import React from "react";
import { isHTMLElement } from "../utils/isHTMLElement";

/**
 * Setter returned by `useFocusTrap`, meant to be used as a ref callback.
 * Accepts `null` so the trap can be detached and so callers can forward a nullable ref value without a non-null assertion.
 */
type FocusTrapContainerSetter<T> = (focusTrapContainer: T | null) => void;

/**
 * Traps focus within a DOM node. Shifting focus with `Tab` and `Shift-Tab` will loop through all focusable elements of the node.
 * Returns a function that should be set as the ref of the node to trap focus within.
 *
 * The trap reacts only to `keydown` events dispatched within the node. It does not move focus into the node by itself.
 * Passing `null` to the returned setter removes the keydown listener.
 *
 * @returns A setter for the container node, suitable for use as a ref callback.
 */
export function useFocusTrap<T extends HTMLElement>(): FocusTrapContainerSetter<T> {
    const [focusContainer, setFocusContainer] = React.useState<T | null>(null);

    React.useEffect(() => {
        if (!focusContainer) return;

        const onKeyDown = (event: KeyboardEvent) => {
            if (event.key !== "Tab") return;

            // Re-query on every Tab press so elements added, removed, or disabled after mount are respected.
            const tabbableElements = Array.from(focusContainer.querySelectorAll(focusableSelector)).filter(isTabbableElement);

            // NOTE(review): these are the first/last elements in DOM order. Positive tabindex values can make the
            // browser's tab order differ from DOM order.
            const firstFocusableElement = tabbableElements[0];
            const lastFocusableElement = tabbableElements[tabbableElements.length - 1];

            // Nothing inside the container can take focus: swallow the Tab so focus cannot escape the trap.
            if (firstFocusableElement === undefined || lastFocusableElement === undefined) {
                event.preventDefault();
                return;
            }

            // Use the container's own document rather than the global one, so containers rendered into another document work.
            const activeElement = focusContainer.ownerDocument.activeElement;

            if (event.shiftKey) {
                // The container itself may hold focus (e.g. via tabindex="-1"). Shift-Tab from there would otherwise leave the trap.
                if (activeElement === firstFocusableElement || activeElement === focusContainer) {
                    lastFocusableElement.focus();
                    event.preventDefault();
                }
            } else if (activeElement === lastFocusableElement) {
                firstFocusableElement.focus();
                event.preventDefault();
            }
        };

        focusContainer.addEventListener("keydown", onKeyDown);
        return () => {
            focusContainer.removeEventListener("keydown", onKeyDown);
        };
    }, [focusContainer]);

    return setFocusContainer;
}

/**
 * Comma-separated CSS selector for elements that can normally receive keyboard focus.
 * Copied from https://github.com/testing-library/user-event/blob/main/src/utils/focus/selector.ts
 * Matches are only candidates: `useFocusTrap` additionally filters them with `isTabbableElement`.
 */
const focusableSelector = [
    "input:not([type=hidden]):not([disabled])",
    "button:not([disabled])",
    "select:not([disabled])",
    "textarea:not([disabled])",
    '[contenteditable=""]',
    '[contenteditable="true"]',
    "a[href]",
    "[tabindex]:not([disabled])",
].join(", ");

/**
 * Type guard: true when `element` is an HTML element whose `tabIndex` is non-negative,
 * i.e. one that sequential (Tab-key) navigation can reach.
 */
function isTabbableElement(element: Element): element is HTMLElement {
    return isHTMLElement(element) && element.tabIndex >= 0;
}
