Python OpenCV RTrees 无法正确加载

Posted

技术标签:

【中文标题】Python OpenCV RTrees 无法正确加载【英文标题】:Python OpenCV RTrees does not load properly 【发布时间】:2019-04-10 09:21:55 【问题描述】:

我有一个使用 OpenCV 3.3.1 版的训练有素的 RTrees 模型,我已保存该模型,需要稍后加载以进行预测。

我正在尝试使用 OpenCV 的 RTrees 创建一个随机森林分类器。但是,我无法成功加载使用 save 功能保存的模型。 我使用的代码示例如下:

import numpy as np
import cv2

def train(samples, class_labels, save_file=None):
    """
    samples      : np.ndarray of type np.float32
    class_labels :  np.ndarray of type int
    save_file    : str of the absolute path to the save file
    """
    model = cv2.ml.RTrees_create()

    # Set paremeters
    model.setMaxDepth(20)
    model.setActiveVarCount(0)
    term_type, n_trees, epsilon = cv2.TERM_CRITERIA_MAX_ITER, 128, 1
    model.setTermCriteria((term_type, n_trees, epsilon))

    train_data = cv2.ml.TrainData_create(samples=samples,
                                         layout=cv2.ml.ROW_SAMPLE,
                                         responses=class_labels)
    model.train(trainData=train_data)

    if save_file:
        model.save(save_file)

def test(samples, load_file):
        """
        samples   : np.ndarray of type np.float32
        load_file : str of the absolute path to the load file
        """
        model = cv2.ml.RTrees_create()
        model.load(load_file)

        _ret, responses = model.predict(samples)
        return responses.ravel()

if __name__ == '__main__':
    model_file = '/home/user/bin/tmp/m.xml'

    x = np.random.randint(0, 5, (1000, 15))
    y = np.random.randint(0, 4, 1000)
    train(x, y, model_file)

    x_test = np.random.randint(0, 5, (100, 15))
    test(x_test, model_file)

错误如下:

Traceback (most recent call last):
  File "tmp.py", line 45, in <module>
    test(x_test, model_file)
  File "tmp.py", line 34, in test
    _ret, responses = model.predict(samples)
cv2.error: OpenCV(3.4.2) /io/opencv/modules/ml/src/tree.cpp:1498: error: (-215:Assertion failed) !roots.empty() in function 'predict'

【问题讨论】:

请显示您正在使用的代码。另见here 和here 添加示例@Miki 【参考方案1】:

问题是加载函数是一个静态函数,它返回加载的模型。 OpenCV 的文档和示例并未具体说明这一点。

...
model = cv2.ml.RTrees_create()
model = model.load(file_name)

# Or:
model = cv2.ml.RTrees_load(file_name)

请参阅 OpenCV 网站上的 this。

【讨论】:

以上是关于Python OpenCV RTrees 无法正确加载的主要内容,如果未能解决你的问题,请参考以下文章

如何使用 OpenCV RTrees 进行二元分类?

OpenCV 随机决策森林:如何获得后验概率

cv2.imshow命令在opencv-python中无法正常工作

OpenCV 中的 KLT 跟踪器无法与 Python 一起正常工作

HoughCircles 在 OpenCV 中无法正确检测圆

安装opencv python mac osx 10.11?