On the confusing behavior of `tf.cond`
This is a post about how `tf.cond` is a little bit of a leaky abstraction.
To review: TensorFlow provides the following API for control flow:
- `tf.cond(pred, fn1, fn2)`
- `tf.while_loop(cond, body, loop_vars)`
- `tf.scan(fn, elems)`
`tf.cond` looks like it should work the same way as an `if` statement in an ordinary programming language: if `pred` is true, you expect `fn1` to be executed, and otherwise `fn2`. But is that actually how it works?
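For comparison, here is a minimal plain-Python sketch of the semantics most people expect. `eager_cond` is a hypothetical helper written only for illustration and is not part of TensorFlow. It calls exactly one branch, so the other function never runs at all.

```python
def eager_cond(pred, fn1, fn2):
    """Ordinary if/else semantics (hypothetical helper): call exactly one branch."""
    return fn1() if pred else fn2()


calls = []  # records which branch functions were actually called


def update():
    calls.append("update")
    return 100


def dont_update():
    calls.append("dont_update")
    return 5


result = eager_cond(False, update, dont_update)
print(result, calls)  # 5 ['dont_update']
```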
Let’s see some code. This works as expected:
```python
from __future__ import absolute_import, division, print_function
import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.Session, tf.assign)

x = tf.Variable(5)

def update():
    return tf.assign(x, 100)  # Take note of this line!

def dont_update():
    return x

# Set this variable to True or False. This probably does what you expect:
dummy_cond = tf.Variable(False)
final_val = tf.cond(dummy_cond, update, dont_update)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
res = sess.run([final_val, x])
print(res)
```
…but what happens if we pull the assignment out of `update`?
```python
from __future__ import absolute_import, division, print_function
import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.Session, tf.assign)

x = tf.Variable(5)
assign_op = tf.assign(x, 100)

def update():
    return assign_op  # Take note: this line is what changed!

def dont_update():
    return x

# Set this variable to True or False. With False, this will not do what you expect:
dummy_cond = tf.Variable(False)
final_val = tf.cond(dummy_cond, update, dont_update)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
res = sess.run([final_val, x])
print(res)
```
This doesn’t work the way most people would expect because of TensorFlow’s execution model: even when `dummy_cond` is `False`, `assign_op` still runs and `x` ends up as 100.
TensorFlow is designed to be used like this:
- You build a computation graph.
- You ask TensorFlow to “compute this node of the graph.”
- To compute a node, TensorFlow works out the node’s dependencies and executes them first. Anything that isn’t a dependency of the requested node is pruned and never executed. (Side note: TensorFlow execution looks lazy, but within a run it’s actually eager: it waits for nodes to become ready and executes each one as soon as it is.)
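The “compute only what the fetched node depends on” rule can be sketched with a toy graph evaluator in plain Python. `Node` and `run` are hypothetical names, and this is a simplified model rather than TensorFlow’s executor. Unlike the real executor, which schedules ready nodes (possibly in parallel), it walks dependencies sequentially by recursion.

```python
class Node:
    """A toy graph node: a name, a function, and the nodes it depends on."""

    def __init__(self, name, fn, deps=()):
        self.name = name
        self.fn = fn
        self.deps = list(deps)


def run(fetch):
    """Toy stand-in for session.run: evaluate `fetch` and its dependencies.

    Returns the fetched value and the names of the nodes that executed, in
    order. Each node runs at most once per call.
    """
    executed, cache = [], {}

    def evaluate(node):
        if node not in cache:
            args = [evaluate(dep) for dep in node.deps]  # dependencies first
            executed.append(node.name)
            cache[node] = node.fn(*args)
        return cache[node]

    return evaluate(fetch), executed


a = Node("a", lambda: 2)
b = Node("b", lambda: 3)
unused = Node("unused", lambda: 0)  # nothing we fetch depends on this node
product = Node("multiply", lambda x, y: x * y, deps=[a, b])

value, executed = run(product)
print(value, executed)  # 6 ['a', 'b', 'multiply']
```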
`tf.cond` does some special stuff! Taking a function rather than a tensor is meant as a message to the API user: you should try to do everything inside that function. TensorFlow builds a special conditional context around your lambda, so that ops created inside it are disabled when the condition for `tf.cond` doesn’t match. Note that both lambdas are called once, when the graph is built; what gets skipped at run time is the ops they created.
However, that special conditional context can’t save you if you refer to an op created outside of it. The execution model says an op’s dependencies must be computed before the op executes, so TensorFlow will still blindly compute them (e.g. if you’re trying to compute `tf.multiply(a, b)`, then `a` and `b` have to be computed first!).
Similarly, if you’re trying to compute `tf.cond(c, a, b)`, the outputs of `a` and `b` have to exist! And if `a` and `b` reach outside of their special closure contexts, you’re asking for trouble, because TensorFlow will compute every dependency of `a` AND `b` that lives outside those contexts, without regard for the conditional.
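Both behaviors can be reproduced with a toy model in plain Python. Everything here (`Node`, `run`, `cond`, `demo`, the `DEAD` marker) is hypothetical and much simpler than TensorFlow’s real implementation, which uses `Switch`/`Merge` ops. The sketch assumes three rules. First, both branch functions are called while the graph is built. Second, nodes created inside a branch remember that branch and are skipped when the predicate doesn’t match. Third, a node’s dependencies are always computed before it runs, wherever they were created.

```python
_context = []  # stack of (predicate node, branch) guards active while building a branch
DEAD = object()  # marker for the output of a node in the untaken branch


class Node:
    """Toy graph node; remembers the cond branch (if any) it was created in."""

    def __init__(self, name, fn, deps=()):
        self.name, self.fn, self.deps = name, fn, list(deps)
        self.guard = _context[-1] if _context else None


def run(fetch):
    """Evaluate `fetch`; return its value and the names of nodes that executed."""
    executed, cache = [], {}

    def evaluate(node):
        if node not in cache:
            # Dependencies are computed first, no matter where they were created.
            args = [evaluate(dep) for dep in node.deps]
            if node.guard is not None:
                pred, branch = node.guard
                if bool(evaluate(pred)) != branch:
                    cache[node] = DEAD  # gated off: this node's own fn never runs
                    return DEAD
            executed.append(node.name)
            cache[node] = node.fn(*args)
        return cache[node]

    return evaluate(fetch), executed


def cond(pred, fn1, fn2):
    """Toy tf.cond: calls BOTH branch functions while building the graph."""
    outputs = []
    for branch, fn in ((True, fn1), (False, fn2)):
        _context.append((pred, branch))
        try:
            outputs.append(fn())
        finally:
            _context.pop()
    # Stand-in for Merge: forward the value of whichever branch is live.
    return Node("merge", lambda p, t, f: t if p else f, deps=[pred] + outputs)


def demo(assign_inside):
    """Run the post's example with the condition set to False."""
    state = {"x": 5}

    def assign_100():
        state["x"] = 100
        return 100

    pred = Node("pred", lambda: False)
    read_x = Node("read_x", lambda: state["x"])
    outside = None if assign_inside else Node("assign", assign_100)

    def update():
        # Created inside the lambda, the node picks up the True-branch guard.
        return Node("assign", assign_100) if assign_inside else outside

    def dont_update():
        return read_x

    value, executed = run(cond(pred, update, dont_update))
    return value, executed, state["x"]


print(demo(assign_inside=True))   # (5, ['pred', 'read_x', 'merge'], 5)
print(demo(assign_inside=False))  # (100, ['pred', 'assign', 'read_x', 'merge'], 100)
```

In the toy, the second case returns 100 only because `assign` happens to be evaluated before `read_x`. In TensorFlow, the order between the assignment and the read of `x` within one run is not guaranteed, but the assignment itself runs either way.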
TL;DR: this is why you can set up a conditional and find that both branches seem to execute regardless of the condition’s actual value. It’s because, whenever a branch strays outside its closure context, the ops it reaches out to really are executed. Creating side-effecting ops such as `tf.assign` inside the lambda, as in the first example, avoids the problem.
Related Stack Overflow post: “Confused by the behavior of tf.cond”