数据增强(每10度旋转一次,每次旋转生成一张增强图片(0°~350°共36张,加上原图共37张),然后对每张图片随机裁剪10张patch,最后得到原始图片数*37*10数量的图片)
Posted by fourmi_gsj
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了数据增强(每10度进行旋转,进行一次增强,然后对每张图片进行扩充10张patch,最后得到原始图片数*37*10数量的图片)相关的知识,希望对你有一定的参考价值。
# -*- coding: utf-8 -*-
"""
Fourmi

Rotation + random-patch data augmentation for grayscale image/label pairs.

1. roatate_img_label_to_file() writes rotated copies (0, 10, ..., 350 degrees)
   of every training image and its label back into the training folders.
2. get_data() loads every image/label pair (originals + rotated copies) and
   crops random, aligned image/label patches.
3. The script section saves the patches as numbered .jpg (image) and .png
   (label) files.
"""
import cv2
import os
import numpy as np
import random
import math


def extract_random(full_imgs, full_masks, patch_h, patch_w, N_patches):
    """Randomly crop ``N_patches`` aligned image/label patches.

    Every image contributes the same number of patches,
    ``N_patches // len(full_imgs)``; the crop position is drawn uniformly so
    the patch lies fully inside the image.

    Args:
        full_imgs: sequence/array of 2-D images indexed as ``img[row, col]``.
        full_masks: label masks, one per image, same sizes as ``full_imgs``.
        patch_h: patch height in pixels.
        patch_w: patch width in pixels.
        N_patches: total number of patches; must be a multiple of
            ``len(full_imgs)``.

    Returns:
        ``(patches, patches_masks)``, each of shape
        ``(N_patches, patch_h, patch_w)`` with the dtype of the inputs.

    Raises:
        ValueError: if there are no images, ``N_patches`` is not a multiple
            of the number of images, or a patch does not fit in an image.
    """
    n_imgs = len(full_imgs)
    if n_imgs == 0 or N_patches % n_imgs != 0:
        raise ValueError(
            "N_patches (%d) must be a multiple of the number of images (%d)"
            % (N_patches, n_imgs))
    patch_per_img = N_patches // n_imgs
    print("patches per full image: " + str(patch_per_img))

    # Keep the source dtype (uint8 for cv2-loaded images): np.empty's default
    # float64 would need 8x the memory for exactly the same pixel values.
    patches = np.empty((N_patches, patch_h, patch_w),
                       dtype=np.asarray(full_imgs[0]).dtype)
    patches_masks = np.empty((N_patches, patch_h, patch_w),
                             dtype=np.asarray(full_masks[0]).dtype)

    iter_tot = 0
    for img, mask in zip(full_imgs, full_masks):
        # Use each image's own size; images are not guaranteed to match.
        img_h, img_w = img.shape[:2]
        if patch_h > img_h or patch_w > img_w:
            raise ValueError("patch %dx%d does not fit in image %dx%d"
                             % (patch_h, patch_w, img_h, img_w))
        for _ in range(patch_per_img):
            # Sample the top-left corner so the slice is exactly
            # patch_h x patch_w, also for odd patch sizes.
            x0 = random.randint(0, img_w - patch_w)
            y0 = random.randint(0, img_h - patch_h)
            patches[iter_tot] = img[y0:y0 + patch_h, x0:x0 + patch_w]
            patches_masks[iter_tot] = mask[y0:y0 + patch_h, x0:x0 + patch_w]
            iter_tot += 1
    return patches, patches_masks


def imagePadding(img, interpolation=cv2.INTER_AREA):
    """Resize ``img`` to a square whose side is twice its diagonal.

    Despite the name, no padding is added: the image is stretched (aspect
    ratio is not preserved) to ``2 * int(sqrt(h**2 + w**2))`` pixels per side.

    Args:
        img: 2-D image.
        interpolation: cv2 interpolation flag; pass ``cv2.INTER_NEAREST`` for
            label masks so no intermediate label values are created.

    Returns:
        The resized square image.
    """
    img_h, img_w = img.shape[:2]
    side = 2 * int(math.sqrt(img_h * img_h + img_w * img_w))
    return cv2.resize(img, (side, side), interpolation=interpolation)


def get_data(data_imgs_org, data_groundTruth, patch_height, patch_width,
             N_subimgs):
    """Load all image/label pairs and return ``N_subimgs`` random patch pairs.

    See ReadandProcessImage for the loading/preprocessing and extract_random
    for the patch sampling and its constraints on ``N_subimgs``.
    """
    imgs_org, imgs_groundTruth = ReadandProcessImage(data_imgs_org,
                                                     data_groundTruth)
    print('imgs.shape', imgs_org.shape)
    print('imgs_groundTruth', imgs_groundTruth.shape)
    return extract_random(imgs_org, imgs_groundTruth, patch_height,
                          patch_width, N_subimgs)


def _read_gray_pair(img_path, label_path):
    """Read an image and its label as grayscale and check their sizes match.

    Raises:
        IOError: if either file cannot be decoded (cv2.imread returns None).
        ValueError: if the image and label shapes differ.
    """
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise IOError('cannot read image: ' + img_path)
    if label is None:
        raise IOError('cannot read label: ' + label_path)
    if img.shape != label.shape:
        raise ValueError('image %s and label %s differ in size: %s vs %s'
                         % (img_path, label_path, img.shape, label.shape))
    return img, label


def ReadandProcessImage(orgImgPath, groundTruthPath):
    """Load every image under ``orgImgPath`` with its label and preprocess.

    For each file ``<stem>.<ext>`` found (recursively) under ``orgImgPath``,
    the label ``<stem>.png`` is read from ``groundTruthPath``. Images are
    histogram-equalised; images and labels are then resized with
    imagePadding (labels with nearest-neighbour interpolation, assuming
    they are discrete class masks).

    Returns:
        ``(images, labels)`` as numpy arrays built from the per-file lists.
    """
    images = []
    labels = []
    for root, _dirs, files in os.walk(orgImgPath, topdown=False):
        for file_name in files:
            stem = os.path.splitext(file_name)[0]
            img_path = os.path.join(root, file_name)
            label_path = os.path.join(groundTruthPath, stem + '.png')
            print('ImgPath:', img_path)
            print('LabelPath:', label_path)
            img, label = _read_gray_pair(img_path, label_path)
            img = cv2.equalizeHist(img)
            images.append(imagePadding(img))
            labels.append(imagePadding(label, interpolation=cv2.INTER_NEAREST))
    return np.array(images), np.array(labels)


def _rotate_pair(img, label, angle):
    """Rotate an image and its label by ``angle`` degrees about the centre.

    The output keeps the input size. No padding is needed: warpAffine's
    default constant border fills pixels that map from outside the source
    with 0. The label uses nearest-neighbour sampling so no new label values
    are interpolated.
    """
    h, w = img.shape[:2]
    # getRotationMatrix2D takes the centre as (x, y); warpAffine takes
    # dsize as (width, height).
    center = ((w - 1) / 2.0, (h - 1) / 2.0)
    M = cv2.getRotationMatrix2D(center, angle, 1.0)
    img_rot = cv2.warpAffine(img, M, (w, h), flags=cv2.INTER_LINEAR)
    label_rot = cv2.warpAffine(label, M, (w, h), flags=cv2.INTER_NEAREST)
    return img_rot, label_rot


def roatate_img_label_to_file(imgPath, labelPath, name_offset=115,
                              angle_step=10):
    """Write rotated copies of every image/label pair into the input folders.

    Every file under ``imgPath`` (recursively) is paired with
    ``<stem>.png`` in ``labelPath`` and rotated by 0, angle_step, ... degrees
    (below 360). Rotated images are written to ``imgPath`` as ``<n>.jpg`` and
    rotated labels to ``labelPath`` as ``<n>.png``, with ``n`` counting up
    from ``name_offset + 1``.

    NOTE: outputs land in the input folders, so running this twice also
    rotates the previously written copies.
    """
    # Snapshot the inputs first so freshly written outputs are never
    # treated as new inputs.
    sources = [(root, name)
               for root, _dirs, files in os.walk(imgPath, topdown=False)
               for name in files]
    next_id = name_offset + 1
    for root, name in sources:
        img_path = os.path.join(root, name)
        label_path = os.path.join(labelPath,
                                  os.path.splitext(name)[0] + '.png')
        print('imgpath:', img_path)
        print('labelpath:', label_path)
        img, label = _read_gray_pair(img_path, label_path)
        print('imgshape:', img.shape)
        print('labelshape:', label.shape)
        for angle in range(0, 360, angle_step):
            img_rot, label_rot = _rotate_pair(img, label, angle)
            cv2.imwrite(os.path.join(imgPath, str(next_id) + '.jpg'), img_rot)
            cv2.imwrite(os.path.join(labelPath, str(next_id) + '.png'),
                        label_rot)
            next_id += 1
        print("ROTATW DONE!!!!")


if __name__ == "__main__":
    data_train_imgs_org = "/home/chendali1/Gsj/JX/Image/train/"
    data_test_imgs_org = "/home/chendali1/Gsj/JX/Image/test/"
    data_train_grountTruth = "/home/chendali1/Gsj/JX/GT/train/"
    data_test_grountTruth = "/home/chendali1/Gsj/JX/GT/test/"
    patches_path_train = '/home/chendali1/Gsj/JX/Patches/Org/train/'
    patches_path_test = '/home/chendali1/Gsj/JX/Patches/Org/test/'
    patches_path_label_train = '/home/chendali1/Gsj/JX/Patches/Label/train/'
    patches_path_label_test = '/home/chendali1/Gsj/JX/Patches/Label/test/'

    NUM_ORIGINAL_IMAGES = 115  # original training images before rotation
    ANGLE_STEP = 10            # degrees between rotated copies
    PATCHES_PER_IMAGE = 10

    for out_dir in (patches_path_train, patches_path_test,
                    patches_path_label_train, patches_path_label_test):
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

    roatate_img_label_to_file(data_train_imgs_org, data_train_grountTruth,
                              name_offset=NUM_ORIGINAL_IMAGES,
                              angle_step=ANGLE_STEP)

    # Each original now exists 1 + 360/ANGLE_STEP times (original + 36
    # rotations, including 0 degrees), i.e. 37 copies.
    copies_per_image = 1 + 360 // ANGLE_STEP
    n_patches = copies_per_image * NUM_ORIGINAL_IMAGES * PATCHES_PER_IMAGE
    train_patches, train_groundTruth = get_data(
        data_train_imgs_org, data_train_grountTruth, 224, 224, n_patches)
    for i, (patch, mask) in enumerate(zip(train_patches, train_groundTruth)):
        cv2.imwrite(os.path.join(patches_path_train, str(i) + '.jpg'), patch)
        cv2.imwrite(os.path.join(patches_path_label_train, str(i) + '.png'),
                    mask)
以上是关于数据增强(每10度进行旋转,进行一次增强,然后对每张图片进行扩充10张patch,最后得到原始图片数*37*10数量的图片)的主要内容,如果未能解决你的问题,请参考以下文章
PyTorch torchvision 数据增强 翻转 旋转