import {
  concat,
  merge,
  MonoTypeOperatorFunction,
  Observable,
  of,
  OperatorFunction,
  throwError,
  timer,
} from 'rxjs';
import {
  catchError,
  filter,
  first,
  map,
  mergeMap,
  scan,
  switchMap,
  switchMapTo,
  take,
  tap,
} from 'rxjs/operators';

/**
 * Snapshot of an asynchronous operation as seen by the operators in this module.
 */
export interface State<T> {
  /** Current phase of the operation. */
  type: 'loading' | 'ok' | 'error';
  /** Error of the operation (possibly carried over while loading), or `null`. */
  error: unknown;
  /** Value of the operation (possibly carried over while loading), or `null`. */
  result: T | null;
}

/**
 * Perform side-effects around the first notification of an Observable.
 *
 * `start` runs each time the returned Observable is subscribed to. Its return
 * value is passed to `success` on the first emitted value (or on completion
 * without any value), or to `error` if the source errors first. Each of these
 * callbacks fires at most once per subscription.
 *
 * @param start Invoked on subscribe; its result is handed to the other callbacks.
 * @param success Invoked with `start`'s result on the first value or on completion.
 * @param error Invoked with `start`'s result and the error when the source errors
 *   before emitting. Defaults to `success` (the error argument is then ignored).
 */
export function watch<T, TResult>(
  start: () => TResult,
  success: (result: TResult) => void,
  error: (result: TResult, err: unknown) => void = success,
): MonoTypeOperatorFunction<T> {
  return (observable) =>
    // Run `start` and track "settled" per subscriber instead of once per
    // `pipe(...)` call: otherwise `start` fires even if nobody subscribes, and
    // any re-subscription (retry, repeat, ...) silently skips the callbacks.
    new Observable<T>((subscriber) => {
      let settled = false;
      const result = start();

      const settle = (callback: () => void): void => {
        if (!settled) {
          // Flag first so a re-entrant notification cannot fire twice.
          settled = true;
          callback();
        }
      };

      return observable
        .pipe(
          tap({
            next: () => settle(() => success(result)),
            error: (err: unknown) => settle(() => error(result, err)),
            // Completing without a value still ends the "started" phase;
            // without this, whatever `start` set up would never be undone.
            complete: () => settle(() => success(result)),
          }),
        )
        .subscribe(subscriber);
    });
}

/**
 * Returns an operator that turns each source value (a "context") into a
 * stream of {@link State} objects describing the outcome of `fn(context)`.
 *
 * For every value the source emits:
 * 1. A *loading* state is emitted. On the very first source value it carries
 *    no result/error. On later values it carries the result/error of the
 *    state emitted before it, so the current outcome can stay visible while
 *    the next one loads.
 * 2. `fn(context)` is invoked. Each value it emits becomes an *ok* state.
 *    An error, whether thrown synchronously by `fn` or emitted by the
 *    returned observable, becomes an *error* state.
 *
 * A new source value cancels the in-flight `fn` observable (`switchMap`).
 */
export function buildState<TContext, TState>(
  fn: (context: TContext) => Observable<TState>,
): OperatorFunction<TContext, State<TState> | null> {
  // Calls `fn`, turning a synchronous throw into an erroring observable so it
  // takes the same path as asynchronous failures.
  const invoke = (context: TContext): Observable<TState> => {
    try {
      return fn(context);
    } catch (error) {
      return throwError(error);
    }
  };

  const toStates = (context: TContext) => {
    const loading = of({ type: 'loading' as const, result: null, error: null });
    const outcome = invoke(context).pipe(wrapState());
    return concat(loading, outcome);
  };

  return (source) => source.pipe(switchMap(toStates), transitionState());
}

/**
 * RxJS operator that maps values to {@link State} objects.
 * Each value becomes an *ok* state. An error from the source is caught and
 * emitted as a single *error* state, after which the stream completes.
 */
function wrapState<T>(): OperatorFunction<T, State<T>> {
  const toOk = (result: T): State<T> => ({
    type: 'ok',
    result,
    error: null,
  });

  const toError = (error: unknown): Observable<State<T>> => {
    const failed: State<T> = { type: 'error', result: null, error };
    return of(failed);
  };

  return (source) => source.pipe(map(toOk), catchError(toError));
}

/**
 * RxJS operator that fills each *loading* state with the result and error of
 * the state emitted just before it, so the previous outcome stays available
 * during loading. Non-loading states, and a loading state with no
 * predecessor, pass through unchanged.
 */
function transitionState<T>(): OperatorFunction<State<T>, State<T> | null> {
  const carryOver = (
    previous: State<T> | null,
    current: State<T>,
  ): State<T> | null => {
    if (current.type !== 'loading' || previous === null) {
      return current;
    }
    return { ...current, result: previous.result, error: previous.error };
  };

  return (source) =>
    source.pipe(scan<State<T>, State<T> | null>(carryOver, null));
}

/**
 * RxJS operator that reports *error* states to `handler` as they pass through.
 *
 * Despite the name, nothing is swallowed: every state (including error states
 * and `null`) is re-emitted unchanged. `handler` is purely a side-effect.
 *
 * @param handler Called with the `error` of each *error* state.
 */
export function catchStateError<T>(
  handler: (error: unknown) => void,
): MonoTypeOperatorFunction<State<T> | null> {
  const reportError = (state: State<T> | null): void => {
    if (state && state.type === 'error') {
      handler(state.error);
    }
  };

  return (source) => source.pipe(tap(reportError));
}

/**
 * RxJS operator that mirrors the source and, whenever `trigger` emits,
 * subscribes to the source again and re-emits the first value of that fresh
 * subscription.
 *
 * Re-emissions from overlapping triggers are merged; none are cancelled. If a
 * re-subscription completes without emitting, that trigger is ignored instead
 * of erroring the whole stream. The result completes once the source,
 * `trigger`, and all pending re-subscriptions have completed.
 */
export function replayWhen<T>(
  trigger: Observable<void>,
): MonoTypeOperatorFunction<T> {
  return (source) =>
    merge(
      source,
      // take(1) rather than first(): first() errors with EmptyError when the
      // source completes without a value, which would tear down the result.
      trigger.pipe(mergeMap(() => source.pipe(take(1)))),
    );
}

/**
 * RxJS operator that polls the source until a value satisfies `condition`.
 *
 * The source is first subscribed after `delay` ms, then afresh every
 * `interval` ms. An attempt still running when the next tick arrives is
 * cancelled. The first value passing `condition` is emitted, and then the
 * result completes. If more than `maxAttempts` ticks fire first, the result
 * errors with `Error('Too many attempts')`.
 *
 * @param condition Predicate a polled value must satisfy.
 * @param options.delay Milliseconds before the first attempt (default 0).
 * @param options.interval Milliseconds between attempts.
 * @param options.maxAttempts Number of attempts allowed (default 10).
 */
export function pollUntil<T>(
  condition: (res: T) => boolean,
  {
    delay = 0,
    interval,
    maxAttempts = 10,
  }: {
    delay?: number;
    interval: number;
    maxAttempts?: number;
  },
): MonoTypeOperatorFunction<T> {
  const guardAttempts = (count: number): void => {
    if (count > maxAttempts) {
      throw new Error('Too many attempts');
    }
  };
  const isDone = (res: T): boolean => condition(res);

  return (source) => {
    // 1-based attempt counter; errors the stream once the budget is exceeded.
    const attempts$ = timer(delay, interval).pipe(
      scan((count) => count + 1, 0),
      tap(guardAttempts),
    );
    return attempts$.pipe(switchMapTo(source), filter(isDone), first());
  };
}
