const hull = require("hull.js");
const simplify = require("simplify-js");

import React, { useContext, useEffect, useState, FC } from "react";
import { InferenceSession, Tensor } from "onnxruntime-web";
import { modelData } from "../components/helpers/onnxModelAPI";
import { modelScaleProps } from "../components/helpers/Interfaces";

const ort = require("onnxruntime-web");
/* @ts-ignore */
import npyjs from "npyjs";
import {
  arrToPoints,
  cropImage,
  findBoundingRectangle,
  float32ArrayToBinaryMask,
} from "../components/helpers/maskUtils";
import { S3_BUCKET_PREFIX } from "../constants";
import { canvasToFile } from "../utils/masks";
import { rotateImageDueToWidth } from "../api/masks";

// Path of the quantized SAM ONNX model file handed to InferenceSession.create().
// Note: despite the name, this is a file path rather than a directory.
const MODEL_DIR = "/model/sam_onnx_quantized_example.onnx";

/**
 * A single user click that is forwarded to the SAM model as a prompt.
 */
export interface modelInputProps {
  // Horizontal click position.
  // NOTE(review): presumably in original-image pixel coordinates — confirm against modelData().
  x: number;
  // Vertical click position (same coordinate space as `x`).
  y: number;
  // Kind of click.
  // NOTE(review): presumably the SAM point label (e.g. 1 = foreground, 0 = background) — confirm against modelData().
  clickType: number;
}

// SAM requires its input image to be resized so that the longer side measures
// 1024 px. Given the original width and height, return the factor by which the
// image has to be scaled to satisfy that requirement.
const handleImageScale = (w: number, h: number) => {
  const longestSide = Math.max(h, w);
  return 1024 / longestSide;
};

/**
 * Arguments accepted by the `useModel` hook.
 */
interface IUseModelProps {
  // Original image width; the hook does nothing while this is 0.
  w: number;
  // Original image height; the hook does nothing while this is 0.
  h: number;
  // Location of the pre-computed image embedding (.npy), relative to
  // S3_BUCKET_PREFIX. While null (or empty) nothing is loaded.
  imageEmbeddingUrl: string | null;
  // Callback invoked with the cropped mask image file and the simplified
  // polygon outline; its return value is what sendClickToModel resolves to.
  persistMask: (imageFile: File, polygon: any) => string | undefined;
}

// Fetch a NumPy (.npy) file and wrap its contents in an ONNX Runtime tensor.
const loadNpyTensor = async (tensorFile: string, dType: string) => {
  const npLoader = new npyjs();
  const npArray = await npLoader.load(tensorFile);
  return new ort.Tensor(dType, npArray.data, npArray.shape);
};

/**
 * React hook that loads the SAM ONNX model together with the pre-computed
 * embedding of the current image and exposes a function that turns a click
 * into a persisted mask.
 *
 * Nothing is loaded until `imageEmbeddingUrl` is set and both `w` and `h` are
 * non-zero.
 *
 * @param w - original image width
 * @param h - original image height
 * @param imageEmbeddingUrl - embedding location relative to S3_BUCKET_PREFIX
 * @param persistMask - callback receiving the cropped mask file and polygon
 * @returns `isModelLoaded` (true once the ONNX session exists; it does NOT
 *   imply that the embedding has finished loading) and `sendClickToModel`.
 */
export const useModel = ({
  w,
  h,
  imageEmbeddingUrl,
  persistMask,
}: IUseModelProps) => {
  const [model, setModel] = useState<InferenceSession | null>(null); // ONNX model
  const [tensor, setTensor] = useState<Tensor | null>(null); // Image embedding tensor
  // The ONNX model expects the input to be rescaled to 1024.
  // The modelScale state variable keeps track of the scale values.
  const [modelScale, setModelScale] = useState<modelScaleProps | null>(null);

  // Create the ONNX session once an image is available. The session does not
  // depend on the image, so an already-loaded session is reused; a failed
  // load is retried the next time the image parameters change.
  useEffect(() => {
    if (model !== null || !imageEmbeddingUrl || w === 0 || h === 0) {
      return undefined;
    }
    let cancelled = false;
    const initModel = async () => {
      try {
        const session = await InferenceSession.create(MODEL_DIR);
        // Ignore the result if the effect was torn down in the meantime.
        if (!cancelled) setModel(session);
      } catch (e) {
        console.log(e);
      }
    };
    void initModel();
    return () => {
      cancelled = true;
    };
  }, [w, h, imageEmbeddingUrl, model]);

  // Track the scale of the current image and load its pre-computed embedding.
  useEffect(() => {
    if (!imageEmbeddingUrl || w === 0 || h === 0) {
      return undefined;
    }
    let cancelled = false;

    setModelScale({
      height: h, // original image height
      width: w, // original image width
      samScale: handleImageScale(w, h), // scaling factor for image which has been resized to longest side 1024
    });
    // Drop the previous image's embedding so that clicks made while the new
    // one is loading are ignored instead of being run against stale data.
    setTensor(null);

    const imageNpyUrl = `${S3_BUCKET_PREFIX}/${imageEmbeddingUrl}`;
    const loadEmbedding = async () => {
      try {
        const embedding = await loadNpyTensor(imageNpyUrl, "float32");
        // A newer image (or an unmount) supersedes this request.
        if (!cancelled) setTensor(embedding);
      } catch (e) {
        console.log(e);
      }
    };
    void loadEmbedding();

    return () => {
      cancelled = true;
    };
  }, [w, h, imageEmbeddingUrl]);

  /**
   * Run the model on a single click, derive a polygon outline and a cropped
   * image from the predicted mask, and hand both to `persistMask`.
   *
   * @param click - the click used as the model prompt
   * @param image - the image element the mask is cropped from
   * @returns whatever `persistMask` returns; `undefined` when the model,
   *   embedding or scale is not ready (or no feeds could be built); `""` when
   *   an error was thrown along the way (the error is logged).
   */
  const sendClickToModel = async (
    click: modelInputProps,
    image: HTMLImageElement
  ): Promise<string | undefined> => {
    try {
      if (
        model === null ||
        click === null ||
        tensor === null ||
        modelScale === null
      ) {
        return undefined;
      }

      // in the future, sam can take multiple clicks
      const clicks = [click];
      const feeds = modelData({
        clicks,
        tensor,
        modelScale,
      });
      if (feeds === undefined) return undefined;

      // Run the SAM ONNX model with the feeds returned from modelData()
      const results = await model.run(feeds);
      const output = results[model.outputNames[0]];
      // NOTE(review): SAM's ONNX export usually emits masks as
      // [batch, num_masks, H, W], which would make dims[2] the height. The
      // helpers consuming these values are not visible here, so the original
      // argument order is kept — confirm against maskUtils.
      const width = output.dims[2];
      const height = output.dims[3];
      const floatArray = output.data;

      const binaryMask = float32ArrayToBinaryMask(floatArray, width, height);

      // Outline of the mask: hull of the mask points, then simplified.
      const points = arrToPoints(floatArray, width, height);
      const hullPoints = hull(points, 80); // Adjust the second parameter (concavity) as needed
      const pointsToSimplify = hullPoints.map((point: any[]) => ({
        x: point[0],
        y: point[1],
      }));
      const simplifiedPolygonMask = simplify(pointsToSimplify, 1); // Adjust the second parameter (tolerance) as needed

      // Crop the source image to the mask's bounding rectangle.
      const [x, y, cropWidth, cropHeight] = findBoundingRectangle(binaryMask);
      const croppedMaskImageElement = cropImage(
        image,
        x,
        y,
        cropWidth,
        cropHeight,
        binaryMask
      );
      const croppedImage = await canvasToFile(
        croppedMaskImageElement,
        "cropped_image.png"
      );

      // Crops that are wider than tall are passed through rotateImageDueToWidth first.
      const imageFile =
        cropWidth > cropHeight
          ? ((await rotateImageDueToWidth(croppedImage)) as File)
          : croppedImage;

      return persistMask(imageFile, simplifiedPolygonMask);
    } catch (e) {
      console.log(e);
      return "";
    }
  };

  // need to check for tensor loading as well
  const isModelLoaded = model !== null;

  return { isModelLoaded, sendClickToModel };
};
