高斯混合模型 (GMM) 提供与训练数据无关的结果

Posted

技术标签:

【中文标题】高斯混合模型 (GMM) 提供与训练数据无关的结果【英文标题】:Gaussian Mixture Model (GMM) delivers results that are unrelated to the training data 【发布时间】:2021-12-15 19:50:09 【问题描述】:

正如您在图片中看到的,聚类似乎与提供的数据完全无关。 我有 34 个数据点。 这可能是什么原因?

用不同的组件拟合 K GMM:

  def calculate_zones(self):

    mini_data = []
    for id, enter_x, enter_y, enter_time, exit_x, exit_y, exit_time in self.data:
        mini_data.append([enter_x, enter_y])
        mini_data.append([exit_x, exit_y])

    K = range(2, 4)

    for k in K:
        # Set the model and its parameters
        self.gms.append(
            GaussianMixture(n_components=k, n_init=20, covariance_type='spherical', init_params='kmeans').fit(
                mini_data))

根据 GMM 的结果屏蔽图像:

  def display_zones(self):
    video = cv2.VideoCapture(self.video_path)
    if not video.isOpened():
        print("Cannot open stream")
        exit()
    _, frame = video.read()

    mask = []

    colors = random_color(11)

    for model in self.gms:
        curr_mask = np.zeros_like(frame)
        mask.append(curr_mask)
        row_index = 0
        for pixel_row in frame:
            column_index = 0
            for _ in pixel_row:
                prediction = model.predict_proba([[row_index, column_index]])

                best_proba = 0
                counter = 0
                for one_prediction in prediction[0]:
                    if one_prediction > prediction[0][best_proba]:
                        best_proba = counter
                    counter += 1

                curr_mask[row_index][column_index] = colors[best_proba]
                column_index += 1

            row_index += 1

【问题讨论】:

【参考方案1】:

首先,您应该为每个班级安装一个 GMM。 GMM 是一种无监督的概率算法,因此您不能为 .fit() 函数提供多个类并期望它天生就可以区分类(除非您的目的是将这些区域映射在一起,那么这种方法很好)。

其次,您使用以下格式的数据训练模型:

mini_data = []
for id, enter_x, enter_y, enter_time, exit_x, exit_y, exit_time in self.data:
    mini_data.append([enter_x, enter_y])
    mini_data.append([exit_x, exit_y])

然后使用以下方法获得预测:

for _ in pixel_row:
     prediction = model.predict_proba([[row_index, column_index]])

您的函数的输入不应该是 [[column_index, row_index]] 以匹配您的火车数据的 [x,y] 格式吗?

【讨论】:

你的第二点是正确的,我在row_index和colum_index之间切换。但是关于第一点,据我了解,GMM 与 BIC 和 EM 一起可以估计类的数量并对其进行聚类,我绘制了均值和协方差,结果看起来还不错。为什么你认为每个班级都应该有一个 GMM?无监督的关键是你不知道你有多少个班级...... 是的。因此,如果您想进行聚类并找到数据点属于哪个集群(无监督),那很好,因为希望每个集群应该主要由一个类组成。然而,如果你想描述你的类,你需要在你的类中分离你的 GMM。当您对来自每个 GMM 的新数据执行 predict_proba() 时,您最大的可能性将代表正确的 GMM 类。 好的,我明白了,如果一个类有多个集群,那么我应该给每个类一个 GMM,但就我而言,我没有那种意义上的类,我只是在寻找对于集群。您能否编辑您的答案并解释一个类可以有多个集群。因为它一开始让我很困惑。 那我可以接受你的回答:) 这完全取决于您的数据。如果您将身高和体重归类为性别的函数,那么您会期望女性更矮/更轻与男性相反之间存在相关性。但是会有包含超过 1 个类的集群,因为这显然不是超离散类分离。因此,集群不能被明确地标记为一个类,因为数据没有那么明显的分离。

以上是关于高斯混合模型 (GMM) 提供与训练数据无关的结果的主要内容,如果未能解决你的问题,请参考以下文章

如何简单易懂的解释高斯混合(GMM)模型?

05 EM算法 - 高斯混合模型 - GMM

高斯混合模型GMM核心参数高斯混合模型GMM的数学形式

使用经过训练的高斯混合模型标记新数据

高斯混合模型(GMM)和EM算法

高斯混合模型(GMM)