Tensorflow2 tf.nn.maxpool2d()进行池化运算及其可视化

Posted 空中旋转篮球

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Tensorflow2 tf.nn.maxpool2d()进行池化运算及其可视化相关的知识,希望对你有一定的参考价值。

1.tf.nn.maxpool2d()函数介绍

tf.nn.max_pool2d(input, ksize, strides, padding, data_format='NHWC', name=None)

参数说明:

Args

inputA 4-D Tensor of the format specified by data_format.
ksizeAn int or list of ints that has length 12 or 4. The size of the window for each dimension of the input tensor.
stridesAn int or list of ints that has length 12 or 4. The stride of the sliding window for each dimension of the input tensor.
paddingEither the string "SAME" or "VALID" indicating the type of padding algorithm to use, or a list indicating the explicit paddings at the start and end of each dimension. When explicit padding is used and data_format is "NHWC", this should be in the form [[0, 0], [pad_top,pad_bottom], [pad_left, pad_right], [0, 0]]. When explicit padding used and data_format is "NCHW", this should be in the form [[0, 0], [0, 0],[pad_top, pad_bottom], [pad_left, pad_right]]. When using explicit padding, the size of the paddings cannot be greater than the sliding window size.
data_formatA string. 'NHWC', 'NCHW' and 'NCHW_VECT_C' are supported.
nameOptional name for the operation.

2.使用数据

flower_photos数据集中选一张玫瑰花图片

3.池化运算代码

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import tensorflow as tf

data = Image.open("roses_4483444865_65962cea07_m.jpg")  # 返回一个PIL图像对象
plt.imshow(data)
plt.show()

x = np.array(data)
x = x / 255
x = x.reshape(1, 240, 180, 3)

image_tensor = tf.convert_to_tensor(x)
x_input = tf.cast(image_tensor, tf.float32)

print("x_in{}", x_input.shape)

kernel_in = np.array(
    [[[-1, 1]], [[1, 1]],])
print(kernel_in.shape)
kernel = tf.constant(kernel_in, dtype=tf.float32)

Z1=tf.nn.max_pool2d(x_input, [1,2,2,1], strides=[1, 2, 2, 1], padding='SAME')

x_max_pool2d = np.array(Z1)
print(x_max_pool2d.shape)
x_max_pool2d = x_max_pool2d.reshape(120, 90, 3)
plt.imshow(x_max_pool2d)
plt.show()

4.计算结果

 

 

以上是关于Tensorflow2 tf.nn.maxpool2d()进行池化运算及其可视化的主要内容,如果未能解决你的问题,请参考以下文章

TensorFlow2.0--TensorFlow2.0构架

TensorFlow2 入门指南 | 06 TensorFLow2 高阶操作汇总

TensorFlow2 入门指南 | 06 TensorFLow2 高阶操作汇总

Tensorflow2.0笔记

tensorflow2.0(2)-自定义Dense层以及训练过程

Tensorflow2.0笔记