import { Clip, getAbsoluteTimestamp } from "~/utils/videoClips";

import type { CenterRect, Rect, FaceSpan, ProjectFaceData } from "./face-data.types";

// Local debug switch: when true, getFaceAtTimestamp returns raw face rects
// instead of smoothed, interpolated ones.
const DISABLE_SMOOTHING = false;

// Debug palette for face spans; a span's color is picked by cycling through it with the face id.
const testColors = ["green", "purple", "white", "yellow", "violet", "blue", "cyan", "orange"];
const cycleColor = (id: number) => {
  const paletteIndex = id % testColors.length;
  return testColors[paletteIndex];
};

/**
 * Component-wise linear blend of two center-format rects.
 * `alpha` weights `r0` and `1 - alpha` weights `r1`, so alpha = 1 yields r0 and alpha = 0 yields r1.
 */
const blendCenterRects = (r0: CenterRect, r1: CenterRect, alpha: number) => {
  const beta = 1 - alpha;
  const mix = (fromR0: number, fromR1: number) => beta * fromR1 + alpha * fromR0;
  return {
    x0: mix(r0.x0, r1.x0),
    y0: mix(r0.y0, r1.y0),
    dx: mix(r0.dx, r1.dx),
    dy: mix(r0.dy, r1.dy),
  };
};

/**
 * Returns the rect of `faceSpan` at the absolute `timestamp`.
 *
 * With smoothing enabled, the result is linearly interpolated between the two smoothed
 * rects whose sample times bracket `timestamp`. A timestamp at or before the first sample
 * yields the first smoothed rect. A timestamp after the last sample yields the last smoothed
 * rect (no extrapolation).
 *
 * With DISABLE_SMOOTHING, the raw rect of the last sample before `timestamp` is returned,
 * clamped to the span's first/last sample.
 *
 * Expects `faceSpan.smoothedRects` to have been populated by smoothFaceSpan.
 */
const getFaceAtTimestamp = (faceSpan: FaceSpan, timestamp: number) => {
  const { timestamps } = faceSpan;
  const lastIndex = Math.max(0, timestamps.length - 1);

  // `findIndex` reports "not found" as -1 (never null/undefined). That happens when
  // `timestamp` is past the last sample. Pin both indices to the last sample so we neither
  // read index -1 (undefined) nor extrapolate beyond the span.
  const found = timestamps.findIndex((t: number) => t >= timestamp);
  const index1 = found === -1 ? lastIndex : found;
  const index0 = found === -1 ? lastIndex : Math.max(0, index1 - 1);

  // no smoothing version at all
  if (DISABLE_SMOOTHING) {
    return faceSpan.faceRects[index0];
  }

  const t0 = timestamps[index0];
  const t1 = timestamps[index1];

  const s0 = faceSpan.smoothedRects![index0];
  const s1 = faceSpan.smoothedRects![index1];

  // Same sample time on both sides (before the first sample, past the last one,
  // or a single-sample span): nothing to interpolate.
  if (t0 === t1) {
    return s0;
  }

  const r0 = rectToCenterRect(s0);
  const r1 = rectToCenterRect(s1);

  // alpha is 1 at t0 (-> s0) and falls linearly to 0 at t1 (-> s1).
  const alpha = 1 - (timestamp - t0) / (t1 - t0);
  return centerRectToRect(blendCenterRects(r0, r1, alpha));
};

/** Converts a corner-based rect (x, y, w, h) into center point plus half-extent form. */
const rectToCenterRect = (rect: Rect) => {
  const { x, y, w, h } = rect;
  const halfW = w / 2;
  const halfH = h / 2;
  return { x0: x + halfW, y0: y + halfH, dx: halfW, dy: halfH };
};

/** Converts center point plus half-extent form back into a corner-based rect (x, y, w, h). */
const centerRectToRect = (rect: CenterRect) => {
  const { x0, y0, dx, dy } = rect;
  return { x: x0 - dx, y: y0 - dy, w: 2 * dx, h: 2 * dy };
};

/**
 * Populates `faceSpan.smoothedRects` with an exponentially weighted moving average (EWMA)
 * of the span's raw face rects.
 *
 * Blending is done in center form. The first smoothed rect equals the first raw rect. Each
 * later one keeps 80% of the previous smoothed rect and takes 20% from the current raw rect.
 * A span with no rects is left untouched.
 */
const smoothFaceSpan = (faceSpan: FaceSpan) => {
  const { faceRects } = faceSpan;
  if (faceRects.length === 0) {
    return;
  }

  const historyWeight = 0.8; // EWMA tuning param: weight kept from the running average
  const centerRects = faceRects.map((rect) => rectToCenterRect(rect));

  let running = centerRects[0]!;
  const smoothed = [running];
  for (const raw of centerRects.slice(1)) {
    // next smoothed = historyWeight * previous smoothed + (1 - historyWeight) * raw
    running = blendCenterRects(running, raw, historyWeight);
    smoothed.push(running);
  }

  faceSpan.smoothedRects = smoothed.map((centerRect) => centerRectToRect(centerRect));
};

/**
 * A detected face rect tagged with tracking metadata, built while FaceData aggregates frames.
 */
interface RectWithId extends Rect {
  // face track id; a rect that sufficiently overlaps a rect from the previous frame
  // (within the same scene) reuses that rect's id
  id: number;
  // scene index of the frame the rect came from (as returned by getCurrentScene)
  scene: number;
  // index of the source frame; filled in when rects are grouped by id
  frameIndex?: number;
}

/**
 * FaceData processes raw ProjectFaceData and provides methods for common tasks:
 * tracking detections across frames under stable face ids, grouping them into smoothed
 * FaceSpans, and querying which faces are present at a time or over a time range.
 */
export class FaceData {
  scenes: number[];
  faceData: Map<number, RectWithId[]>;
  timestamps: number[];
  faceSpansById: Map<number, FaceSpan>;
  fps: number;
  clips: Clip[];

  /**
   * @param faceData raw per-frame detections plus their fps; each detection is read
   *   positionally as [x, y, w, h]
   * @param scenes scene boundary timestamps in absolute seconds (assumed ascending — TODO confirm)
   * @param clips clips used by getFacesAt to turn relative timestamps into absolute ones
   */
  constructor(faceData: ProjectFaceData, scenes: number[], clips: Clip[]) {
    this.scenes = scenes;
    this.faceData = new Map();
    this.fps = faceData.fps;
    // one timestamp per frame: frame i sits at i / fps
    this.timestamps = faceData.faceFrames.map((_, i) => i / this.fps);
    this.clips = clips;

    // Aggregate faces over time: every detection gets a face id, reusing the id of a
    // matching detection in the previous frame so a face keeps its id across frames.
    let faceId = 0;
    let prevFaces: RectWithId[] = [];
    let prevScene = -1;
    faceData.faceFrames.forEach((facesArray, frameNumber) => {
      const faces = facesArray.map((f) => ({ x: f[0], y: f[1], w: f[2], h: f[3] }));
      const currentScene = this.getCurrentScene(frameNumber / this.fps);
      const newScene = prevScene !== currentScene;

      const facesWithId: RectWithId[] = faces.map((f) => {
        // Consider the rect a continuation of a previous rect if
        // 1. the frame does not start a new scene, and
        // 2. the overlap with the previous rect covers more than 50% of this rect's area.
        // NOTE(review): several rects in one frame can match the same previous rect and
        // then share an id — confirm that is acceptable.
        const matchingPrev =
          !newScene &&
          prevFaces.find((pf) => {
            const xOverlap = Math.max(0, Math.min(f.x + f.w, pf.x + pf.w) - Math.max(f.x, pf.x));
            const yOverlap = Math.max(0, Math.min(f.y + f.h, pf.y + pf.h) - Math.max(f.y, pf.y));
            const areaOverlap = xOverlap * yOverlap;
            return areaOverlap / (f.w * f.h) > 0.5;
          });

        // Reuse the matched id, otherwise start a new one (rects are aggregated by id below).
        const id = matchingPrev ? matchingPrev.id : ++faceId;
        return { ...f, id, scene: currentScene };
      });

      prevScene = currentScene;
      prevFaces = facesWithId;
      this.faceData.set(frameNumber, facesWithId);
    });

    // Group raw face rects by id, in frame order (Map iteration follows insertion order).
    const faceRectsById = new Map<number, RectWithId[]>();
    this.faceData.forEach((faces, frameIndex) => {
      faces.forEach((f) => {
        f.frameIndex = frameIndex;
        const group = faceRectsById.get(f.id) ?? [];
        group.push(f);
        faceRectsById.set(f.id, group);
      });
    });

    // Transform each group into a FaceSpan with metadata for later lookups.
    this.faceSpansById = new Map<number, FaceSpan>();
    faceRectsById.forEach((faces, id) => {
      const faceSpan: FaceSpan = {
        faceId: id,
        scene: faces.at(0)?.scene ?? -1,
        timestamps: faces.map((f) => (f.frameIndex ?? 0) / this.fps),
        faceRects: faces.map(({ x, y, w, h }) => ({ x, y, w, h })),
        totalArea: faces.reduce((sum, f) => sum + f.w * f.h, 0),
        name: "primary",
        color: cycleColor(id),
      };

      // populates smoothedRects attribute
      smoothFaceSpan(faceSpan);

      this.faceSpansById.set(id, faceSpan);
    });
  }

  /**
   * Intended for internal use
   * @param timestamp absolute time
   * @returns scene number (zero indexed): the index of the first scene boundary after
   *   `timestamp`, or `scenes.length` when `timestamp` is at/after the last boundary
   */
  getCurrentScene(timestamp: number) {
    if (timestamp >= (this.scenes.at(-1) ?? 0)) {
      return this.scenes.length;
    }
    return this.scenes.findIndex((ts) => ts > timestamp);
  }

  /**
   * Get all the faces present at a given time, largest total area first.
   * "Present" means detected in the first frame strictly after the (absolute) time.
   * The three largest spans are labelled "primary", "secondary" and "tertiary".
   * @param relativeTimestamp clip-relative time, converted with getAbsoluteTimestamp
   * @returns array of face rectangles along with 'name', 'color', 'scene' index, and id
   *   (zero-area rects are dropped), or null when no faces are present
   */
  getFacesAt(relativeTimestamp: number) {
    const timestamp = getAbsoluteTimestamp(this.clips, relativeTimestamp);
    const nextFrameWithData = this.timestamps.findIndex((ts) => ts > timestamp);

    const faceIds = (this.faceData.get(nextFrameWithData) ?? []).map((f) => f.id);
    const faceSpans: FaceSpan[] = faceIds.map((id) => this.faceSpansById.get(id)!);
    if (!faceSpans.length) {
      return null;
    }

    // Consistent comparator (0 for ties) so ordering is well defined.
    faceSpans.sort((a: FaceSpan, b: FaceSpan) => b.totalArea - a.totalArea);

    // NOTE(review): this mutates the shared FaceSpan objects. A 4th+ span keeps whatever
    // name an earlier call gave it, so duplicate names are possible — consider computing
    // names per call instead.
    if (faceSpans[0]) {
      faceSpans[0].name = "primary";
    }
    if (faceSpans[1]) {
      faceSpans[1].name = "secondary";
    }
    if (faceSpans[2]) {
      faceSpans[2].name = "tertiary";
    }

    const faces = faceSpans.map((fSpan: FaceSpan) => getFaceAtTimestamp(fSpan, timestamp));

    const facesAugmented = faces
      .map((f, i) => ({
        ...f,
        name: faceSpans[i]?.name,
        color: faceSpans[i]?.color,
        scene: faceSpans[i]?.scene,
        id: faceSpans[i]?.faceId,
      }))
      .filter((f) => f.w * f.h > 0);

    return facesAugmented;
  }

  /**
   * Computes the maximum rectangle encompassing all faces within the specified time range.
   * Only the first detection of each frame is considered.
   *
   * @param startTime - The starting time of the range (absolute, inclusive).
   * @param endTime - The ending time of the range (absolute, inclusive).
   * @param ignoreOutliers - Indicates if it should filter out outliers.
   * @returns The maximum rectangle that includes all faces within the time range, or undefined if no faces are found.
   */
  getSceneMaxRect(startTime: number, endTime: number, ignoreOutliers = false) {
    const faceRects: Rect[] = [];

    // this.timestamps ascends (frame i at i / fps), so stop at the first frame past the range.
    for (let i = 0; i < this.timestamps.length; i++) {
      const timestamp = this.timestamps[i];

      if (timestamp > endTime) {
        break;
      }
      if (timestamp < startTime) {
        continue;
      }

      // NOTE(review): uses the first detection of the frame, not the largest — confirm intent.
      const face = this.faceData.get(i)?.[0];
      if (face) {
        faceRects.push({ x: face.x, y: face.y, w: face.w, h: face.h });
      }
    }

    // Filter out outliers
    const filteredRects = ignoreOutliers ? this.filterOutliers(faceRects) : faceRects;
    if (filteredRects.length === 0) {
      return;
    }

    // Single-pass bounding box; avoids spreading a potentially large array into Math.min/max,
    // which is limited by the engine's maximum argument count.
    let xMin = Infinity;
    let yMin = Infinity;
    let xMax = -Infinity;
    let yMax = -Infinity;
    for (const r of filteredRects) {
      xMin = Math.min(xMin, r.x);
      yMin = Math.min(yMin, r.y);
      xMax = Math.max(xMax, r.x + r.w);
      yMax = Math.max(yMax, r.y + r.h);
    }

    return { x: xMin, y: yMin, w: xMax - xMin, h: yMax - yMin };
  }

  /**
   * Filters out outlier face rectangles based on their distances from the median center,
   * using the interquartile-range rule on those distances.
   *
   * @param faceRects - Array of face rectangles.
   * @param threshold - IQR multiplier for the outlier fences.
   * @returns The face rectangles without the outliers, in their original order.
   * @private
   */
  private filterOutliers(faceRects: Rect[], threshold = 1.5) {
    if (faceRects.length === 0) {
      return [];
    }

    const centers = faceRects.map(({ x, y, w, h }) => ({
      cx: x + w / 2,
      cy: y + h / 2,
    }));

    const mid = Math.floor(centers.length / 2);
    const medianCenter = {
      cx: centers.map((c) => c.cx).sort((a, b) => a - b)[mid],
      cy: centers.map((c) => c.cy).sort((a, b) => a - b)[mid],
    };

    // Kept in input order so each distance lines up with its rect by index below.
    const distances = centers.map((c) =>
      Math.sqrt(Math.pow(c.cx - medianCenter.cx, 2) + Math.pow(c.cy - medianCenter.cy, 2))
    );

    // Quartiles must come from the *sorted* distances; indexing the unsorted array
    // picks arbitrary elements and lets outliers through.
    const sortedDistances = [...distances].sort((a, b) => a - b);
    const q1 = sortedDistances[Math.floor(sortedDistances.length / 4)];
    const q3 = sortedDistances[Math.floor((sortedDistances.length * 3) / 4)];
    const iqr = q3 - q1; // non-negative because the distances are sorted

    const lowerBound = q1 - threshold * iqr;
    const upperBound = q3 + threshold * iqr;

    return faceRects.filter((_, i) => distances[i] >= lowerBound && distances[i] <= upperBound);
  }

  /**
   * Get all the faces present for a given timespan, largest total area first.
   * @param timestampStart absolute start of timespan
   * @param timestampEnd absolute end of timespan
   * @returns array of face rectangles at `timestampStart`, or null when no face spans the range
   */
  getFacesAtSpan(timestampStart: number, timestampEnd: number) {
    // ids of the faces detected in the first frame at or after `timestamp`
    const faceIdsAt = (timestamp: number) => {
      const frameIndex = this.timestamps.findIndex((ts) => ts >= timestamp);
      return (this.faceData.get(frameIndex) ?? []).map((f) => f.id);
    };

    const faceIdsAtStart = faceIdsAt(timestampStart);

    // heuristic to calculate if a face is present for a timespan
    // - sample 5 timestamps: timestampStart + i * (timespan / 5) for i = 0..4
    // - if the same `faceId` exists at all sampled timestamps
    // - then we assume that the face is present for the entire timespan
    const sampleCount = 5;
    const step = (timestampEnd - timestampStart) / sampleCount;
    // Look up each sample frame once, rather than once per candidate id.
    const idsAtSampleTimes = Array.from(
      { length: sampleCount },
      (_, i) => new Set(faceIdsAt(timestampStart + i * step))
    );
    const faceIdsForTimespan = faceIdsAtStart.filter((id) =>
      idsAtSampleTimes.every((ids) => ids.has(id))
    );

    if (!faceIdsForTimespan.length) {
      return null;
    }

    const faceSpans: FaceSpan[] = faceIdsForTimespan.map((id) => this.faceSpansById.get(id)!);
    // Consistent comparator (0 for ties) so ordering is well defined.
    faceSpans.sort((a: FaceSpan, b: FaceSpan) => b.totalArea - a.totalArea);

    return faceSpans.map((faceSpan) => getFaceAtTimestamp(faceSpan, timestampStart));
  }
}
