pysyft torch.jit.脚本 RuntimeError:未定义的值 _Reduction

Posted

技术标签:

【中文标题】pysyft torch.jit.脚本 RuntimeError:未定义的值 _Reduction【英文标题】:pysyft torrch.jit. script RuntimeError: undefined value _Reduction 【发布时间】:2020-04-02 00:31:20 【问题描述】:

我试图从其高级示例中重现 Pysyft Asynchronous-federated-learning-on-MNIST。其中@torch.jit.script 在损失函数之前使用。我收到了这个错误,不知道这是怎么回事

RuntimeError: undefined value _Reduction: at /home/ab/.virtualenvs/aic/lib/python3.6/site-packages/syft/generic/frameworks/hook/hook.py:1829:20

reduction = _Reduction.legacy_get_string(size_average, reduce)


其实是这几行造成的

@torch.jit.script
def loss_fn(pred, target):
    return F.nll_loss(input=pred, target=target)

train_config = sy.TrainConfig(
        model=traced_model,
        loss_fn=loss_fn,
        batch_size=batch_size,
        shuffle=True,
        max_nr_batches=max_nr_batches,
        epochs=1,
        optimizer="SGD",
        optimizer_args="lr": lr,
    )

【问题讨论】:

【参考方案1】:

编写答案以便对其他人有所帮助。事实证明,@torch.jit.script 需要位于文件的顶部(导入后),而我在两个函数定义之后才拥有它。

将它移到顶部工作

【讨论】:

以上是关于pysyft torch.jit.脚本 RuntimeError:未定义的值 _Reduction的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch模型导出到ONNX文件示例(LeNet-5)

Pysyft学习笔记一:dome思路

Pysyft学习笔记一:dome思路

Pysyft学习笔记二:伪分布式模型训练的实现

Pysyft学习笔记二:伪分布式模型训练的实现

如何在 Maskcrnn libtorch 中获取元组对象返回的值