Skip to content

Commit 67a9f5d

Browse files
author
Attila Cseh
committed
TabActionContext implemented
1 parent aa9c61d commit 67a9f5d

File tree

75 files changed

+577
-704
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

75 files changed

+577
-704
lines changed

invokeai/frontend/web/.storybook/ReduxInit.tsx

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,19 @@ import { memo, useEffect } from 'react';
44

55
import { useAppDispatch } from '../src/app/store/storeHooks';
66
import { modelChanged } from 'features/controlLayers/store/actions';
7-
import { useParamsDispatch } from 'features/controlLayers/store/paramsSlice';
87
/**
98
* Initializes some state for storybook. Must be in a different component
109
* so that it is run inside the redux context.
1110
*/
1211
export const ReduxInit = memo(({ children }: PropsWithChildren) => {
13-
const dispatch = useParamsDispatch();
12+
const dispatch = useAppDispatch();
1413
useGlobalModifiersInit();
1514
useEffect(() => {
16-
dispatch(modelChanged, {
17-
model: { key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' },
18-
});
15+
dispatch(
16+
modelChanged({
17+
model: { key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' },
18+
})
19+
);
1920
}, [dispatch]);
2021

2122
return children;
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import type { Middleware, UnknownAction } from '@reduxjs/toolkit';
2+
import { injectTabActionContext } from 'app/store/util';
3+
import { isCanvasInstanceAction } from 'features/controlLayers/store/canvasSlice';
4+
import { isTabParamsStateAction } from 'features/controlLayers/store/paramsSlice';
5+
import { selectActiveCanvasId } from 'features/controlLayers/store/selectors';
6+
import { selectActiveTab } from 'features/ui/store/uiSelectors';
7+
8+
export const actionContextMiddleware: Middleware = (store) => (next) => (action) => {
9+
const currentAction = action as UnknownAction;
10+
11+
if (isTabActionContextRequired(currentAction)) {
12+
const state = store.getState();
13+
const tab = selectActiveTab(state);
14+
const canvasId = tab === 'canvas' ? selectActiveCanvasId(state) : undefined;
15+
16+
injectTabActionContext(currentAction, tab, canvasId);
17+
}
18+
19+
return next(action);
20+
};
21+
22+
const isTabActionContextRequired = (action: UnknownAction) => {
23+
return isTabParamsStateAction(action) || isCanvasInstanceAction(action);
24+
};

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/appConfigReceived.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import type { AppStartListening } from 'app/store/store';
2-
import { paramsDispatch, selectActiveParams, setInfillMethod } from 'features/controlLayers/store/paramsSlice';
2+
import { selectActiveParams, setInfillMethod } from 'features/controlLayers/store/paramsSlice';
33
import { shouldUseNSFWCheckerChanged, shouldUseWatermarkerChanged } from 'features/system/store/systemSlice';
44
import { appInfoApi } from 'services/api/endpoints/appInfo';
55

@@ -15,7 +15,7 @@ export const addAppConfigReceivedListener = (startAppListening: AppStartListenin
1515
// If the selected infill method does not exist, prefer 'lama' if it's in the list, otherwise 'tile'.
1616
// TODO(psyche): lama _should_ always be in the list, but the API doesn't guarantee it...
1717
const infillMethod = infill_methods.includes('lama') ? 'lama' : 'tile';
18-
paramsDispatch(api, setInfillMethod, infillMethod);
18+
dispatch(setInfillMethod(infillMethod));
1919
}
2020

2121
if (!nsfw_methods.includes('nsfw_checker')) {

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,7 @@ import {
77
selectActiveCanvasStagingAreaSessionId,
88
} from 'features/controlLayers/store/canvasStagingAreaSlice';
99
import { loraIsEnabledChanged } from 'features/controlLayers/store/lorasSlice';
10-
import {
11-
paramsDispatch,
12-
selectActiveParams,
13-
syncedToOptimalDimension,
14-
vaeSelected,
15-
} from 'features/controlLayers/store/paramsSlice';
10+
import { selectActiveParams, syncedToOptimalDimension, vaeSelected } from 'features/controlLayers/store/paramsSlice';
1611
import { refImageModelChanged, selectReferenceImageEntities } from 'features/controlLayers/store/refImagesSlice';
1712
import {
1813
selectActiveCanvas,
@@ -70,7 +65,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
7065
// handle incompatible vae
7166
const { vae } = params;
7267
if (vae && vae.base !== newBase) {
73-
paramsDispatch(api, vaeSelected, null);
68+
dispatch(vaeSelected(null));
7469
modelsUpdatedDisabledOrCleared += 1;
7570
}
7671

@@ -163,13 +158,13 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
163158
}
164159
}
165160

166-
paramsDispatch(api, modelChanged, { model: newModel, previousModel: params.model });
161+
dispatch(modelChanged({ model: newModel, previousModel: params.model }));
167162

168163
const modelBase = selectBboxModelBase(state);
169164

170165
if (modelBase !== params.model?.base) {
171166
// Sync generate tab settings whenever the model base changes
172-
paramsDispatch(api, syncedToOptimalDimension);
167+
dispatch(syncedToOptimalDimension());
173168
const sessionId = selectActiveCanvasStagingAreaSessionId(state);
174169
const selectIsStaging = buildSelectIsStagingBySessionId(sessionId);
175170
const isStaging = selectIsStaging(state);

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@ import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
66
import {
77
clipEmbedModelSelected,
88
fluxVAESelected,
9-
paramsDispatch,
109
refinerModelChanged,
1110
selectActiveParams,
1211
t5EncoderModelSelected,
13-
toStore,
1412
vaeSelected,
1513
} from 'features/controlLayers/store/paramsSlice';
1614
import { refImageModelChanged, selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
@@ -116,7 +114,7 @@ const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
116114
// Only clear the model if we have one currently selected
117115
if (selectedMainModel !== null) {
118116
log.debug({ selectedMainModel }, 'No main models available, clearing');
119-
paramsDispatch(toStore(state, dispatch), modelChanged, { model: null });
117+
dispatch(modelChanged({ model: null }));
120118
}
121119
return;
122120
}
@@ -166,7 +164,7 @@ const handleRefinerModels: ModelHandler = (models, state, dispatch, log) => {
166164

167165
// Else, we need to clear the refiner model
168166
log.debug({ selectedRefinerModel }, 'Selected refiner model is not available, clearing');
169-
paramsDispatch(toStore(state, dispatch), refinerModelChanged, null);
167+
dispatch(refinerModelChanged(null));
170168
return;
171169
};
172170

@@ -190,7 +188,7 @@ const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
190188

191189
// Else, we need to clear the VAE model
192190
log.debug({ selectedVAEModel }, 'Selected VAE model is not available, clearing');
193-
paramsDispatch(toStore(state, dispatch), vaeSelected, null);
191+
dispatch(vaeSelected(null));
194192
return;
195193
};
196194

@@ -435,14 +433,14 @@ const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => {
435433
{ selectedT5EncoderModel, firstModel },
436434
'No selected T5 encoder model or selected T5 encoder model is not available, selecting first available model'
437435
);
438-
paramsDispatch(toStore(state, dispatch), t5EncoderModelSelected, zParameterT5EncoderModel.parse(firstModel));
436+
dispatch(t5EncoderModelSelected(zParameterT5EncoderModel.parse(firstModel)));
439437
return;
440438
}
441439

442440
// No available models, we should clear the selected model - but only if we have one selected
443441
if (selectedT5EncoderModel) {
444442
log.debug({ selectedT5EncoderModel }, 'Selected T5 encoder model is not available, clearing');
445-
paramsDispatch(toStore(state, dispatch), t5EncoderModelSelected, null);
443+
dispatch(t5EncoderModelSelected(null));
446444
return;
447445
}
448446
};
@@ -463,14 +461,14 @@ const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, log) => {
463461
{ selectedCLIPEmbedModel, firstModel },
464462
'No selected CLIP embed model or selected CLIP embed model is not available, selecting first available model'
465463
);
466-
paramsDispatch(toStore(state, dispatch), clipEmbedModelSelected, zParameterCLIPEmbedModel.parse(firstModel));
464+
dispatch(clipEmbedModelSelected(zParameterCLIPEmbedModel.parse(firstModel)));
467465
return;
468466
}
469467

470468
// No available models, we should clear the selected model - but only if we have one selected
471469
if (selectedCLIPEmbedModel) {
472470
log.debug({ selectedCLIPEmbedModel }, 'Selected CLIP embed model is not available, clearing');
473-
paramsDispatch(toStore(state, dispatch), clipEmbedModelSelected, null);
471+
dispatch(clipEmbedModelSelected(null));
474472
return;
475473
}
476474
};
@@ -491,14 +489,14 @@ const handleFLUXVAEModels: ModelHandler = (models, state, dispatch, log) => {
491489
{ selectedFLUXVAEModel, firstModel },
492490
'No selected FLUX VAE model or selected FLUX VAE model is not available, selecting first available model'
493491
);
494-
paramsDispatch(toStore(state, dispatch), fluxVAESelected, zParameterVAEModel.parse(firstModel));
492+
dispatch(fluxVAESelected(zParameterVAEModel.parse(firstModel)));
495493
return;
496494
}
497495

498496
// No available models, we should clear the selected model - but only if we have one selected
499497
if (selectedFLUXVAEModel) {
500498
log.debug({ selectedFLUXVAEModel }, 'Selected FLUX VAE model is not available, clearing');
501-
paramsDispatch(toStore(state, dispatch), fluxVAESelected, null);
499+
dispatch(fluxVAESelected(null));
502500
return;
503501
}
504502
};

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/setDefaultSettings.ts

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import {
77
} from 'features/controlLayers/store/canvasStagingAreaSlice';
88
import {
99
heightChanged,
10-
paramsDispatch,
1110
selectActiveParams,
1211
setCfgRescaleMultiplier,
1312
setCfgScale,
@@ -67,56 +66,56 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
6766
// we store this as "default" within default settings
6867
// to distinguish it from no default set
6968
if (vae === 'default') {
70-
paramsDispatch(api, vaeSelected, null);
69+
dispatch(vaeSelected(null));
7170
} else {
7271
const vaeModel = models.find((model) => model.key === vae);
7372
const result = zParameterVAEModel.safeParse(vaeModel);
7473
if (!result.success) {
7574
return;
7675
}
77-
paramsDispatch(api, vaeSelected, result.data);
76+
dispatch(vaeSelected(result.data));
7877
}
7978
}
8079

8180
if (vae_precision) {
8281
if (isParameterPrecision(vae_precision)) {
83-
paramsDispatch(api, vaePrecisionChanged, vae_precision);
82+
dispatch(vaePrecisionChanged(vae_precision));
8483
}
8584
}
8685

8786
if (guidance) {
8887
if (isParameterGuidance(guidance)) {
89-
paramsDispatch(api, setGuidance, guidance);
88+
dispatch(setGuidance(guidance));
9089
}
9190
}
9291

9392
if (cfg_scale) {
9493
if (isParameterCFGScale(cfg_scale)) {
95-
paramsDispatch(api, setCfgScale, cfg_scale);
94+
dispatch(setCfgScale(cfg_scale));
9695
}
9796
}
9897

9998
if (!isNil(cfg_rescale_multiplier)) {
10099
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
101-
paramsDispatch(api, setCfgRescaleMultiplier, cfg_rescale_multiplier);
100+
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
102101
}
103102
} else {
104103
// Set this to 0 if it doesn't have a default. This value is
105104
// easy to miss in the UI when users are resetting defaults
106105
// and leaving it non-zero could lead to detrimental
107106
// effects.
108-
paramsDispatch(api, setCfgRescaleMultiplier, 0);
107+
dispatch(setCfgRescaleMultiplier(0));
109108
}
110109

111110
if (steps) {
112111
if (isParameterSteps(steps)) {
113-
paramsDispatch(api, setSteps, steps);
112+
dispatch(setSteps(steps));
114113
}
115114
}
116115

117116
if (scheduler) {
118117
if (isParameterScheduler(scheduler)) {
119-
paramsDispatch(api, setScheduler, scheduler);
118+
dispatch(setScheduler(scheduler));
120119
}
121120
}
122121
const setSizeOptions = { updateAspectRatio: true, clamp: true };
@@ -128,10 +127,10 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
128127
const activeTab = selectActiveTab(getState());
129128
if (activeTab === 'generate') {
130129
if (isParameterWidth(width)) {
131-
paramsDispatch(api, widthChanged, { width, ...setSizeOptions });
130+
dispatch(widthChanged({ width, ...setSizeOptions }));
132131
}
133132
if (isParameterHeight(height)) {
134-
paramsDispatch(api, heightChanged, { height, ...setSizeOptions });
133+
dispatch(heightChanged({ height, ...setSizeOptions }));
135134
}
136135
}
137136

invokeai/frontend/web/src/app/store/store.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ import { authToastMiddleware } from 'services/api/authToastMiddleware';
4848
import type { JsonObject } from 'type-fest';
4949

5050
import { reduxRememberDriver } from './enhancers/reduxRemember/driver';
51+
import { actionContextMiddleware } from './middleware/actionContextMiddleware';
5152
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
5253
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
5354
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
@@ -179,6 +180,7 @@ export const createStore = (options?: { persist?: boolean; persistDebounce?: num
179180
.concat(api.middleware)
180181
.concat(dynamicMiddlewares)
181182
.concat(authToastMiddleware)
183+
.concat(actionContextMiddleware)
182184
// .concat(getDebugLoggerMiddleware({ withDiff: true, withNextState: true }))
183185
.prepend(listenerMiddleware.middleware),
184186
enhancers: (getDefaultEnhancers) => {
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import type { UnknownAction } from '@reduxjs/toolkit';
2+
import type { TabName } from 'features/ui/store/uiTypes';
3+
4+
const TAB_KEY = Symbol('tab');
5+
const CANVAS_ID_KEY = Symbol('canvasId');
6+
7+
type TabActionContext = {
8+
[TAB_KEY]: TabName;
9+
[CANVAS_ID_KEY]?: string;
10+
};
11+
12+
export const injectTabActionContext = (action: UnknownAction, tab: TabName, canvasId?: string) => {
13+
const context: TabActionContext = canvasId ? { [TAB_KEY]: tab, [CANVAS_ID_KEY]: canvasId } : { [TAB_KEY]: tab };
14+
Object.assign(action, { meta: context });
15+
};
16+
17+
export const extractTabActionContext = (action: UnknownAction & { meta?: Partial<TabActionContext> }) => {
18+
const tab = action.meta?.[TAB_KEY];
19+
const canvasId = action.meta?.[CANVAS_ID_KEY];
20+
21+
if (!tab || (tab === 'canvas' && !canvasId)) {
22+
return undefined;
23+
}
24+
25+
return {
26+
tab,
27+
canvasId,
28+
};
29+
};

invokeai/frontend/web/src/common/components/SessionMenuItems.tsx

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { MenuItem } from '@invoke-ai/ui-library';
22
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
33
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
44
import { allEntitiesDeleted, inpaintMaskAdded } from 'features/controlLayers/store/canvasSlice';
5-
import { paramsReset, useParamsDispatch } from 'features/controlLayers/store/paramsSlice';
5+
import { paramsReset } from 'features/controlLayers/store/paramsSlice';
66
import { selectActiveTab } from 'features/ui/store/uiSelectors';
77
import { memo, useCallback } from 'react';
88
import { useTranslation } from 'react-i18next';
@@ -11,7 +11,6 @@ import { PiArrowsCounterClockwiseBold } from 'react-icons/pi';
1111
export const SessionMenuItems = memo(() => {
1212
const { t } = useTranslation();
1313
const dispatch = useAppDispatch();
14-
const paramsDispatch = useParamsDispatch();
1514
const tab = useAppSelector(selectActiveTab);
1615
const canvasManager = useCanvasManagerSafe();
1716

@@ -21,8 +20,8 @@ export const SessionMenuItems = memo(() => {
2120
canvasManager?.stage.fitBboxToStage();
2221
}, [dispatch, canvasManager]);
2322
const resetGenerationSettings = useCallback(() => {
24-
paramsDispatch(paramsReset);
25-
}, [paramsDispatch]);
23+
dispatch(paramsReset());
24+
}, [dispatch]);
2625
return (
2726
<>
2827
{tab === 'canvas' && (

invokeai/frontend/web/src/features/controlLayers/components/ParamDenoisingStrength.tsx

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ import {
88
useToken,
99
} from '@invoke-ai/ui-library';
1010
import { createSelector } from '@reduxjs/toolkit';
11-
import { useAppSelector } from 'app/store/storeHooks';
11+
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
1212
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
1313
import WavyLine from 'common/components/WavyLine';
14-
import { selectImg2imgStrength, setImg2imgStrength, useParamsDispatch } from 'features/controlLayers/store/paramsSlice';
14+
import { selectImg2imgStrength, setImg2imgStrength } from 'features/controlLayers/store/paramsSlice';
1515
import { selectActiveRasterLayerEntities } from 'features/controlLayers/store/selectors';
1616
import { selectImg2imgStrengthConfig } from 'features/system/store/configSlice';
1717
import { memo, useCallback, useMemo } from 'react';
@@ -25,15 +25,15 @@ const selectHasRasterLayersWithContent = createSelector(
2525

2626
export const ParamDenoisingStrength = memo(() => {
2727
const img2imgStrength = useAppSelector(selectImg2imgStrength);
28-
const paramsDispatch = useParamsDispatch();
28+
const dispatch = useAppDispatch();
2929
const hasRasterLayersWithContent = useAppSelector(selectHasRasterLayersWithContent);
3030
const selectedModelConfig = useSelectedModelConfig();
3131

3232
const onChange = useCallback(
3333
(v: number) => {
34-
paramsDispatch(setImg2imgStrength, v);
34+
dispatch(setImg2imgStrength(v));
3535
},
36-
[paramsDispatch]
36+
[dispatch]
3737
);
3838

3939
const config = useAppSelector(selectImg2imgStrengthConfig);

0 commit comments

Comments
 (0)