import * as tf from "@tensorflow/tfjs-core";
import { loadGraphModel } from "@tensorflow/tfjs-converter";
import { registerOps } from "./tf_utils";
import modelMetadata from "../model.json";
import "@tensorflow/tfjs-backend-webgl";
// Storage prefix shared by every cached version of this model; modelsCleanup()
// uses it to recognise entries that belong to this module.
const localStorageUrl = "indexeddb://blur_detection/";
// The model name embeds the version from model.json, so each version gets its own cache entry.
const modelName = `model_${modelMetadata.version}`;
// Full storage URL of the cache entry for the current model version.
const modelLocalUrl = `${localStorageUrl}${modelName}`;
// Loaded graph model; stays null until loadModel() has resolved.
let model = null;
// Image -> canvas preprocessing function; assigned by setupCropAndResize().
let cropAndResize = null;
/**
 * Loads the blur-detection graph model and prepares the module for prediction.
 *
 * The locally stored copy at `modelLocalUrl` is tried first. If that load fails
 * for any reason, the model is loaded from `${path}${modelName}/model.json`
 * instead. Once a model is available the crop/resize helper is (re)built and
 * stale cached versions are removed in the background.
 *
 * @param {string} path - Prefix prepended verbatim to `<modelName>/model.json`;
 *   no separator is inserted, so it should normally end with "/".
 * @param {boolean} [cache=false] - When true, a model that had to be loaded
 *   from `path` is also saved to `modelLocalUrl`.
 * @returns {Promise<void>} Resolves once the model is loaded and the
 *   preprocessing helper is ready. Rejects if both load attempts fail.
 */
export async function loadModel(path, cache = false) {
  registerOps();
  let loaded;
  try {
    loaded = await loadGraphModel(modelLocalUrl);
  } catch (cacheLoadError) {
    // Not having a local copy yet is the normal first-run case, so the error
    // is intentionally not reported and the other source is used instead.
    loaded = await loadGraphModel(`${path}${modelName}/model.json`);
    if (cache) {
      // Saving is best-effort and must not delay or fail the load, but the
      // promise still needs a rejection handler to avoid an unhandled rejection.
      loaded.save(modelLocalUrl).catch(err => {
        console.warn("Failed to cache the model locally:", err);
      });
    }
  }
  model = loaded;
  setupCropAndResize();
  // Cleanup runs in the background; a failure here only leaves old entries
  // behind, so it is logged rather than propagated.
  modelsCleanup().catch(err => {
    console.warn("Failed to clean up stale cached models:", err);
  });
}
/**
 * Removes every stored model whose path starts with `localStorageUrl`, except
 * the entry for the current model (`modelLocalUrl`).
 *
 * @returns {Promise<void>} Resolves when all stale entries have been removed;
 *   rejects if listing or any removal fails.
 */
async function modelsCleanup() {
  const storedModels = await tf.io.listModels();
  // Object.keys only visits own keys, and the strict comparison avoids coercion.
  const stalePaths = Object.keys(storedModels).filter(
    storedPath => storedPath.startsWith(localStorageUrl) && storedPath !== modelLocalUrl
  );
  // The removals do not depend on each other, so they are issued together.
  await Promise.all(stalePaths.map(stalePath => tf.io.removeModel(stalePath)));
}
/**
 * Builds the `cropAndResize` helper for the currently loaded `model`.
 *
 * The first two dimensions of the model's first signature input are used as
 * the target height and width. A single off-screen canvas of that size is
 * created once and reused for every call of the helper.
 *
 * The resulting `cropAndResize(img)` reads `img.naturalWidth` and
 * `img.naturalHeight`, takes the largest centred region of the image that has
 * the target aspect ratio, draws it scaled onto the canvas and returns that
 * (shared) canvas.
 */
function setupCropAndResize() {
  const inputDims = Object.values(model.signature.inputs)[0].tensorShape.dim;
  // NOTE(review): dim sizes may be serialized as strings in the model
  // signature — confirm. Number() makes both forms behave the same.
  const height = Number(inputDims[0].size);
  const width = Number(inputDims[1].size);
  const canvas = document.createElement("canvas");
  canvas.width = width;
  canvas.height = height;
  const ctx = canvas.getContext("2d");
  const targetRatio = width / height;
  cropAndResize = img => {
    const { naturalWidth, naturalHeight } = img;
    // Size of the crop if the full height (resp. full width) were kept.
    const cropWidth = naturalHeight * targetRatio;
    const cropHeight = naturalWidth / targetRatio;
    // Only the dimension that overflows the target ratio is actually cropped;
    // the max/min clamps leave the other one at the full image extent.
    const sourceX = Math.max((naturalWidth - cropWidth) / 2, 0);
    const sourceY = Math.max((naturalHeight - cropHeight) / 2, 0);
    const sourceWidth = Math.min(cropWidth, naturalWidth);
    const sourceHeight = Math.min(cropHeight, naturalHeight);
    // The canvas is shared between calls: without clearing it, transparent
    // areas of this image would show pixels left over from the previous one.
    ctx.clearRect(0, 0, width, height);
    ctx.drawImage(img, sourceX, sourceY, sourceWidth, sourceHeight, 0, 0, width, height);
    return canvas;
  };
}
/**
 * Runs the loaded model on an image and returns its raw output.
 *
 * The image is first passed through `cropAndResize`, the resulting canvas is
 * converted to a 3-channel pixel tensor, and the model's prediction is read
 * back synchronously as a plain array.
 *
 * @param {*} img - Image source accepted by `cropAndResize`.
 * @returns {*} The value of `arraySync()` on the model's prediction.
 * @throws {Error} If no model has been loaded yet.
 */
export function predictRaw(img) {
  if (model == null) {
    throw new Error("Model is not loaded!");
  }
  const runInference = () => {
    const preparedCanvas = cropAndResize(img);
    const pixels = tf.browser.fromPixels(preparedCanvas, 3);
    const prediction = model.predict(pixels);
    return prediction.arraySync();
  };
  // All tensor work happens inside tf.tidy; only the plain array leaves it.
  return tf.tidy(runInference);
}
/**
 * Classifies an image as sharp, defocused or motion-blurred.
 *
 * The three values returned by `predictRaw` are read, in order, as the
 * sharp, defocused and motion probabilities. "SHARP" wins whenever its
 * probability reaches `sharpThresh`; otherwise the larger of the two blur
 * probabilities decides, with ties going to "DEFOCUSED".
 *
 * @param {*} img - Image source forwarded to `predictRaw`.
 * @param {number} [sharpThresh=0.2] - Minimum sharp probability for "SHARP".
 * @returns {{cat: string, prob: number}} The chosen category and its probability.
 */
export function predictBlur(img, sharpThresh = 0.2) {
  const probabilities = predictRaw(img);
  const [sharpP, defocusedP, motionP] = probabilities;

  if (sharpP >= sharpThresh) {
    return { cat: "SHARP", prob: sharpP };
  }

  const defocusDominates = defocusedP >= motionP;
  return defocusDominates
    ? { cat: "DEFOCUSED", prob: defocusedP }
    : { cat: "MOTION", prob: motionP };
}