基于Attention U-Net的宠物图像分割
本文基于《Attention U-Net: Learning Where to Look for the Pancreas》,实现了用于宠物图像分割的Attention U-Net模型。通过划分数据集,构建含注意力门的网络结构,用RMSProp优化器和交叉熵损失训练,经15轮后在测试集上预测,结果展示了模型对宠物图像的分割效果,验证了其有效性。

基于Attention U-Net的宠物图像分割
论文:Attention U-Net: Learning Where to Look for the Pancreas
免费影视、动漫、音乐、游戏、小说资源长期稳定更新! 👉 点此立即查看 👈
简介
该论文首次在医学图像的CNN中使用Soft Attention,该模块可以替代分类任务中的Hard Attention和器官定位任务中的外部定位模块。Attention U-Net是一种用于医学成像的新型注意力门(Attention Gate, AG)模型,能够自动学习聚焦于不同形状和大小的目标结构,隐式地学习抑制输入图像中不相关的区域,同时突出对特定任务有用的显著特征。Attention模块只需很小的计算开销,即可提高模型的灵敏度和预测精度。
效果
模型结构
Attention Gate模块
Attention的意思是把注意力放到目标区域上,简单来说就是让目标区域的值变大。Attention模块用在了skip connection上:原始U-Net只是单纯地把同层下采样层的特征直接拼接(concat)到上采样层中;改进后则先用Attention模块对同层下采样层的特征图和上一层上采样得到的特征图进行处理,再将处理结果与上采样后的特征图进行拼接(concat)。
环境设置
# Environment setup cell: imports used by the rest of the notebook.
# Standard library.
import os
import io
# Array maths, plotting and image decoding.
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image as PilImage
# PaddlePaddle framework.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

# Select the GPU as the device for the Paddle operations that follow.
paddle.set_device('gpu')
# Bare expression: a notebook displays its value (the installed Paddle version).
paddle.__version__
# Recorded output of the original run: '2.1.0'
数据处理
此处数据处理部分借鉴了『跟着雨哥学AI』系列06:趣味案例——基于U-Net的宠物图像分割
In [2]# 解压缩!tar -xf data/data50154/images.tar.gz!tar -xf data/data50154/annotations.tar.gz登录后复制 In [3]
IMAGE_SIZE = (160, 160)
train_images_path = "images/"
label_images_path = "annotations/trimaps/"


def _sort_images(image_dir, image_type):
    """Return the sorted paths of the ``image_type`` files in ``image_dir``.

    Args:
        image_dir: Directory to list.
        image_type: File extension without the leading dot, e.g. ``'jpg'``.

    Returns:
        A sorted list of ``os.path.join(image_dir, name)`` paths. Names that
        start with ``.`` (hidden files) are skipped.
    """
    suffix = '.{}'.format(image_type)
    return sorted(
        os.path.join(image_dir, image_name)
        for image_name in os.listdir(image_dir)
        if image_name.endswith(suffix) and not image_name.startswith('.')
    )


def write_file(mode, images, labels):
    """Write ``./<mode>.txt`` with one ``image<TAB>label`` pair per line.

    Args:
        mode: Name of the split; used as the file stem.
        images: Image paths.
        labels: Label paths, paired with ``images`` by position.
    """
    with open('./{}.txt'.format(mode), 'w') as f:
        f.writelines(
            '{}\t{}\n'.format(image, label)
            for image, label in zip(images, labels)
        )


# Split the dataset into training and evaluation lists.
images = _sort_images(train_images_path, 'jpg')
labels = _sort_images(label_images_path, 'png')
if len(images) != len(labels):
    # Pairs are matched purely by sorted position, so a count mismatch would
    # silently attach labels to the wrong images.
    raise ValueError(
        'found {} images but {} labels'.format(len(images), len(labels))
    )

# Count exactly the files that are listed above. The previous version counted
# '.webp' files while listing '.jpg' files, which yields 0 on this dataset and
# leaves the training list empty.
image_count = len(images)
print("用于训练的图片样本数量:", image_count)

eval_num = int(image_count * 0.15)
# Slice on an explicit index: ``images[:-eval_num]`` is empty when eval_num
# is 0, whereas ``images[:split_at]`` then correctly keeps everything.
split_at = image_count - eval_num

write_file('train', images[:split_at], labels[:split_at])
write_file('test', images[split_at:], labels[split_at:])
write_file('predict', images[split_at:], labels[split_at:])
# Recorded output of the original run: 用于训练的图片样本数量: 7390
# Display the first three image/label pairs listed in train.txt.
with open('./train.txt', 'r') as f:
    for i, line in enumerate(f):
        # Stop before decoding anything for a fourth sample.
        if i > 2:
            break

        image_path, label_path = line.strip().split('\t')
        image = np.array(PilImage.open(image_path))
        label = np.array(PilImage.open(label_path))

        # One figure per sample: input image on the left, label on the right.
        plt.figure()

        plt.subplot(1, 2, 1)
        plt.title('Train Image')
        plt.imshow(image.astype('uint8'))
        plt.axis('off')

        plt.subplot(1, 2, 2)
        plt.title('Label')
        plt.imshow(label.astype('uint8'), cmap='gray')
        plt.axis('off')

        plt.show()
# The original run's output was only matplotlib/NumPy DeprecationWarning
# messages followed by the three figures; it is omitted here.
登录后复制登录后复制登录后复制
登录后复制登录后复制登录后复制
登录后复制登录后复制登录后复制
数据集类定义
import random
from paddle.io import Dataset
from paddle.vision.transforms import transforms as T


class PetDataset(Dataset):
    """Image/label pairs listed in ``./<mode>.txt``.

    Every non-empty line of the list file holds an image path and a label
    path separated by a tab.
    """

    def __init__(self, mode='train'):
        """Read the list file for ``mode``.

        Args:
            mode: ``'train'``, ``'test'`` or ``'predict'`` (case-insensitive).

        Raises:
            AssertionError: If ``mode`` is not one of the supported values.
        """
        self.image_size = IMAGE_SIZE
        self.mode = mode.lower()

        assert self.mode in ['train', 'test', 'predict'], \
            "mode should be 'train' or 'test' or 'predict', but got {}".format(self.mode)

        self.train_images = []
        self.label_images = []

        with open('./{}.txt'.format(self.mode), 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline) instead of
                    # failing on the tuple unpacking below.
                    continue
                image, label = line.split('\t')
                self.train_images.append(image)
                self.label_images.append(label)

    def _load_img(self, path, color_mode='rgb', transforms=None,
                  interpolation='bilinear'):
        """Load an image, normalise its colour mode, resize and transform it.

        Args:
            path: Path of the image file.
            color_mode: ``'grayscale'``, ``'rgb'`` or ``'rgba'``.
            transforms: Optional list of extra transforms applied after the
                resize. ``None`` means no extra transforms.
            interpolation: Interpolation mode passed to ``T.Resize``.

        Returns:
            The result of the composed transforms applied to the image.

        Raises:
            ValueError: If ``color_mode`` is not a supported value.
        """
        # ``None`` instead of a shared mutable ``[]`` default.
        extra_transforms = [] if transforms is None else list(transforms)

        with open(path, 'rb') as f:
            img = PilImage.open(io.BytesIO(f.read()))

        if color_mode == 'grayscale':
            # If the image is not already an 8-bit, 16-bit or 32-bit grayscale
            # image, convert it to an 8-bit grayscale image.
            if img.mode not in ('L', 'I;16', 'I'):
                img = img.convert('L')
        elif color_mode == 'rgba':
            if img.mode != 'RGBA':
                img = img.convert('RGBA')
        elif color_mode == 'rgb':
            if img.mode != 'RGB':
                img = img.convert('RGB')
        else:
            raise ValueError('color_mode must be "grayscale", "rgb", or "rgba"')

        pipeline = T.Compose(
            [T.Resize(self.image_size, interpolation=interpolation)]
            + extra_transforms
        )
        return pipeline(img)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is returned as a ``float32`` array and the label as an
        ``int64`` array.
        """
        # Input image: resize, transpose, then normalise with mean/std 127.5.
        train_image = self._load_img(
            self.train_images[idx],
            transforms=[
                T.Transpose(),
                T.Normalize(mean=127.5, std=127.5),
            ],
        )
        # Label image: nearest-neighbour resize so that no new, interpolated
        # pixel values are created at region boundaries (the values are
        # presumably class ids -- confirm against the dataset description).
        label_image = self._load_img(
            self.label_images[idx],
            color_mode='grayscale',
            transforms=[T.Grayscale()],
            interpolation='nearest',
        )

        train_image = np.array(train_image, dtype='float32')
        label_image = np.array(label_image, dtype='int64')
        return train_image, label_image

    def __len__(self):
        """Return the number of samples in the list file."""
        return len(self.train_images)


# Next section of the article: model construction.
基础模块
def _conv_bn_relu(ch_in, ch_out):
    """Return the layers of one 3x3 convolution -> batch norm -> ReLU unit."""
    return [
        nn.Conv2D(ch_in, ch_out, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm(ch_out),
        nn.ReLU(),
    ]


class conv_block(nn.Layer):
    """Two consecutive convolution -> batch norm -> ReLU units."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
        # The first unit changes the channel count, the second keeps it.
        self.conv = nn.Sequential(
            *_conv_bn_relu(ch_in, ch_out),
            *_conv_bn_relu(ch_out, ch_out)
        )

    def forward(self, x):
        return self.conv(x)


class up_conv(nn.Layer):
    """Upsample by a factor of 2, then one convolution -> batch norm -> ReLU unit."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            *_conv_bn_relu(ch_in, ch_out)
        )

    def forward(self, x):
        return self.up(x)


class single_conv(nn.Layer):
    """A single convolution -> batch norm -> ReLU unit."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.conv = nn.Sequential(*_conv_bn_relu(ch_in, ch_out))

    def forward(self, x):
        return self.conv(x)
Attention块
class Attention_block(nn.Layer):
    """Attention gate that rescales ``x`` by a one-channel sigmoid map.

    Args:
        F_g: Number of channels of the gating input ``g``.
        F_l: Number of channels of the gated input ``x``.
        F_int: Number of channels of the intermediate projection.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def project(channels_in, channels_out):
            # 1x1 convolution followed by batch norm.
            return [
                nn.Conv2D(channels_in, channels_out, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm(channels_out),
            ]

        self.W_g = nn.Sequential(*project(F_g, F_int))
        self.W_x = nn.Sequential(*project(F_l, F_int))
        self.psi = nn.Sequential(*project(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU()

    def forward(self, g, x):
        """Return ``x`` multiplied by the attention map computed from ``g`` and ``x``."""
        # Both inputs are projected to F_int channels and summed, so the two
        # projections must produce tensors of compatible shape.
        combined = self.relu(self.W_g(g) + self.W_x(x))
        attention = self.psi(combined)
        return x * attention
Attention U-Net
class AttU_Net(nn.Layer):
    """U-Net whose skip connections pass through attention gates.

    Args:
        img_ch: Number of channels of the input image.
        output_ch: Number of channels produced by the final 1x1 convolution.
    """

    def __init__(self, img_ch=3, output_ch=1):
        super().__init__()

        # Pooling layers used between the encoder levels.
        self.Maxpool = nn.MaxPool2D(kernel_size=2, stride=2)
        for suffix in (1, 2, 3):
            setattr(self, 'Maxpool{}'.format(suffix),
                    nn.MaxPool2D(kernel_size=2, stride=2))

        # channels[k] is the width after encoder level k (channels[0] = input).
        channels = [img_ch, 64, 128, 256, 512, 1024]

        # Encoder: Conv1 .. Conv5.
        for level in range(1, 6):
            setattr(self, 'Conv{}'.format(level),
                    conv_block(ch_in=channels[level - 1], ch_out=channels[level]))

        # Decoder: for each level from 5 down to 2 create Up<k>, Att<k> and
        # Up_conv<k>, in that order.
        for level in range(5, 1, -1):
            deep = channels[level]
            skip = channels[level - 1]
            setattr(self, 'Up{}'.format(level),
                    up_conv(ch_in=deep, ch_out=skip))
            setattr(self, 'Att{}'.format(level),
                    Attention_block(F_g=skip, F_l=skip, F_int=skip // 2))
            # Input is the gated skip features concatenated with the upsampled ones.
            setattr(self, 'Up_conv{}'.format(level),
                    conv_block(ch_in=2 * skip, ch_out=skip))

        self.Conv_1x1 = nn.Conv2D(64, output_ch, kernel_size=1, stride=1, padding=0)

    @staticmethod
    def _decode(up, gate, fuse, deep, skip):
        """Run one decoder stage: upsample, gate the skip features, concat, fuse."""
        upsampled = up(deep)
        gated = gate(g=upsampled, x=skip)
        return fuse(paddle.concat(x=[gated, upsampled], axis=1))

    def forward(self, x):
        # Encoding path.
        enc1 = self.Conv1(x)
        enc2 = self.Conv2(self.Maxpool(enc1))
        enc3 = self.Conv3(self.Maxpool1(enc2))
        enc4 = self.Conv4(self.Maxpool2(enc3))
        bottom = self.Conv5(self.Maxpool3(enc4))

        # Decoding + concat path.
        dec = self._decode(self.Up5, self.Att5, self.Up_conv5, bottom, enc4)
        dec = self._decode(self.Up4, self.Att4, self.Up_conv4, dec, enc3)
        dec = self._decode(self.Up3, self.Att3, self.Up_conv3, dec, enc2)
        dec = self._decode(self.Up2, self.Att2, self.Up_conv2, dec, enc1)

        return self.Conv_1x1(dec)
模型可视化
# Build the network, wrap it in Paddle's high-level Model API and print a
# per-layer summary.
num_classes = 4
network = AttU_Net(img_ch=3, output_ch=num_classes)
model = paddle.Model(network)
# Summary input shape is (-1, 3) followed by IMAGE_SIZE.
model.summary((-1, 3,) + IMAGE_SIZE)
----------------------------------------------------------------------------- Layer (type) Input Shape Output Shape Param # ============================================================================= Conv2D-1 [[1, 3, 160, 160]] [1, 64, 160, 160] 1,792 BatchNorm-1 [[1, 64, 160, 160]] [1, 64, 160, 160] 256 ReLU-1 [[1, 64, 160, 160]] [1, 64, 160, 160] 0 Conv2D-2 [[1, 64, 160, 160]] [1, 64, 160, 160] 36,928 BatchNorm-2 [[1, 64, 160, 160]] [1, 64, 160, 160] 256 ReLU-2 [[1, 64, 160, 160]] [1, 64, 160, 160] 0 conv_block-1 [[1, 3, 160, 160]] [1, 64, 160, 160] 0 MaxPool2D-1 [[1, 64, 160, 160]] [1, 64, 80, 80] 0 Conv2D-3 [[1, 64, 80, 80]] [1, 128, 80, 80] 73,856 BatchNorm-3 [[1, 128, 80, 80]] [1, 128, 80, 80] 512 ReLU-3 [[1, 128, 80, 80]] [1, 128, 80, 80] 0 Conv2D-4 [[1, 128, 80, 80]] [1, 128, 80, 80] 147,584 BatchNorm-4 [[1, 128, 80, 80]] [1, 128, 80, 80] 512 ReLU-4 [[1, 128, 80, 80]] [1, 128, 80, 80] 0 conv_block-2 [[1, 64, 80, 80]] [1, 128, 80, 80] 0 MaxPool2D-2 [[1, 128, 80, 80]] [1, 128, 40, 40] 0 Conv2D-5 [[1, 128, 40, 40]] [1, 256, 40, 40] 295,168 BatchNorm-5 [[1, 256, 40, 40]] [1, 256, 40, 40] 1,024 ReLU-5 [[1, 256, 40, 40]] [1, 256, 40, 40] 0 Conv2D-6 [[1, 256, 40, 40]] [1, 256, 40, 40] 590,080 BatchNorm-6 [[1, 256, 40, 40]] [1, 256, 40, 40] 1,024 ReLU-6 [[1, 256, 40, 40]] [1, 256, 40, 40] 0 conv_block-3 [[1, 128, 40, 40]] [1, 256, 40, 40] 0 MaxPool2D-3 [[1, 256, 40, 40]] [1, 256, 20, 20] 0 Conv2D-7 [[1, 256, 20, 20]] [1, 512, 20, 20] 1,180,160 BatchNorm-7 [[1, 512, 20, 20]] [1, 512, 20, 20] 2,048 ReLU-7 [[1, 512, 20, 20]] [1, 512, 20, 20] 0 Conv2D-8 [[1, 512, 20, 20]] [1, 512, 20, 20] 2,359,808 BatchNorm-8 [[1, 512, 20, 20]] [1, 512, 20, 20] 2,048 ReLU-8 [[1, 512, 20, 20]] [1, 512, 20, 20] 0 conv_block-4 [[1, 256, 20, 20]] [1, 512, 20, 20] 0 MaxPool2D-4 [[1, 512, 20, 20]] [1, 512, 10, 10] 0 Conv2D-9 [[1, 512, 10, 10]] [1, 1024, 10, 10] 4,719,616 BatchNorm-9 [[1, 1024, 10, 10]] [1, 1024, 10, 10] 4,096 ReLU-9 [[1, 1024, 10, 10]] [1, 1024, 10, 10] 0 Conv2D-10 [[1, 
1024, 10, 10]] [1, 1024, 10, 10] 9,438,208 BatchNorm-10 [[1, 1024, 10, 10]] [1, 1024, 10, 10] 4,096 ReLU-10 [[1, 1024, 10, 10]] [1, 1024, 10, 10] 0 conv_block-5 [[1, 512, 10, 10]] [1, 1024, 10, 10] 0 Upsample-1 [[1, 1024, 10, 10]] [1, 1024, 20, 20] 0 Conv2D-11 [[1, 1024, 20, 20]] [1, 512, 20, 20] 4,719,104 BatchNorm-11 [[1, 512, 20, 20]] [1, 512, 20, 20] 2,048 ReLU-11 [[1, 512, 20, 20]] [1, 512, 20, 20] 0 up_conv-1 [[1, 1024, 10, 10]] [1, 512, 20, 20] 0 Conv2D-12 [[1, 512, 20, 20]] [1, 256, 20, 20] 131,328 BatchNorm-12 [[1, 256, 20, 20]] [1, 256, 20, 20] 1,024 Conv2D-13 [[1, 512, 20, 20]] [1, 256, 20, 20] 131,328 BatchNorm-13 [[1, 256, 20, 20]] [1, 256, 20, 20] 1,024 ReLU-12 [[1, 256, 20, 20]] [1, 256, 20, 20] 0 Conv2D-14 [[1, 256, 20, 20]] [1, 1, 20, 20] 257 BatchNorm-14 [[1, 1, 20, 20]] [1, 1, 20, 20] 4 Sigmoid-1 [[1, 1, 20, 20]] [1, 1, 20, 20] 0 Attention_block-1 [] [1, 512, 20, 20] 0 Conv2D-15 [[1, 1024, 20, 20]] [1, 512, 20, 20] 4,719,104 BatchNorm-15 [[1, 512, 20, 20]] [1, 512, 20, 20] 2,048 ReLU-13 [[1, 512, 20, 20]] [1, 512, 20, 20] 0 Conv2D-16 [[1, 512, 20, 20]] [1, 512, 20, 20] 2,359,808 BatchNorm-16 [[1, 512, 20, 20]] [1, 512, 20, 20] 2,048 ReLU-14 [[1, 512, 20, 20]] [1, 512, 20, 20] 0 conv_block-6 [[1, 1024, 20, 20]] [1, 512, 20, 20] 0 Upsample-2 [[1, 512, 20, 20]] [1, 512, 40, 40] 0 Conv2D-17 [[1, 512, 40, 40]] [1, 256, 40, 40] 1,179,904 BatchNorm-17 [[1, 256, 40, 40]] [1, 256, 40, 40] 1,024 ReLU-15 [[1, 256, 40, 40]] [1, 256, 40, 40] 0 up_conv-2 [[1, 512, 20, 20]] [1, 256, 40, 40] 0 Conv2D-18 [[1, 256, 40, 40]] [1, 128, 40, 40] 32,896 BatchNorm-18 [[1, 128, 40, 40]] [1, 128, 40, 40] 512 Conv2D-19 [[1, 256, 40, 40]] [1, 128, 40, 40] 32,896 BatchNorm-19 [[1, 128, 40, 40]] [1, 128, 40, 40] 512 ReLU-16 [[1, 128, 40, 40]] [1, 128, 40, 40] 0 Conv2D-20 [[1, 128, 40, 40]] [1, 1, 40, 40] 129 BatchNorm-20 [[1, 1, 40, 40]] [1, 1, 40, 40] 4 Sigmoid-2 [[1, 1, 40, 40]] [1, 1, 40, 40] 0 Attention_block-2 [] [1, 256, 40, 40] 0 Conv2D-21 [[1, 512, 40, 40]] [1, 256, 
40, 40] 1,179,904 BatchNorm-21 [[1, 256, 40, 40]] [1, 256, 40, 40] 1,024 ReLU-17 [[1, 256, 40, 40]] [1, 256, 40, 40] 0 Conv2D-22 [[1, 256, 40, 40]] [1, 256, 40, 40] 590,080 BatchNorm-22 [[1, 256, 40, 40]] [1, 256, 40, 40] 1,024 ReLU-18 [[1, 256, 40, 40]] [1, 256, 40, 40] 0 conv_block-7 [[1, 512, 40, 40]] [1, 256, 40, 40] 0 Upsample-3 [[1, 256, 40, 40]] [1, 256, 80, 80] 0 Conv2D-23 [[1, 256, 80, 80]] [1, 128, 80, 80] 295,040 BatchNorm-23 [[1, 128, 80, 80]] [1, 128, 80, 80] 512 ReLU-19 [[1, 128, 80, 80]] [1, 128, 80, 80] 0 up_conv-3 [[1, 256, 40, 40]] [1, 128, 80, 80] 0 Conv2D-24 [[1, 128, 80, 80]] [1, 64, 80, 80] 8,256 BatchNorm-24 [[1, 64, 80, 80]] [1, 64, 80, 80] 256 Conv2D-25 [[1, 128, 80, 80]] [1, 64, 80, 80] 8,256 BatchNorm-25 [[1, 64, 80, 80]] [1, 64, 80, 80] 256 ReLU-20 [[1, 64, 80, 80]] [1, 64, 80, 80] 0 Conv2D-26 [[1, 64, 80, 80]] [1, 1, 80, 80] 65 BatchNorm-26 [[1, 1, 80, 80]] [1, 1, 80, 80] 4 Sigmoid-3 [[1, 1, 80, 80]] [1, 1, 80, 80] 0 Attention_block-3 [] [1, 128, 80, 80] 0 Conv2D-27 [[1, 256, 80, 80]] [1, 128, 80, 80] 295,040 BatchNorm-27 [[1, 128, 80, 80]] [1, 128, 80, 80] 512 ReLU-21 [[1, 128, 80, 80]] [1, 128, 80, 80] 0 Conv2D-28 [[1, 128, 80, 80]] [1, 128, 80, 80] 147,584 BatchNorm-28 [[1, 128, 80, 80]] [1, 128, 80, 80] 512 ReLU-22 [[1, 128, 80, 80]] [1, 128, 80, 80] 0 conv_block-8 [[1, 256, 80, 80]] [1, 128, 80, 80] 0 Upsample-4 [[1, 128, 80, 80]] [1, 128, 160, 160] 0 Conv2D-29 [[1, 128, 160, 160]] [1, 64, 160, 160] 73,792 BatchNorm-29 [[1, 64, 160, 160]] [1, 64, 160, 160] 256 ReLU-23 [[1, 64, 160, 160]] [1, 64, 160, 160] 0 up_conv-4 [[1, 128, 80, 80]] [1, 64, 160, 160] 0 Conv2D-30 [[1, 64, 160, 160]] [1, 32, 160, 160] 2,080 BatchNorm-30 [[1, 32, 160, 160]] [1, 32, 160, 160] 128 Conv2D-31 [[1, 64, 160, 160]] [1, 32, 160, 160] 2,080 BatchNorm-31 [[1, 32, 160, 160]] [1, 32, 160, 160] 128 ReLU-24 [[1, 32, 160, 160]] [1, 32, 160, 160] 0 Conv2D-32 [[1, 32, 160, 160]] [1, 1, 160, 160] 33 BatchNorm-32 [[1, 1, 160, 160]] [1, 1, 160, 160] 4 Sigmoid-4 [[1, 
1, 160, 160]] [1, 1, 160, 160] 0 Attention_block-4 [] [1, 64, 160, 160] 0 Conv2D-33 [[1, 128, 160, 160]] [1, 64, 160, 160] 73,792 BatchNorm-33 [[1, 64, 160, 160]] [1, 64, 160, 160] 256 ReLU-25 [[1, 64, 160, 160]] [1, 64, 160, 160] 0 Conv2D-34 [[1, 64, 160, 160]] [1, 64, 160, 160] 36,928 BatchNorm-34 [[1, 64, 160, 160]] [1, 64, 160, 160] 256 ReLU-26 [[1, 64, 160, 160]] [1, 64, 160, 160] 0 conv_block-9 [[1, 128, 160, 160]] [1, 64, 160, 160] 0 Conv2D-35 [[1, 64, 160, 160]] [1, 4, 160, 160] 260 =============================================================================Total params: 34,894,392Trainable params: 34,863,144Non-trainable params: 31,248-----------------------------------------------------------------------------Input size (MB): 0.29Forward/backward pass size (MB): 563.67Params size (MB): 133.11Estimated Total Size (MB): 697.07-----------------------------------------------------------------------------登录后复制
{'total_params': 34894392, 'trainable_params': 34863144}
模型训练
train_dataset = PetDataset(mode='train')  # training dataset
val_dataset = PetDataset(mode='test')  # validation dataset
# RMSProp optimizer over all parameters of the model.
optim = paddle.optimizer.RMSProp(learning_rate=0.001, rho=0.9, momentum=0.0, epsilon=1e-07, centered=False, parameters=model.parameters())
# Cross-entropy loss computed along axis 1 of the network output.
model.prepare(optim, paddle.nn.CrossEntropyLoss(axis=1))
# Train for 15 epochs with batches of 32, evaluating on val_dataset.
model.fit(train_dataset, val_dataset, epochs=15, batch_size=32, verbose=1)
模型预测
# Run the trained model over every sample listed in predict.txt.
predict_dataset = PetDataset(mode='predict')
predict_results = model.predict(predict_dataset)
Predict begin...
step 1108/1108 [==============================] - 20ms/step
Predict samples: 1108
# Show the first three prediction samples in a 3x3 grid:
# one row per sample, columns = input image, label, predicted mask.
plt.figure(figsize=(10, 10))

# Built once: the same resize is applied to every image and label below.
resize_t = T.Compose([
    T.Resize(IMAGE_SIZE)
])

with open('./predict.txt', 'r') as f:
    for mask_idx, line in enumerate(f):
        # The grid only has room for three samples; stop before decoding a fourth.
        if mask_idx > 2:
            break

        image_path, label_path = line.strip().split('\t')
        image = np.array(resize_t(PilImage.open(image_path))).astype('uint8')
        label = np.array(resize_t(PilImage.open(label_path))).astype('uint8')

        # Subplot slots are numbered from 1, row by row; this row uses
        # slots i + 1 .. i + 3.
        i = 3 * mask_idx

        plt.subplot(3, 3, i + 1)
        plt.imshow(image)
        plt.title('Input Image')
        plt.axis("off")

        plt.subplot(3, 3, i + 2)
        plt.imshow(label, cmap='gray')
        plt.title('Label')
        plt.axis("off")

        # The model has a single output, so predict_results[0] holds its
        # predictions, indexed in the same order as the lines of predict.txt.
        # The trailing [0] takes the first item of that entry -- presumably the
        # only sample of a size-1 prediction batch; confirm against
        # model.predict's batch_size.
        data = predict_results[0][mask_idx][0].transpose((1, 2, 0))
        # Per-pixel index of the highest-scoring channel.
        mask = np.argmax(data, axis=-1)

        plt.subplot(3, 3, i + 3)
        plt.imshow(mask.astype('uint8'), cmap='gray')
        plt.title('Predict')
        plt.axis("off")

plt.show()
游乐网为非赢利性网站,所展示的游戏/软件/文章内容均来自于互联网或第三方用户上传分享,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系youleyoucom@outlook.com。
同类文章
LightPDF
LightPDF是什么 提到在线处理PDF,很多人首先想到的可能是一系列繁琐的步骤和复杂的软件。但现在,这事儿变得简单多了。LightPDF,这款由开发者推出的AI智能在线工具,可以说是一个集大成者。它本质上是一个基于云服务的PDF解决方案,集编辑、转换、阅读于一体,并且完全免费使用。它的目标很明确
isitPossible AI
isitPossible AI是什么 在当今快速迭代的创业市场里,一个想法究竟能不能落地成真,往往考验着产品经理和创业者的核心判断。而isitPossible AI,就是为此而生的一款专业工具。它由资深的AI专家团队打造,专门帮助产品经理、创业者以及市场评估专业人士,系统性地分析新产品的推出可行性。
MyQRCode
Free AI QR Code Generator by MyQRCode是什么 简单来说,Free AI QR Code Generator by MyQRCode就是一个能让你轻松做出漂亮二维码的智能工具。它由MyQRCode团队打造,核心就是把一串枯燥的链接、联系人信息或者文件,变成一块可以任
Nittii
Nittii是什么 让我们从一个核心问题开始:在不用敲一行代码的情况下,如何快速搭建起一套智能化的业务管理系统?Nittii这款AI工具给出的答案,或许会让你眼前一亮。它本质上是一个由专业团队打造的业务赋能平台,核心目标很明确——帮助企业和个人轻松构建、自动化并管理关键业务流程。无论是展示产品、打通
Tweeteasy
什么是Tweeteasy 如果运营过推特账号,你大概深有体会:保持高频率、有质量的互动,实在是件耗时耗力的事。这时候,一个得力的AI助手就显得尤为关键。Tweeteasy正是为此而生的一款工具,它由开发者精心设计,核心目标就是帮助用户在推特上提升互动效率与内容水准。 简单来说,它能帮你生成高质量的推
- 日榜
- 周榜
- 月榜
1
2
3
4
5
6
7
8
9
10
1
2
3
4
5
6
7
8
9
10
1
2
3
4
5
6
7
8
9
10
相关攻略
2015-03-10 11:25
2015-03-10 11:05
2021-08-04 13:30
2015-03-10 11:22
2015-03-10 12:39
2022-05-16 18:57
2025-05-23 13:43
2025-05-23 14:01
热门教程
- 游戏攻略
- 安卓教程
- 苹果教程
- 电脑教程
热门话题

