import { InferenceSession, Tensor } from 'onnxruntime-web';
import { useRef, useState } from 'react';
import tinycolor from 'tinycolor2';

import { getEmbeddingUrl, getImageUrl } from '../misc/constants';
import { getImageSize } from '../misc/utils';
import { useAsyncEffect } from './useAsyncEffect';

import type { IPatientMedia } from '../misc/types';

/**
 * A rectangle described by two corner points.
 *
 * When passed as `bounds` to `getModelData`, `(x1, y1)` is sent to the model
 * with the `BoxTopLeft` label and `(x2, y2)` with the `BoxBottomRight` label.
 * The coordinates are multiplied by `samScale` there, so they are presumably
 * expressed in original-image pixels (confirm against callers).
 */
export interface Box {
  x1: number;
  y1: number;
  x2: number;
  y2: number;
}

/**
 * Input accepted by `getModelData`.
 *
 * - `embedding`: the image embedding plus the metadata needed to scale prompts.
 * - `prompts`: point prompts; may be empty.
 * - `bounds`: optional box prompt. When omitted, `getModelData` appends a single
 *   `PromptType.Padding` point at `[0, 0]` instead of the two box corners.
 */
export type ModelParams = {
  embedding: EmbeddingData;
  prompts: Prompt[];
  bounds?: Box;
};

/** An image embedding together with the image metadata required to build model inputs. */
export interface EmbeddingData {
  /** The embedding tensor; forwarded to the model as `image_embeddings`. */
  tensor: Tensor;
  /** Image dimensions as `[width, height]`. */
  imageSize: [number, number];
  /**
   * Factor applied to every prompt coordinate before it is sent to the model.
   * `useEmbedding` computes it as `1024 / max(width, height)`, i.e. the image is
   * scaled so that its longest side is 1024px.
   */
  samScale: number;
}

/** A single labelled point fed to the model. */
export interface Prompt {
  /**
   * Point position; each component is multiplied by `samScale` in `getModelData`.
   * Presumably `[x, y]` in original-image pixels (confirm against callers).
   */
  coords: [number, number];
  /** The label for this point; its numeric value is written into the `point_labels` tensor. */
  type: PromptType;
}

/**
 * Label attached to each prompt point.
 *
 * The numeric value of each member is written verbatim into the `point_labels`
 * tensor by `getModelData`, so the values must not be renumbered.
 */
export enum PromptType {
  /** A special internal prompt type. Do not use manually! */
  Padding = -1,
  /**
   * Subtractive prompt type.
   * The corresponding prompt should be *removed* from the resulting mask.
   * Aka. background prompt.
   */
  Subtractive = 0,
  /**
   * Additive prompt type.
   * The corresponding prompt should be *included* in the resulting mask.
   * Aka. foreground prompt.
   */
  Additive = 1,
  /**
   * First corner of a box prompt.
   * A box prompt is encoded as 2 coords with BoxTopLeft and BoxBottomRight labels.
   */
  BoxTopLeft = 2,
  /**
   * Second corner of a box prompt.
   * A box prompt is encoded as 2 coords with BoxTopLeft and BoxBottomRight labels.
   */
  BoxBottomRight = 3,
}

/**
 * One mask produced by the model, expressed as a window into a shared buffer.
 * A value greater than 0 marks a pixel as belonging to the mask (see `getMaskImageData`).
 */
export interface Mask {
  /** The `Float32Array` holding all masks' data */
  data: Float32Array;
  /** Offset into `data`, b/c `data` stores all masks and creating slices has a noticeable performance impact */
  dataOffset: number;
  /**
   * The size of the mask as `[height, width]` (note the order).
   * All masks should be the same size as the original image.
   */
  size: [number, number];
}

/**
 * Downloads and decodes the precomputed embedding stored under `filename`.
 *
 * Binary layout as read here (all multi-byte values are big-endian):
 *  - byte 0: number of dimensions, which must be 4;
 *  - the next `2 * dimensions` bytes: one uint16 per dimension (the shape);
 *  - the remainder: `product(shape)` float32 values.
 *
 * @param filename Identifier passed to `getEmbeddingUrl` to build the request URL.
 * @returns A `float32` tensor with the decoded shape and data.
 * @throws Error if the HTTP request fails, the dimensionality is not 4, or the
 *   payload is shorter than its own header declares.
 */
export async function getEmbedding(filename: string) {
  const res = await fetch(getEmbeddingUrl(filename));
  // Without this check an HTTP error page would be decoded as if it were an embedding.
  if (!res.ok) {
    throw Error(`Failed to fetch embedding "${filename}": ${res.status} ${res.statusText}`);
  }

  const view = new DataView(await res.arrayBuffer());
  if (view.byteLength < 1) throw Error("Embedding payload is empty");

  const size = view.getUint8(0);
  if (size !== 4) throw Error("Unexpected embedding dimensionality");

  const dataOffset = 1 + size * 2;
  if (view.byteLength < dataOffset) throw Error("Embedding payload is truncated");

  const shape = new Array<number>(size);
  for (let i = 0; i < size; ++i) {
    shape[i] = view.getUint16(1 + i * 2, false);
  }

  // Validate before allocating so a corrupt header fails with a clear message.
  const elementCount = shape.reduce((a, b) => a * b, 1);
  if (view.byteLength < dataOffset + elementCount * 4) {
    throw Error("Embedding payload is truncated");
  }

  // The values are big-endian, so they are read one by one rather than by
  // reinterpreting the buffer with the platform's native byte order.
  const data = new Float32Array(elementCount);
  for (let i = 0; i < elementCount; ++i) {
    data[i] = view.getFloat32(dataOffset + i * 4, false);
  }

  return new Tensor("float32", data, shape);
}

/**
 * Assembles the named input tensors for a single model run.
 *
 * The caller's `prompts` array is not modified. After the caller's prompts,
 * either the two box corners (when `bounds` is given) or a single
 * `PromptType.Padding` point at `[0, 0]` (when it is not) is appended.
 * Every coordinate is multiplied by `embedding.samScale`.
 *
 * @returns An object keyed by model input name.
 */
export function getModelData({ prompts, bounds, embedding }: ModelParams) {
  const { tensor, samScale, imageSize } = embedding;
  const [imageWidth, imageHeight] = imageSize;

  const trailing: Prompt[] = bounds
    ? [
        { coords: [bounds.x1, bounds.y1], type: PromptType.BoxTopLeft },
        { coords: [bounds.x2, bounds.y2], type: PromptType.BoxBottomRight },
      ]
    : [{ coords: [0, 0], type: PromptType.Padding }];
  const allPrompts = [...prompts, ...trailing];
  const count = allPrompts.length;

  // Flatten into the two parallel lists the model expects: scaled coordinates and labels.
  const scaledCoords: number[] = [];
  const labels: number[] = [];
  for (const { coords, type } of allPrompts) {
    for (const c of coords) {
      scaledCoords.push(c * samScale);
    }
    labels.push(type);
  }

  return {
    image_embeddings: tensor,
    point_coords: new Tensor("float32", scaledCoords, [1, count, 2]),
    point_labels: new Tensor("float32", labels, [1, count]),
    // Note the order: height first, then width.
    orig_im_size: new Tensor("float32", [imageHeight, imageWidth]),
    // "Previous mask" inputs: an all-zero 1x1x256x256 tensor with the flag set to 0.
    // Presumably this tells the model that no earlier mask is supplied for
    // refinement (confirm against the model's export documentation).
    mask_input: new Tensor("float32", new Float32Array(256 * 256), [1, 1, 256, 256]),
    has_mask_input: new Tensor("float32", [0]),
  };
}

/**
 * Gets a PNG data URL resembling the given mask in transparent/white: pixels
 * inside the mask are opaque white, all others are fully transparent.
 *
 * @param mask The mask to render; its `size` is `[height, width]`.
 * @returns A `data:image/png` URL with the same dimensions as the mask.
 * @throws Error if a 2D rendering context cannot be obtained for the canvas.
 */
export function getMaskDataUrl(mask: Mask): string {
  const [height, width] = mask.size;

  const canvas = document.createElement("canvas");
  canvas.width = width;
  canvas.height = height;

  // `getContext` may return null; fail with a clear message instead of a
  // "cannot read properties of null" TypeError.
  const ctx = canvas.getContext("2d");
  if (!ctx) throw new Error("Unable to acquire a 2D canvas context");

  ctx.putImageData(getMaskImageData(mask, "white"), 0, 0);
  return canvas.toDataURL("image/png");
}

/**
 * Gets the `ImageData` corresponding to the given mask. The result can be fed to a
 * `CanvasRenderingContext2D.putImageData` call.
 *
 * Pixels whose mask value is greater than 0 are painted with `color` at the given
 * opacity; every other pixel is left fully transparent.
 *
 * @param mask The mask to render; its `size` is `[height, width]`.
 * @param color Any colour string understood by `tinycolor`. If omitted, a random
 *   HSL colour is chosen.
 * @param opacity Alpha multiplier; the alpha channel is `floor(255 * opacity)`.
 */
export function getMaskImageData(mask: Mask, color?: string, opacity = 1) {
  color = color ?? `hsl(${Math.random() * 360}, ${Math.random() * 100}%, ${Math.random() * 50 + 25}%)`;
  const [height, width] = mask.size;
  const {r, g, b} = tinycolor(color).toRgb();

  // One mask value per pixel, four bytes (RGBA) per pixel in the output.
  const pixelCount = width * height;
  const pixels = new Uint8ClampedArray(pixelCount * 4);
  const alpha = Math.floor(255 * opacity);
  const { data, dataOffset } = mask;

  // Iterate over pixels, not bytes: the byte count is four times larger and would
  // walk past this mask into whatever follows it in the shared buffer.
  for (let i = 0; i < pixelCount; i++) {
    // mask confidence threshold
    if (data[dataOffset + i] > 0) {
      const p = 4 * i;
      pixels[p + 0] = r;
      pixels[p + 1] = g;
      pixels[p + 2] = b;
      pixels[p + 3] = alpha;
    }
  }

  return new ImageData(pixels, width, height);
}

/**
 * Runs the model with point prompts only (no box) and returns every mask found
 * in the model's first output.
 *
 * All returned masks share one underlying buffer and differ only in `dataOffset`.
 * According to the original author's note, a run takes roughly 0.5s on a
 * lower-end machine.
 */
export async function getMasks(model: InferenceSession, prompts: Prompt[], embedding: EmbeddingData) {
  const feeds = getModelData({ prompts, embedding });
  const outputs = await model.run(feeds);
  const maskTensor = outputs[model.outputNames[0]];

  // The output is indexed as [_, maskCount, height, width].
  const [, maskCount, height, width] = maskTensor.dims;
  const size: [number, number] = [height, width];
  const stride = height * width;

  return Array.from({ length: maskCount }, (_, index): Mask => ({
    data: maskTensor.data as Float32Array,
    dataOffset: index * stride,
    size,
  }));
}

/**
 * Runs the model with a single box prompt (and no point prompts) and returns
 * the mask stored at the start of the model's first output.
 */
export async function getMaskByBounds(model: InferenceSession, bounds: Box, embedding: EmbeddingData): Promise<Mask> {
  const feeds = getModelData({ bounds, embedding, prompts: [] });
  const outputs = await model.run(feeds);
  const maskTensor = outputs[model.outputNames[0]];

  // dims[2] is the mask height and dims[3] its width.
  const height = maskTensor.dims[2];
  const width = maskTensor.dims[3];

  return {
    data: maskTensor.data as Float32Array,
    dataOffset: 0,
    size: [height, width],
  };
}

/**
 * React hook that creates an inference session for `/models/sam-quantized.onnx`.
 *
 * @returns `null` until the session has been created, then the session itself.
 *   The effect has an empty dependency list, so it presumably runs once per
 *   mount (depends on `useAsyncEffect`, which is not shown here).
 */
export function useSAM() {
  const [session, setSession] = useState<InferenceSession | null>(null);

  useAsyncEffect(async () => {
    const created = await InferenceSession.create("/models/sam-quantized.onnx");
    setSession(created);
  }, []);

  return session;
}

/**
 * React hook that loads the embedding and image dimensions for `media`.
 *
 * The result is reset to `null` whenever `media` changes, and stays `null` when
 * `media`, its `photourl` or its `sam_embedding` is missing.
 *
 * @param media The media item whose embedding should be loaded.
 * @returns `null` while loading or unavailable, otherwise the embedding data.
 */
export function useEmbedding(media: IPatientMedia) {
  const [data, setData] = useState<EmbeddingData | null>(null);
  // Id of the most recently started load. Lets a slow, outdated load notice that
  // it has been superseded, so it cannot overwrite the state of a newer `media`.
  const latestRequest = useRef(0);

  useAsyncEffect(async () => {
    const requestId = ++latestRequest.current;

    setData(null);
    if (!(media && media.photourl && media.sam_embedding)) return;

    // The two downloads are independent, so run them in parallel.
    // `getImageSize` is assumed to resolve to [width, height] (confirm in utils).
    const [tensor, [w, h]] = await Promise.all([
      getEmbedding(media.sam_embedding),
      getImageSize(getImageUrl(media.photourl)),
    ]);

    // `media` changed while we were waiting: drop this stale result.
    if (requestId !== latestRequest.current) return;

    const samScale = 1024 / Math.max(w, h); // scale longest side to 1024px
    setData({
      tensor,
      imageSize: [w, h],
      samScale,
    });
  }, [media]);

  return data;
}
