在Tensorflow中生成关键点热图
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了在Tensorflow中生成关键点热图相关的知识,希望对你有一定的参考价值。
我正在尝试训练面部关键点检测模型。这是Stacked HourGlass模型。它输出256x256x68尺寸张量。 68个输出中的每一个都将在关键点周围具有热区域。我已经定义了模型和图形构造很好。我的问题是生成数据集。
我需要从68x2维地标张量生成256x256x68维标签_tensor。虽然我可以在numpy中完成它并将其保存在TFRecord中,但我想在tf.data.Dataset API的parse_function中探索并查看是否可以在训练时执行此操作。
对于每个热图,我需要在相应的地标点的x,y位置绘制高斯。
码
我在parse_function中有以下代码:
# heatmaps
joints = tf.stack([points_x, points_y], axis=1)
heatmaps = _generate_heatmaps(joints, 1., IMG_DIM)
这是_generate_heatmaps函数:
def _generate_heatmaps(joints, sigma, outres):
npart = 68
gtm = tf.placeholder(tf.float32, shape=[None, outres, outres, npart])
gtmaps = tf.zeros_like(gtm)
for i in range(npart):
visibility = 1
if visibility > 0:
gtmaps[:, :, :, i] = _draw_hotspot(gtmaps[:, :, :, i], joints[:, i, :], sigma)
return gtmaps
_draw_hotspot函数:
def _draw_hotspot(img, pt, sigma, type='Gaussian'):
# Draw a 2D gaussian
# Adopted from https://github.com/anewell/pose-hg-train/blob/master/src/pypose/draw.py
# Check that any part of the gaussian is in-bounds
ul = [(pt[:,0] - 3 * sigma), (pt[:,1] - 3 * sigma)]
br = [(pt[:,0] + 3 * sigma + 1), (pt[:,1] + 3 * sigma + 1)]
# if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or
# br[0] < 0 or br[1] < 0):
# # If not, just return the image as is
# return img
# Generate gaussian
size = 6 * sigma + 1
x = np.arange(0, size, 1, float)
y = x[:, np.newaxis]
x0 = y0 = size // 2
# The gaussian is not normalized, we want the center value to equal 1
# if type == 'Gaussian':
g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
# elif type == 'Cauchy':
# g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5)
# Usable gaussian range
g_x = [tf.clip_by_value(-1*ul[0], -100, 0)*-1, tf.minimum(br[0], img.shape[2].value) - ul[0]]
g_y = [tf.clip_by_value(-1*ul[1], -100, 0)*-1, tf.minimum(br[1], img.shape[1].value) - ul[1]]
g_x = tf.cast(g_x, tf.int64)
g_y = tf.cast(g_y, tf.int64)
# Image range
img_x = [tf.clip_by_value(ul[0], 0, img.shape[1].value), tf.clip_by_value(br[0], 0, img.shape[1].value)]
img_y = [tf.clip_by_value(ul[1], 0, img.shape[2].value), tf.clip_by_value(br[1], 0, img.shape[2].value)]
img_x = tf.cast(img_x, tf.int64)
img_y = tf.cast(img_y, tf.int64)
# img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
img_slice = tf.image.extract_glimpse # ... stuck ...
return img
我需要将这个numpy代码转换为tensorflow代码img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
就在最后一行!有人可以帮忙吗?
答案
这是一种使用SciPy的方法,您可以使用tf.py_func
处理TF管道:
from scipy.stats import multivariate_normal
pos = np.dstack(np.mgrid[0:68:1, 0:68:1])
rv = multivariate_normal(mean=[22,43], cov=4)
plt.imshow(rv.pdf(pos))
以上是关于在Tensorflow中生成关键点热图的主要内容,如果未能解决你的问题,请参考以下文章
使用在另一个片段(NPE)中生成的值设置片段的 TextView [重复]
Android:将片段和弹出窗口的点击事件中生成的变量传递给活动的方法