"""
..
  Copyright (c) 2018 LG Electronics Inc.
  SPDX-License-Identifier: GPL-3.0-or-later
Convert to TFLite
=================
There are four major control parameters for tflite runtime, see :func:`.setup_tfconverter`.
The data feeding function (`data_fun`) is loaded by `--load`, where the argument is the filename of a Python file defining
`get_dataset()`, which is used to `generate data <https://www.tensorflow.org/lite/performance/post_training_integer_quant#convert_using_quantization>`_ for `int8` quantization.
Combine with `--undefok` flag to pass more control arguments.
"""
import logging
from os import path, environ
# NOTE: this assignment deliberately sits between the imports, ahead of the
# tensorflow import below, so the setting is already in the environment when
# tensorflow is loaded.
environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  #disable tensorflow debugging messages
from typing import Callable
import numpy as np
from absl import flags
from tensorflow import lite
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
# Switch on eager execution when the top-level `enable_eager_execution` symbol
# is available. Any failure is only logged; the module import still proceeds.
# NOTE(review): presumably required because `create_converter` calls `.numpy()`
# on dataset elements -- confirm against the supported TensorFlow versions.
try:
    from tensorflow import enable_eager_execution  # pylint: disable=no-name-in-module
    enable_eager_execution()
except Exception:  # pylint: disable=broad-except
    logger.fatal("Compatibility issue with TF eager execution. Tweak with caution.")
# `lite_constants` is looked up in two places: the `tensorflow.lite.python`
# package first, then `tensorflow.lite.constants` if that import fails
# (presumably to cope with different TensorFlow releases -- TODO confirm).
try:
    from tensorflow.lite.python import lite_constants
except ImportError:
    from tensorflow.lite import constants as lite_constants
# Command-line flags controlling the conversion. Each of `opt`, `ops` and
# `type` is validated against the keys of its lookup table defined below.
flags.DEFINE_string("opt", "none", "optimization")
flags.DEFINE_string("type", "float", "data type after quantization")
flags.DEFINE_string("ops", "tflite", "operation set to be used for quantization")
# -- optional for representative data
flags.DEFINE_string("load", "", "use representative data defined in additional python file", short_name="d")
# --opt keyword -> value assigned to `converter.optimizations`.
# When quantizing (any --type other than "float"), --opt must be "default".
OPT_ARGS = {
    "default": [lite.Optimize.DEFAULT],
    "none": [],
}
# The accepted keywords are joined explicitly: formatting `dict.keys()` with %s
# would render as "dict_keys([...])" in the user-facing validation message.
flags.register_validator(
    "opt", lambda x: x in OPT_ARGS, "Keyword not recognized, choose from %s" % ", ".join(OPT_ARGS)
)
# --ops keyword -> value assigned to `converter.target_spec.supported_ops`.
OPS_ARGS = {
    "int8": [lite.OpsSet.TFLITE_BUILTINS_INT8],
    "tflite": [lite.OpsSet.TFLITE_BUILTINS],  # default
    "tf": [lite.OpsSet.SELECT_TF_OPS, lite.OpsSet.TFLITE_BUILTINS],
}
flags.register_validator(
    "ops", lambda x: x in OPS_ARGS, "Keyword not recognized, choose from %s" % ", ".join(OPS_ARGS)
)
# --type keyword -> value assigned to `converter.target_spec.supported_types`.
TYPE_ARGS = {
    "float": [lite_constants.FLOAT],  # tf.float32, default
    "int8": [lite_constants.INT8],  # tf.int8
    "float16": [lite_constants.FLOAT16],  # tf.float16 -> tensorflow>=1.15
    "uint8": [lite_constants.QUANTIZED_UINT8],  # tf.uint8
}
flags.register_validator(
    "type", lambda x: x in TYPE_ARGS, "Keyword not recognized, choose from %s" % ", ".join(TYPE_ARGS)
)
FLAGS = flags.FLAGS
def setup_tfconverter(
    converter: lite.TFLiteConverter, dtype: str, opt: str, ops: str, data_fun: Callable = None
) -> lite.TFLiteConverter:
    """Setup control arguments for `TFLiteConverter <https://www.tensorflow.org/lite/convert>`_

    Each keyword is translated through its module-level lookup table
    (`TYPE_ARGS`, `OPT_ARGS`, `OPS_ARGS`) and the result is assigned to the
    corresponding converter attribute. The converter is modified in place.

    Args:
        converter (lite.TFLiteConverter): loaded `TFLiteConverter`.
        dtype (str): data types: `float`, `float16`, `int8`, `uint8`.
        opt (str): optimization: `none` for `float`, `default` for the other data types.
        ops (str): operation sets: `tflite`, `tf`, `int8`.
        data_fun (Callable, optional): callable assigned to
            `converter.representative_dataset`; the attribute is left untouched
            when this argument is falsy. Defaults to None.

    Returns:
        lite.TFLiteConverter: the same `converter` object with additional arguments set up.

    Raises:
        KeyError: if `dtype`, `opt` or `ops` is not a key of its lookup table.
    """
    # Only override the representative dataset when a generator was supplied.
    if data_fun:
        converter.representative_dataset = data_fun
    converter.optimizations = OPT_ARGS[opt]
    converter.target_spec.supported_ops = OPS_ARGS[ops]
    converter.target_spec.supported_types = TYPE_ARGS[dtype]
    return converter
def _import_get_dataset(filename: str) -> Callable:
    """Import the Python file `filename` as a module and return its `get_dataset` attribute.

    Any exception raised while importing the module or looking up the
    attribute propagates to the caller.
    """
    import importlib
    import sys

    # Put the file's directory first on the import path so that the module can
    # be imported by its bare name.
    sys.path.insert(0, path.dirname(path.abspath(filename)))
    mod_name = path.basename(filename)
    # Remove only a literal ".py" suffix. `str.rstrip(".py")` strips any run of
    # trailing '.', 'p' and 'y' characters, mangling names such as
    # "happy.py" (-> "ha") or "numpy.py" (-> "num").
    if mod_name.endswith(".py"):
        mod_name = mod_name[: -len(".py")]
    return getattr(importlib.import_module(mod_name), "get_dataset")


def create_converter(model: str, model_loader: Callable[[str], lite.TFLiteConverter]) -> lite.TFLiteConverter:
    """Setup the TFLite converter

    Reads `--type`, `--opt`, `--ops` and `--load` from `FLAGS`. When `--load` is
    given, `get_dataset()` from that Python file is called once and its result
    feeds the representative-data generator handed to :func:`.setup_tfconverter`.

    Args:
        model (str): model filename
        model_loader (Callable[[str], lite.TFLiteConverter]): function to load model file and return a `TFLiteConverter`

    Returns:
        lite.TFLiteConverter: `TFLiteConverter` with additional arguments set up.

    Raises:
        AssertionError: if `--type` and `--opt` are inconsistent (`--opt=none`
            is required for `float`, anything else requires another `--opt`).
        Exception: whatever importing the `--load` file or fetching its
            `get_dataset` attribute raises, after the failure has been logged.
    """
    # Raised explicitly rather than through `assert` so the check is not
    # stripped under `python -O`; exception type and messages are unchanged.
    if FLAGS.type != "float":
        if FLAGS.opt == "none":
            raise AssertionError("--opt=default is required for quantization.")
    elif FLAGS.opt != "none":
        raise AssertionError("--opt=none is required for float32 operation.")
    converter = model_loader(model)
    if FLAGS.load:
        try:
            get_dataset = _import_get_dataset(FLAGS.load)
        except Exception:  # pragma: no cover
            # FLAGS.load is logged verbatim; it already carries its own extension.
            logger.fatal("Failed to import get_dataset from %s", FLAGS.load)
            raise
        data = get_dataset()

        def data_gen():
            # For every dataset element, take its first item, convert it with
            # `.numpy()`, cast to float32 and yield it wrapped in a one-element list.
            # NOTE(review): assumes elements are indexable and that item 0 is an
            # eager tensor exposing `.numpy()` -- confirm against the data files.
            for sample in data:
                yield [sample[0].numpy().astype(np.float32)]
    else:
        data_gen = None
    return setup_tfconverter(converter, FLAGS.type, FLAGS.opt, FLAGS.ops, data_gen)