In Chapter 4, Transforming Pictures with Amazing Art Styles, and Chapter 5, Understanding Simple Speech Commands, we used two slightly different versions of a script called freeze.py to merge trained network weights with the network graph definition into a single self-contained model file that we can use on mobile devices. TensorFlow also comes with a more general version of the freeze script, called freeze_graph.py, located in the tensorflow/python/tools folder, that you can use to build a model file. To make it work, you need to provide it with at least four parameters (to see all the available parameters, check out tensorflow/python/tools/freeze_graph.py):
- --input_graph or --input_meta_graph: a graph definition file of the model. For example, in the output of the command ls -lt $MODEL_DIR/train in Step 4 of the last section, model.ckpt-109587.meta is a meta graph file that contains the model's graph definition and other checkpoint-related metadata, and graph.pbtxt is just the model's graph definition.
- --input_checkpoint: a specific checkpoint, given by its filename prefix, for example, model.ckpt-109587. Notice that you don't specify the full filename of the large checkpoint data file, model.ckpt-109587.data-00000-of-00001.
- --output_graph: the path to the frozen model file – this is the one used on mobile devices.
- --output_node_names: a comma-separated list of output node names that tells the freeze_graph tool which part of the model and weights to include in the frozen model; nodes and weights not required for computing the specified output nodes are left out.
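To make the role of --output_node_names more concrete, here is a minimal pure-Python sketch of the pruning idea: start from the requested output nodes, walk backwards through each node's inputs, and keep only what is reachable. The toy graph dictionary, its node names, and the function name nodes_needed_for are all assumptions made up for illustration; this is not how freeze_graph.py is implemented internally, only the idea behind it.

```python
from collections import deque


def nodes_needed_for(graph, output_node_names):
    """Return the set of nodes reachable backwards from the output nodes.

    `graph` maps a node name to the list of node names it takes as input.
    """
    needed = set()
    queue = deque(output_node_names)
    while queue:
        name = queue.popleft()
        if name in needed:
            continue
        if name not in graph:
            raise KeyError("unknown node: %s" % name)
        needed.add(name)
        queue.extend(graph[name])
    return needed


# Hypothetical toy graph: a small inference path plus a training-only branch.
toy_graph = {
    "input_feed": [],
    "embedding_weights": [],
    "embedding_lookup": ["input_feed", "embedding_weights"],
    "logits": ["embedding_lookup"],
    "softmax": ["logits"],
    "labels": [],
    "loss": ["logits", "labels"],
    "train_op": ["loss"],
}

kept = nodes_needed_for(toy_graph, ["softmax"])
dropped = sorted(set(toy_graph) - kept)
print("kept:", sorted(kept))
print("dropped:", dropped)
```

In this toy example, asking only for softmax keeps the inference path and drops the loss and training nodes, which is why choosing the right output node names matters before freezing.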
So for this model, how do we figure out the must-have output node names, as well as the input node names, also essential for inference as we have seen in iOS and Android apps in the previous chapters? Because we already used the run_inference script to generate our test image's caption, we can see how it makes the inference.
Go to your im2txt source code folder, models/research/im2txt/im2txt: you may want to open this in a nice editor such as Atom or Sublime Text, or a Python IDE such as PyCharm, or just open it (https://github.com/tensorflow/models/tree/master/research/im2txt/im2txt) from your browser. In run_inference.py, there's a call to build_graph_from_config in inference_utils/inference_wrapper_base.py, which calls build_model in inference_wrapper.py, which in turn calls the build method in show_and_tell_model.py. The build method calls, among other things, a build_inputs method, which has the following code:
if self.mode == "inference":
    image_feed = tf.placeholder(dtype=tf.string, shape=[], name="image_feed")
    input_feed = tf.placeholder(dtype=tf.int64,
                                shape=[None],  # batch_size
                                name="input_feed")
And a build_model method, which has:
if self.mode == "inference":
    tf.concat(axis=1, values=initial_state, name="initial_state")
    state_feed = tf.placeholder(dtype=tf.float32,
                                shape=[None, sum(lstm_cell.state_size)],
                                name="state_feed")
    ...
    tf.concat(axis=1, values=state_tuple, name="state")
    ...
    tf.nn.softmax(logits, name="softmax")
So, the three placeholders named image_feed, input_feed, and state_feed should be the input node names, while initial_state, state, and softmax should be the output node names. Furthermore, two methods defined in inference_wrapper.py confirm our detective work – the first one is:
def feed_image(self, sess, encoded_image):
    initial_state = sess.run(fetches="lstm/initial_state:0",
                             feed_dict={"image_feed:0": encoded_image})
    return initial_state
So, we provide image_feed and get initial_state back (the lstm/ prefix just means that the node is under the lstm scope). The second method is:
def inference_step(self, sess, input_feed, state_feed):
    softmax_output, state_output = sess.run(
        fetches=["softmax:0", "lstm/state:0"],
        feed_dict={
            "input_feed:0": input_feed,
            "lstm/state_feed:0": state_feed,
        })
    return softmax_output, state_output, None
We feed in input_feed and state_feed, and get back softmax and state. In total, three input node names and three output names.
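The strings passed to sess.run above, such as "lstm/state:0", are tensor names: a node name followed by a colon and the index of that node's output. The --output_node_names parameter expects node names, without the :0 suffix. The following small sketch shows the conversion in plain Python; the helper names split_tensor_name and to_node_name are made up for this example and are not part of TensorFlow.

```python
def split_tensor_name(tensor_name):
    """Split 'scope/op:index' into (scope, op name, output index)."""
    node_name, sep, index = tensor_name.rpartition(":")
    if not sep:
        # No ':' present: treat the whole string as a node name, output 0.
        node_name, index = tensor_name, "0"
    scope, _, op = node_name.rpartition("/")
    return scope, op, int(index)


def to_node_name(tensor_name):
    """Drop the ':index' suffix, keeping any scope prefix such as 'lstm/'."""
    scope, op, _ = split_tensor_name(tensor_name)
    return scope + "/" + op if scope else op


# The three tensors fetched by feed_image and inference_step.
fetches = ["softmax:0", "lstm/initial_state:0", "lstm/state:0"]
output_node_names = ",".join(to_node_name(t) for t in fetches)
print(output_node_names)
```

The resulting comma-separated string is in the form that the freeze_graph tool expects for its list of output nodes.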
Note that these nodes are created only if the mode is "inference", as show_and_tell_model.py is used by both train.py and run_inference.py. This means that the graph built by run_inference.py is not the same as the graph definition saved in --checkpoint_path by train in Step 5: it contains the additional inference-only nodes. So, how do we save the updated graph definition and checkpoint files?
It turns out that, in run_inference.py, after a TensorFlow session is created, there's also a call, restore_fn(sess), that loads the checkpoint file; the function is defined in inference_utils/inference_wrapper_base.py:
def _restore_fn(sess):
    saver.restore(sess, checkpoint_path)
By the time run_inference.py reaches the saver.restore call, the updated graph has already been built, so we can save a new checkpoint and graph file right there by changing the _restore_fn function to the following:
def _restore_fn(sess):
    saver.restore(sess, checkpoint_path)
    saver.save(sess, "model/image2text")
    tf.train.write_graph(sess.graph_def, "model", "image2text.pbtxt")
    tf.summary.FileWriter("logdir", sess.graph_def)
The line tf.train.write_graph(sess.graph_def, "model", "image2text.pbtxt") is optional because, when a new checkpoint is saved by calling saver.save, a meta file also gets generated, which can be used along with the checkpoint file by freeze_graph.py. But it's generated here for those who'd love to see everything in plain text format or who prefer to use a graph definition file with the --input_graph parameter when freezing a model. The last line, tf.summary.FileWriter("logdir", sess.graph_def), is also optional, but it generates an event file that can be visualized by TensorBoard. So with these changes, after running run_inference.py again (remember to run bazel build -c opt //im2txt:run_inference first unless you run run_inference.py directly with Python), you'll see in your model directory the following new checkpoint files and a new graph definition file:
jeff@AiLabby:~/tensorflow-1.5.0/models/research/im2txt$ ls -lt model
-rw-rw-r-- 1 jeff jeff 2076964 Feb 7 12:33 image2text.pbtxt
-rw-rw-r-- 1 jeff jeff 1343049 Feb 7 12:33 image2text.meta
-rw-rw-r-- 1 jeff jeff 77 Feb 7 12:33 checkpoint
-rw-rw-r-- 1 jeff jeff 149002244 Feb 7 12:33 image2text.data-00000-of-00001
-rw-rw-r-- 1 jeff jeff 16873 Feb 7 12:33 image2text.index
And in your logdir directory:
jeff@AiLabby:~/tensorflow-1.5.0/models/research/im2txt$ ls -lt logdir
total 2124
-rw-rw-r-- 1 jeff jeff 2171623 Feb 7 12:33 events.out.tfevents.1518035604.AiLabby
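As the model directory listing shows, a saved checkpoint is really a group of files sharing one prefix (image2text here), and that prefix is what --input_checkpoint expects. Here is a minimal sketch that recovers the prefix from a list of filenames. The helper name checkpoint_prefixes is hypothetical, and the shard pattern is an assumption generalized from the single data-00000-of-00001 file seen in the listing.

```python
import re

# Suffixes seen in the model directory listing; the shard pattern is an
# assumption generalized from "data-00000-of-00001".
_SUFFIX = re.compile(r"\.(index|meta|data-\d{5}-of-\d{5})$")


def checkpoint_prefixes(filenames):
    """Map each checkpoint prefix to the sorted suffixes found for it."""
    groups = {}
    for name in filenames:
        match = _SUFFIX.search(name)
        if match:
            prefix = name[:match.start()]
            groups.setdefault(prefix, []).append(match.group(1))
    return {prefix: sorted(parts) for prefix, parts in groups.items()}


# Filenames copied from the model directory listing.
listing = [
    "image2text.pbtxt",
    "image2text.meta",
    "checkpoint",
    "image2text.data-00000-of-00001",
    "image2text.index",
]
prefixes = checkpoint_prefixes(listing)
print(prefixes)
```

The graph definition file (image2text.pbtxt) and the small checkpoint bookkeeping file are deliberately ignored, since neither is part of the prefix-named checkpoint group.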
Now that we're here, let's quickly use TensorBoard to take a look at what our graph looks like – simply run tensorboard --logdir logdir, and open http://localhost:6006 from a browser. Figure 6.1 shows three output node names (softmax at the top, and lstm/initial_state and lstm/state at the top of the highlighted red rectangle) and one input node name (state_feed at the bottom):

Figure 6.2 shows one additional input node name, image_feed:

Finally, Figure 6.3 shows the last input node name, input_feed:

There are certainly a lot of details we can't and won't cover here. But you get the big picture and, equally important, enough details to move forward. Now running freeze_graph.py should be like a breeze (pun intended):
python tensorflow/python/tools/freeze_graph.py --input_meta_graph=/home/jeff/tensorflow-1.5.0/models/research/im2txt/model/image2text.meta --input_checkpoint=/home/jeff/tensorflow-1.5.0/models/research/im2txt/model/image2text --output_graph=/tmp/image2text_frozen.pb --output_node_names="softmax,lstm/initial_state,lstm/state" --input_binary=true
Notice we use the meta graph file here, along with the --input_binary parameter set to true. By default it's false, meaning the freeze_graph tool expects the input graph or meta graph file to be in text format.
You can use the text-formatted graph file as input, in which case there's no need to provide the --input_binary parameter:
python tensorflow/python/tools/freeze_graph.py --input_graph=/home/jeff/tensorflow-1.5.0/models/research/im2txt/model/image2text.pbtxt --input_checkpoint=/home/jeff/tensorflow-1.5.0/models/research/im2txt/model/image2text --output_graph=/tmp/image2text_frozen2.pb --output_node_names="softmax,lstm/initial_state,lstm/state"
The sizes of the two output graph files, image2text_frozen.pb and image2text_frozen2.pb, will be slightly different, but they behave exactly the same when, after being transformed and possibly optimized, they're used on mobile devices.
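If you freeze models often, the two variants of the command can be assembled by a small helper. The sketch below only builds the argument list, using the flags shown in this section; the function name build_freeze_command is made up for this example, and choosing between the two variants by file extension is an assumption for convenience, not something freeze_graph.py does itself.

```python
import shlex


def build_freeze_command(graph_file, checkpoint_prefix, output_graph,
                         output_node_names):
    """Assemble a freeze_graph.py command line as a list of arguments."""
    args = ["python", "tensorflow/python/tools/freeze_graph.py"]
    if graph_file.endswith(".meta"):
        # Meta graph files are binary, so --input_binary must be true.
        args += ["--input_meta_graph=" + graph_file, "--input_binary=true"]
    else:
        # Text-format graph definition: the default --input_binary applies.
        args.append("--input_graph=" + graph_file)
    args += [
        "--input_checkpoint=" + checkpoint_prefix,
        "--output_graph=" + output_graph,
        "--output_node_names=" + ",".join(output_node_names),
    ]
    return args


model_dir = "/home/jeff/tensorflow-1.5.0/models/research/im2txt/model"
outputs = ["softmax", "lstm/initial_state", "lstm/state"]

meta_cmd = build_freeze_command(model_dir + "/image2text.meta",
                                model_dir + "/image2text",
                                "/tmp/image2text_frozen.pb", outputs)
text_cmd = build_freeze_command(model_dir + "/image2text.pbtxt",
                                model_dir + "/image2text",
                                "/tmp/image2text_frozen2.pb", outputs)

# Print shell-ready command lines; nothing is executed here.
for cmd in (meta_cmd, text_cmd):
    print(" ".join(shlex.quote(arg) for arg in cmd))
```

To actually run one of the commands from the TensorFlow source root, the list could be passed to subprocess.run; that step is left out so the sketch has no side effects.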