import { useState, useEffect, useCallback } from 'react';
import ApexCharts from 'react-apexcharts';
import {
  ArrowLongUpIcon,
  ArrowLongDownIcon,
  XMarkIcon,
} from '@heroicons/react/24/solid';
import toast from 'react-hot-toast';

import Spinner from '../../Spinner';
import MonthPicker from './MonthPicker';
import { useUser } from '../../../UserContext';
import { getPricePerToken } from '../../../utils/generalUtils';
import { prepareChartData } from './prepareChartData';

const TokenUsageChart = ({ model, setModelNotFound, closeModel }) => {
  // HTTP client taken from the user context; used below for the usage requests.
  const { customAxios } = useUser();
  // Multiplied by token counts throughout this component to derive dollar costs.
  const costPerToken = getPricePerToken(model.model_config.base_model);
  // True while the usage requests are in flight; the spinner view is rendered.
  const [loading, setLoading] = useState(true);
  // Set when the usage endpoint answers with status 404.
  const [errorNoModel, setErrorNoModel] = useState(false);
  // Sum of the `y` values of the displayed month's data points.
  const [totalTokens, setTotalTokens] = useState(0);
  // Only the setter is used; the stored previous-month total is never read.
  const [_, setTotalTokensLastMonth] = useState(0);
  // Data points returned by prepareChartData for the displayed month.
  const [chartData, setChartData] = useState([]);
  // The same data points wrapped in the series array passed to ApexCharts.
  const [chartSeries, setChartSeries] = useState([]);
  // Month-over-month change in tokens, in percent, rounded to two decimals.
  const [percentageChange, setPercentageChange] = useState(0);
  // Month-over-month change in cost (displayed month minus previous month).
  const [costChange, setCostChange] = useState(0);

  // The month being displayed plus the month it is compared against.
  // `getInitialDisplayDate` is a function declaration further down in this
  // component, so it is hoisted and can serve directly as the lazy initializer
  // instead of duplicating its body here.
  const [displayDate, setDisplayDateState] = useState(getInitialDisplayDate);

  /**
   * Builds the default display period: the current UTC month and the UTC
   * month immediately before it.
   *
   * @param {Date} [now] - Instant to derive the period from; defaults to the
   *   current time.
   * @returns {{
   *   thisMonth: { month: number, year: number },
   *   previousMonth: { month: number, year: number },
   * }} Months are 0-indexed (0 = January).
   */
  function getInitialDisplayDate(now = new Date()) {
    // Read month and year from one Date so both always describe the same
    // instant, even when called right at a month or year boundary.
    const month = now.getUTCMonth();
    const year = now.getUTCFullYear();
    const isJanuary = month === 0;

    return {
      thisMonth: { month, year },
      previousMonth: {
        // January wraps around to December of the preceding year.
        month: isJanuary ? 11 : month - 1,
        year: isJanuary ? year - 1 : year,
      },
    };
  }

  // Jump back to the current month whenever a different model is selected.
  useEffect(() => {
    const initialDate = getInitialDisplayDate();

    setDisplayDateState((prevDate) => {
      const isSamePeriod =
        prevDate.thisMonth.month === initialDate.thisMonth.month &&
        prevDate.thisMonth.year === initialDate.thisMonth.year &&
        prevDate.previousMonth.month === initialDate.previousMonth.month &&
        prevDate.previousMonth.year === initialDate.previousMonth.year;

      // Returning the existing object lets React skip the update. Storing a
      // fresh but equal object would change `displayDate`'s identity and make
      // the fetch effect run (and hit the API) a second time on mount and on
      // every model change.
      return isSamePeriod ? prevDate : initialDate;
    });
  }, [model]);

  // Stable setter passed to MonthPicker. The stored object is only replaced
  // when one of its month/year fields differs, so selecting the period that is
  // already displayed does not produce a state change.
  const setDisplayDate = useCallback((newDisplayDate) => {
    const isSameMonth = (a, b) => a.month === b.month && a.year === b.year;

    setDisplayDateState((current) =>
      isSameMonth(current.thisMonth, newDisplayDate.thisMonth) &&
      isSameMonth(current.previousMonth, newDisplayDate.previousMonth)
        ? current
        : newDisplayDate,
    );
  }, []);

  /**
   * Formats a number as a US-dollar string using the en-US locale.
   *
   * @param {number} amount - Value to format.
   * @param {number} [decimalPlaces=2] - Minimum number of fraction digits.
   * @returns {string} The formatted amount, e.g. "$1,234.50".
   */
  const getDollars = (amount, decimalPlaces = 2) => {
    const usdFormatter = new Intl.NumberFormat('en-US', {
      style: 'currency',
      currency: 'USD',
      minimumFractionDigits: decimalPlaces,
    });

    return usdFormatter.format(amount);
  };

  // Loads usage for the displayed month and for the month before it, then
  // derives the totals and month-over-month deltas rendered under the chart.
  useEffect(() => {
    // Flipped by the cleanup function below. A response that arrives after
    // `model` / `displayDate` changed (or after unmount) belongs to an
    // outdated run and must not overwrite the state of the current one.
    let isStale = false;

    // Sums the `y` value of every data point; null/missing values count as 0.
    const sumTokens = (points) =>
      points.reduce((total, point) => total + (point.y || 0), 0);

    // Switches the component to its "model not found" view and tells the parent.
    const markModelNotFound = () => {
      setErrorNoModel(true);
      setModelNotFound(true);
    };

    const getModelUsageData = async () => {
      // A 404 is only read as "model not found" when it comes from the first
      // (displayed month) request, never from the comparison request.
      let hasCurrentMonthResponse = false;

      setLoading(true);
      setChartData([]);
      setChartSeries([]);
      setErrorNoModel(false);
      // Clear the deltas of the previous run so they are not shown next to new
      // data if the comparison request fails or is not a 200.
      setPercentageChange(0);
      setCostChange(0);

      try {
        const endpoint =
          model.type === 'base_model'
            ? `tailor/v1/base_model/${model.model_id}`
            : `tailor/v1/models/${model.model_id}`;
        const createdAtMs = model.created_at_unix * 1000;

        const response = await customAxios.get(endpoint, {
          params: {
            month: displayDate.thisMonth.month + 1, // Months are 0-indexed
            year: displayDate.thisMonth.year,
          },
        });
        hasCurrentMonthResponse = true;
        if (isStale) return;

        if (response.status === 404) {
          markModelNotFound();
          return;
        }
        if (response.status !== 200) {
          throw new Error('Failed to fetch model usage data');
        }

        const { usage_data, deployment_records } = response.data.message;
        const preparedData = prepareChartData(
          usage_data,
          deployment_records,
          displayDate.thisMonth,
          createdAtMs,
        );
        const totalTokensThisMonth = sumTokens(preparedData);

        setChartData(preparedData);
        setChartSeries([{ name: 'Tokens', data: preparedData }]);
        setTotalTokens(totalTokensThisMonth);

        const responseLastMonth = await customAxios.get(endpoint, {
          params: {
            month: displayDate.previousMonth.month + 1,
            year: displayDate.previousMonth.year,
          },
        });
        if (isStale || responseLastMonth.status !== 200) return;

        const {
          usage_data: usageDataLastMonth,
          deployment_records: deploymentRecordsLastMonth,
        } = responseLastMonth.data.message;
        const totalTokensLastMonth = sumTokens(
          prepareChartData(
            usageDataLastMonth,
            deploymentRecordsLastMonth,
            displayDate.previousMonth,
            createdAtMs,
          ),
        );
        setTotalTokensLastMonth(totalTokensLastMonth);

        // With no usage last month a percentage is undefined; report 0.
        const changePercent = totalTokensLastMonth
          ? ((totalTokensThisMonth - totalTokensLastMonth) /
              totalTokensLastMonth) *
            100
          : 0;
        setPercentageChange(Number(changePercent.toFixed(2)));

        setCostChange(
          totalTokensThisMonth * costPerToken -
            totalTokensLastMonth * costPerToken,
        );
      } catch (error) {
        if (isStale) return;

        // NOTE(review): assumes customAxios behaves like a default axios
        // instance, which rejects on non-2xx statuses. In that case a 404
        // arrives here rather than in the status check above.
        if (!hasCurrentMonthResponse && error?.response?.status === 404) {
          markModelNotFound();
          return;
        }

        console.error(error);
        toast.error('Failed to fetch model usage data, try again later.', {
          id: 'model-usage-fetch-error',
        });
      } finally {
        // A stale run leaves `loading` alone; the newer run owns it.
        if (!isStale) {
          setLoading(false);
        }
      }
    };

    getModelUsageData();

    return () => {
      isStale = true;
    };
  }, [model, setModelNotFound, customAxios, displayDate, costPerToken]);

  // ApexCharts configuration for the token usage chart.
  const options = {
    chart: {
      animations: {
        enabled: false,
      },
      toolbar: {
        className: 'z-1',
        show: true,
        tools: {
          zoom: true,
          zoomin: true,
          zoomout: true,
          download: false,
          selection: true,
          pan: true,
          reset: true,
        },
      },
    },
    xaxis: {
      type: 'datetime', // Set the x-axis to use datetime values
      labels: {
        datetimeUTC: false,
      },
    },
    dataLabels: {
      enabled: false,
    },
    markers: {
      size: 2,
      strokeColors: '#4338CA',
      colors: '#4338CA',
      strokeWidth: 2,
      hover: {
        size: 4,
      },
    },
    yaxis: {
      labels: {
        // Axis ticks are shown in thousands of tokens.
        formatter: function (value) {
          const valueString = (value / 1000).toLocaleString();
          return valueString + ' k';
        },
      },
    },
    colors: ['#818CF8'],
    tooltip: {
      // Returns the tooltip markup for the hovered data point.
      custom: function ({ seriesIndex, dataPointIndex, w }) {
        const data = w.config.series[seriesIndex].data[dataPointIndex];
        // No explicit timeZone: the tooltip uses the viewer's local time, in
        // line with `datetimeUTC: false` on the x-axis.
        const formattedDate = new Date(data.x).toLocaleString('en-US', {
          month: 'short',
          day: 'numeric',
          hour: 'numeric',
          minute: 'numeric',
        });

        // Treat undefined like null so a point without a value never renders
        // as "Tokens: undefined" / "Cost: $NaN".
        if (data.y === null || data.y === undefined) {
          return `<div class="bg-white border border-zinc-300 rounded-md shadow-md">
                    <div class="text-xs text-zinc-600 px-2 pt-2 pb-1 border-b border-zinc-300">
                      ${formattedDate}
                    </div>
                    <div class="text-sm text-zinc-600 bg-zinc-100 w-full px-2 pt-1">
                      Model Not Deployed
                    </div>
                  </div>`;
        }

        // Format tokens and cost the same way as the summary cards below the
        // chart (locale-grouped numbers, getDollars for currency).
        return `<div class="bg-white border border-zinc-300 rounded-md shadow-md">
                  <div class="text-xs text-zinc-600 px-2 pt-2 pb-1 border-b border-zinc-300">
                    ${formattedDate}
                  </div>
                  <div class="text-sm text-zinc-600 bg-zinc-100 w-full px-2 pt-1">
                    Tokens: ${data.y.toLocaleString()}
                  </div>
                  <div class="text-sm px-2 pt-1 pb-2 text-zinc-600 bg-zinc-100">
                    Cost: ${getDollars(data.y * costPerToken)}
                  </div>
                </div>`;
      },
    },
    stroke: {
      width: 4,
      curve: 'smooth',
    },
    title: {
      text: 'Token Usage Per Hour',
      align: 'left',
    },
  };

  if (loading) {
    return (
      <div className="h-[37.125rem] flex items-center justify-around relative">
        <button
          className="absolute p-1 rounded-full -top-9 right-2 bg-zinc-100 text-zinc-800 md:hidden"
          onClick={closeModel}
        >
          <XMarkIcon className="w-5 h-5" />
        </button>
        <Spinner size={'36px'} borderTopColor={'gray'} />
      </div>
    );
  }

  if (errorNoModel) {
    return (
      <>
        <div className="h-[37.125rem] flex flex-col items-center justify-center text-sm md:text-base text-zinc-800 text-center relative">
          <button
            className="absolute p-1 rounded-full -top-9 right-2 bg-zinc-100 text-zinc-800 md:hidden"
            onClick={closeModel}
          >
            <XMarkIcon className="w-5 h-5" />
          </button>
          <p>Model not found. Please select a different model.</p>
          <p>If you believe this is an error, please contact support.</p>
        </div>
      </>
    );
  }

  if (chartData.length === 0) {
    return (
      <div className="h-[37.125rem] flex flex-col items-center justify-center text-zinc-800 text-center">
        <p>No usage data for the selected month.</p>
      </div>
    );
  }

  return (
    <>
      <div className="px-2 relative">
        <button
          className="absolute p-1 rounded-full -top-9 right-2 bg-zinc-100 text-zinc-800 md:hidden"
          onClick={closeModel}
        >
          <XMarkIcon className="w-5 h-5" />
        </button>

        <div className="flex flex-col flex-wrap items-start mb-4 gap-x-20 gap-y-2">
          <div className="text-xl font-bold text-zinc-800 max-w-[90%] break-words">
            {model.model_name}
          </div>
          <MonthPicker setDate={setDisplayDate} currentDate={displayDate} />
        </div>
        <ApexCharts
          options={options}
          series={chartSeries}
          height={250}
          type={
            chartData?.filter((entry) => entry.tokens !== null).length > 1
              ? 'area'
              : 'scatter'
          }
        />
        <div className="flex flex-wrap md:justify-center md:gap-8 md:flex-nowrap">
          <div className="flex items-baseline w-screen p-1 md:flex-col md:items-center md:justify-around gap-x-2 md:w-auto ">
            <div className="md:flex-1 md:pt-4 md:text-lg w-fit text-nowrap">
              Model Usage
            </div>
            <div className="text-xs md:text-sm text-zinc-600">This Month</div>
            <div className="text-xs md:pb-4 md:text-sm text-zinc-600">
              vs. Last Month
            </div>
          </div>
          <div className="relative w-1/2 p-2 border md:rounded-lg rounded-tl-md border-zinc-300 md:w-fit md:min-w-36 min-h-24">
            <div className="flex flex-col items-center justify-around h-full">
              <span
                className={`flex text-xs absolute top-1 right-1 ${
                  percentageChange > 0
                    ? '!text-lime-600'
                    : percentageChange < 0
                      ? 'text-rose-600'
                      : 'text-zinc-600'
                }`}
              >
                {percentageChange > 0 ? (
                  <ArrowLongUpIcon className="w-4 h-4" />
                ) : percentageChange < 0 ? (
                  <ArrowLongDownIcon className="w-4 h-4" />
                ) : null}
                {percentageChange === 0
                  ? '⎯'
                  : Math.abs(percentageChange) + '%'}
              </span>
              <span className="mt-3 font-bold md:text-lg">
                {totalTokens.toLocaleString()}
              </span>
              <div className="flex flex-col items-center justify-center text-xs text-center md:text-base text-zinc-700 grow">
                <div>Tokens Used</div>
                <div className="text-xs text-zinc-500">(this month)</div>
              </div>
            </div>
          </div>
          <div className="relative w-1/2 p-2 border-r md:border border-y md:rounded-lg rounded-tr-md border-zinc-300 md:min-w-36 md:w-fit min-h-24">
            <div className="flex flex-col items-center justify-around h-full">
              <span
                className={`text-xs flex absolute top-1 right-1 ${
                  costChange > 0
                    ? '!text-lime-600'
                    : costChange < 0
                      ? 'text-rose-600'
                      : 'text-zinc-600'
                }`}
              >
                {costChange > 0 ? (
                  <ArrowLongUpIcon className="w-4 h-4" />
                ) : costChange < 0 ? (
                  <ArrowLongDownIcon className="w-4 h-4" />
                ) : null}
                {getDollars(Math.abs(costChange))}
              </span>
              <span className="mt-3 font-bold md:text-lg">
                {getDollars(totalTokens * costPerToken)}
              </span>
              <div className="flex flex-col items-center justify-center text-xs text-center md:text-base text-zinc-700 grow">
                <div>Cost</div>
                <div className="text-xs text-zinc-500">(this month)</div>
              </div>
            </div>
          </div>
          <div className="relative w-1/2 p-2 border-b md:border border-x md:rounded-lg rounded-bl-md border-zinc-300 md:w-fit md:min-w-36 min-h-24">
            <div className="flex flex-col items-center justify-around h-full">
              <span className="mt-3 font-bold md:text-lg">
                {getDollars(costPerToken * 1_000_000)}
              </span>
              <span className="flex flex-col items-center justify-center text-xs text-center md:text-base text-zinc-700 grow">
                <div>Cost per</div>
                <div>1M tokens</div>
              </span>
            </div>
          </div>
          <div className="relative w-1/2 p-2 border-b border-r rounded-br-md border-zinc-300 md:hidden"></div>
        </div>
      </div>
    </>
  );
};

export default TokenUsageChart;
