import pandas as pd
import numpy as np
from prophet import Prophet
from os import path
import sys
import json
import logging

print("--start prophet--")

# Silence noisy cmdstanpy/prophet logging
logger = logging.getLogger("cmdstanpy")
logger.addHandler(logging.NullHandler())
logger.propagate = False
logger.setLevel(logging.WARNING)

logging.getLogger("prophet").setLevel(logging.WARNING)

print("--start processing--")

output_data_file = ""



def exitWithError(error_code, error_message):
    try:
        print(f"Error {error_code}: {error_message}")

        # JSON object to write
        error_data = {
            "errorCode": error_code,
            "errorMessage": error_message
        }
        
        # Write the error data to the specified file in JSON format
        with open(output_data_file, 'w') as outfile:
            json.dump(error_data, outfile)

        # Exit the program with the provided error code
        sys.exit(error_code)
    
    except Exception as e:
        print(f"Failed to log error: {e}")
        sys.exit(error_code)



def loadConfig(input_directory):
    input_config_file = path.join(input_directory, "config.json")

    # Read input config file
    try:
        with open(input_config_file) as fs:
            input_config_json_string = fs.read()
    except FileNotFoundError:
        exitWithError(101, f"File {input_config_file} does not exist.")
    except PermissionError:
        exitWithError(102, f"Permission denied to open config file {input_config_file}.")
    except Exception as e:
        exitWithError(103, f"An unexpected error occurred: {str(e)} when opening and loading the config file {input_config_file}.")

    try:
        config = json.loads(input_config_json_string)
    except ValueError:
        exitWithError(104, f"Could not decode JSON from the specified config file {input_config_file}.")
    except Exception as e:
        exitWithError(105, f"An unexpected error occurred: {str(e)} while decoding the specified config file.")
    
    return config
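

# For reference, a minimal config.json consumed by this script might look like the
# following (field names are taken from the code in this file; values are illustrative only):
# {
#     "input_data":  {"date": ["2024-01-01 00:00:00", "2024-01-02 00:00:00"], "value": [1.2, 3.4], "promo": [0, 1]},
#     "future_data": {"date": ["2024-02-01 00:00:00", "2024-02-02 00:00:00"], "promo": [1, 0]},
#     "regressors":  [{"colNum": 2, "name": "promo"}],
#     "growth": "linear",
#     "interval_width": 0.8
# }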



def getInputData(config):
    try:

        print("getInputData")

        # load input data (handed over by .php)
        input_data = pd.DataFrame(config["input_data"])
        config.pop("input_data", None)

        print("input_data dataframe (historical data):")
        print(input_data.head(10))  # Print the first 10 records

        print("Missing values in input data:")
        print(input_data.isna().sum().sum())

        # Rename the first two columns only if they aren't already named 'ds' and 'y'
        columns_to_rename = {}
        
        if input_data.columns[0] != "ds":
            columns_to_rename[input_data.columns[0]] = "ds"
        if input_data.columns[1] != "y":
            columns_to_rename[input_data.columns[1]] = "y"
        
        if columns_to_rename:
            input_data.rename(columns=columns_to_rename, inplace=True)

        # Convert the 'ds' column to datetime, coercing errors to NaT
        print("Converting 'ds' column to datetime format.")
        input_data['ds'] = pd.to_datetime(input_data['ds'], format='%Y-%m-%d %H:%M:%S', errors='coerce')

        # Check if there were any invalid datetime conversions
        nat_count = input_data['ds'].isna().sum()
        if nat_count > 0:
            print(f"Warning: {nat_count} values could not be converted to datetime and were set to NaT.")

        return input_data
    
    except KeyError:
        exitWithError(201, "Error: 'input_data' key not found in config.")
    except ValueError as ve:
        exitWithError(202, f"Error: Value error occurred - {ve}")
    except IndexError as ie:
        exitWithError(203, f"Error: Index error occurred - {ie}. Input data might have fewer than two columns.")
    except Exception as e:
        exitWithError(204, f"Runtime error: {e}")



def getFutureData(config):
    try:

        print("getFutureData")

        # load future data (handed over by .php)
        future_data = pd.DataFrame(config["future_data"])
        config.pop("future_data", None)

        print("future_data dataframe (future data):")
        print(future_data.head(10))  # Print the first 10 records

        print("Missing values in future data:")
        print(future_data.isna().sum().sum())

        # Rename the first two columns only if they aren't already named 'ds' and 'y'
        columns_to_rename = {}
        
        if future_data.columns[0] != "ds":
            columns_to_rename[future_data.columns[0]] = "ds"
        
        if columns_to_rename:
            future_data.rename(columns=columns_to_rename, inplace=True)

        # Convert the 'ds' column to datetime, coercing errors to NaT
        print("Converting 'ds' column to datetime format.")
        future_data['ds'] = pd.to_datetime(future_data['ds'], format='%Y-%m-%d %H:%M:%S', errors='coerce')

        # Check if there were any invalid datetime conversions
        nat_count = future_data['ds'].isna().sum()
        if nat_count > 0:
            print(f"Warning: {nat_count} values could not be converted to datetime and were set to NaT.")

        return future_data
    
    except KeyError:
        exitWithError(301, "Error: 'future_data' key not found in config.")
    except ValueError as ve:
        exitWithError(302, f"Error: Value error occurred - {ve}")
    except IndexError as ie:
        exitWithError(303, f"Error: Index error occurred - {ie}. Future data might be empty or missing columns.")
    except Exception as e:
        exitWithError(304, f"Runtime error: {e}")



def getRegressors(config):

    print("getRegressors")

    try:
        regressors = config["regressors"]
        config.pop("regressors", None)

        print("Regressors loaded:")
        for regressor in regressors:
            print(regressor)

        return regressors
    
    except KeyError:
        exitWithError(401, "Error: 'regressors' key not found in config.")
    except Exception as e:
        exitWithError(402, f"Runtime error: {e}")



def renameInputDataColumns(regressors, input_data, future_data):
    try:

        # Rename input_data columns to match the regressor names
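        # (future_data is assumed to lack the 'y' column, so a regressor at column
        # index colNum in input_data sits at index colNum-1 in future_data)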
        for regressor in regressors:
            if "colNum" not in regressor or "name" not in regressor:
                exitWithError(501, "Error: Each regressor must have 'colNum' and 'name' keys.")
            
            col_index = regressor["colNum"]
            col_name = regressor["name"]

            # Check if col_index is valid for input_data
            if not isinstance(col_index, int) or col_index < 0 or col_index >= len(input_data.columns):
                exitWithError(502, f"Error: Invalid column index {col_index} for input_data.")

            # Check if col_index-1 is valid for future_data
            if col_index - 1 < 0 or col_index - 1 >= len(future_data.columns):
                exitWithError(503, f"Error: Invalid column index {col_index-1} for future_data.")

            input_data.rename(columns={input_data.columns[col_index]: col_name}, inplace=True)
            future_data.rename(columns={future_data.columns[col_index-1]: col_name}, inplace=True)
            
        print("input_data dataframe renamed:")
        print(input_data.head(10))  # Print the first 10 records

        print("future_data dataframe renamed:")
        print(future_data.head(10))  # Print the first 10 records
    
    except KeyError as ke:
        exitWithError(504, f"Error: KeyError occurred - {ke}")
    except IndexError as ie:
        exitWithError(505, f"Error: IndexError occurred - {ie}")
    except TypeError as te:
        exitWithError(506, f"Error: TypeError occurred - {te}")
    except Exception as e:
        exitWithError(507, f"Runtime error: {e}")



def dropRegressorsWithTooFewInputDataValues(input_data, regressors):
    try:

        # Ensure regressors is a list of dictionaries with 'name' keys
        if not isinstance(regressors, list) or not all(isinstance(r, dict) and "name" in r for r in regressors):
            exitWithError(601, "Error: regressors must be a list of dictionaries with a 'name' key.")

        # List to keep track of columns to be dropped and corresponding regressors to be removed
        columns_to_drop = []
        regressors_to_remove = []

        # Ensure there are at least 3 columns (since we're checking from the third column onward)
        if len(input_data.columns) < 3:
            print("Warning: input_data does not have enough columns to check for regressors.")
            return input_data, regressors

        columns_to_check = input_data.columns[2:]

        # Identify columns to drop if they have <= 2 non-NaN values
        for column in columns_to_check:
            if input_data[column].dropna().shape[0] <= 2:
                columns_to_drop.append(column)
                # Find the corresponding regressor
                for regressor in regressors:
                    if regressor["name"] == column:
                        regressors_to_remove.append(regressor)
                        break

        # Drop the identified columns from input_data
        if columns_to_drop:
            input_data = input_data.drop(columns=columns_to_drop)

        # remove the identified regressors from the regressors list
        regressors = [r for r in regressors if r not in regressors_to_remove]

        # Print the dropped columns and regressors for verification
        print("Columns dropped from input_data:", columns_to_drop)
        print("Regressors dropped:", regressors_to_remove)

        return input_data, regressors
    
    except KeyError as ke:
        exitWithError(602, f"KeyError: {ke} - The specified column might not exist.")
    except Exception as e:
        exitWithError(603, f"Runtime error: {e}")



def dropRegressorsWithNaNValues(future_data, regressors):
    try:

        # List to keep track of columns to be dropped and corresponding regressors to be removed
        future_columns_to_drop = []
        future_regressors_to_remove = []
        future_columns_to_check = future_data.columns[1:]  # Skip the first column (likely 'ds')

        # Identify columns to drop if they contain any NaN values
        for column in future_columns_to_check:
            if future_data[column].isna().sum() > 0:  # Check if there are any NaN values
                future_columns_to_drop.append(column)
                # Find the corresponding regressor
                for regressor in regressors:
                    if regressor["name"] == column:
                        future_regressors_to_remove.append(regressor)
                        break

        if future_columns_to_drop:
            future_data = future_data.drop(columns=future_columns_to_drop)

        # Remove the identified regressors from the regressors list
        regressors = [r for r in regressors if r not in future_regressors_to_remove]

        # Print the dropped columns and regressors for verification
        print("Columns dropped from future_data:", future_columns_to_drop)
        print("Regressors dropped:", future_regressors_to_remove)

        return future_data, regressors
    
    except KeyError as ke:
        exitWithError(701, f"KeyError: {ke} - The specified column might not exist.")
    except Exception as e:
        exitWithError(702, f"Runtime error: {e}")



# Require exactly one argument: the path to the temp directory
if len(sys.argv) != 2:
    exitWithError(1, "Error: Expected exactly one argument: path to the temp directory.")

input_directory = sys.argv[1]
output_data_file = path.join(input_directory, "output.json")

config = loadConfig(input_directory)
input_data = getInputData(config)
future_data = getFutureData(config)
regressors = getRegressors(config)
renameInputDataColumns(regressors, input_data, future_data)
input_data, regressors = dropRegressorsWithTooFewInputDataValues(input_data, regressors)
future_data, regressors = dropRegressorsWithNaNValues(future_data, regressors)

print("Final Regressors:")
for regressor in regressors:
    print(regressor)

print("Final input_data:")
print(input_data.head(10))  # Print the first 10 records
print("number of records: ", len(input_data))

print("final future_data:")
print(future_data.head(10))  # Print the first 10 records
print("number of records: ", len(future_data))



# growth: type of trend growth used in the model
if 'growth' in config:
    if config['growth'] not in ['linear', 'logistic', 'flat']:
        exitWithError(6, "Error: growth parameter should be either 'linear', 'logistic' or 'flat'.")
else:
    config['growth'] = 'linear' # default

# make sure that we have a 'cap' column if the growth is logistic
# in this case, optionally also add a 'floor' column
if config['growth'] == 'logistic':
    if 'logistic_growth_cap' not in config:
        exitWithError(7, "Error: logistic growth requires 'logistic_growth_cap' parameter.")
    else:
        input_data['cap'] = config["logistic_growth_cap"]
        future_data['cap'] = config["logistic_growth_cap"]
    if 'logistic_growth_floor' in config:
        if config["logistic_growth_floor"] is not None:
            input_data['floor'] = config["logistic_growth_floor"]
            future_data['floor'] = config["logistic_growth_floor"]
config.pop("logistic_growth_cap", None)
config.pop("logistic_growth_floor", None)

# mcmc_samples: To control the number of MCMC samples for the Bayesian estimation of the model parameters
# critical setting for controlling how Prophet estimates model parameters, balancing between computational efficiency and the thoroughness of uncertainty estimation.
if 'mcmc_samples' not in config:
    config['mcmc_samples'] = 0 # default
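# Sketch of the missing range check; error code 23 is newly assigned here and is an assumption
if not isinstance(config['mcmc_samples'], int) or config['mcmc_samples'] < 0:
    exitWithError(23, "Error: mcmc_samples parameter should be an integer >= 0.")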

# stan_backend: backend to use for fitting the model with Stan, the probabilistic programming language used for Bayesian inference. This parameter allows users to choose the computational engine that runs the Stan code for model fitting
if 'stan_backend' not in config:
    config['stan_backend'] = None # default
else:
    if config['stan_backend'] not in ['cmdstanpy', 'pystan', None]:
        exitWithError(8, "Error: stan_backend parameter should be either 'cmdstanpy' or 'pystan'.")

# interval_width: the width of the uncertainty intervals for the forecasts. This parameter determines the range within which the true values are expected to lie with a certain level of confidence.
# The interval_width corresponds to the confidence level for the uncertainty intervals. For instance, an interval_width of 0.90 implies that there is a 90% chance that the true value will fall within the predicted interval
if 'interval_width' not in config:
    config['interval_width'] = 0.8 # default
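# Sketch of the missing range check; error code 24 is newly assigned here and is an assumption
if not isinstance(config['interval_width'], (int, float)) or not 0 < config['interval_width'] < 1:
    exitWithError(24, "Error: interval_width parameter should be a float between 0 and 1.")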

# n_changepoints: controls the number of potential changepoints in the trend of the time series. Changepoints are locations in the data where the trend can change direction or rate.
if 'n_changepoints' not in config:
    config['n_changepoints'] = 25 # default
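# Sketch of the missing range check; error code 25 is newly assigned here and is an assumption
if not isinstance(config['n_changepoints'], int) or config['n_changepoints'] < 0:
    exitWithError(25, "Error: n_changepoints parameter should be an integer >= 0.")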

# seasonality_mode: specifies how seasonality components are modeled. This parameter determines whether the seasonal effects (such as daily, weekly, and yearly patterns) are treated as additive or multiplicative.
if 'seasonality_mode' not in config:
    config['seasonality_mode'] = 'additive' # default
else:
    if config['seasonality_mode'] not in ['additive', 'multiplicative']:
        exitWithError(9, "Error: seasonality_mode parameter should be either 'additive' or 'multiplicative'.")

# changepoint_range: specifies the proportion of the historical data within which the model will look for potential changepoints. Changepoints are points in time where the trend changes its direction or rate of growth/decline.
if not "changepoint_range" in config:
    config["changepoint_range"] = 0.8 # default
# ToDo: check float between 0 and 1
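# Sketch of the missing range check; error code 26 is newly assigned here and is an assumption
if not isinstance(config["changepoint_range"], (int, float)) or not 0 < config["changepoint_range"] <= 1:
    exitWithError(26, "Error: changepoint_range parameter should be a float between 0 and 1.")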

# daily_seasonality: specifies whether to include daily seasonality in the model and, if so, how to configure it. Daily seasonality captures patterns that repeat every 24 hours, which can be crucial for data that exhibits intra-day variations, such as website traffic, energy consumption, or sales data.
if 'daily_seasonality' not in config:
    config['daily_seasonality'] = 'auto' # default
else:
    # 'auto', True, False or an integer Fourier order are all valid for Prophet
    if config['daily_seasonality'] not in ['auto', True, False] and not isinstance(config['daily_seasonality'], int):
        exitWithError(10, "Error: daily_seasonality parameter should be 'auto', True, False or an integer Fourier order.")

# weekly_seasonality: specifies whether to include weekly seasonality in the model and, if so, how to configure it. Weekly seasonality captures patterns that repeat every week, which can be essential for data that shows variations depending on the day of the week, such as retail sales, website traffic, or call center volumes.
if 'weekly_seasonality' not in config:
    config['weekly_seasonality'] = 'auto' # default
else:
    # 'auto', True, False or an integer Fourier order are all valid for Prophet
    if config['weekly_seasonality'] not in ['auto', True, False] and not isinstance(config['weekly_seasonality'], int):
        exitWithError(11, "Error: weekly_seasonality parameter should be 'auto', True, False or an integer Fourier order.")

# yearly_seasonality: specifies whether to include yearly seasonality in the model and, if so, how to configure it. Yearly seasonality captures patterns that repeat every year, which can be crucial for data that exhibits annual variations, such as sales influenced by holidays, weather patterns, or agricultural cycles.
if 'yearly_seasonality' not in config:
    config['yearly_seasonality'] = 'auto' # default
else:
    # 'auto', True, False or an integer Fourier order are all valid for Prophet
    if config['yearly_seasonality'] not in ['auto', True, False] and not isinstance(config['yearly_seasonality'], int):
        exitWithError(12, "Error: yearly_seasonality parameter should be 'auto', True, False or an integer Fourier order.")

# uncertainty_samples: To specify the number of Monte Carlo samples to draw for estimating the uncertainty intervals of the forecast
if not "uncertainty_samples" in config:
    config["uncertainty_samples"] = 1000 # default
# ToDo: check for integer >= 0
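# Sketch of the missing range check; error code 27 is newly assigned here and is an assumption
if not isinstance(config["uncertainty_samples"], int) or config["uncertainty_samples"] < 0:
    exitWithError(27, "Error: uncertainty_samples parameter should be an integer >= 0.")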

# changepoint_prior_scale: controls the flexibility of the automatic changepoint detection. This parameter influences how strongly the model allows the trend to change at the identified changepoints
if 'changepoint_prior_scale' not in config:
    config['changepoint_prior_scale'] = 0.05 # default
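# Sketch of the missing range check; error code 28 is newly assigned here and is an assumption
if not isinstance(config['changepoint_prior_scale'], (int, float)) or config['changepoint_prior_scale'] <= 0:
    exitWithError(28, "Error: changepoint_prior_scale parameter should be a positive number.")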

# seasonality_prior_scale: controls the flexibility of the seasonal components. This parameter influences how much the seasonality can vary, thereby impacting the overall seasonality fit in the model.
if 'seasonality_prior_scale' not in config:
    config['seasonality_prior_scale'] = 10.0 # default
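# Sketch of the missing range check; error code 29 is newly assigned here and is an assumption
if not isinstance(config['seasonality_prior_scale'], (int, float)) or config['seasonality_prior_scale'] <= 0:
    exitWithError(29, "Error: seasonality_prior_scale parameter should be a positive number.")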

# holidays_country_code: specifies the country code for the country-specific holidays to include in the model. This parameter allows users to add holidays that are specific to a particular country, such as national holidays, public observances, or cultural events.
country_holidays_code = config.pop("holidays_country_code", None) # default: None

seasonalities = config["seasonalities"] if config.get("seasonalities", False) else []
config.pop("seasonalities", None)

cutoff_minimum = config.pop("cutoff_minimum", None)

log_transform = bool(config.pop("log_transform", False))
if log_transform:
    # Apply log transformation to the target variable (adding 1 to avoid log(0))
    input_data['y'] = np.log(input_data['y'] + 1)

anomaly_threshold_n_stddev = config.pop("anomaly_threshold_n_stddev", 3.0) # default: 3.0

# ToDo: implement handing over a constant 'changepoints' list of dates
# ToDo: implement handing over a constant 'holidays' list of dates with 'holiday, 'ds', 'lower_window', 'upper_window' columns
# ToDo: implement holidays_prior_scale


# create an instance of the Prophet model; all remaining config keys are passed
# straight through as Prophet constructor arguments
try:
    model = Prophet(**config)
except Exception as e:
    exitWithError(13, f"Error: An unexpected error occurred: {str(e)} when creating an instance of the Prophet model.")


# add country holidays if the 'holidays_country_code' is set
if country_holidays_code:
    try:
        model.add_country_holidays(country_name = country_holidays_code)
        # ToDo: implement 'lower_window', 'upper_window' and 'holidays' parameters
    except ValueError:
        exitWithError(14, f"Error: the country code {country_holidays_code} for adding holidays is not supported.")
    except Exception as e:
        exitWithError(15, f"Error: An unexpected error occurred: {str(e)} when adding country holidays to the Prophet model.")


# Add regressors
try:
    for regressor in regressors:
        # standardize: True, False or 'auto' (default)
        # mode: 'additive' (default) or 'multiplicative'
        # prior_scale: positive float > 0, default 10.0
        regressor.setdefault("standardize", 'auto')
        regressor.setdefault("mode", 'additive')
        regressor.setdefault("prior_scale", 10.0)
        model.add_regressor(regressor["name"], prior_scale=regressor["prior_scale"], standardize=regressor["standardize"], mode=regressor["mode"])
except ValueError as e:
    exitWithError(16, f"Error: {e} when adding a regressor to the Prophet model.")
except Exception as e:
    exitWithError(17, f"Error: An unexpected error occurred: {str(e)} when adding a regressor to the Prophet model.")


# Add seasonalities
for s in seasonalities:
    try:
        # JSON null decodes to Python None, so treat both None and the literal
        # string "null" as "no condition"
        condition_name = s.get("condition_name")
        if condition_name in (None, "null"):
            model.add_seasonality(name=s["name"], period=s["period"], fourier_order=s["fourier_order"], prior_scale=s["prior_scale"], mode=s["mode"])
        else:
            model.add_seasonality(name=s["name"], period=s["period"], fourier_order=s["fourier_order"], prior_scale=s["prior_scale"], mode=s["mode"], condition_name=condition_name)
    except ValueError as e:
        exitWithError(18, f"Error: {e} when adding a seasonality to the Prophet model.")
    except KeyError as e:
        exitWithError(18, f"Error: missing key {e} in a seasonality definition.")
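
# Illustrative "seasonalities" entry as it would appear in config.json (values are examples only):
# {"name": "monthly", "period": 30.5, "fourier_order": 5,
#  "prior_scale": 10.0, "mode": "additive", "condition_name": null}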


# fit the model
try:
    model.fit(input_data)
except ValueError as e:
    exitWithError(19, f"Error: {e} while fitting model.")
except Exception as e:
    exitWithError(19, f"Error: An unexpected error occurred: {str(e)} while fitting model.")


# Make predictions for the future time periods
try:
    forecast = model.predict(future_data)
except Exception as e:
    exitWithError(20, f"Error: {e} while predicting future data.")

# Make predictions for the past
try:
    forecast_past = model.predict(input_data)
except Exception as e:
    exitWithError(21, f"Error: {e} while predicting past data.")


if log_transform:
    # Apply inverse log transformation to the forecasted values
    forecast['yhat'] = np.exp(forecast['yhat']) - 1
    forecast['yhat_lower'] = np.exp(forecast['yhat_lower']) - 1
    forecast['yhat_upper'] = np.exp(forecast['yhat_upper']) - 1

    # Apply inverse log transformation to the in-sample (past) forecast values
    forecast_past['yhat'] = np.exp(forecast_past['yhat']) - 1
    forecast_past['yhat_lower'] = np.exp(forecast_past['yhat_lower']) - 1
    forecast_past['yhat_upper'] = np.exp(forecast_past['yhat_upper']) - 1

    # Apply inverse log transformation to the input data
    input_data['y'] = np.exp(input_data['y']) - 1

# Post-process the forecasts: clip predictions below the configured minimum
if isinstance(cutoff_minimum, (int, float)):
    for col in ['yhat', 'yhat_lower', 'yhat_upper']:
        forecast[col] = forecast[col].clip(lower=cutoff_minimum)
        forecast_past[col] = forecast_past[col].clip(lower=cutoff_minimum)


# Residuals of the in-sample fit (rows of forecast_past align with input_data)
residuals = input_data['y'] - forecast_past['yhat']


# Compute wMAPE (weighted mean absolute percentage error):
# sum(|residuals|) / sum(actuals); undefined if the actuals sum to zero
wMAPE = np.sum(np.abs(residuals)) / np.sum(input_data['y'])
print("wMAPE: ", wMAPE)


# Compute the standard deviation of the residuals
stddev = residuals.std()

# Set the anomaly threshold
threshold = anomaly_threshold_n_stddev * stddev

# Flag points whose absolute residual exceeds the threshold as anomalies
forecast_past['anomaly'] = (abs(residuals) > threshold).astype(int)

forecast['ds'] = forecast['ds'].dt.strftime('%Y-%m-%d %H:%M:%S')
forecast_past['ds'] = forecast_past['ds'].dt.strftime('%Y-%m-%d %H:%M:%S')
model.history['ds'] = model.history['ds'].dt.strftime('%Y-%m-%d %H:%M:%S')
model.changepoints = model.changepoints.dt.strftime('%Y-%m-%d %H:%M:%S')

if config.get("holidays", False) and not config["holidays"].empty:
    model.holidays['ds'] = model.holidays['ds'].dt.strftime('%Y-%m-%d %H:%M:%S')


output = {
    "forecast": forecast.to_dict(orient='records'),
    "forecast_past": forecast_past.to_dict(orient='records'),
    "wMAPE": wMAPE,
    "model_holidays": model.holidays.to_dict(orient='records') if model.holidays is not None else None,
    "model_train_holiday_names": model.train_holiday_names.tolist() if model.train_holiday_names is not None else None,
    "model_history": model.history.to_dict(orient='records'),
    "model_changepoints": model.changepoints.to_dict(),
    "model_seasonalities": model.seasonalities,
    "model_country_holidays": model.country_holidays
}

try:
    with open(output_data_file, "w") as fs:
        # default=str stringifies any value the standard JSON encoder cannot
        # handle (e.g. stray NumPy or Timestamp objects) instead of crashing
        json.dump(output, fs, default=str)
except Exception as e:
    exitWithError(22, f"Error: {e} while writing output.json.")

print("--end processing--")
