提交 c7bf77d0 编写于 作者: T Thuan Nguyen 提交者: Abhinav Arora

Add in is_copy attribute to SelectCase. (#9393)

This is a temporary solution to allowing for variables to be copied during a channel send operations.  Also fixed issue with is_copy for "channel_send" method, and also updated unit tests.
上级 65534c47
...@@ -82,11 +82,14 @@ class SelectCase(object): ...@@ -82,11 +82,14 @@ class SelectCase(object):
RECEIVE = 2 RECEIVE = 2
def __init__(self, def __init__(self,
select,
case_idx, case_idx,
case_to_execute, case_to_execute,
channel_action_fn=None, channel_action_fn=None,
channel=None, channel=None,
value=None): value=None,
is_copy=False):
self.select = select
self.helper = LayerHelper('conditional_block') self.helper = LayerHelper('conditional_block')
self.main_program = self.helper.main_program self.main_program = self.helper.main_program
self.is_scalar_condition = True self.is_scalar_condition = True
...@@ -99,7 +102,24 @@ class SelectCase(object): ...@@ -99,7 +102,24 @@ class SelectCase(object):
self.action = (self.SEND self.action = (self.SEND
if channel_action_fn.__name__ == ('channel_send') else if channel_action_fn.__name__ == ('channel_send') else
self.RECEIVE) if channel_action_fn else self.DEFAULT self.RECEIVE) if channel_action_fn else self.DEFAULT
self.value = value
X = value
if self.action == self.SEND and is_copy:
# We create of copy of the data we want to send
copied_X = self.select.parent_block.create_var(
name=unique_name.generate(value.name + '_copy'),
type=value.type,
dtype=value.dtype,
shape=value.shape,
lod_level=value.lod_level,
capacity=value.capacity
if hasattr(value, 'capacity') else None, )
self.select.parent_block.append_op(
type="assign", inputs={"X": value}, outputs={"Out": copied_X})
X = copied_X
self.value = X
self.channel = channel self.channel = channel
def __enter__(self): def __enter__(self):
...@@ -173,6 +193,7 @@ class SelectCase(object): ...@@ -173,6 +193,7 @@ class SelectCase(object):
class Select(BlockGuard): class Select(BlockGuard):
def __init__(self, name=None): def __init__(self, name=None):
self.helper = LayerHelper('select', name=name) self.helper = LayerHelper('select', name=name)
self.parent_block = self.helper.main_program.current_block()
self.cases = [] self.cases = []
super(Select, self).__init__(self.helper.main_program) super(Select, self).__init__(self.helper.main_program)
...@@ -183,12 +204,12 @@ class Select(BlockGuard): ...@@ -183,12 +204,12 @@ class Select(BlockGuard):
super(Select, self).__enter__() super(Select, self).__enter__()
return self return self
def case(self, channel_action_fn, channel, value): def case(self, channel_action_fn, channel, value, is_copy=False):
"""Create a new block for this condition. """Create a new block for this condition.
""" """
select_case = SelectCase( select_case = SelectCase(self,
len(self.cases), self.case_to_execute, channel_action_fn, channel, len(self.cases), self.case_to_execute,
value) channel_action_fn, channel, value, is_copy)
self.cases.append(select_case) self.cases.append(select_case)
...@@ -197,7 +218,7 @@ class Select(BlockGuard): ...@@ -197,7 +218,7 @@ class Select(BlockGuard):
def default(self): def default(self):
"""Create a default case block for this condition. """Create a default case block for this condition.
""" """
default_case = SelectCase(len(self.cases), self.case_to_execute) default_case = SelectCase(self, len(self.cases), self.case_to_execute)
self.cases.append(default_case) self.cases.append(default_case)
...@@ -341,17 +362,17 @@ def channel_send(channel, value, is_copy=False): ...@@ -341,17 +362,17 @@ def channel_send(channel, value, is_copy=False):
X = value X = value
if is_copy is True: if is_copy:
copied_X = helper.create_variable( copied_X = helper.create_variable(
name=unique_name.generate(value.name + '_copy'), name=unique_name.generate(value.name + '_copy'),
type=value.type, type=value.type,
dtype=value.dtype, dtype=value.dtype,
shape=value.shape, shape=value.shape,
lod_level=value.lod_level, lod_level=value.lod_level,
capacity=value.capacity) capacity=value.capacity if hasattr(value, 'capacity') else None)
assign_op = channel_send_block.append_op( assign_op = channel_send_block.append_op(
type="assign_op", inputs={"X": value}, outputs={"Out": copied_X}) type="assign", inputs={"X": value}, outputs={"Out": copied_X})
X = copied_X X = copied_X
channel_send_block.append_op( channel_send_block.append_op(
......
...@@ -173,16 +173,10 @@ class TestRoutineOp(unittest.TestCase): ...@@ -173,16 +173,10 @@ class TestRoutineOp(unittest.TestCase):
with while_op.block(): with while_op.block():
result2 = fill_constant( result2 = fill_constant(
shape=[1], dtype=core.VarDesc.VarType.INT32, value=0) shape=[1], dtype=core.VarDesc.VarType.INT32, value=0)
x_to_send_tmp = fill_constant(
shape=[1], dtype=core.VarDesc.VarType.INT32, value=0)
# TODO(abhinav): Need to perform copy when doing a channel send.
# Once this is complete, we can remove these lines
assign(input=x, output=x_to_send_tmp)
with fluid.Select() as select: with fluid.Select() as select:
with select.case(fluid.channel_send, channel, with select.case(
x_to_send_tmp): fluid.channel_send, channel, x, is_copy=True):
assign(input=x, output=x_tmp) assign(input=x, output=x_tmp)
assign(input=y, output=x) assign(input=y, output=x)
assign(elementwise_add(x=x_tmp, y=y), output=y) assign(elementwise_add(x=x_tmp, y=y), output=y)
...@@ -230,21 +224,12 @@ class TestRoutineOp(unittest.TestCase): ...@@ -230,21 +224,12 @@ class TestRoutineOp(unittest.TestCase):
core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.LOD_TENSOR,
core.VarDesc.VarType.FP64) core.VarDesc.VarType.FP64)
pong_result = self._create_tensor('pong_return_value',
core.VarDesc.VarType.LOD_TENSOR,
core.VarDesc.VarType.FP64)
def ping(ch, message): def ping(ch, message):
message_to_send_tmp = fill_constant( fluid.channel_send(ch, message, is_copy=True)
shape=[1], dtype=core.VarDesc.VarType.FP64, value=0)
assign(input=message, output=message_to_send_tmp)
fluid.channel_send(ch, message_to_send_tmp)
def pong(ch1, ch2): def pong(ch1, ch2):
fluid.channel_recv(ch1, ping_result) fluid.channel_recv(ch1, ping_result)
assign(input=ping_result, output=pong_result) fluid.channel_send(ch2, ping_result, is_copy=True)
fluid.channel_send(ch2, pong_result)
pings = fluid.make_channel( pings = fluid.make_channel(
dtype=core.VarDesc.VarType.LOD_TENSOR, capacity=1) dtype=core.VarDesc.VarType.LOD_TENSOR, capacity=1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册