color-loss pytorch实现

Posted 流浪若相惜

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了color-loss pytorch实现相关的知识,希望对你有一定的参考价值。

DSLR-Quality Photos on Mobile Devices with Deep Convolutional Networks---colorloss-Pytorch实现

1.实现原理

最近在做图像增强相关的工作,偶然间看到了这篇文章,作者提出了一个损失叫做color-loss,根据文章描述该方法是通过模糊输入图像与ground-Truth的纹理、内容,仅仅保存图像的颜色信息实现图像颜色的校正。实现过程比较简单,首先构建一个高斯模糊核,然后利用高斯模糊核作为卷积核对图像进行卷积运算,得到模糊后的图像;然后计算输入图像与ground-Truth的MSE作为损失函数。
作者的github中有该模型的代码,但是是用TensorFlow实现的。因为我的代码pytorch的,所以自己重新改写了一下。在作者的代码中用到了深度可分离卷积,在pytorch中我没有对其进行深度可分离操作。
算是个深度学习的小白吧,有问题可以给我留言呀~~~

2.代码

// An highlighted block
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from math import exp, pi
import numpy as np
import cv2 as cv
import scipy.stats as st
import matplotlib.pyplot as plt

def gauss_kernel(kernlen=21, nsig=3, channels=1):
    interval = (2*nsig+1.)/(kernlen)
    x = np.linspace(-nsig-interval/2., nsig+interval/2., kernlen+1)
    kern1d = np.diff(st.norm.cdf(x))
    kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
    kernel = kernel_raw/kernel_raw.sum()
    out_filter = np.array(kernel, dtype = np.float32)
    out_filter = out_filter.reshape((kernlen, kernlen))
    # out_filter = np.repeat(out_filter, channels, axis = 0)
    return out_filter   # kernel_size=21

class SeparableConv2d(nn.Module):
    def __init__(self):
        super(SeparableConv2d, self).__init__()
        kernel = gauss_kernel(21, 3, 3)
        kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
        ## kernel_point = [[1.0]]
        ## kernel_point = torch.FloatTensor(kernel_point).unsqueeze(0).unsqueeze(0)
        # kernel = torch.FloatTensor(kernel).expand(3, 3, 21, 21)   # torch.expand()向输入的维度前面进行扩充,输入为三通道时,将weight扩展为[3,3,21,21]
        ## kernel_point = torch.FloatTensor(kernel_point).expand(3,3,1,1)
        self.weight = nn.Parameter(data=kernel, requires_grad=False)
        # self.pointwise = nn.Conv2d(1, 1, 1, 1, 0, 1, 1,bias=False)    # 单通道时in_channels=1,out_channels=1,三通道时,in_channels=3, out_channels=3  卷积核为随机的
        ## self.weight_point = nn.Parameter(data=kernel_point, requires_grad=False)

    def forward(self, img1):
        x = F.conv2d(img1, self.weight, groups=1,padding=10)
        ## x = F.conv2d(x, self.weight_point, groups=1, padding=0)  #卷积核为[1]
        # x = self.pointwise(x)
        return x
# plt.imshow(out_kernel)
# plt.imshow(out_kernel)

以上是关于color-loss pytorch实现的主要内容,如果未能解决你的问题,请参考以下文章

说话人识别损失函数的PyTorch实现与代码解读

Pytorch:为啥在 nn.modules.loss 和 nn.functional 模块中都实现了损失函数?

PyTorch 交叉熵损失函数内部原理简单实现

学习笔记Pytorch十二损失函数与反向传播

Pytorch 线性回归损失增加

Pytorch:损失函数