Unverified commit d3f98088 authored by xiongkun, committed by GitHub

[Dy2static] support set_value_op in static mode by _jst.Ld() (#56028)

* Fix and add unittest

* don't introduce assign when already in global block.

* fix more

* fix bugs

* fix bugs

* fix ci

* fix bfgs

* make function local
Parent d1ea359b
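
For context, here is a minimal sketch of the user-level pattern this commit enables, mirroring the new TestCase12 below: slice assignment on a tensor inside a to_static-converted loop, where the loop lowers to a `while` op with its own sub-block in static mode.

```python
# A minimal sketch mirroring the new TestCase12 below; illustrative only.
import paddle

@paddle.jit.to_static
def fill_rows():
    res = paddle.zeros([4, 3, 2])
    v = paddle.to_tensor(1.0)
    # paddle.shape(...) makes the trip count dynamic, so dy2static lowers
    # this loop to a while op with its own sub-block.
    for i in range(paddle.shape(res)[0]):
        res[i] = v  # __setitem__ -> set_value op inside the sub-block
    return res

print(fill_rows())
```
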
@@ -3573,6 +3573,10 @@ def _stride_in_no_check_dy2st_diff():
def check_if_to_static_diff_with_dygraph(op_type, inplace_map, outputs):
+    if (
+        op_type == "while"
+    ):  # don't need to check `while`: it is only a wrapper of inner ops, and checking would get stuck in the inner ops
+        return
if outputs is not None:
for k, v in outputs.items():
if isinstance(v, Variable):
......
@@ -466,6 +466,39 @@ class LayerHelperBase:
stop_gradient=stop_gradient,
)
+    def _create_global_variable_for_type_inference(
+        self, dtype, stop_gradient=False, shape=None
+    ):
+        """Create a variable in the global block whose type will be inferred by the layer.
+
+        Note:
+            The default type will be set to LOD_TENSOR. However, when
+            the var is used as an operator output, its type will be updated
+            based on the operator's `VarTypeInference` implementation in
+            infer_var_type.
+        """
+        # fall back to the helper's global dtype
+        if not dtype:
+            dtype = self.__dtype
+        output = self.main_program.global_block().create_var(
+            name=unique_name.generate_with_ignorable_key(
+                ".".join([self.name, 'tmp'])
+            ),
+            dtype=dtype,
+            shape=shape,
+            type=core.VarDesc.VarType.LOD_TENSOR,
+            persistable=False,
+            stop_gradient=stop_gradient,
+        )
+        # temporarily switch to the global block so the zero-initializing
+        # fill_constant op is appended there, then restore the current block
+        saved_block_id = self.main_program.current_block_idx
+        self.main_program.current_block_idx = 0
+        paddle.tensor.creation.fill_constant(
+            output.shape, dtype, 0.0, force_cpu=False, out=output
+        )
+        output.stop_gradient = stop_gradient
+        self.main_program.current_block_idx = saved_block_id
+        return output
def create_sparse_variable_for_type_inference(
self, dtype, stop_gradient=False, shape=None
):
......
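
The save/switch/restore of `current_block_idx` is the key trick in the helper above: it makes the initializing `fill_constant` op land in the global block rather than the current sub-block. A standalone sketch of the same pattern, assuming static mode (`make_global_zeros` is a hypothetical helper for illustration, not Paddle API):

```python
# Hedged sketch of the block-switching pattern used above; make_global_zeros
# is a hypothetical helper, not part of Paddle's API.
import paddle

paddle.enable_static()

def make_global_zeros(shape, dtype):
    main = paddle.static.default_main_program()
    out = main.global_block().create_var(
        name=paddle.utils.unique_name.generate("tmp"),
        dtype=dtype, shape=shape, persistable=False,
    )
    saved = main.current_block_idx
    main.current_block_idx = 0  # append the initializer op to the global block
    try:
        paddle.tensor.creation.fill_constant(shape, dtype, 0.0, out=out)
    finally:
        main.current_block_idx = saved  # restore whatever block we were in
    return out

z = make_global_zeros([4, 3], 'float32')
```
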
@@ -566,7 +566,13 @@ def _setitem_impl_(var, item, value):
output = var
else:
helper = paddle.fluid.layer_helper.LayerHelper('set_value', **locals())
-        output = helper.create_variable_for_type_inference(dtype=var.dtype)
+        if helper.main_program.current_block_idx != 0:
+            # not in the global block, so the output must be created as a
+            # global variable to stay visible outside this sub-block
+            output = helper._create_global_variable_for_type_inference(
+                dtype=var.dtype
+            )
+        else:
+            output = helper.create_variable_for_type_inference(dtype=var.dtype)
cur_block = default_main_program().current_block()
cur_block.append_op(
@@ -909,7 +915,15 @@ def _setitem_static(x, indices, values):
helper = paddle.fluid.layer_helper.LayerHelper(
'set_value', **locals()
)
-            output = helper.create_variable_for_type_inference(dtype=x.dtype)
+            if helper.main_program.current_block_idx != 0:
+                # not in the global block, so the output must be created as a
+                # global variable to stay visible outside this sub-block
+                output = helper._create_global_variable_for_type_inference(
+                    dtype=x.dtype
+                )
+            else:
+                output = helper.create_variable_for_type_inference(
+                    dtype=x.dtype
+                )
cur_block = default_main_program().current_block()
cur_block.append_op(
type="set_value",
@@ -975,7 +989,15 @@ def _setitem_static(x, indices, values):
helper = paddle.fluid.layer_helper.LayerHelper(
'set_value', **locals()
)
-            output = helper.create_variable_for_type_inference(dtype=x.dtype)
+            if helper.main_program.current_block_idx != 0:
+                # not in the global block, so the output must be created as a
+                # global variable to stay visible outside this sub-block
+                output = helper._create_global_variable_for_type_inference(
+                    dtype=x.dtype
+                )
+            else:
+                output = helper.create_variable_for_type_inference(
+                    dtype=x.dtype
+                )
cur_block = default_main_program().current_block()
cur_block.append_op(
type="set_value",
......
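
Both call sites above take the new branch whenever `__setitem__` runs inside a sub-block. A hedged user-level sketch of such a case in pure static mode, where the `while_loop` body is a non-global block:

```python
# Hedged sketch; illustrates when current_block_idx != 0 at set_value time.
import paddle

paddle.enable_static()

def body(i, x):
    x[i] = 1.0  # set_value inside the while sub-block: output now goes global
    return i + 1, x

i = paddle.full([1], 0, dtype='int64')
x = paddle.zeros([4, 3])
i, x = paddle.static.nn.while_loop(lambda i, x: i < 4, body, [i, x])
```
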
@@ -152,11 +152,14 @@ class NameloadJstTransformer(BaseTransformer):
Can't convert the name of a function call, because this would affect CallTransformer.
"""
node.args = [self.generic_visit(arg) for arg in node.args]
+        node.func = self.generic_visit(node.func)
return node
def visit_Attribute(self, node):
assert isinstance(node, gast.Attribute)
assert isinstance(node.attr, str)
+        if utils.ast_to_source_code(node).startswith("_jst."):  # skip _jst.xxx
+            return node
self.generic_visit(node)
if isinstance(node.ctx, gast.Load):
node = self._surround_with_ld(node)
......
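
The transformer wraps Load-context names and attributes in `_jst.Ld(...)`; the added lines make call targets visitable while skipping nodes already spelled `_jst.xxx` to avoid double-wrapping. A self-contained toy of the same idea using the stdlib `ast` module (the real code uses gast; this is an independent sketch, not Paddle's implementation):

```python
# Toy version of the Load-wrapping transform; requires Python >= 3.9
# for ast.unparse. Independent sketch, not Paddle's implementation.
import ast

class LoadWrapper(ast.NodeTransformer):
    def visit_Attribute(self, node):
        self.generic_visit(node)
        if ast.unparse(node).startswith("_jst."):
            return node  # already wrapped; skip to avoid double-wrapping
        if isinstance(node.ctx, ast.Load):
            # rewrite <expr>.attr loads as _jst.Ld(<expr>.attr)
            return ast.copy_location(
                ast.Call(
                    func=ast.Attribute(
                        value=ast.Name(id="_jst", ctx=ast.Load()),
                        attr="Ld",
                        ctx=ast.Load(),
                    ),
                    args=[node],
                    keywords=[],
                ),
                node,
            )
        return node

tree = LoadWrapper().visit(ast.parse("y = obj.value"))
print(ast.unparse(ast.fix_missing_locations(tree)))  # y = _jst.Ld(obj.value)
```
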
@@ -1296,9 +1296,16 @@ class ParametersMap:
params = self.params_dict.get(self._program_hash(program))
if params is None:
return None
-        if id in params.keys():
-            return params[id]
-        return None
+        if id not in params:
+            return None
+        # follow chained mappings (id -> var whose id is itself mapped) to the
+        # root variable, compressing the path as we go
+        root_var = params[id]
+        saved = []
+        while root_var.desc.id() in params.keys():
+            saved.append(root_var)
+            root_var = params[root_var.desc.id()]
+        for var in saved:
+            params[var.desc.id()] = root_var
+        return root_var
def _program_hash(self, program):
"""
......
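
The rewritten lookup follows chained id-to-variable mappings to their root and relinks every intermediate entry directly to it, i.e. union-find-style path compression. A toy model with a plain dict (illustrative only):

```python
# Toy model of the lookup above: ids map to ids here; the real code maps
# desc.id() -> Variable and relinks via params[var.desc.id()] = root_var.
def find_root(params, key):
    if key not in params:
        return None
    root = params[key]
    seen = []
    while root in params:      # follow the chain until an unmapped root
        seen.append(root)
        root = params[root]
    for k in seen:             # path compression: point every hop at the root
        params[k] = root
    return root

m = {1: 2, 2: 3, 3: 4}
assert find_root(m, 1) == 4
assert m[2] == 4 and m[3] == 4  # chain flattened for future lookups
```
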
@@ -589,7 +589,8 @@ def assign_skip_lod_tensor_array(input, output):
# input is not generated in the While sub-block and is modified in place; it
# only belongs to in-place ops during program construction, because the
# in-place pass is only available at Graph level.
-        paddle.assign(input, output)
+        with paddle.fluid.framework._stride_in_no_check_dy2st_diff():
+            paddle.assign(input, output)
def while_loop(cond, body, loop_vars, is_test=False, name=None):
......
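
`_stride_in_no_check_dy2st_diff()` temporarily disables the dygraph/static consistency check for this known-safe in-place assign. A generic sketch of that guard shape (a toy, not Paddle's implementation):

```python
# Toy guard with the same contextmanager shape as the real
# _stride_in_no_check_dy2st_diff; not Paddle's implementation.
from contextlib import contextmanager

_skip_diff_check = False

@contextmanager
def no_dy2st_diff_check():
    global _skip_diff_check
    saved, _skip_diff_check = _skip_diff_check, True
    try:
        yield
    finally:
        _skip_diff_check = saved  # restore even if the body raises

with no_dy2st_diff_check():
    pass  # ops appended here would bypass the consistency check
```
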
@@ -179,22 +179,56 @@ class TestCase11(TestSetItemBase):
class TestCase12(TestSetItemBase):
# Test combined indexing
# Test gradient of value tensor
def init_func(self):
-        def foo(x, value):
-            y = x + 1
-            y[[0, 1], 1, :2] = value
-            return y
+        def foo():
+            res = paddle.zeros([4, 3, 2])
+            b = paddle.zeros([4, 3, 2])
+            v = paddle.to_tensor(1.0)
+            for i in range(paddle.shape(b)[0]):
+                res[i] = v
+            return res
return foo
def run_dygraph(self, func):
-        x = self.init_data()
-        value = paddle.ones((32,))
-        value.stop_gradient = False
-        y = func(x, value)
-        x_grad, value_grad = paddle.grad(y, [x, value])
-        return y, x_grad, value_grad
+        y = func()
+        return (y,)
+class TestCase13(TestSetItemBase):
+    # Test gradient of value tensor
+    def init_func(self):
+        def foo():
+            res = paddle.zeros([4, 3, 2])
+            v = paddle.to_tensor(1.0)
+            for i in range(4):
+                res[i] = v
+            return res
+
+        return foo
+
+    def run_dygraph(self, func):
+        y = func()
+        return (y,)
+
+
+class TestCase14(TestSetItemBase):
+    # Test gradient of value tensor
+    def init_func(self):
+        def foo():
+            data = np.arange(8).reshape((2, 4)).astype('float32')
+            x = paddle.to_tensor(data)
+            x[:, 1:] = x[:, :-1].clone()
+            x[:, 0] = 1
+            res = x.flatten()
+            return res
+
+        return foo
+
+    def run_dygraph(self, func):
+        y = func()
+        return (y,)
if __name__ == '__main__':
......