Error while training logistic regression on FashionMNIST with Adam optimizer

Posted 2020-10-28 17:53:07


The dataset is FashionMNIST (784 inputs, 10 outputs). I'm trying to train a logistic regression (coded by hand) using the Adam optimizer:

import math
import torch
import torch.nn as nn
from torch.optim import Adam  # assuming the built-in Adam optimizer

weights = torch.randn(784, 10) / math.sqrt(784)
weights.requires_grad_()

bias = torch.zeros(10, requires_grad=True)

optimizer = Adam([weights, bias])
criterion = nn.CrossEntropyLoss()

The training function is:

def train_logistic_regression(weights, bias, batch, loss, optimizer):

    inputs, labels = batch

    inputs = inputs.view(inputs.shape[0], -1)

    optimizer.zero_grad()
    y_pred = torch.sigmoid(weights@inputs + bias) # the problem must be here
    loss = criterion(y_pred, labels)
    loss.backward()
    optimizer.step()


from IPython.display import clear_output


for epoch in range(1, 5):

    for batch in train_dataloader: # have to go with batches
      metrics = train_logistic_regression(weights, bias, batch, criterion, optimizer)
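The question does not show how train_dataloader is built; a plausible setup, assuming torchvision is available and a batch size of 128 to match the shapes in the traceback below, would be:

import torch
from torchvision import datasets, transforms

# FashionMNIST images are 1x28x28; ToTensor() scales pixels to [0, 1]
train_dataset = datasets.FashionMNIST(root="data", train=True, download=True,
                                      transform=transforms.ToTensor())
# batch size 128 matches the [128 x 784] shape in the error message
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=128,
                                               shuffle=True)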

Every time, I get this error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-161-408b80d71db1> in <module>()
      5 
      6     for batch in train_dataloader:
----> 7       metrics = train_logistic_regression(weights, bias, batch, criterion, optimizer)
      8 
      9 

<ipython-input-160-9c2f95ee56ee> in train_logistic_regression(weights, bias, batch, loss, optimizer)
      6 
      7     optimizer.zero_grad()
----> 8     y_pred = torch.sigmoid(weights@inputs + bias)
      9     # y_pred = model(inputs)
     10     loss = criterion(y_pred, labels)

RuntimeError: size mismatch, m1: [784 x 10], m2: [128 x 784] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:41

Any help would be greatly appreciated.


Answer 1:

Instead of y_pred = torch.sigmoid(weights@inputs + bias), it should be y_pred = torch.sigmoid(inputs.mm(weights) + bias). After flattening, inputs has shape [batch_size, 784] (here [128, 784]) while weights has shape [784, 10], so weights @ inputs tries to multiply [784 x 10] by [128 x 784] and the inner dimensions don't match. Multiplying in the order inputs @ weights gives the expected [128, 10] output.
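For reference, here is a minimal sketch of the corrected training step, assuming torch.optim.Adam and the DataLoader setup above. Note also that nn.CrossEntropyLoss applies log-softmax internally and expects raw logits, so dropping the sigmoid altogether is usually the intended usage:

import math
import torch
import torch.nn as nn
from torch.optim import Adam

weights = torch.randn(784, 10) / math.sqrt(784)
weights.requires_grad_()
bias = torch.zeros(10, requires_grad=True)

optimizer = Adam([weights, bias])
criterion = nn.CrossEntropyLoss()

def train_logistic_regression(weights, bias, batch, criterion, optimizer):
    inputs, labels = batch
    inputs = inputs.view(inputs.shape[0], -1)  # [128, 1, 28, 28] -> [128, 784]
    optimizer.zero_grad()
    logits = inputs @ weights + bias           # [128, 784] @ [784, 10] -> [128, 10]
    loss = criterion(logits, labels)           # CrossEntropyLoss expects raw logits
    loss.backward()
    optimizer.step()
    return loss.item()

With this change the training loop in the question runs, and the returned loss can be tracked per batch instead of the unused metrics value.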

