稀疏分类交叉熵未按预期工作

Posted

技术标签:

【中文标题】稀疏分类交叉熵未按预期工作【英文标题】:Sparse categorical cross entropy not working as expected 【发布时间】:2020-09-10 07:45:13 【问题描述】:

我正在研究 MIMIC 3 数据集以对患者进行分类。我有 7000 个类,类的值范围高达 50000,标签是整数。所以我使用了稀疏分类熵损失,这样我就不必担心标签的范围很大。但我的模型在大于 7000 的标签上给出了 nan 损失。

n_timesteps, n_features, n_outputs = 50, 59, 7000
batch_size=32
model = Sequential()
model.add(CuDNNLSTM(32, input_shape=(n_timesteps,n_features)))
model.add(Dropout(0.3))
model.add(Dense(256, activation='relu'))
model.add(Dropout(0.3))
model.add(Dense(n_outputs, activation='softmax'))
model.compile(loss=keras.losses.sparse_categorical_crossentropy, optimizer=keras.optimizers.Adam(), metrics=['sparse_categorical_accuracy'])

我正在通过

对数据进行规范化
def normalize(df):

  header=df.columns
  subid_col=df['SUBJECT_ID']
  label_col=df['label']
  df = df.replace(0, np.NaN)
  df=pd.DataFrame(df).fillna(method="ffill")
  df=pd.DataFrame(df).fillna(df.mean())
  df=pd.DataFrame(df).fillna(-99)
  x = df.values
  header=df.columns

  min_max_scaler = preprocessing.MinMaxScaler(feature_range = (1,2))
  x_scaled = min_max_scaler.fit_transform(x)
  df = pd.DataFrame(x_scaled)
  df=pd.DataFrame(df).fillna(-1)

  df.columns=header

  df['SUBJECT_ID']=subid_col
  df['label']=label_col
  df.drop(df.columns[0], axis = 1, inplace = True)

  return df

我正在做一个时间序列分类。

【问题讨论】:

依次使用 3 种不同的 fillna 方法有什么意义? 第一个填充Nan下面的数据(在一列中)第二个填充剩余的nans上面的数据在列中。第三个将填充只有 nans 的列 【参考方案1】:

由于您的最后一层的长度为n_outputs=7000sparse_categorical_crossentropy 需要在[0, 7000] 范围内的索引符号中的标签。因此,您需要在训练之前映射您的标签,即:

from sklearn import preprocessing

le = preprocessing.LabelEncoder()
le.fit(label_col)

df['label'] = le.transform(label_col)

要从编码中检索“真实”标签,您可以使用le.inverse_transform(label_col)


或者,您也可以只更改您的 n_outputs = 50000,但如果您只训练/期望看到其中的 7000 个子集,这可能会效率低下,因为您不必要地添加了理想情况下可以学习为 @987654328 的参数@,因为您没有带有这些标签的数据。

【讨论】:

以上是关于稀疏分类交叉熵未按预期工作的主要内容,如果未能解决你的问题,请参考以下文章

具有对数损失的 TensorFlow 单 sigmoid 输出与具有稀疏 softmax 交叉熵损失的两个线性输出,用于二进制分类

Keras 和 TensorFlow 中所有这些交叉熵损失之间有啥区别?

多个正分类的 TensorFlow 损失计算

Softmax 的交叉熵是不是适用于多标签分类?

均方差交叉熵及公式推导

交叉熵损失函数