未验证 提交 a5dc0a79 编写于 作者: W wanghuancoder 提交者: GitHub

[Eager] Rename EagerPyLayer to PyLayer (#43696)

* rename eagerpylayer
上级 8a122ecc
...@@ -129,16 +129,19 @@ PyObject* pylayer_method_apply(PyObject* cls, ...@@ -129,16 +129,19 @@ PyObject* pylayer_method_apply(PyObject* cls,
bool require_any_grad = false; bool require_any_grad = false;
size_t inputs_size = 0; size_t inputs_size = 0;
size_t args_size = 0;
size_t kwargs_size = 0;
PyObject* forward_args = nullptr; PyObject* forward_args = nullptr;
PyObject* kwargs_value_list = nullptr; PyObject* kwargs_value_list = nullptr;
if (kwargs) { if (kwargs) {
inputs_size = PyDict_Size(kwargs); kwargs_size = PyDict_Size(kwargs);
kwargs_value_list = PyDict_Values(kwargs); kwargs_value_list = PyDict_Values(kwargs);
forward_args = PyTuple_New(1);
} else {
inputs_size = PyTuple_GET_SIZE(args);
forward_args = PyTuple_New(inputs_size + 1);
} }
if (args) {
args_size = PyTuple_GET_SIZE(args);
}
inputs_size = kwargs_size + args_size;
forward_args = PyTuple_New(args_size + 1);
Py_INCREF(ctx); Py_INCREF(ctx);
PyTuple_SET_ITEM(forward_args, 0, reinterpret_cast<PyObject*>(ctx)); PyTuple_SET_ITEM(forward_args, 0, reinterpret_cast<PyObject*>(ctx));
...@@ -150,8 +153,8 @@ PyObject* pylayer_method_apply(PyObject* cls, ...@@ -150,8 +153,8 @@ PyObject* pylayer_method_apply(PyObject* cls,
ctx->forward_input_tensor_is_duplicable.reserve(inputs_size); ctx->forward_input_tensor_is_duplicable.reserve(inputs_size);
for (size_t i = 0; i < inputs_size; i++) { for (size_t i = 0; i < inputs_size; i++) {
PyObject* obj = nullptr; PyObject* obj = nullptr;
if (kwargs) { if (i >= args_size) {
obj = PyList_GetItem(kwargs_value_list, i); obj = PyList_GetItem(kwargs_value_list, i - args_size);
} else { } else {
obj = PyTuple_GET_ITEM(args, i); obj = PyTuple_GET_ITEM(args, i);
} }
...@@ -212,7 +215,7 @@ PyObject* pylayer_method_apply(PyObject* cls, ...@@ -212,7 +215,7 @@ PyObject* pylayer_method_apply(PyObject* cls,
} }
} }
if (!kwargs) { if (i < args_size) {
Py_INCREF(obj); Py_INCREF(obj);
PyTuple_SET_ITEM(forward_args, i + 1, obj); PyTuple_SET_ITEM(forward_args, i + 1, obj);
} }
......
...@@ -17,7 +17,13 @@ from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401 ...@@ -17,7 +17,13 @@ from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401
from ..framework import is_grad_enabled, set_grad_enabled # noqa: F401 from ..framework import is_grad_enabled, set_grad_enabled # noqa: F401
from . import backward_mode # noqa: F401 from . import backward_mode # noqa: F401
from .backward_mode import backward # noqa: F401 from .backward_mode import backward # noqa: F401
from .py_layer import PyLayer, PyLayerContext, EagerPyLayer, EagerPyLayerContext # noqa: F401 from ..fluid.framework import _in_eager_mode_
if _in_eager_mode_:
from .py_layer import EagerPyLayer as PyLayer # noqa: F401
from .py_layer import EagerPyLayerContext as PyLayerContext # noqa: F401
else:
from .py_layer import LegacyPyLayer as PyLayer # noqa: F401
from .py_layer import LegacyPyLayerContext as PyLayerContext # noqa: F401
from ..framework import set_grad_enabled, is_grad_enabled # noqa: F401 from ..framework import set_grad_enabled, is_grad_enabled # noqa: F401
from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401 from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401
from .functional import vjp, jvp, Jacobian, Hessian # noqa: F401 from .functional import vjp, jvp, Jacobian, Hessian # noqa: F401
......
...@@ -21,7 +21,7 @@ from paddle.fluid import core ...@@ -21,7 +21,7 @@ from paddle.fluid import core
__all__ = [] __all__ = []
class PyLayerContext(object): class LegacyPyLayerContext(object):
""" """
The object of this class is a context that is used in PyLayer to enhance the function. The object of this class is a context that is used in PyLayer to enhance the function.
...@@ -181,7 +181,7 @@ class CPyLayer(object): ...@@ -181,7 +181,7 @@ class CPyLayer(object):
return core.pylayer_apply(place, cls, *args, **kwargs) return core.pylayer_apply(place, cls, *args, **kwargs)
class PyLayerBackward(PyLayerContext): class PyLayerBackward(LegacyPyLayerContext):
def backward(self, *args, **kwargs): def backward(self, *args, **kwargs):
with paddle.fluid.dygraph.guard(): with paddle.fluid.dygraph.guard():
...@@ -205,7 +205,7 @@ class LayerMeta(type): ...@@ -205,7 +205,7 @@ class LayerMeta(type):
return super(LayerMeta, cls).__init__(name, bases, attrs) return super(LayerMeta, cls).__init__(name, bases, attrs)
class PyLayer(with_mateclass(LayerMeta, CPyLayer)): class LegacyPyLayer(with_mateclass(LayerMeta, CPyLayer)):
""" """
Build a custom `Layer` by creating subclasses. Subclasses need to follow the following rules: Build a custom `Layer` by creating subclasses. Subclasses need to follow the following rules:
1. Subclasses contain `forward` and `backward` function. Both forward and backward are @staticmethod. 1. Subclasses contain `forward` and `backward` function. Both forward and backward are @staticmethod.
...@@ -425,6 +425,8 @@ class EagerPyLayerContext(object): ...@@ -425,6 +425,8 @@ class EagerPyLayerContext(object):
Examples: Examples:
.. code-block:: python .. code-block:: python
import os
os.environ['FLAGS_enable_eager_mode'] = '1'
import paddle import paddle
from paddle.autograd import PyLayer from paddle.autograd import PyLayer
import numpy as np import numpy as np
...@@ -464,6 +466,8 @@ class EagerPyLayerContext(object): ...@@ -464,6 +466,8 @@ class EagerPyLayerContext(object):
Examples: Examples:
.. code-block:: python .. code-block:: python
import os
os.environ['FLAGS_enable_eager_mode'] = '1'
import paddle import paddle
from paddle.autograd import PyLayer from paddle.autograd import PyLayer
import numpy as np import numpy as np
......
...@@ -1181,9 +1181,9 @@ def _mp_allreduce(tensor, ...@@ -1181,9 +1181,9 @@ def _mp_allreduce(tensor,
if in_dygraph_mode(): if in_dygraph_mode():
assert op == ReduceOp.SUM, "Unknown parameter: {}.".format(op) assert op == ReduceOp.SUM, "Unknown parameter: {}.".format(op)
from paddle.autograd import EagerPyLayer from paddle.autograd import PyLayer
class mp_allreduce_eager(EagerPyLayer): class mp_allreduce_eager(PyLayer):
@staticmethod @staticmethod
def forward(ctx, tensor, use_calc_stream, ring_id, def forward(ctx, tensor, use_calc_stream, ring_id,
......
...@@ -37,7 +37,7 @@ from ..meta_optimizers import HybridParallelOptimizer, HeterParallelOptimizer ...@@ -37,7 +37,7 @@ from ..meta_optimizers import HybridParallelOptimizer, HeterParallelOptimizer
from paddle import _C_ops from paddle import _C_ops
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.dygraph import to_variable from paddle.fluid.dygraph import to_variable
from paddle.distributed.fleet.utils.recompute import RecomputeFunction from paddle.distributed.fleet.utils.recompute import LegacyRecomputeFunction
from paddle.fluid.dygraph.varbase_patch_methods import _grad_scalar from paddle.fluid.dygraph.varbase_patch_methods import _grad_scalar
__all__ = [] __all__ = []
...@@ -68,7 +68,8 @@ class _RecomputeModelWrapper(paddle.nn.Layer): ...@@ -68,7 +68,8 @@ class _RecomputeModelWrapper(paddle.nn.Layer):
return do_run return do_run
def _checkpoint(self, func, *args, **kwargs): def _checkpoint(self, func, *args, **kwargs):
return RecomputeFunction.apply(func, self._preserve_rng_state, *args) return LegacyRecomputeFunction.apply(func, self._preserve_rng_state,
*args)
def forward(self, input): def forward(self, input):
end = 0 end = 0
......
...@@ -17,7 +17,7 @@ import contextlib ...@@ -17,7 +17,7 @@ import contextlib
import paddle import paddle
from paddle.fluid import core from paddle.fluid import core
from paddle import _C_ops from paddle import _C_ops
from paddle.autograd import PyLayer, EagerPyLayer from paddle.autograd import PyLayer
from paddle.fluid import framework from paddle.fluid import framework
from ...utils.recompute import check_recompute_necessary, detach_variable, swith_rng_state_tracker from ...utils.recompute import check_recompute_necessary, detach_variable, swith_rng_state_tracker
from ..parallel_layers.random import get_rng_state_tracker from ..parallel_layers.random import get_rng_state_tracker
...@@ -151,7 +151,7 @@ def _merge_activation(tensor): ...@@ -151,7 +151,7 @@ def _merge_activation(tensor):
return _all_gather(tensor, group=mp_group) return _all_gather(tensor, group=mp_group)
class _HPEagerRecomputeFunction(EagerPyLayer): class _HPRecomputeFunction(PyLayer):
""" """
Compared with paddle.distributed.fleet.utils.recompute, there are the following differences: Compared with paddle.distributed.fleet.utils.recompute, there are the following differences:
1. In order to support PipeLineParallel, the input of recompute is modified to ensure that the input can be tuple type. 1. In order to support PipeLineParallel, the input of recompute is modified to ensure that the input can be tuple type.
...@@ -256,7 +256,7 @@ class _HPEagerRecomputeFunction(EagerPyLayer): ...@@ -256,7 +256,7 @@ class _HPEagerRecomputeFunction(EagerPyLayer):
detached_inputs = detach_variable(tuple(inputs)) detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs) outputs = ctx.run_function(*detached_inputs)
if isinstance(outputs, core.eager.Tensor): if isinstance(outputs, (core.VarBase, core.eager.Tensor)):
outputs = (outputs, ) outputs = (outputs, )
assert len(outputs) == len(args) assert len(outputs) == len(args)
...@@ -266,137 +266,8 @@ class _HPEagerRecomputeFunction(EagerPyLayer): ...@@ -266,137 +266,8 @@ class _HPEagerRecomputeFunction(EagerPyLayer):
for i in range(len(outputs)): for i in range(len(outputs)):
if isinstance( if isinstance(
outputs[i], outputs[i],
core.eager.Tensor) and not outputs[i].stop_gradient: (core.VarBase,
forward_outputs_with_grad.append(outputs[i]) core.eager.Tensor)) and not outputs[i].stop_gradient:
backward_inputs.append(args[i])
if len(forward_outputs_with_grad) == 0:
raise RuntimeError(
"none of output has stop_gradient=False, this recompute() is not necessary"
)
# actually backward
paddle.autograd.backward(forward_outputs_with_grad, backward_inputs)
grads = tuple(inp._grad_ivar() for inp in detached_inputs
if isinstance(inp, core.eager.Tensor))
return grads
class _HPRecomputeFunction(PyLayer):
"""
Compared with paddle.distributed.fleet.utils.recompute, there are the following differences:
1. In order to support PipeLineParallel, the input of recompute is modified to ensure that the input can be tuple type.
2. Offload support for activation
3. Support MP segmentation of activation to further reduce cuda memory
4. Adapt to the random state of MP
"""
@staticmethod
def forward(ctx, run_function, all_outputs, *args):
check_recompute_necessary(args)
# store for recomputing
ctx.run_function = run_function
# store the rng states
ctx.fwd_cuda_rng_state = paddle.get_cuda_rng_state()
ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker(
).get_states_tracker()
# save input for backward
ctx.inputs = []
ctx.tensor_indices = []
ctx.tensor_shapes = []
tensor_inputs = []
cur_device = paddle.get_device()
assert 'gpu:' in paddle.get_device(
), "Recompute with RNG is not support current device: {}.".format(
cur_device)
# TODO support AMP
tracer = framework._dygraph_tracer()
ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
if tracer._amp_level == core.AmpLevel.O2:
ctx.amp_level = 'O2'
elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
ctx.amp_level = 'O1'
else:
raise ValueError("unsupported amp level: {}".format(
tracer._amp_level))
ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()
with paddle.no_grad():
outputs = run_function(*args)
for i, arg in enumerate(args):
if paddle.is_tensor(arg):
state = arg.stop_gradient
if _recompute_partition:
ctx.tensor_shapes.append(arg.shape)
partition = _split_activation(arg.detach()).clone()
# TODO(shenliang03) not use calculate stream to D2H to speed
arg = partition.cpu() if _recompute_offload else partition
else:
arg = arg.cpu() if _recompute_offload else arg
arg.stop_gradient = state
tensor_inputs.append(arg)
ctx.tensor_indices.append(i)
ctx.inputs.append(None)
else:
ctx.inputs.append(arg)
ctx.save_for_backward(*tensor_inputs)
if paddle.is_tensor(outputs):
all_outputs += [outputs]
return outputs
else:
all_outputs += outputs
return tuple(outputs)
@staticmethod
def backward(ctx, *args):
with paddle.fluid.dygraph.guard():
# Restore inputs
inputs = list(ctx.inputs)
tensor_indices = ctx.tensor_indices
tensor_shapes = ctx.tensor_shapes
tensors = list(ctx.saved_tensor())
device_id = paddle.distributed.ParallelEnv().device_id
for i, idx in enumerate(tensor_indices):
if _recompute_partition:
state = tensors[i].stop_gradient
tensors[i] = _merge_activation(
tensors[i]).detach().reshape_(tensor_shapes[i])
tensors[i].stop_gradient = state
inputs[idx] = tensors[i].cuda(
device_id) if _recompute_offload else tensors[i]
tracer = framework._dygraph_tracer()
tracer._has_grad = True
# need restore auto_cast state as well as w/b list
with swith_rng_state_tracker(ctx.fwd_cuda_rng_state,
ctx.fwd_cuda_rng_state_tracker):
with paddle.amp.auto_cast(enable=ctx.is_fw_autocast,
custom_white_list=ctx.amp_white_list,
custom_black_list=ctx.amp_black_list,
level=ctx.amp_level):
detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs)
if isinstance(outputs, core.VarBase):
outputs = (outputs, )
assert len(outputs) == len(args)
forward_outputs_with_grad = []
backward_inputs = []
for i in range(len(outputs)):
if isinstance(outputs[i],
core.VarBase) and not outputs[i].stop_gradient:
forward_outputs_with_grad.append(outputs[i]) forward_outputs_with_grad.append(outputs[i])
backward_inputs.append(args[i]) backward_inputs.append(args[i])
...@@ -408,7 +279,7 @@ class _HPRecomputeFunction(PyLayer): ...@@ -408,7 +279,7 @@ class _HPRecomputeFunction(PyLayer):
# actually backward # actually backward
paddle.autograd.backward(forward_outputs_with_grad, backward_inputs) paddle.autograd.backward(forward_outputs_with_grad, backward_inputs)
grads = tuple(inp._grad_ivar() for inp in detached_inputs grads = tuple(inp._grad_ivar() for inp in detached_inputs
if isinstance(inp, core.VarBase)) if isinstance(inp, (core.VarBase, core.eager.Tensor)))
return grads return grads
...@@ -420,9 +291,6 @@ def _hp_recompute(function, *args): ...@@ -420,9 +291,6 @@ def _hp_recompute(function, *args):
# 3. Here, we only use float dtype to distinguish whether a gradient is needed in output tensor # 3. Here, we only use float dtype to distinguish whether a gradient is needed in output tensor
all_outputs = [] all_outputs = []
if in_dygraph_mode():
_HPEagerRecomputeFunction.apply(function, all_outputs, *args)
else:
_HPRecomputeFunction.apply(function, all_outputs, *args) _HPRecomputeFunction.apply(function, all_outputs, *args)
if len(all_outputs) == 1: if len(all_outputs) == 1:
......
...@@ -20,7 +20,7 @@ from collections import OrderedDict ...@@ -20,7 +20,7 @@ from collections import OrderedDict
import paddle import paddle
from paddle import nn from paddle import nn
from paddle.autograd import EagerPyLayer from paddle.autograd import PyLayer
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid.framework as framework import paddle.fluid.framework as framework
from paddle.fluid.framework import EagerParamBase from paddle.fluid.framework import EagerParamBase
...@@ -398,7 +398,7 @@ class GroupShardedStage3(nn.Layer): ...@@ -398,7 +398,7 @@ class GroupShardedStage3(nn.Layer):
def _register_forward_hooks(self, layer): def _register_forward_hooks(self, layer):
""" """
Register EagerPyLayer to manage memory slices. Register PyLayer to manage memory slices.
There are four stages: There are four stages:
FW FW
1. Before the forward layers, synchronize the full parameters. 1. Before the forward layers, synchronize the full parameters.
...@@ -653,7 +653,7 @@ def ForwardPreHooks(layer, order_tracer, trainable_params, param2buffer_size, ...@@ -653,7 +653,7 @@ def ForwardPreHooks(layer, order_tracer, trainable_params, param2buffer_size,
return return
class ForwardPostHooks(EagerPyLayer): class ForwardPostHooks(PyLayer):
@staticmethod @staticmethod
def forward(ctx, inputs, layer, order_tracer, trainable_params, def forward(ctx, inputs, layer, order_tracer, trainable_params,
......
...@@ -14,7 +14,8 @@ ...@@ -14,7 +14,8 @@
import paddle import paddle
from paddle.fluid import core from paddle.fluid import core
from paddle.autograd import PyLayer, EagerPyLayer from paddle.autograd import PyLayer
from paddle.autograd.py_layer import LegacyPyLayer
from paddle.fluid import framework from paddle.fluid import framework
import contextlib import contextlib
...@@ -68,7 +69,7 @@ def swith_rng_state_tracker(rng_state, tracker): ...@@ -68,7 +69,7 @@ def swith_rng_state_tracker(rng_state, tracker):
get_rng_state_tracker().set_states_tracker(orig_cuda_rng_tracker) get_rng_state_tracker().set_states_tracker(orig_cuda_rng_tracker)
class EagerRecomputeFunction(EagerPyLayer): class LegacyRecomputeFunction(LegacyPyLayer):
@staticmethod @staticmethod
def forward(ctx, run_function, preserve_rng_state, *args): def forward(ctx, run_function, preserve_rng_state, *args):
...@@ -171,7 +172,7 @@ class EagerRecomputeFunction(EagerPyLayer): ...@@ -171,7 +172,7 @@ class EagerRecomputeFunction(EagerPyLayer):
detached_inputs = detach_variable(tuple(inputs)) detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs) outputs = ctx.run_function(*detached_inputs)
if isinstance(outputs, core.eager.Tensor): if isinstance(outputs, core.VarBase):
outputs = (outputs, ) outputs = (outputs, )
assert len(outputs) == len(args) assert len(outputs) == len(args)
...@@ -183,9 +184,8 @@ class EagerRecomputeFunction(EagerPyLayer): ...@@ -183,9 +184,8 @@ class EagerRecomputeFunction(EagerPyLayer):
# the following backward_inputs_with_grad is used to avoid this case. # the following backward_inputs_with_grad is used to avoid this case.
backward_inputs_with_grad = [] backward_inputs_with_grad = []
for i in range(len(outputs)): for i in range(len(outputs)):
if isinstance( if isinstance(outputs[i],
outputs[i], core.VarBase) and not outputs[i].stop_gradient:
core.eager.Tensor) and not outputs[i].stop_gradient:
forward_outputs_with_grad.append(outputs[i]) forward_outputs_with_grad.append(outputs[i])
backward_inputs_with_grad.append(args[i]) backward_inputs_with_grad.append(args[i])
...@@ -199,8 +199,8 @@ class EagerRecomputeFunction(EagerPyLayer): ...@@ -199,8 +199,8 @@ class EagerRecomputeFunction(EagerPyLayer):
paddle.autograd.backward(forward_outputs_with_grad, paddle.autograd.backward(forward_outputs_with_grad,
backward_inputs_with_grad) backward_inputs_with_grad)
grads = tuple(inp.grad for inp in detached_inputs grads = list(inp._grad_ivar() for inp in detached_inputs
if isinstance(inp, core.eager.Tensor)) if isinstance(inp, core.VarBase))
return grads return grads
...@@ -307,7 +307,7 @@ class RecomputeFunction(PyLayer): ...@@ -307,7 +307,7 @@ class RecomputeFunction(PyLayer):
detached_inputs = detach_variable(tuple(inputs)) detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs) outputs = ctx.run_function(*detached_inputs)
if isinstance(outputs, core.VarBase): if isinstance(outputs, (core.VarBase, core.eager.Tensor)):
outputs = (outputs, ) outputs = (outputs, )
assert len(outputs) == len(args) assert len(outputs) == len(args)
...@@ -319,8 +319,10 @@ class RecomputeFunction(PyLayer): ...@@ -319,8 +319,10 @@ class RecomputeFunction(PyLayer):
# the following backward_inputs_with_grad is used to avoid this case. # the following backward_inputs_with_grad is used to avoid this case.
backward_inputs_with_grad = [] backward_inputs_with_grad = []
for i in range(len(outputs)): for i in range(len(outputs)):
if isinstance(outputs[i], if isinstance(
core.VarBase) and not outputs[i].stop_gradient: outputs[i],
(core.VarBase,
core.eager.Tensor)) and not outputs[i].stop_gradient:
forward_outputs_with_grad.append(outputs[i]) forward_outputs_with_grad.append(outputs[i])
backward_inputs_with_grad.append(args[i]) backward_inputs_with_grad.append(args[i])
...@@ -334,8 +336,14 @@ class RecomputeFunction(PyLayer): ...@@ -334,8 +336,14 @@ class RecomputeFunction(PyLayer):
paddle.autograd.backward(forward_outputs_with_grad, paddle.autograd.backward(forward_outputs_with_grad,
backward_inputs_with_grad) backward_inputs_with_grad)
grads = list(inp._grad_ivar() for inp in detached_inputs if in_dygraph_mode():
if isinstance(inp, core.VarBase)) grads = tuple(
inp._grad_ivar() for inp in detached_inputs
if isinstance(inp, (core.VarBase, core.eager.Tensor)))
else:
grads = list(
inp._grad_ivar() for inp in detached_inputs
if isinstance(inp, (core.VarBase, core.eager.Tensor)))
return grads return grads
...@@ -465,7 +473,4 @@ def recompute(function, *args, **kwargs): ...@@ -465,7 +473,4 @@ def recompute(function, *args, **kwargs):
if framework._dygraph_tracer()._has_grad: if framework._dygraph_tracer()._has_grad:
check_recompute_necessary(args) check_recompute_necessary(args)
if in_dygraph_mode():
return EagerRecomputeFunction.apply(function, preserve, *args)
else:
return RecomputeFunction.apply(function, preserve, *args) return RecomputeFunction.apply(function, preserve, *args)
...@@ -60,7 +60,9 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_sharding_parallel) ...@@ -60,7 +60,9 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_sharding_parallel)
list(APPEND DIST_TEST_OPS test_dygraph_sharding_optimizer_stage2) list(APPEND DIST_TEST_OPS test_dygraph_sharding_optimizer_stage2)
list(APPEND DIST_TEST_OPS test_dygraph_sharding_stage2) list(APPEND DIST_TEST_OPS test_dygraph_sharding_stage2)
list(APPEND DIST_TEST_OPS test_dygraph_sharding_stage3) list(APPEND DIST_TEST_OPS test_dygraph_sharding_stage3)
list(APPEND DIST_TEST_OPS test_dygraph_sharding_stage3_for_eager)
list(APPEND DIST_TEST_OPS test_dygraph_group_sharded_api) list(APPEND DIST_TEST_OPS test_dygraph_group_sharded_api)
list(APPEND DIST_TEST_OPS test_dygraph_group_sharded_api_for_eager)
list(APPEND DIST_TEST_OPS test_auto_parallel_parallelizer) list(APPEND DIST_TEST_OPS test_auto_parallel_parallelizer)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers) list(APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers)
list(APPEND DIST_TEST_OPS test_hybrid_parallel_inference_helper) list(APPEND DIST_TEST_OPS test_hybrid_parallel_inference_helper)
...@@ -305,13 +307,17 @@ if((NOT WITH_GPU) AND (NOT WITH_ROCM)) ...@@ -305,13 +307,17 @@ if((NOT WITH_GPU) AND (NOT WITH_ROCM))
list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_optimizer_stage2) list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_optimizer_stage2)
list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage2) list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage2)
list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage3) list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage3)
list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage3_for_eager)
list(REMOVE_ITEM TEST_OPS test_dygraph_group_sharded_api) list(REMOVE_ITEM TEST_OPS test_dygraph_group_sharded_api)
list(REMOVE_ITEM TEST_OPS test_dygraph_group_sharded_api_for_eager)
list(REMOVE_ITEM TEST_OPS test_auto_parallel_parallelizer) list(REMOVE_ITEM TEST_OPS test_auto_parallel_parallelizer)
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mp_layers) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mp_layers)
list(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision) list(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision)
list(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision_for_eager)
list(REMOVE_ITEM TEST_OPS test_mixed_precision) list(REMOVE_ITEM TEST_OPS test_mixed_precision)
list(REMOVE_ITEM TEST_OPS test_fleet_base_single) list(REMOVE_ITEM TEST_OPS test_fleet_base_single)
list(REMOVE_ITEM TEST_OPS test_dygraph_recompute) list(REMOVE_ITEM TEST_OPS test_dygraph_recompute)
list(REMOVE_ITEM TEST_OPS test_dygraph_recompute_for_eager)
list(REMOVE_ITEM TEST_OPS test_hybrid_parallel_inference_helper) list(REMOVE_ITEM TEST_OPS test_hybrid_parallel_inference_helper)
list(REMOVE_ITEM TEST_OPS test_parallel_class_center_sample) list(REMOVE_ITEM TEST_OPS test_parallel_class_center_sample)
list(REMOVE_ITEM TEST_OPS test_parallel_margin_cross_entropy) list(REMOVE_ITEM TEST_OPS test_parallel_margin_cross_entropy)
...@@ -1547,7 +1553,11 @@ if(WITH_DISTRIBUTE ...@@ -1547,7 +1553,11 @@ if(WITH_DISTRIBUTE
120) 120)
set_tests_properties(test_dygraph_sharding_stage2 PROPERTIES TIMEOUT 200) set_tests_properties(test_dygraph_sharding_stage2 PROPERTIES TIMEOUT 200)
set_tests_properties(test_dygraph_sharding_stage3 PROPERTIES TIMEOUT 350) set_tests_properties(test_dygraph_sharding_stage3 PROPERTIES TIMEOUT 350)
set_tests_properties(test_dygraph_sharding_stage3_for_eager PROPERTIES TIMEOUT
350)
set_tests_properties(test_dygraph_group_sharded_api PROPERTIES TIMEOUT 120) set_tests_properties(test_dygraph_group_sharded_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_dygraph_group_sharded_api_for_eager
PROPERTIES TIMEOUT 120)
set_tests_properties(test_auto_parallel_parallelizer PROPERTIES TIMEOUT 120) set_tests_properties(test_auto_parallel_parallelizer PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_dygraph_mp_layers PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_mp_layers PROPERTIES TIMEOUT 120)
set_tests_properties(test_hybrid_parallel_inference_helper PROPERTIES TIMEOUT set_tests_properties(test_hybrid_parallel_inference_helper PROPERTIES TIMEOUT
...@@ -1637,6 +1647,8 @@ endif() ...@@ -1637,6 +1647,8 @@ endif()
if(WITH_GPU OR WITH_ROCM) if(WITH_GPU OR WITH_ROCM)
set_tests_properties(test_imperative_auto_mixed_precision PROPERTIES TIMEOUT set_tests_properties(test_imperative_auto_mixed_precision PROPERTIES TIMEOUT
300) 300)
set_tests_properties(test_imperative_auto_mixed_precision_for_eager
PROPERTIES TIMEOUT 300)
set_tests_properties(test_parallel_dygraph_sync_batch_norm PROPERTIES TIMEOUT set_tests_properties(test_parallel_dygraph_sync_batch_norm PROPERTIES TIMEOUT
120) 120)
set_tests_properties(test_rank_attention_op PROPERTIES TIMEOUT 120) set_tests_properties(test_rank_attention_op PROPERTIES TIMEOUT 120)
......
...@@ -21,7 +21,7 @@ import paddle ...@@ -21,7 +21,7 @@ import paddle
import numpy as np import numpy as np
import paddle.distributed as dist import paddle.distributed as dist
from paddle.fluid.dygraph.nn import Linear from paddle.fluid.dygraph.nn import Linear
from paddle.autograd import PyLayer, EagerPyLayer from paddle.autograd import PyLayer
from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients
...@@ -45,21 +45,6 @@ class cus_tanh(PyLayer): ...@@ -45,21 +45,6 @@ class cus_tanh(PyLayer):
return grad return grad
class cus_tanh_eager(EagerPyLayer):
@staticmethod
def forward(ctx, x):
y = paddle.tanh(x)
ctx.save_for_backward(y)
return y
@staticmethod
def backward(ctx, dy):
y, = ctx.saved_tensor()
grad = dy * (1 - paddle.square(y))
return grad
class SimpleNet(paddle.nn.Layer): class SimpleNet(paddle.nn.Layer):
def __init__(self, train_id, model_id): def __init__(self, train_id, model_id):
...@@ -73,9 +58,6 @@ class SimpleNet(paddle.nn.Layer): ...@@ -73,9 +58,6 @@ class SimpleNet(paddle.nn.Layer):
def forward(self, inputs): def forward(self, inputs):
if self.model_id == 0: if self.model_id == 0:
if in_dygraph_mode():
inputs = cus_tanh_eager.apply(inputs)
elif _in_legacy_dygraph():
inputs = cus_tanh.apply(inputs) inputs = cus_tanh.apply(inputs)
else: else:
inputs = self.tanh(inputs) inputs = self.tanh(inputs)
......
...@@ -15,6 +15,9 @@ ...@@ -15,6 +15,9 @@
from __future__ import print_function from __future__ import print_function
import os import os
os.environ['FLAGS_enable_eager_mode'] = '0'
import unittest import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -26,9 +29,7 @@ class TestDygraphGroupSharded(TestMultipleGpus): ...@@ -26,9 +29,7 @@ class TestDygraphGroupSharded(TestMultipleGpus):
# check group sharded logic as well as the accuracy with single mode # check group sharded logic as well as the accuracy with single mode
def test_dygraph_group_sharded(self): def test_dygraph_group_sharded(self):
self.run_mnist_2gpu('dygraph_group_sharded_api.py', eager_mode=False) self.run_mnist_2gpu('dygraph_group_sharded_api.py', eager_mode=False)
self.run_mnist_2gpu('dygraph_group_sharded_api_eager.py')
if __name__ == "__main__": if __name__ == "__main__":
os.environ["FLAGS_enable_eager_mode"] = "1"
unittest.main() unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
os.environ['FLAGS_enable_eager_mode'] = '1'
import unittest
import paddle.fluid as fluid
from test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestDygraphGroupSharded(TestMultipleGpus):
# check group sharded logic as well as the accuracy with single mode
def test_dygraph_group_sharded(self):
self.run_mnist_2gpu('dygraph_group_sharded_api_eager.py')
if __name__ == "__main__":
unittest.main()
...@@ -23,7 +23,6 @@ from paddle.distributed.fleet.utils import recompute ...@@ -23,7 +23,6 @@ from paddle.distributed.fleet.utils import recompute
import random import random
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
from paddle.fluid.framework import _test_eager_guard
def get_fc_block(block_idx, input_size, is_last=False): def get_fc_block(block_idx, input_size, is_last=False):
...@@ -181,34 +180,15 @@ class TestPyLayer(unittest.TestCase): ...@@ -181,34 +180,15 @@ class TestPyLayer(unittest.TestCase):
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
def test_fc_net_with_dropout(self): def test_fc_net_with_dropout(self):
with _test_eager_guard():
self.test_base_case() self.test_base_case()
self.test_base_case()
def test_fc_net_without_restore_rng(self):
with _test_eager_guard():
loss_ref, param_ref, grad_ref = run_model(
recompute_block=[2],
recompute_kwargs={"preserve_rng_state": False},
enable_autocast=True)
def test_fc_net_with_amp(self): def test_fc_net_with_amp(self):
with _test_eager_guard():
self.test_base_case(enable_autocast=True)
self.test_base_case(enable_autocast=True) self.test_base_case(enable_autocast=True)
def test_fc_net_with_fp16(self): def test_fc_net_with_fp16(self):
with _test_eager_guard():
self.test_base_case(enable_autocast=True, pure_fp16=True)
self.test_base_case(enable_autocast=True, pure_fp16=True) self.test_base_case(enable_autocast=True, pure_fp16=True)
def test_recompute_kwargs(self): def test_recompute_kwargs(self):
with _test_eager_guard():
paddle.set_device("gpu")
kwargs = {"is_test": False}
with self.assertRaises(ValueError):
loss_ref, param_ref, grad_ref = run_model(
recompute_block=[2], recompute_kwargs=kwargs)
paddle.set_device("gpu") paddle.set_device("gpu")
kwargs = {"is_test": False} kwargs = {"is_test": False}
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
...@@ -216,11 +196,6 @@ class TestPyLayer(unittest.TestCase): ...@@ -216,11 +196,6 @@ class TestPyLayer(unittest.TestCase):
recompute_kwargs=kwargs) recompute_kwargs=kwargs)
def test_recompute_cpu_rng(self): def test_recompute_cpu_rng(self):
with _test_eager_guard():
paddle.set_device("cpu")
with self.assertRaises(RuntimeError):
loss_ref, param_ref, grad_ref = run_model(recompute_block=[2])
paddle.set_device("cpu") paddle.set_device("cpu")
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
loss_ref, param_ref, grad_ref = run_model(recompute_block=[2]) loss_ref, param_ref, grad_ref = run_model(recompute_block=[2])
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
os.environ['FLAGS_enable_eager_mode'] = '1'
import unittest
import numpy as np
import paddle
from paddle.autograd import PyLayer
from paddle.distributed.fleet.utils import recompute
import random
import paddle.fluid.layers as layers
def get_fc_block(block_idx, input_size, is_last=False):
block_name = "block_" + str(block_idx)
block = paddle.nn.Sequential(
(block_name + "_fc_0",
paddle.nn.Linear(input_size, input_size, bias_attr=False)),
(block_name + "_dropout", paddle.nn.Dropout(p=0.5)),
(block_name + "_relu_1", paddle.nn.ReLU()),
(block_name + "_fc_1",
paddle.nn.Linear(input_size, input_size, bias_attr=False)),
(block_name + "_relu_2", paddle.nn.ReLU()),
)
if is_last:
block.add_sublayer(block_name + "_fc_2",
paddle.nn.Linear(input_size, 1,
bias_attr=False)) # add sublayer
else:
block.add_sublayer(block_name + "_fc_2",
paddle.nn.Linear(input_size,
input_size,
bias_attr=False)) # add sublayer
return block
class Naive_fc_net(paddle.nn.Layer):
def __init__(self,
input_size=10,
recompute_blocks=[1, 3],
recompute_kwargs={}):
super(Naive_fc_net, self).__init__()
self.recompute_blocks = recompute_blocks
self.recompute_kwargs = recompute_kwargs
self.runfunc0 = get_fc_block(0, input_size, is_last=False)
self.runfunc1 = get_fc_block(1, input_size, is_last=False)
self.runfunc2 = get_fc_block(2, input_size, is_last=False)
self.runfunc3 = get_fc_block(3, input_size, is_last=False)
self.runfunc4 = get_fc_block(4, input_size, is_last=True)
def forward(self, inputs):
if 0 in self.recompute_blocks:
inputs = recompute(self.runfunc0, inputs)
else:
inputs = self.runfunc0(inputs)
if 1 in self.recompute_blocks:
inputs = recompute(self.runfunc1, inputs)
else:
inputs = self.runfunc1(inputs)
if 2 in self.recompute_blocks:
inputs = recompute(self.runfunc2, inputs, **self.recompute_kwargs)
else:
inputs = self.runfunc2(inputs)
if 3 in self.recompute_blocks:
inputs = recompute(self.runfunc3, inputs)
else:
inputs = self.runfunc3(inputs)
if 4 in self.recompute_blocks:
inputs = recompute(self.runfunc4, inputs)
else:
inputs = self.runfunc4(inputs)
return inputs
def run_model(recompute_block=[],
recompute_kwargs={},
enable_autocast=False,
pure_fp16=False):
gen = paddle.seed(10)
gen.manual_seed(10)
np.random.seed(10)
random.seed(10)
batch_size, input_size = 1, 10
model = Naive_fc_net(input_size,
recompute_blocks=recompute_block,
recompute_kwargs=recompute_kwargs)
loss_fn = paddle.nn.MSELoss(reduction='mean')
optimizer = paddle.optimizer.SGD(learning_rate=0.01,
parameters=model.parameters())
if enable_autocast:
scaler = paddle.amp.GradScaler()
loss_ = []
param_ = []
grad_ = []
for step in range(10):
x_data = np.random.randn(batch_size, input_size).astype(np.float32)
x = paddle.to_tensor(x_data)
# x.stop_gradient = False
level = 'O2' if pure_fp16 else 'O1'
with paddle.amp.auto_cast(True, level=level):
y_pred = model(x)
loss = y_pred.mean()
if enable_autocast:
scaler.scale(loss).backward()
scaler.minimize(optimizer, loss)
else:
loss_.append(np.asarray(loss).tolist())
loss.backward()
optimizer.step()
param_.append(np.asarray(model.parameters()[9]).tolist())
grad_.append(np.asarray(model.parameters()[3]._grad_ivar()).tolist())
optimizer.clear_grad()
return loss_, param_, grad_
class TestPyLayer(unittest.TestCase):
def test_base_case(self, enable_autocast=False, pure_fp16=False):
def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
self.assertEqual(loss_ref, loss)
self.assertEqual(param_ref, param)
self.assertEqual(grad_ref, grad)
# without recompute
loss_ref, param_ref, grad_ref = run_model(
recompute_block=[],
enable_autocast=enable_autocast,
pure_fp16=pure_fp16)
# recompute second block
loss, param, grad = run_model(recompute_block=[1],
enable_autocast=enable_autocast,
pure_fp16=pure_fp16)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
# recompute fourth block
loss, param, grad = run_model(recompute_block=[3],
enable_autocast=enable_autocast,
pure_fp16=pure_fp16)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
# recompute second to fourth block
loss, param, grad = run_model(recompute_block=[1, 2, 3],
enable_autocast=enable_autocast,
pure_fp16=pure_fp16)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
# recompute second & fourth block
loss, param, grad = run_model(recompute_block=[1, 3],
enable_autocast=enable_autocast,
pure_fp16=pure_fp16)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
def test_fc_net_with_dropout(self):
self.test_base_case()
def test_fc_net_without_restore_rng(self):
loss_ref, param_ref, grad_ref = run_model(
recompute_block=[2],
recompute_kwargs={"preserve_rng_state": False},
enable_autocast=True)
def test_fc_net_with_amp(self):
self.test_base_case(enable_autocast=True)
def test_fc_net_with_fp16(self):
self.test_base_case(enable_autocast=True, pure_fp16=True)
def test_recompute_kwargs(self):
paddle.set_device("gpu")
kwargs = {"is_test": False}
with self.assertRaises(ValueError):
loss_ref, param_ref, grad_ref = run_model(recompute_block=[2],
recompute_kwargs=kwargs)
def test_recompute_cpu_rng(self):
paddle.set_device("cpu")
with self.assertRaises(RuntimeError):
loss_ref, param_ref, grad_ref = run_model(recompute_block=[2])
if __name__ == '__main__':
unittest.main()
...@@ -15,6 +15,9 @@ ...@@ -15,6 +15,9 @@
from __future__ import print_function from __future__ import print_function
import os import os
os.environ['FLAGS_enable_eager_mode'] = '0'
import unittest import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -25,15 +28,12 @@ class TestDygraphShardingStage3(TestMultipleGpus): ...@@ -25,15 +28,12 @@ class TestDygraphShardingStage3(TestMultipleGpus):
# check sharding logic as well as the accuracy with single mode # check sharding logic as well as the accuracy with single mode
def test_dygraph_sharding_stage3(self): def test_dygraph_sharding_stage3(self):
self.run_mnist_2gpu('dygraph_group_sharded_stage3.py')
self.run_mnist_2gpu('dygraph_sharding_stage3.py', eager_mode=False) self.run_mnist_2gpu('dygraph_sharding_stage3.py', eager_mode=False)
def test_dygraph_sharding_stage3_offload(self): def test_dygraph_sharding_stage3_offload(self):
self.run_mnist_2gpu('dygraph_group_sharded_stage3_offload.py')
self.run_mnist_2gpu('dygraph_sharding_stage3_offload.py', self.run_mnist_2gpu('dygraph_sharding_stage3_offload.py',
eager_mode=False) eager_mode=False)
if __name__ == "__main__": if __name__ == "__main__":
os.environ["FLAGS_enable_eager_mode"] = "1"
unittest.main() unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
os.environ['FLAGS_enable_eager_mode'] = '1'
import os
import unittest
import paddle.fluid as fluid
from test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestDygraphShardingStage3(TestMultipleGpus):
# check sharding logic as well as the accuracy with single mode
def test_dygraph_sharding_stage3(self):
self.run_mnist_2gpu('dygraph_group_sharded_stage3.py')
def test_dygraph_sharding_stage3_offload(self):
self.run_mnist_2gpu('dygraph_group_sharded_stage3_offload.py')
if __name__ == "__main__":
os.environ["FLAGS_enable_eager_mode"] = "1"
unittest.main()
...@@ -12,6 +12,10 @@ ...@@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
os.environ['FLAGS_enable_eager_mode'] = '0'
import unittest import unittest
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -19,13 +23,11 @@ import paddle.fluid.core as core ...@@ -19,13 +23,11 @@ import paddle.fluid.core as core
import numpy as np import numpy as np
import six import six
import cv2 import cv2
import os
import tempfile import tempfile
from test_imperative_resnet import ResNet, BottleneckBlock, ConvBNLayer, train_parameters, optimizer_setting from test_imperative_resnet import ResNet, BottleneckBlock, ConvBNLayer, train_parameters, optimizer_setting
import paddle.nn as nn import paddle.nn as nn
from paddle.static import InputSpec from paddle.static import InputSpec
from paddle.autograd import PyLayer from paddle.autograd import PyLayer
from paddle.fluid.framework import _test_eager_guard
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
fluid.set_flags({"FLAGS_cudnn_deterministic": True}) fluid.set_flags({"FLAGS_cudnn_deterministic": True})
...@@ -73,8 +75,6 @@ class TestAutoCast(unittest.TestCase): ...@@ -73,8 +75,6 @@ class TestAutoCast(unittest.TestCase):
self.assertTrue(out_fp32.dtype == fluid.core.VarDesc.VarType.FP32) self.assertTrue(out_fp32.dtype == fluid.core.VarDesc.VarType.FP32)
def test_amp_guard_white_op(self): def test_amp_guard_white_op(self):
with _test_eager_guard():
self.amp_guard_white_op()
self.amp_guard_white_op() self.amp_guard_white_op()
def amp_guard_black_op(self): def amp_guard_black_op(self):
...@@ -88,8 +88,6 @@ class TestAutoCast(unittest.TestCase): ...@@ -88,8 +88,6 @@ class TestAutoCast(unittest.TestCase):
self.assertTrue(out_fp32.dtype == fluid.core.VarDesc.VarType.FP32) self.assertTrue(out_fp32.dtype == fluid.core.VarDesc.VarType.FP32)
def test_amp_guard_black_op(self): def test_amp_guard_black_op(self):
with _test_eager_guard():
self.amp_guard_black_op()
self.amp_guard_black_op() self.amp_guard_black_op()
def custom_op_list(self): def custom_op_list(self):
...@@ -123,8 +121,6 @@ class TestAutoCast(unittest.TestCase): ...@@ -123,8 +121,6 @@ class TestAutoCast(unittest.TestCase):
| {"conv2d"}) | {"conv2d"})
def test_custom_op_list(self): def test_custom_op_list(self):
with _test_eager_guard():
self.custom_op_list()
self.custom_op_list() self.custom_op_list()
def custom_op_list_exception(self): def custom_op_list_exception(self):
...@@ -145,8 +141,6 @@ class TestAutoCast(unittest.TestCase): ...@@ -145,8 +141,6 @@ class TestAutoCast(unittest.TestCase):
self.assertRaises(ValueError, func) self.assertRaises(ValueError, func)
def test_custom_op_list_exception(self): def test_custom_op_list_exception(self):
with _test_eager_guard():
self.custom_op_list_exception()
self.custom_op_list_exception() self.custom_op_list_exception()
def amp_guard_upsupported_fp16_op(self): def amp_guard_upsupported_fp16_op(self):
...@@ -174,8 +168,6 @@ class TestAutoCast(unittest.TestCase): ...@@ -174,8 +168,6 @@ class TestAutoCast(unittest.TestCase):
out_purefp16_fp32.dtype == fluid.core.VarDesc.VarType.FP32) out_purefp16_fp32.dtype == fluid.core.VarDesc.VarType.FP32)
def test_amp_guard_upsupported_fp16_op(self): def test_amp_guard_upsupported_fp16_op(self):
with _test_eager_guard():
self.amp_guard_upsupported_fp16_op()
self.amp_guard_upsupported_fp16_op() self.amp_guard_upsupported_fp16_op()
def mode_exception(self): def mode_exception(self):
...@@ -195,8 +187,6 @@ class TestAutoCast(unittest.TestCase): ...@@ -195,8 +187,6 @@ class TestAutoCast(unittest.TestCase):
self.assertRaises(ValueError, func) self.assertRaises(ValueError, func)
def test_mode_exception(self): def test_mode_exception(self):
with _test_eager_guard():
self.mode_exception()
self.mode_exception() self.mode_exception()
...@@ -212,8 +202,6 @@ class TestAmpScaler(unittest.TestCase): ...@@ -212,8 +202,6 @@ class TestAmpScaler(unittest.TestCase):
data.numpy() * 1024), True) data.numpy() * 1024), True)
def test_scale(self): def test_scale(self):
with _test_eager_guard():
self.scale()
self.scale() self.scale()
def minimize(self): def minimize(self):
...@@ -265,8 +253,6 @@ class TestAmpScaler(unittest.TestCase): ...@@ -265,8 +253,6 @@ class TestAmpScaler(unittest.TestCase):
outs_no_scaler[1][i][0].numpy()), True) outs_no_scaler[1][i][0].numpy()), True)
def test_minimize(self): def test_minimize(self):
with _test_eager_guard():
self.minimize()
self.minimize() self.minimize()
def step(self): def step(self):
...@@ -310,8 +296,6 @@ class TestAmpScaler(unittest.TestCase): ...@@ -310,8 +296,6 @@ class TestAmpScaler(unittest.TestCase):
outs_no_scaler[i].numpy()), True) outs_no_scaler[i].numpy()), True)
def test_step(self): def test_step(self):
with _test_eager_guard():
self.step()
self.step() self.step()
def nan_inf(self): def nan_inf(self):
...@@ -344,8 +328,6 @@ class TestAmpScaler(unittest.TestCase): ...@@ -344,8 +328,6 @@ class TestAmpScaler(unittest.TestCase):
np.array_equal(param.numpy(), params_init[param.name])) np.array_equal(param.numpy(), params_init[param.name]))
def test_nan_inf(self): def test_nan_inf(self):
with _test_eager_guard():
self.nan_inf()
self.nan_inf() self.nan_inf()
def step_update_exception(self): def step_update_exception(self):
...@@ -396,8 +378,6 @@ class TestAmpScaler(unittest.TestCase): ...@@ -396,8 +378,6 @@ class TestAmpScaler(unittest.TestCase):
self.assertRaises(RuntimeError, func3) self.assertRaises(RuntimeError, func3)
def test_step_update_exception(self): def test_step_update_exception(self):
with _test_eager_guard():
self.step_update_exception()
self.step_update_exception() self.step_update_exception()
def test_get_and_set(self): def test_get_and_set(self):
...@@ -578,8 +558,6 @@ class TestGradScalerStateDict(unittest.TestCase): ...@@ -578,8 +558,6 @@ class TestGradScalerStateDict(unittest.TestCase):
self.assertTrue( self.assertTrue(
np.allclose(out_use_state_dict[0], out_no_state_dict[0])) np.allclose(out_use_state_dict[0], out_no_state_dict[0]))
with _test_eager_guard():
func_isinstance()
func_isinstance() func_isinstance()
...@@ -742,8 +720,6 @@ class TestStateDictHookForAMP(unittest.TestCase): ...@@ -742,8 +720,6 @@ class TestStateDictHookForAMP(unittest.TestCase):
for key in param_value_ori.keys(): for key in param_value_ori.keys():
print(np.equal(param_value_ori[key], param_value_now[key])) print(np.equal(param_value_ori[key], param_value_now[key]))
with _test_eager_guard():
func_isinstance()
func_isinstance() func_isinstance()
...@@ -899,8 +875,6 @@ class TestPureFp16SaveLoad(unittest.TestCase): ...@@ -899,8 +875,6 @@ class TestPureFp16SaveLoad(unittest.TestCase):
self.assertTrue( self.assertTrue(
np.allclose(out_use_save_load[0], out_no_save_load[0])) np.allclose(out_use_save_load[0], out_no_save_load[0]))
with _test_eager_guard():
func_isinstance()
func_isinstance() func_isinstance()
...@@ -1005,8 +979,6 @@ class TestPureFp16InferenceSaveLoad(unittest.TestCase): ...@@ -1005,8 +979,6 @@ class TestPureFp16InferenceSaveLoad(unittest.TestCase):
def test_inference_save_load(self): def test_inference_save_load(self):
self.inference_save_load() self.inference_save_load()
with _test_eager_guard():
self.inference_save_load()
class TestResnet2(unittest.TestCase): class TestResnet2(unittest.TestCase):
...@@ -1146,8 +1118,6 @@ class TestResnet2(unittest.TestCase): ...@@ -1146,8 +1118,6 @@ class TestResnet2(unittest.TestCase):
self.assertTrue( self.assertTrue(
np.allclose(out_fp32[0], out_pure_fp16[0], atol=1.e-2)) np.allclose(out_fp32[0], out_pure_fp16[0], atol=1.e-2))
with _test_eager_guard():
func_isinstance()
func_isinstance() func_isinstance()
def test_with_data_loader(self): def test_with_data_loader(self):
...@@ -1166,8 +1136,6 @@ class TestResnet2(unittest.TestCase): ...@@ -1166,8 +1136,6 @@ class TestResnet2(unittest.TestCase):
self.assertTrue( self.assertTrue(
np.allclose(out_fp32[0], out_pure_fp16[0], atol=1.e-2)) np.allclose(out_fp32[0], out_pure_fp16[0], atol=1.e-2))
with _test_eager_guard():
func_isinstance()
func_isinstance() func_isinstance()
def test_param_group(self): def test_param_group(self):
...@@ -1189,8 +1157,6 @@ class TestResnet2(unittest.TestCase): ...@@ -1189,8 +1157,6 @@ class TestResnet2(unittest.TestCase):
self.assertTrue( self.assertTrue(
np.allclose(out_fp32[0], out_pure_fp16[0], atol=1.e-2)) np.allclose(out_fp32[0], out_pure_fp16[0], atol=1.e-2))
with _test_eager_guard():
func_isinstance()
func_isinstance() func_isinstance()
...@@ -1285,8 +1251,6 @@ class TestResnet(unittest.TestCase): ...@@ -1285,8 +1251,6 @@ class TestResnet(unittest.TestCase):
self.assertTrue( self.assertTrue(
np.allclose(out_fp32[0], out_pure_fp16[0], atol=1.e-1)) np.allclose(out_fp32[0], out_pure_fp16[0], atol=1.e-1))
with _test_eager_guard():
func_isinstance()
func_isinstance() func_isinstance()
...@@ -1308,8 +1272,6 @@ class TestLayerNormFp16(unittest.TestCase): ...@@ -1308,8 +1272,6 @@ class TestLayerNormFp16(unittest.TestCase):
self.assertTrue( self.assertTrue(
out.dtype == fluid.core.VarDesc.VarType.FP16) out.dtype == fluid.core.VarDesc.VarType.FP16)
with _test_eager_guard():
func_isinstance()
func_isinstance() func_isinstance()
...@@ -1344,8 +1306,6 @@ class TestBf16(unittest.TestCase): ...@@ -1344,8 +1306,6 @@ class TestBf16(unittest.TestCase):
self.assertTrue( self.assertTrue(
np.allclose(out_fp32, out_bf16_O2, rtol=1.e-3, atol=1.e-1)) np.allclose(out_fp32, out_bf16_O2, rtol=1.e-3, atol=1.e-1))
with _test_eager_guard():
func_isinstance()
func_isinstance() func_isinstance()
...@@ -1399,8 +1359,6 @@ class TestAmpWithHook(unittest.TestCase): ...@@ -1399,8 +1359,6 @@ class TestAmpWithHook(unittest.TestCase):
loss = a.sum() loss = a.sum()
self.assertRaises(RuntimeError, loss.backward) self.assertRaises(RuntimeError, loss.backward)
with _test_eager_guard():
func_isinstance()
func_isinstance() func_isinstance()
def test_hook_change_place(self): def test_hook_change_place(self):
...@@ -1420,8 +1378,6 @@ class TestAmpWithHook(unittest.TestCase): ...@@ -1420,8 +1378,6 @@ class TestAmpWithHook(unittest.TestCase):
loss = a.sum() loss = a.sum()
self.assertRaises(RuntimeError, loss.backward) self.assertRaises(RuntimeError, loss.backward)
with _test_eager_guard():
func_isinstance()
func_isinstance() func_isinstance()
......
...@@ -18,7 +18,7 @@ import unittest ...@@ -18,7 +18,7 @@ import unittest
import numpy as np import numpy as np
import paddle import paddle
from paddle.autograd import PyLayer, EagerPyLayer from paddle.autograd.py_layer import LegacyPyLayer, EagerPyLayer
from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode
...@@ -32,7 +32,7 @@ class TestPyLayer(unittest.TestCase): ...@@ -32,7 +32,7 @@ class TestPyLayer(unittest.TestCase):
def func_test_simple_pylayer_multiple_output(self): def func_test_simple_pylayer_multiple_output(self):
class tanh(EagerPyLayer if in_dygraph_mode() else PyLayer): class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod @staticmethod
def forward(ctx, x1, x2, func1, func2=paddle.square): def forward(ctx, x1, x2, func1, func2=paddle.square):
...@@ -70,7 +70,7 @@ class TestPyLayer(unittest.TestCase): ...@@ -70,7 +70,7 @@ class TestPyLayer(unittest.TestCase):
def func_test_simple_pylayer_return_none_with_no_grad(self): def func_test_simple_pylayer_return_none_with_no_grad(self):
class tanh(EagerPyLayer if in_dygraph_mode() else PyLayer): class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod @staticmethod
def forward(ctx, x1, x2, func1, func2=paddle.square): def forward(ctx, x1, x2, func1, func2=paddle.square):
...@@ -112,7 +112,7 @@ class TestPyLayer(unittest.TestCase): ...@@ -112,7 +112,7 @@ class TestPyLayer(unittest.TestCase):
def func_test_simple_pylayer_single_output(self): def func_test_simple_pylayer_single_output(self):
class tanh(EagerPyLayer if in_dygraph_mode() else PyLayer): class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod @staticmethod
def forward(ctx, x1, func1, func2=paddle.square): def forward(ctx, x1, func1, func2=paddle.square):
...@@ -146,7 +146,7 @@ class TestPyLayer(unittest.TestCase): ...@@ -146,7 +146,7 @@ class TestPyLayer(unittest.TestCase):
def func_test_pylayer_num_output_match(self): def func_test_pylayer_num_output_match(self):
class tanh(EagerPyLayer if in_dygraph_mode() else PyLayer): class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod @staticmethod
def forward( def forward(
...@@ -175,7 +175,7 @@ class TestPyLayer(unittest.TestCase): ...@@ -175,7 +175,7 @@ class TestPyLayer(unittest.TestCase):
def func_test_pylayer_dtype(self): def func_test_pylayer_dtype(self):
class tanh(EagerPyLayer if in_dygraph_mode() else PyLayer): class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod @staticmethod
def forward(ctx, x, dtype): def forward(ctx, x, dtype):
...@@ -206,7 +206,7 @@ class TestPyLayer(unittest.TestCase): ...@@ -206,7 +206,7 @@ class TestPyLayer(unittest.TestCase):
def func_test_pylayer_Exception_forward(self): def func_test_pylayer_Exception_forward(self):
class Layer_None1(EagerPyLayer if in_dygraph_mode() else PyLayer): class Layer_None1(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod @staticmethod
def forward(ctx, *args): def forward(ctx, *args):
...@@ -220,7 +220,7 @@ class TestPyLayer(unittest.TestCase): ...@@ -220,7 +220,7 @@ class TestPyLayer(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
z = Layer_None1.apply(input1) z = Layer_None1.apply(input1)
class Layer_None2(EagerPyLayer if in_dygraph_mode() else PyLayer): class Layer_None2(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod @staticmethod
def forward(ctx, *args): def forward(ctx, *args):
...@@ -234,7 +234,7 @@ class TestPyLayer(unittest.TestCase): ...@@ -234,7 +234,7 @@ class TestPyLayer(unittest.TestCase):
# return None # return None
z = Layer_None2.apply(input1) z = Layer_None2.apply(input1)
class Layer_one1(EagerPyLayer if in_dygraph_mode() else PyLayer): class Layer_one1(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod @staticmethod
def forward(ctx, *args): def forward(ctx, *args):
...@@ -249,7 +249,7 @@ class TestPyLayer(unittest.TestCase): ...@@ -249,7 +249,7 @@ class TestPyLayer(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
z = Layer_one1.apply(input1) z = Layer_one1.apply(input1)
class Layer_one2(EagerPyLayer if in_dygraph_mode() else PyLayer): class Layer_one2(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod @staticmethod
def forward(ctx, *args): def forward(ctx, *args):
...@@ -263,7 +263,7 @@ class TestPyLayer(unittest.TestCase): ...@@ -263,7 +263,7 @@ class TestPyLayer(unittest.TestCase):
# return int # return int
z = Layer_one2.apply(input1) z = Layer_one2.apply(input1)
class Layer_no_fw(EagerPyLayer if in_dygraph_mode() else PyLayer): class Layer_no_fw(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod @staticmethod
def backward(ctx, *args): def backward(ctx, *args):
...@@ -280,7 +280,7 @@ class TestPyLayer(unittest.TestCase): ...@@ -280,7 +280,7 @@ class TestPyLayer(unittest.TestCase):
def func_test_pylayer_nograd(self): def func_test_pylayer_nograd(self):
class tanh(EagerPyLayer if in_dygraph_mode() else PyLayer): class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod @staticmethod
def forward(ctx, x1, func1, func2=paddle.square, xx=None): def forward(ctx, x1, func1, func2=paddle.square, xx=None):
...@@ -305,7 +305,8 @@ class TestPyLayer(unittest.TestCase): ...@@ -305,7 +305,8 @@ class TestPyLayer(unittest.TestCase):
def func_test_pylayer_Exception_bk(self): def func_test_pylayer_Exception_bk(self):
class Layer_bk_none1(EagerPyLayer if in_dygraph_mode() else PyLayer): class Layer_bk_none1(
EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod @staticmethod
def forward(ctx, x): def forward(ctx, x):
...@@ -322,7 +323,8 @@ class TestPyLayer(unittest.TestCase): ...@@ -322,7 +323,8 @@ class TestPyLayer(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
z.sum().backward() z.sum().backward()
class Layer_bk_none2(EagerPyLayer if in_dygraph_mode() else PyLayer): class Layer_bk_none2(
EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod @staticmethod
def forward(ctx, x1, x2): def forward(ctx, x1, x2):
...@@ -339,7 +341,8 @@ class TestPyLayer(unittest.TestCase): ...@@ -339,7 +341,8 @@ class TestPyLayer(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
z.mean().backward() z.mean().backward()
class Layer_bk_one1(EagerPyLayer if in_dygraph_mode() else PyLayer): class Layer_bk_one1(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer
):
@staticmethod @staticmethod
def forward(ctx, x): def forward(ctx, x):
...@@ -356,7 +359,8 @@ class TestPyLayer(unittest.TestCase): ...@@ -356,7 +359,8 @@ class TestPyLayer(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
z.mean().backward() z.mean().backward()
class Layer_bk_one2(EagerPyLayer if in_dygraph_mode() else PyLayer): class Layer_bk_one2(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer
):
@staticmethod @staticmethod
def forward(ctx, x1, x2): def forward(ctx, x1, x2):
...@@ -374,7 +378,7 @@ class TestPyLayer(unittest.TestCase): ...@@ -374,7 +378,7 @@ class TestPyLayer(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
z.mean().backward() z.mean().backward()
class Layer_no_bk(EagerPyLayer if in_dygraph_mode() else PyLayer): class Layer_no_bk(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod @staticmethod
def forward(ctx, x): def forward(ctx, x):
...@@ -388,7 +392,8 @@ class TestPyLayer(unittest.TestCase): ...@@ -388,7 +392,8 @@ class TestPyLayer(unittest.TestCase):
z = z[0] + z[1] z = z[0] + z[1]
z.mean().backward() z.mean().backward()
class Layer_bk_match(EagerPyLayer if in_dygraph_mode() else PyLayer): class Layer_bk_match(
EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod @staticmethod
def forward(ctx, x): def forward(ctx, x):
...@@ -412,7 +417,8 @@ class TestPyLayer(unittest.TestCase): ...@@ -412,7 +417,8 @@ class TestPyLayer(unittest.TestCase):
def func_test_pylayer_bk_return_none(self): def func_test_pylayer_bk_return_none(self):
class Layer_bk_none1(EagerPyLayer if in_dygraph_mode() else PyLayer): class Layer_bk_none1(
EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod @staticmethod
def forward(ctx, x1, x2): def forward(ctx, x1, x2):
...@@ -431,7 +437,8 @@ class TestPyLayer(unittest.TestCase): ...@@ -431,7 +437,8 @@ class TestPyLayer(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
z.mean().backward() z.mean().backward()
class Layer_bk_none2(EagerPyLayer if in_dygraph_mode() else PyLayer): class Layer_bk_none2(
EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod @staticmethod
def forward(ctx, x1, x2): def forward(ctx, x1, x2):
...@@ -457,7 +464,7 @@ class TestPyLayer(unittest.TestCase): ...@@ -457,7 +464,7 @@ class TestPyLayer(unittest.TestCase):
def func_test_pylayer_inplace(self): def func_test_pylayer_inplace(self):
class cus_tanh(EagerPyLayer if in_dygraph_mode() else PyLayer): class cus_tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod @staticmethod
def forward(ctx, x): def forward(ctx, x):
...@@ -494,7 +501,8 @@ class TestPyLayer(unittest.TestCase): ...@@ -494,7 +501,8 @@ class TestPyLayer(unittest.TestCase):
def test_pylayer_inplace_backward_error(self): def test_pylayer_inplace_backward_error(self):
with _test_eager_guard(): with _test_eager_guard():
class cus_tanh(EagerPyLayer if in_dygraph_mode() else PyLayer): class cus_tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer
):
@staticmethod @staticmethod
def forward(ctx, x): def forward(ctx, x):
...@@ -530,7 +538,8 @@ class TestPyLayer(unittest.TestCase): ...@@ -530,7 +538,8 @@ class TestPyLayer(unittest.TestCase):
def test_pylayer_inplace_backward_success_1(self): def test_pylayer_inplace_backward_success_1(self):
with _test_eager_guard(): with _test_eager_guard():
class cus_tanh(EagerPyLayer if in_dygraph_mode() else PyLayer): class cus_tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer
):
@staticmethod @staticmethod
def forward(ctx, x): def forward(ctx, x):
...@@ -564,7 +573,8 @@ class TestPyLayer(unittest.TestCase): ...@@ -564,7 +573,8 @@ class TestPyLayer(unittest.TestCase):
def test_pylayer_inplace_backward_success_2(self): def test_pylayer_inplace_backward_success_2(self):
with _test_eager_guard(): with _test_eager_guard():
class cus_tanh(EagerPyLayer if in_dygraph_mode() else PyLayer): class cus_tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer
):
@staticmethod @staticmethod
def forward(ctx, x): def forward(ctx, x):
...@@ -597,7 +607,8 @@ class TestPyLayer(unittest.TestCase): ...@@ -597,7 +607,8 @@ class TestPyLayer(unittest.TestCase):
def func_test_pylayer_inplace_and_leaf_exception(self): def func_test_pylayer_inplace_and_leaf_exception(self):
class cus_pylayer_op(EagerPyLayer if in_dygraph_mode() else PyLayer): class cus_pylayer_op(
EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod @staticmethod
def forward(ctx, x): def forward(ctx, x):
...@@ -633,7 +644,7 @@ class TestPyLayer(unittest.TestCase): ...@@ -633,7 +644,7 @@ class TestPyLayer(unittest.TestCase):
def func_test_backward_in_backward(self): def func_test_backward_in_backward(self):
class cus_tanh(EagerPyLayer if in_dygraph_mode() else PyLayer): class cus_tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod @staticmethod
def forward(ctx, x): def forward(ctx, x):
...@@ -665,7 +676,7 @@ class TestPyLayer(unittest.TestCase): ...@@ -665,7 +676,7 @@ class TestPyLayer(unittest.TestCase):
def func_test_return_to_tensor(self): def func_test_return_to_tensor(self):
class Tanh(EagerPyLayer if in_dygraph_mode() else PyLayer): class Tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod @staticmethod
def forward(ctx, x1): def forward(ctx, x1):
...@@ -779,7 +790,7 @@ class TestPyLayerReturnType(unittest.TestCase): ...@@ -779,7 +790,7 @@ class TestPyLayerReturnType(unittest.TestCase):
def test_forward_args_fake_tensor(self): def test_forward_args_fake_tensor(self):
class Tanh(PyLayer): class Tanh(LegacyPyLayer):
@staticmethod @staticmethod
def forward(ctx, x1): def forward(ctx, x1):
...@@ -797,7 +808,7 @@ class TestPyLayerReturnType(unittest.TestCase): ...@@ -797,7 +808,7 @@ class TestPyLayerReturnType(unittest.TestCase):
def test_forward_kwargs_fake_tensor(self): def test_forward_kwargs_fake_tensor(self):
class Tanh(PyLayer): class Tanh(LegacyPyLayer):
@staticmethod @staticmethod
def forward(ctx, x1): def forward(ctx, x1):
...@@ -815,7 +826,7 @@ class TestPyLayerReturnType(unittest.TestCase): ...@@ -815,7 +826,7 @@ class TestPyLayerReturnType(unittest.TestCase):
def test_forward_return_fake_tensor(self): def test_forward_return_fake_tensor(self):
class Tanh(PyLayer): class Tanh(LegacyPyLayer):
@staticmethod @staticmethod
def forward(ctx, x1): def forward(ctx, x1):
...@@ -833,7 +844,7 @@ class TestPyLayerReturnType(unittest.TestCase): ...@@ -833,7 +844,7 @@ class TestPyLayerReturnType(unittest.TestCase):
def test_forward_return_fake_tensor_tuple(self): def test_forward_return_fake_tensor_tuple(self):
class Tanh(PyLayer): class Tanh(LegacyPyLayer):
@staticmethod @staticmethod
def forward(ctx, x1): def forward(ctx, x1):
...@@ -851,7 +862,7 @@ class TestPyLayerReturnType(unittest.TestCase): ...@@ -851,7 +862,7 @@ class TestPyLayerReturnType(unittest.TestCase):
def test_backward_return_fake_tensor_tuple(self): def test_backward_return_fake_tensor_tuple(self):
class Tanh(PyLayer): class Tanh(LegacyPyLayer):
@staticmethod @staticmethod
def forward(ctx, x1, x2): def forward(ctx, x1, x2):
...@@ -871,7 +882,7 @@ class TestPyLayerReturnType(unittest.TestCase): ...@@ -871,7 +882,7 @@ class TestPyLayerReturnType(unittest.TestCase):
def test_backward_return_fake_tensor(self): def test_backward_return_fake_tensor(self):
class Tanh(PyLayer): class Tanh(LegacyPyLayer):
@staticmethod @staticmethod
def forward(ctx, x1): def forward(ctx, x1):
......
...@@ -31,7 +31,7 @@ from paddle.distributed import alltoall, all_gather ...@@ -31,7 +31,7 @@ from paddle.distributed import alltoall, all_gather
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.autograd import PyLayer, EagerPyLayer from paddle.autograd import PyLayer
from .gate import NaiveGate, GShardGate, SwitchGate, BaseGate from .gate import NaiveGate, GShardGate, SwitchGate, BaseGate
from .utils import count_by_gate from .utils import count_by_gate
from paddle.distributed.fleet.meta_parallel.pp_utils.utils import _hp_recompute from paddle.distributed.fleet.meta_parallel.pp_utils.utils import _hp_recompute
...@@ -132,53 +132,6 @@ class MoEScatter(PyLayer): ...@@ -132,53 +132,6 @@ class MoEScatter(PyLayer):
return grad_in, None, None, None return grad_in, None, None, None
class EagerMoEScatter(EagerPyLayer):
r"""
Scatter input samples from [batch x sequences] to contiguous alone experts.
If `world_size` is greater than 1, the samples will first be locally
scattered, and then exchanged across workers.
"""
@staticmethod
def forward(ctx,
inp,
pos,
local_expert_count,
global_expert_count,
fwd_batch_size,
world_size,
group=None):
local_input_buf = _local_scatter(inp, pos)
if world_size > 1:
global_input_buf = global_scatter(local_input_buf,
local_expert_count,
global_expert_count,
group=group)
else:
global_input_buf = local_input_buf
ctx.moe_args = inp.shape[0], world_size, group
variables = (pos, local_expert_count, global_expert_count)
ctx.save_for_backward(*variables)
return global_input_buf
@staticmethod
def backward(ctx, grad):
(pos, local_expert_count, global_expert_count) = ctx.saved_tensor()
(inp_batch_size, world_size, group) = ctx.moe_args
if world_size > 1:
local_grad_in = global_gather(grad,
local_expert_count,
global_expert_count,
group=group)
else:
local_grad_in = grad
grad_in = _local_gather(local_grad_in, pos, inp_batch_size)
return grad_in, None, None, None
class MoEGather(PyLayer): class MoEGather(PyLayer):
r""" r"""
Gather output samples from contiguous alone experts back to [batch x Gather output samples from contiguous alone experts back to [batch x
...@@ -226,53 +179,6 @@ class MoEGather(PyLayer): ...@@ -226,53 +179,6 @@ class MoEGather(PyLayer):
return global_grad_out_buf, None, None, None return global_grad_out_buf, None, None, None
class EagerMoEGather(EagerPyLayer):
r"""
Gather output samples from contiguous alone experts back to [batch x
sequences]. Works symmetrically with MoEScatter.
"""
@staticmethod
def forward(ctx,
global_output_buf,
pos,
local_expert_count,
global_expert_count,
local_batch_size,
world_size,
group=None):
if world_size > 1:
local_output_buf = global_gather(global_output_buf,
local_expert_count,
global_expert_count,
group=group)
else:
local_output_buf = global_output_buf
output = _local_gather(local_output_buf,
pos,
local_batch_size,
maybe_overlap=False)
ctx.moe_args = (global_output_buf.shape[0], world_size, group)
variables = (pos, local_expert_count, global_expert_count)
ctx.save_for_backward(*variables)
return output
@staticmethod
def backward(ctx, grad_out):
pos, local_expert_count, global_expert_count = ctx.saved_tensor()
fwd_batch_size, world_size, group = ctx.moe_args
grad_out_buf = _local_scatter(grad_out, pos)
if world_size > 1:
global_grad_out_buf = global_scatter(grad_out_buf,
local_expert_count,
global_expert_count,
group=group)
else:
global_grad_out_buf = grad_out_buf
return global_grad_out_buf, None, None, None
class AllGather(PyLayer): class AllGather(PyLayer):
r""" r"""
A wrapper for the All-Gather function to support auto-differentiation. A wrapper for the All-Gather function to support auto-differentiation.
...@@ -295,28 +201,6 @@ class AllGather(PyLayer): ...@@ -295,28 +201,6 @@ class AllGather(PyLayer):
ends=[(rank + 1) * dim0]) ends=[(rank + 1) * dim0])
class EagerAllGather(EagerPyLayer):
r"""
A wrapper for the All-Gather function to support auto-differentiation.
"""
@staticmethod
def forward(ctx, inp, rank, world_size, group):
tensor_list = []
paddle.distributed.all_gather(tensor_list, inp, group=group)
output = paddle.concat(tensor_list, axis=0)
ctx.args = rank, inp.shape[0]
return output
@staticmethod
def backward(ctx, grad_out):
rank, dim0 = ctx.args
return paddle.slice(grad_out,
axes=[0],
starts=[rank * dim0],
ends=[(rank + 1) * dim0])
class Slice(PyLayer): class Slice(PyLayer):
r""" r"""
A wrapper for the Slice function to support auto-differentiation. A wrapper for the Slice function to support auto-differentiation.
...@@ -341,30 +225,6 @@ class Slice(PyLayer): ...@@ -341,30 +225,6 @@ class Slice(PyLayer):
return _all_gather(grad_out, group=group) return _all_gather(grad_out, group=group)
class EagerSlice(EagerPyLayer):
r"""
A wrapper for the Slice function to support auto-differentiation.
"""
@staticmethod
def forward(ctx, inp, rank, world_size, group):
B = inp.shape[0]
local_batch_size = B // world_size
batch_start = local_batch_size * rank
batch_end = min(batch_start + local_batch_size, B)
inp = paddle.slice(inp,
axes=[0],
starts=[batch_start],
ends=[batch_end])
ctx.args = world_size, group
return inp
@staticmethod
def backward(ctx, grad_out):
world_size, group = ctx.args
return _all_gather(grad_out, group=group)
def prepare_forward(gate, num_expert, world_size, moe_group): def prepare_forward(gate, num_expert, world_size, moe_group):
pos, local_expert_count, global_expert_count = count_by_gate( pos, local_expert_count, global_expert_count = count_by_gate(
gate, num_expert, world_size, group=moe_group) gate, num_expert, world_size, group=moe_group)
...@@ -517,9 +377,6 @@ class MoELayer(nn.Layer): ...@@ -517,9 +377,6 @@ class MoELayer(nn.Layer):
mp_rank = self.mp_group.rank mp_rank = self.mp_group.rank
mp_size = self.mp_group.nranks mp_size = self.mp_group.nranks
if mp_size > 1: if mp_size > 1:
if in_dygraph_mode():
inp = EagerSlice.apply(inp, mp_rank, mp_size, self.mp_group)
else:
inp = Slice.apply(inp, mp_rank, mp_size, self.mp_group) inp = Slice.apply(inp, mp_rank, mp_size, self.mp_group)
value, gate = self.gate(inp) value, gate = self.gate(inp)
...@@ -541,11 +398,6 @@ class MoELayer(nn.Layer): ...@@ -541,11 +398,6 @@ class MoELayer(nn.Layer):
temp_pos = pos temp_pos = pos
assert topk == self.top_k assert topk == self.top_k
if in_dygraph_mode():
x = EagerMoEScatter.apply(inp, temp_pos, local_expert_count,
global_expert_count, fwd_batch_size,
self.world_size, self.group)
else:
x = MoEScatter.apply(inp, temp_pos, local_expert_count, x = MoEScatter.apply(inp, temp_pos, local_expert_count,
global_expert_count, fwd_batch_size, global_expert_count, fwd_batch_size,
self.world_size, self.group) self.world_size, self.group)
...@@ -577,11 +429,6 @@ class MoELayer(nn.Layer): ...@@ -577,11 +429,6 @@ class MoELayer(nn.Layer):
if len(gate.shape) == 2: if len(gate.shape) == 2:
out_batch_size *= gate.shape[1] out_batch_size *= gate.shape[1]
if in_dygraph_mode():
x = EagerMoEGather.apply(x, pos, local_expert_count,
global_expert_count, out_batch_size,
self.world_size, self.group)
else:
x = MoEGather.apply(x, pos, local_expert_count, global_expert_count, x = MoEGather.apply(x, pos, local_expert_count, global_expert_count,
out_batch_size, self.world_size, self.group) out_batch_size, self.world_size, self.group)
...@@ -590,9 +437,6 @@ class MoELayer(nn.Layer): ...@@ -590,9 +437,6 @@ class MoELayer(nn.Layer):
x = paddle.bmm(value, x).reshape([-1, d_model]) x = paddle.bmm(value, x).reshape([-1, d_model])
if mp_size > 1: if mp_size > 1:
if in_dygraph_mode():
x = EagerAllGather.apply(x, mp_rank, mp_size, self.mp_group)
else:
x = AllGather.apply(x, mp_rank, mp_size, self.mp_group) x = AllGather.apply(x, mp_rank, mp_size, self.mp_group)
x = paddle.reshape_(x, origin_shape) x = paddle.reshape_(x, origin_shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册