提交 c7bf77d0 编写于 作者: T Thuan Nguyen 提交者: Abhinav Arora

Add in is_copy attribute to SelectCase. (#9393)

This is a temporary solution to allowing for variables to be copied during a channel send operations.  Also fixed issue with is_copy for "channel_send" method, and also updated unit tests.
上级 65534c47
......@@ -82,11 +82,14 @@ class SelectCase(object):
RECEIVE = 2
def __init__(self,
select,
case_idx,
case_to_execute,
channel_action_fn=None,
channel=None,
value=None):
value=None,
is_copy=False):
self.select = select
self.helper = LayerHelper('conditional_block')
self.main_program = self.helper.main_program
self.is_scalar_condition = True
......@@ -99,7 +102,24 @@ class SelectCase(object):
self.action = (self.SEND
if channel_action_fn.__name__ == ('channel_send') else
self.RECEIVE) if channel_action_fn else self.DEFAULT
self.value = value
X = value
if self.action == self.SEND and is_copy:
# We create of copy of the data we want to send
copied_X = self.select.parent_block.create_var(
name=unique_name.generate(value.name + '_copy'),
type=value.type,
dtype=value.dtype,
shape=value.shape,
lod_level=value.lod_level,
capacity=value.capacity
if hasattr(value, 'capacity') else None, )
self.select.parent_block.append_op(
type="assign", inputs={"X": value}, outputs={"Out": copied_X})
X = copied_X
self.value = X
self.channel = channel
def __enter__(self):
......@@ -173,6 +193,7 @@ class SelectCase(object):
class Select(BlockGuard):
def __init__(self, name=None):
self.helper = LayerHelper('select', name=name)
self.parent_block = self.helper.main_program.current_block()
self.cases = []
super(Select, self).__init__(self.helper.main_program)
......@@ -183,12 +204,12 @@ class Select(BlockGuard):
super(Select, self).__enter__()
return self
def case(self, channel_action_fn, channel, value):
def case(self, channel_action_fn, channel, value, is_copy=False):
"""Create a new block for this condition.
"""
select_case = SelectCase(
len(self.cases), self.case_to_execute, channel_action_fn, channel,
value)
select_case = SelectCase(self,
len(self.cases), self.case_to_execute,
channel_action_fn, channel, value, is_copy)
self.cases.append(select_case)
......@@ -197,7 +218,7 @@ class Select(BlockGuard):
def default(self):
"""Create a default case block for this condition.
"""
default_case = SelectCase(len(self.cases), self.case_to_execute)
default_case = SelectCase(self, len(self.cases), self.case_to_execute)
self.cases.append(default_case)
......@@ -341,17 +362,17 @@ def channel_send(channel, value, is_copy=False):
X = value
if is_copy is True:
if is_copy:
copied_X = helper.create_variable(
name=unique_name.generate(value.name + '_copy'),
type=value.type,
dtype=value.dtype,
shape=value.shape,
lod_level=value.lod_level,
capacity=value.capacity)
capacity=value.capacity if hasattr(value, 'capacity') else None)
assign_op = channel_send_block.append_op(
type="assign_op", inputs={"X": value}, outputs={"Out": copied_X})
type="assign", inputs={"X": value}, outputs={"Out": copied_X})
X = copied_X
channel_send_block.append_op(
......
......@@ -173,16 +173,10 @@ class TestRoutineOp(unittest.TestCase):
with while_op.block():
result2 = fill_constant(
shape=[1], dtype=core.VarDesc.VarType.INT32, value=0)
x_to_send_tmp = fill_constant(
shape=[1], dtype=core.VarDesc.VarType.INT32, value=0)
# TODO(abhinav): Need to perform copy when doing a channel send.
# Once this is complete, we can remove these lines
assign(input=x, output=x_to_send_tmp)
with fluid.Select() as select:
with select.case(fluid.channel_send, channel,
x_to_send_tmp):
with select.case(
fluid.channel_send, channel, x, is_copy=True):
assign(input=x, output=x_tmp)
assign(input=y, output=x)
assign(elementwise_add(x=x_tmp, y=y), output=y)
......@@ -230,21 +224,12 @@ class TestRoutineOp(unittest.TestCase):
core.VarDesc.VarType.LOD_TENSOR,
core.VarDesc.VarType.FP64)
pong_result = self._create_tensor('pong_return_value',
core.VarDesc.VarType.LOD_TENSOR,
core.VarDesc.VarType.FP64)
def ping(ch, message):
message_to_send_tmp = fill_constant(
shape=[1], dtype=core.VarDesc.VarType.FP64, value=0)
assign(input=message, output=message_to_send_tmp)
fluid.channel_send(ch, message_to_send_tmp)
fluid.channel_send(ch, message, is_copy=True)
def pong(ch1, ch2):
fluid.channel_recv(ch1, ping_result)
assign(input=ping_result, output=pong_result)
fluid.channel_send(ch2, pong_result)
fluid.channel_send(ch2, ping_result, is_copy=True)
pings = fluid.make_channel(
dtype=core.VarDesc.VarType.LOD_TENSOR, capacity=1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册