import '@tensorflow/tfjs-backend-cpu';

import {ImageClassificationModel, loadImageClassification} from '@tensorflow/tfjs-automl';
import {PhotoContentPredicter, Prediction} from './photo_content_predicter';

import type {ImagePrediction} from '@tensorflow/tfjs-automl/dist/img_classification';

/**
 * AutoML model versions that can be loaded.
 *
 * Each value is used as the directory name in the model URL
 * (`/ml/automl/<value>/model.json`). The values look like YYYYMMDD dates —
 * NOTE(review): presumably the model's export/training date; confirm with
 * whoever publishes the models.
 */
export enum AutoMLVersion {
    // Version used when no explicit version is configured.
    DEFAULT = '20221130',
    MODEL_2021 = '20211120',
    CONSTRUCTIONAL_DEFECTS = '20210617',
}

/**
 * Probability thresholds applied to AutoML predictions.
 *
 * NOTE(review): this is a numeric enum, so TypeScript also emits a reverse
 * mapping (`AutoMLBoundries[0.41] === 'MINIMAL'`). The name is a misspelling
 * of "Boundaries" but is exported, so renaming it would break importers.
 */
export enum AutoMLBoundries {
    // Predictions must score strictly above this probability to be kept in multi-label results.
    MINIMAL = 0.41,
}

/**
 * PhotoContentPredicter backed by TensorFlow.js AutoML image-classification
 * models served from `/ml/automl/<version>/model.json`.
 *
 * Each model version is downloaded at most once per instance. Concurrent requests
 * for the same version share one download, and a failed download is forgotten
 * so that a later call can retry it.
 */
export class AutoMLPhotoContentPredicter implements PhotoContentPredicter {
    /**
     * @param version Model version used by `getPrediction` and `prepare`. Defaults to
     *     `process.env.MIX_FEATURE_ML_MODEL_VERSION`, or `AutoMLVersion.DEFAULT` when that
     *     is unset. An empty value also falls back to `AutoMLVersion.DEFAULT` when the
     *     model is mounted.
     */
    constructor(
        private version = (process.env.MIX_FEATURE_ML_MODEL_VERSION as AutoMLVersion) ?? AutoMLVersion.DEFAULT
    ) {}

    /** How long `getPrediction` waits for the default model to mount before giving up (ms). */
    private static readonly MOUNT_TIMEOUT_MS = 2000;

    /** In-flight or completed mount of the default model; reset to null if mounting fails. */
    private mountModelPromise: Promise<null> | null = null;
    /** The default model, set once it has been mounted. */
    private model: ImageClassificationModel | null = null;
    /** Model downloads keyed by version, shared between concurrent callers. */
    private modelLoads = new Map<AutoMLVersion, Promise<ImageClassificationModel>>();

    /**
     * Classifies `file` with the default model and returns its most probable label.
     *
     * On first use this waits at most MOUNT_TIMEOUT_MS for the model to mount. If mounting
     * fails, or is still running at that point, this resolves to `null`. A slow download
     * keeps running in the background, so a later call may succeed.
     *
     * @param file Image file or image URL to classify.
     * @returns The highest-probability prediction, or `null` if no model is available
     *     or the model returned no predictions.
     * @throws Error if the image cannot be loaded.
     */
    public async getPrediction(file: File | string): Promise<Prediction | null> {
        if (!this.modelIsMounted()) {
            try {
                await this.timeout(this.mountModel(), AutoMLPhotoContentPredicter.MOUNT_TIMEOUT_MS);
            } catch {
                // Best effort: a missing/slow model yields "no prediction" rather than an error.
                return null;
            }
        }
        if (this.model === null) {
            return null;
        }
        return this.transformMLResultToPrediction(await this.classifyWith(this.model, file));
    }

    /**
     * Classifies `uploadedFile` with a specific model version (no timeout).
     *
     * @returns The highest-probability prediction, or `null` if the model returned none.
     * @throws If the model cannot be downloaded or the image cannot be loaded.
     */
    public async getVersionPrediction(version: AutoMLVersion, uploadedFile: File): Promise<Prediction | null> {
        const model = await this.mountVersion(version);
        return this.transformMLResultToPrediction(await this.classifyWith(model, uploadedFile));
    }

    /**
     * Classifies `uploadedFile` with a specific model version (no timeout).
     *
     * @returns Every prediction whose probability is strictly above `AutoMLBoundries.MINIMAL`,
     *     in the order the model returned them.
     * @throws If the model cannot be downloaded or the image cannot be loaded.
     */
    public async getVersionPredictions(version: AutoMLVersion, uploadedFile: File): Promise<Prediction[] | null> {
        const model = await this.mountVersion(version);
        return this.reduceAutoMLPredictions(await this.classifyWith(model, uploadedFile));
    }

    /**
     * Starts mounting the default model, or waits for an in-flight mount, so the
     * first prediction does not pay the download cost.
     *
     * @throws The download error if the model cannot be loaded. A later call retries.
     */
    public async prepare(): Promise<void> {
        if (!this.modelIsMounted()) {
            await this.mountModel();
        }
    }

    /** Loads `file` into an image element and runs `model` on it with centre cropping. */
    private async classifyWith(model: ImageClassificationModel, file: File | string): Promise<ImagePrediction[]> {
        const img = await this.getImgElementFromFile(file);
        return model.classify(img, {centerCrop: true});
    }

    /**
     * Loads `file` into a detached `<img>` element.
     *
     * A `File` is exposed through a blob URL. That URL is revoked once the image has
     * loaded or failed, so repeated classifications do not leak blobs.
     *
     * @throws Error if the browser fails to load the image.
     */
    private getImgElementFromFile(file: File | string): Promise<HTMLImageElement> {
        return new Promise<HTMLImageElement>((resolve, reject) => {
            const img = document.createElement('img');
            const ownsUrl = typeof file !== 'string';
            const src = typeof file === 'string' ? file : URL.createObjectURL(file);
            const release = () => {
                if (ownsUrl) {
                    URL.revokeObjectURL(src);
                }
            };
            img.onload = () => {
                release();
                resolve(img);
            };
            // Without this handler a broken image would leave the promise pending forever.
            img.onerror = () => {
                release();
                reject(new Error('Could not load image for classification'));
            };
            img.src = src;
        });
    }

    /** Keeps predictions scoring strictly above `AutoMLBoundries.MINIMAL`, preserving order. */
    private reduceAutoMLPredictions(predictions: ImagePrediction[]): Prediction[] {
        return predictions
            .filter((prediction) => prediction.prob > AutoMLBoundries.MINIMAL)
            .map((prediction) => this.transformMLPredictionToPrediction(prediction));
    }

    /** Converts one AutoML prediction; a missing probability is treated as 0. */
    private transformMLPredictionToPrediction(mlPrediction: ImagePrediction): Prediction {
        return {
            className: String(mlPrediction.label),
            probability: mlPrediction.prob ?? 0,
        };
    }

    /**
     * Picks the highest-probability prediction.
     *
     * @returns The winning prediction (the first one wins on ties), or `null` for an empty list.
     */
    private transformMLResultToPrediction(predictions: ImagePrediction[]): Prediction | null {
        return predictions
            .map((prediction) => this.transformMLPredictionToPrediction(prediction))
            .reduce<Prediction | null>(
                (best, candidate) => (best === null || candidate.probability > best.probability ? candidate : best),
                null
            );
    }

    /**
     * Mounts the configured default model into `this.model`, sharing one attempt
     * between concurrent callers.
     *
     * On failure the rejection is propagated and the cached attempt is cleared, so
     * the next call retries instead of awaiting a promise that never settles.
     */
    private mountModel(): Promise<null> {
        if (this.mountModelPromise === null) {
            const modelVersion = this.version || AutoMLVersion.DEFAULT;
            this.mountModelPromise = this.mountVersion(modelVersion).then(
                (model) => {
                    this.model = model;
                    return null;
                },
                (error: unknown) => {
                    this.mountModelPromise = null;
                    throw error;
                }
            );
        }
        return this.mountModelPromise;
    }

    /**
     * Returns the model for `modelVersion`, downloading it on first request.
     *
     * The in-flight promise is cached (not just the finished model), so concurrent
     * callers share one download. A failed download is evicted so it can be retried.
     */
    private mountVersion(modelVersion: AutoMLVersion): Promise<ImageClassificationModel> {
        const cached = this.modelLoads.get(modelVersion);
        if (cached !== undefined) {
            return cached;
        }
        const load = loadImageClassification('/ml/automl/' + modelVersion + '/model.json');
        this.modelLoads.set(modelVersion, load);
        load.catch(() => {
            // Only evict our own entry; a retry may already have replaced it.
            if (this.modelLoads.get(modelVersion) === load) {
                this.modelLoads.delete(modelVersion);
            }
        });
        return load;
    }

    private modelIsMounted(): boolean {
        return this.model !== null;
    }

    /**
     * Resolves with `promise`'s value, or rejects with an Error if it has not settled
     * within `ms` milliseconds.
     *
     * The timer is always cleared, and the underlying work is not cancelled.
     */
    private async timeout<T>(promise: Promise<T>, ms: number): Promise<T> {
        let timer: ReturnType<typeof setTimeout> | undefined;
        const expiry = new Promise<never>((_, reject) => {
            timer = setTimeout(() => reject(new Error(`Timed out after ${ms} ms`)), ms);
        });
        try {
            return await Promise.race([promise, expiry]);
        } finally {
            clearTimeout(timer);
        }
    }
}
