当前位置: 首页
AI
FF Only:Attention真的需要吗?

FF Only:Attention真的需要吗?

热心网友 时间:2025-07-18
转载
本文复现去attention化论文,以Feed-Forward替代Transformer的attention层,基于ViT、DeiT模型在ImageNet表现良好。代码构建相关模型,展示结构与参数,在Cifar10简短训练,表明视觉Transformer中除注意力外的部分可能很重要。

ff only:attention真的需要吗? - 游乐网

前言

hi guy!我们又再一次见面了,这次来复现一篇非常有趣的论文,去attention化

免费影视、动漫、音乐、游戏、小说资源长期稳定更新! 👉 点此立即查看 👈

FF Only:Attention真的需要吗? - 游乐网

作者实验竟然惊讶发现,仅Feed-Forward就能在ImageNet表现良好的性能,这可以帮助研究人员理解为什么当前模型为什么这么有效

具体来说,该模型用FF(fead-forward)替换Transformer的attention层,基于ViT、DeiT的模型获得了良好的top1准确性,如下所示

FF Only:Attention真的需要吗? - 游乐网

完整代码

导入所需要的包

In [ ]
import paddleimport paddle.nn as nnimport paddle.nn.functional as Ffrom functools import partialtrunc_normal_ = nn.initializer.TruncatedNormal(std=.02)zeros_ = nn.initializer.Constant(value=0.)ones_ = nn.initializer.Constant(value=1.)kaiming_normal_ = nn.initializer.KaimingNormal()
登录后复制
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations  def convert_to_list(value, n, name, dtype=np.int):
登录后复制

基础函数定义

In [ ]
def swapdim(x, dim1, dim2):    a = list(range(len(x.shape)))    a[dim1], a[dim2] = a[dim2], a[dim1]    return x.transpose(a)def drop_path(x, drop_prob = 0., training = False):    if drop_prob == 0. or not training:        return x    keep_prob = 1 - drop_prob    shape = (x.shape[0],) + (1,) * (x.ndim - 1)      random_tensor = paddle.to_tensor(keep_prob) + paddle.rand(shape)    random_tensor = paddle.floor(random_tensor)     output = x.divide(keep_prob) * random_tensor    return outputclass DropPath(nn.Layer):    def __init__(self, drop_prob=None):        super(DropPath, self).__init__()        self.drop_prob = drop_prob    def forward(self, x):        return drop_path(x, self.drop_prob, self.training)        class Identity(nn.Layer):                          def __init__(self, *args, **kwargs):        super(Identity, self).__init__()     def forward(self, input):        return inputclass Mlp(nn.Layer):    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):        super().__init__()        out_features = out_features or in_features        hidden_features = hidden_features or in_features        self.fc1 = nn.Linear(in_features, hidden_features)        self.act = act_layer()        self.fc2 = nn.Linear(hidden_features, out_features)        self.drop = nn.Dropout(drop)    def forward(self, x):        x = self.fc1(x)        x = self.act(x)        x = self.drop(x)        x = self.fc2(x)        x = self.drop(x)        return x
登录后复制

模型组网

In [ ]
class LinearBlock(nn.Layer):    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, num_tokens=197):        super().__init__()        # First stage        self.mlp1 = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)        self.norm1 = norm_layer(dim)        # Second stage        self.mlp2 = Mlp(in_features=num_tokens, hidden_features=int(            num_tokens * mlp_ratio), act_layer=act_layer, drop=drop)        self.norm2 = norm_layer(num_tokens)        # Dropout (or a variant)        self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()    def forward(self, x):        x = x + self.drop_path(self.mlp1(self.norm1(x)))        x = swapdim(x, -2, -1)        x = x + self.drop_path(self.mlp2(self.norm2(x)))        x = swapdim(x, -2, -1)        return xclass PatchEmbed(nn.Layer):    """ Wraps a convolution """    def __init__(self, patch_size=16, in_chans=3, embed_dim=768):        super().__init__()        self.proj = nn.Conv2D(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)    def forward(self, x):        x = self.proj(x)        return xclass LearnedPositionalEncoding(nn.Layer):    """ Learned positional encoding with dynamic interpolation at runtime """    def __init__(self, height, width, embed_dim):        super().__init__()        self.height = height        self.width = width        self.pos_embed = self.create_parameter(shape=[1, embed_dim, height, width], default_initializer=trunc_normal_)        self.add_parameter("pos_embed", self.pos_embed)        self.cls_pos_embed = self.create_parameter(shape=[1, 1, embed_dim], default_initializer=trunc_normal_)        self.add_parameter("cls_pos_embed", self.cls_pos_embed)    def forward(self, x):        B, C, H, W = x.shape        if H == self.height and W == self.width:            pos_embed = self.pos_embed        else:            pos_embed = F.interpolate(self.pos_embed, size=[H, W], mode='bilinear', align_corners=False) # ??        return self.cls_pos_embed, pos_embedclass LinearVisionTransformer(nn.Layer):    """    Basically the same as the standard Vision Transformer, but with support for resizable     or sinusoidal positional embeddings.     """    def __init__(self, *, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,                 drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm,                 positional_encoding='learned', learned_positional_encoding_size=(14, 14), block_cls=LinearBlock):        super().__init__()        # Config        self.num_classes = num_classes        self.patch_size = patch_size        self.num_features = self.embed_dim = embed_dim        # Patch embedding        self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)        # Class token        self.cls_token = self.create_parameter(shape=[1, 1, embed_dim], default_initializer=trunc_normal_)        self.add_parameter("cls_token", self.cls_token)        # Positional encoding        if positional_encoding == 'learned':            height, width = self.learned_positional_encoding_size = learned_positional_encoding_size            self.pos_encoding = LearnedPositionalEncoding(height, width, embed_dim)        else:            raise NotImplementedError('Unsupposed positional encoding')        self.pos_drop = nn.Dropout(p=drop_rate)        # Stochastic depth        dpr = [x for x in paddle.linspace(0, drop_path_rate, depth)]        self.blocks = nn.LayerList([            block_cls(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,                      drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, num_tokens=1 + (224 // patch_size)**2)            for i in range(depth)])        self.norm = norm_layer(embed_dim)        # Classifier head        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else Identity()        self.apply(self._init_weights)    def _init_weights(self, m):        if isinstance(m, nn.Linear):            trunc_normal_(m.weight)            if isinstance(m, nn.Linear) and m.bias is not None:                zeros_(m.bias)        elif isinstance(m, nn.LayerNorm):            zeros_(m.bias)            ones_(m.weight)    def forward_features(self, x):        # Patch embedding        B, C, H, W = x.shape  # B x C x H x W        x = self.patch_embed(x)  # B x E x H//p x W//p        # Positional encoding        # NOTE: cls_pos_embed for compatibility with pretrained models        cls_pos_embed, pos_embed = self.pos_encoding(x)        # Flatten image, append class token, add positional encoding        cls_tokens = self.cls_token.expand([B, -1, -1])        x = x.flatten(2)  # flatten        x = swapdim(x , 1, 2)        x = paddle.concat((cls_tokens, x), axis=1)  # class token        pos_embed = pos_embed.flatten(2)            # flatten        pos_embed = swapdim(pos_embed, 1, 2)        pos_embed = paddle.concat([cls_pos_embed, pos_embed], axis=1)  # class pos emb        x = x + pos_embed        x = self.pos_drop(x)        # Transformer        for blk in self.blocks:            x = blk(x)        # Final layernorm        x = self.norm(x)        return x[:, 0]    def forward(self, x):        x = self.forward_features(x)        x = self.head(x)        return x
登录后复制

模型生成

In [ ]
def linear_tiny(**kwargs):    model = LinearVisionTransformer(        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,        norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)    return modeldef linear_base(**kwargs):    model = LinearVisionTransformer(        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,        norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)    return modeldef linear_large(**kwargs):    model = LinearVisionTransformer(        patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,        norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)    return model
登录后复制

模型结构可视化

In [ ]
paddle.Model(linear_base()).summary((1,3,224,224))
登录后复制
--------------------------------------------------------------------------------------------------       Layer (type)             Input Shape               Output Shape               Param #    ==================================================================================================         Conv2D-4            [[1, 3, 224, 224]]         [1, 768, 14, 14]             590,592           PatchEmbed-4          [[1, 3, 224, 224]]         [1, 768, 14, 14]                0       LearnedPositionalEncoding-4  [[1, 768, 14, 14]]  [[1, 1, 768], [1, 768, 14, 14]]     151,296            Dropout-76            [[1, 197, 768]]             [1, 197, 768]                 0              LayerNorm-76           [[1, 197, 768]]             [1, 197, 768]               1,536             Linear-148            [[1, 197, 768]]            [1, 197, 3072]             2,362,368             GELU-73             [[1, 197, 3072]]           [1, 197, 3072]                 0               Dropout-77            [[1, 197, 768]]             [1, 197, 768]                 0               Linear-149            [[1, 197, 3072]]            [1, 197, 768]             2,360,064             Mlp-73              [[1, 197, 768]]             [1, 197, 768]                 0               Identity-37           [[1, 768, 197]]             [1, 768, 197]                 0              LayerNorm-77           [[1, 768, 197]]             [1, 768, 197]                394              Linear-150            [[1, 768, 197]]             [1, 768, 788]              156,024              GELU-74             [[1, 768, 788]]             [1, 768, 788]                 0               Dropout-78            [[1, 768, 197]]             [1, 768, 197]                 0               Linear-151            [[1, 768, 788]]             [1, 768, 197]              155,433              Mlp-74              [[1, 768, 197]]             [1, 768, 197]                 0             LinearBlock-37          [[1, 197, 768]]             [1, 197, 768]                 0              LayerNorm-78           [[1, 197, 768]]             [1, 197, 768]               1,536             Linear-152            [[1, 197, 768]]            [1, 197, 3072]             2,362,368             GELU-75             [[1, 197, 3072]]           [1, 197, 3072]                 0               Dropout-79            [[1, 197, 768]]             [1, 197, 768]                 0               Linear-153            [[1, 197, 3072]]            [1, 197, 768]             2,360,064             Mlp-75              [[1, 197, 768]]             [1, 197, 768]                 0               Identity-38           [[1, 768, 197]]             [1, 768, 197]                 0              LayerNorm-79           [[1, 768, 197]]             [1, 768, 197]                394              Linear-154            [[1, 768, 197]]             [1, 768, 788]              156,024              GELU-76             [[1, 768, 788]]             [1, 768, 788]                 0               Dropout-80            [[1, 768, 197]]             [1, 768, 197]                 0               Linear-155            [[1, 768, 788]]             [1, 768, 197]              155,433              Mlp-76              [[1, 768, 197]]             [1, 768, 197]                 0             LinearBlock-38          [[1, 197, 768]]             [1, 197, 768]                 0              LayerNorm-80           [[1, 197, 768]]             [1, 197, 768]               1,536             Linear-156            [[1, 197, 768]]            [1, 197, 3072]             2,362,368             GELU-77             [[1, 197, 3072]]           [1, 197, 3072]                 0               Dropout-81            [[1, 197, 768]]             [1, 197, 768]                 0               Linear-157            [[1, 197, 3072]]            [1, 197, 768]             2,360,064             Mlp-77              [[1, 197, 768]]             [1, 197, 768]                 0               Identity-39           [[1, 768, 197]]             [1, 768, 197]                 0              LayerNorm-81           [[1, 768, 197]]             [1, 768, 197]                394              Linear-158            [[1, 768, 197]]             [1, 768, 788]              156,024              GELU-78             [[1, 768, 788]]             [1, 768, 788]                 0               Dropout-82            [[1, 768, 197]]             [1, 768, 197]                 0               Linear-159            [[1, 768, 788]]             [1, 768, 197]              155,433              Mlp-78              [[1, 768, 197]]             [1, 768, 197]                 0             LinearBlock-39          [[1, 197, 768]]             [1, 197, 768]                 0              LayerNorm-82           [[1, 197, 768]]             [1, 197, 768]               1,536             Linear-160            [[1, 197, 768]]            [1, 197, 3072]             2,362,368             GELU-79             [[1, 197, 3072]]           [1, 197, 3072]                 0               Dropout-83            [[1, 197, 768]]             [1, 197, 768]                 0               Linear-161            [[1, 197, 3072]]            [1, 197, 768]             2,360,064             Mlp-79              [[1, 197, 768]]             [1, 197, 768]                 0               Identity-40           [[1, 768, 197]]             [1, 768, 197]                 0              LayerNorm-83           [[1, 768, 197]]             [1, 768, 197]                394              Linear-162            [[1, 768, 197]]             [1, 768, 788]              156,024              GELU-80             [[1, 768, 788]]             [1, 768, 788]                 0               Dropout-84            [[1, 768, 197]]             [1, 768, 197]                 0               Linear-163            [[1, 768, 788]]             [1, 768, 197]              155,433              Mlp-80              [[1, 768, 197]]             [1, 768, 197]                 0             LinearBlock-40          [[1, 197, 768]]             [1, 197, 768]                 0              LayerNorm-84           [[1, 197, 768]]             [1, 197, 768]               1,536             Linear-164            [[1, 197, 768]]            [1, 197, 3072]             2,362,368             GELU-81             [[1, 197, 3072]]           [1, 197, 3072]                 0               Dropout-85            [[1, 197, 768]]             [1, 197, 768]                 0               Linear-165            [[1, 197, 3072]]            [1, 197, 768]             2,360,064             Mlp-81              [[1, 197, 768]]             [1, 197, 768]                 0               Identity-41           [[1, 768, 197]]             [1, 768, 197]                 0              LayerNorm-85           [[1, 768, 197]]             [1, 768, 197]                394              Linear-166            [[1, 768, 197]]             [1, 768, 788]              156,024              GELU-82             [[1, 768, 788]]             [1, 768, 788]                 0               Dropout-86            [[1, 768, 197]]             [1, 768, 197]                 0               Linear-167            [[1, 768, 788]]             [1, 768, 197]              155,433              Mlp-82              [[1, 768, 197]]             [1, 768, 197]                 0             LinearBlock-41          [[1, 197, 768]]             [1, 197, 768]                 0              LayerNorm-86           [[1, 197, 768]]             [1, 197, 768]               1,536             Linear-168            [[1, 197, 768]]            [1, 197, 3072]             2,362,368             GELU-83             [[1, 197, 3072]]           [1, 197, 3072]                 0               Dropout-87            [[1, 197, 768]]             [1, 197, 768]                 0               Linear-169            [[1, 197, 3072]]            [1, 197, 768]             2,360,064             Mlp-83              [[1, 197, 768]]             [1, 197, 768]                 0               Identity-42           [[1, 768, 197]]             [1, 768, 197]                 0              LayerNorm-87           [[1, 768, 197]]             [1, 768, 197]                394              Linear-170            [[1, 768, 197]]             [1, 768, 788]              156,024              GELU-84             [[1, 768, 788]]             [1, 768, 788]                 0               Dropout-88            [[1, 768, 197]]             [1, 768, 197]                 0               Linear-171            [[1, 768, 788]]             [1, 768, 197]              155,433              Mlp-84              [[1, 768, 197]]             [1, 768, 197]                 0             LinearBlock-42          [[1, 197, 768]]             [1, 197, 768]                 0              LayerNorm-88           [[1, 197, 768]]             [1, 197, 768]               1,536             Linear-172            [[1, 197, 768]]            [1, 197, 3072]             2,362,368             GELU-85             [[1, 197, 3072]]           [1, 197, 3072]                 0               Dropout-89            [[1, 197, 768]]             [1, 197, 768]                 0               Linear-173            [[1, 197, 3072]]            [1, 197, 768]             2,360,064             Mlp-85              [[1, 197, 768]]             [1, 197, 768]                 0               Identity-43           [[1, 768, 197]]             [1, 768, 197]                 0              LayerNorm-89           [[1, 768, 197]]             [1, 768, 197]                394              Linear-174            [[1, 768, 197]]             [1, 768, 788]              156,024              GELU-86             [[1, 768, 788]]             [1, 768, 788]                 0               Dropout-90            [[1, 768, 197]]             [1, 768, 197]                 0               Linear-175            [[1, 768, 788]]             [1, 768, 197]              155,433              Mlp-86              [[1, 768, 197]]             [1, 768, 197]                 0             LinearBlock-43          [[1, 197, 768]]             [1, 197, 768]                 0              LayerNorm-90           [[1, 197, 768]]             [1, 197, 768]               1,536             Linear-176            [[1, 197, 768]]            [1, 197, 3072]             2,362,368             GELU-87             [[1, 197, 3072]]           [1, 197, 3072]                 0               Dropout-91            [[1, 197, 768]]             [1, 197, 768]                 0               Linear-177            [[1, 197, 3072]]            [1, 197, 768]             2,360,064             Mlp-87              [[1, 197, 768]]             [1, 197, 768]                 0               Identity-44           [[1, 768, 197]]             [1, 768, 197]                 0              LayerNorm-91           [[1, 768, 197]]             [1, 768, 197]                394              Linear-178            [[1, 768, 197]]             [1, 768, 788]              156,024              GELU-88             [[1, 768, 788]]             [1, 768, 788]                 0               Dropout-92            [[1, 768, 197]]             [1, 768, 197]                 0               Linear-179            [[1, 768, 788]]             [1, 768, 197]              155,433              Mlp-88              [[1, 768, 197]]             [1, 768, 197]                 0             LinearBlock-44          [[1, 197, 768]]             [1, 197, 768]                 0              LayerNorm-92           [[1, 197, 768]]             [1, 197, 768]               1,536             Linear-180            [[1, 197, 768]]            [1, 197, 3072]             2,362,368             GELU-89             [[1, 197, 3072]]           [1, 197, 3072]                 0               Dropout-93            [[1, 197, 768]]             [1, 197, 768]                 0               Linear-181            [[1, 197, 3072]]            [1, 197, 768]             2,360,064             Mlp-89              [[1, 197, 768]]             [1, 197, 768]                 0               Identity-45           [[1, 768, 197]]             [1, 768, 197]                 0              LayerNorm-93           [[1, 768, 197]]             [1, 768, 197]                394              Linear-182            [[1, 768, 197]]             [1, 768, 788]              156,024              GELU-90             [[1, 768, 788]]             [1, 768, 788]                 0               Dropout-94            [[1, 768, 197]]             [1, 768, 197]                 0               Linear-183            [[1, 768, 788]]             [1, 768, 197]              155,433              Mlp-90              [[1, 768, 197]]             [1, 768, 197]                 0             LinearBlock-45          [[1, 197, 768]]             [1, 197, 768]                 0              LayerNorm-94           [[1, 197, 768]]             [1, 197, 768]               1,536             Linear-184            [[1, 197, 768]]            [1, 197, 3072]             2,362,368             GELU-91             [[1, 197, 3072]]           [1, 197, 3072]                 0               Dropout-95            [[1, 197, 768]]             [1, 197, 768]                 0               Linear-185            [[1, 197, 3072]]            [1, 197, 768]             2,360,064             Mlp-91              [[1, 197, 768]]             [1, 197, 768]                 0               Identity-46           [[1, 768, 197]]             [1, 768, 197]                 0              LayerNorm-95           [[1, 768, 197]]             [1, 768, 197]                394              Linear-186            [[1, 768, 197]]             [1, 768, 788]              156,024              GELU-92             [[1, 768, 788]]             [1, 768, 788]                 0               Dropout-96            [[1, 768, 197]]             [1, 768, 197]                 0               Linear-187            [[1, 768, 788]]             [1, 768, 197]              155,433              Mlp-92              [[1, 768, 197]]             [1, 768, 197]                 0             LinearBlock-46          [[1, 197, 768]]             [1, 197, 768]                 0              LayerNorm-96           [[1, 197, 768]]             [1, 197, 768]               1,536             Linear-188            [[1, 197, 768]]            [1, 197, 3072]             2,362,368             GELU-93             [[1, 197, 3072]]           [1, 197, 3072]                 0               Dropout-97            [[1, 197, 768]]             [1, 197, 768]                 0               Linear-189            [[1, 197, 3072]]            [1, 197, 768]             2,360,064             Mlp-93              [[1, 197, 768]]             [1, 197, 768]                 0               Identity-47           [[1, 768, 197]]             [1, 768, 197]                 0              LayerNorm-97           [[1, 768, 197]]             [1, 768, 197]                394              Linear-190            [[1, 768, 197]]             [1, 768, 788]              156,024              GELU-94             [[1, 768, 788]]             [1, 768, 788]                 0               Dropout-98            [[1, 768, 197]]             [1, 768, 197]                 0               Linear-191            [[1, 768, 788]]             [1, 768, 197]              155,433              Mlp-94              [[1, 768, 197]]             [1, 768, 197]                 0             LinearBlock-47          [[1, 197, 768]]             [1, 197, 768]                 0              LayerNorm-98           [[1, 197, 768]]             [1, 197, 768]               1,536             Linear-192            [[1, 197, 768]]            [1, 197, 3072]             2,362,368             GELU-95             [[1, 197, 3072]]           [1, 197, 3072]                 0               Dropout-99            [[1, 197, 768]]             [1, 197, 768]                 0               Linear-193            [[1, 197, 3072]]            [1, 197, 768]             2,360,064             Mlp-95              [[1, 197, 768]]             [1, 197, 768]                 0               Identity-48           [[1, 768, 197]]             [1, 768, 197]                 0              LayerNorm-99           [[1, 768, 197]]             [1, 768, 197]                394              Linear-194            [[1, 768, 197]]             [1, 768, 788]              156,024              GELU-96             [[1, 768, 788]]             [1, 768, 788]                 0               Dropout-100           [[1, 768, 197]]             [1, 768, 197]                 0               Linear-195            [[1, 768, 788]]             [1, 768, 197]              155,433              Mlp-96              [[1, 768, 197]]             [1, 768, 197]                 0             LinearBlock-48          [[1, 197, 768]]             [1, 197, 768]                 0              LayerNorm-100          [[1, 197, 768]]             [1, 197, 768]               1,536             Linear-196               [[1, 768]]                 [1, 1000]                769,000    ==================================================================================================Total params: 61,942,252Trainable params: 61,942,252Non-trainable params: 0--------------------------------------------------------------------------------------------------Input size (MB): 0.57Forward/backward pass size (MB): 365.91Params size (MB): 236.29Estimated Total Size (MB): 602.77--------------------------------------------------------------------------------------------------
登录后复制
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/numpy/core/fromnumeric.py:87: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)
登录后复制
{'total_params': 61942252, 'trainable_params': 61942252}
登录后复制

添加预训练权重

ImageNet-1k validation

In [ ]
# ff tinyff_tiny = linear_tiny()ff_tiny.set_state_dict(paddle.load('/home/aistudio/data/data96150/linear_tiny.pdparams'))# ff baseff_base = linear_base()ff_base.set_state_dict(paddle.load('/home/aistudio/data/data96150/linear_base.pdparams'))
登录后复制

Cifar10 验证性能

采用Cifar10数据集,无过多的数据增强

数据准备

In [ ]
import paddle.vision.transforms as Tfrom paddle.vision.datasets import Cifar10# 开启 GPUpaddle.set_device('gpu')# 数据准备transform = T.Compose([    T.Resize(size=(224,224)),    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],data_format='HWC'),    T.ToTensor()])train_dataset = Cifar10(mode='train', transform=transform)val_dataset = Cifar10(mode='test',  transform=transform)
登录后复制
Cache file /home/aistudio/.cache/paddle/dataset/cifar/cifar-10-python.tar.gz not found, downloading https://dataset.bj.bcebos.com/cifar/cifar-10-python.tar.gz Begin to downloadDownload finished
登录后复制

模型准备

In [17]
ff_base = linear_base(num_classes=10)ff_base.set_state_dict(paddle.load('/home/aistudio/data/data96150/linear_base.pdparams'))model = paddle.Model(ff_base)
登录后复制

开始训练

由于时间篇幅只训练5轮,感兴趣的同学可以继续训练

In [16]
model.prepare(optimizer=paddle.optimizer.SGD(learning_rate=0.001, parameters=model.parameters()),              loss=paddle.nn.CrossEntropyLoss(),              metrics=paddle.metric.Accuracy())# 开启训练可视化visualdl=paddle.callbacks.VisualDL(log_dir='visual_log') model.fit(    train_data=train_dataset,     eval_data=val_dataset,     batch_size=64,     epochs=5,    verbose=1,    callbacks=[visualdl] )
登录后复制

训练过程可视化

FF Only:Attention真的需要吗? - 游乐网

总结

这些结果表明,除了注意力以外,视觉Transformer的其他方面(例如patch embedding)可能比以前认为的要重要。我们希望这些结果能促使社区花费更多时间来理解为什么我们当前的模型如此有效。

来源:https://www.php.cn/faq/1414136.html

游乐网为非赢利性网站,所展示的游戏/软件/文章内容均来自于互联网或第三方用户上传分享,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系youleyoucom@outlook.com。

同类文章
更多
Kyugo

Kyugo

Kyugo Calendar 是什么 市面上日历应用层出不穷,但大多脱不开线性列表或方格矩阵的老路子。这时候,Kyugo Calendar带着它那独特的圆形设计出现了,它想干的,可不只是帮你记个日程那么简单。 由Kyugo团队打造的这款工具,本质上是一个专注生产力的日历。它的野心在于改变我们看待和管

时间:2026-04-22 22:34
Cantrip.io

Cantrip.io

Cantrip io是什么 如果你一直在寻找一个能让网站搭建变得像“魔法”一样简单的工具,那么Cantrip io很可能就是答案。这款由专注用户体验和技术的团队开发的平台,其核心卖点非常明确:为用户,尤其是那些不想操心插件、设计或复杂后台设置的用户,提供一个真正“无痛”的建站体验。它巧妙地将AI内容

时间:2026-04-22 22:34
Blessing Wiki

Blessing Wiki

Blessing Wiki是什么 在数字问候日渐同质化的今天,你是否想过,一条祝福也能真正“为你而生”?这就是Blessing Wiki想回答的问题。它并非出自大厂之手,而是一群由技术爱好者、创意作家和充满同理心的客服人员共同打造的工具。其核心理念很纯粹:将语言的优雅与人工智能的智能相结合,生成那些

时间:2026-04-22 22:33
Datascale

Datascale

Datascale是什么 在数据团队日常工作中,面对成百上千的SQL脚本和错综复杂的数据关系,是种什么体验?想必不少数据库管理员和工程师都深有体会:混乱、耗时且极易出错。好在我们现在有了新的解题思路——Datascale。这是一款由Poom开发的创新型云SQL建模平台,它最厉害的地方在于,能够帮你彻

时间:2026-04-22 22:33
Ecomtent

Ecomtent

Ecomtent AI是什么 当你在亚马逊、谷歌或eBay上浏览产品时,有没有想过,那些抓人眼球的图片和文案是怎么来的?背后很可能有AI的助力。Ecomtent AI正是这样一款工具,专为优化电商产品内容而生。它由Ecomtent公司开发,能自动生成高质量的图片、信息图表和文案,核心目标就一个:显著

时间:2026-04-22 22:32
热门专题
更多
刀塔传奇破解版无限钻石下载大全 刀塔传奇破解版无限钻石下载大全
洛克王国正式正版手游下载安装大全 洛克王国正式正版手游下载安装大全
思美人手游下载专区 思美人手游下载专区
好玩的阿拉德之怒游戏下载合集 好玩的阿拉德之怒游戏下载合集
不思议迷宫手游下载合集 不思议迷宫手游下载合集
百宝袋汉化组游戏最新合集 百宝袋汉化组游戏最新合集
jsk游戏合集30款游戏大全 jsk游戏合集30款游戏大全
宾果消消消原版下载大全 宾果消消消原版下载大全
  • 日榜
  • 周榜
  • 月榜
热门教程
更多
  • 游戏攻略
  • 安卓教程
  • 苹果教程
  • 电脑教程