import { PaginatedElement } from 'types';
import CFService from '../cfservice';

import { getArmMetrics } from 'services/model/model.arms.repo';

import { CFTraitRepository } from 'services/traits/traits.repo';

import { AlgoMeta, AppModel, Model, ModelId, ModelTag } from 'domain/model.types';
import { AuthAction, isAllowedTo } from '../authorization.service';
import { AlgorithmClass } from 'services/intervention/intervention.types';
import { AppTrait, Trait, TraitCategory, TraitSubject } from 'domain/traits.types';
import TraitService from '../traits/traitSession.service';
import { CohortID } from 'services/cohort/cohort.types';
import { ArmCluster, ArmInfo, ArmSensitivity } from './model.repo.types';
import { CFModelRepository } from './model.repo';
import { subscribe } from 'services/session/session.service';
import { KEY } from 'repositories/storage.localstorage';

export default class ModelService extends CFService {
  _name = 'modelService';
  /** Last page returned by `getModels`; used as a lookup cache by `getById`, `addTag` and `removeTag`. */
  models: AppModel[] = [];
  /** Algorithm class metadata, loaded lazily by `initializeClassMeta`. */
  classmetas: AlgoMeta[] = [];
  /** Distinct `class_name` values found in `classmetas`. */
  algorithmClasses: AlgorithmClass[] = [];
  /** True once `classmetas`/`algorithmClasses` have been loaded successfully. */
  classMetaReady = false;
  /** Id of the model this service believes is pinned; set by `getModels` (page 0) and `pin`. */
  pinnedModelId: ModelId | null = null;

  modelRepository: CFModelRepository;
  traitRepository: CFTraitRepository;
  traitService: TraitService;

  // NOTE(review): not read or written anywhere in this class — confirm external use before removing.
  armInfo: Record<ModelId, ArmInfo> = {};
  /** Memoized `getArmMetrics` requests per model; failed requests are evicted so they can be retried. */
  initArmPromise: Record<ModelId, Promise<ArmInfo>> = {};

  /** Mirrors the session's incognito flag (see `init`); the check that used it in `getModels` is currently disabled. */
  isIncognito = false;

  constructor(modelRepository: CFModelRepository, traitRepository: CFTraitRepository, traitService: TraitService) {
    super();
    this.traitRepository = traitRepository;
    this.traitService = traitService;
    this.modelRepository = modelRepository;
  }

  /** Loads algorithm class metadata and derives the distinct algorithm class names from it. */
  private async initializeClassMeta() {
    this.classmetas = await this.modelRepository.getClassmeta();

    this.algorithmClasses = [...new Set(this.classmetas.map((classmeta) => classmeta.class_name))];

    this.classMetaReady = true;
  }

  /** Subscribes to the session's incognito flag so `isIncognito` stays in sync. */
  async init() {
    subscribe(KEY.INCOGNITO, (value) => {
      this.isIncognito = value;
    });
  }

  /** Derives the UI flags (`removable`, `pinned`, `published`, `pinnable`) from a raw model. */
  private transformModel = (model: Model): AppModel => {
    return {
      ...model,
      removable: model.definition.purpose !== ModelTag.Intervention,
      pinned: (model.definition.tags || []).includes(ModelTag.Landing),
      // A model without an explicit `published` value is treated as published.
      published: model.definition.published === undefined ? true : model.definition.published,
      pinnable:
        model.definition?.algo_spec?.class_name === AlgorithmClass.Censoring && isAllowedTo(AuthAction.PinModel),
    };
  };

  /**
   * Returns a copy of `model` with `invId` fetched from the repository.
   * Best-effort: if the lookup fails, `invId` is set to the sentinel `-1` instead of throwing.
   */
  private addInvId = async (model: AppModel): Promise<AppModel> => {
    let invId;
    try {
      invId = await this.modelRepository.getInvId(model.definition.id);
    } catch {
      invId = -1;
    }

    return {
      ...model,
      invId,
    };
  };

  /** True when the model's algorithm class is Bandit; tolerates a missing `algo_spec` like `transformModel` does. */
  private isBandit(model: Model): boolean {
    return model.definition?.algo_spec?.class_name === AlgorithmClass.Bandit;
  }

  /**
   * Finds a model in the `models` cache. Ids are compared by their string form so that
   * a numeric id and its string representation match, as `getById` has always done.
   */
  private findCachedModel(id: ModelId): AppModel | undefined {
    const key = id.toString();
    return this.models.find((model) => model.definition.id.toString() === key);
  }

  /** Returns one page of Recommender-class models. */
  async getRecommendModels(page: number, pageSize: number): Promise<PaginatedElement<AppModel>> {
    return this.getModels(page, pageSize, AlgorithmClass.Recommender);
  }

  /**
   * Fetches one page of models, optionally filtered by algorithm class, and caches it in `models`.
   * Unpublished models are included only when the user is allowed to see them.
   *
   * @param page zero-based page index
   * @param pageSize number of models per page
   * @param className optional algorithm class filter
   */
  async getModels(
    page: number,
    pageSize: number,
    className: AlgorithmClass | undefined = undefined
  ): Promise<PaginatedElement<AppModel>> {
    // The incognito restriction is intentionally disabled for now.
    const seeUnpublished = isAllowedTo(AuthAction.SeeUnpublishedModel); /*&& !this.isIncognito*/

    const list = await this.modelRepository.get(page, pageSize, className, seeUnpublished);

    const extendedList = list.data.map(this.transformModel);

    this.models = extendedList;

    // The pinned model is only detected from the first entry of page 0; this assumes the
    // backend lists the pinned model first (the same assumption getPinnedModel relies on).
    const first = extendedList[0];
    if (page === 0 && first !== undefined && first.pinned) {
      this.pinnedModelId = first.definition.id as ModelId;
    }

    return { total: list.total, data: extendedList };
  }

  /** Returns the distinct algorithm class names, loading class metadata on first use. */
  async getAlgorithmClasses() {
    if (!this.classMetaReady) {
      await this.initializeClassMeta();
    }

    return this.algorithmClasses;
  }

  /** Returns the metadata entries belonging to `modelClass`, loading class metadata on first use. */
  async getAlgorithmNames(modelClass: AlgorithmClass) {
    if (!this.classMetaReady) {
      await this.initializeClassMeta();
    }

    return this.classmetas.filter((meta) => meta.class_name === modelClass);
  }

  /**
   * Returns a model by id, preferring the `models` cache and falling back to the repository.
   * Bandit models are always returned as a fresh copy enriched with `invId`; other cached
   * models are returned as the cached object itself.
   */
  async getById(id: ModelId): Promise<AppModel> {
    const cachedModel = this.findCachedModel(id);

    if (cachedModel && !this.isBandit(cachedModel)) {
      return cachedModel;
    }

    const model = cachedModel ?? (await this.modelRepository.getById(id));
    const appModel = this.transformModel(model);

    return this.isBandit(model) ? this.addInvId(appModel) : appModel;
  }

  /**
   * Pins or unpins a model and keeps the Landing tag in sync with the pin state.
   *
   * @throws Error('another-model-pinned') when pinning while a model is already pinned
   */
  async pin(id: ModelId, pin: boolean): Promise<void> {
    if (pin && this.pinnedModelId !== null) {
      throw new Error('another-model-pinned');
    }

    await this.modelRepository.pin(`${id}`, pin);

    // Record the pin state as soon as the pin call itself has succeeded.
    this.pinnedModelId = pin ? id : null;

    // Awaited (previously fire-and-forget) so tag failures reach the caller instead of
    // becoming unhandled rejections; going through addTag/removeTag also updates the cache.
    if (pin) {
      await this.addTag(id, ModelTag.Landing);
    } else {
      await this.removeTag(id, ModelTag.Landing);
    }
  }

  /**
   * Returns the pinned model, or `undefined` when there is none.
   * Only the first model of the unfiltered list (unpublished included) is inspected.
   */
  async getPinnedModel(): Promise<AppModel | undefined> {
    const list = await this.modelRepository.get(0, 1, undefined, true);
    const first = list.data[0];

    if (list.total === 0 || first === undefined) {
      return undefined;
    }

    return (first.definition.tags || []).includes(ModelTag.Landing) ? this.transformModel(first) : undefined;
  }

  /** Returns the tags that can be applied to the model, depending on its algorithm class. */
  async getAvailableTags(id: ModelId): Promise<ModelTag[]> {
    const model = await this.getById(id);

    if (model.definition.algo_spec?.class_name === AlgorithmClass.Censoring) {
      return [ModelTag.Landing, ModelTag.ML];
    } else {
      return [ModelTag.Intervention];
    }
  }

  /** Adds `tag` to the model in the repository and mirrors the change in the cached copy, if any. */
  async addTag(id: ModelId, tag: ModelTag) {
    await this.modelRepository.addTag(id, tag);

    const cachedModel = this.findCachedModel(id);

    if (cachedModel) {
      const tags = cachedModel.definition.tags ?? [];
      // Avoid duplicates: removeTag only drops one occurrence.
      if (!tags.includes(tag)) {
        tags.push(tag);
      }
      cachedModel.definition.tags = tags;
      // Keep the derived flag consistent with transformModel.
      cachedModel.pinned = tags.includes(ModelTag.Landing);
    }
  }

  /** Removes `tag` from the model in the repository and mirrors the change in the cached copy, if any. */
  async removeTag(id: ModelId, tag: ModelTag) {
    await this.modelRepository.removeTag(id, tag);

    const cachedModel = this.findCachedModel(id);
    const tags = cachedModel?.definition.tags;

    if (cachedModel && tags) {
      const tagIndex = tags.indexOf(tag);

      // indexOf yields -1 for a missing tag; splice(-1, 1) would wrongly drop the last tag.
      if (tagIndex !== -1) {
        tags.splice(tagIndex, 1);
      }
      // Keep the derived flag consistent with transformModel.
      cachedModel.pinned = tags.includes(ModelTag.Landing);
    }
  }

  /** Returns one page of users associated with the model, as reported by the repository. */
  async getUsers(modelId: ModelId, curPage: number, curPageSize: number): Promise<PaginatedElement<string>> {
    const users = await this.modelRepository.getUsers(modelId, curPage, curPageSize);

    return users;
  }

  /** Returns all user-subject traits in the machine-learning (MLT) category. */
  async getMachineLearningTraits(): Promise<Trait[]> {
    const traits = await this.traitService.getTraits({ subject: TraitSubject.User, category: TraitCategory.MLT });
    return traits;
  }

  /**
   * Returns the model traits of a cohort, each merged with its definition and the
   * `models` value the repository reports for it. Codes with no known definition are skipped.
   */
  async getMLTraitsInCohort(cohortId: CohortID): Promise<AppTrait[]> {
    const traits = await this.traitRepository.getModelTraits(cohortId);
    const result: AppTrait[] = [];

    for (const traitCode of Object.keys(traits)) {
      const trait = this.traitService.getTraitDefinition(traitCode);

      if (trait !== undefined) {
        result.push({ ...trait, models: traits[traitCode] });
      }
    }

    return result;
  }

  /** Pauses the model via the repository. */
  async pause(modelId: ModelId) {
    await this.modelRepository.pause(modelId, true);
  }

  /** Resumes a paused model via the repository. */
  async resume(modelId: ModelId) {
    await this.modelRepository.pause(modelId, false);
  }

  /** Marks the model as published via the repository. */
  async publish(modelId: ModelId) {
    await this.modelRepository.publish(modelId, true);
  }

  /** Marks the model as unpublished via the repository. */
  async unpublish(modelId: ModelId) {
    await this.modelRepository.publish(modelId, false);
  }

  /**
   * Returns the (memoized) arm metrics request for a model. A rejected request is evicted
   * from the memo so the next call retries instead of replaying the failure forever.
   */
  private loadArmMetrics(modelId: ModelId): Promise<ArmInfo> {
    const pending = this.initArmPromise[modelId];
    if (pending !== undefined) {
      return pending;
    }

    const request = getArmMetrics(modelId).catch((error: unknown) => {
      delete this.initArmPromise[modelId];
      throw error;
    });
    this.initArmPromise[modelId] = request;
    return request;
  }

  /** Returns the arm cluster part of the model's arm metrics (fetched once per model). */
  async getArmClusters(modelId: ModelId): Promise<ArmCluster> {
    const armMetrics = await this.loadArmMetrics(modelId);
    return armMetrics.arm_cluster;
  }

  /**
   * Returns the sensitivity part of the model's arm metrics (fetched once per model).
   * The method name's spelling ("Sensitivy") is kept for caller compatibility.
   */
  async getArmSensitivy(modelId: ModelId): Promise<ArmSensitivity> {
    const armMetrics = await this.loadArmMetrics(modelId);
    return armMetrics.sensitivity;
  }
}
