Unverified commit d3f98088, authored by xiongkun, committed by GitHub

[Dy2static] support set_value_op in static mode by _jst.Ld() (#56028)

* Fix and add unittest

* don't introduce assign when already in global block.

* fix more

* fix bugs

* fix bugs

* fix ci

* fix bfgs

* make function local
Parent: d1ea359b
@@ -3573,6 +3573,10 @@ def _stride_in_no_check_dy2st_diff():
 def check_if_to_static_diff_with_dygraph(op_type, inplace_map, outputs):
+    if (
+        op_type == "while"
+    ):  # no need to check `while`: it is only a wrapper of inner ops, and checking it would get stuck in the inner op.
+        return
     if outputs is not None:
         for k, v in outputs.items():
             if isinstance(v, Variable):
...
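For context, a minimal static-mode sketch (illustrative only, not part of the commit) of why the wrapper op can be skipped: a `while` op carries no computation itself; the real ops live in a sub-block and are checked on their own.

```python
import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    i = paddle.full([1], 0, dtype='int64')
    limit = paddle.full([1], 10, dtype='int64')
    # while_loop appends a single `while` op to the current block and places
    # the body ops in a separate sub-block.
    (i,) = paddle.static.nn.while_loop(
        cond=lambda x: x < limit, body=lambda x: [x + 1], loop_vars=[i]
    )

print([op.type for op in main.block(0).ops])  # includes 'while'
print([op.type for op in main.block(1).ops])  # the wrapped inner ops
```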
@@ -466,6 +466,39 @@ class LayerHelperBase:
             stop_gradient=stop_gradient,
         )

+    def _create_global_variable_for_type_inference(
+        self, dtype, stop_gradient=False, shape=None
+    ):
+        """Create a global-block variable whose type should be inferred by the layer.
+
+        Note:
+            The default type will be set to LOD_TENSOR. However, when
+            the var is used as operator output, its type will be updated
+            based on the operator's `VarTypeInference` implementation in
+            infer_var_type.
+        """
+        # set global dtype
+        if not dtype:
+            dtype = self.__dtype
+        output = self.main_program.global_block().create_var(
+            name=unique_name.generate_with_ignorable_key(
+                ".".join([self.name, 'tmp'])
+            ),
+            dtype=dtype,
+            shape=shape,
+            type=core.VarDesc.VarType.LOD_TENSOR,
+            persistable=False,
+            stop_gradient=stop_gradient,
+        )
+        # Temporarily switch to the global block so the zero-fill op is
+        # appended there, then restore the current block index.
+        saved_block_id = self.main_program.current_block_idx
+        self.main_program.current_block_idx = 0
+        paddle.tensor.creation.fill_constant(
+            output.shape, dtype, 0.0, force_cpu=False, out=output
+        )
+        output.stop_gradient = stop_gradient
+        self.main_program.current_block_idx = saved_block_id
+        return output
+
     def create_sparse_variable_for_type_inference(
         self, dtype, stop_gradient=False, shape=None
     ):
...
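A minimal sketch (assuming static mode is enabled; the helper name mirrors the call sites further down) of how a caller is expected to pick between an ordinary temporary variable and the new global-block variable:

```python
import paddle
from paddle.fluid.layer_helper import LayerHelper

paddle.enable_static()
helper = LayerHelper('set_value')
if helper.main_program.current_block_idx != 0:
    # Inside a sub-block (e.g. a while/cond body): the output must live in
    # the global block so the written value is visible outside the sub-block.
    output = helper._create_global_variable_for_type_inference(dtype='float32')
else:
    # In the global block an ordinary temporary variable is enough.
    output = helper.create_variable_for_type_inference(dtype='float32')
```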
@@ -566,7 +566,13 @@ def _setitem_impl_(var, item, value):
         output = var
     else:
         helper = paddle.fluid.layer_helper.LayerHelper('set_value', **locals())
-        output = helper.create_variable_for_type_inference(dtype=var.dtype)
+        if helper.main_program.current_block_idx != 0:
+            # not in the global block, so create a global variable.
+            output = helper._create_global_variable_for_type_inference(
+                dtype=var.dtype
+            )
+        else:
+            output = helper.create_variable_for_type_inference(dtype=var.dtype)

     cur_block = default_main_program().current_block()
     cur_block.append_op(
@@ -909,7 +915,15 @@ def _setitem_static(x, indices, values):
             helper = paddle.fluid.layer_helper.LayerHelper(
                 'set_value', **locals()
             )
-            output = helper.create_variable_for_type_inference(dtype=x.dtype)
+            if helper.main_program.current_block_idx != 0:
+                # not in the global block, so create a global variable.
+                output = helper._create_global_variable_for_type_inference(
+                    dtype=x.dtype
+                )
+            else:
+                output = helper.create_variable_for_type_inference(
+                    dtype=x.dtype
+                )
             cur_block = default_main_program().current_block()
             cur_block.append_op(
                 type="set_value",
@@ -975,7 +989,15 @@ def _setitem_static(x, indices, values):
             helper = paddle.fluid.layer_helper.LayerHelper(
                 'set_value', **locals()
             )
-            output = helper.create_variable_for_type_inference(dtype=x.dtype)
+            if helper.main_program.current_block_idx != 0:
+                # not in the global block, so create a global variable.
+                output = helper._create_global_variable_for_type_inference(
+                    dtype=x.dtype
+                )
+            else:
+                output = helper.create_variable_for_type_inference(
+                    dtype=x.dtype
+                )
             cur_block = default_main_program().current_block()
             cur_block.append_op(
                 type="set_value",
...
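The case these branches target, sketched as a standalone example adapted from the new TestCase12 further down (the assumption being that a loop whose bound is a tensor lowers to a `while` sub-block under `to_static`, so the set_value output has to be a global-block variable):

```python
import paddle

def fill():
    res = paddle.zeros([4, 3, 2])
    b = paddle.zeros([4, 3, 2])
    v = paddle.to_tensor(1.0)
    # The loop bound is a tensor, so under to_static the loop becomes a
    # while op and `res[i] = v` runs inside the while sub-block.
    for i in range(paddle.shape(b)[0]):
        res[i] = v
    return res

out = paddle.jit.to_static(fill)()
```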
@@ -152,11 +152,14 @@ class NameloadJstTransformer(BaseTransformer):
         Can't convert the name of a function call, because this will affect CallTransformer.
         """
         node.args = [self.generic_visit(arg) for arg in node.args]
+        node.func = self.generic_visit(node.func)
         return node

     def visit_Attribute(self, node):
         assert isinstance(node, gast.Attribute)
         assert isinstance(node.attr, str)
+        if utils.ast_to_source_code(node).startswith("_jst."):  # skip _jst.xxx
+            return node
         self.generic_visit(node)
         if isinstance(node.ctx, gast.Load):
             node = self._surround_with_ld(node)
...
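A small standalone check of the new guard (illustrative; the exact import path of `ast_to_source_code` is an assumption here): attribute nodes whose source already starts with `_jst.` are returned unchanged, so the transformer never re-wraps its own helper calls.

```python
import gast
from paddle.jit.dy2static.utils import ast_to_source_code  # assumed location

node = gast.parse("_jst.Ld(x).shape").body[0].value  # the outer Attribute node
# Mirrors the guard added to visit_Attribute: skip anything under _jst.*
print(ast_to_source_code(node).startswith("_jst."))  # True -> node is left as-is
```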
@@ -1296,9 +1296,16 @@ class ParametersMap:
         params = self.params_dict.get(self._program_hash(program))
         if params is None:
             return None
-        if id in params.keys():
-            return params[id]
-        return None
+        if id not in params:
+            return None
+        root_var = params[id]
+        saved = []
+        while root_var.desc.id() in params.keys():
+            saved.append(root_var)
+            root_var = params[root_var.desc.id()]
+        for var in saved:
+            params[var.desc.id()] = root_var
+        return root_var

     def _program_hash(self, program):
         """
...
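The rewritten lookup follows a chain of id-to-variable mappings to its end and then points the intermediate entries straight at the result, a light form of path compression. A dict-only sketch with hypothetical names (`ident` stands in for `var.desc.id()`):

```python
def resolve(params, key, ident):
    """Follow key -> value links until the value's id is no longer a key,
    then redirect every visited value to the final root."""
    if key not in params:
        return None
    root = params[key]
    visited = []
    while ident(root) in params:
        visited.append(root)
        root = params[ident(root)]
    for v in visited:
        params[ident(v)] = root  # flatten the chain for later lookups
    return root

# Toy chain: 1 -> 'a' (id 2), 2 -> 'b' (id 3), 3 -> 'c' (id 7, not a key)
ids = {'a': 2, 'b': 3, 'c': 7}
params = {1: 'a', 2: 'b', 3: 'c'}
print(resolve(params, 1, ids.__getitem__))  # c
print(params)  # {1: 'a', 2: 'c', 3: 'c'}
```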
@@ -589,7 +589,8 @@ def assign_skip_lod_tensor_array(input, output):
         # input is not generated in While sub block and modified by in-place and only
         # belong to inplace ops in constructing program process, because in-place pass
         # is only available in Graph level.
-        paddle.assign(input, output)
+        with paddle.fluid.framework._stride_in_no_check_dy2st_diff():
+            paddle.assign(input, output)


 def while_loop(cond, body, loop_vars, is_test=False, name=None):
...
@@ -179,22 +179,56 @@ class TestCase11(TestSetItemBase):
 class TestCase12(TestSetItemBase):
-    # Test combind-indexing
+    # Test set_value inside a tensor-dependent loop (lowers to a while op under to_static)
     def init_func(self):
-        def foo(x, value):
-            y = x + 1
-            y[[0, 1], 1, :2] = value
-            return y
+        def foo():
+            res = paddle.zeros([4, 3, 2])
+            b = paddle.zeros([4, 3, 2])
+            v = paddle.to_tensor(1.0)
+            for i in range(paddle.shape(b)[0]):
+                res[i] = v
+            return res

         return foo

     def run_dygraph(self, func):
-        x = self.init_data()
-        value = paddle.ones((32,))
-        value.stop_gradient = False
-        y = func(x, value)
-        x_grad, value_grad = paddle.grad(y, [x, value])
-        return y, x_grad, value_grad
+        y = func()
+        return (y,)
+
+
+class TestCase13(TestSetItemBase):
+    # Test set_value inside a plain Python for loop
+    def init_func(self):
+        def foo():
+            res = paddle.zeros([4, 3, 2])
+            v = paddle.to_tensor(1.0)
+            for i in range(4):
+                res[i] = v
+            return res
+
+        return foo
+
+    def run_dygraph(self, func):
+        y = func()
+        return (y,)
+
+
+class TestCase14(TestSetItemBase):
+    # Test in-place slice assignment on a cloned slice
+    def init_func(self):
+        def foo():
+            data = np.arange(8).reshape((2, 4)).astype('float32')
+            x = paddle.to_tensor(data)
+            x[:, 1:] = x[:, :-1].clone()
+            x[:, 0] = 1
+            res = x.flatten()
+            return res
+
+        return foo
+
+    def run_dygraph(self, func):
+        y = func()
+        return (y,)


 if __name__ == '__main__':
...
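As a quick standalone sanity check of what the new cases exercise (outside the test harness, assuming the changes in this commit are applied), the TestCase14 pattern should produce the same values eagerly and under to_static:

```python
import numpy as np
import paddle

def foo():
    data = np.arange(8).reshape((2, 4)).astype('float32')
    x = paddle.to_tensor(data)
    x[:, 1:] = x[:, :-1].clone()
    x[:, 0] = 1
    return x.flatten()

eager = foo()
static = paddle.jit.to_static(foo)()
np.testing.assert_allclose(eager.numpy(), static.numpy())
```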