您好,欢迎来到游6网!

当前位置:首页 > AI > 使用PaddlePaddle2.0高层API完成基于VGG16的图像分类任务

使用PaddlePaddle2.0高层API完成基于VGG16的图像分类任务

发布时间:2025-07-23    编辑:游乐网

本文介绍如何用PaddlePaddle2.0高层API,基于VGG16完成Cifar10图像分类任务。包括利用高层API简化模型组网与训练,加载VGG16网络并查看结构,加载Cifar10数据集及数据增强,还讲解了用高层API进行模型训练、验证与测试的过程,最后提及使用时的注意事项。

使用paddlepaddle2.0高层api完成基于vgg16的图像分类任务 - 游乐网

使用PaddlePaddle2.0高层API完成基于VGG16的图像分类任务

本示例教程将会演示如何使用飞桨的卷积神经网络来完成目标检测任务。这是一个较为简单的示例,将会使用飞桨框架内置模型VGG16网络完成Cifar10数据集的图像分类任务。

一、PaddlePaddle2.0新亮点——高层API助力开发者快速上手深度学习

飞桨致力于让深度学习技术的创新与应用更简单

1.模型组网更简单

对于新手来说,完全可以省去以往复杂的组网代码,一行代码便可以完成组网。

目前PaddlePaddle2.0-rc1的内置模型有:

'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'VGG', 'vgg11', 'vgg13', 'vgg16', 'vgg19', 'MobileNetV1', 'mobilenet_v1', 'MobileNetV2', 'mobilenet_v2', 'LeNet'

使用一行代码便可以加载:

ModelName = paddle.vision.models.ModelName()
登录后复制        

(将ModelName替换成上面的模型名称即可,模型名称后面别忘了加括号!!!)

2.模型训练更简单

PaddlePaddle2.0-rc1增加了paddle.Model高层API,大部分任务可以使用此API用于简化训练、评估、预测类代码开发。

使用两句代码便可以训练模型:

# 训练前准备ModelName.prepare(    paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters()),    paddle.nn.CrossEntropyLoss(),    paddle.metric.Accuracy(topk=(1, 2))    )# 启动训练ModelName.fit(train_dataset, epochs=2, batch_size=64, log_freq=200)
登录后复制    

二、使用飞桨快速加载VGG网络并查看模型结构

Very Deep Convolutional Networks For Large-Scale Image Recognition 论文地址:https://arxiv.org/pdf/1409.1556.pdf

1.查看飞桨框架内置模型

In [ ]
import paddleprint('飞桨框架内置模型:', paddle.vision.models.__all__)
登录后复制        
飞桨框架内置模型: ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'VGG', 'vgg11', 'vgg13', 'vgg16', 'vgg19', 'MobileNetV1', 'mobilenet_v1', 'MobileNetV2', 'mobilenet_v2', 'LeNet']
登录后复制        

2.一行代码加载VGG16

In [ ]
vgg16 = paddle.vision.models.vgg16()
登录后复制    

3.查看VGG16的网络结构及参数

In [ ]
paddle.summary(vgg16, (64, 3, 32, 32))
登录后复制        
-------------------------------------------------------------------------------   Layer (type)         Input Shape          Output Shape         Param #    ===============================================================================     Conv2D-1        [[64, 3, 32, 32]]     [64, 64, 32, 32]        1,792           ReLU-1         [[64, 64, 32, 32]]    [64, 64, 32, 32]          0            Conv2D-2        [[64, 64, 32, 32]]    [64, 64, 32, 32]       36,928           ReLU-2         [[64, 64, 32, 32]]    [64, 64, 32, 32]          0           MaxPool2D-1      [[64, 64, 32, 32]]    [64, 64, 16, 16]          0            Conv2D-3        [[64, 64, 16, 16]]   [64, 128, 16, 16]       73,856           ReLU-3        [[64, 128, 16, 16]]   [64, 128, 16, 16]          0            Conv2D-4       [[64, 128, 16, 16]]   [64, 128, 16, 16]       147,584          ReLU-4        [[64, 128, 16, 16]]   [64, 128, 16, 16]          0           MaxPool2D-2     [[64, 128, 16, 16]]    [64, 128, 8, 8]           0            Conv2D-5        [[64, 128, 8, 8]]     [64, 256, 8, 8]        295,168          ReLU-5         [[64, 256, 8, 8]]     [64, 256, 8, 8]           0            Conv2D-6        [[64, 256, 8, 8]]     [64, 256, 8, 8]        590,080          ReLU-6         [[64, 256, 8, 8]]     [64, 256, 8, 8]           0            Conv2D-7        [[64, 256, 8, 8]]     [64, 256, 8, 8]        590,080          ReLU-7         [[64, 256, 8, 8]]     [64, 256, 8, 8]           0           MaxPool2D-3      [[64, 256, 8, 8]]     [64, 256, 4, 4]           0            Conv2D-8        [[64, 256, 4, 4]]     [64, 512, 4, 4]       1,180,160         ReLU-8         [[64, 512, 4, 4]]     [64, 512, 4, 4]           0            Conv2D-9        [[64, 512, 4, 4]]     [64, 512, 4, 4]       2,359,808         ReLU-9         [[64, 512, 4, 4]]     [64, 512, 4, 4]           0            Conv2D-10       [[64, 512, 4, 4]]     [64, 512, 4, 4]       2,359,808         ReLU-10        [[64, 512, 4, 4]]     [64, 512, 4, 4]           0           MaxPool2D-4      [[64, 512, 4, 4]]     [64, 512, 2, 2]           0            Conv2D-11       [[64, 512, 2, 2]]     [64, 512, 2, 2]       2,359,808         ReLU-11        [[64, 512, 2, 2]]     [64, 512, 2, 2]           0            Conv2D-12       [[64, 512, 2, 2]]     [64, 512, 2, 2]       2,359,808         ReLU-12        [[64, 512, 2, 2]]     [64, 512, 2, 2]           0            Conv2D-13       [[64, 512, 2, 2]]     [64, 512, 2, 2]       2,359,808         ReLU-13        [[64, 512, 2, 2]]     [64, 512, 2, 2]           0           MaxPool2D-5      [[64, 512, 2, 2]]     [64, 512, 1, 1]           0       AdaptiveAvgPool2D-1  [[64, 512, 1, 1]]     [64, 512, 7, 7]           0            Linear-1          [[64, 25088]]          [64, 4096]        102,764,544        ReLU-14           [[64, 4096]]          [64, 4096]             0            Dropout-1          [[64, 4096]]          [64, 4096]             0            Linear-2           [[64, 4096]]          [64, 4096]        16,781,312         ReLU-15           [[64, 4096]]          [64, 4096]             0            Dropout-2          [[64, 4096]]          [64, 4096]             0            Linear-3           [[64, 4096]]          [64, 1000]         4,097,000   ===============================================================================Total params: 138,357,544Trainable params: 138,357,544Non-trainable params: 0-------------------------------------------------------------------------------Input size (MB): 0.75Forward/backward pass size (MB): 309.99Params size (MB): 527.79Estimated Total Size (MB): 838.53-------------------------------------------------------------------------------
登录后复制        
{'total_params': 138357544, 'trainable_params': 138357544}
登录后复制                

三、使用飞桨框架API加载数据集

飞桨框架将一些我们常用到的数据集作为领域API对用户进行开放,对应API所在目录为paddle.vision.datasets与paddle.text.datasets

目前已经收录的数据集有:

视觉相关数据集: ['DatasetFolder', 'ImageFolder', 'MNIST', 'FashionMNIST', 'Flowers', 'Cifar10', 'Cifar100', 'VOC2012']

自然语言相关数据集: ['Conll05st', 'Imdb', 'Imikolov', 'Movielens', 'UCIHousing', 'WMT14', 'WMT16']

1.快速加载数据集

使用一行代码即可加载数据集到本机缓存目录~/.cache/paddle/dataset:

paddle.vision.datasets.DataSetName()

将DataSetName替换成上述数据集名称即可,别忘了名称后面跟一个小括号!

In [ ]
from paddle.vision.transforms import ToTensor# 训练数据集 用ToTensor将数据格式转为Tensortrain_dataset = paddle.vision.datasets.Cifar10(mode='train', transform=ToTensor())# 验证数据集val_dataset = paddle.vision.datasets.Cifar10(mode='test', transform=ToTensor())
登录后复制    

2.对训练数据做数据增强

训练过程中有时会遇到过拟合的问题,其中一个解决方法就是对训练数据做增强,对数据进行处理得到不同的图像,从而泛化数据集。

查看飞桨框架提供的数据增强方法:

In [ ]
import paddleprint('数据处理方法:', paddle.vision.transforms.__all__)
登录后复制        
数据处理方法: ['BaseTransform', 'Compose', 'Resize', 'RandomResizedCrop', 'CenterCrop', 'RandomHorizontalFlip', 'RandomVerticalFlip', 'Transpose', 'Normalize', 'BrightnessTransform', 'SaturationTransform', 'ContrastTransform', 'HueTransform', 'ColorJitter', 'RandomCrop', 'Pad', 'RandomRotation', 'Grayscale', 'ToTensor', 'to_tensor', 'hflip', 'vflip', 'resize', 'pad', 'rotate', 'to_grayscale', 'crop', 'center_crop', 'adjust_brightness', 'adjust_contrast', 'adjust_hue', 'normalize']
登录后复制        In [ ]
from paddle.vision.transforms import Compose, Resize, ColorJitter, ToTensor, RandomHorizontalFlip, RandomVerticalFlip, RandomRotationimport numpy as npfrom PIL import Image# 定义想要使用那些数据增强方式,这里用到了随机调整亮度、对比度和饱和度、图像翻转等transform = Compose([ColorJitter(), RandomHorizontalFlip(), ToTensor()])# 通过transform参数传递定义好的数据增项方法即可完成对自带数据集的应用train_dataset = paddle.vision.datasets.Cifar10(mode='train', transform=transform)# 验证数据集val_dataset = paddle.vision.datasets.Cifar10(mode='test', transform=transform)
登录后复制    

这里需要注意的坑是,一定要把ToTensor()放在最后,否则会报错

检查数据集:

In [ ]
train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True)for batch_id, data in enumerate(train_loader()):    x_data = data[0]    y_data = data[1]    breakprint(x_data.numpy().shape)print(y_data.numpy().shape)
登录后复制        
(64, 3, 32, 32)(64,)
登录后复制        

四、使用高层API进行模型训练、验证与测试

飞桨框架提供了两种训练与预测的方法:

一种是用paddle.Model对模型进行封装,通过高层API如Model.fit()、Model.evaluate()、Model.predict()等完成模型的训练与预测;另一种就是基于基础API常规的训练方式。

使用高层API只需要改动少量参数即可完成模型训练,对新手小白真的特别友好!

1.调用fit()接口来启动训练过程

In [26]
import paddlefrom paddle.vision.transforms import ToTensorfrom paddle.vision.models import vgg16# build modelmodel = vgg16()# build vgg16 model with batch_normmodel = vgg16(batch_norm=True)# 使用高层API——paddle.Model对模型进行封装model = paddle.Model(model)# 为模型训练做准备,设置优化器,损失函数和精度计算方式model.prepare(optimizer=paddle.optimizer.Adam(parameters=model.parameters()),                    loss=paddle.nn.CrossEntropyLoss(),                    metrics=paddle.metric.Accuracy())# 启动模型训练,指定训练数据集,设置训练轮次,设置每次数据集计算的批次大小,设置日志格式model.fit(train_dataset,          epochs=10,          batch_size=256,          save_dir="vgg16/",          save_freq=10,          verbose=1)
登录后复制        
The loss value printed in the log is the current step, and the metric is the average value of previous step.Epoch 1/10step 196/196 [==============================] - loss: 2.3752 - acc: 0.1054 - 202ms/step        save checkpoint at /home/aistudio/vgg16/0Epoch 2/10step 196/196 [==============================] - loss: 2.2819 - acc: 0.1114 - 203ms/step        Epoch 3/10step 196/196 [==============================] - loss: 2.2614 - acc: 0.1306 - 201ms/step        Epoch 4/10step 196/196 [==============================] - loss: 2.2229 - acc: 0.1822 - 195ms/step        Epoch 5/10step 196/196 [==============================] - loss: 1.9132 - acc: 0.1846 - 198ms/step        Epoch 6/10step 196/196 [==============================] - loss: 1.7738 - acc: 0.2000 - 194ms/step        Epoch 7/10step 196/196 [==============================] - loss: 1.8450 - acc: 0.2286 - 195ms/step        Epoch 8/10step 196/196 [==============================] - loss: 1.5770 - acc: 0.2782 - 198ms/step        Epoch 9/10step 196/196 [==============================] - loss: 1.5743 - acc: 0.3446 - 195ms/step        Epoch 10/10step 196/196 [==============================] - loss: 1.4283 - acc: 0.4043 - 201ms/step        save checkpoint at /home/aistudio/vgg16/final
登录后复制        

看到loss在明显下降、acc在明显上升,说明模型效果还不错,剩下需要慢慢调参优化

2.调用evaluate()在测试集上对模型进行验证

对于训练好的模型进行评估操作可以使用evaluate接口来实现,事先定义好用于评估使用的数据集后,可以简单的调用evaluate接口即可完成模型评估操作,结束后根据prepare中loss和metric的定义来进行相关评估结果计算返回。

In [27]
# 用 evaluate 在测试集上对模型进行验证eval_result = model.evaluate(val_dataset, verbose=1)
登录后复制        
Eval begin...The loss value printed in the log is the current batch, and the metric is the average value of previous step.step 10000/10000 [==============================] - loss: 0.2509 - acc: 0.4379 - 10ms/step         Eval samples: 10000
登录后复制        

3.调用predict()接口进行模型测试

高层API中提供了predict接口来方便用户对训练好的模型进行预测验证,只需要基于训练好的模型将需要进行预测测试的数据放到接口中进行计算即可,接口会将经过模型计算得到的预测结果进行返回。

In [28]
# 用 predict 在测试集上对模型进行测试test_result = model.predict(val_dataset)
登录后复制        
Predict begin...step 10000/10000 [==============================] - 10ms/step         Predict samples: 10000
登录后复制        

五、总结与展望——PaddlePaddle2.0 rc1入手指南

给大家总结一下我在使用PaddlePaddle2.0 rc1时遇到的坑,希望大家可以避免:

1.这个项目本来是使用VOC2012数据集进行目标检测任务的训练的,奈何VOC2012数据集的下载速度实在是太慢了,所以我果断放弃,希望后期可以找到解决办法2.最好结合PaddlePaddle2.0的文档和GitHub上的源码来使用,特别是新手小白,不然出现一些报错可能会很难解决,可以多去GitHub上提issue3.使用数据增强 Compose()方法时,切记!一定要把ToTensor()放在最后,这个问题看看源码就能解决

热门合集

MORE

+

MORE

+

变态游戏推荐

MORE

+

热门游戏推荐

MORE

+

关于我们  |  游戏下载排行榜  |  专题合集  |  端游游戏  |  手机游戏  |  联系方式: youleyoucom@outlook.com

Copyright 2013-2019 www.youleyou.com    湘公网安备 43070202000716号

声明:游6网为非赢利性网站 不接受任何赞助和广告 湘ICP备2023003002号-9