import {ReducerAction, ReducerState, useReducer} from 'react'; export type UndoableNode<S> = { current: S; next: UndoableNode<S> | undefined; previous: UndoableNode<S> | undefined; }; type UndoAction = { type: 'undo'; }; type RedoAction = { type: 'redo'; }; export type UndoableReducerAction<A> = UndoAction | RedoAction | A; export type UndoableReducer<R extends React.Reducer<any, any>> = React.Reducer< UndoableNode<ReducerState<R>>, UndoableReducerAction<ReducerAction<R>> >; function isUndoOrRedoAction( action: UndoableReducerAction<any> ): action is UndoAction | RedoAction { if (action.type) { return action.type === 'undo' || action.type === 'redo'; } return false; } function undoableReducer<S>( state: UndoableNode<S>, action: UndoAction | RedoAction ): UndoableNode<S> { if (action.type === 'undo') { return state.previous === undefined ? state : state.previous; } if (action.type === 'redo') { return state.next === undefined ? state : state.next; } throw new Error('Unreachable case'); } export function makeUndoableReducer<R extends React.Reducer<any, any>>( reducer: R ): UndoableReducer<R> { return ( state: UndoableNode<ReducerState<R>>, action: UndoableReducerAction<ReducerAction<R>> ) => { if (isUndoOrRedoAction(action)) { return undoableReducer(state, action); } const newState: UndoableNode<ReducerState<R>> = { next: undefined, previous: state, current: reducer(state.current, action), }; state.next = newState; return newState; }; } export function useUndoableReducer< R extends React.Reducer<ReducerState<R>, ReducerAction<R>> >( reducer: R, initialState: ReducerState<R> ): [ ReducerState<R>, React.Dispatch<UndoableReducerAction<ReducerAction<R>>>, { nextState: ReducerState<R> | undefined; previousState: ReducerState<R> | undefined; } ] { const [state, dispatch] = useReducer(makeUndoableReducer(reducer), { current: initialState, previous: undefined, next: undefined, }); return [ state.current, dispatch, {previousState: state.previous?.current, nextState: state.next?.current}, ]; }