# Source code for aup.dlconvert.to_tflite

"""
..
  Copyright (c) 2018 LG Electronics Inc.
  SPDX-License-Identifier: GPL-3.0-or-later

Convert to TFLite
=================

There are four major control parameters for the TFLite converter (data type, optimization, operation set and representative data), see :func:`.setup_tfconverter`.

The data feeding function (`data_fun`) is loaded by `--load`, whose argument is the path to a Python file defining
`get_dataset()`, which is used to `generate representative data <https://www.tensorflow.org/lite/performance/post_training_integer_quant#convert_using_quantization>`_ for `int8` quantization.
Combine with `--undefok` flag to pass more control arguments.

"""
import logging
from os import path, environ
environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  #disable tensorflow debugging messages
from typing import Callable
import numpy as np
from absl import flags
from tensorflow import lite


logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

# Eager execution has to be switched on explicitly in TensorFlow 1.x. TensorFlow 2.x no
# longer exports a top-level ``enable_eager_execution`` (eager mode is already the default
# there), so a failed import is the expected TF2 case and is not reported as an error.
try:
    from tensorflow import enable_eager_execution  # pylint: disable=no-name-in-module
except ImportError:
    logger.debug("tensorflow.enable_eager_execution is unavailable; assuming eager execution is already enabled.")
else:
    try:
        enable_eager_execution()
    except Exception:  # pylint: disable=broad-except
        # The symbol exists but enabling eager mode failed; conversion may misbehave.
        logger.fatal("Compatibility issue with TF eager execution. Tweak with caution.")

try:
    from tensorflow.lite.python import lite_constants
except ImportError:
    from tensorflow.lite import constants as lite_constants

# Command-line controls for the converter. The accepted keywords for --opt, --type and
# --ops are enforced by the validators registered next to OPT_ARGS/OPS_ARGS/TYPE_ARGS.
flags.DEFINE_string(name="opt", default="none", help="optimization")
flags.DEFINE_string(name="type", default="float", help="data type after quantization")
flags.DEFINE_string(name="ops", default="tflite", help="operation set to be used for quantization")
# -- optional: Python file providing representative data (used for int8 quantization)
flags.DEFINE_string(
    name="load",
    default="",
    help="use representative data defined in additional python file",
    short_name="d",
)

def _unknown_keyword_message(choices) -> str:
    """Return the validator error text listing the accepted keywords of ``choices``."""
    # Joining the keys gives "choose from default, none" instead of the
    # "dict_keys(['default', 'none'])" repr produced by formatting ``d.keys()`` with %s.
    return "Keyword not recognized, choose from %s" % ", ".join(choices)


# --opt: values assigned to ``converter.optimizations``.
# Quantized data types require --opt=default (checked in create_converter).
OPT_ARGS = {
    "default": [lite.Optimize.DEFAULT],
    "none": [],
}
flags.register_validator("opt", lambda x: x in OPT_ARGS, _unknown_keyword_message(OPT_ARGS))

# --ops: values assigned to ``converter.target_spec.supported_ops``.
OPS_ARGS = {
    "int8": [lite.OpsSet.TFLITE_BUILTINS_INT8],
    "tflite": [lite.OpsSet.TFLITE_BUILTINS],  # default
    "tf": [lite.OpsSet.SELECT_TF_OPS, lite.OpsSet.TFLITE_BUILTINS],
}
flags.register_validator("ops", lambda x: x in OPS_ARGS, _unknown_keyword_message(OPS_ARGS))

# --type: values assigned to ``converter.target_spec.supported_types``.
TYPE_ARGS = {
    "float": [lite_constants.FLOAT],  # tf.float32, default
    "int8": [lite_constants.INT8],  # tf.int8
    "float16": [lite_constants.FLOAT16],  # tf.float16, requires tensorflow>=1.15
    # NOTE(review): confirm the installed TF version accepts QUANTIZED_UINT8 in supported_types.
    "uint8": [lite_constants.QUANTIZED_UINT8],  # tf.uint8
}
flags.register_validator("type", lambda x: x in TYPE_ARGS, _unknown_keyword_message(TYPE_ARGS))
FLAGS = flags.FLAGS


def setup_tfconverter(
    converter: lite.TFLiteConverter, dtype: str, opt: str, ops: str, data_fun: Callable = None
) -> lite.TFLiteConverter:
    """Setup control arguments for `TFLiteConverter <https://www.tensorflow.org/lite/convert>`_

    Args:
        converter (lite.TFLiteConverter): loaded `TFLiteConverter`.
        dtype (str): data types: `float`, `float16`, `int8`, `uint8` (keys of ``TYPE_ARGS``).
        opt (str): optimization: `none` for `float`, `default` for the other data types
            (keys of ``OPT_ARGS``).
        ops (str): operation sets: `tflite`, `tf`, `int8` (keys of ``OPS_ARGS``).
        data_fun (Callable, optional): generator function assigned to
            ``converter.representative_dataset``; left unset when None. Defaults to None.

    Returns:
        lite.TFLiteConverter: the same `converter`, configured in place and returned.

    Raises:
        KeyError: if `opt`, `ops` or `dtype` is not a recognized keyword.
    """
    # Resolve every keyword before touching the converter so an invalid keyword
    # leaves it unmodified. Copies keep callers from mutating the shared tables.
    optimizations = list(OPT_ARGS[opt])
    supported_ops = list(OPS_ARGS[ops])
    supported_types = list(TYPE_ARGS[dtype])

    if data_fun is not None:
        converter.representative_dataset = data_fun
    converter.optimizations = optimizations
    converter.target_spec.supported_ops = supported_ops
    converter.target_spec.supported_types = supported_types
    return converter

def _load_representative_data(filename: str) -> Callable:
    """Import ``get_dataset`` from the Python file ``filename`` and wrap it as a data generator.

    The file's directory is prepended to ``sys.path`` so the module (and any siblings it
    imports) can be found. ``get_dataset()`` is called once here.

    Raises:
        Exception: whatever the import or the ``get_dataset`` lookup raised, after logging.
    """
    import importlib
    import sys

    try:
        sys.path.insert(0, path.dirname(path.abspath(filename)))
        # splitext removes exactly the extension; str.rstrip(".py") strips any trailing
        # '.', 'p' or 'y' characters and would turn "happy.py" into "ha".
        module_name = path.splitext(path.basename(filename))[0]
        module = importlib.import_module(module_name)
        get_dataset = getattr(module, "get_dataset")
    except Exception:  # pragma: no cover
        logger.fatal("Failed to import get_dataset from %s", filename)
        raise

    data = get_dataset()

    def data_gen():
        # Assumes each dataset element is a sequence whose first item exposes ``.numpy()``
        # (e.g. a (features, label) pair of tensors) -- TODO confirm against data files.
        for sample in data:
            yield [sample[0].numpy().astype(np.float32)]

    return data_gen


def create_converter(model: str, model_loader: Callable[[str], lite.TFLiteConverter]) -> lite.TFLiteConverter:
    """Setup the TFLite converter

    Args:
        model (str): model filename
        model_loader (Callable[[str], lite.TFLiteConverter]): function to load model file and
            return a `TFLiteConverter`

    Returns:
        lite.TFLiteConverter: `TFLiteConverter` with additional arguments set up.

    Raises:
        AssertionError: if ``--opt`` does not match ``--type`` (``none`` is required for
            ``float``, anything else for quantized types).
    """
    # Explicit raises instead of ``assert`` so the checks survive ``python -O``;
    # AssertionError is kept so callers see the same exception type as before.
    if FLAGS.type != "float":
        if FLAGS.opt == "none":
            raise AssertionError("--opt=default is required for quantization.")
    elif FLAGS.opt != "none":
        raise AssertionError("--opt=none is required for float32 operation.")

    converter = model_loader(model)
    data_gen = _load_representative_data(FLAGS.load) if FLAGS.load else None
    return setup_tfconverter(converter, FLAGS.type, FLAGS.opt, FLAGS.ops, data_gen)