import {
    tidy,
    loadGraphModel,
    GraphModel,
    Tensor3D,
    image,
    Rank,
    Tensor,
} from "@tensorflow/tfjs"
import { IMAGENET_CLASSES } from "./imagenet_classes"

import type { TrashTypes } from "helpers/history/type"

/**
 * Wrapper around a TensorFlow.js GraphModel used to classify image captures.
 *
 * The model is loaded from `${process.env.PUBLIC_URL}/model.json`. Output
 * indices are mapped to labels through IMAGENET_CLASSES.
 */
export class AI {
    private model: GraphModel | undefined

    /** Loads the graph model. Must complete before `predict` can return results. */
    public async loadModel() {
        this.model = await loadGraphModel(`${process.env.PUBLIC_URL}/model.json`)
    }

    /**
     * Releases the loaded model, if any.
     *
     * The reference is cleared so that a later `predict` returns `undefined`
     * instead of running on disposed weights.
     */
    public dispose() {
        if (this.model) {
            this.model.dispose()
            this.model = undefined
        }
    }

    /**
     * Returns the `topK` highest-scoring entries of `logits`, best first.
     * Each entry pairs a score with its label from IMAGENET_CLASSES.
     *
     * @param logits flat score vector; index `i` is looked up in IMAGENET_CLASSES
     * @param topK maximum number of entries to return; clamped to `[0, logits.length]`
     *             so an oversized or negative value cannot index past the data
     */
    private getClasses(
        logits: Float32Array | Int32Array | Uint8Array,
        topK: number,
    ): { label: typeof TrashTypes[number]; value: number }[] {
        const count = Math.max(0, Math.min(Math.floor(topK), logits.length))
        // Copy into a plain array before sorting so the caller's buffer is not reordered.
        return Array.from(logits, (value, index) => ({ value, index }))
            .sort((a, b) => b.value - a.value)
            .slice(0, count)
            .map(({ value, index }) => ({
                label: IMAGENET_CLASSES[index],
                value,
            }))
    }

    /**
     * Adds a leading batch dimension and casts to float.
     * It does not resize the image.
     * Intermediate tensors are released by `tidy`. The caller owns, and must dispose,
     * the returned tensor.
     */
    public async preprocess(image: Tensor3D) {
        return tidy(() => image.expandDims(0).toFloat())
    }

    /**
     * Classifies `capture` and reports the top-2 labels.
     *
     * `capture` is resized to 224x224, batched and cast to float before inference.
     * The caller keeps ownership of `capture`. All tensors created here are disposed
     * before returning, including on error.
     *
     * @returns `[topClasses, elapsedMs]` where `elapsedMs` covers inference and
     *          post-processing (not resizing), or `undefined` when no model is loaded.
     */
    public async predict(
        capture: Tensor3D,
    ): Promise<[{ label: typeof TrashTypes[number]; value: any }[], number] | undefined> {
        const model = this.model
        if (!model) {
            return undefined
        }

        // The resized intermediate is freed by tidy; only the batched tensor survives.
        const imageData = tidy(() =>
            image.resizeBilinear(capture, [224, 224]).expandDims(0).toFloat(),
        )

        const start = performance.now()
        let logits: Tensor<Rank> | undefined
        try {
            // GraphModel.predict is synchronous; only reading the data back is async.
            logits = model.predict(imageData) as Tensor<Rank>
            const probs = await logits.data()

            const res = this.getClasses(probs, 2)
            // NOTE(review): debug logging on every prediction; consider removing.
            console.log(res)
            const time = performance.now() - start

            return [res, time]
        } finally {
            // Previously `logits` was never disposed and every call leaked memory.
            imageData.dispose()
            logits?.dispose()
        }
    }
}
