TensorFlow Lite Model Maker --- Image Classification (with Source Code)

2021-10-11 10:32:18

TFLite_tutorials

The TensorFlow Lite Model Maker library simplifies the process of adapting and converting a TensorFlow neural-network model to particular input data when deploying this model for on-device ML applications. Interpretation: what we want to end up with is a model in .tflite format, which can be deployed on mobile or embedded devices.

The table below lists the task types that TFLite Model Maker currently supports.

Supported Tasks | Task Utility
Image Classification: tutorial, api | Classify images into predefined categories.
Object Detection: tutorial, api | Detect objects in real time.
Text Classification: tutorial, api | Classify text into predefined categories.
BERT Question Answer: tutorial, api | Find the answer in a certain context for a given question with BERT.
Audio Classification: tutorial, api | Classify audio into predefined categories.
Recommendation: demo, api | Recommend items based on the context information for on-device scenario.

If your tasks are not supported, please first use TensorFlow to retrain a TensorFlow model with transfer learning (following guides like images, text, audio) or train it from scratch, and then convert it to a TensorFlow Lite model. Interpretation: if the model you want to train does not fit one of the task types above, you can first train a regular TensorFlow model and then convert it to TFLite.

To use TensorFlow Lite Model Maker, we first need to install it:

pip install tflite-model-maker
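
A quick sanity check after installation is to import the packages and confirm that a TensorFlow 2.x release is active (a minimal sketch, assuming the install above went into the current environment):

# Sanity check: both packages import cleanly and TensorFlow is a 2.x release.
import tensorflow as tf
import tflite_model_maker

print(tf.__version__)  # should start with '2'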

Essentially this is all one classification task; we swap in different models and compare the final accuracy, as well as the resulting TFLite model's size, inference speed, memory usage, CPU usage, and so on.

The following code snippet downloads the dataset:

image_path = tf.keras.utils.get_file(
    'flower_photos.tgz',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    extract=True)
image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')

The dataset is structured as follows:

flower_photos
|__ daisy
    |__ 100080576_f52e8ee070_n.jpg
    |__ 14167534527_781ceb1b7a_n.jpg
    |__ ...
|__ dandelion
    |__ 10043234166_e6dd915111_n.jpg
    |__ 1426682852_e62169221f_m.jpg
    |__ ...
|__ roses
    |__ 102501987_3cdb8e5394_n.jpg
    |__ 14982802401_a3dfb22afb.jpg
    |__ ...
|__ sunflowers
    |__ 12471791574_bb1be83df4.jpg
    |__ 15122112402_cafa41934f.jpg
    |__ ...
|__ tulips
    |__ 13976522214_ccec508fe7.jpg
    |__ 14487943607_651e8062a1_m.jpg
    |__ ...
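
DataLoader.from_folder treats each subdirectory name as a class label. Below is a small sketch (not part of the original tutorial; the helper name is made up for illustration) that counts the images per class to sanity-check the layout:

import os

def count_images_per_class(root):
    # Count image files in each class subdirectory of the dataset root.
    counts = {}
    for class_name in sorted(os.listdir(root)):
        class_dir = os.path.join(root, class_name)
        if os.path.isdir(class_dir):
            counts[class_name] = len([
                f for f in os.listdir(class_dir)
                if f.lower().endswith(('.jpg', '.jpeg', '.png'))
            ])
    return counts

print(count_images_per_class(image_path))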

Load the dataset and split it into training, validation, and test sets:

data = DataLoader.from_folder(image_path)
train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)
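
To confirm the 80/10/10 split, the set sizes can be printed; this sketch assumes the DataLoader exposes a size attribute, as recent tflite-model-maker releases do:

# Roughly 80% / 10% / 10% of the 3670 flower images.
print(train_data.size, validation_data.size, test_data.size)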
The assertion below checks that the installed TensorFlow version string starts with '2', i.e. that TensorFlow 2.x is being used:

assert tf.__version__.startswith('2')

Training results: train_acc = 0.9698, val_acc = 0.9375, test_acc = 0.9210, which is broadly in line with the expected generalization behaviour (training accuracy slightly above validation accuracy, which is slightly above test accuracy).

import os
import time

import numpy as np
import tensorflow as tf
from tflite_model_maker import model_spec
from tflite_model_maker import image_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.image_classifier import DataLoader
import matplotlib.pyplot as plt

assert tf.__version__.startswith('2')

# Download and extract the flower_photos dataset.
image_path = tf.keras.utils.get_file(
    'flower_photos.tgz',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    extract=True)
image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')

# Load the images (subdirectory names become labels) and split them
# 80/10/10 into train/validation/test sets.
data = DataLoader.from_folder(image_path)
# data = data.gen_dataset(batch_size=1)
train_data, rest_data = data.split(0.8)
# for batch in data.take(1):
#     print(batch)
#     break

validation_data, test_data = rest_data.split(0.5)

# Train an EfficientNet-Lite0 classifier for 20 epochs.
model = image_classifier.create(train_data, validation_data=validation_data,
                                model_spec=model_spec.get('efficientnet_lite0'), epochs=20)

# Evaluate the Keras model on the held-out test set.
loss, accuracy = model.evaluate(test_data)

# Export the TFLite model (quantized by default) together with the label file.
model.export(export_dir='./testTFlite', export_format=(ExportFormat.TFLITE, ExportFormat.LABEL))

# Evaluate the exported TFLite model on the test set and time the evaluation.
start = time.time()
print(model.evaluate_tflite('./testTFlite/model.tflite', test_data))
end = time.time()
print('elapsed time: ', end - start)
Judging from the output above, the model loses very little accuracy after quantization, and the quantized model is 4.0 MB (efficientnet_lite0). Monitoring the run shows that inference uses a single CPU core; test_data contains 367 images and the whole evaluation takes 273.43 s.
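
To see where that time goes, a single inference can be timed directly with the TensorFlow Lite Interpreter; this is a minimal sketch using the exported model path from above and a dummy input (the real per-image preprocessing is omitted):

import time
import numpy as np
import tensorflow as tf

# Load the exported TFLite model and allocate its tensors.
interpreter = tf.lite.Interpreter(model_path='./testTFlite/model.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]

# Dummy input matching the model's expected shape and dtype.
dummy = np.zeros(input_details['shape'], dtype=input_details['dtype'])

start = time.time()
interpreter.set_tensor(input_details['index'], dummy)
interpreter.invoke()
probs = interpreter.get_tensor(output_details['index'])
print('single inference: %.4f s' % (time.time() - start))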

config = QuantizationConfig.for_float16()
model.export(export_dir='./testTFlite', tflite_filename='model_fp16.tflite', quantization_config=config, export_format=(ExportFormat.TFLITE, ExportFormat.LABEL))

If the model is instead exported as fp16, it is 6.8 MB (efficientnet_lite0) and inference over the test set takes 5.54 s, which is much faster.
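
The exported file sizes can be compared directly on disk (a small sketch, using the export paths from the calls above):

import os

for name in ('model.tflite', 'model_fp16.tflite'):
    path = os.path.join('./testTFlite', name)
    print(name, '%.1f MB' % (os.path.getsize(path) / 1e6))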

model = image_classifier.create(train_data, validation_data=validation_data,
                                model_spec=model_spec.get('mobilenet_v2'), epochs=20)

Switching the model to mobilenet_v2, the exported fp16 model is 4.6 MB and inference over the test set takes 4.36 s.

inception_v3_spec = image_classifier.ModelSpec(
    uri='https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1')
inception_v3_spec.input_image_shape = [299, 299]
model = image_classifier.create(train_data, validation_data=validation_data,
                                model_spec=inception_v3_spec, epochs=20)

Switching the model to inception_v3, the exported fp16 model is 43.8 MB and inference over the test set takes 25.31 s.
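
The three comparisons above can also be reproduced with a single loop over the model specs (a sketch using only names already introduced; each iteration retrains for 20 epochs, so it is slow):

# Compare test accuracy across the three backbones used above.
for name, spec in [('efficientnet_lite0', model_spec.get('efficientnet_lite0')),
                   ('mobilenet_v2', model_spec.get('mobilenet_v2')),
                   ('inception_v3', inception_v3_spec)]:
    m = image_classifier.create(train_data, validation_data=validation_data,
                                model_spec=spec, epochs=20)
    loss, acc = m.evaluate(test_data)
    print(name, 'test accuracy: %.4f' % acc)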

The rest of this post walks through the relevant parts of the TFLite Model Maker source code, starting with the generic DataLoader.
Common Dataset used for tasks.
​
class DataLoader(object):
  """This class provides generic utilities for loading customized domain data that will be used later in model retraining.
​
  For different ML problems or tasks, such as image classification, text
  classification etc., a subclass is provided to handle task-specific data
  loading requirements.
  """
​
  def __init__(self, dataset, size):
    """Init function for class `DataLoader`.
​
    In most cases, one should use helper functions like `from_folder` to create
    an instance of this class.
​
    Args:
      dataset: A tf.data.Dataset object that contains a potentially large set of
        elements, where each element is a pair of (input_data, target). The
        `input_data` means the raw input data, like an image, a text etc., while
        the `target` means some ground truth of the raw input data, such as the
        classification label of the image etc.
      size: The size of the dataset. tf.data.Dataset doesn't support a function
        to get the length directly since it's lazy-loaded and may be infinite.
    """
    self._dataset = dataset
    self._size = size
​
  def gen_dataset(self,
                  batch_size=1,
                  is_training=False,
                  shuffle=False,
                  input_pipeline_context=None,
                  preprocess=None,
                  drop_remainder=False):
    """Generate a shared and batched tf.data.Dataset for training/evaluation.
代码语言:javascript复制
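
The commented-out lines in the training script hint at how gen_dataset can be used directly; here is a minimal sketch based on the signature above (keeping the default batch size of 1, since the raw images have different shapes):

# Turn the DataLoader into a batched tf.data.Dataset and peek at one element.
ds = train_data.gen_dataset(batch_size=1)
for images, labels in ds.take(1):
    print(images.shape, labels)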
Image dataloader
​
class ImageClassifierDataLoader(dataloader.ClassificationDataLoader):
  """DataLoader for image classifier."""
​
  @classmethod
  def from_folder(cls, filename, shuffle=True):
    """Image analysis for image classification load images with labels.
​
    Assume the image data of the same label are in the same subdirectory.
​
    Args:
      filename: Name of the file.
      shuffle: boolean, if shuffle, random shuffle data.
​
    Returns:
      ImageDataset containing images and labels and other related info.
    """
  @classmethod
  def from_tfds(cls, name):
    """Loads data from tensorflow_datasets."""
ImageNet preprocessing
​
class Preprocessor(object):
  """Preprocessing for image classification."""
​
  def __init__(self,
               input_shape,
               num_classes,
               mean_rgb,
               stddev_rgb,
               use_augmentation=False):
    self.input_shape = input_shape
    self.num_classes = num_classes
    self.mean_rgb = mean_rgb
    self.stddev_rgb = stddev_rgb
    self.use_augmentation = use_augmentation
​
  def __call__(self, image, label, is_training=True):
    if self.use_augmentation:
      return self._preprocess_with_augmentation(image, label, is_training)
    return self._preprocess_without_augmentation(image, label)
​
  def _preprocess_with_augmentation(self, image, label, is_training):
    """Image preprocessing method with data augmentation."""
    image_size = self.input_shape[0]
    if is_training:
      image = preprocess_for_train(image, image_size)
    else:
      image = preprocess_for_eval(image, image_size)
​
    image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
    image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)
​
    label = tf.one_hot(label, depth=self.num_classes)
    return image, label
​
  # TODO(yuqili): Changes to preprocess to support batch input.
  def _preprocess_without_augmentation(self, image, label):
    """Image preprocessing method without data augmentation."""
    image = tf.cast(image, tf.float32)
​
    image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
    image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)
​
    image = tf.compat.v1.image.resize(image, self.input_shape)
    label = tf.one_hot(label, depth=self.num_classes)
    return image, label
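
Based on the class above, applying the preprocessor to a single decoded image looks roughly like this sketch; the mean/stddev values are illustrative stand-ins for what model_spec actually provides:

# Sketch: preprocess one dummy image for a 5-class problem.
preprocessor = Preprocessor(input_shape=[224, 224], num_classes=5,
                            mean_rgb=[0.0], stddev_rgb=[255.0],
                            use_augmentation=False)
image = tf.zeros([300, 300, 3], dtype=tf.uint8)   # dummy decoded image
label = tf.constant(2, dtype=tf.int32)
image, one_hot = preprocessor(image, label, is_training=False)
print(image.shape, one_hot.numpy())  # (224, 224, 3) and a one-hot vector of length 5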
Image classifier
class ImageClassifier(classification_model.ClassificationModel):
  """ImageClassifier class for inference and exporting to tflite."""
​
  def __init__(self,
               model_spec,
               index_to_label,
               shuffle=True,
               hparams=hub_lib.get_default_hparams(),
               use_augmentation=False,
               representative_data=None):
    """Init function for ImageClassifier class.
​
    Args:
      model_spec: Specification for the model.
      index_to_label: A list that map from index to label class name.
      shuffle: Whether the data should be shuffled.
      hparams: A namedtuple of hyperparameters. This function expects
        .dropout_rate: The fraction of the input units to drop, used in dropout
          layer.
        .do_fine_tuning: If true, the Hub module is trained together with the
          classification layer on top.
      use_augmentation: Use data augmentation for preprocessing.
      representative_data:  Representative dataset for full integer
        quantization. Used when converting the keras model to the TFLite model
        with full integer quantization.
    """
    super(ImageClassifier, self).__init__(model_spec, index_to_label, shuffle,
                                          hparams.do_fine_tuning)
    num_classes = len(index_to_label)
    self._hparams = hparams
    self.preprocess = image_preprocessing.Preprocessor(
        self.model_spec.input_image_shape,
        num_classes,
        self.model_spec.mean_rgb,
        self.model_spec.stddev_rgb,
        use_augmentation=use_augmentation)
    self.history = None  # Training history that returns from `keras_model.fit`.
    self.representative_data = representative_data
​
  def _get_tflite_input_tensors(self, input_tensors):
    """Gets the input tensors for the TFLite model."""
    return input_tensors
​
  def create_model(self, hparams=None, with_loss_and_metrics=False):
    """Creates the classifier model for retraining."""
    hparams = self._get_hparams_or_default(hparams)
​
    module_layer = hub_loader.HubKerasLayerV1V2(
        self.model_spec.uri, trainable=hparams.do_fine_tuning)
    self.model = hub_lib.build_model(module_layer, hparams,
                                     self.model_spec.input_image_shape,
                                     self.num_classes)
    if with_loss_and_metrics:
      # Adds loss and metrics in the keras model.
      self.model.compile(
          loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
          metrics=['accuracy'])
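
hub_lib.build_model itself is not shown here; a sketch of what it assembles, based on the hparams described above (hub feature-vector layer, dropout, then a softmax classification head), might look like the following. This is an assumption about the layer stack, not the library's exact code:

import tensorflow_hub as hub

def build_model_sketch(uri, num_classes, image_size=224,
                       dropout_rate=0.2, do_fine_tuning=False):
    # Hub feature extractor -> dropout -> dense softmax classifier.
    return tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(image_size, image_size, 3)),
        hub.KerasLayer(uri, trainable=do_fine_tuning),
        tf.keras.layers.Dropout(rate=dropout_rate),
        tf.keras.layers.Dense(num_classes, activation='softmax'),
    ])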
Custom classification model that is already retrained by data
​
class ClassificationModel(custom_model.CustomModel):
  """"The abstract base class that represents a Tensorflow classification model."""
​
  DEFAULT_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.LABEL)
  ALLOWED_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.LABEL,
                           ExportFormat.SAVED_MODEL, ExportFormat.TFJS)
​
  def __init__(self, model_spec, index_to_label, shuffle, train_whole_model):
    """Initialize a instance with data, deploy mode and other related parameters.
​
    Args:
      model_spec: Specification for the model.
      index_to_label: A list that map from index to label class name.
      shuffle: Whether the data should be shuffled.
      train_whole_model: If true, the Hub module is trained together with the
        classification layer on top. Otherwise, only train the top
        classification layer.
    """
    super(ClassificationModel, self).__init__(model_spec, shuffle)
    self.index_to_label = index_to_label
    self.num_classes = len(index_to_label)
    self.train_whole_model = train_whole_model
​
  def evaluate(self, data, batch_size=32):
    """Evaluates the model.
​
    Args:
      data: Data to be evaluated.
      batch_size: Number of samples per evaluation step.
​
    Returns:
      The loss value and accuracy.
    """
    ds = data.gen_dataset(
        batch_size, is_training=False, preprocess=self.preprocess)
    return self.model.evaluate(ds)
​
  def predict_top_k(self, data, k=1, batch_size=32):
    """Predicts the top-k predictions.
Custom model base class
class CustomModel(abc.ABC):
  """"The abstract base class that represents a Tensorflow classification model."""
​
  DEFAULT_EXPORT_FORMAT = (ExportFormat.TFLITE)
  ALLOWED_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.SAVED_MODEL,
                           ExportFormat.TFJS)
​
  def __init__(self, model_spec, shuffle):
    """Initialize a instance with data, deploy mode and other related parameters.
​
    Args:
      model_spec: Specification for the model.
      shuffle: Whether the training data should be shuffled.
    """
    self.model_spec = model_spec
    self.shuffle = shuffle
    self.model = None
    # TODO(yuqili): remove this method once preprocess for image classifier is
    # also moved to DataLoader part.
    self.preprocess = None
​
  @abc.abstractmethod
  def train(self, train_data, validation_data=None, **kwargs):
    return
​
  def summary(self):
    self.model.summary()
​
  @abc.abstractmethod
  def evaluate(self, data, **kwargs):
    return
Exporting and running the TFLite model
def export_tflite(model,
                  tflite_filepath,
                  quantization_config=None,
                  convert_from_saved_model_tf2=False,
                  preprocess=None,
                  supported_ops=(tf.lite.OpsSet.TFLITE_BUILTINS,)):
  """Converts the retrained model to tflite format and saves it.
​
  Args:
    model: model to be converted to tflite.
    tflite_filepath: File path to save tflite model.
    quantization_config: Configuration for post-training quantization.
    convert_from_saved_model_tf2: Convert to TFLite from saved_model in TF 2.x.
    preprocess: A preprocess function to apply on the dataset.
        # TODO(wangtz): Remove when preprocess is split off from CustomModel.
    supported_ops: A list of supported ops in the converted TFLite file.
  """
  if tflite_filepath is None:
    raise ValueError(
        "TFLite filepath couldn't be None when exporting to tflite.")
​
  if compat.get_tf_behavior() == 1:
    lite = tf.compat.v1.lite
  else:
    lite = tf.lite
​
  convert_from_saved_model = (
      compat.get_tf_behavior() == 1 or convert_from_saved_model_tf2)
  with _create_temp_dir(convert_from_saved_model) as temp_dir_name:
    if temp_dir_name:
      save_path = os.path.join(temp_dir_name, 'saved_model')
      model.save(save_path, include_optimizer=False, save_format='tf')
      converter = lite.TFLiteConverter.from_saved_model(save_path)
    else:
      converter = lite.TFLiteConverter.from_keras_model(model)
​
    if quantization_config:
      converter = quantization_config.get_converter_with_quantization(
          converter, preprocess=preprocess)
​
    converter.target_spec.supported_ops = supported_ops
    tflite_model = converter.convert()
​
  with tf.io.gfile.GFile(tflite_filepath, 'wb') as f:
    f.write(tflite_model)
​
​
def get_lite_runner(tflite_filepath, model_spec=None):
  """Gets `LiteRunner` from file path to TFLite model and `model_spec`."""
  # Gets the functions to handle the input & output indexes if exists.
  reorder_input_details_fn = None
  if hasattr(model_spec, 'reorder_input_details'):
    reorder_input_details_fn = model_spec.reorder_input_details
​
  reorder_output_details_fn = None
  if hasattr(model_spec, 'reorder_output_details'):
    reorder_output_details_fn = model_spec.reorder_output_details
​
  lite_runner = LiteRunner(tflite_filepath, reorder_input_details_fn,
                           reorder_output_details_fn)
  return lite_runner
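
Finally, for full-integer quantization (rather than the fp16 export shown earlier), the library's QuantizationConfig.for_int8 helper takes a representative dataset for calibration; a sketch reusing the objects from the training script above:

# Export an int8 fully-quantized model, calibrated on the test split.
config = QuantizationConfig.for_int8(test_data)
model.export(export_dir='./testTFlite', tflite_filename='model_int8.tflite',
             quantization_config=config)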
