tensorflow2.0 新特性小练习
Posted chrisinsistpy
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow2.0 新特性小练习相关的知识,希望对你有一定的参考价值。
基于tf2.0 对Kaggel Google street view characters classify 项目练手, 熟悉一下tf2.0的新特性
下载下来kaggle的数据集如下:
所有训练数据在train文件夹中, labels在trainLabels.cvs文件中, label文件格式如下:
分别每个label对应其图片的名称
首先对数据进行预处理 总共有61个类别从a-z, A-Z, 0-9等,代码如下:
from __future__ import absolute_import, division, print_function, unicode_literals import numpy as np import tensorflow as tf from tensorflow.python import keras import csv import pathlib keras.backend.clear_session() csv_filepath = ‘E:\\\\work\\\\Kaggle\\\\street-view-getting-started-with-julia\\\\trainLabels.csv‘ data_root_path = ‘E:\\\\work\\\\Kaggle\\\\street-view-getting-started-with-julia\\\\train‘ csv_file = csv.reader(open(csv_filepath, ‘r‘)) label_container = [] labels = [] all_image_labels = [] AUTOTUNE = tf.data.experimental.AUTOTUNE for cnt in csv_file: if cnt[1] not in labels: labels.append(cnt[1]) label_container.append(cnt) labels = labels[1:] label_container = label_container[1:] labels = np.sort(labels) labels_to_index = dict((name, index) for index,name in enumerate(labels)) data_root = pathlib.Path(data_root_path) all_image_paths = list(data_root.glob(‘*‘)) all_image_paths = [str(path) for path in all_image_paths] for item in data_root.iterdir(): # all_img_path.append(item) name = item.name[:-4] for match in label_container: if name == match[0]:
all_image_labels.append(key if value == match[1]
for key, value in enumerate(labels_to_index))
生成的all_image_paths 和 all_image_labels分别包含如下:
因为测试集没有label,所以把训练集分三份,并用tf.data.Data去映射数据空间
train_img_path = all_image_paths[:4000] val_img_path = all_image_paths[4000:5000] test_img_path = all_image_paths[5000:] train_img_labels = all_image_labels[:4000] val_img_labels = all_image_labels[4000:5000] test_img_labels = all_image_labels[5000:] raw_train_ds = tf.data.Dataset.from_tensor_slices((train_img_path, train_img_labels)) raw_val_ds = tf.data.Dataset.from_tensor_slices((val_img_path, val_img_labels)) raw_test_ds = tf.data.Dataset.from_tensor_slices((test_img_path, test_img_labels))
Scale 图片,并对其做数据增强,来满足translation invarience
以上是关于tensorflow2.0 新特性小练习的主要内容,如果未能解决你的问题,请参考以下文章
详解深度强化学习展现TensorFlow 2.0新特性(代码)
译ECMAScript 2016, 2017, 2018 新特性之必读篇