color-loss pytorch实现
Posted 流浪若相惜
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了color-loss pytorch实现相关的知识,希望对你有一定的参考价值。
DSLR-Quality Photos on Mobile Devices with Deep Convolutional Networks---colorloss-Pytorch实现
1.实现原理
最近在做图像增强相关的工作,偶然间看到了这篇文章,作者提出了一个损失叫做color-loss,根据文章描述该方法是通过模糊输入图像与ground-Truth的纹理、内容,仅仅保存图像的颜色信息实现图像颜色的校正。实现过程比较简单,首先构建一个高斯模糊核,然后利用高斯模糊核作为卷积核对图像进行卷积运算,得到模糊后的图像;然后计算输入图像与ground-Truth的MSE作为损失函数。
作者的github中有该模型的代码,但是是用TensorFlow实现的。因为我的代码pytorch的,所以自己重新改写了一下。在作者的代码中用到了深度可分离卷积,在pytorch中我没有对其进行深度可分离操作。
算是个深度学习的小白吧,有问题可以给我留言呀~~~
2.代码
// An highlighted block
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from math import exp, pi
import numpy as np
import cv2 as cv
import scipy.stats as st
import matplotlib.pyplot as plt
def gauss_kernel(kernlen=21, nsig=3, channels=1):
interval = (2*nsig+1.)/(kernlen)
x = np.linspace(-nsig-interval/2., nsig+interval/2., kernlen+1)
kern1d = np.diff(st.norm.cdf(x))
kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
kernel = kernel_raw/kernel_raw.sum()
out_filter = np.array(kernel, dtype = np.float32)
out_filter = out_filter.reshape((kernlen, kernlen))
# out_filter = np.repeat(out_filter, channels, axis = 0)
return out_filter # kernel_size=21
class SeparableConv2d(nn.Module):
def __init__(self):
super(SeparableConv2d, self).__init__()
kernel = gauss_kernel(21, 3, 3)
kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
## kernel_point = [[1.0]]
## kernel_point = torch.FloatTensor(kernel_point).unsqueeze(0).unsqueeze(0)
# kernel = torch.FloatTensor(kernel).expand(3, 3, 21, 21) # torch.expand()向输入的维度前面进行扩充,输入为三通道时,将weight扩展为[3,3,21,21]
## kernel_point = torch.FloatTensor(kernel_point).expand(3,3,1,1)
self.weight = nn.Parameter(data=kernel, requires_grad=False)
# self.pointwise = nn.Conv2d(1, 1, 1, 1, 0, 1, 1,bias=False) # 单通道时in_channels=1,out_channels=1,三通道时,in_channels=3, out_channels=3 卷积核为随机的
## self.weight_point = nn.Parameter(data=kernel_point, requires_grad=False)
def forward(self, img1):
x = F.conv2d(img1, self.weight, groups=1,padding=10)
## x = F.conv2d(x, self.weight_point, groups=1, padding=0) #卷积核为[1]
# x = self.pointwise(x)
return x
# plt.imshow(out_kernel)
# plt.imshow(out_kernel)
以上是关于color-loss pytorch实现的主要内容,如果未能解决你的问题,请参考以下文章