import { useCallback, useMemo, useReducer } from "react"

/**
 * A node of the selectable tree.
 *
 * A node whose `children` is missing or empty is treated as a leaf by the
 * helpers in this file; only leaf ids are added to or removed from the
 * selected-id set.
 */
export interface Node<T> {
    /** Key used in the selected-id set and in the selection-state map (presumably unique within a tree; lookups return the first match). */
    id: string
    /** Optional payload; `toValues` collects it when this node's id is in the selected set. */
    value?: T
    /** Child nodes. Absent or empty means the node is a leaf. */
    children?: Node<T>[]
    /** When truthy, `toNodes` records no state for this node or its subtree, and `setTree` never adds it as a selected leaf. */
    disable?: boolean
}

/**
 * Selection state of a tree node.
 *
 * Numeric enum starting at 1, so every member is truthy.
 */
export enum Selected {
    /** Leaf: its id is in the selected set. Branch: every child that has a state is YES. */
    YES = 1,
    /** Leaf: its id is not in the selected set. Branch: every child that has a state is NO, or no child has a state. */
    NO,
    /** Branch with a PARTIAL child, or with both YES and NO children. */
    PARTIAL,
}

/**
 * Collects the id of every leaf at or below `node`.
 *
 * Disabled nodes are not filtered out: any node without children counts.
 *
 * @param node Root of the subtree to scan.
 * @param ids  Optional set to add into; a new set is created when omitted.
 * @returns The set that received the ids (the `ids` argument when one was given).
 */
export const AllIDs = <T>(node: Node<T>, ids?: Set<string>) => {
    const collected = ids ?? new Set<string>()

    if (isLeaf(node)) {
        collected.add(node.id)
    }
    // Depth-first, children in declaration order, all writing into the same set.
    for (const child of node.children ?? []) {
        AllIDs(child, collected)
    }
    return collected
}

/**
 * Computes the selection state of a single node.
 *
 * A leaf is YES when its id is in `ids`, otherwise NO. A branch is derived
 * from the states already stored in `nodes` for its direct children, so the
 * children must have been processed first. Children with no entry in `nodes`
 * (`toNodes` records none for disabled nodes) are ignored.
 *
 * @param node  Node to evaluate.
 * @param ids   Currently selected leaf ids.
 * @param nodes States computed so far, keyed by node id.
 * @returns YES / NO when all children that have a state agree, PARTIAL when
 *          they differ or any child is PARTIAL, NO when no child has a state.
 */
const toSelection = <T>(node: Node<T>, ids: Set<string>, nodes: Map<string, Selected>) => {
    if (isLeaf(node)) {
        return ids.has(node.id) ? Selected.YES : Selected.NO
    }

    // State of the first child that has one; every later child must match it.
    let agreed: Selected | undefined
    for (const child of node.children ?? []) {
        const childState = nodes.get(child.id)

        if (childState === undefined) {
            continue
        }
        if (childState === Selected.PARTIAL) {
            return Selected.PARTIAL
        }
        if (agreed === undefined) {
            agreed = childState
            continue
        }
        if (agreed !== childState) {
            return Selected.PARTIAL
        }
    }
    return agreed ?? Selected.NO
}

/**
 * Builds the id -> selection-state map for `tree`.
 *
 * Children are visited before their parent so that `toSelection` can read
 * their states. A disabled node is skipped together with its whole subtree,
 * so neither it nor its descendants get an entry.
 *
 * @param tree  Root of the subtree to evaluate.
 * @param ids   Currently selected leaf ids.
 * @param nodes Optional map to fill; a new one is created when omitted.
 * @returns The map that received the states.
 */
const toNodes = <T>(tree: Node<T>, ids: Set<string>, nodes?: Map<string, Selected>): Map<string, Selected> => {
    const states = nodes ?? new Map<string, Selected>()

    if (!tree.disable) {
        for (const child of tree.children ?? []) {
            toNodes(child, ids, states)
        }
        states.set(tree.id, toSelection(tree, ids, states))
    }

    return states
}

/**
 * Gathers the `value` of every node in `tree` whose id is in `ids`.
 *
 * Nodes are visited parent-first, children in declaration order. Nodes
 * without a value (`undefined`) contribute nothing; falsy values such as `0`
 * or `""` are kept.
 *
 * @param tree   Root of the subtree to scan.
 * @param ids    Ids whose values should be collected.
 * @param values Optional array to append to; a new one is created when omitted.
 * @returns The array that received the values.
 */
const toValues = <T>(tree: Node<T>, ids: Set<string>, values?: T[]): T[] => {
    const collected: T[] = values ?? []
    const payload = tree.value

    if (payload !== undefined && ids.has(tree.id)) {
        collected.push(payload)
    }
    for (const child of tree.children ?? []) {
        toValues(child, ids, collected)
    }
    return collected
}

/** A node is a leaf when it has no `children` array or the array is empty. */
const isLeaf = <T>(node: Node<T>) => (node.children?.length ?? 0) === 0

/**
 * Applies `selected` to every leaf at or below `tree`, mutating `values`.
 *
 * A leaf id is added only when the request is YES and neither the leaf nor
 * any ancestor inside `tree` is disabled; in every other case the id is
 * removed. Disabling is inherited because `toNodes` drops a disabled node
 * together with its whole subtree: selecting leaves under a disabled branch
 * would put ids in the set that never show up in the state map.
 *
 * @param tree     Root of the subtree to update.
 * @param values   Set of selected leaf ids, updated in place.
 * @param selected Requested state; anything other than YES clears the leaves.
 */
const setTree = <T>(tree: Node<T>, values: Set<string>, selected: Selected) => {
    // Once a disabled node is crossed, everything underneath is treated as a
    // deselect, exactly like a disabled leaf.
    const effective = tree.disable ? Selected.NO : selected

    if (isLeaf(tree)) {
        if (effective === Selected.YES) {
            values.add(tree.id)
        } else {
            values.delete(tree.id)
        }
    }
    tree.children?.forEach((ch) => setTree(ch, values, effective))
}

/**
 * Toggles the selection of the subtree rooted at `tree`, mutating `values`.
 *
 * The current state of `tree` is taken from `toNodes`. If it is YES, `setTree`
 * is called with NO; for any other outcome (NO, PARTIAL, or no entry at all,
 * which is what `toNodes` yields for a disabled root) `setTree` is called
 * with YES.
 *
 * @param tree   Root of the subtree to toggle.
 * @param values Set of selected leaf ids, updated in place by `setTree`.
 */
const flipTree = <T>(tree: Node<T>, values: Set<string>) => {
    const fullySelected = toNodes(tree, values).get(tree.id) === Selected.YES

    setTree(tree, values, fullySelected ? Selected.NO : Selected.YES)
}

/**
 * Finds the first node with the given id, searching `tree` depth-first with
 * a parent visited before its children and children in declaration order.
 *
 * @param tree Root of the subtree to search.
 * @param id   Id to look for.
 * @returns The matching node object itself, or `undefined` when there is none.
 */
const findNode = <T>(tree: Node<T>, id: string): Node<T> | undefined => {
    const pending: Node<T>[] = [tree]

    for (let current = pending.pop(); current !== undefined; current = pending.pop()) {
        if (current.id === id) {
            return current
        }
        // The stack is last-in first-out, so push the children reversed to
        // visit the first child next.
        for (const child of [...(current.children ?? [])].reverse()) {
            pending.push(child)
        }
    }
    return undefined
}

/** Reducer state: the ids currently selected. The reducer replaces the set on every flip rather than mutating it. */
type State = {
    ids: Set<string>
}

/** Reducer action. "Flip" toggles the subtree whose root has `id`; `tree` is the tree that is searched for that id. */
type Action<T> = { type: "Flip"; id: string; tree: Node<T> }

/**
 * Reducer behind `useTreeSet`.
 *
 * "Flip" looks up `action.id` inside `action.tree` and toggles that subtree
 * on a copy of the selected ids, so the previous state object and its set are
 * never mutated. When the id cannot be found the request is logged and the
 * current state is returned as-is.
 *
 * Any action type this reducer does not know also returns the current state:
 * falling out of the switch would return `undefined` and replace the whole
 * reducer state with it.
 *
 * @param state  Current selection.
 * @param action Requested change.
 * @returns The next state; the same object when nothing changed.
 */
function reducer<T>(state: State, action: Action<T>): State {
    switch (action.type) {
        // Braces give the case its own scope for the declarations below.
        case "Flip": {
            const current = findNode(action.tree, action.id)
            if (!current) {
                console.error("Node id not found. Ignoring flip request.", action.id, action.tree)
                return state
            }
            const flipped = new Set(state.ids)
            flipTree(current, flipped)
            return { ids: flipped }
        }
        default:
            return state
    }
}

/**
 * React hook that keeps track of which ids of `tree` are selected.
 *
 * @param tree       Tree whose nodes can be flipped.
 * @param initialIDs Set handed to `useReducer` as the initial selected ids.
 * @returns
 *  - `ids`: the currently selected ids,
 *  - `values`: the `value` of every node whose id is selected (see `toValues`),
 *  - `nodes`: map from node id to its `Selected` state (see `toNodes`),
 *  - `flip(id)`: dispatches a "Flip" for the node with that id in `tree`.
 */
export function useTreeSet<T>(tree: Node<T>, initialIDs: Set<string>) {
    const [{ ids }, dispatch] = useReducer(reducer, { ids: initialIDs })

    // Re-created only when the tree changes, since the action carries the tree.
    const flip = useCallback((id: string) => dispatch({ type: "Flip", id, tree }), [tree])

    // Both derived views are recomputed only when the selection or the tree changes.
    const nodes = useMemo(() => toNodes(tree, ids), [ids, tree])
    const values = useMemo(() => toValues(tree, ids), [ids, tree])

    return { ids, values, nodes, flip }
}
