未验证 提交 05c00af5 编写于 作者: H Huihuang Zheng 提交者: GitHub

Add dygraph_to_static_code and get_code in ProgramTranslator (#23162)

As the title says, we add the decorator "dygraph_to_static_code", and add the related "get_code", "get_func", and "get_output" methods to ProgramTranslator. The next step will be adding "dygraph_to_static_program".
上级 cc8ca8ce
......@@ -13,16 +13,21 @@
# limitations under the License.
from __future__ import print_function
import gast
import inspect
import textwrap
import threading
import numpy
import six
import textwrap
import threading
import warnings
from paddle.fluid import framework
from paddle.fluid import core, executor
from paddle.fluid.data import data
from paddle.fluid.dygraph.dygraph_to_static import convert_to_static
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_static
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.framework import in_dygraph_mode
__all__ = ['ProgramTranslator']
......@@ -89,8 +94,8 @@ class ProgramCache(object):
self._feed_name_to_idx = {}
self._is_repeated = False
# Indicates whether the function call is still building program.
# Because `__call__` can be called recursively when `Net` has
# sub class in `forward()`.
# Because user can call recursively when `Net` has sub class in
# `forward()`.
self._in_build_process = True
def build_program_and_return_output(self, dyfunc, *args, **kwargs):
......@@ -242,6 +247,52 @@ class ProgramTranslator(object):
# Once main_program is changed, should run startup_program.
self._need_startup = True
def get_output(self, dygraph_func, *args, **kwargs):
    """
    Translate ``dygraph_func`` to a static program and return the output
    tensors produced for the given arguments.

    In dygraph mode translation is skipped: a warning is issued and the
    function is executed directly instead.
    """
    if in_dygraph_mode():
        warnings.warn(
            "The decorator 'dygraph_to_static_output' doesn't work in dygraph mode."
            " Please use it in static mode.")
        return dygraph_func(*args, **kwargs)

    cache = self.get_program_cache()
    outputs = cache.build_program_and_return_output(dygraph_func, *args,
                                                    **kwargs)
    # While the program is still being built (e.g. recursive calls from a
    # sub-network's forward()), return the build-time outputs; once building
    # is finished, execute the program to fetch the real tensors.
    if not cache.in_build_process:
        outputs = self.run(*args, **kwargs)
    return outputs
def get_func(self, dygraph_func):
    """
    Return the static-graph function translated from ``dygraph_func``.

    In dygraph mode translation is skipped: a warning is issued and the
    original function is returned unchanged.
    """
    if in_dygraph_mode():
        warnings.warn(
            "The decorator 'dygraph_to_static_graph' doesn't work in dygraph mode."
            " Please use it in static mode.")
        return dygraph_func
    # convert_to_static returns (static_func, ast_transformer); the
    # transformer is not needed here, so discard it explicitly.
    static_func, _ = convert_to_static(dygraph_func)
    return static_func
def get_code(self, dygraph_func):
    """
    Translate ``dygraph_func`` and return the resulting static-graph
    source code as a string.
    """
    # Parse the dedented source of the dygraph function into an AST.
    func_source = textwrap.dedent(inspect.getsource(dygraph_func))
    ast_root = gast.parse(func_source)

    # Rewrite the AST from dygraph form into static-graph form.
    root_wrapper = DygraphToStaticAst().get_static_ast(ast_root)

    # Render the transformed AST back into source code.
    return ast_to_source_code(root_wrapper.node)
def run(self, *args, **kwargs):
"""
Executes main_program and returns output Tensors.
......@@ -255,6 +306,19 @@ class ProgramTranslator(object):
return outputs
def set_optimizer(self, optimizer, loss_name):
    """
    Set or update the optimizer used to minimize the loss.

    Args:
        optimizer: the optimizer instance to apply.
        loss_name (str): name of the loss variable to minimize.

    Raises:
        ValueError: if ``loss_name`` is not a string.
    """
    self._check_cache_valid()
    # Validate before mutating any state so a bad loss_name does not
    # leave the translator with a new optimizer but a stale loss name.
    if not isinstance(loss_name, six.string_types):
        raise ValueError(
            "Type of input loss_name should be str, but received {}.".
            format(type(loss_name)))
    self._optimizer = optimizer
    self._loss_name = loss_name
def _prepare(self, args):
"""
Prepares with feed_dict, fetch_list, optimizer and initialize vars
......@@ -298,19 +362,6 @@ class ProgramTranslator(object):
return feed_dict
def set_optimizer(self, optimizer, loss_name):
    """
    Supports to set or update the optimizer used to minimize loss.
    """
    # Verify the cached program state is still consistent before mutating it.
    self._check_cache_valid()
    self._optimizer = optimizer
    # `loss_name` must name the loss variable to minimize; reject non-strings.
    if not isinstance(loss_name, six.string_types):
        raise ValueError(
            "Type of input loss_name should type(str), but received {}.".
            format(type(loss_name)))
    self._loss_name = loss_name
def _add_optimizer(self):
"""
Supports to set or update the optimizer used to minimize loss.
......
......@@ -14,7 +14,10 @@
from __future__ import print_function
__all__ = ['TracedLayer', 'dygraph_to_static_output', 'dygraph_to_static_graph']
__all__ = [
'TracedLayer', 'dygraph_to_static_code', 'dygraph_to_static_graph',
'dygraph_to_static_output'
]
import warnings
......@@ -51,41 +54,43 @@ def extract_vars(inputs):
return result_list
def _dygraph_to_static_graph_(dygraph_func):
def _dygraph_to_static_code_(dygraph_func):
def __impl__(*args, **kwargs):
if in_dygraph_mode():
warnings.warn(
"The decorator 'dygraph_to_static_graph' doesn't work in dygraph mode."
"The decorator 'dygraph_to_static_code' doesn't work in dygraph mode."
" Please use it in static mode.")
return dygraph_func(*args, **kwargs)
static_func, ast_transformer = convert_to_static(dygraph_func)
return static_func(*args, **kwargs)
program_translator = ProgramTranslator()
return program_translator.get_code(dygraph_func)
return __impl__
dygraph_to_static_graph = wrap_decorator(_dygraph_to_static_graph_)
dygraph_to_static_code = wrap_decorator(_dygraph_to_static_code_)
def _dygraph_to_static_output_(dygraph_func):
program_translator = ProgramTranslator()
def _dygraph_to_static_graph_(dygraph_func):
def __impl__(*args, **kwargs):
if in_dygraph_mode():
warnings.warn(
"The decorator 'dygraph_to_static_output' doesn't work in dygraph mode."
"The decorator 'dygraph_to_static_graph' doesn't work in dygraph mode."
" Please use it in static mode.")
return dygraph_func(*args, **kwargs)
program_translator = ProgramTranslator()
static_func = program_translator.get_func(dygraph_func)
return static_func(*args, **kwargs)
program_cache = program_translator.get_program_cache()
outputs = program_cache.build_program_and_return_output(dygraph_func,
*args, **kwargs)
return __impl__
# Run program to fetch output Tensors once building successfully.
if not program_cache.in_build_process:
outputs = program_translator.run(*args, **kwargs)
return outputs
dygraph_to_static_graph = wrap_decorator(_dygraph_to_static_graph_)
def _dygraph_to_static_output_(dygraph_func):
    """Decorator implementation: evaluate ``dygraph_func`` through
    ProgramTranslator.get_output on every call."""

    def __impl__(*args, **kwargs):
        # NOTE(review): a ProgramTranslator is constructed on each call —
        # presumably it is cheap/singleton-like; confirm before relying on it.
        translator = ProgramTranslator()
        return translator.get_output(dygraph_func, *args, **kwargs)

    return __impl__
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.jit import dygraph_to_static_code
from ifelse_simple_func import dyfunc_with_if_else
class TestDygraphToStaticCode(unittest.TestCase):
    """Checks that dygraph-to-static translation emits the expected source code."""

    def setUp(self):
        # set to print all string diff when assertEqual fails
        self.maxDiff = None

    def test_decorator(self):
        # Expected translated source for dyfunc_with_if_else, built with
        # backslash line continuations so each continued line contributes
        # its text verbatim to the string.
        # NOTE(review): in-string leading indentation may have been lost in
        # this view — confirm against the translator's actual output.
        answer = "\
def dyfunc_with_if_else(x_v, label=None):\n\
\n\
def true_fn_0(x_v):\n\
x_v = x_v - 1\n\
return x_v\n\
\n\
def false_fn_0(x_v):\n\
x_v = x_v + 1\n\
return x_v\n\
x_v = fluid.layers.cond(fluid.layers.mean(x_v)[0] > 5, lambda :\n\
true_fn_0(x_v), lambda : false_fn_0(x_v))\n\
if label is not None:\n\
loss = fluid.layers.cross_entropy(x_v, label)\n\
return loss\n\
return x_v\n"
        x_v = None
        # The decorated call returns the translated source code string.
        code = dygraph_to_static_code(dyfunc_with_if_else)(x_v)
        self.assertEqual(answer, code)

    def test_program_translator(self):
        # Same expectation obtained via the ProgramTranslator.get_code API.
        # The helper names are true_fn_1/false_fn_1 here — presumably a
        # unique-name counter advanced by the previous translation.
        answer = "\
def dyfunc_with_if_else(x_v, label=None):\n\
\n\
def true_fn_1(x_v):\n\
x_v = x_v - 1\n\
return x_v\n\
\n\
def false_fn_1(x_v):\n\
x_v = x_v + 1\n\
return x_v\n\
x_v = fluid.layers.cond(fluid.layers.mean(x_v)[0] > 5, lambda :\n\
true_fn_1(x_v), lambda : false_fn_1(x_v))\n\
if label is not None:\n\
loss = fluid.layers.cross_entropy(x_v, label)\n\
return loss\n\
return x_v\n"
        program_translator = ProgramTranslator()
        code = program_translator.get_code(dyfunc_with_if_else)
        self.assertEqual(answer, code)
# Run the tests when this file is executed directly.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册