未验证 提交 ce5b235f 编写于 作者: H Huihuang Zheng 提交者: GitHub

[Dy2stat] Rename Dygraph To Static Decorators (#23880) (#23962)

1. Rename Dygraph To Static Decorators to declarative
2. dygraph_to_static_func is still using in some training tests, I cannot delete it now.
3. Add some API docs
上级 84f899cb
......@@ -43,10 +43,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static
__all__ = ['DygraphToStaticAst', 'convert_to_static']
DECORATOR_NAMES = [
'dygraph_to_static_code', 'dygraph_to_static_program',
'dygraph_to_static_func', 'dygraph_to_static_output'
]
DECORATOR_NAMES = ['declarative', 'dygraph_to_static_func']
class DygraphToStaticAst(gast.NodeTransformer):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -15,10 +15,10 @@
from __future__ import print_function
import gast
import inspect
import logging
import numpy
import textwrap
import threading
import warnings
from paddle.fluid import framework
from paddle.fluid import core, executor
......@@ -32,6 +32,8 @@ from paddle.fluid.data_feeder import check_type
__all__ = ['ProgramTranslator', 'convert_function_with_cache']
logger = logging.getLogger("fluid")
class FunctionCache(object):
"""
......@@ -235,6 +237,10 @@ class ProgramCache(object):
class ProgramTranslator(object):
"""
Class to translate dygraph function into static graph function.
"""
_singleton_lock = threading.Lock()
_instance = None
......@@ -274,16 +280,37 @@ class ProgramTranslator(object):
self._loss_name = None
# Once startup_program is changed, should run startup_program.
self._prev_startup = None
self.enable_declarative = True
def enable_declarative_function(self, enable_declarative):
"""
Enable or disable the converting from imperative to declarative by
ProgramTranslator globally.
Args:
enable_declarative (bool): True or False to enable or disable declarative
"""
self.enable_declarative = enable_declarative
def get_output(self, dygraph_func, *args, **kwargs):
"""
Returns the output tensors for dygraph function and its arguments
Returns the output dygraph VarBase for dygraph function. The dygraph
function will be translated into static graph function so the under
beneath numerical result will be calculated by declarative mode.
Args:
dygraph_func (callable): the dygraph function.
*args, **kwargs : the input argument of dygraph_func.
Returns:
VarBase or tuple of VarBase: the dygraph VarBase containing digital
result.
"""
if in_dygraph_mode():
warnings.warn(
if in_dygraph_mode() or not self.enable_declarative:
logger.info(
"The ProgramTranslator.get_output doesn't work in dygraph "
"mode. We will just return dygraph output. Use it in "
"static mode if you would like to translate to static graph.")
"mode or set enable_declarative_function to False. We will "
"just return dygraph output.")
return dygraph_func(*args, **kwargs)
program_cache = self.get_program_cache()
......@@ -292,33 +319,60 @@ class ProgramTranslator(object):
if not program_cache.in_build_process:
outputs = self.run(*args, **kwargs)
with guard():
outputs = [to_variable(x) for x in outputs]
if len(outputs) == 1:
outputs = to_variable(outputs[0])
else:
outputs = tuple(to_variable(x) for x in outputs)
return outputs
def get_func(self, dygraph_func):
"""
Returns the translated static function from dygraph function
Returns a callable function which converts imperative dygraph APIs of
the input dygraph_func into declarative net-building APIs, which means
it doesn't return immediate digital result as get_output does.
Users should handle Program and Executor by themselves.
Args:
dygraph_func (callable): the dygraph function.
Returns:
callable: converting imperative dygraph APIs into declarative
net-building APIs.
"""
if in_dygraph_mode():
warnings.warn(
if in_dygraph_mode() or not self.enable_declarative:
logger.info(
"The ProgramTranslator.get_func doesn't work in dygraph "
"mode. We will just return dygraph function. Use it in "
"static mode if you would like to translate to static graph.")
"mode or set enable_declarative_function to False. We will "
"just return dygraph output.")
return dygraph_func
static_func = convert_function_with_cache(dygraph_func)
return static_func
def get_program(self, dygraph_func, *args, **kwargs):
"""
Returns the translated static program and input/output variables from
dygraph function.
"""
if in_dygraph_mode():
warnings.warn(
dygraph function. The users can use the program to run by executor.
Args:
dygraph_func (callable): the dygraph function.
*args, **kwargs : the input argument of dygraph_func.
Returns:
tuple of (main_program, startup_program, inputs, outputs) whose
types are (Program, Program, list of Variable, list of Variable).
main_program: the converted main program.
startup_program: the converted startup program.
inputs: list of input Variables which need to be fed.
outputs: list of output Variables which users can fetch.
"""
if in_dygraph_mode() or not self.enable_declarative:
logger.info(
"The ProgramTranslator.get_program doesn't work in dygraph "
"mode. We will just return dygraph output. Use it in static "
"mode if you would like to translate to static graph.")
"mode or set enable_declarative_function to False. We will "
"just return dygraph output.")
return dygraph_func(*args, **kwargs)
program_cache = self.get_program_cache()
outputs = program_cache.build_program_and_return_output(dygraph_func,
*args, **kwargs)
......@@ -326,7 +380,13 @@ class ProgramTranslator(object):
def get_code(self, dygraph_func):
"""
Returns the translated static function code from dygraph code
Returns the translated static function string code from dygraph function.
Args:
dygraph_func (callable): the dygraph function.
Returns:
str: the string code of translated static function
"""
# Gets AST from dygraph function
raw_code = inspect.getsource(dygraph_func)
......
......@@ -14,12 +14,9 @@
from __future__ import print_function
__all__ = [
'TracedLayer', 'dygraph_to_static_code', 'dygraph_to_static_func',
'dygraph_to_static_output', 'dygraph_to_static_program'
]
__all__ = ['TracedLayer', 'declarative', 'dygraph_to_static_func']
import warnings
import logging
from ..wrapped_decorator import wrap_decorator
from .base import program_desc_tracing_guard, switch_to_static_graph
......@@ -30,6 +27,8 @@ from paddle.fluid.executor import Executor, scope_guard
from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator
logger = logging.getLogger("fluid")
def create_program_from_desc(program_desc):
program = Program()
......@@ -54,62 +53,114 @@ def extract_vars(inputs):
return result_list
def _dygraph_to_static_code_(dygraph_func):
def _dygraph_to_static_func_(dygraph_func):
"""
Converts imperative dygraph APIs into declarative function APIs. Decorator
@dygraph_to_static_func only converts imperative dygraph APIs into
declarative net-building APIs, which means it doesn't return immediate
digital result as imperative mode. Users should handle Program and Executor
by themselves.
Note:
This decorator is NOT our recommended way to transform imperative function
to declarative function. We will remove this decorator after we finalize
cleaning up code.
Args:
dygraph_func (callable): callable imperative function.
Returns:
Callable: converting imperative dygraph APIs into declarative
net-building APIs.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
from paddle.fluid.dygraph.jit import dygraph_to_static_func
@dygraph_to_static_func
def func(x):
if fluid.layers.mean(x) < 0:
x_v = x - 1
else:
x_v = x + 1
return x_v
x = fluid.layers.fill_constant(shape=[3, 3], value=0, dtype='float64')
x_v = func(x)
exe = fluid.Executor(fluid.CPUPlace())
out = exe.run(fetch_list=[x_v])
print(out[0])
# [[1. 1. 1.]
# [1. 1. 1.]
# [1. 1. 1.]]
"""
# TODO: remove this decorator after we finalize training API
def __impl__(*args, **kwargs):
program_translator = ProgramTranslator()
return program_translator.get_code(dygraph_func)
if in_dygraph_mode() or not program_translator.enable_declarative:
logger.info(
"The decorator 'dygraph_to_static_func' doesn't work in "
"dygraph mode or set enable_declarative_function to False. "
"We will just return dygraph output.")
return dygraph_func(*args, **kwargs)
static_func = program_translator.get_func(dygraph_func)
return static_func(*args, **kwargs)
return __impl__
dygraph_to_static_code = wrap_decorator(_dygraph_to_static_code_)
dygraph_to_static_func = wrap_decorator(_dygraph_to_static_func_)
def _dygraph_to_static_program_(dygraph_func):
def __impl__(*args, **kwargs):
if in_dygraph_mode():
warnings.warn(
"The decorator 'dygraph_to_static_program' doesn't work in "
"dygraph mode. We will just return dygraph output. Use the "
"decorator in static mode if you would like to translate to "
"static graph.")
return dygraph_func(*args, **kwargs)
program_translator = ProgramTranslator()
return program_translator.get_program(dygraph_func, *args, **kwargs)
return __impl__
def _declarative_(dygraph_func):
"""
Converts imperative dygraph APIs into declarative function APIs. Decorator
@declarative handles the Program and Executor of static mode and returns
the result as a dygraph VarBase.
Args:
dygraph_func (callable): callable imperative function.
dygraph_to_static_program = wrap_decorator(_dygraph_to_static_program_)
Returns:
VarBase: containing the numerical result.
Examples:
.. code-block:: python
def _dygraph_to_static_func_(dygraph_func):
def __impl__(*args, **kwargs):
if in_dygraph_mode():
warnings.warn(
"The decorator 'dygraph_to_static_func' doesn't work in "
"dygraph mode. We will just return dygraph output. Use the "
"decorator in static mode if you would like to translate to "
"static graph.")
return dygraph_func(*args, **kwargs)
program_translator = ProgramTranslator()
static_func = program_translator.get_func(dygraph_func)
return static_func(*args, **kwargs)
import paddle.fluid as fluid
import numpy as np
from paddle.fluid.dygraph.jit import declarative
return __impl__
@declarative
def func(x):
x = fluid.dygraph.to_variable(x)
if fluid.layers.mean(x) < 0:
x_v = x - 1
else:
x_v = x + 1
return x_v
dygraph_to_static_func = wrap_decorator(_dygraph_to_static_func_)
x = np.ones([1, 2])
x_v = func(x)
print(x_v.numpy()) # [[2. 2.]]
"""
def _dygraph_to_static_output_(dygraph_func):
def __impl__(*args, **kwargs):
if in_dygraph_mode():
warnings.warn(
"The decorator 'dygraph_to_static_output' doesn't work in "
"dygraph mode. We will just return dygraph output. Use the "
"decorator in static mode if you would like to translate to "
"static graph.")
program_translator = ProgramTranslator()
if in_dygraph_mode() or not program_translator.enable_declarative:
logger.info(
"The decorator 'declarative' doesn't work in dygraph "
"mode or set enable_declarative_function to False. We will "
"just return dygraph output.")
return dygraph_func(*args, **kwargs)
program_translator = ProgramTranslator()
return program_translator.get_output(dygraph_func, *args, **kwargs)
......@@ -117,7 +168,7 @@ def _dygraph_to_static_output_(dygraph_func):
return __impl__
dygraph_to_static_output = wrap_decorator(_dygraph_to_static_output_)
declarative = wrap_decorator(_declarative_)
@dygraph_only
......
......@@ -20,7 +20,7 @@ from collections import Counter
import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import dygraph_to_static_output
from paddle.fluid.dygraph.jit import declarative
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.dygraph_to_static import convert_function_with_cache
......@@ -50,10 +50,14 @@ class TestCacheProgram(unittest.TestCase):
op.type for op in fluid.default_main_program().block(0).ops
])
if batch_id > 0:
prev_out_numpy = prev_out[0].numpy() if isinstance(
prev_out, tuple) else prev_out.numpy()
cur_out_numpy = cur_out[0].numpy() if isinstance(
cur_out, tuple) else cur_out.numpy()
self.assertTrue(
np.allclose(prev_out[0].numpy(), cur_out[0].numpy()),
np.allclose(prev_out_numpy, cur_out_numpy),
msg='Output in previous batch is {}\n Output in current batch is \n{}'
.format(prev_out, cur_out))
.format(prev_out_numpy, cur_out_numpy))
self.assertEqual(prev_ops, cur_ops)
......@@ -139,7 +143,7 @@ class TestConvertWithCache(unittest.TestCase):
self.assertTrue(id(static_func), id(cached_func))
@dygraph_to_static_output
@declarative
def sum_even_util_limit(max_len, limit):
ret_sum = fluid.dygraph.to_variable(np.zeros((1)).astype('int32'))
for i in range(max_len):
......@@ -152,7 +156,7 @@ def sum_even_util_limit(max_len, limit):
return ret_sum
@dygraph_to_static_output
@declarative
def sum_under_while(limit):
i = fluid.dygraph.to_variable(np.zeros((1)).astype('int32'))
ret_sum = fluid.dygraph.to_variable(np.zeros((1)).astype('int32'))
......@@ -165,10 +169,10 @@ def sum_under_while(limit):
class TestToOutputWithCache(unittest.TestCase):
def test_output(self):
ret = sum_even_util_limit(80, 10)
self.assertEqual(ret[0].numpy(), 30)
self.assertEqual(ret.numpy(), 30)
ret = sum_under_while(100)
self.assertEqual(ret[0].numpy(), 5050)
self.assertEqual(ret.numpy(), 5050)
if __name__ == '__main__':
......
......@@ -14,7 +14,7 @@
from __future__ import print_function
from paddle.fluid.dygraph.jit import dygraph_to_static_output
from paddle.fluid.dygraph.jit import declarative
import numpy as np
import unittest
......@@ -30,7 +30,7 @@ class Pool2D(fluid.dygraph.Layer):
self.pool2d = fluid.dygraph.Pool2D(
pool_size=2, pool_type='avg', pool_stride=1, global_pooling=False)
@dygraph_to_static_output
@declarative
def forward(self, x):
inputs = fluid.dygraph.to_variable(x)
......@@ -54,7 +54,7 @@ class Linear(fluid.dygraph.Layer):
bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.5)))
@dygraph_to_static_output
@declarative
def forward(self, x):
inputs = fluid.dygraph.to_variable(x)
pre = self.fc(inputs)
......@@ -82,7 +82,9 @@ class TestPool2D(unittest.TestCase):
with fluid.program_guard(main_prog, startup_prog):
dy_layer = self.dygraph_class()
out = dy_layer(x=self.data)
return out[0].numpy()
if isinstance(out, tuple):
return out[0].numpy()
return out.numpy()
def test_static_output(self):
dygraph_res = self.run_dygraph_mode()
......
......@@ -17,15 +17,38 @@ from __future__ import print_function
import astor
import gast
import inspect
import numpy as np
import textwrap
import unittest
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.jit import dygraph_to_static_code
from paddle.fluid.dygraph.jit import declarative
from paddle.fluid.dygraph.nn import Linear
from ifelse_simple_func import dyfunc_with_if_else
np.random.seed(0)
def simple_func(x, weight_numpy):
weight_initalizer = fluid.initializer.NumpyArrayInitializer(weight_numpy)
linear = Linear(32, 64, param_attr=weight_initalizer, bias_attr=False)
x = fluid.dygraph.to_variable(x)
y = linear(x)
z = linear(x)
return z
@declarative
def decorated_simple_func(x, weight_numpy):
weight_initalizer = fluid.initializer.NumpyArrayInitializer(weight_numpy)
linear = Linear(32, 64, param_attr=weight_initalizer, bias_attr=False)
x = fluid.dygraph.to_variable(x)
y = linear(x)
z = linear(x)
return z
def get_source_code(func):
raw_code = inspect.getsource(func)
......@@ -81,8 +104,9 @@ class TestDygraphToStaticCode(unittest.TestCase):
def test_decorator(self):
x_v = None
program_translator = ProgramTranslator()
code = program_translator.get_code(dyfunc_with_if_else)
answer = get_source_code(StaticCode1.dyfunc_with_if_else)
code = dygraph_to_static_code(dyfunc_with_if_else)(x_v)
self.assertEqual(answer, code)
def test_program_translator(self):
......@@ -92,5 +116,80 @@ class TestDygraphToStaticCode(unittest.TestCase):
self.assertEqual(answer, code)
class TestEnableDeclarative(unittest.TestCase):
def test_enable_disable_get_output(self):
x = np.random.randn(30, 10, 32).astype('float32')
weight = np.random.randn(32, 64).astype('float32')
program_translator = ProgramTranslator()
with fluid.program_guard(fluid.Program(), fluid.Program()):
program_translator.enable_declarative_function(True)
static_output = program_translator.get_output(simple_func, x,
weight)
program_translator.enable_declarative_function(False)
with fluid.dygraph.guard():
dygraph_output = program_translator.get_output(simple_func, x,
weight)
self.assertTrue(
np.allclose(
static_output.numpy(), dygraph_output.numpy(), atol=1e-4))
def test_enable_disable_get_func(self):
x = np.random.randn(30, 10, 32).astype('float32')
weight = np.random.randn(32, 64).astype('float32')
program_translator = ProgramTranslator()
with fluid.program_guard(fluid.Program(), fluid.Program()):
program_translator.enable_declarative_function(True)
static_func = program_translator.get_func(simple_func)
self.assertTrue(callable(static_func))
static_output = static_func(x, weight)
self.assertTrue(isinstance(static_output, fluid.Variable))
program_translator.enable_declarative_function(False)
with fluid.dygraph.guard():
dygraph_func = program_translator.get_func(simple_func)
self.assertTrue(callable(dygraph_func))
dygraph_output = dygraph_func(x, weight)
self.assertTrue(isinstance(dygraph_output, fluid.core.VarBase))
def test_enable_disable_get_program(self):
x = np.random.randn(30, 10, 32).astype('float32')
weight = np.random.randn(32, 64).astype('float32')
program_translator = ProgramTranslator()
with fluid.program_guard(fluid.Program(), fluid.Program()):
program_translator.enable_declarative_function(True)
static_output = program_translator.get_program(simple_func, x,
weight)
self.assertTrue(isinstance(static_output, tuple))
self.assertEqual(len(static_output), 4)
self.assertTrue(isinstance(static_output[0], fluid.Program))
self.assertTrue(isinstance(static_output[1], fluid.Program))
program_translator.enable_declarative_function(False)
with fluid.dygraph.guard():
dygraph_output = program_translator.get_program(simple_func, x,
weight)
self.assertTrue(isinstance(dygraph_output, fluid.core.VarBase))
def test_enable_disable_declarative(self):
x = np.random.randn(30, 10, 32).astype('float32')
weight = np.random.randn(32, 64).astype('float32')
program_translator = ProgramTranslator()
with fluid.program_guard(fluid.Program(), fluid.Program()):
program_translator.enable_declarative_function(True)
static_output = decorated_simple_func(x, weight)
program_translator.enable_declarative_function(False)
with fluid.dygraph.guard():
dygraph_output = decorated_simple_func(x, weight)
self.assertTrue(
np.allclose(
static_output.numpy(), dygraph_output.numpy(), atol=1e-4))
if __name__ == '__main__':
unittest.main()
......@@ -21,7 +21,7 @@ import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.jit import dygraph_to_static_output
from paddle.fluid.dygraph.jit import declarative
np.random.seed(2020)
......@@ -34,7 +34,7 @@ class SimpleFcLayer(fluid.dygraph.Layer):
super(SimpleFcLayer, self).__init__()
self._linear = fluid.dygraph.Linear(fc_size, fc_size)
@dygraph_to_static_output
@declarative
def forward(self, x):
x = fluid.dygraph.to_variable(x)
y = self._linear(x)
......
......@@ -20,7 +20,7 @@ import numpy as np
import paddle.fluid as fluid
import paddle.fluid.framework as framework
from paddle.fluid.dygraph.jit import dygraph_to_static_program
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.nn import Linear
np.random.seed(2020)
......@@ -38,7 +38,6 @@ def simple_func(x, weight_numpy):
return z
@dygraph_to_static_program
def decorated_simple_func(x, weight_numpy):
weight_initalizer = fluid.initializer.NumpyArrayInitializer(weight_numpy)
linear = Linear(32, 64, param_attr=weight_initalizer)
......@@ -55,8 +54,8 @@ class TestDyToStaticSaveLoad(unittest.TestCase):
with fluid.dygraph.guard(place):
dygraph_result = simple_func(x, weight)
main_program, startup_program, inputs, outputs = decorated_simple_func(
x, weight)
main_program, startup_program, inputs, outputs = ProgramTranslator(
).get_program(decorated_simple_func, x, weight)
exe = fluid.Executor(place)
exe.run(startup_program)
fluid.save(main_program, "./test_dy2stat_save_load")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册