import {
  Dispatch,
  FC,
  ReactNode,
  SetStateAction,
  createContext,
  useCallback,
  useEffect,
  useState,
} from 'react';
import useWebSocket from 'react-use-websocket';
import useSession from '@/hooks/useSession';
import { useNavigate, useParams } from 'react-router-dom';
import {
  Thread,
  ThreadID,
  ThreadBlock,
  ThreadStatus,
  ThreadFilters,
  DEFAULT_FILTERS,
  WSThread,
  WSParameters,
} from '@/models/thread';

import { getToday } from '@/utils/functions';
import updateTitle from '@/api/threads/updateTitle';
import { fetchAuthSession } from 'aws-amplify/auth';

/**
 * Value published through `ThreadContext`.
 *
 * Combines the fields inherited from `Thread` with the request status of the
 * conversation and the callbacks consumers use to drive it.
 */
interface ThreadContextProps extends Thread {
  // --- state -------------------------------------------------------------

  /** Current request status of the thread (the context default is `'READY'`). */
  status: ThreadStatus;

  /** React state setter for `status`. */
  setStatus: Dispatch<SetStateAction<ThreadStatus>>;

  // --- actions -----------------------------------------------------------

  /** Clears the thread state held by the provider. */
  resetThread: () => void;

  /**
   * Submits a prompt, optionally together with the filters to apply.
   *
   * @param msg - Prompt text entered by the user.
   * @param filters - Optional filters for the request.
   */
  handlePromptSubmit: (msg: string, filters?: ThreadFilters) => void;

  /**
   * Requests a fresh response for an existing block.
   *
   * @param block - The block whose prompt is submitted again.
   */
  handleRegenerate: (block: ThreadBlock) => void;

  /**
   * Changes the thread title.
   *
   * @param newTitle - Title to apply.
   */
  handleTitleUpdate: (newTitle: string) => void;
}

/**
 * Placeholder callback used for every action in the default context value.
 * It ignores its arguments and returns an empty object.
 */
const noop = () => ({});

/**
 * Value seen by consumers rendered outside of a `ThreadProvider`: an empty,
 * `'READY'` thread whose actions do nothing.
 */
const defaultThreadContext: ThreadContextProps = {
  // Thread data
  id: undefined,
  title: '',
  blocks: [],
  filters: DEFAULT_FILTERS,

  // Request status
  status: 'READY',
  setStatus: noop,

  // Actions
  handleTitleUpdate: noop,
  handlePromptSubmit: noop,
  handleRegenerate: noop,
  resetThread: noop,
};

/** React context carrying the active thread and the actions that modify it. */
export const ThreadContext =
  createContext<ThreadContextProps>(defaultThreadContext);

/**
 * Returns a copy of `blocks` in which the last block has `patch` merged in.
 *
 * The last block is copied rather than mutated so that React state objects
 * are never modified in place. When `blocks` is empty there is nothing to
 * patch and the same array is returned unchanged.
 *
 * @param blocks - Current list of thread blocks.
 * @param patch - Fields to overwrite on the last block.
 * @returns The updated list, or `blocks` itself when it is empty.
 */
const patchLastBlock = (
  blocks: ThreadBlock[],
  patch: Partial<ThreadBlock>
): ThreadBlock[] => {
  const last = blocks[blocks.length - 1];

  if (last === undefined) {
    return blocks;
  }

  return [...blocks.slice(0, -1), { ...last, ...patch }];
};

/**
 * Provides `ThreadContext` to its children.
 *
 * Holds the state of the thread selected by the `id` route parameter (title,
 * blocks, filters, request status), talks to the LLM backend over a WebSocket
 * and exposes the actions used to submit prompts, regenerate a response,
 * rename the thread and reset the state.
 *
 * @param children - Elements rendered inside the provider.
 */
export const ThreadProvider: FC<{ children: ReactNode }> = ({ children }) => {
  const { id: routeParamID } = useParams();

  const [threadID, setThreadID] = useState<ThreadID>(routeParamID);
  const { session, refreshThreads, pushUserEvent } = useSession();
  const { sub, threads } = session;
  const [title, setTitle] = useState<string>('');
  const [blocks, setBlocks] = useState<ThreadBlock[]>([]);
  const [filters, setFilters] = useState<ThreadFilters>(DEFAULT_FILTERS);
  const [status, setStatus] = useState<ThreadStatus>('READY');
  const navigate = useNavigate();
  const socketUrl = `${import.meta.env.VITE_BACKEND_WS_LLM}/${
    import.meta.env.VITE_BACKEND_STAGE
  }`;
  // The socket URL below is only built while a token is present; clearing the
  // token therefore hands `null` to `useWebSocket`.
  const [idToken, setIdToken] = useState<string | undefined>(undefined);

  /**
   * Drops the current ID token and tries to obtain a new one from the auth
   * session.
   *
   * @returns `true` when a token was stored, `false` when the session has no
   * ID token or fetching it failed. This promise never rejects.
   */
  const refreshToken = async (): Promise<boolean> => {
    setIdToken(undefined);

    try {
      const authSession = await fetchAuthSession();
      const token = authSession?.tokens?.idToken;

      if (token) {
        setIdToken(token.toString());
        return true;
      }

      console.error('Failed to fetch session');
    } catch (error: unknown) {
      console.error('Failed to fetch session', error);
    }

    return false;
  };

  const { sendJsonMessage } = useWebSocket(
    idToken ? `${socketUrl}/?IdToken=${idToken}` : null,
    {
      onError: (event) => {
        pushUserEvent('error_websocket_failed');
        console.error('Failed to connect', event);
        if (status === 'FETCHING') {
          setStatus('ERROR');
          refreshThreads();
        }
      },
      onMessage: (msg: { data: string }) => {
        try {
          const { success, data: msgData }: WSThread = JSON.parse(msg.data);

          if (success === false) {
            setStatus('ERROR');
            refreshThreads();
            return;
          }

          // Ignore messages that belong to a different thread than the one
          // currently shown.
          if (threadID !== undefined && msgData.thread_id !== threadID) {
            pushUserEvent('error_websocket_interrupted');
            return;
          }

          if (msgData.type === 'streamNewMessage') {
            if (msgData.additional_data === undefined) {
              throw new Error('No additional data');
            }

            const bufferedMessage =
              msgData.additional_data.buffered_streaming_contents;

            if (status === 'FETCHING') {
              setStatus('TYPING');
            }

            // Each streamed message carries the buffered contents, which
            // replace the response of the last block.
            setBlocks((current: ThreadBlock[]) =>
              patchLastBlock(current, { response: bufferedMessage })
            );
          } else if (msgData.type === 'streamComplete') {
            // Only overwrite the fields the final message actually provides.
            const finalPatch: Partial<ThreadBlock> = {};

            if (msgData.additional_data !== undefined) {
              finalPatch.response =
                msgData.additional_data.buffered_streaming_contents;
            }

            if (msgData.rag_sources !== undefined) {
              finalPatch.sources = msgData.rag_sources.map((source) => ({
                title: source.title,
                url: source.url,
                logo: '',
              }));
            }

            setBlocks((current: ThreadBlock[]) =>
              patchLastBlock(current, finalPatch)
            );

            refreshThreads();

            setIdToken(undefined);
            setStatus('READY');
          } else if (msgData.type === 'streamError') {
            pushUserEvent('error_websocket_stream');
            setIdToken(undefined);
            setStatus('ERROR');
          } else if (success !== undefined) {
            if (success === true) {
              // Acknowledgement of the request: adopt the server-side IDs and
              // wait for the streamed response.
              const payloadData = msgData;

              if (threadID === undefined) {
                setThreadID(payloadData.thread_id);
                setTitle(payloadData.title as string);
              }

              setBlocks((current: ThreadBlock[]) =>
                patchLastBlock(current, { id: payloadData.message_id })
              );

              setStatus('FETCHING');
            } else {
              console.error('Failed to fetch', msgData);
              setStatus('ERROR');
            }
          }
        } catch (e) {
          pushUserEvent('error_websocket_internal_error');
          console.error(e);
          setStatus('ERROR');
        }
      },
      shouldReconnect: () => false,
    }
  );

  /**
   * Clears the thread ID, token, title and blocks and sets the status back to
   * `'READY'`. Memoized (it only uses state setters) so it can be listed as an
   * effect dependency.
   */
  const resetThread = useCallback((): void => {
    setThreadID(undefined);
    setIdToken(undefined);
    setTitle('');
    setBlocks([]);
    setStatus('READY');
  }, []);

  /**
   * Optimistically applies `newTitle`, persists it through `updateTitle` and
   * restores the previous title when the update reports failure or rejects.
   * Does nothing while no thread ID is set.
   *
   * @param newTitle - Title to apply to the current thread.
   */
  const handleTitleUpdate = (newTitle: string): void => {
    if (!threadID) return;

    const previousTitle = title;
    setTitle(newTitle);
    pushUserEvent('update_thread_title');
    updateTitle(sub, threadID, newTitle).then(
      (success) => {
        if (success) {
          refreshThreads();
        } else {
          setTitle(previousTitle);
        }
      },
      (error: unknown) => {
        // A rejected request must not leave the optimistic title in place.
        console.error(error);
        setTitle(previousTitle);
      }
    );
  };

  /**
   * Appends a new block for `prompt`. Without a thread ID it stores the
   * filters and navigates to `/app/thread/new` (the "Start New Thread" effect
   * then sends the request); otherwise it sends a follow-up right away.
   *
   * @param prompt - Prompt text entered by the user.
   * @param filters - Filters for a new thread; defaults to `DEFAULT_FILTERS`.
   */
  const handlePromptSubmit = (
    prompt: string,
    filters?: ThreadFilters
  ): void => {
    const newThreadBlock: ThreadBlock = {
      id: 'new',
      disliked: false,
      prompt,
      response: '',
      queryDate: getToday(),
    };
    // Functional update: append to the latest state rather than to the
    // `blocks` snapshot captured by this render.
    setBlocks((current: ThreadBlock[]) => [...current, newThreadBlock]);

    if (!threadID) {
      setFilters(filters ?? DEFAULT_FILTERS);
      pushUserEvent('submit_new_thread');
      navigate('/app/thread/new');
    } else {
      pushUserEvent('add_follow_up');
      addResponse(newThreadBlock);
    }
  };

  /**
   * Replaces the last block with a cleared copy of `block` (no response, no
   * sources, not disliked, today's date) and requests a new response for it.
   *
   * @param block - The block whose prompt is submitted again.
   */
  const handleRegenerate = (block: ThreadBlock): void => {
    setBlocks((current: ThreadBlock[]) => {
      return [
        ...current.slice(0, -1),
        {
          id: block.id,
          prompt: block.prompt,
          response: null,
          disliked: false,
          sources: [],
          queryDate: getToday(),
        },
      ];
    });
    addResponse(block);
  };

  // Add Response
  /**
   * Sends the prompt of `block` to the backend.
   *
   * Sets the status to `'FETCHING'`, or to `'ERROR'` when the daily or monthly
   * prompt allowance is used up or no ID token could be obtained. Without a
   * thread ID the request creates a thread and carries the filters; otherwise
   * it adds a message to the existing thread.
   *
   * @param block - The block whose prompt is sent.
   */
  const addResponse = useCallback(
    (block: ThreadBlock): void => {
      setStatus('FETCHING');

      const { remainingToday, remainingThisMonth } = session.promptUsage;

      if (remainingToday <= 0 || remainingThisMonth <= 0) {
        setStatus('ERROR');
        return;
      }

      const { earliestDate, latestDate } = filters;

      const wsParams: WSParameters = {
        action: 'createThreadPre',
        user_sub: session.sub,
        user_input: block.prompt,
      };

      if (!threadID) {
        wsParams.filters = {
          source_category: filters.sources,
          publish_date: {
            earliest: earliestDate ? new Date(earliestDate).getTime() : 1,
            latest: latestDate
              ? new Date(latestDate).getTime()
              : getToday().getTime(),
          },
          countries: filters.countries,
        };
      } else {
        wsParams.action = 'addMessageThreadPre';
        wsParams.thread_id = threadID;
      }

      refreshToken()
        .then((hasToken) => {
          if (hasToken) {
            sendJsonMessage(wsParams);
          } else {
            // Without a token the socket URL stays null, so the request
            // could not be delivered; report the failure instead of leaving
            // the status at 'FETCHING'.
            setStatus('ERROR');
          }
        })
        .catch((error: unknown) => {
          console.error(error);
          setStatus('ERROR');
        });
    },
    [session.promptUsage, session.sub, filters, sendJsonMessage, threadID]
  );

  // Update ThreadID
  // Follow the route parameter, except for the placeholder route 'new'.
  useEffect(() => {
    if (routeParamID !== 'new' && routeParamID !== threadID) {
      setThreadID(routeParamID);
      setBlocks([]);
    }
  }, [routeParamID, threadID]);

  // Start New Thread
  // On the 'new' route, send the single pending block once the status is
  // 'READY'.
  useEffect(() => {
    if (
      routeParamID === 'new' &&
      blocks.length === 1 &&
      !blocks[0].response &&
      status === 'READY'
    ) {
      addResponse(blocks[blocks.length - 1]);
    }
  }, [addResponse, blocks, routeParamID, status]);

  // Load Existing Thread
  // Fill the state from the session's threads, or go to the not-found page
  // when the thread ID is unknown.
  useEffect(() => {
    if (
      blocks.length === 0 &&
      typeof threadID !== 'undefined' &&
      routeParamID === threadID
    ) {
      const savedThread = threads.find(
        (thread: Thread) => thread.id === threadID
      );
      if (savedThread) {
        setTitle(savedThread.title);
        setBlocks(savedThread.blocks);
        setFilters(savedThread.filters);
        setIdToken(undefined);
        setStatus('READY');
      } else {
        navigate('/not-found', { replace: true });
      }
    }
  }, [navigate, routeParamID, threadID, threads, blocks]);

  // Update URL on new ThreadID
  // Replace the 'new' placeholder route once a thread ID is known.
  useEffect(() => {
    if (routeParamID === 'new' && threadID) {
      navigate(`/app/thread/${threadID}`, { replace: true });
    }
  }, [navigate, routeParamID, threadID]);

  // Clean up when not in a thread
  useEffect(() => {
    if (typeof routeParamID === 'undefined') {
      resetThread();
    }
  }, [routeParamID, resetThread]);

  return (
    <ThreadContext.Provider
      value={{
        id: threadID,
        title,
        blocks,
        status,
        filters,
        setStatus,
        handleTitleUpdate,
        handlePromptSubmit,
        handleRegenerate,
        resetThread,
      }}
    >
      {children}
    </ThreadContext.Provider>
  );
};
