|
| 1 | +from ._fx import ( |
| 2 | + create_feature_extractor, |
| 3 | + get_graph_node_names, |
| 4 | + register_notrace_function, |
| 5 | + register_notrace_module, |
| 6 | + is_notrace_module, |
| 7 | + is_notrace_function, |
| 8 | + get_notrace_modules, |
| 9 | + get_notrace_functions, |
| 10 | +) |
1 | 11 | from .activations import *
|
2 |
| -from .adaptive_avgmax_pool import \ |
3 |
| - adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d |
| 12 | +from .adaptive_avgmax_pool import ( |
| 13 | + adaptive_avgmax_pool2d, |
| 14 | + select_adaptive_pool2d, |
| 15 | + AdaptiveAvgMaxPool2d, |
| 16 | + SelectAdaptivePool2d, |
| 17 | +) |
4 | 18 | from .attention import Attention, AttentionRope, maybe_add_mask
|
5 | 19 | from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2
|
6 | 20 | from .attention_pool import AttentionPoolLatent
|
7 | 21 | from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
|
8 | 22 | from .blur_pool import BlurPool2d, create_aa
|
9 | 23 | from .classifier import create_classifier, ClassifierHead, NormMlpClassifierHead, ClNormMlpClassifierHead
|
10 | 24 | from .cond_conv2d import CondConv2d, get_condconv_initializer
|
11 |
| -from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \ |
12 |
| - set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn, \ |
13 |
| - set_reentrant_ckpt, use_reentrant_ckpt |
| 25 | +from .config import ( |
| 26 | + is_exportable, |
| 27 | + is_scriptable, |
| 28 | + is_no_jit, |
| 29 | + use_fused_attn, |
| 30 | + set_exportable, |
| 31 | + set_scriptable, |
| 32 | + set_no_jit, |
| 33 | + set_layer_config, |
| 34 | + set_fused_attn, |
| 35 | + set_reentrant_ckpt, |
| 36 | + use_reentrant_ckpt, |
| 37 | +) |
14 | 38 | from .conv2d_same import Conv2dSame, conv2d_same
|
15 | 39 | from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
|
16 | 40 | from .create_act import create_act_layer, get_act_layer, get_act_fn
|
|
20 | 44 | from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer
|
21 | 45 | from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
|
22 | 46 | from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
|
23 |
| -from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ |
24 |
| - EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a |
| 47 | +from .evo_norm import ( |
| 48 | + EvoNorm2dB0, |
| 49 | + EvoNorm2dB1, |
| 50 | + EvoNorm2dB2, |
| 51 | + EvoNorm2dS0, |
| 52 | + EvoNorm2dS0a, |
| 53 | + EvoNorm2dS1, |
| 54 | + EvoNorm2dS1a, |
| 55 | + EvoNorm2dS2, |
| 56 | + EvoNorm2dS2a, |
| 57 | +) |
25 | 58 | from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
|
26 | 59 | from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
|
27 | 60 | from .format import Format, get_channel_dim, get_spatial_dim, nchw_to, nhwc_to
|
|
37 | 70 | from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
|
38 | 71 | from .non_local_attn import NonLocalAttn, BatNonLocalAttn
|
39 | 72 | from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d
|
40 |
| -from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\ |
41 |
| - SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d |
| 73 | +from .norm_act import ( |
| 74 | + BatchNormAct2d, |
| 75 | + GroupNormAct, |
| 76 | + GroupNorm1Act, |
| 77 | + LayerNormAct, |
| 78 | + LayerNormAct2d, |
| 79 | + SyncBatchNormAct, |
| 80 | + convert_sync_batchnorm, |
| 81 | + FrozenBatchNormAct2d, |
| 82 | + freeze_batch_norm_2d, |
| 83 | + unfreeze_batch_norm_2d, |
| 84 | +) |
42 | 85 | from .padding import get_padding, get_same_padding, pad_same
|
43 | 86 | from .patch_dropout import PatchDropout
|
44 | 87 | from .patch_embed import PatchEmbed, PatchEmbedWithSize, PatchEmbedInterpolator, resample_patch_embed
|
45 | 88 | from .pool1d import global_pool_nlc
|
46 | 89 | from .pool2d_same import AvgPool2dSame, create_pool2d
|
47 | 90 | from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
|
48 |
| -from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords, \ |
49 |
| - resize_rel_pos_bias_table, resize_rel_pos_bias_table_simple, resize_rel_pos_bias_table_levit |
50 |
| -from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \ |
51 |
| - build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, apply_rot_embed_list, apply_keep_indices_nlc, \ |
52 |
| - FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat |
| 91 | +from .pos_embed_rel import ( |
| 92 | + RelPosMlp, |
| 93 | + RelPosBias, |
| 94 | + RelPosBiasTf, |
| 95 | + gen_relative_position_index, |
| 96 | + gen_relative_log_coords, |
| 97 | + resize_rel_pos_bias_table, |
| 98 | + resize_rel_pos_bias_table_simple, |
| 99 | + resize_rel_pos_bias_table_levit, |
| 100 | +) |
| 101 | +from .pos_embed_sincos import ( |
| 102 | + pixel_freq_bands, |
| 103 | + freq_bands, |
| 104 | + build_sincos2d_pos_embed, |
| 105 | + build_fourier_pos_embed, |
| 106 | + build_rotary_pos_embed, |
| 107 | + apply_rot_embed, |
| 108 | + apply_rot_embed_cat, |
| 109 | + apply_rot_embed_list, |
| 110 | + apply_keep_indices_nlc, |
| 111 | + FourierEmbed, |
| 112 | + RotaryEmbedding, |
| 113 | + RotaryEmbeddingCat, |
| 114 | + RotaryEmbeddingMixed, |
| 115 | + get_mixed_freqs, |
| 116 | +) |
53 | 117 | from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
|
54 | 118 | from .selective_kernel import SelectiveKernel
|
55 | 119 | from .separable_conv import SeparableConv2d, SeparableConvNormAct
|
|
60 | 124 | from .test_time_pool import TestTimePoolHead, apply_test_time_pool
|
61 | 125 | from .trace_utils import _assert, _float_to_int
|
62 | 126 | from .typing import LayerType, PadType
|
63 |
| -from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_, \ |
64 |
| - init_weight_jax, init_weight_vit |
| 127 | +from .weight_init import ( |
| 128 | + trunc_normal_, |
| 129 | + trunc_normal_tf_, |
| 130 | + variance_scaling_, |
| 131 | + lecun_normal_, |
| 132 | + init_weight_jax, |
| 133 | + init_weight_vit, |
| 134 | +) |
0 commit comments