import {
  FilesetResolver,
  ImageSegmenter,
  type ImageSegmenterOptions,
} from "@mediapipe/tasks-vision";

/** Keys of the segmentation models catalogued in this module. */
type AIModelId = "multiClassSegmenter" | "selfieSegmenter";

/** Static description of one downloadable segmentation model. */
export interface AIModelConfig {
  /** Location of the `.tflite` model file. */
  url: string;
  /** Version tag attached to this catalogue entry. */
  version: string;
  /** Model input resolution (presumably pixels — confirm against the model card). */
  inputShape: { width: number; height: number };
  /** Human-readable class names; presumably indexed by output class id. */
  labels: Array<string>;
}

/** Lookup table with exactly one configuration per {@link AIModelId}. */
export type AIModels = Record<AIModelId, AIModelConfig>;

/**
 * Catalogue of the MediaPipe image-segmentation models this module knows
 * about, keyed by model id. Each entry carries its own `inputShape` object
 * and `labels` array, so they are not shared between entries.
 */
export const aiModels: AIModels = {
  // Person vs. background segmentation (float16 weights, per the URL path).
  selfieSegmenter: {
    url: "https://storage.googleapis.com/mediapipe-models/image_segmenter/selfie_segmenter/float16/latest/selfie_segmenter.tflite",
    version: "v1",
    inputShape: { width: 256, height: 256 },
    labels: ["background", "person"],
  },
  // Six-class segmentation (float32 weights, per the URL path).
  multiClassSegmenter: {
    url: "https://storage.googleapis.com/mediapipe-models/image_segmenter/selfie_multiclass_256x256/float32/latest/selfie_multiclass_256x256.tflite",
    version: "v1",
    inputShape: { width: 256, height: 256 },
    labels: [
      "background",
      "hair",
      "body-skin",
      "face-skin",
      "clothes",
      "others",
    ],
  },
};

/** Where to load MediaPipe WASM assets from, and for which task family. */
interface WasmConfig {
  /** MediaPipe task family whose fileset should be resolved. */
  task: "audio" | "gen-ai" | "gen-ai-experimental" | "text" | "vision";
  /** Base path or URL that the fileset is resolved against. */
  wasmPath: string;
}

/**
 * jsDelivr CDN location of the `@mediapipe/tasks-vision` WASM assets,
 * pinned to package version 0.10.14.
 *
 * NOTE(review): this pinned version presumably has to stay in sync with the
 * `@mediapipe/tasks-vision` version installed in package.json — confirm and
 * update both together when upgrading.
 */
export const VISION_TASKS_WASM_PATH =
  "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@0.10.14/wasm";

/**
 * Creates a MediaPipe {@link ImageSegmenter}.
 *
 * First resolves the vision-task WASM fileset from `wasmPath`, then builds the
 * segmenter from `options`. Rejections from either step propagate unchanged.
 *
 * @param options - Options forwarded verbatim to `ImageSegmenter.createFromOptions`.
 * @param wasmPath - Base path/URL of the tasks-vision WASM assets. Defaults to
 *   {@link VISION_TASKS_WASM_PATH}. Override it to self-host the assets or to
 *   match a different installed package version.
 * @returns The initialised segmenter.
 */
export async function createImageSegmenter(
  options: ImageSegmenterOptions,
  wasmPath = VISION_TASKS_WASM_PATH,
): Promise<ImageSegmenter> {
  const vision = await getWasmFileSet({ wasmPath, task: "vision" });
  return ImageSegmenter.createFromOptions(vision, options);
}

/**
 * Resolves the MediaPipe WASM fileset for one task family.
 *
 * Dispatches to the matching `FilesetResolver.for*Tasks` method and passes
 * `wasmPath` through unchanged.
 *
 * @param config - `task` selects the resolver; `wasmPath` is the base path/URL
 *   that the fileset is resolved against.
 * @returns The fileset produced by the selected resolver.
 * @throws Error (surfaced as a rejected promise) with message
 *   `Unknown task: <task>` when `task` is not a recognised value. This is
 *   reachable only from untyped callers.
 */
export async function getWasmFileSet({ task, wasmPath }: WasmConfig) {
  switch (task) {
    case "vision":
      return FilesetResolver.forVisionTasks(wasmPath);
    case "audio":
      return FilesetResolver.forAudioTasks(wasmPath);
    case "text":
      return FilesetResolver.forTextTasks(wasmPath);
    case "gen-ai":
      return FilesetResolver.forGenAiTasks(wasmPath);
    case "gen-ai-experimental":
      return FilesetResolver.forGenAiExperimentalTasks(wasmPath);
    default: {
      // Exhaustiveness check: if a member is added to WasmConfig["task"]
      // without a matching case above, this assignment becomes a compile
      // error instead of a silent runtime rejection.
      const unhandled: never = task;
      throw new Error(`Unknown task: ${unhandled}`);
    }
  }
}
