import pandas as pd
import numpy as np
from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score
import matplotlib.backends
import matplotlib.pyplot as plt
from datetime import datetime
import os

class SubmissionTracker:
    """Score predictions against a solution CSV and keep a CSV log of the results.

    Two modes exist, selected by ``final``: the "final" mode uses
    ``data/numerai_test_solution.csv`` / ``data/performance_log.csv`` and the
    development mode uses ``data/numerai_dev_test_solution.csv`` /
    ``data/performance_dev_log.csv``.  Each scored prediction appends one row
    (timestamp, model name, metric values) to the mode's log file.
    """

    # Row counts that, per the error messages in ``score_prediction``, identify
    # predictions made on the dev_test set and on the test set respectively.
    # They are only used to catch a set/mode mix-up early.
    _DEV_TEST_ROWS = 17337
    _TEST_ROWS = 21191

    # Metrics that are logged, compared and plotted.  Only ROC AUC is computed
    # by ``score_prediction``; extending this tuple requires computing the
    # additional metric there as well.
    _METRICS = ('roc_auc',)

    def __init__(self, final=True):
        """Select the solution/log files for the mode and create the log if missing.

        Args:
            final: ``True`` for final (test set) submissions, ``False`` for
                development (dev_test set) submissions.
        """
        if final:
            self.solution_file = 'data/numerai_test_solution.csv'
            self.log_file = 'data/performance_log.csv'
        else:
            self.solution_file = 'data/numerai_dev_test_solution.csv'
            self.log_file = 'data/performance_dev_log.csv'
        self.final = final

        if not os.path.exists(self.log_file):
            # Start the log with a header-only CSV so later reads always succeed.
            header = ['timestamp', 'model_name', *self._METRICS]
            pd.DataFrame(columns=header).to_csv(self.log_file, index=False)

    def _read_log(self):
        """Return the performance log as a DataFrame.

        ``model_name`` is read as text: with dtype inference a name such as
        ``'007'`` would come back as the integer ``7``, which then neither
        matches the name passed by the caller nor survives a rewrite of the log.
        """
        return pd.read_csv(self.log_file, dtype={'model_name': str})

    def score_prediction(self, preds, model_name='Unnamed Model'):
        """Compute ROC AUC for ``preds``, log it, and print/plot comparisons.

        Args:
            preds: Array-like of scores.  For a 2-D input, column 1 is used as
                the score (presumably the positive-class probability, as
                returned by ``predict_proba`` -- confirm with callers).  A 1-D
                input is used as the score vector directly.
            model_name: Label under which the result is stored in the log.

        Returns:
            dict: ``{'roc_auc': <score>}``.

        Raises:
            ValueError: If the number of rows indicates that predictions for
                the wrong set (dev_test vs. test) were passed for this mode.
        """
        preds = np.asarray(preds)
        n_rows = preds.shape[0]

        # Cheap sanity check first, before touching the file system.
        if n_rows == self._DEV_TEST_ROWS and self.final:
            raise ValueError("It seems you used the dev_test set for a final submission. Use the test set instead.")
        if n_rows == self._TEST_ROWS and not self.final:
            raise ValueError("It seems you used the test set for a development submission. Use the dev_test set instead.")

        scores = preds if preds.ndim == 1 else preds[:, 1]

        # The last column of the solution file holds the true labels.  Selecting
        # the column itself (rather than slicing ``.values``) keeps its own
        # dtype even if other columns are non-numeric.
        df_test_solution = pd.read_csv(self.solution_file)
        y_true = df_test_solution.iloc[:, -1].to_numpy()

        overall_metrics = {'roc_auc': roc_auc_score(y_true, scores)}
        print('ROC AUC ', overall_metrics['roc_auc'])

        self._log_performance(model_name, overall_metrics)
        self._compare_with_previous(model_name, overall_metrics)
        self._plot_performance_history(model_name)
        self.compare_all_models()

        return overall_metrics

    def _log_performance(self, model_name, metrics):
        """Append one row (current time, ``model_name``, ``metrics``) to the log file.

        Args:
            model_name: Label of the model that produced the metrics.
            metrics: Mapping of metric name to value; each key becomes a column.
        """
        log_df = self._read_log()

        new_row = pd.DataFrame({
            'timestamp': [datetime.now()],
            'model_name': [model_name],
            **{name: [value] for name, value in metrics.items()},
        })

        if log_df.empty:
            # Concatenating with an empty frame is deprecated in recent pandas;
            # write the new row alone, keeping the header's column order.
            columns = list(log_df.columns)
            columns += [col for col in new_row.columns if col not in columns]
            log_df = new_row.reindex(columns=columns)
        else:
            log_df = pd.concat([log_df, new_row], ignore_index=True)

        log_df.to_csv(self.log_file, index=False)
        print(f"Performance logged for {model_name}")

    def _compare_with_previous(self, model_name, current_metrics):
        """Print the change in each metric relative to the model's previous logged run.

        Expects the current run to be in the log already: the second-to-last
        row for ``model_name`` (ordered by timestamp) is taken as the previous
        run.

        Args:
            model_name: Label of the model to look up in the log.
            current_metrics: Mapping of metric name to the current value.
        """
        log_df = self._read_log()
        previous_runs = log_df[log_df['model_name'] == model_name].sort_values('timestamp')

        if len(previous_runs) < 2:
            print(f"\nThis is the first run for {model_name}. No comparison available.")
            return

        last_run = previous_runs.iloc[-2]
        print(f"\nComparison with previous run for {model_name}:")
        for metric in self._METRICS:
            diff = current_metrics[metric] - last_run[metric]
            if diff > 0:
                verdict = 'improved'
            elif diff == 0:
                # An identical score is neither an improvement nor a regression.
                verdict = 'unchanged'
            else:
                verdict = 'decreased'
            print(f"{metric.capitalize()}: {diff:.4f} ({verdict})")

    def _plot_performance_history(self, model_name):
        """Draw the logged metric values of ``model_name`` over submission time.

        A new matplotlib figure is created and left open; showing or saving it
        is up to the caller / the active backend.

        Args:
            model_name: Label of the model whose history is plotted.
        """
        log_df = self._read_log()
        model_history = log_df[log_df['model_name'] == model_name].sort_values('timestamp')

        plt.figure(figsize=(12, 6))
        for metric in self._METRICS:
            plt.plot(model_history['timestamp'], model_history[metric], label=metric, marker='o')

            if metric == 'roc_auc':
                # Force plain (non-scientific) tick labels and a fixed 0.5-1 range.
                plt.ticklabel_format(style='plain', axis='y')
                plt.gca().get_yaxis().get_major_formatter().set_scientific(False)
                plt.ylim(0.5, 1)

        plt.title(f'{model_name} Performance History')
        plt.xlabel('Submission Time')
        plt.ylabel('Score')
        plt.legend()
        plt.xticks(rotation=45)
        plt.tight_layout()

    def compare_all_models(self, plot=False):
        """Print the most recently logged metrics of every model in the log.

        Args:
            plot: If ``True``, additionally draw a grouped bar chart of the
                same values in a new matplotlib figure.  Defaults to ``False``
                (table only).
        """
        log_df = self._read_log()
        # Rows are appended chronologically, so the last row per model is its latest run.
        latest_performances = log_df.groupby('model_name').last().reset_index()

        metrics = list(self._METRICS)

        # Table comparison
        print("\nAll models:")
        print(latest_performances[['model_name'] + metrics].to_string(index=False))

        if not plot:
            return

        # Visualization
        plt.figure(figsize=(12, 6))
        bar_width = 0.2
        index = np.arange(len(latest_performances))

        for i, metric in enumerate(metrics):
            plt.bar(index + i * bar_width, latest_performances[metric], bar_width, label=metric)

        plt.xlabel('Models')
        plt.ylabel('Score')
        plt.title('Performance Comparison Across All Models')
        # Centre each tick under its group of bars, whatever the number of metrics.
        group_centre = index + bar_width * (len(metrics) - 1) / 2
        plt.xticks(group_centre, latest_performances['model_name'], rotation=45, ha='right')
        plt.legend()
        plt.tight_layout()