一文看懂Keras和TensorFlow到底哪家强

Posted 新智元

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了一文看懂Keras和TensorFlow到底哪家强相关的知识,希望对你有一定的参考价值。






  新智元推荐  

来源:AI前线(ID:ai-front)

编辑:Debra

【新智元导读】本文的作者经常在电子邮箱中、社交媒体上,甚至在与深度学习研究人员、从业者和工程师面对面交谈时,会被问到这些问题:我应该在项目中使用 Keras 还是 TensorFlow?TensorFlow 和 Keras 哪个更好?我应该花时间研究TensorFlow 还是 Keras?你是不是也有相同的疑问?如果有,相信这篇文章会给你答案。


实际上,到 2017 年中,Keras 已经被大规模采用,并与 TensorFlow 集成在一起。这种 TensorFlow + Keras 的组合让你可以:


  1. 使用 Keras 的接口定义模型;

  2. 如果你需要特定的 TensorFlow 功能或者需要实现 Keras 不支持但 TensorFlow 支持的自定义功能,可以回到 TensorFlow。


简单地说,你可以将 TensorFlow 代码直接插入到 Keras 的模型或训练管道中!


但请别误会,我并不是说你就不需要了解 TensorFlow 了。我的意思是,如果你:


  1. 刚开始接触深度学习……

  2. 在为下一个项目选型……

  3. 想知道 Keras 或 TensorFlow 哪个“更好”……


我的建议是先从 Keras 着手,然后深入 TensorFlow,这样可以获得你需要的某些特定功能。


在这篇文章中,我将向你展示如何使用 Keras 训练神经网络,以及如何使用直接构建在 TensorFlow 库中的 Keras + TensorFlow 组合来训练模型。


Keras 与 TF 我该学哪个?


在文章的其余部分,我将继续讨论有关 Keras 与 TensorFlow 的争论以及为什么说这个问题其实是个错误的问题。


我们将使用标准的 keras 模块以及 TensorFlow 的 tf.keras 模块实现一个卷积神经网络(CNN)。


我们将在一个样本数据集上训练 CNN,然后检查结果——你会发现,Keras 和 TensorFlow 可以很融洽地合作。


最重要的是,你将会了解为什么 Keras 与 TensorFlow 之间的争论其实是没有意义的。


尽管从 TensorFlow 宣布将 Keras 集成到官方 TensorFlow 版本中已经一年多时间了,但很多深度学习从业者仍然不知道他们可以通过 tf.keras 子模块访问 Keras,为此我感到很惊讶。


更重要的是,Keras + TensorFlow 的集成是无缝的,你可以直接将 TensorFlow 代码放到 Keras 模型中。


在 TensorFlow 中使用 Keras 将为你带来两全其美的好处:


  1. 你可以使用 Keras 提供的简单直观的 API 来创建模型;

  2. Keras API 与 scikit-learn(被认为是机器学习 API 的“黄金标准”)很像;

  3. Keras API 采用了模块化,易于使用;

  4. 当你需要自定义实现或者更复杂的损失函数时,可以直接进入 TensorFlow,并让代码自动与 Keras 模型集成。


在过去几年中,深度学习研究人员、从业人员和工程师通常需要做出以下选择:


  1. 我是选择易用但难以定制的 Keras 库?

  2. 还是选择难用的 TensorFlow API,并编写更多的代码?


所幸的是,我们不必再纠结了。


如果你发现自己还在问这样的问题,那么请退后一步——你问的是错误的问题——你可以同时拥有这两个框架。


一文看懂Keras和TensorFlow到底哪家强


如图所示,导入 TensorFlow(tf),然后调用 tf.keras,可见 Keras 实际上已经成为 TensorFlow 的一部分。


在 tf.keras 中包含 Keras 让你可以使用标准的 Keras 包实现简单的前馈神经网络:


一文看懂Keras和TensorFlow到底哪家强


然后使用 tf.keras 子模块实现相同的网络:


一文看懂Keras和TensorFlow到底哪家强


这是否意味着你必须使用 tf.keras?标准的 Keras 包是不是已经过时?当然不是。


作为一个库,Keras 仍然可以单独使用,因此未来两者可能会分道扬镳。不过,因为谷歌官方支持 Keras 和 TensorFlow,所以似乎不太可能出现这种情况。

关键是:


如果你习惯使用 Keras 编写代码,那么请继续这样做。


但如果你主要使用的是 TensorFlow,那么应该开始考虑一下 Keras API:


  1. 它内置于 TensorFlow 中;

  2. 它更容易使用;

  3. 当你需要使用 TensorFlow 来实现特定功能时,可以直接将其集成到 Keras 模型中。


我们的样本数据集


一文看懂Keras和TensorFlow到底哪家强


CIFAR-10 数据集包含了 10 个分类,我们将它用在我们的演示中。


为简单起见,我们将使用以下方法在 CIFAR-10 数据集上训练两个单独的卷积神经网络:


  1. TensorFlow + Keras;

  2. tf.keras 的 Keras 子模块。


我还将展示如何将自定义的 TensorFlow 代码包含在 Keras 模型中。


我们的项目结构


可以使用 tree 命令在终端中查看我们的项目结构:


一文看懂Keras和TensorFlow到底哪家强


pyimagesearch 模块不能通过 pip 安装,请点击文末提供的下载链接。现在让我们看一下该模块的两个重要 Python 文件:


  • minivggnetkeras.py:MiniVGGNet(一个机遇 VGGNet 的深度学习模型)的 Keras 实现。

  • minivggnettf.py:MiniVGGNet 的 TensorFlow + Keras(即 tf.keras)实现。


项目根目录包含两个 Python 文件:


  • train_network_keras.py:Keras 版本的训练脚本。

  • train_network_tf.py:TensorFlow + Keras 版本的训练脚本,几乎与前一个一模一样。


每个脚本都将生成相应的训练准确率和损失:


  • plot_keras.png

  • plot_tf.png


使用 Keras 训练网络


一文看懂Keras和TensorFlow到底哪家强


训练的第一步是使用 Keras 实现网络架构。


打开 minivggnetkeras.py 文件,并插入以下代码:


一文看懂Keras和TensorFlow到底哪家强


我们先导入构建模型需要的一系列 Keras 包。


然后定义我们的 MiniVGGNetKeras 类:


一文看懂Keras和TensorFlow到底哪家强


我们定义了 build 方法、inputShape 和 input。


然后定义卷积神经网络的主要部分:


一文看懂Keras和TensorFlow到底哪家强


你会发现我们在应用池化层之前堆叠了一系列卷积、ReLU 激活和批量规范化层,以便减少卷的空间维度。还使用了 Dropout 来减少过拟合。


现在将全连接层添加到网络中:


一文看懂Keras和TensorFlow到底哪家强


我们已经使用 Keras 实现了 CNN,现在创建将用于训练的驱动脚本。


打开 train_network_keras.py 并插入以下代码:


一文看懂Keras和TensorFlow到底哪家强


我们先导入需要的包。


  • matplotlib 设置为“Agg”,这样就可以将训练结果保存为图像文件。

  • 然后导入 MiniVGGNetKeras 类。

  • 我们使用 scikit-learn 的 LabelBinarizer 进行“独热”编码,并使用 classification_report 打印分类精度。

  • 然后导入数据集。


我们通过 --plot 传入命令行参数,也就是图像的保存路径。


现在让我们加载 CIFAR-10 数据集,并对标签进行编码:


一文看懂Keras和TensorFlow到底哪家强


我们先加载和提取训练和测试分割,并将它们转换为浮点数和进行数据缩放。


然后我们对标签进行编码,并初始化 labelNames。


接下来,让我们开始训练模型:


一文看懂Keras和TensorFlow到底哪家强


我们先设置训练参数和优化方法。


然后我们使用 MiniVGGNetKeras.build 方法初始化和编译模型。


随后,我们启动了训练程序。


现在让我们来评估网络并生成结果图:



一文看懂Keras和TensorFlow到底哪家强


我们基于数据的测试分割来评估网络,并生成 classification_report,最后再导出结果。


注意:通常我会序列化并导出模型,以便可以将其用在图像或视频的处理脚本中,但这里不打算这样做,因为这超出了本文的范围。


打开一个终端并执行以下命令:


一文看懂Keras和TensorFlow到底哪家强

一文看懂Keras和TensorFlow到底哪家强

一文看懂Keras和TensorFlow到底哪家强


我的 CPU 完成一个 epoch 需要 5 分多钟。

一文看懂Keras和TensorFlow到底哪家强


我们获得了 75%的准确率——当然不是最先进的,不过它比随机猜测(1/10)要好得多。


对于小型网络来说,我们的准确率算是非常好的了,而且没有发生过拟合。


使用 TensorFlow 和 tf.keras 训练网络


使用 tf.keras 构建的 MiniVGGNet CNN 与我们直接使用 Keras 构建的模型是一样的,除了为演示目的而修改的激活函数。


现在我们已经使用 Keras 库实现并训练了一个简单的 CNN,接下来我们要:


  1. 使用 TensorFlow 的 tf.keras 实现相同的网络;

  2. 在 Keras 模型中包含一个 TensorFlow 激活函数,这个函数不是使用 Keras 实现的。


首先,打开 minivggnettf.py 文件,我们将实现 TensorFlow 版本的 MiniVGGNet:


一文看懂Keras和TensorFlow到底哪家强

一文看懂Keras和TensorFlow到底哪家强

一文看懂Keras和TensorFlow到底哪家强


请注意,导入部分只有一行。tf.keras 子模块包含了我们可以直接调用的所有 Keras 函数。


我想强调一下 Lambda 层——它们用来插入自定义激活函数 CRELU(Concatenated ReLU)。


Keras 并没有实现 CRELU,但 TensorFlow 实现了——通过使用 TensorFlow 和 tf.keras,我们可以使用一行代码将 CRELU 添加到 Keras 模型中。


下一步是编写 TensorFlow + Keras 驱动脚本来训练 MiniVGGNetTF。


打开 train_network_tf.py 并插入以下代码:


一文看懂Keras和TensorFlow到底哪家强

一文看懂Keras和TensorFlow到底哪家强


然后是解析命令行参数。


接着像之前一样加载数据集。


其余的行都一样——提取训练 / 测试分割和编码标签。


现在让我们开始训练模型:


一文看懂Keras和TensorFlow到底哪家强

一文看懂Keras和TensorFlow到底哪家强


训练过程几乎是一样的。我们已经实现了完全相同的训练流程,只是这次使用的是 tf.keras。


打开一个终端并执行以下命令:


一文看懂Keras和TensorFlow到底哪家强

一文看懂Keras和TensorFlow到底哪家强


训练完成后,你将获得类似于下面这样的结果:


通过使用 CRELU 替换 RELU 激活函数,我们获得了 76%的准确率。不过,这 1%的提升可能是因为网络权重的随机初始化,需要通过进一步的交叉验证实验来证明这种准确率的提升确实是因为 CRELU。


不管怎样,原始准确率并不是本节的重点。我们需要关注的是如何在 Keras 模型内部使用 TensorFlow 激活函数替换标准的 Keras 激活函数!


你也可以使用自己的自定义激活函数、损失 / 成本函数或层。


总结


在这篇文章中,我们讨论了 Keras 和 TensorFlow 相关的问题,包括:


  • 我应该在项目中使用 Keras 还是 TensorFlow?

  • TensorFlow 和 Keras 哪个更好?

  • 我应该花时间研究 TensorFlow 还是 Keras?


最后我们发现,在 Keras 和 TensorFlow 之间做出选择变得不那么重要。

因为 Keras 库已经通过 tf.keras 模块直接集成到 TensorFlow 中了。


相关代码下载:

https://app.monstercampaigns.com/c/hvovin011avqlrtdtz0j/


英文原文:

https://www.pyimagesearch.com/2018/10/08/keras-vs-tensorflow-which-one-is-better-and-which-one-should-i-learn/


(本文经授权转载自“AI前线”,ID:ai-front)



【加入社群】



以上是关于一文看懂Keras和TensorFlow到底哪家强的主要内容,如果未能解决你的问题,请参考以下文章

深度学习上演神仙打架,PyTorch与TensorFlow到底哪家强?

深度学习框架哪家强?MXNet称霸CNNRNN和情感分析,TensorFlow仅擅长推断特征提取

萌新必看——10种客户端存储哪家强,一文读尽!

一文带你了解知识图谱融入预训练模型哪家强?九大模型集中放送

PyTorch 和 TensorFlow 哪家强:九项对比读懂各自长项短板

PyTorch和TensorFlow哪家强:九项对比读懂各自长项短板