import {
  InvocationLogRecord,
  InvocationProcess,
  InvocationRecord,
  InvocationStatus,
} from "@/shared/domain/actions";
import { APIService } from "@/shared/services";
import { APIResponse, RobotoAPICall } from "@/types";

/**
 * Fetch a single action invocation record.
 *
 * @param invocationId - ID of the invocation to load.
 * @param orgId - Optional organization the request is made for.
 * @param signal - Optional `AbortSignal` used to cancel the request.
 * @returns The `data` payload of the API response.
 * @throws The `error` reported by `APIService.authorizedRequest`, or
 *   `Error("Invocation not found")` when the response carries no data.
 */
export async function getInvocation(
  invocationId: string,
  orgId?: string,
  signal?: AbortSignal,
): Promise<InvocationRecord> {
  const call: RobotoAPICall = {
    method: "GET",
    endpoint: () => `/actions/invocations/${invocationId}`,
    orgId,
    signal,
  };

  const result =
    await APIService.authorizedRequest<APIResponse<InvocationRecord>>(call);

  if (result.error) {
    throw result.error;
  }

  const invocation = result.response?.data;
  if (!invocation) {
    throw new Error("Invocation not found");
  }
  return invocation;
}

/** Payload of one `GET /actions/invocations/{id}/logs/stream` response. */
interface StreamLogsResponse {
  /** Cursor to send back as the `last_read` query parameter on the next request. */
  last_read: string;
  /** Whether the endpoint reports further log records beyond this page. */
  has_next: boolean;
  /** Log records in this page. */
  items: InvocationLogRecord[];
}

/**
 * Fetch one page of log records for an invocation.
 *
 * @param invocationId - ID of the invocation whose logs are read.
 * @param orgId - Optional organization the request is made for.
 * @param signal - Optional `AbortSignal` used to cancel the request.
 * @param lastRead - Cursor from a previous page; when truthy it is sent as the
 *   `last_read` query parameter so reading resumes after it.
 * @returns The `data` payload of the API response.
 * @throws The `error` reported by `APIService.authorizedRequest`, or
 *   `Error("No logs data found")` when the response carries no data.
 */
export async function getInvocationLogs(
  invocationId: string,
  orgId?: string,
  signal?: AbortSignal,
  lastRead?: string,
): Promise<StreamLogsResponse> {
  const call: RobotoAPICall = {
    method: "GET",
    endpoint: () => `/actions/invocations/${invocationId}/logs/stream`,
    orgId,
    signal,
    // Only attach query params when there is a cursor to resume from.
    ...(lastRead
      ? { queryParams: new URLSearchParams({ last_read: lastRead }) }
      : {}),
  };

  const result =
    await APIService.authorizedRequest<APIResponse<StreamLogsResponse>>(call);

  if (result.error) {
    throw result.error;
  }

  const page = result.response?.data;
  if (!page) {
    throw new Error("No logs data found");
  }
  return page;
}

/** Log records bucketed by the process that emitted them. */
type LogsBatch = Record<InvocationProcess, InvocationLogRecord[]>;

/**
 * Snapshot exposed by `LogStreamer.getSnapshot`: the accumulated logs per
 * process plus bookkeeping about the stream.
 */
interface LogStreamState extends LogsBatch {
  /** Time of the most recent state change. */
  modified: Date;
  /** Time the listener was most recently notified. */
  lastEmit: Date;
  /** Whether a stream is currently polling for logs. */
  loading: boolean;
  /** Latest fetched invocation record, or `null` before the first fetch. */
  invocation: InvocationRecord | null;
}

/** Callback invoked (with no arguments) when the streamer's state changes. */
type LogStreamListener = () => void;

/**
 * A `useSyncExternalStore` compatible API for streaming invocation logs.
 * It is **not** intended to have multiple subscribers.
 *
 * - `subscribe` registers the (single) listener; the returned unsubscribe
 *   function also closes any active stream.
 * - `getSnapshot` returns the current state object. State is never mutated in
 *   place: every change produces a new object, as React expects of snapshots.
 * - `stream` polls the logs endpoint, merging new records into state, until the
 *   invocation status is `InvocationStatus.Completed` (or later) and the
 *   endpoint reports no further logs.
 */
export class LogStreamer {
  /** Delay between polls while the invocation has not yet completed. */
  private static readonly pollIntervalMs = 2000;

  private static readonly noopListener: LogStreamListener = () => {
    /** noop */
  };

  private readonly invocationId: string;
  private readonly orgId?: string;
  /** Cursor returned by the logs endpoint; sent back on the next request. */
  private lastRead: string | undefined = undefined;
  private listener: LogStreamListener = LogStreamer.noopListener;
  private state: LogStreamState;
  /**
   * True when `state` changed since the listener was last notified. A flag is
   * used instead of comparing the `modified`/`lastEmit` Dates: those have
   * millisecond resolution and follow the wall clock, so a change made in the
   * same millisecond as the previous emit (or after the clock stepped
   * backwards) would never be emitted.
   */
  private hasPendingChange = false;
  /** ID of the active stream, or `null` when none is running. */
  private streamId: string | null = null;
  /** Aborts the in-flight requests of the active stream. */
  private abortController: AbortController | null = null;

  constructor(invocationId: string, orgId?: string) {
    this.invocationId = invocationId;
    this.orgId = orgId;
    this.state = {
      modified: new Date(),
      lastEmit: new Date(),
      loading: false,
      invocation: null,
      [InvocationProcess.Setup]: [],
      [InvocationProcess.Action]: [],
      [InvocationProcess.OutputHandler]: [],
    };
  }

  /**
   * Stop the active stream: abort its in-flight requests and mark the state as
   * no longer loading. Safe to call repeatedly; does not notify the listener.
   */
  public close = () => {
    if (this.abortController) {
      this.abortController.abort();
      this.abortController = null;
    }

    if (!this.state.loading && this.streamId === null) {
      return;
    }

    this.updateState({ loading: false });
    this.streamId = null;
  };

  /**
   * Starting stream of logs, emitting changes to the listener when logs are available
   *
   * The use of `streamId` is to prevent multiple streams from being opened at once.
   * This is seen in development in React StrictMode, when Effect setup is called twice.
   *
   * Once this stream is closed or superseded by another `stream()` call it stops
   * touching state and resolves quietly (the resulting abort is expected). Any
   * other request failure resets `loading`, notifies the listener and is rethrown.
   */
  public stream = async (streamId: string) => {
    // Tear down any previous stream before starting this one.
    this.close();

    this.streamId = streamId;
    const controller = new AbortController();
    this.abortController = controller;
    // Capture the signal locally: `this.abortController` is nulled by `close()`
    // and replaced by later `stream()` calls, so reading it after an `await`
    // could yield `null` or another stream's signal.
    const { signal } = controller;
    // False once this stream was closed or superseded. Checked after every
    // `await` so a stale loop never writes into state owned by a newer stream.
    const isCurrent = () => this.streamId === streamId && !signal.aborted;

    try {
      let invocation = this.state.invocation;
      if (invocation === null) {
        invocation = await getInvocation(this.invocationId, this.orgId, signal);
        if (!isCurrent()) {
          return;
        }
      }

      this.updateState({ loading: true, invocation });
      this.emitChange();

      while (this.state.loading && isCurrent()) {
        // Whether the invocation had finished *before* this logs request. Only
        // then does an empty `has_next` prove every log has been read; if it
        // finishes between the logs and status requests, we poll once more.
        const finishedBeforeFetch = LogStreamer.isFinished(
          this.state.invocation,
        );

        const { items, has_next, last_read } = await getInvocationLogs(
          this.invocationId,
          this.orgId,
          signal,
          this.lastRead,
        );
        if (!isCurrent()) {
          return;
        }
        this.lastRead = last_read;
        this.mergeState(items);

        if (!finishedBeforeFetch) {
          const latest = await getInvocation(
            this.invocationId,
            this.orgId,
            signal,
          );
          if (!isCurrent()) {
            return;
          }
          this.updateState({ invocation: latest });
        }

        if (!has_next && finishedBeforeFetch) {
          this.updateState({ loading: false });
          this.streamId = null;
        }

        this.emitChange();

        if (!LogStreamer.isFinished(this.state.invocation)) {
          await new Promise((resolve) =>
            setTimeout(resolve, LogStreamer.pollIntervalMs),
          );
        }
      }
    } catch (error) {
      if (!isCurrent()) {
        // Closed or superseded: the failure is the expected abort.
        return;
      }
      // Don't leave the UI stuck in a loading state after a failed request.
      this.updateState({ loading: false });
      this.streamId = null;
      this.abortController = null;
      this.emitChange();
      throw error;
    }
  };

  /**
   * Register the listener notified on state changes. Only one listener is
   * kept; a later call replaces the previous one. The returned function
   * detaches this listener and closes the active stream.
   */
  public subscribe = (listener: LogStreamListener): (() => void) => {
    this.listener = listener;
    return () => {
      if (this.listener === listener) {
        this.listener = LogStreamer.noopListener;
      }
      this.close();
    };
  };

  /** Current state; a new object is returned only after a change. */
  public getSnapshot = (): LogStreamState => {
    return this.state;
  };

  /** True when the invocation's status is `Completed` or a later status. */
  private static isFinished(invocation: InvocationRecord | null): boolean {
    return (
      invocation !== null &&
      invocation.last_status >= InvocationStatus.Completed
    );
  }

  /**
   * Replace `state` with a copy that has `patch` applied and a fresh
   * `modified` time, and mark the change as pending emission.
   */
  private updateState(
    patch: Partial<Omit<LogStreamState, "modified" | "lastEmit">>,
  ) {
    this.state = { ...this.state, ...patch, modified: new Date() };
    this.hasPendingChange = true;
  }

  /** Notify the listener if state changed since the last notification. */
  private emitChange() {
    if (!this.hasPendingChange) {
      return;
    }
    this.hasPendingChange = false;
    // Record the emit time on a new object instead of mutating a snapshot the
    // subscriber may already hold.
    this.state = { ...this.state, lastEmit: new Date() };
    this.listener();
  }

  /** Merge a fetched page of records into the per-process log lists. */
  private mergeState(records: InvocationLogRecord[]) {
    const batch = this.groupAndTransformLogs(records);
    const patch: Partial<LogsBatch> = {};

    for (const processKey of [
      InvocationProcess.Setup,
      InvocationProcess.Action,
      InvocationProcess.OutputHandler,
    ]) {
      const incoming = batch[processKey];
      if (incoming.length) {
        patch[processKey] = LogStreamer.mergeRecords(
          this.state[processKey],
          incoming,
        );
      }
    }

    if (Object.keys(patch).length === 0) {
      return;
    }
    this.updateState(patch);
  }

  /**
   * Combine existing and incoming records, de-duplicated by `recordKey`. A
   * later record with the same key replaces the earlier one in place, keeping
   * the original position.
   */
  private static mergeRecords(
    existing: InvocationLogRecord[],
    incoming: InvocationLogRecord[],
  ): InvocationLogRecord[] {
    const merged = new Map<string, InvocationLogRecord>();
    for (const record of existing) {
      merged.set(LogStreamer.recordKey(record), record);
    }
    for (const record of incoming) {
      merged.set(LogStreamer.recordKey(record), record);
    }
    return Array.from(merged.values());
  }

  /**
   * De-duplication key for a record. Fragments of a partial message share a
   * (truthy) `partial_id`, so newer fragments replace older ones. Other records
   * are keyed by timestamp *and* text: a timestamp alone can be shared by
   * distinct lines logged at the same instant, which would overwrite each other.
   */
  private static recordKey(record: InvocationLogRecord): string {
    return record.partial_id
      ? record.partial_id
      : `${record.timestamp}\u0000${record.log}`;
  }

  /**
   * Apply terminal-style carriage returns per line: on each line only the text
   * after the last non-empty `\r` segment is kept. `\r\n` is treated as a plain
   * newline, and a trailing `\r` no longer wipes the line's content.
   */
  private static applyCarriageReturns(log: string): string {
    return log
      .replace(/\r\n/g, "\n")
      .split("\n")
      .map((line) => {
        const segments = line.split("\r").filter((part) => part.length > 0);
        return segments[segments.length - 1] ?? "";
      })
      .join("\n");
  }

  /**
   * - Group message by process
   * - Account for partial log messages (consecutive records with the same
   *   `partial_id`: only the latest is kept)
   * - Apply carriage returns
   */
  private groupAndTransformLogs(records: InvocationLogRecord[]): LogsBatch {
    const batch: LogsBatch = {
      [InvocationProcess.Setup]: [],
      [InvocationProcess.Action]: [],
      [InvocationProcess.OutputHandler]: [],
    };

    for (const record of records) {
      const processLogs = batch[record.process];

      const lastRecord = processLogs[processLogs.length - 1];
      if (record.partial_id && lastRecord?.partial_id === record.partial_id) {
        processLogs.pop();
      }

      processLogs.push(
        record.log.includes("\r")
          ? { ...record, log: LogStreamer.applyCarriageReturns(record.log) }
          : record,
      );
    }

    return batch;
  }
}
