From 05c00af5f16da64d1e8953711c647512121ef3d2 Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Wed, 25 Mar 2020 17:26:09 +0800 Subject: [PATCH] Add dygraph_to_static_code and get_code in ProgramTranslator (#23162) As the title, we add decorator "dygraph_to_static_code", and add related "get_code", "get_func", "get_output" for ProgramTranslator. Next step will be adding "dygraph_to_static_program" --- .../dygraph_to_static/program_translator.py | 87 +++++++++++++++---- python/paddle/fluid/dygraph/jit.py | 39 +++++---- .../test_program_translator.py | 79 +++++++++++++++++ 3 files changed, 170 insertions(+), 35 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index 36a18c18d9..a3f9e0b9c2 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -13,16 +13,21 @@ # limitations under the License. from __future__ import print_function +import gast import inspect -import textwrap -import threading import numpy import six +import textwrap +import threading +import warnings from paddle.fluid import framework from paddle.fluid import core, executor from paddle.fluid.data import data -from paddle.fluid.dygraph.dygraph_to_static import convert_to_static +from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_static +from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst +from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code +from paddle.fluid.framework import in_dygraph_mode __all__ = ['ProgramTranslator'] @@ -89,8 +94,8 @@ class ProgramCache(object): self._feed_name_to_idx = {} self._is_repeated = False # Indicates whether the function call is still building program. - # Because `__call__` can be called recursively when `Net` has - # sub class in `forward()`. + # Because user can call recursively when `Net` has sub class in + # `forward()`. self._in_build_process = True def build_program_and_return_output(self, dyfunc, *args, **kwargs): @@ -242,6 +247,52 @@ class ProgramTranslator(object): # Once main_program is changed, should run startup_program. self._need_startup = True + def get_output(self, dygraph_func, *args, **kwargs): + """ + Returns the output tensors for dygraph function and its arguments + """ + if in_dygraph_mode(): + warnings.warn( + "The decorator 'dygraph_to_static_output' doesn't work in dygraph mode." + " Please use it in static mode.") + return dygraph_func(*args, **kwargs) + + program_cache = self.get_program_cache() + outputs = program_cache.build_program_and_return_output(dygraph_func, + *args, **kwargs) + if not program_cache.in_build_process: + outputs = self.run(*args, **kwargs) + return outputs + + def get_func(self, dygraph_func): + """ + Returns the translated static function from dygraph function + """ + if in_dygraph_mode(): + warnings.warn( + "The decorator 'dygraph_to_static_graph' doesn't work in dygraph mode." + " Please use it in static mode.") + return dygraph_func + static_func, ast_transformer = convert_to_static(dygraph_func) + return static_func + + def get_code(self, dygraph_func): + """ + Returns the translated static function code from dygraph code + """ + # Get AST from dygraph function + raw_code = inspect.getsource(dygraph_func) + code = textwrap.dedent(raw_code) + root = gast.parse(code) + + # Transform AST + dygraph_to_static = DygraphToStaticAst() + root_wrapper = dygraph_to_static.get_static_ast(root) + + # Get source_code + source_code = ast_to_source_code(root_wrapper.node) + return source_code + def run(self, *args, **kwargs): """ Executes main_program and returns output Tensors. @@ -255,6 +306,19 @@ class ProgramTranslator(object): return outputs + def set_optimizer(self, optimizer, loss_name): + """ + Supports to set or update the optimizer used to minimize loss. + """ + self._check_cache_valid() + self._optimizer = optimizer + + if not isinstance(loss_name, six.string_types): + raise ValueError( + "Type of input loss_name should type(str), but received {}.". + format(type(loss_name))) + self._loss_name = loss_name + def _prepare(self, args): """ Prepares with feed_dict, fetch_list, optimizer and initialize vars @@ -298,19 +362,6 @@ class ProgramTranslator(object): return feed_dict - def set_optimizer(self, optimizer, loss_name): - """ - Supports to set or update the optimizer used to minimize loss. - """ - self._check_cache_valid() - self._optimizer = optimizer - - if not isinstance(loss_name, six.string_types): - raise ValueError( - "Type of input loss_name should type(str), but received {}.". - format(type(loss_name))) - self._loss_name = loss_name - def _add_optimizer(self): """ Supports to set or update the optimizer used to minimize loss. diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index 56b1a9362a..12e4c1927b 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -14,7 +14,10 @@ from __future__ import print_function -__all__ = ['TracedLayer', 'dygraph_to_static_output', 'dygraph_to_static_graph'] +__all__ = [ + 'TracedLayer', 'dygraph_to_static_code', 'dygraph_to_static_graph', + 'dygraph_to_static_output' +] import warnings @@ -51,41 +54,43 @@ def extract_vars(inputs): return result_list -def _dygraph_to_static_graph_(dygraph_func): +def _dygraph_to_static_code_(dygraph_func): def __impl__(*args, **kwargs): if in_dygraph_mode(): warnings.warn( - "The decorator 'dygraph_to_static_graph' doesn't work in dygraph mode." + "The decorator 'dygraph_to_static_code' doesn't work in dygraph mode." " Please use it in static mode.") return dygraph_func(*args, **kwargs) - static_func, ast_transformer = convert_to_static(dygraph_func) - return static_func(*args, **kwargs) + program_translator = ProgramTranslator() + return program_translator.get_code(dygraph_func) return __impl__ -dygraph_to_static_graph = wrap_decorator(_dygraph_to_static_graph_) - +dygraph_to_static_code = wrap_decorator(_dygraph_to_static_code_) -def _dygraph_to_static_output_(dygraph_func): - program_translator = ProgramTranslator() +def _dygraph_to_static_graph_(dygraph_func): def __impl__(*args, **kwargs): if in_dygraph_mode(): warnings.warn( - "The decorator 'dygraph_to_static_output' doesn't work in dygraph mode." + "The decorator 'dygraph_to_static_graph' doesn't work in dygraph mode." " Please use it in static mode.") return dygraph_func(*args, **kwargs) + program_translator = ProgramTranslator() + static_func = program_translator.get_func(dygraph_func) + return static_func(*args, **kwargs) + + return __impl__ - program_cache = program_translator.get_program_cache() - outputs = program_cache.build_program_and_return_output(dygraph_func, - *args, **kwargs) - # Run program to fetch output Tensors once building successfully. - if not program_cache.in_build_process: - outputs = program_translator.run(*args, **kwargs) +dygraph_to_static_graph = wrap_decorator(_dygraph_to_static_graph_) - return outputs + +def _dygraph_to_static_output_(dygraph_func): + def __impl__(*args, **kwargs): + program_translator = ProgramTranslator() + return program_translator.get_output(dygraph_func, *args, **kwargs) return __impl__ diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py new file mode 100644 index 0000000000..f96f1ceafa --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py @@ -0,0 +1,79 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np + +import paddle.fluid as fluid + +from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator +from paddle.fluid.dygraph.jit import dygraph_to_static_code + +from ifelse_simple_func import dyfunc_with_if_else + + +class TestDygraphToStaticCode(unittest.TestCase): + def setUp(self): + # set to print all string diff when assertEqual fails + self.maxDiff = None + + def test_decorator(self): + answer = "\ +def dyfunc_with_if_else(x_v, label=None):\n\ +\n\ + def true_fn_0(x_v):\n\ + x_v = x_v - 1\n\ + return x_v\n\ +\n\ + def false_fn_0(x_v):\n\ + x_v = x_v + 1\n\ + return x_v\n\ + x_v = fluid.layers.cond(fluid.layers.mean(x_v)[0] > 5, lambda :\n\ + true_fn_0(x_v), lambda : false_fn_0(x_v))\n\ + if label is not None:\n\ + loss = fluid.layers.cross_entropy(x_v, label)\n\ + return loss\n\ + return x_v\n" + + x_v = None + code = dygraph_to_static_code(dyfunc_with_if_else)(x_v) + self.assertEqual(answer, code) + + def test_program_translator(self): + answer = "\ +def dyfunc_with_if_else(x_v, label=None):\n\ +\n\ + def true_fn_1(x_v):\n\ + x_v = x_v - 1\n\ + return x_v\n\ +\n\ + def false_fn_1(x_v):\n\ + x_v = x_v + 1\n\ + return x_v\n\ + x_v = fluid.layers.cond(fluid.layers.mean(x_v)[0] > 5, lambda :\n\ + true_fn_1(x_v), lambda : false_fn_1(x_v))\n\ + if label is not None:\n\ + loss = fluid.layers.cross_entropy(x_v, label)\n\ + return loss\n\ + return x_v\n" + + program_translator = ProgramTranslator() + code = program_translator.get_code(dyfunc_with_if_else) + self.assertEqual(answer, code) + + +if __name__ == '__main__': + unittest.main() -- GitLab