未验证 提交 2e07c8b7 编写于 作者: X xiongkun 提交者: GitHub

[dy2static] Parameter Recorder Part 2: new parameter collection mechanism (#50336)

* [dy2static] support fallback for whole graph. (stage 1)

* bug fix

* bug fix and add a new unittest

* fix code by code review

* fix coverage

* [dy2static] ParameterRecorder Part 2

* Parameter Recorder - 2 bug fix

* bugfix: fix the dygraph go into _jst.ld() errors.

* fix ci error.

* fix ci error
上级 c47f11f5
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
class FakeNet:
def __init__(self):
self.var = paddle.to_tensor([2.0])
f = FakeNet()
g = paddle.to_tensor([1.0])
class Net(paddle.nn.Layer):
def __init__(self):
super().__init__()
def forward(self, x):
# unsupport g as store.
t = g * 2 + x
t = f.var * t
return t
class TestFallback(unittest.TestCase):
def setUp(self):
self.x = paddle.to_tensor(1.0).astype('int')
def test_name_load(self):
net_dy = Net()
net_st = Net()
output_dy = net_dy(self.x)
output_st = paddle.jit.to_static(net_st)(self.x)
np.testing.assert_allclose(output_dy.numpy(), output_st.numpy())
if __name__ == "__main__":
unittest.main()
......@@ -26,6 +26,7 @@ from .convert_operators import convert_shape as Shape # noqa: F401
from .convert_operators import convert_while_loop as While # noqa: F401
from .convert_operators import unpack_by_structure as Unpack # noqa: F401
from .convert_operators import convert_attr as Attr # noqa: F401
from .convert_operators import convert_load as Ld # noqa: F401
from .convert_operators import indexable as Indexable # noqa: F401
from .variable_trans_func import create_bool_as_type # noqa: F401
from .variable_trans_func import to_static_variable # noqa: F401
......
......@@ -22,7 +22,7 @@ import os
from . import logging_utils
from .assert_transformer import AssertTransformer
from .base_transformer import BaseTransformer
from .basic_api_transformer import BasicApiTransformer
from .basic_api_transformer import BasicApiTransformer, NameloadJstTransformer
from .break_continue_transformer import (
BreakContinueTransformer,
BreakTransformOptimizer,
......@@ -93,6 +93,7 @@ class DygraphToStaticAst(BaseTransformer):
transformers = [
EarlyReturnTransformer,
DecoratorTransformer, # transform decorators to function call
BasicApiTransformer, # Basic Api
TensorShapeTransformer, # Tensor.shape -> paddle.shape(Tensor)
BreakContinueTransformer, # break/continue in loops
......@@ -104,7 +105,7 @@ class DygraphToStaticAst(BaseTransformer):
AssertTransformer, # assert statement
CallTransformer, # transform call recursively
CastTransformer, # type casting statement
DecoratorTransformer, # transform decorators to function call
NameloadJstTransformer,
TypeHintTransformer, # remove all typehint in gast.Name
]
......
......@@ -43,7 +43,6 @@ class BasicApiTransformer(BaseTransformer):
attribute_transformer = AttributeJstTransformer(self.root)
attribute_transformer.transform()
self.visit(self.root)
return self.wrapper_root
def visit_Assign(self, node):
......@@ -127,6 +126,63 @@ class ToTensorTransformer(BaseTransformer):
return node
class NameloadJstTransformer(BaseTransformer):
"""
change name and attribute load to __jst.Ld(name) pattern.
for example:
a.dtype --> __jst.Ld(__jst.Ld(a).dtype)
In paddle science and deepxde, we have to support changing tensor into variable
in arbitrary occasion such as global tensor.
NOTE: we only deal with ctx=Load() case.
"""
def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of BasicApiTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
def transform(self):
self.visit(self.root)
return self.root
def _surround_with_ld(self, node):
node = (
gast.parse(
"_jst.Ld({})".format(utils.ast_to_source_code(node).strip())
)
.body[0]
.value
)
return node
def visit_Call(self, node):
"""
Can't convert name of function call, bacause this will affect CallTransformer.
"""
node.args = [self.generic_visit(arg) for arg in node.args]
return node
def visit_Attribute(self, node):
assert isinstance(node, gast.Attribute)
assert isinstance(node.attr, str)
self.generic_visit(node)
if isinstance(node.ctx, gast.Load):
node = self._surround_with_ld(node)
return node
def visit_Name(self, node):
assert isinstance(node, gast.Name)
self.generic_visit(node)
if isinstance(node.ctx, gast.Load):
node = self._surround_with_ld(node)
return node
class AttributeJstTransformer(BaseTransformer):
"""
change some special attribute into __jst.XXX(obj, "attr_name") format.
......
......@@ -16,6 +16,7 @@ import re
import paddle
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.dygraph.base import _convert_into_variable
from paddle.fluid.framework import Variable, core
from paddle.fluid.layers import Print, control_flow, fill_constant
from paddle.fluid.layers.control_flow import while_loop
......@@ -39,6 +40,17 @@ def convert_attr(x, attr):
return getattr(x, attr)
def convert_load(x):
from paddle.fluid.dygraph.base import in_declarative_mode
if in_declarative_mode() and isinstance(x, paddle.fluid.core.eager.Tensor):
"""
TODO:(@xiongkun) may run convert_load in dygraph mode, which should be fixed.
"""
return _convert_into_variable(x)
return x
def indexable(x, code=None):
if isinstance(x, Variable):
return x
......
......@@ -1145,9 +1145,10 @@ class ProgramCache:
def _build_once(self, cache_key):
# TODO(Aurelius84): Need a gloabl FLAGS to enable/disable to_prim
enable_prim = cache_key.kwargs['build_strategy'].build_cinn_pass
# TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch.
# NOTE(xiongkun): Need a global FLAGS to enable/disable fallback
enable_fallback = enable_prim
# TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch.
core.check_and_set_prim_all_enabled()
try:
concrete_program = ConcreteProgram.from_func_spec(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册