import sys
import json
import os
import pandas as pd
import mysql.connector
import numpy as np

# uses mysql-connector-python, to be installed with pip install mysql-connector-python
# uses pandas, to be installed with pip install pandas

# Database connection settings shared by the functions in this module.
# They start out empty; predict_loadData overwrites them with the values from
# the config, and predict_saveResults reuses them to open its own connection.
_dbHost = ""
_dbUser = ""
_dbPassword = ""
_dbName = ""


def exitWithError(error_code: int, error_message: str, additional_data: dict = None):
    """Report an error on stdout and in ``output.json``, then terminate.

    The JSON report is written to ``output.json`` in the current working
    directory and contains ``errorCode`` and ``errorMessage`` plus every
    key of ``additional_data`` (when given and non-empty).

    Args:
        error_code: Value used both in the report and as process exit code.
        error_message: Human-readable description of the failure.
        additional_data: Optional extra key/value pairs merged into the
            report; keys equal to ``errorCode``/``errorMessage`` overwrite
            the defaults.

    Raises:
        SystemExit: Always, carrying ``error_code``.
    """
    try:
        print(f"Error {error_code}: {error_message}")

        report = {
            "errorCode": error_code,
            "errorMessage": error_message
        }
        if additional_data:
            report.update(additional_data)

        with open("output.json", 'w') as report_file:
            json.dump(report, report_file)
    except Exception as e:
        # Reporting is best effort: a failure here must not prevent the exit.
        print(f"Failed to log error: {e}")

    # Single exit point, reached whether or not the report could be written.
    sys.exit(error_code)



def predict_initialize():
    """Set up a prediction run from the temp directory given on the command line.

    Changes into the directory passed as the single command-line argument,
    loads the config, extracts the regressors, loads the historical and
    future data and pulls the optional settings out of the config.

    Exit codes 1-4 are reported with a plain ``print`` and ``sys.exit``;
    a missing ``resultId`` is reported through ``exitWithError`` with code 5.

    Returns:
        A tuple ``(config, resultId, output_data_file, regressors, histData,
        futureData, cutoff_minimum, log_transform,
        anomaly_threshold_n_stddev, onlyModelAnalysis)``. ``config`` holds
        only the keys that were not consumed here or by the helpers called.
    """
    print("--Initializing prediction module--")

    # Exactly one argument is expected: the path to the temp directory.
    if len(sys.argv) != 2:
        print("Error: Missing argument: Path to temp directory.")
        sys.exit(1)

    input_directory = sys.argv[1]

    try:
        os.chdir(input_directory)
    except FileNotFoundError:
        print(f"Error: Directory {input_directory} does not exist.")
        sys.exit(2)
    except PermissionError:
        print(f"Error: Permission denied to access directory {input_directory}.")
        sys.exit(3)
    except Exception as e:
        print(f"An unexpected error occurred: {str(e)} while accessing the directory {input_directory}.")
        sys.exit(4)

    config = predict_loadConfig()

    if "resultId" not in config:
        exitWithError(5, "Error: 'resultId' parameter is missing.")
    resultId = config.pop("resultId")

    regressors = predict_getRegressors(config)
    histData, futureData = predict_loadData(config, regressors)

    # The optional settings are taken out of the config only now, because
    # predict_loadData itself still reads 'onlyModelAnalysis' from it.
    onlyModelAnalysis = config.pop("onlyModelAnalysis", False)
    cutoff_minimum = config.pop("cutoff_minimum", None)
    log_transform = config.pop("log_transform", False)
    # ToDo: check for positive float > 0
    anomaly_threshold_n_stddev = config.pop("anomaly_threshold_n_stddev", 3.0)

    if log_transform:
        # Adding 1 to avoid log(0)
        histData['y'] = np.log(histData['y'] + 1)

    output_data_file = "output.json"

    return (config, resultId, output_data_file, regressors, histData, futureData,
            cutoff_minimum, log_transform, anomaly_threshold_n_stddev, onlyModelAnalysis)



def predict_loadConfig():
    """Read ``config.json`` from the current directory and decode it.

    Failures are reported through ``exitWithError``: code 10 (file missing),
    11 (permission denied), 12 (other read error), 13 (invalid JSON) and
    14 (other decoding error).

    Returns:
        The decoded JSON content of the config file.
    """
    print("--Loading config--")
    input_config_file = "config.json"

    # Step 1: fetch the raw text of the config file.
    try:
        with open(input_config_file) as config_handle:
            raw_config_text = config_handle.read()
    except FileNotFoundError:
        exitWithError(10, f"File {input_config_file} does not exist.")
    except PermissionError:
        exitWithError(11, f"Permission denied to open config file {input_config_file}.")
    except Exception as read_error:
        exitWithError(12, f"An unexpected error occurred: {str(read_error)} when opening and loading the config file {input_config_file}.")

    # Step 2: decode the text; json.JSONDecodeError is a ValueError subclass.
    try:
        parsed_config = json.loads(raw_config_text)
    except ValueError:
        exitWithError(13, "Could not decode JSON from the specified config file.")
    except Exception as decode_error:
        exitWithError(14, f"An unexpected error occurred: {str(decode_error)} while decoding the specified config file.")

    return parsed_config


# regressors: list of regressors to be used in the prediction
# columns: 'name', 'numDaysSince', 'numDaysBefore'

def predict_getRegressors(config):
    """Remove the ``regressors`` entry from ``config`` and return it.

    Args:
        config: Configuration mapping; its ``regressors`` key is popped.

    Returns:
        The value that was stored under ``config["regressors"]``.

    A missing key is reported through ``exitWithError`` (code 20, or 21 if
    it vanishes between check and removal); any other failure uses code 22.
    """
    print("--getRegressors--")

    if "regressors" not in config:
        exitWithError(20, "Error: 'regressors' key not found in config.")

    try:
        # pop() hands back the stored value and removes the key in one step.
        loaded_regressors = config.pop("regressors")

        print("Regressors loaded:")
        for entry in loaded_regressors:
            print(f"  {entry}")

        return loaded_regressors

    except KeyError:
        exitWithError(21, "Error: 'regressors' key not found in config.")
    except Exception as e:
        exitWithError(22, f"Runtime error: {e}")


def predict_generateRegressorsDaysSince(df, regressors):
    """Add a ``ds_<name>`` "days since" column for every eligible regressor.

    For each regressor whose ``numDaysSince`` is at least 1, the rows that
    follow a row where the regressor column equals 1 are numbered 1, 2, ...
    up to ``numDaysSince`` (stopping early at the next such row or at the
    end of the frame). All other rows get 0.

    Args:
        df: DataFrame holding one column per regressor; modified in place.
        regressors: List of dicts with the keys ``name``, ``numDaysSince``
            and ``numDaysBefore``; extended in place with one entry per
            newly created ``ds_<name>`` column.

    Rows are addressed by position, not by index label, so the result is
    correct even when ``df`` no longer has a 0..n-1 index (for example after
    rows were dropped in place). A derived regressor is appended only once,
    so calling this function for a second frame with the same list does not
    create duplicate entries.
    """
    new_regressors = []
    # Names already registered, used to avoid duplicate derived regressors.
    known_names = {regressor["name"] for regressor in regressors}

    for regressor in regressors:

        numDaysSince = regressor["numDaysSince"]
        if numDaysSince is None or numDaysSince < 1:
            continue

        regressorName = regressor["name"]
        derivedName = f"ds_{regressorName}"

        print(f"Generating 'days since' column for regressor: {regressorName}, with a maximum of {numDaysSince} days.")

        # Create an array to hold the result
        records_since_1 = np.zeros(len(df), dtype=int)

        # Row positions (not index labels) where the regressor is 1
        ones_positions = np.flatnonzero((df[regressorName] == 1).to_numpy()).tolist()

        # Number the rows after each 1, up to the next 1 or numDaysSince rows
        for i, position in enumerate(ones_positions):
            start_idx = position + 1
            end_idx = ones_positions[i + 1] if i + 1 < len(ones_positions) else len(df)
            length = min(numDaysSince, end_idx - start_idx)
            records_since_1[start_idx:start_idx + length] = range(1, length + 1)

        # Assign the result to the dataframe (positional, array has len(df) entries)
        df[derivedName] = records_since_1

        # Register the 'days since' column as a regressor, but only once
        if derivedName not in known_names:
            known_names.add(derivedName)
            new_regressors.append({"name": derivedName, "numDaysSince": 0, "numDaysBefore": 0})

    regressors.extend(new_regressors)

    print("New Regressors added for 'days since':")
    print(new_regressors)
    print("First 10 rows of the dataframe after generating 'days since' columns:")
    print(df.head(10))  # Print the first 10 records



def predict_generateRegressorsDaysBefore(df, regressors):
    """Add a ``db_<name>`` "days before" column for every eligible regressor.

    For each regressor whose ``numDaysBefore`` is at least 1, the rows that
    precede a row where the regressor column equals 1 are numbered
    ..., 2, 1 (counting down towards that row), covering at most
    ``numDaysBefore`` rows and never reaching before the first row.
    All other rows get 0.

    Args:
        df: DataFrame holding one column per regressor; modified in place.
        regressors: List of dicts with the keys ``name``, ``numDaysSince``
            and ``numDaysBefore``; extended in place with one entry per
            newly created ``db_<name>`` column.

    Rows are addressed by position, not by index label, so the result is
    correct even when ``df`` no longer has a 0..n-1 index (for example after
    rows were dropped in place). A derived regressor is appended only once,
    so calling this function for a second frame with the same list does not
    create duplicate entries.
    """
    new_regressors = []
    # Names already registered, used to avoid duplicate derived regressors.
    known_names = {regressor["name"] for regressor in regressors}

    for regressor in regressors:

        numDaysBefore = regressor["numDaysBefore"]
        if numDaysBefore is None or numDaysBefore < 1:
            continue

        regressorName = regressor["name"]
        derivedName = f"db_{regressorName}"

        print(f"Generating 'days before' column for regressor: {regressorName}, with a maximum of {numDaysBefore} days.")

        # Create an array to hold the result
        records_before_1 = np.zeros(len(df), dtype=int)

        # Row positions (not index labels) where the regressor is 1
        ones_positions = np.flatnonzero((df[regressorName] == 1).to_numpy()).tolist()

        # NOTE(review): when two 1-rows are closer together than numDaysBefore,
        # the later countdown overwrites part of the earlier one (including the
        # earlier 1-row itself) - confirm this is the intended behaviour.
        for end_idx in ones_positions:
            # Start no earlier than the beginning of the DataFrame
            start_idx = max(0, end_idx - numDaysBefore)
            length = end_idx - start_idx
            # Decreasing count from length down to 1
            records_before_1[start_idx:end_idx] = range(length, 0, -1)

        # Assign the new column to the DataFrame (positional, array has len(df) entries)
        df[derivedName] = records_before_1

        # Register the 'days before' column as a regressor, but only once
        if derivedName not in known_names:
            known_names.add(derivedName)
            new_regressors.append({"name": derivedName, "numDaysSince": 0, "numDaysBefore": 0})

    regressors.extend(new_regressors)

    print("New Regressors added for 'days before':")
    print(new_regressors)
    print("First 10 rows of the dataframe after generating 'days before' columns:")
    print(df.head(10))  # Print the first 10 records



def predict_dropRegressorsWithNaNValues(df, df2, regressors):
    """Drop every column of ``df`` that contains NaN, together with its regressor.

    All columns of ``df`` except the first one are inspected. Columns with
    at least one NaN are removed from ``df`` and, where present, from
    ``df2``; regressors whose ``name`` matches a dropped column are removed
    from ``regressors``.

    Args:
        df: DataFrame to inspect; modified in place.
        df2: Companion DataFrame that should lose the same columns; modified
            in place. Columns it does not contain are skipped, so an empty
            or narrower frame is accepted.
        regressors: List of regressor dicts (key ``name``); filtered in place.

    A ``KeyError`` is reported through ``exitWithError`` with code 30, any
    other failure with code 31.
    """
    try:
        columns_to_check = df.columns[1:]  # Skip the first column (likely 'ds')

        # Columns that contain at least one NaN value
        columns_to_drop = [column for column in columns_to_check if df[column].isna().any()]

        if columns_to_drop:
            df.drop(columns=columns_to_drop, inplace=True)
            # df2 does not necessarily contain every column of df,
            # so missing columns must not raise here.
            df2.drop(columns=columns_to_drop, inplace=True, errors="ignore")

        # Remove the regressors that belong to the dropped columns (list kept in place)
        dropped_names = set(columns_to_drop)
        regressors_to_remove = [r for r in regressors if r["name"] in dropped_names]
        regressors[:] = [r for r in regressors if r["name"] not in dropped_names]

        # Print the dropped columns and regressors for verification
        print("Columns dropped from both data frames:", columns_to_drop)
        print("Regressors dropped:", regressors_to_remove)

    except KeyError as ke:
        exitWithError(30, f"KeyError: {ke} - The specified column might not exist.")
    except Exception as e:
        exitWithError(31, f"Runtime error: {e}")



def predict_prepareDataframe (df):
    """Clean ``df`` in place so that ``ds`` is datetime and ``y`` is numeric.

    Steps, all performed in place:
      1. drop rows with missing ``ds`` (and missing ``y`` if that column exists),
      2. parse ``ds`` with the format ``%Y-%m-%d %H:%M:%S`` and drop rows
         that cannot be parsed,
      3. if ``y`` exists, convert it to numeric and drop rows that cannot
         be converted.

    The index is not reset, so it may contain gaps afterwards.

    Args:
        df: DataFrame with a ``ds`` column and optionally a ``y`` column.
    """
    has_target = 'y' in df.columns

    print("Initial dataframe shape:", df.shape)
    print("First 10 rows of the dataframe:")
    print(df.head(10))  # Print the first 10 records

    # Step 1: rows without a timestamp (or without a target value) are useless.
    if has_target:
        df.dropna(subset=['ds', 'y'], inplace=True)
        print(f"Rows after dropping missing values in 'ds' or 'y': {df.shape[0]}")
    else:
        df.dropna(subset=['ds'], inplace=True)
        print(f"Rows after dropping missing values in 'ds': {df.shape[0]}")

    # Step 2: parse the timestamps; unparsable values become NaT and are removed.
    print("Converting 'ds' column to datetime format.")
    df['ds'] = pd.to_datetime(df['ds'], format='%Y-%m-%d %H:%M:%S', errors='coerce')

    unparsed_rows = df['ds'].isna().sum()
    if unparsed_rows > 0:
        print(f"Warning: {unparsed_rows} values could not be converted to datetime and were set to NaT. These rows will be removed.")
        df.dropna(subset=['ds'], inplace=True)

    # Step 3: force the target to numeric; non-numeric values become NaN and are removed.
    if has_target:
        df['y'] = pd.to_numeric(df['y'], errors='coerce')
        non_numeric_rows = df['y'].isna().sum()
        if non_numeric_rows > 0:
            print(f"Warning: {non_numeric_rows} non-numeric values in 'y' column. These rows will be removed.")
            df.dropna(subset=['y'], inplace=True)

    # Print final summary
    remaining_missing = df.isna().sum().sum()
    print("Final dataframe shape after all preprocessing:", df.shape)
    print(f"Total missing values in the cleaned dataframe: {remaining_missing}")



def predict_loadData(config, regressors) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Load and prepare the historical and future data from the database.

    The connection settings and view names are taken (and removed) from
    ``config``; the connection settings are additionally stored in the
    module-level ``_db*`` variables. ``onlyModelAnalysis`` is only read, not
    removed. Both frames are cleaned, columns with NaN values are dropped,
    the "days since"/"days before" regressor columns are generated and each
    frame is written to a CSV file.

    Args:
        config: Configuration mapping. Required keys: ``db_host``,
            ``db_user``, ``db_password``, ``db_name``, ``histViewName`` and,
            unless ``onlyModelAnalysis`` is truthy, ``futureViewName``.
        regressors: List of regressor dicts; modified in place by the
            preparation helpers.

    Returns:
        ``(histData, futureData)``. ``futureData`` is an empty DataFrame
        when ``onlyModelAnalysis`` is truthy.

    Every failure is reported through ``exitWithError`` (codes 50-64). The
    database connection is closed before the function returns or exits.
    """
    global _dbHost, _dbUser, _dbPassword, _dbName

    def _takeRequired(key, error_code):
        # Remove a mandatory setting from config, or exit with its error code.
        if key not in config:
            exitWithError(error_code, f"Error: '{key}' parameter is missing.")
        return config.pop(key)

    onlyModelAnalysis = config.get('onlyModelAnalysis', False)

    _dbHost = _takeRequired('db_host', 50)
    _dbUser = _takeRequired('db_user', 51)
    _dbPassword = _takeRequired('db_password', 52)
    _dbName = _takeRequired('db_name', 53)
    histViewName = _takeRequired('histViewName', 54)

    # The future view is only needed when an actual forecast is requested.
    futureViewName = None
    if not onlyModelAnalysis:
        futureViewName = _takeRequired('futureViewName', 55)

    try:
        conn = mysql.connector.connect(
            host=_dbHost,
            user=_dbUser,
            password=_dbPassword,
            database=_dbName
        )
    except mysql.connector.Error as e:
        exitWithError(56, f"Error: Could not connect to the database: {str(e)}")
    except Exception as e:
        exitWithError(57, f"An unexpected error occurred: {str(e)} while connecting to the database.")

    try:
        # NOTE(review): the view names are interpolated into the SQL text as-is
        # (identifiers cannot be bound as parameters) - they must come from a
        # trusted config.
        print("--load histData--")
        SQL = f"SELECT * FROM {histViewName}"
        histData = pd.read_sql(SQL, conn)

        if not onlyModelAnalysis:
            print("--load futureData--")
            SQL = f"SELECT * FROM {futureViewName}"
            futureData = pd.read_sql(SQL, conn)
        else:
            futureData = pd.DataFrame()

        print("--prepare histData--")
        predict_prepareDataframe(histData)
        predict_dropRegressorsWithNaNValues(histData, futureData, regressors)
        predict_generateRegressorsDaysSince(histData, regressors)
        predict_generateRegressorsDaysBefore(histData, regressors)
        if histData.empty:
            exitWithError(58, "Error: The historical data is empty.")

        # NOTE(review): fixed absolute output path, independent of the
        # working directory - confirm this is intended.
        histData.to_csv(
            "/var/www/prophettmp/histData.csv",
            index=True,
            sep=";",
            header=True,
            na_rep="NA",
            encoding="utf-8"
        )

        if not onlyModelAnalysis:
            print("--prepare futureData--")
            predict_prepareDataframe(futureData)
            predict_dropRegressorsWithNaNValues(futureData, histData, regressors)
            predict_generateRegressorsDaysSince(futureData, regressors)
            predict_generateRegressorsDaysBefore(futureData, regressors)
            if futureData.empty:
                exitWithError(59, "Error: The future data is empty.")

            futureData.to_csv(
                "/var/www/prophettmp/futureData.csv",
                index=True,
                sep=";",
                header=True,
                na_rep="NA",
                encoding="utf-8"
            )

    except KeyError as ke:
        exitWithError(60, f"Error: a key was not found - {ke}")
    except ValueError as ve:
        exitWithError(61, f"Error: Value error occurred - {ve}")
    except IndexError as ie:
        exitWithError(62, f"Error: Index error occurred - {ie}")
    except mysql.connector.Error as e:
        exitWithError(63, f"Error: Could not read data from the database: {str(e)}")
    except Exception as e:
        exitWithError(64, f"An unexpected error occurred: {e} while reading data from the database.")
    finally:
        # Release the connection on success as well as on every exit path
        # (exitWithError raises SystemExit, which still runs this block).
        conn.close()

    return histData, futureData



def predict_saveResults(resultId, result_type, df, columns_to_insert, wMAPE):
    """Write the selected columns of ``df`` to the ``prediction_results`` table.

    For every row of ``df`` one multi-value INSERT is executed that stores,
    per selected column, the tuple ``(resultId, result_type, ds, k, v)``
    where ``k`` is the id paired with the column and ``v`` the cell value
    (missing values are stored as 0). All inserts are committed together at
    the end.

    Args:
        resultId: Value for the ``result_id`` column.
        result_type: Value for the ``result_type`` column.
        df: DataFrame with a ``ds`` column and the columns to be saved.
        columns_to_insert: Sequence of ``(column_name, k)`` pairs.
        wMAPE: Not used by this function; kept for interface compatibility.

    Connection failures are reported through ``exitWithError`` with codes
    70/72, failed inserts with 73/74. Cursor and connection are closed on
    every path.
    """
    print(f"--Saving {result_type} results--")

    column_names = [item[0] for item in columns_to_insert]

    print("Columns to be saved:")
    for col in column_names:
        print(col)
    print("")

    print("Columns not saved:")
    for col in df.columns:
        if col not in column_names:
            print(col)
    print("")

    # Column name -> target id; the first pair wins if a name occurs twice.
    target_ids = {}
    for item in columns_to_insert:
        target_ids.setdefault(item[0], item[1])

    # The statement is identical for every row: one value tuple per saved column.
    insert_query = "INSERT INTO prediction_results (result_id, result_type, ds, k, v) VALUES "
    insert_query += ", ".join(["(%s, %s, %s, %s, %s)"] * len(column_names))

    try:
        conn = mysql.connector.connect(
            host=_dbHost,
            user=_dbUser,
            password=_dbPassword,
            database=_dbName
        )
    except mysql.connector.Error as e:
        exitWithError(70, f"Database connection error: {e}")
    except Exception as e:
        exitWithError(72, f"Unexpected error during database connection: {e}")

    try:
        cursor = conn.cursor()
        try:
            # Iterate through each row in the dataframe
            for _, row in df.iterrows():
                ds_value = row['ds']

                # Flat parameter list: five values per saved column
                flattened_values = []
                for col in column_names:
                    value = row[col]
                    if pd.isna(value):
                        value = 0
                    flattened_values.extend((resultId, result_type, ds_value, target_ids.get(col), value))

                # Execute the insert query for the current row
                try:
                    cursor.execute(insert_query, flattened_values)
                except mysql.connector.Error as e:
                    exitWithError(73, f"Database query execution error: {e}")
                except Exception as e:
                    exitWithError(74, f"Unexpected error during query execution: {e}")

            # Commit the transaction to the database
            conn.commit()
        finally:
            cursor.close()
    finally:
        # Also reached when exitWithError raises SystemExit above.
        conn.close()




def predict_processAfterForecast(log_transform, cutoff_minimum, anomaly_threshold_n_stddev,
                                 forecast, forecast_columns, forecast_past, forecast_past_columns,
                                 histData, histDataColumns):
    """Post-process forecast frames: undo log, apply cutoff, flag anomalies, score accuracy.

    All frames are modified in place.

    Args:
        log_transform: If truthy, ``predict_undoLogTransformation`` is
            applied to the given columns of all three frames first.
        cutoff_minimum: If an int or float, lower bound applied to the
            forecast columns of ``forecast`` and ``forecast_past``.
        anomaly_threshold_n_stddev: Number of residual standard deviations
            beyond which a past interval is flagged; must be positive.
        forecast: Future forecast frame with a ``ds`` column, or None.
        forecast_columns: Forecast value columns of ``forecast``; the first
            one is the primary prediction column.
        forecast_past: In-sample forecast frame with a ``ds`` column, or None.
        forecast_past_columns: Forecast value columns of ``forecast_past``;
            the first one is the primary prediction column.
        histData: Historical frame with ``ds`` and ``y`` columns, or None.
        histDataColumns: Columns of ``histData`` to un-log.

    Side effects:
        ``forecast`` gains ``robust_z_score``, ``anomaly1``, ``anomaly2``,
        ``rolling_median``, ``rolling_mad``, ``rolling_mad_scaled``,
        ``anomaly3`` and ``anomaly``. ``forecast_past`` gains ``anomaly``,
        ``mape`` and ``interval_scaled_daily_mape`` when it can be compared
        with ``histData``. The ``ds`` columns of ``forecast`` and
        ``forecast_past`` end up as ``%Y-%m-%d %H:%M:%S`` strings.

    Returns:
        The wMAPE (sum of absolute residuals divided by the sum of actuals)
        over the intervals present in both ``histData`` and
        ``forecast_past``, or None if it could not be computed.
    """
    if log_transform:
        predict_undoLogTransformation(forecast, forecast_columns, forecast_past, forecast_past_columns,
                                      histData, histDataColumns)

    # Raise every forecast value below cutoff_minimum up to cutoff_minimum
    # (the isinstance check already excludes None).
    if isinstance(cutoff_minimum, (int, float)):
        if forecast is not None:
            for col in forecast_columns:
                forecast[col] = forecast[col].apply(lambda x: max(cutoff_minimum, x))
        if forecast_past is not None:
            for col in forecast_past_columns:
                forecast_past[col] = forecast_past[col].apply(lambda x: max(cutoff_minimum, x))

    wMAPE = None

    if forecast is not None:

        # convention: the first name in forecast_columns is always the primary forecast column (e.g., "yhat" in Prophet)
        pred_col = forecast_columns[0]
        predicted = forecast[pred_col]

        # 1. Robust Z-score using median and MAD (Median Absolute Deviation)
        median_yhat = predicted.median()
        mad_yhat = (predicted - median_yhat).abs().median()
        # Avoid division by zero in case mad_yhat is zero:
        scale = mad_yhat * 1.4826 if mad_yhat != 0 else 1
        forecast['robust_z_score'] = (predicted - median_yhat) / scale
        threshold_z = 3  # Adjust if needed
        forecast['anomaly1'] = (forecast['robust_z_score'].abs() > threshold_z).astype(int)

        # 2. IQR-based anomaly detection (global)
        Q1 = predicted.quantile(0.25)
        Q3 = predicted.quantile(0.75)
        IQR = Q3 - Q1
        lower_bound = Q1 - 1.5 * IQR
        upper_bound = Q3 + 1.5 * IQR
        forecast['anomaly2'] = ((predicted < lower_bound) | (predicted > upper_bound)).astype(int)

        # 3. Rolling window anomaly detection using robust statistics
        window_size = 10  # Adjust as needed
        rolling_window = predicted.rolling(window=window_size, min_periods=1)
        forecast['rolling_median'] = rolling_window.median()
        forecast['rolling_mad'] = rolling_window.apply(
            lambda x: np.median(np.abs(x - np.median(x))), raw=True)
        # Scale rolling MAD to be comparable to standard deviation:
        forecast['rolling_mad_scaled'] = forecast['rolling_mad'] * 1.4826
        # Use a threshold of 2 robust deviations (tweak as necessary)
        forecast['anomaly3'] = (
            (predicted > forecast['rolling_median'] + 2 * forecast['rolling_mad_scaled']) |
            (predicted < forecast['rolling_median'] - 2 * forecast['rolling_mad_scaled'])
        ).astype(int)

        # Combine the anomaly flags using a majority vote (flag anomaly if at least 2 methods agree)
        forecast['anomaly'] = ((forecast['anomaly1'] + forecast['anomaly2'] + forecast['anomaly3']) >= 2).astype(int)

    if (histData is not None) and (forecast_past is not None) and (not histData.empty) and (not forecast_past.empty):

        # convention: the first name in forecast_past_columns is always the primary forecast column (e.g., "yhat" in Prophet)
        pred_col = forecast_past_columns[0]

        if anomaly_threshold_n_stddev <= 0:
            exitWithError(99, "Error: anomaly_threshold_n_stddev must be a positive value.")

        # Ensure both 'ds' columns are datetime
        histData['ds'] = pd.to_datetime(histData['ds'])
        forecast_past['ds'] = pd.to_datetime(forecast_past['ds'])

        # Merge on the 'ds' column to retain only matching intervals
        common = pd.merge(histData[['ds', 'y']], forecast_past[['ds', pred_col]], on='ds', how='inner')

        # Remove rows where 'y' or the forecast column is missing (None or NaN)
        common = common.dropna(subset=['y', pred_col])

        # Error if there are no common rows
        if common.empty:
            exitWithError(99, "Error: No matching 'ds' values found between historical data and forecasted data.")

        # Calculate residuals
        common['residual'] = common['y'] - common[pred_col]

        # Protect against division by zero by checking the sum of 'y'
        total_y = np.sum(common['y'])
        if total_y != 0:

            # Calculate wMAPE using the common rows only
            wMAPE = np.sum(np.abs(common['residual'])) / total_y
            print("wMAPE: ", wMAPE)

        # Calculate the standard deviation of the residuals for anomaly detection
        stddev = common['residual'].std()

        # If stddev is NaN or zero, set the threshold to 0.
        if pd.isna(stddev) or stddev == 0:
            threshold = 0
        else:
            threshold = anomaly_threshold_n_stddev * stddev

        # calculate Anomalies
        common['anomaly'] = (np.abs(common['residual']) > threshold).astype(int)

        # Compute MAPE for valid (non-zero) actual values; set MAPE to NaN if 'y' is zero.
        common['mape'] = np.where(common['y'] != 0, np.abs(common['residual']) / common['y'] * 100, np.nan)

        # add a date column to the common dataframe
        common['date'] = common['ds'].dt.date  # Extract date part (removes time)

        # Daily totals of absolute error and of actuals, both indexed by date
        daily_abs_error = common['residual'].abs().groupby(common['date']).sum()
        daily_actuals = common.groupby('date')['y'].sum()

        # Daily weighted MAPE; days whose actuals sum to zero yield NaN.
        daily_mape_series = daily_abs_error / daily_actuals.where(daily_actuals != 0) * 100

        # Map the computed daily_mape back to each interval in the common dataframe.
        common['daily_mape'] = common['date'].map(daily_mape_series)

        print("First 10 rows of the common dataframe:")
        print(common.head(10))

        # Map data back to forecast_past by 'ds' value, so the result does
        # not depend on the index of forecast_past.
        common_by_ds = common.set_index('ds')
        forecast_past['anomaly'] = forecast_past['ds'].map(common_by_ds['anomaly']).fillna(0).astype(int)
        forecast_past['mape'] = forecast_past['ds'].map(common_by_ds['mape'])

        common_subset = common[['ds', 'daily_mape', 'y']].rename(columns={'y': 'actual_y'})
        print("First 10 rows of the common_subset dataframe:")
        print(common_subset.head(10))

        # daily_mape scaled by the interval's actual value; NaN where 'ds' has no match
        subset_by_ds = common_subset.set_index('ds')
        forecast_past['interval_scaled_daily_mape'] = forecast_past['ds'].map(
            subset_by_ds['daily_mape'] * subset_by_ds['actual_y'])

        # Compute the volume-weighted average daily MAPE:
        # (sum over days of (daily_mape * daily_actual)) / (sum over days of daily_actual)
        if daily_actuals.sum() != 0:
            average_daily_mape = (daily_mape_series * daily_actuals).sum() / daily_actuals.sum()
            print("Average Daily MAPE: ", average_daily_mape)

    # Render 'ds' as text. The conversion to datetime is repeated here because
    # 'ds' is only converted above when the comparison branch actually ran.
    if forecast is not None:
        forecast['ds'] = pd.to_datetime(forecast['ds']).dt.strftime('%Y-%m-%d %H:%M:%S')

    if forecast_past is not None:
        forecast_past['ds'] = pd.to_datetime(forecast_past['ds']).dt.strftime('%Y-%m-%d %H:%M:%S')

    return wMAPE



def predict_undoLogTransformation (forecast, forecast_columns, forecast_past, forecast_past_columns, histData, histDataColumns):
    """Reverse the ``log(y + 1)`` transformation on the given columns, in place.

    Each listed column of each frame is replaced by ``exp(value) - 1``.
    A frame that is None is skipped (its column list is then not used).

    Args:
        forecast: Future forecast frame or None.
        forecast_columns: Columns of ``forecast`` to transform back.
        forecast_past: In-sample forecast frame or None.
        forecast_past_columns: Columns of ``forecast_past`` to transform back.
        histData: Historical data frame or None.
        histDataColumns: Columns of ``histData`` to transform back.
    """
    frames_with_columns = (
        (forecast, forecast_columns),
        (forecast_past, forecast_past_columns),
        (histData, histDataColumns),
    )

    for frame, column_names in frames_with_columns:
        if frame is None:
            continue
        for column_name in column_names:
            # Inverse of log(value + 1)
            frame[column_name] = np.exp(frame[column_name]) - 1
