# Source code for aup.dlconvert.keras_to_tflite

#!/usr/bin/env python3
"""
..
  Copyright (c) 2018 LG Electronics Inc.
  SPDX-License-Identifier: GPL-3.0-or-later

Keras to TFlite
===============

See :func:`dlconvert.to_tflite.setup_converter` for more control arguments for `tflite`.

Example
-------

.. code-block:: bash

   $ python -m aup.dlconvert.keras_to_tflite --model model.h5 \\
       --output model_keras.tflite \\
       [--load rep_data] \\
       [--opt default --ops int8 --type int8]

"""
from os import path
from tensorflow import lite, keras, __version__ # pylint: disable=no-name-in-module
from absl import flags, app
from .utils import reset_flag

# Shared absl flag registry; read by _main() after app.run() has parsed argv.
FLAGS = flags.FLAGS
# reset_flag() is deliberately called before .to_tflite is imported, which is
# why that import sits below instead of at the top of the file.
# NOTE(review): presumably this clears previously registered absl flags so the
# flags defined by .to_tflite and this module do not clash -- confirm in .utils.
reset_flag()

from .to_tflite import create_converter # pylint: disable=wrong-import-position

# Command-line flags of this converter:
#   --model / -i   : input model file (defaults to "model.h5")
#   --output / -o  : file the converted model is written to (defaults to "model.tflite")
flags.DEFINE_string("model", "model.h5", "input model", short_name="i")
flags.DEFINE_string("output", "model.tflite", "output", short_name="o")
# Validate --model with os.path.isfile; a failing check is reported with the
# message "missing input model".
flags.register_validator("model", path.isfile, message="missing input model")


def model_loader(filename: str) -> lite.TFLiteConverter:
    """TF 1/2 tflite converter loading function

    Picks the converter factory that matches the installed TensorFlow major
    version: from TF 2 onwards the Keras model is loaded first and handed to
    ``from_keras_model``; on TF 1.x the model file is passed directly to
    ``from_keras_model_file``.

    Args:
        filename (str): Keras model file

    Returns:
        lite.TFLiteConverter: TFLite converter to be used
    """
    # Parse the whole major component instead of comparing only the first
    # character of the version string, so that any major version >= 2 takes
    # the TF 2 code path rather than falling back to the TF 1.x-only API.
    major_version = int(__version__.split(".", 1)[0])
    if major_version >= 2:
        model = keras.models.load_model(filename)
        return lite.TFLiteConverter.from_keras_model(model)
    # TF 1.x: the converter is built straight from the saved Keras file.
    return lite.TFLiteConverter.from_keras_model_file(filename)
def convert(model: str, output: str):
    """Convert a Keras model file into a tflite model file.

    A converter is created through ``create_converter`` using
    ``model_loader`` as the loading function, the conversion is run, and the
    result is written to ``output`` in binary mode.

    Args:
        model (str): input model file name
        output (str): output model file name
    """
    tflite_bytes = create_converter(model, model_loader).convert()
    with open(output, "wb") as out_file:
        out_file.write(tflite_bytes)
def _main(_):
    """absl entry point: convert using the parsed --model and --output flags.

    The single positional argument supplied by ``app.run`` is ignored.
    """
    convert(model=FLAGS.model, output=FLAGS.output)


if __name__ == "__main__":
    app.run(_main)