MobileNetV5 #2527
Conversation
Summary of commits in this PR:

- Adds a TSV describing the full params from an Orbax checkpoint
- Switching to RMS Norm
- Fixes for correct export
- Tinkering with 'mobilenetv5' details, fixing some issues with MSFA
- A few tweaks and comments to example MNV5 impl
- Update RmsNorm2d modules to use own 2d eager kernel instead of torch rms_norm w/ permute (see the sketch below)
- Fix propagation of act_layer to RmsNormAct*, use ConvNormAct for stem instead of just Conv2d
- Fixes from weights conversion
- Plumbing norm_layer through to MultiQueryAttention2d
- impl forward_features for Transformers compatibility
- Adding forward_* APIs to MobileNetV5Encoder
- cleanup
- cleanup, model entrypoint rename
- 'large' config redundant with 300m
- Update input size for configs
- Fix stem conv layer name
- fix: always norm in MSFA
- Always call final MSFA norm layer
- Remove some FIXME, fix MSFA docstring. Remove use_layer_scale and rely on values == None; not used currently in any case.
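A minimal sketch of the RmsNorm2d idea referenced above: normalizing directly over the channel dim of an NCHW tensor with an eager kernel, instead of permuting to channels-last to use `torch.nn.functional.rms_norm`. Names, eps default, and dtype handling here are illustrative assumptions, not the exact timm implementation.

```python
import torch
import torch.nn as nn


class RmsNorm2d(nn.Module):
    """RMS norm over the channel dim of an NCHW tensor (illustrative sketch)."""

    def __init__(self, num_channels: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (N, C, H, W); normalize over C without a permute round-trip
        dtype = x.dtype
        v = x.float()
        v = v * torch.rsqrt(v.pow(2).mean(dim=1, keepdim=True) + self.eps)
        return v.to(dtype) * self.weight.view(1, -1, 1, 1)
```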
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Cool, do you have any plans to pretrain it?
@JulienMaille possibly yes, I may do something at a 'base' or slightly smaller size as a reference. More near term, I will bring over the image-encoder-only weights as timm models for fine-tuning. Currently the model def is just being used as the encoder for the full weights as loaded in transformers. There wasn't time to coordinate the other bits before release.
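A hedged usage sketch of what the timm-side encoder could look like once the image-encoder weights land, exercising the forward_features API mentioned in the commit list. The entrypoint name 'mobilenetv5_300m' and the input resolution are assumptions for illustration; the released names, pretrained tags, and default input size may differ.

```python
import timm
import torch

# Entrypoint name is an assumption; check timm.list_models('mobilenetv5*') for actual names.
model = timm.create_model('mobilenetv5_300m', pretrained=False)
model.eval()

x = torch.randn(1, 3, 256, 256)        # illustrative input resolution
with torch.no_grad():
    feats = model.forward_features(x)   # encoder feature output, pre-head
print(feats.shape)
```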
It should be pointed out that the '300m' size is the only official Google size, as used in Gemma 3n. The 'base' definition here is my own scale-down that I used for testing and validation of the architecture; I did some initial epochs of pretraining, etc., as sanity checks. I might still tweak/change that model def before any final weights are trained on my end.