import React, { useCallback, useEffect, useMemo, useState } from 'react';

import { EmbeddingMetrics, EmbeddingPlots } from 'domain/model.types';
import { AggLevel } from 'domain/stats.types';
import { CFRole } from 'domain/general.types';

import useModelId from 'views/model/hooks/useModelId';

import { getEmbeddingsPlots } from 'services/model/model.repo';

import { CFLineChart, ScaleType } from 'components/charts/CFLineChart';
import CFButton from 'components/buttons/CFButton';
import { CFScatterPlotChart } from 'components/charts/CFScatterPlotChart';

import { generateSeries } from './helpers';

import { scoreDescription } from './constants';

/** Props accepted by {@link EmbeddingMetricsView}. */
type Props = {
  /** Embedding metrics for the model; `embedding.rocauc` drives the "Scores" line chart. */
  metrics: EmbeddingMetrics;
};

/**
 * Tab buttons rendered above the metrics charts.
 * `label` is the visible button text; `value` identifies which panel is shown.
 */
const buttonTabs = [
  { label: 'Scores', value: 'scores' },
  { label: 'Plot', value: 'plot' },
] as const;

/** A single entry of `buttonTabs`, with its literal (readonly) field types. */
type ButtonTab = (typeof buttonTabs)[number];

/** Union of tab identifiers, derived from `buttonTabs` so the two cannot drift apart. */
type ButtonTabs = ButtonTab['value'];

import './embeddings.scss';

/**
 * Embedding metrics panel for the current model.
 *
 * Renders two tabs:
 * - "Scores": a line chart of `metrics.embedding.rocauc` (each point mapped to `{ time, value }`).
 * - "Plot": PCA and TSNE scatter plots whose data is loaded with `getEmbeddingsPlots`
 *   for the model id returned by `useModelId`.
 *
 * @param props.metrics - Embedding metrics; when falsy, the score chart receives no series.
 */
const EmbeddingMetricsView = ({ metrics }: Props) => {
  const [currentButtonTab, setCurrentButtonTab] = useState<ButtonTabs>(buttonTabs[0].value);

  // Empty series until the plot request resolves.
  const [plots, setPlots] = useState<EmbeddingPlots>({
    pca: { x: [], y: [], index: [], legend: [] },
    tsne: { x: [], y: [], index: [], legend: [] },
  });

  const modelId = useModelId();

  useEffect(() => {
    // Flipped by the cleanup so that a response for a previous modelId, or one that
    // arrives after unmount, is not written into state.
    let cancelled = false;

    if (modelId) {
      void (async () => {
        try {
          const fetchedPlots = await getEmbeddingsPlots(modelId);
          if (!cancelled) {
            setPlots(fetchedPlots);
          }
        } catch (error: unknown) {
          // Keep whatever plots are already shown, but don't leave the rejection unhandled.
          console.error('Failed to load embedding plots', error);
        }
      })();
    }

    return () => {
      cancelled = true;
    };
  }, [modelId]);

  // Returns a click handler that activates `buttonTab`.
  const handleCurrentButton = useCallback(
    (buttonTab: ButtonTabs) => () => {
      setCurrentButtonTab(buttonTab);
    },
    []
  );

  const series = useMemo(() => {
    if (!metrics) {
      return [];
    }

    return [
      {
        items: (metrics.embedding.rocauc ?? []).map((timepoint) => ({ time: timepoint.t, value: timepoint.v })),
        name: 'Link Prediction Performance',
      },
    ];
  }, [metrics]);

  // Both the active tab and the fetched plot data are inputs. Without the plot dependency,
  // data that arrives after the "Plot" tab is already open would never be rendered.
  const pcaSeries = useMemo(() => generateSeries(currentButtonTab, plots.pca), [currentButtonTab, plots.pca]);
  const tsneSeries = useMemo(() => generateSeries(currentButtonTab, plots.tsne), [currentButtonTab, plots.tsne]);

  return (
    <div className="embeddings-metrics">
      <div className="embeddings-metrics__controls">
        {buttonTabs.map(({ label, value }) => (
          <CFButton
            key={value}
            value={label}
            onClick={handleCurrentButton(value)}
            role={currentButtonTab === value ? CFRole.Cyan : CFRole.Secondary}
          />
        ))}
      </div>

      {currentButtonTab === 'scores' && (
        <CFLineChart
          scale={ScaleType.Linear}
          yLabel={''}
          units={''}
          title={'Link Prediction Performance'}
          data={[...series]}
          aggregationLevel={AggLevel.Day}
          isLoading={false}
          showLegend={false}
          description={scoreDescription}
        ></CFLineChart>
      )}

      {currentButtonTab === 'plot' && (
        <div className="embeddings-metrics__plots">
          <CFScatterPlotChart
            title={'PCA'}
            showLegend={true}
            series={pcaSeries}
            xLabel={``}
            yLabel={``}
            step={20}
            square={false}
            size={500}
            expandable={true}
          />

          <CFScatterPlotChart
            title={'TSNE'}
            showLegend={true}
            series={tsneSeries}
            xLabel={``}
            yLabel={``}
            step={20}
            square={false}
            size={500}
          />
        </div>
      )}
    </div>
  );
};

export default EmbeddingMetricsView;
