未验证 提交 c1913a5f 编写于 作者: X xiongkun 提交者: GitHub

[dy2static] PaddleSOT pr (#54202)

* add paddle-symbolic-trace to paddle

* add symoblic trace

* delete swp

* support Layer in symbolic trace

* fix test-symbolic-trace, make symbolic trace return a StaticFunction

* template the error message

* fix some unittest

* Modify the execution mode of test

* Modify the module name

* add dy2static unittest decorator

* change some unittest files by @ast_only_test

* fix unittest.

* test-symbolic-trace

* update test_write_python_container.py

* update

* fix test_param_parse.py

* add submodule and ln -sf in cmakefile

* update

* update

* fix some ast only errors

* update

* Polish ut

* fix unittests

* update

* update

* fix unittests

* update

* test warning ast only

* update

* Ast only some uts

* Fix unitests

* test_error ast only

* update

* update

* Support build_strategy for sot

* update

* import sot as a third party module

* update

* update

* Polish code

* update

* update

* update

* update

* update

* remove old fluid api and use paddle.nn.relu instead

* fix

* comment the print of ast code

* add try-finally block

* fix dy2static stop-gradient bugs

* fix code

* remove unused submodule and minor codestyle fix

* fix

* fix cast error

* fix interpolate meets int64 in static model

* add evalframe support for py311

* fix

* fix err

* switch ENABLE_FALL_BACK=False

* fix

* Fix CI for some unittest

* add ENABLE_SOT

* remove setup.py dependences

---------
Co-authored-by: NNotHaozi <zhangmenghao@baidu.com>
Co-authored-by: Nfeifei-111 <2364819892@qq.com>
Co-authored-by: N0x45f <wangzhen45@baidu.com>
Co-authored-by: NSigureMo <sigure.qaq@gmail.com>
上级 fc177d25
......@@ -254,7 +254,10 @@ static PyObject *_custom_eval_frame(PyThreadState *tstate,
if (disable_eval_frame != Py_True) {
// Re-enable custom behavior
eval_frame_callback_set(callback);
VLOG(7) << "Start eval new frame and code.";
auto out = eval_custom_code(tstate, frame, code, throw_flag);
Py_DECREF(result);
Py_DECREF(code);
return out;
} else {
auto out = eval_custom_code(tstate, frame, code, throw_flag);
......
......@@ -106,15 +106,17 @@ def program_desc_tracing_guard(enable):
def param_guard(parameters):
# Note: parameters is a reference of self._parameters or self._buffers
if in_declarative_mode() and not paddle.in_dynamic_mode() and parameters:
origin_parameters = parameters.copy()
for name, var_base in parameters.items():
if isinstance(var_base, list):
new_var = [_convert_into_variable(var) for var in var_base]
else:
new_var = _convert_into_variable(var_base)
parameters[name] = new_var
yield
parameters.update(origin_parameters)
try:
origin_parameters = parameters.copy()
for name, var_base in parameters.items():
if isinstance(var_base, list):
new_var = [_convert_into_variable(var) for var in var_base]
else:
new_var = _convert_into_variable(var_base)
parameters[name] = new_var
yield
finally:
parameters.update(origin_parameters)
else:
yield
......
......@@ -25,6 +25,7 @@ from collections import OrderedDict
import inspect
import threading
from typing import Any
import types
import paddle
from paddle.fluid import core, dygraph
......@@ -46,6 +47,8 @@ from .dy2static.convert_call_func import (
from .dy2static.program_translator import (
ProgramTranslator,
StaticFunction,
ASTStaticFunction,
SymbolicStaticFunction,
unwrap_decorators,
)
from paddle.jit.translated_layer import (
......@@ -232,6 +235,7 @@ def to_static(
input_spec=None,
build_strategy=None,
backend=None,
enable_fallback=None,
**kwargs,
):
"""
......@@ -283,15 +287,29 @@ def to_static(
def decorated(python_func):
"""
Decorates a python function into a StaticFunction object.
Decorates a python function into a ASTStaticFunction object.
"""
nonlocal enable_fallback
if enable_fallback is None:
flag = os.environ.get("ENABLE_FALL_BACK", None)
if flag == "True":
enable_fallback = True
else: # None or True
enable_fallback = False
StaticClass = StaticFunctionClass = {
True: SymbolicStaticFunction,
False: ASTStaticFunction,
}[enable_fallback]
# Step 1. unwrap the function if it is already decorated.
_, python_func = unwrap_decorators(python_func)
# Step 2. copy some attributes from original python function.
static_layer = copy_decorator_attrs(
original_func=python_func,
decorated_obj=StaticFunction(
decorated_obj=StaticClass(
function=python_func,
input_spec=input_spec,
build_strategy=build_strategy,
......@@ -1033,7 +1051,9 @@ def save(layer, path, input_spec=None, **configs):
concrete_program = None
for attr_func in functions:
if isinstance(layer, Layer):
static_func = getattr(inner_layer, attr_func, None)
static_func = get_ast_static_function(
getattr(inner_layer, attr_func, None)
)
if isinstance(static_func, StaticFunction):
if static_func.is_property:
# property method to be exported
......@@ -1066,7 +1086,9 @@ def save(layer, path, input_spec=None, **configs):
input_spec, inner_input_spec
)
static_forward = to_static(
inner_layer.forward, input_spec=inner_input_spec
inner_layer.forward,
input_spec=inner_input_spec,
enable_fallback=False,
)
concrete_program = (
static_forward.concrete_program_specify_input_spec(
......@@ -1082,24 +1104,29 @@ def save(layer, path, input_spec=None, **configs):
else:
# When layer is a function
if isinstance(attr_func, StaticFunction):
if attr_func.is_property:
static_func = get_ast_static_function(attr_func)
if static_func.is_property:
# property method to be exported
immediate_val = attr_func()
property_vals.append((immediate_val, attr_func))
immediate_val = static_func()
property_vals.append((immediate_val, static_func))
continue
concrete_program = (
attr_func.concrete_program_specify_input_spec(
static_func.concrete_program_specify_input_spec(
inner_input_spec, is_prim_infer=is_prim_infer
)
)
else:
static_func = get_ast_static_function(attr_func)
if inner_input_spec:
inner_input_spec = paddle.utils.pack_sequence_as(
input_spec, inner_input_spec
)
static_function = to_static(
attr_func, input_spec=inner_input_spec
static_func,
input_spec=inner_input_spec,
enable_fallback=False,
)
concrete_program = static_function.concrete_program
......@@ -1115,9 +1142,9 @@ def save(layer, path, input_spec=None, **configs):
if isinstance(inner_layer, Layer):
dygraph_state_dict = inner_layer.to_static_state_dict()
elif isinstance(attr_func, StaticFunction):
if attr_func._class_instance:
if static_func._class_instance:
dygraph_state_dict = (
attr_func._class_instance.to_static_state_dict()
static_func._class_instance.to_static_state_dict()
)
if dygraph_state_dict:
......@@ -1887,3 +1914,29 @@ class TracedLayer:
clip_extra=clip_extra,
legacy_format=legacy_format,
)
def get_ast_static_function(function):
if isinstance(function, SymbolicStaticFunction):
if function._class_instance:
dygraph_function = types.MethodType(
function._dygraph_function, function._class_instance
)
else:
dygraph_function = function._dygraph_function
if function._function_spec._input_spec is None:
ast_static_function = ASTStaticFunction(
dygraph_function,
function.last_call_input_spec,
**function._kwargs,
)
return ast_static_function
else:
ast_static_function = ASTStaticFunction(
dygraph_function,
function._function_spec._input_spec,
**function._kwargs,
)
return ast_static_function
return function
......@@ -231,6 +231,7 @@ class PartialProgramLayer:
self._cuda_graph_vec,
*attrs
)
self._update_stop_gradient(out_vars)
restored_nest_out = self._restore_out(out_vars)
return self._remove_no_value(restored_nest_out)
......@@ -960,6 +961,17 @@ class PartialProgramLayer:
var.stop_gradient = True
return var
def _update_stop_gradient(self, out_vars):
# Update stop_gradient for all outputs
def set_stop_gradient(var_id, eager_tensor):
var = self._outputs[var_id]
assert isinstance(var, framework.Variable)
eager_tensor.stop_gradient = var.stop_gradient
return None
for idx, var in zip(self._outputs.var_ids, out_vars):
set_stop_gradient(idx, var)
def _restore_out(self, out_vars):
"""
Restores same nested outputs by only replacing the Variable with Tensor.
......
......@@ -309,11 +309,6 @@ def unwrap_decorators(func):
class StaticFunction:
"""
Wrapper class to Manage program conversion of decorated function.
"""
def __init__(self, function, input_spec=None, **kwargs):
"""
Initializes a `StaticFunction`.
......@@ -364,7 +359,6 @@ class StaticFunction:
self._training = True
self._cuda_graph_capture_mode = ""
self._cuda_graph_pool_id = 0
self._property = kwargs.get("property", False)
@property
......@@ -473,42 +467,7 @@ class StaticFunction:
)
)
# 2. trace ops from dygraph layers and cache the generated program.
args, kwargs = self._function_spec.unified_args_and_kwargs(args, kwargs)
try:
concrete_program, partial_program_layer = self.get_concrete_program(
*args, **kwargs, is_train=self._is_train_mode()
)
# 3. synchronize self.training attribute.
if isinstance(self._class_instance, layers.Layer):
partial_program_layer.training = self._class_instance.training
else:
partial_program_layer.training = self._training
partial_program_layer._cuda_graph_capture_mode = (
self._cuda_graph_capture_mode
)
partial_program_layer._cuda_graph_pool_id = self._cuda_graph_pool_id
# 4. return outputs.
try:
return partial_program_layer(args)
except Exception as e:
if not hasattr(e, error.ERROR_DATA):
# runtime error
error.attach_error_data(e, in_runtime=True)
raise
except Exception as e:
error_data = getattr(e, error.ERROR_DATA, None)
if error_data:
error_data.raise_new_exception()
else:
logging_utils.warn(
"Please file an issue at 'https://github.com/PaddlePaddle/Paddle/issues'"
" if you can't handle this {} yourself.".format(type(e))
)
raise e
return self._perform_call(*args, **kwargs)
def _is_train_mode(self):
if self._class_instance is not None:
......@@ -543,6 +502,298 @@ class StaticFunction:
if self.is_property:
raise RuntimeError("Can not call the func when property=True.")
def get_concrete_program(self, *args, **kwargs):
raise NotImplementedError("Not implemented yet.")
def get_concrete_program_with_cache_key(self, cached_key):
raise NotImplementedError("Not implemented yet.")
def get_traced_count(self):
raise NotImplementedError("Not implemented yet.")
@property
def code(self):
raise NotImplementedError("Not implemented yet.")
@property
def dygraph_function(self):
"""
Returns the original decorated function.
"""
if self._class_instance is not None:
return self._dygraph_function.__get__(self._class_instance)
else:
return self._dygraph_function
@property
def concrete_program(self):
raise NotImplementedError("Not implemented yet.")
def concrete_program_specify_input_spec(
self, input_spec=None, with_hook=False, is_prim_infer=False
):
raise NotImplementedError("Not implemented yet.")
def rollback(self):
"""
Rollback into original dygraph functions for current class instance.
Returns:
Function or Method
Example::
.. code-block:: python
>>> # doctest: +SKIP
>>> import paddle
>>> class Net(paddle.nn.Layer):
... def __init__(self):
... super().__init__()
...
... def forward(self, x, flag=True):
... if flag:
... out = x + 1
... else:
... out = x - 1
... return out
...
>>> x = paddle.randn([10, 1], 'float32')
>>> net = paddle.jit.to_static(Net()) # convert into static graph mode
>>> out = net(x)
>>> net.forward.rollback() # rollback into dygraph mode
>>> out = net(x)
"""
def rollback_impl(class_instance):
for name, func in class_instance._original_funcs.items():
setattr(class_instance, name, func.__get__(class_instance))
for sublayer in class_instance.sublayers(include_self=False):
rollback_impl(sublayer)
if self._class_instance is None:
return self._dygraph_function
# only rollback sub-functions on path of top _dygraph_function
func_name = self._dygraph_function.__name__
assert (
func_name in self._class_instance._original_funcs
), "Not Found function '{}' in class '{}'.".format(
func_name, self._class_instance.__name__
)
func = self._class_instance._original_funcs[func_name]
setattr(
self._class_instance, func_name, func.__get__(self._class_instance)
)
for sublayer in self._class_instance.sublayers(include_self=False):
rollback_impl(sublayer)
return getattr(self._class_instance, func_name)
def __deepcopy__(self, memo):
"""
Customized behavior for copy.deepcopy, return original decorated function instead
of a new StaticFunction Object. StaticFunction itself is not copyable becuase it's
associated with class_instance.
We add __deepcopy__ here only for the following usage:
Example::
.. code-block:: python
>>> import copy
>>> import paddle
>>> class Net(paddle.nn.Layer):
... def __init__(self):
... super().__init__()
...
... def forward(self, x, flag=True):
... if flag:
... out = x + 1
... else:
... out = x - 1
... return out
...
>>> x = paddle.randn([10, 1], 'float32')
>>> net = paddle.jit.to_static(Net()) # convert into static graph mode
>>> copy_net = copy.deepcopy(net) # deepcopy a new net without @to_static
Please attention that original 'net' will unwrap @to_static and rollback into simple Layer.
"""
if self._class_instance is not None:
net_name = type(self._class_instance).__name__
logging_utils.log(
level=-1,
msg="Not recommend to deepcopy '{}' decorated with @to_static, it has side effect that will"
" rollback into original state before @to_static. Please deepcopy '{}' before applying @to_static.".format(
net_name, net_name
),
)
self.rollback()
return self._dygraph_function.__get__(
memo[id(self._class_instance)]
)
else:
return self._dygraph_function
@property
def inputs(self):
raise NotImplementedError("Not implemented yet.")
@property
def outputs(self):
raise NotImplementedError("Not implemented yet.")
@property
def main_program(self):
raise NotImplementedError("Not implemented yet.")
@property
def program_cache(self):
raise NotImplementedError("Not implemented yet.")
@property
def function_spec(self):
raise NotImplementedError("Not implemented yet.")
def raise_error_template(func_str):
def _raise_error(*args, **kwargs):
error_template = (
"Can't call {func} when enable_fallback=True."
"Use paddle.jit.to_static(enable_fallback=False) instead."
)
raise RuntimeError(error_template.format(func=func_str))
return _raise_error
class SymbolicStaticFunction(StaticFunction):
def __init__(self, function, input_spec=None, **kwargs):
if input_spec is not None:
warnings.warn(
"\nSymbolic Trace don't support input_spec arguments. It will Will not produce any effect.\n"
"1. You can disable fallback mode by `paddle.jit.to_static(enable_fallback=False)` to switch to AST to static, then you can assign input spec.\n"
)
super().__init__(function, input_spec, **kwargs)
self.last_call_input_spec = None
def _perform_call(self, *args, **kwargs):
args, kwargs = self._function_spec.unified_args_and_kwargs(args, kwargs)
(
input_args_with_spec,
input_kwargs_with_spec,
) = self._function_spec.args_to_input_spec(args, kwargs)
self.last_call_input_spec = input_args_with_spec
try:
from sot import symbolic_translate
except:
import os
os.system(
"pip install git+https://github.com/PaddlePaddle/PaddleSOT@develop"
)
from sot import symbolic_translate
build_strategy = self._kwargs.get("build_strategy", None)
traced_fun = symbolic_translate(
self._dygraph_function, build_strategy=build_strategy
)
if self._class_instance is not None:
args = (self._class_instance,) + args
return traced_fun(*args, **kwargs)
@property
def code(self):
raise_error_template("code")()
@property
def concrete_program(self):
raise_error_template("concrete_program")()
concrete_program_specify_input_spec = raise_error_template(
"concrete_program_specify_input_spec"
)
get_concrete_program = raise_error_template("get_concrete_program")
get_concrete_program_with_cache_key = raise_error_template(
"get_concrete_program_with_cache_key"
)
get_traced_count = raise_error_template("get_traced_count")
@property
def inputs(self):
raise_error_template("inputs")()
@property
def outputs(self):
raise_error_template("outputs")()
@property
def main_program(self):
raise_error_template("main_program")()
@property
def program_cache(self):
raise_error_template("program_cache")()
@property
def function_spec(self):
raise_error_template("function_spec ")()
class ASTStaticFunction(StaticFunction):
"""
Wrapper class to Manage program conversion of decorated function.
"""
def __init__(self, function, input_spec=None, **kwargs):
super().__init__(function, input_spec, **kwargs)
def _perform_call(self, *args, **kwargs):
# 1. trace ops from dygraph layers and cache the generated program.
args, kwargs = self._function_spec.unified_args_and_kwargs(args, kwargs)
try:
concrete_program, partial_program_layer = self.get_concrete_program(
*args, **kwargs, is_train=self._is_train_mode()
)
# 2. synchronize self.training attribute.
if isinstance(self._class_instance, layers.Layer):
partial_program_layer.training = self._class_instance.training
else:
partial_program_layer.training = self._training
partial_program_layer._cuda_graph_capture_mode = (
self._cuda_graph_capture_mode
)
partial_program_layer._cuda_graph_pool_id = self._cuda_graph_pool_id
# 3. return outputs.
try:
return partial_program_layer(args)
except Exception as e:
if not hasattr(e, error.ERROR_DATA):
# runtime error
error.attach_error_data(e, in_runtime=True)
raise
except Exception as e:
error_data = getattr(e, error.ERROR_DATA, None)
if error_data:
error_data.raise_new_exception()
else:
logging_utils.warn(
"Please file an issue at 'https://github.com/PaddlePaddle/Paddle/issues'"
" if you can't handle this {} yourself.".format(type(e))
)
raise e
def get_concrete_program(self, *args, **kwargs):
"""
Returns traced concrete program and inner executable partial layer.
......@@ -629,16 +880,6 @@ class StaticFunction:
source_code = func_to_source_code(static_func)
return source_code
@property
def dygraph_function(self):
"""
Returns the original decorated function.
"""
if self._class_instance is not None:
return self._dygraph_function.__get__(self._class_instance)
else:
return self._dygraph_function
@property
def concrete_program(self):
"""
......@@ -757,113 +998,6 @@ class StaticFunction:
)
return concrete_program
def rollback(self):
"""
Rollback into original dygraph functions for current class instance.
Returns:
Function or Method
Example::
.. code-block:: python
>>> # doctest: +SKIP
>>> import paddle
>>> class Net(paddle.nn.Layer):
... def __init__(self):
... super().__init__()
...
... def forward(self, x, flag=True):
... if flag:
... out = x + 1
... else:
... out = x - 1
... return out
...
>>> x = paddle.randn([10, 1], 'float32')
>>> net = paddle.jit.to_static(Net()) # convert into static graph mode
>>> out = net(x)
>>> net.forward.rollback() # rollback into dygraph mode
>>> out = net(x)
"""
def rollback_impl(class_instance):
for name, func in class_instance._original_funcs.items():
setattr(class_instance, name, func.__get__(class_instance))
for sublayer in class_instance.sublayers(include_self=False):
rollback_impl(sublayer)
if self._class_instance is None:
return self._dygraph_function
# only rollback sub-functions on path of top _dygraph_function
func_name = self._dygraph_function.__name__
assert (
func_name in self._class_instance._original_funcs
), "Not Found function '{}' in class '{}'.".format(
func_name, self._class_instance.__name__
)
func = self._class_instance._original_funcs[func_name]
setattr(
self._class_instance, func_name, func.__get__(self._class_instance)
)
for sublayer in self._class_instance.sublayers(include_self=False):
rollback_impl(sublayer)
return getattr(self._class_instance, func_name)
def __deepcopy__(self, memo):
"""
Customized behavior for copy.deepcopy, return original decorated function instead
of a new StaticFunction Object. StaticFunction itself is not copyable becuase it's
associated with class_instance.
We add __deepcopy__ here only for the following usage:
Example::
.. code-block:: python
>>> import copy
>>> import paddle
>>> class Net(paddle.nn.Layer):
... def __init__(self):
... super().__init__()
...
... def forward(self, x, flag=True):
... if flag:
... out = x + 1
... else:
... out = x - 1
... return out
...
>>> x = paddle.randn([10, 1], 'float32')
>>> net = paddle.jit.to_static(Net()) # convert into static graph mode
>>> copy_net = copy.deepcopy(net) # deepcopy a new net without @to_static
Please attention that original 'net' will unwrap @to_static and rollback into simple Layer.
"""
if self._class_instance is not None:
net_name = type(self._class_instance).__name__
logging_utils.log(
level=-1,
msg="Not recommend to deepcopy '{}' decorated with @to_static, it has side effect that will"
" rollback into original state before @to_static. Please deepcopy '{}' before applying @to_static.".format(
net_name, net_name
),
)
self.rollback()
return self._dygraph_function.__get__(
memo[id(self._class_instance)]
)
else:
return self._dygraph_function
@property
def inputs(self):
"""
......
......@@ -404,6 +404,7 @@ def interpolate(
)
if isinstance(size, Variable):
size = size.cast("int32") # static mode only support int32
if size.ndim != 1:
raise ValueError(
f"If size is a tensor, it's rank must be 1, but received {size.ndim}."
......
......@@ -133,7 +133,7 @@ def shape(input):
outputs={'Out': out},
stop_gradient=True,
)
out.stop_gradient = True
return out
......
......@@ -178,9 +178,9 @@ def cast(x, dtype):
x = paddle.to_tensor([2, 3, 4], 'float64')
y = paddle.cast(x, 'uint8')
"""
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
if in_dynamic_mode():
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
return _C_ops.cast(x, dtype)
else:
check_variable_and_dtype(
......@@ -2038,7 +2038,7 @@ def split(x, num_or_sections, axis=0, name=None):
attrs['axis'] = dim
if isinstance(num_or_sections, int):
assert num_or_sections > 1, 'num_or_sections must be more than 1.'
assert num_or_sections > 0, 'num_or_sections must be than 0.'
if isinstance(dim, int) and input_shape[dim] > 0:
assert input_shape[dim] % num_or_sections == 0, (
"The input's size along the split dimension "
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import os
from functools import wraps
import numpy as np
from paddle import set_flags, static
from paddle.fluid import core
@contextlib.contextmanager
def enable_fallback_guard(enable):
flag = os.environ.get("ENABLE_FALL_BACK", None)
os.environ["ENABLE_FALL_BACK"] = enable
yield
if flag is not None:
os.environ["ENABLE_FALL_BACK"] = flag
else:
del os.environ["ENABLE_FALL_BACK"]
def to_ast(func):
"""
convet run fall_back to ast
"""
def impl(*args, **kwargs):
with enable_fallback_guard("False"):
func(*args, **kwargs)
return impl
def to_sot(func):
"""
convet run fall_back to ast
"""
enable_sot = os.environ.get("ENABLE_SOT", "False") == "True"
def impl(*args, **kwargs):
if enable_sot:
with enable_fallback_guard("True"):
func(*args, **kwargs)
else:
return
return impl
def dy2static_unittest(cls):
"""
dy2static unittest must be decorated to each Dy2static Unittests.
run both in Fallback and Ast mode.
Usage like:
@dy2static_unittest
class TestA (unittest.TestCase):
...
"""
for key in dir(cls):
if key.startswith("test"):
if not key.endswith("_ast"):
test_func = getattr(cls, key)
setattr(cls, key + "_ast", to_ast(test_func))
test_func = getattr(cls, key)
setattr(cls, key, to_sot(test_func))
return cls
def ast_only_test(func):
"""
run this test function in ast only mode.
Usage:
class TestA (unittest.TestCase):
@ast_only_test
def test_ast_only(self):
pass
"""
def impl(*args, **kwargs):
if os.environ.get("ENABLE_FALL_BACK", "True") == "False":
func(*args, **kwargs)
return impl
def sot_only_test(func):
"""
run this test function in ast only mode.
Usage:
class TestA (unittest.TestCase):
@ast_only_test
def test_ast_only(self):
pass
"""
def impl(*args, **kwargs):
if os.environ.get("ENABLE_FALL_BACK", "True") == "True":
func(*args, **kwargs)
return impl
def test_with_new_ir(func):
@wraps(func)
def impl(*args, **kwargs):
ir_outs = None
with static.scope_guard(static.Scope()):
with static.program_guard(static.Program()):
try:
new_ir_flag = 'FLAGS_enable_new_ir_in_executor'
os.environ[new_ir_flag] = 'True'
set_flags({new_ir_flag: True})
ir_outs = func(*args, **kwargs)
finally:
del os.environ[new_ir_flag]
set_flags({new_ir_flag: False})
return ir_outs
return impl
def test_and_compare_with_new_ir(need_check_output: bool = True):
def decorator(func):
@wraps(func)
def impl(*args, **kwargs):
outs = func(*args, **kwargs)
if core._is_bwd_prim_enabled() or core._is_fwd_prim_enabled():
return outs
# only run in CI-Coverage
if os.environ.get('FLAGS_NEW_IR_DY2ST_TEST', None) is None:
return outs
ir_outs = test_with_new_ir(func)(*args, **kwargs)
if not need_check_output:
return outs
for i in range(len(outs)):
np.testing.assert_array_equal(
outs[i],
ir_outs[i],
err_msg='Dy2St Unittest Check ('
+ func.__name__
+ ') has diff '
+ '\nExpect '
+ str(outs[i])
+ '\n'
+ 'But Got'
+ str(ir_outs[i]),
)
return outs
return impl
return decorator
......@@ -297,7 +297,7 @@ class BaseModel(paddle.nn.Layer):
loss = paddle.nn.functional.softmax_with_cross_entropy(
logits=dec_output, label=label, soft_label=False
)
loss = paddle.squeeze(loss, axes=[2])
loss = paddle.squeeze(loss, axis=[2])
max_tar_seq_len = paddle.shape(tar)[1]
tar_mask = paddle.static.nn.sequence_lod.sequence_mask(
tar_sequence_length, maxlen=max_tar_seq_len, dtype='float32'
......@@ -835,13 +835,13 @@ class AttentionModel(paddle.nn.Layer):
loss = paddle.nn.functional.softmax_with_cross_entropy(
logits=dec_output, label=label, soft_label=False
)
loss = paddle.squeeze(loss, axes=[2])
loss = paddle.squeeze(loss, axis=[2])
max_tar_seq_len = paddle.shape(tar)[1]
tar_mask = paddle.static.nn.sequence_lod.sequence_mask(
tar_sequence_length, maxlen=max_tar_seq_len, dtype='float32'
)
loss = loss * tar_mask
loss = paddle.mean(loss, axis=[0])
loss = fluid.layers.reduce_sum(loss)
loss = paddle.sum(loss)
return loss
......@@ -50,6 +50,7 @@ class TestAST2Func(unittest.TestCase):
self.assertEqual(func(x, y), self._ast2func(func)(x, y))
def test_ast2func_dygraph(self):
paddle.disable_static()
funcs = [dyfunc_with_if_else, dyfunc_with_if_else2, nested_if_else]
x_data = np.random.random([10, 16]).astype('float32')
for func in funcs:
......@@ -60,6 +61,8 @@ class TestAST2Func(unittest.TestCase):
self.assertTrue((true_ret == test_ret).all())
def test_ast2func_static(self):
paddle.enable_static()
def func(x):
y = F.relu(x)
loss = paddle.mean(y)
......
......@@ -20,6 +20,7 @@ import unittest
import numpy as np
from bert_dygraph_model import PretrainModelLayer
from bert_utils import get_bert_config, get_feed_data_reader
from dygraph_to_static_util import ast_only_test
from predictor_utils import PredictorTools
import paddle
......@@ -262,6 +263,7 @@ class TestBert(unittest.TestCase):
out = output()
return out
@ast_only_test
def test_train(self):
static_loss, static_ppl = self.train_static(
self.bert_config, self.data_reader
......
......@@ -18,6 +18,7 @@ import tempfile
import unittest
import numpy as np
from dygraph_to_static_util import dy2static_unittest
from predictor_utils import PredictorTools
import paddle
......@@ -636,6 +637,7 @@ def val_bmn(model, args):
return loss_data
@dy2static_unittest
class TestTrain(unittest.TestCase):
def setUp(self):
self.args = Args()
......@@ -666,6 +668,7 @@ class TestTrain(unittest.TestCase):
local_random = np.random.RandomState(SEED)
bmn = BMN(args)
bmn = paddle.jit.to_static(bmn)
adam = optimizer(args, parameter_list=bmn.parameters())
train_reader = fake_data_reader(args, 'train')
......
......@@ -15,6 +15,7 @@
import unittest
import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest
import paddle
from paddle import fluid
......@@ -25,12 +26,14 @@ SEED = 2020
np.random.seed(SEED)
@dy2static_unittest
class TestDy2staticException(unittest.TestCase):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = None
self.error = "Your if/else have different number of return value."
@ast_only_test
def test_error(self):
if self.dyfunc:
with self.assertRaisesRegex(Dygraph2StaticException, self.error):
......
......@@ -15,11 +15,13 @@
import unittest
import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest
from test_resnet import ResNetHelper
import paddle
@dy2static_unittest
class TestResnetWithPass(unittest.TestCase):
def setUp(self):
self.build_strategy = paddle.static.BuildStrategy()
......@@ -64,6 +66,7 @@ class TestResnetWithPass(unittest.TestCase):
),
)
@ast_only_test
def test_resnet(self):
static_loss = self.train(to_static=True)
dygraph_loss = self.train(to_static=False)
......@@ -77,6 +80,7 @@ class TestResnetWithPass(unittest.TestCase):
)
self.verify_predict()
@ast_only_test
def test_in_static_mode_mkldnn(self):
paddle.fluid.set_flags({'FLAGS_use_mkldnn': True})
try:
......
......@@ -16,6 +16,7 @@ import unittest
from collections import Counter
import numpy as np
from dygraph_to_static_util import dy2static_unittest
from test_fetch_feed import Linear, Pool2D
import paddle
......@@ -24,6 +25,7 @@ from paddle.jit.api import to_static
from paddle.jit.dy2static import convert_to_static
@dy2static_unittest
class TestCacheProgram(unittest.TestCase):
def setUp(self):
self.batch_num = 5
......@@ -36,7 +38,7 @@ class TestCacheProgram(unittest.TestCase):
with fluid.dygraph.guard(fluid.CPUPlace()):
static_net = self.dygraph_class()
for batch_id in range(self.batch_num):
out = static_net(self.data)
out = static_net(paddle.to_tensor(self.data))
# Check outputs
prev_out = cur_out
cur_out = out
......
......@@ -15,7 +15,7 @@
import unittest
import numpy as np
from dy2st_test_utils import test_and_compare_with_new_ir
from dygraph_to_static_util import ast_only_test, test_and_compare_with_new_ir
from paddle import fluid
from paddle.jit.api import to_static
......@@ -38,7 +38,6 @@ def test_int_cast(x):
return x
@to_static
def test_float_cast(x):
x = fluid.dygraph.to_variable(x)
x = float(x)
......@@ -89,6 +88,7 @@ class TestCastBase(unittest.TestCase):
res = self.func(self.input)
return res
@ast_only_test # TODO: add new symbolic only test.
@test_and_compare_with_new_ir(False)
def test_cast_result(self):
res = self.do_test().numpy()
......@@ -136,7 +136,7 @@ class TestFloatCast(TestCastBase):
self.cast_dtype = 'float32'
def set_func(self):
self.func = test_float_cast
self.func = to_static(test_float_cast)
class TestMixCast(TestCastBase):
......@@ -156,6 +156,7 @@ class TestMixCast(TestCastBase):
def set_func(self):
self.func = test_mix_cast
@ast_only_test # TODO: add new symbolic only test.
@test_and_compare_with_new_ir(False)
def test_cast_result(self):
res = self.do_test().numpy()
......@@ -189,6 +190,7 @@ class TestNotVarCast(TestCastBase):
def set_func(self):
self.func = test_not_var_cast
@ast_only_test # TODO: add new symbolic only test.
@test_and_compare_with_new_ir(False)
def test_cast_result(self):
res = self.do_test()
......
......@@ -15,6 +15,7 @@
import unittest
import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest
import paddle
import paddle.nn.functional as F
......@@ -38,6 +39,7 @@ class PrimeNet(paddle.nn.Layer):
return out
@dy2static_unittest
class TestPrimForward(unittest.TestCase):
"""
This case only tests prim_forward + to_static + cinn. Thus we need to
......@@ -88,6 +90,7 @@ class TestPrimForward(unittest.TestCase):
# Ensure that softmax is splitted into small ops
self.assertTrue('softmax' not in fwd_ops)
@ast_only_test
def test_cinn_prim_forward(self):
dy_res = self.train(use_prim=False)
cinn_res = self.train(use_prim=True)
......@@ -98,6 +101,7 @@ class TestPrimForward(unittest.TestCase):
)
@dy2static_unittest
class TestPrimForwardAndBackward(unittest.TestCase):
"""
Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph
......@@ -153,6 +157,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
if op != "matmul_v2_grad":
self.assertTrue("_grad" not in op)
@ast_only_test
def test_cinn_prim(self):
dy_res = self.train(use_prim=False)
cinn_res = self.train(use_prim=True)
......
......@@ -15,6 +15,7 @@
import unittest
import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest
import paddle
import paddle.nn.functional as F
......@@ -52,6 +53,7 @@ class PrimeNet(paddle.nn.Layer):
return out
@dy2static_unittest
class TestPrimForwardAndBackward(unittest.TestCase):
"""
Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph
......@@ -104,6 +106,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
# Ensure that gelu is splitted into small ops
self.assertTrue('gelu' not in fwd_ops)
@ast_only_test
def test_cinn_prim(self):
for shape in self.shapes:
for dtype in self.dtypes:
......
......@@ -15,6 +15,7 @@
import unittest
import numpy as np
from dygraph_to_static_util import ast_only_test
import paddle
import paddle.nn.functional as F
......@@ -101,6 +102,7 @@ class TestPrimForward(unittest.TestCase):
# Ensure that layer_norm is splitted into small ops
self.assertTrue('layer_norm' not in fwd_ops)
@ast_only_test
def test_cinn_prim_forward(self):
for dtype in self.dtypes:
if paddle.device.get_device() == "cpu":
......@@ -168,6 +170,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
# Ensure that layer_norm is splitted into small ops
self.assertTrue('layer_norm' not in fwd_ops)
@ast_only_test
def test_cinn_prim(self):
for dtype in self.dtypes:
if paddle.device.get_device() == "cpu":
......
......@@ -15,6 +15,7 @@
import unittest
import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest
import paddle
from paddle import tensor
......@@ -54,6 +55,7 @@ class PrimeNet(
return out
@dy2static_unittest
class TestPrimForward(unittest.TestCase):
"""
This case only tests prim_forward + to_static + cinn. Thus we need to
......@@ -110,6 +112,7 @@ class TestPrimForward(unittest.TestCase):
# Ensure that reduce_mean is splitted into small ops
self.assertTrue('reduce_mean' not in fwd_ops)
@ast_only_test
def test_cinn_prim_forward(self):
for shape in self.shapes:
for dtype in self.dtypes:
......@@ -131,6 +134,7 @@ class TestPrimForward(unittest.TestCase):
)
@dy2static_unittest
class TestPrimForwardAndBackward(unittest.TestCase):
"""
Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph
......@@ -183,6 +187,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
# Ensure that reduce_mean is splitted into small ops
self.assertTrue('reduce_mean' not in fwd_ops)
@ast_only_test
def test_cinn_prim(self):
for shape in self.shapes:
for dtype in self.dtypes:
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import inspect
import os
import unittest
from numpy import append
......@@ -324,4 +325,5 @@ class TestPushPopTrans(unittest.TestCase):
if __name__ == '__main__':
os.environ['ENABLE_FALL_BACK'] = "False"
unittest.main()
......@@ -17,6 +17,7 @@ import tempfile
import unittest
import numpy as np
from dygraph_to_static_util import dy2static_unittest
import paddle
......@@ -69,6 +70,7 @@ class NestSequentialNet(paddle.nn.Layer):
return self.layers(x)
@dy2static_unittest
class TestSequential(unittest.TestCase):
def setUp(self):
paddle.set_device('cpu')
......
......@@ -16,6 +16,7 @@ import logging
import unittest
import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest
import paddle
import paddle.jit.dy2static as _jst
......@@ -252,6 +253,7 @@ class TestNotToConvert(TestRecursiveCall2):
)
@dy2static_unittest
class TestNotToConvert2(TestRecursiveCall2):
def set_func(self):
self.net = NotToStaticHelper()
......@@ -264,7 +266,9 @@ class TestNotToConvert2(TestRecursiveCall2):
self.assertIsNotNone(options)
self.assertTrue(options.not_convert)
@ast_only_test
def test_code(self):
self.dygraph_func = paddle.jit.to_static(self.net.sum)
# check 'if statement' is not converted
self.assertIn("if x.shape[0] > 1", self.dygraph_func.code)
......@@ -277,19 +281,23 @@ def forward(self, x):
return x
@dy2static_unittest
class TestConvertPaddleAPI(unittest.TestCase):
@ast_only_test
def test_functional_api(self):
func = paddle.nn.functional.relu
func = paddle.jit.to_static(func)
self.assertNotIn("_jst.IfElse", func.code)
self.assertIn("if in_dynamic_mode()", func.code)
@ast_only_test
def test_class_api(self):
bn = paddle.nn.SyncBatchNorm(2)
paddle.jit.to_static(bn)
self.assertNotIn("_jst.IfElse", bn.forward.code)
self.assertIn("if in_dynamic_mode()", bn.forward.code)
@ast_only_test
def test_class_patch_api(self):
paddle.nn.SyncBatchNorm.forward = forward
bn = paddle.nn.SyncBatchNorm(2)
......
......@@ -14,6 +14,8 @@
import unittest
from dygraph_to_static_util import ast_only_test
import paddle
from paddle.jit import to_static
from paddle.jit.dy2static.convert_call_func import translator_logger
......@@ -31,6 +33,8 @@ def main_func():
class TestConvertGenerator(unittest.TestCase):
# fallback will ok.
@ast_only_test
def test_raise_error(self):
translator_logger.verbosity_level = 1
with self.assertLogs(
......
......@@ -15,6 +15,7 @@
import unittest
import numpy as np
from dygraph_to_static_util import ast_only_test
import paddle
......@@ -40,6 +41,8 @@ net.forward = "A string so that convert forward will fail"
class TestConvertCall(unittest.TestCase):
# fallback mode will raise a InnerError, it's ok.
@ast_only_test
def test_class_exception(self):
@paddle.jit.to_static
def call_not_exist():
......
......@@ -15,6 +15,11 @@
import unittest
import numpy as np
from dygraph_to_static_util import (
ast_only_test,
dy2static_unittest,
sot_only_test,
)
import paddle
......@@ -28,8 +33,8 @@ class TestCpuCuda(unittest.TestCase):
return x
x = paddle.to_tensor([3])
print(paddle.jit.to_static(func).code)
print(paddle.jit.to_static(func)(x))
# print(paddle.jit.to_static(func).code)
# print(paddle.jit.to_static(func)(x))
class TestToTensor(unittest.TestCase):
......@@ -41,7 +46,7 @@ class TestToTensor(unittest.TestCase):
return x
x = paddle.to_tensor([3])
print(paddle.jit.to_static(func).code)
# print(paddle.jit.to_static(func).code)
np.testing.assert_allclose(
paddle.jit.to_static(func)(x).numpy(),
np.array([1, 2, 3, 4]),
......@@ -49,7 +54,9 @@ class TestToTensor(unittest.TestCase):
)
@dy2static_unittest
class TestToTensor1(unittest.TestCase):
@ast_only_test
def test_to_tensor_with_variable_list(self):
def func(x):
ones = paddle.to_tensor([1])
......@@ -61,28 +68,59 @@ class TestToTensor1(unittest.TestCase):
return x
x = paddle.to_tensor([3])
print(paddle.jit.to_static(func).code)
np.testing.assert_allclose(
paddle.jit.to_static(func)(x).numpy(),
np.array([1, 2, 3, 4]),
rtol=1e-05,
)
@sot_only_test
def test_to_tensor_with_variable_list_sot(self):
def func(x):
ones = paddle.to_tensor([1])
twos = paddle.to_tensor([2])
""" we ignore the [3] and [4], they will be assign to a variable, and is regard as scalar.
TODO: deal with this case after 0-dim tensor is developed.
"""
x = paddle.to_tensor([ones, twos, [3], [4]])
return x
x = paddle.to_tensor([3])
np.testing.assert_allclose(
paddle.jit.to_static(func)(x),
np.array([[1], [2], [3], [4]]),
rtol=1e-05,
)
@dy2static_unittest
class TestToTensor2(unittest.TestCase):
@ast_only_test
def test_to_tensor_with_variable_list(self):
def func(x):
x = paddle.to_tensor([[1], [2], [3], [4]])
return x
x = paddle.to_tensor([3])
print(paddle.jit.to_static(func).code)
np.testing.assert_allclose(
paddle.jit.to_static(func)(x).numpy(),
np.array([[1], [2], [3], [4]]),
rtol=1e-05,
)
@sot_only_test
def test_to_tensor_with_variable_list_sot(self):
def func(x):
x = paddle.to_tensor([[1], [2], [3], [4]])
return x
x = paddle.to_tensor([3])
np.testing.assert_allclose(
paddle.jit.to_static(func)(x),
np.array([[1], [2], [3], [4]]),
rtol=1e-05,
)
if __name__ == '__main__':
unittest.main()
......@@ -30,6 +30,8 @@ from paddle.jit.dy2static.program_translator import (
from paddle.nn import Layer
from paddle.static import InputSpec
os.environ['ENABLE_FALL_BACK'] = "False" # NOTE: ast only
class SimpleNet(Layer):
def __init__(self):
......
......@@ -19,6 +19,7 @@ from functools import wraps
import decos
import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest
import paddle
......@@ -147,7 +148,6 @@ def fun8(x, y=0):
return a
@paddle.jit.to_static
def forward():
funcs = [fun1, fun2, fun3, fun4, fun5, fun6, fun7, fun8]
out = []
......@@ -166,7 +166,6 @@ def fun9():
print('in fun9 want contextmanager warning')
@paddle.jit.to_static
def warn1():
fun9()
......@@ -182,9 +181,10 @@ def deco_with_paddle_api():
return fun10()
@dy2static_unittest
class TestDecoratorTransform(unittest.TestCase):
def test_deco_transform(self):
outs = forward()
outs = paddle.jit.to_static(forward)()
np.testing.assert_allclose(outs[0], np.array(3), rtol=1e-05)
np.testing.assert_allclose(outs[1], np.array(5), rtol=1e-05)
np.testing.assert_allclose(outs[2], np.array(6), rtol=1e-05)
......@@ -194,11 +194,12 @@ class TestDecoratorTransform(unittest.TestCase):
np.testing.assert_allclose(outs[6], np.array(9), rtol=1e-05)
np.testing.assert_allclose(outs[7], np.array(10), rtol=1e-05)
@ast_only_test
def test_contextmanager_warning(self):
paddle.disable_static()
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
warn1()
paddle.jit.to_static(warn1)()
flag = False
for warn in w:
if (
......
......@@ -23,6 +23,8 @@ from paddle import fluid
from paddle.jit.dy2static import error
from paddle.jit.dy2static.origin_info import unwrap
os.environ['ENABLE_FALL_BACK'] = "False" # NOTE: ast only
def inner_func():
paddle.tensor.fill_constant(shape=[1, 2], value=9, dtype="int")
......@@ -255,11 +257,11 @@ class TestErrorStaticLayerCallInCompiletime(TestErrorBase):
def set_message(self):
self.expected_message = [
'File "{}", line 35, in func_error_in_compile_time'.format(
'File "{}", line 37, in func_error_in_compile_time'.format(
self.filepath
),
'inner_func()',
f'File "{self.filepath}", line 28, in inner_func',
f'File "{self.filepath}", line 30, in inner_func',
'def inner_func():',
'paddle.tensor.fill_constant(shape=[1, 2], value=9, dtype="int")',
'<--- HERE',
......@@ -286,7 +288,7 @@ class TestErrorStaticLayerCallInCompiletime_2(
def set_message(self):
self.expected_message = [
'File "{}", line 46, in func_error_in_compile_time_2'.format(
'File "{}", line 48, in func_error_in_compile_time_2'.format(
self.filepath
),
'def func_error_in_compile_time_2(x):',
......@@ -312,7 +314,7 @@ class TestErrorStaticLayerCallInCompiletime_3(
def set_message(self):
self.expected_message = [
f'File "{self.filepath}", line 91, in forward',
f'File "{self.filepath}", line 93, in forward',
'@paddle.jit.to_static',
'def forward(self):',
'self.test_func()',
......@@ -336,7 +338,7 @@ class TestErrorStaticLayerCallInRuntime(TestErrorStaticLayerCallInCompiletime):
def set_message(self):
self.expected_message = [
'File "{}", line 54, in func_error_in_runtime'.format(
'File "{}", line 56, in func_error_in_runtime'.format(
self.filepath
),
'x = fluid.dygraph.to_variable(x)',
......@@ -353,7 +355,7 @@ class TestErrorStaticLayerCallInRuntime2(TestErrorStaticLayerCallInRuntime):
def set_message(self):
self.expected_message = [
'File "{}", line 106, in func_error_in_runtime_with_empty_line'.format(
'File "{}", line 108, in func_error_in_runtime_with_empty_line'.format(
self.filepath
),
'two = paddle.tensor.fill_constant(shape=[1], value=2, dtype="int32")',
......@@ -376,7 +378,7 @@ class TestJitSaveInCompiletime(TestErrorBase):
def set_message(self):
self.expected_message = [
f'File "{self.filepath}", line 80, in forward',
f'File "{self.filepath}", line 82, in forward',
'def forward(self, x):',
'y = self._linear(x)',
'z = paddle.tensor.fill_constant(shape=[1, 2], value=9, dtype="int")',
......
......@@ -16,6 +16,7 @@
import unittest
import numpy as np
from dygraph_to_static_util import ast_only_test
import paddle
......@@ -84,6 +85,7 @@ class TestFallback(unittest.TestCase):
u_net(self.x).numpy(),
)
@ast_only_test
def test_case_net_error(self):
s_net = SuppportNet()
u_net = UnsuppportNet()
......
......@@ -53,7 +53,7 @@ class Linear(paddle.nn.Layer):
)
self.act = paddle.nn.ReLU()
@to_static
# @to_static
def forward(self, x):
pre = self.fc(x)
pre = self.act(pre)
......
......@@ -15,10 +15,10 @@
import unittest
import numpy as np
from dygraph_to_static_util import ast_only_test
import paddle
from paddle import fluid
from paddle.jit import to_static
@paddle.jit.to_static
......@@ -48,7 +48,7 @@ def decorated_call_decorated(x):
class DoubleDecorated:
@classmethod
@to_static
@paddle.jit.to_static
def double_decorated_func1(self, x):
return dygraph_decorated_func(x)
......@@ -59,6 +59,7 @@ class DoubleDecorated:
class TestFullNameDecorator(unittest.TestCase):
@ast_only_test
def test_run_success(self):
x = np.ones([1, 2]).astype("float32")
answer = np.zeros([1, 2]).astype("float32")
......
......@@ -17,6 +17,7 @@ import tempfile
import unittest
import numpy as np
from dygraph_to_static_util import dy2static_unittest
import paddle
......@@ -25,7 +26,6 @@ class GradLayer(paddle.nn.Layer):
def __init__(self):
super().__init__()
@paddle.jit.to_static
def forward(self, x):
x.stop_gradient = False
y = x * x
......@@ -38,7 +38,6 @@ class GradLinearLayer(paddle.nn.Layer):
super().__init__()
self.linear = paddle.nn.Linear(5, 5, bias_attr=False)
@paddle.jit.to_static
def forward(self, x):
x.stop_gradient = False
tmp = x + x
......@@ -56,7 +55,6 @@ class NoGradLinearLayer(paddle.nn.Layer):
super().__init__()
self.linear = paddle.nn.Linear(5, 5, bias_attr=False)
@paddle.jit.to_static
def forward(self, x):
x.stop_gradient = False
......@@ -69,7 +67,7 @@ class NoGradLinearLayer(paddle.nn.Layer):
class TestGrad(unittest.TestCase):
def setUp(self):
self.func = GradLayer()
self.func = paddle.jit.to_static(GradLayer())
self.x = paddle.ones(shape=[10, 2, 5], dtype='float32')
self.x.stop_gradient = False
......@@ -85,9 +83,10 @@ class TestGrad(unittest.TestCase):
np.testing.assert_allclose(static_res, dygraph_res, rtol=1e-05)
@dy2static_unittest
class TestGradLinear(TestGrad):
def setUp(self):
self.func = GradLinearLayer()
self.func = paddle.jit.to_static(GradLinearLayer())
self.x = paddle.ones(shape=[10, 2, 5], dtype='float32')
self.x.stop_gradient = False
......@@ -103,6 +102,7 @@ class TestGradLinear(TestGrad):
self.temp_dir.cleanup()
def test_save_infer_program(self):
self.setUp() # make self.func change to ast mode
input_spec = [
paddle.static.InputSpec(shape=[10, 2, 5], dtype='float32')
]
......@@ -114,6 +114,7 @@ class TestGradLinear(TestGrad):
np.testing.assert_allclose(origin_res, load_res, rtol=1e-05)
def test_save_train_program(self):
self.setUp() # make self.func change to ast mode
grad_clip = paddle.nn.ClipGradByGlobalNorm(2.0)
optimizer = paddle.optimizer.SGD(
learning_rate=0.01,
......@@ -138,7 +139,7 @@ class TestGradLinear(TestGrad):
class TestNoGradLinear(TestGradLinear):
def setUp(self):
self.func = NoGradLinearLayer()
self.func = paddle.jit.to_static(NoGradLinearLayer())
self.x = paddle.ones(shape=[10, 2, 5], dtype='float32')
self.x.stop_gradient = False
......
......@@ -16,9 +16,9 @@
import unittest
import numpy as np
from dygraph_to_static_util import dy2static_unittest
import paddle
from paddle import ParamAttr
from paddle.nn import BatchNorm, Linear
......@@ -28,13 +28,9 @@ class SimpleNet(paddle.nn.Layer):
self.linear0 = Linear(100, 50)
self.linear1 = Linear(50, 10)
param_attr0 = ParamAttr(name="aaaprefix_bn_scale")
bias_attr0 = ParamAttr(name="aaaprefix_bn_offset")
self.bn0 = BatchNorm(50, param_attr=param_attr0, bias_attr=bias_attr0)
self.bn0 = BatchNorm(50)
param_attr1 = ParamAttr(name="bn_scale")
bias_attr1 = ParamAttr(name="bn_offset")
self.bn1 = BatchNorm(10, param_attr=param_attr1, bias_attr=bias_attr1)
self.bn1 = BatchNorm(10)
def forward(self, x):
x1 = self.linear0(x)
......@@ -45,6 +41,7 @@ class SimpleNet(paddle.nn.Layer):
return dx[0]
@dy2static_unittest
class TestGradNameParse(unittest.TestCase):
def test_grad_name_parse(self):
net = SimpleNet()
......@@ -72,6 +69,7 @@ def tanh_high_order_grad(x):
return paddle.grad(y, x, create_graph=True)[0]
@dy2static_unittest
class TestTanhHighOrderGrad(unittest.TestCase):
def setUp(self):
self.func = tanh_high_order_grad
......@@ -116,10 +114,11 @@ class TestTanhHighOrderGrad(unittest.TestCase):
def matmul_high_order_grad(x, y):
z = paddle.matmul(x, y)
g = paddle.grad(z, [x, y], create_graph=False)
g = paddle.grad(z, [x, y], create_graph=True)
return g[0]
@dy2static_unittest
class TestMatMulHighOrderGrad1(TestTanhHighOrderGrad):
def setUp(self):
self.func = matmul_high_order_grad
......@@ -139,6 +138,7 @@ class TestMatMulHighOrderGrad1(TestTanhHighOrderGrad):
self.dy2st_grad_input = (x2,)
@dy2static_unittest
class TestMatMulHighOrderGrad2(TestTanhHighOrderGrad):
def setUp(self):
self.func = matmul_high_order_grad
......
......@@ -15,6 +15,7 @@
import unittest
import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest
from ifelse_simple_func import (
NetWithControlFlowIf,
add_fn,
......@@ -54,12 +55,14 @@ else:
place = fluid.CPUPlace()
@dy2static_unittest
class TestDy2staticException(unittest.TestCase):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = None
self.error = "Your if/else have different number of return value."
@ast_only_test
def test_error(self):
if self.dyfunc:
with self.assertRaisesRegex(Dygraph2StaticException, self.error):
......@@ -412,10 +415,11 @@ class TestNewVarCreateInOneBranch(unittest.TestCase):
self.assertEqual(paddle.jit.to_static(case_func)(True), -2)
@dy2static_unittest
class TestDy2StIfElseRetInt1(unittest.TestCase):
def setUp(self):
self.x = np.random.random([5]).astype('float32')
self.dyfunc = dyfunc_ifelse_ret_int1
self.dyfunc = paddle.jit.to_static(dyfunc_ifelse_ret_int1)
self.out = self.get_dy2stat_out()
def get_dy2stat_out(self):
......@@ -425,7 +429,9 @@ class TestDy2StIfElseRetInt1(unittest.TestCase):
paddle.jit.enable_to_static(False)
return out
@ast_only_test
def test_ast_to_func(self):
self.setUp()
self.assertIsInstance(self.out[0], (paddle.Tensor, core.eager.Tensor))
self.assertIsInstance(self.out[1], int)
......@@ -437,21 +443,26 @@ class TestDy2StIfElseRetInt2(TestDy2staticException):
self.dyfunc = dyfunc_ifelse_ret_int2
@dy2static_unittest
class TestDy2StIfElseRetInt3(TestDy2StIfElseRetInt1):
def setUp(self):
self.x = np.random.random([5]).astype('float32')
self.dyfunc = dyfunc_ifelse_ret_int3
self.dyfunc = paddle.jit.to_static(dyfunc_ifelse_ret_int3)
self.out = self.get_dy2stat_out()
@ast_only_test
def test_ast_to_func(self):
self.setUp()
self.assertIsInstance(self.out, (paddle.Tensor, core.eager.Tensor))
@dy2static_unittest
class TestDy2StIfElseRetInt4(TestDy2StIfElseRetInt1):
def setUp(self):
self.x = np.random.random([5]).astype('float32')
self.dyfunc = dyfunc_ifelse_ret_int4
self.dyfunc = paddle.jit.to_static(dyfunc_ifelse_ret_int4)
@ast_only_test
def test_ast_to_func(self):
paddle.jit.enable_to_static(True)
with self.assertRaises(Dygraph2StaticException):
......
......@@ -286,7 +286,7 @@ class TestListInWhileLoop(TestListWithoutControlFlow):
def train(self, to_static=False):
with fluid.dygraph.guard():
if to_static:
print(paddle.jit.to_static(self.dygraph_func).code)
# print(paddle.jit.to_static(self.dygraph_func).code)
res = paddle.jit.to_static(self.dygraph_func)(
self.input, self.iter_num
)
......
......@@ -17,6 +17,7 @@ import tempfile
import unittest
import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest
import paddle
from paddle import nn
......@@ -44,6 +45,7 @@ class Net(nn.Layer):
return x
@dy2static_unittest
class TestLstm(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
......@@ -69,6 +71,7 @@ class TestLstm(unittest.TestCase):
static_out = self.run_lstm(to_static=True)
np.testing.assert_allclose(dygraph_out, static_out, rtol=1e-05)
@ast_only_test
def test_save_in_eval(self, with_training=True):
paddle.jit.enable_to_static(True)
net = Net(12, 2)
......@@ -133,6 +136,7 @@ class LinearNet(nn.Layer):
return y
@dy2static_unittest
class TestSaveInEvalMode(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
......@@ -178,6 +182,7 @@ class TestSaveInEvalMode(unittest.TestCase):
)
@dy2static_unittest
class TestEvalAfterSave(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
......
......@@ -18,6 +18,7 @@ import unittest
from time import time
import numpy as np
from dygraph_to_static_util import ast_only_test
from predictor_utils import PredictorTools
import paddle
......@@ -158,6 +159,7 @@ class TestMNISTWithToStatic(TestMNIST):
def train_dygraph(self):
return self.train(to_static=False)
@ast_only_test
def test_mnist_to_static(self):
dygraph_loss = self.train_dygraph()
static_loss = self.train_static()
......
......@@ -14,6 +14,8 @@
import unittest
from dygraph_to_static_util import ast_only_test
import paddle
from paddle.static import InputSpec
......@@ -75,6 +77,7 @@ class CheckOpAttr(unittest.TestCase):
'elementwise_sub': self.sub_attrs,
}
@ast_only_test
def test_set_op_attrs(self):
net = NetWithOpAttr(self.in_num, self.out_num)
# set attrs
......@@ -116,6 +119,7 @@ class CheckOpAttr(unittest.TestCase):
else:
self.assertEqual(op_val, expect_val)
@ast_only_test
def test_set_op_attrs_with_sub_block(self):
net = NetWithOpAttr(self.in_num, self.out_num)
# set attrs
......
......@@ -79,11 +79,6 @@ class TestParameterList(unittest.TestCase):
dygraph_loss = self.train(False, to_static=False)
np.testing.assert_allclose(dygraph_loss, static_loss, rtol=1e-05)
def test_parameter_list_iter(self):
static_loss = self.train(True, to_static=True)
dygraph_loss = self.train(True, to_static=False)
np.testing.assert_allclose(dygraph_loss, static_loss, rtol=1e-05)
class NetWithRawParamList(paddle.nn.Layer):
def __init__(self, in_size, out_size):
......
......@@ -15,6 +15,7 @@
import unittest
import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest
from test_fetch_feed import Linear
import paddle
......@@ -52,6 +53,7 @@ def fake_data(shape):
return fluid.dygraph.to_variable(x_data)
@dy2static_unittest
class TestWithNestedInput(unittest.TestCase):
def setUp(self):
self.x = None
......@@ -88,6 +90,7 @@ class TestWithNestedInput(unittest.TestCase):
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05)
@dy2static_unittest
class TestWithNestedOutput(unittest.TestCase):
def setUp(self):
self.x = None
......@@ -124,10 +127,13 @@ class TestWithNestedOutput(unittest.TestCase):
self.assertTrue(dy_var, st_var)
@dy2static_unittest
class TestWithTrainAndEval(unittest.TestCase):
@ast_only_test
def test_switch_eval_and_train(self):
with fluid.dygraph.guard():
linear_net = Linear()
linear_net = paddle.jit.to_static(linear_net)
x_data = np.random.random((4, 10)).astype('float32')
x = fluid.dygraph.to_variable(x_data)
linear_net(x)
......@@ -154,16 +160,20 @@ class TestWithTrainAndEval(unittest.TestCase):
)
@dy2static_unittest
class TestWithNoGrad(unittest.TestCase):
@ast_only_test
def test_with_no_grad(self):
with fluid.dygraph.guard():
linear_net = Linear()
linear_net = paddle.jit.to_static(linear_net)
x_data = np.random.random((5, 10)).astype('float32')
x = fluid.dygraph.to_variable(x_data)
with paddle.no_grad():
linear_net.train()
linear_net(x)
# BUG: 我们希望这里 是 ASTStaticFunction(StaticFunction):
_, partial_layer = linear_net.forward.program_cache.last()[-1]
self.assertEqual(
partial_layer.program, partial_layer._train_program
......@@ -186,6 +196,7 @@ class GPT2LMHeadModel(paddle.nn.Layer):
return x1
@dy2static_unittest
class TestPruneUnusedParamInProgram(unittest.TestCase):
def test_prune(self):
input_ids = np.array([[15, 11, 6, 3, 18, 13]]).astype("float32")
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import paddle
......@@ -21,6 +22,7 @@ from paddle.jit.dy2static import partial_program, program_translator
class TestPartiaProgramLayerHook(unittest.TestCase):
def setUp(self):
os.environ["ENABLE_FALL_BACK"] = "False"
self._hook = partial_program.PartialProgramLayerHook()
def test_before_append_backward(self):
......@@ -35,6 +37,7 @@ class TestPartiaProgramLayerHook(unittest.TestCase):
class TestPrimHook(unittest.TestCase):
def setUp(self):
os.environ["ENABLE_FALL_BACK"] = "False"
core._set_prim_all_enabled(False)
def f():
......
......@@ -18,6 +18,7 @@ import unittest
import astor
import numpy as np
from dygraph_to_static_util import ast_only_test
from ifelse_simple_func import (
dyfunc_with_if_else_early_return1,
dyfunc_with_if_else_early_return2,
......@@ -216,6 +217,7 @@ class TestEnableDeclarative(unittest.TestCase):
self.x = np.random.randn(30, 10, 32).astype('float32')
self.weight = np.random.randn(32, 64).astype('float32')
@ast_only_test
def test_raise_error(self):
with fluid.dygraph.guard():
paddle.jit.enable_to_static(True)
......@@ -266,6 +268,7 @@ def switch_mode_function():
class TestFunctionTrainEvalMode(unittest.TestCase):
@ast_only_test
def test_switch_mode(self):
paddle.disable_static()
switch_mode_function.eval()
......
......@@ -133,11 +133,7 @@ class BottleneckBlock(paddle.nn.Layer):
short = self.short(inputs)
y = paddle.add(x=short, y=conv2)
layer_helper = fluid.layer_helper.LayerHelper(
self.full_name(), act='relu'
)
return layer_helper.append_activation(y)
return paddle.nn.functional.relu(y)
class ResNet(paddle.nn.Layer):
......
......@@ -131,10 +131,12 @@ class BottleneckBlock(paddle.nn.Layer):
y = paddle.add(x=short, y=conv2)
layer_helper = paddle.fluid.layer_helper.LayerHelper(
self.full_name(), act='relu'
)
return layer_helper.append_activation(y)
# TODO: uncomment this lines to reproduce the oneDNN segment fault error.
# layer_helper = paddle.fluid.layer_helper.LayerHelper(
# self.full_name(), act='relu'
# )
# return layer_helper.append_activation(y)
return paddle.nn.functional.relu(y)
class ResNet(paddle.nn.Layer):
......
......@@ -15,6 +15,7 @@
import unittest
import numpy as np
from dygraph_to_static_util import ast_only_test
from ifelse_simple_func import dyfunc_with_if_else
import paddle
......@@ -349,12 +350,20 @@ class TestReturnInWhile2(TestReturnBase):
self.dygraph_func = test_return_in_while_2
self.error = "Found return statement in While or For body and loop"
@ast_only_test
def test_transformed_static_result(self):
super().test_transformed_static_result()
class TestReturnInFor2(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_in_for_2
self.error = "Found return statement in While or For body and loop"
@ast_only_test
def test_transformed_static_result(self):
super().test_transformed_static_result()
class TestRecursiveReturn(TestReturnBase):
def init_dygraph_func(self):
......@@ -367,12 +376,20 @@ class TestReturnDifferentLengthIfBody(TestReturnBase):
self.dygraph_func = test_return_different_length_if_body
self.error = "Your if/else have different number of return value."
@ast_only_test
def test_transformed_static_result(self):
super().test_transformed_static_result()
class TestReturnDifferentLengthElse(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_different_length_else
self.error = "Your if/else have different number of return value."
@ast_only_test
def test_transformed_static_result(self):
super().test_transformed_static_result()
class TestNoReturn(TestReturnBase):
def init_dygraph_func(self):
......@@ -384,12 +401,20 @@ class TestReturnNone(TestReturnBase):
self.dygraph_func = test_return_none
self.error = "Your if/else have different number of return value."
@ast_only_test
def test_transformed_static_result(self):
super().test_transformed_static_result()
class TestReturnNoVariable(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_no_variable
self.error = "Your if/else have different number of return value."
@ast_only_test
def test_transformed_static_result(self):
super().test_transformed_static_result()
class TestReturnListOneValue(TestReturnBase):
def init_dygraph_func(self):
......
......@@ -15,6 +15,7 @@
import unittest
import numpy as np
from dygraph_to_static_util import ast_only_test
import paddle
from paddle.jit.dy2static.program_translator import StaticFunction
......@@ -88,6 +89,7 @@ class TestRollBackNet(unittest.TestCase):
def setUp(self):
paddle.set_device("cpu")
@ast_only_test
def test_net(self):
net = paddle.jit.to_static(Net())
x = paddle.randn([3, 4])
......
......@@ -17,6 +17,7 @@ import tempfile
import unittest
import numpy as np
from dygraph_to_static_util import ast_only_test
import paddle
from paddle import fluid
......@@ -53,6 +54,7 @@ class TestDyToStaticSaveInferenceModel(unittest.TestCase):
def tearDown(self):
self.temp_dir.cleanup()
@ast_only_test
def test_save_inference_model(self):
fc_size = 20
x_data = np.random.random((fc_size, fc_size)).astype('float32')
......@@ -144,6 +146,7 @@ class TestDyToStaticSaveInferenceModel(unittest.TestCase):
class TestPartialProgramRaiseError(unittest.TestCase):
@ast_only_test
def test_param_type(self):
paddle.jit.enable_to_static(True)
x_data = np.random.random((20, 20)).astype('float32')
......
......@@ -17,6 +17,7 @@ import tempfile
import unittest
import numpy as np
from dygraph_to_static_util import ast_only_test
from test_fetch_feed import Linear
import paddle
......@@ -114,6 +115,7 @@ class TestDyToStaticSaveLoad(unittest.TestCase):
dygraph_loss.numpy(), static_loss.numpy(), rtol=1e-05
)
@ast_only_test
def test_save_load_prim(self):
with fluid.dygraph.guard(place):
self.x = paddle.randn([4, 2, 6, 6], dtype="float32")
......@@ -154,6 +156,7 @@ class TestDyToStaticSaveLoad(unittest.TestCase):
self.assertIn("pool2d", load_op_type_list)
np.testing.assert_allclose(res.numpy(), new_res.numpy(), rtol=1e-05)
@ast_only_test
def test_save_load_prim_with_hook(self):
with fluid.dygraph.guard(place):
self.x = paddle.randn([4, 2, 6, 6], dtype="float32")
......
......@@ -20,6 +20,7 @@ import time
import unittest
import numpy as np
from dygraph_to_static_util import ast_only_test
from predictor_utils import PredictorTools
import paddle
......@@ -560,6 +561,7 @@ class TestSeResnet(unittest.TestCase):
),
)
@ast_only_test
def test_check_result(self):
pred_1, loss_1, acc1_1, acc5_1 = self.train(
self.train_reader, to_static=False
......
......@@ -17,6 +17,7 @@ import tempfile
import unittest
import numpy as np
from dygraph_to_static_util import ast_only_test
import paddle
from paddle.static import InputSpec
......@@ -51,7 +52,6 @@ def test_slice_in_if(x):
return out
@paddle.jit.to_static
def test_slice_in_while_loop(x, iter_num=3):
x = paddle.to_tensor(x)
iter_num_var = paddle.full(shape=[1], fill_value=iter_num, dtype="int32")
......@@ -153,7 +153,7 @@ class TestSliceInIf(TestSliceWithoutControlFlow):
class TestSliceInWhileLoop(TestSliceWithoutControlFlow):
def init_dygraph_func(self):
self.dygraph_func = test_slice_in_while_loop
self.dygraph_func = paddle.jit.to_static(test_slice_in_while_loop)
class TestSliceInForLoop(TestSliceWithoutControlFlow):
......@@ -179,6 +179,7 @@ class TestSetValueWithLayerAndSave(unittest.TestCase):
def tearDown(self):
self.temp_dir.cleanup()
@ast_only_test
def test_set_value_with_save(self):
paddle.jit.enable_to_static(True)
model = LayerWithSetValue(input_dim=10, hidden=1)
......
......@@ -14,6 +14,8 @@
import unittest
from dygraph_to_static_util import enable_fallback_guard
import paddle
from paddle.nn import Layer
......@@ -101,4 +103,5 @@ class TestArgsSpecName(unittest.TestCase):
if __name__ == '__main__':
unittest.main()
with enable_fallback_guard("False"):
unittest.main()
......@@ -15,6 +15,7 @@
import unittest
import numpy as np
from dygraph_to_static_util import ast_only_test
import paddle
......@@ -33,6 +34,7 @@ class TestTensorClone(unittest.TestCase):
return tensor_clone(x).numpy()
def test_tensor_clone(self):
paddle.disable_static()
dygraph_res = self._run(to_static=False)
static_res = self._run(to_static=True)
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05)
......@@ -52,7 +54,9 @@ class TestTensorDygraphOnlyMethodError(unittest.TestCase):
y = tensor_numpy(x)
return y.numpy()
@ast_only_test
def test_to_static_numpy_report_error(self):
paddle.disable_static()
dygraph_res = self._run(to_static=False)
with self.assertRaises(AssertionError):
static_res = self._run(to_static=True)
......@@ -74,6 +78,7 @@ class TestTensorItem(unittest.TestCase):
return tensor_item(x)
def test_tensor_clone(self):
paddle.disable_static()
dygraph_res = self._run(to_static=False)
static_res = self._run(to_static=True)
np.testing.assert_allclose(dygraph_res, static_res)
......@@ -93,9 +98,13 @@ class TestTensorSize(unittest.TestCase):
x = paddle.ones([1, 2, 3])
if not to_static:
return tensor_size(x)
return tensor_size(x).numpy()
ret = tensor_size(x)
if hasattr(ret, 'numpy'):
ret = ret.numpy()
return ret
def test_tensor_clone(self):
paddle.disable_static()
dygraph_res = self._run(to_static=False)
static_res = self._run(to_static=True)
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-5)
......@@ -115,6 +124,7 @@ class TestTrueDiv(unittest.TestCase):
return true_div(x, y).numpy()
def test_ture_div(self):
paddle.disable_static()
dygraph_res = self._run(to_static=False)
static_res = self._run(to_static=True)
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-5)
......
......@@ -15,6 +15,7 @@
import unittest
import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest
import paddle
from paddle import fluid
......@@ -230,6 +231,7 @@ def dyfunc_dict_assign_shape():
# 1. Basic tests without control flow
@dy2static_unittest
class TestTensorShapeBasic(unittest.TestCase):
def setUp(self):
self.input = np.ones(5).astype("int32")
......@@ -287,6 +289,7 @@ class TestTensorShapeBasic(unittest.TestCase):
[op for op in block.ops if op.type == "slice"]
)
@ast_only_test
def test_op_num(self):
static_layer = paddle.jit.to_static(self.dygraph_func, self.input_spec)
program = static_layer.main_program
......@@ -519,6 +522,7 @@ class TestOpNumBasicWithTensorShape(unittest.TestCase):
[op for op in block.ops if op.type == "slice"]
)
@ast_only_test
def test_op_num(self):
static_layer = paddle.jit.to_static(self.dygraph_func, self.input_spec)
program = static_layer.main_program
......@@ -609,6 +613,7 @@ def dyfunc_with_static_convert_var_shape(x):
class TestFindStatiConvertVarShapeSuffixVar(unittest.TestCase):
@ast_only_test
def test(self):
x_spec = paddle.static.InputSpec(shape=[None, 10])
func = paddle.jit.to_static(dyfunc_with_if_2, input_spec=[x_spec])
......
......@@ -15,6 +15,11 @@
import unittest
import numpy
from dygraph_to_static_util import (
ast_only_test,
dy2static_unittest,
sot_only_test,
)
import paddle
from paddle.fluid import core
......@@ -95,6 +100,7 @@ def case8(x):
return a
@dy2static_unittest
class TestToTensorReturnVal(unittest.TestCase):
def test_to_tensor_badreturn(self):
paddle.disable_static()
......@@ -148,6 +154,7 @@ class TestToTensorReturnVal(unittest.TestCase):
self.assertTrue(a.stop_gradient == b.stop_gradient)
self.assertTrue(a.place._equals(b.place))
@ast_only_test
def test_to_tensor_err_log(self):
paddle.disable_static()
x = paddle.to_tensor([3])
......@@ -159,6 +166,18 @@ class TestToTensorReturnVal(unittest.TestCase):
in str(e)
)
@sot_only_test
def test_to_tensor_err_log_sot(self):
paddle.disable_static()
x = paddle.to_tensor([3])
try:
a = paddle.jit.to_static(case8)(x)
except Exception as e:
self.assertTrue(
"Can't constructs a 'paddle.Tensor' with data type <class 'dict'>"
in str(e)
)
class TestStatic(unittest.TestCase):
def test_static(self):
......
......@@ -17,6 +17,7 @@ import unittest
from functools import partial
import numpy as np
from dygraph_to_static_util import enable_fallback_guard
import paddle
......@@ -433,4 +434,5 @@ class TestTrainStepTinyModelLRCyclicLR(TestTrainStepTinyModel):
if __name__ == "__main__":
unittest.main()
with enable_fallback_guard("False"):
unittest.main()
......@@ -15,6 +15,7 @@
import platform
import unittest
from dygraph_to_static_util import enable_fallback_guard
from test_train_step import (
TestTrainStepTinyModel,
loss_fn_tiny_model,
......@@ -40,4 +41,5 @@ class TestTrainStepResNet18Adam(TestTrainStepTinyModel):
if __name__ == "__main__":
unittest.main()
with enable_fallback_guard("False"):
unittest.main()
......@@ -15,6 +15,7 @@
import platform
import unittest
from dygraph_to_static_util import enable_fallback_guard
from test_train_step import (
TestTrainStepTinyModel,
loss_fn_tiny_model,
......@@ -40,4 +41,5 @@ class TestTrainStepResNet18Sgd(TestTrainStepTinyModel):
if __name__ == "__main__":
unittest.main()
with enable_fallback_guard("False"):
unittest.main()
......@@ -45,7 +45,9 @@ def parse_args():
default=fluid.is_compiled_with_cuda(),
help='default use gpu.',
)
args = parser.parse_args(['--config', 'tsm.yaml'])
args = parser.parse_args(
['--config', __file__.rpartition('/')[0] + '/tsm.yaml']
)
return args
......
......@@ -17,6 +17,7 @@ import unittest
from typing import Dict, List, Tuple
import numpy as np
from dygraph_to_static_util import dy2static_unittest
import paddle
......@@ -68,6 +69,7 @@ class LinearNetWithDict(BaseLayer):
return {'out': out2}
@dy2static_unittest
class TestTyping(unittest.TestCase):
def setUp(self):
self.in_num = 16
......
......@@ -15,6 +15,8 @@
import unittest
import warnings
from dygraph_to_static_util import ast_only_test, dy2static_unittest
import paddle
from paddle.static.nn import cond
......@@ -37,12 +39,14 @@ def false_fn():
return [paddle.to_tensor(3), [None, paddle.to_tensor(4)]]
@dy2static_unittest
class TestReturnNoneInIfelse(unittest.TestCase):
@ast_only_test
def test_dy2static_warning(self):
paddle.disable_static()
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
fun1()
paddle.jit.to_static(fun1)()
flag = False
for warn in w:
if (
......
......@@ -14,6 +14,12 @@
import unittest
from dygraph_to_static_util import (
ast_only_test,
dy2static_unittest,
sot_only_test,
)
import paddle
......@@ -93,6 +99,7 @@ def func_ifelse_write_nest_list_dict(x):
return res
@dy2static_unittest
class TestWriteContainer(unittest.TestCase):
def setUp(self):
self.set_func()
......@@ -110,6 +117,15 @@ class TestWriteContainer(unittest.TestCase):
out = out[path]
return out
@sot_only_test
def test_write_container_sot(self):
func_static = paddle.jit.to_static(self.func)
input = paddle.to_tensor([1, 2, 3])
out_static = self.get_raw_value(func_static(input), self.getitem_path)
out_dygraph = self.get_raw_value(self.func(input), self.getitem_path)
self.assertEqual(out_static, out_dygraph)
@ast_only_test
def test_write_container(self):
func_static = paddle.jit.to_static(self.func)
input = paddle.to_tensor([1, 2, 3])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册