在 PyTorch 中进行数据增强后得到糟糕的图像
Posted
技术标签:
【中文标题】在 PyTorch 中进行数据增强后得到糟糕的图像【英文标题】:Getting Bad Images After Data Augmentation in PyTorch 【发布时间】:2020-09-05 17:27:07 【问题描述】:我正在研究一个核分割问题,我试图在染色组织的图像中识别核的位置。给定的训练数据集有一张染色组织的图片和一个带有细胞核位置的掩码。由于数据集很小,我想尝试在 PyTorch 中进行数据增强,但在这样做之后,由于某种原因,当我输出我的蒙版图像时,它看起来很好,但对应的组织图像不正确。
我所有的训练图像都在X_train
中,形状为(128, 128, 3)
,对应的掩码在Y_train
中,形状为(128, 128, 1)
,类似的交叉验证图像和掩码分别在X_val
和Y_val
。
Y_train
和Y_val
有dtype = np.bool
、X_train
和X_val
有dtype = np.uint8
。
在数据增强之前,我会像这样检查我的图像:
fig, axis = plt.subplots(2, 2)
axis[0][0].imshow(X_train[0].astype(np.uint8))
axis[0][1].imshow(np.squeeze(Y_train[0]).astype(np.uint8))
axis[1][0].imshow(X_val[0].astype(np.uint8))
axis[1][1].imshow(np.squeeze(Y_val[0]).astype(np.uint8))
输出如下: Before Data Augmentation
对于数据扩充,我定义了一个自定义类如下:
在这里,我将torchvision.transforms.functional
导入为TF
和torchvision.transforms as transforms
。 images_np
和 masks_np
是 numpy 数组的输入。
class Nuc_Seg(Dataset):
def __init__(self, images_np, masks_np):
self.images_np = images_np
self.masks_np = masks_np
def transform(self, image_np, mask_np):
ToPILImage = transforms.ToPILImage()
image = ToPILImage(image_np)
mask = ToPILImage(mask_np.astype(np.int32))
angle = random.uniform(-10, 10)
width, height = image.size
max_dx = 0.2 * width
max_dy = 0.2 * height
translations = (np.round(random.uniform(-max_dx, max_dx)), np.round(random.uniform(-max_dy, max_dy)))
scale = random.uniform(0.8, 1.2)
shear = random.uniform(-0.5, 0.5)
image = TF.affine(image, angle = angle, translate = translations, scale = scale, shear = shear)
mask = TF.affine(mask, angle = angle, translate = translations, scale = scale, shear = shear)
image = TF.to_tensor(image)
mask = TF.to_tensor(mask)
return image, mask
def __len__(self):
return len(self.images_np)
def __getitem__(self, idx):
image_np = self.images_np[idx]
mask_np = self.masks_np[idx]
image, mask = self.transform(image_np, mask_np)
return image, mask
接下来是:
我用过from torch.utils.data import DataLoader
train_dataset = Nuc_Seg(X_train, Y_train)
train_loader = DataLoader(train_dataset, batch_size = 16, shuffle = True)
val_dataset = Nuc_Seg(X_val, Y_val)
val_loader = DataLoader(val_dataset, batch_size = 16, shuffle = True)
在这一步之后,我尝试使用以下方法检查我的第一组训练图像和蒙版:
%matplotlib inline
for ex_img, ex_mask in train_loader:
img = ex_img[0]
img = img.reshape(128, 128, 3)
mask = ex_mask[0]
mask = mask.reshape(128, 128)
img = img.numpy()
mask = mask.numpy()
fig, (axis_1, axis_2) = plt.subplots(1, 2)
axis_1.imshow(img.astype(np.uint8))
axis_2.imshow(mask.astype(np.uint8))
break
我得到这个作为我的输出: After Data Augmentation 1
当我将axis_1.imshow(img.astype(np.uint8))
更改为axis_1.imshow(img)
时,
我得到这张图片: After Data Augmentation 2
面具的图像是正确的,但由于某种原因,细胞核的图像是错误的。使用.astype(np.uint8)
,组织图像是完全黑色的。
没有.astype(np.uint8)
,原子核的位置是正确的,但是配色方案全乱了(我希望图像像数据增强之前看到的那样,灰色或粉红色),加上 9 个副本出于某种原因,在网格中显示相同的图像。你能帮我得到组织图像的正确输出吗?
【问题讨论】:
【参考方案1】:您正在将图像转换为 PyTorch 张量,并且在 PyTorch 中,图像的大小为 [C, H, W]。当您将它们可视化时,您会将张量转换回 NumPy 数组,其中图像的大小为 [H, W, C]。因此,您尝试重新排列维度,但您使用的是torch.reshape
,它不会交换维度,而只会以不同的方式对数据进行分区。
一个例子更清楚地说明了这一点:
# Incrementing numbers with size 2 x 3 x 3
image = torch.arange(2 * 3 * 3).reshape(2, 3, 3)
# => tensor([[[ 0, 1, 2],
# [ 3, 4, 5],
# [ 6, 7, 8]],
#
# [[ 9, 10, 11],
# [12, 13, 14],
# [15, 16, 17]]])
# Reshape keeps the same order of elements but for a different size
# The numbers are still incrementing from left to right
image.reshape(3, 3, 2)
# => tensor([[[ 0, 1],
# [ 2, 3],
# [ 4, 5]],
#
# [[ 6, 7],
# [ 8, 9],
# [10, 11]],
#
# [[12, 13],
# [14, 15],
# [16, 17]]])
要重新排序维度,您可以使用permute
:
# Dimensions are swapped
# Now the numbers increment from top to bottom
image.permute(1, 2, 0)
# => tensor([[[ 0, 9],
# [ 1, 10],
# [ 2, 11]],
#
# [[ 3, 12],
# [ 4, 13],
# [ 5, 14]],
#
# [[ 6, 15],
# [ 7, 16],
# [ 8, 17]]])
使用
.astype(np.uint8)
,组织图像是全黑的。
PyTorch 图像表示为具有 [0, 1] 之间值的浮点数,但 NumPy 使用 [0, 255] 之间的整数值。将浮点值转换为 np.uint8
将只得到 0 和 1,其中不等于 1 的所有内容都将设置为 0,因此整个图像是黑色的。
您需要将这些值乘以 255 以使它们进入 [0, 255] 的范围内。
img = img.permute(1, 2, 0) * 255
img = img.numpy().astype(np.uint8)
当您使用transforms.ToPILImage
(或使用TF.to_pil_image
,如果您更喜欢功能版本)将张量转换为 PIL 图像时,此转换也会自动完成,并且 PIL 图像可以直接转换为 NumPy 数组。这样您就不必担心尺寸、值范围或类型,上面的代码可以替换为:
img = np.array(TF.to_pil_image(img))
【讨论】:
非常感谢!我有最后一个问题,当我用img = img.numpy()
做img = img.permute(1, 2, 0)
然后去plt.imshow(img)
时,我也得到了正确的图像。但是由于 img 的值在 [0, 1] 范围内,而 numpy 表示图像在 [0, 255] 范围内,那么 matplotlib 是如何知道如何缩放图像的呢?这可能是一个愚蠢的问题,因为我是新手。
matplotlib.imshow
也接受图像作为浮点数 [0, 1]。它可能会检查类型以将一种转换为另一种。
啊,这就解释了。对不起,我还有最后一个疑问。在数据增强步骤之后,在输出图像中,图像边缘以外的区域会自动填充为黑色,以完成 128x128 网格,如this。如何摆脱黑色区域并用与图像其余部分相似的颜色填充它? torchvision.transforms.functional.affine
中有一个 fillcolor
选项,但它只接受 int 值,不接受 fillcolor = reflect
之类的选项。
转换是用 PIL 完成的,据我所知,它只支持一个恒定的填充值。您需要为此使用其他库。以上是关于在 PyTorch 中进行数据增强后得到糟糕的图像的主要内容,如果未能解决你的问题,请参考以下文章