diff --git a/python/paddle/fluid/concurrency.py b/python/paddle/fluid/concurrency.py index 535e881c42f675198a2679cb7974af64b65cc194..0fc4981a8e9da09f15e6d0a5e5c6761e01328876 100644 --- a/python/paddle/fluid/concurrency.py +++ b/python/paddle/fluid/concurrency.py @@ -131,7 +131,7 @@ def make_channel(dtype, capacity=0): return channel -def channel_send(channel, value): +def channel_send(channel, value, copy=False): """ Sends a value through a channel variable. Used by an unbuffered or buffered channel to pass data from within or to a concurrent Go block, where @@ -141,6 +141,8 @@ def channel_send(channel, value): channel (Variable|Channel): Channel variable created using `make_channel`. value (Variable): Value to send to channel + copy (bool): Copy data while channel send. If False, then data + is moved. The input cannot be used after move. Returns: Variable: The boolean status on whether or not the channel successfully sent the passed value. @@ -162,11 +164,26 @@ def channel_send(channel, value): type=core.VarDesc.VarType.LOD_TENSOR, dtype=core.VarDesc.VarType.BOOL) + X = value + + if copy is True: + copied_X = helper.create_variable( + name=unique_name.generate(value.name + '_copy'), + type=value.type, + dtype=value.dtype, + shape=value.shape, + lod_level=value.lod_level, + capacity=value.capacity) + + assign_op = channel_send_block.append_op( + type="assign_op", inputs={"X": value}, outputs={"Out": copied_X}) + X = copied_X + channel_send_op = channel_send_block.append_op( type="channel_send", inputs={ "Channel": channel, - "X": value, + "X": X, }, outputs={"Status": status})