未验证 提交 9ef7f6e8 编写于 作者: A Aurelius84 提交者: GitHub

support to modify dict and list in while_loop (#23083)

* support to modify dict and list in while_loop test=develop

* polish code test=develop
上级 8c6fde9e
......@@ -21,7 +21,7 @@ from .. import core
from ..framework import Program, Variable, Operator, in_dygraph_mode
from ..layer_helper import LayerHelper, unique_name
from .nn import logical_and, logical_not, logical_or
from .utils import assert_same_structure, map_structure
from .utils import assert_same_structure, map_structure, hold_mutable_vars, copy_mutable_vars
import numpy
import warnings
import six
......@@ -1018,8 +1018,17 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
return loop_vars
while_loop_block = While(pre_cond, is_test, name)
has_mutable_vars_in_loop = hold_mutable_vars(loop_vars)
with while_loop_block.block():
output_vars = body(*loop_vars)
# If a variable with mutable type is included in loop_vars, like `dict/list`,
# modifying it in the body function will cause origin variable to be modified
# synchronously. This will raise an assignment error out of while block.
# Here we make a copy of the mutable vars to avoid this problem.
if has_mutable_vars_in_loop:
new_loop_vars = copy_mutable_vars(loop_vars)
output_vars = body(*new_loop_vars)
else:
output_vars = body(*loop_vars)
if not isinstance(output_vars, (list, tuple)):
output_vars = [output_vars]
if len(output_vars) != len(loop_vars):
......
......@@ -14,6 +14,7 @@
from __future__ import print_function
import collections
import copy
import six
import numpy as np
from ..framework import Variable
......@@ -187,6 +188,24 @@ def map_structure(func, *structure):
return pack_sequence_as(structure[0], [func(*x) for x in entries])
def hold_mutable_vars(structure):
"""
Returns whether structure holds sequence like `list/dict`.
"""
for s in structure:
if is_sequence(s):
return True
return False
def copy_mutable_vars(structure):
"""
Returns vars copied from sequence without mutable property.
"""
flat_structure = copy.copy(flatten(structure))
return pack_sequence_as(structure, flat_structure)
def _recursive_assert_same_structure(nest1, nest2, check_types):
"""
Helper function for `assert_same_structure`.
......
......@@ -59,17 +59,7 @@ class SubNetWithDict(fluid.dygraph.Layer):
cache_k, cache_v = cache["k"], cache["v"]
k = 0.1 * cache_k + k
v = 0.2 * cache_v + v
# TODO: currently while_loop can have a dict as loop_vars, but
# to change the value in a dict, you have to use layers.assign
# because cache["k"] = k is putting k in dict without building
# network. So we cannot write:
#
# cache["k"], cache["v"] = k, v
#
# we have to support this kind of dict in loop in the future.
# For example, automatically change = to assign in AutoTracer
fluid.layers.assign(k, cache["k"])
fluid.layers.assign(v, cache["v"])
cache["k"], cache["v"] = k, v
weight = fluid.layers.matmul(x=q, y=k, transpose_y=True)
weight = fluid.layers.softmax(weight)
......@@ -108,16 +98,7 @@ class MainNetWithDict(fluid.dygraph.Layer):
def update_cache(self, cache):
for k, val in six.iteritems(cache):
# TODO: currently while_loop can have a dict as loop_vars, but
# to change the value in a dict, you have to use layers.assign
# because cache["k"] = k is putting k in dict without building
# network. So we cannot write:
#
# cache[k] = fluid.layers.softmax(val)
#
# we have to support this kind of dict in loop in the future.
# For example, automatically change = to assign in AutoTracer
fluid.layers.assign(fluid.layers.softmax(val), cache[k])
cache[k] = fluid.layers.softmax(val)
return cache
......
......@@ -78,13 +78,21 @@ class TestApiWhileLoop(unittest.TestCase):
self.assertTrue(np.allclose(np.asarray(res[1]), data))
def test_var_dict(self):
def cond(i, ten, test_dict):
def cond(i, ten, test_dict, test_list, test_list_dict):
return layers.less_than(i, ten)
def body(i, ten, test_dict):
layers.assign(i, test_dict["test_key"])
def body(i, ten, test_dict, test_list, test_list_dict):
test_dict["test_key"] = i
test_dict["test_key"] += 1
test_list[0] = fluid.layers.reshape(test_list[0], [2, -1]) + 1
test_list_dict[0]["test_key"] += 1
test_list_dict[0]["test_key"] = fluid.layers.relu(test_list_dict[0][
"test_key"])
i = layers.increment(i)
return [i, ten, test_dict]
return [i, ten, test_dict, test_list, test_list_dict]
main_program = Program()
startup_program = Program()
......@@ -92,18 +100,42 @@ class TestApiWhileLoop(unittest.TestCase):
i = layers.zeros(shape=[1], dtype='int64')
ten = layers.fill_constant(shape=[1], dtype='int64', value=10)
test_data = layers.fill_constant(shape=[1], dtype='int64', value=0)
test_dict = {"test_key": test_data}
i, ten, test_dict = layers.while_loop(cond, body,
[i, ten, test_dict])
test_list = [
layers.fill_constant(
shape=[1, 2], dtype='int64', value=0)
]
test_list_dict = [{
"test_key": layers.fill_constant(
shape=[1], dtype='float32', value=0)
}]
i, ten, test_dict, test_list, test_list_dict = layers.while_loop(
cond, body, [i, ten, test_dict, test_list, test_list_dict])
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
res = exe.run(main_program, fetch_list=[test_dict["test_key"]])
res = exe.run(main_program,
fetch_list=[
test_dict["test_key"], test_list[0],
test_list_dict[0]["test_key"]
])
self.assertTrue(
np.allclose(
np.asarray(res[0]),
np.full(
shape=(1), fill_value=9, dtype=np.int64)))
shape=(1), fill_value=10, dtype=np.int64)))
self.assertTrue(
np.allclose(
np.asarray(res[1]),
np.full(
shape=(2, 1), fill_value=10, dtype=np.int64)))
self.assertTrue(
np.allclose(
np.asarray(res[2]),
np.full(
shape=(1), fill_value=10, dtype=np.float32)))
class TestApiWhileLoop_Nested(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册