import abc
import datetime as dt
import json
import os
import random
import sys
import textwrap
import time
import traceback
import warnings
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import pandas as pd
import simplejson
from flair.data import Sentence
from flair.models import TextClassifier
from google.auth import load_credentials_from_file
from langchain_community.callbacks import get_openai_callback
from langchain_community.chat_models import ChatOpenAI
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_google_vertexai import VertexAI
from nltk.sentiment import SentimentIntensityAnalyzer
from textblob import TextBlob
from tqdm import tqdm
from transformers import BertForSequenceClassification as BertModel
from transformers import BertTokenizerFast as BertTokens
from transformers import pipeline


class Sentiment(metaclass=abc.ABCMeta):
    """
    Abstract base class for Sentiment models.

    Subclasses implement `sentiment_label`. Callers that want timing and
    failure handling should go through `wrapped_sentiment_label` instead.
    """

    def __init__(self, name: str):
        """
        Instantiate the Sentiment model.

        :param name: the name of the model.
        """
        self.name = name
        self.total_cost = 0.0  # Accumulated cost; updated by subclasses that track it.
        self.total_runtime = 0.0  # Accumulated seconds spent in `sentiment_label`.

    @abc.abstractmethod
    def sentiment_label(self, sentence: str) -> int:
        """
        Abstract method for returning a sentiment label from a model.

        :param sentence: the sentence we want to classify the sentiment of.
        :return: a heuristic for the sentiment of the given sentence.
        """
        raise NotImplementedError("Implement `sentiment_label`.")

    def wrapped_sentiment_label(self, sentence: str) -> int:
        """
        Call `sentiment_label`, time the call, and never raise.

        The elapsed time is added to `total_runtime` whether or not the call
        succeeds, so that time spent in failed calls is not lost. If the model
        raises, the traceback is printed together with the model name and a
        neutral label is returned.

        :param sentence: the sentence we want to classify the sentiment of.
        :return: the label from `sentiment_label`, or 0 (neutral) on failure.
        """
        # A monotonic clock is used so that system clock adjustments cannot
        # produce negative or inflated durations.
        started = time.perf_counter()

        try:
            # Call the sentiment label method internally.
            return self.sentiment_label(sentence=sentence)

        except Exception:
            print(self.name, traceback.format_exc())
            return 0  # Return a neutral label.

        finally:
            # Add the runtime to the internal runtime counter.
            self.total_runtime += time.perf_counter() - started


class HumanSentiment(Sentiment):
    """
    Sentiment model whose labels are typed in by the user at the console.
    """

    def __init__(self, name="Human"):
        """
        Create a `HumanSentiment` model, which asks the user for every label.

        :param name: the name of the model. The default for `HumanSentiment` is "Human".
        """
        super().__init__(name=name)

    def sentiment_label(self, sentence: str) -> int:
        """
        Show the sentence to the user and keep asking until a valid rating is typed.

        :param sentence: the sentence we want to quantify the sentiment of.
        :return: 1 for "+", 0 for "*", or -1 for "-".
        """
        # Accepted answers and the label each one stands for.
        rating_to_label = {"+": 1, "*": 0, "-": -1}

        while True:
            # Show the story wrapped to 80 columns, followed by a prompt marker.
            wrapped_story = textwrap.fill(sentence, 80)
            print("\n", wrapped_story, ">>>", sep="\n")

            answer = input("Please assign a +, *, OR - rating to this story: ")

            # Anything other than the three accepted symbols asks again.
            if answer in rating_to_label:
                return rating_to_label[answer]


class RandomSentiment(Sentiment):
    """
    Baseline sentiment model that ignores the text and guesses a label.
    """

    def __init__(self, name: str = "Random"):
        """
        Create a `RandomSentiment` model, which picks labels at random.

        :param name: the name of the model. The default for `RandomSentiment` is "Random".
        """
        super().__init__(name=name)

    def sentiment_label(self, sentence: str) -> int:
        """
        Pick one of -1 (negative), 0 (neutral) or 1 (positive) at random.

        :param sentence: the sentence to label; it is not inspected.
        :return: the sentiment (a value of either -1, 0, or 1).
        """
        candidate_labels = (-1, 0, 1)
        guess = random.choice(candidate_labels)
        return guess


class TextBlobSentiment(Sentiment):
    """
    Sentiment model that thresholds the polarity reported by TextBlob.
    """

    def __init__(self, name: str = "TextBlob", decision_boundary: float = 0.10):
        """
        Create a `TextBlobSentiment` model, which uses TextBlob to label text.

        :param name: the name of the model. The default for `TextBlobSentiment` is "TextBlob".
        :param decision_boundary: the polarity threshold for positive or negative;
            only its absolute value is used.
        """
        super().__init__(name=name)

        self.decision_boundary = decision_boundary

    def sentiment_label(self, sentence: str) -> int:
        """
        Label the sentence as -1 (negative), 0 (neutral) or 1 (positive).

        :param sentence: the sentence we want to quantify the sentiment of.
        :return: 1 when the polarity is at or above the threshold, -1 when it is
            at or below the negated threshold, and 0 otherwise.
        """
        threshold = abs(self.decision_boundary)
        score = TextBlob(sentence).sentiment.polarity

        # The positive check comes first, so it wins when the threshold is zero.
        if score >= threshold:
            return 1
        if score <= -threshold:
            return -1
        return 0


class VaderSentiment(Sentiment):
    """
    Sentiment model that thresholds the compound score from NLTK's VADER.
    """

    def __init__(self, name: str = "VADER", decision_boundary: float = 0.05):
        """
        Create a `VaderSentiment` model, which uses NLTK (VADER) to label text.

        :param name: the name of the model. The default for `VaderSentiment` is "VADER".
        :param decision_boundary: the compound score threshold for positive or
            negative; only its absolute value is used.
        """
        super().__init__(name=name)

        # The analyzer is created once and reused for every sentence.
        self.vader_model = SentimentIntensityAnalyzer()
        self.decision_boundary = decision_boundary

    def sentiment_label(self, sentence: str) -> int:
        """
        Label the sentence as -1 (negative), 0 (neutral) or 1 (positive).

        :param sentence: the sentence we want to quantify the sentiment of.
        :return: 1 when the compound score is at or above the threshold, -1 when
            it is at or below the negated threshold, and 0 otherwise.
        """
        threshold = abs(self.decision_boundary)
        scores = self.vader_model.polarity_scores(sentence)
        score = scores['compound']

        # The positive check comes first, so it wins when the threshold is zero.
        if score >= threshold:
            return 1
        if score <= -threshold:
            return -1
        return 0


class FlairSentiment(Sentiment):
    """
    Sentiment model backed by Flair's pre-trained 'en-sentiment' classifier.
    """

    def __init__(self, name: str = "Flair", decision_boundary: float = 0.75):
        """
        Instantiates a `FlairSentiment` model. This model uses Flair to label text.

        :param name: the name of the model. The default for `FlairSentiment` is "Flair".
        :param decision_boundary: the probability at which + or - labels are accepted.
        """
        super().__init__(name=name)

        # Initialize the `flair_model` class attribute.
        self.flair_model = TextClassifier.load('en-sentiment')
        self.decision_boundary = decision_boundary

    def sentiment_label(self, sentence: str) -> int:
        """
        This method will return a sentiment label of either -1 (negative), 0 (neutral) or 1 (positive).

        Predictions whose score is below `decision_boundary` are treated as
        neutral, as is a sentence for which no label was produced.

        :param sentence: the sentence we want to quantify the sentiment of.
        :return: the sentiment (a value of either -1, 0, or 1).
        """
        sentence_obj = Sentence(sentence)
        self.flair_model.predict(sentence_obj)

        # Guard against an empty label list (e.g. nothing to classify) rather
        # than failing with an IndexError on `labels[0]`.
        if not sentence_obj.labels:
            return 0  # Return neutral.

        top_label = sentence_obj.labels[0]

        if top_label.score < self.decision_boundary:
            return 0  # Return neutral.

        # Any value other than "NEGATIVE" is treated as positive.
        return -1 if top_label.value == "NEGATIVE" else 1


class FinBERTSentiment(Sentiment):
    """
    Sentiment model backed by a FinBERT checkpoint run through a Transformers pipeline.
    """

    # Maps the lower-cased pipeline label onto the integer sentiment.
    _LABEL_TO_SENTIMENT = {"negative": -1, "neutral": 0, "positive": 1}

    def __init__(self, name: str = "FinBERT", version: str = "ProsusAI/finbert", device: str = "cuda"):
        """
        Instantiates a `FinBERTSentiment` model. This model uses FinBERT to label text.

        :param name: the name of the model. The default for `FinBERTSentiment` is "FinBERT".
        :param version: the version of FinBERT to use to assign the labels to the text.
        :param device: the device handed to the pipeline. Defaults to "cuda", as before;
            pass e.g. "cpu" on machines without a GPU.
        """
        super().__init__(name=name)

        # Initialize the `finbert_model` class attribute.
        self.finbert_model = pipeline(
            task="sentiment-analysis", device=device,
            model=BertModel.from_pretrained(version),
            tokenizer=BertTokens.from_pretrained(version),
        )

    def sentiment_label(self, sentence: str) -> int:
        """
        This method will return a sentiment label of either -1 (negative), 0 (neutral) or 1 (positive).

        :param sentence: the sentence we want to quantify the sentiment of.
        :return: the sentiment (a value of either -1, 0, or 1). A label the
            mapping does not recognise is reported as 0 (neutral).
        """
        # Warnings raised while running the pipeline are silenced.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            prediction = self.finbert_model([sentence])[0]

        # Labels are compared case-insensitively. An unknown label returns
        # neutral instead of falling through and returning None.
        label = prediction["label"].lower()
        return self._LABEL_TO_SENTIMENT.get(label, 0)


class SigmaSentiment(Sentiment):
    """
    Sentiment model backed by the Sigma financial sentiment checkpoint.
    """

    # Maps the lower-cased pipeline label onto the integer sentiment.
    _LABEL_TO_SENTIMENT = {"label_0": -1, "label_1": 0, "label_2": 1}

    def __init__(
        self,
        name: str = "SigmaFSA",
        version: str = "Sigma/financial-sentiment-analysis",
        device: str = "cuda",
    ):
        """
        Instantiates a `SigmaSentiment` model. This model uses FinancialBERT to label text.

        :param name: the name of the model. The default for `SigmaSentiment` is "SigmaFSA".
        :param version: the model identifier handed to the pipeline.
        :param device: the device handed to the pipeline. Defaults to "cuda", as before;
            pass e.g. "cpu" on machines without a GPU.
        """
        super().__init__(name=name)

        # Initialize the `sigma_sent_model` class attribute.
        self.sigma_sent_model = pipeline(
            task="sentiment-analysis",
            device=device, model=version
        )

    def sentiment_label(self, sentence: str) -> int:
        """
        This method will return a sentiment label of either -1 (negative), 0 (neutral) or 1 (positive).

        :param sentence: the sentence we want to quantify the sentiment of.
        :return: the sentiment (a value of either -1, 0, or 1). A label the
            mapping does not recognise is reported as 0 (neutral).
        """
        # Warnings raised while running the pipeline are silenced.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            prediction = self.sigma_sent_model([sentence])[0]

        # An unknown label returns neutral instead of falling through and
        # returning None.
        label = prediction["label"].lower()
        return self._LABEL_TO_SENTIMENT.get(label, 0)


class SentimentOutputParser(BaseOutputParser[int]):
    """
    Output parser that turns a free-text LLM reply into an integer sentiment.
    """

    def parse(self, text: str) -> int:
        """
        Convert the reply of the language model into a sentiment label.

        The reply is lower-cased and searched for the fragments "pos" and
        "neg", in that order; the first fragment found decides the label.

        :param text: the answer provided by the language model.
        :return: an integer classification {-1, 0, 1}
        """
        reply = text.lower().strip()

        # Checked in order, so "pos" takes precedence when both fragments occur.
        for fragment, label in (("pos", 1), ("neg", -1)):
            if fragment in reply:
                return label

        return 0  # Neither fragment was found: return neutral.


class LangChainSentiment(Sentiment, metaclass=abc.ABCMeta):
    """
    Shared base for the LLM-backed sentiment models.

    It holds the few-shot prompt and the reply parser. It does not implement
    `sentiment_label`, and `model` is only a placeholder set to None here.
    """

    def __init__(self, name: str = "LangChain"):
        """
        Set up the state shared by the LLM-backed sentiment models.

        :param name: the name of the model. The default for `LangChainSentiment` is "LangChain".
        """
        super().__init__(name=name)

        # Placeholder for the LLM client.
        self.model = None

        # Converts the text reply of the LLM into -1, 0 or 1.
        self.parser = SentimentOutputParser()

        # Few-shot prompt; `{story}` is the only placeholder.
        self.sentiment_template = """
            You are a sentiment classification AI.

                1. You are only able to reply with POS, NEU, or NEG.
                2. When you are unsure of the sentiment, you MUST reply with NEU.

            Here are some examples of correct replies:

                Story: Unexpected demand for the new XYZ product is likely to boost earnings.
                Reply: POS

                Story: Sales of the new XYZ product are inline with projections.
                Reply: NEU

                Story: Company releases profit warning after the sales of XYZ disappoint.
                Reply: NEG

                Story: Following better than expected job numbers, the stock market rallied.
                Reply: POS

                Story: The stock market ended flat after job numbers came in as expected.
                Reply: NEU

                Story: A large spike in unemployment numbers sent the stock market into panic.
                Reply: NEG

                Story: XYZ stock soared after the FDA approved its new cancer treatment.
                Reply: POS

                Story: XYZ will announce results of its cancer treatment on the 15th of July.
                Reply: NEU

                Story: Following poor results, the FDA shuts down trials of XYZ cancer treatment.
                Reply: NEG

            Okay, now please classify the following news story:

                Story: {story}
                Reply: 
        """


class ChatGPTSentiment(LangChainSentiment):
    """
    LLM-backed sentiment model that sends the prompt to an OpenAI chat model.
    """

    def __init__(self, name: str = "ChatGPT", version: str = "gpt-3.5-turbo"):
        """
        Instantiates a `ChatGPTSentiment` model. This model uses ChatGPT to label text.

        The API key is read from the `OPENAI_API_KEY` environment variable; a
        `KeyError` is raised if it is not set.

        :param name: the name of the model. The default for `ChatGPTSentiment` is "ChatGPT".
        :param version: the name of the LLM in OpenAI that we want to use for labelling.
        """
        super().__init__(name=name)

        self.model = ChatOpenAI(
            openai_api_key=os.environ["OPENAI_API_KEY"],
            model=version
        )

    def sentiment_label(self, sentence: str) -> int:
        """
        This method will return a sentiment label of either -1 (negative), 0 (neutral) or 1 (positive).

        The cost reported by the OpenAI callback is added to `total_cost`.

        :param sentence: the sentence we want to quantify the sentiment of.
        :return: the sentiment (a value of either -1, 0, or 1).
        """
        # Build the template from the raw prompt and pass the story as a
        # variable. Formatting the story in first and then parsing the result
        # as a template would turn any "{" or "}" in the story into template
        # syntax and make the call fail.
        prompt = PromptTemplate.from_template(self.sentiment_template)
        chain = prompt | self.model | self.parser

        with get_openai_callback() as cb:
            output = chain.invoke({"story": sentence})

            # Update the total cost with the data from the callback.
            self.total_cost += cb.total_cost

        return output


class VertexSentiment(LangChainSentiment):
    """
    LLM-backed sentiment model that sends the prompt to a Vertex AI model.
    """

    def __init__(self, name: str = "Vertex", version: str = "gemini-pro", credentials=None):
        """
        Instantiates a `VertexSentiment` model. This model uses Vertex AI to label text.

        :param name: the name of the model. The default for `VertexSentiment` is "Vertex".
        :param version: the name of the LLM in Vertex AI that we want to use for labelling.
        :param credentials: the credentials object handed to the `VertexAI` client.
        """
        super().__init__(name=name)

        self.model = VertexAI(
            model_name=version,
            credentials=credentials
        )

        self.version = version
        self.chars_in = 0  # Running count of prompt characters sent.
        self.chars_out = 0  # Running estimate of reply characters received.

        # Per-character prices used for the cost estimate, keyed by version.
        self.version_to_costs = {
            "text-bison": {
                "char_in": 0.00025 / 1000,
                "char_out": 0.0005 / 1000
            },
            "text-unicorn": {
                "char_in": 0.0025 / 1000,
                "char_out": 0.0075 / 1000
            },
            "gemini-pro": {
                "char_in": 0.00025 / 1000,
                "char_out": 0.0005 / 1000
            }
        }

    def sentiment_label(self, sentence: str) -> int:
        """
        This method will return a sentiment label of either -1 (negative), 0 (neutral) or 1 (positive).

        `total_cost` is re-estimated from the running character counts when the
        version has an entry in `version_to_costs`; otherwise it is left as is.

        :param sentence: the sentence we want to quantify the sentiment of.
        :return: the sentiment (a value of either -1, 0, or 1).
        """
        # Build the template from the raw prompt and pass the story as a
        # variable. Formatting the story in first and then parsing the result
        # as a template would turn any "{" or "}" in the story into template
        # syntax and make the call fail.
        prompt = PromptTemplate.from_template(self.sentiment_template)
        chain = prompt | self.model | self.parser
        response = chain.invoke({"story": sentence})

        # Update the estimate of input and output chars used. The reply is
        # counted as 3 characters, the length of POS / NEU / NEG.
        self.chars_in += len(self.sentiment_template.format(story=sentence))
        self.chars_out += 3

        # Estimate the cost incurred using this model. A version without a
        # price entry must not raise here: the label has already been obtained
        # and an exception would throw it away.
        cost_table = self.version_to_costs.get(self.version)
        if cost_table is not None:
            chars_in_est_cost = self.chars_in * cost_table["char_in"]
            chars_out_est_cost = self.chars_out * cost_table["char_out"]
            self.total_cost = chars_in_est_cost + chars_out_est_cost

        return response


if __name__ == "__main__":
    # Benchmark script: label every story with every model, then write the
    # labels, the pairwise agreement matrix, and the cost / runtime totals.

    gcp_credentials = load_credentials_from_file('YOUR_GCP_AUTH_FILE.json')[0]

    # Column of the agreement matrix used to rank the models.
    sort_against = "Human"

    dataset = "small/nosible-news-small"
    with open(f"{dataset}.json", "r") as f:
        news_sample = json.load(f)

    # ==================================================================================================================
    # EXTRACT THE SENTIMENT HEURISTIC FROM EACH MODEL FOR EACH NEWS STORY.
    # ==================================================================================================================

    # The dictionary key is the identifier used in every output file, so each
    # model's `name` is kept equal to its key. Thresholded variants carry the
    # threshold that is actually applied in their name.
    classifiers = {
        "Human": HumanSentiment(),
        "Random": RandomSentiment(),

        "TextBlob-0.10": TextBlobSentiment(name="TextBlob-0.10", decision_boundary=0.10),
        "TextBlob-0.20": TextBlobSentiment(name="TextBlob-0.20", decision_boundary=0.20),
        "TextBlob-0.30": TextBlobSentiment(name="TextBlob-0.30", decision_boundary=0.30),

        "VADER-0.10": VaderSentiment(name="VADER-0.10", decision_boundary=0.10),
        "VADER-0.20": VaderSentiment(name="VADER-0.20", decision_boundary=0.20),
        "VADER-0.30": VaderSentiment(name="VADER-0.30", decision_boundary=0.30),

        # NOTE(review): these were previously named Flair-0.50 / 0.75 / 0.95
        # while running with 0.70 / 0.80 / 0.90. The names now match the
        # thresholds in use; confirm these are the intended thresholds.
        "Flair-0.70": FlairSentiment(name="Flair-0.70", decision_boundary=0.70),
        "Flair-0.80": FlairSentiment(name="Flair-0.80", decision_boundary=0.80),
        "Flair-0.90": FlairSentiment(name="Flair-0.90", decision_boundary=0.90),

        "SigmaFSA": SigmaSentiment(name="SigmaFSA", version="Sigma/financial-sentiment-analysis"),
        "FinBERT": FinBERTSentiment(name="FinBERT", version="ProsusAI/finbert"),
        "FinBERT-Tone": FinBERTSentiment(name="FinBERT-Tone", version="yiyanghkust/finbert-tone"),

        "Text-Bison": VertexSentiment(name="Text-Bison", version="text-bison", credentials=gcp_credentials),
        "Text-Unicorn": VertexSentiment(name="Text-Unicorn", version="text-unicorn", credentials=gcp_credentials),
        "Gemini-Pro": VertexSentiment(name="Gemini-Pro", version="gemini-pro", credentials=gcp_credentials),

        "GPT-3.5-Turbo": ChatGPTSentiment(name="GPT-3.5-Turbo", version="gpt-3.5-turbo"),
        "GPT-4-Turbo": ChatGPTSentiment(name="GPT-4-Turbo", version="gpt-4-1106-preview"),
        "GPT-4-Original": ChatGPTSentiment(name="GPT-4-Original", version="gpt-4"),
    }

    # Models that prompt the user run on the main thread without a timeout.
    # In the thread pool, a slow answer was recorded as neutral after five
    # seconds while the abandoned prompt kept waiting for input.
    interactive = {name: m for name, m in classifiers.items() if isinstance(m, HumanSentiment)}
    automated = {name: m for name, m in classifiers.items() if name not in interactive}

    keys = list(news_sample.keys())
    news_labels = {}
    labels_path = f"{dataset}-labels.json"

    with (
        tqdm(total=len(keys), ncols=80, file=sys.stdout) as bar,
        ThreadPoolExecutor(max_workers=30) as tpe
    ):

        for ix, key in enumerate(keys):

            # Join the headline and description and collapse runs of whitespace.
            story = news_sample[key]
            story = " ".join((story["Headline"] + " . " + story["Description"]).split())

            # Every model starts out neutral, in the order of `classifiers`, so
            # a failed or timed-out model still has an entry.
            news_labels[key] = {name: 0 for name in classifiers}

            # Start the automated models in the background ...
            futures = {name: tpe.submit(m.wrapped_sentiment_label, story) for name, m in automated.items()}

            # ... and collect the interactive labels while they run.
            for name, model in interactive.items():
                news_labels[key][name] = model.wrapped_sentiment_label(story)

            for name, future in futures.items():

                try:
                    # Get the result from the thread.
                    news_labels[key][name] = future.result(timeout=5)

                except Exception:
                    # Print out the caught exception; the neutral default stays.
                    print(traceback.format_exc())

            bar.update(n=1)

            if ix % 250 == 0:
                print("Caching the results just in case ...")
                with open(labels_path, "w") as f:
                    simplejson.dump(news_labels, f, indent=4)

    with open(labels_path, "w") as f:
        simplejson.dump(news_labels, f, indent=4)

    # ==================================================================================================================
    # CALCULATE THE PERCENTAGE OF TIMES THAT EACH MODEL AGREES WITH ONE ANOTHER.
    # ==================================================================================================================

    # Convert the sentiment benchmarks into a pandas DataFrame.
    df = pd.DataFrame.from_dict(data=news_labels, orient="index")
    df.to_csv(f"{dataset}-labels.csv")

    # Calculate how often each of the models agrees with each other.
    matches = np.zeros((len(df.columns), len(df.columns)))
    for i, model_i in enumerate(df.columns):
        i_labels = df[model_i].values
        for j, model_j in enumerate(df.columns):
            j_labels = df[model_j].values
            same: np.ndarray = i_labels == j_labels
            matches[i, j] = same.sum() / len(j_labels)

    # Save the comparison matrix, with rows and columns ordered from the model
    # that agrees most with `sort_against` to the one that agrees least.
    comp = pd.DataFrame(matches, index=df.columns, columns=df.columns)
    vs_sort_against: pd.Series = comp[sort_against].sort_values(ascending=False)
    best_to_worst = list(vs_sort_against.index)
    comp = comp.reindex(columns=best_to_worst, index=best_to_worst)
    comp.to_csv(f"{dataset}-benchmark.csv")

    # ==================================================================================================================
    # EXTRACT AND SAVE THE TIME TAKEN AND THE COST INCURRED OF EACH OF THE MODELS.
    # ==================================================================================================================

    cost_and_time = {
        name: {
            "Total Cost": classifier.total_cost,
            "Total Runtime": classifier.total_runtime
        }
        for name, classifier in classifiers.items()
    }

    # Convert the costs and runtimes into a pandas DataFrame.
    costs_df = pd.DataFrame.from_dict(data=cost_and_time, orient="index")
    costs_df.to_csv(f"{dataset}-costs.csv")
