Size mismatch error when converting a TensorFlow model to PyTorch

I am trying to convert a TensorFlow model to PyTorch, but I am stuck on the error below. Can anyone help me?

import torch

# Get weights and biases from the TensorFlow (Keras) model.
# `model` is the trained TensorFlow model, defined elsewhere.
weights, biases = model.layers[0].get_weights()
# layers[1] is the dropout layer, which has no weights
weights2, biases2 = model.layers[2].get_weights()


# Define the PyTorch model
class TwoLayerNet(torch.nn.Module):
    def __init__(self, weights, biases, weights2, biases2):
        super(TwoLayerNet, self).__init__()
        # Create the model with the same dimensions as the TensorFlow model
        self.linear1 = torch.nn.Linear(9, 2048)
        self.hidden1 = torch.nn.Dropout(0.2)
        self.linear2 = torch.nn.Linear(2048, 5)

        # Convert the NumPy arrays to tensors
        weights = torch.from_numpy(weights)
        biases = torch.from_numpy(biases)
        weights2 = torch.from_numpy(weights2)
        biases2 = torch.from_numpy(biases2)

        # Load the TensorFlow weights into the PyTorch layers
        self.linear1.weight = torch.nn.Parameter(weights)
        self.linear1.bias = torch.nn.Parameter(biases)
        self.linear2.weight.data = weights2
        self.linear2.bias.data = biases2
        # In this print the dimensions look ok:
        # Linear(in_features=9, out_features=2048, bias=True)
        print(self.linear1)

    def forward(self, x):
        print(self.linear1)
        x = self.linear1(x)
        x = self.hidden1(x)
        x = self.linear2(x)
        return x


model_pytorch = TwoLayerNet(weights, biases, weights2, biases2)

model_pytorch.eval()
# `exemplo_input` is a NumPy array of shape (1, 9), defined elsewhere
exemplo_input_torch = torch.from_numpy(exemplo_input)
exemplo_input_torch = exemplo_input_torch.float()
print(exemplo_input_torch)
result = model_pytorch(exemplo_input_torch)

The error is:

RuntimeError: size mismatch, m1: [1 x 9], m2: [2048 x 9] at pytorch/aten/src/TH/generic/THTensorMath.cpp:41

Answer

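You need to transpose the weights. A Keras Dense layer stores its kernel with shape (in_features, out_features), whereas torch.nn.Linear stores its weight with shape (out_features, in_features). The forward pass multiplies the [1 x 9] input by the transposed weight, which with the untransposed Keras kernel is [2048 x 9], hence the size mismatch. The biases are one-dimensional, so they can be copied as they are.

weights = torch.from_numpy(weights).T
biases = torch.from_numpy(biases)
weights2 = torch.from_numpy(weights2).T
biases2 = torch.from_numpy(biases2)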
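
To check that the transposed weights reproduce the TensorFlow layer's output, something like the following sketch may help. It does not use the asker's model: keras_kernel and keras_bias are hypothetical random arrays standing in for what get_weights() would return for the first Dense layer, and the NumPy expression x @ kernel + bias is assumed to represent what that layer computes before any activation.

```python
import numpy as np
import torch

# Hypothetical stand-ins for the arrays returned by Keras get_weights():
# a Dense kernel has shape (in_features, out_features), the bias (out_features,).
rng = np.random.default_rng(0)
keras_kernel = rng.standard_normal((9, 2048)).astype(np.float32)
keras_bias = rng.standard_normal(2048).astype(np.float32)

linear = torch.nn.Linear(9, 2048)
print(tuple(linear.weight.shape))  # (2048, 9): (out_features, in_features)

# Copy the values in place; the kernel must be transposed to fit.
with torch.no_grad():
    linear.weight.copy_(torch.from_numpy(keras_kernel).T)
    linear.bias.copy_(torch.from_numpy(keras_bias))

# Compare against the Dense-style computation on a sample input.
x = rng.standard_normal((1, 9)).astype(np.float32)
expected = x @ keras_kernel + keras_bias
actual = linear(torch.from_numpy(x)).detach().numpy()
matches = np.allclose(expected, actual, atol=1e-4)
print(matches)
```

Copying with copy_ under torch.no_grad() keeps the layer's existing parameters and raises an error if the shapes are incompatible, so a forgotten transpose shows up when the weights are loaded rather than in the forward pass.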
