Using numpy.argmax to compute a confusion matrix

Posted by 将者,智、信、仁、勇、严也 ("A general must be wise, trustworthy, benevolent, brave, and strict.")


Preface: This article was compiled by the editors of 小常识网 (cha138.com). It covers using numpy.argmax to compute a confusion matrix, and we hope you find it useful.

numpy.argmax

numpy.argmax(a, axis=None, out=None)

Returns the indices of the maximum values along an axis.

Parameters:

a : array_like

Input array.

axis : int, optional

By default, the index is into the flattened array, otherwise along the specified axis.

out : array, optional

If provided, the result will be inserted into this array. It should be of the appropriate shape and dtype.
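A minimal sketch of the `out` parameter on toy data, assuming only NumPy: the preallocated array must match the result's shape, here `(2,)` for a `(2, 3)` input reduced along `axis=1`. Using `np.intp`, NumPy's native index type, as its dtype is the safe choice.

```python
import numpy as np

a = np.array([[0, 7, 2],
              [9, 4, 5]])

# Preallocate the result buffer: shape (2,) because axis=1 is removed from (2, 3).
buf = np.empty(2, dtype=np.intp)
result = np.argmax(a, axis=1, out=buf)

print(buf)  # [1 0]
```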

Returns:

index_array : ndarray of ints

Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed.
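To illustrate the shape rule, here is a short sketch with a 3-D array (the contents are arbitrary and only the shapes matter). Reducing along an axis drops that axis from `a.shape`, while the default `axis=None` returns a single index into the flattened array.

```python
import numpy as np

a = np.arange(24).reshape(2, 3, 4)   # shape (2, 3, 4)

print(np.argmax(a, axis=0).shape)    # (3, 4): axis 0 removed
print(np.argmax(a, axis=1).shape)    # (2, 4): axis 1 removed
print(np.argmax(a, axis=2).shape)    # (2, 3): axis 2 removed
print(np.argmax(a))                  # 23: index into a.ravel()
```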

See also

ndarray.argmax, argmin
amax: The maximum value along a given axis.
unravel_index: Convert a flat index into an index tuple.
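Without `axis`, `argmax` returns a flat index, so `unravel_index` is the usual companion for recovering the row and column of the maximum. The following is a small sketch on a made-up matrix:

```python
import numpy as np

m = np.array([[3, 8, 1],
              [4, 2, 9],
              [7, 0, 5]])

flat = np.argmax(m)                        # 5: index into m.ravel()
row, col = np.unravel_index(flat, m.shape)
print(flat, row, col)                      # 5 1 2
print(m[row, col])                         # 9
```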

Notes

In case of multiple occurrences of the maximum values, the indices corresponding to the first occurrence are returned.

Examples

>>> a = np.arange(6).reshape(2,3)
>>> a
array([[0, 1, 2],
       [3, 4, 5]])
>>> np.argmax(a)
5
>>> np.argmax(a, axis=0)
array([1, 1, 1])
>>> np.argmax(a, axis=1)
array([2, 2])
>>> b = np.arange(6)
>>> b[1] = 5
>>> b
array([0, 5, 2, 3, 4, 5])
>>> np.argmax(b) # Only the first occurrence is returned.
1

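How I use it when training a multi-class model (org_labels = [0, 1, 2, ..., max_label] are the class labels, numbered from 0):
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
# load_data, get_model and EarlyStoppingCallback are the author's own helpers (not shown in this post).

if __name__ == "__main__":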
    width, height = 32, 32
    X, Y, org_labels = load_data(dirname="data", resize_pics=(width, height))
    trainX, testX, trainY, testY = train_test_split(X, Y, test_size=0.2, random_state=666)
    print("sample data:")
    print(trainX[0])
    print(trainY[0])
    print(testX[-1])
    print(testY[-1])

    model = get_model(width, height, classes=100)

    filename = 'cnn_handwrite-acc0.8.tflearn'
    # try to load model and resume training
    #try:
    #    model.load(filename)
    #    print("Model loaded OK. Resume training!")
    #except:
    #    pass

    # Initialize our callback with desired accuracy threshold.
    early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.6)
    try:
        model.fit(trainX, trainY, validation_set=(testX, testY), n_epoch=500, shuffle=True,
                  snapshot_epoch=True,  # Snapshot (save & evaluate) model every epoch.
                  show_metric=True, batch_size=32, callbacks=early_stopping_cb, run_id='cnn_handwrite')
    except StopIteration:
        # EarlyStoppingCallback is expected to raise StopIteration once the accuracy threshold is reached.
        print("OK, stopped iterating. Good!")

    model.save(filename)

    # Predict on all data and calculate the confusion matrix.
    model.load(filename)

    pro_arr = model.predict(X)  # class probabilities, one row per sample
    predict_labels = np.argmax(pro_arr, axis=1)  # most probable class index per sample
    print(classification_report(org_labels, predict_labels))
    print(confusion_matrix(org_labels, predict_labels))
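The evaluation step above boils down to two operations: take the row-wise `argmax` of the predicted probabilities, then count (true, predicted) pairs. The sketch below does that in plain NumPy on made-up probabilities standing in for `model.predict(X)`. It also recovers 0-based true labels from one-hot targets with the same `argmax` trick. `confusion_matrix_np` is a hypothetical helper name, and its layout (rows = true class, columns = predicted class) follows scikit-learn's `confusion_matrix`.

```python
import numpy as np

def confusion_matrix_np(y_true, y_pred, n_classes):
    """Hypothetical helper: rows are true classes, columns are predicted classes."""
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    # np.add.at accumulates correctly even when the same (true, pred) pair repeats.
    np.add.at(cm, (y_true, y_pred), 1)
    return cm

# Made-up one-hot targets for 5 samples and 3 classes.
Y = np.array([[1, 0, 0],
              [0, 1, 0],
              [0, 0, 1],
              [0, 1, 0],
              [1, 0, 0]])
# Made-up class probabilities standing in for model.predict(X).
proba = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6],
                  [0.2, 0.2, 0.6],
                  [0.1, 0.8, 0.1],
                  [0.4, 0.4, 0.2]])  # tie: argmax returns the first maximum, class 0

true_labels = np.argmax(Y, axis=1)      # [0 1 2 1 0]
pred_labels = np.argmax(proba, axis=1)  # [0 2 2 1 0]

cm = confusion_matrix_np(true_labels, pred_labels, n_classes=3)
print(cm)
# [[2 0 0]
#  [0 1 1]
#  [0 0 1]]
```

Here every class appears in the labels, so scikit-learn's `confusion_matrix(true_labels, pred_labels)` should produce the same matrix. When some class never occurs, pass `labels=` to keep the matrix at full size.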

 
