'use client'; import { Button } from '@/components/ui/button'; import { cn } from '@/lib/utils'; import type { UIMessage } from 'ai'; import { ChevronLeftIcon, ChevronRightIcon } from 'lucide-react'; import type { ComponentProps, HTMLAttributes, ReactElement } from 'react'; import { createContext, useContext, useEffect, useState } from 'react'; /* NOTE(review): this copy of the file appears to have been flattened by an extraction step. JSX markup and generic type arguments (for example on HTMLAttributes, createContext and useState) are missing. Each // comment on a collapsed line also comments out the code that follows it on that line. Restore the original formatting before building or editing this module. */ /* Shape of the value that useBranch reads from BranchContext; Branch builds one as contextValue. */ type BranchContextType = { currentBranch: number; totalBranches: number; goToPrevious: () => void; goToNext: () => void; branches: ReactElement[]; setBranches: (branches: ReactElement[]) => void; }; /* Created with null as the default value so useBranch can detect use outside a provider. */ const BranchContext = createContext(null); /* Returns the current BranchContext value. Throws an Error when that value is falsy (the null default, i.e. no provider above the caller). */ const useBranch = () => { const context = useContext(BranchContext); if (!context) { throw new Error('Branch components must be used within Branch'); } return context; }; /* defaultBranch: initial selected index (0 when omitted). onBranchChange: called with the resulting index every time goToPrevious or goToNext runs. */ export type BranchProps = HTMLAttributes & { defaultBranch?: number; onBranchChange?: (branchIndex: number) => void; }; /* Root component. It owns the selected index and the list of registered branch elements, and combines them with the navigation callbacks into a BranchContextType value (contextValue). defaultBranch only seeds useState, so later changes to that prop are not picked up. */ export const Branch = ({ defaultBranch = 0, onBranchChange, className, ...props }: BranchProps) => { const [currentBranch, setCurrentBranch] = useState(defaultBranch); const [branches, setBranches] = useState([]); /* Single path for changing the selection: update state, then notify the optional callback. */ const handleBranchChange = (newBranch: number) => { setCurrentBranch(newBranch); onBranchChange?.(newBranch); }; /* Steps back, wrapping from index 0 to the last branch. NOTE(review): with no registered branches this selects -1; confirm navigation cannot happen before branches are registered. */ const goToPrevious = () => { const newBranch = currentBranch > 0 ? currentBranch - 1 : branches.length - 1; handleBranchChange(newBranch); }; /* Steps forward, wrapping from the last branch back to index 0. */ const goToNext = () => { const newBranch = currentBranch < branches.length - 1 ? currentBranch + 1 : 0; handleBranchChange(newBranch); }; const contextValue: BranchContextType = { currentBranch, totalBranches: branches.length, goToPrevious, goToNext, branches, setBranches, }; return (
div]:pb-0', className)} {...props} /> ); }; export type BranchMessagesProps = HTMLAttributes; /* Renders every child as a branch and registers the children with the enclosing Branch via setBranches. The child at currentBranch gets the 'block' class and the rest get 'hidden'. NOTE(review): children are re-registered only when their count differs from branches.length, so replacing children without changing how many there are leaves the stored list stale. The map also reads branch.key, which assumes each child is a React element. On this collapsed line the registration effect follows a // comment and is therefore commented out. */ export const BranchMessages = ({ children, ...props }: BranchMessagesProps) => { const { currentBranch, setBranches, branches } = useBranch(); /* A lone child is wrapped so both cases can be mapped as an array. */ const childrenArray = Array.isArray(children) ? children : [children]; // Use useEffect to update branches when they change useEffect(() => { if (branches.length !== childrenArray.length) { setBranches(childrenArray); } }, [childrenArray, branches, setBranches]); return childrenArray.map((branch, index) => (
div]:pb-0', index === currentBranch ? 'block' : 'hidden' )} key={branch.key} {...props} > {branch}
)); }; /* from takes a UIMessage role value; how it is used is not visible in this copy. */ export type BranchSelectorProps = HTMLAttributes & { from: UIMessage['role']; }; /* Container for the navigation controls. Intended to return null when there is at most one branch. NOTE(review): on this collapsed line that check follows the // comment and is therefore commented out. */ export const BranchSelector = ({ className, from, ...props }: BranchSelectorProps) => { const { totalBranches } = useBranch(); // Don't render if there's only one branch if (totalBranches <= 1) { return null; } return (
); }; export type BranchPreviousProps = ComponentProps; /* Reads goToPrevious and totalBranches from the enclosing Branch. Its rendered markup is missing from this copy (the return is empty). */ export const BranchPrevious = ({ className, children, ...props }: BranchPreviousProps) => { const { goToPrevious, totalBranches } = useBranch(); return ( ); }; export type BranchNextProps = ComponentProps; /* Reads goToNext and totalBranches from the enclosing Branch. Its rendered markup is missing from this copy (the return is empty). */ export const BranchNext = ({ className, children, ...props }: BranchNextProps) => { const { goToNext, totalBranches } = useBranch(); return ( ); }; export type BranchPageProps = HTMLAttributes; /* Its visible content is the 1-based position of the current branch out of totalBranches, e.g. "2 of 3"; the wrapping markup is missing from this copy. */ export const BranchPage = ({ className, ...props }: BranchPageProps) => { const { currentBranch, totalBranches } = useBranch(); return ( {currentBranch + 1} of {totalBranches} ); };