import math
import pandas as pd
import numpy as np
#from prophet import Prophet
from neuralprophet import NeuralProphet
from os import path
import sys
import json
import logging

logger = logging.getLogger("cmdstanpy")
logger.addHandler(logging.NullHandler())
logger.propagate = False
logger.setLevel(logging.WARNING)

logging.getLogger("prophet").setLevel(logging.WARNING)
logging.getLogger("cmdstanpy").setLevel(logging.WARNING)

print("--start prophet--")

# Check for required argument Temp Directory Path
if len(sys.argv) != 2:
    print("Error: Missing argument: Path to temp directory.")
    sys.exit(1)

input_directory = sys.argv[1]
input_config_file = path.join(input_directory, "config.json")
output_data_file = path.join(input_directory, "output.json")

# try to read the input config file
try:
    with open(input_config_file) as fs:
        input_config_json_string = fs.read()
except FileNotFoundError:
    print(f"Error: File {input_config_file} does not exist.")
    sys.exit(2)
except PermissionError:
    print(f"Error: Permission denied to open config file {input_config_file}.")
    sys.exit(3)
except Exception as e:
    print(f"Error: An unexpected error occurred: {str(e)} when opening and loading the config file {input_config_file}.")
    sys.exit(4)

try:
    config = json.loads(input_config_json_string)
except ValueError:
    print(f"Error: Could not decode JSON from the specified config file.")
    sys.exit(5)
except Exception as e:
    print(f"Error: An unexpected error occurred: {str(e)} while decoding the specified config file.")

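# For reference, an illustrative sketch of the config.json structure this script expects,
# assembled from the keys read below (row layout, column names and example values are assumptions;
# the actual payload is produced by the calling .php):
#
# {
#   "input_data":  [["2024-01-01 00:00:00", 42.0, 1.0], ...],    # col 0 = 'ds', col 1 = 'y', then regressors
#   "future_data": [["2024-02-01 00:00:00", 1.0], ...],          # col 0 = 'ds', then regressors
#   "regressors":  [{"name": "promo", "colNum": 2}],             # "promo" is a hypothetical name
#   "growth": "linear",
#   "n_changepoints": 25,
#   "seasonality_mode": "additive",
#   "daily_seasonality": true, "weekly_seasonality": true, "yearly_seasonality": true,
#   "holidays_country_code": "US",
#   "seasonalities": [],
#   "cutoff_minimum": 0,
#   "log_transform": false,
#   "anomaly_threshold_n_stddev": 3.0
# }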

# load input data (handed over by .php)
input_data_loaded = pd.DataFrame(config["input_data"])
config.pop("input_data", None)

# load future data (handed over by .php)
future_data_loaded = pd.DataFrame(config["future_data"])
config.pop("future_data", None)

print("Original DataFrame Past:")
print(input_data_loaded.head(10))  # Print the first 10 records

print("Original DataFrame Future:")
print(future_data_loaded.head(10))  # Print the first 10 records

# Past: rename the first two columns to 'ds' and 'y'
# ToDo: only do this if not already named 'ds' and 'y'
input_data_loaded.rename(columns={input_data_loaded.columns[0]: "ds", input_data_loaded.columns[1]: "y"}, inplace=True)

# Future: rename the first column to 'ds'
# ToDo: only do this if not already named 'ds'
future_data_loaded.rename(columns={future_data_loaded.columns[0]: "ds"}, inplace=True)

# convert the 'ds' columns to datetime
input_data_loaded['ds'] = pd.to_datetime(input_data_loaded['ds'], format='%Y-%m-%d %H:%M:%S')
future_data_loaded['ds'] = pd.to_datetime(future_data_loaded['ds'], format='%Y-%m-%d %H:%M:%S')


# load regressors (handed over by .php)
regressors = config["regressors"]   # array of regressor objects with keys: 'name', 'colNum'
config.pop("regressors", None)

# rename input_data columns to match the regressor names
for regressor in regressors:
    col_index = regressor["colNum"]
    col_name = regressor["name"]
    input_data_loaded.rename(columns={input_data_loaded.columns[col_index]: col_name}, inplace=True)
    future_data_loaded.rename(columns={future_data_loaded.columns[col_index-1]: col_name}, inplace=True)

print("Original DataFrame Past renamed:")
print(input_data_loaded.head(10))  # Print the first 10 records

print("Original DataFrame Future renamed:")
print(future_data_loaded.head(10))  # Print the first 10 records

# List to keep track of columns to be dropped and corresponding regressors to be removed
columns_to_drop = []
regressors_to_remove = []
columns_to_check = input_data_loaded.columns[2:]

# Identify columns to drop if they have <= 2 non-NaN values
for column in columns_to_check:
    if input_data_loaded[column].dropna().shape[0] <= 2:
        columns_to_drop.append(column)
        # Find the corresponding regressor
        for regressor in regressors:
            if regressor["name"] == column:
                regressors_to_remove.append(regressor)
                break


# Drop the identified columns
input_data = input_data_loaded.drop(columns=columns_to_drop)
future_data = future_data_loaded

del input_data_loaded
del future_data_loaded

# Remove the corresponding regressors
regressors = [regressor for regressor in regressors if regressor not in regressors_to_remove]

print("Regressors:")
for regressor in regressors:
    print(regressor)

print("Input DataFrame Past:")
print(input_data.head(10))  # Print the first 10 records

print("Input DataFrame Future:")
print(future_data.head(10))  # Print the first 10 records
print("number of records: ", len(future_data))


# growth: type of trend growth used in the model
if 'growth' in config:
    if not config['growth'] in ['linear', 'logistic'] :
        print(f"Error: growth parameter should be either 'linear' or 'logistic'.")
        sys.exit(6)
else:
    config['growth'] = 'linear' # default

# make sure that we have a 'cap' column if the growth is logistic
# in this case, optionally also add a 'floor' column
# NOTE: this block is intentionally disabled because NeuralProphet does not support
# logistic growth with cap/floor columns; re-enable it once that is supported.
if False and config['growth'] == 'logistic':
    if 'logistic_growth_cap' not in config:
        print("Error: logistic growth requires 'logistic_growth_cap' parameter.")
        sys.exit(7)
    else:
        input_data['cap'] = config["logistic_growth_cap"]
        future_data['cap'] = config["logistic_growth_cap"]
    if 'logistic_growth_floor' in config:
        if config["logistic_growth_floor"] is not None:
            input_data['floor'] = config["logistic_growth_floor"]
            future_data['floor'] = config["logistic_growth_floor"]
config.pop("logistic_growth_cap", None)
config.pop("logistic_growth_floor", None)

# mcmc_samples: To control the number of MCMC samples for the Bayesian estimation of the model parameters
# critical setting for controlling how Prophet estimates model parameters, balancing between computational efficiency and the thoroughness of uncertainty estimation.
# ToDo: check that it's an integer >= 0
if not 'mcmc_samples' in config:
    config['mcmc_samples'] = 0 # default
config.pop("mcmc_samples", None)    # Neuralprophet doesn't support mcmc_samples

# stan_backend: backend to use for fitting the model with Stan, the probabilistic programming language used for Bayesian inference. This parameter allows users to choose the computational engine that runs the Stan code for model fitting
if not 'stan_backend' in config:
    config['stan_backend'] = None # default
else:
    if not config['stan_backend'] in ['cmdstanpy', 'pystan', None]:
        print(f"Error: stan_backend parameter should be either 'cmdstanpy' or 'pystan'.")
        sys.exit(8)
config.pop("stan_backend", None)    # Neuralprophet doesn't support stan_backend

# interval_width: the width of the uncertainty intervals for the forecasts. This parameter determines the range within which the true values are expected to lie with a certain level of confidence.
# The interval_width corresponds to the confidence level for the uncertainty intervals. For instance, an interval_width of 0.90 implies that there is a 90% chance that the true value will fall within the predicted interval
if not 'interval_width' in config:
    config['interval_width'] = 0.8 # default
# ToDo: check that it's a float between 0 and 1
config.pop("interval_width", None)    # Neuralprophet doesn't support interval_width

# n_changepoints: controls the number of potential changepoints in the trend of the time series. Changepoints are locations in the data where the trend can change direction or rate.
if not 'n_changepoints' in config:
    config['n_changepoints'] = 25 # default
# ToDo: check integer >= 0
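# Hedged sketch for the ToDo above: validate n_changepoints as a non-negative integer; falling
# back to the default instead of aborting with an exit code is an assumption, not original behavior.
if not isinstance(config['n_changepoints'], int) or isinstance(config['n_changepoints'], bool) \
        or config['n_changepoints'] < 0:
    print("Warning: n_changepoints should be an integer >= 0; falling back to the default of 25.")
    config['n_changepoints'] = 25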

# seasonality_mode: specifies how seasonality components are modeled. This parameter determines whether the seasonal effects (such as daily, weekly, and yearly patterns) are treated as additive or multiplicative.
if not 'seasonality_mode' in config:
    config['seasonality_mode'] = 'additive' # default
else:
    if not config['seasonality_mode'] in ['additive', 'multiplicative']:
        print(f"Error: seasonality_mode parameter should be either 'additive' or 'multiplicative'.")
        sys.exit(9)

# changepoint_range: specifies the proportion of the historical data within which the model will look for potential changepoints. Changepoints are points in time where the trend changes its direction or rate of growth/decline.
if not "changepoint_range" in config:
    config["changepoint_range"] = 0.8 # default
# ToDo: check float between 0 and 1
config.pop("changepoint_range", None)    # Neuralprophet doesn't support changepoint_range

# daily_seasonality: specifies whether to include daily seasonality in the model and, if so, how to configure it. Daily seasonality captures patterns that repeat every 24 hours, which can be crucial for data that exhibits intra-day variations, such as website traffic, energy consumption, or sales data.
if not 'daily_seasonality' in config:
    config['daily_seasonality'] = True # default
else:
    if not config['daily_seasonality'] in ['auto', True, False]:
        # ToDo: there is one correct use case: an integer: Specifies the Fourier order of the seasonal components for the daily seasonality, which determines the flexibility and complexity of the seasonal pattern.
        print(f"Error: daily_seasonality parameter should be either 'auto', True or False.")
        sys.exit(10)

# weekly_seasonality: specifies whether to include weekly seasonality in the model and, if so, how to configure it. Weekly seasonality captures patterns that repeat every week, which can be essential for data that shows variations depending on the day of the week, such as retail sales, website traffic, or call center volumes.
if not 'weekly_seasonality' in config:
    config['weekly_seasonality'] = True # default
else:
    if not config['weekly_seasonality'] in ['auto', True, False]:
        # ToDo: there is one correct use case: an integer: Specifies the Fourier order of the seasonal components for the weekly seasonality, which determines the flexibility and complexity of the seasonal pattern.
        print(f"Error: weekly_seasonality parameter should be either 'auto', True or False.")
        sys.exit(11)

# yearly_seasonality: specifies whether to include yearly seasonality in the model and, if so, how to configure it. Yearly seasonality captures patterns that repeat every year, which can be crucial for data that exhibits annual variations, such as sales influenced by holidays, weather patterns, or agricultural cycles.
if not 'yearly_seasonality' in config:
    config['yearly_seasonality'] = True # default
else:
    if not config['yearly_seasonality'] in ['auto', True, False]:
        # ToDo: there is one correct use case: an integer: Specifies the Fourier order of the seasonal components for the yearly seasonality, which determines the flexibility and complexity of the seasonal pattern.
        print(f"Error: yearly_seasonality parameter should be either 'auto', True or False.")
        sys.exit(12)

# uncertainty_samples: To specify the number of Monte Carlo samples to draw for estimating the uncertainty intervals of the forecast
if not "uncertainty_samples" in config:
    config["uncertainty_samples"] = 1000 # default
# ToDo: check for integer >= 0
config.pop("uncertainty_samples", None)    # Neuralprophet doesn't support uncertainty_samples

# changepoint_prior_scale: controls the flexibility of the automatic changepoint detection. This parameter influences how strongly the model allows the trend to change at the identified changepoints
if not 'changepoint_prior_scale' in config:
    config['changepoint_prior_scale'] = 0.05 # default
# ToDo: check for positive float > 0
config.pop("changepoint_prior_scale", None)    # Neuralprophet doesn't support changepoint_prior_scale

# seasonality_prior_scale: controls the flexibility of the seasonal components. This parameter influences how much the seasonality can vary, thereby impacting the overall seasonality fit in the model.
if not 'seasonality_prior_scale' in config:
    config['seasonality_prior_scale'] = 10.0 # default
# ToDo: check for positive float > 0
config.pop("seasonality_prior_scale", None)    # Neuralprophet doesn't support seasonality_prior_scale

# holidays_country_code: specifies the country code for the country-specific holidays to include in the model. This parameter allows users to add holidays that are specific to a particular country, such as national holidays, public observances, or cultural events.
country_holidays_code = config.get("holidays_country_code", None)  # default: None
# ToDo: check if string, i. e. "US" or "DE"
config.pop("holidays_country_code", None)

seasonalities = config.get("seasonalities") or []
config.pop("seasonalities", None)

if 'cutoff_minimum' in config:
    cutoff_minimum = config["cutoff_minimum"]
else:
    cutoff_minimum = None
config.pop("cutoff_minimum", None)

# log_transform: if truthy, fit on log-transformed 'y' and invert the transform on the forecasts later
log_transform = bool(config.get("log_transform", False))
if log_transform:
    # Apply log transformation to the target variable (adding 1 avoids log(0))
    input_data['y'] = np.log(input_data['y'] + 1)
config.pop("log_transform", None)

if not 'anomaly_threshold_n_stddev' in config:
    anomaly_threshold_n_stddev = 3.0 # default
else:
    # ToDo: check for positive float > 0
    anomaly_threshold_n_stddev = config["anomaly_threshold_n_stddev"]
config.pop("anomaly_threshold_n_stddev", None)

# ToDo: implement handing over a constant 'changepoints' list of dates
# ToDo: implement handing over a constant 'holidays' list of dates with 'holiday, 'ds', 'lower_window', 'upper_window' columns
# ToDo: implement holidays_prior_scale

# NeuralProphet requires a 'y' column in the prediction dataframe; fill with NaN for unknown future values
if 'y' not in future_data.columns:
    future_data['y'] = np.nan

# create an instance of the NeuralProphet model
try:
    #model = Prophet(**config)
    model = NeuralProphet(**config)
except Exception as e:
    print(f"Error: An unexpected error occurred: {str(e)} when creating an instance of the Prophet model.")
    sys.exit(13)


# add country holidays if the 'holidays_country_code' is set
if country_holidays_code:
    try:
        model.add_country_holidays(country_name = country_holidays_code)
        # ToDo: implement 'lower_window', 'upper_window' and 'holidays' parameters
    except ValueError as e:
        print(f"Error: the country code {country_holidays_code} for adding holidays is not supported.")
        sys.exit(14)
    except Exception as e:
        print(f"Error: An unexpected error occurred: {str(e)} when adding country holidays to the Prophet model.")
        sys.exit(15)


# Add regressors
try:
    for regressor in regressors:
        # standardize: True, False or 'auto' (default)
        # mode: 'additive' (default) or 'multiplicative'
        # prior_scale: positive float > 0, default 10.0
        if not 'standardize' in regressor:
            regressor["standardize"] = 'auto'
        if not 'mode' in regressor:
            regressor["mode"] = 'additive'
        if not 'prior_scale' in regressor:
            regressor["prior_scale"] = 10.0
        #model.add_regressor(regressor["name"], prior_scale=regressor["prior_scale"], standardize=regressor["standardize"], mode=regressor["mode"])
        model.add_future_regressor(regressor["name"])
except ValueError as e:
    print(f"Error: {e} when adding a regressor to the Prophet model.")
    sys.exit(16)
except Exception as e:
    print(f"Error: An unexpected error occurred: {str(e)} when adding a regressor to the Prophet model.")
    sys.exit(17)


# Add seasonalities
for s in seasonalities:
    try:
        if s["condition_name"] == "null":
            model.add_seasonality(s["name"], s["period"], s["fourier_order"], s["prior_scale"], s["mode"])
        else:
            model.add_seasonality(s["name"], s["period"], s["fourier_order"], s["prior_scale"], s["mode"], s["condition_name"])
    except ValueError as e:
        print(f"Error: {e} when adding a seasonality to the Prophet model.")
        sys.exit(18)


# fit the model
try:
    model.fit(input_data)
except ValueError as e:
    print(f"Error: {e} while fitting model.")
    sys.exit(19)


# Make predictions for the future time periods
try:
    forecast = model.predict(future_data)
except Exception as e:
    print(f"Error: {e} while predicting future data.")
    sys.exit(20)

# Make predictions for the past
try:
    forecast_past = model.predict(input_data)
except Exception as e:
    print(f"Error: {e} while predicting past data.")
    sys.exit(21)

# Set pandas display options to show all columns
pd.set_option('display.max_columns', None)  # Show all columns
pd.set_option('display.width', None)        # Adjust the width to avoid line breaks
pd.set_option('display.max_rows', None)     # Show all rows if needed

print("Output DataFrame Forecast Future:")
print(forecast.head(10))  # Print the first 10 records
print("number of records: ", len(future_data))


if log_transform:
    # Apply inverse log transformation to the forecasted values
    forecast['yhat1'] = np.exp(forecast['yhat1']) - 1
    #forecast['yhat_lower'] = np.exp(forecast['yhat_lower']) - 1
    #forecast['yhat_upper'] = np.exp(forecast['yhat_upper']) - 1

    # Apply inverse log transformation to the forecasted values
    forecast_past['yhat1'] = np.exp(forecast_past['yhat1']) - 1
    #forecast_past['yhat_lower'] = np.exp(forecast_past['yhat_lower']) - 1
    #forecast_past['yhat_upper'] = np.exp(forecast_past['yhat_upper']) - 1

    # Apply inverse log transformation to the input data
    input_data['y'] = np.exp(input_data['y']) - 1

# Post-process the forecast: clip predictions below the configured minimum (e.g. to avoid negative values)
if isinstance(cutoff_minimum, (int, float)):
    forecast['yhat1'] = forecast['yhat1'].clip(lower=cutoff_minimum)
    #forecast['yhat_lower'] = forecast['yhat_lower'].clip(lower=cutoff_minimum)
    #forecast['yhat_upper'] = forecast['yhat_upper'].clip(lower=cutoff_minimum)
    forecast_past['yhat1'] = forecast_past['yhat1'].clip(lower=cutoff_minimum)
    #forecast_past['yhat_lower'] = forecast_past['yhat_lower'].clip(lower=cutoff_minimum)
    #forecast_past['yhat_upper'] = forecast_past['yhat_upper'].clip(lower=cutoff_minimum)


residuals = input_data['y'] - forecast_past['yhat1']


# compute wMAPE (weighted mean absolute percentage error): sum(|residuals|) / sum(y)
wMAPE = np.sum(np.abs(residuals)) / np.sum(input_data['y'])
print("wMAPE: ", wMAPE)


# compute the standard deviation of the residuals
stddev = residuals.std()


# set the anomaly threshold (configurable via 'anomaly_threshold_n_stddev', typically 2.0 to 10.0)
threshold = anomaly_threshold_n_stddev * stddev

# flag anomalies: residuals larger than the threshold are marked with 1
forecast_past['anomaly'] = (abs(residuals) > threshold).astype(int)

forecast['ds'] = forecast['ds'].dt.strftime('%Y-%m-%d %H:%M:%S')
forecast_past['ds'] = forecast_past['ds'].dt.strftime('%Y-%m-%d %H:%M:%S')
#model.history['ds'] = model.history['ds'].dt.strftime('%Y-%m-%d %H:%M:%S')
#model.changepoints = model.changepoints.dt.strftime('%Y-%m-%d %H:%M:%S')

#if config.get("holidays", False) and not config["holidays"].empty:
#    model.holidays['ds'] = model.holidays['ds'].dt.strftime('%Y-%m-%d %H:%M:%S')

# ToDo: forecast to dict

def replace_nan(obj):
    if isinstance(obj, float) and math.isnan(obj):
        return None
    elif isinstance(obj, dict):
        return {k: replace_nan(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [replace_nan(elem) for elem in obj]
    else:
        return obj

output = {
    "forecast": forecast.to_dict(orient='records'),
    "forecast_past": forecast_past.to_dict(orient='records'),
    "wMAPE": wMAPE
    #"model_holidays": model.holidays.to_dict(orient='records') if model.holidays else None,
    #"model_train_holiday_names": model.train_holiday_names.to_dict(orient='records') if model.train_holiday_names else None,
    #"model_history": model.history.to_dict(orient='records'),
    #"model_changepoints": model.changepoints.to_dict(),
    #"model_seasonalities": model.seasonalities,
    #"model_country_holidays": model.country_holidays
}

output = replace_nan(output)

try:
    with open(output_data_file, "w") as fs:
        fs.write(json.dumps(output, ensure_ascii=True))
except Exception as e:
    print(f"Error: {e} while writing output.json.")
    sys.exit(22)
