Python:如何从 3D numpy/torch 数组中提取连接的组件(边界框)?
Posted
技术标签:
【中文标题】Python:如何从 3D numpy/torch 数组中提取连接的组件(边界框)?【英文标题】:Python: How to extract connected components (bounding boxes) from 3D numpy / torch array? 【发布时间】:2021-10-27 19:30:04 【问题描述】:我在 NumPy/Torch 中有 3D 数组的二进制分割掩码。我想将这些转换为边界框(又名连接组件)。作为免责声明,每个数组可以包含多个连接的组件/边界框,这意味着我不能只取最小和最大非零索引值。
具体而言,假设我有一个二进制值的 3D 数组(我将使用 2D,因为 2D 更容易可视化)。我想知道连接的组件是什么。比如我想用这个分割掩码:
>>> segmentation_mask
array([[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[1, 1, 1, 0, 0],
[1, 1, 0, 1, 0],
[1, 1, 0, 0, 1]], dtype=int32)
并将其转换为连通分量,其中连通分量具有任意标签,即
>>> connected_components
array([[1, 0, 0, 0, 0],
[0, 2, 0, 0, 0],
[2, 2, 2, 0, 0],
[2, 2, 0, 3, 0],
[2, 2, 0, 0, 4]], dtype=int32)
如何使用 3D 数组执行此操作?我愿意使用 Numpy、Scipy、Torchvision、opencv 和任何库。
【问题讨论】:
【参考方案1】:这应该适用于任意数量的维度:
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
segmentation_mask = np.array([[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[1, 1, 1, 0, 0],
[1, 1, 0, 1, 0],
[1, 1, 0, 0, 1]], dtype=np.int32)
row = []
col = []
segmentation_mask_reader = segmentation_mask.reshape(-1)
n_nodes = len(segmentation_mask_reader)
for node in range(n_nodes):
idxs = np.unravel_index(node, segmentation_mask.shape)
if segmentation_mask[idxs] == 0:
col.append(n_nodes)
else:
for i in range(len(idxs)):
if idxs[i] > 0:
new_idxs = list(idxs)
new_idxs[i] -= 1
new_node = np.ravel_multi_index(new_idxs, segmentation_mask.shape)
if segmentation_mask_reader[new_node] != 0:
col.append(new_node)
while len(col) > len(row):
row.append(node)
row = np.array(row, dtype=np.int32)
col = np.array(col, dtype=np.int32)
data = np.ones(len(row), dtype=np.int32)
graph = csr_matrix((np.array(data), (np.array(row), np.array(col))),
shape=(n_nodes+1, n_nodes+1))
n_components, labels = connected_components(csgraph=graph)
background_label = labels[-1]
solution = np.zeros(segmentation_mask.shape, dtype=segmentation_mask.dtype)
solution_writer = solution.reshape(-1)
for node in range(n_nodes):
label = labels[node]
if label < background_label:
solution_writer[node] = label+1
elif label > background_label:
solution_writer[node] = label
print(solution)
【讨论】:
以上是关于Python:如何从 3D numpy/torch 数组中提取连接的组件(边界框)?的主要内容,如果未能解决你的问题,请参考以下文章