import streamlit as st
from streamlit_oauth import OAuth2Component
import boto3
import pandas as pd
import jwt
from botocore.exceptions import ClientError
import time
import pydeck as pdk
from datetime import datetime, timezone
import logging


# TIP (trusted identity propagation) token exchange configuration.
# Replace every <placeholder> with values from your own AWS account.
AWS_REGION = "<your-region-we-used: us-east-1>"
TOKEN_EXCHANGE_APP_ARN = "<ARN of customer managed application RedshiftStreamlitDemo you created in IAM Identity center>"
TOKEN_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"
# Role assumed with the Okta ID token (AssumeRoleWithWebIdentity); its
# temporary credentials are used to call CreateTokenWithIAM and AssumeRole.
TEMP_ROLE_ARN = "arn:aws:iam::<yourawsaccountid>:role/IDCBridgeRole"
# Role assumed with the Identity Center identity context; its credentials
# are the ones used for the Redshift Data API calls.
ENHANCED_ROLE_ARN = "arn:aws:iam::<yourawsaccountid>:role/RedshiftDataAPIClientRole"
IDENHANCED_ROLE_SESSION_NAME = "<name-your-session>"
ROLE_DURATION_SECS = 3600  # 1 hour

# Hard-coded Okta OAuth configuration. Every endpoint is derived from
# OKTA_DOMAIN so the domain only has to be filled in once.
OKTA_DOMAIN = "<your-domain>.okta.com"
AUTHORIZE_URL = f"https://{OKTA_DOMAIN}/oauth2/v1/authorize"
TOKEN_URL = f"https://{OKTA_DOMAIN}/oauth2/v1/token"
REFRESH_TOKEN_URL = f"https://{OKTA_DOMAIN}/oauth2/v1/token"
REVOKE_TOKEN_URL = f"https://{OKTA_DOMAIN}/oauth2/v1/revoke"
LOGOUT_URL = f"https://{OKTA_DOMAIN}/oauth2/v1/logout"

CLIENT_ID = "<aud claim from okta>"
CLIENT_SECRET = "<client secret value corresponding aud from okta>"
# Used both as the OAuth redirect URI and as the post-logout redirect; it must
# match a redirect URI registered for the Okta app. 8501 is Streamlit's
# default local port.
REDIRECT_URI = "http://localhost:8501"
SCOPE = "openid profile email"

WORKGROUP_NAME = "<your-redshift-workgroup-we-used:redshift-tip-enabled>"
DATABASE = "dev"


def execute_statement(sql, redshift_client):
    """
    Submit a SQL statement to Amazon Redshift through the Redshift Data API.

    The statement is submitted against WORKGROUP_NAME / DATABASE and only its
    id is returned; use it with ``wait_for_query`` and ``fetch_results``.

    Args:
        sql (str): The SQL query to execute.
        redshift_client (boto3.client): The Redshift Data API client.

    Returns:
        str: The execution ID of the statement.

    Raises:
        ClientError: If the Data API rejects the call. This includes
            ``ExpiredTokenException`` when the session credentials have
            expired; reacting to that (e.g. logging the user out) is left to
            the caller, so it happens once and no ``None`` id is returned.
    """
    try:
        response = redshift_client.execute_statement(
            WorkgroupName=WORKGROUP_NAME,
            Database=DATABASE,
            Sql=sql,
        )
    except ClientError as e:
        error_code = e.response.get('Error', {}).get('Code', '')
        if error_code == 'ExpiredTokenException':
            logging.error("Session expired while executing statement.")
        else:
            logging.error(f"Error executing statement: {e}")
        raise
    return response["Id"]

def wait_for_query(statement_id, redshift_client, poll_interval=2, timeout=None):
    """
    Poll the Redshift Data API until a statement reaches a terminal state.

    Args:
        statement_id (str): The execution ID of the statement.
        redshift_client (boto3.client): The Redshift Data API client.
        poll_interval (int, optional): Seconds to sleep between status checks. Default is 2.
        timeout (float, optional): Maximum number of seconds to keep polling.
            ``None`` (the default) waits indefinitely.

    Returns:
        dict: The final ``describe_statement`` response. Its ``Status`` is one
        of ``FINISHED``, ``FAILED`` or ``ABORTED``; callers should check it
        before fetching results.

    Raises:
        ClientError: If ``describe_statement`` fails.
        TimeoutError: If ``timeout`` is set and elapses before the statement
            reaches a terminal state.
    """
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        try:
            response = redshift_client.describe_statement(Id=statement_id)
        except ClientError as e:
            logging.error(f"Error checking query status: {e}")
            raise

        if response.get("Status") in ("FINISHED", "FAILED", "ABORTED"):
            return response

        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"Statement {statement_id} did not finish within {timeout} seconds"
            )
        time.sleep(poll_interval)

def fetch_results(statement_id, redshift_client):
    """
    Fetch every result record of a finished Redshift Data API statement.

    ``get_statement_result`` is paginated: when more records are available
    the response carries a ``NextToken``. All pages are read and concatenated
    so large result sets are not silently truncated to the first page.

    Args:
        statement_id (str): The execution ID of the statement.
        redshift_client (boto3.client): The Redshift Data API client.

    Returns:
        list: All records, in order; each record is a list of Data API field dicts.

    Raises:
        ClientError: If any ``get_statement_result`` call fails.
    """
    records = []
    request = {"Id": statement_id}
    try:
        while True:
            response = redshift_client.get_statement_result(**request)
            records.extend(response.get("Records", []))
            next_token = response.get("NextToken")
            if not next_token:
                return records
            request["NextToken"] = next_token
    except ClientError as e:
        logging.error(f"Error fetching query results: {e}")
        raise


# Column order of the DataFrame returned by query_sales_data. Also used so an
# empty result still carries the columns the dashboard filters on.
_SALES_COLUMNS = ["City", "Country", "Latitude", "Longitude", "Total_Sales"]


def _field_value(field):
    """
    Return the scalar stored in one Redshift Data API field dict.

    The Data API stores each value under a type-specific key (e.g.
    ``stringValue``, ``doubleValue``, ``longValue``). Accepting all three keeps
    rows from being dropped just because a column is not returned as a string.

    Raises:
        KeyError: For NULL fields (``{"isNull": True}``) or unsupported value types.
    """
    for key in ("stringValue", "doubleValue", "longValue"):
        if key in field:
            return field[key]
    raise KeyError(f"no supported value in field {field!r}")


def _parse_sales_row(row):
    """
    Convert one Data API record into a dict keyed by ``_SALES_COLUMNS``.

    Raises:
        KeyError, ValueError, TypeError: If the record is malformed, too short,
            or contains NULL values.
    """
    city, country, latitude, longitude, total_sales = (_field_value(field) for field in row[:5])
    return {
        "City": city,
        "Country": country,
        "Latitude": float(latitude),
        "Longitude": float(longitude),
        "Total_Sales": float(total_sales),
    }


def query_sales_data(redshift_client):
    """
    Query total sales per city from Redshift and return them as a pandas DataFrame.

    The aggregation runs through the Data API (submit, poll, fetch) while a
    Streamlit spinner is shown. Records that cannot be parsed (for example
    those containing NULLs) are logged and skipped.

    Args:
        redshift_client (boto3.client): The Redshift Data API client.

    Returns:
        pd.DataFrame: Columns ``City, Country, Latitude, Longitude,
        Total_Sales``. The columns are present even when there are no rows. If
        the statement ends ``FAILED`` or ``ABORTED``, its error is shown with
        ``st.error`` and an empty DataFrame is returned.

    Raises:
        ClientError: Any Data API error, including ``ExpiredTokenException``.
            Reporting it and logging the user out is the caller's job, so that
            it happens exactly once.
    """
    sql = """
        SELECT City, Country, Latitude, Longitude, SUM(Sales_Price * Quantity) AS Total_Sales
        FROM public.sales_data
        GROUP BY City, Country, Latitude, Longitude
    """

    with st.spinner("Loading sales data..."):
        try:
            statement_id = execute_statement(sql, redshift_client)
            description = wait_for_query(statement_id, redshift_client)
            status = description.get("Status")
            if status != "FINISHED":
                # A failed/aborted statement has no result set; report its
                # own error instead of attempting to fetch results.
                message = f"Sales query {status}: {description.get('Error', 'no error details')}"
                logging.error(message)
                st.error(message)
                return pd.DataFrame(columns=_SALES_COLUMNS)
            results = fetch_results(statement_id, redshift_client)
        except ClientError as e:
            logging.error(f"Error fetching data: {e}")
            raise

    data = []
    for row in results:
        try:
            data.append(_parse_sales_row(row))
        except (KeyError, ValueError, TypeError) as e:
            logging.error(f"Error processing row: {row}, Error: {e}")
    return pd.DataFrame(data, columns=_SALES_COLUMNS)


def assume_role_with_web_identity(jwt_token):
    """
    Exchange an IdP-issued JWT for temporary AWS credentials via STS.

    Calls ``AssumeRoleWithWebIdentity`` for ``TEMP_ROLE_ARN`` with a session
    lasting ``ROLE_DURATION_SECS``. The token's claims are decoded WITHOUT
    signature verification purely so they can be written to the debug log.

    Args:
        jwt_token (str): ID token issued by the external identity provider.

    Returns:
        dict or None: The ``Credentials`` mapping from the STS response
        (``AccessKeyId``, ``SecretAccessKey``, ``SessionToken``, ...), or
        ``None`` if any step fails. Failures are logged, never raised.
    """
    try:
        sts = boto3.client('sts', region_name=AWS_REGION)

        # Debug aid only: the signature is deliberately not checked here.
        claims = jwt.decode(jwt_token, options={"verify_signature": False})
        logging.debug(f"Decoded JWT Token: {claims}")

        response = sts.assume_role_with_web_identity(
            RoleArn=TEMP_ROLE_ARN,
            RoleSessionName='WebIdentitySession',
            WebIdentityToken=jwt_token,
            DurationSeconds=ROLE_DURATION_SECS,
        )
        credentials = response['Credentials']
        logging.info("Temporary credentials successfully obtained.")
        return credentials
    except ClientError as err:
        logging.error(f"Error calling AssumeRoleWithWebIdentity: {err}")
    except jwt.ExpiredSignatureError:
        logging.error("JWT token has expired.")
    except jwt.DecodeError:
        logging.error("Error decoding JWT token.")
    except Exception as err:
        logging.error(f"Unexpected error: {err}")
    return None


def create_token_with_iam(jwt_token, temp_credentials):
    """
    Exchange an IdP JWT for an IAM Identity Center ID token (CreateTokenWithIAM).

    The SSO OIDC client is authenticated with ``temp_credentials``. The JWT is
    presented as a ``jwt-bearer`` assertion to the customer managed
    application ``TOKEN_EXCHANGE_APP_ARN``.

    Args:
        jwt_token (str): The JWT token to exchange.
        temp_credentials (dict): Temporary AWS credentials with ``AccessKeyId``,
            ``SecretAccessKey`` and ``SessionToken``.

    Returns:
        str or None: The ``idToken`` from the response, or None on any failure
        (failures are logged, not raised).
    """
    logging.info("Starting token creation process with IAM.")

    # Initialize the SSO OIDC client with temporary credentials
    try:
        sso_oidc_client = boto3.client(
            'sso-oidc',
            region_name=AWS_REGION,
            aws_access_key_id=temp_credentials['AccessKeyId'],
            aws_secret_access_key=temp_credentials['SecretAccessKey'],
            aws_session_token=temp_credentials['SessionToken']
        )
    except Exception as e:
        logging.error(f"Error initializing SSO OIDC client: {e}")
        return None

    try:
        token_result = sso_oidc_client.create_token_with_iam(
            clientId=TOKEN_EXCHANGE_APP_ARN,
            grantType=TOKEN_GRANT_TYPE,
            assertion=jwt_token,
        )
        id_token = token_result['idToken']
    except ClientError as e:
        logging.error(f"Error calling CreateTokenWithIAM API: {e}")
        return None
    except KeyError as e:
        logging.error(f"Missing expected field in response: {e}")
        return None

    # Never log the token itself: its identity context is what grants the
    # identity-enhanced AWS session, so it must be treated as a secret.
    logging.info("Successfully obtained ID Token from CreateTokenWithIAM.")
    return id_token


def extract_identity_context_from_id_token(id_token):
    """
    Pull the identity context claim out of an encoded ID token.

    The token is decoded without verifying its signature. The
    ``sts:identity_context`` claim is preferred; ``sts:audit_context`` is
    used only when the former is absent.

    Args:
        id_token (str): Encoded JWT, e.g. the one returned by CreateTokenWithIAM.

    Returns:
        The value of the first matching claim, or None if neither claim is
        present or the token cannot be decoded (errors are logged, not raised).
    """
    logging.info("Decoding ID token to extract identity context.")

    try:
        claims = jwt.decode(id_token, options={"verify_signature": False})
        logging.debug(f"Decoded JWT Claims: {claims}")

        present = [name for name in ('sts:identity_context', 'sts:audit_context') if name in claims]
        if present:
            return claims[present[0]]

        logging.warning("No valid identity context found in the token.")
        return None
    except Exception as err:
        logging.error(f"Error decoding JWT: {err}")
        return None


def assume_enhanced_role_session(id_token, temp_credentials):
    """
    Assume ``ENHANCED_ROLE_ARN`` as an identity-enhanced role session.

    The identity context is read from ``id_token`` and passed to STS
    ``AssumeRole`` as a provided context from the IAM Identity Center
    context provider. The STS client is signed with ``temp_credentials``.

    Args:
        id_token (str): ID token carrying the identity context claim.
        temp_credentials (dict): Temporary AWS credentials with ``AccessKeyId``,
            ``SecretAccessKey`` and ``SessionToken``.

    Returns:
        dict or None: The ``Credentials`` of the new session, or None if the
        identity context is missing or STS returns an error.
    """
    logging.info("Extracting identity context from ID token.")
    context_assertion = extract_identity_context_from_id_token(id_token)
    if not context_assertion:
        logging.error("Failed to extract identity context from ID token.")
        return None

    try:
        sts = boto3.client(
            'sts',
            region_name=AWS_REGION,
            aws_access_key_id=temp_credentials['AccessKeyId'],
            aws_secret_access_key=temp_credentials['SecretAccessKey'],
            aws_session_token=temp_credentials['SessionToken'],
        )

        provided_context = {
            'ContextAssertion': context_assertion,
            'ProviderArn': "arn:aws:iam::aws:contextProvider/IdentityCenter",
        }

        logging.info("Calling STS AssumeRole for identity-enhanced session.")
        response = sts.assume_role(
            RoleArn=ENHANCED_ROLE_ARN,
            RoleSessionName=IDENHANCED_ROLE_SESSION_NAME,
            DurationSeconds=ROLE_DURATION_SECS,
            ProvidedContexts=[provided_context],
        )
        credentials = response['Credentials']
    except ClientError as err:
        logging.error(f"Error calling AssumeRole: {err}")
        return None

    logging.info("Successfully assumed enhanced role session.")
    return credentials


def get_id_enhanced_session(jwt_token):
    """
    Run the full token exchange that yields identity-enhanced AWS credentials.

    The chain is: IdP JWT -> temporary bridge-role credentials ->
    Identity Center ID token -> identity-enhanced role session. The chain
    stops at the first step that fails.

    Args:
        jwt_token (str): The ID token issued by the identity provider.

    Returns:
        dict or None: Credentials of the identity-enhanced session, or None if
        any step failed (each failure is logged).
    """
    logging.info("Starting identity-enhanced session process.")

    # 1. IdP token -> temporary credentials for the bridge role.
    bridge_creds = assume_role_with_web_identity(jwt_token)
    if not bridge_creds:
        logging.error("Failed to assume role with web identity.")
        return None

    # 2. Same IdP token -> Identity Center ID token, called with the bridge credentials.
    idc_token = create_token_with_iam(jwt_token, bridge_creds)
    if not idc_token:
        logging.error("Failed to create ID token with IAM.")
        return None

    # 3. Identity context from that token -> identity-enhanced role session.
    session_creds = assume_enhanced_role_session(idc_token, bridge_creds)
    if not session_creds:
        logging.error("Failed to assume enhanced role session.")
        return None

    logging.info("Successfully obtained identity-enhanced session credentials.")
    return session_creds
        
def get_user_email_from_token(id_token):
    """
    Return the ``email`` claim of an ID token, for display in the UI.

    The token is decoded without signature verification. A missing claim, a
    missing/malformed token, or any decoding error all yield "Unknown User".
    """
    fallback = "Unknown User"
    try:
        return jwt.decode(id_token, options={"verify_signature": False}).get("email", fallback)
    except Exception:
        return fallback


def logout():
    """
    Log the user out of the app and, when possible, out of Okta too.

    Always clears every key in Streamlit session state, which drops the OAuth
    tokens, cached AWS credentials and cached query results. If an Okta ID
    token was stored, it writes an HTML meta-refresh that sends the browser to
    Okta's logout endpoint, with that token as ``id_token_hint`` and
    REDIRECT_URI as the post-logout redirect. Otherwise it shows an error
    asking the user to log in again.

    Safe to call repeatedly. It works when ``token`` is ``None`` (its initial
    value) or already removed by a previous logout.
    """
    from urllib.parse import urlencode

    # .get() avoids AttributeError when state was already cleared; `or {}`
    # covers the initial None value.
    token = st.session_state.get("token") or {}
    id_token = token.get("id_token")

    # Clear Streamlit session even if Okta logout is impossible, so no stale
    # credentials survive.
    for key in list(st.session_state.keys()):
        del st.session_state[key]

    if id_token:
        # URL-encode the parameters; REDIRECT_URI contains ':' and '/'.
        query = urlencode({
            "id_token_hint": id_token,
            "post_logout_redirect_uri": REDIRECT_URI,
        })
        logout_url = f"{LOGOUT_URL}?{query}"

        # Redirect to Okta logout
        st.write(f'<meta http-equiv="refresh" content="0; URL={logout_url}">', unsafe_allow_html=True)
    else:
        st.error("No valid ID token found. Please log in again.")


def is_token_expired():
    """
    Return True when the identity-enhanced AWS credentials are unusable.

    Reads ``st.session_state["aws_creds"]``, the credentials dict stored at
    login. All of the following count as expired, so the user is sent back
    to the login screen instead of crashing:
    - the key is missing;
    - it holds ``None`` (the token exchange failed);
    - it has no ``Expiration`` datetime.
    """
    creds = st.session_state.get("aws_creds")
    if not creds:
        return True
    expiration = creds.get("Expiration")
    if isinstance(expiration, datetime):
        return expiration <= datetime.now(timezone.utc)
    return True

def _filter_sales(df, sales_range, city_filter, country_filter):
    """
    Return the rows of ``df`` that match the sidebar filters.

    Args:
        df (pd.DataFrame): Sales data with ``Total_Sales``, ``City`` and ``Country`` columns.
        sales_range (tuple): Inclusive ``(min, max)`` bounds on ``Total_Sales``.
        city_filter (str): Case-insensitive substring to find in ``City``; empty disables it.
        country_filter (str): Case-insensitive substring to find in ``Country``; empty disables it.

    Returns:
        pd.DataFrame: The matching rows.
    """
    min_sales, max_sales = sales_range
    mask = df["Total_Sales"].between(min_sales, max_sales)
    # regex=False: the filters are free-text user input and must be matched
    # literally (e.g. "St. Louis" or "(" must not be parsed as a regex).
    if city_filter:
        mask &= df["City"].str.contains(city_filter, case=False, regex=False, na=False)
    if country_filter:
        mask &= df["Country"].str.contains(country_filter, case=False, regex=False, na=False)
    return df[mask]


def _render_sales_map(df):
    """
    Render ``df`` as a pydeck bubble map centred on the mean coordinate.

    Expects ``Latitude``, ``Longitude``, ``City``, ``Total_Sales`` and a
    precomputed ``Scaled_Sales`` column, which is used as the bubble radius.
    """
    layer = pdk.Layer(
        "ScatterplotLayer",
        data=df,
        get_position="[Longitude, Latitude]",
        get_radius="Scaled_Sales",
        get_fill_color="[200, 30, 0, 160]",
        pickable=True,
        auto_highlight=True,
    )

    view_state = pdk.ViewState(
        latitude=df["Latitude"].mean(),
        longitude=df["Longitude"].mean(),
        zoom=3,
        pitch=0,
    )

    st.pydeck_chart(pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={"html": "<b>{City}</b>: ${Total_Sales}"},
        map_style="light"
    ))


def _render_login(oauth2):
    """
    Show the Okta login button and sign the user in once Okta redirects back.

    On a successful callback the Okta token is stored, the ID token is
    exchanged for identity-enhanced AWS credentials, and the app reruns. If
    the exchange fails, the user is NOT marked as authenticated and an error
    is shown; the failing step is written to the application log.

    Args:
        oauth2 (OAuth2Component): Configured Okta OAuth2 component.
    """
    st.title("Login")
    result = oauth2.authorize_button("Login with SSO", REDIRECT_URI, SCOPE)
    if not (result and "token" in result):
        return

    # Store the token so logout() can pass it to Okta as id_token_hint.
    token = result.get("token") or {}
    st.session_state.token = token
    id_token = token.get("id_token")
    st.session_state.user_email = get_user_email_from_token(id_token)

    aws_creds = get_id_enhanced_session(id_token)
    if not aws_creds:
        # Marking the user authenticated without credentials would bounce them
        # straight back to this screen on the next rerun with no explanation.
        st.error("Could not obtain AWS credentials for your identity. Please check the logs and try again.")
        return

    st.session_state.aws_creds = aws_creds
    st.session_state.is_authenticated = True
    st.rerun()


def _load_sales_data():
    """
    Return this session's sales DataFrame, querying Redshift at most once.

    Streamlit reruns the whole script on every widget interaction, so the
    result is cached in ``st.session_state.query_result``; moving a filter
    therefore does not trigger a new Redshift query. Empty results are not
    cached, so a failed load is retried on the next rerun.

    Returns:
        pd.DataFrame or None: The data, or None if the AWS session expired
        (in that case the user has been logged out).

    Raises:
        ClientError: For Data API errors other than ``ExpiredTokenException``.
    """
    cached = st.session_state.get("query_result")
    if cached is not None:
        return cached

    creds = st.session_state.aws_creds
    try:
        # Use the enhanced credentials to create the Redshift client
        redshift_client = boto3.client(
            "redshift-data",
            region_name=AWS_REGION,
            aws_access_key_id=creds['AccessKeyId'],
            aws_secret_access_key=creds['SecretAccessKey'],
            aws_session_token=creds['SessionToken'],
        )
        df = query_sales_data(redshift_client)
    except ClientError as e:
        if e.response.get('Error', {}).get('Code') == 'ExpiredTokenException':
            st.error("Session expired. Logging out...")
            logout()
            return None
        st.error(f"Unexpected error: {e}")
        raise

    if not df.empty:
        st.session_state.query_result = df
    return df


def main():
    """
    Streamlit entry point for the sales dashboard.

    - A user who is not authenticated, or whose AWS credentials have expired,
      gets the Okta login button.
    - After login, sales data is loaded from Redshift once per session with
      the identity-enhanced credentials.
    - The table and map are then filtered locally from the sidebar controls.
    """
    import math

    # Create OAuth2Component instance
    oauth2 = OAuth2Component(
        CLIENT_ID,
        CLIENT_SECRET,
        AUTHORIZE_URL,
        TOKEN_URL,
        REFRESH_TOKEN_URL,
        REVOKE_TOKEN_URL)

    st.set_page_config(page_title="Sales Dashboard", layout="wide")

    # Initialize session state variables for user management
    if "is_authenticated" not in st.session_state:
        st.session_state.is_authenticated = False
    if "token" not in st.session_state:
        st.session_state.token = None
    if "query_result" not in st.session_state:
        st.session_state.query_result = None

    # Handle authentication
    if not st.session_state.is_authenticated or is_token_expired():
        _render_login(oauth2)
        return

    # NOTE: demo only -- st.json renders the complete credentials, including
    # the secret access key and session token, on the page.
    st.text('Identity-enhanced role session credentials:')
    st.json(st.session_state.aws_creds)
    st.title("Total Sales by City")

    # The credentials may expire between the check above and this point.
    if is_token_expired():
        st.error("Session expired. Please re-authenticate.")
        logout()
        return

    df = _load_sales_data()
    if df is None:
        return  # The AWS session expired and the user has been logged out.

    # Size the slider from the data so its default (full) range never hides
    # the highest-selling cities; the original 50000 remains the minimum.
    slider_step = 200
    slider_max = 50000
    if not df.empty:
        largest_total = df["Total_Sales"].max()
        if pd.notna(largest_total):
            slider_max = max(slider_max, int(math.ceil(largest_total / slider_step)) * slider_step)

    # Filter controls in the sidebar
    with st.sidebar:
        st.header("Filters")
        sales_range = st.slider("Sales Range", 0, slider_max, (0, slider_max), step=slider_step)
        city_filter = st.text_input("City", "")
        country_filter = st.text_input("Country", "")

        st.divider()
        st.markdown(f"**Logged in as:** {st.session_state.user_email}")
        if st.button("Logout"):
            logout()
            return

    if df.empty:
        st.warning("No sales data was returned by the query.")
        return

    # Filtering happens locally on the cached DataFrame (no backend call).
    filtered_df = _filter_sales(df, sales_range, city_filter, country_filter)
    if filtered_df.empty:
        st.warning("No sales data available for the applied filters.")
        return

    # Normalize bubble size so the largest visible total gets radius 150000;
    # guard against an all-zero selection, which would divide by zero.
    max_sales_value = filtered_df["Total_Sales"].max()
    scale = 150000 / max_sales_value if max_sales_value > 0 else 0.0
    filtered_df = filtered_df.assign(Scaled_Sales=filtered_df["Total_Sales"] * scale)

    col1, col2 = st.columns([1, 2])

    with col1:
        st.subheader("Sales Data Table")
        st.dataframe(filtered_df, height=600)  # Fixed height for the table

    with col2:
        st.subheader("Sales Distribution Map")
        _render_sales_map(filtered_df)


if __name__ == "__main__":
    main()