在 Sklearn 中再次绘制测试、验证和训练 Acc Epochs
Posted
技术标签:
【中文标题】在 Sklearn 中再次绘制测试、验证和训练 Acc Epochs【英文标题】:Plotting Test, Valid and Train Acc again Epochs in Sklearn 【发布时间】:2019-02-20 06:46:10 【问题描述】:是否有任何内置方法可以在 Sklearn 中为 MLP 分类器在每个时期绘制 Train、Valid、Test 图?
【问题讨论】:
我不太确定你在问什么。 scikit-learn verbose parameter 是您要找的吗? 这仅显示训练进度。我需要有关每次迭代或时期的训练、验证和测试数据的准确性的信息。或者一种获得这三个精度对迭代或时期的图的方法。感谢您的反馈。 是的。我正在寻找这个解决方案。谢谢兄弟.... 【参考方案1】:这个解决方案(代码取自here)应该可以帮助你:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_mldata
from sklearn.neural_network import MLPClassifier
np.random.seed(1)
""" Example based on sklearn's docs """
mnist = fetch_mldata("MNIST original")
# rescale the data, use the traditional train/test split
X, y = mnist.data / 255., mnist.target
X_train, X_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:]
mlp = MLPClassifier(hidden_layer_sizes=(50,), max_iter=10, alpha=1e-4,
solver='adam', verbose=0, tol=1e-8, random_state=1,
learning_rate_init=.01)
""" Home-made mini-batch learning
-> not to be used in out-of-core setting!
"""
N_TRAIN_SAMPLES = X_train.shape[0]
N_EPOCHS = 25
N_BATCH = 128
N_CLASSES = np.unique(y_train)
scores_train = []
scores_test = []
# EPOCH
epoch = 0
while epoch < N_EPOCHS:
print('epoch: ', epoch)
# SHUFFLING
random_perm = np.random.permutation(X_train.shape[0])
mini_batch_index = 0
while True:
# MINI-BATCH
indices = random_perm[mini_batch_index:mini_batch_index + N_BATCH]
mlp.partial_fit(X_train[indices], y_train[indices], classes=N_CLASSES)
mini_batch_index += N_BATCH
if mini_batch_index >= N_TRAIN_SAMPLES:
break
# SCORE TRAIN
scores_train.append(mlp.score(X_train, y_train))
# SCORE TEST
scores_test.append(mlp.score(X_test, y_test))
epoch += 1
""" Plot """
fig, ax = plt.subplots(2, sharex=True, sharey=True)
ax[0].plot(scores_train)
ax[0].set_title('Train')
ax[1].plot(scores_test)
ax[1].set_title('Test')
fig.suptitle("Accuracy over epochs", fontsize=14)
plt.show()
【讨论】:
以上是关于在 Sklearn 中再次绘制测试、验证和训练 Acc Epochs的主要内容,如果未能解决你的问题,请参考以下文章
在带有分组约束的 sklearn (python 2.7) 中创建训练、测试和交叉验证数据集?