pytorch的 torchvision transforms
torchvision是pytorch的数据集,也包含常用数据处理工具,包含几个模块:
- datasets(包含常用的数据集:minist,COCO等)
- models(包含常用的著名网络结构:AlexNet,VGG,ResNet等等,你可以使用随机初始化的网络结构,也可以使用已经训练好的网络)
- transforms(对PIL.Image进行变换处理:Scale(缩放)、CenterCrop(中心切割)、Pad(填充)等),PIL(Python Imaging Library)是python对图形处理的库。
下面具体讲一下transforms中常用函数的使用
transforms.Scale(size)
将输入的PIL.Image
重新改变大小成给定的size,size是最小边的边长。举个例子,如果原图的height>width,那么改变大小后的图片大小是(size*height/width, size)
,若是height<width,那么就是(size, size*width/height)
。
例:
from PIL import Image
from torchvision import transforms
crop=transforms.Scale(12)
img=Image.open('test.jpg')
print(type(img))
print(img.size)
print(crop(img).size)
输出:
<class ‘PIL.JpegImagePlugin.JpegImageFile’>
(261, 230)
(13, 12)
transforms.ToTensor()
把一个取值范围是[0,255]的PIL.Image
或者shape
为(Height,Width,Channel)
的numpy.ndarray
,转换成形状为[Channel,Height,Width]
,(也就是把通道数放第一维度了)且取值范围是[0,1.0]的torch.FloadTensor
。
例:
from PIL import Image
from torchvision import transforms
import numpy as np
im = Image.open('test.jpg')
im_arry=np.asarray(im)
print(im_arry.shape)
im_tensor=transforms.ToTensor()(im)#或者用im_arry也是可以的
print(im_tensor)
print(im_tensor.shape)
输出:
(230, 261, 3)
tensor([[[0.2314, 0.2392, 0.2392, …, 0.2314, 0.2314, 0.2392],
[0.2314, 0.2314, 0.2314, …, 0.2314, 0.2314, 0.2314],
[0.2314, 0.2314, 0.2314, …, 0.2314, 0.2314, 0.2314],
…,
torch.Size([3, 230, 261])
可以看出通道数的确放前面去了,且取值范围都在0-1之间,而且transforms.ToTensor()是直接处理PIL image也可以是image array
transforms.ToPILImage
与前面的相反:将shape为(C,H,W)
的Tensor或shape为(H,W,C)
的numpy.ndarray
转换成PIL.Image
,值不变 。
transforms.Normalize(mean, std)
给定均值:(R,G,B) ,方差:(R,G,B),将会把Tensor正则化。即:Normalized_image=(image-mean)/std
例:
from PIL import Image
from torchvision import transforms
import numpy as np
im = Image.open('test.jpg')
im_arry=np.asarray(im)
print(im_arry.shape)
im_tensor=transforms.ToTensor()(im)
im_Normal=transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))(im_tensor)
print(im_tensor)
print(im_tensor.shape)
输出:
(230, 261, 3)
tensor([[[-0.5373, -0.5216, -0.5216, …, -0.5373, -0.5373, -0.5216],
[-0.5373, -0.5373, -0.5373, …, -0.5373, -0.5373, -0.5373],
[-0.5373, -0.5373, -0.5373, …, -0.5373, -0.5373, -0.5373],
…,
[-0.6627, -0.6627, -0.6627, …, -0.6627, -0.6627, -0.6627],
[-0.6627, -0.6627, -0.6627, …, -0.6627, -0.6627, -0.6627],
[-0.6627, -0.6627, -0.6627, …, -0.6627, -0.6627, -0.6627]],
torch.Size([3, 230, 261])
一定要把图像先转换为tensor,在用此函数。
transforms.Pad(padding, fill=0)
将给定的PIL.Image
的所有边用给定的pad value填充。 padding:要填充多少像素 fill:用什么值填充
例:
from PIL import Image
from torchvision import transforms
import numpy as np
im = Image.open('test.jpg')
print(im)
im_re=transforms.Resize((3, 3))(im)
print(np.asarray(im_re))
print(np.asarray(im_re).shape)
im_pad=transforms.Pad(padding=1,fill=0)(im_re)
print(np.asarray(im_pad))
print(np.asarray(im_pad).shape)
输出:
<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=261x230 at 0x20FBDA02EB8>
[[[60 61 61]
[61 62 62]
[62 64 64]]
[[53 54 53]
[53 53 52]
[51 51 50]]
[[45 45 45]
[45 45 45]
[45 45 45]]]
(3, 3, 3)
[[[ 0 0 0]
[ 0 0 0]
[ 0 0 0]
[ 0 0 0]
[ 0 0 0]]
[[ 0 0 0]
[60 61 61]
[61 62 62]
[62 64 64]
[ 0 0 0]]
[[ 0 0 0]
[53 54 53]
[53 53 52]
[51 51 50]
[ 0 0 0]]
[[ 0 0 0]
[45 45 45]
[45 45 45]
[45 45 45]
[ 0 0 0]]
[[ 0 0 0]
[ 0 0 0]
[ 0 0 0]
[ 0 0 0]
[ 0 0 0]]]
(5, 5, 3)
可以看出图片填充在每一个通道上都进行了填充,可以把初始[3,3,3]想像成一个333的立方体,然后上下两个面不动,周围4个面各向外推出2,就得到553的立方体。
transforms.Resize((height, width))
resize图像,例子见上
transforms.Compose()
就是把多个transforms组合起来.
例子
from PIL import Image
from torchvision import transforms
import numpy as np
im = Image.open('test.jpg')
print(im)
com=transforms.Compose([
transforms.Resize((3,4)),
transforms.ToTensor(),
])
im_com=com(im)
print(im_com)
输出:
<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=261x230 at 0x1BDDCFC2E10>
tensor([[[0.2353, 0.2314, 0.2471, 0.2392],
[0.2078, 0.2078, 0.2039, 0.1961],
[0.1765, 0.1765, 0.1765, 0.1765]],…
transforms.Lambda()
用户可以用transforms.Lambda()
函数自行定义transform操作,该操作不是由torchvision库所拥有的,其中参数是lambda表示的是自定义函数。
举例说明:
比如当我们想要截取图像,但并不想在随机位置截取,而是希望在一个自己指定的位置去截取
那么你就需要自定义一个截取函数,然后使用transforms.Lambda
去封装它即可,如:
# coding:utf-8
from torchvision import transforms as T
def __crop(img, pos, size):
"""
:param img: 输入的图像
:param pos: 图像截取的位置,类型为元组,包含(x, y)
:param size: 图像截取的大小
:return: 返回截取后的图像
"""
ow, oh = img.size
x1, y1 = pos
tw = th = size
# 有足够的大小截取
# img.crop坐标表示 (left, upper, right, lower)
if (ow > tw or oh > th):
return img.crop((x1, y1, x1+tw, y1+th))
return img
# 然后使用transforms.Lambda封装其为transforms策略
# 然后定义新的transforms为
normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
data_transforms = T.Compose([
T.Lambda(lambda img: __crop(img, (5,5), 224)),
T.RandomHorizontalFlip(), # 随机水平翻转给定的PIL.Image,翻转概率为0.5
T.ToTensor(), # 转成Tensor格式,大小范围为[0,1]
normalize
])