From 68c76793ca15aec944d64578f262df520c827f6e Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Mon, 24 Feb 2020 11:46:05 +0800 Subject: [PATCH] support dygraph basic api transformed to static api (#22678) * support dygraph to static graph for simple case. * add test for dygraph API recognition. test=develop * support basic to_variable api. test=develop * update dict: dygraph_class_to_static_api * add all tests of dygraph api. test=develop * use gast/astor instead of ast/codegen for the compatibility of PY2 and PY3. test=develop * add arg 'num_flatten_dims' for fc ast node. test=develop * Modify names of class by Camel-Case. --- .../dygraph_to_static/ast_transformer.py | 95 +++- .../fluid/dygraph/dygraph_to_static/utils.py | 177 ++++++++ python/paddle/fluid/dygraph/jit.py | 20 +- ...raph_to_static_basic_api_transformation.py | 426 ++++++++++++++++++ 4 files changed, 707 insertions(+), 11 deletions(-) create mode 100644 python/paddle/fluid/dygraph/dygraph_to_static/utils.py create mode 100644 python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic_api_transformation.py diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index d793dcecb8b..7e6f3fc34ae 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -13,7 +13,8 @@ # limitations under the License. from __future__ import print_function - +import astor +from .utils import * import gast # gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST). # It provides a compatibility layer between the AST of various Python versions, @@ -66,7 +67,7 @@ class IfElseTransformer(gast.NodeTransformer): def visit_Call(self, node): # Remove `numpy()` statement, like `Tensor.numpy()[i]` -> `Tensor[i]` - # Todo: should be removed. it may be considered as basic api transformation. + # TODO: should be removed. it may be considered as basic api transformation. if isinstance(node.func, gast.Attribute): attribute = node.func if attribute.attr == 'numpy': @@ -105,6 +106,10 @@ class DygraphToStaticAst(gast.NodeTransformer): def transfer_from_node_type(self, node): # Generic transformation self.visit(node.node) + + # Transform basic api of dygraph to static graph + BasicApiTransformer(node).ast_visit() + # Transform all if/else statement of Dygraph into Static Graph. IfElseTransformer(node).ast_visit() @@ -128,3 +133,89 @@ class DygraphToStaticAst(gast.NodeTransformer): # Should consider BaseAPITransformer which add new module name in Yamei's PR. assert self.decorate_func_name, "decorate_func_name shall not be None." return self.decorate_func_name + + +class BasicApiTransformer(gast.NodeTransformer): + """ + Class to transform basic API from dygraph to static graph. + """ + + def __init__(self, wrapper_root): + assert isinstance( + wrapper_root, AstNodeWrapper + ), "Input non-AstNodeWrapper node for the initialization of BasicApiTransformer." + self.wrapper_root = wrapper_root + self.root = wrapper_root.node + self.class_node_dict = {} + + def ast_visit(self): + self.visit(self.root) + return self.wrapper_root + + def visit_FunctionDef(self, node): + self.generic_visit(node) + if hasattr(node, 'decorator_list'): + decorator_list = [ + d for d in node.decorator_list if d.id != DECORATOR_NAME + ] + node.decorator_list = decorator_list + return node + + def visit_Assign(self, node): + if self._update_class_node_dict(node): + return None + + value_node = node.value + for child_node in gast.walk(value_node): + if isinstance(child_node, gast.Call): + self._visit_Call(child_node) + + return node + + def visit_Expr(self, node): + value_node = node.value + for child_node in gast.walk(value_node): + if isinstance(child_node, gast.Call): + if is_dygraph_api(child_node): + return + else: + self._visit_Call(child_node) + + return node + + def _visit_Call(self, node): + assert isinstance(node, gast.Call) + + # Replace API `to_variable` with `fluid.layers.assign` + if is_to_variable(node): + node = to_assign_node(node) + return node + + func_name = astor.to_source(node.func) + if self._is_dygraph_forward(func_name): + class_node = self._get_class_node(func_name) + static_node = to_static_ast(node, class_node) + return static_node + else: + return node + + def _is_dygraph_forward(self, func_id): + return func_id in self.class_node_dict + + def _get_class_node(self, func_id): + return self.class_node_dict[func_id] + + def _update_class_node_dict(self, node): + assert isinstance(node, gast.Assign) + node_value = node.value + if isinstance(node_value, gast.Call): + if is_to_variable(node_value): + return False + + if is_dygraph_api(node_value): + update_args_of_func(node_value, node_value, "__init__") + target_str = astor.to_source(node.targets[0]) + self.class_node_dict[target_str] = node_value + return True + # TODO: node.value is not dygraph class + return False diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py new file mode 100644 index 00000000000..4ebe58da575 --- /dev/null +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -0,0 +1,177 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import inspect +import gast +import astor +import atexit +import os +import tempfile +import six +import imp + +dygraph_class_to_static_api = { + "BatchNorm": "batch_norm", + "BilinearTensorProduct": "bilinear_tensor_product", + "Conv2D": "conv2d", + "Conv3D": "conv3d", + "Conv2DTranspose": "conv2d_transpose", + "Conv3DTranspose": "conv3d_transpose", + "CosineDecay": "cosine_decay", + "Embedding": "embedding", + "ExponentialDecay": "exponential_decay", + "GroupNorm": "group_norm", + "GRUUnit": "gru_unit", + "InverseTimeDecay": "inverse_time_decay", + "LayerNorm": "layer_norm", + "Linear": "fc", + "NaturalExpDecay": "natural_exp_decay", + "NCE": "nce", + "NoamDecay": "noam_decay", + "PiecewiseDecay": "piecewise_decay", + "PolynomialDecay": "polynomial_decay", + "Pool2D": "pool2d", + "PRelu": "prelu", + "SpectralNorm": "spectral_norm", +} + + +def _delete_keywords_from(node): + assert isinstance(node, gast.Call) + func_src = astor.to_source(node.func) + import paddle.fluid as fluid + full_args = eval("inspect.getargspec({})".format(func_src)) + full_args_name = full_args[0] + + node.keywords = [k for k in node.keywords if k.arg in full_args_name] + return + + +def to_static_api(dygraph_class): + if dygraph_class in dygraph_class_to_static_api: + return dygraph_class_to_static_api[dygraph_class] + else: + raise NotImplementedError("Paddle dygraph API {} cannot be converted " + "to static graph at present.".format( + dygraph_class)) + + +def _add_keywords_to(node, dygraph_api_name): + assert isinstance(node, gast.Call) + if dygraph_api_name == "Linear": + for ast_keyword in node.keywords: + if ast_keyword.arg == "output_dim": + ast_keyword.arg = "size" + + node.keywords.append( + gast.keyword( + arg="num_flatten_dims", + value=gast.Constant( + value=-1, kind=None))) + + if dygraph_api_name == "BilinearTensorProduct": + for ast_keyword in node.keywords: + if ast_keyword.arg == "output_dim": + ast_keyword.arg = "size" + + if dygraph_api_name == "PRelu": + for ast_keyword in node.keywords: + if ast_keyword.arg == "input": + ast_keyword.arg = "x" + return + + +def _is_paddle_dygraph_api(obj): + m = inspect.getmodule(obj) + return m is not None and m.__name__.startswith("paddle.fluid.dygraph") + + +def is_dygraph_api(node): + assert isinstance(node, gast.Call) + func_src = astor.to_source(node.func) + try: + import paddle.fluid as fluid + return eval("_is_paddle_dygraph_api({})".format(func_src)) + except NameError: + return False + + +def is_to_variable(node): + assert isinstance(node, gast.Call) + if is_dygraph_api(node): + api_name = node.func.attr + return api_name == "to_variable" + return False + + +def to_static_ast(node, class_node): + assert isinstance(node, gast.Call) + assert isinstance(class_node, gast.Call) + static_api = to_static_api(class_node.func.attr) + + node.func = gast.Attribute( + attr=static_api, + ctx=gast.Load(), + value=gast.Attribute( + attr='layers', + ctx=gast.Load(), + value=gast.Name( + ctx=gast.Load(), id='fluid', annotation=None, + type_comment=None))) + + update_args_of_func(node, class_node, 'forward') + + node.args.extend(class_node.args) + node.keywords.extend(class_node.keywords) + _add_keywords_to(node, class_node.func.attr) + _delete_keywords_from(node) + + gast.fix_missing_locations(node) + + return node + + +def to_assign_node(ori_node): + assert isinstance(ori_node, gast.Call) + assign_api = gast.parse('fluid.layers.assign').body[0].value + ori_node.func = assign_api + return ori_node + + +def update_args_of_func(node, dygraph_node, method_name): + assert isinstance(node, gast.Call) + if method_name not in ["__init__", "forward"]: + raise ValueError( + "The method name of class to update args should be '__init__' or 'forward'" + ) + + class_src = astor.to_source(dygraph_node.func) + import paddle.fluid as fluid + if method_name == "__init__" or eval( + "issubclass({}, fluid.dygraph.Layer)".format(class_src)): + full_args = eval("inspect.getargspec({}.{})".format(class_src, + method_name)) + full_args_name = [ + arg_name for arg_name in full_args[0] if arg_name != "self" + ] + else: + full_args_name = [] + added_keywords = [] + for idx, arg in enumerate(node.args): + added_keywords.append(gast.keyword(arg=full_args_name[idx], value=arg)) + + node.args = [] + node.keywords = added_keywords + node.keywords diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index 43012c08af9..f84173b8c7e 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + __all__ = ['TracedLayer', 'dygraph_to_static_output'] import gast @@ -58,11 +60,11 @@ def _dygraph_to_static_output_(dygraph_func): dygraph_code = inspect.getsource(dygraph_func) dygraph_code = textwrap.dedent(dygraph_code) root = gast.parse(dygraph_code) + # Transform AST dygraph_to_static = DygraphToStaticAst() root_wrapper = dygraph_to_static.get_static_ast(root) func_name = dygraph_to_static.get_module_name() - static_func, file_name = ast_to_func(root_wrapper.node, func_name) return static_func(*args, **kwargs) @@ -108,17 +110,17 @@ def _trace(layer, class TracedLayer(object): """ - TracedLayer is used to convert a forward dygraph model to a static - graph model. This is mainly used to save the dygraph model for online - inference using C++. Besides, users can also do inference in Python - using the converted static graph model, which usually has better - performance than the original dygraph model. + TracedLayer is used to convert a forward dygraph model to a static + graph model. This is mainly used to save the dygraph model for online + inference using C++. Besides, users can also do inference in Python + using the converted static graph model, which usually has better + performance than the original dygraph model. TracedLayer would run the static graph model using :code:`Executor` and :code:`CompiledProgram` . The static graph model would share parameters with the dygraph model. - - All TracedLayer objects should not be created by constructor and should + + All TracedLayer objects should not be created by constructor and should be created by static method :code:`TracedLayer.trace(layer, inputs)` . The TracedLayer can only be used to convert the data-independent dygraph @@ -159,7 +161,7 @@ class TracedLayer(object): @dygraph_only def trace(layer, inputs): """ - This method is the only allowed method to create TracedLayer object. + This method is the only allowed method to create TracedLayer object. It would call the :code:`layer(*inputs)` method to run the dygraph model and convert it into a static graph model. diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic_api_transformation.py b/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic_api_transformation.py new file mode 100644 index 00000000000..55895507c2c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic_api_transformation.py @@ -0,0 +1,426 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import paddle.fluid as fluid +import unittest +import inspect +import gast + +from paddle.fluid.dygraph.jit import dygraph_to_static_output +from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api + +SEED = 2020 +np.random.seed(SEED) + + +def dyfunc_to_variable(x): + res = fluid.dygraph.to_variable(x) + return res + + +class TestDygraphBasicApi_ToVariable(unittest.TestCase): + def setUp(self): + self.input = np.ones(5).astype("int32") + self.dygraph_func = dyfunc_to_variable + + def get_dygraph_output(self): + with fluid.dygraph.guard(): + res = self.dygraph_func(self.input).numpy() + + return res + + def get_static_output(self): + main_program = fluid.Program() + main_program.random_seed = SEED + with fluid.program_guard(main_program): + static_out = dygraph_to_static_output(self.dygraph_func)(self.input) + + exe = fluid.Executor(fluid.CPUPlace()) + static_res = exe.run(main_program, fetch_list=static_out) + + return static_res[0] + + def test_transformed_static_result(self): + dygraph_res = self.get_dygraph_output() + static_res = self.get_static_output() + self.assertTrue( + np.allclose(dygraph_res, static_res), + msg='dygraph is {}\n static_res is {}'.format(dygraph_res, + static_res)) + + +# 1. test Apis that inherit from layers.Layer + + +def dyfunc_BilinearTensorProduct(layer1, layer2): + bilinearTensorProduct = fluid.dygraph.nn.BilinearTensorProduct( + input1_dim=5, + input2_dim=4, + output_dim=1000, + param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( + value=0.99)), + bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( + value=0.5))) + + res = bilinearTensorProduct( + fluid.dygraph.base.to_variable(layer1), + fluid.dygraph.base.to_variable(layer2)) + return res + + +def dyfunc_Conv2D(input): + conv2d = fluid.dygraph.Conv2D( + num_channels=3, + num_filters=2, + filter_size=3, + param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( + value=0.99)), + bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( + value=0.5)), ) + res = conv2d(input) + return res + + +def dyfunc_Conv3D(input): + conv3d = fluid.dygraph.Conv3D( + num_channels=3, + num_filters=2, + filter_size=3, + param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( + value=0.99)), + bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( + value=0.5)), ) + res = conv3d(input) + return res + + +def dyfunc_Conv2DTranspose(input): + conv2dTranspose = fluid.dygraph.nn.Conv2DTranspose( + num_channels=3, + num_filters=12, + filter_size=12, + use_cudnn=False, + param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( + value=0.99)), + bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( + value=0.5)), ) + ret = conv2dTranspose(input) + return ret + + +def dyfunc_Conv3DTranspose(input): + conv3dTranspose = fluid.dygraph.nn.Conv3DTranspose( + num_channels=3, + num_filters=12, + filter_size=12, + use_cudnn=False, + param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( + value=0.99)), + bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( + value=0.5)), ) + ret = conv3dTranspose(input) + return ret + + +def dyfunc_Linear(input): + fc = fluid.dygraph.Linear( + input_dim=10, + output_dim=5, + act='relu', + param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( + value=0.99)), + bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( + value=0.5)), ) + res = fc(input) + return res + + +def dyfunc_Pool2D(input): + fluid.dygraph.Pool2D( + pool_size=2, pool_type='avg', pool_stride=1, global_pooling=False) + pool2d = fluid.dygraph.Pool2D( + pool_size=2, pool_type='avg', pool_stride=1, global_pooling=False) + res = pool2d(input) + return res + + +def dyfunc_Prelu(input): + prelu0 = fluid.PRelu( + mode='all', + param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(1.0))) + res = prelu0(input=input) + return res + + +class TestDygraphBasicApi(unittest.TestCase): + # Compare results of dynamic graph and transformed static graph function which only + # includes basic Api. + + def setUp(self): + self.input = np.random.random((1, 4, 3, 3)).astype('float32') + self.dygraph_func = dyfunc_Pool2D + + def get_dygraph_output(self): + with fluid.dygraph.guard(): + fluid.default_startup_program.random_seed = SEED + fluid.default_main_program.random_seed = SEED + data = fluid.dygraph.to_variable(self.input) + res = self.dygraph_func(data).numpy() + + return res + + def get_static_output(self): + startup_program = fluid.Program() + startup_program.random_seed = SEED + main_program = fluid.Program() + main_program.random_seed = SEED + with fluid.program_guard(main_program, startup_program): + data = fluid.layers.assign(self.input) + static_out = dygraph_to_static_output(self.dygraph_func)(data) + + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(startup_program) + static_res = exe.run(main_program, fetch_list=static_out) + return static_res[0] + + def test_transformed_static_result(self): + dygraph_res = self.get_dygraph_output() + static_res = self.get_static_output() + self.assertTrue( + np.allclose(dygraph_res, static_res), + msg='dygraph is {}\n static_res is \n{}'.format(dygraph_res, + static_res)) + + +class TestDygraphBasicApi_BilinearTensorProduct(TestDygraphBasicApi): + def setUp(self): + self.input1 = np.random.random((5, 5)).astype('float32') + self.input2 = np.random.random((5, 4)).astype('float32') + self.dygraph_func = dyfunc_BilinearTensorProduct + + def get_dygraph_output(self): + with fluid.dygraph.guard(): + fluid.default_startup_program.random_seed = SEED + fluid.default_main_program.random_seed = SEED + res = self.dygraph_func(self.input1, self.input2).numpy() + return res + + def get_static_output(self): + startup_program = fluid.Program() + startup_program.random_seed = SEED + main_program = fluid.Program() + main_program.random_seed = SEED + with fluid.program_guard(main_program, startup_program): + static_out = dygraph_to_static_output(self.dygraph_func)( + self.input1, self.input2) + + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(startup_program) + static_res = exe.run(main_program, fetch_list=static_out) + return static_res[0] + + +class TestDygraphBasicApi_Conv2D(TestDygraphBasicApi): + def setUp(self): + self.input = np.random.random((1, 3, 3, 5)).astype('float32') + self.dygraph_func = dyfunc_Conv2D + + +class TestDygraphBasicApi_Conv3D(TestDygraphBasicApi): + def setUp(self): + self.input = np.random.random((1, 3, 3, 3, 5)).astype('float32') + self.dygraph_func = dyfunc_Conv3D + + +class TestDygraphBasicApi_Conv2DTranspose(TestDygraphBasicApi): + def setUp(self): + self.input = np.random.random((5, 3, 32, 32)).astype('float32') + self.dygraph_func = dyfunc_Conv2DTranspose + + +class TestDygraphBasicApi_Conv3DTranspose(TestDygraphBasicApi): + def setUp(self): + self.input = np.random.random((5, 3, 12, 32, 32)).astype('float32') + self.dygraph_func = dyfunc_Conv3DTranspose + + +class TestDygraphBasicApi_Linear(TestDygraphBasicApi): + def setUp(self): + self.input = np.random.random((4, 3, 10)).astype('float32') + self.dygraph_func = dyfunc_Linear + + +class TestDygraphBasicApi_Prelu(TestDygraphBasicApi): + def setUp(self): + self.input = np.ones([5, 20, 10, 10]).astype('float32') + self.dygraph_func = dyfunc_Prelu + + +# 2. test Apis that inherit from LearningRateDecay +def dyfunc_CosineDecay(): + base_lr = 0.1 + CosineDecay = fluid.dygraph.CosineDecay( + learning_rate=base_lr, step_each_epoch=10000, epochs=120) + lr = CosineDecay() + return lr + + +def dyfunc_ExponentialDecay(): + base_lr = 0.1 + exponential_decay = fluid.dygraph.ExponentialDecay( + learning_rate=base_lr, + decay_steps=10000, + decay_rate=0.5, + staircase=True) + lr = exponential_decay() + return lr + + +def dyfunc_InverseTimeDecay(): + base_lr = 0.1 + inverse_time_decay = fluid.dygraph.InverseTimeDecay( + learning_rate=base_lr, + decay_steps=10000, + decay_rate=0.5, + staircase=True) + lr = inverse_time_decay() + return lr + + +def dyfunc_NaturalExpDecay(): + base_lr = 0.1 + natural_exp_decay = fluid.dygraph.NaturalExpDecay( + learning_rate=base_lr, + decay_steps=10000, + decay_rate=0.5, + staircase=True) + lr = natural_exp_decay() + return lr + + +def dyfunc_NoamDecay(): + noam_decay = fluid.dygraph.NoamDecay(100, 100) + lr = noam_decay() + return lr + + +def dyfunc_PiecewiseDecay(): + boundaries = [10000, 20000] + values = [1.0, 0.5, 0.1] + pd = fluid.dygraph.PiecewiseDecay(boundaries, values, begin=0) + lr = pd() + return lr + + +def dyfunc_PolynomialDecay(): + start_lr = 0.01 + total_step = 5000 + end_lr = 0 + pd = fluid.dygraph.PolynomialDecay(start_lr, total_step, end_lr, power=1.0) + lr = pd() + return lr + + +class TestDygraphBasicApi_CosineDecay(unittest.TestCase): + def setUp(self): + self.dygraph_func = dyfunc_CosineDecay + + def get_dygraph_output(self): + with fluid.dygraph.guard(): + fluid.default_startup_program.random_seed = SEED + fluid.default_main_program.random_seed = SEED + res = self.dygraph_func().numpy() + return res + + def get_static_output(self): + startup_program = fluid.Program() + startup_program.random_seed = SEED + main_program = fluid.Program() + main_program.random_seed = SEED + with fluid.program_guard(main_program, startup_program): + static_out = dygraph_to_static_output(self.dygraph_func)() + + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(startup_program) + static_res = exe.run(main_program, fetch_list=static_out) + return static_res[0] + + def test_transformed_static_result(self): + dygraph_res = self.get_dygraph_output() + static_res = self.get_static_output() + self.assertTrue( + np.allclose(dygraph_res, static_res), + msg='dygraph is {}\n static_res is \n{}'.format(dygraph_res, + static_res)) + + +class TestDygraphBasicApi_ExponentialDecay(TestDygraphBasicApi_CosineDecay): + def setUp(self): + self.dygraph_func = dyfunc_ExponentialDecay + + +class TestDygraphBasicApi_InverseTimeDecay(TestDygraphBasicApi_CosineDecay): + def setUp(self): + self.dygraph_func = dyfunc_InverseTimeDecay + + +class TestDygraphBasicApi_NaturalExpDecay(TestDygraphBasicApi_CosineDecay): + def setUp(self): + self.dygraph_func = dyfunc_NaturalExpDecay + + +class TestDygraphBasicApi_NoamDecay(TestDygraphBasicApi_CosineDecay): + def setUp(self): + self.dygraph_func = dyfunc_NoamDecay + + +class TestDygraphBasicApi_PiecewiseDecay(TestDygraphBasicApi_CosineDecay): + def setUp(self): + self.dygraph_func = dyfunc_PiecewiseDecay + + +class TestDygraphBasicApi_PolynomialDecay(TestDygraphBasicApi_CosineDecay): + def setUp(self): + self.dygraph_func = dyfunc_PolynomialDecay + + +def _dygraph_fn(): + import paddle.fluid as fluid + x = np.random.random((1, 3)).astype('float32') + with fluid.dygraph.guard(): + fluid.dygraph.to_variable(x) + np.random.random((1)) + + +class TestDygraphApiRecognition(unittest.TestCase): + def setUp(self): + self.src = inspect.getsource(_dygraph_fn) + self.root = gast.parse(self.src) + + def _get_dygraph_ast_node(self): + return self.root.body[0].body[2].body[0].value + + def _get_static_ast_node(self): + return self.root.body[0].body[2].body[1].value + + def test_dygraph_api(self): + self.assertTrue(is_dygraph_api(self._get_dygraph_ast_node()) is True) + self.assertTrue(is_dygraph_api(self._get_static_ast_node()) is False) + + +if __name__ == '__main__': + unittest.main() -- GitLab