TensorFlow tools
There are many TensorFlow tools that are not widely known.
Freeze Graph
Freeze Graph literally freezes the graph: variables are converted into constants, so the checkpoint weights are baked into the graph definition.
- Unnecessary nodes are removed
- The model becomes one simple protobuf file (weights & graph definition)
~/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=your_graph_definition.pb \
--input_checkpoint=your_tensorflow_checkpoint \
--input_binary=true \
--output_graph=frozen_graph.pb \
--output_node_names=Softmax
input_binary specifies whether the input graph file is in binary format.
- If the file extension is .pb, it is in binary format.
- If the file extension is .pbtxt, it is in text format.
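For reference, the two input files can be produced like this before freezing (a minimal TF 1.x sketch; the model-building part is elided and the file names mirror the command above):
import tensorflow as tf

# ... build your model in the default graph ...
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # graph definition; as_text=False writes the binary .pb format
    tf.train.write_graph(sess.graph_def, '.', 'your_graph_definition.pb', as_text=False)
    # checkpoint holding the variable values
    saver.save(sess, 'your_tensorflow_checkpoint')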
output_node_names takes the names of the graph's output nodes. A tensor name is a node name plus an output index (e.g. Softmax:0), so you should be able to retrieve tensors by name:
# retrieve tensors
image_input = graph.get_tensor_by_name('image_input:0')
keep_prob = graph.get_tensor_by_name('keep_prob:0')
softmax = graph.get_tensor_by_name('Softmax:0')
# operations
prob = sess.run(softmax, {image_input: some_image, keep_prob: 1.0})
And here is the code to load the frozen graph:
def load_graph(graph_file, use_xla=False):
    jit_level = 0
    config = tf.ConfigProto()
    if use_xla:
        jit_level = tf.OptimizerOptions.ON_1
        config.graph_options.optimizer_options.global_jit_level = jit_level
    with tf.Session(graph=tf.Graph(), config=config) as sess:
        gd = tf.GraphDef()
        with tf.gfile.Open(graph_file, 'rb') as f:
            data = f.read()
            gd.ParseFromString(data)
        tf.import_graph_def(gd, name='')
        # unnecessary: only to see how many operations are in the model
        ops = sess.graph.get_operations()
        n_ops = len(ops)
        return sess.graph, ops
After freezing, the number of operations has been reduced by about 89%.
sess, base_ops = load_graph('base_graph.pb')
print(len(base_ops)) # 2165
sess, frozen_ops = load_graph('frozen_graph.pb')
print(len(frozen_ops)) # 245
Optimize for Inference
For inference, the graph can be further optimized with the Optimize for Inference tool.
- Some operations are not necessary for inference
- For example, batch normalization can be folded into the preceding layer once its mean and variance are extracted (see the sketch after this list)
- Many operations can be fused into one
- For example, 3 steps (Conv - BN - ReLU) can be fused into one step (ConvBnRelu)
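To see why BN folding works: at inference time batch norm is just a per-channel affine transform, so it can be absorbed into the convolution's weights and bias. A minimal numpy sketch (the function and parameter names are mine, not TensorFlow's):
import numpy as np

def fold_batch_norm(w, b, gamma, beta, mean, var, eps=1e-3):
    # w: conv weights (kh, kw, c_in, c_out); b, gamma, beta, mean, var: (c_out,)
    scale = gamma / np.sqrt(var + eps)
    w_folded = w * scale                  # scales each output channel
    b_folded = (b - mean) * scale + beta  # the shift absorbs mean and beta
    return w_folded, b_folded
conv(x, w_folded) + b_folded gives the same output as applying batch norm to conv(x, w) + b, so the BN node can simply be dropped.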
Usage
bazel build tensorflow/python/tools:optimize_for_inference && \
bazel-bin/tensorflow/python/tools/optimize_for_inference \
--input=frozen_inception_graph.pb \
--output=optimized_inception_graph.pb \
--frozen_graph=True \
--input_names=Mul \
--output_names=softmax
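The same pass is also available from Python via optimize_for_inference_lib; here is a minimal sketch, assuming the TF 1.x tooling and the file/node names from the command above:
import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow.python.tools import optimize_for_inference_lib

# load the frozen graph definition
gd = tf.GraphDef()
with tf.gfile.Open('frozen_inception_graph.pb', 'rb') as f:
    gd.ParseFromString(f.read())

# strip training-only ops and fold batch norms for the given input/output nodes
optimized = optimize_for_inference_lib.optimize_for_inference(
    gd, ['Mul'], ['softmax'], dtypes.float32.as_datatype_enum)

with tf.gfile.Open('optimized_inception_graph.pb', 'wb') as f:
    f.write(optimized.SerializeToString())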
Result
Now the number of operations has been reduced to 200.
sess, optimized_ops = load_graph('optimized_graph.pb')
print(len(optimized_ops)) # 200
Graph Transform
The Graph Transform tool lets you apply a set of transformations to the graph.
For example,
- Floats to integers
- 32 bits to 8 bits
With far less data to move around, the model runs much faster.
Is it okay?
For training, backpropagation requires high numerical precision. For inference it is not necessary, since we are interested in which output is highest (the label), not in the value itself.
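A toy example of why this is safe: quantizing some made-up logits to 8 bits changes their values slightly, but not which one is largest:
import numpy as np

logits = np.array([1.27, 4.83, 0.02, 3.91], dtype=np.float32)  # made-up outputs

# linear quantization of the float range to uint8 and back
lo, hi = logits.min(), logits.max()
q = np.round((logits - lo) / (hi - lo) * 255).astype(np.uint8)
dq = q.astype(np.float32) / 255 * (hi - lo) + lo

print(q)                                   # [ 66 255   0 206]
print(np.argmax(dq) == np.argmax(logits))  # True: the predicted label survives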
Usage
bazel build tensorflow/tools/graph_transforms:transform_graph
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=tensorflow_inception_graph.pb \
--out_graph=optimized_inception_graph.pb \
--inputs='Mul:0' \
--outputs='softmax:0' \
--transforms='
strip_unused_nodes(type=float, shape="1,299,299,3")
remove_nodes(op=Identity, op=CheckNumerics)
fold_old_batch_norms
quantize_weights
quantize_nodes
'
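The transforms can also be applied from Python through the TransformGraph binding; a sketch under the same assumptions (TF 1.x, the inception file and node names from above):
import tensorflow as tf
from tensorflow.tools.graph_transforms import TransformGraph

gd = tf.GraphDef()
with tf.gfile.Open('tensorflow_inception_graph.pb', 'rb') as f:
    gd.ParseFromString(f.read())

transforms = [
    'strip_unused_nodes(type=float, shape="1,299,299,3")',
    'remove_nodes(op=Identity, op=CheckNumerics)',
    'fold_old_batch_norms',
    'quantize_weights',
    'quantize_nodes',
]
out_gd = TransformGraph(gd, ['Mul'], ['softmax'], transforms)

with tf.gfile.Open('eightbit_graph.pb', 'wb') as f:
    f.write(out_gd.SerializeToString())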
Result
Note that the number of operations has increased, but the model runs much faster with the lower-bit values.
sess, eightbit_ops = load_graph('eightbit_graph.pb')
print(len(eightbit_ops)) # 425
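To verify the speedup on your own model, here is a rough timing sketch using the load_graph helper from earlier (the input/output tensor names and shape are the inception ones and may differ for your graph):
import time
import numpy as np
import tensorflow as tf

def time_graph(graph_file, n_runs=20):
    graph, _ = load_graph(graph_file)
    with tf.Session(graph=graph) as sess:
        x = np.zeros((1, 299, 299, 3), dtype=np.float32)  # dummy image
        inp = graph.get_tensor_by_name('Mul:0')
        out = graph.get_tensor_by_name('softmax:0')
        sess.run(out, {inp: x})  # warm-up run
        start = time.time()
        for _ in range(n_runs):
            sess.run(out, {inp: x})
        return (time.time() - start) / n_runs  # average seconds per run

print(time_graph('optimized_inception_graph.pb'))
print(time_graph('eightbit_graph.pb'))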
AOT & JIT
TensorFlow supports AOT and JIT compilation through XLA.
- AOT stands for "Ahead Of Time"
- JIT stands for "Just In Time"
Furthermore, it can be used not only for inference but also for training.
Usage
# Create a TensorFlow configuration object.
config = tf.ConfigProto()
# JIT level, this can be set to ON_1 or ON_2
jit_level = tf.OptimizerOptions.ON_1
config.graph_options.optimizer_options.global_jit_level = jit_level
# Open a session with the config.
with tf.Session(config=config) as sess:
    ...
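The snippet above only covers JIT. AOT compilation goes through the tfcompile tool instead, usually driven by the tf_library Bazel macro; a sketch adapted from the XLA AOT tutorial (the file names and C++ class are illustrative):
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")

tf_library(
    name = "my_graph",
    config = "my_graph.config.pbtxt",        # declares the feeds and fetches
    cpp_class = "mynamespace::MyGraphComp",  # generated C++ class
    graph = "my_graph.pb",                   # frozen graph definition
)
The macro compiles the graph to native code ahead of time and emits a header plus object file that can be linked into a C++ binary without the full TensorFlow runtime.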