六、用自然语言描述图像
如果图像分类和物体检测是明智的任务,那么用自然语言描述图像绝对是一项更具挑战性的任务,需要更多的智能-请片刻考虑一下每个人如何从新生儿成长(他们学会了识别物体并检测它们的位置)到三岁的孩子(他们学会讲述图片故事)。 用自然语言描述图像的任务的正式术语是图像标题。 与具有长期研究和发展历史的语音识别不同,图像字幕(具有完整的自然语言,而不仅仅是关键词输出)由于其复杂性和 2012 年的深度学习突破而仅经历了短暂而令人兴奋的研究历史。
在本章中,我们将首先回顾基于深度学习的图像字幕模型如何赢得 2015 年 Microsoft COCO(大规模对象检测,分割和字幕数据集),我们在第 3 章,“检测对象及其位置”中简要介绍了该有效模型。 然后,我们将总结在 TensorFlow 中训练模型的步骤,并详细介绍如何准备和优化要在移动设备上部署的复杂模型。 之后,我们将向您展示有关如何构建 iOS 和 Android 应用以使用该模型生成描述图像的自然语言语句的分步教程。 由于该模型同时涉及计算机视觉和自然语言处理,因此您将首次看到两种主要的深度神经网络架构 CNN 和 RNN 如何协同工作,以及如何编写 iOS 和 Android 代码以访问经过训练的网络并进行多个推理。 总而言之,我们将在本章介绍以下主题:
- 图像字幕 – 工作原理
- 训练和冻结图像字幕模型
- 转换和优化图像字幕模型
- 在 iOS 中使用图像字幕模型
- 在 Android 中使用图像字幕模型
图像字幕 – 工作原理
Show and Tell:从 2015 年 MSCOCO 图像字幕挑战赛中汲取的经验教训。 在讨论训练过程之前,TensorFlow 的 im2txt 模型文档网站中也对此进行了很好的介绍,让我们首先基本了解一下解模型的工作原理。 这也将帮助您了解 Python 中的训练和推理代码,以及本章稍后将介绍的 iOS 和 Android 中的推理代码。
获奖的 Show and Tell 模型是使用端到端方法进行训练的,类似于我们在上一章中简要介绍的最新的基于深度学习的语音识别模型。 它使用 MSCOCO 图像字幕 2014 数据集,可从这里下载,该数据集包含超过 82,000 个训练图像,并以描述它们的自然语言句子为目标。 训练模型以使为每个输入图像输出目标自然语言句子的可能性最大化。 与使用多个子系统的其他更复杂的训练方法不同,端到端方法优雅,简单,并且可以实现最新的结果。
为了处理和表示输入图像,Show and Tell 模型使用预训练的 Inception v3 模型,该模型与我们在第 2 章,“通过迁移学习对图像进行分类”所使用的相同。 Inception v3 CNN 网络的最后一个隐藏层用作输入图像的表示。 由于 CNN 模型的性质,较早的层捕获更多的基本图像信息,而较后的层捕获更高级的图像概念。 因此,通过使用输入图像的最后一个隐藏层来表示图像,我们可以更好地准备具有高级概念的自然语言输出。 毕竟,我们通常会开始用诸如“人”或“火车”之类的词来描述图片,而不是“带有尖锐边缘的东西”。
为了表示目标自然语言输出中的每个单词,使用了单词嵌入方法。 词嵌入只是词的向量表示。 TensorFlow 网站上有一个不错的教程,介绍如何构建模型来获取单词的向量表示。
现在,在既表示输入图像又表示输出单词的情况下(每个这样的单词对构成一个训练示例),给定的最佳训练模型可用于最大化在目标输出中生成每个单词w
的概率,给定输入图像和该单词w
之前的先前单词,它是 RNN 序列模型,或更具体地说,是长短期记忆(LSTM)的 RNN 模型类型。 LSTM 以解决常规 RNN 模型固有的消失和爆炸梯度问题而闻名。 为了更好地了解 LSTM,您应该查看这个热门博客。
梯度概念在反向传播过程中用于更新网络权重,因此它可以学习生成更好的输出。 如果您不熟悉反向传播过程,它是神经网络中最基本,功能最强大的算法之一,那么您绝对应该花些时间来理解它-只是 Google 的“反向传播”,排名前五的结果都不会令人失望。 消失的梯度意味着,在深度神经网络反向传播学习过程中,早期层中的网络权重几乎没有更新,因此网络永不收敛。 梯度爆炸意味着这些权重更新得过分疯狂,从而导致网络差异很大。 因此,如果某人头脑封闭,从不学习,或者某人对新事物疯狂而又失去兴趣就快,那么您就会知道他们似乎遇到了什么样的梯度问题。
训练后,可以将 CNN 和 LSTM 模型一起用于推理:给定输入图像,该模型可以估计每个单词的概率,从而预测最有可能为输出语句生成哪n
个最佳单词; 然后,给定输入图像和n
个最佳单词,可以生成n
个最佳的下一个单词,然后继续进行,直到模型返回句子的特定结尾单词,或达到了生成的句子的指定单词长度(以防止模型过于冗长)时,我们得到一个完整的句子。
在每次生成单词时使用n
个最佳单词(意味着在末尾具有n
个最佳句子)被称为集束搜索。 当n
(即集束大小)为 1 时,它仅基于模型返回的所有可能单词中的最高概率值,就成为贪婪搜索或最佳搜索。 TensorFlow im2txt 官方模型的下一部分中的训练和推理过程使用以 Python 实现的集束大小设置为 3 的集束搜索; 为了进行比较,我们将开发的 iOS 和 Android 应用使用更简单的贪婪或最佳搜索。 您将看到哪种方法可以生成更好的字幕。
训练和冻结图像字幕模型
在本部分中,我们将首先总结训练训练名为 im2txt 的 Show and Tell 模型的过程,该模型记录在这个页面中, 一些提示,以帮助您更好地了解该过程。 然后,我们将展示 im2txt 模型项目随附的 Python 代码的一些关键更改,以便冻结该模型以准备在移动设备上使用。
训练和测试字幕生成
如果您已按照第 3 章“检测对象及其位置”中的“设置 TensorFlow 对象检测 API”部分进行操作,那么您已经安装im2txt
文件夹; 否则,只需将cd
移至您的 TensorFlow 源根目录,然后运行:
git clone https://github.com/tensorflow/models
您可能尚未安装的一个 Python 库是 自然语言工具包(NLTK),这是最流行的用于自然语言处理的 Python 库之一。 只需访问其网站以获得安装说明。
现在,请按照以下步骤来训练模型:
- 通过打开终端并运行以下命令来设置保存 2014 MSCOCO 图像字幕训练和验证数据集的位置:
MSCOCO_DIR="${HOME}/im2txt/data/mscoco"
请注意,尽管 2014 年要下载和保存的原始数据集约为 20GB,但该数据集将转换为 TFRecord 格式(我们还在第 3 章 “检测对象及其位置”来转换对象检测数据集,这是运行以下训练脚本所需的,并添加了大约 100GB 数据。 因此,使用 TensorFlow im2txt 项目总共需要约 140GB 的训练自己的图像字幕模型。
- 转到您的 im2txt 源代码所在的位置,然后下载并处理 MSCOCO 数据集:
cd <your_tensorflow_root>/models/research/im2txt
bazel build //im2txt:download_and_preprocess_mscoco
bazel-bin/im2txt/download_and_preprocess_mscoco "${MSCOCO_DIR}"
download_and_preprocess_mscoco
脚本完成后,您将在$MSCOCO_DIR
文件夹中看到所有 TFRecord 格式的训练,验证和测试数据文件。
在$MSCOCO_DIR
文件夹中还生成了一个名为word_counts.txt
的文件。 它总共有 11,518 个单词,每行包含一个单词,一个空格以及该单词出现在数据集中的次数。 文件中仅保存计数等于或大于 4 的单词。 还保存特殊词,例如句子的开头和结尾(分别表示为<S>
和 </S>
)。 稍后,您将看到我们如何在 iOS 和 Android 应用中专门使用和解析文件来生成字幕。
- 通过运行以下命令来获取 Inception v3 检查点文件:
INCEPTION_DIR="${HOME}/im2txt/data"
mkdir -p ${INCEPTION_DIR}
cd ${INCEPTION_DIR}
wget "http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz"
tar -xvf inception_v3_2016_08_28.tar.gz -C ${INCEPTION_DIR}
rm inception_v3_2016_08_28.tar.gz
之后,您将在${HOME}/im2txt/data
文件夹中看到一个名为inception_v3.ckpt
的文件,如下所示:
jeff@AiLabby:~/im2txt/data$ ls -lt inception_v3.ckpt
-rw-r----- 1 jeff jeff 108816380 Aug 28 2016 inception_v3.ckpt
- 现在,我们准备使用以下命令来训练我们的模型:
INCEPTION_CHECKPOINT="${HOME}/im2txt/data/inception_v3.ckpt"
MODEL_DIR="${HOME}/im2txt/model"
cd <your_tensorflow_root>/models/research/im2txt
bazel build -c opt //im2txt/...
bazel-bin/im2txt/train
--input_file_pattern="${MSCOCO_DIR}/train-?????-of-00256"
--inception_checkpoint_file="${INCEPTION_CHECKPOINT}"
--train_dir="${MODEL_DIR}/train"
--train_inception=false
--number_of_steps=1000000
即使在 GPU 上(例如第 1 章, “移动 TensorFlow 入门”中设置的 Nvidia GTX 1070),整个步骤(在前面的--number_of_steps
参数中指定)也会超过 5 个昼夜,因为运行 5 万步大约需要 6.5 个小时。 幸运的是,您很快就会看到,即使以大约 50K 的步长,图像字幕的结果也已经相当不错了。 另请注意,您可以随时取消train
脚本,然后稍后重新运行它,该脚本将从最后保存的检查点开始; 默认情况下,检查点会每 10 分钟保存一次,因此在最坏的情况下,您只会损失 10 分钟的训练时间。
经过几个小时的训练,取消前面的train
脚本,然后查看--train_dir
指向的位置。 您将看到类似这样的内容(默认情况下,将保存五组检查点文件,但此处仅显示三组):
ls -lt $MODEL_DIR/train
-rw-rw-r-- 1 jeff jeff 2171543 Feb 6 22:17 model.ckpt-109587.meta
-rw-rw-r-- 1 jeff jeff 463 Feb 6 22:17 checkpoint
-rw-rw-r-- 1 jeff jeff 149002244 Feb 6 22:17 model.ckpt-109587.data-00000-of-00001
-rw-rw-r-- 1 jeff jeff 16873 Feb 6 22:17 model.ckpt-109587.index
-rw-rw-r-- 1 jeff jeff 2171543 Feb 6 22:07 model.ckpt-109332.meta
-rw-rw-r-- 1 jeff jeff 16873 Feb 6 22:07 model.ckpt-109332.index
-rw-rw-r-- 1 jeff jeff 149002244 Feb 6 22:07 model.ckpt-109332.data-00000-of-00001
-rw-rw-r-- 1 jeff jeff 2171543 Feb 6 21:57 model.ckpt-109068.meta
-rw-rw-r-- 1 jeff jeff 149002244 Feb 6 21:57 model.ckpt-109068.data-00000-of-00001
-rw-rw-r-- 1 jeff jeff 16873 Feb 6 21:57 model.ckpt-109068.index
-rw-rw-r-- 1 jeff jeff 4812699 Feb 6 14:27 graph.pbtxt
您可以告诉每 10 分钟生成一组检查点文件(model.ckpt-109068.*
和model.ckpt-109332.*
和model.ckpt-109587.*
)。 graph.pbtxt
是模型的图定义文件(以文本格式),model.ckpt-??????.meta
文件还包含模型的图定义,以及特定检查点的其他一些元数据,例如model.ckpt-109587.data-00000-of-00001
(请注意, 大小几乎为 150MB,因为所有网络参数都保存在此处)。
- 测试字幕生成,如下所示:
CHECKPOINT_PATH="${HOME}/im2txt/model/train"
VOCAB_FILE="${HOME}/im2txt/data/mscoco/word_counts.txt"
IMAGE_FILE="${HOME}/im2txt/data/mscoco/raw-data/val2014/COCO_val2014_000000224477.jpg"
bazel build -c opt //im2txt:run_inference
bazel-bin/im2txt/run_inference
--checkpoint_path=${CHECKPOINT_PATH}
--vocab_file=${VOCAB_FILE}
--input_files=${IMAGE_FILE}
CHECKPOINT_PATH
被设置为与--train_dir
被设置为相同的路径。 run_inference
脚本将生成类似以下内容(不完全相同,具体取决于已执行了多少训练步骤):
Captions for image COCO_val2014_000000224477.jpg:
0) a man on a surfboard riding a wave . (p=0.015135)
1) a person on a surfboard riding a wave . (p=0.011918)
2) a man riding a surfboard on top of a wave . (p=0.009856)
这很酷。 如果我们可以在智能手机上运行此模型,会不会更酷? 但是在此之前,由于模型的相对复杂性以及 Python 中train
和run_inference
脚本的编写方式,我们还需要采取一些额外的步骤。
冻结图像字幕模型
在第 4 章,“转换具有惊人艺术风格的图片”,和第 5 章,“了解简单语音命令”中,我们使用了一个名为freeze.py
的脚本的两个略有不同的版本,将受过训练的网络权重与网络图定义合并到一个自足的模型文件中,这是我们可以在移动设备上使用的好处。 TensorFlow 还带有freeze
脚本的更通用版本,称为freeze_graph.py
,位于tensorflow/python/tools
文件夹中,可用于构建模型文件。 要使其正常运行,您需要为其提供至少四个参数(要查看所有可用参数,请查看 tensorflow/python/tools/freeze_graph.py
):
-
--input_graph
或--input_meta_graph
:模型的图定义文件。 例如,在上一节的第 4 步的命令ls -lt $MODEL_DIR/train
的输出中,model.ckpt-109587.meta
是一个元图文件,其中包含模型的图定义和其他与检查点相关的元数据,而graph.pbtxt
只是模型的图定义。 -
--input_checkpoint
:特定的检查点文件,例如model.ckpt-109587
。 注意,您没有指定大型检查点文件model.ckpt-109587.data-00000-of-00001
的完整文件名。 -
--output_graph
:冻结模型文件的路径–这是在移动设备上使用的路径。 -
--output_node_names
:输出节点名称列表,以逗号分隔,告诉freeze_graph
工具冻结模型中应包括模型的哪一部分和权重,因此生成特定输出不需要的节点名称和权重将保留。
因此,对于该模型,我们如何找出必备的输出节点名称以及输入节点名称,这些对推理也至关重要,正如我们在上一章的 iOS 和 Android 应用中所见到的那样? 因为我们已经使用run_inference
脚本来生成测试图像的标题,所以我们可以看到它是如何进行推理的。
转到您的 im2txt 源代码文件夹models/research/im2txt/im2txt
:您可能想在一个不错的编辑器(例如 Atom 或 Sublime Text)中打开它,或者在 Python IDE(例如 PyCharm)中打开它。 在run_inference.py
中,对inference_utils/inference_wrapper_base.py
中的build_graph_from_config
进行了调用,在inference_wrapper.py
中调用了build_model
,在show_and_tell_model.py
中进一步调用了build
方法。 最后,build
方法将调用build_input
方法,该方法具有以下代码:
if self.mode == "inference":
image_feed = tf.placeholder(dtype=tf.string, shape=[], name="image_feed")
input_feed = tf.placeholder(dtype=tf.int64,
shape=[None], # batch_size
name="input_feed")
还有build_model
方法,它具有:
if self.mode == "inference":
tf.concat(axis=1, values=initial_state, name="initial_state")
state_feed = tf.placeholder(dtype=tf.float32,
shape=[None, sum(lstm_cell.state_size)],
name="state_feed")
...
tf.concat(axis=1, values=state_tuple, name="state")
...
tf.nn.softmax(logits, name="softmax")
因此,名为image_feed
,input_feed
和state_feed
的三个占位符应该是输入节点名称,而initial_state
,state
和softmax
应当是输出节点名称。 此外,inference_wrapper.py
中定义的两种方法证实了我们的侦探工作–第一种是:
def feed_image(self, sess, encoded_image):
initial_state = sess.run(fetches="lstm/initial_state:0",
feed_dict={"image_feed:0": encoded_image})
return initial_state
因此,我们提供image_feed
并返回initial_state
(lstm/
前缀仅表示该节点在lstm
范围内)。 第二种方法是:
def inference_step(self, sess, input_feed, state_feed):
softmax_output, state_output = sess.run(
fetches=["softmax:0", "lstm/state:0"],
feed_dict={
"input_feed:0": input_feed,
"lstm/state_feed:0": state_feed,
})
return softmax_output, state_output, None
我们输入input_feed
和state_feed
,然后返回softmax
和state
。 总共三个输入节点名称和三个输出名称。
注意,仅当mode
为“推断”时才创建这些节点,因为train.py
和run_inference.py
都使用了 show_and_tell_model.py
。 这意味着在运行run_inference.py
脚本后,将修改在步骤 5 中使用train
生成的--checkpoint_path
中模型的图定义文件和权重。 那么,我们如何保存更新的图定义和检查点文件?
事实证明,在run_inference.py
中,在创建 TensorFlow 会话后,还有一个调用restore_fn(sess)
来加载检查点文件,并且该调用在inference_utils/inference_wrapper_base.py
中定义:
def _restore_fn(sess):
saver.restore(sess, checkpoint_path)
在启动run_inference.py
之后到达saver.restore
调用时,已进行了更新的图定义,因此我们可以在此处保存新的检查点和图文件,从而使_restore_fn
函数如下:
def _restore_fn(sess):
saver.restore(sess, checkpoint_path)
saver.save(sess, "model/image2text")
tf.train.write_graph(sess.graph_def, "model", 'im2txt4.pbtxt')
tf.summary.FileWriter("logdir", sess.graph_def)
tf.train.write_graph(sess.graph_def, "model", 'im2txt4.pbtxt')
行是可选的,因为当通过调用saver.save
保存新的检查点文件时,也会生成一个元文件,freeze_graph.py
可以将其与检查点文件一起使用。 但是对于那些希望以纯文本格式查看所有内容,或者在冻结模型时更喜欢使用带有--in_graph
参数的图定义文件的人来说,它是在这里生成的。 最后一行tf.summary.FileWriter("logdir", sess.graph_def)
也是可选的,但它会生成一个可由 TensorBoard 可视化的事件文件。 因此,有了这些更改,在再次运行run_inference.py
之后(除非首先直接使用 Python 运行run_inference.py
,否则请记住首先运行bazel build -c opt //im2txt:run_inference
),您将在model
目录中看到以下新的检查点文件和新的图定义文件:
jeff@AiLabby:~/tensorflow-1.5.0/models/research/im2txt$ ls -lt model
-rw-rw-r-- 1 jeff jeff 2076964 Feb 7 12:33 image2text.pbtxt
-rw-rw-r-- 1 jeff jeff 1343049 Feb 7 12:33 image2text.meta
-rw-rw-r-- 1 jeff jeff 77 Feb 7 12:33 checkpoint
-rw-rw-r-- 1 jeff jeff 149002244 Feb 7 12:33 image2text.data-00000-of-00001
-rw-rw-r-- 1 jeff jeff 16873 Feb 7 12:33 image2text.index
在logdir
目录中:
jeff@AiLabby:~/tensorflow-1.5.0/models/research/im2txt$ ls -lt logdir
total 2124
-rw-rw-r-- 1 jeff jeff 2171623 Feb 7 12:33 events.out.tfevents.1518035604.AiLabby
Running the bazel build
command to build a TensorFlow Python script is optional. You can just run the Python script directly. For example, we can run python tensorflow/python/tools/freeze_graph.py
without building it first with bazel build tensorflow/python/tools:freeze_graph
then running bazel-bin/tensorflow/python/tools/freeze_graph
. But be aware that running the Python script directly will use the version of TensorFlow you’ve installed via pip, which may be different from the version you’ve downloaded as source and built by the bazel build
command. This can be the cause of some confusing errors so be sure you know the TensorFlow version used to run a script. In addition, for a C based tool, you have to build it first with bazel before you can run it. For example, the transform_graph
tool, which we’ll see soon, is implemented in transform_graph.cc
located at tensorflow/tools/graph_transforms
; another important tool called convert_graphdef_memmapped_format
, which we’ll use for our iOS app later, is also implemented in C located at tensorflow/contrib/util
.
现在我们到了,让我们快速使用 TensorBoard 看一下我们的图–只需运行tensorboard --logdir logdir
,然后从浏览器中打开http://localhost:6006
。 图 6.1 显示了三个输出节点名称(顶部为softmax
,以及lstm/initial_state
和红色矩形顶部的突出显示的lstm/state
)和一个输入节点名称(底部的state_feed
):
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-9IJBSSEp-1681653119029)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/d47ea13b-441c-4fb1-bccc-bbceaf2a8bdf.png)]
图 6.1:该图显示了三个输出节点名称和一个输入节点名称
图 6.2 显示了另一个输入节点名称image_feed
:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-pLuHCbVw-1681653119030)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/a7c39f91-68d3-4cd4-8e9b-0ef0b6dd11ff.png)]
图 6.2:该图显示了一个附加的输入节点名称image_feed
最后,图 6.3 显示了最后一个输入节点名称input_feed
:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Q2fMceyK-1681653119030)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/887360ce-b609-4f8a-95d4-c1c3d2e8f97b.png)]
图 6.3:该图显示了最后一个输入节点名称input_feed
当然,这里有很多我们不能也不会涵盖的细节。 但是,您将了解大局,同样重要的是,有足够的细节可以继续前进。 现在运行freeze_graph.py
应该像轻风(双关语):
python tensorflow/python/tools/freeze_graph.py --input_meta_graph=/home/jeff/tensorflow-1.5.0/models/research/im2txt/model/image2text.meta --input_checkpoint=/home/jeff/tensorflow-1.5.0/models/research/im2txt/model/image2text --output_graph=/tmp/image2text_frozen.pb --output_node_names="softmax,lstm/initial_state,lstm/state" --input_binary=true
请注意,我们在这里使用元图文件以及将--input_binary
参数设置为true
,因为默认情况下它为false
,这意味着freeze_graph
工具期望输入图或元图文件为文本格式。
您可以使用文本格式的图文件作为输入,在这种情况下,无需提供--input_binary
参数:
python tensorflow/python/tools/freeze_graph.py --input_graph=/home/jeff/tensorflow-1.5.0/models/research/im2txt/model/image2text.pbtxt --input_checkpoint=/home/jeff/tensorflow-1.5.0/models/research/im2txt/model/image2text --output_graph=/tmp/image2text_frozen2.pb --output_node_names="softmax,lstm/initial_state,lstm/state"
两个输出图文件image2text_frozen.pb
和image2text_frozen2.pb
的大小会稍有不同,但是在经过转换和可能的优化后,它们在移动设备上使用时,它们的行为完全相同。
转换和优化图像字幕模型
如果您真的等不及了,现在决定尝试在 iOS 或 Android 应用上尝试新近冻结的热模型,则可以,但是您会看到一个致命错误No OpKernel was registered to support Op 'DecodeJpeg' with these attrs
,迫使你重新考虑你的决定。
使用转换的模型修复错误
通常,您可以使用strip_unused.py,
工具,与 tensorflow/python/tools,
中的 freeze_graph.py
位于相同位置,来删除不包含在 TensorFlow 核心库中的DecodeJpeg
操作。但是由于输入节点image_feed
需要进行解码操作(图 6.2), strip_unused
之类的工具不会将DecodeJpeg
视为未使用,因此不会被剥夺。 您可以先运行strip_unused
命令,如下所示进行验证:
bazel-bin/tensorflow/python/tools/strip_unused --input_graph=/tmp/image2text_frozen.pb --output_graph=/tmp/image2text_frozen_stripped.pb --input_node_names="image_feed,input_feed,lstm/state_feed" --output_node_names="softmax,lstm/initial_state,lstm/state" --input_binary=True
然后在 iPython 中加载输出图并列出前几个节点,如下所示:
代码语言:javascript复制import tensorflow as tf
g=tf.GraphDef()
g.ParseFromString(open("/tmp/image2text_frozen_stripped", "rb").read())
x=[n.name for n in g.node]
x[:6]
输出如下:
代码语言:javascript复制[u'image_feed',
u'input_feed',
u'decode/DecodeJpeg',
u'convert_image/Cast',
u'convert_image/y',
u'convert_image']
解决您的 iOS 应用错误的第二种可能解决方案,像第 5 章, “了解简单语音命令”一样,是在 tf_op_files
文件中添加未注册的操作实现,并重建 TensorFlow iOS 库。 坏消息是,由于 TensorFlow 中没有DecodeJpeg
函数的实现,因此无法将DecodeJpeg
的 TensorFlow 实现添加到tf_op_files
中。
实际上,在图 6.2 中也暗示了对此烦恼的解决方法,其中convert_image
节点用作image_feed
输入的解码版本。 为了更准确,单击 TensorBoard 图中的转换和解码节点,如图 6.4 所示,您将从右侧的 TensorBoard 信息卡中看到输入转换(名为convert_image/Cast
)的输出为decode/DecodeJpeg
和convert_image
,解码的输入和输出为image_feed
和convert_image/Cast
:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-l7B550bD-1681653119031)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/eeb4e7ad-04ef-4ad5-ac89-a9a4a1e2f1e4.png)]
图 6.4:调查解码和conver_image
节点
实际上,在im2txt/ops/image_processing.py
中有一行image = tf.image.convert_image_dtype(image, dtype=tf.float32)
将解码的图像转换为浮点数。 让我们用convert_image/Cast
代替 TensorBoard 中显示的名称image_feed
,以及前面代码片段的输出,然后再次运行strip_unused
:
bazel-bin/tensorflow/python/tools/strip_unused --input_graph=/tmp/image2text_frozen.pb --output_graph=/tmp/image2text_frozen_stripped.pb --input_node_names="convert_image/Cast,input_feed,lstm/state_feed" --output_node_names="softmax,lstm/initial_state,lstm/state" --input_binary=True
现在,重新运行代码片段,如下所示:
代码语言:javascript复制g.ParseFromString(open("/tmp/image2text_frozen_stripped", "rb").read())
x=[n.name for n in g.node]
x[:6]
并且输出不再具有decode
/ DecodeJpeg
节点:
[u'input_feed',
u'convert_image/Cast',
u'convert_image/y',
u'convert_image',
u'ExpandDims_1/dim',
u'ExpandDims_1']
如果我们在 iOS 或 Android 应用中使用新的模型文件image2text_frozen_stripped.pb
,则No OpKernel was registered to support Op 'DecodeJpeg' with these attrs.
肯定会消失。 但是发生另一个错误, Not a valid TensorFlow Graph serialization: Input 0 of node ExpandDims_6 was passed float from input_feed:0 incompatible with expected int64
。 如果您通过名为 TensorFlow for Poets 2 的不错的 Google TensorFlow 代码实验室,可能会想起来,还有另一个名为optimize_for_inference
的工具,其功能类似于strip_unused
,并且可以很好地用于代码实验室中的图像分类任务。 您可以像这样运行它:
bazel build tensorflow/python/tools:optimize_for_inference
bazel-bin/tensorflow/python/tools/optimize_for_inference
--input=/tmp/image2text_frozen.pb
--output=/tmp/image2text_frozen_optimized.pb
--input_names="convert_image/Cast,input_feed,lstm/state_feed"
--output_names="softmax,lstm/initial_state,lstm/state"
但是在 iOS 或 Android 应用上加载输出模型文件 image2text_frozen_optimized.pb
会导致相同的Input 0 of node ExpandDims_6 was passed float from input_feed:0 incompatible with expected int64
错误。 看起来,尽管我们试图至少在某种程度上实现福尔摩斯在本章中可以做的事情,但有人希望我们首先成为福尔摩斯。
如果您在其他模型(例如我们在前几章中看到的模型)上尝试过strip_unused
或optimize_for_inference
工具,则它们可以正常工作。 事实证明,尽管官方 TensorFlow 1.4 和 1.5 发行版中包含了两个基于 Python 的工具,但在优化一些更复杂的模型时却存在一些错误。 更新和正确的工具是基于 C 的transform_graph
工具,现在是 TensorFlow Mobile 网站推荐的官方工具。 运行以下命令以消除在移动设备上部署时的int64
不兼容float
的错误:
bazel build tensorflow/tools/graph_transforms:transform_graph
bazel-bin/tensorflow/tools/graph_transforms/transform_graph
--in_graph=/tmp/image2text_frozen.pb
--out_graph=/tmp/image2text_frozen_transformed.pb
--inputs="convert_image/Cast,input_feed,lstm/state_feed"
--outputs="softmax,lstm/initial_state,lstm/state"
--transforms='
strip_unused_nodes(type=float, shape="299,299,3")
fold_constants(ignore_errors=true, clear_output_shapes=true)
fold_batch_norms
fold_old_batch_norms'
我们将不讨论所有--transforms
选项的详细信息,这些选项在这里有完整记录。 基本上,--transforms
设置可以正确消除模型中未使用的节点,例如DecodeJpeg
,并且还可以进行其他一些优化。
现在,如果您在 iOS 和 Android 应用中加载image2text_frozen_transformed.pb
文件,则不兼容的错误将消失。 当然,我们还没有编写任何真实的 iOS 和 Android 代码,但是我们知道该模型很好,可以随时使用。 很好,但是可以更好。
优化转换后的模型
真正的最后一步,也是至关重要的一步,尤其是在运行复杂的冻结和转换模型(例如我们在较旧的 iOS 设备上训练过的模型)时,是使用位于 tensorflow/contrib/util
的另一个工具convert_graphdef_memmapped_format
,将冻结和转换后的模型转换为映射格式。 映射文件允许现代操作系统(例如 iOS 和 Android)将文件直接映射到主内存,因此无需为文件分配内存,也无需写回磁盘,因为文件数据是只读的,这非常重要。 性能提高。
更重要的是,iOS 不会将已映射文件视为内存使用量,因此,当内存压力过大时,即使文件很大,使用已映射文件的应用也不会由于内存使用太大而被 iOS 杀死和崩溃。 实际上,正如我们将在下一节中很快看到的那样,如果模型文件的转换版本未转换为 memmapped 格式,则将在较旧的移动设备(如 iPhone 6)上崩溃,在这种情况下,转换是必须的, 有。
构建和运行该工具的命令非常简单:
代码语言:javascript复制bazel build tensorflow/contrib/util:convert_graphdef_memmapped_format
bazel-bin/tensorflow/contrib/util/convert_graphdef_memmapped_format
--in_graph=/tmp/image2text_frozen_transformed.pb
--out_graph=/tmp/image2text_frozen_transformed_memmapped.pb
下一节将向您展示如何在 iOS 应用中使用image2text_frozen_transformed_memmapped.pb
模型文件。 它也可以在使用本机代码的 Android 中使用,但是由于时间限制,我们将无法在本章中介绍它。
我们花了很多功夫才能最终为移动应用准备好复杂的图像字幕模型。 是时候欣赏使用模型的简单性了。 实际上,使用模型不仅仅是 iOS 中的单个 session->Run
调用,还是 Android 中的 mInferenceInterface.run
调用,就像我们在前面所有章节中所做的那样; 从输入图像到自然语言输出的推论(如您在上一节中研究run_inference.py
的工作原理时所见)涉及到对模型的run
方法的多次调用。 LSTM 模型就是这样工作的:“继续向我发送新的输入(基于我以前的状态和输出),我将向您发送下一个状态和输出。” 简单来说,我们的意思是向您展示如何使用尽可能少的简洁代码来构建 iOS 和 Android 应用,这些应用使用该模型以自然语言描述图像。 这样,如果需要,您可以轻松地在自己的应用中集成模型及其推理代码。
在 iOS 中使用图像字幕模型
由于该模型的 CNN 部分基于 Inception v3,因此我们在第 2 章,“通过迁移学习对图像进行分类”时使用的模型相同,因此我们可以并且将使用更简单的 TensorFlow Pod 进行以下操作: 创建我们的 Objective-C iOS 应用。 请按照此处的步骤查看如何在新的 iOS 应用中同时使用image2text_frozen_transformed.pb
和image2text_frozen_transformed_memmapped.pb
模型文件:
- 类似于第 2 章,“通过迁移学习对图像进行分类”,“将 TensorFlow 添加到 Objective-C iOS 应用”部分中的前四个步骤, 名为
Image2Text
的 iOS 项目,添加具有以下内容的名为Podfile
的新文件:
target 'Image2Text'
pod 'TensorFlow-experimental'
然后在终端上运行pod install
并打开Image2Text.xcworkspace
文件。 将ios_image_load.h
, ios_image_load.mm
,tensorflow_utils.h
和tensorflow_utils.mm
文件从位于tensorflow/examples/ios/camera
的 TensorFlow iOS 示例相机应用拖放到 Xcode 的Image2Text
项目中。 之前我们已经重用了ios_image_load.*
文件,此处tensorflow_utils.*
文件主要用于加载映射的模型文件。 tensorflow_utils.mm
中有两种方法LoadModel
和 LoadMemoryMappedModel
:一种以我们以前的方式加载非映射模型,另一种加载了映射模型 。 如果有兴趣,请看一下LoadMemoryMappedModel
的实现方式,并且这个页面上的文档也可能会有用。
- 添加在上一节末尾生成的两个模型文件,在“训练和测试字幕生成”小节第 2 步中生成的
word_counts.txt
文件,以及一些测试图像–我们保存并使用 TensorFlow im2txt 模型页面顶部的四个图像,以便我们比较我们的模型的字幕结果,以及那些由使用更多步骤训练的模型所生成的结果。 还将ViewController.m
重命名为.mm
,从现在开始,我们将只处理ViewController.mm
文件即可完成应用。 现在,您的 XcodeImage2Text
项目应类似于图 6.5:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-b48gBjaP-1681653119031)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/62be062e-39dc-4f0c-8a59-7e0e675bafaa.png)]
图 6.5:设置Image2Text
iOS 应用,还显示如何实现LoadMemoryMappedModel
- 打开
ViewController.mm
并添加一堆 Objective-C 和 C 常量,如下所示:
static NSString* MODEL_FILE = @"image2text_frozen_transformed";
static NSString* MODEL_FILE_MEMMAPPED = @"image2text_frozen_transformed_memmapped";
static NSString* MODEL_FILE_TYPE = @"pb";
static NSString* VOCAB_FILE = @"word_counts";
static NSString* VOCAB_FILE_TYPE = @"txt";
static NSString *image_name = @"im2txt4.png";
const string INPUT_NODE1 = "convert_image/Cast";
const string OUTPUT_NODE1 = "lstm/initial_state";
const string INPUT_NODE2 = "input_feed";
const string INPUT_NODE3 = "lstm/state_feed";
const string OUTPUT_NODE2 = "softmax";
const string OUTPUT_NODE3 = "lstm/state";
const int wanted_width = 299;
const int wanted_height = 299;
const int wanted_channels = 3;
const int CAPTION_LEN = 20;
const int START_ID = 2;
const int END_ID = 3;
const int WORD_COUNT = 12000;
const int STATE_COUNT = 1024;
它们都是自我解释的,如果您通读了本章,则应该看起来都很熟悉,除了最后五个常量:CAPTION_LEN
是我们要在标题中生成的最大单词数,START_ID
是句子起始词<S>
的 ID,定义为word_counts.txt
文件中的行号; 所以 2 在第二行表示,在第三行表示 3。 word_counts.txt
文件的前几行是这样的:
a 969108
<S> 586368
</S> 586368
. 440479
on 213612
of 202290
WORD_COUNT
是模型假设的总单词数,对于您很快就会看到的每个推理调用,模型将返回总计 12,000 的概率得分以及 LSTM 模型的 1,024 个状态值。
- 添加一些全局变量和一个函数签名:
unique_ptr<tensorflow::Session> session;
unique_ptr<tensorflow::MemmappedEnv> tf_memmapped_env;
std::vector<std::string> words;
UIImageView *_iv;
UILabel *_lbl;
NSString* generateCaption(bool memmapped);
此简单的与 UI 相关的代码类似于第 2 章,“通过迁移学习对图像进行分类”的 iOS 应用的代码。 基本上,您可以在应用启动后点击任意位置,然后选择两个模型之一,图像描述结果将显示在顶部。 当用户在alert
操作中选择了映射模型时,将运行以下代码:
dispatch_async(dispatch_get_global_queue(0, 0), ^{
NSString *caption = generateCaption(true);
dispatch_async(dispatch_get_main_queue(), ^{
_lbl.text = caption;
});
});
如果选择了非映射模型,则使用generateCaption(false)
。
- 在
viewDidLoad
方法的末尾,添加代码以加载word_counts.txt
并将这些单词逐行保存在 Objective-C 和 C 中:
NSString* voc_file_path = FilePathForResourceName(VOCAB_FILE, VOCAB_FILE_TYPE);
if (!voc_file_path) {
LOG(FATAL) << "Couldn't load vocabuary file: " << voc_file_path;
}
ifstream t;
t.open([voc_file_path UTF8String]);
string line;
while(t){
getline(t, line);
size_t pos = line.find(" ");
words.push_back(line.substr(0, pos));
}
t.close();
- 剩下的我们要做的就是实现
generateCaption
函数。 在其中,首先加载正确的模型:
tensorflow::Status load_status;
if (memmapped)
load_status = LoadMemoryMappedModel(MODEL_FILE_MEMMAPPED, MODEL_FILE_TYPE, &session, &tf_memmapped_env);
else
load_status = LoadModel(MODEL_FILE, MODEL_FILE_TYPE, &session);
if (!load_status.ok()) {
return @"Couldn't load model";
}
- 然后,使用类似的图像处理代码来准备要输入到模型中的图像张量:
int image_width;
int image_height;
int image_channels;
NSArray *name_ext = [image_name componentsSeparatedByString:@"."];
NSString* image_path = FilePathForResourceName(name_ext[0], name_ext[1]);
std::vector<tensorflow::uint8> image_data = LoadImageFromFile([image_path UTF8String], &image_width, &image_height, &image_channels);
tensorflow::Tensor image_tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({wanted_height, wanted_width, wanted_channels}));
auto image_tensor_mapped = image_tensor.tensor<float, 3>();
tensorflow::uint8* in = image_data.data();
float* out = image_tensor_mapped.data();
for (int y = 0; y < wanted_height; y) {
const int in_y = (y * image_height) / wanted_height;
tensorflow::uint8* in_row = in (in_y * image_width * image_channels);
float* out_row = out (y * wanted_width * wanted_channels);
for (int x = 0; x < wanted_width; x) {
const int in_x = (x * image_width) / wanted_width;
tensorflow::uint8* in_pixel = in_row (in_x * image_channels);
float* out_pixel = out_row (x * wanted_channels);
for (int c = 0; c < wanted_channels; c) {
out_pixel[c] = in_pixel[c];
}
}
}
- 现在,我们可以将图像发送到模型,并获取返回的
initial_state
张量向量,该向量包含 1,200(STATE_COUNT
)个值:
vector<tensorflow::Tensor> initial_state;
if (session.get()) {
tensorflow::Status run_status = session->Run({{INPUT_NODE1, image_tensor}}, {OUTPUT_NODE1}, {}, &initial_state);
if (!run_status.ok()) {
return @"Getting initial state failed";
}
}
- 定义
input_feed
和state_feed
张量,并将它们的值分别设置为起始字的 ID 和返回的initial_state
值:
tensorflow::Tensor input_feed(tensorflow::DT_INT64, tensorflow::TensorShape({1,}));
tensorflow::Tensor state_feed(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, STATE_COUNT}));
auto input_feed_map = input_feed.tensor<int64_t, 1>();
auto state_feed_map = state_feed.tensor<float, 2>();
input_feed_map(0) = START_ID;
auto initial_state_map = initial_state[0].tensor<float, 2>();
for (int i = 0; i < STATE_COUNT; i ){
state_feed_map(0,i) = initial_state_map(0,i);
}
- 在
CAPTION_LEN
上创建一个for
循环,然后在该循环内,首先创建output_feed
和output_states
张量向量,然后馈入我们先前设置的input_feed
和state_feed
,并运行模型以返回由softmax
张量和new_state
张量组成的output
张量向量:
vector<int> captions;
for (int i=0; i<CAPTION_LEN; i ) {
vector<tensorflow::Tensor> output;
tensorflow::Status run_status = session->Run({{INPUT_NODE2, input_feed}, {INPUT_NODE3, state_feed}}, {OUTPUT_NODE2, OUTPUT_NODE3}, {}, &output);
if (!run_status.ok()) {
return @"Getting LSTM state failed";
}
else {
tensorflow::Tensor softmax = output[0];
tensorflow::Tensor state = output[1];
auto softmax_map = softmax.tensor<float, 2>();
auto state_map = state.tensor<float, 2>();
- 现在,找到可能性最大(softmax 值)的单词 ID。 如果是结束字的 ID,则结束
for
循环;否则,结束循环。 否则,将具有最大 softmax 值的单词id
添加到向量captions
中。 请注意,此处我们使用贪婪搜索,始终选择概率最大的单词,而不是像run_inference.py
脚本中那样将大小设置为 3 的集束搜索。 在for
循环的末尾,用最大字数id
更新input_feed
值,并用先前返回的state
值更新state_feed
值,然后再将两个输入,所有下一个单词的 softmax 值和下一个状态值,馈送到模型:
float max_prob = 0.0f;
int max_word_id = 0;
for (int j = 0; j < WORD_COUNT; j ){
if (softmax_map(0,j) > max_prob) {
max_prob = softmax_map(0,j);
max_word_id = j;
}
}
if (max_word_id == END_ID) break;
captions.push_back(max_word_id);
input_feed_map(0) = max_word_id;
for (int j = 0; j < STATE_COUNT; j ){
state_feed_map(0,j) = state_map(0,j);
}
}
}
我们可能从未详细解释过如何在 C 中获取和设置 TensorFlow 张量值。 但是,如果您到目前为止已经阅读了本书中的代码,那么您应该已经学会了。 这就像 RNN 学习:如果您接受了足够的代码示例训练,就可以编写有意义的代码。 总而言之,首先使用Tensor
类型定义变量,并使用该变量的数据类型和形状指定,然后调用Tensor
类的tensor
方法,传入数据类型的 C 版本和形状,以创建张量的贴图变量。 之后,您可以简单地使用映射来获取或设置张量的值。
- 最后,只需遍历
captions
向量并将向量中存储的每个词 ID 转换为一个词,然后将该词添加到sentence
字符串中,而忽略起始 ID 和结束 ID,然后返回该句子,希望是可读的自然语言:
NSString *sentence = @"";
for (int i=0; i<captions.size(); i ) {
if (captions[i] == START_ID) continue;
if (captions[i] == END_ID) break;
sentence = [NSString stringWithFormat:@"%@ %s", sentence, words[captions[i]].c_str()];
}
return sentence;
这就是在 iOS 应用中运行模型所需的一切。 现在,在 iOS 模拟器或设备中运行该应用,点击并选择一个模型,如图 6.6 所示:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-yeO4JaRG-1681653119031)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/dd9bd09b-1b39-47ff-8a8f-f7f180faff3c.png)]
图 6.6:运行Image2Text
iOS 应用并选择模型
在 iOS 模拟器上,运行非映射模型需要 10 秒钟以上,运行映射模型则需要 5 秒钟以上。 在 iPhone 6 上,运行贴图模型还需要大约 5 秒钟,但由于模型文件和内存压力较大,运行非贴图模型时会崩溃。
至于结果,图 6.7 显示了四个测试图像结果:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-MghzzYzK-1681653119032)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/59f462cb-59e6-4165-bd6f-80166f84bd61.png)]
图 6.7:显示图像字幕结果
图 6.8 显示了 TensorFlow im2txt 网站上的结果,您可以看到我们更简单的贪婪搜索结果看起来也不错。 但是对于长颈鹿图片,看来我们的模型或推理代码不够好。 完成本章中的工作后,希望您会在改进训练或模型推断方面有所收获:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-d0KLlsdY-1681653119032)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/e08dec95-3b49-432a-bc4f-e3077151f2e9.png)]
图 6.8:字幕示例显示在 TensorFlow im2txt 模型网站上
在我们进行下一个智能任务之前,是时候给 Android 开发人员一个不错的选择了。
在 Android 中使用图像字幕模型
遵循相同的简单性考虑,我们将开发具有最小 UI 的新 Android 应用,并着重于如何在 Android 中使用该模型:
- 创建一个名为
Image2Text
的新 Android 应用,在应用build.gradle
文件的依存关系的末尾添加compile 'org.tensorflow:tensorflow-android: '
,创建一个assets
文件夹,然后将image2text_frozen_transformed.pb
模型文件word_counts.txt
文件和一些测试图像文件拖放到其中。 - 在
activity_main.xml
文件中添加一个ImageView
和一个按钮:
<ImageView
android:id="@ id/imageview"
android:layout_width="match_parent"
android:layout_height="match_parent"
app:layout_constraintBottom_toBottomOf="parent"
app:layout_constraintHorizontal_bias="0.0"
app:layout_constraintLeft_toLeftOf="parent"
app:layout_constraintRight_toRightOf="parent"
app:layout_constraintTop_toTopOf="parent"
app:layout_constraintVertical_bias="1.0"/>
<Button
android:id="@ id/button"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="DESCRIBE ME"
app:layout_constraintBottom_toBottomOf="parent"
app:layout_constraintHorizontal_bias="0.5"
app:layout_constraintLeft_toLeftOf="parent"
app:layout_constraintRight_toRightOf="parent"
app:layout_constraintTop_toTopOf="parent"
app:layout_constraintVertical_bias="1.0"/>
- 打开
MainActivity.java
,使其实现Runnable
接口,然后添加以下常量,在前一节中说明了其中的最后五个,而其他则是自解释的:
private static final String MODEL_FILE = "file:///android_asset/image2text_frozen_transformed.pb";
private static final String VOCAB_FILE = "file:///android_asset/word_counts.txt";
private static final String IMAGE_NAME = "im2txt1.png";
private static final String INPUT_NODE1 = "convert_image/Cast";
private static final String OUTPUT_NODE1 = "lstm/initial_state";
private static final String INPUT_NODE2 = "input_feed";
private static final String INPUT_NODE3 = "lstm/state_feed";
private static final String OUTPUT_NODE2 = "softmax";
private static final String OUTPUT_NODE3 = "lstm/state";
private static final int IMAGE_WIDTH = 299;
private static final int IMAGE_HEIGHT = 299;
private static final int IMAGE_CHANNEL = 3;
private static final int CAPTION_LEN = 20;
private static final int WORD_COUNT = 12000;
private static final int STATE_COUNT = 1024;
private static final int START_ID = 2;
private static final int END_ID = 3;
以及以下实例变量和处理器实现:
代码语言:javascript复制private ImageView mImageView;
private Button mButton;
private TensorFlowInferenceInterface mInferenceInterface;
private String[] mWords = new String[WORD_COUNT];
private int[] intValues;
private float[] floatValues;
Handler mHandler = new Handler() {
@Override
public void handleMessage(Message msg) {
mButton.setText("DESCRIBE ME");
String text = (String)msg.obj;
Toast.makeText(MainActivity.this, text, Toast.LENGTH_LONG).show();
mButton.setEnabled(true);
} };
- 在
onCreate
方法中,首先在ImageView
中添加显示测试图像并处理按钮单击事件的代码:
mImageView = findViewById(R.id.imageview);
try {
AssetManager am = getAssets();
InputStream is = am.open(IMAGE_NAME);
Bitmap bitmap = BitmapFactory.decodeStream(is);
mImageView.setImageBitmap(bitmap);
} catch (IOException e) {
e.printStackTrace();
}
mButton = findViewById(R.id.button);
mButton.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View v) {
mButton.setEnabled(false);
mButton.setText("Processing...");
Thread thread = new Thread(MainActivity.this);
thread.start();
}
});
然后添加读取word_counts.txt
每行的代码,并将每个单词保存在mWords
数组中:
String filename = VOCAB_FILE.split("file:///android_asset/")[1];
BufferedReader br = null;
int linenum = 0;
try {
br = new BufferedReader(new InputStreamReader(getAssets().open(filename)));
String line;
while ((line = br.readLine()) != null) {
String word = line.split(" ")[0];
mWords[linenum ] = word;
}
br.close();
} catch (IOException e) {
throw new RuntimeException("Problem reading vocab file!" , e);
}
- 现在,在
public void run()
方法中,在DESCRIBE ME
按钮发生onClick
事件时启动,添加代码以调整测试图像的大小,从调整后的位图中读取像素值,然后将它们转换为浮点数-我们已经在前三章中看到了这样的代码:
intValues = new int[IMAGE_WIDTH * IMAGE_HEIGHT];
floatValues = new float[IMAGE_WIDTH * IMAGE_HEIGHT * IMAGE_CHANNEL];
Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open(IMAGE_NAME));
Bitmap croppedBitmap = Bitmap.createScaledBitmap(bitmap, IMAGE_WIDTH, IMAGE_HEIGHT, true);
croppedBitmap.getPixels(intValues, 0, IMAGE_WIDTH, 0, 0, IMAGE_WIDTH, IMAGE_HEIGHT);
for (int i = 0; i < intValues.length; i) {
final int val = intValues[i];
floatValues[i * IMAGE_CHANNEL 0] = ((val >> 16) & 0xFF);
floatValues[i * IMAGE_CHANNEL 1] = ((val >> 8) & 0xFF);
floatValues[i * IMAGE_CHANNEL 2] = (val & 0xFF);
}
- 创建一个
TensorFlowInferenceInterface
实例,该实例加载模型文件,并通过向其提供图像值,然后在initialState
中获取返回结果来使用该模型进行第一个推断:
AssetManager assetManager = getAssets();
mInferenceInterface = new TensorFlowInferenceInterface(assetManager, MODEL_FILE);
float[] initialState = new float[STATE_COUNT];
mInferenceInterface.feed(INPUT_NODE1, floatValues, IMAGE_WIDTH, IMAGE_HEIGHT, 3);
mInferenceInterface.run(new String[] {OUTPUT_NODE1}, false);
mInferenceInterface.fetch(OUTPUT_NODE1, initialState);
- 将第一个
input_feed
值设置为起始 ID,并将第一个state_feed
值设置为返回的initialState
值:
long[] inputFeed = new long[] {START_ID};
float[] stateFeed = new float[STATE_COUNT * inputFeed.length];
for (int i=0; i < STATE_COUNT; i ) {
stateFeed[i] = initialState[i];
}
如您所见,得益于 Android 中的TensorFlowInferenceInterface
实现,在 Android 中获取和设置张量值并进行推理比在 iOS 中更简单。 在我们开始重复使用inputFeed
和stateFeed
进行模型推断之前,我们创建了一个captions
列表,该列表包含一对整数和浮点数,其中整数作为单词 ID,具有最大 softmax 值(在模型为每个推理调用返回的所有 softmax 值中)和float
作为单词的 softmax 值。 我们可以使用一个简单的向量来保存每个推论返回中具有最大 softmax 值的单词,但是使用对的列表可以使以后我们从贪婪搜索方法切换到集束搜索时更加容易:
List<Pair<Integer, Float>> captions = new ArrayList<Pair<Integer, Float>>();
- 在字幕长度的
for
循环中,我们将上面设置的值提供给input_feed
和state_feed
,然后获取返回的softmax
和newstate
值:
for (int i=0; i<CAPTION_LEN; i ) {
float[] softmax = new float[WORD_COUNT * inputFeed.length];
float[] newstate = new float[STATE_COUNT * inputFeed.length];
mInferenceInterface.feed(INPUT_NODE2, inputFeed, 1);
mInferenceInterface.feed(INPUT_NODE3, stateFeed, 1, STATE_COUNT);
mInferenceInterface.run(new String[]{OUTPUT_NODE2, OUTPUT_NODE3}, false);
mInferenceInterface.fetch(OUTPUT_NODE2, softmax);
mInferenceInterface.fetch(OUTPUT_NODE3, newstate);
- 现在,创建另一个由整数和浮点对组成的列表,将每个单词的 ID 和 softmax 值添加到列表中,并以降序对列表进行排序:
List<Pair<Integer, Float>> prob_id = new ArrayList<Pair<Integer, Float>>();
for (int j = 0; j < WORD_COUNT; j ) {
prob_id.add(new Pair(j, softmax[j]));
}
Collections.sort(prob_id, new Comparator<Pair<Integer, Float>>() {
@Override
public int compare(final Pair<Integer, Float> o1, final Pair<Integer, Float> o2) {
return o1.second > o2.second ? -1 : (o1.second == o2.second ? 0 : 1);
}
});
- 如果最大概率的单词是结束单词,则结束循环。 否则,将该对添加到
captions
列表,并使用最大 softmax 值的单词 ID 更新input_feed
并使用返回的状态值更新state_feed
,以继续进行下一个推断:
if (prob_id.get(0).first == END_ID) break;
captions.add(new Pair(prob_id.get(0).first, prob_id.get(0).first));
inputFeed = new long[] {prob_id.get(0).first};
for (int j=0; j < STATE_COUNT; j ) {
stateFeed[j] = newstate[j];
}
}
- 最后,遍历
captions
列表中的每一对,并将每个单词(如果不是开头和结尾的话)添加到sentence
字符串,该字符串通过处理器返回,以向用户显示自然语言输出:
String sentence = "";
for (int i=0; i<captions.size(); i ) {
if (captions.get(i).first == START_ID) continue;
if (captions.get(i).first == END_ID) break;
sentence = sentence " " mWords[captions.get(i).first];
}
Message msg = new Message();
msg.obj = sentence;
mHandler.sendMessage(msg);
在您的虚拟或真实 Android 设备上运行该应用。 大约需要 10 秒钟才能看到结果。 您可以使用上一节中显示的四个不同的测试图像,并在图 6.9 中查看结果:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-fOpv4y9U-1681653119032)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/249e6bc8-7f89-45d4-a217-2cf1c54bf5fe.png)]
图 6.9:在 Android 中显示图像字幕结果
一些结果与 iOS 结果以及 TensorFlow im2txt 网站上的结果略有不同。 但是它们看起来都不错。 另外,在相对较旧的 Android 设备(例如 Nexus 5)上运行该模型的非映射版本也可以。 但是最好在 Android 中加载映射模型,以查看性能的显着提高,我们可能会在本书后面的章节中介绍。
因此,这将使用功能强大的图像字幕模型完成分步的 Android 应用构建过程。 无论您使用的是 iOS 还是 Android 应用,您都应该能够轻松地将我们训练有素的模型和推理代码集成到自己的应用中,或者返回到训练过程以微调模型,然后准备并优化更好的模型。 在您的移动应用中使用的模型。
总结
在本章中,我们首先讨论了由现代端到端深度学习支持的图像字幕如何工作,然后总结了如何使用 TensorFlow im2txt 模型项目训练这种模型。 我们详细讨论了如何找到正确的输入节点名称和输出节点名称,以及如何冻结模型,然后使用最新的图转换工具和映射转换工具修复在将模型加载到手机上时出现的一些讨厌的错误。 之后,我们展示了有关如何使用模型构建 iOS 和 Android 应用以及如何使用模型的 LSTM RNN 组件进行新的序列推断的详细教程。
令人惊讶的是,经过训练了成千上万个图像字幕示例,并在现代 CNN 和 LSTM 模型的支持下,我们可以构建和使用一个模型,该模型可以在移动设备上生成合理的自然语言描述。 不难想象可以在此基础上构建什么样的有用应用。 我们喜欢福尔摩斯吗? 当然不。 我们已经在路上了吗? 我们希望如此。 AI 的世界既令人着迷又充满挑战,但是只要我们不断取得稳步进步并改善自己的学习过程,同时又避免了梯度问题的消失和爆炸,我们就有很大机会建立一个类似于 Holmes 的模型,并可以随时随地在一天中在移动应用中使用它。
漫长的篇章讨论了基于 CNN 和 LSTM 的网络模型的实际使用,我们值得一试。 在下一章中,您将看到如何使用另一个基于 CNN 和 LSTM 的模型来开发有趣的 iOS 和 Android 应用,这些应用使您可以绘制对象然后识别它们是什么。 要快速获得游戏在线版本的乐趣,请访问这里。
七、使用 CNN 和 LSTM 识别绘画
在上一章中,我们看到了使用深度学习模型的强大功能,该模型将 CNN 与 LSTM RNN 集成在一起以生成图像的自然语言描述。 如果深度学习驱动的 AI 就像新的电力一样,我们当然希望看到这种混合神经网络模型在许多不同领域中的应用。 诸如图像字幕之类的严肃应用与之相反? 一个有趣的绘画应用,例如 Quick Draw(请参见这里了解有趣的示例数据),使用经过训练并基于 345 个类别中的 5000 万张绘画的模型,并将新绘画分类到这些类别中,听起来不错。 还有一个正式的 TensorFlow 教程,该教程介绍了如何构建这样的模型来帮助我们快速入门。
事实证明,在 iOS 和 Android 应用上使用本教程构建的模型的任务提供了一个绝佳的机会:
- 加深我们对找出模型的正确输入和输出节点名称的理解,因此我们可以为移动应用适当地准备模型
- 使用其他方法来修复 iOS 中的新模型加载和推断错误
- 首次为 Android 构建自定义的 TensorFlow 本机库,以修复 Android 中的新模型加载和预测错误
- 查看有关如何使用预期格式的输入来输入 TensorFlow 模型以及如何在 iOS 和 Android 中获取和处理其输出的更多示例
此外,在处理所有繁琐而重要的细节的过程中,以便模型可以像魔术一样工作,以进行漂亮的绘画分类,您将在 iOS 和 Android 设备上享受有趣的涂鸦。
因此,在本章中,我们将介绍以下主题:
- 绘画分类 – 工作原理
- 训练并准备绘画分类模型
- 在 iOS 中使用绘画分类模型
- 在 Android 中使用绘画分类模型
绘画分类 – 工作原理
TensorFlow 教程中内置的绘画分类模型,首先接受表示为点列表的用户绘画输入,并将规范化输入转换为连续点的增量的张量,以及有关每个点是否是新笔画的开始的信息。 然后将张量穿过几个卷积层和 LSTM 层,最后穿过 softmax 层,如图 7.1 所示,以对用户绘画进行分类:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-hFvjcltm-1681653119032)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/ef475044-6568-4e50-bc98-8ff0a0d6efbf.png)]
图 7.1:绘画分类模式
与接受 2D 图像输入的 2D 卷积 API tf.layers.conv2d
不同,此处将 1D 卷积 API tf.layers.conv1d
用于时间卷积(例如绘画)。 默认情况下,在绘画分类模型中,使用三个 1D 卷积层,每个层具有 48、64 和 96 个过滤器,其长度分别为 5、5 和 3。 卷积层之后,将创建 3 个 LSTM 层,每层具有 128 个正向BasicLSTMCell
节点和 128 个反向BasicLSTMCell
节点,然后将其用于创建动态双向循环神经网络,该网络的输出将发送到最终的完全连接层以计算logits
(非标准化的对数概率)。
If you don’t have a good understanding of all these details, don’t worry; to develop powerful mobile apps using a model built by others, you don’t have to understand all the details, but in the next chapter we’ll also discuss in greater detail how you can build a RNN model from scratch for stock prediction, and with that, you’ll have a better understanding of all the RNN stuff.
在前面提到的有趣的教程中详细描述了简单而优雅的模型以及构建模型的 Python 实现,其源代码位于仓库中。 在继续进行下一部分之前,我们只想说一件事:模型的构建,训练,评估和预测的代码与上一章中看到的代码不同,它使用了称为Estimator
的 TensorFlow API,或更准确地说,是自定义Estimator
。 如果您对模型实现的详细信息感兴趣,则应该阅读有关创建和使用自定义Estimator
的指南。 这个页面的models/samples/core/get_started/custom_estimator.py
上的指南的有用源代码。 基本上,首先要实现一个函数,该函数定义模型,指定损失和准确率度量,设置优化器和training
操作,然后创建tf.estimator.Estimator
类的实例并调用其train
,evaluate
和predict
方法。 就像您将很快看到的那样,使用Estimator
可以简化如何构建,训练和推断神经网络模型,但是由于它是高级 API,因此它还会执行一些更加困难的低级任务,例如找出输入和输出节点名称来推断移动设备。
训练,预测和准备绘画分类模型
训练模型非常简单,但为移动部署准备模型则有些棘手。 在我们开始训练之前,请首先确保您已经在 TensorFlow 根目录中克隆了 TensorFlow 模型库,就像我们在前两章中所做的一样。 然后从这里下载绘画分类训练数据集,大约 1.1GB,创建一个名为rnn_tutorial_data
的新文件夹, 并解压缩dataset tar.gz
文件。 您将看到 10 个训练 TFRecord 文件和 10 个评估 TFRecord 文件,以及两个带有.classes
扩展名的文件,它们具有相同的内容,并且只是该数据集可用于分类的 345 个类别的纯文本,例如"sheep", "skull", "donut", "apple"
。
训练绘画分类模型
要训练模型,只需打开终端cd
到tensorflow/models/tutorials/rnn/quickdraw
,然后运行以下脚本:
python train_model.py
--training_data=rnn_tutorial_data/training.tfrecord-?????-of-?????
--eval_data=rnn_tutorial_data/eval.tfrecord-?????-of-?????
--model_dir quickdraw_model/
--classes_file=rnn_tutorial_data/training.tfrecord.classes
默认情况下,训练步骤为 100k,在我们的 GTX 1070 GPU 上大约需要 6 个小时才能完成训练。 训练完成后,您将在模型目录中看到一个熟悉的文件列表(省略了其他四组model.ckpt*
文件):
ls -lt quickdraw_model/
-rw-rw-r-- 1 jeff jeff 164419871 Feb 12 05:56 events.out.tfevents.1518422507.AiLabby
-rw-rw-r-- 1 jeff jeff 1365548 Feb 12 05:56 model.ckpt-100000.meta
-rw-rw-r-- 1 jeff jeff 279 Feb 12 05:56 checkpoint
-rw-rw-r-- 1 jeff jeff 13707200 Feb 12 05:56 model.ckpt-100000.data-00000-of-00001
-rw-rw-r-- 1 jeff jeff 2825 Feb 12 05:56 model.ckpt-100000.index
-rw-rw-r-- 1 jeff jeff 2493402 Feb 12 05:47 graph.pbtxt
drwxr-xr-x 2 jeff jeff 4096 Feb 12 00:11 eval
如果您运行tensorboard --logdir quickdraw_model
,然后从浏览器在http://localhost:6006
上启动 TensorBoard,您会看到精度达到约 0.55,损失到约 2.0。 如果继续进行约 200k 的训练,则精度将提高到约 0.65,损失将下降到 1.3,如图 7.2 所示:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ZwniLh1v-1681653119033)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/479a1d37-1775-459d-a616-c45f2f10f0f1.png)]
图 7.2:300k 训练步骤后模型的准确率和损失
现在,我们可以像上一章一样运行freeze_graph.py
工具,以生成用于移动设备的模型文件。 但是在执行此操作之前,我们首先来看一下如何在 Python 中使用该模型进行推断,例如上一章中的run_inference.py
脚本。
使用绘画分类模型进行预测
看一下models/tutorial/rnn/quickdraw
文件夹中的train_model.py
文件。 当它开始运行时,将在create_estimator_and_specs
函数中创建一个Estimator
实例:
estimator = tf.estimator.Estimator(
model_fn=model_fn,
config=run_config,
params=model_params)
传递给Estimator
类的关键参数是名为model_fn
的模型函数,该函数定义:
- 获取输入张量并创建卷积,RNN 和最终层的函数
- 调用这些函数来构建模型的代码
- 损失,优化器和预测
在返回tf.estimator.EstimatorSpec
实例之前,model_fn
函数还具有一个名为mode
的参数,该参数可以具有以下三个值之一:
tf.estimator.ModeKeys.TRAIN
tf.estimator.ModeKeys.EVAL
tf.estimator.ModeKeys.PREDICT
实现train_model.py
的方式支持训练和求值模式,但是您不能直接使用它来通过特定的绘画输入进行推理(对绘画进行分类)。 要使用特定输入来测试预测,请按照以下步骤操作:
- 复制
train_model.py
,然后将新文件重命名为predict.py
-这样您就可以更自由地进行预测了。 - 在
predict.py
中,定义[预测]的输入函数,并将features
设置为模型期望的绘画输入(连续点的增量,其中第三个数字表示该点是否为笔划的起点) :
def predict_input_fn():
def _input_fn():
features = {'shape': [[16, 3]], 'ink': [[
-0.23137257, 0.31067961, 0. ,
-0.05490196, 0.1116505 , 0. ,
0.00784314, 0.09223297, 0. ,
0.19215687, 0.07766992, 0. ,
...
0.12156862, 0.05825245, 0. ,
0. , -0.06310678, 1. ,
0. , 0., 0. ,
...
0. , 0., 0. ,
]]}
features['shape'].append( features['shape'][0])
features['ink'].append( features['ink'][0])
features=dict(features)
dataset = tf.data.Dataset.from_tensor_slices(features)
dataset = dataset.batch(FLAGS.batch_size)
return dataset.make_one_shot_iterator().get_next()
return _input_fn
我们并没有显示所有的点值,但它们是使用 TensorFlow RNN 用于绘画分类的教程中显示的示例猫示例数据创建的,并应用了parse_line
函数(请参见教程或models/tutorials/rnn/quickdraw
文件夹中的create_dataset.py
细节)。
还要注意,我们使用tf.data.Dataset
的make_one_shot_iterator
方法创建了一个迭代器,该迭代器从数据集中返回一个示例(在这种情况下,我们在数据集中只有一个示例),与模型在处理大型数据集时,在训练和评估过程中获取数据的方式相同–这就是为什么稍后在模型的图中看到OneShotIterator
操作的原因。
- 在主函数中,调用估计器的
predict
方法,该方法将生成给定特征的预测,然后打印下一个预测:
predictions = estimator.predict(input_fn=predict_input_fn())
print(next(predictions)['argmax'])
- 在
model_fn
函数中,在logits = _add_fc_layers(final_state)
之后,添加以下代码:
argmax = tf.argmax(logits, axis=1)
if mode == tf.estimator.ModeKeys.PREDICT:
predictions = {
'argmax': argmax,
'softmax': tf.nn.softmax(logits),
'logits': logits,
}
return tf.estimator.EstimatorSpec(mode, predictions=predictions)
现在,如果您运行predict.py
,您将在步骤 2 中获得具有输入数据返回最大值的类 ID。
基本了解如何使用Estimator
高级 API 构建的模型进行预测后,我们现在就可以冻结该模型,以便可以在移动设备上使用该模型,这需要我们首先弄清楚输出节点名称应该是什么。
准备绘画分类模型
让我们使用 TensorBoard 看看我们能找到什么。 在我们模型的 TensorBoard 视图的 GRAPHS 部分中,您可以看到,如图 7.3 所示,以红色突出显示的BiasAdd
节点是ArgMax
操作的输入,用于计算精度,以及 softmax 操作的输入。 我们可以使用SparseSoftmaxCrossEntropyWithLogits
(图 7.3 仅显示为SparseSiftnaxCr ...
)操作,也可以仅使用Dense
/BiasAdd
作为输出节点名称,但我们将ArgMax
和Dense
/BiasAdd
用作freeze_graph
工具的两个输出节点名称,因此我们可以更轻松地查看最终密集层的输出以及ArgMax
结果:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-XWqHPeoM-1681653119033)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/d911bab2-cfc2-41d3-a352-d00db55b1f71.png)]
图 7.3:显示模型的可能输出节点名称
用您的graph.pbtxt
文件的路径和最新的模型检查点前缀替换--input_graph
和--input_checkpoint
值后,在 TensorFlow 根目录中运行以下脚本以获取冻结的图:
python tensorflow/python/tools/freeze_graph.py --input_graph=/tmp/graph.pbtxt --input_checkpoint=/tmp/model.ckpt-314576 --output_graph=/tmp/quickdraw_frozen_dense_biasadd_argmax.pb --output_node_names="dense/BiasAdd,ArgMax"
您会看到quickdraw_frozen_dense_biasadd_argmax.pb
成功创建。 但是,如果您尝试在 iOS 或 Android 应用中加载模型,则会收到一条错误消息,内容为Could not create TensorFlow Graph: Not found: Op type not registered 'OneShotIterator' in binary. Make sure the Op and Kernel are registered in the binary running in this process.
我们在前面的小节中讨论了OneShotIterator
的含义。 回到 TensorBoard GRAPHS
部分,我们可以看到OneShotIterator
(如图 7.4 所示),该区域以红色突出显示,并且还显示在右下方的信息面板中,在图表的底部,以及上方的几个层次中,有一个 Reshape
操作用作第一卷积层的输入:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-DKlJXEkX-1681653119033)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/a88ae370-8a39-4ee3-8ce0-ebaa5fd8a8ed.png)]
图 7.4:查找可能的输入节点名称
您可能想知道为什么我们不能使用我们之前使用的技术来解决Not found: Op type not registered 'OneShotIterator'
错误,即先使用命令grep 'REGISTER.*"OneShotIterator"' tensorflow/core/ops/*.cc
(您将看到输出为tensorflow/core/ops/dataset_ops.cc:REGISTER_OP("OneShotIterator")
),然后将tensorflow/core/ops/dataset_ops.cc
添加到tf_op_files.txt
并重建 TensorFlow 库。 即使这可行,也会使解决方案复杂化,因为现在我们需要向模型提供一些与OneShotIterator
相关的数据,而不是以点为单位的直接用户绘画。
此外,在右侧上方一层(图 7.5),还有另一种操作 Squeeze
,它是 rnn_classification
子图的输入:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-BMOmquMH-1681653119033)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/6230229e-1c21-4280-b60a-cac640c1cb1f.png)]
图 7.5:找出输入节点名称的进一步研究
我们不必担心Reshape
右侧的Shape
运算,因为它实际上是rnn_classification
子图的输出。 因此,所有这些研究背后的直觉是,我们可以使用Reshape
和Squeeze
作为两个输入节点,然后使用在上一章中看到的transform_graph
工具,我们应该能够删除 Reshape
和Squeeze
以下的节点,包括OneShotIterator
。
现在在 TensorFlow 根目录中运行以下命令:
代码语言:javascript复制bazel-bin/tensorflow/tools/graph_transforms/transform_graph --in_graph=/tmp/quickdraw_frozen_dense_biasadd_argmax.pb --out_graph=/tmp/quickdraw_frozen_strip_transformed.pb --inputs="Reshape,Squeeze" --outputs="dense/BiasAdd,ArgMax" --transforms='
strip_unused_nodes(name=Squeeze,type_for_name=int64,shape_for_name="8",name=Reshape,type_for_name=float,shape_for_name="8,16,3")'
在这里,我们为strip_unused_nodes
使用了更高级的格式:对于每个输入节点名称(Squeeze
和Reshape
),我们指定其特定的类型和形状,以避免以后出现模型加载错误。 有关transform_graph
工具的strip_unused_nodes
的更多详细信息,请参见其上的文档 https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms 。
现在在 iOS 或 Android 中加载模型,OneShotIterator
错误将消失。 但是,您可能已经学会了预期,但是会出现一个新错误:Could not create TensorFlow Graph: Invalid argument: Input 0 of node IsVariableInitialized was passed int64 from global_step:0 incompatible with expected int64_ref.
我们首先需要了解有关IsVariableInitialized
的更多信息。 如果我们回到 TensorBoard GRAPHS
标签,我们会在左侧看到一个IsVariableInitialized
操作,该操作以红色突出显示并在右侧的信息面板中以global_step
作为其输入(图 7.6)。
即使我们不确切知道它的用途,我们也可以确保它与模型推断无关,该模型推断只需要一些输入(图 7.4 和图 7.5)并生成绘画分类作为输出(图 7.3)。 :
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-eTUXwrz3-1681653119034)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/6cc9946c-dfb2-41de-af95-a35e47d405a7.png)]
图 7.6:查找导致模型加载错误但与模型推断无关的节点
那么,如何摆脱global_step
以及其他相关的cond
节点,由于它们的隔离性,它们不会被变换图工具剥离掉? 幸运的是,freeze_graph
脚本支持这一点 – 仅在其源代码中记录。 我们可以为脚本使用variable_names_blacklist
参数来指定应在冻结模型中删除的节点:
python tensorflow/python/tools/freeze_graph.py --input_graph=/tmp/graph.pbtxt --input_checkpoint=/tmp/model.ckpt-314576 --output_graph=/tmp/quickdraw_frozen_long_blacklist.pb --output_node_names="dense/BiasAdd,ArgMax" --variable_names_blacklist="IsVariableInitialized,global_step,global_step/Initializer/zeros,cond/pred_id,cond/read/Switch,cond/read,cond/Switch_1,cond/Merge"
在这里,我们只列出global_step
和cond
范围内的节点。 现在再次运行transform_graph
工具:
bazel-bin/tensorflow/tools/graph_transforms/transform_graph --in_graph=/tmp/quickdraw_frozen_long_blacklist.pb --out_graph=/tmp/quickdraw_frozen_long_blacklist_strip_transformed.pb --inputs="Reshape,Squeeze" --outputs="dense/BiasAdd,ArgMax" --transforms='
strip_unused_nodes(name=Squeeze,type_for_name=int64,shape_for_name="8",name=Reshape,type_for_name=float,shape_for_name="8,16,3")'
在 iOS 或 Android 中加载生成的模型文件quickdraw_frozen_long_blacklist_strip_transformed.pb
,您将不再看到 IsVariableInitialized
错误。 当然,在 iOS 和 Android 上,您还会看到另一个错误。 加载先前的模型将导致此错误:
Couldn't load model: Invalid argument: No OpKernel was registered to support Op 'RefSwitch' with these attrs. Registered devices: [CPU], Registered kernels:
device='GPU'; T in [DT_FLOAT]
device='GPU'; T in [DT_INT32]
device='GPU'; T in [DT_BOOL]
device='GPU'; T in [DT_STRING]
device='CPU'; T in [DT_INT32]
device='CPU'; T in [DT_FLOAT]
device='CPU'; T in [DT_BOOL]
[[Node: cond/read/Switch = RefSwitch[T=DT_INT64, _class=["loc:@global_step"], _output_shapes=[[], []]](global_step, cond/pred_id)]]
要解决此错误,我们必须以不同的方式为 iOS 和 Android 构建自定义的 TensorFlow 库。 在下面的 iOS 和 Android 部分中讨论如何执行此操作之前,让我们首先做一件事:将模型转换为映射版本,以便在 iOS 中更快地加载并使用更少的内存:
代码语言:javascript复制bazel-bin/tensorflow/contrib/util/convert_graphdef_memmapped_format
--in_graph=/tmp/quickdraw_frozen_long_blacklist_strip_transformed.pb
--out_graph=/tmp/quickdraw_frozen_long_blacklist_strip_transformed_memmapped.pb
在 iOS 中使用绘画分类模型
要解决以前的 RefSwitch 错误,无论您是否像在第 2 章,“通过迁移学习对图像分类”和第 6 章,“用自然语言描述图像”或手动构建的 TensorFlow 库,就像在其他章节中一样,我们必须使用一些新技巧。 发生错误的原因是RefSwitch
操作需要INT64
数据类型,但它不是 TensorFlow 库中内置的已注册数据类型之一,因为默认情况下,要使该库尽可能小,仅包括每个操作的共同数据类型。 我们可能会从 Python 的模型构建端修复此问题,但是在这里,我们仅向您展示如何从 iOS 端修复此问题,当您无权访问源代码来构建模型时,这很有用。
为 iOS 构建自定义的 TensorFlow 库
从tensorflow/contrib/makefile/Makefile
打开 Makefile,然后,如果您使用 TensorFlow 1.4,则搜索IOS_ARCH
。 对于每种架构(总共 5 种:ARMV7,ARMV7S,ARM64,I386,X86_64),将-D__ANDROID_TYPES_SLIM__
更改为
-D__ANDROID_TYPES_FULL__
。 TensorFlow 1.5(或 1.6/1.7)中的Makefile
稍有不同,尽管它仍位于同一文件夹中。 对于 1.5/1.6/1.7,搜索ANDROID_TYPES_SLIM
并将其更改为 ANDROID_TYPES_FULL
。 现在,通过运行tensorflow/contrib/makefile/build_all_ios.sh
重建 TensorFlow 库。 此后,在加载模型文件时,RefSwitch
错误将消失。 使用 TensorFlow 库构建并具有完整数据类型支持的应用大小约为 70MB,而使用默认的细长数据类型构建的应用大小为 37MB。
好像还不够,仍然发生另一个模型加载错误:
Could not create TensorFlow Graph: Invalid argument: No OpKernel was registered to support Op 'RandomUniform' with these attrs. Registered devices: [CPU], Registered kernels: <no registered kernels>.
幸运的是,如果您已经阅读了前面的章节,那么您应该非常熟悉如何解决这种错误。 快速回顾一下:首先找出哪些操作和内核文件定义并实现了该操作,然后检查tf_op_files.txt
文件中是否包含操作或内核文件,并且应该至少缺少一个文件,从而导致错误 ; 现在只需将操作或内核文件添加到tf_op_files.txt
并重建库。 在我们的情况下,运行以下命令:
grep RandomUniform tensorflow/core/ops/*.cc
grep RandomUniform tensorflow/core/kernels/*.cc
您将看到这些文件作为输出:
代码语言:javascript复制tensorflow/core/ops/random_grad.cc
tensorflow/core/ops/random_ops.cc:
tensorflow/core/kernels/random_op.cc
tensorflow/contrib/makefile/tf_op_files.txt
文件只有前两个文件,因此只需将最后一个tensorflow/core/kernels/random_op.cc
添加到 tf_op_files.txt
的末尾,然后再次运行tensorflow/contrib/makefile/build_all_ios.sh
。
最终,在加载模型时所有错误都消失了,我们可以通过实现应用逻辑来处理用户绘画,将点转换为模型期望的格式并返回分类结果,从而开始获得一些真正的乐趣。
开发 iOS 应用来使用模型
让我们使用 Objective-C 创建一个新的 Xcode 项目,然后从上一章中创建的Image2Text
iOS 项目中拖放tensorflow_util.h
和tensorflow_util.mm
文件。 另外,将两个模型文件quickdraw_frozen_long_blacklist_strip_transformed.pb
和quickdraw_frozen_long_blacklist_strip_transformed_memmapped.pb
以及training.tfrecord.classes
文件从 models/tutorials/rnn/quickdraw/rnn_tutorial_data
拖放到QuickDraw
项目,然后将training.tfrecord.classes
重命名为classes.txt
。
还将ViewController.m
重命名为ViewController.mm
,并在tensorflow_util.h
中注释GetTopN
函数定义,并在tensorflow_util.mm
中注释其实现,因为我们将在ViewController.mm
中实现修改后的版本。 您的项目现在应如图 7.7 所示:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ySJd1grN-1681653119034)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/ee464972-a065-4400-aea9-72ce9306b901.png)]
图 7.7:显示带有ViewController
初始内容的QuickDraw
Xcode 项目。
我们现在准备单独处理ViewController.mm
,以完成我们的任务。
- 在按图 7.6 设置基本常量和变量以及两个函数原型之后,在
ViewController
的viewDidLoad
中实例化UIButton
,UILabel
和UIImageView
。 每个 UI 控件都设置有多个NSLayoutConstraint
(有关完整的代码列表,请参见源代码仓库)。UIImageView
的相关代码如下:
_iv = [[UIImageView alloc] init];
_iv.contentMode = UIViewContentModeScaleAspectFit;
[_iv setTranslatesAutoresizingMaskIntoConstraints:NO];
[self.view addSubview:_iv];
UIImageView
将用于显示通过UIBezierPath
实现的用户绘画。 同样,初始化两个用于保存每个连续点和用户绘制的所有点的数组:
_allPoints = [NSMutableArray array];
_consecutivePoints = [NSMutableArray array];
- 点击具有初始标题“开始”的按钮后,用户可以开始绘画; 按钮标题更改为“重新启动”,并进行了其他一些重置:
- (IBAction)btnTapped:(id)sender {
_canDraw = YES;
[_btn setTitle:@"Restart" forState:UIControlStateNormal];
[_lbl setText:@""];
_iv.image = [UIImage imageNamed:@""];
[_allPoints removeAllObjects];
}
- 为了处理用户绘画,我们首先实现
touchesBegan
方法:
- (void) touchesBegan:(NSSet *)touches withEvent:(UIEvent *)event {
if (!_canDraw) return;
[_consecutivePoints removeAllObjects];
UITouch *touch = [touches anyObject];
CGPoint point = [touch locationInView:self.view];
[_consecutivePoints addObject:[NSValue valueWithCGPoint:point]];
_iv.image = [self createDrawingImageInRect:_iv.frame];
}
然后是touchesMoved
方法:
- (void) touchesMoved:(NSSet *)touches withEvent:(UIEvent *)event {
if (!_canDraw) return;
UITouch *touch = [touches anyObject];
CGPoint point = [touch locationInView:self.view];
[_consecutivePoints addObject:[NSValue valueWithCGPoint:point]];
_iv.image = [self createDrawingImageInRect:_iv.frame];
}
最后是touchesEnd
方法:
- (void) touchesEnded:(NSSet *)touches withEvent:(UIEvent *)event {
if (!_canDraw) return;
UITouch *touch = [touches anyObject];
CGPoint point = [touch locationInView:self.view];
[_consecutivePoints addObject:[NSValue valueWithCGPoint:point]];
[_allPoints addObject:[NSArray arrayWithArray:_consecutivePoints]];
[_consecutivePoints removeAllObjects];
_iv.image = [self createDrawingImageInRect:_iv.frame];
dispatch_async(dispatch_get_global_queue(0, 0), ^{
std::string classes = getDrawingClassification(_allPoints);
dispatch_async(dispatch_get_main_queue(), ^{
NSString *c = [NSString stringWithCString:classes.c_str() encoding:[NSString defaultCStringEncoding]];
[_lbl setText:c];
});
});
}
这里的代码很容易解释,除了createDrawingImageInRect
和getDrawingClassification
这两种方法外,我们将在后面介绍。
- 方法
createDrawingImageInRect
使用UIBezierPath's
moveToPoint
和addLineToPoint
方法显示用户绘画。 它首先通过触摸事件准备所有完成的笔划,并将所有点存储在_allPoints
数组中:
- (UIImage *)createDrawingImageInRect:(CGRect)rect
{
UIGraphicsBeginImageContextWithOptions(CGSizeMake(rect.size.width, rect.size.height), NO, 0.0);
UIBezierPath *path = [UIBezierPath bezierPath];
for (NSArray *cp in _allPoints) {
bool firstPoint = TRUE;
for (NSValue *pointVal in cp) {
CGPoint point = pointVal.CGPointValue;
if (firstPoint) {
[path moveToPoint:point];
firstPoint = FALSE;
}
else
[path addLineToPoint:point];
}
}
然后,它准备当前正在进行的笔划中的所有点,并存储在_consecutivePoints
中:
bool firstPoint = TRUE;
for (NSValue *pointVal in _consecutivePoints) {
CGPoint point = pointVal.CGPointValue;
if (firstPoint) {
[path moveToPoint:point];
firstPoint = FALSE;
}
else
[path addLineToPoint:point];
}
最后,它执行实际绘画,并将绘画作为UIImage
返回,以显示在UIImageView
中:
path.lineWidth = 6.0;
[[UIColor blackColor] setStroke];
[path stroke];
UIImage *image = UIGraphicsGetImageFromCurrentImageContext();
UIGraphicsEndImageContext();
return image;
}
getDrawingClassification
首先使用与上一章相同的代码来加载模型或其映射版本:
std::string getDrawingClassification(NSMutableArray *allPoints) {
if (!_modelLoaded) {
tensorflow::Status load_status;
if (USEMEMMAPPED) {
load_status = LoadMemoryMappedModel(MODEL_FILE_MEMMAPPED, MODEL_FILE_TYPE, &tf_session, &tf_memmapped_env);
}
else {
load_status = LoadModel(MODEL_FILE, MODEL_FILE_TYPE, &tf_session);
}
if (!load_status.ok()) {
LOG(FATAL) << "Couldn't load model: " << load_status;
return "";
}
_modelLoaded = YES;
}
然后,它获得总点数并分配一个浮点数数组,然后调用另一个函数normalizeScreenCoordinates
(稍后将介绍)将点转换为模型期望的格式:
if ([allPoints count] == 0) return "";
int total_points = 0;
for (NSArray *cp in allPoints) {
total_points = cp.count;
}
float *normalized_points = new float[total_points * 3];
normalizeScreenCoordinates(allPoints, normalized_points);
接下来,我们定义输入和输出节点名称,并创建一个包含总点数的张量:
代码语言:javascript复制 std::string input_name1 = "Reshape";
std::string input_name2 = "Squeeze";
std::string output_name1 = "dense/BiasAdd";
std::string output_name2 = "ArgMax"
const int BATCH_SIZE = 8;
tensorflow::Tensor seqlen_tensor(tensorflow::DT_INT64, tensorflow::TensorShape({BATCH_SIZE}));
auto seqlen_mapped = seqlen_tensor.tensor<int64_t, 1>();
int64_t* seqlen_mapped_data = seqlen_mapped.data();
for (int i=0; i<BATCH_SIZE; i ) {
seqlen_mapped_data[i] = total_points;
}
请注意,在运行train_model.py
来训练模型时,我们必须使用与BATCH_SIZE
相同的BATCH_SIZE
,默认情况下为 8。
保存所有转换点值的另一个张量在这里创建:
代码语言:javascript复制 tensorflow::Tensor points_tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({8, total_points, 3}));
auto points_tensor_mapped = points_tensor.tensor<float, 3>();
float* out = points_tensor_mapped.data();
for (int i=0; i<BATCH_SIZE; i ) {
for (int j=0; j<total_points*3; j )
out[i*total_points*3 j] = normalized_points[j];
}
- 现在,我们运行模型并获得预期的输出:
std::vector<tensorflow::Tensor> outputs;
tensorflow::Status run_status = tf_session->Run({{input_name1, points_tensor}, {input_name2, seqlen_tensor}}, {output_name1, output_name2}, {}, &outputs);
if (!run_status.ok()) {
LOG(ERROR) << "Getting model failed:" << run_status;
return "";
}
tensorflow::string status_string = run_status.ToString();
tensorflow::Tensor* logits_tensor = &outputs[0];
- 使用修改后的
GetTopN
版本并解析logits
获得最佳结果:
const int kNumResults = 5;
const float kThreshold = 0.1f;
std::vector<std::pair<float, int> > top_results;
const Eigen::TensorMap<Eigen::Tensor<float, 1, Eigen::RowMajor>, Eigen::Aligned>& logits = logits_tensor->flat<float>();
GetTopN(logits, kNumResults, kThreshold, &top_results);
string result = "";
for (int i=0; i<top_results.size(); i ) {
std::pair<float, int> r = top_results[i];
if (result == "")
result = classes[r.second];
else result = ", " classes[r.second];
}
- 通过将
logits
值转换为 softmax 值来更改GetTopN
,然后返回顶部 softmax 值及其位置:
float sum = 0.0;
for (int i = 0; i < CLASS_COUNT; i) {
sum = expf(prediction(i));
}
for (int i = 0; i < CLASS_COUNT; i) {
const float value = expf(prediction(i)) / sum;
if (value < threshold) {
continue;
}
top_result_pq.push(std::pair<float, int>(value, i));
if (top_result_pq.size() > num_results) {
top_result_pq.pop();
}
}
- 最后,
normalizeScreenCoordinates
函数将其在触摸事件中捕获的屏幕坐标中的所有点转换为增量差异 – 这几乎是这个页面中的 Python 方法parse_line
的一部分:
void normalizeScreenCoordinates(NSMutableArray *allPoints, float *normalized) {
float lowerx=MAXFLOAT, lowery=MAXFLOAT, upperx=-MAXFLOAT, uppery=-MAXFLOAT;
for (NSArray *cp in allPoints) {
for (NSValue *pointVal in cp) {
CGPoint point = pointVal.CGPointValue;
if (point.x < lowerx) lowerx = point.x;
if (point.y < lowery) lowery = point.y;
if (point.x > upperx) upperx = point.x;
if (point.y > uppery) uppery = point.y;
}
}
float scalex = upperx - lowerx;
float scaley = uppery - lowery;
int n = 0;
for (NSArray *cp in allPoints) {
int m=0;
for (NSValue *pointVal in cp) {
CGPoint point = pointVal.CGPointValue;
normalized[n*3] = (point.x - lowerx) / scalex;
normalized[n*3 1] = (point.y - lowery) / scaley;
normalized[n*3 2] = (m ==cp.count-1 ? 1 : 0);
n ; m ;
}
}
for (int i=0; i<n-1; i ) {
normalized[i*3] = normalized[(i 1)*3] - normalized[i*3];
normalized[i*3 1] = normalized[(i 1)*3 1] - normalized[i*3 1];
normalized[i*3 2] = normalized[(i 1)*3 2];
}
}
现在,您可以在 iOS 模拟器或设备中运行该应用,开始绘画,并查看模型认为您正在绘画的内容。 图 7.8 显示了一些绘画和分类结果–不是最佳绘画,而是整个过程!
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-MnGEbMeY-1681653119034)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/48678dd2-0490-4b97-a0a6-fa4ba52bc27b.png)]
图 7.8:在 iOS 上显示绘画和分类结果
在 Android 中使用绘画分类模型
现在该看看我们如何在 Android 中加载和使用该模型。 在之前的章节中,我们通过使用 Android 应用的build.gradle
文件并添加了一行 compile 'org.tensorflow:tensorflow-android: '
仅添加了 TensorFlow 支持。 与 iOS 相比,我们必须构建一个自定义的 TensorFlow 库来修复不同的模型加载或运行错误(例如,在第 3 章,“检测对象及其位置”中,第四章,“变换具有惊人艺术风格的图片”和第五章,“了解简单的语音命令”),Android 的默认 TensorFlow 库对注册的操作和数据类型有更好的支持,这可能是因为 Android 是 Google 的一等公民,而 iOS 是第二名,甚至是第二名。
事实是,当我们处理各种惊人的模型时,我们不得不面对不可避免的问题只是时间问题:我们必须手动为 Android 构建 TensorFlow 库,以修复默认 TensorFlow 库中的一些根本无法应对的错误。 No OpKernel was registered to support Op 'RefSwitch' with these attrs.
错误就是这样的错误之一。 对于乐观的开发人员来说,这仅意味着另一种向您的技能组合中添加新技巧的机会。
为 Android 构建自定义 TensorFlow 库
请按照以下步骤手动为 Android 构建自定义的 TensorFlow 库:
- 在您的 TensorFlow 根目录中,有一个名为
WORKSPACE
的文件。 编辑它,并使android_sdk_repository
和android_ndk_repository
看起来像以下设置(用您自己的设置替换build_tools_version
以及 SDK 和 NDK 路径):
android_sdk_repository(
name = "androidsdk",
api_level = 23,
build_tools_version = "26.0.1",
path = "$HOME/Library/Android/sdk",
)
android_ndk_repository(
name="androidndk",
path="$HOME/Downloads/android-ndk-r15c",
api_level=14)
- 如果您还使用过本书中的 iOS 应用,并且已将
tensorflow/core/platform/default/mutex.h
从#include "nsync_cv.h"
和#include "nsync_mu.h"
更改为#include "nsync/public/nsync_cv.h"
和#include "nsync/public/nsync_mu.h"
,请参见第 3 章, “检测对象及其位置” 时,您需要将其更改回以成功构建 TensorFlow Android 库(此后,当您使用手动构建的 TensorFlow 库在 Xcode 和 iOS 应用上工作时,需要先添加nsync/public
这两个标头。
Changing tensorflow/core/platform/default/mutex.h
back and forth certainly is not an ideal solution. It’s supposed to be just as a workaround. As it only needs to be changed when you start using a manually built TensorFlow iOS library or when you build a custom TensorFlow library, we can live with it for now.
- 如果您具有支持 x86 CPU 的虚拟模拟器或 Android 设备,请运行以下命令来构建本机 TensorFlow 库:
bazel build -c opt --copt="-D__ANDROID_TYPES_FULL__" //tensorflow/contrib/android:libtensorflow_inference.so
--crosstool_top=//external:android/crosstool
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain
--cpu=x86_64
如果您的 Android 设备像大多数 Android 设备一样支持 armeabi-v7a,请运行以下命令:
代码语言:javascript复制bazel build -c opt --copt="-D__ANDROID_TYPES_FULL__" //tensorflow/contrib/android:libtensorflow_inference.so
--crosstool_top=//external:android/crosstool
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain
--cpu=armeabi-v7a
在 Android 应用中使用手动构建的本机库时,您需要让该应用知道该库是针对哪个 CPU 指令集(也称为应用二进制接口(ABI))构建的。 Android 支持两种主要的 ABI:ARM 和 X86,而armeabi-v7a
是 Android 上最受欢迎的 ABI。 要找出您的设备或仿真器使用的是哪个 ABI,请运行adb -s <device_id> shell getprop ro.product.cpu.abi
。 例如,此命令为我的 Nexus 7 平板电脑返回armeabi-v7a
,为我的模拟器返回x86_64
。
如果您具有支持 x86_64 的虚拟仿真器以在开发过程中进行快速测试,并且在设备上进行最终性能测试,则可能要同时构建两者。
构建完成后,您将在bazel-bin/tensorflow/contrib/android
文件夹中看到 TensorFlow 本机库文件libtensorflow_inference.so
。 将其拖到android/app/src/main/jniLibs/armeabi-v7a
或 android/app/src/main/jniLibs/x86_64
的app
文件夹中,如图 7.9 所示:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-U37zgeVF-1681653119035)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/0865ad60-dece-4349-9825-24ca46220d07.png)]
图 7.9:显示 TensorFlow 本机库文件
- 通过运行以下命令构建 TensorFlow 本机库的 Java 接口:
bazel build //tensorflow/contrib/android:android_tensorflow_inference_java
这将在bazel-bin/tensorflow/contrib/android
处生成文件libandroid_tensorflow_inference_java.jar
。 将文件移动到 android/app/lib
文件夹,如图 7.10 所示:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-p6iE4MOg-1681653119035)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/51e4c597-7447-4623-8e07-396b56faddfa.png)]
图 7.10:将 Java 接口文件显示到 TensorFlow 库
现在,我们准备在 Android 中编码和测试模型。
开发一个 Android 应用来使用该模型
请按照以下步骤使用 TensorFlow 库和我们先前构建的模型创建一个新的 Android 应用:
- 在 Android Studio 中,创建一个名为 QuickDraw 的新 Android 应用,接受所有默认设置。 然后在应用的
build.gradle
中,将compile files('libs/libandroid_tensorflow_inference_java.jar')
添加到依赖项的末尾。 像以前一样创建一个新的assets
文件夹,并将quickdraw_frozen_long_blacklist_strip_transformed.pb
和classes.txt
拖放到该文件夹中。 - 创建一个名为
QuickDrawView
的新 Java 类,该类扩展了View
,并如下设置字段及其构造器:
public class QuickDrawView extends View {
private Path mPath;
private Paint mPaint, mCanvasPaint;
private Canvas mCanvas;
private Bitmap mBitmap;
private MainActivity mActivity;
private List<List<Pair<Float, Float>>> mAllPoints = new ArrayList<List<Pair<Float, Float>>>();
private List<Pair<Float, Float>> mConsecutivePoints = new ArrayList<Pair<Float, Float>>();
public QuickDrawView(Context context, AttributeSet attrs) {
super(context, attrs);
mActivity = (MainActivity) context;
setPathPaint();
}
mAllPoints
用于保存mConsecutivePoints
的列表。 QuickDrawView
用于主要活动的布局中,以显示用户的绘画。
- 如下定义
setPathPaint
方法:
private void setPathPaint() {
mPath = new Path();
mPaint = new Paint();
mPaint.setColor(0xFF000000);
mPaint.setAntiAlias(true);
mPaint.setStrokeWidth(18);
mPaint.setStyle(Paint.Style.STROKE);
mPaint.setStrokeJoin(Paint.Join.ROUND);
mCanvasPaint = new Paint(Paint.DITHER_FLAG);
}
添加两个实例化Bitmap
和Canvas
对象并向用户显示在画布上绘画的重写方法:
@Override protected void onSizeChanged(int w, int h, int oldw, int oldh) {
super.onSizeChanged(w, h, oldw, oldh);
mBitmap = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888);
mCanvas = new Canvas(mBitmap);
}
@Override protected void onDraw(Canvas canvas) {
canvas.drawBitmap(mBitmap, 0, 0, mCanvasPaint);
canvas.drawPath(mPath, mPaint);
}
- 覆盖方法
onTouchEvent
用于填充mConsecutivePoints
和mAllPoints
,调用画布的drawPath
方法,使图无效(以调用onDraw
方法),以及(每次使用MotionEvent.ACTION_UP
完成笔划线),以启动一个新线程以使用模型对绘画进行分类:
@Override
public boolean onTouchEvent(MotionEvent event) {
if (!mActivity.canDraw()) return true;
float x = event.getX();
float y = event.getY();
switch (event.getAction()) {
case MotionEvent.ACTION_DOWN:
mConsecutivePoints.clear();
mConsecutivePoints.add(new Pair(x, y));
mPath.moveTo(x, y);
break;
case MotionEvent.ACTION_MOVE:
mConsecutivePoints.add(new Pair(x, y));
mPath.lineTo(x, y);
break;
case MotionEvent.ACTION_UP:
mConsecutivePoints.add(new Pair(x, y));
mAllPoints.add(new ArrayList<Pair<Float, Float>>
(mConsecutivePoints));
mCanvas.drawPath(mPath, mPaint);
mPath.reset();
Thread thread = new Thread(mActivity);
thread.start();
break;
default:
return false;
}
invalidate();
return true;
}
- 定义两个将由
MainActivity
调用的公共方法,以获取所有点并在用户点击重新启动按钮后重置绘画:
public List<List<Pair<Float, Float>>> getAllPoints() {
return mAllPoints;
}
public void clearAllPointsAndRedraw() {
mBitmap = Bitmap.createBitmap(mBitmap.getWidth(),
mBitmap.getHeight(), Bitmap.Config.ARGB_8888);
mCanvas = new Canvas(mBitmap);
mCanvasPaint = new Paint(Paint.DITHER_FLAG);
mCanvas.drawBitmap(mBitmap, 0, 0, mCanvasPaint);
setPathPaint();
invalidate();
mAllPoints.clear();
}
- 现在打开
MainActivity
,并使其实现Runnable
及其字段,如下所示:
public class MainActivity extends AppCompatActivity implements Runnable {
private static final String MODEL_FILE = "file:///android_asset/quickdraw_frozen_long_blacklist_strip_transformed.pb";
private static final String CLASSES_FILE = "file:///android_asset/classes.txt";
private static final String INPUT_NODE1 = "Reshape";
private static final String INPUT_NODE2 = "Squeeze";
private static final String OUTPUT_NODE1 = "dense/BiasAdd";
private static final String OUTPUT_NODE2 = "ArgMax";
private static final int CLASSES_COUNT = 345;
private static final int BATCH_SIZE = 8;
private String[] mClasses = new String[CLASSES_COUNT];
private QuickDrawView mDrawView;
private Button mButton;
private TextView mTextView;
private String mResult = "";
private boolean mCanDraw = false;
private TensorFlowInferenceInterface mInferenceInterface;
- 在主布局文件
activity_main.xml
中,除了我们之前所做的TextView
和Button
之外,还创建一个QuickDrawView
元素:
<com.ailabby.quickdraw.QuickDrawView
android:id="@ id/drawview"
android:layout_width="fill_parent"
android:layout_height="fill_parent"
app:layout_constraintBottom_toBottomOf="parent"
app:layout_constraintLeft_toLeftOf="parent"
app:layout_constraintRight_toRightOf="parent"
app:layout_constraintTop_toTopOf="parent"/>
- 返回
MainActivity
; 在其onCreate
方法中,将 UI 元素 ID 与字段绑定,为启动/重启按钮设置点击监听器。 然后将classes.txt
文件读入字符串数组:
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
mDrawView = findViewById(R.id.drawview);
mButton = findViewById(R.id.button);
mTextView = findViewById(R.id.textview);
mButton.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View v) {
mCanDraw = true;
mButton.setText("Restart");
mTextView.setText("");
mDrawView.clearAllPointsAndRedraw();
}
});
String classesFilename = CLASSES_FILE.split("file:///android_asset/")[1];
BufferedReader br = null;
int linenum = 0;
try {
br = new BufferedReader(new InputStreamReader(getAssets().open(classesFilename)));
String line;
while ((line = br.readLine()) != null) {
mClasses[linenum ] = line;
}
br.close();
} catch (IOException e) {
throw new RuntimeException("Problem reading classes file!" , e);
}
}
- 然后从线程的
run
方法中调用同步方法classifyDrawing
:
public void run() {
classifyDrawing();
}
private synchronized void classifyDrawing() {
try {
double normalized_points[] = normalizeScreenCoordinates();
long total_points = normalized_points.length / 3;
float[] floatValues = new float[normalized_points.length*BATCH_SIZE];
for (int i=0; i<normalized_points.length; i ) {
for (int j=0; j<BATCH_SIZE; j )
floatValues[j*normalized_points.length i] = (float)normalized_points[i];
}
long[] seqlen = new long[BATCH_SIZE];
for (int i=0; i<BATCH_SIZE; i )
seqlen[i] = total_points;
即将实现的normalizeScreenCoordinates
方法将用户绘画点转换为模型期望的格式。 floatValues
和seqlen
将被输入模型。 请注意,由于模型需要这些确切的数据类型(float
和int64
),因此我们必须在floatValues
中使用float
在seqlen
中使用long
,否则在使用模型时会发生运行时错误。
- 创建一个与 TensorFlow 库的 Java 接口以加载模型,向模型提供输入并获取输出:
AssetManager assetManager = getAssets();
mInferenceInterface = new TensorFlowInferenceInterface(assetManager, MODEL_FILE);
mInferenceInterface.feed(INPUT_NODE1, floatValues, BATCH_SIZE, total_points, 3);
mInferenceInterface.feed(INPUT_NODE2, seqlen, BATCH_SIZE);
float[] logits = new float[CLASSES_COUNT * BATCH_SIZE];
float[] argmax = new float[CLASSES_COUNT * BATCH_SIZE];
mInferenceInterface.run(new String[] {OUTPUT_NODE1, OUTPUT_NODE2}, false);
mInferenceInterface.fetch(OUTPUT_NODE1, logits);
mInferenceInterface.fetch(OUTPUT_NODE1, argmax);
- 归一化所提取的
logits
概率并以降序对其进行排序:
double sum = 0.0;
for (int i=0; i<CLASSES_COUNT; i )
sum = Math.exp(logits[i]);
List<Pair<Integer, Float>> prob_idx = new ArrayList<Pair<Integer, Float>>();
for (int j = 0; j < CLASSES_COUNT; j ) {
prob_idx.add(new Pair(j, (float)(Math.exp(logits[j]) / sum) ));
}
Collections.sort(prob_idx, new Comparator<Pair<Integer, Float>>() {
@Override
public int compare(final Pair<Integer, Float> o1, final Pair<Integer, Float> o2) {
return o1.second > o2.second ? -1 : (o1.second == o2.second ? 0 : 1);
}
});
获取前五个结果并将其显示在TextView
中:
mResult = "";
for (int i=0; i<5; i ) {
if (prob_idx.get(i).second > 0.1) {
if (mResult == "") mResult = "" mClasses[prob_idx.get(i).first];
else mResult = mResult ", " mClasses[prob_idx.get(i).first];
}
}
runOnUiThread(
new Runnable() {
@Override
public void run() {
mTextView.setText(mResult);
}
});
- 最后,实现
normalizeScreenCoordinates
方法,它是 iOS 实现的便捷端口:
private double[] normalizeScreenCoordinates() {
List<List<Pair<Float, Float>>> allPoints = mDrawView.getAllPoints();
int total_points = 0;
for (List<Pair<Float, Float>> cp : allPoints) {
total_points = cp.size();
}
double[] normalized = new double[total_points * 3];
float lowerx=Float.MAX_VALUE, lowery=Float.MAX_VALUE, upperx=-Float.MAX_VALUE, uppery=-Float.MAX_VALUE;
for (List<Pair<Float, Float>> cp : allPoints) {
for (Pair<Float, Float> p : cp) {
if (p.first < lowerx) lowerx = p.first;
if (p.second < lowery) lowery = p.second;
if (p.first > upperx) upperx = p.first;
if (p.second > uppery) uppery = p.second;
}
}
float scalex = upperx - lowerx;
float scaley = uppery - lowery;
int n = 0;
for (List<Pair<Float, Float>> cp : allPoints) {
int m = 0;
for (Pair<Float, Float> p : cp) {
normalized[n*3] = (p.first - lowerx) / scalex;
normalized[n*3 1] = (p.second - lowery) / scaley;
normalized[n*3 2] = (m ==cp.size()-1 ? 1 : 0);
n ; m ;
}
}
for (int i=0; i<n-1; i ) {
normalized[i*3] = normalized[(i 1)*3] - normalized[i*3];
normalized[i*3 1] = normalized[(i 1)*3 1] -
normalized[i*3 1];
normalized[i*3 2] = normalized[(i 1)*3 2];
}
return normalized;
}
在您的 Android 模拟器或设备上运行该应用,并享受分类结果的乐趣。 您应该看到类似图 7.11 的内容:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-o4rePOTz-1681653119035)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/560fa3a2-94a9-4f93-ae05-bd09921b8e0c.png)]
图 7.11:在 Android 上显示绘画和分类结果
既然您已经了解了训练 Quick Draw 模型的全过程,并在 iOS 和 Android 应用中使用了它,那么您当然可以微调训练方法,使其更加准确,并改善移动应用的乐趣。
在本章我们不得不结束有趣旅程之前的最后一个提示是,如果您使用错误的 ABI 构建适用于 Android 的 TensorFlow 本机库,您仍然可以从 Android Studio 构建和运行该应用,但将出现运行时错误java.lang.RuntimeException: Native TF methods not found; check that the correct native libraries are present in the APK.
,这意味着您的应用的jniLibs
文件夹中没有正确的 TensorFlow 本机库(图 7.9)。 要找出jniLibs
内特定 ABI 文件夹中是否缺少该文件,可以从Android Studio | View | Tool Windows
中打开Device File Explorer
,然后选择设备的data | app | package | lib
来查看,如图 7.12 所示。 如果您更喜欢命令行,则也可以使用adb
工具找出来。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-0X6N6LxN-1681653119035)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/16494075-d42a-4b86-a48a-4cb0bc2ec865.png)]
图 7.12:使用设备文件资源管理器检出 TensorFlow 本机库文件
总结
在本章中,我们首先描述了绘画分类模型的工作原理,然后介绍了如何使用高级 TensorFlow Estimator API 训练这种模型。 我们研究了如何编写 Python 代码以使用经过训练的模型进行预测,然后详细讨论了如何找到正确的输入和输出节点名称以及如何以正确的方式冻结和转换模型以使移动应用可以使用它。 我们还提供了一种新方法来构建新的 TensorFlow 自定义 iOS 库,并提供了一个逐步教程,以构建适用于 Android 的 TensorFlow 自定义库,以修复使用模型时的运行时错误。 最后,我们展示了 iOS 和 Android 代码,这些代码捕获并显示用户绘画,将其转换为模型所需的数据,并处理和呈现模型返回的分类结果。 希望您在漫长的旅途中学到了很多东西。
到目前为止,除了来自其他开放源代码项目的几个模型以外,所有由我们自己进行预训练或训练的模型,我们在 iOS 和 Android 应用中使用的都是 TensorFlow 开放源代码项目,当然,该项目提供了大量强大的模型,其中一些模型在强大的 GPU 上进行了数周的训练。 但是,如果您有兴趣从头开始构建自己的模型,并且还对本章中使用和应用的强大 RNN 模型以及概念感到困惑,那么下一章就是您所需要的:我们将讨论如何从头开始构建自己的 RNN 模型并在移动应用中使用它,从而带来另一种乐趣-从股市中赚钱-至少我们会尽力做到这一点。 当然,没有人能保证您每次都能从每次股票交易中获利,但是至少让我们看看我们的 RNN 模型如何帮助我们提高这样做的机会。
八、用 RNN 预测股价
如果在上一章中在移动设备上玩过涂鸦和构建(并运行模型以识别涂鸦),当您在股市上赚钱时会感到很开心,而如果您不认真的话会变得很认真。 一方面,股价是时间序列数据,一系列离散时间数据,而处理时间序列数据的最佳深度学习方法是 RNN,这是我们在前两章中使用的方法。 AurélienGéron 在他的畅销书《Scikit-Learn 和 TensorFlow 机器学习实战》中,建议使用 RNN“分析时间序列数据,例如股票价格,并告诉您何时买卖”。 另一方面,其他人则认为股票的过去表现无法预测其未来收益,因此,随机选择的投资组合的表现与专家精心挑选的股票一样好。 实际上,Keras(在 TensorFlow 和其他几个库之上运行的非常受欢迎的高级深度学习库)的作者 FrançoisChollet 在他的畅销书《Python 深度学习》中表示,使用 RNN。 仅用公开数据来击败市场是“一项非常困难的努力,您可能会浪费时间和资源,而无所作为。”
因此,冒着“可能”浪费我们时间和资源的风险,但是可以肯定的是,我们至少将了解更多有关 RNN 的知识,以及为什么有可能比随机 50% 的策略更好地预测股价,我们将首先概述如何使用 RNN 进行股票价格预测,然后讨论如何使用 TensorFlow API 构建 RNN 模型来预测股票价格,以及如何使用易于使用的 Keras API 来为价格预测构建 RNN LSTM 模型。 我们将测试这些模型是否可以击败随机的买入或卖出策略。 如果我们对我们的模型感到满意,以提高我们在市场上的领先优势,或者只是出于专有技术的目的,我们将了解如何冻结并准备 TensorFlow 和 Keras 模型以在 iOS 和 Android 应用上运行。 如果该模型可以提高我们的机会,那么我们支持该模型的移动应用可以在任何时候,无论何时何地做出买或卖决定。 感觉有点不确定和兴奋? 欢迎来到市场。
总之,本章将涵盖以下主题:
- RNN 和股价预测:什么以及如何
- 使用 TensorFlow RNN API 进行股价预测
- 使用 Keras RNN LSTM API 进行股价预测
- 在 iOS 上运行 TensorFlow 和 Keras 模型
- 在 Android 上运行 TensorFlow 和 Keras 模型
RNN 和股价预测 – 什么以及如何
前馈网络(例如密集连接的网络)没有内存,无法将每个输入视为一个整体。 例如,表示为像素向量的图像输入在单个步骤中由前馈网络处理。 但是,使用具有内存的网络可以更好地处理时间序列数据,例如最近 10 或 20 天的股价。 假设过去 10 天的价格为X1, X2, ..., X10
,其中X1
为最早的和X10
为最晚,然后将所有 10 天价格视为一个序列输入,并且当 RNN 处理此类输入时,将发生以下步骤:
- 按顺序连接到第一个元素
X1
的特定 RNN 单元处理X1
并获取其输出y1
- 在序列输入中,连接到下一个元素
X2
的另一个 RNN 单元使用X2
以及先前的输出y1
, 获得下一个输出y2
- 重复该过程:在时间步长使用 RNN 单元处理输入序列中的
Xi
元素时,先前的输出y[i-1]
,在时间步i-1
与Xi
一起使用,以在时间步i
生成新的输出yi
。
因此,在时间步长i
的每个yi
输出,都具有有关输入序列中直到时间步长i
以及包括时间步长i
的所有元素的信息:X1, X2, ..., X[i-1]
和Xi
。 在 RNN 训练期间,预测价格y1, y2, ..., y9
和y10
的每个时间步长与每个时间步长的真实目标价格进行比较,即X2, X3, ..., X10
和X11
和损失函数因此被定义并用于优化以更新网络参数。 训练完成后,在预测期间,将X11
用作输入序列的预测,X1, X2, ..., X10
。
这就是为什么我们说 RNN 有内存。 RNN 对于处理股票价格数据似乎很有意义,因为直觉是,今天(以及明天和后天等等)的股票价格可能会受其前N
天的价格影响。
LSTM 只是解决 RNN 已知梯度消失问题的一种 RNN,我们在第 6 章,“用自然语言描述图像”中引入了 LSTM。 基本上,在训练 RNN 模型的过程中,,如果到 RNN 的输入序列的时间步太长,则使用反向传播更新较早时间步的网络权重可能会得到 0 的梯度值, 导致没有学习发生。 例如,当我们使用 50 天的价格作为输入,并且如果使用 50 天甚至 40 天的时间步长变得太长,则常规 RNN 将是不可训练的。 LSTM 通过添加一个长期状态来解决此问题,该状态决定可以丢弃哪些信息以及需要在许多时间步骤中存储和携带哪些信息。
可以很好地解决梯度消失问题的另一种 RNN 被称为门控循环单元(GRU),它稍微简化了标准 LSTM 模型,并且越来越受欢迎。 TensorFlow 和 Keras API 均支持基本的 RNN 和 LSTM/GRU 模型。 在接下来的两部分中,您将看到使用 RNN 和标准 LSTM 的具体 TensorFlow 和 Keras API,并且可以在代码中简单地将LSTM
替换为GRU
,以将使用 GRU 模型的结果与 RNN 和标准 LSTM 模型比较。
三种常用技术可以使 LSTM 模型表现更好:
- 堆叠 LSTM 层并增加层中神经元的数量:如果不产生过拟合,通常这将导致功能更强大,更准确的网络模型。 如果还没有,那么您绝对应该玩 TensorFlow Playground来体验一下。
- 使用丢弃处理过拟合。 删除意味着随机删除层中的隐藏单元和输入单元。
- 使用双向 RNN 在两个方向(常规方向和反向方向)处理每个输入序列,希望检测出可能被常规单向 RNN 忽略的模式。
所有这些技术已经实现,并且可以在 TensorFlow 和 Keras API 中轻松访问。
那么,我们如何使用 RNN 和 LSTM 测试股价预测? 我们将在这个页面上使用免费的 API 收集特定股票代码的每日股票价格数据,将其解析为训练集和测试集,并每次向 RNN/LSTM 模型提供一批训练输入(每个训练输入有 20 个时间步长,即,连续 20 天的价格),对模型进行训练,然后进行测试以查看模型在测试数据集中的准确率。 我们将同时使用 TensorFlow 和 Keras API 进行测试,并比较常规 RNN 和 LSTM 模型之间的差异。 我们还将测试三个略有不同的序列输入和输出,看看哪个是最好的:
- 根据过去
N
天预测一天的价格 - 根据过去
N
天预测M
天的价格 - 基于将过去
N
天移动 1 并使用预测序列的最后输出作为第二天的预测价格进行预测
现在让我们深入研究 TensorFlow RNN API 并进行编码以训练模型来预测股票价格,以查看其准确率如何。
将 TensorFlow RNN API 用于股价预测
首先,您需要在这里索取免费的 API 密钥,以便获取任何股票代码的股价数据。 取得 API 密钥后,打开终端并运行以下命令(将<your_api_key>
替换为您自己的密钥后)以获取 Amazon(amzn)和 Google(goog)的每日股票数据,或将它们替换为你感兴趣的任何符号:
curl -o daily_amzn.csv "https://www.alphavantage.co/query?function=TIME_SERIES_DAILY&symbol=amzn&apikey=<your_api_key>&datatype=csv&outputsize=full"
curl -o daily_goog.csv "https://www.alphavantage.co/query?function=TIME_SERIES_DAILY&symbol=goog&apikey=<your_api_key>&datatype=csv&outputsize=full"
这将生成一个daily_amzn.csv
或daily_goog.csv
csv 文件 ,其顶行为“时间戳,开盘,高位,低位,收盘,交易量”,这些行的其余部分作为每日股票信息。 我们只关心收盘价,因此运行以下命令以获取所有收盘价:
cut -d ',' -f 5 daily_amzn.csv | tail -n 2 > amzn.txt
cut -d ',' -f 5 daily_goog.csv | tail -n 2 > goog.txt
截至 2018 年 2 月 26 日,amzn.txt
或goog.txt
中的行数为 4,566 或 987,这是亚马逊或 Google 的交易天数。 现在,让我们看一下使用 TensorFlow RNN API 训练和预测模型的完整 Python 代码。
在 TensorFlow 中训练 RNN 模型
- 导入所需的 Python 包并定义一些常量:
import numpy as np
import tensorflow as tf
from tensorflow.contrib.rnn import *
import matplotlib.pyplot as plt
num_neurons = 100
num_inputs = 1
num_outputs = 1
symbol = 'goog' # amzn
epochs = 500
seq_len = 20
learning_rate = 0.001
NumPy 是用于 N 维数组操作的最受欢迎的 Python 库,而 Matplotlib 是领先的 Python 2D 绘图库。 我们将使用 numpy 处理数据集,并使用 Matplotlib 可视化股票价格和预测。 num_neurons
是 RNN(或更准确地说是 RNN 单元)在每个时间步长上的神经元数量-每个神经元在该时间步长上都接收输入序列的输入元素,并从前一个时间步长上接收输出。 num_inputs
和num_outputs
指定每个时间步长的输入和输出数量-我们将从每个时间步长的 20 天输入序列中将一个股票价格提供给带有num_neurons
神经元的 RNN 单元,并在每个步骤期望一个预测的股票输出。 seq_len
是时间步数。 因此,我们将使用 Google 的 20 天股票价格作为输入序列,并将这些输入发送给具有 100 个神经元的 RNN 单元。
- 打开并读取包含所有价格的文本文件,将价格解析为
float
数字列表,颠倒列表顺序,以便最早的价格首先开始,然后每次添加seq_len 1
值(第一个seq_len
值将是 RNN 的输入序列,最后的seq_len
值将是目标输出序列),从列表中的第一个开始,每次移动 1 直到列表的末尾,直到一个 numpyresult
数组:
f = open(symbol '.txt', 'r').read()
data = f.split('n')[:-1] # get rid of the last '' so float(n) works
data.reverse()
d = [float(n) for n in data]
result = []
for i in range(len(d) - seq_len - 1):
result.append(d[i: i seq_len 1])
result = np.array(result)
result
数组现在包含我们模型的整个数据集,但是我们需要将其进一步处理为 RNN API 期望的格式。 首先,将其分为训练集(占整个数据集的 90%)和测试集(占 10%):
row = int(round(0.9 * result.shape[0]))
train = result[:row, :]
test = result[row:, :]
然后随机地随机排列训练集,作为机器学习模型训练中的标准做法:
代码语言:javascript复制np.random.shuffle(train)
制定训练集和测试集X_train
和X_test
的输入序列,以及训练集和测试集y_train
和y_test
的目标输出序列。 请注意,大写字母X
和小写字母y
是机器学习中常用的命名约定,分别代表输入和目标输出:
X_train = train[:, :-1] # all rows with all columns except the last one
X_test = test[:, :-1] # each row contains seq_len 1 columns
y_train = train[:, 1:]
y_test = test[:, 1:]
最后,将四个数组重塑为 3-D(批大小,时间步数以及输入或输出数),以完成训练和测试数据集的准备:
代码语言:javascript复制X_train = np.reshape(X_train, (X_train.shape[0], X_train.shape[1], num_inputs))
X_test = np.reshape(X_test, (X_test.shape[0], X_test.shape[1], num_inputs))
y_train = np.reshape(y_train, (y_train.shape[0], y_train.shape[1], num_outputs))
y_test = np.reshape(y_test, (y_test.shape[0], y_test.shape[1], num_outputs))
注意,X_train.shape[1]
,X_test.shape[1]
,y_train.shape[1]
和y_test.shape[1]
与seq_len
相同。
- 我们已经准备好构建模型。 创建两个占位符,以便在训练期间和
X_test
一起喂入X_train
和y_train
:
X = tf.placeholder(tf.float32, [None, seq_len, num_inputs])
y = tf.placeholder(tf.float32, [None, seq_len, num_outputs])
使用BasicRNNCell
创建一个 RNN 单元,每个时间步分别具有 num_neurons
神经元,:
cell = tf.contrib.rnn.OutputProjectionWrapper(
tf.contrib.rnn.BasicRNNCell(num_units=num_neurons, activation=tf.nn.relu), output_size=num_outputs)
outputs, _ = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)
OutputProjectionWrapper
用于在每个单元的输出之上添加一个完全连接的层,因此,在每个时间步长处,RNN 单元的输出(将是num_neurons
值的序列)都会减小为单个值。 这就是 RNN 在每个时间步为输入序列中的每个值输出一个值,或为每个实例的seq_len
个数的值的每个输入序列输出总计seq_len
个数的值的方式。
dynamic_rnn
用于循环所有时间步长的 RNN 信元,总和为seq_len
(在X
形状中定义),它返回两个值:每个时间步长的输出列表,以及网络的最终状态。 接下来,我们将使用第一个outputs
返回的整形值来定义损失函数。
- 通过以标准方式指定预测张量,损失,优化器和训练操作来完成模型定义:
preds = tf.reshape(outputs, [1, seq_len], name="preds")
loss = tf.reduce_mean(tf.square(outputs - y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
training_op = optimizer.minimize(loss)
请注意,当我们使用freeze_graph
工具准备要在移动设备上部署的模型时,"preds"
将用作输出节点名称,它也将在 iOS 和 Android 中用于运行模型进行预测。 如您所见,在我们甚至开始训练模型之前一定要知道那条信息,这绝对是一件很高兴的事情,而这是我们从头开始构建的模型的好处。
- 开始训练过程。 对于每个周期,我们将
X_train
和y_train
数据输入以运行training_op
以最小化loss
,然后保存模型检查点文件,并每 10 个周期打印损失值:
init = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
init.run()
count = 0
for _ in range(epochs):
n=0
sess.run(training_op, feed_dict={X: X_train, y: y_train})
count = 1
if count % 10 == 0:
saver.save(sess, "/tmp/" symbol "_model.ckpt")
loss_val = loss.eval(feed_dict={X: X_train, y: y_train})
print(count, "loss:", loss_val)
如果您运行上面的代码,您将看到如下输出:
代码语言:javascript复制(10, 'loss:', 243802.61)
(20, 'loss:', 80629.57)
(30, 'loss:', 40018.996)
(40, 'loss:', 28197.496)
(50, 'loss:', 24306.758)
...
(460, 'loss:', 93.095985)
(470, 'loss:', 92.864082)
(480, 'loss:', 92.33461)
(490, 'loss:', 92.09893)
(500, 'loss:', 91.966286)
您可以在第 4 步中用BasicLSTMCell
替换BasicRNNCell
并运行训练代码,但是使用BasicLSTMCell
进行训练要慢得多,并且在 500 个周期之后损失值仍然很大。 在本节中,我们将不再对BasicLSTMCell
进行实验,但是为了进行比较,在使用 Keras 的下一部分中,您将看到堆叠 LSTM 层,丢弃法和双向 RNN 的详细用法。
测试 TensorFlow RNN 模型
要查看 500 个周期后的损失值是否足够好,让我们使用测试数据集添加以下代码,以计算总测试示例中正确预测的数量(正确的意思是,预测价格在目标价格的同一个方向上上下波动,相对于前一天的价格):
代码语言:javascript复制 correct = 0
y_pred = sess.run(outputs, feed_dict={X: X_test})
targets = []
predictions = []
for i in range(y_pred.shape[0]):
input = X_test[i]
target = y_test[i]
prediction = y_pred[i]
targets.append(target[-1][0])
predictions.append(prediction[-1][0])
if target[-1][0] >= input[-1][0] and prediction[-1][0] >=
input[-1][0]:
correct = 1
elif target[-1][0] < input[-1][0] and prediction[-1][0] <
input[-1][0]:
correct = 1
现在我们可以使用plot
方法可视化预测正确率:
total = len(X_test)
xs = [i for i, _ in enumerate(y_test)]
plt.plot(xs, predictions, 'r-', label='prediction')
plt.plot(xs, targets, 'b-', label='true')
plt.legend(loc=0)
plt.title("%s - %d/%d=%.2f%%" %(symbol, correct, total,
100*float(correct)/total))
plt.show()
现在运行代码将显示如图 8.1 所示,正确预测的比率为 56.25% :
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Q8Ngho3K-1681653119035)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/71f57975-9e80-4886-a13f-7a38b52dc84d.png)]
图 8.1:显示使用 TensorFlow RNN 训练的股价预测正确性
注意,每次运行此训练和测试代码时,您获得的比率可能都会有所不同。 通过微调模型的超参数,您可能会获得超过 60% 的比率,这似乎比随机预测要好。 如果您乐观的话,您可能会认为至少有 50% (56.25%)的东西要显示出来,并且可能希望看到该模型在移动设备上运行。 但首先让我们看看是否可以使用酷的 Keras 库来构建更好的模型-在执行此操作之前,让我们通过简单地运行来冻结经过训练的 TensorFlow 模型:
代码语言:javascript复制python tensorflow/python/tools/freeze_graph.py --input_meta_graph=/tmp/amzn_model.ckpt.meta --input_checkpoint=/tmp/amzn_model.ckpt --output_graph=/tmp/amzn_tf_frozen.pb --output_node_names="preds" --input_binary=true
将 Keras RNN LSTM API 用于股价预测
Keras 是一个非常易于使用的高级深度学习 Python 库,它运行在 TensorFlow,Theano 和 CNTK 等其他流行的深度学习库之上。 您很快就会看到,Keras 使构建和使用模型变得更加容易。 要安装和使用 Keras 以及 TensorFlow 作为 Keras 的后端,最好首先设置一个 VirtualEnv:
代码语言:javascript复制sudo pip install virtualenv
如果您的机器和 iOS 和 Android 应用上都有 TensorFlow 1.4 源,请运行以下命令;否则,请运行以下命令。 使用 TensorFlow 1.4 自定义库:
代码语言:javascript复制cd
mkdir ~/tf14_keras
virtualenv --system-site-packages ~/tf14_keras/
cd ~/tf14_keras/
source ./bin/activate
easy_install -U pip
pip install --upgrade https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.4.0-py2-none-any.whl
pip install keras
如果您的机器上装有 TensorFlow 1.5 源,则应在 Keras 上安装 TensorFlow 1.5,因为使用 Keras 创建的模型需要具有与 TensorFlow 移动应用所使用的模型相同的 TensorFlow 版本,或者在尝试加载模型时发生错误:
代码语言:javascript复制cd
mkdir ~/tf15_keras
virtualenv --system-site-packages ~/tf15_keras/
cd ~/tf15_keras/
source ./bin/activate
easy_install -U pip
pip install --upgrade https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.5.0-py2-none-any.whl
pip install keras
如果您的操作系统不是 Mac 或计算机具有 GPU,则您需要用正确的 URL 替换 TensorFlow Python 包 URL,您可以在这个页面上找到它。
在 Keras 中训练 RNN 模型
现在,让我们看看在 Keras 中建立和训练 LSTM 模型以预测股价的过程。 首先,一些导入和常量设置:
代码语言:javascript复制import keras
from keras import backend as K
from keras.layers.core import Dense, Activation, Dropout
from keras.layers.recurrent import LSTM
from keras.layers import Bidirectional
from keras.models import Sequential
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
symbol = 'amzn'
epochs = 10
num_neurons = 100
seq_len = 20
pred_len = 1
shift_pred = False
shift_pred
用于指示我们是否要预测价格的输出序列而不是单个输出价格。 如果是True
,我们将根据输入X1, X2, ..., Xn
来预测X2, X3, ..., X[n 1]
,就像我们在使用 TensorFlow API 的最后一部分中所做的那样。 如果shift_pred
为False
,我们将基于输入X1, X2, ..., Xn
来预测输出的pred_len
。 例如,如果pred_len
为 1,我们将预测X[n 1]
,如果pred_len
为 3,我们将预测X[n 1], X[n 2], X[n 3]
,这很有意义,因为我们很想知道价格是连续连续 3 天上涨还是仅上涨 1 天然后下降 2 天。
现在,让我们创建一个根据上一节中的数据加载代码进行修改的方法,该方法根据pred_len
和shift_pred
设置准备适当的训练和测试数据集:
def load_data(filename, seq_len, pred_len, shift_pred):
f = open(filename, 'r').read()
data = f.split('n')[:-1] # get rid of the last '' so float(n) works
data.reverse()
d = [float(n) for n in data]
lower = np.min(d)
upper = np.max(d)
scale = upper-lower
normalized_d = [(x-lower)/scale for x in d]
result = []
if shift_pred:
pred_len = 1
for i in range((len(normalized_d) - seq_len - pred_len)/pred_len):
result.append(normalized_d[i*pred_len: i*pred_len seq_len pred_len])
result = np.array(result)
row = int(round(0.9 * result.shape[0]))
train = result[:row, :]
test = result[row:, :]
np.random.shuffle(train)
X_train = train[:, :-pred_len]
X_test = test[:, :-pred_len]
if shift_pred:
y_train = train[:, 1:]
y_test = test[:, 1:]
else:
y_train = train[:, -pred_len:]
y_test = test[:, -pred_len:]
X_train = np.reshape(X_train, (X_train.shape[0], X_train.shape[1],
1))
X_test = np.reshape(X_test, (X_test.shape[0], X_test.shape[1], 1))
return [X_train, y_train, X_test, y_test, lower, scale]
注意,在这里我们也使用归一化,使用与上一章相同的归一化方法,以查看它是否可以改善我们的模型。 当使用训练模型进行预测时,我们还返回lower
和scale
值,这是非规范化所需的值。
现在我们可以调用load_data
来获取训练和测试数据集,以及lower
和scale
值:
X_train, y_train, X_test, y_test, lower, scale = load_data(symbol '.txt', seq_len, pred_len, shift_pred)
完整的模型构建代码如下:
代码语言:javascript复制model = Sequential()
model.add(Bidirectional(LSTM(num_neurons, return_sequences=True, input_shape=(None, 1)), input_shape=(seq_len, 1)))
model.add(Dropout(0.2))
model.add(LSTM(num_neurons, return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(num_neurons, return_sequences=False))
model.add(Dropout(0.2))
if shift_pred:
model.add(Dense(units=seq_len))
else:
model.add(Dense(units=pred_len))
model.add(Activation('linear'))
model.compile(loss='mse', optimizer='rmsprop')
model.fit(
X_train,
y_train,
batch_size=512,
epochs=epochs,
validation_split=0.05)
print(model.output.op.name)
print(model.input.op.name)
即使使用新添加的Bidirectional
,Dropout
,validation_split
和堆叠 LSTM 层,该代码也比 TensorFlow 中的模型构建代码更容易解释和简化。 请注意,LSTM 调用中的return_sequences
参数i
必须为True
,因此 LSTM 单元的输出将是完整的输出序列,而不仅仅是输出序列中的最后一个输出, 除非它是最后的堆叠层。 最后两个 print
语句将打印输入节点名称( bidirectional_1_input
)和输出节点名称(activation_1/Identity
),当我们冻结模型并在移动设备上运行模型时需要。
现在,如果您运行前面的代码,您将看到如下输出:
代码语言:javascript复制824/824 [==============================] - 7s 9ms/step - loss: 0.0833 - val_loss: 0.3831
Epoch 2/10
824/824 [==============================] - 2s 3ms/step - loss: 0.2546 - val_loss: 0.0308
Epoch 3/10
824/824 [==============================] - 2s 2ms/step - loss: 0.0258 - val_loss: 0.0098
Epoch 4/10
824/824 [==============================] - 2s 2ms/step - loss: 0.0085 - val_loss: 0.0035
Epoch 5/10
824/824 [==============================] - 2s 2ms/step - loss: 0.0044 - val_loss: 0.0026
Epoch 6/10
824/824 [==============================] - 2s 2ms/step - loss: 0.0038 - val_loss: 0.0022
Epoch 7/10
824/824 [==============================] - 2s 2ms/step - loss: 0.0033 - val_loss: 0.0019
Epoch 8/10
824/824 [==============================] - 2s 2ms/step - loss: 0.0030 - val_loss: 0.0019
Epoch 9/10
824/824 [==============================] - 2s 2ms/step - loss: 0.0028 - val_loss: 0.0017
Epoch 10/10
824/824 [==============================] - 2s 3ms/step - loss: 0.0027 - val_loss: 0.0019
训练损失和验证损失都可以通过简单调用model.fit
进行打印。
测试 Keras RNN 模型
现在该保存模型检查点并使用测试数据集来计算正确预测的数量,正如我们在上一节中所解释的那样:
代码语言:javascript复制saver = tf.train.Saver()
saver.save(K.get_session(), '/tmp/keras_' symbol '.ckpt')
predictions = []
correct = 0
total = pred_len*len(X_test)
for i in range(len(X_test)):
input = X_test[i]
y_pred = model.predict(input.reshape(1, seq_len, 1))
predictions.append(scale * y_pred[0][-1] lower)
if shift_pred:
if y_test[i][-1] >= input[-1][0] and y_pred[0][-1] >= input[-1]
[0]:
correct = 1
elif y_test[i][-1] < input[-1][0] and y_pred[0][-1] < input[-1][0]:
correct = 1
else:
for j in range(len(y_test[i])):
if y_test[i][j] >= input[-1][0] and y_pred[0][j] >= input[-1][0]:
correct = 1
elif y_test[i][j] < input[-1][0] and y_pred[0][j] < input[-1][0]:
correct = 1
我们主要调用model.predict
来获取X_test
中每个实例的预测,并将其与真实值和前一天的价格一起使用,以查看在方向方面是否为正确的预测。 最后,让我们根据测试数据集和预测来绘制真实价格:
y_test = scale * y_test lower
y_test = y_test[:, -1]
xs = [i for i, _ in enumerate(y_test)]
plt.plot(xs, y_test, 'g-', label='true')
plt.plot(xs, predictions, 'r-', label='prediction')
plt.legend(loc=0)
if shift_pred:
plt.title("%s - epochs=%d, shift_pred=True, seq_len=%d: %d/%d=%.2f%%" %(symbol, epochs, seq_len, correct, total, 100*float(correct)/total))
else:
plt.title("%s - epochs=%d, lens=%d,%d: %d/%d=%.2f%%" %(symbol, epochs, seq_len, pred_len, correct, total, 100*float(correct)/total))
plt.show()
您会看到类似图 8.2 的内容:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-tZ7FK7s8-1681653119036)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/b8227667-1d0d-4ea7-bf0b-fa7dac192064.png)]
图 8.2:使用 Keras 双向和堆叠 LSTM 层进行股价预测
很容易在栈中添加更多 LSTM 层,或者使用诸如学习率和丢弃率以及许多恒定设置之类的超参数。 但是,对于使用pred_len
和shift_pred
的不同设置,正确率的差异还没有发现。 也许我们现在应该对接近 60% 的正确率感到满意,并看看如何在 iOS 和 Android 上使用 TensorFlow 和 Keras 训练的模型-我们可以在以后继续尝试改进模型,但是,了解使用 TensorFlow 和 Keras 训练的 RNN 模型是否会遇到任何问题将非常有价值。
正如 FrançoisChollet 指出的那样,“深度学习更多的是艺术而不是科学……每个问题都是独特的,您将不得不尝试并经验地评估不同的策略。目前尚无理论可以提前准确地告诉您应该做什么。 以最佳方式解决问题。您必须尝试并进行迭代。” 希望我们为您使用 TensorFlow 和 Keras API 改善股票价格预测模型提供了一个很好的起点。
本节中最后要做的就是从检查点冻结 Keras 模型-因为我们在虚拟环境中安装了 TensorFlow 和 Keras,而 TensorFlow 是 VirtualEnv 中唯一安装并受支持的深度学习库,Keras 使用 TensorFlow 后端,并通过saver.save(K.get_session(), '/tmp/keras_' symbol '.ckpt')
调用以 TensorFlow 格式生成检查点。 现在运行以下命令冻结检查点(回想我们在训练期间从print(model.input.op.name)
获得output_node_name
):
python tensorflow/python/tools/freeze_graph.py --input_meta_graph=/tmp/keras_amzn.ckpt.meta --input_checkpoint=/tmp/keras_amzn.ckpt --output_graph=/tmp/amzn_keras_frozen.pb --output_node_names="activation_1/Identity" --input_binary=true
因为我们的模型非常简单明了,所以我们将直接在移动设备上尝试这两个冻结的模型,而无需像前两章中那样使用transform_graph
工具。
在 iOS 上运行 TensorFlow 和 Keras 模型
我们不会通过重复项目设置步骤来烦您-只需按照我们之前的操作即可创建一个名为 StockPrice 的新 Objective-C 项目,该项目将使用手动构建的 TensorFlow 库(请参阅第 7 章,“使用 CNN 和 LSTM 识别绘画”的 iOS 部分(如果需要详细信息)。 然后将两个模型文件amzn_tf_frozen.pb
和amzn_keras_frozen.pb
添加到项目中,您应该在 Xcode 中拥有 StockPrice 项目,如图 8.3 所示:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-VZGO2jEI-1681653119036)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/454020d6-1f32-480d-b159-b474b6878540.png)]
图 8.3:在 Xcode 中使用 TensorFlow 和 Keras 训练的模型的 iOS 应用
在ViewController.mm
中,我们将首先声明一些变量和一个常量:
unique_ptr<tensorflow::Session> tf_session;
UITextView *_tv;
UIButton *_btn;
NSMutableArray *_closeprices;
const int SEQ_LEN = 20;
然后创建一个按钮点击处理器,以使用户可以选择 TensorFlow 或 Keras 模型(该按钮在viewDidLoad
方法中像以前一样创建):
- (IBAction)btnTapped:(id)sender {
UIAlertAction* tf = [UIAlertAction actionWithTitle:@"Use TensorFlow Model" style:UIAlertActionStyleDefault handler:^(UIAlertAction * action) {
[self getLatestData:NO];
}];
UIAlertAction* keras = [UIAlertAction actionWithTitle:@"Use Keras Model" style:UIAlertActionStyleDefault handler:^(UIAlertAction * action) {
[self getLatestData:YES];
}];
UIAlertAction* none = [UIAlertAction actionWithTitle:@"None" style:UIAlertActionStyleDefault handler:^(UIAlertAction * action) {}];
UIAlertController* alert = [UIAlertController alertControllerWithTitle:@"RNN Model Pick" message:nil preferredStyle:UIAlertControllerStyleAlert];
[alert addAction:tf];
[alert addAction:keras];
[alert addAction:none];
[self presentViewController:alert animated:YES completion:nil];
}
getLatestData
方法首先发出 URL 请求以获取紧凑型版本的 Alpha Vantage API,该 API 返回 Amazon 每日股票数据的最后 100 个数据点,然后解析结果并将最后 20 个收盘价保存在_closeprices
数组中:
-(void)getLatestData:(BOOL)useKerasModel {
NSURLSession *session = [NSURLSession sharedSession];
[[session dataTaskWithURL:[NSURL URLWithString:@"https://www.alphavantage.co/query?function=TIME_SERIES_DAILY&symbol=amzn&apikey=<your_api_key>&datatype=csv&outputsize=compact"]
completionHandler:^(NSData *data,
NSURLResponse *response,
NSError *error) {
NSString *stockinfo = [[NSString alloc] initWithData:data encoding:NSASCIIStringEncoding];
NSArray *lines = [stockinfo componentsSeparatedByString:@"n"];
_closeprices = [NSMutableArray array];
for (int i=0; i<SEQ_LEN; i ) {
NSArray *items = [lines[i 1] componentsSeparatedByString:@","];
[_closeprices addObject:items[4]];
}
if (useKerasModel)
[self runKerasModel];
else
[self runTFModel];
}] resume];
}
runTFModel
方法定义如下:
- (void) runTFModel {
tensorflow::Status load_status;
load_status = LoadModel(@"amzn_tf_frozen", @"pb", &tf_session);
tensorflow::Tensor prices(tensorflow::DT_FLOAT,
tensorflow::TensorShape({1, SEQ_LEN, 1}));
auto prices_map = prices.tensor<float, 3>();
NSString *txt = @"Last 20 Days:n";
for (int i = 0; i < SEQ_LEN; i ){
prices_map(0,i,0) = [_closeprices[SEQ_LEN-i-1] floatValue];
txt = [NSString stringWithFormat:@"%@%@n", txt,
_closeprices[SEQ_LEN-i-1]];
}
std::vector<tensorflow::Tensor> output;
tensorflow::Status run_status = tf_session->Run({{"Placeholder",
prices}}, {"preds"}, {}, &output);
if (!run_status.ok()) {
LOG(ERROR) << "Running model failed:" << run_status;
}
else {
tensorflow::Tensor preds = output[0];
auto preds_map = preds.tensor<float, 2>();
txt = [NSString stringWithFormat:@"%@nPrediction with TF RNN
model:n%f", txt, preds_map(0,SEQ_LEN-1)];
dispatch_async(dispatch_get_main_queue(), ^{
[_tv setText:txt];
[_tv sizeToFit];
});
}
}
preds_map(0,SEQ_LEN-1)
是基于最近 20 天的第二天的预测价格; Placeholder
是“在 TensorFlow 中训练 RNN 模型”小节的第四步的X = tf.placeholder(tf.float32, [None, seq_len, num_inputs])
中定义的输入节点名称。 在模型生成预测后,我们将其与最近 20 天的价格一起显示在TextView
中。
runKeras
方法的定义与此类似,但具有反规范化以及不同的输入和输出节点名称。 由于我们的 Keras 模型经过训练只能输出一个预测价格,而不是一系列seq_len
价格,因此我们使用preds_map(0,0)
来获得预测:
- (void) runKerasModel {
tensorflow::Status load_status;
load_status = LoadModel(@"amzn_keras_frozen", @"pb", &tf_session);
if (!load_status.ok()) return;
tensorflow::Tensor prices(tensorflow::DT_FLOAT,
tensorflow::TensorShape({1, SEQ_LEN, 1}));
auto prices_map = prices.tensor<float, 3>();
float lower = 5.97;
float scale = 1479.37;
NSString *txt = @"Last 20 Days:n";
for (int i = 0; i < SEQ_LEN; i ){
prices_map(0,i,0) = ([_closeprices[SEQ_LEN-i-1] floatValue] -
lower)/scale;
txt = [NSString stringWithFormat:@"%@%@n", txt,
_closeprices[SEQ_LEN-i-1]];
}
std::vector<tensorflow::Tensor> output;
tensorflow::Status run_status = tf_session->Run({{"bidirectional_1_input", prices}}, {"activation_1/Identity"},
{}, &output);
if (!run_status.ok()) {
LOG(ERROR) << "Running model failed:" << run_status;
}
else {
tensorflow::Tensor preds = output[0];
auto preds_map = preds.tensor<float, 2>();
txt = [NSString stringWithFormat:@"%@nPrediction with Keras
RNN model:n%f", txt, scale * preds_map(0,0) lower];
dispatch_async(dispatch_get_main_queue(), ^{
[_tv setText:txt];
[_tv sizeToFit];
});
}
}
如果您现在运行该应用并点击Predict
按钮,您将看到模型选择消息(图 8.4):
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-VMADgb2i-1681653119036)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/067fa448-a497-4f13-944f-e976240694c9.png)]
图 8.4:选择 TensorFlow 或 Keras RNN 模型
如果选择 TensorFlow 模型,则可能会出现错误:
代码语言:javascript复制Could not create TensorFlow Graph: Invalid argument: No OpKernel was registered to support Op 'Less' with these attrs. Registered devices: [CPU], Registered kernels:
device='CPU'; T in [DT_FLOAT]
[[Node: rnn/while/Less = Less[T=DT_INT32, _output_shapes=[[]]](rnn/while/Merge, rnn/while/Less/Enter)]]
如果选择 Keras 模型,则可能会出现稍微不同的错误:
代码语言:javascript复制Could not create TensorFlow Graph: Invalid argument: No OpKernel was registered to support Op 'Less' with these attrs. Registered devices: [CPU], Registered kernels:
device='CPU'; T in [DT_FLOAT]
[[Node: bidirectional_1/while_1/Less = Less[T=DT_INT32, _output_shapes=[[]]](bidirectional_1/while_1/Merge, bidirectional_1/while_1/Less/Enter)]]
我们在上一章中已经看到RefSwitch
操作出现类似的错误,并且知道针对此类错误的解决方法是在启用 -D__ANDROID_TYPES_FULL__
的情况下构建 TensorFlow 库。 如果没有看到这些错误,则意味着您在上一章的 iOS 应用中已建立了这样的库; 否则,请按照“为 iOS 构建自定义 TensorFlow 库”的开头的说明。 上一章的内容构建新的 TensorFlow 库,然后再次运行该应用。
现在选择 TensorFlow 模型,您将看到如图 8.5 所示的结果:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-j0y2fPol-1681653119036)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/28715473-1f89-4f14-8e81-f98242a11c8d.png)]
图 8.5:使用 TensorFlow RNN 模型进行预测
使用 Keras 模型输出不同的预测,如图 8.6 所示:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-cET3Uh9U-1681653119037)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/4114a6b9-15da-41f7-b535-bf04c7fb4dad.png)]
图 8.6:使用 Keras RNN 模型进行预测
我们无法确定哪个模型能在没有进一步研究的情况下更好地工作,但是我们可以确定的是,我们的两个 RNN 模型都使用 TensorFlow 和 Keras API 从头开始训练了,其准确率接近 60%, 在 iOS 上运行良好,这很值得我们付出努力,因为我们正在尝试建立一个许多专家认为将达到与随机选择相同的表现的模型,并且在此过程中,我们学到了一些新奇的东西-使用 TensorFlow 和 Keras 构建 RNN 模型并在 iOS 上运行它们。 在下一章中,我们只剩下一件事了:如何在 Android 上使用模型? 我们会遇到新的障碍吗?
在 Android 上运行 TensorFlow 和 Keras 模型
事实证明,这就像使用 Android 上的模型在沙滩上散步一样-尽管我们必须使用自定义的 TensorFlow 库(而不是 TensorFlow pod),我们甚至不需要像上一章那样使用自定义的 TensorFlow Android 库。 截至 2018 年 2 月)。 与用于 iOS 的 TensorFlow Pod 相比,在build.gradle
文件中使用compile 'org.tensorflow:tensorflow-android: '
构建的 TensorFlow Android 库必须对Less
操作具有更完整的数据类型支持。
要在 Android 中测试模型,请创建一个新的 Android 应用 StockPrice,并将两个模型文件添加到其assets
文件夹中。 然后在布局中添加几个按钮和一个TextView
并在MainActivity.java
中定义一些字段和常量:
private static final String TF_MODEL_FILENAME = "file:///android_asset/amzn_tf_frozen.pb";
private static final String KERAS_MODEL_FILENAME = "file:///android_asset/amzn_keras_frozen.pb";
private static final String INPUT_NODE_NAME_TF = "Placeholder";
private static final String OUTPUT_NODE_NAME_TF = "preds";
private static final String INPUT_NODE_NAME_KERAS = "bidirectional_1_input";
private static final String OUTPUT_NODE_NAME_KERAS = "activation_1/Identity";
private static final int SEQ_LEN = 20;
private static final float LOWER = 5.97f;
private static final float SCALE = 1479.37f;
private TensorFlowInferenceInterface mInferenceInterface;
private Button mButtonTF;
private Button mButtonKeras;
private TextView mTextView;
private boolean mUseTFModel;
private String mResult;
制作onCreate
如下:
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
mButtonTF = findViewById(R.id.tfbutton);
mButtonKeras = findViewById(R.id.kerasbutton);
mTextView = findViewById(R.id.textview);
mTextView.setMovementMethod(new ScrollingMovementMethod());
mButtonTF.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View v) {
mUseTFModel = true;
Thread thread = new Thread(MainActivity.this);
thread.start();
}
});
mButtonKeras.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View v) {
mUseTFModel = false;
Thread thread = new Thread(MainActivity.this);
thread.start();
}
});
}
其余代码全部在run
方法中,在点击TF PREDICTION
或KERAS PREDICTION
按钮时在工作线程中启动,需要一些解释,使用 Keras 模型需要在运行模型之前和之后规范化和非规范化:
public void run() {
runOnUiThread(
new Runnable() {
@Override
public void run() {
mTextView.setText("Getting data...");
}
});
float[] floatValues = new float[SEQ_LEN];
try {
URL url = new URL("https://www.alphavantage.co/query?function=TIME_SERIES_DAILY&symbol=amzn&apikey=4SOSJM2XCRIB5IUS&datatype=csv&outputsize=compact");
HttpURLConnection urlConnection = (HttpURLConnection) url.openConnection();
InputStream in = new BufferedInputStream(urlConnection.getInputStream());
Scanner s = new Scanner(in).useDelimiter("\n");
mResult = "Last 20 Days:n";
if (s.hasNext()) s.next(); // get rid of the first title line
List<String> priceList = new ArrayList<>();
while (s.hasNext()) {
String line = s.next();
String[] items = line.split(",");
priceList.add(items[4]);
}
for (int i=0; i<SEQ_LEN; i )
mResult = priceList.get(SEQ_LEN-i-1) "n";
for (int i=0; i<SEQ_LEN; i ) {
if (mUseTFModel)
floatValues[i] = Float.parseFloat(priceList.get(SEQ_LEN-i-1));
else
floatValues[i] = (Float.parseFloat(priceList.get(SEQ_LEN-i-1)) - LOWER) / SCALE;
}
AssetManager assetManager = getAssets();
mInferenceInterface = new TensorFlowInferenceInterface(assetManager, mUseTFModel ? TF_MODEL_FILENAME : KERAS_MODEL_FILENAME);
mInferenceInterface.feed(mUseTFModel ? INPUT_NODE_NAME_TF : INPUT_NODE_NAME_KERAS, floatValues, 1, SEQ_LEN, 1);
float[] predictions = new float[mUseTFModel ? SEQ_LEN : 1];
mInferenceInterface.run(new String[] {mUseTFModel ? OUTPUT_NODE_NAME_TF : OUTPUT_NODE_NAME_KERAS}, false);
mInferenceInterface.fetch(mUseTFModel ? OUTPUT_NODE_NAME_TF : OUTPUT_NODE_NAME_KERAS, predictions);
if (mUseTFModel) {
mResult = "nPrediction with TF RNN model:n" predictions[SEQ_LEN - 1];
}
else {
mResult = "nPrediction with Keras RNN model:n" (predictions[0] * SCALE LOWER);
}
runOnUiThread(
new Runnable() {
@Override
public void run() {
mTextView.setText(mResult);
}
});
} catch (Exception e) {
e.printStackTrace();
}
}
现在运行该应用,然后点击TF PREDICTION
按钮,您将在图 8.7 中看到结果:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-yZX2YsPg-1681653119037)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/8ce6016b-658a-46c4-a051-04c772d503fc.png)]
图 8.7:使用 TensorFlow 模型在亚马逊上进行股价预测
选择 KERAS 预测将为您提供如图 8.8 所示的结果:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-1l5o0sai-1681653119037)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/ec5b48b2-0c2a-439a-af59-303b9c628019.png)]
图 8.8:使用 Keras 模型在亚马逊上进行股价预测
总结
在本章中,我们首先对表示不屑一顾,试图通过使用 TensorFlow 和 Keras RNN API 预测股价来击败市场。 我们首先讨论了 RNN 和 LSTM 模型是什么以及如何使用它们进行股价预测。 然后,我们使用 TensorFlow 和 Keras 从零开始构建了两个 RNN 模型,接近测试正确率的 60%。 最后,我们介绍了如何冻结模型并在 iOS 和 Android 上使用它们,并使用自定义 TensorFlow 库修复了 iOS 上可能出现的运行时错误。
如果您对我们尚未建立预测正确率为 80% 或 90% 的模型感到有些失望,则可能需要继续进行“尝试并迭代”过程,以查看是否可以以该正确率预测股票价格。 但是,您肯定会从使用 TensorFlow 和 Keras API 的 RNN 模型构建,训练和测试中学到的技能以及在 iOS 和 Android 上运行的技能而受益。
如果您对使用深度学习技术打败市场感兴趣并感到兴奋,让我们在 GAN(生成对抗网络)上的下一章中进行研究,该模型试图击败能够分辨真实数据与虚假数据之间差异的对手, 并且越来越擅长生成看起来像真实数据的数据,欺骗对手。 GAN 实际上被深度学习的一些顶级研究人员誉为是过去十年中深度学习中最有趣和令人兴奋的想法。
九、使用 GAN 生成和增强图像
自 2012 年深度学习起步以来,有人认为 Ian Goodfellow 在 2014 年提出的生成对抗网络(GAN)比这更有趣或更有前途。 实际上, Facebook AI 研究主管和之一,深度学习研究人员之一的 Yann LeCun 将 GAN 和对抗训练称为,“这是近十年来机器学习中最有趣的想法。” 因此,我们如何在这里不介绍它,以了解 GAN 为什么如此令人兴奋,以及如何构建 GAN 模型并在 iOS 和 Android 上运行它们?
在本章中,我们将首先概述 GAN 是什么,它如何工作以及为什么它具有如此巨大的潜力。 然后,我们将研究两个 GAN 模型:一个基本的 GAN 模型可用于生成类似人的手写数字,另一个更高级的 GAN 模型可将低分辨率的图像增强为高分辨率的图像。 我们将向您展示如何在 Python 和 TensorFlow 中构建和训练此类模型,以及如何为移动部署准备模型。 然后,我们将提供带有完整源代码的 iOS 和 Android 应用,它们使用这些模型来生成手写数字并增强图像。 在本章的最后,您应该准备好进一步探索各种基于 GAN 的模型,或者开始构建自己的模型,并了解如何在移动应用中运行它们。
总之,本章将涵盖以下主题:
- GAN – 什么以及为什么
- 使用 TensorFlow 构建和训练 GAN 模型
- 在 iOS 中使用 GAN 模型
- 在 Android 中使用 GAN 模型
GAN – 什么以及为什么
GAN 是学习生成类似于真实数据或训练集中数据的神经网络。 GAN 的关键思想是让生成器网络和判别器网络相互竞争:生成器试图生成看起来像真实数据的数据,而判别器试图分辨生成的数据是否真实(从已知真实数据)或伪造(由生成器生成)。 生成器和判别器是一起训练的,在训练过程中,生成器学会生成看起来越来越像真实数据的数据,而判别器则学会将真实数据与伪数据区分开。 生成器通过尝试使判别器的输出概率为真实数据来学习,当将生成器的输出作为判别器的输入时,生成器的输出概率尽可能接近 1.0,而判别器通过尝试实现两个目标来学习:
- 当以生成器的输出作为输入时,使其输出的可能性为实,尽可能接近 0.0,这恰好是生成器的相反目标
- 当输入真实数据作为输入时,使其输出的可能性为实数,尽可能接近 1.0
在下一节中,您将看到与生成器和判别器网络及其训练过程的给定描述相匹配的详细代码片段。 如果您想了解更多关于 GAN 的知识,除了这里的摘要概述之外,您还可以在 YouTube 上搜索“GAN 简介”,并观看 2016 年 NIPS(神经信息处理系统)和 ICCV(国际计算机视觉会议)2017 大会上的 Ian Goodfellow 的 GAN 入门和教程视频。 事实上,YouTube 上有 7 个 NIPS 2016 对抗训练训练班视频和 12 个 ICCV 2017 GAN 指导视频,您可以自己投入其中。
在生成器和判别器两个参与者的竞争目标下,GAN 是一个寻求两个对手之间保持平衡的系统。 如果两个玩家都具有无限的能力并且可以进行最佳训练,那么纳什均衡(继 1994 年诺贝尔经济学奖得主约翰·纳什和电影主题《美丽心灵》之后) 一种状态,在这种状态下,任何玩家都无法通过仅更改其自己的策略来获利,这对应于生成器生成数据的状态,该数据看起来像真实数据,而判别器无法从假数据中分辨真实数据。
如果您有兴趣了解有关纳什均衡的更多信息,请访问 Google “可汗学院纳什均衡”,并观看 Sal Khan 撰写的两个有趣的视频。 《经济学家》解释经济学的“纳什均衡”维基百科页面和文章“纳什均衡是什么,为什么重要?”也是不错的读物。 了解 GAN 的基本直觉和想法将有助于您进一步了解 GAN 具有巨大潜力的原因。
生成器能够生成看起来像真实数据的数据的潜力意味着可以使用 GAN 开发各种出色的应用,例如:
- 从劣质图像生成高质量图像
- 图像修复(修复丢失或损坏的图像)
- 翻译图像(例如,从边缘草图到照片,或者在人脸上添加或移除诸如眼镜之类的对象)
- 从文本生成图像(和第 6 章,“使用自然语言描述图像”的 Text2Image 相反)
- 撰写看起来像真实新闻的新闻文章
- 生成与训练集中的音频相似的音频波形
基本上,GAN 可以从随机输入生成逼真的图像,文本或音频数据; 如果您具有一组源数据和目标数据的训练集,则 GAN 还可从类似于源数据的输入中生成类似于目标数据的数据。 GAN 模型中的生成器和判别器以动态方式工作的这一通用特性,使 GAN 可以生成任何种类的现实输出,这使 GAN 十分令人兴奋。
但是,由于生成器和判别器的动态或竞争目标,训练 GAN 达到纳什均衡状态是一个棘手且困难的问题。 实际上,这仍然是一个开放的研究问题 – Ian Goodfellow 在 2017 年 8 月对 Andrew Ng 进行的“深度学习英雄”采访中(YouTube 上的搜索ian goodfellow andrew ng
)说,如果我们可以使 GAN 变得像深度学习一样可靠,我们将看到 GAN 取得更大的成功,否则我们最终将用其他形式的生成模型代替它们。
尽管在 GAN 的训练方面存在挑战,但是在训练期间您已经可以应用许多有效的已知技巧 – 我们在这里不会介绍它们,但是如果您有兴趣调整我们将在本章中描述的模型或许多其他 GAN 模型 ),或构建自己的 GAN 模型。
使用 TensorFlow 构建和训练 GAN 模型
通常,GAN 模型具有两个神经网络:G
用于生成器,D
用于判别器。 x
是来自训练集的一些实际数据输入,z
是随机输入噪声。 在训练过程中,D(x)
是x
为真实的概率,D
尝试使D(x)
接近 1;G(z)
是具有随机输入z
的生成的输出,并且D
试图使D(G(z))
接近 0,但同时G
试图使D(G(z))
接近 1。 现在,让我们首先来看一下如何在 TensorFlow 和 Python 中构建基本的 GAN 模型,该模型可以编写或生成手写数字。
生成手写数字的基本 GAN 模型
手写数字的训练模型基于仓库,这是这个页面的分支,并添加了显示生成的数字并使用输入占位符保存 TensorFlow 训练模型的脚本,因此我们的 iOS 和 Android 应用可以使用该模型。 是的您应该查看原始仓库的博客。在继续之前,需要对具有代码的 GAN 模型有基本的了解。
在研究定义生成器和判别器网络并进行 GAN 训练的核心代码片段之前,让我们先运行脚本以在克隆存储库并转到仓库目录之后训练和测试模型:
代码语言:javascript复制git clone https://github.com/jeffxtang/generative-adversarial-networks
cd generative-adversarial-networks
该派生向gan-script-fast.py
脚本添加了检查点保存代码,还添加了新脚本gan-script-test.py
以使用随机输入的占位符测试和保存新的检查点–因此,使用新检查点冻结的模型可以在 iOS 和 Android 应用中使用。
运行命令python gan-script-fast.py
训练模型,在 Ubuntu 上的 GTX-1070 GPU 上花费不到一小时。 训练完成后,检查点文件将保存在模型目录中。 现在运行python gan-script-test.py
来查看一些生成的手写数字。 该脚本还从模型目录读取检查点文件,并在运行gan-script-fast.py
时保存该文件,然后将更新的检查点文件以及随机输入占位符重新保存在newmodel
目录中:
ls -lt newmodel
-rw-r--r-- 1 jeffmbair staff 266311 Mar 5 16:43 ckpt.meta
-rw-r--r-- 1 jeffmbair staff 65 Mar 5 16:42 checkpoint
-rw-r--r-- 1 jeffmbair staff 69252168 Mar 5 16:42 ckpt.data-00000-of-00001
-rw-r--r-- 1 jeffmbair staff 2660 Mar 5 16:42 ckpt.index
gan-script-test.py
中的下一个代码片段显示了输入节点名称(z_placeholder
)和输出节点名称(Sigmoid_1
),如print(generated_images)
所示:
z_placeholder = tf.placeholder(tf.float32, [None, z_dimensions], name='z_placeholder')
...
saver.restore(sess, 'model/ckpt')
generated_images = generator(z_placeholder, 5, z_dimensions)
print(generated_images)
images = sess.run(generated_images, {z_placeholder: z_batch})
saver.save(sess, "newmodel/ckpt")
在gan-script-fast.py
脚本中,方法def discriminator(images, reuse_variables=None)
定义了一个判别器网络,该网络使用一个真实的手写图像输入或由生成器生成的一个手写输入,经过一个典型的小型 CNN 网络,该网络具有两层conv2d
层,每一层都带有relu
激活和平均池化层以及两个完全连接的层来输出一个标量值,该标量值将保持输入图像为真或假的概率。 另一种方法def generator(batch_size, z_dim)
定义了生成器网络,该网络采用随机输入的图像向量并将其转换为具有 3 个conv2d
层的28 x 28
图像。
现在可以使用这两种方法来定义三个输出:
Gz
,即随机图像输入的生成器输出:Gz = generator(batch_size, z_dimensions)
Dx
,是真实图像输入的判别器输出:Dx = discriminator(x_placeholder)
Dg
,Gz
的判别器输出:Dg = discriminator(Gz, reuse_variables=True)
和三个损失函数:
d_loss_real
,Dx
和 1 之差:d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = Dx, labels = tf.ones_like(Dx)))
d_loss_fake
,Dg
和 0 之差:d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = Dg, labels = tf.zeros_like(Dg)))
g_loss
,Dg
和 1 之差:g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = Dg, labels = tf.ones_like(Dg)))
请注意,判别器尝试使 d_loss_fake
最小化,而生成器尝试使g_loss
最小化,两种情况下Dg
之间的差分别为 0 和 1。
最后,现在可以为三个损失函数设置三个优化器:d_trainer_fake
,d_trainer_real
和g_trainer
,它们全部是通过tf.train.AdamOptimizer
的minimize
方法定义的。
现在,脚本仅创建一个 TensorFlow 会话,通过运行三个优化器将生成器和判别器进行 100,000 步训练,将随机图像输入馈入生成器,将真实和伪图像输入均馈入判别器。
在运行 gan-script-fast.py
和gan-script-test.py
之后,将检查点文件从newmodel
目录运至/tmp
,然后转到 TensorFlow 源根目录并运行:
python tensorflow/python/tools/freeze_graph.py
--input_meta_graph=/tmp/ckpt.meta
--input_checkpoint=/tmp/ckpt
--output_graph=/tmp/gan_mnist.pb
--output_node_names="Sigmoid_1"
--input_binary=true
这将创建可用于移动应用的冻结模型gan_mnist.pb
。 但是在此之前,让我们看一下可以增强低分辨率图像的更高级的 GAN 模型。
增强图像分辨率的高级 GAN 模型
我们将用于增强低分辨率模糊图像的模型,基于论文《使用条件对抗网络的图像到图像转换》及其 TensorFlow 实现 pix2pix。 在仓库的分支中,我们添加了两个脚本:
tools/convert.py
从普通图像创建模糊图像pix2pix_runinference.py
添加了一个用于低分辨率图像输入的占位符和一个用于返回增强图像的操作,并保存了新的检查点文件,我们将冻结这些文件以生成在移动设备上使用的模型文件。
基本上,pix2pix 使用 GAN 将输入图像映射到输出图像。 您可以使用不同类型的输入图像和输出图像来创建许多有趣的图像转换:
- 地图到航拍
- 白天到黑夜
- 边界到照片
- 黑白图像到彩色图像
- 损坏的图像到原始图像
- 从低分辨率图像到高分辨率图像
在所有情况下,生成器都将输入图像转换为输出图像,试图使输出看起来像真实的目标图像,判别器将训练集中的样本或生成器的输出作为输入,并尝试告诉它是真实图像还是生成器生成的图像。 自然,与模型相比,pix2pix 中的生成器和判别器网络以更复杂的方式构建以生成手写数字,并且训练还应用了一些技巧来使过程稳定-有关详细信息,您可以阅读本文或较早提供的 TensorFlow 实现链接。 我们在这里仅向您展示如何设置训练集和训练 pix2pix 模型以增强低分辨率图像。
- 通过在终端上运行来克隆仓库:
git clone https://github.com/jeffxtang/pix2pix-tensorflow
cd pix2pix-tensorflow
- 创建一个新目录
photos/original
并复制一些图像文件-例如,我们将所有拉布拉多犬的图片从斯坦福狗数据集(在第 2 章,“使用迁移学习的图像分类”中使用)复制到photos/original
目录 - 运行脚本
python tools/process.py --input_dir photos/original --operation resize --output_dir photos/resized
调整photo/original
目录中图像的大小并将调整后的图像保存到photos/resized
目录中 - 运行
mkdir photos/blurry
,然后运行python tools/convert.py
,以使用流行的 ImageMagick 的convert
命令将调整大小的图像转换为模糊的图像。convert.py
的代码如下:
import os
file_names = os.listdir("photos/resized/")
for f in file_names:
if f.find(".png") != -1:
os.system("convert photos/resized/" f " -blur 0x3 photos/blurry/" f)
- 将
photos/resized
和photos/blurry
中的每个文件合并为一个对,并将所有配对的图像(一个调整大小的图像,另一个模糊的版本)保存到photos/resized_blurry
目录:
python tools/process.py --input_dir photos/resized --b_dir photos/blurry --operation combine --output_dir photos/resized_blurry
- 运行拆分工具
python tools/split.py --dir photos/resized_blurry
,将文件转换为train
目录和val
目录 - 通过运行以下命令训练
pix2pix
模型:
python pix2pix.py
--mode train
--output_dir photos/resized_blurry/ckpt_1000
--max_epochs 1000
--input_dir photos/resized_blurry/train
--which_direction BtoA
方向BtoA
表示从模糊图像转换为原始图像。 在 GTX-1070 GPU 上进行的训练大约需要四个小时,并且photos/resized_blurry/ckpt_1000
目录中生成的检查点文件如下所示:
-rw-rw-r-- 1 jeff jeff 1721531 Mar 2 18:37 model-136000.meta
-rw-rw-r-- 1 jeff jeff 81 Mar 2 18:37 checkpoint
-rw-rw-r-- 1 jeff jeff 686331732 Mar 2 18:37 model-136000.data-00000-of-00001
-rw-rw-r-- 1 jeff jeff 10424 Mar 2 18:37 model-136000.index
-rw-rw-r-- 1 jeff jeff 3807975 Mar 2 14:19 graph.pbtxt
-rw-rw-r-- 1 jeff jeff 682 Mar 2 14:19 options.json
- (可选)您可以在测试模式下运行脚本,然后在
--output_dir
指定的目录中检查图像翻译结果:
python pix2pix.py
--mode test
--output_dir photos/resized_blurry/output_1000
--input_dir photos/resized_blurry/val
--checkpoint photos/resized_blurry/ckpt_1000
- 运行
pix2pix_runinference.py
脚本以恢复在步骤 7 中保存的检查点,为图像输入创建一个新的占位符,为它提供测试图像ww.png
,将翻译输出为result.png
,最后将新的检查点文件保存在newckpt
目录:
python pix2pix_runinference.py
--mode test
--output_dir photos/blurry_output
--input_dir photos/blurry_test
--checkpoint photos/resized_blurry/ckpt_1000
以下pix2pix_runinference.py
中的代码段设置并打印输入和输出节点:
image_feed = tf.placeholder(dtype=tf.float32, shape=(1, 256, 256, 3), name="image_feed")
print(image_feed) # Tensor("image_feed:0", shape=(1, 256, 256, 3), dtype=float32)
with tf.variable_scope("generator", reuse=True):
output_image = deprocess(create_generator(image_feed, 3))
print(output_image) #Tensor("generator_1/deprocess/truediv:0", shape=(1, 256, 256, 3), dtype=float32)
具有tf.variable_scope("generator", reuse=True):
的行非常重要,因为需要共享generator
变量,以便可以使用所有训练后的参数值。 否则,您会看到奇怪的翻译结果。
以下代码显示了如何在newckpt
目录中填充占位符,运行 GAN 模型并保存生成器的输出以及检查点文件:
if a.mode == "test":
from scipy import misc
image = misc.imread("ww.png").reshape(1, 256, 256, 3)
image = (image / 255.0) * 2 - 1
result = sess.run(output_image, feed_dict={image_feed:image})
misc.imsave("result.png", result.reshape(256, 256, 3))
saver.save(sess, "newckpt/pix2pix")
图 9.1 显示了原始测试图像,其模糊版本以及经过训练的 GAN 模型的生成器输出。 结果并不理想,但是 GAN 模型确实具有更好的分辨率而没有模糊效果:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-PNdQLBQU-1681653119037)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/b73deef3-1598-4019-ac72-5c0212d53c74.png)]
图 9.1:原始的,模糊的和生成的
- 现在,将
newckpt
目录复制到/tmp
,我们可以如下冻结模型:
python tensorflow/python/tools/freeze_graph.py
--input_meta_graph=/tmp/newckpt/pix2pix.meta
--input_checkpoint=/tmp/newckpt/pix2pix
--output_graph=/tmp/newckpt/pix2pix.pb
--output_node_names="generator_1/deprocess/truediv"
--input_binary=true
- 生成的
pix2pix.pb
模型文件很大,约为 217MB,将其加载到 iOS 或 Android 设备上时会崩溃或导致内存不足(OOM)错误。 我们必须像在第 6 章,“使用自然语言描述图像”的复杂 im2txt 模型中所做的那样,将其转换为 iOS 的映射格式。
bazel-bin/tensorflow/tools/graph_transforms/transform_graph
--in_graph=/tmp/newckpt/pix2pix.pb
--out_graph=/tmp/newckpt/pix2pix_transformed.pb
--inputs="image_feed"
--outputs="generator_1/deprocess/truediv"
--transforms='strip_unused_nodes(type=float, shape="1,256,256,3")
fold_constants(ignore_errors=true, clear_output_shapes=true)
fold_batch_norms
fold_old_batch_norms'
bazel-bin/tensorflow/contrib/util/convert_graphdef_memmapped_format
--in_graph=/tmp/newckpt/pix2pix_transformed.pb
--out_graph=/tmp/newckpt/pix2pix_transformed_memmapped.pb
pix2pix_transformed_memmapped.pb
模型文件现在可以在 iOS 中使用。
- 要为 Android 构建模型,我们需要量化冻结的模型,以将模型大小从 217MB 减少到约 54MB:
bazel-bin/tensorflow/tools/graph_transforms/transform_graph
--in_graph=/tmp/newckpt/pix2pix.pb
--out_graph=/tmp/newckpt/pix2pix_transformed_quantized.pb --inputs="image_feed"
--outputs="generator_1/deprocess/truediv"
--transforms='quantize_weights'
现在,让我们看看如何在移动应用中使用两个 GAN 模型。
在 iOS 中使用 GAN 模型
如果您尝试在 iOS 应用中使用 TensorFlow 窗格并加载gan_mnist.pb
文件,则会收到错误消息:
Could not create TensorFlow Graph: Invalid argument: No OpKernel was registered to support Op 'RandomStandardNormal' with these attrs. Registered devices: [CPU], Registered kernels:
<no registered kernels>
[[Node: z_1/RandomStandardNormal = RandomStandardNormal[T=DT_INT32, _output_shapes=[[50,100]], dtype=DT_FLOAT, seed=0, seed2=0](z_1/shape)]]
将行添加到tf_op_files.txt
之后,请确保tensorflow/contrib/makefile/tf_op_files.txt
文件具有tensorflow/core/kernels/random_op.cc
,该文件实现了RandomStandardNormal
操作,并且libtensorflow-core.a
是由 tensorflow/contrib/makefile/build_all_ios.sh
构建的。
此外,如果即使在使用 TensorFlow 1.4 构建的自定义 TensorFlow 库中尝试加载pix2pix_transformed_memmapped.pb
,也会出现以下错误:
No OpKernel was registered to support Op 'FIFOQueueV2' with these attrs. Registered devices: [CPU], Registered kernels:
<no registered kernels>
[[Node: batch/fifo_queue = FIFOQueueV2[_output_shapes=[[]], capacity=32, component_types=[DT_STRING, DT_FLOAT, DT_FLOAT], container="", shapes=[[], [256,256,1], [256,256,2]], shared_name=""]()]]
您需要将tensorflow/core/kernels/fifo_queue_op.cc
添加到tf_op_files.txt
并重建 iOS 库。 但是,如果您使用 TensorFlow 1.5 或 1.6,则tensorflow/core/kernels/fifo_queue_op.cc
文件已经添加到tf_op_files.txt
文件中。 在每个新版本的 TensorFlow 中,默认情况下,越来越多的内核被添加到tf_op_files.txt
。
借助为模型构建的 TensorFlow iOS 库,让我们在 Xcode 中创建一个名为 GAN 的新项目,并像在第 8 章,“使用 RNN 预测股价”一样在该项目中设置 TensorFlow。 以及其他不使用 TensorFlow 窗格的章节。 然后将两个模型文件gan_mnist.pb
和pix2pix_transformed_memmapped.pb
以及一个测试图像拖放到项目中。 另外,将第 6 章,“使用自然语言描述图像”的 iOS 项目中的tensorflow_utils.h
, tensorflow_utils.mm
,ios_image_load.h
和 ios_image_load.mm
文件复制到 GAN 项目。 将ViewController.m
重命名为ViewController.mm
。
现在,您的 Xcode 应该类似于图 9.2:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-otj0WMJO-1681653119038)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/9cb1fa66-5a4e-4a1b-aa51-a16fd9051f57.png)]
图 9.2:在 Xcode 中显示 GAN 应用
我们将创建一个按钮,在点击该按钮时,提示用户选择一个模型以生成数字或增强图像:
代码语言:javascript复制- (IBAction)btnTapped:(id)sender {
UIAlertAction* mnist = [UIAlertAction actionWithTitle:@"Generate Digits" style:UIAlertActionStyleDefault handler:^(UIAlertAction * action) {
_iv.image = NULL;
dispatch_async(dispatch_get_global_queue(0, 0), ^{
NSArray *arrayGreyscaleValues = [self runMNISTModel];
dispatch_async(dispatch_get_main_queue(), ^{
UIImage *imgDigit = [self createMNISTImageInRect:_iv.frame values:arrayGreyscaleValues];
_iv.image = imgDigit;
});
});
}];
UIAlertAction* pix2pix = [UIAlertAction actionWithTitle:@"Enhance Image" style:UIAlertActionStyleDefault handler:^(UIAlertAction * action) {
_iv.image = [UIImage imageNamed:image_name];
dispatch_async(dispatch_get_global_queue(0, 0), ^{
NSArray *arrayRGBValues = [self runPix2PixBlurryModel];
dispatch_async(dispatch_get_main_queue(), ^{
UIImage *imgTranslated = [self createTranslatedImageInRect:_iv.frame values:arrayRGBValues];
_iv.image = imgTranslated;
});
});
}];
UIAlertAction* none = [UIAlertAction actionWithTitle:@"None" style:UIAlertActionStyleDefault handler:^(UIAlertAction * action) {}];
UIAlertController* alert = [UIAlertController alertControllerWithTitle:@"Use GAN to" message:nil preferredStyle:UIAlertControllerStyleAlert];
[alert addAction:mnist];
[alert addAction:pix2pix];
[alert addAction:none];
[self presentViewController:alert animated:YES completion:nil];
}
这里的代码非常简单。 应用的主要功能通过以下四种方法实现: runMNISTModel
, runPix2PixBlurryModel
, createMNISTImageInRect
和 createTranslatedImageInRect
。
使用基本 GAN 模型
在runMNISTModel
中,我们调用辅助方法LoadModel
来加载 GAN 模型,然后将输入张量设置为具有正态分布(均值 0.0 和 std 1.0)的 100 个随机数的 6 批。 该模型期望具有正态分布的随机输入。 您可以将 6 更改为任何其他数字,然后取回该数字的生成位数:
- (NSArray*) runMNISTModel {
tensorflow::Status load_status;
load_status = LoadModel(@"gan_mnist", @"pb", &tf_session);
if (!load_status.ok()) return NULL;
std::string input_layer = "z_placeholder";
std::string output_layer = "Sigmoid_1";
tensorflow::Tensor input_tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({6, 100}));
auto input_map = input_tensor.tensor<float, 2>();
unsigned seed = (unsigned)std::chrono::system_clock::now().time_since_epoch().count();
std::default_random_engine generator (seed);
std::normal_distribution<double> distribution(0.0, 1.0);
for (int i = 0; i < 6; i ){
for (int j = 0; j < 100; j ){
double number = distribution(generator);
input_map(i,j) = number;
}
}
runMNISTModel
方法中的其余代码运行模型,获得6 * 28 * 28
浮点数的输出,表示每批像素大小为28 * 28
的图像在每个像素处的灰度值,并调用方法createMNISTImageInRect
,以便在将图像上下文转换为UIImage
之前,先使用 UIBezierPath
在图像上下文中呈现数字,然后将其返回并显示在UIImageView
中:
std::vector<tensorflow::Tensor> outputs;
tensorflow::Status run_status = tf_session->Run({{input_layer, input_tensor}},
{output_layer}, {}, &outputs);
if (!run_status.ok()) {
LOG(ERROR) << "Running model failed: " << run_status;
return NULL;
}
tensorflow::string status_string = run_status.ToString();
tensorflow::Tensor* output_tensor = &outputs[0];
const Eigen::TensorMap<Eigen::Tensor<float, 1, Eigen::RowMajor>, Eigen::Aligned>& output = output_tensor->flat<float>();
const long count = output.size();
NSMutableArray *arrayGreyscaleValues = [NSMutableArray array];
for (int i = 0; i < count; i) {
const float value = output(i);
[arrayGreyscaleValues addObject:[NSNumber numberWithFloat:value]];
}
return arrayGreyscaleValues;
}
createMNISTImageInRect
的定义如下-我们在第 7 章,“使用 CNN 和 LSTM 识别绘画”中使用了类似的技术:
- (UIImage *)createMNISTImageInRect:(CGRect)rect values:(NSArray*)greyscaleValues
{
UIGraphicsBeginImageContextWithOptions(CGSizeMake(rect.size.width, rect.size.height), NO, 0.0);
int i=0;
const int size = 3;
for (NSNumber *val in greyscaleValues) {
float c = [val floatValue];
int x = i(;
int y = i/28;
i ;
CGRect rect = CGRectMake(145 size*x, 50 y*size, size, size);
UIBezierPath *path = [UIBezierPath bezierPathWithRect:rect];
UIColor *color = [UIColor colorWithRed:c green:c blue:c alpha:1.0];
[color setFill];
[path fill];
}
UIImage *image = UIGraphicsGetImageFromCurrentImageContext();
UIGraphicsEndImageContext();
return image;
}
对于每个像素,我们绘制一个宽度和高度均为 3 的小矩形,并为该像素返回灰度值。
使用高级 GAN 模型
在runPix2PixBlurryModel
方法中,我们使用LoadMemoryMappedModel
方法加载pix2pix_transformed_memmapped.pb
模型文件,并加载测试图像并设置输入张量,其方式与第 4 章,“以惊人的艺术样式迁移图片”相同:
- (NSArray*) runPix2PixBlurryModel {
tensorflow::Status load_status;
load_status = LoadMemoryMappedModel(@"pix2pix_transformed_memmapped", @"pb", &tf_session, &tf_memmapped_env);
if (!load_status.ok()) return NULL;
std::string input_layer = "image_feed";
std::string output_layer = "generator_1/deprocess/truediv";
NSString* image_path = FilePathForResourceName(@"ww", @"png");
int image_width;
int image_height;
int image_channels;
std::vector<tensorflow::uint8> image_data = LoadImageFromFile([image_path UTF8String], &image_width, &image_height, &image_channels);
然后我们运行模型,获得256 * 256 * 3
(图像大小为256 * 256
,RGB 具有 3 个值)浮点数的输出,并调用createTranslatedImageInRect
将数字转换为UIImage
:
std::vector<tensorflow::Tensor> outputs;
tensorflow::Status run_status = tf_session->Run({{input_layer, image_tensor}},
{output_layer}, {}, &outputs);
if (!run_status.ok()) {
LOG(ERROR) << "Running model failed: " << run_status;
return NULL;
}
tensorflow::string status_string = run_status.ToString();
tensorflow::Tensor* output_tensor = &outputs[0];
const Eigen::TensorMap<Eigen::Tensor<float, 1, Eigen::RowMajor>, Eigen::Aligned>& output = output_tensor->flat<float>();
const long count = output.size(); // 256*256*3
NSMutableArray *arrayRGBValues = [NSMutableArray array];
for (int i = 0; i < count; i) {
const float value = output(i);
[arrayRGBValues addObject:[NSNumber numberWithFloat:value]];
}
return arrayRGBValues;
最终方法createTranslatedImageInRect
定义如下,所有这些都很容易解释:
- (UIImage *)createTranslatedImageInRect:(CGRect)rect values:(NSArray*)rgbValues
{
UIGraphicsBeginImageContextWithOptions(CGSizeMake(wanted_width, wanted_height), NO, 0.0);
for (int i=0; i<256*256; i ) {
float R = [rgbValues[i*3] floatValue];
float G = [rgbValues[i*3 1] floatValue];
float B = [rgbValues[i*3 2] floatValue];
const int size = 1;
int x = i%6;
int y = i/256;
CGRect rect = CGRectMake(size*x, y*size, size, size);
UIBezierPath *path = [UIBezierPath bezierPathWithRect:rect];
UIColor *color = [UIColor colorWithRed:R green:G blue:B alpha:1.0];
[color setFill];
[path fill];
}
UIImage *image = UIGraphicsGetImageFromCurrentImageContext();
UIGraphicsEndImageContext();
return image;
}
现在,在 iOS 模拟器或设备中运行该应用,点击 GAN 按钮,然后选择生成数字,您将看到 GAN 生成的手写数字的结果,如图 9.3 所示:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-hx6kLWCC-1681653119038)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/0e0f47b1-2efa-4ec0-9eaa-4b21356cad23.png)]
图 9.3:显示 GAN 模型选择和生成的手写数字结果
这些数字看起来很像真实的人类手写数字,都是在训练了基本 GAN 模型之后完成的。 如果您返回并查看进行训练的代码,并且停下来思考一下 GAN 的工作原理,一般来说,则生成器和判别器如何相互竞争,以及尝试达到稳定的纳什均衡状态,在这种状态下,生成器可以生成判别器无法分辨出真实还是伪造的真实假数据,您可能会更欣赏 GAN 的魅力。
现在,让我们选择Enhance Image
选项,您将在图 9.4 中看到结果,该结果与图 9.1 中的 Python 测试代码生成的结果相同:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-qiA3J1B4-1681653119038)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/63898b4a-4c1d-4605-9fa3-97c8aa00c968.png)]
图 9.4:iOS 上原始的模糊和增强图像
你知道该怎么做。 是时候将我们的爱献给 Android 了。
在 Android 中使用 GAN 模型
事实证明,我们不需要使用自定义的 TensorFlow Android 库,就像我们在第 7 章,“通过 CNN 和 LSTM 识别绘画”中所做的那样,即可在 Android 中运行 GAN 模型。 只需创建一个具有所有默认设置的名为 GAN 的新 Android Studio 应用,将compile 'org.tensorflow:tensorflow-android: '
添加到应用的build.gradle
文件,创建一个新的素材文件夹,然后复制两个 GAN 模型文件和一个测试模糊图像。
现在,您在 Android Studio 中的项目应如图 9.5 所示:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Kzdl4ngv-1681653119038)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/f1c741dd-1fc6-41da-9243-3b0840aa9b20.png)]
图 9.5:Android Studio GAN 应用概述,显示常量定义
请注意,为简单起见,我们将BATCH_SIZE
设置为 1。您可以轻松地将其设置为任何数字,并像在 iOS 中一样获得很多输出。 除了图 9.5 中定义的常量之外,我们还将创建一些实例变量:
private Button mButtonMNIST;
private Button mButtonPix2Pix;
private ImageView mImageView;
private Bitmap mGeneratedBitmap;
private boolean mMNISTModel;
private TensorFlowInferenceInterface mInferenceInterface;
应用布局由一个ImageView
和两个按钮组成,就像我们之前所做的那样,它们在onCreate
方法中实例化:
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
mButtonMNIST = findViewById(R.id.mnistbutton);
mButtonPix2Pix = findViewById(R.id.pix2pixbutton);
mImageView = findViewById(R.id.imageview);
try {
AssetManager am = getAssets();
InputStream is = am.open(IMAGE_NAME);
Bitmap bitmap = BitmapFactory.decodeStream(is);
mImageView.setImageBitmap(bitmap);
} catch (IOException e) {
e.printStackTrace();
}
然后,为两个按钮设置两个单击监听器:
代码语言:javascript复制 mButtonMNIST.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View v) {
mMNISTModel = true;
Thread thread = new Thread(MainActivity.this);
thread.start();
}
});
mButtonPix2Pix.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View v) {
try {
AssetManager am = getAssets();
InputStream is = am.open(IMAGE_NAME);
Bitmap bitmap = BitmapFactory.decodeStream(is);
mImageView.setImageBitmap(bitmap);
mMNISTModel = false;
Thread thread = new Thread(MainActivity.this);
thread.start();
} catch (IOException e) {
e.printStackTrace();
}
}
});
}
轻按按钮后,run
方法在辅助线程中运行:
public void run() {
if (mMNISTModel)
runMNISTModel();
else
runPix2PixBlurryModel();
}
使用基本 GAN 模型
在runMNISTModel
方法中,我们首先为模型准备一个随机输入:
void runMNISTModel() {
float[] floatValues = new float[BATCH_SIZE*100];
Random r = new Random();
for (int i=0; i<BATCH_SIZE; i ) {
for (int j=0; i<100; i ) {
double sample = r.nextGaussian();
floatValues[i] = (float)sample;
}
}
然后将输入提供给模型,运行模型并获得输出值,它们是介于 0.0 到 1.0 之间的缩放灰度值,并将它们转换为 0 到 255 范围内的整数:
代码语言:javascript复制 float[] outputValues = new float[BATCH_SIZE * 28 * 28];
AssetManager assetManager = getAssets();
mInferenceInterface = new TensorFlowInferenceInterface(assetManager, MODEL_FILE1);
mInferenceInterface.feed(INPUT_NODE1, floatValues, BATCH_SIZE, 100);
mInferenceInterface.run(new String[] {OUTPUT_NODE1}, false);
mInferenceInterface.fetch(OUTPUT_NODE1, outputValues);
int[] intValues = new int[BATCH_SIZE * 28 * 28];
for (int i = 0; i < intValues.length; i ) {
intValues[i] = (int) (outputValues[i] * 255);
}
之后,对于创建位图时设置的每个像素,我们使用返回和转换的灰度值:
代码语言:javascript复制 try {
Bitmap bitmap = Bitmap.createBitmap(28, 28, Bitmap.Config.ARGB_8888);
for (int y=0; y<28; y ) {
for (int x=0; x<28; x ) {
int c = intValues[y*28 x];
int color = (255 & 0xff) << 24 | (c & 0xff) << 16 | (c & 0xff) << 8 | (c & 0xff);
bitmap.setPixel(x, y, color);
}
}
mGeneratedBitmap = Bitmap.createBitmap(bitmap);
}
catch (Exception e) {
e.printStackTrace();
}
最后,我们在主 UI 线程的 ImageView 中显示位图:
代码语言:javascript复制 runOnUiThread(
new Runnable() {
@Override
public void run() {
mImageView.setImageBitmap(mGeneratedBitmap);
}
});
}
如果现在运行该应用,并使用void runPix2PixBlurryModel() {}
的空白实现来避免生成错误,则在单击GENERATE DIGITS
后会看到初始屏幕和结果,如图 9.6 所示:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-3a29pMbu-1681653119038)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/1b2a1b02-e084-40bf-a33d-ce5561a34850.png)]
图 9.6:显示生成的数字
使用高级 GAN 模型
runPix2PixBlurryModel
方法类似于前面几章中的代码,在前几章中,我们使用图像输入来馈入模型。 我们首先从图像位图中获取 RGB 值,然后将它们保存到float
数组中:
void runPix2PixBlurryModel() {
int[] intValues = new int[WANTED_WIDTH * WANTED_HEIGHT];
float[] floatValues = new float[WANTED_WIDTH * WANTED_HEIGHT * 3];
float[] outputValues = new float[WANTED_WIDTH * WANTED_HEIGHT * 3];
try {
Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open(IMAGE_NAME));
Bitmap scaledBitmap = Bitmap.createScaledBitmap(bitmap, WANTED_WIDTH, WANTED_HEIGHT, true);
scaledBitmap.getPixels(intValues, 0, scaledBitmap.getWidth(), 0, 0, scaledBitmap.getWidth(), scaledBitmap.getHeight());
for (int i = 0; i < intValues.length; i) {
final int val = intValues[i];
floatValues[i * 3 0] = (((val >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD;
floatValues[i * 3 1] = (((val >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD;
floatValues[i * 3 2] = ((val & 0xFF) - IMAGE_MEAN) / IMAGE_STD;
}
然后,我们使用输入来运行模型,并获取并将输出值转换为整数数组,该整数数组随后用于设置新位图的像素:
代码语言:javascript复制 AssetManager assetManager = getAssets();
mInferenceInterface = new TensorFlowInferenceInterface(assetManager, MODEL_FILE2);
mInferenceInterface.feed(INPUT_NODE2, floatValues, 1, WANTED_HEIGHT, WANTED_WIDTH, 3);
mInferenceInterface.run(new String[] {OUTPUT_NODE2}, false);
mInferenceInterface.fetch(OUTPUT_NODE2, outputValues);
for (int i = 0; i < intValues.length; i) {
intValues[i] = 0xFF000000
| (((int) (outputValues[i * 3] * 255)) << 16)
| (((int) (outputValues[i * 3 1] * 255)) << 8)
| ((int) (outputValues[i * 3 2] * 255));
}
Bitmap outputBitmap = scaledBitmap.copy( scaledBitmap.getConfig() , true);
outputBitmap.setPixels(intValues, 0, outputBitmap.getWidth(), 0, 0, outputBitmap.getWidth(), outputBitmap.getHeight());
mGeneratedBitmap = Bitmap.createScaledBitmap(outputBitmap, bitmap.getWidth(), bitmap.getHeight(), true);
}
catch (Exception e) {
e.printStackTrace();
}
最后,我们在主 UI 的ImageView
中显示位图:
runOnUiThread(
new Runnable() {
@Override
public void run() {
mImageView.setImageBitmap(mGeneratedBitmap);
}
});
}
再次运行该应用,然后立即点击增强图像按钮,您将在几秒钟内看到图 9.7 中的增强图像:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-6NPzBbg4-1681653119039)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/7e7c6110-09d5-4908-bfaa-e9b5b67596d3.png)]
图 9.7:Android 上的模糊和增强图像
这使用两个 GAN 模型完成了我们的 Android 应用。
总结
在本章中,我们快速浏览了 GAN 的美好世界。 我们介绍了 GAN 的含义以及它们为何如此有趣的原因-生成器和判别器相互竞争并尝试击败的方式听起来对大多数人来说很有吸引力。 然后,我们详细介绍了如何训练基本 GAN 模型和更高级的图像分辨率增强模型以及如何为移动设备准备它们的详细步骤。 最后,我们向您展示了如何使用这些模型构建 iOS 和 Android 应用。 如果您对整个过程和结果感到兴奋,那么您肯定会想进一步探索 GAN,这是一个快速发展的领域,在该领域中,新型 GAN 已经迅速开发出来,以克服先前模型的缺点; 例如,正如我们在“增强图像分辨率”小节的 GAN 高级模型中看到的那样,开发了需要配对图像进行训练的 pix2pix 模型的同一位研究人员提出了一种称为 CycleGAN 的模型,删除了图像配对的要求。 如果您对我们生成的数字或增强的图像的质量不满意,则可能还应该进一步探索 GAN,以了解如何改进 GAN 模型。 正如我们之前提到的,GAN 仍很年轻,研究人员仍在努力稳定训练,如果可以稳定的话,将会取得更大的成功。 至少到目前为止,您已经获得了如何在移动应用中快速部署 GAN 模型的经验。 由您决定是关注最新,最出色的 GAN 并在移动设备上使用它们,还是暂时搁置您的移动开发人员的帽子,会全力以赴来构建新的或改进现有的 GAN 模型。
如果 GAN 在深度学习社区中引起了极大的兴奋,那么 AlphaGo 在 2016 年和 2017 年击败最优秀的人类 GO 玩家的成就无疑令所有人都感到惊讶。 此外,在 2017 年 10 月,AlphaGo Zero(一种完全基于自学强化学习而无需任何人类知识的新算法)被推举为击败 AlphaGo 100-0,令人难以置信。 2017 年 12 月,与仅在 GO 游戏中定位的 AlphaGo 和 AlphaGo Zero 不同,AlphaZero(一种可在许多具有挑战性的领域实现“超人表现”的算法)被发布。 在下一章中,我们将看到如何使用最新最酷的 AlphaZero 来构建和训练用于玩简单游戏的模型,以及如何在移动设备上运行该模型。
十、构建类似 AlphaZero 的手机游戏应用
尽管现代人工智能(AI)的日益普及基本上是由 2012 年深度学习的突破引起的,但 2016 年 3 月,Google DeepMind 的 AlphaGo 以 4-1 击败围棋世界冠军 Lee Sedol,然后在 2017 年 5 月以 3-0 击败了目前排名第一的围棋玩家 Ke Jie 的历史性事件,这在很大程度上使 AI 家喻户晓。 由于围棋游戏的复杂性,人们普遍认为任务无法实现,或者至少十年内计算机程序不可能击败顶级围棋玩家。
在 2017 年 5 月 AlphaGo 和 Ke Jie 的比赛之后,Google 退役了 AlphaGo; 谷歌(DeepMind)是 Google 因其开创性的深度强化学习技术而收购的创业公司,也是 AlphaGo 的开发商,决定将其 AI 研究重点放在其他领域。 然后,有趣的是,在 2017 年 10 月,DeepMind 在游戏上发表了另一篇论文《围棋:在没有人类知识的情况下掌握围棋游戏》,它描述了一种称为 AlphaGo Zero 的改进算法,该算法仅通过自我强化学习来学习如何玩围棋,而无需依赖任何人类专家知识,例如大量玩过的专业的围棋游戏,AlphaGo 用它来训练其模型。 令人惊讶的是,AlphaGo Zero 完全击败了 AlphaGo,后者在几个月前以 100-0 击败了世界上最好的人类 GO 玩家!
事实证明,这只是朝着 Google 更雄心勃勃的目标迈出的一步,该目标是将 AlphaGo 背后的 AI 技术应用和改进到其他领域。 2017 年 12 月,DeepMind 发表了另一篇论文,即使用通用强化学习算法通过自学掌握国际象棋和将棋,对 AlphaGo 进行了概括。 将零程序归类为一个称为 AlphaZero 的算法,并使用该算法从头开始快速学习如何玩象棋和将棋的游戏,从除了游戏规则之外没有任何领域知识的随机游戏开始,并在 24 小时内实现了超人级别并击败世界冠军。
在本章中,我们将带您浏览 AlphaZero 的最新最酷的部分,向您展示如何构建和训练类似 AlphaZero 的模型来玩一个简单而有趣的游戏,称为 Connect4,在 TensorFlow 和 Keras 中使用,这是我们在第 8 章,“使用 RNN 预测股价”的流行的高级深度学习库。 我们还将介绍如何使用训练有素的 AlphaZero 模型来获得训练有素的专家策略,以指导移动游戏的玩法,以及使用该模型玩 Connect4 游戏的完整 iOS 和 Android 应用的源代码。
总之,本章将涵盖以下主题:
- AlphaZero – 它如何工作?
- 为 Connect4 构建和训练类似于 AlphaZero 的模型
- 在 iOS 中使用模型玩 Connect4
- 在 Android 中使用模型玩 Connect4
AlphaZero – 它如何工作?
AlphaZero 算法包含三个主要组件:
- 一个深度卷积神经网络,它以棋盘位置(或状态)为输入,并从该位置输出一个值作为预测的博弈结果,该策略是输入棋盘状态下每个可能动作的移动概率列表。
- 一种通用的强化学习算法,该算法通过自玩从头开始学习,除了游戏规则外,没有特定的领域知识。 通过自增强学习学习深度神经网络的参数,以使预测值与实际自游戏结果之间的损失最小,并使预测策略与搜索概率之间的相似性最大化,这来自以下算法。
- 一种通用(与域无关)的蒙特卡洛树搜索(MCTS)算法,该算法从头至尾模拟自玩游戏,并通过考虑到从深度神经网络返回的预测值和策略概率值,以及访问节点的频率—有时,选择访问次数较少的节点称为强化学习中的探索(与采取较高预测值和策略的举动相反,这称为利用)。 探索与利用之间的良好平衡可以带来更好的结果。
强化学习的历史可以追溯到 1960 年代,当时该术语在工程文献中首次使用。 但是突破发生在 2013 年,当时 DeepMind 将强化学习与深度学习相结合,并开发了深度强化学习应用,该应用学会了从头开始玩 Atari 游戏,以原始像素为输入的,并随后击败了人类。 与监督学习不同,监督学习需要标记数据进行训练,就像我们在前几章中建立或使用的许多模型中所看到的那样,强化学习使用反复试验的方法来获得更好的效果:智能体与环境交互并接收在每个状态上采取的每个动作的奖励(正面或负面)。 在 AlphaZero 下象棋的示例中,只有在游戏结束后才能获得奖励,获胜的结果为 1,失败的为 -1,平局为 0。强化学习 AlphaZero 中的算法对我们前面提到的损失使用梯度下降来更新深层神经网络的参数, 就像一个通用函数近似来学习和编码游戏技巧。
学习或训练过程的结果可以是由深度神经网络生成的策略,该策略说出对任何状态应采取的行动,或者是将每个状态以及该状态的每个可能动作映射到长期奖励的值函数 。
如果深层神经网络使用自我玩法强化学习所学习的策略是理想的,则我们可能无需让程序在游戏过程中执行任何 MCTS,而程序总是可以最大可能地选择移动。 但是在诸如象棋或围棋的复杂游戏中,无法生成完美的策略,因此 MCTS 必须与训练有素的深度网络一起工作,以指导针对每种游戏状态的最佳可能动作的搜索。
如果您不熟悉强化学习或 MCTS,则在互联网上有很多关于强化学习或 MCTS 的信息。 考虑查看 Richard Sutton 和 Andrew Barto 的经典著作《强化学习:简介》,该书可在以下网站上公开获得。 您还可以在 YouTube 上观看 DeepMind 的 AlphaGo 的技术负责人 David Silver 的强化学习课程视频(搜索“强化学习 David Silver”)。 一个有趣且有用的强化学习工具包是 OpenAI Gym。 在本书的最后一章中,我们将更深入地学习强化学习和 OpenAI Gym。 对于 MCTS,请查看其维基页面,以及此博客。
在下一节中,我们将研究以 TensorFlow 为后端的 Keras 实现 AlphaZero 算法,其目标是使用该算法构建和训练模型以玩 Connect4。您将看到模型架构是什么样,以及构建模型的 Keras 关键代码。
训练和测试适用于 Connect4 的类似 AlphaZero 的模型
如果您从未玩过 Connect4,则可以在这个页面上免费玩它。 这是一个快速有趣的游戏。 基本上,两个玩家轮流从一列的顶部将不同颜色的光盘放入六行乘七列的网格中。 如果尚未在该列中放入任何光盘,则新放置的光盘将位于该列的底部,或者位于该列中最后放置的光盘的顶部。 谁先在三个可能的方向(水平,垂直,对角线)中的任何一个方向上拥有自己颜色的四个连续光盘赢得比赛。
Connect4 的 AlphaZero 模型基于存储库,这是这个页面的分支, 有一个不错的博客,如何使用 Python 和 Keras 构建自己的 AlphaZero AI,您可能应该在继续之前阅读它,因此以下步骤更有意义。
训练模型
在我们看一些核心代码片段之前,让我们首先看一下如何训练模型。 首先,通过在终端上运行以下命令来获取存储库:
代码语言:javascript复制 git clone https://github.com/jeffxtang/DeepReinforcementLearning
然后,如果尚未在第 8 章,“使用 RNN 预测股价”中设置,则设置 Keras 和 TensorFlow 虚拟环境:
代码语言:javascript复制cd
mkdir ~/tf_keras
virtualenv --system-site-packages ~/tf_keras/
cd ~/tf_keras/
source ./bin/activate
easy_install -U pip
#On Mac:
pip install --upgrade https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.4.0-py2-none-any.whl
#On Ubuntu:
pip install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.4.0-cp27-none-linux_x86_64.whl
easy_install ipython
pip install keras
您也可以在前面的pip install
命令中尝试 TensorFlow 1.5-1.8 下载 URL。
现在,先按cd DeepReinforcementLearning
打开run.ipynb
,然后按jupyter notebook
打开-根据您的环境,如果发现任何错误,则需要安装缺少的 Python 包。 在浏览器上,打开http://localhost:8888/notebooks/run.ipynb
,然后运行笔记本中的第一个代码块以加载所有必需的核心库,并运行第二个代码块以开始训练—该代码被编写为永远训练,因此经过数小时的训练后,您可能要取消jupyter notebook
命令。 在较旧的 Mac 上,要花一个小时才能看到在以下目录中创建的模型的第一个版本(较新的版本,例如version0004.h5
,其权重比旧版本中的权重要微调,例如 version0001.h5
):
(tf_keras) MacBook-Air:DeepReinforcementLearning jeffmbair$ ls -lt run/models
-rw-r--r-- 1 jeffmbair staff 3781664 Mar 8 15:23 version0004.h5
-rw-r--r-- 1 jeffmbair staff 3781664 Mar 8 14:59 version0003.h5
-rw-r--r-- 1 jeffmbair staff 3781664 Mar 8 14:36 version0002.h5
-rw-r--r-- 1 jeffmbair staff 3781664 Mar 8 14:12 version0001.h5
-rw-r--r-- 1 jeffmbair staff 656600 Mar 8 12:29 model.png
带有.h5
扩展名的文件是 HDF5 格式的 Keras 模型文件,每个文件主要包含模型架构定义,训练后的权重和训练配置。 稍后,您将看到如何使用 Keras 模型文件生成 TensorFlow 检查点文件,然后将其冻结为可在移动设备上运行的模型文件。
model.png
文件包含深度神经网络架构的详细视图。 卷积层的许多残差块之后是批量归一化和 ReLU 层,以稳定训练,它的深度非常大。 该模型的顶部如下图所示(中间部分很大,因此我们将不显示其中间部分,建议您打开model.png
文件以供参考):
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-PYBZOSHd-1681653119039)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/99dc16a9-52c9-4adf-bd8f-1a244504de79.png)]
图 10.1:深度残差网络的第一层
值得注意的是,神经网络称为残差网络(ResNet),由 Microsoft 于 2015 年在 ImageNet 和 COCO 2015 竞赛的获奖作品中引入。 在 ResNet 中,使用身份映射(图 10.1 右侧的箭头)可避免在网络越深时出现更高的训练误差。 有关 ResNet 的更多信息,您可以查看原始论文《用于图像识别的深度残差学习》, 以及博客《了解深度残差网络》 - 一个简单的模块化学习框架,它重新定义了构成最新技术的内容。
深度网络的最后一层如图 10.2 所示,您可以看到,在最后的残差块和具有批量归一化和 ReLU 层的卷积层之后,将应用密集的全连接层以输出value_head and policy_head
值:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-evyAiKIU-1681653119039)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/e3383760-5f06-4d10-a903-8978f41281de.png)]
图 10.2:深度 Resnet 的最后一层
在本节的最后部分,您将看到一些使用 Keras API 的 Python 代码片段,该片段对 ResNet 有着很好的支持,以构建这样的网络。 现在让我们让这些模型首先互相对抗,然后与我们一起对抗,看看这些模型有多好。
测试模型
例如,要让模型的版本 4 与版本 1 竞争,请首先通过运行mkdir -p run_archive/connect4/run0001/models
创建新的目录路径,然后将run/models
文件从run/models
复制到run0001/models
目录。 然后将DeepReinforcementLearning
目录中的play.py
更改为:
playMatchesBetweenVersions(env, 1, 1, 4, 10, lg.logger_tourney, 0)
参数1,1,4,10
的第一个值表示运行版本,因此 1 表示模型位于run_archive/connect4
的run0001/models
中。 第二个和第三个值是两个玩家的模型版本,因此 1 和 4 表示该模型的版本 1 将与版本 4 一起玩。10 是玩的次数或剧集。
运行python play.py
脚本按照指定的方式玩游戏后,可以使用以下命令找出结果:
grep WINS run/logs/logger_tourney.log |tail -10
对于与版本 1 对抗的版本 4,您可能会看到与以下内容相似的结果,这意味着它们处于大致相同的水平:
代码语言:javascript复制2018-03-14 23:55:21,001 INFO player2 WINS!
2018-03-14 23:55:58,828 INFO player1 WINS!
2018-03-14 23:56:43,778 INFO player2 WINS!
2018-03-14 23:56:51,981 INFO player1 WINS!
2018-03-14 23:57:00,985 INFO player1 WINS!
2018-03-14 23:57:30,389 INFO player2 WINS!
2018-03-14 23:57:39,742 INFO player1 WINS!
2018-03-14 23:58:19,498 INFO player2 WINS!
2018-03-14 23:58:27,554 INFO player1 WINS!
2018-03-14 23:58:36,490 INFO player1 WINS!
config.py
中有一个设置MCTS_SIMS = 50
(MCTS 的模拟次数)会对游玩时间产生重大影响。 在每个状态下,MCTS 都会进行MCTS_SIMS
次仿真,并与受过训练的网络一起提出最佳方案。 因此,将MCTS_SIMS
设置为 50 会使play.py
脚本运行更长的时间,但如果训练的模型不够好,并不一定会使玩家更强大。 在使用特定版本的模型时,可以将其更改为不同的值,以查看其如何影响其强度水平。 要手动玩一个特定版本,请将play.py
更改为:
playMatchesBetweenVersions(env, 1, 4, -1, 10, lg.logger_tourney, 0)
在这里,-1 表示人类玩家。 因此,上一行会要求您(玩家 2)与该模型的玩家 1,版本 4 对抗。 现在运行python play.py
后,您会看到输入提示Enter your chosen action:
; 打开另一个终端,转到DeepReinforcementLearning
目录,然后键入 tail -f run/logs/logger_tourney.log
命令,您将看到这样打印的电路板网格:
2018-03-15 00:03:43,907 INFO ====================
2018-03-15 00:03:43,907 INFO EPISODE 1 OF 10
2018-03-15 00:03:43,907 INFO ====================
2018-03-15 00:03:43,908 INFO player2 plays as X
2018-03-15 00:03:43,908 INFO --------------
2018-03-15 00:03:43,908 INFO ['-', '-', '-', '-', '-', '-', '-']
2018-03-15 00:03:43,908 INFO ['-', '-', '-', '-', '-', '-', '-']
2018-03-15 00:03:43,908 INFO ['-', '-', '-', '-', '-', '-', '-']
2018-03-15 00:03:43,909 INFO ['-', '-', '-', '-', '-', '-', '-']
2018-03-15 00:03:43,909 INFO ['-', '-', '-', '-', '-', '-', '-']
2018-03-15 00:03:43,909 INFO ['-', '-', '-', '-', '-', '-', '-']
请注意,最后 6 行代表 6 行乘 7 列的板格:第一行对应于 7 个动作编号 0、1、2、3、4、5、6,第二行对应于 7、8、9 10、11、12、13 等,因此最后一行映射到 35、36、37、38、39、40、41 动作编号。
现在,在运行play.py
的第一个终端中输入数字 38,该模型的版本 4 的玩家 1(打为 O)将移动,显示新的棋盘格,如下所示:
2018-03-15 00:06:13,360 INFO action: 38
2018-03-15 00:06:13,364 INFO ['-', '-', '-', '-', '-', '-', '-']
2018-03-15 00:06:13,365 INFO ['-', '-', '-', '-', '-', '-', '-']
2018-03-15 00:06:13,365 INFO ['-', '-', '-', '-', '-', '-', '-']
2018-03-15 00:06:13,365 INFO ['-', '-', '-', '-', '-', '-', '-']
2018-03-15 00:06:13,365 INFO ['-', '-', '-', '-', '-', '-', '-']
2018-03-15 00:06:13,365 INFO ['-', '-', '-', 'X', '-', '-', '-']
2018-03-15 00:06:13,366 INFO --------------
2018-03-15 00:06:15,155 INFO action: 31
2018-03-15 00:06:15,155 INFO ['-', '-', '-', '-', '-', '-', '-']
2018-03-15 00:06:15,156 INFO ['-', '-', '-', '-', '-', '-', '-']
2018-03-15 00:06:15,156 INFO ['-', '-', '-', '-', '-', '-', '-']
2018-03-15 00:06:15,156 INFO ['-', '-', '-', '-', '-', '-', '-']
2018-03-15 00:06:15,156 INFO ['-', '-', '-', 'O', '-', '-', '-']
2018-03-15 00:06:15,156 INFO ['-', '-', '-', 'X', '-', '-', '-']
在玩家 1 移至游戏结束后继续输入新动作,直到可能的新游戏开始:
代码语言:javascript复制2018-03-15 00:16:03,205 INFO action: 23
2018-03-15 00:16:03,206 INFO ['-', '-', '-', '-', '-', '-', '-']
2018-03-15 00:16:03,206 INFO ['-', '-', '-', 'O', '-', '-', '-']
2018-03-15 00:16:03,206 INFO ['-', '-', '-', 'O', 'O', 'O', '-']
2018-03-15 00:16:03,207 INFO ['-', '-', 'O', 'X', 'X', 'X', '-']
2018-03-15 00:16:03,207 INFO ['-', '-', 'X', 'O', 'X', 'O', '-']
2018-03-15 00:16:03,207 INFO ['-', '-', 'O', 'X', 'X', 'X', '-']
2018-03-15 00:16:03,207 INFO --------------
2018-03-15 00:16:14,175 INFO action: 16
2018-03-15 00:16:14,178 INFO ['-', '-', '-', '-', '-', '-', '-']
2018-03-15 00:16:14,179 INFO ['-', '-', '-', 'O', '-', '-', '-']
2018-03-15 00:16:14,179 INFO ['-', '-', 'X', 'O', 'O', 'O', '-']
2018-03-15 00:16:14,179 INFO ['-', '-', 'O', 'X', 'X', 'X', '-']
2018-03-15 00:16:14,179 INFO ['-', '-', 'X', 'O', 'X', 'O', '-']
2018-03-15 00:16:14,180 INFO ['-', '-', 'O', 'X', 'X', 'X', '-']
2018-03-15 00:16:14,180 INFO --------------
2018-03-15 00:16:14,180 INFO player2 WINS!
2018-03-15 00:16:14,180 INFO ====================
2018-03-15 00:16:14,180 INFO EPISODE 2 OF 5
这样便可以手动测试模型特定版本的强度。 了解前面板上的表示形式还可以帮助您稍后了解 iOS 和 Android 代码。 如果您过于轻易地击败模型,可以采取几种措施来尝试改善模型:
- 在
run.ipynb
(第二个代码块)Python 笔记本中运行模型几天。 在我们的测试中,该模型的版本 19 在较旧的 iMac 上运行了大约一天后,击败了版本 1 或 4 10:0(回想一下版本 1 和版本 4 处于相同水平) - 为了提高 MCTS 评分公式的强度:MCTS 在模拟过程中使用上置信度树(UCT)评分来选择要做出的举动,并且仓库中的公式是这样的(请参见博客以及 AlphaZero 官方论文以获取更多详细信息):
edge.stats['P'] * np.sqrt(Nb) / (1 edge.stats['N'])
如果我们将其更改为更类似于 DeepMind 的用法:
代码语言:javascript复制edge.stats['P'] * np.sqrt(np.log(1 Nb) / (1 edge.stats['N']))
然后,即使将MCTS_SIMS
设置为 10,版本 19 仍以 10:0 完全击败版本 1。
- 微调深度神经网络模型以尽可能接近地复制 AlphaZero
关于模型的细节不在本书的讨论范围之内,但让我们继续看看如何在 Keras 中构建模型,以便在以后在 iOS 和 Android 上运行它时更加欣赏它(您可以查看其余部分)。 agent.py
,MCTS.py
和game.py
中的主要代码,以更好地了解游戏的工作方式)。
研究模型构建代码
在model.py
中,Keras 的导入如下:
from keras.models import Sequential, load_model, Model
from keras.layers import Input, Dense, Conv2D, Flatten, BatchNormalization, Activation, LeakyReLU, add
from keras.optimizers import SGD
from keras import regularizers
四种主要的模型构建方法是:
代码语言:javascript复制def residual_layer(self, input_block, filters, kernel_size)
def conv_layer(self, x, filters, kernel_size)
def value_head(self, x)
def policy_head(self, x)
它们都具有一个或多个Conv2d
层,然后激活BatchNormalization
和LeakyReLU
,如图 10.1 所示,但是value_head
和policy_head
也具有完全连接的层,如图 10.2 所示。 卷积层以生成我们之前谈到的输入状态的预测值和策略概率。 在_build_model
方法中,定义了模型输入和输出:
main_input = Input(shape = self.input_dim, name = 'main_input')
vh = self.value_head(x)
ph = self.policy_head(x)
model = Model(inputs=[main_input], outputs=[vh, ph])
_build_model
方法中还定义了深度神经网络以及模型损失和优化器:
if len(self.hidden_layers) > 1:
for h in self.hidden_layers[1:]:
x = self.residual_layer(x, h['filters'], h['kernel_size'])
model.compile(loss={'value_head': 'mean_squared_error', 'policy_head': softmax_cross_entropy_with_logits}, optimizer=SGD(lr=self.learning_rate, momentum = config.MOMENTUM), loss_weights={'value_head': 0.5, 'policy_head': 0.5})
为了找出确切的输出节点名称(输入节点名称指定为'main_input'
),我们可以在model.py
中添加print(vh)
和print(ph)
; 现在运行的python play.py
将输出以下两行:
Tensor("value_head/Tanh:0", shape=(?, 1), dtype=float32)
Tensor("policy_head/MatMul:0", shape=(?, 42), dtype=float32)
冻结 TensorFlow 检查点文件并将模型加载到移动应用时,我们将需要它们。
冻结模型
首先,我们需要创建 TensorFlow 检查点文件–只需取消注释funcs.py
中player1
和player2
的两行,然后再次运行python play.py
:
if player1version > 0:
player1_network = player1_NN.read(env.name, run_version, player1version)
player1_NN.model.set_weights(player1_network.get_weights())
# saver = tf.train.Saver()
# saver.save(K.get_session(), '/tmp/alphazero19.ckpt')
if player2version > 0:
player2_network = player2_NN.read(env.name, run_version, player2version)
player2_NN.model.set_weights(player2_network.get_weights())
# saver = tf.train.Saver()
# saver.save(K.get_session(), '/tmp/alphazero_4.ckpt')
您可能会觉得很熟悉,因为我们在第 8 章,“使用 RNN 预测股票价格”做了类似的操作。 确保将alphazero19.ckpt
和alphazero_4.ckpt
中的版本号(例如 19 或 4)与play.py
中定义的内容(例如playMatchesBetweenVersions(env, 1, 19, 4, 10, lg.logger_tourney, 0)
)以及 run_archive/connect4/run0001/models
目录中的版本号匹配。在这种情况下, version0019.h5
和 version0004.h5
都需要存在。
运行play.py
后,将在/tmp
目录中生成alphazero19
检查点文件:
-rw-r--r-- 1 jeffmbair wheel 99 Mar 13 18:17 checkpoint
-rw-r--r-- 1 jeffmbair wheel 1345545 Mar 13 18:17 alphazero19.ckpt.meta
-rw-r--r-- 1 jeffmbair wheel 7296096 Mar 13 18:17 alphazero19.ckpt.data-00000-of-00001
-rw-r--r-- 1 jeffmbair wheel 8362 Mar 13 18:17 alphazero19.ckpt.index
现在,您可以转到 TensorFlow 根源目录并运行freeze_graph
脚本:
python tensorflow/python/tools/freeze_graph.py
--input_meta_graph=/tmp/alphazero19.ckpt.meta
--input_checkpoint=/tmp/alphazero19.ckpt
--output_graph=/tmp/alphazero19.pb
--output_node_names="value_head/Tanh,policy_head/MatMul"
--input_binary=true
为简单起见,由于它是小型模型,因此我们不会我们不会进行图变换和内存映射变换,就像第 6 章,“用自然语言描述图像”和第 9 章,“用 GAN 生成和增强图像”。 现在,我们准备在移动设备上使用该模型并编写代码以在 iOS 和 Android 设备上玩 Connect4。
在 iOS 中使用模型玩 Connect4
对于新冻结的,可选的经过转换和映射的模型,您始终可以将其与 TensorFlow Pod 一起尝试,以查看是否有幸能够以简单的方式使用它。 在我们的案例中,当使用 TensorFlow Pod 加载它时,我们生成的alphazero19.pb
模型会导致以下错误:
Couldn't load model: Invalid argument: No OpKernel was registered to support Op 'Switch' with these attrs. Registered devices: [CPU], Registered kernels:
device='GPU'; T in [DT_FLOAT]
device='GPU'; T in [DT_INT32]
device='GPU'; T in [DT_BOOL]
device='GPU'; T in [DT_STRING]
device='CPU'; T in [DT_INT32]
device='CPU'; T in [DT_FLOAT]
[[Node: batch_normalization_13/cond/Switch = Switch[T=DT_BOOL, _output_shapes=[[], []]](batch_normalization_1/keras_learning_phase, batch_normalization_1/keras_learning_phase)]]
您应该已经知道如何解决这种类型的错误,因为前面的章节已经对此进行了讨论。 回顾一下,只需确保tensorflow/contrib/makefile/tf_op_files.txt
文件中包含Switch
操作的内核文件。 您可以通过运行grep 'REGISTER.*"Switch"' tensorflow/core/kernels/*.cc
来查找哪个Switch
内核文件,该文件应显示tensorflow/core/kernels/control_flow_ops.cc
。 默认情况下,从 TensorFlow 1.4 开始, tf_op_files.txt
中包含 control_flow_ops.cc
文件,因此您所需要做的就是通过运行tensorflow/contrib/makefile/build_all_ios.sh
来构建 TensorFlow iOS 自定义库。 如果您已在上一章中成功运行了 iOS 应用,则该库已经不错,您不需要或不想再次运行耗时的命令。
现在,只需创建一个名为 AlphaZero 的新 Xcode iOS 项目,然后将上一章中的 iOS 项目中的tensorflow_utils.mm
和tensorflow_utils.h
文件以及上一节中生成的alphazero19.pb
模型文件拖放到项目。 将ViewController.m
重命名为ViewController.mm
,并添加一些常量和变量。 您的项目应如图 10.3 所示:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4AUhovAY-1681653119040)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/a0639a2a-55f3-485a-92cd-4ffdc99b398c.png)]
图 10.3:在 Xcode 中显示 AlphaZero iOS 应用
我们只需要使用三个 UI 组件:
- 一个
UIImageView
,显示棋盘和演奏的棋子。 - 显示游戏结果并提示用户采取措施的
UILabel
。 - 一个
UIButton
可以玩或重玩游戏。 和以前一样,我们以编程方式在viewDidLoad
方法中创建和定位它们。
轻按游玩或重放按钮时,随机决定谁先走,重置表示为整数数组的棋盘,清除存储我们的移动和 AI 的移动的两个向量,以及重新绘制原始板格:
代码语言:javascript复制 int n = rand() % 2;
aiFirst = (n==0);
if (aiFirst) aiTurn = true;
else aiTurn = false;
for (int i=0; i<PIECES_NUM; i )
board[i] = 0;
aiMoves.clear();
humanMoves.clear();
_iv.image = [self createBoardImageInRect:_iv.frame];
然后在辅助线程上开始游戏:
代码语言:javascript复制 dispatch_async(dispatch_get_global_queue(0, 0), ^{
std::string result = playGame(withMCTS);
dispatch_async(dispatch_get_main_queue(), ^{
NSString *rslt = [NSString stringWithCString:result.c_str() encoding:[NSString defaultCStringEncoding]];
[_lbl setText:rslt];
_iv.image = [self createBoardImageInRect:_iv.frame];
});
});
在playGame
方法中,首先检查是否已经加载了我们的模型,如果没有加载,则进行加载:
string playGame(bool withMCTS) {
if (!_modelLoaded) {
tensorflow::Status load_status;
load_status = LoadModel(MODEL_FILE, MODEL_FILE_TYPE, &tf_session);
if (!load_status.ok()) {
LOG(FATAL) << "Couldn't load model: " << load_status;
return "";
}
_modelLoaded = YES;
}
如果轮到我们了,请返回并告诉我们。 否则,按照模型的期望将板状态转换为二进制格式的输入:
代码语言:javascript复制 if (!aiTurn) return "Tap the column for your move";
int binary[PIECES_NUM*2];
for (int i=0; i<PIECES_NUM; i )
if (board[i] == 1) binary[i] = 1;
else binary[i] = 0;
for (int i=0; i<PIECES_NUM; i )
if (board[i] == -1) binary[42 i] = 1;
else binary[PIECES_NUM i] = 0;
例如,如果板数组为[0 1 1 -1 1 -1 0 0 1 -1 -1 -1 -1 1 0 0 1 -1 1 -1 1 0 0 -1 -1 -1 1 -1 0 1 1 1 -1 -1 -1 -1 1 1 1 -1 1 1 -1]
,代表以下板状态(X
表示 1,O
表示 -1,-
表示 0):
['-', 'X', 'X', 'O', 'X', 'O', '-']
['-', 'X', 'O', 'O', 'O', 'X', '-']
['-', 'X', 'O', 'X', 'O', 'X', '-']
['-', 'O', 'O', 'O', 'X', 'O', '-']
['X', 'X', 'X', 'O', 'O', 'O', 'O']
['X', 'X', 'X', 'O', 'X', 'X', 'O']
然后,使用前面的代码段构建的二进制数组将为[0 1 1 0 1 0 0 0 0 0 0 0 1 0 0 1 0 1 0 1 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 1 1 1 0 1 1 0 0 0 0 1 0 1 0 0 0 1 1 1 0 0 0 0 1 0 1 0 0 0 1 1 1 0 1 0 0 0 0 1 1 1 1 0 0 0 1 0 0 1]
,它在板上编码两个玩家的棋子。
仍然在playGame
方法中,调用getProbs
方法,该方法使用binary
输入运行冻结的模型,并在probs
中返回概率策略,并在策略中找到最大概率值:
float *probs = new float[PIECES_NUM];
for (int i=0; i<PIECES_NUM; i )
probs[i] = -100.0;
if (getProbs(binary, probs)) {
int action = -1;
float max = 0.0;
for (int i=0; i<PIECES_NUM; i ) {
if (probs[i] > max) {
max = probs[i];
action = i;
}
}
我们将所有probs
数组元素初始化为 -100.0 的原因是,在getProbs
方法内部(我们将很快显示),probs
数组将仅针对允许的操作更改为由策略返回的值(所有 -1.0 到 1.0 之间的小值),因此所有非法行为的probs
值将保持为 -100.0,并且在softmax
函数之后,这使得非法移动的可能性基本为零,我们可以使用合法行动的可能性。
我们仅使用最大概率值来指导 AI 的移动,而不使用 MCTS,如果我们希望 AI 在象棋或围棋这样的复杂游戏中真正强大,这将是必要的。 如前所述,如果从经过训练的模型返回的策略是完美的,则无需使用 MCTS。 我们将在书的源代码存储库中保留 MCTS 实现,以供您参考,而不是显示 MCTS 的所有实现细节。
playGame
方法中的其余代码根据模型返回的所有合法动作中的最大概率,以选定的动作来更新木板,将printBoard
辅助方法调用来在 Xcode 输出面板上打印板以进行更好的调试,将动作添加到 aiMoves
向量中,以便可以正确重绘板,并在游戏结束时返回正确的状态信息。 通过将 aiTurn
设置为 false
,您将很快看到的触摸事件处理器将接受人类的触摸手势,作为人类打算采取的动作; 如果 aiTurn
为 true
,则触摸处理器将忽略所有触摸手势:
board[action] = AI_PIECE;
printBoard(board);
aiMoves.push_back(action);
delete []probs;
if (aiWon(board)) return "AI Won!";
else if (aiLost(board)) return "You Won!";
else if (aiDraw(board)) return "Draw";
} else {
delete []probs;
}
aiTurn = false;
return "Tap the column for your move";
}
printBoard
辅助方法如下:
void printBoard(int bd[]) {
for (int i = 0; i<6; i ) {
for (int j=0; j<7; j ) {
cout << PIECE_SYMBOL[bd[i*7 j]] << " ";
}
cout << endl;
}
cout << endl << endl;
}
因此,在 Xcode 输出面板中,它将打印出如下内容:
代码语言:javascript复制- - - - - - -
- - - - - - -
- - O - - - -
X - O - - - O
O O O X X - X
X X O O X - X
在getProbs
键方法中,首先定义输入和输出节点名称,然后使用binary
中的值准备输入张量:
bool getProbs(int *binary, float *probs) {
std::string input_name = "main_input";
std::string output_name1 = "value_head/Tanh";
std::string output_name2 = "policy_head/MatMul";
tensorflow::Tensor input_tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({1,2,6,7}));
auto input_mapped = input_tensor.tensor<float, 4>();
for (int i = 0; i < 2; i ) {
for (int j = 0; j<6; j ) {
for (int k=0; k<7; k ) {
input_mapped(0,i,j,k) = binary[i*42 j*7 k];
}
}
}
现在使用输入运行模型并获取输出:
代码语言:javascript复制 std::vector<tensorflow::Tensor> outputs;
tensorflow::Status run_status = tf_session->Run({{input_name, input_tensor}}, {output_name1, output_name2}, {}, &outputs);
if (!run_status.ok()) {
LOG(ERROR) << "Getting model failed:" << run_status;
return false;
}
tensorflow::Tensor* value_tensor = &outputs[0];
tensorflow::Tensor* policy_tensor = &outputs[1];
const Eigen::TensorMap<Eigen::Tensor<float, 1, Eigen::RowMajor>, Eigen::Aligned>& value = value_tensor->flat<float>();
const Eigen::TensorMap<Eigen::Tensor<float, 1, Eigen::RowMajor>, Eigen::Aligned>& policy = policy_tensor->flat<float>();
仅设置允许动作的概率值,然后调用softmax
以使允许动作的probs
值之和为 1:
vector<int> actions;
getAllowedActions(board, actions);
for (int action : actions) {
probs[action] = policy(action);
}
softmax(probs, PIECES_NUM);
return true;
}
getAllowedActions
函数定义如下:
void getAllowedActions(int bd[], vector<int> &actions) {
for (int i=0; i<PIECES_NUM; i ) {
if (i>=PIECES_NUM-7) {
if (bd[i] == 0)
actions.push_back(i);
}
else {
if (bd[i] == 0 && bd[i 7] != 0)
actions.push_back(i);
}
}
}
以下是softmax
函数,它们都很简单:
void softmax(float vals[], int count) {
float max = -FLT_MAX;
for (int i=0; i<count; i ) {
max = fmax(max, vals[i]);
}
float sum = 0.0;
for (int i=0; i<count; i ) {
vals[i] = exp(vals[i] - max);
sum = vals[i];
}
for (int i=0; i<count; i ) {
vals[i] /= sum;
}
}
定义了其他一些辅助函数来测试游戏结束状态:
代码语言:javascript复制bool aiWon(int bd[]) {
for (int i=0; i<69; i ) {
int sum = 0;
for (int j=0; j<4; j )
sum = bd[winners[i][j]];
if (sum == 4*AI_PIECE ) return true;
}
return false;
}
bool aiLost(int bd[]) {
for (int i=0; i<69; i ) {
int sum = 0;
for (int j=0; j<4; j )
sum = bd[winners[i][j]];
if (sum == 4*HUMAN_PIECE ) return true;
}
return false;
}
bool aiDraw(int bd[]) {
bool hasZero = false;
for (int i=0; i<PIECES_NUM; i ) {
if (bd[i] == 0) {
hasZero = true;
break;
}
}
if (!hasZero) return true;
return false;
}
bool gameEnded(int bd[]) {
if (aiWon(bd) || aiLost(bd) || aiDraw(bd)) return true;
return false;
}
aiWon
和aiLost
函数都使用一个常量数组,该数组定义了所有 69 个可能的获胜位置:
int winners[69][4] = {
{0,1,2,3},
{1,2,3,4},
{2,3,4,5},
{3,4,5,6},
{7,8,9,10},
{8,9,10,11},
{9,10,11,12},
{10,11,12,13},
......
{3,11,19,27},
{2,10,18,26},
{10,18,26,34},
{1,9,17,25},
{9,17,25,33},
{17,25,33,41},
{0,8,16,24},
{8,16,24,32},
{16,24,32,40},
{7,15,23,31},
{15,23,31,39},
{14,22,30,38}};
在触摸事件处理器中,首先确保轮到人了。 然后检查触摸点值是否在面板区域内,根据触摸位置获取点击的列,并更新board
数组和humanMoves
向量:
- (void) touchesEnded:(NSSet *)touches withEvent:(UIEvent *)event {
if (aiTurn) return;
UITouch *touch = [touches anyObject];
CGPoint point = [touch locationInView:self.view];
if (point.y < startY || point.y > endY) return;
int column = (point.x-startX)/BOARD_COLUMN_WIDTH;
for (int i=0; i<6; i )
if (board[35 column-7*i] == 0) {
board[35 column-7*i] = HUMAN_PIECE;
humanMoves.push_back(35 column-7*i);
break;
}
其余触摸处理器通过调用createBoardImageInRect
来重绘ImageView
,它使用BezierPath
绘制或重绘棋盘和所有已玩过的棋子,检查游戏状态并在游戏结束时返回结果,或者继续玩游戏,如果没有:
_iv.image = [self createBoardImageInRect:_iv.frame];
aiTurn = true;
if (gameEnded(board)) {
if (aiWon(board)) _lbl.text = @"AI Won!";
else if (aiLost(board)) _lbl.text = @"You Won!";
else if (aiDraw(board)) _lbl.text = @"Draw";
return;
}
dispatch_async(dispatch_get_global_queue(0, 0), ^{
std::string result = playGame(withMCTS));
dispatch_async(dispatch_get_main_queue(), ^{
NSString *rslt = [NSString stringWithCString:result.c_str() encoding:[NSString defaultCStringEncoding]];
[_lbl setText:rslt];
_iv.image = [self createBoardImageInRect:_iv.frame];
});
});
}
其余的 iOS 代码全部在createBoardImageInRect
方法中,该方法使用 UIBezierPath
中的moveToPoint
和addLineToPoint
方法绘制面板:
- (UIImage *)createBoardImageInRect:(CGRect)rect
{
int margin_y = 170;
UIGraphicsBeginImageContextWithOptions(CGSizeMake(rect.size.width, rect.size.height), NO, 0.0);
UIBezierPath *path = [UIBezierPath bezierPath];
startX = (rect.size.width - 7*BOARD_COLUMN_WIDTH)/2.0;
startY = rect.origin.y margin_y 30;
endY = rect.origin.y - margin_y rect.size.height;
for (int i=0; i<8; i ) {
CGPoint point = CGPointMake(startX i * BOARD_COLUMN_WIDTH, startY);
[path moveToPoint:point];
point = CGPointMake(startX i * BOARD_COLUMN_WIDTH, endY);
[path addLineToPoint:point];
}
CGPoint point = CGPointMake(startX, endY);
[path moveToPoint:point];
point = CGPointMake(rect.size.width - startX, endY);
[path addLineToPoint:point];
path.lineWidth = BOARD_LINE_WIDTH;
[[UIColor blueColor] setStroke];
[path stroke];
bezierPathWithOvalInRect
方法绘制由 AI 和人工移动的所有碎片–根据谁先采取行动,它开始交替绘制碎片,但顺序不同:
int columnPieces[] = {0,0,0,0,0,0,0};
if (aiFirst) {
for (int i=0; i<aiMoves.size(); i ) {
int action = aiMoves[i];
int column = action % 7;
CGRect r = CGRectMake(startX column * BOARD_COLUMN_WIDTH, endY - BOARD_COLUMN_WIDTH - BOARD_COLUMN_WIDTH * columnPieces[column], BOARD_COLUMN_WIDTH, BOARD_COLUMN_WIDTH);
UIBezierPath *path = [UIBezierPath bezierPathWithOvalInRect:r];
UIColor *color = [UIColor redColor];
[color setFill];
[path fill];
columnPieces[column] ;
if (i<humanMoves.size()) {
int action = humanMoves[i];
int column = action % 7;
CGRect r = CGRectMake(startX column * BOARD_COLUMN_WIDTH, endY - BOARD_COLUMN_WIDTH - BOARD_COLUMN_WIDTH * columnPieces[column], BOARD_COLUMN_WIDTH, BOARD_COLUMN_WIDTH);
UIBezierPath *path = [UIBezierPath bezierPathWithOvalInRect:r];
UIColor *color = [UIColor yellowColor];
[color setFill];
[path fill];
columnPieces[column] ;
}
}
}
else {
for (int i=0; i<humanMoves.size(); i ) {
int action = humanMoves[i];
int column = action % 7;
CGRect r = CGRectMake(startX column * BOARD_COLUMN_WIDTH, endY - BOARD_COLUMN_WIDTH - BOARD_COLUMN_WIDTH * columnPieces[column], BOARD_COLUMN_WIDTH, BOARD_COLUMN_WIDTH);
UIBezierPath *path = [UIBezierPath bezierPathWithOvalInRect:r];
UIColor *color = [UIColor yellowColor];
[color setFill];
[path fill];
columnPieces[column] ;
if (i<aiMoves.size()) {
int action = aiMoves[i];
int column = action % 7;
CGRect r = CGRectMake(startX column * BOARD_COLUMN_WIDTH, endY - BOARD_COLUMN_WIDTH - BOARD_COLUMN_WIDTH * columnPieces[column], BOARD_COLUMN_WIDTH, BOARD_COLUMN_WIDTH);
UIBezierPath *path = [UIBezierPath bezierPathWithOvalInRect:r];
UIColor *color = [UIColor redColor];
[color setFill];
[path fill];
columnPieces[column] ;
}
}
}
UIImage *image = UIGraphicsGetImageFromCurrentImageContext();
UIGraphicsEndImageContext();
return image;
}
现在运行该应用,您将看到类似于图 10.4 的屏幕:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-kOcc3SVn-1681653119041)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/fa949c2a-3114-4a56-9041-9acb069b1ff4.png)]
图 10.4:在 iOS 上玩 Connect4
使用 AI 玩一些游戏,图 10.5 显示了一些可能的最终游戏:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-b6oYmCmM-1681653119041)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/1e45df24-e772-462b-aef9-67c8f49bed67.png)]
图 10.5:iOS 上 Connect4 的一些游戏结果
在我们暂停之前,让我们快速看一下使用该模型并玩游戏的 Android 代码。
在 Android 中使用模型玩 Connect4
毫不奇怪,我们不需要像第 7 章,“使用 CNN 和 LSTM 识别绘画”那样使用自定义 Android 库来加载模型。 只需创建一个名称为 AlphaZero 的新 Android Studio 应用,将alphazero19.pb
模型文件复制到新创建的素材资源文件夹,然后将 compile 'org.tensorflow:tensorflow-android: '
行添加到应用的build.gradle
文件。
我们首先创建一个新类BoardView
,该类扩展了View
并负责绘制游戏板以及 AI 和用户制作的棋子:
public class BoardView extends View {
private Path mPathBoard, mPathAIPieces, mPathHumanPieces;
private Paint mPaint, mCanvasPaint;
private Canvas mCanvas;
private Bitmap mBitmap;
private MainActivity mActivity;
private static final float MARGINX = 20.0f;
private static final float MARGINY = 210.0f;
private float endY;
private float columnWidth;
public BoardView(Context context, AttributeSet attrs) {
super(context, attrs);
mActivity = (MainActivity) context;
setPathPaint();
}
我们使用了mPathBoard
,mPathAIPieces
和mPathHumanPieces
这三个Path
实例分别绘制了板子,AI 做出的动作和人类做出的不同颜色的。 。 BoardView
的绘制功能是通过Path
的moveTo
和lineTo
方法以及Canvas
的drawPath
方法在onDraw
方法中实现的:
protected void onDraw(Canvas canvas) {
canvas.drawBitmap(mBitmap, 0, 0, mCanvasPaint);
columnWidth = (canvas.getWidth() - 2*MARGINX) / 7.0f;
for (int i=0; i<8; i ) {
float x = MARGINX i * columnWidth;
mPathBoard.moveTo(x, MARGINY);
mPathBoard.lineTo(x, canvas.getHeight()-MARGINY);
}
mPathBoard.moveTo(MARGINX, canvas.getHeight()-MARGINY);
mPathBoard.lineTo(MARGINX 7*columnWidth, canvas.getHeight()-
MARGINY);
mPaint.setColor(0xFF0000FF);
canvas.drawPath(mPathBoard, mPaint);
如果 AI 首先移动,我们开始绘制第一个 AI 移动,然后绘制第一个人类移动(如果有的话),并交替绘制 AI 和人类移动的图形:
代码语言:javascript复制 endY = canvas.getHeight()-MARGINY;
int columnPieces[] = {0,0,0,0,0,0,0};
for (int i=0; i<mActivity.getAIMoves().size(); i ) {
int action = mActivity.getAIMoves().get(i);
int column = action % 7;
float x = MARGINX column * columnWidth columnWidth /
2.0f;
float y = canvas.getHeight()-MARGINY-
columnWidth*columnPieces[column]-columnWidth/2.0f;
mPathAIPieces.addCircle(x,y, columnWidth/2,
Path.Direction.CW);
mPaint.setColor(0xFFFF0000);
canvas.drawPath(mPathAIPieces, mPaint);
columnPieces[column] ;
if (i<mActivity.getHumanMoves().size()) {
action = mActivity.getHumanMoves().get(i);
column = action % 7;
x = MARGINX column * columnWidth columnWidth /
2.0f;
y = canvas.getHeight()-MARGINY-
columnWidth*columnPieces[column]-columnWidth/2.0f;
mPathHumanPieces.addCircle(x,y, columnWidth/2,
Path.Direction.CW);
mPaint.setColor(0xFFFFFF00);
canvas.drawPath(mPathHumanPieces, mPaint);
columnPieces[column] ;
}
}
如果人先移动,则将应用类似的绘图代码,如 iOS 代码中一样。 在BoardView
的public boolean onTouchEvent(MotionEvent event)
内部,如果轮到 AI 了,则返回它,我们检查哪一列已被挖掘,并且如果该列还没有被全部六个可能的片断填满,则将新的人工移动添加到humanMoves
MainActivity
的向量,然后重绘视图:
public boolean onTouchEvent(MotionEvent event) {
if (mActivity.getAITurn()) return true;
float x = event.getX();
float y = event.getY();
switch (event.getAction()) {
case MotionEvent.ACTION_DOWN:
break;
case MotionEvent.ACTION_MOVE:
break;
case MotionEvent.ACTION_UP:
if (y < MARGINY || y > endY) return true;
int column = (int)((x-MARGINX)/columnWidth);
for (int i=0; i<6; i )
if (mActivity.getBoard()[35 column-7*i] == 0) {
mActivity.getBoard()[35 column-7*i] =
MainActivity.HUMAN_PIECE;
mActivity.getHumanMoves().add(35 column-7*i);
break;
}
invalidate();
之后,将回合设置为 AI,如果游戏结束则返回。 否则,在人类可以触摸并选择下一步动作之前,让 AI 根据模型的策略返回进行下一步动作,以启动新线程继续玩游戏:
代码语言:javascript复制 mActivity.setAiTurn();
if (mActivity.gameEnded(mActivity.getBoard())) {
if (mActivity.aiWon(mActivity.getBoard()))
mActivity.getTextView().setText("AI Won!");
else if (mActivity.aiLost(mActivity.getBoard()))
mActivity.getTextView().setText("You Won!");
else if (mActivity.aiDraw(mActivity.getBoard()))
mActivity.getTextView().setText("Draw");
return true;
}
Thread thread = new Thread(mActivity);
thread.start();
break;
default:
return false;
}
return true;
}
UI 的主要布局是在activity_main.xml
中定义的,它由三个 UI 元素组成:TextView
,自定义BoardView
和Button
:
<TextView
android:id="@ id/textview"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text=""
android:textAlignment="center"
android:textColor="@color/colorPrimary"
android:textSize="24sp"
android:textStyle="bold"
app:layout_constraintBottom_toBottomOf="parent"
app:layout_constraintHorizontal_bias="0.5"
app:layout_constraintLeft_toLeftOf="parent"
app:layout_constraintRight_toRightOf="parent"
app:layout_constraintTop_toTopOf="parent"
app:layout_constraintVertical_bias="0.06"/>
<com.ailabby.alphazero.BoardView
android:id="@ id/boardview"
android:layout_width="fill_parent"
android:layout_height="fill_parent"
app:layout_constraintBottom_toBottomOf="parent"
app:layout_constraintLeft_toLeftOf="parent"
app:layout_constraintRight_toRightOf="parent"
app:layout_constraintTop_toTopOf="parent"/>
<Button
android:id="@ id/button"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="Play"
app:layout_constraintBottom_toBottomOf="parent"
app:layout_constraintHorizontal_bias="0.5"
app:layout_constraintLeft_toLeftOf="parent"
app:layout_constraintRight_toRightOf="parent"
app:layout_constraintTop_toTopOf="parent"
app:layout_constraintVertical_bias="0.94" />
在MainActivity.java
中,首先定义一些常量和字段:
public class MainActivity extends AppCompatActivity implements Runnable {
private static final String MODEL_FILE =
"file:///android_asset/alphazero19.pb";
private static final String INPUT_NODE = "main_input";
private static final String OUTPUT_NODE1 = "value_head/Tanh";
private static final String OUTPUT_NODE2 = "policy_head/MatMul";
private Button mButton;
private BoardView mBoardView;
private TextView mTextView;
public static final int AI_PIECE = -1;
public static final int HUMAN_PIECE = 1;
private static final int PIECES_NUM = 42;
private Boolean aiFirst = false;
private Boolean aiTurn = false;
private Vector<Integer> aiMoves = new Vector<>();
private Vector<Integer> humanMoves = new Vector<>();
private int board[] = new int[PIECES_NUM];
private static final HashMap<Integer, String> PIECE_SYMBOL;
static
{
PIECE_SYMBOL = new HashMap<Integer, String>();
PIECE_SYMBOL.put(AI_PIECE, "X");
PIECE_SYMBOL.put(HUMAN_PIECE, "O");
PIECE_SYMBOL.put(0, "-");
}
private TensorFlowInferenceInterface mInferenceInterface;
然后像在 iOS 版本的应用中一样定义所有获胜职位:
代码语言:javascript复制 private final int winners[][] = {
{0,1,2,3},
{1,2,3,4},
{2,3,4,5},
{3,4,5,6},
{7,8,9,10},
{8,9,10,11},
{9,10,11,12},
{10,11,12,13},
...
{0,8,16,24},
{8,16,24,32},
{16,24,32,40},
{7,15,23,31},
{15,23,31,39},
{14,22,30,38}};
BoardView
类使用的一些获取器和设置器:
public boolean getAITurn() {
return aiTurn;
}
public boolean getAIFirst() {
return aiFirst;
}
public Vector<Integer> getAIMoves() {
return aiMoves;
}
public Vector<Integer> getHumanMoves() {
return humanMoves;
}
public int[] getBoard() {
return board;
}
public void setAiTurn() {
aiTurn = true;
}
还有一些助手,它们是 iOS 代码的直接端口,用于检查游戏状态:
代码语言:javascript复制 public boolean aiWon(int bd[]) {
for (int i=0; i<69; i ) {
int sum = 0;
for (int j=0; j<4; j )
sum = bd[winners[i][j]];
if (sum == 4*AI_PIECE ) return true;
}
return false;
}
public boolean aiLost(int bd[]) {
for (int i=0; i<69; i ) {
int sum = 0;
for (int j=0; j<4; j )
sum = bd[winners[i][j]];
if (sum == 4*HUMAN_PIECE ) return true;
}
return false;
}
public boolean aiDraw(int bd[]) {
boolean hasZero = false;
for (int i=0; i<PIECES_NUM; i ) {
if (bd[i] == 0) {
hasZero = true;
break;
}
}
if (!hasZero) return true;
return false;
}
public boolean gameEnded(int[] bd) {
if (aiWon(bd) || aiLost(bd) || aiDraw(bd)) return true;
return false;
}
getAllowedActions
方法(也是 iOS 代码的直接端口)将给定板位置的所有允许的动作设置为actions
向量:
void getAllowedActions(int bd[], Vector<Integer> actions) {
for (int i=0; i<PIECES_NUM; i ) {
if (i>=PIECES_NUM-7) {
if (bd[i] == 0)
actions.add(i);
}
else {
if (bd[i] == 0 && bd[i 7] != 0)
actions.add(i);
}
}
}
在onCreate
方法中,实例化三个 UI 元素,并设置按钮单击监听器,以便它随机决定谁先采取行动。 当用户想要重玩游戏时,也会点击该按钮,因此我们需要在绘制面板和启动线程进行游戏之前重置aiMoves
和humanMoves
向量:
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
mButton = findViewById(R.id.button);
mTextView = findViewById(R.id.textview);
mBoardView = findViewById(R.id.boardview);
mButton.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View v) {
mButton.setText("Replay");
mTextView.setText("");
Random rand = new Random();
int n = rand.nextInt(2);
aiFirst = (n==0);
if (aiFirst) aiTurn = true;
else aiTurn = false;
if (aiTurn)
mTextView.setText("Waiting for AI's move");
else
mTextView.setText("Tap the column for your move");
for (int i=0; i<PIECES_NUM; i )
board[i] = 0;
aiMoves.clear();
humanMoves.clear();
mBoardView.drawBoard();
Thread thread = new Thread(MainActivity.this);
thread.start();
}
});
}
线程启动run
方法,该方法进一步调用playGame
方法,首先将板的位置转换为binary
整数数组,以用作模型的输入:
public void run() {
final String result = playGame();
runOnUiThread(
new Runnable() {
@Override
public void run() {
mBoardView.invalidate();
mTextView.setText(result);
}
});
}
String playGame() {
if (!aiTurn) return "Tap the column for your move";
int binary[] = new int[PIECES_NUM*2];
for (int i=0; i<PIECES_NUM; i )
if (board[i] == 1) binary[i] = 1;
else binary[i] = 0;
for (int i=0; i<PIECES_NUM; i )
if (board[i] == -1) binary[42 i] = 1;
else binary[PIECES_NUM i] = 0;
playGame
方法的其余部分也几乎是 iOS 代码的直接端口,它调用getProbs
方法以使用为所有操作返回的概率值来获取所有允许的操作中的最大概率值, 该模型的策略输出中总共包括 42 个法律和非法的:
float probs[] = new float[PIECES_NUM];
for (int i=0; i<PIECES_NUM; i )
probs[i] = -100.0f;
getProbs(binary, probs);
int action = -1;
float max = 0.0f;
for (int i=0; i<PIECES_NUM; i ) {
if (probs[i] > max) {
max = probs[i];
action = i;
}
}
board[action] = AI_PIECE;
printBoard(board);
aiMoves.add(action);
if (aiWon(board)) return "AI Won!";
else if (aiLost(board)) return "You Won!";
else if (aiDraw(board)) return "Draw";
aiTurn = false;
return "Tap the column for your move";
}
如果尚未加载getProbs
方法,则加载模型;使用当前板状态作为输入运行模型;并在调用softmax
以获得真实概率值之前获取输出策略,该值之和对于允许的动作为 1:
void getProbs(int binary[], float probs[]) {
if (mInferenceInterface == null) {
AssetManager assetManager = getAssets();
mInferenceInterface = new
TensorFlowInferenceInterface(assetManager, MODEL_FILE);
}
float[] floatValues = new float[2`6`7];
for (int i=0; i<2`6`7; i ) {
floatValues[i] = binary[i];
}
float[] value = new float[1];
float[] policy = new float[42];
mInferenceInterface.feed(INPUT_NODE, floatValues, 1, 2, 6, 7);
mInferenceInterface.run(new String[] {OUTPUT_NODE1, OUTPUT_NODE2},
false);
mInferenceInterface.fetch(OUTPUT_NODE1, value);
mInferenceInterface.fetch(OUTPUT_NODE2, policy);
Vector<Integer> actions = new Vector<>();
getAllowedActions(board, actions);
for (int action : actions) {
probs[action] = policy[action];
}
softmax(probs, PIECES_NUM);
}
softmax
方法的定义与 iOS 版本中的定义几乎相同:
void softmax(float vals[], int count) {
float maxval = -Float.MAX_VALUE;
for (int i=0; i<count; i ) {
maxval = max(maxval, vals[i]);
}
float sum = 0.0f;
for (int i=0; i<count; i ) {
vals[i] = (float)exp(vals[i] - maxval);
sum = vals[i];
}
for (int i=0; i<count; i ) {
vals[i] /= sum;
}
}
现在,在 Android 虚拟或真实设备上运行该应用并使用该应用进行游戏,您将看到初始屏幕和一些游戏结果:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-aUqRx2QC-1681653119041)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/2d40ca7c-5d3d-4586-9034-2b2ae1c71ecb.png)]
图 10.6:在 Android 上显示游戏板和一些结果
当您使用前面的代码在 iOS 和 Android 上玩游戏时,很快就会发现该模型返回的策略并不强大-主要原因是 MCTS 没有出现在这里,由于范围限制,不会与深度神经网络模型一起使用。 强烈建议您自己研究和实现 MCTS,或者在源代码存储库中使用我们的实现作为参考。 您还应该将网络模型和 MCTS 应用于您感兴趣的其他游戏-毕竟,AlphaZero 使用了通用 MCTS 和无领域知识的自我强化学习,从而使超人学习轻松移植到其他问题领域。 通过将 MCTS 与深度神经网络模型结合,您可以实现 AlphaZero 所做的事情。
总结
在本章中,我们介绍了 AlphaZero 的惊人世界,这是 DeepMind 截至 2017 年 12 月的最新和最大成就。我们向您展示了如何使用功能强大的 Keras API 和 TensorFlow 后端为 Connect4 训练类似 AlphaZero 的模型,以及如何测试并可能改善这种模型。 然后,我们冻结了该模型,并详细介绍了如何构建 iOS 和 Android 应用以使用该模型,以及如何使用基于模型的 AI 玩 Connect4。 尚不能完全击败人类象棋或 GO 冠军的确切 AlphaZero 模型,但我们希望本章为您提供扎实的基础,并激发您继续进行工作,以复制 AlphaZero 最初所做的工作并将其进一步扩展到其他问题领域。 这将需要很多努力,但完全值得。
如果最新的 AI 进展(例如 AlphaZero)使您兴奋不已,那么您还可能会发现由 TensorFlow 驱动的最新移动平台解决方案或工具包令人兴奋。 如我们在第 1 章“移动 TensorFlow 入门”中提到的,TensorFlow Lite 是 TensorFlow Mobile 的替代解决方案,我们在前面的所有章节中都有介绍。 根据 Google 的说法,TensorFlow Lite 将成为 TensorFlow 在移动设备上的未来,尽管在此时和可预见的将来,TensorFlow Mobile 仍应用于生产场合。
虽然 TensorFlow Lite 在 iOS 和 Android 上均可使用,但在 Android 设备上运行时,它也可以利用 Android Neural Networks API 进行硬件加速。 另一方面,iOS 开发人员可以利用 Core ML, Apple 针对 iOS 11 或更高版本的最新机器学习框架,该框架支持运行许多强大的预训练深度学习模型,以及使用经典的机器学习算法和 Keras,以优化的方式在设备上以最小的应用二进制文件大小运行。 在下一章中,我们将介绍如何在 iOS 和 Android 应用中使用 TensorFlow Lite 和 Core ML。