import { FilesetResolver, FaceDetector } from "@mediapipe/tasks-vision";

import { VideoLoader } from "./videoLoader";

/**
 * Options accepted by `detectFaces`.
 * @property wasmBasePath - The base path for the MediaPipe WebAssembly files; it is passed
 * unchanged to `FilesetResolver.forVisionTasks`.
 * @property [fps] - The number of frames per second to sample from the video. Default is 30.
 * @property [normalize] - Whether to divide the face boxes by the video dimensions so that
 * coordinates and sizes are expressed in the range [0, 1]. Default is true.
 * @property [onProgress] - An optional callback invoked after each sampled frame has been
 * processed. It receives the zero-based index of that frame and the total number of frames
 * that will be processed.
 * @property [signal] - An optional AbortSignal to cancel the operation. It is checked before
 * each frame is processed.
 */
export interface FaceDetectionOptions {
  wasmBasePath: string;
  fps?: number;
  normalize?: boolean;
  onProgress?: (currentFrame: number, totalFrames: number) => void;
  signal?: AbortSignal;
}

/**
 * Bounding box of a single detected face.
 *
 * When the `normalize` option is enabled (the default) every field is divided by the video
 * width (`x`, `width`) or video height (`y`, `height`). Otherwise the values are the
 * detector's raw bounding-box values (presumably pixels — confirm against MediaPipe docs).
 */
export interface FaceBox {
  /** Horizontal origin of the box, taken from the detector's `boundingBox.originX`. */
  x: number;
  /** Vertical origin of the box, taken from the detector's `boundingBox.originY`. */
  y: number;
  /** Horizontal extent of the box. */
  width: number;
  /** Vertical extent of the box. */
  height: number;
}

/**
 * Detection result for one sampled video frame.
 */
export interface FrameFaces {
  /** Position of the sampled frame within the video, in seconds. */
  timestamp: number;
  /** One box per detection that carried a bounding box; empty when none were found. */
  faces: FaceBox[];
}

/** URL of the BlazeFace short-range (float16) model used as the detector's model asset. */
const FACE_DETECTOR_MODEL_BLAZE =
  "https://storage.googleapis.com/mediapipe-models/face_detector/blaze_face_short_range/float16/1/blaze_face_short_range.tflite";

/** Sampling rate, in frames per second, applied when the caller does not provide `fps`. */
const DEFAULT_FPS = 30;

/**
 * Detect faces in a video file by sampling it at a fixed frame rate.
 *
 * For every sampled frame the video is positioned at `i / fps` seconds and the MediaPipe
 * face detector is run in `VIDEO` mode on the current frame image. Between frames the
 * function yields to the event loop so the UI stays responsive.
 *
 * @param videoUrl - The URL of the video file.
 * @param options - The options for face detection.
 * @returns A promise that resolves to one entry per sampled frame, in ascending timestamp
 * order. Each entry holds the frame timestamp (seconds) and the boxes of the faces detected
 * in that frame (possibly none).
 * @throws RangeError (as a rejection) when `options.fps` is not a finite number greater
 * than zero.
 * @throws Error (as a rejection, message "Invalid video file") when the loaded video reports
 * a width, height or duration that is not a finite positive number.
 * @throws The abort reason of `options.signal` (as a rejection) when the signal is aborted
 * before the call starts or before any frame is processed.
 */
export async function detectFaces(
  videoUrl: string | URL,
  options: FaceDetectionOptions,
): Promise<FrameFaces[]> {
  // Validate the sampling rate up front: an infinite rate would make the frame loop
  // unbounded, and zero/negative/NaN rates would silently produce an empty result.
  const fps = options.fps ?? DEFAULT_FPS;
  if (!Number.isFinite(fps) || fps <= 0) {
    throw new RangeError(`fps must be a finite number greater than 0 (received ${fps})`);
  }
  // Do not load the WASM runtime and model when the caller has already cancelled.
  options.signal?.throwIfAborted();

  const vision = await FilesetResolver.forVisionTasks(options.wasmBasePath);
  const normalize = options.normalize ?? true;
  let videoReader: VideoLoader | null = null;
  let faceDetector: FaceDetector | null = null;
  try {
    videoReader = new VideoLoader(videoUrl);
    faceDetector = await FaceDetector.createFromOptions(vision, {
      baseOptions: {
        modelAssetPath: FACE_DETECTOR_MODEL_BLAZE,
        delegate: "GPU",
      },
      runningMode: "VIDEO",
    });
    await videoReader.flush();

    const { videoWidth, videoHeight, duration } = videoReader;
    // A plain `<= 0` comparison lets NaN and Infinity through; both would break the
    // frame-count computation below, so require finite positive values.
    const isPositiveFinite = (value: number) => Number.isFinite(value) && value > 0;
    if (![videoWidth, videoHeight, duration].every(isPositiveFinite)) {
      throw new Error("Invalid video file");
    }

    // Dividing by 1 leaves the detector's raw values untouched when not normalizing.
    const widthDivisor = normalize ? videoWidth : 1;
    const heightDivisor = normalize ? videoHeight : 1;
    const numFrames = Math.ceil(duration * fps);
    const frameFaces: FrameFaces[] = [];
    for (let i = 0; i < numFrames; i++) {
      options.signal?.throwIfAborted();
      // Clamp so that rounding can never position the reader past the end of the video.
      const timestamp = Math.min(i / fps, duration);
      videoReader.currentTime = timestamp;
      // The detector takes its timestamp in milliseconds.
      const result = faceDetector.detectForVideo(videoReader.videoImage, timestamp * 1000);
      // flatMap drops detections that carry no bounding box.
      const faces: FaceBox[] = result.detections.flatMap((detection) => {
        if (!detection.boundingBox) {
          return [];
        }
        const { originX, originY, width, height } = detection.boundingBox;
        return [
          {
            x: originX / widthDivisor,
            y: originY / heightDivisor,
            width: width / widthDivisor,
            height: height / heightDivisor,
          },
        ];
      });
      frameFaces.push({ timestamp, faces });
      // Reports the zero-based index of the frame that was just processed.
      options.onProgress?.(i, numFrames);
      // This is necessary to allow the UI to remain responsive
      await new Promise((resolve) => setTimeout(resolve, 0));
    }
    return frameFaces;
  } finally {
    // Nested finally: the detector must be released even if destroying the reader throws.
    try {
      videoReader?.destroy();
    } finally {
      faceDetector?.close();
    }
  }
}
