You need to sign in or sign up before continuing.
未验证 提交 38eac880 编写于 作者: L liym27 提交者: GitHub

[Cherry-pick Release/2.0] Support recursive call. (#23960)

* [Cherry-pick][dy2static]Support recursive call (#23900)
  * Support recursive call.

  * Remove Redundant decorator to pass the Py35 unittest temporarily.

* [Cherry-pick]Fix bug in convert_call because difference exists between python3 and python2. (#23966 )
上级 1ea43954
...@@ -29,9 +29,13 @@ from .variable_trans_func import * ...@@ -29,9 +29,13 @@ from .variable_trans_func import *
from . import program_translator from . import program_translator
from .program_translator import * from .program_translator import *
from . import convert_call_func
from .convert_call_func import *
__all__ = [] __all__ = []
__all__ += ast_transformer.__all__ __all__ += ast_transformer.__all__
__all__ += loop_transformer.__all__ __all__ += loop_transformer.__all__
__all__ += static_analysis.__all__ __all__ += static_analysis.__all__
__all__ += variable_trans_func.__all__ __all__ += variable_trans_func.__all__
__all__ += program_translator.__all__ __all__ += program_translator.__all__
__all__ += convert_call_func.__all__
...@@ -32,6 +32,7 @@ from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTran ...@@ -32,6 +32,7 @@ from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTran
from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransformer from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransformer
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer
from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer
from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
...@@ -58,7 +59,6 @@ class DygraphToStaticAst(gast.NodeTransformer): ...@@ -58,7 +59,6 @@ class DygraphToStaticAst(gast.NodeTransformer):
self.static_analysis_visitor = StaticAnalysisVisitor(root) self.static_analysis_visitor = StaticAnalysisVisitor(root)
self.static_analysis_root = self.static_analysis_visitor.get_node_wrapper_root( self.static_analysis_root = self.static_analysis_visitor.get_node_wrapper_root(
) )
self.decorate_func_name = None self.decorate_func_name = None
self.arg_name_to_idx = {} self.arg_name_to_idx = {}
self.transfer_from_node_type(self.static_analysis_root) self.transfer_from_node_type(self.static_analysis_root)
...@@ -88,6 +88,9 @@ class DygraphToStaticAst(gast.NodeTransformer): ...@@ -88,6 +88,9 @@ class DygraphToStaticAst(gast.NodeTransformer):
# Transform all if/else statement of Dygraph into Static Graph. # Transform all if/else statement of Dygraph into Static Graph.
IfElseTransformer(node_wrapper).transform() IfElseTransformer(node_wrapper).transform()
# Transform call recursively
CallTransformer(node_wrapper).transform()
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
if self.decorate_func_name is None: if self.decorate_func_name is None:
self.decorate_func_name = node.name self.decorate_func_name = node.name
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api
class CallTransformer(gast.NodeTransformer):
"""
This class transforms function calls into Static Graph Ast.
"""
def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of CallTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
def transform(self):
self.visit(self.root)
def visit_Call(self, node):
self.generic_visit(node)
if is_paddle_api(node):
return node
func_str = ast_to_source_code(node.func).strip()
new_func_str = "fluid.dygraph.dygraph_to_static.convert_call({})".format(
func_str)
new_func_ast = gast.parse(new_func_str).body[0].value
node.func = new_func_ast
return node
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
__all__ = ['convert_call']
import collections
import copy
import functools
import inspect
import pdb
import re
import types
import numpy
import six
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.layers import Layer
DECORATOR_NAMES = ['declarative', 'dygraph_to_static_func']
program_translator = ProgramTranslator()
to_static_func = program_translator.get_func
def is_builtin(func):
if isinstance(func, types.BuiltinFunctionType):
return True
elif func in six.moves.builtins.__dict__.values():
return True
# Other built-in modules
# TODO(liym27): A better way to do this.
elif any(func in m.__dict__.values()
for m in (collections, pdb, copy, inspect, re, six, numpy)):
return True
else:
return False
def is_paddle_func(func):
m = inspect.getmodule(func)
return m is not None and m.__name__.startswith("paddle")
def convert_call(func):
"""
Converts a function call which needs to be transformed to static fucntion.
Args:
func (callable): A callable function or method to convert.
Returns:
Callable: A converted function.
Examples:
.. code-block:: python
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static import convert_call
def dyfunc(x):
if fluid.layers.mean(x) < 0:
x_v = x - 1
else:
x_v = x + 1
return x_v
new_func = convert_call(dyfunc)
x = fluid.layers.fill_constant(shape=[3, 3], value=0, dtype='float64')
x_v = new_func(x)
exe = fluid.Executor(fluid.CPUPlace())
out = exe.run(fetch_list=[x_v])
print(out[0])
# [[1. 1. 1.]
# [1. 1. 1.]
# [1. 1. 1.]]
"""
func_self = None
converted_call = None
if is_builtin(func):
return func
if is_paddle_func(func):
return func
if inspect.isfunction(func):
# TODO(liym27): If func is a lambda function, special conversion is needed.
if func.__name__ == '<lambda>':
return func
try:
if func in func.__globals__.values():
if six.PY3:
source_code = inspect.getsource(func)
if any(decorator in source_code
for decorator in DECORATOR_NAMES):
converted_call = None
else:
converted_call = to_static_func(func)
func_self = getattr(func, '__self__', None)
else:
converted_call = to_static_func(func)
func_self = getattr(func, '__self__', None)
except AttributeError:
# NOTE:
# If func is not in __globals__, it does not need to be transformed
# because it has been transformed before.
converted_call = None
except (IOError, OSError):
# NOTE:
# If func has beed decorated, its source code can not be get
# so that it can not be transformed to static function.
converted_call = None
elif inspect.ismethod(func):
try:
if six.PY3:
source_code = inspect.getsource(func)
if any(decorator in source_code
for decorator in DECORATOR_NAMES):
converted_call = None
else:
converted_call = to_static_func(func)
func_self = getattr(func, '__self__', None)
else:
converted_call = to_static_func(func)
func_self = getattr(func, '__self__', None)
except (IOError, OSError):
# NOTE: func may have beed decorated.
converted_call = None
elif hasattr(func, '__class__') and hasattr(func.__class__, '__call__'):
if hasattr(func, 'forward') and isinstance(func, Layer):
try:
if six.PY3:
source_code = inspect.getsource(func.forward)
if any(decorator in source_code
for decorator in DECORATOR_NAMES):
converted_call = None
else:
forward_func = to_static_func(func.forward)
setattr(func, 'forward', forward_func)
func_self = func
else:
forward_func = to_static_func(func.forward)
setattr(func, 'forward', forward_func)
func_self = func
except Exception:
# NOTE: func.forward may have beed decorated.
func_self = None if func_self else func_self
converted_call = func
else:
try:
call_func = func.__class__.__call__
converted_call = to_static_func(call_func)
func_self = func
except Exception:
# NOTE:
# If `func` is a class which is being initialized, for example `convert_call(Foo)()`,
# it doesn't need to be transformed
func_self = None if func_self else func_self
if converted_call is None:
return func
if func_self:
converted_call = functools.partial(converted_call, func_self)
return converted_call
...@@ -18,14 +18,14 @@ __all__ = ['TracedLayer', 'declarative', 'dygraph_to_static_func'] ...@@ -18,14 +18,14 @@ __all__ = ['TracedLayer', 'declarative', 'dygraph_to_static_func']
import logging import logging
from ..wrapped_decorator import wrap_decorator
from .base import program_desc_tracing_guard, switch_to_static_graph
from .layers import Layer
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.framework import Program, Block, Variable, _dygraph_tracer, dygraph_only, _dygraph_guard, _current_expected_place, in_dygraph_mode
from paddle.fluid.executor import Executor, scope_guard
from paddle.fluid.compiler import CompiledProgram from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.dygraph.base import program_desc_tracing_guard, switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.executor import Executor, scope_guard
from paddle.fluid.framework import Program, Block, Variable, _dygraph_tracer, dygraph_only, _dygraph_guard, _current_expected_place, in_dygraph_mode
from paddle.fluid.wrapped_decorator import wrap_decorator
logger = logging.getLogger("fluid") logger = logging.getLogger("fluid")
......
...@@ -47,7 +47,6 @@ class SubNetWithDict(fluid.dygraph.Layer): ...@@ -47,7 +47,6 @@ class SubNetWithDict(fluid.dygraph.Layer):
bias_attr=False, bias_attr=False,
param_attr=init_weight(0.2)) param_attr=init_weight(0.2))
@dygraph_to_static_func
def forward(self, input, cache=None): def forward(self, input, cache=None):
input = fluid.dygraph.to_variable(input) input = fluid.dygraph.to_variable(input)
......
...@@ -13,16 +13,17 @@ ...@@ -13,16 +13,17 @@
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
import unittest
from time import time from time import time
import numpy as np import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.optimizer import AdamOptimizer
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from paddle.fluid.dygraph.jit import dygraph_to_static_func from paddle.fluid.dygraph.jit import dygraph_to_static_func
from paddle.fluid.dygraph.nn import Conv2D, Linear, Pool2D
import unittest from paddle.fluid.optimizer import AdamOptimizer
class SimpleImgConvPool(fluid.dygraph.Layer): class SimpleImgConvPool(fluid.dygraph.Layer):
...@@ -66,7 +67,6 @@ class SimpleImgConvPool(fluid.dygraph.Layer): ...@@ -66,7 +67,6 @@ class SimpleImgConvPool(fluid.dygraph.Layer):
global_pooling=global_pooling, global_pooling=global_pooling,
use_cudnn=use_cudnn) use_cudnn=use_cudnn)
@dygraph_to_static_func
def forward(self, inputs): def forward(self, inputs):
x = self._conv2d(inputs) x = self._conv2d(inputs)
x = self._pool2d(x) x = self._pool2d(x)
...@@ -105,7 +105,6 @@ class MNIST(fluid.dygraph.Layer): ...@@ -105,7 +105,6 @@ class MNIST(fluid.dygraph.Layer):
else: else:
return x return x
@dygraph_to_static_func
def inference(self, inputs): def inference(self, inputs):
x = self._simple_img_conv_pool_1(inputs) x = self._simple_img_conv_pool_1(inputs)
x = self._simple_img_conv_pool_2(x) x = self._simple_img_conv_pool_2(x)
......
...@@ -69,8 +69,10 @@ class StaticCode1(): ...@@ -69,8 +69,10 @@ class StaticCode1():
return x_v return x_v
x_v = fluid.layers.cond( x_v = fluid.layers.cond(
fluid.layers.mean(x_v)[0] > 5, lambda: true_fn_0(x_v), fluid.layers.mean(x_v)[0] > 5,
lambda: false_fn_0(x_v)) lambda: fluid.dygraph.dygraph_to_static.convert_call(true_fn_0)(x_v),
lambda: fluid.dygraph.dygraph_to_static.convert_call(false_fn_0)(x_v)
)
if label is not None: if label is not None:
loss = fluid.layers.cross_entropy(x_v, label) loss = fluid.layers.cross_entropy(x_v, label)
return loss return loss
...@@ -88,9 +90,10 @@ class StaticCode2(): ...@@ -88,9 +90,10 @@ class StaticCode2():
return x_v return x_v
x_v = fluid.layers.cond( x_v = fluid.layers.cond(
fluid.layers.mean(x_v)[0] > 5, lambda: true_fn_1(x_v), fluid.layers.mean(x_v)[0] > 5,
lambda: false_fn_1(x_v)) lambda: fluid.dygraph.dygraph_to_static.convert_call(true_fn_1)(x_v),
lambda: fluid.dygraph.dygraph_to_static.convert_call(false_fn_1)(x_v)
)
if label is not None: if label is not None:
loss = fluid.layers.cross_entropy(x_v, label) loss = fluid.layers.cross_entropy(x_v, label)
return loss return loss
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import dygraph_to_static_func
SEED = 2020
np.random.seed(SEED)
# Use a decorator to test exception
@dygraph_to_static_func
def dyfunc_with_if(x_v):
if fluid.layers.mean(x_v).numpy()[0] > 5:
x_v = x_v - 1
else:
x_v = x_v + 1
return x_v
@dygraph_to_static_func
def nested_func(x_v):
x_v = fluid.dygraph.to_variable(x_v)
def fn1():
return x_v
res = fn1()
return res
class TestRecursiveCall1(unittest.TestCase):
def setUp(self):
self.input = np.random.random([10, 16]).astype('float32')
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
self.init_test_func()
def init_test_func(self):
self.dyfunc = nested_func
def get_dygraph_output(self):
with fluid.dygraph.guard():
res = self.dyfunc(self.input).numpy()
return res
def get_static_output(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
static_out = self.dyfunc(self.input)
exe = fluid.Executor(self.place)
static_res = exe.run(main_program, fetch_list=static_out)
return static_res[0]
def test_transformed_static_result(self):
static_res = self.get_static_output()
dygraph_res = self.get_dygraph_output()
self.assertTrue(
np.allclose(dygraph_res, static_res),
msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res,
static_res))
lambda_fun = lambda x: x
class MyConvLayer(fluid.dygraph.Layer):
def __init__(self):
super(MyConvLayer, self).__init__()
self._conv = fluid.dygraph.Conv2D(
num_channels=3,
num_filters=2,
filter_size=3,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.99)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.5)))
@dygraph_to_static_func
def forward(self, inputs):
y = dyfunc_with_if(inputs)
y = lambda_fun(y)
y = self.dymethod(y)
return y
@dygraph_to_static_func
def dymethod(self, x_v):
x_v = fluid.layers.assign(x_v)
return x_v
class MyLayer(fluid.dygraph.Layer):
def __init__(self):
super(MyLayer, self).__init__()
self.conv = MyConvLayer()
self.fc = fluid.dygraph.Linear(
input_dim=5,
output_dim=1,
act='relu',
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.99)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.5)))
@dygraph_to_static_func
def forward(self, inputs):
h = self.conv(inputs)
out = self.fc(h)
return out
class TestRecursiveCall2(unittest.TestCase):
def setUp(self):
self.input = np.random.random((1, 3, 3, 5)).astype('float32')
self.Layer = MyLayer
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
def get_dygraph_output(self):
with fluid.dygraph.guard():
self.dygraph_func = self.Layer()
fluid.default_startup_program.random_seed = SEED
fluid.default_main_program.random_seed = SEED
data = fluid.dygraph.to_variable(self.input)
res = self.dygraph_func(data)
return res.numpy()
def get_static_output(self):
startup_program = fluid.Program()
startup_program.random_seed = SEED
main_program = fluid.Program()
main_program.random_seed = SEED
with fluid.program_guard(main_program, startup_program):
self.dygraph_func = self.Layer()
data = fluid.layers.assign(self.input)
static_out = self.dygraph_func(data)
exe = fluid.Executor(self.place)
exe.run(startup_program)
static_res = exe.run(main_program, fetch_list=static_out)
return static_res[0]
def test_transformed_static_result(self):
dygraph_res = self.get_dygraph_output()
static_res = self.get_static_output()
self.assertTrue(
np.allclose(dygraph_res, static_res),
msg='dygraph is {}\n static_res is \n{}'.format(dygraph_res,
static_res))
if __name__ == '__main__':
unittest.main()
...@@ -12,30 +12,18 @@ ...@@ -12,30 +12,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function from __future__ import print_function
import math
import time
import unittest
import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.jit import dygraph_to_static_func from paddle.fluid.dygraph.jit import dygraph_to_static_func
from paddle.fluid.dygraph.nn import BatchNorm, Conv2D, Linear, Pool2D
import unittest
import time
import math
import numpy as np
IMAGENET1000 = 1281167 IMAGENET1000 = 1281167
base_lr = 0.1 base_lr = 0.1
...@@ -93,7 +81,6 @@ class ConvBNLayer(fluid.dygraph.Layer): ...@@ -93,7 +81,6 @@ class ConvBNLayer(fluid.dygraph.Layer):
self._batch_norm = BatchNorm(num_filters, act=act) self._batch_norm = BatchNorm(num_filters, act=act)
@dygraph_to_static_func
def forward(self, inputs): def forward(self, inputs):
y = self._conv(inputs) y = self._conv(inputs)
y = self._batch_norm(y) y = self._batch_norm(y)
...@@ -133,7 +120,6 @@ class BottleneckBlock(fluid.dygraph.Layer): ...@@ -133,7 +120,6 @@ class BottleneckBlock(fluid.dygraph.Layer):
self._num_channels_out = num_filters * 4 self._num_channels_out = num_filters * 4
@dygraph_to_static_func
def forward(self, inputs): def forward(self, inputs):
y = self.conv0(inputs) y = self.conv0(inputs)
conv1 = self.conv1(y) conv1 = self.conv1(y)
......
...@@ -14,15 +14,16 @@ ...@@ -14,15 +14,16 @@
import logging import logging
import math import math
import numpy as np
import time import time
import unittest import unittest
import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph.jit import dygraph_to_static_func from paddle.fluid.dygraph.jit import dygraph_to_static_func
from paddle.fluid.dygraph.nn import BatchNorm, Conv2D, Linear, Pool2D
SEED = 2020 SEED = 2020
np.random.seed(SEED) np.random.seed(SEED)
...@@ -98,7 +99,6 @@ class ConvBNLayer(fluid.dygraph.Layer): ...@@ -98,7 +99,6 @@ class ConvBNLayer(fluid.dygraph.Layer):
self._batch_norm = BatchNorm(num_filters, act=act) self._batch_norm = BatchNorm(num_filters, act=act)
@dygraph_to_static_func
def forward(self, inputs): def forward(self, inputs):
y = self._conv(inputs) y = self._conv(inputs)
y = self._batch_norm(y) y = self._batch_norm(y)
...@@ -127,7 +127,6 @@ class SqueezeExcitation(fluid.dygraph.Layer): ...@@ -127,7 +127,6 @@ class SqueezeExcitation(fluid.dygraph.Layer):
initializer=fluid.initializer.Uniform(-stdv, stdv)), initializer=fluid.initializer.Uniform(-stdv, stdv)),
act='sigmoid') act='sigmoid')
@dygraph_to_static_func
def forward(self, input): def forward(self, input):
y = self._pool(input) y = self._pool(input)
y = fluid.layers.reshape(y, shape=[-1, self._num_channels]) y = fluid.layers.reshape(y, shape=[-1, self._num_channels])
...@@ -179,7 +178,6 @@ class BottleneckBlock(fluid.dygraph.Layer): ...@@ -179,7 +178,6 @@ class BottleneckBlock(fluid.dygraph.Layer):
self._num_channels_out = num_filters * 2 self._num_channels_out = num_filters * 2
@dygraph_to_static_func
def forward(self, inputs): def forward(self, inputs):
y = self.conv0(inputs) y = self.conv0(inputs)
conv1 = self.conv1(y) conv1 = self.conv1(y)
...@@ -301,6 +299,7 @@ class SeResNeXt(fluid.dygraph.Layer): ...@@ -301,6 +299,7 @@ class SeResNeXt(fluid.dygraph.Layer):
for bottleneck_block in self.bottleneck_block_list: for bottleneck_block in self.bottleneck_block_list:
y = bottleneck_block(y) y = bottleneck_block(y)
y = self.pool2d_avg(y) y = self.pool2d_avg(y)
y = fluid.layers.dropout(y, dropout_prob=0.5, seed=100) y = fluid.layers.dropout(y, dropout_prob=0.5, seed=100)
y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_output]) y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_output])
......
...@@ -13,17 +13,15 @@ ...@@ -13,17 +13,15 @@
# limitations under the License. # limitations under the License.
import logging import logging
import numpy as np
import time
import os import os
import time
import unittest import unittest
import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import transformer_util as util import transformer_util as util
from transformer_dygraph_model import position_encoding_init from transformer_dygraph_model import CrossEntropyCriterion, Transformer, position_encoding_init
from transformer_dygraph_model import Transformer
from transformer_dygraph_model import CrossEntropyCriterion
trainer_count = 1 trainer_count = 1
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace( place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace(
......
...@@ -18,10 +18,9 @@ import numpy as np ...@@ -18,10 +18,9 @@ import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, Layer, to_variable from paddle.fluid.dygraph import Embedding, Layer, LayerNorm, Linear, to_variable
from paddle.fluid.dygraph.jit import dygraph_to_static_func from paddle.fluid.dygraph.jit import dygraph_to_static_func
from paddle.fluid.layers.utils import map_structure from paddle.fluid.layers.utils import map_structure
from paddle.fluid.framework import Program, Block, Variable, _dygraph_tracer, dygraph_only, _dygraph_guard, _current_expected_place, in_dygraph_mode
def position_encoding_init(n_position, d_pos_vec): def position_encoding_init(n_position, d_pos_vec):
...@@ -67,7 +66,6 @@ class PrePostProcessLayer(Layer): ...@@ -67,7 +66,6 @@ class PrePostProcessLayer(Layer):
self.functors.append(lambda x: layers.dropout( self.functors.append(lambda x: layers.dropout(
x, dropout_prob=dropout_rate, is_test=False)) x, dropout_prob=dropout_rate, is_test=False))
@dygraph_to_static_func
def forward(self, x, residual=None): def forward(self, x, residual=None):
for i, cmd in enumerate(self.process_cmd): for i, cmd in enumerate(self.process_cmd):
if cmd == "a": if cmd == "a":
...@@ -94,7 +92,6 @@ class MultiHeadAttention(Layer): ...@@ -94,7 +92,6 @@ class MultiHeadAttention(Layer):
self.proj_fc = Linear( self.proj_fc = Linear(
input_dim=d_value * n_head, output_dim=d_model, bias_attr=False) input_dim=d_value * n_head, output_dim=d_model, bias_attr=False)
@dygraph_to_static_func
def forward(self, queries, keys, values, attn_bias, cache=None): def forward(self, queries, keys, values, attn_bias, cache=None):
# compute q ,k ,v # compute q ,k ,v
keys = queries if keys is None else keys keys = queries if keys is None else keys
...@@ -138,7 +135,6 @@ class FFN(Layer): ...@@ -138,7 +135,6 @@ class FFN(Layer):
self.fc1 = Linear(input_dim=d_model, output_dim=d_inner_hid, act="relu") self.fc1 = Linear(input_dim=d_model, output_dim=d_inner_hid, act="relu")
self.fc2 = Linear(input_dim=d_inner_hid, output_dim=d_model) self.fc2 = Linear(input_dim=d_inner_hid, output_dim=d_model)
@dygraph_to_static_func
def forward(self, x): def forward(self, x):
hidden = self.fc1(x) hidden = self.fc1(x)
if self.dropout_rate: if self.dropout_rate:
...@@ -176,7 +172,6 @@ class EncoderLayer(Layer): ...@@ -176,7 +172,6 @@ class EncoderLayer(Layer):
self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model, self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
prepostprocess_dropout) prepostprocess_dropout)
@dygraph_to_static_func
def forward(self, enc_input, attn_bias): def forward(self, enc_input, attn_bias):
attn_output = self.self_attn( attn_output = self.self_attn(
self.preprocesser1(enc_input), None, None, attn_bias) self.preprocesser1(enc_input), None, None, attn_bias)
...@@ -214,7 +209,6 @@ class Encoder(Layer): ...@@ -214,7 +209,6 @@ class Encoder(Layer):
self.processer = PrePostProcessLayer(preprocess_cmd, d_model, self.processer = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout) prepostprocess_dropout)
@dygraph_to_static_func
def forward(self, enc_input, attn_bias): def forward(self, enc_input, attn_bias):
for encoder_layer in self.encoder_layers: for encoder_layer in self.encoder_layers:
enc_output = encoder_layer(enc_input, attn_bias) enc_output = encoder_layer(enc_input, attn_bias)
...@@ -232,7 +226,6 @@ class Embedder(Layer): ...@@ -232,7 +226,6 @@ class Embedder(Layer):
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Normal(0., emb_dim**-0.5))) initializer=fluid.initializer.Normal(0., emb_dim**-0.5)))
@dygraph_to_static_func
def forward(self, word): def forward(self, word):
word_emb = self.word_embedder(word) word_emb = self.word_embedder(word)
return word_emb return word_emb
...@@ -258,7 +251,6 @@ class WrapEncoder(Layer): ...@@ -258,7 +251,6 @@ class WrapEncoder(Layer):
attention_dropout, relu_dropout, preprocess_cmd, attention_dropout, relu_dropout, preprocess_cmd,
postprocess_cmd) postprocess_cmd)
@dygraph_to_static_func
def forward(self, src_word, src_pos, src_slf_attn_bias): def forward(self, src_word, src_pos, src_slf_attn_bias):
word_emb = self.word_embedder(src_word) word_emb = self.word_embedder(src_word)
word_emb = layers.scale(x=word_emb, scale=self.emb_dim**0.5) word_emb = layers.scale(x=word_emb, scale=self.emb_dim**0.5)
...@@ -304,7 +296,6 @@ class DecoderLayer(Layer): ...@@ -304,7 +296,6 @@ class DecoderLayer(Layer):
self.postprocesser3 = PrePostProcessLayer(postprocess_cmd, d_model, self.postprocesser3 = PrePostProcessLayer(postprocess_cmd, d_model,
prepostprocess_dropout) prepostprocess_dropout)
@dygraph_to_static_func
def forward(self, def forward(self,
dec_input, dec_input,
enc_output, enc_output,
...@@ -342,7 +333,6 @@ class Decoder(Layer): ...@@ -342,7 +333,6 @@ class Decoder(Layer):
self.processer = PrePostProcessLayer(preprocess_cmd, d_model, self.processer = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout) prepostprocess_dropout)
@dygraph_to_static_func
def forward(self, def forward(self,
dec_input, dec_input,
enc_output, enc_output,
...@@ -386,7 +376,6 @@ class WrapDecoder(Layer): ...@@ -386,7 +376,6 @@ class WrapDecoder(Layer):
self.linear = Linear( self.linear = Linear(
input_dim=d_model, output_dim=trg_vocab_size, bias_attr=False) input_dim=d_model, output_dim=trg_vocab_size, bias_attr=False)
@dygraph_to_static_func
def forward(self, def forward(self,
trg_word, trg_word,
trg_pos, trg_pos,
...@@ -415,7 +404,6 @@ class CrossEntropyCriterion(object): ...@@ -415,7 +404,6 @@ class CrossEntropyCriterion(object):
def __init__(self, label_smooth_eps): def __init__(self, label_smooth_eps):
self.label_smooth_eps = label_smooth_eps self.label_smooth_eps = label_smooth_eps
@dygraph_to_static_func
def __call__(self, predict, label, weights): def __call__(self, predict, label, weights):
if self.label_smooth_eps: if self.label_smooth_eps:
label_out = layers.label_smooth( label_out = layers.label_smooth(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册