tensorflow

Creating a custom operation with tf.py_func (CPU only)

Parameters#

Parameter Details
func python function, which takes numpy arrays as its inputs and returns numpy arrays as its outputs
inp list of Tensors (inputs)
Tout list of tensorflow data types for the outputs of func
## Basic example
The tf.py_func(func, inp, Tout) operator creates a TensorFlow operation that calls a Python function, func on a list of tensors inp.

See the documentation for tf.py_func(func, inp, Tout).

Warning: The tf.py_func() operation will only run on CPU. If you are using distributed TensorFlow, the tf.py_func() operation must be placed on a CPU device in the same process as the client.

def func(x):
    return 2*x

x = tf.constant(1.)
res = tf.py_func(func, [x], [tf.float32])
# res is a list of length 1

Why to use tf.py_func

The tf.py_func() operator enables you to run arbitrary Python code in the middle of a TensorFlow graph. It is particularly convenient for wrapping custom NumPy operators for which no equivalent TensorFlow operator (yet) exists. Adding tf.py_func() is an alternative to using sess.run() calls inside the graph.

Another way of doing that is to cut the graph in two parts:

# Part 1 of the graph
inputs = ...  # in the TF graph

# Get the numpy array and apply func
val = sess.run(inputs)  # get the value of inputs
output_val = func(val)  # numpy array

# Part 2 of the graph
output = tf.placeholder(tf.float32, shape=...)
train_op = ...

# We feed the output_val to the tensor output
sess.run(train_op, feed_dict={output: output_val})

With tf.py_func this is much easier:

# Part 1 of the graph
inputs = ...

# call to tf.py_func
output = tf.py_func(func, [inputs], [tf.float32])[0]

# Part 2 of the graph
train_op = ...

# Only one call to sess.run, no need of a intermediate placeholder
sess.run(train_op)

This modified text is an extract of the original Stack Overflow Documentation created by the contributors and released under CC BY-SA 3.0 This website is not affiliated with Stack Overflow