# Source code for aup.dlconvert.keras_to_onnx

#!/usr/bin/env python
"""
..
  Copyright (c) 2018 LG Electronics Inc.
  SPDX-License-Identifier: GPL-3.0-or-later

Keras to ONNX
=============

Depends on `tf2onnx <https://github.com/onnx/tensorflow-onnx>`_ and
`keras2onnx <https://github.com/onnx/keras-onnx>`_ (TF1.x support only).
`tf2onnx` must be imported before `keras2onnx` to resolve its dependencies;
otherwise importing `keras2onnx` may fail with an error.

Example
-------

.. code-block:: bash

   $ python -m aup.dlconvert.keras_to_onnx -i model.h5 -o model.onnx

"""
from os import path, environ
environ["TF_KERAS"] = "1"

import tf2onnx
import keras2onnx # pylint: disable=import-error

from absl import flags, app
from tensorflow.compat.v1 import keras # pylint: disable=import-error
from .utils import reset_flag

# Module-level alias for absl's global flag container; read in _main().
FLAGS = flags.FLAGS
# NOTE(review): reset_flag() comes from .utils and is not visible here --
# presumably it clears previously registered absl flags so the DEFINE_* calls
# below do not clash with other converter modules; confirm in .utils.
reset_flag()

# Command-line flags: --model / -i is the input Keras file (default
# "model.h5"), --output / -o is the ONNX file to write (default "model.onnx").
flags.DEFINE_string("model", "model.h5", "input model", short_name="i")
flags.DEFINE_string("output", "model.onnx", "output", short_name="o")
# path.isfile is the checker: a --model value that is not an existing file
# is rejected by absl with the message below.
flags.register_validator("model", path.isfile, message="missing input model")


def convert(model: str, output: str):
    """Load a Keras model from disk and save it as a serialized ONNX model.

    Args:
        model (str): path of the Keras h5 model file to read
        output (str): path of the ONNX file to write
    """
    keras_model = keras.models.load_model(model)
    onnx_model = keras2onnx.convert_keras(keras_model)
    # The output file is opened before serialization, in binary mode because
    # SerializeToString() yields the protobuf wire format as bytes.
    with open(output, "wb") as onnx_file:
        payload = onnx_model.SerializeToString()
        onnx_file.write(payload)
def _main(_):
    """Entry point for ``app.run``: convert the model named by the flags.

    The positional argument supplied by absl is ignored; the input and output
    paths are taken from ``FLAGS.model`` and ``FLAGS.output``.
    """
    source_path = FLAGS.model
    target_path = FLAGS.output
    convert(source_path, target_path)


if __name__ == "__main__":
    app.run(_main)