A Convolutional Neural Network for Recognizing Painting Styles
Posted by 一只夫夫
Learning a mapping from images to artistic styles
I ran this quite a while ago on a thin-and-light laptop whose GPU was an MX250 with 2 GB of VRAM. I turned a few hundred images from a Western art history slide deck into a small dataset and trained on it.
from data_utils.data_utils import load_art5
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from model.squeezenet import SqueezeNet
from tensorflow.keras.applications.resnet50 import ResNet50
from model.backprop_squeezenet import SqueezeNet as backprop_SqueezeNet
from pathlib import Path
import os
import datetime
%load_ext tensorboard
x_train, y_train, x_test, y_test, label_to_name = load_art5()
print('x_train',x_train.shape)
print('y_train',y_train.shape)
print('x_test', x_test.shape)
print('y_test', y_test.shape)
print("\\n标签对应的风格:\\n",label_to_name)
x_train (250, 224, 224, 3)
y_train (250,)
x_test (40, 224, 224, 3)
y_test (40,)
Styles corresponding to the labels:
{0: 'Ancient Egyptian', 1: 'Impressionism', 2: 'Cubism', 3: 'Purism', 4: 'Light-effect art / Op Art'}
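load_art5 is just a small helper in my repo and isn't reproduced in this post. Here is a minimal sketch of what such a loader can look like, assuming one sub-folder per style under a data/art5/ directory and an 8-images-per-class test split (which matches the 40 test images above); the folder layout and file extension are illustrative assumptions, not the actual implementation:

# Illustrative sketch of a load_art5-style loader; assumes
# data/art5/<style_name>/*.jpg, one folder per style.
# np and Path are already imported above.
from PIL import Image

def load_art5_sketch(root='data/art5', size=(224, 224), test_per_class=8):
    styles = sorted(p.name for p in Path(root).iterdir() if p.is_dir())
    label_to_name = {i: s for i, s in enumerate(styles)}
    xs, ys = [], []
    for label, style in label_to_name.items():
        for img_path in sorted((Path(root) / style).glob('*.jpg')):
            img = Image.open(img_path).convert('RGB').resize(size)
            xs.append(np.asarray(img, dtype=np.float32) / 255.0)
            ys.append(label)
    x, y = np.stack(xs), np.asarray(ys)
    # Hold out the last few images of every class as the test split
    test = np.zeros(len(y), dtype=bool)
    for label in label_to_name:
        test[np.where(y == label)[0][-test_per_class:]] = True
    return x[~test], y[~test], x[test], y[test], label_to_name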
learning_rate = [2e-4]  # candidate learning rates to sweep over
model = SqueezeNet(num_classes=5)
def lr_schedule(epoch):
    # Learning-rate schedule table
    learning_rate = 1e-2
    if epoch > 100:
        learning_rate = 1e-3
    if epoch > 200:
        learning_rate = 1e-5
    if epoch > 1000:
        learning_rate = 2e-7
    if epoch > 1500:
        learning_rate = 1e-9
    tf.summary.scalar('learning rate', data=learning_rate, step=epoch)
    return learning_rate
# Tune the learning rate: train once for each candidate value
for le in learning_rate:
    print('learning_rate', le)
    print('---------------------------------------------------------------------------------')
    # One timestamped TensorBoard log directory per run
    log_dir = "logs/train1/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    file_writer = tf.summary.create_file_writer(log_dir + "/metrics")
    file_writer.set_as_default()
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=Path(log_dir), histogram_freq=1)
    lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_schedule, verbose=1)
    optimizer = tf.keras.optimizers.Adam(le)
    # model.load_weights('./model/squeezenet_weight')
    model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy',
                  metrics=[tf.keras.metrics.sparse_categorical_accuracy])
    model.fit(x_train, y_train, epochs=100, batch_size=50,
              validation_data=(x_test, y_test), verbose=2,
              callbacks=[tensorboard_callback])
    model.evaluate(x_test, y_test, verbose=1)
    model.save_weights('./model/trained_model/art5')
learning_rate 0.0002
---------------------------------------------------------------------------------
Epoch 1/100
WARNING:tensorflow:From C:\Users\Yishif\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\summary_ops_v2.py:1277: stop (from tensorflow.python.eager.profiler) is deprecated and will be removed after 2020-07-01.
Instructions for updating:
use `tf.profiler.experimental.stop` instead.
WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0319s vs `on_train_batch_end` time: 0.4966s). Check your callbacks.
5/5 - 3s - loss: 2.9243 - sparse_categorical_accuracy: 0.1600 - val_loss: 2.8109 - val_sparse_categorical_accuracy: 0.1000
Epoch 2/100
5/5 - 2s - loss: 2.4745 - sparse_categorical_accuracy: 0.1840 - val_loss: 2.3348 - val_sparse_categorical_accuracy: 0.3500
Epoch 3/100
5/5 - 2s - loss: 2.1843 - sparse_categorical_accuracy: 0.3840 - val_loss: 2.1565 - val_sparse_categorical_accuracy: 0.3750
Epoch 4/100
5/5 - 2s - loss: 2.0166 - sparse_categorical_accuracy: 0.3560 - val_loss: 2.0870 - val_sparse_categorical_accuracy: 0.3500
Epoch 5/100
5/5 - 2s - loss: 1.8970 - sparse_categorical_accuracy: 0.4160 - val_loss: 2.0318 - val_sparse_categorical_accuracy: 0.3750
Epoch 6/100
5/5 - 2s - loss: 1.8280 - sparse_categorical_accuracy: 0.3720 - val_loss: 1.9484 - val_sparse_categorical_accuracy: 0.4000
Epoch 7/100
5/5 - 2s - loss: 1.7702 - sparse_categorical_accuracy: 0.4120 - val_loss: 1.9097 - val_sparse_categorical_accuracy: 0.3750
Epoch 8/100
5/5 - 2s - loss: 1.7212 - sparse_categorical_accuracy: 0.4160 - val_loss: 1.8980 - val_sparse_categorical_accuracy: 0.4000
Epoch 9/100
5/5 - 2s - loss: 1.6827 - sparse_categorical_accuracy: 0.4080 - val_loss: 1.8328 - val_sparse_categorical_accuracy: 0.4000
Epoch 10/100
5/5 - 2s - loss: 1.6674 - sparse_categorical_accuracy: 0.4560 - val_loss: 1.8411 - val_sparse_categorical_accuracy: 0.4000
Epoch 11/100
5/5 - 2s - loss: 1.6331 - sparse_categorical_accuracy: 0.4200 - val_loss: 1.8294 - val_sparse_categorical_accuracy: 0.3750
Epoch 12/100
5/5 - 2s - loss: 1.6046 - sparse_categorical_accuracy: 0.4480 - val_loss: 1.8084 - val_sparse_categorical_accuracy: 0.3750
Epoch 13/100
5/5 - 2s - loss: 1.5849 - sparse_categorical_accuracy: 0.4640 - val_loss: 1.7869 - val_sparse_categorical_accuracy: 0.3750
Epoch 14/100
5/5 - 2s - loss: 1.5783 - sparse_categorical_accuracy: 0.4520 - val_loss: 1.7202 - val_sparse_categorical_accuracy: 0.3750
Epoch 15/100
5/5 - 2s - loss: 1.5890 - sparse_categorical_accuracy: 0.4480 - val_loss: 1.7938 - val_sparse_categorical_accuracy: 0.3750
Epoch 16/100
5/5 - 2s - loss: 1.5535 - sparse_categorical_accuracy: 0.5000 - val_loss: 1.7424 - val_sparse_categorical_accuracy: 0.3750
Epoch 17/100
5/5 - 2s - loss: 1.5241 - sparse_categorical_accuracy: 0.4760 - val_loss: 1.7203 - val_sparse_categorical_accuracy: 0.3750
Epoch 18/100
5/5 - 2s - loss: 1.5358 - sparse_categorical_accuracy: 0.4640 - val_loss: 1.6901 - val_sparse_categorical_accuracy: 0.4000
Epoch 19/100
5/5 - 2s - loss: 1.5553 - sparse_categorical_accuracy: 0.4400 - val_loss: 1.8723 - val_sparse_categorical_accuracy: 0.3750
Epoch 20/100
5/5 - 2s - loss: 1.5542 - sparse_categorical_accuracy: 0.4840 - val_loss: 1.7149 - val_sparse_categorical_accuracy: 0.4250
Epoch 21/100
5/5 - 2s - loss: 1.5115 - sparse_categorical_accuracy: 0.4960 - val_loss: 1.8497 - val_sparse_categorical_accuracy: 0.4000
Epoch 22/100
5/5 - 2s - loss: 1.4885 - sparse_categorical_accuracy: 0.4760 - val_loss: 1.7004 - val_sparse_categorical_accuracy: 0.3750
Epoch 23/100
5/5 - 2s - loss: 1.4982 - sparse_categorical_accuracy: 0.4560 - val_loss: 1.8699 - val_sparse_categorical_accuracy: 0.4000
Epoch 24/100
5/5 - 2s - loss: 1.4815 - sparse_categorical_accuracy: 0.4520 - val_loss: 1.6570 - val_sparse_categorical_accuracy: 0.2750
Epoch 25/100
5/5 - 2s - loss: 1.4591 - sparse_categorical_accuracy: 0.4600 - val_loss: 1.8539 - val_sparse_categorical_accuracy: 0.4000
Epoch 26/100
5/5 - 2s - loss: 1.4765 - sparse_categorical_accuracy: 0.5080 - val_loss: 1.6151 - val_sparse_categorical_accuracy: 0.3500
Epoch 27/100
5/5 - 2s - loss: 1.4180 - sparse_categorical_accuracy: 0.4960 - val_loss: 1.6629 - val_sparse_categorical_accuracy: 0.3750
Epoch 28/100
5/5 - 2s - loss: 1.4058 - sparse_categorical_accuracy: 0.5400 - val_loss: 1.7261 - val_sparse_categorical_accuracy: 0.3500
Epoch 29/100
5/5 - 2s - loss: 1.3546 - sparse_categorical_accuracy: 0.5280 - val_loss: 1.6914 - val_sparse_categorical_accuracy: 0.3500
Epoch 30/100
5/5 - 2s - loss: 1.3454 - sparse_categorical_accuracy: 0.5320 - val_loss: 1.7979 - val_sparse_categorical_accuracy: 0.3500
Epoch 31/100
5/5 - 2s - loss: 1.3643 - sparse_categorical_accuracy: 0.5440 - val_loss: 1.6637 - val_sparse_categorical_accuracy: 0.4250
Epoch 32/100
5/5 - 2s - loss: 1.3620 - sparse_categorical_accuracy: 0.5800 - val_loss: 1.5831 - val_sparse_categorical_accuracy: 0.4250
Epoch 33/100
5/5 - 2s - loss: 1.3816 - sparse_categorical_accuracy: 0.5480 - val_loss: 1.7300 - val_sparse_categorical_accuracy: 0.4250
Epoch 34/100
5/5 - 2s - loss: 1.3421 - sparse_categorical_accuracy: 0.5920 - val_loss: 1.6342 - val_sparse_categorical_accuracy: 0.3750
Epoch 35/100
5/5 - 2s - loss: 1.2748 - sparse_categorical_accuracy: 0.5960 - val_loss: 1.6406 - val_sparse_categorical_accuracy: 0.4750
Epoch 36/100
5/5 - 2s - loss: 1.2682 - sparse_categorical_accuracy: 0.5680 - val_loss: 1.7089 - val_sparse_categorical_accuracy: 0.4000
Epoch 37/100
5/5 - 2s - loss: 1.1762 - sparse_categorical_accuracy: 0.6240 - val_loss: 1.6893 - val_sparse_categorical_accuracy: 0.5000
Epoch 38/100
5/5 - 2s - loss: 1.1809 - sparse_categorical_accuracy: 0.6440 - val_loss: 1.5903 - val_sparse_categorical_accuracy: 0.4000
Epoch 39/100
5/5 - 2s - loss: 1.2160 - sparse_categorical_accuracy: 0.6480 - val_loss: 1.8914 - val_sparse_categorical_accuracy: 0.4000
Epoch 40/100
5/5 - 2s - loss: 1.2396 - sparse_categorical_accuracy: 0.6160 - val_loss: 1.6752 - val_sparse_categorical_accuracy: 0.4250
Epoch 41/100
5/5 - 2s - loss: 1.2391 - sparse_categorical_accuracy: 0.6680 - val_loss: 1.7163 - val_sparse_categorical_accuracy: 0.4500
Epoch 42/100
5/5 - 2s - loss: 1.1911 - sparse_categorical_accuracy: 0.6600 - val_loss: 1.6151 - val_sparse_categorical_accuracy: 0.4250
Epoch 43/100
5/5 - 2s - loss: 1.1188 - sparse_categorical_accuracy: 0.6840 - val_loss: 1.7541 - val_sparse_categorical_accuracy: 0.4750
Epoch 44/100
5/5 - 2s - loss: 1.0513 - sparse_categorical_accuracy: 0.6840 - val_loss: 1.6569 - val_sparse_categorical_accuracy: 0.5000
Epoch 45/100
5/5 - 2s - loss: 0.9712 - sparse_categorical_accuracy: 0.7440 - val_loss: 1.8007 - val_sparse_categorical_accuracy: 0.5250
Epoch 46/100
5/5 - 2s - loss: 0.9272 - sparse_categorical_accuracy: 0.7440 - val_loss: 2.0097 - val_sparse_categorical_accuracy: 0.5500
Epoch 47/100
5/5 - 2s - loss: 0.8684 - sparse_categorical_accuracy: 0.7880 - val_loss: 2.0083 - val_sparse_categorical_accuracy: 0.5250
Epoch 48/100
5/5 - 2s - loss: 0.8466 - sparse_categorical_accuracy: 0.7560 - val_loss: 1.9569 - val_sparse_categorical_accuracy: 0.5500
Epoch 49/100
5/5 - 2s - loss: 0.8995 - sparse_categorical_accuracy: 0.7760 - val_loss: 2.0226 - val_sparse_categorical_accuracy: 0.5000
Epoch 50/100
5/5 - 2s - loss: 0.9518 - sparse_categorical_accuracy: 0.7400 - val_loss: 1.9933 - val_sparse_categorical_accuracy: 0.4750
Epoch 51/100
5/5 - 2s - loss: 0.8644 - sparse_categorical_accuracy: 0.7920 - val_loss: 2.3111 - val_sparse_categorical_accuracy: 0.4500
Epoch 52/100
5/5 - 2s - loss: 0.8003 - sparse_categorical_accuracy: 0.8200 - val_loss: 2.5071 - val_sparse_categorical_accuracy: 0.4500
Epoch 53/100
5/5 - 2s - loss: 0.7072 - sparse_categorical_accuracy: 0.8760 - val_loss: 2.4365 - val_sparse_categorical_accuracy: 0.5000
Epoch 54/100
5/5 - 2s - loss: 0.6829 - sparse_categorical_accuracy: 0.8640 - val_loss: 2.8239 - val_sparse_categorical_accuracy: 0.4750
Epoch 55/100
5/5 - 2s - loss: 0.7343 - sparse_categorical_accuracy: 0.8480 - val_loss: 2.5380 - val_sparse_categorical_accuracy: 0.5000
Epoch 56/100
5/5 - 2s - loss: 0.6952 - sparse_categorical_accuracy: 0.8520 - val_loss: 2.7319 - val_sparse_categorical_accuracy: 0.4500
Epoch 57/100
5/5 - 2s - loss: 0.6705 - sparse_categorical_accuracy: 0.8840 - val_loss: 2.6535 - val_sparse_categorical_accuracy: 0.5000
Epoch 58/100
5/5 - 2s - loss: 0.6237 - sparse_categorical_accuracy: 0.9040 - val_loss: 2.7103 - val_sparse_categorical_accuracy: 0.5250
Epoch 59/100
5/5 - 2s - loss: 0.6020 - sparse_categorical_accuracy: 0.8960 - val_loss: 2.8843 - val_sparse_categorical_accuracy: 0.5000
Epoch 60/100
5/5 - 2s - loss: 0.5429 - sparse_categorical_accuracy: 0.9200 - val_loss: 2.7795 - val_sparse_categorical_accuracy: 0.5250
Epoch 61/100
5/5 - 2s - loss: 0.5071 - sparse_categorical_accuracy: 0.9520 - val_loss: 2.9273 - val_sparse_categorical_accuracy: 0.5750
Epoch 62/100
5/5 - 2s - loss: 0.4594 - sparse_categorical_accuracy: 0.9600 - val_loss: 3.2589 - val_sparse_categorical_accuracy: 0.5250
Epoch 63/100
5/5 - 2s - loss: 0.4720 - sparse_categorical_accuracy: 0.9600 - val_loss: 3.6066 - val_sparse_categorical_accuracy: 0.5750
Epoch 64/100
5/5 - 2s - loss: 0.4593 - sparse_categorical_accuracy: 0.9360 - val_loss: 3.8621 - val_sparse_categorical_accuracy: 0.5000
Epoch 65/100
5/5 - 2s - loss: 0.4724 - sparse_categorical_accuracy: 0.9400 - val_loss: 4.0097 - val_sparse_categorical_accuracy: 0.5500
Epoch 66/100
5/5 - 2s - loss: 0.4126 - sparse_categorical_accuracy: 0.9640 - val_loss: 3.4587 - val_sparse_categorical_accuracy: 0.5250
Epoch 67/100
5/5 - 2s - loss: 0.4057 - sparse_categorical_accuracy: 0.9840 - val_loss: 4.7477 - val_sparse_categorical_accuracy: 0.5250
Epoch 68/100
5/5 - 2s - loss: 0.4589 - sparse_categorical_accuracy: 0.9440 - val_loss: 3.6056 - val_sparse_categorical_accuracy: 0.5750
Epoch 69/100
5/5 - 2s - loss: 0.5638 - sparse_categorical_accuracy: 0.9160 - val_loss: 4.1461 - val_sparse_categorical_accuracy: 0.5250
Epoch 70/100
5/5 - 2s - loss: 1.0732 - sparse_categorical_accuracy: 0.7840 - val_loss: 3.0761 - val_sparse_categorical_accuracy: 0.4750
Epoch 71/100
5/5 - 2s - loss: 0.7008 - sparse_categorical_accuracy: 0.8560 - val_loss: 2.8879 - val_sparse_categorical_accuracy: 0.4000
Epoch 72/100
5/5 - 2s - loss: 0.7639 - sparse_categorical_accuracy: 0.8600 - val_loss: 1.9902 - val_sparse_categorical_accuracy: 0.5750
Epoch 73/100
5/5 - 2s - loss: 0.6854 - sparse_categorical_accuracy: 0.9000 - val_loss: 2.7760 - val_sparse_categorical_accuracy: 0.6250
Epoch 74/100
5/5 - 2s - loss: 0.5475 - sparse_categorical_accuracy: 0.9520 - val_loss: 3.0067 - val_sparse_categorical_accuracy: 0.6000
Epoch 75/100
5/5 - 2s - loss: 0.5402 - sparse_categorical_accuracy: 0.9520 - val_loss: 3.2242 - val_sparse_categorical_accuracy: 0.5500
Epoch 76/100
5/5 - 2s - loss: 0.4982 - sparse_categorical_accuracy: 0.9560 - val_loss: 3.4900 - val_sparse_categorical_accuracy: 0.5750
Epoch 77/100
5/5 - 2s - loss: 0.4636 - sparse_categorical_accuracy: 0.9800 - val_loss: 3.5054 - val_sparse_categorical_accuracy: 0.6000
Epoch 78/100
5/5 - 2s - loss: 0.4421 - sparse_categorical_accuracy: 0.9800 - val_loss: 3.8175 - val_sparse_categorical_accuracy: 0.5500
Epoch 79/100
5/5 - 2s - loss: 0.4156 - sparse_categorical_accuracy: 0.9800 - val_loss: 4.0961 - val_sparse_categorical_accuracy: 0.5500
Epoch 80/100
5/5 - 2s - loss: 0.3983 - sparse_categorical_accuracy: 0.9840 - val_loss: 4.0084 - val_sparse_categorical_accuracy: 0.5000
Epoch 81/100
5/5 - 2s - loss: 0.3933 - sparse_categorical_accuracy: 0.9880 - val_loss: 4.3952 - val_sparse_categorical_accuracy: 0.5500
Epoch 82/100
5/5 - 2s - loss: 0.3802 - sparse_categorical_accuracy: 0.9880 - val_loss: 4.5125 - val_sparse_categorical_accuracy: 0.5750
Epoch 83/100
5/5 - 2s - loss: 0.3598 - sparse_categorical_accuracy: 0.9920 - val_loss: 4.6252 - val_sparse_categorical_accuracy: 0.5250
Epoch 84/100
5/5 - 2s - loss: 0.3486 - sparse_categorical_accuracy: 0.9960 - val_loss: 4.8283 - val_sparse_categorical_accuracy: 0.5250
Epoch 85/100
5/5 - 2s - loss: 0.3814 - sparse_categorical_accuracy: 0.9840 - val_loss: 4.2472 - val_sparse_categorical_accuracy: 0.5750
Epoch 86/100
5/5 - 2s - loss: 0.3722 - sparse_categorical_accuracy: 0.9840 - val_loss: 4.4263 - val_sparse_categorical_accuracy: 0.5500
Epoch 87/100
5/5 - 2s - loss: 0.3801 - sparse_categorical_accuracy: 0.9720 - val_loss: 4.6447 - val_sparse_categorical_accuracy: 0.5750
Epoch 88/100
5/5 - 2s - loss: 0.3513 - sparse_categorical_accuracy: 0.9920 - val_loss: 4.6057 - val_sparse_categorical_accuracy: 0.5750
Epoch 89/100
5/5 - 2s - loss: 0.3604 - sparse_categorical_accuracy: 0.9880 - val_loss: 4.5952 - val_sparse_categorical_accuracy: 0.5250
Epoch 90/100
5/5 - 2s - loss: 0.4851 - sparse_categorical_accuracy: 0.9600 - val_loss: 4.0727 - val_sparse_categorical_accuracy: 0.5750
Epoch 91/100
5/5 - 2s - loss: 0.3744 - sparse_categorical_accuracy: 0.9800 - val_loss: 3.9599 - val_sparse_categorical_accuracy: 0.5750
Epoch 92/100
5/5 - 2s - loss: 0.3490 - sparse_categorical_accuracy: 1.0000 - val_loss: 4.1361 - val_sparse_categorical_accuracy: 0.6250
Epoch 93/100
5/5 - 2s - loss: 0.3401 - sparse_categorical_accuracy: 0.9960 - val_loss: 4.3469 - val_sparse_categorical_accuracy: 0.6000
Epoch 94/100
5/5 - 2s - loss: 0.3290 - sparse_categorical_accuracy: 0.9960 - val_loss: 4.9883 - val_sparse_categorical_accuracy: 0.6250
Epoch 95/100
5/5 - 2s - loss: 0.3250 - sparse_categorical_accuracy: 0.9960 - val_loss: 5.0136 - val_sparse_categorical_accuracy: 0.6000
Epoch 96/100
5/5 - 2s - loss: 0.3181 - sparse_categorical_accuracy: 1.0000 - val_loss: 4.8905 - val_sparse_categorical_accuracy: 0.6000
Epoch 97/100
5/5 - 2s - loss: 0.3051 - sparse_categorical_accuracy: 1.0000 - val_loss: 5.0910 - val_sparse_categorical_accuracy: 0.6000
Epoch 98/100
5/5 - 2s - loss: 0.3001 - sparse_categorical_accuracy: 1.0000 - val_loss: 5.3141 - val_sparse_categorical_accuracy: 0.6000
Epoch 99/100
5/5 - 2s - loss: 0.2903 - sparse_categorical_accuracy: 1.0000 - val_loss: 5.3422 - val_sparse_categorical_accuracy: 0.6000
Epoch 100/100
5/5 - 2s - loss: 0.2850 - sparse_categorical_accuracy: 1.0000 - val_loss: 5.2941 - val_sparse_categorical_accuracy: 0.6000
2/2 [==============================] - 1s 341ms/step - loss: 5.2941 - sparse_categorical_accuracy: 0.6000
%tensorboard --logdir logs/train1 --port=6008
Reusing TensorBoard on port 6008 (pid 20356), started 0:00:08 ago. (Use '!kill 20356' to kill it.)
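One detail worth flagging: lr_callback is created above but never added to the callbacks list, so lr_schedule never fires and Adam keeps the fixed rate of 2e-4 for the whole run. To actually apply the schedule, the fit call would need to include it:

model.fit(x_train, y_train, epochs=100, batch_size=50,
          validation_data=(x_test, y_test), verbose=2,
          callbacks=[tensorboard_callback, lr_callback])  # scheduler included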
# Get the model's predictions on the test set
y = model.predict(x_test)
y = np.argmax(y, axis=1)
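With the predicted labels in hand, the test accuracy can be recomputed directly and checked against the 0.60 that model.evaluate reported above:

# Fraction of test images whose predicted style matches the true label
test_acc = np.mean(y == y_test)
print('test accuracy:', test_acc)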
As the training log above shows, training accuracy reaches 100%, while test accuracy only reaches 60% because the training set is so small: a clear case of overfitting. Even so, the result makes the point. A convolutional network is not limited to recognizing concrete features; it is essentially a function approximator, and it is quite capable of fitting anything that correlates with visual features, including something as abstract as painting style.
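Since the overfitting here comes mainly from the tiny training set, simple data augmentation would be a natural first remedy. A minimal sketch using the Keras preprocessing layers of this TF 2.x era (I did not run this, so the actual gain is untested):

# Random flips/rotations/zooms preserve a painting's style, so they
# are safe label-preserving transforms for this task
augment = tf.keras.Sequential([
    tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
    tf.keras.layers.experimental.preprocessing.RandomRotation(0.05),
    tf.keras.layers.experimental.preprocessing.RandomZoom(0.1),
])

train_ds = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
            .shuffle(250)
            .batch(50)
            .map(lambda x, y: (augment(x, training=True), y)))
# model.fit(train_ds, epochs=100, validation_data=(x_test, y_test))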
Next, let's test the trained model.
As the results below show, accuracy exceeds 50%, well above the 20% chance level for five classes. With a sufficiently large training sample, the test accuracy should in principle approach the 100% achieved on the training set.
plt.figure(dpi=100)
for i in range(3 * 3):
    plt.subplot(3, 3, i + 1)
    i = i + 6  # offset into the test set: show samples 6 through 14
    title = "label:" + str(int(y_test[i])) + " predict:" + str(y[i])
    plt.title(title)
    plt.imshow(x_test[i])
    plt.axis('off')
print("Label-to-style mapping:", label_to_name)
Label-to-style mapping: {0: 'Ancient Egyptian', 1: 'Impressionism', 2: 'Cubism', 3: 'Purism', 4: 'Light-effect art / Op Art'}
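For a finer view than the single accuracy number, a confusion matrix shows which styles get mixed up with which (this is my addition, not part of the original run):

# Rows are true styles, columns are predicted styles
cm = tf.math.confusion_matrix(y_test, y, num_classes=5)
print(cm.numpy())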