Source code for aup.dlconvert.pb_to_onnx

#!/usr/bin/env python

"""
..
  Copyright (c) 2018 LG Electronics Inc.
  SPDX-License-Identifier: GPL-3.0-or-later

ProtoBuf to ONNX
================

Input and output node names are needed.

Example
-------

.. code-block:: bash

   $ python -m aup.dlconvert.pb_to_onnx.py --model model.pb --output model.onnx \\
       --input_nodes input:0 --output_nodes output/Softmax:0 \\
       [--input_shape 1,224,224,3]

"""
from os import path
from absl import flags, app
from .utils import reset_flag

FLAGS = flags.FLAGS
reset_flag()

from . import to_onnx # pylint: disable=wrong-import-position

flags.DEFINE_string("model", "model.pb", "input model", short_name="i")
flags.DEFINE_string("output", "model.onnx", "output onnx model", short_name="o")
flags.DEFINE_string("input_nodes", "input:0", "model input names (sep by comma)")
flags.DEFINE_string("output_nodes", "label:0", "model output names (sep by comma)")
flags.DEFINE_string("input_shape", None, "input shape")
flags.register_validator("model", path.isfile, message="missing input model")

[docs]def convert(model: str, output: str, input_nodes: str, output_nodes: str, input_shape: str): """Convert TF ProtoBuf to ONNX model Args: model (str): TF ProtoBuf file name output (str): output ONNX model name input_nodes (str): model input names output_nodes (str): model output names [input_shape (str): model input shape, needed if input dimension is not specified in model graph] """ if input_shape is not None: # match input_shape with input_node and recreate the input_nodes string shapes = [] input_shapes = FLAGS.input_shape.split(";") input_nodes = input_nodes.split(",") new_input_nodes = [] for input_node, input_shape in zip(input_nodes, input_shapes): new_input_node = input_node + "[" + input_shape + "]" new_input_nodes.append(new_input_node) input_nodes = ",".join(new_input_nodes) argv = {"graphdef" : model, "output" : output, "inputs" : input_nodes, "outputs" : output_nodes} to_onnx.convert_onnx(**argv)
def _main(_): convert(FLAGS.model, FLAGS.output,FLAGS.input_nodes, FLAGS.output_nodes, FLAGS.input_shape) if __name__ == "__main__": app.run(_main)