#!/usr/bin/env python
"""
..
  Copyright (c) 2018 LG Electronics Inc.
  SPDX-License-Identifier: GPL-3.0-or-later
ProtoBuf to TFlite
==================
See :func:`dlconvert.to_tflite.setup_converter` for more control arguments for `tflite`.
Example
-------
.. code-block:: bash
    $ python -m aup.dlconvert.pb_to_tflite \\
        --model model.pb --output model.tflite \\
        [--load rep_data] \\
        [--opt default --ops int8 --type int8] \\
        [--input_shape 1,224,224,3]
"""
from os import path
import logging
from typing import List
from absl import flags, app
from tensorflow import lite, keras
import tensorflow as tf
from .utils import reset_flag
reset_flag()
# pylint: disable=wrong-import-position
from .to_tflite import create_converter
from .spec_utils import pb
# Module-level logger; its level is raised to INFO in _main.
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
# Turn off eager execution at import time.
# NOTE(review): presumably required because model_loader imports the GraphDef
# into an explicit tf.Graph -- confirm before changing.
tf.compat.v1.disable_eager_execution()
# Command-line flags. "model" and "output" also get the short names "i" and "o".
flags.DEFINE_string("model", "model.pb", "input model", short_name="i")
flags.DEFINE_string("output", "model.tflite", "output", short_name="o")
# Node name lists: model_loader splits both values on ",".
flags.DEFINE_string("input_nodes", "input:0", "input tensor names")
flags.DEFINE_string("output_nodes", "output/Softmax:0", "output tensor names")
# Optional shape override: model_loader splits on ";" (one shape per input)
# and then on "," (one integer per dimension).
flags.DEFINE_string("input_shape", None, "input shape")
# The "model" flag is validated with path.isfile (it must name an existing file).
flags.register_validator("model", path.isfile, message="missing input model")
FLAGS = flags.FLAGS
def verify_output_names(output_names, graph_def):
    """Check that every requested output node exists in the graph.

    Each missing name is logged. When at least one name is missing, the
    candidate names returned by ``pb.search_output_names`` are logged too,
    so the user can pick a valid one.

    Args:
        output_names: iterable of node names, checked with ``pb.verify_name``
        graph_def: graph definition to search

    Returns:
        bool: True when all names are found, False otherwise
    """
    missing = [name for name in output_names
               if not pb.verify_name(name, graph_def)]
    if not missing:
        return True
    # Report every missing name, not only the first one.
    for name in missing:
        logger.fatal("%s is not found in graph", name)
    candidates = pb.search_output_names(graph_def)
    logger.fatal("Potential output names:\n%s", "\n".join(candidates))
    return False
def find_node_shape(tensor_name: str, graph_def: "tf.compat.v1.GraphDef") -> List[int]:
    """Find node shape for the given tensor name

    A node whose name equals the requested node name is preferred. When no
    node matches exactly, the first node whose name merely contains it is
    used instead, which keeps loosely specified names working.

    Args:
        tensor_name (str): name of the tensor, with or without ``:index``
        graph_def (tf.compat.v1.GraphDef): TF GraphDef

    Raises:
        ValueError: When no node name matches

    Returns:
        List[int]: dimension sizes read from the node's ``shape`` attribute,
        excluding first (batch) dimension
    """
    node_name = tensor_name.split(":")[0]  # drop the ":x" suffix if given
    exact = None
    partial = None
    for node in graph_def.node:
        if node.name == node_name:
            exact = node
            break
        # Remember only the first substring hit; it is used as a fallback.
        if partial is None and node_name in node.name:
            partial = node
    match = exact if exact is not None else partial
    if match is None:
        raise ValueError("No match node for %s" % node_name)
    dims = [dim.size for dim in match.attr['shape'].shape.dim]
    return dims[1:]
def verify_input_names(input_names, graph_def):
    """Check that every requested input node exists in the graph.

    Input-side counterpart of :func:`verify_output_names`; each missing name
    is logged.

    Args:
        input_names: iterable of node names, checked with ``pb.verify_name``
        graph_def: graph definition to search

    Returns:
        bool: True when all names are found, False otherwise
    """
    missing = [name for name in input_names
               if not pb.verify_name(name, graph_def)]
    for name in missing:
        logger.fatal("%s is not found in graph", name)
    return not missing


def model_loader(filename: str) -> "lite.TFLiteConverter":
    """Load TF ProtoBuf (for TF v1 and v2)

    Node names are read from ``--input_nodes`` / ``--output_nodes`` (comma
    separated). ``--input_shape`` optionally overrides the input shapes:
    shapes of several inputs are separated by ``;``, dimensions by ``,``,
    and the first dimension of each shape is the batch dimension.

    Args:
        filename (str): ProtoBuf file name

    Raises:
        ValueError: When the number of shapes in ``--input_shape`` differs
            from the number of input nodes, or (TF 2.x) when an input or
            output node is not found in the graph

    Returns:
        lite.TFLiteConverter: TFLite converter
    """
    input_names = [name.strip() for name in FLAGS.input_nodes.split(",")]
    output_names = [name.strip() for name in FLAGS.output_nodes.split(",")]

    # Overwrite default input_shape with user defined input_shape
    user_shapes = None
    if FLAGS.input_shape:
        user_shapes = [[int(dim) for dim in spec.split(",")]
                       for spec in FLAGS.input_shape.split(";")]
        # zip() below would silently drop the surplus; fail loudly instead.
        if len(user_shapes) != len(input_names):
            raise ValueError("Got %d input shapes for %d input nodes"
                             % (len(user_shapes), len(input_names)))

    if tf.__version__.startswith("2"):  # pylint: disable=no-member
        # require frozen pb
        logger.info("Tensorflow version 2.x")
        graph_def = pb.load_graphdef(filename)
        if not verify_input_names([n.split(":")[0] for n in input_names], graph_def):
            raise ValueError("Wrong input nodes")
        logger.info("Correct input names")
        if not verify_output_names([n.split(":")[0] for n in output_names], graph_def):
            raise ValueError("Wrong output nodes")
        logger.info("Correct output names")

        # get input shapes; keras Input takes the shape without the batch dimension
        if user_shapes is None:
            shapes = [find_node_shape(name, graph_def) for name in input_names]
        else:
            shapes = [shape[1:] for shape in user_shapes]

        # import protobuf and create Keras model
        graph = tf.Graph()
        with graph.as_default():
            inputs = {}
            for name, shape in zip(input_names, shapes):
                # NOTE(review): every Input layer is named "input"; with more
                # than one input node this probably collides -- confirm.
                inputs[name] = keras.layers.Input(shape=shape, name="input")
            tf.import_graph_def(graph_def, name="", input_map=inputs)
            tf_outputs = [graph.get_tensor_by_name(name) for name in output_names]
            model = keras.Model(inputs=inputs, outputs=tf_outputs)
        return lite.TFLiteConverter.from_keras_model(model)

    logger.info("Tensorflow version 1.x")
    input_names = [n.split(":")[0] for n in input_names]
    output_names = [n.split(":")[0] for n in output_names]
    # Forward the user supplied shapes (full shape, keyed by node name) so
    # that --input_shape is honoured on TF 1.x as well.
    input_shapes = None
    if user_shapes is not None:
        input_shapes = dict(zip(input_names, user_shapes))
    try:
        return lite.TFLiteConverter.from_frozen_graph(
            filename, input_names, output_names, input_shapes=input_shapes)
    except Exception:
        # Help diagnose the failure by checking the node names, then
        # re-raise the original error with its traceback intact.
        graph_def = pb.load_graphdef(filename)
        if verify_input_names(input_names, graph_def):
            logger.info("Correct input names")
        if verify_output_names(output_names, graph_def):
            logger.info("Correct output names")
        raise
def _main(_):
    """Convert ``FLAGS.model`` and write the result to ``FLAGS.output``.

    Used as the callback of ``app.run``; the positional argument it receives
    is ignored.
    """
    logger.setLevel(logging.INFO)
    converted = create_converter(FLAGS.model, model_loader).convert()
    with open(FLAGS.output, "wb") as out_file:
        out_file.write(converted)
# Script entry point: hand control to absl's app.run with _main as callback.
if __name__ == "__main__":
    app.run(_main)