Does anyone know why SHAP's Deep Explainer fails on a pretrained ResNet-50 model?
Posted: 2021-06-10 10:41:30

Question:
I have trained a ResNet-50 model and I am trying to apply the SHAP explainability method to it. I found that others have run into the same problem, see https://github.com/slundberg/shap/issues/1479.
My code closely follows the SHAP PyTorch example at https://github.com/slundberg/shap/blob/master/notebooks/image_examples/image_classification/PyTorch%20Deep%20Explainer%20MNIST%20example.ipynb:
import numpy as np
import shap
import torch
from torch.utils.data import DataLoader
from torchvision import models

# Pupils, DatasetPath, RealDataset and get_transforms are the asker's own
# project helpers; their definitions are not shown in the post.


def main():
    output_format = Pupils()
    model_path = ".../trained_model_30_epochs"
    ellipse_overlay_path = ".../ellipse_overlay.png"
    path_to_dataset = [DatasetPath(path='...',
                                   image_type='png')]
    path_to_dataset_to_explain = [DatasetPath(path='...',
                                              image_type='png')]
    dataset = RealDataset(path_to_dataset,
                          output_format.ground_truth_from_annotations,
                          transform=get_transforms(is_training=False, is_synthetic=False),
                          load_into_ram=True
                          )
    dataset_to_explain = RealDataset(path_to_dataset_to_explain,
                                     output_format.ground_truth_from_annotations,
                                     transform=get_transforms(is_training=False, is_synthetic=False),
                                     load_into_ram=True
                                     )
    dataloader = DataLoader(dataset,
                            batch_size=100,
                            shuffle=False)
    dataloader_to_explain = DataLoader(dataset_to_explain,
                                       batch_size=32,
                                       shuffle=False)
    resnet = models.resnet50(pretrained=True)
    resnet.fc = torch.nn.Linear(2048, 5)
    resnet.load_state_dict(torch.load(model_path))
    resnet.eval()
    batch = next(iter(dataloader))
    images = batch['image'].expand(-1, 3, -1, -1)
    # select a set of background examples to take an expectation over
    background = images[:100]  # dataset[0]['image']
    # background = background.expand(3, -1, -1).unsqueeze(0)
    batch = next(iter(dataloader_to_explain))
    test_images = batch['image'].expand(-1, 3, -1, -1)  # dataset_to_explain[1]['image']
    # test_images = test_images.expand(3, -1, -1).unsqueeze(0)
    e = shap.DeepExplainer(resnet, background)
    shap_values = e.shap_values(test_images)
    shap_numpy = [np.swapaxes(np.swapaxes(s, 1, -1), 1, 2) for s in shap_values]
    test_numpy = np.swapaxes(np.swapaxes(test_images[:3].numpy(), 1, -1), 1, 2)
    # plot the feature attributions
    shap.image_plot(shap_numpy, -test_numpy)
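
A side note on the last lines of the code above: the nested `np.swapaxes` calls convert arrays from PyTorch's channels-first layout (N, C, H, W) to a channels-last layout (N, H, W, C) before plotting. The sketch below uses a small made-up array (its shape is an assumption for illustration, not the asker's data) to show that the double swap should be equivalent to a single `np.transpose`:

```python
import numpy as np

# Hypothetical stand-in for one entry of shap_values:
# 2 images, 3 channels, 4x5 pixels, in NCHW order.
s = np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)

# The double swap used in the post: NCHW -> NWHC -> NHWC.
swapped = np.swapaxes(np.swapaxes(s, 1, -1), 1, 2)

# The same reordering written as one transpose.
transposed = np.transpose(s, (0, 2, 3, 1))

print(swapped.shape)                        # (2, 4, 5, 3)
print(np.array_equal(swapped, transposed))  # True
```

This step only affects plotting. The traceback shows the error being raised earlier, inside `e.shap_values(test_images)`.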
The traceback is:
Traceback (most recent call last):
  File "/.../PycharmProjects/thesis/SHAP.py", line 133, in <module>
    main()
  File ".../PycharmProjects/thesis/SHAP.py", line 123, in main
    shap_values = e.shap_values(test_images)
  File ".../anaconda3/lib/python3.8/site-packages/shap/explainers/_deep/__init__.py", line 124, in shap_values
    return self.explainer.shap_values(X, ranked_outputs, output_rank_order, check_additivity=check_additivity)
  File ".../anaconda3/lib/python3.8/site-packages/shap/explainers/_deep/deep_pytorch.py", line 185, in shap_values
    sample_phis = self.gradient(feature_ind, joint_x)
  File ".../anaconda3/lib/python3.8/site-packages/shap/explainers/_deep/deep_pytorch.py", line 121, in gradient
    grad = torch.autograd.grad(selected, x,
  File ".../anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 202, in grad
    return Variable._execution_engine.run_backward(
  File ".../anaconda3/lib/python3.8/site-packages/shap/explainers/_deep/deep_pytorch.py", line 226, in deeplift_grad
    return op_handler[module_type](module, grad_input, grad_output)
  File ".../anaconda3/lib/python3.8/site-packages/shap/explainers/_deep/deep_pytorch.py", line 358, in nonlinear_1d
    grad_output[0] * (delta_out / delta_in).repeat(dup0))
RuntimeError: The size of tensor a (512) must match the size of tensor b (2048) at non-singleton dimension 1
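Does anyone know what might be causing this? I have gone through the code many times and cannot find what is wrong. Tensor sizes should not mismatch on a standard, properly trained network architecture.
Note that I have intentionally removed any personally identifying paths :).
Many thanks for any input!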
Comments:
Has the problem been solved?

Answer 1:
I think this error occurs because ResNet needs a softmax at the FC layer. Try this:
from torch import nn  # needed for nn.Sequential and nn.Softmax below

resnet = models.resnet50(pretrained=True)
resnet.fc = torch.nn.Linear(2048, 5)
resnet.load_state_dict(torch.load(model_path))
# add this block after the weights have been loaded
resnet.fc = nn.Sequential(
    resnet.fc,
    nn.Softmax(1),
)
resnet.eval()
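
To make the effect of the suggested change concrete, here is a minimal sketch of what wrapping the head in `nn.Sequential(..., nn.Softmax(1))` does to the outputs. The layer and the input features are hypothetical stand-ins (a freshly initialised `nn.Linear(2048, 5)` and random features), and the sketch does not show whether the change resolves the DeepExplainer error:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for the 5-class head from the post.
fc = nn.Linear(2048, 5)

# Same wrapping as in the answer: keep the existing layer, then apply
# softmax over the class dimension (dim=1).
head = nn.Sequential(fc, nn.Softmax(dim=1))

features = torch.randn(4, 2048)  # dummy pooled features for 4 images
with torch.no_grad():
    logits = fc(features)
    probs = head(features)

print(probs.shape)       # torch.Size([4, 5])
print(probs.sum(dim=1))  # each row sums to 1, up to rounding
```

The order in the answer matters for loading: wrapping `resnet.fc` before `load_state_dict` would rename the parameters (for example `fc.weight` becomes `fc.0.weight`), so a checkpoint saved with the plain linear head would no longer match. With the softmax in place, the explained outputs are probabilities instead of raw logits.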