import { useCallback, useContext, useMemo } from "react";
import SandBoxContext from "../../../../SandBoxContext";
import {
  GraphEntity,
  IEdge,
} from "../../../../../../components/graph-builder/types/GraphTypes";
import {
  useSelectSchemaEdges,
  useSelectSchemaNodes,
} from "../selectors/SchemaBuilderSelectors";
import { useSelectNodes } from "../selectors/GraphBuilderSelectors";

/**
 * React hook that validates graph edges against the schema held in the
 * sandbox context.
 *
 * When the schema has no nodes, validation is skipped and every check
 * returns an empty string.
 *
 * @returns An object exposing `validateEdge`, which returns `""` for a legal
 *   edge and a `"Schema validation error: ..."` message otherwise.
 */
export const useSchemaValidator = () => {
  const [state] = useContext(SandBoxContext);
  const schemaEdges = useSelectSchemaEdges(state);
  const schemaNodes = useSelectSchemaNodes(state);
  const graphNodes = useSelectNodes(state);

  const hasSchema = useMemo(() => schemaNodes.length > 0, [schemaNodes]);

  /** Returns the label of the entity with the given id, or null if none matches. */
  const getEntityLabel = useCallback((entities: GraphEntity[], id: number) => {
    const entity = entities.find((e) => e.id === id);
    if (entity) {
      return entity.data.label;
    }
    return null;
  }, []);

  /**
   * Maps each schema edge label to the list of [sourceLabel, targetLabel]
   * pairs it is allowed to connect. Schema edges whose endpoints cannot be
   * resolved to a labelled schema node are skipped.
   *
   * A Map is used instead of a plain object so that labels such as
   * "constructor" or "toString" do not collide with Object.prototype members.
   */
  const edgesRelationsMap = useMemo(() => {
    const relations = new Map<string, [string, string][]>();
    for (const schemaEdge of schemaEdges) {
      const sourceLabel = getEntityLabel(schemaNodes, schemaEdge.source);
      const targetLabel = getEntityLabel(schemaNodes, schemaEdge.target);
      if (sourceLabel && targetLabel) {
        const pairs = relations.get(schemaEdge.data.label);
        if (pairs) {
          pairs.push([sourceLabel, targetLabel]);
        } else {
          relations.set(schemaEdge.data.label, [[sourceLabel, targetLabel]]);
        }
      }
    }
    return relations;
  }, [getEntityLabel, schemaEdges, schemaNodes]);

  /** Runs `func` only when a schema exists; otherwise reports no error. */
  const validator = useCallback(
    (func: () => string) => {
      if (hasSchema) {
        return func();
      }
      return "";
    },
    [hasSchema]
  );

  const makeError = useCallback(
    (msg: string) => `Schema validation error: ${msg}`,
    []
  );

  /**
   * Validates that an edge is legal under the schema.
   *
   * An edge is legal when its label exists in the schema and the labels of
   * its source and target graph nodes match one of the pairs allowed for that
   * label. An edge whose label is not in the schema is reported as invalid
   * rather than throwing.
   *
   * @param edge - The graph edge to check.
   * @returns `""` if the edge is legal (or no schema is defined), otherwise an
   *   error message.
   */
  const validateEdge = useCallback(
    (edge: IEdge) => {
      return validator(() => {
        const sourceLabel = getEntityLabel(graphNodes, edge.source);
        const targetLabel = getEntityLabel(graphNodes, edge.target);
        const allowedPairs = edgesRelationsMap.get(edge.data.label);
        // Guard before searching: labels unknown to the schema have no entry.
        const isValid =
          allowedPairs !== undefined &&
          allowedPairs.some(
            ([source, target]) =>
              sourceLabel === source && targetLabel === target
          );
        if (isValid) {
          return "";
        }
        return makeError(
          `Invalid edge '${edge.data.label}' between '${sourceLabel}' and '${targetLabel}'`
        );
      });
    },
    [edgesRelationsMap, getEntityLabel, graphNodes, makeError, validator]
  );

  return {
    validateEdge,
  };
};
