基于LSTM的验证码识别
Posted 人工智能与图像处理
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了基于LSTM的验证码识别相关的知识,希望对你有一定的参考价值。
一,基本版本
1.1 训练图片示例:
1.2 代码:
详细代码见原文（篇幅所限，这里不贴出全部源码）。
1.3 结果:
二,改进版本
改进版本改为实时生成图片，无需读取本地图片，可以无限生成数据集。代码改动比较大，有兴趣可以对比一下。同时新增了一些功能，比如：可视化、指数衰减学习率、断点续训等。
2.1 训练图片示例:
可以根据个人口味酌情更改验证码生成样式:
2.2 代码:
2.2.1 验证码生成代码一:
1#encoding=utf-8
2import random
3# import matplotlib.pyplot as plt
4import string
5import sys
6import math
7from PIL import Image,ImageDraw,ImageFont,ImageFilter
filename = "./My_captcha/"  # output directory for generated images; must end with '/' (paths are built as filename + text + ".png")
# Font file; its location differs between systems (e.g. BuxtonSketch.ttf on some)
font_path = 'C:/Windows/Fonts/Georgia.ttf'
# font_path = 'C:/Windows/Fonts/默陌肥圆手写体.ttf'
# Number of characters per captcha
number = 4
# Canvas size as (width, height), before the affine shear enlarges it
size = (80, 26)
# Background colour: white
bgcolor = (255, 255, 255)
# Text colour: black
fontcolor = (0, 0, 0)
# Interference-line colour: black
linecolor = (0, 0, 0)
# Whether to draw an interference line
draw_line = True
# Intended (min, max) number of interference lines.
# NOTE: not referenced by gene_code(), which draws exactly one line when draw_line is True.
line_number = (1, 5)
26
# Random captcha string generator
def gene_text():
    """Return `number` distinct characters drawn from digits and uppercase letters.

    random.sample draws without replacement, so no character repeats within one
    captcha.
    """
    # Same population in the same order as before (note 'Z' comes before 'X', 'Y').
    # random.sample's picks depend on that order, so it is kept verbatim.
    pool = "0123456789ABCDEFGHIJKLMNOPQRSTUVWZXY"
    picked = random.sample(pool, number)
    return "".join(picked)
# Draw one interference line
def gene_line(draw, width, height):
    """Draw a single 3-px interference line in ``linecolor``.

    The line runs from the left edge (x=0) to x=74, with a random row at each
    end. For images narrower than 75 px the end is clamped to the last column.

    Args:
        draw: object with a PIL ``ImageDraw``-style ``line(xy, fill=..., width=...)``.
        width: image width in pixels.
        height: image height in pixels.
    """
    # random.randint is inclusive at both ends; the last valid row is height - 1
    # (the original randint(0, height) could place an endpoint outside the image).
    last_row = max(height - 1, 0)
    end_x = min(74, max(width - 1, 0))
    begin = (0, random.randint(0, last_row))
    end = (end_x, random.randint(0, last_row))
    draw.line([begin, end], fill=linecolor, width=3)
44
# Render one captcha and save it as <filename><text>.png
def gene_code():
    """Render one random captcha and save it to ``filename + text + '.png'``.

    Steps:
    1. Draw ``gene_text()`` in ``fontcolor`` on a ``size`` RGBA canvas.
    2. Optionally add one interference line.
    3. Apply an affine shear onto a canvas enlarged by 30 x 10 px.
    4. Apply EDGE_ENHANCE_MORE and save as PNG.

    The output directory is created if it does not exist.
    """
    import os  # local import: the listing's top-level imports do not include os

    width, height = size
    image = Image.new('RGBA', (width, height), bgcolor)
    font = ImageFont.truetype(font_path, 25)
    draw = ImageDraw.Draw(image)
    text = gene_text()

    # FreeTypeFont.getsize() was removed in Pillow 10. The right/bottom edges of
    # getbbox() are the (width, height) it used to return.
    if hasattr(font, "getbbox"):
        _, _, font_width, font_height = font.getbbox(text)
    else:
        font_width, font_height = font.getsize(text)

    # NOTE(review): the free space is divided by `number` (the character count),
    # not by 2, so the text is not centred. Kept because it shapes the images.
    origin = ((width - font_width) / number, (height - font_height) / number)
    draw.text(origin, text, font=font, fill=fontcolor)

    if draw_line:
        gene_line(draw, width, height)

    # Affine shear: each output pixel (x, y) samples input (x - 0.3*y, -0.1*x + y).
    image = image.transform((width + 30, height + 10), Image.AFFINE,
                            (1, -0.3, 0, -0.1, 1, 0), Image.BILINEAR)
    image = image.filter(ImageFilter.EDGE_ENHANCE_MORE)  # edge enhancement

    os.makedirs(filename, exist_ok=True)
    image.save(filename + text + ".png")
66
67
# Script entry point: write a few sample captchas into `filename`.
# Guarded so importing this module no longer generates images as a side effect.
if __name__ == "__main__":
    # The original `x = 1; while x < 10` loop produced 9 images; keep that count.
    image_count = 9
    for done in range(1, image_count + 1):
        gene_code()
        if done % 100 == 0:
            print("Iter:%d" % done)
2.2.2 验证码生成代码二:
1from captcha.image import ImageCaptcha # pip install captcha
2import numpy as np
3import matplotlib.pyplot as plt
4from PIL import Image
5import random
6
# Character sets for the captcha (no Chinese characters)
number = list('0123456789')
alphabet = list('abcdefghijklmnopqrstuvwxyz')
ALPHABET = list('ABCDEFGHIJKLMNOPQRSTUVWXYZ')


# Captchas usually ignore case, so only digits + uppercase are used; 4 characters long.
def random_captcha_text(char_set=number + ALPHABET, captcha_size=4):
    """Return a list of `captcha_size` characters picked from `char_set` with replacement."""
    return [random.choice(char_set) for _ in range(captcha_size)]
22
23
# Generate a captcha image for a random text
def gen_captcha_text_and_image():
    """Render a random captcha with the `captcha` package and save it.

    The image is resized to 80x26 and written to ``./imgs/<text>.jpg``. The
    directory is created if missing.

    Returns:
        (captcha_text, captcha_image): the text as a str and the resized PIL Image.
    """
    import os  # local import: not among this listing's top-level imports

    generator = ImageCaptcha()
    captcha_text = ''.join(random_captcha_text())

    # generate() returns an in-memory PNG; Image.open reads it back as a PIL Image.
    # (The former np.array -> Image.fromarray round-trip produced the same image.)
    captcha_image = Image.open(generator.generate(captcha_text))
    # Image.ANTIALIAS was removed in Pillow 10; it was an alias of LANCZOS.
    captcha_image = captcha_image.resize((80, 26), Image.LANCZOS)

    os.makedirs("./imgs/", exist_ok=True)
    captcha_image.save("./imgs/" + captcha_text + '.jpg')
    return captcha_text, captcha_image
40
41
if __name__ == '__main__':
    # Smoke run: write 10000 captchas to ./imgs/, reporting progress every 100 images.
    for idx in range(10000):
        _ = gen_captcha_text_and_image()
        if idx % 100 == 0:
            print("生成第%s张图" % idx)
2.2.3 模型训练整体代码:
1#-*- coding:utf-8 -*
2import os
3import random
4import captcha
5import numpy as np
6import tensorflow as tf
7from captcha.image import ImageCaptcha # pip install captcha
8from PIL import Image,ImageDraw,ImageFont,ImageFilter
9
######### Global settings ###########################################
path = os.getcwd()  # project directory (current working directory)
output_path = path + '/result/result.txt'  # where prediction results are written
MODEL_SAVE_PATH = "./model/"
MODEL_NAME = "LSTM_Captcha"  # checkpoint file prefix

# Characters to recognise
number = list('0123456789')
ALPHABET = list('ABCDEFGHIJKLMNOPQRSTUVWXYZ')
# Lowercase letters: not part of n_classes below (digits + uppercase only)
alphabet = list('abcdefghijklmnopqrstuvwxyz')

batch_size = 64  # samples per batch; the LSTM zero_state is built for exactly this size
time_steps = 26  # one LSTM time step per pixel row (image height)
n_input = 80  # pixels per row fed at each time step (image width)
image_channels = 1  # channel count after grey conversion (not referenced by the code below)
captcha_num = 4  # characters per captcha
n_classes = len(number) + len(ALPHABET)  # 36 classes per character

learning_rate = 0.001  # initial Adam learning rate
decaystep = 5000  # steps between learning-rate decays
decay_rate = 0.5  # natural-exp decay rate
num_units = 64  # hidden units per LSTM layer
layer_num = 2  # number of stacked LSTM layers
iteration = 20000  # training-loop bound

# On-the-fly image generation
IMAGE_HEIGHT = 26  # image height
IMAGE_WIDTH = 80  # image width
# Captcha canvas size as (width, height)
size = (IMAGE_WIDTH, IMAGE_HEIGHT)
# Background colour: white
bgcolor = (255, 255, 255)
# Text colour: black
fontcolor = (0, 0, 0)
# Font file; its location differs between systems (e.g. BuxtonSketch.ttf on some)
font_path = 'C:/Windows/Fonts/Georgia.ttf'
######### Global settings ###########################################
47
# Random captcha text: characters drawn with replacement from digits + uppercase letters
def random_captcha_text(char_set=number+ALPHABET, captcha_size=4):
    """Return a string of `captcha_size` characters chosen (with replacement) from `char_set`."""
    picks = [random.choice(char_set) for _ in range(captcha_size)]
    return ''.join(picks)
55
# Parsed fonts keyed by (path, point size), so each TrueType file is loaded once
# instead of once per generated image.
_FONT_CACHE = {}


def _captcha_font(size_pt=25):
    """Return ImageFont.truetype(font_path, size_pt), loading each (path, size) only once."""
    key = (font_path, size_pt)
    font = _FONT_CACHE.get(key)
    if font is None:
        font = ImageFont.truetype(font_path, size_pt)
        _FONT_CACHE[key] = font
    return font


# Render a random captcha image (digits + uppercase) in memory
def gen_captcha_text_and_image():
    """Render one random captcha.

    Returns:
        (captcha_text, captcha_image): the text string, and the RGBA image as a
        numpy array. get_next_batch() keeps only arrays of shape
        (IMAGE_HEIGHT, IMAGE_WIDTH, 4).
    """
    width, height = size
    image = Image.new('RGBA', (width, height), bgcolor)
    font = _captcha_font(25)
    draw = ImageDraw.Draw(image)
    captcha_text = random_captcha_text()

    # FreeTypeFont.getsize() was removed in Pillow 10. The right/bottom edges of
    # getbbox() are the (width, height) it used to return.
    if hasattr(font, "getbbox"):
        _, _, font_width, font_height = font.getbbox(captcha_text)
    else:
        font_width, font_height = font.getsize(captcha_text)

    # NOTE(review): dividing the free space by captcha_num (not 2) means the text
    # is not centred; kept so the training images stay the same.
    origin = ((width - font_width) / captcha_num, (height - font_height) / captcha_num)
    draw.text(origin, captcha_text, font=font, fill=fontcolor)
    image = image.filter(ImageFilter.EDGE_ENHANCE_MORE)  # edge enhancement
    captcha_image = np.array(image)
    return captcha_text, captcha_image
72
# Convert to a single grey channel
def convert2gray(img):
    """Collapse a multi-channel image array to one channel.

    Arrays with more than two dimensions are averaged over their last axis. A
    2-D array is returned unchanged (the same object).
    """
    if len(img.shape) <= 2:
        return img
    # A plain channel mean is fast. Luminance weights (0.2989 R + 0.5870 G +
    # 0.1140 B) would be the textbook conversion. For the RGBA arrays
    # get_next_batch() feeds in, the alpha channel is part of this mean.
    return np.mean(img, -1)
83
# Encode a string as a flat one-hot label vector
def text2vec(text):
    """Encode `text` as a flat one-hot vector of length captcha_num * n_classes.

    Character i occupies the slice [i*n_classes, (i+1)*n_classes). Its class id is:
    - 0-9 for '0'-'9'
    - 10-35 for 'A'-'Z'
    - 36-61 for 'a'-'z'
    - 62 for '_'

    Positions beyond len(text) stay all-zero.

    Raises:
        ValueError: `text` is longer than captcha_num; a character has no mapping;
            or a character's class id does not fit in n_classes (with the default
            36 classes: lowercase letters and '_'). Previously such ids silently
            overwrote the next character's slot, or indexed from the end.
    """
    if len(text) > captcha_num:
        raise ValueError('验证码最长4个字符')

    vector = np.zeros(captcha_num * n_classes)

    def char2pos(c):
        # Explicit ranges: the old ord()-offset arithmetic mapped ':'..'@' onto
        # digits and gave negative ids for characters below '0'.
        if c == '_':
            return 62
        if '0' <= c <= '9':
            return ord(c) - ord('0')
        if 'A' <= c <= 'Z':
            return ord(c) - ord('A') + 10
        if 'a' <= c <= 'z':
            return ord(c) - ord('a') + 36
        raise ValueError('No Map')

    for i, c in enumerate(text):
        k = char2pos(c)
        if k >= n_classes:
            raise ValueError('No Map')
        vector[i * n_classes + k] = 1
    return vector
109
# Decode a flat one-hot label vector back to a string
def vec2text(vec):
    """Decode a flat one-hot vector (as built by text2vec) into its string.

    Each non-zero entry's class id is its flat index modulo n_classes:
    - 0-9 → digits
    - 10-35 → 'A'-'Z'
    - 36-61 → 'a'-'z'
    - 62 → '_'

    Raises ValueError('error') for an id above 62.
    """
    # Lookup table indexed by class id.
    charset = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_'
    decoded = []
    for flat_idx in vec.nonzero()[0]:
        class_id = flat_idx % n_classes
        if class_id > 62:
            raise ValueError('error')
        decoded.append(charset[class_id])
    return "".join(decoded)
129
# Convert argmax class ids such as [22, 32, 1, 5] into strings
def index2char(vec):
    """Turn predicted class ids into captcha strings.

    Args:
        vec: array indexed as vec[0][sample][position]. predict() passes
            np.array([pre_arg_value]).

    Class ids map as:
    - 0-9 → `number`
    - 10-35 → `ALPHABET`
    - 36-61 → `alphabet`
    - 62 → '_'

    Returns:
        One string per sample, built from its first captcha_num ids.

    Raises:
        ValueError('error'): an id is above 62.
    """
    def id2char(class_id):
        if class_id < 10:
            return number[class_id]
        if class_id < 36:
            return ALPHABET[class_id - 10]
        if class_id < 62:
            # Was ALPHABET[...] (uppercase), which disagreed with vec2text for 'a'-'z'.
            return alphabet[class_id - 36]
        if class_id == 62:
            return '_'
        raise ValueError('error')

    texts = []
    for sample in vec[0]:
        texts.append(''.join(id2char(sample[pos]) for pos in range(captcha_num)))
    return texts
154
# Build a training batch of batch_size0 freshly generated captchas
def get_next_batch(batch_size0=64):
    """Generate one batch of synthetic captchas.

    Returns:
        batch_x: float array (batch_size0, time_steps, n_input). Grey images
            divided by 255.
        batch_y: float array (batch_size0, captcha_num, n_classes). One-hot
            labels from text2vec.
    """
    batch_x = np.zeros([batch_size0, time_steps, n_input])
    batch_y = np.zeros([batch_size0, captcha_num, n_classes])

    def wrap_gen_captcha_text_and_image():
        # Retry until the generator returns an RGBA array of the expected size.
        while True:
            text, image = gen_captcha_text_and_image()
            if image.shape == (IMAGE_HEIGHT, IMAGE_WIDTH, 4):
                return text, image

    for i in range(batch_size0):
        text, image = wrap_gen_captcha_text_and_image()
        # convert2gray already returns an ndarray; no extra np.array() copy needed.
        batch_x[i] = convert2gray(image) / 255
        # text2vec is called once per sample (the old unused `ss` duplicate is gone).
        batch_y[i] = np.reshape(text2vec(text), [captcha_num, n_classes])
    return batch_x, batch_y
178
# Build the LSTM network
def computational_graph_lstm(x, y, global_step):
    """Build the stacked-LSTM captcha classifier with its loss, optimiser and metrics.

    Each of the `time_steps` image rows is one LSTM time step. The outputs of the
    last `captcha_num` steps are each projected, with shared weights, to
    `n_classes` logits: one character per step.

    Args:
        x: input placeholder (batch, time_steps, n_input).
        y: one-hot labels (batch, captcha_num, n_classes).
        global_step: step counter; the optimiser increments it and the
            learning-rate schedule reads it.

    Returns:
        (opt, loss, accuracy, pre_arg, y_arg)
    """
    out_weights = tf.Variable(tf.random_normal([num_units, n_classes]), name='out_weight')
    out_bias = tf.Variable(tf.random_normal([n_classes]), name='out_bias')

    # layer_num stacked LSTM layers.
    lstm_layer = [tf.nn.rnn_cell.LSTMCell(num_units, state_is_tuple=True) for _ in range(layer_num)]
    mlstm_cell = tf.nn.rnn_cell.MultiRNNCell(lstm_layer, state_is_tuple=True)
    # zero_state takes a concrete batch size, so the graph expects batches of exactly batch_size.
    state = mlstm_cell.zero_state(batch_size, tf.float32)

    outputs = []
    with tf.variable_scope('RNN'):
        for timestep in range(time_steps):
            if timestep > 0:
                tf.get_variable_scope().reuse_variables()
            cell_output, state = mlstm_cell(x[:, timestep, :], state)  # state holds every layer's state
            outputs.append(cell_output)

    # One head per character: the last captcha_num time steps, oldest first.
    # Stacked on axis 1 -> (batch, captcha_num, n_classes).
    logits = tf.stack([tf.matmul(out, out_weights) + out_bias
                       for out in outputs[-captcha_num:]], axis=1)
    prediction_all = tf.nn.softmax(logits, name='prediction_merge')

    with tf.name_scope('loss'):
        # Cross-entropy computed from logits. The former -mean(y * log(softmax))
        # can turn into NaN once a probability underflows to exactly 0.
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits), name='loss')
        tf.summary.scalar('loss', loss)

    # Apply the natural-exp decay schedule that train() reports. The optimiser
    # previously received the constant `learning_rate`, so the decay never took effect.
    decayed_lr = tf.train.natural_exp_decay(learning_rate, global_step, decaystep,
                                            decay_rate, staircase=True)
    # global_step must be passed here for checkpoint resume to track progress.
    opt = tf.train.AdamOptimizer(learning_rate=decayed_lr, name='opt').minimize(
        loss, global_step=global_step)

    # Per-character accuracy.
    pre_arg = tf.argmax(prediction_all, 2, name='predict')
    y_arg = tf.argmax(y, 2)
    correct_prediction = tf.equal(pre_arg, y_arg)

    with tf.name_scope('accuracy'):
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy')
        tf.summary.scalar('accuracy', accuracy)

    return opt, loss, accuracy, pre_arg, y_arg
230
# Training
def train():
    """Build the graph, resume from the newest checkpoint if any, and train.

    - Every 100 steps: write TensorBoard summaries to ./logs and print batch
      accuracy/loss.
    - Every 1000 steps: print the scheduled learning rate and save a checkpoint
      under MODEL_SAVE_PATH (max_to_keep=1, so only the newest survives).
    - After the loop: print accuracy on one freshly generated batch.
    """
    x = tf.placeholder("float", [None, time_steps, n_input], name="x")
    y = tf.placeholder("float", [None, captcha_num, n_classes], name="y")

    # Count of optimisation steps run so far; restored with the checkpoint, not trained.
    global_step = tf.Variable(0, trainable=False)

    # Natural-exponential learning-rate decay (reported below).
    learning_rate_decay = tf.train.natural_exp_decay(learning_rate, global_step, decaystep,
                                                     decay_rate, staircase=True)

    opt, loss, accuracy, _, _ = computational_graph_lstm(x, y, global_step)

    saver = tf.train.Saver(max_to_keep=1)
    init = tf.global_variables_initializer()
    merged = tf.summary.merge_all()

    # Saver.save() will not create a missing parent directory.
    os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

    with tf.Session() as sess:
        sess.run(init)
        writer = tf.summary.FileWriter('logs', sess.graph)
        try:
            # Resume from a checkpoint if one exists.
            ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)

            step = sess.run(global_step)
            # Same bound as the original counter loop: the last step trained is
            # iteration - 1. With iteration=20000, the final checkpoint is therefore
            # at step 19000.
            while step + 1 < iteration:
                batch_x, batch_y = get_next_batch(batch_size)
                feed = {x: batch_x, y: batch_y}
                _, step = sess.run([opt, global_step], feed_dict=feed)

                if step % 100 == 0:
                    # One forward pass yields summaries, loss and accuracy together.
                    summary, los, acc = sess.run([merged, loss, accuracy], feed_dict=feed)
                    writer.add_summary(summary, step)
                    print("iter:%d,Accuracy:%f,Loss:%f " % (step, acc, los))

                    if step % 1000 == 0:
                        learning_rate_val = sess.run(learning_rate_decay)
                        print("After %s steps,learing rate is %f" % (step, learning_rate_val))
                        saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME),
                                   global_step=global_step)
        finally:
            # Flush pending summaries even if training is interrupted.
            writer.close()

        # Accuracy on one freshly generated batch.
        valid_x, valid_y = get_next_batch(batch_size)
        print("Validation Accuracy:", sess.run(accuracy, feed_dict={x: valid_x, y: valid_y}))
293
# Inference
def predict():
    """Restore the newest checkpoint from <cwd>/model/ and predict one generated batch.

    Results are handed to write_to_file().

    Raises:
        FileNotFoundError: no checkpoint exists under <cwd>/model/.
    """
    model_dir = path + "/model/"
    with tf.Session() as sess:
        # The .meta file name used to be hard-coded as "LSTM_Captcha-19000.meta".
        # Derive it from the newest checkpoint so any training length works.
        checkpoint = tf.train.latest_checkpoint(model_dir)
        if checkpoint is None:
            raise FileNotFoundError("no checkpoint found in " + model_dir)
        saver = tf.train.import_meta_graph(checkpoint + ".meta")
        saver.restore(sess, checkpoint)

        # Look up the tensors in the restored graph by name.
        graph = tf.get_default_graph()
        x = graph.get_tensor_by_name("x:0")
        y = graph.get_tensor_by_name("y:0")
        pre_arg = graph.get_tensor_by_name("predict:0")

        test_x, test_y = get_next_batch(batch_size)
        # All-zero labels for y, fed as in the original code.
        dummy_y = np.zeros([batch_size, captcha_num, n_classes])
        test_predict = sess.run([pre_arg], feed_dict={x: test_x, y: dummy_y})
        predict_result = index2char(np.array(test_predict))  # one string per sample
        write_to_file(predict_result, test_y)
312
# Write prediction results to a text file
def write_to_file(predict_list, test_y):
    """Write one tab-separated line per sample to `output_path`.

    Each line has the form ``id<TAB>label<TAB>result``, where label is the
    decoded ground truth from `test_y`. The header is always written, and the
    output directory is created if missing.

    Args:
        predict_list: predicted strings, one per sample.
        test_y: labels shaped (batch_size, captcha_num, n_classes).
    """
    label_y = np.reshape(test_y, [batch_size, captcha_num * n_classes])
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        # The second column holds the true label (it was mislabeled "file").
        f.write("id\tlabel\tresult\n")
        for i, res in enumerate(predict_list):
            y_ = vec2text(label_y[i])
            f.write(str(i) + "\t" + y_ + "\t" + res + "\n")
    print("预测结果保存在:", output_path)
324
# Entry point (guarded so importing this module does not start training)
if __name__ == "__main__":
    # Train (resumes from ./model/ when a checkpoint exists)
    train()

    # Inference on a generated batch:
    # predict()
2.3 结果:
欢迎扫码关注我的微信公众号
日语口语小学堂:
以上是关于基于LSTM的验证码识别的主要内容,如果未能解决你的问题,请参考以下文章