# Source code for aup.dlconvert.checkpoint_to_pb

#!/usr/bin/env python
"""
..
  Copyright (c) 2018 LG Electronics Inc.
  SPDX-License-Identifier: GPL-3.0-or-later

Checkpoint to TF ProtoBuf
=========================

Requires a checkpoint folder containing a `.meta` file. If it is missing, save the meta graph manually before conversion.

Example
-------

.. code-block:: bash

   $ python -m aup.dlconvert.checkpoint_to_pb --model model_ckpt/model.meta \\
       --output model_frozen.pb \\
       --frozen \\
       --output_nodes output/Softmax:0

"""
from os import path
from typing import List
import tensorflow.compat.v1 as tf # pylint: disable=import-error
from absl import flags, app
from .utils import reset_flag

# Global absl flag registry; the flags defined below live here and are read by _main().
FLAGS = flags.FLAGS
# NOTE(review): reset_flag() comes from .utils and its effect is not visible here;
# presumably it clears flags defined by previously imported modules so this
# module can define its own names — confirm against .utils.
reset_flag()

# Intentionally imported after reset_flag() (hence the pylint disable below);
# presumably anything .to_frozen_pb registers at import time must not be wiped
# by the reset — TODO confirm.
# pylint: disable=wrong-import-position
from .to_frozen_pb import to_frozen

# Command-line flags for the converter (see the module docstring for an example).
flags.DEFINE_string("model", "model-ckpt/model.meta", "input model ckpt meta file path", short_name="i")
flags.DEFINE_string("output", "model.pb", "output filename", short_name='o')
flags.DEFINE_bool("frozen", True, "create frozen protobuf")
flags.DEFINE_string("output_nodes", "", "model output names (separated by comma)")
flags.register_validator("model", path.isfile, message="Input checkpoint meta file is missing")


def _output_nodes_given(flag_dict):
    """Return True unless a frozen graph is requested without any output node name.

    Args:
        flag_dict (dict): mapping with the current ``frozen`` and ``output_nodes``
            flag values, as passed by absl to a multi-flag validator.
    """
    # Output node names are only consumed on the freezing path of convert(),
    # so a non-frozen conversion does not need any.
    if not flag_dict["frozen"]:
        return True
    # ``"".split(",")`` is ``[""]`` (length 1), so counting the pieces can never
    # fail; require at least one non-blank name instead.
    return any(name.strip() for name in flag_dict["output_nodes"].split(","))


flags.register_multi_flags_validator(
    ["frozen", "output_nodes"],
    _output_nodes_given,
    message="Provide at least one output node name",
)


def convert(checkpoint_meta_file: str, frozen: bool = False, output_nodes: List[str] = ()) -> tf.GraphDef:
    """Convert TF Checkpoint to ProtoBuf

    Only the directory of ``checkpoint_meta_file`` is used. The latest
    checkpoint recorded in that directory's checkpoint state is located, and its
    ``<checkpoint path>.meta`` graph is imported with device placements cleared.
    Its variables are then restored into a fresh session.

    Args:
        checkpoint_meta_file (str): checkpoint meta file name; its directory must
            contain the TensorFlow checkpoint state.
        frozen (bool, optional): to create a frozen graphdef. Defaults to False.
        output_nodes (List[str], optional): A list of output node names for the
            frozen graph; only used when ``frozen`` is True. Defaults to ().

    Returns:
        tf.GraphDef: Tensorflow Graph to be written to file

    Raises:
        FileNotFoundError: if no checkpoint state is found in the directory of
            ``checkpoint_meta_file``.
    """
    model_dir = path.dirname(checkpoint_meta_file)
    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt is None:
        # get_checkpoint_state() returns None instead of raising when the
        # directory holds no readable checkpoint state; report that clearly
        # rather than failing later on ``None.model_checkpoint_path``.
        raise FileNotFoundError(
            "No checkpoint state found in directory {!r}".format(model_dir or ".")
        )

    graph = tf.Graph()
    with graph.as_default():
        saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + ".meta", clear_devices=True)
        # Snapshot of the imported graph: this is the result when not freezing.
        input_graph_def = graph.as_graph_def()
        with tf.Session(graph=graph) as sess:
            # import_meta_graph() returns None when the meta graph contains no
            # variables; there is then nothing to restore.
            if saver is not None:
                saver.restore(sess, ckpt.model_checkpoint_path)
            if frozen:
                input_graph_def = to_frozen(sess, output_nodes)
    return input_graph_def
def _main(_): protobuf = convert(FLAGS.model, FLAGS.frozen, FLAGS.output_nodes.split(",")) with open(FLAGS.output, "wb") as fp: # pylint: disable=invalid-name fp.write(protobuf.SerializeToString()) if __name__ == "__main__": app.run(_main)