数据增强
卷积神经网络非常容易出现过拟合的问题,而数据增强的方法是对抗过拟合问题的一个重要方法。
2012 年 AlexNet 在 ImageNet 上大获全胜,图片增强方法功不可没,因为有了图片增强,使得训练的数据集比实际数据集多了很多'新'样本,减少了过拟合的问题,下面我们来具体解释一下。
常用的数据增强方法
常用的数据增强方法如下:
1.对图片进行一定比例缩放
2.对图片进行随机位置的截取
3.对图片进行随机的水平和竖直翻转
4.对图片进行随机角度的旋转
5.对图片进行亮度、对比度和颜色的随机变化
这些方法 pytorch 都已经为我们内置在了 torchvision 里面,我们在安装 pytorch 的时候也安装了 torchvision,下面我们来依次展示一下这些数据增强方法。
import sys sys.path.append('..') from PIL import Image from torchvision import transforms as tfs # 读入一张图片 im = Image.open('./cat.png') im
随机比例放缩
随机比例缩放主要使用的是 torchvision.transforms.Resize()
这个函数,第一个参数可以是一个整数,那么图片会保存现在的宽和高的比例,并将更短的边缩放到这个整数的大小,第一个参数也可以是一个 tuple,那么图片会直接把宽和高缩放到这个大小;第二个参数表示放缩图片使用的方法,比如最邻近法,或者双线性差值等,一般双线性差值能够保留图片更多的信息,所以 pytorch 默认使用的是双线性差值,你可以手动去改这个参数,更多的信息可以看看文档
# 比例缩放 print('before scale, shape: {}'.format(im.size)) new_im = tfs.Resize((100, 200))(im) print('after scale, shape: {}'.format(new_im.size)) new_im
随机位置截取
随机位置截取能够提取出图片中局部的信息,使得网络接受的输入具有多尺度的特征,所以能够有较好的效果。在 torchvision 中主要有下面两种方式,一个是 torchvision.transforms.RandomCrop()
,传入的参数就是截取出的图片的长和宽,对图片在随机位置进行截取;第二个是 torchvision.transforms.CenterCrop()
,同样传入介曲初的图片的大小作为参数,会在图片的中心进行截取
# 随机裁剪出 100 x 100 的区域 random_im1 = tfs.RandomCrop(100)(im) random_im1
# 中心裁剪出 100 x 100 的区域 center_im = tfs.CenterCrop(100)(im) center_im
随机的水平和竖直方向翻转
对于上面这一张猫的图片,如果我们将它翻转一下,它仍然是一张猫,但是图片就有了更多的多样性,所以随机翻转也是一种非常有效的手段。在 torchvision 中,随机翻转使用的是 torchvision.transforms.RandomHorizontalFlip()
和 torchvision.transforms.RandomVerticalFlip()
# 随机水平翻转 h_filp = tfs.RandomHorizontalFlip()(im) h_filp
# 随机竖直翻转 v_flip = tfs.RandomVerticalFlip()(im) v_flip
随机角度旋转
一些角度的旋转仍然是非常有用的数据增强方式,在 torchvision 中,使用 torchvision.transforms.RandomRotation()
来实现,其中第一个参数就是随机旋转的角度,比如填入 10,那么每次图片就会在 -10 ~ 10 度之间随机旋转
rot_im = tfs.RandomRotation(45)(im) rot_im
亮度、对比度和颜色的变化
除了形状变化外,颜色变化又是另外一种增强方式,其中可以设置亮度变化,对比度变化和颜色变化等,在 torchvision 中主要使用 torchvision.transforms.ColorJitter() 来实现的,第一个参数就是亮度的比例,第二个是对比度,第三个是饱和度,第四个是颜色
# 亮度 bright_im = tfs.ColorJitter(brightness=1)(im) # 随机从 0 ~ 2 之间亮度变化,1 表示原图 bright_im
# 对比度 contrast_im = tfs.ColorJitter(contrast=1)(im) # 随机从 0 ~ 2 之间对比度变化,1 表示原图 contrast_im
# 颜色 color_im = tfs.ColorJitter(hue=0.5)(im) # 随机从 -0.5 ~ 0.5 之间对颜色变化 color_im
上面我们讲了这么图片增强的方法,其实这些方法都不是孤立起来用的,可以联合起来用,比如先做随机翻转,然后随机截取,再做对比度增强等等,torchvision 里面有个非常方便的函数能够将这些变化合起来,就是 torchvision.transforms.Compose(),下面我们举个例子
im_aug = tfs.Compose([ tfs.Resize(120), tfs.RandomHorizontalFlip(), tfs.RandomCrop(96), tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5) ])
import matplotlib.pyplot as plt %matplotlib inline nrows = 3 ncols = 3 figsize = (8, 8) _, figs = plt.subplots(nrows, ncols, figsize=figsize) for i in range(nrows): for j in range(ncols): figs[i][j].imshow(im_aug(im)) figs[i][j].axes.get_xaxis().set_visible(False) figs[i][j].axes.get_yaxis().set_visible(False) plt.show()
可以看到每次做完增强之后的图片都有一些变化,所以这就是我们前面讲的,增加了一些'新'数据
下面我们使用图像增强进行训练网络,看看具体的提升究竟在什么地方,使用 ResNet 进行训练
使用数据增强
import numpy as np import torch from torch import nn import torch.nn.functional as F from torch.autograd import Variable from torchvision.datasets import CIFAR10 from utils import train, resnet from torchvision import transforms as tfs # 使用数据增强 def train_tf(x): im_aug = tfs.Compose([ tfs.Resize(120), tfs.RandomHorizontalFlip(), tfs.RandomCrop(96), tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5), tfs.ToTensor(), tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) ]) x = im_aug(x) return x def test_tf(x): im_aug = tfs.Compose([ tfs.Resize(96), tfs.ToTensor(), tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) ]) x = im_aug(x) return x train_set = CIFAR10('./data', train=True, transform=train_tf) train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True) test_set = CIFAR10('./data', train=False, transform=test_tf) test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False) net = resnet(3, 10) optimizer = torch.optim.SGD(net.parameters(), lr=0.01) criterion = nn.CrossEntropyLoss() train(net, train_data, test_data, 10, optimizer, criterion)
不使用数据增强
# 不使用数据增强 def data_tf(x): im_aug = tfs.Compose([ tfs.Resize(96), tfs.ToTensor(), tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) ]) x = im_aug(x) return x train_set = CIFAR10('./data', train=True, transform=data_tf) train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True) test_set = CIFAR10('./data', train=False, transform=data_tf) test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False) net = resnet(3, 10) optimizer = torch.optim.SGD(net.parameters(), lr=0.01) criterion = nn.CrossEntropyLoss() train(net, train_data, test_data, 10, optimizer, criterion)
从上面可以看出,对于训练集,不做数据增强跑 10 次,准确率已经到了 95%,而使用了数据增强,跑 10 次准确率只有 75%,说明数据增强之后变得更难了。
而对于测试集,使用数据增强进行训练的时候,准确率会比不使用更高,因为数据增强提高了模型应对于更多的不同数据集的泛化能力,所以有更好的效果。
以上就是深度学习入门之Pytorch 数据增强的实现的详细内容,更多关于Pytorch 数据增强的资料请关注其它相关文章!
Pytorch,数据增强
免责声明:本站文章均来自网站采集或用户投稿,网站不提供任何软件下载或自行开发的软件! 如有用户或公司发现本站内容信息存在侵权行为,请邮件告知! 858582#qq.com
《魔兽世界》大逃杀!60人新游玩模式《强袭风暴》3月21日上线
暴雪近日发布了《魔兽世界》10.2.6 更新内容,新游玩模式《强袭风暴》即将于3月21 日在亚服上线,届时玩家将前往阿拉希高地展开一场 60 人大逃杀对战。
艾泽拉斯的冒险者已经征服了艾泽拉斯的大地及遥远的彼岸。他们在对抗世界上最致命的敌人时展现出过人的手腕,并且成功阻止终结宇宙等级的威胁。当他们在为即将于《魔兽世界》资料片《地心之战》中来袭的萨拉塔斯势力做战斗准备时,他们还需要在熟悉的阿拉希高地面对一个全新的敌人──那就是彼此。在《巨龙崛起》10.2.6 更新的《强袭风暴》中,玩家将会进入一个全新的海盗主题大逃杀式限时活动,其中包含极高的风险和史诗级的奖励。
《强袭风暴》不是普通的战场,作为一个独立于主游戏之外的活动,玩家可以用大逃杀的风格来体验《魔兽世界》,不分职业、不分装备(除了你在赛局中捡到的),光是技巧和战略的强弱之分就能决定出谁才是能坚持到最后的赢家。本次活动将会开放单人和双人模式,玩家在加入海盗主题的预赛大厅区域前,可以从强袭风暴角色画面新增好友。游玩游戏将可以累计名望轨迹,《巨龙崛起》和《魔兽世界:巫妖王之怒 经典版》的玩家都可以获得奖励。
更新日志
- 小骆驼-《草原狼2(蓝光CD)》[原抓WAV+CUE]
- 群星《欢迎来到我身边 电影原声专辑》[320K/MP3][105.02MB]
- 群星《欢迎来到我身边 电影原声专辑》[FLAC/分轨][480.9MB]
- 雷婷《梦里蓝天HQⅡ》 2023头版限量编号低速原抓[WAV+CUE][463M]
- 群星《2024好听新歌42》AI调整音效【WAV分轨】
- 王思雨-《思念陪着鸿雁飞》WAV
- 王思雨《喜马拉雅HQ》头版限量编号[WAV+CUE]
- 李健《无时无刻》[WAV+CUE][590M]
- 陈奕迅《酝酿》[WAV分轨][502M]
- 卓依婷《化蝶》2CD[WAV+CUE][1.1G]
- 群星《吉他王(黑胶CD)》[WAV+CUE]
- 齐秦《穿乐(穿越)》[WAV+CUE]
- 发烧珍品《数位CD音响测试-动向效果(九)》【WAV+CUE】
- 邝美云《邝美云精装歌集》[DSF][1.6G]
- 吕方《爱一回伤一回》[WAV+CUE][454M]