哪一个更有效:tf.where 还是元素乘法?
Posted
技术标签:
【中文标题】哪一个更有效:tf.where 还是元素乘法?【英文标题】:Which one is more efficient: tf.where or element-wise multiplication? 【发布时间】:2018-03-20 13:34:27 【问题描述】:我正在实现一个损失函数,它将使用由0s and 1s
组成的掩码张量(M)
来消除给定预测(P)
和ground-truth(G)
张量的一些损失值。
所以,我有两种可能的方法:
元素乘法:
loss = K.sum(M * K.binary_crossentropy(G, P))
条件选择:
bin_ce = K.binary_crossentropy(G, P)
loss = K.sum(tf.where(tf.equal(M, 1), bin_ce, 0))
那么,就运行时间而言,哪个更高效?
【问题讨论】:
您自己运行过任何基准测试吗? 我正在运行基准测试,但尚未完成。我事先征求您的意见。 我非常相信乘法的情况会更好......等待你的测试结果。我无法想象第二种情况使用少于 2 个步骤。 你是对的 :) 我做了基准测试,结果在我的答案中。 【参考方案1】:我做了基准测试,很明显乘法比条件选择好得多。
结果如下:
一张图表值一千字。
基准代码:
import keras.backend as K
import tensorflow as tf
import numpy as np
import sys
import time
import matplotlib.pyplot as plt
def elm(G, P, M):
return K.sum(M * K.binary_crossentropy(G, P))
def cond(G, P, M, t):
C = K.variable(np.zeros((t, t)))
bin_ce = K.binary_crossentropy(G, P)
return K.sum(tf.where(tf.equal(M, 1), bin_ce, C))
s = [100, 1000, 10000, 100000]
elms = []
conds = []
for t in s:
print t
t = int(t)
# number of 1s in mask
n = int(t/2)
M = np.zeros((t,t))
P = np.random.rand(t, t)
G = np.random.rand(t, t)
for i in range(n):
r = np.random.randint(0, t)
c = np.random.randint(0, t)
M[r,c] = 1
M = K.variable(M)
P = K.variable(P)
G = K.variable(G)
start_time = time.time()
elm(G, P, M)
elms.append(time.time() - start_time)
start_time = time.time()
cond(G, P, M, t)
conds.append(time.time() - start_time)
print elms
print conds
# create plot
fig, ax = plt.subplots()
index = np.arange(n_groups)
bar_width = 0.35
opacity = 0.8
rects1 = plt.bar(index, elms, bar_width,
alpha=opacity,
color='b',
label='Element-wise')
rects2 = plt.bar(index + bar_width, conds, bar_width,
alpha=opacity,
color='g',
label='Conditional')
plt.xlabel('Input tensor size')
plt.ylabel('Execution time (s)')
plt.title('')
plt.xticks(index + bar_width, ('100', '10e3', '10e4', '10e5'))
plt.legend()
plt.tight_layout()
plt.show()
【讨论】:
以上是关于哪一个更有效:tf.where 还是元素乘法?的主要内容,如果未能解决你的问题,请参考以下文章