Unverified commit c1913a5f authored by xiongkun, committed by GitHub

[dy2static] PaddleSOT pr (#54202)

* add paddle-symbolic-trace to paddle

* add symbolic trace

* delete swp

* support Layer in symbolic trace

* fix test-symbolic-trace, make symbolic trace return a StaticFunction

* template the error message

* fix some unittest

* Modify the execution mode of test

* Modify the module name

* add dy2static unittest decorator

* change some unittest files by @ast_only_test

* fix unittest.

* test-symbolic-trace

* update test_write_python_container.py

* update

* fix test_param_parse.py

* add submodule and ln -sf in cmakefile

* update

* update

* fix some ast only errors

* update

* Polish ut

* fix unittests

* update

* update

* fix unittests

* update

* test warning ast only

* update

* Ast only some uts

* Fix unittests

* test_error ast only

* update

* update

* Support build_strategy for sot

* update

* import sot as a third party module

* update

* update

* Polish code

* update

* update

* update

* update

* update

* remove old fluid api and use paddle.nn.relu instead

* fix

* comment the print of ast code

* add try-finally block

* fix dy2static stop-gradient bugs

* fix code

* remove unused submodule and minor codestyle fix

* fix

* fix cast error

* fix interpolate meets int64 in static model

* add evalframe support for py311

* fix

* fix err

* switch ENABLE_FALL_BACK=False

* fix

* Fix CI for some unittest

* add ENABLE_SOT

* remove setup.py dependences

---------
Co-authored-by: NotHaozi <zhangmenghao@baidu.com>
Co-authored-by: feifei-111 <2364819892@qq.com>
Co-authored-by: 0x45f <wangzhen45@baidu.com>
Co-authored-by: SigureMo <sigure.qaq@gmail.com>
Parent fc177d25
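For orientation before the diff, a minimal sketch of the entry point this PR introduces, based on the `enable_fallback` argument added to `paddle.jit.to_static` below (when it is left as None, the choice falls back to the `ENABLE_FALL_BACK` environment variable):

import paddle

def add_one(x):
    return x + 1

# AST transcription, the pre-existing dy2static path (ASTStaticFunction).
ast_fn = paddle.jit.to_static(add_one, enable_fallback=False)

# Symbolic bytecode translation via PaddleSOT (SymbolicStaticFunction).
sot_fn = paddle.jit.to_static(add_one, enable_fallback=True)

x = paddle.to_tensor([1.0])
print(ast_fn(x), sot_fn(x))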
@@ -254,7 +254,10 @@ static PyObject *_custom_eval_frame(PyThreadState *tstate,
  if (disable_eval_frame != Py_True) {
    // Re-enable custom behavior
    eval_frame_callback_set(callback);
+    VLOG(7) << "Start eval new frame and code.";
    auto out = eval_custom_code(tstate, frame, code, throw_flag);
+    Py_DECREF(result);
+    Py_DECREF(code);
    return out;
  } else {
    auto out = eval_custom_code(tstate, frame, code, throw_flag);
......
@@ -106,15 +106,17 @@ def program_desc_tracing_guard(enable):
def param_guard(parameters):
    # Note: parameters is a reference of self._parameters or self._buffers
    if in_declarative_mode() and not paddle.in_dynamic_mode() and parameters:
-        origin_parameters = parameters.copy()
-        for name, var_base in parameters.items():
-            if isinstance(var_base, list):
-                new_var = [_convert_into_variable(var) for var in var_base]
-            else:
-                new_var = _convert_into_variable(var_base)
-            parameters[name] = new_var
-        yield
-        parameters.update(origin_parameters)
+        try:
+            origin_parameters = parameters.copy()
+            for name, var_base in parameters.items():
+                if isinstance(var_base, list):
+                    new_var = [_convert_into_variable(var) for var in var_base]
+                else:
+                    new_var = _convert_into_variable(var_base)
+                parameters[name] = new_var
+            yield
+        finally:
+            parameters.update(origin_parameters)
    else:
        yield
......
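The change above moves the restoration of the original parameters into a `finally` block, so the guard rolls back even when the code run under it raises. A minimal standalone sketch of the same pattern (the `swap_guard` helper is hypothetical, not part of this diff):

import contextlib

@contextlib.contextmanager
def swap_guard(d, new_items):
    # Restore the original mapping even if the with-body raises,
    # mirroring the try/finally added to param_guard above.
    original = d.copy()
    try:
        d.update(new_items)
        yield d
    finally:
        d.clear()
        d.update(original)

cfg = {"mode": "dygraph"}
try:
    with swap_guard(cfg, {"mode": "static"}):
        raise RuntimeError("tracing failed")
except RuntimeError:
    pass
assert cfg == {"mode": "dygraph"}  # rolled back despite the error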
@@ -25,6 +25,7 @@ from collections import OrderedDict
import inspect
import threading
from typing import Any
+import types

import paddle
from paddle.fluid import core, dygraph
@@ -46,6 +47,8 @@ from .dy2static.convert_call_func import (
from .dy2static.program_translator import (
    ProgramTranslator,
    StaticFunction,
+    ASTStaticFunction,
+    SymbolicStaticFunction,
    unwrap_decorators,
)
from paddle.jit.translated_layer import (
@@ -232,6 +235,7 @@ def to_static(
    input_spec=None,
    build_strategy=None,
    backend=None,
+    enable_fallback=None,
    **kwargs,
):
    """
@@ -283,15 +287,29 @@ def to_static(
    def decorated(python_func):
        """
-        Decorates a python function into a StaticFunction object.
+        Decorates a python function into an ASTStaticFunction object.
        """
+        nonlocal enable_fallback
+        if enable_fallback is None:
+            flag = os.environ.get("ENABLE_FALL_BACK", None)
+            if flag == "True":
+                enable_fallback = True
+            else:  # None or "False"
+                enable_fallback = False
+
+        StaticClass = {
+            True: SymbolicStaticFunction,
+            False: ASTStaticFunction,
+        }[enable_fallback]
+
        # Step 1. unwrap the function if it is already decorated.
        _, python_func = unwrap_decorators(python_func)

        # Step 2. copy some attributes from original python function.
        static_layer = copy_decorator_attrs(
            original_func=python_func,
-            decorated_obj=StaticFunction(
+            decorated_obj=StaticClass(
                function=python_func,
                input_spec=input_spec,
                build_strategy=build_strategy,
@@ -1033,7 +1051,9 @@ def save(layer, path, input_spec=None, **configs):
    concrete_program = None
    for attr_func in functions:
        if isinstance(layer, Layer):
-            static_func = getattr(inner_layer, attr_func, None)
+            static_func = get_ast_static_function(
+                getattr(inner_layer, attr_func, None)
+            )
            if isinstance(static_func, StaticFunction):
                if static_func.is_property:
                    # property method to be exported
@@ -1066,7 +1086,9 @@ def save(layer, path, input_spec=None, **configs):
                    input_spec, inner_input_spec
                )
                static_forward = to_static(
-                    inner_layer.forward, input_spec=inner_input_spec
+                    inner_layer.forward,
+                    input_spec=inner_input_spec,
+                    enable_fallback=False,
                )
                concrete_program = (
                    static_forward.concrete_program_specify_input_spec(
@@ -1082,24 +1104,29 @@ def save(layer, path, input_spec=None, **configs):
        else:
            # When layer is a function
            if isinstance(attr_func, StaticFunction):
-                if attr_func.is_property:
+                static_func = get_ast_static_function(attr_func)
+                if static_func.is_property:
                    # property method to be exported
-                    immediate_val = attr_func()
-                    property_vals.append((immediate_val, attr_func))
+                    immediate_val = static_func()
+                    property_vals.append((immediate_val, static_func))
                    continue

                concrete_program = (
-                    attr_func.concrete_program_specify_input_spec(
+                    static_func.concrete_program_specify_input_spec(
                        inner_input_spec, is_prim_infer=is_prim_infer
                    )
                )
            else:
+                static_func = get_ast_static_function(attr_func)
                if inner_input_spec:
                    inner_input_spec = paddle.utils.pack_sequence_as(
                        input_spec, inner_input_spec
                    )
                static_function = to_static(
-                    attr_func, input_spec=inner_input_spec
+                    static_func,
+                    input_spec=inner_input_spec,
+                    enable_fallback=False,
                )
                concrete_program = static_function.concrete_program
@@ -1115,9 +1142,9 @@ def save(layer, path, input_spec=None, **configs):
        if isinstance(inner_layer, Layer):
            dygraph_state_dict = inner_layer.to_static_state_dict()
        elif isinstance(attr_func, StaticFunction):
-            if attr_func._class_instance:
+            if static_func._class_instance:
                dygraph_state_dict = (
-                    attr_func._class_instance.to_static_state_dict()
+                    static_func._class_instance.to_static_state_dict()
                )

        if dygraph_state_dict:
@@ -1887,3 +1914,29 @@ class TracedLayer:
            clip_extra=clip_extra,
            legacy_format=legacy_format,
        )
def get_ast_static_function(function):
if isinstance(function, SymbolicStaticFunction):
if function._class_instance:
dygraph_function = types.MethodType(
function._dygraph_function, function._class_instance
)
else:
dygraph_function = function._dygraph_function
if function._function_spec._input_spec is None:
ast_static_function = ASTStaticFunction(
dygraph_function,
function.last_call_input_spec,
**function._kwargs,
)
return ast_static_function
else:
ast_static_function = ASTStaticFunction(
dygraph_function,
function._function_spec._input_spec,
**function._kwargs,
)
return ast_static_function
return function
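A hedged sketch of how `jit.save` uses this new helper: a SOT-decorated function is rebuilt as an ASTStaticFunction, reusing the input spec captured on its last call when none was supplied (the import path is assumed from this file):

import paddle
from paddle.jit.api import get_ast_static_function  # path assumed from this diff

def add_one(x):
    return x + 1

fn = paddle.jit.to_static(add_one, enable_fallback=True)
fn(paddle.to_tensor([1.0]))  # records last_call_input_spec

ast_fn = get_ast_static_function(fn)  # an ASTStaticFunction, usable by jit.save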
@@ -231,6 +231,7 @@ class PartialProgramLayer:
            self._cuda_graph_vec,
            *attrs
        )
+        self._update_stop_gradient(out_vars)
        restored_nest_out = self._restore_out(out_vars)
        return self._remove_no_value(restored_nest_out)
@@ -960,6 +961,17 @@ class PartialProgramLayer:
            var.stop_gradient = True
        return var

+    def _update_stop_gradient(self, out_vars):
+        # Update stop_gradient for all outputs
+        def set_stop_gradient(var_id, eager_tensor):
+            var = self._outputs[var_id]
+            assert isinstance(var, framework.Variable)
+            eager_tensor.stop_gradient = var.stop_gradient
+            return None
+
+        for idx, var in zip(self._outputs.var_ids, out_vars):
+            set_stop_gradient(idx, var)
+
    def _restore_out(self, out_vars):
        """
        Restores same nested outputs by only replacing the Variable with Tensor.
......
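The `_update_stop_gradient` hook above copies `stop_gradient` from the traced program's output Variables onto the eager tensors handed back to the caller. A sketch of the behavior this is meant to guarantee (expected output per the fix, not verified here):

import paddle

@paddle.jit.to_static(enable_fallback=False)
def forward(x):
    y = x * 2
    y.stop_gradient = True  # marked on the static-graph Variable
    return y

x = paddle.to_tensor([1.0], stop_gradient=False)
out = forward(x)
print(out.stop_gradient)  # expected: True, synced from the Variable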
@@ -309,11 +309,6 @@ def unwrap_decorators(func):
class StaticFunction:
-    """
-    Wrapper class to Manage program conversion of decorated function.
-    """
-
    def __init__(self, function, input_spec=None, **kwargs):
        """
        Initializes a `StaticFunction`.
@@ -364,7 +359,6 @@ class StaticFunction:
        self._training = True
        self._cuda_graph_capture_mode = ""
        self._cuda_graph_pool_id = 0
-
        self._property = kwargs.get("property", False)

    @property
@@ -473,42 +467,7 @@ class StaticFunction:
                )
            )

-        # 2. trace ops from dygraph layers and cache the generated program.
-        args, kwargs = self._function_spec.unified_args_and_kwargs(args, kwargs)
-
-        try:
-            concrete_program, partial_program_layer = self.get_concrete_program(
-                *args, **kwargs, is_train=self._is_train_mode()
-            )
-            # 3. synchronize self.training attribute.
-            if isinstance(self._class_instance, layers.Layer):
-                partial_program_layer.training = self._class_instance.training
-            else:
-                partial_program_layer.training = self._training
-
-            partial_program_layer._cuda_graph_capture_mode = (
-                self._cuda_graph_capture_mode
-            )
-            partial_program_layer._cuda_graph_pool_id = self._cuda_graph_pool_id
-
-            # 4. return outputs.
-            try:
-                return partial_program_layer(args)
-            except Exception as e:
-                if not hasattr(e, error.ERROR_DATA):
-                    # runtime error
-                    error.attach_error_data(e, in_runtime=True)
-                raise
-        except Exception as e:
-            error_data = getattr(e, error.ERROR_DATA, None)
-            if error_data:
-                error_data.raise_new_exception()
-            else:
-                logging_utils.warn(
-                    "Please file an issue at 'https://github.com/PaddlePaddle/Paddle/issues'"
-                    " if you can't handle this {} yourself.".format(type(e))
-                )
-                raise e
+        return self._perform_call(*args, **kwargs)

    def _is_train_mode(self):
        if self._class_instance is not None:
@@ -543,6 +502,298 @@ class StaticFunction:
        if self.is_property:
            raise RuntimeError("Can not call the func when property=True.")
def get_concrete_program(self, *args, **kwargs):
raise NotImplementedError("Not implemented yet.")
def get_concrete_program_with_cache_key(self, cached_key):
raise NotImplementedError("Not implemented yet.")
def get_traced_count(self):
raise NotImplementedError("Not implemented yet.")
@property
def code(self):
raise NotImplementedError("Not implemented yet.")
@property
def dygraph_function(self):
"""
Returns the original decorated function.
"""
if self._class_instance is not None:
return self._dygraph_function.__get__(self._class_instance)
else:
return self._dygraph_function
@property
def concrete_program(self):
raise NotImplementedError("Not implemented yet.")
def concrete_program_specify_input_spec(
self, input_spec=None, with_hook=False, is_prim_infer=False
):
raise NotImplementedError("Not implemented yet.")
def rollback(self):
"""
Rollback into original dygraph functions for current class instance.
Returns:
Function or Method
Example::
.. code-block:: python
>>> # doctest: +SKIP
>>> import paddle
>>> class Net(paddle.nn.Layer):
... def __init__(self):
... super().__init__()
...
... def forward(self, x, flag=True):
... if flag:
... out = x + 1
... else:
... out = x - 1
... return out
...
>>> x = paddle.randn([10, 1], 'float32')
>>> net = paddle.jit.to_static(Net()) # convert into static graph mode
>>> out = net(x)
>>> net.forward.rollback() # rollback into dygraph mode
>>> out = net(x)
"""
def rollback_impl(class_instance):
for name, func in class_instance._original_funcs.items():
setattr(class_instance, name, func.__get__(class_instance))
for sublayer in class_instance.sublayers(include_self=False):
rollback_impl(sublayer)
if self._class_instance is None:
return self._dygraph_function
# only rollback sub-functions on path of top _dygraph_function
func_name = self._dygraph_function.__name__
assert (
func_name in self._class_instance._original_funcs
), "Not Found function '{}' in class '{}'.".format(
func_name, self._class_instance.__name__
)
func = self._class_instance._original_funcs[func_name]
setattr(
self._class_instance, func_name, func.__get__(self._class_instance)
)
for sublayer in self._class_instance.sublayers(include_self=False):
rollback_impl(sublayer)
return getattr(self._class_instance, func_name)
def __deepcopy__(self, memo):
"""
Customized behavior for copy.deepcopy, return original decorated function instead
of a new StaticFunction Object. StaticFunction itself is not copyable because it's
associated with class_instance.
We add __deepcopy__ here only for the following usage:
Example::
.. code-block:: python
>>> import copy
>>> import paddle
>>> class Net(paddle.nn.Layer):
... def __init__(self):
... super().__init__()
...
... def forward(self, x, flag=True):
... if flag:
... out = x + 1
... else:
... out = x - 1
... return out
...
>>> x = paddle.randn([10, 1], 'float32')
>>> net = paddle.jit.to_static(Net()) # convert into static graph mode
>>> copy_net = copy.deepcopy(net) # deepcopy a new net without @to_static
Please note that the original 'net' will unwrap @to_static and roll back into a simple Layer.
"""
if self._class_instance is not None:
net_name = type(self._class_instance).__name__
logging_utils.log(
level=-1,
msg="Not recommend to deepcopy '{}' decorated with @to_static, it has side effect that will"
" rollback into original state before @to_static. Please deepcopy '{}' before applying @to_static.".format(
net_name, net_name
),
)
self.rollback()
return self._dygraph_function.__get__(
memo[id(self._class_instance)]
)
else:
return self._dygraph_function
@property
def inputs(self):
raise NotImplementedError("Not implemented yet.")
@property
def outputs(self):
raise NotImplementedError("Not implemented yet.")
@property
def main_program(self):
raise NotImplementedError("Not implemented yet.")
@property
def program_cache(self):
raise NotImplementedError("Not implemented yet.")
@property
def function_spec(self):
raise NotImplementedError("Not implemented yet.")
def raise_error_template(func_str):
def _raise_error(*args, **kwargs):
error_template = (
"Can't call {func} when enable_fallback=True."
"Use paddle.jit.to_static(enable_fallback=False) instead."
)
raise RuntimeError(error_template.format(func=func_str))
return _raise_error
class SymbolicStaticFunction(StaticFunction):
def __init__(self, function, input_spec=None, **kwargs):
if input_spec is not None:
warnings.warn(
"\nSymbolic Trace don't support input_spec arguments. It will Will not produce any effect.\n"
"1. You can disable fallback mode by `paddle.jit.to_static(enable_fallback=False)` to switch to AST to static, then you can assign input spec.\n"
)
super().__init__(function, input_spec, **kwargs)
self.last_call_input_spec = None
def _perform_call(self, *args, **kwargs):
args, kwargs = self._function_spec.unified_args_and_kwargs(args, kwargs)
(
input_args_with_spec,
input_kwargs_with_spec,
) = self._function_spec.args_to_input_spec(args, kwargs)
self.last_call_input_spec = input_args_with_spec
try:
from sot import symbolic_translate
except:
import os
os.system(
"pip install git+https://github.com/PaddlePaddle/PaddleSOT@develop"
)
from sot import symbolic_translate
build_strategy = self._kwargs.get("build_strategy", None)
traced_fun = symbolic_translate(
self._dygraph_function, build_strategy=build_strategy
)
if self._class_instance is not None:
args = (self._class_instance,) + args
return traced_fun(*args, **kwargs)
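# Illustrative sketch (not part of this diff): _perform_call above delegates
# to PaddleSOT's symbolic_translate, which can also be used directly,
# assuming the external `sot` package is installed:
#
#     from sot import symbolic_translate
#
#     def foo(x):
#         return x * 2 + 1
#
#     traced = symbolic_translate(foo, build_strategy=None)
#     out = traced(paddle.to_tensor([1.0]))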
@property
def code(self):
raise_error_template("code")()
@property
def concrete_program(self):
raise_error_template("concrete_program")()
concrete_program_specify_input_spec = raise_error_template(
"concrete_program_specify_input_spec"
)
get_concrete_program = raise_error_template("get_concrete_program")
get_concrete_program_with_cache_key = raise_error_template(
"get_concrete_program_with_cache_key"
)
get_traced_count = raise_error_template("get_traced_count")
@property
def inputs(self):
raise_error_template("inputs")()
@property
def outputs(self):
raise_error_template("outputs")()
@property
def main_program(self):
raise_error_template("main_program")()
@property
def program_cache(self):
raise_error_template("program_cache")()
@property
def function_spec(self):
raise_error_template("function_spec ")()
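# Illustrative sketch (not part of this diff): in fallback mode the AST-only
# introspection surface raises, per raise_error_template above:
#
#     fn = paddle.jit.to_static(lambda x: x + 1, enable_fallback=True)
#     fn(paddle.to_tensor([1.0]))
#     fn.code  # RuntimeError: Can't call code when enable_fallback=True. ...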
class ASTStaticFunction(StaticFunction):
"""
Wrapper class to Manage program conversion of decorated function.
"""
def __init__(self, function, input_spec=None, **kwargs):
super().__init__(function, input_spec, **kwargs)
def _perform_call(self, *args, **kwargs):
# 1. trace ops from dygraph layers and cache the generated program.
args, kwargs = self._function_spec.unified_args_and_kwargs(args, kwargs)
try:
concrete_program, partial_program_layer = self.get_concrete_program(
*args, **kwargs, is_train=self._is_train_mode()
)
# 2. synchronize self.training attribute.
if isinstance(self._class_instance, layers.Layer):
partial_program_layer.training = self._class_instance.training
else:
partial_program_layer.training = self._training
partial_program_layer._cuda_graph_capture_mode = (
self._cuda_graph_capture_mode
)
partial_program_layer._cuda_graph_pool_id = self._cuda_graph_pool_id
# 3. return outputs.
try:
return partial_program_layer(args)
except Exception as e:
if not hasattr(e, error.ERROR_DATA):
# runtime error
error.attach_error_data(e, in_runtime=True)
raise
except Exception as e:
error_data = getattr(e, error.ERROR_DATA, None)
if error_data:
error_data.raise_new_exception()
else:
logging_utils.warn(
"Please file an issue at 'https://github.com/PaddlePaddle/Paddle/issues'"
" if you can't handle this {} yourself.".format(type(e))
)
raise e
    def get_concrete_program(self, *args, **kwargs):
        """
        Returns traced concrete program and inner executable partial layer.
@@ -629,16 +880,6 @@ class StaticFunction:
        source_code = func_to_source_code(static_func)
        return source_code
@property
def dygraph_function(self):
"""
Returns the original decorated function.
"""
if self._class_instance is not None:
return self._dygraph_function.__get__(self._class_instance)
else:
return self._dygraph_function
    @property
    def concrete_program(self):
        """
@@ -757,113 +998,6 @@ class StaticFunction:
            )
        return concrete_program
def rollback(self):
"""
Rollback into original dygraph functions for current class instance.
Returns:
Function or Method
Example::
.. code-block:: python
>>> # doctest: +SKIP
>>> import paddle
>>> class Net(paddle.nn.Layer):
... def __init__(self):
... super().__init__()
...
... def forward(self, x, flag=True):
... if flag:
... out = x + 1
... else:
... out = x - 1
... return out
...
>>> x = paddle.randn([10, 1], 'float32')
>>> net = paddle.jit.to_static(Net()) # convert into static graph mode
>>> out = net(x)
>>> net.forward.rollback() # rollback into dygraph mode
>>> out = net(x)
"""
def rollback_impl(class_instance):
for name, func in class_instance._original_funcs.items():
setattr(class_instance, name, func.__get__(class_instance))
for sublayer in class_instance.sublayers(include_self=False):
rollback_impl(sublayer)
if self._class_instance is None:
return self._dygraph_function
# only rollback sub-functions on path of top _dygraph_function
func_name = self._dygraph_function.__name__
assert (
func_name in self._class_instance._original_funcs
), "Not Found function '{}' in class '{}'.".format(
func_name, self._class_instance.__name__
)
func = self._class_instance._original_funcs[func_name]
setattr(
self._class_instance, func_name, func.__get__(self._class_instance)
)
for sublayer in self._class_instance.sublayers(include_self=False):
rollback_impl(sublayer)
return getattr(self._class_instance, func_name)
def __deepcopy__(self, memo):
"""
Customized behavior for copy.deepcopy, return original decorated function instead
of a new StaticFunction Object. StaticFunction itself is not copyable becuase it's
associated with class_instance.
We add __deepcopy__ here only for the following usage:
Example::
.. code-block:: python
>>> import copy
>>> import paddle
>>> class Net(paddle.nn.Layer):
... def __init__(self):
... super().__init__()
...
... def forward(self, x, flag=True):
... if flag:
... out = x + 1
... else:
... out = x - 1
... return out
...
>>> x = paddle.randn([10, 1], 'float32')
>>> net = paddle.jit.to_static(Net()) # convert into static graph mode
>>> copy_net = copy.deepcopy(net) # deepcopy a new net without @to_static
Please attention that original 'net' will unwrap @to_static and rollback into simple Layer.
"""
if self._class_instance is not None:
net_name = type(self._class_instance).__name__
logging_utils.log(
level=-1,
msg="Not recommend to deepcopy '{}' decorated with @to_static, it has side effect that will"
" rollback into original state before @to_static. Please deepcopy '{}' before applying @to_static.".format(
net_name, net_name
),
)
self.rollback()
return self._dygraph_function.__get__(
memo[id(self._class_instance)]
)
else:
return self._dygraph_function
    @property
    def inputs(self):
        """
......
@@ -404,6 +404,7 @@ def interpolate(
        )
    if isinstance(size, Variable):
+        size = size.cast("int32")  # static mode only support int32
        if size.ndim != 1:
            raise ValueError(
                f"If size is a tensor, it's rank must be 1, but received {size.ndim}."
            )
......
@@ -133,7 +133,7 @@ def shape(input):
        outputs={'Out': out},
        stop_gradient=True,
    )
+    out.stop_gradient = True

    return out
......
@@ -178,9 +178,9 @@ def cast(x, dtype):
            x = paddle.to_tensor([2, 3, 4], 'float64')
            y = paddle.cast(x, 'uint8')
    """
-    if not isinstance(dtype, core.VarDesc.VarType):
-        dtype = convert_np_dtype_to_dtype_(dtype)
    if in_dynamic_mode():
+        if not isinstance(dtype, core.VarDesc.VarType):
+            dtype = convert_np_dtype_to_dtype_(dtype)
        return _C_ops.cast(x, dtype)
    else:
        check_variable_and_dtype(
@@ -2038,7 +2038,7 @@ def split(x, num_or_sections, axis=0, name=None):
        attrs['axis'] = dim

    if isinstance(num_or_sections, int):
-        assert num_or_sections > 1, 'num_or_sections must be more than 1.'
+        assert num_or_sections > 0, 'num_or_sections must be more than 0.'
        if isinstance(dim, int) and input_shape[dim] > 0:
            assert input_shape[dim] % num_or_sections == 0, (
                "The input's size along the split dimension "
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import os
from functools import wraps
import numpy as np
from paddle import set_flags, static
from paddle.fluid import core
@contextlib.contextmanager
def enable_fallback_guard(enable):
flag = os.environ.get("ENABLE_FALL_BACK", None)
os.environ["ENABLE_FALL_BACK"] = enable
yield
if flag is not None:
os.environ["ENABLE_FALL_BACK"] = flag
else:
del os.environ["ENABLE_FALL_BACK"]
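# Illustrative usage (assumption, not part of this file): force AST mode for
# a block, then restore whatever ENABLE_FALL_BACK was set to beforehand.
#
#     with enable_fallback_guard("False"):
#         run_some_dy2static_test()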
def to_ast(func):
"""
Run the decorated test in AST mode (fallback disabled).
"""
def impl(*args, **kwargs):
with enable_fallback_guard("False"):
func(*args, **kwargs)
return impl
def to_sot(func):
"""
Run the decorated test in SOT mode (fallback enabled).
"""
enable_sot = os.environ.get("ENABLE_SOT", "False") == "True"
def impl(*args, **kwargs):
if enable_sot:
with enable_fallback_guard("True"):
func(*args, **kwargs)
else:
return
return impl
def dy2static_unittest(cls):
"""
Each Dy2static unittest should be decorated with this decorator.
It runs every test in both fallback (SOT) and AST mode.
Usage like:
@dy2static_unittest
class TestA (unittest.TestCase):
...
"""
for key in dir(cls):
if key.startswith("test"):
if not key.endswith("_ast"):
test_func = getattr(cls, key)
setattr(cls, key + "_ast", to_ast(test_func))
test_func = getattr(cls, key)
setattr(cls, key, to_sot(test_func))
return cls
def ast_only_test(func):
"""
run this test function in ast only mode.
Usage:
class TestA (unittest.TestCase):
@ast_only_test
def test_ast_only(self):
pass
"""
def impl(*args, **kwargs):
if os.environ.get("ENABLE_FALL_BACK", "True") == "False":
func(*args, **kwargs)
return impl
def sot_only_test(func):
"""
run this test function in sot only mode.
Usage:
    class TestA(unittest.TestCase):
        @sot_only_test
        def test_sot_only(self):
pass
"""
def impl(*args, **kwargs):
if os.environ.get("ENABLE_FALL_BACK", "True") == "True":
func(*args, **kwargs)
return impl
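# Illustrative combination of the three decorators above (hypothetical test):
#
#     @dy2static_unittest
#     class TestExample(unittest.TestCase):
#         def test_both(self):        # cloned into test_both_ast; runs in SOT and AST
#             ...
#
#         @ast_only_test
#         def test_ast_branch(self):  # runs only when ENABLE_FALL_BACK == "False"
#             ...
#
#         @sot_only_test
#         def test_sot_branch(self):  # runs only when ENABLE_FALL_BACK == "True"
#             ...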
def test_with_new_ir(func):
@wraps(func)
def impl(*args, **kwargs):
ir_outs = None
with static.scope_guard(static.Scope()):
with static.program_guard(static.Program()):
try:
new_ir_flag = 'FLAGS_enable_new_ir_in_executor'
os.environ[new_ir_flag] = 'True'
set_flags({new_ir_flag: True})
ir_outs = func(*args, **kwargs)
finally:
del os.environ[new_ir_flag]
set_flags({new_ir_flag: False})
return ir_outs
return impl
def test_and_compare_with_new_ir(need_check_output: bool = True):
def decorator(func):
@wraps(func)
def impl(*args, **kwargs):
outs = func(*args, **kwargs)
if core._is_bwd_prim_enabled() or core._is_fwd_prim_enabled():
return outs
# only run in CI-Coverage
if os.environ.get('FLAGS_NEW_IR_DY2ST_TEST', None) is None:
return outs
ir_outs = test_with_new_ir(func)(*args, **kwargs)
if not need_check_output:
return outs
for i in range(len(outs)):
np.testing.assert_array_equal(
outs[i],
ir_outs[i],
err_msg='Dy2St Unittest Check ('
+ func.__name__
+ ') has diff '
+ '\nExpect '
+ str(outs[i])
+ '\n'
+ 'But Got'
+ str(ir_outs[i]),
)
return outs
return impl
return decorator
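# Illustrative usage (hypothetical test): the decorated test returns its
# outputs so they can be re-run and compared under the new IR executor when
# FLAGS_NEW_IR_DY2ST_TEST is set.
#
#     class TestAdd(unittest.TestCase):
#         @test_and_compare_with_new_ir(need_check_output=True)
#         def test_add(self):
#             x = paddle.to_tensor([1.0])
#             return [(x + 1).numpy()]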
@@ -297,7 +297,7 @@ class BaseModel(paddle.nn.Layer):
        loss = paddle.nn.functional.softmax_with_cross_entropy(
            logits=dec_output, label=label, soft_label=False
        )
-        loss = paddle.squeeze(loss, axes=[2])
+        loss = paddle.squeeze(loss, axis=[2])
        max_tar_seq_len = paddle.shape(tar)[1]
        tar_mask = paddle.static.nn.sequence_lod.sequence_mask(
            tar_sequence_length, maxlen=max_tar_seq_len, dtype='float32'
@@ -835,13 +835,13 @@ class AttentionModel(paddle.nn.Layer):
        loss = paddle.nn.functional.softmax_with_cross_entropy(
            logits=dec_output, label=label, soft_label=False
        )
-        loss = paddle.squeeze(loss, axes=[2])
+        loss = paddle.squeeze(loss, axis=[2])
        max_tar_seq_len = paddle.shape(tar)[1]
        tar_mask = paddle.static.nn.sequence_lod.sequence_mask(
            tar_sequence_length, maxlen=max_tar_seq_len, dtype='float32'
        )
        loss = loss * tar_mask
        loss = paddle.mean(loss, axis=[0])
-        loss = fluid.layers.reduce_sum(loss)
+        loss = paddle.sum(loss)
        return loss
@@ -50,6 +50,7 @@ class TestAST2Func(unittest.TestCase):
        self.assertEqual(func(x, y), self._ast2func(func)(x, y))

    def test_ast2func_dygraph(self):
+        paddle.disable_static()
        funcs = [dyfunc_with_if_else, dyfunc_with_if_else2, nested_if_else]
        x_data = np.random.random([10, 16]).astype('float32')
        for func in funcs:
@@ -60,6 +61,8 @@ class TestAST2Func(unittest.TestCase):
        self.assertTrue((true_ret == test_ret).all())

    def test_ast2func_static(self):
+        paddle.enable_static()
+
        def func(x):
            y = F.relu(x)
            loss = paddle.mean(y)
......
@@ -20,6 +20,7 @@ import unittest
import numpy as np
from bert_dygraph_model import PretrainModelLayer
from bert_utils import get_bert_config, get_feed_data_reader
+from dygraph_to_static_util import ast_only_test
from predictor_utils import PredictorTools

import paddle
@@ -262,6 +263,7 @@ class TestBert(unittest.TestCase):
            out = output()
            return out

+    @ast_only_test
    def test_train(self):
        static_loss, static_ppl = self.train_static(
            self.bert_config, self.data_reader
......
@@ -18,6 +18,7 @@ import tempfile
import unittest

import numpy as np
+from dygraph_to_static_util import dy2static_unittest
from predictor_utils import PredictorTools

import paddle
@@ -636,6 +637,7 @@ def val_bmn(model, args):
    return loss_data

+@dy2static_unittest
class TestTrain(unittest.TestCase):
    def setUp(self):
        self.args = Args()
@@ -666,6 +668,7 @@ class TestTrain(unittest.TestCase):
        local_random = np.random.RandomState(SEED)

        bmn = BMN(args)
+        bmn = paddle.jit.to_static(bmn)
        adam = optimizer(args, parameter_list=bmn.parameters())

        train_reader = fake_data_reader(args, 'train')
......
@@ -15,6 +15,7 @@
import unittest

import numpy as np
+from dygraph_to_static_util import ast_only_test, dy2static_unittest

import paddle
from paddle import fluid
@@ -25,12 +26,14 @@ SEED = 2020
np.random.seed(SEED)

+@dy2static_unittest
class TestDy2staticException(unittest.TestCase):
    def setUp(self):
        self.x = np.random.random([10, 16]).astype('float32')
        self.dyfunc = None
        self.error = "Your if/else have different number of return value."

+    @ast_only_test
    def test_error(self):
        if self.dyfunc:
            with self.assertRaisesRegex(Dygraph2StaticException, self.error):
......
@@ -15,11 +15,13 @@
import unittest

import numpy as np
+from dygraph_to_static_util import ast_only_test, dy2static_unittest
from test_resnet import ResNetHelper

import paddle

+@dy2static_unittest
class TestResnetWithPass(unittest.TestCase):
    def setUp(self):
        self.build_strategy = paddle.static.BuildStrategy()
@@ -64,6 +66,7 @@ class TestResnetWithPass(unittest.TestCase):
            ),
        )

+    @ast_only_test
    def test_resnet(self):
        static_loss = self.train(to_static=True)
        dygraph_loss = self.train(to_static=False)
@@ -77,6 +80,7 @@ class TestResnetWithPass(unittest.TestCase):
        )
        self.verify_predict()

+    @ast_only_test
    def test_in_static_mode_mkldnn(self):
        paddle.fluid.set_flags({'FLAGS_use_mkldnn': True})
        try:
......
@@ -16,6 +16,7 @@ import unittest
from collections import Counter

import numpy as np
+from dygraph_to_static_util import dy2static_unittest
from test_fetch_feed import Linear, Pool2D

import paddle
@@ -24,6 +25,7 @@ from paddle.jit.api import to_static
from paddle.jit.dy2static import convert_to_static

+@dy2static_unittest
class TestCacheProgram(unittest.TestCase):
    def setUp(self):
        self.batch_num = 5
@@ -36,7 +38,7 @@ class TestCacheProgram(unittest.TestCase):
        with fluid.dygraph.guard(fluid.CPUPlace()):
            static_net = self.dygraph_class()
            for batch_id in range(self.batch_num):
-                out = static_net(self.data)
+                out = static_net(paddle.to_tensor(self.data))
                # Check outputs
                prev_out = cur_out
                cur_out = out
......
@@ -15,7 +15,7 @@
import unittest

import numpy as np
-from dy2st_test_utils import test_and_compare_with_new_ir
+from dygraph_to_static_util import ast_only_test, test_and_compare_with_new_ir

from paddle import fluid
from paddle.jit.api import to_static
@@ -38,7 +38,6 @@ def test_int_cast(x):
    return x

-@to_static
def test_float_cast(x):
    x = fluid.dygraph.to_variable(x)
    x = float(x)
@@ -89,6 +88,7 @@ class TestCastBase(unittest.TestCase):
        res = self.func(self.input)
        return res

+    @ast_only_test  # TODO: add new symbolic only test.
    @test_and_compare_with_new_ir(False)
    def test_cast_result(self):
        res = self.do_test().numpy()
@@ -136,7 +136,7 @@ class TestFloatCast(TestCastBase):
        self.cast_dtype = 'float32'

    def set_func(self):
-        self.func = test_float_cast
+        self.func = to_static(test_float_cast)

class TestMixCast(TestCastBase):
@@ -156,6 +156,7 @@ class TestMixCast(TestCastBase):
    def set_func(self):
        self.func = test_mix_cast

+    @ast_only_test  # TODO: add new symbolic only test.
    @test_and_compare_with_new_ir(False)
    def test_cast_result(self):
        res = self.do_test().numpy()
@@ -189,6 +190,7 @@ class TestNotVarCast(TestCastBase):
    def set_func(self):
        self.func = test_not_var_cast

+    @ast_only_test  # TODO: add new symbolic only test.
    @test_and_compare_with_new_ir(False)
    def test_cast_result(self):
        res = self.do_test()
......
@@ -15,6 +15,7 @@
import unittest

import numpy as np
+from dygraph_to_static_util import ast_only_test, dy2static_unittest

import paddle
import paddle.nn.functional as F
@@ -38,6 +39,7 @@ class PrimeNet(paddle.nn.Layer):
        return out

+@dy2static_unittest
class TestPrimForward(unittest.TestCase):
    """
    This case only tests prim_forward + to_static + cinn. Thus we need to
@@ -88,6 +90,7 @@ class TestPrimForward(unittest.TestCase):
        # Ensure that softmax is splitted into small ops
        self.assertTrue('softmax' not in fwd_ops)

+    @ast_only_test
    def test_cinn_prim_forward(self):
        dy_res = self.train(use_prim=False)
        cinn_res = self.train(use_prim=True)
@@ -98,6 +101,7 @@ class TestPrimForward(unittest.TestCase):
        )

+@dy2static_unittest
class TestPrimForwardAndBackward(unittest.TestCase):
    """
    Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph
@@ -153,6 +157,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
            if op != "matmul_v2_grad":
                self.assertTrue("_grad" not in op)

+    @ast_only_test
    def test_cinn_prim(self):
        dy_res = self.train(use_prim=False)
        cinn_res = self.train(use_prim=True)
......
@@ -15,6 +15,7 @@
import unittest

import numpy as np
+from dygraph_to_static_util import ast_only_test, dy2static_unittest

import paddle
import paddle.nn.functional as F
@@ -52,6 +53,7 @@ class PrimeNet(paddle.nn.Layer):
        return out

+@dy2static_unittest
class TestPrimForwardAndBackward(unittest.TestCase):
    """
    Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph
@@ -104,6 +106,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
        # Ensure that gelu is splitted into small ops
        self.assertTrue('gelu' not in fwd_ops)

+    @ast_only_test
    def test_cinn_prim(self):
        for shape in self.shapes:
            for dtype in self.dtypes:
......
@@ -15,6 +15,7 @@
import unittest

import numpy as np
+from dygraph_to_static_util import ast_only_test

import paddle
import paddle.nn.functional as F
@@ -101,6 +102,7 @@ class TestPrimForward(unittest.TestCase):
        # Ensure that layer_norm is splitted into small ops
        self.assertTrue('layer_norm' not in fwd_ops)

+    @ast_only_test
    def test_cinn_prim_forward(self):
        for dtype in self.dtypes:
            if paddle.device.get_device() == "cpu":
@@ -168,6 +170,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
        # Ensure that layer_norm is splitted into small ops
        self.assertTrue('layer_norm' not in fwd_ops)

+    @ast_only_test
    def test_cinn_prim(self):
        for dtype in self.dtypes:
            if paddle.device.get_device() == "cpu":
......
@@ -15,6 +15,7 @@
import unittest

import numpy as np
+from dygraph_to_static_util import ast_only_test, dy2static_unittest

import paddle
from paddle import tensor
@@ -54,6 +55,7 @@ class PrimeNet(
        return out

+@dy2static_unittest
class TestPrimForward(unittest.TestCase):
    """
    This case only tests prim_forward + to_static + cinn. Thus we need to
@@ -110,6 +112,7 @@ class TestPrimForward(unittest.TestCase):
        # Ensure that reduce_mean is splitted into small ops
        self.assertTrue('reduce_mean' not in fwd_ops)

+    @ast_only_test
    def test_cinn_prim_forward(self):
        for shape in self.shapes:
            for dtype in self.dtypes:
@@ -131,6 +134,7 @@ class TestPrimForward(unittest.TestCase):
        )

+@dy2static_unittest
class TestPrimForwardAndBackward(unittest.TestCase):
    """
    Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph
@@ -183,6 +187,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
        # Ensure that reduce_mean is splitted into small ops
        self.assertTrue('reduce_mean' not in fwd_ops)

+    @ast_only_test
    def test_cinn_prim(self):
        for shape in self.shapes:
            for dtype in self.dtypes:
......
@@ -13,6 +13,7 @@
# limitations under the License.

import inspect
+import os
import unittest

from numpy import append
@@ -324,4 +325,5 @@ class TestPushPopTrans(unittest.TestCase):

if __name__ == '__main__':
+    os.environ['ENABLE_FALL_BACK'] = "False"
    unittest.main()
@@ -17,6 +17,7 @@ import tempfile
import unittest

import numpy as np
+from dygraph_to_static_util import dy2static_unittest

import paddle
@@ -69,6 +70,7 @@ class NestSequentialNet(paddle.nn.Layer):
        return self.layers(x)

+@dy2static_unittest
class TestSequential(unittest.TestCase):
    def setUp(self):
        paddle.set_device('cpu')
......
@@ -16,6 +16,7 @@ import logging
import unittest

import numpy as np
+from dygraph_to_static_util import ast_only_test, dy2static_unittest

import paddle
import paddle.jit.dy2static as _jst
@@ -252,6 +253,7 @@ class TestNotToConvert(TestRecursiveCall2):
        )

+@dy2static_unittest
class TestNotToConvert2(TestRecursiveCall2):
    def set_func(self):
        self.net = NotToStaticHelper()
@@ -264,7 +266,9 @@ class TestNotToConvert2(TestRecursiveCall2):
        self.assertIsNotNone(options)
        self.assertTrue(options.not_convert)

+    @ast_only_test
    def test_code(self):
+        self.dygraph_func = paddle.jit.to_static(self.net.sum)
        # check 'if statement' is not converted
        self.assertIn("if x.shape[0] > 1", self.dygraph_func.code)
@@ -277,19 +281,23 @@ def forward(self, x):
    return x

+@dy2static_unittest
class TestConvertPaddleAPI(unittest.TestCase):
+    @ast_only_test
    def test_functional_api(self):
        func = paddle.nn.functional.relu
        func = paddle.jit.to_static(func)
        self.assertNotIn("_jst.IfElse", func.code)
        self.assertIn("if in_dynamic_mode()", func.code)

+    @ast_only_test
    def test_class_api(self):
        bn = paddle.nn.SyncBatchNorm(2)
        paddle.jit.to_static(bn)
        self.assertNotIn("_jst.IfElse", bn.forward.code)
        self.assertIn("if in_dynamic_mode()", bn.forward.code)

+    @ast_only_test
    def test_class_patch_api(self):
        paddle.nn.SyncBatchNorm.forward = forward
        bn = paddle.nn.SyncBatchNorm(2)
......
@@ -14,6 +14,8 @@
import unittest

+from dygraph_to_static_util import ast_only_test
+
import paddle
from paddle.jit import to_static
from paddle.jit.dy2static.convert_call_func import translator_logger
@@ -31,6 +33,8 @@ def main_func():
class TestConvertGenerator(unittest.TestCase):
+    # fallback will be ok.
+    @ast_only_test
    def test_raise_error(self):
        translator_logger.verbosity_level = 1
        with self.assertLogs(
......
@@ -15,6 +15,7 @@
import unittest

import numpy as np
+from dygraph_to_static_util import ast_only_test

import paddle
@@ -40,6 +41,8 @@ net.forward = "A string so that convert forward will fail"

class TestConvertCall(unittest.TestCase):
+    # fallback mode will raise an InnerError, it's ok.
+    @ast_only_test
    def test_class_exception(self):
        @paddle.jit.to_static
        def call_not_exist():
......
@@ -15,6 +15,11 @@
import unittest

import numpy as np
+from dygraph_to_static_util import (
+    ast_only_test,
+    dy2static_unittest,
+    sot_only_test,
+)

import paddle
@@ -28,8 +33,8 @@ class TestCpuCuda(unittest.TestCase):
            return x

        x = paddle.to_tensor([3])
-        print(paddle.jit.to_static(func).code)
-        print(paddle.jit.to_static(func)(x))
+        # print(paddle.jit.to_static(func).code)
+        # print(paddle.jit.to_static(func)(x))

class TestToTensor(unittest.TestCase):
@@ -41,7 +46,7 @@ class TestToTensor(unittest.TestCase):
            return x

        x = paddle.to_tensor([3])
-        print(paddle.jit.to_static(func).code)
+        # print(paddle.jit.to_static(func).code)
        np.testing.assert_allclose(
            paddle.jit.to_static(func)(x).numpy(),
            np.array([1, 2, 3, 4]),
@@ -49,7 +54,9 @@ class TestToTensor(unittest.TestCase):
        )

+@dy2static_unittest
class TestToTensor1(unittest.TestCase):
+    @ast_only_test
    def test_to_tensor_with_variable_list(self):
        def func(x):
            ones = paddle.to_tensor([1])
@@ -61,28 +68,59 @@ class TestToTensor1(unittest.TestCase):
            return x

        x = paddle.to_tensor([3])
-        print(paddle.jit.to_static(func).code)
        np.testing.assert_allclose(
            paddle.jit.to_static(func)(x).numpy(),
            np.array([1, 2, 3, 4]),
            rtol=1e-05,
        )
@sot_only_test
def test_to_tensor_with_variable_list_sot(self):
def func(x):
ones = paddle.to_tensor([1])
twos = paddle.to_tensor([2])
""" we ignore the [3] and [4], they will be assign to a variable, and is regard as scalar.
TODO: deal with this case after 0-dim tensor is developed.
"""
x = paddle.to_tensor([ones, twos, [3], [4]])
return x
x = paddle.to_tensor([3])
np.testing.assert_allclose(
paddle.jit.to_static(func)(x),
np.array([[1], [2], [3], [4]]),
rtol=1e-05,
)
@dy2static_unittest
class TestToTensor2(unittest.TestCase):
@ast_only_test
    def test_to_tensor_with_variable_list(self):
        def func(x):
            x = paddle.to_tensor([[1], [2], [3], [4]])
            return x

        x = paddle.to_tensor([3])
-        print(paddle.jit.to_static(func).code)
        np.testing.assert_allclose(
            paddle.jit.to_static(func)(x).numpy(),
            np.array([[1], [2], [3], [4]]),
            rtol=1e-05,
        )
@sot_only_test
def test_to_tensor_with_variable_list_sot(self):
def func(x):
x = paddle.to_tensor([[1], [2], [3], [4]])
return x
x = paddle.to_tensor([3])
np.testing.assert_allclose(
paddle.jit.to_static(func)(x),
np.array([[1], [2], [3], [4]]),
rtol=1e-05,
)
if __name__ == '__main__':
    unittest.main()
@@ -30,6 +30,8 @@ from paddle.jit.dy2static.program_translator import (
from paddle.nn import Layer
from paddle.static import InputSpec

+os.environ['ENABLE_FALL_BACK'] = "False"  # NOTE: ast only

class SimpleNet(Layer):
    def __init__(self):
......
...@@ -19,6 +19,7 @@ from functools import wraps ...@@ -19,6 +19,7 @@ from functools import wraps
import decos import decos
import numpy as np import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest
import paddle import paddle
...@@ -147,7 +148,6 @@ def fun8(x, y=0): ...@@ -147,7 +148,6 @@ def fun8(x, y=0):
return a return a
@paddle.jit.to_static
def forward(): def forward():
funcs = [fun1, fun2, fun3, fun4, fun5, fun6, fun7, fun8] funcs = [fun1, fun2, fun3, fun4, fun5, fun6, fun7, fun8]
out = [] out = []
...@@ -166,7 +166,6 @@ def fun9(): ...@@ -166,7 +166,6 @@ def fun9():
print('in fun9 want contextmanager warning') print('in fun9 want contextmanager warning')
@paddle.jit.to_static
def warn1(): def warn1():
fun9() fun9()
...@@ -182,9 +181,10 @@ def deco_with_paddle_api(): ...@@ -182,9 +181,10 @@ def deco_with_paddle_api():
return fun10() return fun10()
@dy2static_unittest
class TestDecoratorTransform(unittest.TestCase): class TestDecoratorTransform(unittest.TestCase):
def test_deco_transform(self): def test_deco_transform(self):
outs = forward()
outs = paddle.jit.to_static(forward)()
np.testing.assert_allclose(outs[0], np.array(3), rtol=1e-05) np.testing.assert_allclose(outs[0], np.array(3), rtol=1e-05)
np.testing.assert_allclose(outs[1], np.array(5), rtol=1e-05) np.testing.assert_allclose(outs[1], np.array(5), rtol=1e-05)
np.testing.assert_allclose(outs[2], np.array(6), rtol=1e-05) np.testing.assert_allclose(outs[2], np.array(6), rtol=1e-05)
...@@ -194,11 +194,12 @@ class TestDecoratorTransform(unittest.TestCase): ...@@ -194,11 +194,12 @@ class TestDecoratorTransform(unittest.TestCase):
np.testing.assert_allclose(outs[6], np.array(9), rtol=1e-05) np.testing.assert_allclose(outs[6], np.array(9), rtol=1e-05)
np.testing.assert_allclose(outs[7], np.array(10), rtol=1e-05) np.testing.assert_allclose(outs[7], np.array(10), rtol=1e-05)
@ast_only_test
def test_contextmanager_warning(self): def test_contextmanager_warning(self):
paddle.disable_static() paddle.disable_static()
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always") warnings.simplefilter("always")
warn1()
paddle.jit.to_static(warn1)()
flag = False flag = False
for warn in w: for warn in w:
if ( if (
......
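The decorator-transform hunks above drop @paddle.jit.to_static from the function definitions and convert at the call site instead, keeping the plain dygraph function usable under both the AST and SOT paths. A minimal, self-contained illustration of that pattern:

```python
import paddle

def forward(x):
    # Plain dygraph function: no decorator baked in at definition time.
    return paddle.nn.functional.relu(x + 1)

# Conversion now happens at the call site, mirroring
# `paddle.jit.to_static(forward)()` in the hunk above.
static_forward = paddle.jit.to_static(forward)
print(static_forward(paddle.ones([2, 3])))
```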
...@@ -23,6 +23,8 @@ from paddle import fluid ...@@ -23,6 +23,8 @@ from paddle import fluid
from paddle.jit.dy2static import error from paddle.jit.dy2static import error
from paddle.jit.dy2static.origin_info import unwrap from paddle.jit.dy2static.origin_info import unwrap
os.environ['ENABLE_FALL_BACK'] = "False" # NOTE: ast only
def inner_func(): def inner_func():
paddle.tensor.fill_constant(shape=[1, 2], value=9, dtype="int") paddle.tensor.fill_constant(shape=[1, 2], value=9, dtype="int")
...@@ -255,11 +257,11 @@ class TestErrorStaticLayerCallInCompiletime(TestErrorBase): ...@@ -255,11 +257,11 @@ class TestErrorStaticLayerCallInCompiletime(TestErrorBase):
def set_message(self): def set_message(self):
self.expected_message = [ self.expected_message = [
'File "{}", line 35, in func_error_in_compile_time'.format( 'File "{}", line 37, in func_error_in_compile_time'.format(
self.filepath self.filepath
), ),
'inner_func()', 'inner_func()',
f'File "{self.filepath}", line 28, in inner_func', f'File "{self.filepath}", line 30, in inner_func',
'def inner_func():', 'def inner_func():',
'paddle.tensor.fill_constant(shape=[1, 2], value=9, dtype="int")', 'paddle.tensor.fill_constant(shape=[1, 2], value=9, dtype="int")',
'<--- HERE', '<--- HERE',
...@@ -286,7 +288,7 @@ class TestErrorStaticLayerCallInCompiletime_2( ...@@ -286,7 +288,7 @@ class TestErrorStaticLayerCallInCompiletime_2(
def set_message(self): def set_message(self):
self.expected_message = [ self.expected_message = [
'File "{}", line 46, in func_error_in_compile_time_2'.format( 'File "{}", line 48, in func_error_in_compile_time_2'.format(
self.filepath self.filepath
), ),
'def func_error_in_compile_time_2(x):', 'def func_error_in_compile_time_2(x):',
...@@ -312,7 +314,7 @@ class TestErrorStaticLayerCallInCompiletime_3( ...@@ -312,7 +314,7 @@ class TestErrorStaticLayerCallInCompiletime_3(
def set_message(self): def set_message(self):
self.expected_message = [ self.expected_message = [
f'File "{self.filepath}", line 91, in forward', f'File "{self.filepath}", line 93, in forward',
'@paddle.jit.to_static', '@paddle.jit.to_static',
'def forward(self):', 'def forward(self):',
'self.test_func()', 'self.test_func()',
...@@ -336,7 +338,7 @@ class TestErrorStaticLayerCallInRuntime(TestErrorStaticLayerCallInCompiletime): ...@@ -336,7 +338,7 @@ class TestErrorStaticLayerCallInRuntime(TestErrorStaticLayerCallInCompiletime):
def set_message(self): def set_message(self):
self.expected_message = [ self.expected_message = [
'File "{}", line 54, in func_error_in_runtime'.format( 'File "{}", line 56, in func_error_in_runtime'.format(
self.filepath self.filepath
), ),
'x = fluid.dygraph.to_variable(x)', 'x = fluid.dygraph.to_variable(x)',
...@@ -353,7 +355,7 @@ class TestErrorStaticLayerCallInRuntime2(TestErrorStaticLayerCallInRuntime): ...@@ -353,7 +355,7 @@ class TestErrorStaticLayerCallInRuntime2(TestErrorStaticLayerCallInRuntime):
def set_message(self): def set_message(self):
self.expected_message = [ self.expected_message = [
'File "{}", line 106, in func_error_in_runtime_with_empty_line'.format( 'File "{}", line 108, in func_error_in_runtime_with_empty_line'.format(
self.filepath self.filepath
), ),
'two = paddle.tensor.fill_constant(shape=[1], value=2, dtype="int32")', 'two = paddle.tensor.fill_constant(shape=[1], value=2, dtype="int32")',
...@@ -376,7 +378,7 @@ class TestJitSaveInCompiletime(TestErrorBase): ...@@ -376,7 +378,7 @@ class TestJitSaveInCompiletime(TestErrorBase):
def set_message(self): def set_message(self):
self.expected_message = [ self.expected_message = [
f'File "{self.filepath}", line 80, in forward', f'File "{self.filepath}", line 82, in forward',
'def forward(self, x):', 'def forward(self, x):',
'y = self._linear(x)', 'y = self._linear(x)',
'z = paddle.tensor.fill_constant(shape=[1, 2], value=9, dtype="int")', 'z = paddle.tensor.fill_constant(shape=[1, 2], value=9, dtype="int")',
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import unittest import unittest
import numpy as np import numpy as np
from dygraph_to_static_util import ast_only_test
import paddle import paddle
...@@ -84,6 +85,7 @@ class TestFallback(unittest.TestCase): ...@@ -84,6 +85,7 @@ class TestFallback(unittest.TestCase):
u_net(self.x).numpy(), u_net(self.x).numpy(),
) )
@ast_only_test
def test_case_net_error(self): def test_case_net_error(self):
s_net = SuppportNet() s_net = SuppportNet()
u_net = UnsuppportNet() u_net = UnsuppportNet()
......
...@@ -53,7 +53,7 @@ class Linear(paddle.nn.Layer): ...@@ -53,7 +53,7 @@ class Linear(paddle.nn.Layer):
) )
self.act = paddle.nn.ReLU() self.act = paddle.nn.ReLU()
@to_static
# @to_static
def forward(self, x): def forward(self, x):
pre = self.fc(x) pre = self.fc(x)
pre = self.act(pre) pre = self.act(pre)
......
...@@ -15,10 +15,10 @@ ...@@ -15,10 +15,10 @@
import unittest import unittest
import numpy as np import numpy as np
from dygraph_to_static_util import ast_only_test
import paddle import paddle
from paddle import fluid from paddle import fluid
from paddle.jit import to_static
@paddle.jit.to_static @paddle.jit.to_static
...@@ -48,7 +48,7 @@ def decorated_call_decorated(x): ...@@ -48,7 +48,7 @@ def decorated_call_decorated(x):
class DoubleDecorated: class DoubleDecorated:
@classmethod @classmethod
@to_static
@paddle.jit.to_static
def double_decorated_func1(self, x): def double_decorated_func1(self, x):
return dygraph_decorated_func(x) return dygraph_decorated_func(x)
...@@ -59,6 +59,7 @@ class DoubleDecorated: ...@@ -59,6 +59,7 @@ class DoubleDecorated:
class TestFullNameDecorator(unittest.TestCase): class TestFullNameDecorator(unittest.TestCase):
@ast_only_test
def test_run_success(self): def test_run_success(self):
x = np.ones([1, 2]).astype("float32") x = np.ones([1, 2]).astype("float32")
answer = np.zeros([1, 2]).astype("float32") answer = np.zeros([1, 2]).astype("float32")
......
...@@ -17,6 +17,7 @@ import tempfile ...@@ -17,6 +17,7 @@ import tempfile
import unittest import unittest
import numpy as np import numpy as np
from dygraph_to_static_util import dy2static_unittest
import paddle import paddle
...@@ -25,7 +26,6 @@ class GradLayer(paddle.nn.Layer): ...@@ -25,7 +26,6 @@ class GradLayer(paddle.nn.Layer):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@paddle.jit.to_static
def forward(self, x): def forward(self, x):
x.stop_gradient = False x.stop_gradient = False
y = x * x y = x * x
...@@ -38,7 +38,6 @@ class GradLinearLayer(paddle.nn.Layer): ...@@ -38,7 +38,6 @@ class GradLinearLayer(paddle.nn.Layer):
super().__init__() super().__init__()
self.linear = paddle.nn.Linear(5, 5, bias_attr=False) self.linear = paddle.nn.Linear(5, 5, bias_attr=False)
@paddle.jit.to_static
def forward(self, x): def forward(self, x):
x.stop_gradient = False x.stop_gradient = False
tmp = x + x tmp = x + x
...@@ -56,7 +55,6 @@ class NoGradLinearLayer(paddle.nn.Layer): ...@@ -56,7 +55,6 @@ class NoGradLinearLayer(paddle.nn.Layer):
super().__init__() super().__init__()
self.linear = paddle.nn.Linear(5, 5, bias_attr=False) self.linear = paddle.nn.Linear(5, 5, bias_attr=False)
@paddle.jit.to_static
def forward(self, x): def forward(self, x):
x.stop_gradient = False x.stop_gradient = False
...@@ -69,7 +67,7 @@ class NoGradLinearLayer(paddle.nn.Layer): ...@@ -69,7 +67,7 @@ class NoGradLinearLayer(paddle.nn.Layer):
class TestGrad(unittest.TestCase): class TestGrad(unittest.TestCase):
def setUp(self): def setUp(self):
self.func = GradLayer()
self.func = paddle.jit.to_static(GradLayer())
self.x = paddle.ones(shape=[10, 2, 5], dtype='float32') self.x = paddle.ones(shape=[10, 2, 5], dtype='float32')
self.x.stop_gradient = False self.x.stop_gradient = False
...@@ -85,9 +83,10 @@ class TestGrad(unittest.TestCase): ...@@ -85,9 +83,10 @@ class TestGrad(unittest.TestCase):
np.testing.assert_allclose(static_res, dygraph_res, rtol=1e-05) np.testing.assert_allclose(static_res, dygraph_res, rtol=1e-05)
@dy2static_unittest
class TestGradLinear(TestGrad): class TestGradLinear(TestGrad):
def setUp(self): def setUp(self):
self.func = GradLinearLayer()
self.func = paddle.jit.to_static(GradLinearLayer())
self.x = paddle.ones(shape=[10, 2, 5], dtype='float32') self.x = paddle.ones(shape=[10, 2, 5], dtype='float32')
self.x.stop_gradient = False self.x.stop_gradient = False
...@@ -103,6 +102,7 @@ class TestGradLinear(TestGrad): ...@@ -103,6 +102,7 @@ class TestGradLinear(TestGrad):
self.temp_dir.cleanup() self.temp_dir.cleanup()
def test_save_infer_program(self): def test_save_infer_program(self):
self.setUp()  # re-run setUp so self.func is rebuilt in AST mode
input_spec = [ input_spec = [
paddle.static.InputSpec(shape=[10, 2, 5], dtype='float32') paddle.static.InputSpec(shape=[10, 2, 5], dtype='float32')
] ]
...@@ -114,6 +114,7 @@ class TestGradLinear(TestGrad): ...@@ -114,6 +114,7 @@ class TestGradLinear(TestGrad):
np.testing.assert_allclose(origin_res, load_res, rtol=1e-05) np.testing.assert_allclose(origin_res, load_res, rtol=1e-05)
def test_save_train_program(self): def test_save_train_program(self):
self.setUp()  # re-run setUp so self.func is rebuilt in AST mode
grad_clip = paddle.nn.ClipGradByGlobalNorm(2.0) grad_clip = paddle.nn.ClipGradByGlobalNorm(2.0)
optimizer = paddle.optimizer.SGD( optimizer = paddle.optimizer.SGD(
learning_rate=0.01, learning_rate=0.01,
...@@ -138,7 +139,7 @@ class TestGradLinear(TestGrad): ...@@ -138,7 +139,7 @@ class TestGradLinear(TestGrad):
class TestNoGradLinear(TestGradLinear): class TestNoGradLinear(TestGradLinear):
def setUp(self): def setUp(self):
self.func = NoGradLinearLayer()
self.func = paddle.jit.to_static(NoGradLinearLayer())
self.x = paddle.ones(shape=[10, 2, 5], dtype='float32') self.x = paddle.ones(shape=[10, 2, 5], dtype='float32')
self.x.stop_gradient = False self.x.stop_gradient = False
......
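The grad tests apply the same move to layers: forward is no longer decorated, and setUp wraps the instance. A condensed sketch of the pattern, where TinyLayer is a stand-in rather than the real GradLayer:

```python
import paddle

class TinyLayer(paddle.nn.Layer):  # stand-in, not the real GradLayer
    def forward(self, x):
        x.stop_gradient = False
        return x * x

# Wrapping the instance in setUp (instead of decorating forward) lets a
# test re-run setUp to rebuild the same layer under a different mode.
func = paddle.jit.to_static(TinyLayer())
out = func(paddle.ones(shape=[10, 2, 5], dtype='float32'))
```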
...@@ -16,9 +16,9 @@ ...@@ -16,9 +16,9 @@
import unittest import unittest
import numpy as np import numpy as np
from dygraph_to_static_util import dy2static_unittest
import paddle import paddle
from paddle import ParamAttr
from paddle.nn import BatchNorm, Linear from paddle.nn import BatchNorm, Linear
...@@ -28,13 +28,9 @@ class SimpleNet(paddle.nn.Layer): ...@@ -28,13 +28,9 @@ class SimpleNet(paddle.nn.Layer):
self.linear0 = Linear(100, 50) self.linear0 = Linear(100, 50)
self.linear1 = Linear(50, 10) self.linear1 = Linear(50, 10)
param_attr0 = ParamAttr(name="aaaprefix_bn_scale")
bias_attr0 = ParamAttr(name="aaaprefix_bn_offset")
self.bn0 = BatchNorm(50, param_attr=param_attr0, bias_attr=bias_attr0)
self.bn0 = BatchNorm(50)
param_attr1 = ParamAttr(name="bn_scale")
bias_attr1 = ParamAttr(name="bn_offset")
self.bn1 = BatchNorm(10, param_attr=param_attr1, bias_attr=bias_attr1)
self.bn1 = BatchNorm(10)
def forward(self, x): def forward(self, x):
x1 = self.linear0(x) x1 = self.linear0(x)
...@@ -45,6 +41,7 @@ class SimpleNet(paddle.nn.Layer): ...@@ -45,6 +41,7 @@ class SimpleNet(paddle.nn.Layer):
return dx[0] return dx[0]
@dy2static_unittest
class TestGradNameParse(unittest.TestCase): class TestGradNameParse(unittest.TestCase):
def test_grad_name_parse(self): def test_grad_name_parse(self):
net = SimpleNet() net = SimpleNet()
...@@ -72,6 +69,7 @@ def tanh_high_order_grad(x): ...@@ -72,6 +69,7 @@ def tanh_high_order_grad(x):
return paddle.grad(y, x, create_graph=True)[0] return paddle.grad(y, x, create_graph=True)[0]
@dy2static_unittest
class TestTanhHighOrderGrad(unittest.TestCase): class TestTanhHighOrderGrad(unittest.TestCase):
def setUp(self): def setUp(self):
self.func = tanh_high_order_grad self.func = tanh_high_order_grad
...@@ -116,10 +114,11 @@ class TestTanhHighOrderGrad(unittest.TestCase): ...@@ -116,10 +114,11 @@ class TestTanhHighOrderGrad(unittest.TestCase):
def matmul_high_order_grad(x, y): def matmul_high_order_grad(x, y):
z = paddle.matmul(x, y) z = paddle.matmul(x, y)
g = paddle.grad(z, [x, y], create_graph=False)
g = paddle.grad(z, [x, y], create_graph=True)
return g[0] return g[0]
@dy2static_unittest
class TestMatMulHighOrderGrad1(TestTanhHighOrderGrad): class TestMatMulHighOrderGrad1(TestTanhHighOrderGrad):
def setUp(self): def setUp(self):
self.func = matmul_high_order_grad self.func = matmul_high_order_grad
...@@ -139,6 +138,7 @@ class TestMatMulHighOrderGrad1(TestTanhHighOrderGrad): ...@@ -139,6 +138,7 @@ class TestMatMulHighOrderGrad1(TestTanhHighOrderGrad):
self.dy2st_grad_input = (x2,) self.dy2st_grad_input = (x2,)
@dy2static_unittest
class TestMatMulHighOrderGrad2(TestTanhHighOrderGrad): class TestMatMulHighOrderGrad2(TestTanhHighOrderGrad):
def setUp(self): def setUp(self):
self.func = matmul_high_order_grad self.func = matmul_high_order_grad
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import unittest import unittest
import numpy as np import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest
from ifelse_simple_func import ( from ifelse_simple_func import (
NetWithControlFlowIf, NetWithControlFlowIf,
add_fn, add_fn,
...@@ -54,12 +55,14 @@ else: ...@@ -54,12 +55,14 @@ else:
place = fluid.CPUPlace() place = fluid.CPUPlace()
@dy2static_unittest
class TestDy2staticException(unittest.TestCase): class TestDy2staticException(unittest.TestCase):
def setUp(self): def setUp(self):
self.x = np.random.random([10, 16]).astype('float32') self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = None self.dyfunc = None
self.error = "Your if/else have different number of return value." self.error = "Your if/else have different number of return value."
@ast_only_test
def test_error(self): def test_error(self):
if self.dyfunc: if self.dyfunc:
with self.assertRaisesRegex(Dygraph2StaticException, self.error): with self.assertRaisesRegex(Dygraph2StaticException, self.error):
...@@ -412,10 +415,11 @@ class TestNewVarCreateInOneBranch(unittest.TestCase): ...@@ -412,10 +415,11 @@ class TestNewVarCreateInOneBranch(unittest.TestCase):
self.assertEqual(paddle.jit.to_static(case_func)(True), -2) self.assertEqual(paddle.jit.to_static(case_func)(True), -2)
@dy2static_unittest
class TestDy2StIfElseRetInt1(unittest.TestCase): class TestDy2StIfElseRetInt1(unittest.TestCase):
def setUp(self): def setUp(self):
self.x = np.random.random([5]).astype('float32') self.x = np.random.random([5]).astype('float32')
self.dyfunc = dyfunc_ifelse_ret_int1
self.dyfunc = paddle.jit.to_static(dyfunc_ifelse_ret_int1)
self.out = self.get_dy2stat_out() self.out = self.get_dy2stat_out()
def get_dy2stat_out(self): def get_dy2stat_out(self):
...@@ -425,7 +429,9 @@ class TestDy2StIfElseRetInt1(unittest.TestCase): ...@@ -425,7 +429,9 @@ class TestDy2StIfElseRetInt1(unittest.TestCase):
paddle.jit.enable_to_static(False) paddle.jit.enable_to_static(False)
return out return out
@ast_only_test
def test_ast_to_func(self): def test_ast_to_func(self):
self.setUp()
self.assertIsInstance(self.out[0], (paddle.Tensor, core.eager.Tensor)) self.assertIsInstance(self.out[0], (paddle.Tensor, core.eager.Tensor))
self.assertIsInstance(self.out[1], int) self.assertIsInstance(self.out[1], int)
...@@ -437,21 +443,26 @@ class TestDy2StIfElseRetInt2(TestDy2staticException): ...@@ -437,21 +443,26 @@ class TestDy2StIfElseRetInt2(TestDy2staticException):
self.dyfunc = dyfunc_ifelse_ret_int2 self.dyfunc = dyfunc_ifelse_ret_int2
@dy2static_unittest
class TestDy2StIfElseRetInt3(TestDy2StIfElseRetInt1): class TestDy2StIfElseRetInt3(TestDy2StIfElseRetInt1):
def setUp(self): def setUp(self):
self.x = np.random.random([5]).astype('float32') self.x = np.random.random([5]).astype('float32')
self.dyfunc = dyfunc_ifelse_ret_int3
self.dyfunc = paddle.jit.to_static(dyfunc_ifelse_ret_int3)
self.out = self.get_dy2stat_out() self.out = self.get_dy2stat_out()
@ast_only_test
def test_ast_to_func(self): def test_ast_to_func(self):
self.setUp()
self.assertIsInstance(self.out, (paddle.Tensor, core.eager.Tensor)) self.assertIsInstance(self.out, (paddle.Tensor, core.eager.Tensor))
@dy2static_unittest
class TestDy2StIfElseRetInt4(TestDy2StIfElseRetInt1): class TestDy2StIfElseRetInt4(TestDy2StIfElseRetInt1):
def setUp(self): def setUp(self):
self.x = np.random.random([5]).astype('float32') self.x = np.random.random([5]).astype('float32')
self.dyfunc = dyfunc_ifelse_ret_int4
self.dyfunc = paddle.jit.to_static(dyfunc_ifelse_ret_int4)
@ast_only_test
def test_ast_to_func(self): def test_ast_to_func(self):
paddle.jit.enable_to_static(True) paddle.jit.enable_to_static(True)
with self.assertRaises(Dygraph2StaticException): with self.assertRaises(Dygraph2StaticException):
......
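get_dy2stat_out in the hunk above flips conversion globally via paddle.jit.enable_to_static. One way to keep such a toggle from leaking into later cases is a try/finally guard, sketched below with a placeholder function:

```python
import paddle

def dyfunc(x):
    return x + 1  # placeholder body

paddle.jit.enable_to_static(True)
try:
    out = paddle.jit.to_static(dyfunc)(paddle.to_tensor([1.0]))
finally:
    # Restore the global flag so later dygraph-only cases are unaffected.
    paddle.jit.enable_to_static(False)
```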
...@@ -286,7 +286,7 @@ class TestListInWhileLoop(TestListWithoutControlFlow): ...@@ -286,7 +286,7 @@ class TestListInWhileLoop(TestListWithoutControlFlow):
def train(self, to_static=False): def train(self, to_static=False):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
if to_static: if to_static:
print(paddle.jit.to_static(self.dygraph_func).code)
# print(paddle.jit.to_static(self.dygraph_func).code)
res = paddle.jit.to_static(self.dygraph_func)( res = paddle.jit.to_static(self.dygraph_func)(
self.input, self.iter_num self.input, self.iter_num
) )
......
...@@ -17,6 +17,7 @@ import tempfile ...@@ -17,6 +17,7 @@ import tempfile
import unittest import unittest
import numpy as np import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest
import paddle import paddle
from paddle import nn from paddle import nn
...@@ -44,6 +45,7 @@ class Net(nn.Layer): ...@@ -44,6 +45,7 @@ class Net(nn.Layer):
return x return x
@dy2static_unittest
class TestLstm(unittest.TestCase): class TestLstm(unittest.TestCase):
def setUp(self): def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory() self.temp_dir = tempfile.TemporaryDirectory()
...@@ -69,6 +71,7 @@ class TestLstm(unittest.TestCase): ...@@ -69,6 +71,7 @@ class TestLstm(unittest.TestCase):
static_out = self.run_lstm(to_static=True) static_out = self.run_lstm(to_static=True)
np.testing.assert_allclose(dygraph_out, static_out, rtol=1e-05) np.testing.assert_allclose(dygraph_out, static_out, rtol=1e-05)
@ast_only_test
def test_save_in_eval(self, with_training=True): def test_save_in_eval(self, with_training=True):
paddle.jit.enable_to_static(True) paddle.jit.enable_to_static(True)
net = Net(12, 2) net = Net(12, 2)
...@@ -133,6 +136,7 @@ class LinearNet(nn.Layer): ...@@ -133,6 +136,7 @@ class LinearNet(nn.Layer):
return y return y
@dy2static_unittest
class TestSaveInEvalMode(unittest.TestCase): class TestSaveInEvalMode(unittest.TestCase):
def setUp(self): def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory() self.temp_dir = tempfile.TemporaryDirectory()
...@@ -178,6 +182,7 @@ class TestSaveInEvalMode(unittest.TestCase): ...@@ -178,6 +182,7 @@ class TestSaveInEvalMode(unittest.TestCase):
) )
@dy2static_unittest
class TestEvalAfterSave(unittest.TestCase): class TestEvalAfterSave(unittest.TestCase):
def setUp(self): def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory() self.temp_dir = tempfile.TemporaryDirectory()
......
...@@ -18,6 +18,7 @@ import unittest ...@@ -18,6 +18,7 @@ import unittest
from time import time from time import time
import numpy as np import numpy as np
from dygraph_to_static_util import ast_only_test
from predictor_utils import PredictorTools from predictor_utils import PredictorTools
import paddle import paddle
...@@ -158,6 +159,7 @@ class TestMNISTWithToStatic(TestMNIST): ...@@ -158,6 +159,7 @@ class TestMNISTWithToStatic(TestMNIST):
def train_dygraph(self): def train_dygraph(self):
return self.train(to_static=False) return self.train(to_static=False)
@ast_only_test
def test_mnist_to_static(self): def test_mnist_to_static(self):
dygraph_loss = self.train_dygraph() dygraph_loss = self.train_dygraph()
static_loss = self.train_static() static_loss = self.train_static()
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
import unittest import unittest
from dygraph_to_static_util import ast_only_test
import paddle import paddle
from paddle.static import InputSpec from paddle.static import InputSpec
...@@ -75,6 +77,7 @@ class CheckOpAttr(unittest.TestCase): ...@@ -75,6 +77,7 @@ class CheckOpAttr(unittest.TestCase):
'elementwise_sub': self.sub_attrs, 'elementwise_sub': self.sub_attrs,
} }
@ast_only_test
def test_set_op_attrs(self): def test_set_op_attrs(self):
net = NetWithOpAttr(self.in_num, self.out_num) net = NetWithOpAttr(self.in_num, self.out_num)
# set attrs # set attrs
...@@ -116,6 +119,7 @@ class CheckOpAttr(unittest.TestCase): ...@@ -116,6 +119,7 @@ class CheckOpAttr(unittest.TestCase):
else: else:
self.assertEqual(op_val, expect_val) self.assertEqual(op_val, expect_val)
@ast_only_test
def test_set_op_attrs_with_sub_block(self): def test_set_op_attrs_with_sub_block(self):
net = NetWithOpAttr(self.in_num, self.out_num) net = NetWithOpAttr(self.in_num, self.out_num)
# set attrs # set attrs
......
...@@ -79,11 +79,6 @@ class TestParameterList(unittest.TestCase): ...@@ -79,11 +79,6 @@ class TestParameterList(unittest.TestCase):
dygraph_loss = self.train(False, to_static=False) dygraph_loss = self.train(False, to_static=False)
np.testing.assert_allclose(dygraph_loss, static_loss, rtol=1e-05) np.testing.assert_allclose(dygraph_loss, static_loss, rtol=1e-05)
def test_parameter_list_iter(self):
static_loss = self.train(True, to_static=True)
dygraph_loss = self.train(True, to_static=False)
np.testing.assert_allclose(dygraph_loss, static_loss, rtol=1e-05)
class NetWithRawParamList(paddle.nn.Layer): class NetWithRawParamList(paddle.nn.Layer):
def __init__(self, in_size, out_size): def __init__(self, in_size, out_size):
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import unittest import unittest
import numpy as np import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest
from test_fetch_feed import Linear from test_fetch_feed import Linear
import paddle import paddle
...@@ -52,6 +53,7 @@ def fake_data(shape): ...@@ -52,6 +53,7 @@ def fake_data(shape):
return fluid.dygraph.to_variable(x_data) return fluid.dygraph.to_variable(x_data)
@dy2static_unittest
class TestWithNestedInput(unittest.TestCase): class TestWithNestedInput(unittest.TestCase):
def setUp(self): def setUp(self):
self.x = None self.x = None
...@@ -88,6 +90,7 @@ class TestWithNestedInput(unittest.TestCase): ...@@ -88,6 +90,7 @@ class TestWithNestedInput(unittest.TestCase):
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05) np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05)
@dy2static_unittest
class TestWithNestedOutput(unittest.TestCase): class TestWithNestedOutput(unittest.TestCase):
def setUp(self): def setUp(self):
self.x = None self.x = None
...@@ -124,10 +127,13 @@ class TestWithNestedOutput(unittest.TestCase): ...@@ -124,10 +127,13 @@ class TestWithNestedOutput(unittest.TestCase):
self.assertTrue(dy_var, st_var) self.assertTrue(dy_var, st_var)
@dy2static_unittest
class TestWithTrainAndEval(unittest.TestCase): class TestWithTrainAndEval(unittest.TestCase):
@ast_only_test
def test_switch_eval_and_train(self): def test_switch_eval_and_train(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
linear_net = Linear() linear_net = Linear()
linear_net = paddle.jit.to_static(linear_net)
x_data = np.random.random((4, 10)).astype('float32') x_data = np.random.random((4, 10)).astype('float32')
x = fluid.dygraph.to_variable(x_data) x = fluid.dygraph.to_variable(x_data)
linear_net(x) linear_net(x)
...@@ -154,16 +160,20 @@ class TestWithTrainAndEval(unittest.TestCase): ...@@ -154,16 +160,20 @@ class TestWithTrainAndEval(unittest.TestCase):
) )
@dy2static_unittest
class TestWithNoGrad(unittest.TestCase): class TestWithNoGrad(unittest.TestCase):
@ast_only_test
def test_with_no_grad(self): def test_with_no_grad(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
linear_net = Linear() linear_net = Linear()
linear_net = paddle.jit.to_static(linear_net)
x_data = np.random.random((5, 10)).astype('float32') x_data = np.random.random((5, 10)).astype('float32')
x = fluid.dygraph.to_variable(x_data) x = fluid.dygraph.to_variable(x_data)
with paddle.no_grad(): with paddle.no_grad():
linear_net.train() linear_net.train()
linear_net(x) linear_net(x)
# BUG: we expect an ASTStaticFunction(StaticFunction) here:
_, partial_layer = linear_net.forward.program_cache.last()[-1] _, partial_layer = linear_net.forward.program_cache.last()[-1]
self.assertEqual( self.assertEqual(
partial_layer.program, partial_layer._train_program partial_layer.program, partial_layer._train_program
...@@ -186,6 +196,7 @@ class GPT2LMHeadModel(paddle.nn.Layer): ...@@ -186,6 +196,7 @@ class GPT2LMHeadModel(paddle.nn.Layer):
return x1 return x1
@dy2static_unittest
class TestPruneUnusedParamInProgram(unittest.TestCase): class TestPruneUnusedParamInProgram(unittest.TestCase):
def test_prune(self): def test_prune(self):
input_ids = np.array([[15, 11, 6, 3, 18, 13]]).astype("float32") input_ids = np.array([[15, 11, 6, 3, 18, 13]]).astype("float32")
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import unittest import unittest
import paddle import paddle
...@@ -21,6 +22,7 @@ from paddle.jit.dy2static import partial_program, program_translator ...@@ -21,6 +22,7 @@ from paddle.jit.dy2static import partial_program, program_translator
class TestPartiaProgramLayerHook(unittest.TestCase): class TestPartiaProgramLayerHook(unittest.TestCase):
def setUp(self): def setUp(self):
os.environ["ENABLE_FALL_BACK"] = "False"
self._hook = partial_program.PartialProgramLayerHook() self._hook = partial_program.PartialProgramLayerHook()
def test_before_append_backward(self): def test_before_append_backward(self):
...@@ -35,6 +37,7 @@ class TestPartiaProgramLayerHook(unittest.TestCase): ...@@ -35,6 +37,7 @@ class TestPartiaProgramLayerHook(unittest.TestCase):
class TestPrimHook(unittest.TestCase): class TestPrimHook(unittest.TestCase):
def setUp(self): def setUp(self):
os.environ["ENABLE_FALL_BACK"] = "False"
core._set_prim_all_enabled(False) core._set_prim_all_enabled(False)
def f(): def f():
......
...@@ -18,6 +18,7 @@ import unittest ...@@ -18,6 +18,7 @@ import unittest
import astor import astor
import numpy as np import numpy as np
from dygraph_to_static_util import ast_only_test
from ifelse_simple_func import ( from ifelse_simple_func import (
dyfunc_with_if_else_early_return1, dyfunc_with_if_else_early_return1,
dyfunc_with_if_else_early_return2, dyfunc_with_if_else_early_return2,
...@@ -216,6 +217,7 @@ class TestEnableDeclarative(unittest.TestCase): ...@@ -216,6 +217,7 @@ class TestEnableDeclarative(unittest.TestCase):
self.x = np.random.randn(30, 10, 32).astype('float32') self.x = np.random.randn(30, 10, 32).astype('float32')
self.weight = np.random.randn(32, 64).astype('float32') self.weight = np.random.randn(32, 64).astype('float32')
@ast_only_test
def test_raise_error(self): def test_raise_error(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
paddle.jit.enable_to_static(True) paddle.jit.enable_to_static(True)
...@@ -266,6 +268,7 @@ def switch_mode_function(): ...@@ -266,6 +268,7 @@ def switch_mode_function():
class TestFunctionTrainEvalMode(unittest.TestCase): class TestFunctionTrainEvalMode(unittest.TestCase):
@ast_only_test
def test_switch_mode(self): def test_switch_mode(self):
paddle.disable_static() paddle.disable_static()
switch_mode_function.eval() switch_mode_function.eval()
......
...@@ -133,11 +133,7 @@ class BottleneckBlock(paddle.nn.Layer): ...@@ -133,11 +133,7 @@ class BottleneckBlock(paddle.nn.Layer):
short = self.short(inputs) short = self.short(inputs)
y = paddle.add(x=short, y=conv2) y = paddle.add(x=short, y=conv2)
layer_helper = fluid.layer_helper.LayerHelper(
self.full_name(), act='relu'
)
return layer_helper.append_activation(y)
return paddle.nn.functional.relu(y)
class ResNet(paddle.nn.Layer): class ResNet(paddle.nn.Layer):
......
...@@ -131,10 +131,12 @@ class BottleneckBlock(paddle.nn.Layer): ...@@ -131,10 +131,12 @@ class BottleneckBlock(paddle.nn.Layer):
y = paddle.add(x=short, y=conv2) y = paddle.add(x=short, y=conv2)
layer_helper = paddle.fluid.layer_helper.LayerHelper(
self.full_name(), act='relu'
)
return layer_helper.append_activation(y)
# TODO: uncomment these lines to reproduce the oneDNN segment fault error.
# layer_helper = paddle.fluid.layer_helper.LayerHelper(
#     self.full_name(), act='relu'
# )
# return layer_helper.append_activation(y)
return paddle.nn.functional.relu(y)
class ResNet(paddle.nn.Layer): class ResNet(paddle.nn.Layer):
......
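Both ResNet variants replace the deprecated fluid LayerHelper(act='relu') activation with the public functional API; the segfaulting oneDNN path is kept only as a commented-out repro. The replacement in isolation:

```python
import paddle

short = paddle.randn([2, 4])
conv2 = paddle.randn([2, 4])
# Old: LayerHelper(self.full_name(), act='relu').append_activation(y)
# New: the public functional op, directly traceable by AST and SOT alike.
y = paddle.nn.functional.relu(paddle.add(x=short, y=conv2))
```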
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import unittest import unittest
import numpy as np import numpy as np
from dygraph_to_static_util import ast_only_test
from ifelse_simple_func import dyfunc_with_if_else from ifelse_simple_func import dyfunc_with_if_else
import paddle import paddle
...@@ -349,12 +350,20 @@ class TestReturnInWhile2(TestReturnBase): ...@@ -349,12 +350,20 @@ class TestReturnInWhile2(TestReturnBase):
self.dygraph_func = test_return_in_while_2 self.dygraph_func = test_return_in_while_2
self.error = "Found return statement in While or For body and loop" self.error = "Found return statement in While or For body and loop"
@ast_only_test
def test_transformed_static_result(self):
super().test_transformed_static_result()
class TestReturnInFor2(TestReturnBase): class TestReturnInFor2(TestReturnBase):
def init_dygraph_func(self): def init_dygraph_func(self):
self.dygraph_func = test_return_in_for_2 self.dygraph_func = test_return_in_for_2
self.error = "Found return statement in While or For body and loop" self.error = "Found return statement in While or For body and loop"
@ast_only_test
def test_transformed_static_result(self):
super().test_transformed_static_result()
class TestRecursiveReturn(TestReturnBase): class TestRecursiveReturn(TestReturnBase):
def init_dygraph_func(self): def init_dygraph_func(self):
...@@ -367,12 +376,20 @@ class TestReturnDifferentLengthIfBody(TestReturnBase): ...@@ -367,12 +376,20 @@ class TestReturnDifferentLengthIfBody(TestReturnBase):
self.dygraph_func = test_return_different_length_if_body self.dygraph_func = test_return_different_length_if_body
self.error = "Your if/else have different number of return value." self.error = "Your if/else have different number of return value."
@ast_only_test
def test_transformed_static_result(self):
super().test_transformed_static_result()
class TestReturnDifferentLengthElse(TestReturnBase): class TestReturnDifferentLengthElse(TestReturnBase):
def init_dygraph_func(self): def init_dygraph_func(self):
self.dygraph_func = test_return_different_length_else self.dygraph_func = test_return_different_length_else
self.error = "Your if/else have different number of return value." self.error = "Your if/else have different number of return value."
@ast_only_test
def test_transformed_static_result(self):
super().test_transformed_static_result()
class TestNoReturn(TestReturnBase): class TestNoReturn(TestReturnBase):
def init_dygraph_func(self): def init_dygraph_func(self):
...@@ -384,12 +401,20 @@ class TestReturnNone(TestReturnBase): ...@@ -384,12 +401,20 @@ class TestReturnNone(TestReturnBase):
self.dygraph_func = test_return_none self.dygraph_func = test_return_none
self.error = "Your if/else have different number of return value." self.error = "Your if/else have different number of return value."
@ast_only_test
def test_transformed_static_result(self):
super().test_transformed_static_result()
class TestReturnNoVariable(TestReturnBase): class TestReturnNoVariable(TestReturnBase):
def init_dygraph_func(self): def init_dygraph_func(self):
self.dygraph_func = test_return_no_variable self.dygraph_func = test_return_no_variable
self.error = "Your if/else have different number of return value." self.error = "Your if/else have different number of return value."
@ast_only_test
def test_transformed_static_result(self):
super().test_transformed_static_result()
class TestReturnListOneValue(TestReturnBase): class TestReturnListOneValue(TestReturnBase):
def init_dygraph_func(self): def init_dygraph_func(self):
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import unittest import unittest
import numpy as np import numpy as np
from dygraph_to_static_util import ast_only_test
import paddle import paddle
from paddle.jit.dy2static.program_translator import StaticFunction from paddle.jit.dy2static.program_translator import StaticFunction
...@@ -88,6 +89,7 @@ class TestRollBackNet(unittest.TestCase): ...@@ -88,6 +89,7 @@ class TestRollBackNet(unittest.TestCase):
def setUp(self): def setUp(self):
paddle.set_device("cpu") paddle.set_device("cpu")
@ast_only_test
def test_net(self): def test_net(self):
net = paddle.jit.to_static(Net()) net = paddle.jit.to_static(Net())
x = paddle.randn([3, 4]) x = paddle.randn([3, 4])
......
...@@ -17,6 +17,7 @@ import tempfile ...@@ -17,6 +17,7 @@ import tempfile
import unittest import unittest
import numpy as np import numpy as np
from dygraph_to_static_util import ast_only_test
import paddle import paddle
from paddle import fluid from paddle import fluid
...@@ -53,6 +54,7 @@ class TestDyToStaticSaveInferenceModel(unittest.TestCase): ...@@ -53,6 +54,7 @@ class TestDyToStaticSaveInferenceModel(unittest.TestCase):
def tearDown(self): def tearDown(self):
self.temp_dir.cleanup() self.temp_dir.cleanup()
@ast_only_test
def test_save_inference_model(self): def test_save_inference_model(self):
fc_size = 20 fc_size = 20
x_data = np.random.random((fc_size, fc_size)).astype('float32') x_data = np.random.random((fc_size, fc_size)).astype('float32')
...@@ -144,6 +146,7 @@ class TestDyToStaticSaveInferenceModel(unittest.TestCase): ...@@ -144,6 +146,7 @@ class TestDyToStaticSaveInferenceModel(unittest.TestCase):
class TestPartialProgramRaiseError(unittest.TestCase): class TestPartialProgramRaiseError(unittest.TestCase):
@ast_only_test
def test_param_type(self): def test_param_type(self):
paddle.jit.enable_to_static(True) paddle.jit.enable_to_static(True)
x_data = np.random.random((20, 20)).astype('float32') x_data = np.random.random((20, 20)).astype('float32')
......
...@@ -17,6 +17,7 @@ import tempfile ...@@ -17,6 +17,7 @@ import tempfile
import unittest import unittest
import numpy as np import numpy as np
from dygraph_to_static_util import ast_only_test
from test_fetch_feed import Linear from test_fetch_feed import Linear
import paddle import paddle
...@@ -114,6 +115,7 @@ class TestDyToStaticSaveLoad(unittest.TestCase): ...@@ -114,6 +115,7 @@ class TestDyToStaticSaveLoad(unittest.TestCase):
dygraph_loss.numpy(), static_loss.numpy(), rtol=1e-05 dygraph_loss.numpy(), static_loss.numpy(), rtol=1e-05
) )
@ast_only_test
def test_save_load_prim(self): def test_save_load_prim(self):
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
self.x = paddle.randn([4, 2, 6, 6], dtype="float32") self.x = paddle.randn([4, 2, 6, 6], dtype="float32")
...@@ -154,6 +156,7 @@ class TestDyToStaticSaveLoad(unittest.TestCase): ...@@ -154,6 +156,7 @@ class TestDyToStaticSaveLoad(unittest.TestCase):
self.assertIn("pool2d", load_op_type_list) self.assertIn("pool2d", load_op_type_list)
np.testing.assert_allclose(res.numpy(), new_res.numpy(), rtol=1e-05) np.testing.assert_allclose(res.numpy(), new_res.numpy(), rtol=1e-05)
@ast_only_test
def test_save_load_prim_with_hook(self): def test_save_load_prim_with_hook(self):
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
self.x = paddle.randn([4, 2, 6, 6], dtype="float32") self.x = paddle.randn([4, 2, 6, 6], dtype="float32")
......
...@@ -20,6 +20,7 @@ import time ...@@ -20,6 +20,7 @@ import time
import unittest import unittest
import numpy as np import numpy as np
from dygraph_to_static_util import ast_only_test
from predictor_utils import PredictorTools from predictor_utils import PredictorTools
import paddle import paddle
...@@ -560,6 +561,7 @@ class TestSeResnet(unittest.TestCase): ...@@ -560,6 +561,7 @@ class TestSeResnet(unittest.TestCase):
), ),
) )
@ast_only_test
def test_check_result(self): def test_check_result(self):
pred_1, loss_1, acc1_1, acc5_1 = self.train( pred_1, loss_1, acc1_1, acc5_1 = self.train(
self.train_reader, to_static=False self.train_reader, to_static=False
......
...@@ -17,6 +17,7 @@ import tempfile ...@@ -17,6 +17,7 @@ import tempfile
import unittest import unittest
import numpy as np import numpy as np
from dygraph_to_static_util import ast_only_test
import paddle import paddle
from paddle.static import InputSpec from paddle.static import InputSpec
...@@ -51,7 +52,6 @@ def test_slice_in_if(x): ...@@ -51,7 +52,6 @@ def test_slice_in_if(x):
return out return out
@paddle.jit.to_static
def test_slice_in_while_loop(x, iter_num=3): def test_slice_in_while_loop(x, iter_num=3):
x = paddle.to_tensor(x) x = paddle.to_tensor(x)
iter_num_var = paddle.full(shape=[1], fill_value=iter_num, dtype="int32") iter_num_var = paddle.full(shape=[1], fill_value=iter_num, dtype="int32")
...@@ -153,7 +153,7 @@ class TestSliceInIf(TestSliceWithoutControlFlow): ...@@ -153,7 +153,7 @@ class TestSliceInIf(TestSliceWithoutControlFlow):
class TestSliceInWhileLoop(TestSliceWithoutControlFlow): class TestSliceInWhileLoop(TestSliceWithoutControlFlow):
def init_dygraph_func(self): def init_dygraph_func(self):
self.dygraph_func = test_slice_in_while_loop
self.dygraph_func = paddle.jit.to_static(test_slice_in_while_loop)
class TestSliceInForLoop(TestSliceWithoutControlFlow): class TestSliceInForLoop(TestSliceWithoutControlFlow):
...@@ -179,6 +179,7 @@ class TestSetValueWithLayerAndSave(unittest.TestCase): ...@@ -179,6 +179,7 @@ class TestSetValueWithLayerAndSave(unittest.TestCase):
def tearDown(self): def tearDown(self):
self.temp_dir.cleanup() self.temp_dir.cleanup()
@ast_only_test
def test_set_value_with_save(self): def test_set_value_with_save(self):
paddle.jit.enable_to_static(True) paddle.jit.enable_to_static(True)
model = LayerWithSetValue(input_dim=10, hidden=1) model = LayerWithSetValue(input_dim=10, hidden=1)
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
import unittest import unittest
from dygraph_to_static_util import enable_fallback_guard
import paddle import paddle
from paddle.nn import Layer from paddle.nn import Layer
...@@ -101,4 +103,5 @@ class TestArgsSpecName(unittest.TestCase): ...@@ -101,4 +103,5 @@ class TestArgsSpecName(unittest.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
unittest.main()
with enable_fallback_guard("False"):
    unittest.main()
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import unittest import unittest
import numpy as np import numpy as np
from dygraph_to_static_util import ast_only_test
import paddle import paddle
...@@ -33,6 +34,7 @@ class TestTensorClone(unittest.TestCase): ...@@ -33,6 +34,7 @@ class TestTensorClone(unittest.TestCase):
return tensor_clone(x).numpy() return tensor_clone(x).numpy()
def test_tensor_clone(self): def test_tensor_clone(self):
paddle.disable_static()
dygraph_res = self._run(to_static=False) dygraph_res = self._run(to_static=False)
static_res = self._run(to_static=True) static_res = self._run(to_static=True)
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05) np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05)
...@@ -52,7 +54,9 @@ class TestTensorDygraphOnlyMethodError(unittest.TestCase): ...@@ -52,7 +54,9 @@ class TestTensorDygraphOnlyMethodError(unittest.TestCase):
y = tensor_numpy(x) y = tensor_numpy(x)
return y.numpy() return y.numpy()
@ast_only_test
def test_to_static_numpy_report_error(self): def test_to_static_numpy_report_error(self):
paddle.disable_static()
dygraph_res = self._run(to_static=False) dygraph_res = self._run(to_static=False)
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
static_res = self._run(to_static=True) static_res = self._run(to_static=True)
...@@ -74,6 +78,7 @@ class TestTensorItem(unittest.TestCase): ...@@ -74,6 +78,7 @@ class TestTensorItem(unittest.TestCase):
return tensor_item(x) return tensor_item(x)
def test_tensor_clone(self): def test_tensor_clone(self):
paddle.disable_static()
dygraph_res = self._run(to_static=False) dygraph_res = self._run(to_static=False)
static_res = self._run(to_static=True) static_res = self._run(to_static=True)
np.testing.assert_allclose(dygraph_res, static_res) np.testing.assert_allclose(dygraph_res, static_res)
...@@ -93,9 +98,13 @@ class TestTensorSize(unittest.TestCase): ...@@ -93,9 +98,13 @@ class TestTensorSize(unittest.TestCase):
x = paddle.ones([1, 2, 3]) x = paddle.ones([1, 2, 3])
if not to_static: if not to_static:
return tensor_size(x) return tensor_size(x)
return tensor_size(x).numpy()
ret = tensor_size(x)
if hasattr(ret, 'numpy'):
    ret = ret.numpy()
return ret
def test_tensor_clone(self): def test_tensor_clone(self):
paddle.disable_static()
dygraph_res = self._run(to_static=False) dygraph_res = self._run(to_static=False)
static_res = self._run(to_static=True) static_res = self._run(to_static=True)
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-5) np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-5)
...@@ -115,6 +124,7 @@ class TestTrueDiv(unittest.TestCase): ...@@ -115,6 +124,7 @@ class TestTrueDiv(unittest.TestCase):
return true_div(x, y).numpy() return true_div(x, y).numpy()
def test_ture_div(self): def test_ture_div(self):
paddle.disable_static()
dygraph_res = self._run(to_static=False) dygraph_res = self._run(to_static=False)
static_res = self._run(to_static=True) static_res = self._run(to_static=True)
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-5) np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-5)
......
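Under SOT a converted call may hand back a plain Python value instead of a Tensor, which is why TestTensorSize now duck-types the result before comparing. The check in isolation, with a hypothetical stand-in for the test helper:

```python
import paddle

def tensor_size(x):
    return paddle.numel(x)  # hypothetical stand-in for the real test body

ret = tensor_size(paddle.ones([1, 2, 3]))
# A Tensor exposes .numpy(); a plain Python value (possible under SOT
# fallback) does not, so normalize before comparing.
if hasattr(ret, 'numpy'):
    ret = ret.numpy()
print(ret)
```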
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import unittest import unittest
import numpy as np import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest
import paddle import paddle
from paddle import fluid from paddle import fluid
...@@ -230,6 +231,7 @@ def dyfunc_dict_assign_shape(): ...@@ -230,6 +231,7 @@ def dyfunc_dict_assign_shape():
# 1. Basic tests without control flow # 1. Basic tests without control flow
@dy2static_unittest
class TestTensorShapeBasic(unittest.TestCase): class TestTensorShapeBasic(unittest.TestCase):
def setUp(self): def setUp(self):
self.input = np.ones(5).astype("int32") self.input = np.ones(5).astype("int32")
...@@ -287,6 +289,7 @@ class TestTensorShapeBasic(unittest.TestCase): ...@@ -287,6 +289,7 @@ class TestTensorShapeBasic(unittest.TestCase):
[op for op in block.ops if op.type == "slice"] [op for op in block.ops if op.type == "slice"]
) )
@ast_only_test
def test_op_num(self): def test_op_num(self):
static_layer = paddle.jit.to_static(self.dygraph_func, self.input_spec) static_layer = paddle.jit.to_static(self.dygraph_func, self.input_spec)
program = static_layer.main_program program = static_layer.main_program
...@@ -519,6 +522,7 @@ class TestOpNumBasicWithTensorShape(unittest.TestCase): ...@@ -519,6 +522,7 @@ class TestOpNumBasicWithTensorShape(unittest.TestCase):
[op for op in block.ops if op.type == "slice"] [op for op in block.ops if op.type == "slice"]
) )
@ast_only_test
def test_op_num(self): def test_op_num(self):
static_layer = paddle.jit.to_static(self.dygraph_func, self.input_spec) static_layer = paddle.jit.to_static(self.dygraph_func, self.input_spec)
program = static_layer.main_program program = static_layer.main_program
...@@ -609,6 +613,7 @@ def dyfunc_with_static_convert_var_shape(x): ...@@ -609,6 +613,7 @@ def dyfunc_with_static_convert_var_shape(x):
class TestFindStatiConvertVarShapeSuffixVar(unittest.TestCase): class TestFindStatiConvertVarShapeSuffixVar(unittest.TestCase):
@ast_only_test
def test(self): def test(self):
x_spec = paddle.static.InputSpec(shape=[None, 10]) x_spec = paddle.static.InputSpec(shape=[None, 10])
func = paddle.jit.to_static(dyfunc_with_if_2, input_spec=[x_spec]) func = paddle.jit.to_static(dyfunc_with_if_2, input_spec=[x_spec])
......
...@@ -15,6 +15,11 @@ ...@@ -15,6 +15,11 @@
import unittest import unittest
import numpy import numpy
from dygraph_to_static_util import (
ast_only_test,
dy2static_unittest,
sot_only_test,
)
import paddle import paddle
from paddle.fluid import core from paddle.fluid import core
...@@ -95,6 +100,7 @@ def case8(x): ...@@ -95,6 +100,7 @@ def case8(x):
return a return a
@dy2static_unittest
class TestToTensorReturnVal(unittest.TestCase): class TestToTensorReturnVal(unittest.TestCase):
def test_to_tensor_badreturn(self): def test_to_tensor_badreturn(self):
paddle.disable_static() paddle.disable_static()
...@@ -148,6 +154,7 @@ class TestToTensorReturnVal(unittest.TestCase): ...@@ -148,6 +154,7 @@ class TestToTensorReturnVal(unittest.TestCase):
self.assertTrue(a.stop_gradient == b.stop_gradient) self.assertTrue(a.stop_gradient == b.stop_gradient)
self.assertTrue(a.place._equals(b.place)) self.assertTrue(a.place._equals(b.place))
@ast_only_test
def test_to_tensor_err_log(self): def test_to_tensor_err_log(self):
paddle.disable_static() paddle.disable_static()
x = paddle.to_tensor([3]) x = paddle.to_tensor([3])
...@@ -159,6 +166,18 @@ class TestToTensorReturnVal(unittest.TestCase): ...@@ -159,6 +166,18 @@ class TestToTensorReturnVal(unittest.TestCase):
in str(e) in str(e)
) )
@sot_only_test
def test_to_tensor_err_log_sot(self):
paddle.disable_static()
x = paddle.to_tensor([3])
try:
a = paddle.jit.to_static(case8)(x)
except Exception as e:
self.assertTrue(
"Can't constructs a 'paddle.Tensor' with data type <class 'dict'>"
in str(e)
)
class TestStatic(unittest.TestCase): class TestStatic(unittest.TestCase):
def test_static(self): def test_static(self):
......
...@@ -17,6 +17,7 @@ import unittest ...@@ -17,6 +17,7 @@ import unittest
from functools import partial from functools import partial
import numpy as np import numpy as np
from dygraph_to_static_util import enable_fallback_guard
import paddle import paddle
...@@ -433,4 +434,5 @@ class TestTrainStepTinyModelLRCyclicLR(TestTrainStepTinyModel): ...@@ -433,4 +434,5 @@ class TestTrainStepTinyModelLRCyclicLR(TestTrainStepTinyModel):
if __name__ == "__main__": if __name__ == "__main__":
unittest.main()
with enable_fallback_guard("False"):
    unittest.main()
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import platform import platform
import unittest import unittest
from dygraph_to_static_util import enable_fallback_guard
from test_train_step import ( from test_train_step import (
TestTrainStepTinyModel, TestTrainStepTinyModel,
loss_fn_tiny_model, loss_fn_tiny_model,
...@@ -40,4 +41,5 @@ class TestTrainStepResNet18Adam(TestTrainStepTinyModel): ...@@ -40,4 +41,5 @@ class TestTrainStepResNet18Adam(TestTrainStepTinyModel):
if __name__ == "__main__": if __name__ == "__main__":
unittest.main()
with enable_fallback_guard("False"):
    unittest.main()
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import platform import platform
import unittest import unittest
from dygraph_to_static_util import enable_fallback_guard
from test_train_step import ( from test_train_step import (
TestTrainStepTinyModel, TestTrainStepTinyModel,
loss_fn_tiny_model, loss_fn_tiny_model,
...@@ -40,4 +41,5 @@ class TestTrainStepResNet18Sgd(TestTrainStepTinyModel): ...@@ -40,4 +41,5 @@ class TestTrainStepResNet18Sgd(TestTrainStepTinyModel):
if __name__ == "__main__": if __name__ == "__main__":
unittest.main()
with enable_fallback_guard("False"):
    unittest.main()
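The three train-step suites now run under enable_fallback_guard("False") so every case executes in AST mode. The guard comes from dygraph_to_static_util; a plausible sketch of it (implementation details are assumptions):

```python
import os
from contextlib import contextmanager

@contextmanager
def enable_fallback_guard(value):
    # Pin ENABLE_FALL_BACK for the guarded block, then restore the old value.
    old = os.environ.get("ENABLE_FALL_BACK")
    os.environ["ENABLE_FALL_BACK"] = value
    try:
        yield
    finally:
        if old is None:
            os.environ.pop("ENABLE_FALL_BACK", None)
        else:
            os.environ["ENABLE_FALL_BACK"] = old
```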
...@@ -45,7 +45,9 @@ def parse_args(): ...@@ -45,7 +45,9 @@ def parse_args():
default=fluid.is_compiled_with_cuda(), default=fluid.is_compiled_with_cuda(),
help='default use gpu.', help='default use gpu.',
) )
args = parser.parse_args(['--config', 'tsm.yaml'])
args = parser.parse_args(
    ['--config', __file__.rpartition('/')[0] + '/tsm.yaml']
)
return args return args
......
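The tsm hunk resolves tsm.yaml relative to the test file instead of the working directory. An equivalent, platform-neutral spelling, under the assumption that tsm.yaml sits beside the script:

```python
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str)
# Resolve the YAML next to this file rather than the current working directory.
config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tsm.yaml')
args = parser.parse_args(['--config', config_path])
```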
...@@ -17,6 +17,7 @@ import unittest ...@@ -17,6 +17,7 @@ import unittest
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import numpy as np import numpy as np
from dygraph_to_static_util import dy2static_unittest
import paddle import paddle
...@@ -68,6 +69,7 @@ class LinearNetWithDict(BaseLayer): ...@@ -68,6 +69,7 @@ class LinearNetWithDict(BaseLayer):
return {'out': out2} return {'out': out2}
@dy2static_unittest
class TestTyping(unittest.TestCase): class TestTyping(unittest.TestCase):
def setUp(self): def setUp(self):
self.in_num = 16 self.in_num = 16
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
import unittest import unittest
import warnings import warnings
from dygraph_to_static_util import ast_only_test, dy2static_unittest
import paddle import paddle
from paddle.static.nn import cond from paddle.static.nn import cond
...@@ -37,12 +39,14 @@ def false_fn(): ...@@ -37,12 +39,14 @@ def false_fn():
return [paddle.to_tensor(3), [None, paddle.to_tensor(4)]] return [paddle.to_tensor(3), [None, paddle.to_tensor(4)]]
@dy2static_unittest
class TestReturnNoneInIfelse(unittest.TestCase): class TestReturnNoneInIfelse(unittest.TestCase):
@ast_only_test
def test_dy2static_warning(self): def test_dy2static_warning(self):
paddle.disable_static() paddle.disable_static()
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always") warnings.simplefilter("always")
fun1() paddle.jit.to_static(fun1)()
flag = False flag = False
for warn in w: for warn in w:
if ( if (
......
...@@ -14,6 +14,12 @@ ...@@ -14,6 +14,12 @@
import unittest import unittest
from dygraph_to_static_util import (
ast_only_test,
dy2static_unittest,
sot_only_test,
)
import paddle import paddle
...@@ -93,6 +99,7 @@ def func_ifelse_write_nest_list_dict(x): ...@@ -93,6 +99,7 @@ def func_ifelse_write_nest_list_dict(x):
return res return res
@dy2static_unittest
class TestWriteContainer(unittest.TestCase): class TestWriteContainer(unittest.TestCase):
def setUp(self): def setUp(self):
self.set_func() self.set_func()
...@@ -110,6 +117,15 @@ class TestWriteContainer(unittest.TestCase): ...@@ -110,6 +117,15 @@ class TestWriteContainer(unittest.TestCase):
out = out[path] out = out[path]
return out return out
@sot_only_test
def test_write_container_sot(self):
func_static = paddle.jit.to_static(self.func)
input = paddle.to_tensor([1, 2, 3])
out_static = self.get_raw_value(func_static(input), self.getitem_path)
out_dygraph = self.get_raw_value(self.func(input), self.getitem_path)
self.assertEqual(out_static, out_dygraph)
@ast_only_test
def test_write_container(self): def test_write_container(self):
func_static = paddle.jit.to_static(self.func) func_static = paddle.jit.to_static(self.func)
input = paddle.to_tensor([1, 2, 3]) input = paddle.to_tensor([1, 2, 3])
......
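Both the AST and SOT variants of TestWriteContainer compare values through get_raw_value, which walks a sequence of keys/indices into the nested result; condensed from the hunk above, with a tiny usage example:

```python
def get_raw_value(container, getitem_path):
    out = container
    for path in getitem_path:
        out = out[path]
    return out

print(get_raw_value({'a': [1, 2]}, ('a', 0)))  # -> 1
```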