python - Is tf.gradients thread-safe?


I have several calls to tf.gradients that each take some time, so I would like to call tf.gradients concurrently. However, I receive one of several errors when I try to do so in my graph. I suspect it is not thread-safe, but I have not been able to reproduce the error with an MWE. I tried using both pathos.pools.ThreadPool and pathos.pools.ProcessPool in both my MWE and my real code; only my real code fails. Here is the MWE I tried:

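from pathos.pools import ThreadPool, ProcessPool
import tensorflow as tf
import numpy as np

# Three random 10x10 float64 tensors and three expressions built from them.
xs = [tf.cast(np.random.random((10, 10)), dtype=tf.float64) for _ in range(3)]
ys = [xs[0]*xs[1]*xs[2], xs[0]/xs[1]*xs[2], xs[0]/xs[1]/xs[2]]

def compute_grad(yx):
    # yx is a (y, x) pair; build the gradient of y with respect to x.
    return tf.gradients(yx[0], yx[1])

# Build the three gradients concurrently, one per worker thread.
tp = ThreadPool(3)
res = tp.map(compute_grad, zip(ys, xs))
print(res)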

Here's a partial traceback that I encountered when trying my real code. This is the ThreadPool version.

  File "pathos/threading.py", line 134, in map
    return _pool.map(star(f), zip(*args)) # chunksize
  File "multiprocess/pool.py", line 260, in map
    return self._map_async(func, iterable, mapstar, chunksize).get()
  File "multiprocess/pool.py", line 608, in
    raise self._value
  File "multiprocess/pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
  File "multiprocess/pool.py", line 44, in mapstar
    return list(map(*args))
  File "pathos/helpers/mp_helper.py", line 15, in <lambda>
    func = lambda args: f(*args)
  File "my_code.py", line 939, in gradients_with_index
    return (tf.gradients(y, variables), b_idx)
  File "tensorflow/python/ops/gradients_impl.py", line 448, in gradients
    colocate_gradients_with_ops)
  File "tensorflow/python/ops/gradients_impl.py", line 188, in _pendingcount
    between_op_list, between_ops, colocate_gradients_with_ops)
  File "tensorflow/python/ops/control_flow_ops.py", line 1288, in maybecreatecontrolflowstate
    loop_state.addwhilecontext(op, between_op_list, between_ops)
  File "tensorflow/python/ops/control_flow_ops.py", line 1103, in addwhilecontext
    grad_state = gradloopstate(forward_ctxt, outer_grad_state)
  File "tensorflow/python/ops/control_flow_ops.py", line 737, in __init__
    cnt, outer_grad_state)
  File "tensorflow/python/ops/control_flow_ops.py", line 2282, in addbackproploopcounter
    merge_count = merge([enter_count, enter_count])[0]
  File "tensorflow/python/ops/control_flow_ops.py", line 404, in merge
    return gen_control_flow_ops._merge(inputs, name)
  File "tensorflow/python/ops/gen_control_flow_ops.py", line 150, in _merge
    result = _op_def_lib.apply_op("merge", inputs=inputs, name=name)
  File "tensorflow/python/framework/op_def_library.py", line 767, in apply_op
    op_def=op_def)
  File "tensorflow/python/framework/ops.py", line 2506, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "tensorflow/python/framework/ops.py", line 1273, in __init__
    self._control_flow_context.addop(self)
  File "tensorflow/python/ops/control_flow_ops.py", line 2147, in addop
    self._addopinternal(op)
  File "tensorflow/python/ops/control_flow_ops.py", line 2177, in _addopinternal
    self._maybeaddcontroldependency(op)
  File "tensorflow/python/ops/control_flow_ops.py", line 2204, in _maybeaddcontroldependency
    op._add_control_input(self.getcontrolpivot().op)
AttributeError: 'NoneType' object has no attribute 'op'

Here is another traceback. Note that the error is different:

Traceback (most recent call last):
  File "tensorflow/python/ops/control_flow_ops.py", line 869, in addforwardaccumulator
    enter_acc = self.forward_context.addvalue(acc)
  File "tensorflow/python/ops/control_flow_ops.py", line 2115, in addvalue
    self._outer_context.addinnerop(enter.op)
  File "tensorflow/python/framework/ops.py", line 3355, in __exit__
    self._graph._pop_control_dependencies_controller(self)
  File "tensorflow/python/framework/ops.py", line 3375, in _pop_control_dependencies_controller
    assert self._control_dependencies_stack[-1] is controller
AssertionError

The ProcessPool version encountered this error:

_pickle.PicklingError: Can't pickle <class 'tensorflow.python.util.tf_should_use._add_should_use_warning.<locals>.tfshouldusewarningwrapper'>: it's not found as tensorflow.python.util.tf_should_use._add_should_use_warning.<locals>.tfshouldusewarningwrapper

The tf.gradients() function is not thread-safe. It makes a sequence of complicated, non-atomic modifications to your graph, and these are not protected by locks. In particular, it seems that using tf.gradients() on a graph that contains control flow operations (such as tf.while_loop()) is more likely to run into problems if you run it concurrently.
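If the calls have to be issued from several threads anyway, one way to stay safe is to serialize them behind a single lock. The following is only a minimal sketch of that idea using the standard library: `call_serialized`, `_graph_lock`, and `fake_gradients` are hypothetical names introduced for illustration, and `fake_gradients` is a stand-in for tf.gradients() so the example runs without TensorFlow. TensorFlow itself is not claimed to provide such a lock.

```python
import threading
from concurrent.futures import ThreadPoolExecutor

# One process-wide lock guarding every graph-mutating call. This lock is an
# assumption of the sketch, not something TensorFlow supplies.
_graph_lock = threading.Lock()


def call_serialized(fn, *args, **kwargs):
    """Run fn(*args, **kwargs) while holding the graph lock."""
    with _graph_lock:
        return fn(*args, **kwargs)


# Hypothetical stand-in for a graph-mutating call such as tf.gradients().
# It records how many callers are inside it at the same time.
_state_lock = threading.Lock()
_active = 0
max_active = 0


def fake_gradients(y, x):
    global _active, max_active
    with _state_lock:
        _active += 1
        max_active = max(max_active, _active)
    result = (y, x)  # placeholder for the real gradient tensors
    with _state_lock:
        _active -= 1
    return result


pairs = [("y0", "x0"), ("y1", "x1"), ("y2", "x2")]
with ThreadPoolExecutor(max_workers=3) as pool:
    results = list(pool.map(lambda p: call_serialized(fake_gradients, *p), pairs))
```

With TensorFlow installed, the corresponding call would presumably look like `call_serialized(tf.gradients, y, x)`. This buys safety only: the guarded calls still run one at a time.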

Note that issuing parallel calls to tf.gradients() is unlikely to speed it up, even if it were implemented in a thread-safe manner. The function performs no I/O and does not call any native methods that release Python's GIL, so the execution would most likely be serialized. Implementing multiprocessing-based parallelism would require additional system calls for accessing the shared graph (and acquiring/releasing locks), so it is unlikely to be faster.
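The GIL point can be checked with a small experiment that does not involve TensorFlow at all. The sketch below assumes a standard CPython build with the GIL enabled; `cpu_bound` is a hypothetical pure-Python workload standing in for graph construction. It runs the same work sequentially and on a thread pool and prints both timings, which on such a build are typically close to each other.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def cpu_bound(n):
    """Pure-Python loop: no I/O and no native call that releases the GIL."""
    total = 0
    for i in range(n):
        total += i * i
    return total


work = [200_000] * 4

# Run the four tasks one after another.
start = time.perf_counter()
sequential = [cpu_bound(n) for n in work]
sequential_seconds = time.perf_counter() - start

# Run the same four tasks on four threads.
start = time.perf_counter()
with ThreadPoolExecutor(max_workers=4) as pool:
    threaded = list(pool.map(cpu_bound, work))
threaded_seconds = time.perf_counter() - start

print(f"sequential: {sequential_seconds:.3f}s, threaded: {threaded_seconds:.3f}s")
```

The exact numbers depend on the machine and interpreter, so none are quoted here. The threads produce the same results as the sequential run, but they take turns holding the GIL instead of running in parallel.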

