Unverified commit a5dc0a79, authored by wanghuancoder, committed by GitHub

[Eager] Rename EagerPyLayer to PyLayer (#43696)

* rename eagerpylayer
Parent 8a122ecc
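For orientation (not part of the diff itself): after this rename, eager-mode user code defines and applies a custom layer through the single public PyLayer name. A minimal sketch, adapted from the cus_tanh test class further down in this commit:

import paddle
from paddle.autograd import PyLayer  # resolves to EagerPyLayer in eager mode after this commit

class cus_tanh(PyLayer):

    @staticmethod
    def forward(ctx, x):
        y = paddle.tanh(x)
        ctx.save_for_backward(y)   # stash tensors needed in backward
        return y

    @staticmethod
    def backward(ctx, dy):
        y, = ctx.saved_tensor()
        return dy * (1 - paddle.square(y))

x = paddle.randn([2, 3])
x.stop_gradient = False
y = cus_tanh.apply(x)
y.sum().backward()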
......@@ -129,16 +129,19 @@ PyObject* pylayer_method_apply(PyObject* cls,
bool require_any_grad = false;
size_t inputs_size = 0;
size_t args_size = 0;
size_t kwargs_size = 0;
PyObject* forward_args = nullptr;
PyObject* kwargs_value_list = nullptr;
if (kwargs) {
inputs_size = PyDict_Size(kwargs);
kwargs_size = PyDict_Size(kwargs);
kwargs_value_list = PyDict_Values(kwargs);
forward_args = PyTuple_New(1);
} else {
inputs_size = PyTuple_GET_SIZE(args);
forward_args = PyTuple_New(inputs_size + 1);
}
if (args) {
args_size = PyTuple_GET_SIZE(args);
}
inputs_size = kwargs_size + args_size;
forward_args = PyTuple_New(args_size + 1);
Py_INCREF(ctx);
PyTuple_SET_ITEM(forward_args, 0, reinterpret_cast<PyObject*>(ctx));
......@@ -150,8 +153,8 @@ PyObject* pylayer_method_apply(PyObject* cls,
ctx->forward_input_tensor_is_duplicable.reserve(inputs_size);
for (size_t i = 0; i < inputs_size; i++) {
PyObject* obj = nullptr;
if (kwargs) {
obj = PyList_GetItem(kwargs_value_list, i);
if (i >= args_size) {
obj = PyList_GetItem(kwargs_value_list, i - args_size);
} else {
obj = PyTuple_GET_ITEM(args, i);
}
......@@ -212,7 +215,7 @@ PyObject* pylayer_method_apply(PyObject* cls,
}
}
if (!kwargs) {
if (i < args_size) {
Py_INCREF(obj);
PyTuple_SET_ITEM(forward_args, i + 1, obj);
}
......
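The hunk above reworks pylayer_method_apply so that forward can receive positional and keyword arguments together: ctx is placed first, all positional args are copied into forward_args, and the gradient scan walks the positional args first and then the kwargs values. A rough Python rendering of the new indexing order (purely illustrative, not code from this commit):

def iter_forward_inputs(args, kwargs):
    # Mirrors the C++ loop: positional inputs first, then kwargs values in dict order.
    args_size = len(args)
    kwargs_values = list(kwargs.values()) if kwargs else []
    for i in range(args_size + len(kwargs_values)):
        if i < args_size:
            yield args[i]                       # PyTuple_GET_ITEM(args, i)
        else:
            yield kwargs_values[i - args_size]  # PyList_GetItem(kwargs_value_list, i - args_size)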
......@@ -17,7 +17,13 @@ from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401
from ..framework import is_grad_enabled, set_grad_enabled # noqa: F401
from . import backward_mode # noqa: F401
from .backward_mode import backward # noqa: F401
from .py_layer import PyLayer, PyLayerContext, EagerPyLayer, EagerPyLayerContext # noqa: F401
from ..fluid.framework import _in_eager_mode_
if _in_eager_mode_:
from .py_layer import EagerPyLayer as PyLayer # noqa: F401
from .py_layer import EagerPyLayerContext as PyLayerContext # noqa: F401
else:
from .py_layer import LegacyPyLayer as PyLayer # noqa: F401
from .py_layer import LegacyPyLayerContext as PyLayerContext # noqa: F401
from ..framework import set_grad_enabled, is_grad_enabled # noqa: F401
from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401
from .functional import vjp, jvp, Jacobian, Hessian # noqa: F401
......
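With this dispatch, downstream code keeps importing PyLayer from paddle.autograd and transparently gets the eager or legacy class. A quick, purely illustrative check of which class the public name binds to:

import paddle.autograd as autograd
from paddle.fluid.framework import _in_eager_mode_

expected = "EagerPyLayer" if _in_eager_mode_ else "LegacyPyLayer"
assert autograd.PyLayer.__name__ == expected  # the alias keeps the underlying class name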
......@@ -21,7 +21,7 @@ from paddle.fluid import core
__all__ = []
class PyLayerContext(object):
class LegacyPyLayerContext(object):
"""
The object of this class is a context that is used in PyLayer to enhance the function.
......@@ -181,7 +181,7 @@ class CPyLayer(object):
return core.pylayer_apply(place, cls, *args, **kwargs)
class PyLayerBackward(PyLayerContext):
class PyLayerBackward(LegacyPyLayerContext):
def backward(self, *args, **kwargs):
with paddle.fluid.dygraph.guard():
......@@ -205,7 +205,7 @@ class LayerMeta(type):
return super(LayerMeta, cls).__init__(name, bases, attrs)
class PyLayer(with_mateclass(LayerMeta, CPyLayer)):
class LegacyPyLayer(with_mateclass(LayerMeta, CPyLayer)):
"""
Build a custom `Layer` by creating subclasses. Subclasses need to follow the following rules:
1. Subclasses contain `forward` and `backward` function. Both forward and backward are @staticmethod.
......@@ -425,6 +425,8 @@ class EagerPyLayerContext(object):
Examples:
.. code-block:: python
import os
os.environ['FLAGS_enable_eager_mode'] = '1'
import paddle
from paddle.autograd import PyLayer
import numpy as np
......@@ -464,6 +466,8 @@ class EagerPyLayerContext(object):
Examples:
.. code-block:: python
import os
os.environ['FLAGS_enable_eager_mode'] = '1'
import paddle
from paddle.autograd import PyLayer
import numpy as np
......
......@@ -1181,9 +1181,9 @@ def _mp_allreduce(tensor,
if in_dygraph_mode():
assert op == ReduceOp.SUM, "Unknown parameter: {}.".format(op)
from paddle.autograd import EagerPyLayer
from paddle.autograd import PyLayer
class mp_allreduce_eager(EagerPyLayer):
class mp_allreduce_eager(PyLayer):
@staticmethod
def forward(ctx, tensor, use_calc_stream, ring_id,
......
......@@ -37,7 +37,7 @@ from ..meta_optimizers import HybridParallelOptimizer, HeterParallelOptimizer
from paddle import _C_ops
from paddle.fluid import core
from paddle.fluid.dygraph import to_variable
from paddle.distributed.fleet.utils.recompute import RecomputeFunction
from paddle.distributed.fleet.utils.recompute import LegacyRecomputeFunction
from paddle.fluid.dygraph.varbase_patch_methods import _grad_scalar
__all__ = []
......@@ -68,7 +68,8 @@ class _RecomputeModelWrapper(paddle.nn.Layer):
return do_run
def _checkpoint(self, func, *args, **kwargs):
return RecomputeFunction.apply(func, self._preserve_rng_state, *args)
return LegacyRecomputeFunction.apply(func, self._preserve_rng_state,
*args)
def forward(self, input):
end = 0
......
......@@ -17,7 +17,7 @@ import contextlib
import paddle
from paddle.fluid import core
from paddle import _C_ops
from paddle.autograd import PyLayer, EagerPyLayer
from paddle.autograd import PyLayer
from paddle.fluid import framework
from ...utils.recompute import check_recompute_necessary, detach_variable, swith_rng_state_tracker
from ..parallel_layers.random import get_rng_state_tracker
......@@ -151,7 +151,7 @@ def _merge_activation(tensor):
return _all_gather(tensor, group=mp_group)
class _HPEagerRecomputeFunction(EagerPyLayer):
class _HPRecomputeFunction(PyLayer):
"""
Compared with paddle.distributed.fleet.utils.recompute, there are the following differences:
1. In order to support PipeLineParallel, the input of recompute is modified to ensure that the input can be tuple type.
......@@ -256,7 +256,7 @@ class _HPEagerRecomputeFunction(EagerPyLayer):
detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs)
if isinstance(outputs, core.eager.Tensor):
if isinstance(outputs, (core.VarBase, core.eager.Tensor)):
outputs = (outputs, )
assert len(outputs) == len(args)
......@@ -266,137 +266,8 @@ class _HPEagerRecomputeFunction(EagerPyLayer):
for i in range(len(outputs)):
if isinstance(
outputs[i],
core.eager.Tensor) and not outputs[i].stop_gradient:
forward_outputs_with_grad.append(outputs[i])
backward_inputs.append(args[i])
if len(forward_outputs_with_grad) == 0:
raise RuntimeError(
"none of output has stop_gradient=False, this recompute() is not necessary"
)
# actually backward
paddle.autograd.backward(forward_outputs_with_grad, backward_inputs)
grads = tuple(inp._grad_ivar() for inp in detached_inputs
if isinstance(inp, core.eager.Tensor))
return grads
class _HPRecomputeFunction(PyLayer):
"""
Compared with paddle.distributed.fleet.utils.recompute, there are the following differences:
1. In order to support PipeLineParallel, the input of recompute is modified to ensure that the input can be tuple type.
2. Offload support for activation
3. Support MP segmentation of activation to further reduce cuda memory
4. Adapt to the random state of MP
"""
@staticmethod
def forward(ctx, run_function, all_outputs, *args):
check_recompute_necessary(args)
# store for recomputing
ctx.run_function = run_function
# store the rng states
ctx.fwd_cuda_rng_state = paddle.get_cuda_rng_state()
ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker(
).get_states_tracker()
# save input for backward
ctx.inputs = []
ctx.tensor_indices = []
ctx.tensor_shapes = []
tensor_inputs = []
cur_device = paddle.get_device()
assert 'gpu:' in paddle.get_device(
), "Recompute with RNG is not support current device: {}.".format(
cur_device)
# TODO support AMP
tracer = framework._dygraph_tracer()
ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
if tracer._amp_level == core.AmpLevel.O2:
ctx.amp_level = 'O2'
elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
ctx.amp_level = 'O1'
else:
raise ValueError("unsupported amp level: {}".format(
tracer._amp_level))
ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()
with paddle.no_grad():
outputs = run_function(*args)
for i, arg in enumerate(args):
if paddle.is_tensor(arg):
state = arg.stop_gradient
if _recompute_partition:
ctx.tensor_shapes.append(arg.shape)
partition = _split_activation(arg.detach()).clone()
# TODO(shenliang03) not use calculate stream to D2H to speed
arg = partition.cpu() if _recompute_offload else partition
else:
arg = arg.cpu() if _recompute_offload else arg
arg.stop_gradient = state
tensor_inputs.append(arg)
ctx.tensor_indices.append(i)
ctx.inputs.append(None)
else:
ctx.inputs.append(arg)
ctx.save_for_backward(*tensor_inputs)
if paddle.is_tensor(outputs):
all_outputs += [outputs]
return outputs
else:
all_outputs += outputs
return tuple(outputs)
@staticmethod
def backward(ctx, *args):
with paddle.fluid.dygraph.guard():
# Restore inputs
inputs = list(ctx.inputs)
tensor_indices = ctx.tensor_indices
tensor_shapes = ctx.tensor_shapes
tensors = list(ctx.saved_tensor())
device_id = paddle.distributed.ParallelEnv().device_id
for i, idx in enumerate(tensor_indices):
if _recompute_partition:
state = tensors[i].stop_gradient
tensors[i] = _merge_activation(
tensors[i]).detach().reshape_(tensor_shapes[i])
tensors[i].stop_gradient = state
inputs[idx] = tensors[i].cuda(
device_id) if _recompute_offload else tensors[i]
tracer = framework._dygraph_tracer()
tracer._has_grad = True
# need restore auto_cast state as well as w/b list
with swith_rng_state_tracker(ctx.fwd_cuda_rng_state,
ctx.fwd_cuda_rng_state_tracker):
with paddle.amp.auto_cast(enable=ctx.is_fw_autocast,
custom_white_list=ctx.amp_white_list,
custom_black_list=ctx.amp_black_list,
level=ctx.amp_level):
detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs)
if isinstance(outputs, core.VarBase):
outputs = (outputs, )
assert len(outputs) == len(args)
forward_outputs_with_grad = []
backward_inputs = []
for i in range(len(outputs)):
if isinstance(outputs[i],
core.VarBase) and not outputs[i].stop_gradient:
(core.VarBase,
core.eager.Tensor)) and not outputs[i].stop_gradient:
forward_outputs_with_grad.append(outputs[i])
backward_inputs.append(args[i])
......@@ -408,7 +279,7 @@ class _HPRecomputeFunction(PyLayer):
# actually backward
paddle.autograd.backward(forward_outputs_with_grad, backward_inputs)
grads = tuple(inp._grad_ivar() for inp in detached_inputs
if isinstance(inp, core.VarBase))
if isinstance(inp, (core.VarBase, core.eager.Tensor)))
return grads
......@@ -420,10 +291,7 @@ def _hp_recompute(function, *args):
# 3. Here, we only use float dtype to distinguish whether a gradient is needed in output tensor
all_outputs = []
if in_dygraph_mode():
_HPEagerRecomputeFunction.apply(function, all_outputs, *args)
else:
_HPRecomputeFunction.apply(function, all_outputs, *args)
_HPRecomputeFunction.apply(function, all_outputs, *args)
if len(all_outputs) == 1:
return all_outputs[0]
......
......@@ -20,7 +20,7 @@ from collections import OrderedDict
import paddle
from paddle import nn
from paddle.autograd import EagerPyLayer
from paddle.autograd import PyLayer
import paddle.fluid.core as core
import paddle.fluid.framework as framework
from paddle.fluid.framework import EagerParamBase
......@@ -398,7 +398,7 @@ class GroupShardedStage3(nn.Layer):
def _register_forward_hooks(self, layer):
"""
Register EagerPyLayer to manage memory slices.
Register PyLayer to manage memory slices.
There are four stages:
FW
1. Before the forward layers, synchronize the full parameters.
......@@ -653,7 +653,7 @@ def ForwardPreHooks(layer, order_tracer, trainable_params, param2buffer_size,
return
class ForwardPostHooks(EagerPyLayer):
class ForwardPostHooks(PyLayer):
@staticmethod
def forward(ctx, inputs, layer, order_tracer, trainable_params,
......
......@@ -14,7 +14,8 @@
import paddle
from paddle.fluid import core
from paddle.autograd import PyLayer, EagerPyLayer
from paddle.autograd import PyLayer
from paddle.autograd.py_layer import LegacyPyLayer
from paddle.fluid import framework
import contextlib
......@@ -68,7 +69,7 @@ def swith_rng_state_tracker(rng_state, tracker):
get_rng_state_tracker().set_states_tracker(orig_cuda_rng_tracker)
class EagerRecomputeFunction(EagerPyLayer):
class LegacyRecomputeFunction(LegacyPyLayer):
@staticmethod
def forward(ctx, run_function, preserve_rng_state, *args):
......@@ -171,7 +172,7 @@ class EagerRecomputeFunction(EagerPyLayer):
detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs)
if isinstance(outputs, core.eager.Tensor):
if isinstance(outputs, core.VarBase):
outputs = (outputs, )
assert len(outputs) == len(args)
......@@ -183,9 +184,8 @@ class EagerRecomputeFunction(EagerPyLayer):
# the following backward_inputs_with_grad is used to avoid this case.
backward_inputs_with_grad = []
for i in range(len(outputs)):
if isinstance(
outputs[i],
core.eager.Tensor) and not outputs[i].stop_gradient:
if isinstance(outputs[i],
core.VarBase) and not outputs[i].stop_gradient:
forward_outputs_with_grad.append(outputs[i])
backward_inputs_with_grad.append(args[i])
......@@ -199,8 +199,8 @@ class EagerRecomputeFunction(EagerPyLayer):
paddle.autograd.backward(forward_outputs_with_grad,
backward_inputs_with_grad)
grads = tuple(inp.grad for inp in detached_inputs
if isinstance(inp, core.eager.Tensor))
grads = list(inp._grad_ivar() for inp in detached_inputs
if isinstance(inp, core.VarBase))
return grads
......@@ -307,7 +307,7 @@ class RecomputeFunction(PyLayer):
detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs)
if isinstance(outputs, core.VarBase):
if isinstance(outputs, (core.VarBase, core.eager.Tensor)):
outputs = (outputs, )
assert len(outputs) == len(args)
......@@ -319,8 +319,10 @@ class RecomputeFunction(PyLayer):
# the following backward_inputs_with_grad is used to avoid this case.
backward_inputs_with_grad = []
for i in range(len(outputs)):
if isinstance(outputs[i],
core.VarBase) and not outputs[i].stop_gradient:
if isinstance(
outputs[i],
(core.VarBase,
core.eager.Tensor)) and not outputs[i].stop_gradient:
forward_outputs_with_grad.append(outputs[i])
backward_inputs_with_grad.append(args[i])
......@@ -334,8 +336,14 @@ class RecomputeFunction(PyLayer):
paddle.autograd.backward(forward_outputs_with_grad,
backward_inputs_with_grad)
grads = list(inp._grad_ivar() for inp in detached_inputs
if isinstance(inp, core.VarBase))
if in_dygraph_mode():
grads = tuple(
inp._grad_ivar() for inp in detached_inputs
if isinstance(inp, (core.VarBase, core.eager.Tensor)))
else:
grads = list(
inp._grad_ivar() for inp in detached_inputs
if isinstance(inp, (core.VarBase, core.eager.Tensor)))
return grads
......@@ -465,7 +473,4 @@ def recompute(function, *args, **kwargs):
if framework._dygraph_tracer()._has_grad:
check_recompute_necessary(args)
if in_dygraph_mode():
return EagerRecomputeFunction.apply(function, preserve, *args)
else:
return RecomputeFunction.apply(function, preserve, *args)
return RecomputeFunction.apply(function, preserve, *args)
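After this change recompute() no longer branches on the dygraph mode; RecomputeFunction now accepts both core.VarBase and core.eager.Tensor outputs, and the public call pattern is unchanged. A minimal usage sketch (layer sizes are illustrative; a GPU device is assumed because preserving the RNG state requires CUDA, as in the tests below):

import paddle
from paddle.distributed.fleet.utils import recompute

paddle.set_device("gpu")
block = paddle.nn.Sequential(paddle.nn.Linear(10, 10), paddle.nn.ReLU())
x = paddle.randn([4, 10])
x.stop_gradient = False
y = recompute(block, x)   # activations are recomputed during backward instead of being stored
y.mean().backward()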
......@@ -60,7 +60,9 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_sharding_parallel)
list(APPEND DIST_TEST_OPS test_dygraph_sharding_optimizer_stage2)
list(APPEND DIST_TEST_OPS test_dygraph_sharding_stage2)
list(APPEND DIST_TEST_OPS test_dygraph_sharding_stage3)
list(APPEND DIST_TEST_OPS test_dygraph_sharding_stage3_for_eager)
list(APPEND DIST_TEST_OPS test_dygraph_group_sharded_api)
list(APPEND DIST_TEST_OPS test_dygraph_group_sharded_api_for_eager)
list(APPEND DIST_TEST_OPS test_auto_parallel_parallelizer)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers)
list(APPEND DIST_TEST_OPS test_hybrid_parallel_inference_helper)
......@@ -305,13 +307,17 @@ if((NOT WITH_GPU) AND (NOT WITH_ROCM))
list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_optimizer_stage2)
list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage2)
list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage3)
list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage3_for_eager)
list(REMOVE_ITEM TEST_OPS test_dygraph_group_sharded_api)
list(REMOVE_ITEM TEST_OPS test_dygraph_group_sharded_api_for_eager)
list(REMOVE_ITEM TEST_OPS test_auto_parallel_parallelizer)
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mp_layers)
list(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision)
list(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision_for_eager)
list(REMOVE_ITEM TEST_OPS test_mixed_precision)
list(REMOVE_ITEM TEST_OPS test_fleet_base_single)
list(REMOVE_ITEM TEST_OPS test_dygraph_recompute)
list(REMOVE_ITEM TEST_OPS test_dygraph_recompute_for_eager)
list(REMOVE_ITEM TEST_OPS test_hybrid_parallel_inference_helper)
list(REMOVE_ITEM TEST_OPS test_parallel_class_center_sample)
list(REMOVE_ITEM TEST_OPS test_parallel_margin_cross_entropy)
......@@ -1547,7 +1553,11 @@ if(WITH_DISTRIBUTE
120)
set_tests_properties(test_dygraph_sharding_stage2 PROPERTIES TIMEOUT 200)
set_tests_properties(test_dygraph_sharding_stage3 PROPERTIES TIMEOUT 350)
set_tests_properties(test_dygraph_sharding_stage3_for_eager PROPERTIES TIMEOUT
350)
set_tests_properties(test_dygraph_group_sharded_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_dygraph_group_sharded_api_for_eager
PROPERTIES TIMEOUT 120)
set_tests_properties(test_auto_parallel_parallelizer PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_dygraph_mp_layers PROPERTIES TIMEOUT 120)
set_tests_properties(test_hybrid_parallel_inference_helper PROPERTIES TIMEOUT
......@@ -1637,6 +1647,8 @@ endif()
if(WITH_GPU OR WITH_ROCM)
set_tests_properties(test_imperative_auto_mixed_precision PROPERTIES TIMEOUT
300)
set_tests_properties(test_imperative_auto_mixed_precision_for_eager
PROPERTIES TIMEOUT 300)
set_tests_properties(test_parallel_dygraph_sync_batch_norm PROPERTIES TIMEOUT
120)
set_tests_properties(test_rank_attention_op PROPERTIES TIMEOUT 120)
......
......@@ -21,7 +21,7 @@ import paddle
import numpy as np
import paddle.distributed as dist
from paddle.fluid.dygraph.nn import Linear
from paddle.autograd import PyLayer, EagerPyLayer
from paddle.autograd import PyLayer
from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients
......@@ -45,21 +45,6 @@ class cus_tanh(PyLayer):
return grad
class cus_tanh_eager(EagerPyLayer):
@staticmethod
def forward(ctx, x):
y = paddle.tanh(x)
ctx.save_for_backward(y)
return y
@staticmethod
def backward(ctx, dy):
y, = ctx.saved_tensor()
grad = dy * (1 - paddle.square(y))
return grad
class SimpleNet(paddle.nn.Layer):
def __init__(self, train_id, model_id):
......@@ -73,10 +58,7 @@ class SimpleNet(paddle.nn.Layer):
def forward(self, inputs):
if self.model_id == 0:
if in_dygraph_mode():
inputs = cus_tanh_eager.apply(inputs)
elif _in_legacy_dygraph():
inputs = cus_tanh.apply(inputs)
inputs = cus_tanh.apply(inputs)
else:
inputs = self.tanh(inputs)
......
......@@ -15,6 +15,9 @@
from __future__ import print_function
import os
os.environ['FLAGS_enable_eager_mode'] = '0'
import unittest
import paddle.fluid as fluid
......@@ -26,9 +29,7 @@ class TestDygraphGroupSharded(TestMultipleGpus):
# check group sharded logic as well as the accuracy with single mode
def test_dygraph_group_sharded(self):
self.run_mnist_2gpu('dygraph_group_sharded_api.py', eager_mode=False)
self.run_mnist_2gpu('dygraph_group_sharded_api_eager.py')
if __name__ == "__main__":
os.environ["FLAGS_enable_eager_mode"] = "1"
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
os.environ['FLAGS_enable_eager_mode'] = '1'
import unittest
import paddle.fluid as fluid
from test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestDygraphGroupSharded(TestMultipleGpus):
# check group sharded logic as well as the accuracy with single mode
def test_dygraph_group_sharded(self):
self.run_mnist_2gpu('dygraph_group_sharded_api_eager.py')
if __name__ == "__main__":
unittest.main()
......@@ -23,7 +23,6 @@ from paddle.distributed.fleet.utils import recompute
import random
import paddle.fluid.layers as layers
from paddle.fluid.framework import _test_eager_guard
def get_fc_block(block_idx, input_size, is_last=False):
......@@ -181,34 +180,15 @@ class TestPyLayer(unittest.TestCase):
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
def test_fc_net_with_dropout(self):
with _test_eager_guard():
self.test_base_case()
self.test_base_case()
def test_fc_net_without_restore_rng(self):
with _test_eager_guard():
loss_ref, param_ref, grad_ref = run_model(
recompute_block=[2],
recompute_kwargs={"preserve_rng_state": False},
enable_autocast=True)
def test_fc_net_with_amp(self):
with _test_eager_guard():
self.test_base_case(enable_autocast=True)
self.test_base_case(enable_autocast=True)
def test_fc_net_with_fp16(self):
with _test_eager_guard():
self.test_base_case(enable_autocast=True, pure_fp16=True)
self.test_base_case(enable_autocast=True, pure_fp16=True)
def test_recompute_kwargs(self):
with _test_eager_guard():
paddle.set_device("gpu")
kwargs = {"is_test": False}
with self.assertRaises(ValueError):
loss_ref, param_ref, grad_ref = run_model(
recompute_block=[2], recompute_kwargs=kwargs)
paddle.set_device("gpu")
kwargs = {"is_test": False}
with self.assertRaises(ValueError):
......@@ -216,11 +196,6 @@ class TestPyLayer(unittest.TestCase):
recompute_kwargs=kwargs)
def test_recompute_cpu_rng(self):
with _test_eager_guard():
paddle.set_device("cpu")
with self.assertRaises(RuntimeError):
loss_ref, param_ref, grad_ref = run_model(recompute_block=[2])
paddle.set_device("cpu")
with self.assertRaises(RuntimeError):
loss_ref, param_ref, grad_ref = run_model(recompute_block=[2])
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
os.environ['FLAGS_enable_eager_mode'] = '1'
import unittest
import numpy as np
import paddle
from paddle.autograd import PyLayer
from paddle.distributed.fleet.utils import recompute
import random
import paddle.fluid.layers as layers
def get_fc_block(block_idx, input_size, is_last=False):
block_name = "block_" + str(block_idx)
block = paddle.nn.Sequential(
(block_name + "_fc_0",
paddle.nn.Linear(input_size, input_size, bias_attr=False)),
(block_name + "_dropout", paddle.nn.Dropout(p=0.5)),
(block_name + "_relu_1", paddle.nn.ReLU()),
(block_name + "_fc_1",
paddle.nn.Linear(input_size, input_size, bias_attr=False)),
(block_name + "_relu_2", paddle.nn.ReLU()),
)
if is_last:
block.add_sublayer(block_name + "_fc_2",
paddle.nn.Linear(input_size, 1,
bias_attr=False)) # add sublayer
else:
block.add_sublayer(block_name + "_fc_2",
paddle.nn.Linear(input_size,
input_size,
bias_attr=False)) # add sublayer
return block
class Naive_fc_net(paddle.nn.Layer):
def __init__(self,
input_size=10,
recompute_blocks=[1, 3],
recompute_kwargs={}):
super(Naive_fc_net, self).__init__()
self.recompute_blocks = recompute_blocks
self.recompute_kwargs = recompute_kwargs
self.runfunc0 = get_fc_block(0, input_size, is_last=False)
self.runfunc1 = get_fc_block(1, input_size, is_last=False)
self.runfunc2 = get_fc_block(2, input_size, is_last=False)
self.runfunc3 = get_fc_block(3, input_size, is_last=False)
self.runfunc4 = get_fc_block(4, input_size, is_last=True)
def forward(self, inputs):
if 0 in self.recompute_blocks:
inputs = recompute(self.runfunc0, inputs)
else:
inputs = self.runfunc0(inputs)
if 1 in self.recompute_blocks:
inputs = recompute(self.runfunc1, inputs)
else:
inputs = self.runfunc1(inputs)
if 2 in self.recompute_blocks:
inputs = recompute(self.runfunc2, inputs, **self.recompute_kwargs)
else:
inputs = self.runfunc2(inputs)
if 3 in self.recompute_blocks:
inputs = recompute(self.runfunc3, inputs)
else:
inputs = self.runfunc3(inputs)
if 4 in self.recompute_blocks:
inputs = recompute(self.runfunc4, inputs)
else:
inputs = self.runfunc4(inputs)
return inputs
def run_model(recompute_block=[],
recompute_kwargs={},
enable_autocast=False,
pure_fp16=False):
gen = paddle.seed(10)
gen.manual_seed(10)
np.random.seed(10)
random.seed(10)
batch_size, input_size = 1, 10
model = Naive_fc_net(input_size,
recompute_blocks=recompute_block,
recompute_kwargs=recompute_kwargs)
loss_fn = paddle.nn.MSELoss(reduction='mean')
optimizer = paddle.optimizer.SGD(learning_rate=0.01,
parameters=model.parameters())
if enable_autocast:
scaler = paddle.amp.GradScaler()
loss_ = []
param_ = []
grad_ = []
for step in range(10):
x_data = np.random.randn(batch_size, input_size).astype(np.float32)
x = paddle.to_tensor(x_data)
# x.stop_gradient = False
level = 'O2' if pure_fp16 else 'O1'
with paddle.amp.auto_cast(True, level=level):
y_pred = model(x)
loss = y_pred.mean()
if enable_autocast:
scaler.scale(loss).backward()
scaler.minimize(optimizer, loss)
else:
loss_.append(np.asarray(loss).tolist())
loss.backward()
optimizer.step()
param_.append(np.asarray(model.parameters()[9]).tolist())
grad_.append(np.asarray(model.parameters()[3]._grad_ivar()).tolist())
optimizer.clear_grad()
return loss_, param_, grad_
class TestPyLayer(unittest.TestCase):
def test_base_case(self, enable_autocast=False, pure_fp16=False):
def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
self.assertEqual(loss_ref, loss)
self.assertEqual(param_ref, param)
self.assertEqual(grad_ref, grad)
# without recompute
loss_ref, param_ref, grad_ref = run_model(
recompute_block=[],
enable_autocast=enable_autocast,
pure_fp16=pure_fp16)
# recompute second block
loss, param, grad = run_model(recompute_block=[1],
enable_autocast=enable_autocast,
pure_fp16=pure_fp16)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
# recompute fourth block
loss, param, grad = run_model(recompute_block=[3],
enable_autocast=enable_autocast,
pure_fp16=pure_fp16)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
# recompute second to fourth block
loss, param, grad = run_model(recompute_block=[1, 2, 3],
enable_autocast=enable_autocast,
pure_fp16=pure_fp16)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
# recompute second & fourth block
loss, param, grad = run_model(recompute_block=[1, 3],
enable_autocast=enable_autocast,
pure_fp16=pure_fp16)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
def test_fc_net_with_dropout(self):
self.test_base_case()
def test_fc_net_without_restore_rng(self):
loss_ref, param_ref, grad_ref = run_model(
recompute_block=[2],
recompute_kwargs={"preserve_rng_state": False},
enable_autocast=True)
def test_fc_net_with_amp(self):
self.test_base_case(enable_autocast=True)
def test_fc_net_with_fp16(self):
self.test_base_case(enable_autocast=True, pure_fp16=True)
def test_recompute_kwargs(self):
paddle.set_device("gpu")
kwargs = {"is_test": False}
with self.assertRaises(ValueError):
loss_ref, param_ref, grad_ref = run_model(recompute_block=[2],
recompute_kwargs=kwargs)
def test_recompute_cpu_rng(self):
paddle.set_device("cpu")
with self.assertRaises(RuntimeError):
loss_ref, param_ref, grad_ref = run_model(recompute_block=[2])
if __name__ == '__main__':
unittest.main()
......@@ -15,6 +15,9 @@
from __future__ import print_function
import os
os.environ['FLAGS_enable_eager_mode'] = '0'
import unittest
import paddle.fluid as fluid
......@@ -25,15 +28,12 @@ class TestDygraphShardingStage3(TestMultipleGpus):
# check sharding logic as well as the accuracy with single mode
def test_dygraph_sharding_stage3(self):
self.run_mnist_2gpu('dygraph_group_sharded_stage3.py')
self.run_mnist_2gpu('dygraph_sharding_stage3.py', eager_mode=False)
def test_dygraph_sharding_stage3_offload(self):
self.run_mnist_2gpu('dygraph_group_sharded_stage3_offload.py')
self.run_mnist_2gpu('dygraph_sharding_stage3_offload.py',
eager_mode=False)
if __name__ == "__main__":
os.environ["FLAGS_enable_eager_mode"] = "1"
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
os.environ['FLAGS_enable_eager_mode'] = '1'
import os
import unittest
import paddle.fluid as fluid
from test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestDygraphShardingStage3(TestMultipleGpus):
# check sharding logic as well as the accuracy with single mode
def test_dygraph_sharding_stage3(self):
self.run_mnist_2gpu('dygraph_group_sharded_stage3.py')
def test_dygraph_sharding_stage3_offload(self):
self.run_mnist_2gpu('dygraph_group_sharded_stage3_offload.py')
if __name__ == "__main__":
os.environ["FLAGS_enable_eager_mode"] = "1"
unittest.main()
......@@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
os.environ['FLAGS_enable_eager_mode'] = '0'
import unittest
import paddle
import paddle.fluid as fluid
......@@ -19,13 +23,11 @@ import paddle.fluid.core as core
import numpy as np
import six
import cv2
import os
import tempfile
from test_imperative_resnet import ResNet, BottleneckBlock, ConvBNLayer, train_parameters, optimizer_setting
import paddle.nn as nn
from paddle.static import InputSpec
from paddle.autograd import PyLayer
from paddle.fluid.framework import _test_eager_guard
if fluid.core.is_compiled_with_cuda():
fluid.set_flags({"FLAGS_cudnn_deterministic": True})
......@@ -73,8 +75,6 @@ class TestAutoCast(unittest.TestCase):
self.assertTrue(out_fp32.dtype == fluid.core.VarDesc.VarType.FP32)
def test_amp_guard_white_op(self):
with _test_eager_guard():
self.amp_guard_white_op()
self.amp_guard_white_op()
def amp_guard_black_op(self):
......@@ -88,8 +88,6 @@ class TestAutoCast(unittest.TestCase):
self.assertTrue(out_fp32.dtype == fluid.core.VarDesc.VarType.FP32)
def test_amp_guard_black_op(self):
with _test_eager_guard():
self.amp_guard_black_op()
self.amp_guard_black_op()
def custom_op_list(self):
......@@ -123,8 +121,6 @@ class TestAutoCast(unittest.TestCase):
| {"conv2d"})
def test_custom_op_list(self):
with _test_eager_guard():
self.custom_op_list()
self.custom_op_list()
def custom_op_list_exception(self):
......@@ -145,8 +141,6 @@ class TestAutoCast(unittest.TestCase):
self.assertRaises(ValueError, func)
def test_custom_op_list_exception(self):
with _test_eager_guard():
self.custom_op_list_exception()
self.custom_op_list_exception()
def amp_guard_upsupported_fp16_op(self):
......@@ -174,8 +168,6 @@ class TestAutoCast(unittest.TestCase):
out_purefp16_fp32.dtype == fluid.core.VarDesc.VarType.FP32)
def test_amp_guard_upsupported_fp16_op(self):
with _test_eager_guard():
self.amp_guard_upsupported_fp16_op()
self.amp_guard_upsupported_fp16_op()
def mode_exception(self):
......@@ -195,8 +187,6 @@ class TestAutoCast(unittest.TestCase):
self.assertRaises(ValueError, func)
def test_mode_exception(self):
with _test_eager_guard():
self.mode_exception()
self.mode_exception()
......@@ -212,8 +202,6 @@ class TestAmpScaler(unittest.TestCase):
data.numpy() * 1024), True)
def test_scale(self):
with _test_eager_guard():
self.scale()
self.scale()
def minimize(self):
......@@ -265,8 +253,6 @@ class TestAmpScaler(unittest.TestCase):
outs_no_scaler[1][i][0].numpy()), True)
def test_minimize(self):
with _test_eager_guard():
self.minimize()
self.minimize()
def step(self):
......@@ -310,8 +296,6 @@ class TestAmpScaler(unittest.TestCase):
outs_no_scaler[i].numpy()), True)
def test_step(self):
with _test_eager_guard():
self.step()
self.step()
def nan_inf(self):
......@@ -344,8 +328,6 @@ class TestAmpScaler(unittest.TestCase):
np.array_equal(param.numpy(), params_init[param.name]))
def test_nan_inf(self):
with _test_eager_guard():
self.nan_inf()
self.nan_inf()
def step_update_exception(self):
......@@ -396,8 +378,6 @@ class TestAmpScaler(unittest.TestCase):
self.assertRaises(RuntimeError, func3)
def test_step_update_exception(self):
with _test_eager_guard():
self.step_update_exception()
self.step_update_exception()
def test_get_and_set(self):
......@@ -578,8 +558,6 @@ class TestGradScalerStateDict(unittest.TestCase):
self.assertTrue(
np.allclose(out_use_state_dict[0], out_no_state_dict[0]))
with _test_eager_guard():
func_isinstance()
func_isinstance()
......@@ -742,8 +720,6 @@ class TestStateDictHookForAMP(unittest.TestCase):
for key in param_value_ori.keys():
print(np.equal(param_value_ori[key], param_value_now[key]))
with _test_eager_guard():
func_isinstance()
func_isinstance()
......@@ -899,8 +875,6 @@ class TestPureFp16SaveLoad(unittest.TestCase):
self.assertTrue(
np.allclose(out_use_save_load[0], out_no_save_load[0]))
with _test_eager_guard():
func_isinstance()
func_isinstance()
......@@ -1005,8 +979,6 @@ class TestPureFp16InferenceSaveLoad(unittest.TestCase):
def test_inference_save_load(self):
self.inference_save_load()
with _test_eager_guard():
self.inference_save_load()
class TestResnet2(unittest.TestCase):
......@@ -1146,8 +1118,6 @@ class TestResnet2(unittest.TestCase):
self.assertTrue(
np.allclose(out_fp32[0], out_pure_fp16[0], atol=1.e-2))
with _test_eager_guard():
func_isinstance()
func_isinstance()
def test_with_data_loader(self):
......@@ -1166,8 +1136,6 @@ class TestResnet2(unittest.TestCase):
self.assertTrue(
np.allclose(out_fp32[0], out_pure_fp16[0], atol=1.e-2))
with _test_eager_guard():
func_isinstance()
func_isinstance()
def test_param_group(self):
......@@ -1189,8 +1157,6 @@ class TestResnet2(unittest.TestCase):
self.assertTrue(
np.allclose(out_fp32[0], out_pure_fp16[0], atol=1.e-2))
with _test_eager_guard():
func_isinstance()
func_isinstance()
......@@ -1285,8 +1251,6 @@ class TestResnet(unittest.TestCase):
self.assertTrue(
np.allclose(out_fp32[0], out_pure_fp16[0], atol=1.e-1))
with _test_eager_guard():
func_isinstance()
func_isinstance()
......@@ -1308,8 +1272,6 @@ class TestLayerNormFp16(unittest.TestCase):
self.assertTrue(
out.dtype == fluid.core.VarDesc.VarType.FP16)
with _test_eager_guard():
func_isinstance()
func_isinstance()
......@@ -1344,8 +1306,6 @@ class TestBf16(unittest.TestCase):
self.assertTrue(
np.allclose(out_fp32, out_bf16_O2, rtol=1.e-3, atol=1.e-1))
with _test_eager_guard():
func_isinstance()
func_isinstance()
......@@ -1399,8 +1359,6 @@ class TestAmpWithHook(unittest.TestCase):
loss = a.sum()
self.assertRaises(RuntimeError, loss.backward)
with _test_eager_guard():
func_isinstance()
func_isinstance()
def test_hook_change_place(self):
......@@ -1420,8 +1378,6 @@ class TestAmpWithHook(unittest.TestCase):
loss = a.sum()
self.assertRaises(RuntimeError, loss.backward)
with _test_eager_guard():
func_isinstance()
func_isinstance()
......
......@@ -18,7 +18,7 @@ import unittest
import numpy as np
import paddle
from paddle.autograd import PyLayer, EagerPyLayer
from paddle.autograd.py_layer import LegacyPyLayer, EagerPyLayer
from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode
......@@ -32,7 +32,7 @@ class TestPyLayer(unittest.TestCase):
def func_test_simple_pylayer_multiple_output(self):
class tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod
def forward(ctx, x1, x2, func1, func2=paddle.square):
......@@ -70,7 +70,7 @@ class TestPyLayer(unittest.TestCase):
def func_test_simple_pylayer_return_none_with_no_grad(self):
class tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod
def forward(ctx, x1, x2, func1, func2=paddle.square):
......@@ -112,7 +112,7 @@ class TestPyLayer(unittest.TestCase):
def func_test_simple_pylayer_single_output(self):
class tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod
def forward(ctx, x1, func1, func2=paddle.square):
......@@ -146,7 +146,7 @@ class TestPyLayer(unittest.TestCase):
def func_test_pylayer_num_output_match(self):
class tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod
def forward(
......@@ -175,7 +175,7 @@ class TestPyLayer(unittest.TestCase):
def func_test_pylayer_dtype(self):
class tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod
def forward(ctx, x, dtype):
......@@ -206,7 +206,7 @@ class TestPyLayer(unittest.TestCase):
def func_test_pylayer_Exception_forward(self):
class Layer_None1(EagerPyLayer if in_dygraph_mode() else PyLayer):
class Layer_None1(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod
def forward(ctx, *args):
......@@ -220,7 +220,7 @@ class TestPyLayer(unittest.TestCase):
with self.assertRaises(ValueError):
z = Layer_None1.apply(input1)
class Layer_None2(EagerPyLayer if in_dygraph_mode() else PyLayer):
class Layer_None2(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod
def forward(ctx, *args):
......@@ -234,7 +234,7 @@ class TestPyLayer(unittest.TestCase):
# return None
z = Layer_None2.apply(input1)
class Layer_one1(EagerPyLayer if in_dygraph_mode() else PyLayer):
class Layer_one1(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod
def forward(ctx, *args):
......@@ -249,7 +249,7 @@ class TestPyLayer(unittest.TestCase):
with self.assertRaises(ValueError):
z = Layer_one1.apply(input1)
class Layer_one2(EagerPyLayer if in_dygraph_mode() else PyLayer):
class Layer_one2(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod
def forward(ctx, *args):
......@@ -263,7 +263,7 @@ class TestPyLayer(unittest.TestCase):
# return int
z = Layer_one2.apply(input1)
class Layer_no_fw(EagerPyLayer if in_dygraph_mode() else PyLayer):
class Layer_no_fw(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod
def backward(ctx, *args):
......@@ -280,7 +280,7 @@ class TestPyLayer(unittest.TestCase):
def func_test_pylayer_nograd(self):
class tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod
def forward(ctx, x1, func1, func2=paddle.square, xx=None):
......@@ -305,7 +305,8 @@ class TestPyLayer(unittest.TestCase):
def func_test_pylayer_Exception_bk(self):
class Layer_bk_none1(EagerPyLayer if in_dygraph_mode() else PyLayer):
class Layer_bk_none1(
EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod
def forward(ctx, x):
......@@ -322,7 +323,8 @@ class TestPyLayer(unittest.TestCase):
with self.assertRaises(ValueError):
z.sum().backward()
class Layer_bk_none2(EagerPyLayer if in_dygraph_mode() else PyLayer):
class Layer_bk_none2(
EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod
def forward(ctx, x1, x2):
......@@ -339,7 +341,8 @@ class TestPyLayer(unittest.TestCase):
with self.assertRaises(ValueError):
z.mean().backward()
class Layer_bk_one1(EagerPyLayer if in_dygraph_mode() else PyLayer):
class Layer_bk_one1(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer
):
@staticmethod
def forward(ctx, x):
......@@ -356,7 +359,8 @@ class TestPyLayer(unittest.TestCase):
with self.assertRaises(ValueError):
z.mean().backward()
class Layer_bk_one2(EagerPyLayer if in_dygraph_mode() else PyLayer):
class Layer_bk_one2(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer
):
@staticmethod
def forward(ctx, x1, x2):
......@@ -374,7 +378,7 @@ class TestPyLayer(unittest.TestCase):
with self.assertRaises(ValueError):
z.mean().backward()
class Layer_no_bk(EagerPyLayer if in_dygraph_mode() else PyLayer):
class Layer_no_bk(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod
def forward(ctx, x):
......@@ -388,7 +392,8 @@ class TestPyLayer(unittest.TestCase):
z = z[0] + z[1]
z.mean().backward()
class Layer_bk_match(EagerPyLayer if in_dygraph_mode() else PyLayer):
class Layer_bk_match(
EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod
def forward(ctx, x):
......@@ -412,7 +417,8 @@ class TestPyLayer(unittest.TestCase):
def func_test_pylayer_bk_return_none(self):
class Layer_bk_none1(EagerPyLayer if in_dygraph_mode() else PyLayer):
class Layer_bk_none1(
EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod
def forward(ctx, x1, x2):
......@@ -431,7 +437,8 @@ class TestPyLayer(unittest.TestCase):
with self.assertRaises(ValueError):
z.mean().backward()
class Layer_bk_none2(EagerPyLayer if in_dygraph_mode() else PyLayer):
class Layer_bk_none2(
EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod
def forward(ctx, x1, x2):
......@@ -457,7 +464,7 @@ class TestPyLayer(unittest.TestCase):
def func_test_pylayer_inplace(self):
class cus_tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
class cus_tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod
def forward(ctx, x):
......@@ -494,7 +501,8 @@ class TestPyLayer(unittest.TestCase):
def test_pylayer_inplace_backward_error(self):
with _test_eager_guard():
class cus_tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
class cus_tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer
):
@staticmethod
def forward(ctx, x):
......@@ -530,7 +538,8 @@ class TestPyLayer(unittest.TestCase):
def test_pylayer_inplace_backward_success_1(self):
with _test_eager_guard():
class cus_tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
class cus_tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer
):
@staticmethod
def forward(ctx, x):
......@@ -564,7 +573,8 @@ class TestPyLayer(unittest.TestCase):
def test_pylayer_inplace_backward_success_2(self):
with _test_eager_guard():
class cus_tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
class cus_tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer
):
@staticmethod
def forward(ctx, x):
......@@ -597,7 +607,8 @@ class TestPyLayer(unittest.TestCase):
def func_test_pylayer_inplace_and_leaf_exception(self):
class cus_pylayer_op(EagerPyLayer if in_dygraph_mode() else PyLayer):
class cus_pylayer_op(
EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod
def forward(ctx, x):
......@@ -633,7 +644,7 @@ class TestPyLayer(unittest.TestCase):
def func_test_backward_in_backward(self):
class cus_tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
class cus_tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod
def forward(ctx, x):
......@@ -665,7 +676,7 @@ class TestPyLayer(unittest.TestCase):
def func_test_return_to_tensor(self):
class Tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
class Tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
@staticmethod
def forward(ctx, x1):
......@@ -779,7 +790,7 @@ class TestPyLayerReturnType(unittest.TestCase):
def test_forward_args_fake_tensor(self):
class Tanh(PyLayer):
class Tanh(LegacyPyLayer):
@staticmethod
def forward(ctx, x1):
......@@ -797,7 +808,7 @@ class TestPyLayerReturnType(unittest.TestCase):
def test_forward_kwargs_fake_tensor(self):
class Tanh(PyLayer):
class Tanh(LegacyPyLayer):
@staticmethod
def forward(ctx, x1):
......@@ -815,7 +826,7 @@ class TestPyLayerReturnType(unittest.TestCase):
def test_forward_return_fake_tensor(self):
class Tanh(PyLayer):
class Tanh(LegacyPyLayer):
@staticmethod
def forward(ctx, x1):
......@@ -833,7 +844,7 @@ class TestPyLayerReturnType(unittest.TestCase):
def test_forward_return_fake_tensor_tuple(self):
class Tanh(PyLayer):
class Tanh(LegacyPyLayer):
@staticmethod
def forward(ctx, x1):
......@@ -851,7 +862,7 @@ class TestPyLayerReturnType(unittest.TestCase):
def test_backward_return_fake_tensor_tuple(self):
class Tanh(PyLayer):
class Tanh(LegacyPyLayer):
@staticmethod
def forward(ctx, x1, x2):
......@@ -871,7 +882,7 @@ class TestPyLayerReturnType(unittest.TestCase):
def test_backward_return_fake_tensor(self):
class Tanh(PyLayer):
class Tanh(LegacyPyLayer):
@staticmethod
def forward(ctx, x1):
......
......@@ -31,7 +31,7 @@ from paddle.distributed import alltoall, all_gather
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed import fleet
from paddle.autograd import PyLayer, EagerPyLayer
from paddle.autograd import PyLayer
from .gate import NaiveGate, GShardGate, SwitchGate, BaseGate
from .utils import count_by_gate
from paddle.distributed.fleet.meta_parallel.pp_utils.utils import _hp_recompute
......@@ -132,53 +132,6 @@ class MoEScatter(PyLayer):
return grad_in, None, None, None
class EagerMoEScatter(EagerPyLayer):
r"""
Scatter input samples from [batch x sequences] to contiguous alone experts.
If `world_size` is greater than 1, the samples will first be locally
scattered, and then exchanged across workers.
"""
@staticmethod
def forward(ctx,
inp,
pos,
local_expert_count,
global_expert_count,
fwd_batch_size,
world_size,
group=None):
local_input_buf = _local_scatter(inp, pos)
if world_size > 1:
global_input_buf = global_scatter(local_input_buf,
local_expert_count,
global_expert_count,
group=group)
else:
global_input_buf = local_input_buf
ctx.moe_args = inp.shape[0], world_size, group
variables = (pos, local_expert_count, global_expert_count)
ctx.save_for_backward(*variables)
return global_input_buf
@staticmethod
def backward(ctx, grad):
(pos, local_expert_count, global_expert_count) = ctx.saved_tensor()
(inp_batch_size, world_size, group) = ctx.moe_args
if world_size > 1:
local_grad_in = global_gather(grad,
local_expert_count,
global_expert_count,
group=group)
else:
local_grad_in = grad
grad_in = _local_gather(local_grad_in, pos, inp_batch_size)
return grad_in, None, None, None
class MoEGather(PyLayer):
r"""
Gather output samples from contiguous alone experts back to [batch x
......@@ -226,53 +179,6 @@ class MoEGather(PyLayer):
return global_grad_out_buf, None, None, None
class EagerMoEGather(EagerPyLayer):
r"""
Gather output samples from contiguous alone experts back to [batch x
sequences]. Works symmetrically with MoEScatter.
"""
@staticmethod
def forward(ctx,
global_output_buf,
pos,
local_expert_count,
global_expert_count,
local_batch_size,
world_size,
group=None):
if world_size > 1:
local_output_buf = global_gather(global_output_buf,
local_expert_count,
global_expert_count,
group=group)
else:
local_output_buf = global_output_buf
output = _local_gather(local_output_buf,
pos,
local_batch_size,
maybe_overlap=False)
ctx.moe_args = (global_output_buf.shape[0], world_size, group)
variables = (pos, local_expert_count, global_expert_count)
ctx.save_for_backward(*variables)
return output
@staticmethod
def backward(ctx, grad_out):
pos, local_expert_count, global_expert_count = ctx.saved_tensor()
fwd_batch_size, world_size, group = ctx.moe_args
grad_out_buf = _local_scatter(grad_out, pos)
if world_size > 1:
global_grad_out_buf = global_scatter(grad_out_buf,
local_expert_count,
global_expert_count,
group=group)
else:
global_grad_out_buf = grad_out_buf
return global_grad_out_buf, None, None, None
class AllGather(PyLayer):
r"""
A wrapper for the All-Gather function to support auto-differentiation.
......@@ -295,28 +201,6 @@ class AllGather(PyLayer):
ends=[(rank + 1) * dim0])
class EagerAllGather(EagerPyLayer):
r"""
A wrapper for the All-Gather function to support auto-differentiation.
"""
@staticmethod
def forward(ctx, inp, rank, world_size, group):
tensor_list = []
paddle.distributed.all_gather(tensor_list, inp, group=group)
output = paddle.concat(tensor_list, axis=0)
ctx.args = rank, inp.shape[0]
return output
@staticmethod
def backward(ctx, grad_out):
rank, dim0 = ctx.args
return paddle.slice(grad_out,
axes=[0],
starts=[rank * dim0],
ends=[(rank + 1) * dim0])
class Slice(PyLayer):
r"""
A wrapper for the Slice function to support auto-differentiation.
......@@ -341,30 +225,6 @@ class Slice(PyLayer):
return _all_gather(grad_out, group=group)
class EagerSlice(EagerPyLayer):
r"""
A wrapper for the Slice function to support auto-differentiation.
"""
@staticmethod
def forward(ctx, inp, rank, world_size, group):
B = inp.shape[0]
local_batch_size = B // world_size
batch_start = local_batch_size * rank
batch_end = min(batch_start + local_batch_size, B)
inp = paddle.slice(inp,
axes=[0],
starts=[batch_start],
ends=[batch_end])
ctx.args = world_size, group
return inp
@staticmethod
def backward(ctx, grad_out):
world_size, group = ctx.args
return _all_gather(grad_out, group=group)
def prepare_forward(gate, num_expert, world_size, moe_group):
pos, local_expert_count, global_expert_count = count_by_gate(
gate, num_expert, world_size, group=moe_group)
......@@ -517,10 +377,7 @@ class MoELayer(nn.Layer):
mp_rank = self.mp_group.rank
mp_size = self.mp_group.nranks
if mp_size > 1:
if in_dygraph_mode():
inp = EagerSlice.apply(inp, mp_rank, mp_size, self.mp_group)
else:
inp = Slice.apply(inp, mp_rank, mp_size, self.mp_group)
inp = Slice.apply(inp, mp_rank, mp_size, self.mp_group)
value, gate = self.gate(inp)
(
......@@ -541,14 +398,9 @@ class MoELayer(nn.Layer):
temp_pos = pos
assert topk == self.top_k
if in_dygraph_mode():
x = EagerMoEScatter.apply(inp, temp_pos, local_expert_count,
global_expert_count, fwd_batch_size,
self.world_size, self.group)
else:
x = MoEScatter.apply(inp, temp_pos, local_expert_count,
global_expert_count, fwd_batch_size,
self.world_size, self.group)
x = MoEScatter.apply(inp, temp_pos, local_expert_count,
global_expert_count, fwd_batch_size,
self.world_size, self.group)
d_model = self.d_model
......@@ -577,23 +429,15 @@ class MoELayer(nn.Layer):
if len(gate.shape) == 2:
out_batch_size *= gate.shape[1]
if in_dygraph_mode():
x = EagerMoEGather.apply(x, pos, local_expert_count,
global_expert_count, out_batch_size,
self.world_size, self.group)
else:
x = MoEGather.apply(x, pos, local_expert_count, global_expert_count,
out_batch_size, self.world_size, self.group)
x = MoEGather.apply(x, pos, local_expert_count, global_expert_count,
out_batch_size, self.world_size, self.group)
x = x.reshape([-1, self.top_k, d_model])
value = value.reshape([x.shape[0], 1, self.top_k])
x = paddle.bmm(value, x).reshape([-1, d_model])
if mp_size > 1:
if in_dygraph_mode():
x = EagerAllGather.apply(x, mp_rank, mp_size, self.mp_group)
else:
x = AllGather.apply(x, mp_rank, mp_size, self.mp_group)
x = AllGather.apply(x, mp_rank, mp_size, self.mp_group)
x = paddle.reshape_(x, origin_shape)
......