Draft
Changes from all commits
20 commits
3f3f941
feat(mm): add UnknownModelConfig
psychedelicious Sep 18, 2025
b68871a
refactor(ui): move model categorisation-ish logic to central location…
psychedelicious Sep 18, 2025
bd893cf
refactor(ui): more cleanup of model categories
psychedelicious Sep 18, 2025
fa47e23
refactor(ui): remove unused excludeSubmodels
psychedelicious Sep 18, 2025
7f9022e
feat(nodes): add unknown as model base
psychedelicious Sep 18, 2025
a87fcfd
chore(ui): typegen
psychedelicious Sep 18, 2025
e348105
feat(ui): add unknown model base support in ui
psychedelicious Sep 18, 2025
3f82c38
feat(ui): allow changing model type in MM, fix up base and variant se…
psychedelicious Sep 18, 2025
c9dd115
feat(mm): omit model description instead of making it "base type file…
psychedelicious Sep 18, 2025
57787e3
feat(app): add setting to allow unknown models
psychedelicious Sep 18, 2025
82409d1
feat(ui): allow changing model format in MM
psychedelicious Sep 18, 2025
39bb60a
feat(app): add the installed model config to install complete events
psychedelicious Sep 18, 2025
d6b72a3
chore(ui): typegen
psychedelicious Sep 18, 2025
b18916d
feat(ui): toast warning when installed model is unidentified
psychedelicious Sep 18, 2025
0159634
docs: update config docstrings
psychedelicious Sep 18, 2025
15e5c9a
chore(ui): typegen
psychedelicious Sep 18, 2025
4070f26
tests(mm): fix test for MM, leave the UnknownModelConfig class in the…
psychedelicious Sep 18, 2025
73bed0d
tidy(ui): prefer types from zod schemas for model attrs
psychedelicious Sep 18, 2025
b9c7c6a
chore(ui): lint
psychedelicious Sep 18, 2025
7f3e5ce
fix(ui): wrong translation string
psychedelicious Sep 18, 2025
2 changes: 2 additions & 0 deletions invokeai/app/services/config/config_default.py
@@ -108,6 +108,7 @@ class InvokeAIAppConfig(BaseSettings):
remote_api_tokens: List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.
scan_models_on_startup: Scan the models directory on startup, registering orphaned models. This is typically only used in conjunction with `use_memory_db` for testing purposes.
unsafe_disable_picklescan: UNSAFE. Disable the picklescan security check during model installation. Recommended only for development and testing purposes. This will allow arbitrary code execution during model installation, so should never be used in production.
allow_unknown_models: Allow installation of models that we are unable to identify. If enabled, models will be marked as `unknown` in the database, and will not have any metadata associated with them. If disabled, unknown models will be rejected during installation.
"""

_root: Optional[Path] = PrivateAttr(default=None)
@@ -198,6 +199,7 @@ class InvokeAIAppConfig(BaseSettings):
remote_api_tokens: Optional[list[URLRegexTokenPair]] = Field(default=None, description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.")
scan_models_on_startup: bool = Field(default=False, description="Scan the models directory on startup, registering orphaned models. This is typically only used in conjunction with `use_memory_db` for testing purposes.")
unsafe_disable_picklescan: bool = Field(default=False, description="UNSAFE. Disable the picklescan security check during model installation. Recommended only for development and testing purposes. This will allow arbitrary code execution during model installation, so should never be used in production.")
allow_unknown_models: bool = Field(default=True, description="Allow installation of models that we are unable to identify. If enabled, models will be marked as `unknown` in the database, and will not have any metadata associated with them. If disabled, unknown models will be rejected during installation.")

# fmt: on

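
For reference, a minimal sketch of the new flag, constructed programmatically for illustration only (normally the value would come from `invokeai.yaml`; the field name and its `True` default are taken from this diff):

```python
from invokeai.app.services.config.config_default import InvokeAIAppConfig

# Illustrative only: the flag defaults to True per this PR. Setting it to False
# makes the installer reject models it cannot identify instead of registering
# them with the new "unknown" base/type/format.
config = InvokeAIAppConfig(allow_unknown_models=False)
assert config.allow_unknown_models is False
```
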
9 changes: 8 additions & 1 deletion invokeai/app/services/events/events_common.py
@@ -546,11 +546,18 @@ class ModelInstallCompleteEvent(ModelEventBase):
source: ModelSource = Field(description="Source of the model; local path, repo_id or url")
key: str = Field(description="Model config record key")
total_bytes: Optional[int] = Field(description="Size of the model (may be None for installation of a local path)")
config: AnyModelConfig = Field(description="The installed model's config")

@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallCompleteEvent":
assert job.config_out is not None
return cls(id=job.id, source=job.source, key=(job.config_out.key), total_bytes=job.total_bytes)
return cls(
id=job.id,
source=job.source,
key=(job.config_out.key),
total_bytes=job.total_bytes,
config=job.config_out,
)


@payload_schema.register
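
Because `ModelInstallCompleteEvent` now carries the full installed config, a subscriber can inspect it directly instead of re-fetching the record by key. A rough sketch — the handler name and wiring are illustrative, not part of this PR:

```python
from invokeai.app.services.events.events_common import ModelInstallCompleteEvent
from invokeai.backend.model_manager.taxonomy import ModelType

def on_model_install_complete(event: ModelInstallCompleteEvent) -> None:
    cfg = event.config  # full AnyModelConfig, the new field in this PR
    if cfg.type is ModelType.Unknown:
        # Mirrors the UI toast added in this PR: ask the user to edit the
        # model's type, base, and format in the Model Manager.
        print(f"Installed model '{cfg.name}' could not be identified.")
```
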
58 changes: 41 additions & 17 deletions invokeai/backend/model_manager/config.py
@@ -28,11 +28,12 @@
from enum import Enum
from inspect import isabstract
from pathlib import Path
from typing import ClassVar, Literal, Optional, TypeAlias, Union
from typing import ClassVar, Literal, Optional, Type, TypeAlias, Union

from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict

from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.hash_validator import validate_hash
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
@@ -55,6 +56,7 @@
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES

logger = logging.getLogger(__name__)
app_config = get_config()


class InvalidModelConfigException(Exception):
@@ -109,6 +111,18 @@ class MatchSpeed(int, Enum):
SLOW = 2


class LegacyProbeMixin:
"""Mixin for classes using the legacy probe for model classification."""

@classmethod
def matches(cls, *args, **kwargs):
raise NotImplementedError(f"Method 'matches' not implemented for {cls.__name__}")

@classmethod
def parse(cls, *args, **kwargs):
raise NotImplementedError(f"Method 'parse' not implemented for {cls.__name__}")


class ModelConfigBase(ABC, BaseModel):
"""
Abstract Base class for model configurations.
@@ -125,7 +139,7 @@ class ModelConfigBase(ABC, BaseModel):

@staticmethod
def json_schema_extra(schema: dict[str, Any]) -> None:
schema["required"].extend(["key", "type", "format"])
schema["required"].extend(["key", "base", "type", "format"])

model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)

@@ -152,14 +166,15 @@ def json_schema_extra(schema: dict[str, Any]) -> None:
)
usage_info: Optional[str] = Field(default=None, description="Usage information for this model")

USING_LEGACY_PROBE: ClassVar[set] = set()
USING_CLASSIFY_API: ClassVar[set] = set()
USING_LEGACY_PROBE: ClassVar[set[Type["ModelConfigBase"]]] = set()
USING_CLASSIFY_API: ClassVar[set[Type["ModelConfigBase"]]] = set()
_MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.MED

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if issubclass(cls, LegacyProbeMixin):
ModelConfigBase.USING_LEGACY_PROBE.add(cls)
# Cannot use `elif isinstance(cls, UnknownModelConfig)` because UnknownModelConfig is not defined yet
else:
ModelConfigBase.USING_CLASSIFY_API.add(cls)

@@ -170,7 +185,9 @@ def all_config_classes():
return concrete

@staticmethod
def classify(mod: str | Path | ModelOnDisk, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides):
def classify(
mod: str | Path | ModelOnDisk, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides
) -> "AnyModelConfig":
"""
Returns the best matching ModelConfig instance from a model's file/folder path.
Raises InvalidModelConfigException if no valid configuration is found.
@@ -192,6 +209,13 @@ def classify(mod: str | Path | ModelOnDisk, hash_algo: HASHING_ALGORITHMS = "bla
else:
return config_cls.from_model_on_disk(mod, **overrides)

if app_config.allow_unknown_models:
try:
return UnknownModelConfig.from_model_on_disk(mod, **overrides)
except Exception:
# Fall through to raising the exception below
pass

raise InvalidModelConfigException("Unable to determine model type")

@classmethod
@@ -240,32 +264,31 @@ def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
cls.cast_overrides(overrides)
fields.update(overrides)

type = fields.get("type") or cls.model_fields["type"].default
base = fields.get("base") or cls.model_fields["base"].default

fields["path"] = mod.path.as_posix()
fields["source"] = fields.get("source") or fields["path"]
fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
fields["name"] = name = fields.get("name") or mod.name
fields["name"] = fields.get("name") or mod.name
fields["hash"] = fields.get("hash") or mod.hash()
fields["key"] = fields.get("key") or uuid_string()
fields["description"] = fields.get("description") or f"{base.value} {type.value} model {name}"
fields["description"] = fields.get("description")
fields["repo_variant"] = fields.get("repo_variant") or mod.repo_variant()
fields["file_size"] = fields.get("file_size") or mod.size()

return cls(**fields)


class LegacyProbeMixin:
"""Mixin for classes using the legacy probe for model classification."""
class UnknownModelConfig(ModelConfigBase):
base: Literal[BaseModelType.Unknown] = BaseModelType.Unknown
type: Literal[ModelType.Unknown] = ModelType.Unknown
format: Literal[ModelFormat.Unknown] = ModelFormat.Unknown

@classmethod
def matches(cls, *args, **kwargs):
raise NotImplementedError(f"Method 'matches' not implemented for {cls.__name__}")
def matches(cls, mod: ModelOnDisk) -> bool:
return False

@classmethod
def parse(cls, *args, **kwargs):
raise NotImplementedError(f"Method 'parse' not implemented for {cls.__name__}")
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
return {}


class CheckpointConfigBase(ABC, BaseModel):
@@ -353,7 +376,7 @@ def matches(cls, mod: ModelOnDisk) -> bool:

metadata = mod.metadata()
return (
metadata.get("modelspec.sai_model_spec")
bool(metadata.get("modelspec.sai_model_spec"))
and metadata.get("ot_branch") == "omi_format"
and metadata["modelspec.architecture"].split("/")[1].lower() == "lora"
)
@@ -751,6 +774,7 @@ def get_model_discriminator_value(v: Any) -> str:
Annotated[LlavaOnevisionConfig, LlavaOnevisionConfig.get_tag()],
Annotated[ApiModelConfig, ApiModelConfig.get_tag()],
Annotated[VideoApiModelConfig, VideoApiModelConfig.get_tag()],
Annotated[UnknownModelConfig, UnknownModelConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),
]
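
Putting the fallback together: when no concrete config class matches and `allow_unknown_models` is enabled, `classify` now returns an `UnknownModelConfig` instead of raising. A sketch with a hypothetical model path:

```python
from invokeai.backend.model_manager.config import ModelConfigBase, UnknownModelConfig

# Hypothetical path; assumes allow_unknown_models is enabled (its default).
config = ModelConfigBase.classify("/models/mystery.safetensors")

if isinstance(config, UnknownModelConfig):
    # base/type/format fall back to the new Unknown members, and per this PR the
    # auto-generated description is omitted rather than synthesised.
    print(config.base, config.type, config.format)  # unknown unknown unknown
    print(config.description)  # None
```
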
3 changes: 3 additions & 0 deletions invokeai/backend/model_manager/taxonomy.py
@@ -33,6 +33,7 @@ class BaseModelType(str, Enum):
FluxKontext = "flux-kontext"
Veo3 = "veo3"
Runway = "runway"
Unknown = "unknown"


class ModelType(str, Enum):
@@ -55,6 +56,7 @@ class ModelType(str, Enum):
FluxRedux = "flux_redux"
LlavaOnevision = "llava_onevision"
Video = "video"
Unknown = "unknown"


class SubModelType(str, Enum):
@@ -107,6 +109,7 @@ class ModelFormat(str, Enum):
BnbQuantizednf4b = "bnb_quantized_nf4b"
GGUFQuantized = "gguf_quantized"
Api = "api"
Unknown = "unknown"


class SchedulerPredictionType(str, Enum):
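
Each of the three taxonomy enums gains an `Unknown` member with the same string value:

```python
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType

assert BaseModelType.Unknown.value == "unknown"
assert ModelType.Unknown.value == "unknown"
assert ModelFormat.Unknown.value == "unknown"
```
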
4 changes: 4 additions & 0 deletions invokeai/frontend/web/public/locales/en.json
@@ -914,6 +914,9 @@
"hfTokenReset": "HF Token Reset",
"urlUnauthorizedErrorMessage": "You may need to configure an API token to access this model.",
"urlUnauthorizedErrorMessage2": "Learn how here.",
"unidentifiedModelTitle": "Unable to identify model",
"unidentifiedModelMessage": "We were unable to identify the type, base and/or format of the installed model. Try editing the model and selecting the appropriate settings for the model.",
"unidentifiedModelMessage2": "If you don't see the correct settings, or the model doesn't work after changing them, ask for help on <DiscordLink /> or create an issue on <GitHubIssuesLink />.",
"imageEncoderModelId": "Image Encoder Model ID",
"installedModelsCount": "{{installed}} of {{total}} models installed.",
"includesNModels": "Includes {{n}} models and their dependencies.",
@@ -942,6 +945,7 @@
"modelConverted": "Model Converted",
"modelDeleted": "Model Deleted",
"modelDeleteFailed": "Failed to delete model",
"modelFormat": "Model Format",
"modelImageDeleted": "Model Image Deleted",
"modelImageDeleteFailed": "Model Image Delete Failed",
"modelImageUpdated": "Model Image Updated",
@@ -11,8 +11,8 @@ import {
selectCanvasSlice,
} from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { SUPPORTS_REF_IMAGES_BASE_MODELS } from 'features/modelManagerV2/models';
import { modelSelected } from 'features/parameters/store/actions';
import { SUPPORTS_REF_IMAGES_BASE_MODELS } from 'features/parameters/types/constants';
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
@@ -37,7 +37,7 @@ import type { Logger } from 'roarr';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import {
isCLIPEmbedModelConfig,
isCLIPEmbedModelConfigOrSubmodel,
isControlLayerModelConfig,
isControlNetModelConfig,
isFluxReduxModelConfig,
@@ -48,7 +48,7 @@ import {
isNonRefinerMainModelConfig,
isRefinerMainModelModelConfig,
isSpandrelImageToImageModelConfig,
isT5EncoderModelConfig,
isT5EncoderModelConfigOrSubmodel,
isVideoModelConfig,
} from 'services/api/types';
import type { JsonObject } from 'type-fest';
@@ -418,7 +418,7 @@ const handleTileControlNetModel: ModelHandler = (models, state, dispatch, log) =

const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => {
const selectedT5EncoderModel = state.params.t5EncoderModel;
const t5EncoderModels = models.filter((m) => isT5EncoderModelConfig(m));
const t5EncoderModels = models.filter((m) => isT5EncoderModelConfigOrSubmodel(m));

// If the currently selected model is available, we don't need to do anything
if (selectedT5EncoderModel && t5EncoderModels.some((m) => m.key === selectedT5EncoderModel.key)) {
@@ -446,7 +446,7 @@ const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => {

const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, log) => {
const selectedCLIPEmbedModel = state.params.clipEmbedModel;
const CLIPEmbedModels = models.filter((m) => isCLIPEmbedModelConfig(m));
const CLIPEmbedModels = models.filter((m) => isCLIPEmbedModelConfigOrSubmodel(m));

// If the currently selected model is available, we don't need to do anything
if (selectedCLIPEmbedModel && CLIPEmbedModels.some((m) => m.key === selectedCLIPEmbedModel.key)) {
@@ -17,7 +17,7 @@ import Konva from 'konva';
import { atom, computed } from 'nanostores';
import type { Logger } from 'roarr';
import { serializeError } from 'serialize-error';
import { buildSelectModelConfig } from 'services/api/hooks/modelsByType';
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
import { isControlLayerModelConfig } from 'services/api/types';
import stableHash from 'stable-hash';
import type { Equals } from 'tsafe';
@@ -202,11 +202,19 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
createInitialFilterConfig = (): FilterConfig => {
if (this.parent.type === 'control_layer_adapter' && this.parent.state.controlAdapter.model) {
// If the parent is a control layer adapter, we should check if the model has a default filter and set it if so
const selectModelConfig = buildSelectModelConfig(
this.parent.state.controlAdapter.model.key,
isControlLayerModelConfig
);
const modelConfig = this.manager.stateApi.runSelector(selectModelConfig);
const key = this.parent.state.controlAdapter.model.key;
const modelConfig = this.manager.stateApi.runSelector((state) => {
const { data } = selectModelConfigsQuery(state);
if (!data) {
return null;
}
return (
modelConfigsAdapterSelectors
.selectAll(data)
.filter(isControlLayerModelConfig)
.find((m) => m.key === key) ?? null
);
});
// This always returns a filter
const filter = getFilterForModel(modelConfig) ?? IMAGE_FILTERS.canny_edge_detection;
return filter.buildDefaults();
@@ -13,8 +13,8 @@ import { selectBboxOverlay } from 'features/controlLayers/store/canvasSettingsSl
import { selectModel } from 'features/controlLayers/store/paramsSlice';
import { selectBbox } from 'features/controlLayers/store/selectors';
import type { Coordinate, Rect, Tool } from 'features/controlLayers/store/types';
import { API_BASE_MODELS } from 'features/modelManagerV2/models';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { API_BASE_MODELS } from 'features/parameters/types/constants';
import Konva from 'konva';
import { atom } from 'nanostores';
import type { Logger } from 'roarr';
@@ -35,8 +35,8 @@ import {
getScaledBoundingBoxDimensions,
} from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
import { simplifyFlatNumbersArray } from 'features/controlLayers/util/simplify';
import { API_BASE_MODELS } from 'features/modelManagerV2/models';
import { isMainModelBase, zModelIdentifierField } from 'features/nodes/types/common';
import { API_BASE_MODELS } from 'features/parameters/types/constants';
import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import type { IRect } from 'konva/lib/types';
import type { UndoableOptions } from 'redux-undo';
@@ -25,14 +25,14 @@ import {
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
import {
API_BASE_MODELS,
CLIP_SKIP_MAP,
SUPPORTS_ASPECT_RATIO_BASE_MODELS,
SUPPORTS_NEGATIVE_PROMPT_BASE_MODELS,
SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS,
SUPPORTS_PIXEL_DIMENSIONS_BASE_MODELS,
SUPPORTS_REF_IMAGES_BASE_MODELS,
SUPPORTS_SEED_BASE_MODELS,
} from 'features/parameters/types/constants';
} from 'features/modelManagerV2/models';
import { CLIP_SKIP_MAP } from 'features/parameters/types/constants';
import type {
ParameterCanvasCoherenceMode,
ParameterCFGRescaleMultiplier,
@@ -6,8 +6,8 @@ import { InformationalPopover from 'common/components/InformationalPopover/Inf
import type { GroupStatusMap } from 'common/components/Picker/Picker';
import { loraAdded, selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { API_BASE_MODELS } from 'features/modelManagerV2/models';
import { ModelPicker } from 'features/parameters/components/ModelPicker';
import { API_BASE_MODELS } from 'features/parameters/types/constants';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useLoRAModels } from 'services/api/hooks/modelsByType';
4 changes: 2 additions & 2 deletions invokeai/frontend/web/src/features/metadata/parsing.tsx
@@ -49,7 +49,7 @@ import {
zVideoDuration,
zVideoResolution,
} from 'features/controlLayers/store/types';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import type { ModelIdentifierField, ModelType } from 'features/nodes/types/common';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { zModelIdentifier } from 'features/nodes/types/v2/common';
import { modelSelected } from 'features/parameters/store/actions';
@@ -108,7 +108,7 @@ import { useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { imagesApi } from 'services/api/endpoints/images';
import { modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig, ModelType } from 'services/api/types';
import type { AnyModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import z from 'zod';
