TensorFlow:使用Cloud TPU在30分钟内训练出实时移动对象检测器

2018-07-27 14:38:11 浏览数 (1)

编译:yxy

出品:ATYUN订阅号

是否能够更快地训练和提供对象检测模型?我们已经听到了这种的反馈,在今天我们很高兴地宣布支持训练Cloud TPU上的对象检测模型,模型量化以及并添加了包括RetinaNet和MobileNet改编的RetinaNet在内的新模型。本文将引导你使用迁移学习在Cloud TPU上训练量化的宠物品种检测器。

公告:https://ai.googleblog.com/2018/07/accelerated-training-and-inference-with.html

整个过程,从训练到在Android设备上推理 只需要30分钟,Google云的花费不到5美元。完成后,你将拥有一个Android应用程序(即将推出的iOS教程!),可以对狗和猫品种进行实时检测,并且手机上的空间不超过12M。请注意,除了在云中训练对象检测模型之外,你也可以在自己的硬件或Colab上运行训练。

设置你的环境

我们将首先建立训练模型所需的一些库和其他先决条件。请注意,设置过程可能比训练模型本身花费更长的时间。为方便起见,你可以使用Dockerfile,它提供了从源代码安装Tensorflow并下载本教程所需的数据集和模型的依赖项。。如果你决定使用Docker,则仍应使用“Google Cloud Setup”部分,然后跳至“将数据集上传到GCS”部分。Dockerfile还将为Tensorflow Lite部分构建Android依赖项。更多信息,请参阅随附的README文件。

Dockerfile:https://github.com/tensorflow/models/blob/master/research/object_detection/dockerfiles/android/Dockerfile

Google云设置

首先,在谷歌云控制台中创建一个项目,并启用该项目的计费。我们使用Cloud Machine Learning Engine在Cloud TPU上运行我们的训练工作。ML Engine是Google Cloud的TensorFlow托管平台,它简化了训练和提供机器学习模型的过程。要使用它,请为刚刚创建的项目启用必要的API。

API:https://console.cloud.google.com/flows/enableapi?apiid=ml.googleapis.com,compute_component&_ga=2.43515109.-1978295503.1509743045

其次,我们将创建一个Google云存储桶,用于存储我们模型的训练和测试数据,以及我们训练工作中的模型检查点。

请注意,本教程中的所有命令都假设你正在运行Ubuntu。对于本教程中的许多命令,我们将使用Google Cloud gcloud CLI,并和Cloud Storage gsutil CLI一起与我们的GCS存储桶交互。如果你没有安装它们,你可以在访问下方链接安装

gcloud:https://cloud.google.com/sdk/docs/quickstart-debian-ubuntu

gsutil:https://cloud.google.com/storage/docs/gsutil_install

运行以下命令将当前项目设置为刚创建的项目,将YOUR_PROJECT_NAME替换为项目名称:

代码语言:javascript复制
gcloud config set project YOUR_PROJECT_NAME

然后,我们将使用以下命令创建云存储桶。请注意,存储桶名称必须全局唯一,因此如果选择的名称被占用,则可能会出错。

gsutil mb gs:// YOUR_UNIQUE_BUCKET_NAME

这里可能会提示你先运行gcloud auth login,之后你需要提供验证码。

然后在本教程中设置两个环境变量以简化命令:

代码语言:javascript复制
export PROJECT="YOUR_PROJECT_ID"
export YOUR_GCS_BUCKET="YOUR_UNIQUE_BUCKET_NAME"

接下来,为了让我们的Cloud TPU能够访问我们的项目,我们需要添加一个特定的TPU服务帐户。首先,使用以下命令获取服务帐户的名称:

代码语言:javascript复制
curl -H "Authorization: Bearer $(gcloud auth print-access-token)"  
    https://ml.googleapis.com/v1/projects/${PROJECT}:getConfig

当此命令完成时,复制tpuServiceAccount(它看起来像your-service-account-12345@cloud-tpu.iam.gserviceaccount.com)的值,然后将其保存为环境变量:

代码语言:javascript复制
export TPU_ACCOUNT=your-service-account

最后,允许ml.serviceAgent任务到你的TPU服务帐户

代码语言:javascript复制
gcloud projects add-iam-policy-binding $PROJECT  
    --member serviceAccount:$TPU_ACCOUNT --role roles/ml.serviceAgent

安装Tensorflow

如果你没有安装TensorFlow,请按照官网步骤操作。要按照本教程的设备上的部分进行操作,你需要按照下方链接的说明使用Bazel从源代码安装TensorFlow 。编译TensorFlow可能需要一段时间。如果你只想按照本教程的Cloud TPU训练部分进行操作,则无需从源代码编译TensorFlow,并且可以通过pip,Anaconda等安装已发布的版本。

链接:https://www.tensorflow.org/install/install_sources

安装TensorFlow对象检测

如果这是你第一次使用TensorFlow对象检测,欢迎!

安装:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md

安装对象检测后,请通过运行以下命令来测试安装:

代码语言:javascript复制
python object_detection / builders / model_builder_test.py

如果安装成功,应该看到以下输出:

代码语言:javascript复制
Ran 18 tests in 0.079s

OK

设置数据集

为了简单起见,我们将使用上一篇文章中关于训练对象检测模型的相同宠物品种数据集。该数据集包括大约7,400张图像 - 37种不同品种的猫和狗图像,每种200张图像。每个图像都有一个关联的注释文件,其中包括特定宠物在图像中所在的边界框坐标。我们不能直接将这些图像和注释提供给我们的模型;而是需要将它们转换为我们的模型可以理解的格式。为此,我们将使用TFRecord格式。

上一篇文:https://cloud.google.com/blog/big-data/2017/06/training-an-object-detector-using-cloud-machine-learning-engine

为了直接参加训练,我们在这里公开了pet_faces_train.record和pet_faces_val.record文件(下方链接下载)。可以使用公共TFRecord文件,或者如果你想自己生成它们,请按照GitHub上的步骤操作。

链接:http://download.tensorflow.org/models/object_detection/pet_faces_tfrecord.tar.gz

GitHub:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/preparing_inputs.md#generating-the-oxford-iiit-pet-tfrecord-files

可以使用以下命令下载并解压缩公共TFRecord文件:

代码语言:javascript复制
mkdir /tmp/pet_faces_tfrecord/
cd /tmp/pet_faces_tfrecord/
curl "http://download.tensorflow.org/models/object_detection/pet_faces_tfrecord.tar.gz" | tar xzf -

请注意,这些TFRecord文件是分片的,因此提取它们,你会有10个pet_faces_train.record文件和10个pet_faces_val.record文件。

将数据集上载到GCS

在本地获得TFRecord文件后,将它们复制到/data子目录下的GCS存储桶中:

代码语言:javascript复制
gsutil -m cp -r / tmp / pet_faces_tfrecord / pet_faces * gs:// $ {YOUR_GCS_BUCKET} / data /

使用GCS中的TFRecord文件,返回models/research本地计算机上的目录。接下来,你将在GCS存储桶中添加该pet_label_map.pbtxt文件。这将我们将要检测的37个宠物品种中的每一个映射到整数,以便我们的模型可以以数字格式理解它们。从models/research目录中,运行以下命令:

代码语言:javascript复制
gsutil cp object_detection / data / pet_label_map.pbtxt gs:// $ {YOUR_GCS_BUCKET} /data/pet_label_map.pbtxt

此时,在GCS bucket的/data子目录中应该有21个文件:20个用于训练和测试的分片TFRecord文件,以及标签映射文件。

使用SSD MobileNet检查点进行迁移学习

从头开始训练模型以识别宠物品种需要为每个宠物品种拍摄数千张训练图像并花费数小时或数天的训练时间。为了加快这一速度,我们可以利用迁移学习  - 我们采用已经在大量数据上训练执行类似的任务的模型权重来,然后用我们自己的数据上训练模型,微调预训练模型的层。

我们可以使用许多模型来训练识别图像中的各种对象。我们可以使用这些训练模型中的检查点,然后将它们应用于我们的自定义对象检测任务。这是有效的,对于机器而言,识别包含基本对象(如桌子,椅子或猫)的图像中的像素的任务与识别包含特定宠物品种的图像中的像素区别不大。

对于这个例子,我们使用MobileNet的SSD,MobileNet是一种针对移动设备进行优化的对象检测模型。首先,下载并提取已在COCO数据集上预训练的最新MobileNet检查点。要查看Object Detection API支持的所有模型的列表,请查看下方链接(model zoo)。提取检查点后,将3个文件复制到GCS存储桶中。运行以下命令下载检查点并将其复制到存储桶中:

代码语言:javascript复制
cd / tmp
curl -O http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_0.75_depth_300x300_coco14_sync_2018_07_03.tar.gz
tar xzf ssd_mobilenet_v1_0.75_depth_300x300_coco14_sync_2018_07_03.tar.gz

gsutil cp /tmp/ssd_mobilenet_v1_0.75_depth_300x300_coco14_sync_2018_07_03/model.ckpt。* gs:// $ {YOUR_GCS_BUCKET} / data /

model zoo:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md

当我们训练我们的模型时,它将使用这些检查点作为训练的起点。现在,你的GCS存储桶中应该有24个文件。我们几乎准备好开展我们的训练工作,但我们需要一个方法来告诉ML Engine我们的数据和模型检查点的位置。我们将使用配置文件执行此操作,我们将在下一步中设置该配置文件。我们的配置文件为我们的模型提供超参数,以及我们的训练数据、测试数据和初始模型检查点的文件路径。

使用Cloud ML Engine上使用Cloud TPU训练量化模型

机器学习模型有两个不同的计算组件:训练和推理。在此示例中,我们正在利用Cloud TPU来加速训练。配置文件中有几行专门与TPU训练相关。我们可以在TPU训练时使用更大的批尺寸,因为它们可以更轻松地处理大型数据集(在你自己的数据集上试验批尺寸时,请使用8的倍数,因为数据需要均匀分配8个TPU核心)。对于我们的模型来说,使用更大的批尺寸,我们可以减少训练步骤的数量(在本例中我们使用2000)。我们用于此训练作业的focal loss函数(在配置中的以下行中定义)也非常适合TPU:

代码语言:javascript复制
loss {
  classification_loss {
    weighted_sigmoid_focal {
      alpha: 0.75,
      gamma: 2.0
    }
  }

损失函数计算数据集中每个实例的损失,然后重新计算权重,将更多的相对权重分配给难分类的实例。与其他训练工作中使用的难的实例挖掘操作相比,它更适合TPU(更多,阅读论文:https://arxiv.org/abs/1708.02002)。

综上,初始化预训练模型检查点然后添加我们自己的训练数据的过程称为迁移学习。配置中的以下几行告诉我们的模型,我们将从预先训练的检查点开始进行对象检测的迁移学习。

代码语言:javascript复制
fine_tune_checkpoint: "gs://your-bucket/data/model.ckpt"
fine_tune_checkpoint_type: "detection"

我们还需要考虑我们的模型在经过训练后如何使用。假设我们的宠物检测器成为全球热门,动物爱好者和宠物商店随处可见。我们需要一种可扩展的方法来以低延迟处理这些推理请求。机器学习模型的输出是一个二进制文件,其中包含我们模型的训练权重 - 这些文件通常非常大,但由于我们将直接在移动设备上提供此模型,我们需要将其设置到尽可能小。

这时就要用到模型量化。量化将我们模型中的权重和激活压缩为8位定点表示。配置文件中的以下行将生成量化模型:

代码语言:javascript复制
graph_rewriter {
  quantization {
    delay: 1800
    activation_bits: 8
    weight_bits: 8
  }
}

通常通过量化,一个模型在转换到量化训练之前,会对一定数量的步骤进行完全精确的训练。上面的延迟(delay)数告诉ML Engine在1800个训练步骤后开始量化我们的权重和激活。

要告诉ML Engine在哪里找到我们的训练和测试文件以及模型检查点,你需要在我们为你创建的配置文件中更新几行,以指向你的存储桶。从研究目录中,找到object_detection/samples/configs/ssd_mobilenet_v1_0.75_depth_quantized_300x300_pets_sync.config。使用GCS数据目录的完整路径更新所有PATH_TO_BE_CONFIGURED字符串。例如,train_input_reader配置的部分将如下所示(确保替换YOUR_GCS_BUCKET为你的存储桶的名称):

代码语言:javascript复制
train_input_reader:{
  tf_record_input_reader {
    input_path:“gs:// YOUR_GCS_BUCKET / data / pet_faces_train *”
  }
  label_map_path:“gs://YOUR_GCS_BUCKET/data/pet_label_map.pbtxt”
}

然后将此量化的配置文件复制到你的GCS存储桶中:

代码语言:javascript复制
gsutil cp object_detection / samples / configs / ssd_mobilenet_v1_0.75_depth_quantized_300x300_pets_sync.config gs:// $ {YOUR_GCS_BUCKET} /data/pipeline.config

在我们启动Cloud ML Engine的训练工作之前,我们需要打包Object Detection API,pycocotools和TF Slim。我们可以使用以下命令执行此操作(从research/目录运行此命令,同时注意,括号也是命令的一部分):

代码语言:javascript复制
bash object_detection / dataset_tools / create_pycocotools_package.sh / tmp / pycocotools
python setup.py sdist
(cd slim && python setup.py sdist)

我们准备训练我们的模型了!要启动训练,请运行以下gcloud命令:

代码语言:javascript复制
gcloud ml-engine jobs submit training `whoami`_object_detection_`date  %s` 
--job-dir=gs://${YOUR_GCS_BUCKET}/train 
--packages dist/object_detection-0.1.tar.gz,slim/dist/slim-0.1.tar.gz,/tmp/pycocotools/pycocotools-2.0.tar.gz 
--module-name object_detection.model_tpu_main 
--runtime-version 1.8 
--scale-tier BASIC_TPU 
--region us-central1 
-- 
--model_dir=gs://${YOUR_GCS_BUCKET}/train 
--tpu_zone us-central1 
--pipeline_config_path=gs://${YOUR_GCS_BUCKET}/data/pipeline.config

请注意,如果你到错误消息,指出没有可用的Cloud TPU,我们建议你只在另一个区域重试(Cloud TPU目前在us-central1-b,us-central1-c,europe-west4-a和asia-east1-c上可用)。

在我们开始我们的训练工作后,运行以下命令来开始评估工作:

代码语言:javascript复制
gcloud ml-engine jobs submit training `whoami`_object_detection_eval_validation_`date  %s` 
--job-dir=gs://${YOUR_GCS_BUCKET}/train 
--packages dist/object_detection-0.1.tar.gz,slim/dist/slim-0.1.tar.gz,/tmp/pycocotools/pycocotools-2.0.tar.gz 
--module-name object_detection.model_main 
--runtime-version 1.8 
--scale-tier BASIC_GPU 
--region us-central1 
-- 
--model_dir=gs://${YOUR_GCS_BUCKET}/train 
--pipeline_config_path=gs://${YOUR_GCS_BUCKET}/data/pipeline.config 
--checkpoint_dir=gs://${YOUR_GCS_BUCKET}/train

训练和评估都应在大约30分钟内完成。在运行时,你可以使用TensorBoard查看模型的准确性。要启动TensorBoard,请运行以下命令:

代码语言:javascript复制
tensorboard --logdir = gs:// $ {YOUR_GCS_BUCKET} / train

请注意,你可能需要先运行gcloud auth application-default login。

导航到localhost:6006查看你的TensorBoard输出。在这里,你将看到一些常用的ML指标,用于分析模型的准确性。请注意,这些图表仅绘制了2个点,因为模型在很短的步骤中快速训练(如果你在使用TensorBoard之前可能会习惯于在此处查看更多曲线)。这里的第一点是训练过程的早期,最后一点显示最后一步的权重(步骤2000)。

首先,让我们看一下0.5 IOU(mAP @ .50IOU)平均精度的图表:

平均精确度衡量我们模型对所有37个标签的正确预测百分比。IoU特定于对象检测模型,代表Intersection-over-Union。这测量我们的模型生成的边界框与地面实况边界框之间的重叠,以百分比表示。此图表测量我们的模型返回的正确边界框和标签的百分比,在这种情况下“正确”指的是与其对应的地面真值边框重叠50%或更多。训练后,我们的模型实现了82%的平均精确度。

接下来,查看TensorBoard 中的Images选项卡:

在左图中,我们看到了模型对此图像的预测,在右侧我们看到了正确的地面真值边框。边界框非常准确,但在这种特殊情况下,我们模型的标签预测是不正确的。没有ML模型可以是完美的。:)

使用TensorFlow Lite在移动设备上运行

此时,你以及拥有了一个训练好的宠物种类检测器,你可以使用Colab notebook在零点设置的情况下在浏览器中测试你自己的图像。要在手机上实时运行此模型需要一些额外的步骤。在本节中,我们将向你展示如何使用TensorFlow Lite获得更小的模型,并允许你利用针对移动设备优化的操作。TensorFlow Lite是TensorFlow针对移动和嵌入式设备的轻量级解决方案。它支持设备内机器学习推理,具有低延迟和小的二进制尺寸。TensorFlow Lite使用了许多技术,例如允许更小和更快(定点数学)模型的量化内核。

Colab notebook:https://colab.research.google.com/github/tensorflow/models/blob/master/research/object_detection/object_detection_tutorial.ipynb

如上所述,对于本节,你需要使用提供的Dockerfile,或者从源构建TensorFlow(支持GCP)并安装bazel构建工具。请注意,如果你只想在不训练模型的情况下完成本教程的第二部分,我们提供了一个预训练的模型。

预训练模型:https://storage.googleapis.com/download.tensorflow.org/models/tflite/ssd_mobilenet_v1_0.75_depth_300x300_quant_pets_2018_06_29.zip

为了使这些命令更容易运行,让我们设置一些环境变量:

代码语言:javascript复制
export CONFIG_FILE = gs:// $ {YOUR_GCS_BUCKET} /data/pipeline.config
export CHECKPOINT_PATH = gs:// $ {YOUR_GCS_BUCKET} /train/model.ckpt-2000
export OUTPUT_DIR = / tmp / tflite

我们首先获得一个TensorFlow冻结图(frozen graph),其中包含我们可以与TensorFlow Lite一起使用的兼容操作。首先,你需要安装这些python库。然后,要获取冻结图,请使用以下命令从models/research目录运行脚本export_tflite_ssd_graph.py:

代码语言:javascript复制
python object_detection/export_tflite_ssd_graph.py 
--pipeline_config_path=$CONFIG_FILE 
--trained_checkpoint_prefix=$CHECKPOINT_PATH 
--output_directory=$OUTPUT_DIR 
--add_postprocessing_op=true

在/tmp/tflite目录中,你现在应该看到两个文件:tflite_graph.pb 和tflite_graph.pbtxt(样本冻结图见下方链接)。请注意,add_postprocessing标志使模型能够利用自定义最优化检测的后处理操作,可被视为替代tf.image.non_max_suppression。确保不要将同一个目录中的export_tflite_ssd_graph与export_inference_graph混淆。这两个脚本都输出了冻结图:export_tflite_ssd_graph输出我们可以直接输入到TensorFlow Lite的冻结图,并且这是我们要使用的图。

链接:https://storage.googleapis.com/download.tensorflow.org/models/tflite/frozengraphs_ssd_mobilenet_v1_0.75_quant_pets_2018_06_29.zip

接下来,我们将使用TensorFlow Lite获得优化模型,我们要使用的是TOCO(TensorFlow Lite Optimizing Converter)。这将通过以下命令将生成的冻结图(tflite_graph.pb)转换为TensorFlow Lite flatbuffer格式(detec .tflite)。从tensorflow /目录运行:

代码语言:javascript复制
bazel run -c opt tensorflow / contrib / lite / toco:toco  -  
--input_file = $ OUTPUT_DIR /tflite_graph.pb 
--output_file = $ OUTPUT_DIR /detect.tflite 
--input_shapes = 1,300,300,3 
--input_arrays = normalized_input_image_tensor 
--output_arrays = ' TFLite_Detection_PostProcess ',' TFLite_Detection_PostProcess:1 ',' TFLite_Detection_PostProcess:2 ',' TFLite_Detection_PostProcess:3 '   
--inference_type = QUANTIZED_UINT8 
--mean_values = 128 
--std_values = 128 
--change_concat_input_ranges = false 
--allow_custom_ops

调整每张图像到300x300之后,这个命令获取normalized_input_image_tensor输入。量化模型的输出被命名为‘TFLite_Detection_PostProcess’,‘TFLite_Detection_PostProcess:1’,‘TFLite_Detection_PostProcess:2’,和‘TFLite_Detection_PostProcess:3’,代表四个数组:detection_boxes,detection_classes,detection_scores,和num_detections。如果成功运行,你现在应该在/tmp/tflite目录中看到第三个文件detect.tflite(示例tflite文件如下)。文件包含图形和所有模型参数,可以通过Android设备上的TensorFlow Lite解释器运行,并且应该小于4 Mb。

示例tflite文件:https://storage.googleapis.com/download.tensorflow.org/models/tflite/pets_ssd_mobilenet_v1_0.75_quant_2018_06_29.zip

标志文档:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md

在Android上运行我们的模型

要在设备上运行我们的最终模型,我们需要使用提供的Dockerfile,或者安装Android NDK和SDK。目前推荐的Android NDK版本为14b ,可以在NDK档案页上找到。请注意,Bazel的当前版本与NDK 15及更高版本不兼容。Android SDK和构建工具可以单独下载,也可以作为Android Studio的一部分使用。要构建TensorFlow Lite Android demo,构建工具需要API >= 23(但它将在API> = 21的设备上运行)。其他详细信息可在TensorFlow Lite Android App页面上找到。

14b:https://dl.google.com/android/repository/android-ndk-r14b-linux-x86_64.zip

NDK:https://developer.android.com/ndk/downloads/older_releases.html#ndk-14b-downloads

TensorFlow Lite Android App page:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/README.md

在尝试获得刚训练的宠物分类模型之前,首先运行带有默认模型的演示应用程序,该模型是在COCO数据集上训练的。要构建演示应用程序,请从tensorflow目录运行bazel命令:

代码语言:javascript复制
bazel build -c opt --config=android_arm{,64} --cxxopt='--std=c  11' 
//tensorflow/contrib/lite/examples/android:tflite_demo

上面的apk将针对64位架构而构建,你可以用-- config=android_arm替换它,以获得32位支持。现在通过Android Debug Bridge(adb)在支持调试的 Android手机上安装演示程序:

代码语言:javascript复制
adb install bazel-bin / tensorflow / contrib / lite / examples / android / tflite_demo.apk

尝试运行这个初级app(TFLDetect),把你的相机对准人、家具、汽车、宠物等等。你将在检测到的对象周围看到带有标签的框。运行的测试应用程序是使用COCO数据集训练的。

示例:https://www.youtube.com/watch?v=jU5jYwbMTPQ&feature=youtu.be

当你使用通用检测器时,将其替换为你定制的宠物检测器非常简单。我们需要做的就是将应用程序指向我们新的detect.tflite文件,并为其指定新标签的名称。具体来说,我们使用以下命令将我们的TensorFlow Lite flatbuffer复制到app assets目录:

代码语言:javascript复制
cp /tmp/tflite/detect.tflite 
tensorflow/contrib/lite/examples/android/app/src/main/assets

我们现在将编辑BUILD文件以指向这个新模型。首先,打开BUILD文件tensorflow/contrib/lite/examples/android/BUILD。然后找到assets部分,并将行“@tflite_mobilenet_ssd_quant//:detect.tflite”(默认情况下指向COCO预训练模型)替换为你的TFLite宠物模型“ //tensorflow/contrib/lite/examples/android/app/src/main/assets:detect.tflite” 的路径。最后,更改assets部分的最后一行以使用新的标签映射。最后,assets部分应如下所示:

代码语言:javascript复制
assets = [
   "//tensorflow/contrib/lite/examples/android/app/src/main/assets:labels_mobilenet_quant_v1_224.txt",
     "@tflite_mobilenet//:mobilenet_quant_v1_224.tflite",
     "@tflite_conv_actions_frozen//:conv_actions_frozen.tflite",
     "//tensorflow/contrib/lite/examples/android/app/src/main/assets:conv_actions_labels.txt",
     "@tflite_mobilenet_ssd//:mobilenet_ssd.tflite",
     "//tensorflow/contrib/lite/examples/android/app/src/main/assets:detect.tflite",
     "//tensorflow/contrib/lite/examples/android/app/src/main/assets:box_priors.txt",
     "//tensorflow/contrib/lite/examples/android/app/src/main/assets:pets_labels_list.txt",
],

我们还需要告诉我们的应用程序使用新的标签映射。在文本编辑器中打开tensorflow / contrib / lite / examples / android / app / src / main / java / org / tensorflow / demo / DetectorActivity.java文件,找到其定义TF_OD_API_LABELS_FILE。更新此路径以指向你的宠物标签映射文件:“ file:///android_asset/pets_labels_list.txt”。请注意,为了您的方便,我们已提供了pets_labels_list.txt文件。DetectorActivity.java的新部分(第50行)现在应该如下所示:

代码语言:javascript复制
// Configuration values for the prepackaged SSD model.
private static final int TF_OD_API_INPUT_SIZE = 300;
private static final boolean TF_OD_API_IS_QUANTIZED = true;
private static final String TF_OD_API_MODEL_FILE = "detect.tflite";
private static final String TF_OD_API_LABELS_FILE = "file:///android_asset/pets_labels_list.txt";

复制TensorFlow Lite文件并编辑BUILD和DetectorActivity.java文件后,使用以下命令重建并重新安装应用程序:

代码语言:javascript复制
bazel build -c opt --config=android_arm{,64} --cxxopt='--std=c  11' 
//tensorflow/contrib/lite/examples/android:tflite_demo
adb install -r bazel-bin/tensorflow/contrib/lite/examples/android/tflite_demo.apk

现在来看最精彩的部分:找到最近的狗或猫,并尝试检测它。在像素2上,我们每秒大于15帧。

对象检测文档:https://github.com/tensorflow/models/tree/master/research/object_detection/g3doc

资源:https://github.com/tensorflow/models/tree/master/research/object_detection/dataset_tools

GitHub:https://github.com/tensorflow/models/tree/master/research/object_detection

0 人点赞