From 98c17a68f271a074fe934f1b27a850f80a9b98d2 Mon Sep 17 00:00:00 2001
From: QingshuChen
Date: Fri, 23 Dec 2022 22:09:09 +0800
Subject: [PATCH] support recompute for kunlun (#49069)

---
 .../bind_threaded_ssa_graph_executor.cc       |  10 +-
 paddle/fluid/framework/generator.cc           |  32 +++++
 paddle/fluid/framework/generator.h            |   2 +
 paddle/fluid/platform/device_context.cc       |   7 ++
 paddle/fluid/pybind/generator_py.cc           |   1 +
 paddle/phi/kernels/xpu/dropout_kernel.cc      |   3 +
 python/paddle/__init__.py                     |   4 +
 .../distributed/fleet/recompute/recompute.py  |  37 ++----
 .../fleet/test_dygraph_recompute_for_eager.py |   9 --
 .../unittests/xpu/test_recompute_op_xpu.py    | 116 ++++++++++++++++++
 python/paddle/framework/random.py             |  94 ++++++++++++++
 11 files changed, 277 insertions(+), 38 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_recompute_op_xpu.py

diff --git a/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc
index aa31a556c92..577d458ba7d 100644
--- a/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc
@@ -254,9 +254,13 @@ void BindThreadedSSAGraphExecutor::RunMultiDeviceOpAsync(
     auto dev_ctxes = op->DeviceContext();
     auto &inputs = op->Inputs();
     for (auto &input : inputs) {
-      auto dev_ctxes = input->GeneratedOp()->DeviceContext();
-      for (auto &item : dev_ctxes) {
-        ((platform::XPUDeviceContext *)(item.second))->Wait();
+      if (input && input->GeneratedOp() != nullptr) {
+        auto dev_ctxes = input->GeneratedOp()->DeviceContext();
+        for (auto &item : dev_ctxes) {
+          ((platform::XPUDeviceContext *)(item.second))->Wait();
+        }
+      } else {
+        VLOG(3) << "No generated op:" << op->Name();
       }
     }
     op->Run(strategy_.use_device_);
diff --git a/paddle/fluid/framework/generator.cc b/paddle/fluid/framework/generator.cc
index cc2d76a57a9..ce516ca4f55 100644
--- a/paddle/fluid/framework/generator.cc
+++ b/paddle/fluid/framework/generator.cc
@@ -20,11 +20,43 @@ limitations under the License.
 */
 #include <glog/logging.h>
 
 #include "paddle/fluid/platform/device/gpu/gpu_info.h"
+#include "paddle/fluid/platform/device/xpu/xpu_info.h"
 #include "paddle/fluid/platform/enforce.h"
 
 namespace paddle {
 namespace framework {
 
+const std::shared_ptr<Generator>& DefaultXPUGenerator(int64_t device_id) {
+#if defined(PADDLE_WITH_XPU)
+
+  static int64_t num_xpu_devices = -1;
+  static std::once_flag num_devices_init_flag;
+  static std::deque<std::once_flag> xpu_device_flags;
+  static std::vector<std::shared_ptr<Generator>> default_xpu_generators;
+
+  std::call_once(num_devices_init_flag, []() {
+    num_xpu_devices = paddle::platform::GetXPUDeviceCount();
+    xpu_device_flags.resize(num_xpu_devices);
+    default_xpu_generators.resize(num_xpu_devices);
+  });
+  if (device_id < 0) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "xpu device id should be greater than or equal to 0"));
+  }
+
+  std::call_once(xpu_device_flags[device_id], [device_id]() {
+    default_xpu_generators[device_id] =
+        std::make_shared<Generator>(GetRandomSeed(), device_id);
+    VLOG(4) << "initial seed: "
+            << default_xpu_generators[device_id]->GetCurrentSeed();
+  });
+  return default_xpu_generators[device_id];
+#else
+  PADDLE_THROW(platform::errors::PermissionDenied(
+      "getDefaultXPUGenerator is only supported in XPU place"));
+#endif
+}
+
 const std::shared_ptr<Generator>& DefaultCUDAGenerator(int64_t device_id) {
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
diff --git a/paddle/fluid/framework/generator.h b/paddle/fluid/framework/generator.h
index f62e8f74d26..54096bf5c44 100644
--- a/paddle/fluid/framework/generator.h
+++ b/paddle/fluid/framework/generator.h
@@ -107,6 +107,8 @@ const std::shared_ptr<Generator>& DefaultCPUGenerator();
 
 const std::shared_ptr<Generator>& DefaultCUDAGenerator(int64_t device_id = -1);
 
+const std::shared_ptr<Generator>& DefaultXPUGenerator(int64_t device_id = -1);
+
 std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t);
 
 const std::shared_ptr<Generator>& SetRandomSeedGenerator(
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index 539bbfb87d0..4ec96f606fa 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -170,6 +170,13 @@ std::unique_ptr<DeviceContext> CreateDeviceContext(
     cuda_ctx->PartialInitWithAllocator();
     dev_ctx->SetGenerator(
         framework::DefaultCUDAGenerator(p.GetDeviceId()).get());
+#endif
+  } else if (is_xpu_place(p)) {
+#if defined(PADDLE_WITH_XPU)
+    dev_ctx->SetAllocator(
+        memory::allocation::AllocatorFacade::Instance().GetAllocator(p).get());
+    dev_ctx->SetGenerator(
+        framework::DefaultXPUGenerator(p.GetDeviceId()).get());
 #endif
   } else {
     dev_ctx->SetAllocator(
diff --git a/paddle/fluid/pybind/generator_py.cc b/paddle/fluid/pybind/generator_py.cc
index 00ee7b23368..b144c79aabb 100644
--- a/paddle/fluid/pybind/generator_py.cc
+++ b/paddle/fluid/pybind/generator_py.cc
@@ -90,6 +90,7 @@ void BindGenerator(py::module* m_ptr) {
       .def("random", &framework::Generator::Random64);
   m.def("default_cpu_generator", &framework::DefaultCPUGenerator);
   m.def("default_cuda_generator", &framework::DefaultCUDAGenerator);
+  m.def("default_xpu_generator", &framework::DefaultXPUGenerator);
   m.def("set_random_seed_generator", &framework::SetRandomSeedGenerator);
   m.def("get_random_seed_generator", &framework::GetRandomSeedGenerator);
 }
diff --git a/paddle/phi/kernels/xpu/dropout_kernel.cc b/paddle/phi/kernels/xpu/dropout_kernel.cc
index c9645f06a13..f710f313414 100644
--- a/paddle/phi/kernels/xpu/dropout_kernel.cc
+++ b/paddle/phi/kernels/xpu/dropout_kernel.cc
@@ -58,6 +58,9 @@ void DropoutRawKernel(const Context& dev_ctx,
     } else {
       seed_data = fix_seed ? seed : 0;
     }
+    if (seed_data == 0) {
+      seed_data = dev_ctx.GetGenerator()->Random64();
+    }
 
     auto* mask_data = dev_ctx.template Alloc<T>(mask);
     // Special case when dropout_prob is 1.0
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index d9a51f7016d..820a02c9343 100755
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -327,6 +327,8 @@ from .tensor.einsum import einsum  # noqa: F401
 from .framework.random import seed  # noqa: F401
 from .framework.random import get_cuda_rng_state  # noqa: F401
 from .framework.random import set_cuda_rng_state  # noqa: F401
+from .framework.random import get_rng_state  # noqa: F401
+from .framework.random import set_rng_state  # noqa: F401
 from .framework import ParamAttr  # noqa: F401
 from .framework import CPUPlace  # noqa: F401
 from .framework import IPUPlace  # noqa: F401
@@ -424,6 +426,7 @@ __all__ = [  # noqa
     'save',
     'multinomial',
     'get_cuda_rng_state',
+    'get_rng_state',
     'rank',
     'empty_like',
     'eye',
@@ -606,6 +609,7 @@ __all__ = [  # noqa
     'unique',
     'unique_consecutive',
     'set_cuda_rng_state',
+    'set_rng_state',
     'set_printoptions',
     'std',
     'flatten',
diff --git a/python/paddle/distributed/fleet/recompute/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py
index ae4bbae6a69..f20bbdb3834 100755
--- a/python/paddle/distributed/fleet/recompute/recompute.py
+++ b/python/paddle/distributed/fleet/recompute/recompute.py
@@ -56,16 +56,15 @@ def check_recompute_necessary(inputs):
 
 @contextlib.contextmanager
 def swith_rng_state_tracker(rng_state, tracker):
-    orig_cuda_rng_state = paddle.get_cuda_rng_state()
-    orig_cuda_rng_tracker = get_rng_state_tracker().get_states_tracker()
-
-    paddle.set_cuda_rng_state(rng_state)
+    orig_rng_state = paddle.get_rng_state()
+    orig_rng_tracker = get_rng_state_tracker().get_states_tracker()
+    paddle.set_rng_state(rng_state)
     get_rng_state_tracker().set_states_tracker(tracker)
     try:
         yield
     finally:
-        paddle.set_cuda_rng_state(orig_cuda_rng_state)
-        get_rng_state_tracker().set_states_tracker(orig_cuda_rng_tracker)
+        paddle.set_rng_state(orig_rng_state)
+        get_rng_state_tracker().set_states_tracker(orig_rng_tracker)
 
 
 class LegacyRecomputeFunction(LegacyPyLayer):
@@ -95,15 +94,8 @@ class LegacyRecomputeFunction(LegacyPyLayer):
         # NOTE recompute with restore RNG only support one senario where one process for one cuda gpu.
         # one process with multiple gpu and mix-gpu-cpu senarios are not support
         if ctx.preserve_rng_state:
-            cur_device = paddle.get_device()
-            if 'gpu:' not in cur_device:
-                raise RuntimeError(
-                    "Recompute with RNG perserve is not support current device: {}.".format(
-                        cur_device
-                    )
-                )
-            ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()
-            ctx.fwd_cuda_rng_state_tracker = (
+            ctx.fw_rng_state = paddle.get_rng_state()
+            ctx.fwd_rng_state_tracker = (
                 get_rng_state_tracker().get_states_tracker()
             )
 
@@ -156,7 +148,7 @@ class LegacyRecomputeFunction(LegacyPyLayer):
             # need restore auto_cast state as well as w/b list
             if ctx.preserve_rng_state:
                 with swith_rng_state_tracker(
-                    ctx.fw_cuda_rng_state, ctx.fwd_cuda_rng_state_tracker
+                    ctx.fw_rng_state, ctx.fwd_rng_state_tracker
                 ):
                     with paddle.amp.auto_cast(
                         enable=ctx.is_fw_autocast,
@@ -244,15 +236,8 @@ class RecomputeFunction(PyLayer):
         # NOTE recompute with restore RNG only support one senario where one process for one cuda gpu.
         # one process with multiple gpu and mix-gpu-cpu senarios are not support
         if ctx.preserve_rng_state:
-            cur_device = paddle.get_device()
-            if 'gpu:' not in cur_device:
-                raise RuntimeError(
-                    "Recompute with RNG perserve is not support current device: {}.".format(
-                        cur_device
-                    )
-                )
-            ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()
-            ctx.fwd_cuda_rng_state_tracker = (
+            ctx.fw_rng_state = paddle.get_rng_state()
+            ctx.fwd_rng_state_tracker = (
                 get_rng_state_tracker().get_states_tracker()
             )
 
@@ -305,7 +290,7 @@
             # need restore auto_cast state as well as w/b list
             if ctx.preserve_rng_state:
                 with swith_rng_state_tracker(
-                    ctx.fw_cuda_rng_state, ctx.fwd_cuda_rng_state_tracker
+                    ctx.fw_rng_state, ctx.fwd_rng_state_tracker
                 ):
                     with paddle.amp.auto_cast(
                         enable=ctx.is_fw_autocast,
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py
index 5e982587c25..f496b4e4f09 100755
--- a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py
@@ -312,15 +312,6 @@ class TestRecompute(unittest.TestCase):
                 recompute_block=[2], recompute_kwargs=kwargs
             )
 
-    def test_recompute_cpu_rng(self):
-        paddle.set_device("cpu")
-        for flag in [True, False]:
-            with self.assertRaises(RuntimeError):
-                loss_ref, param_ref, grad_ref = run_model(
-                    recompute_block=[2],
-                    recompute_kwargs={"use_reentrant": flag},
-                )
-
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_recompute_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_recompute_op_xpu.py
new file mode 100644
index 00000000000..3ec67eb76e1
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_recompute_op_xpu.py
@@ -0,0 +1,116 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+
+import numpy as np
+
+import paddle
+from paddle.distributed.fleet.utils import recompute
+
+
+def get_fc_block(block_idx, input_size, is_last=False):
+    block_name = "block_" + str(block_idx)
+    block = paddle.nn.Sequential(
+        (
+            block_name + "_fc_0",
+            paddle.nn.Linear(input_size, input_size, bias_attr=False),
+        ),
+        (block_name + "_dropout", paddle.nn.Dropout(p=0.5)),
+        (block_name + "_relu_1", paddle.nn.ReLU()),
+        (
+            block_name + "_fc_1",
+            paddle.nn.Linear(input_size, input_size, bias_attr=False),
+        ),
+        (block_name + "_relu_2", paddle.nn.ReLU()),
+    )
+    if is_last:
+        block.add_sublayer(
+            block_name + "_fc_2",
+            paddle.nn.Linear(input_size, 1, bias_attr=False),
+        )
+    else:
+        block.add_sublayer(
+            block_name + "_fc_2",
+            paddle.nn.Linear(input_size, input_size, bias_attr=False),
+        )
+    return block
+
+
+class Naive_fc_net(paddle.nn.Layer):
+    def __init__(
+        self, input_size=10, recompute_blocks=[1, 3], recompute_kwargs={}
+    ):
+        super(Naive_fc_net, self).__init__()
+        self.recompute_blocks = recompute_blocks
+        self.recompute_kwargs = recompute_kwargs
+        self.runfunc0 = get_fc_block(0, input_size, is_last=False)
+        self.runfunc1 = get_fc_block(1, input_size, is_last=False)
+        self.runfunc2 = get_fc_block(2, input_size, is_last=False)
+        self.runfunc3 = get_fc_block(3, input_size, is_last=False)
+        self.runfunc4 = get_fc_block(4, input_size, is_last=True)
+        self.total_func = [
+            self.runfunc0,
+            self.runfunc1,
+            self.runfunc2,
+            self.runfunc3,
+            self.runfunc4,
+        ]
+
+    def forward(self, inputs):
+        nums = len(self.total_func)
+        for i in range(nums):
+            if i in self.recompute_blocks:
+                inputs = recompute(
+                    self.total_func[i], inputs, **{"preserve_rng_state": True}
+                )
+            else:
+                inputs = self.total_func[i](inputs)
+        return inputs
+
+
+def run_model(xpu_state, recompute_block=[], recompute_kwargs={}):
+    gen = paddle.seed(10)
+    random.seed(10)
+    batch_size, input_size = 1, 10
+    model = Naive_fc_net(
+        input_size,
+        recompute_blocks=recompute_block,
+        recompute_kwargs=recompute_kwargs,
+    )
+    optimizer = paddle.optimizer.SGD(
+        learning_rate=0.01, parameters=model.parameters()
+    )
+    loss_ = []
+    param_ = []
+    grad_ = []
+    for _ in range(5):
+        x = paddle.rand(shape=[batch_size, input_size], dtype="float32")
+        y_pred = model(x)
+        loss = y_pred.mean()
+        loss_.append(loss.item())
+        loss.backward()
+        optimizer.step()
+        param_.append(model.parameters()[9])
+        grad_.append(model.parameters()[3]._grad_ivar())
+        optimizer.clear_grad()
+    return loss_, param_, grad_
+
+
+xpu_state = paddle.get_rng_state()
+# without recompute
+loss_ref, param_ref, grad_ref = run_model(xpu_state, recompute_block=[])
+loss, param, grad = run_model(xpu_state, recompute_block=[1, 3])
+# The result of the recompute_loss should be the same as the normal_loss.
+np.testing.assert_allclose(loss_ref, loss, rtol=1e-05, atol=1e-05)
diff --git a/python/paddle/framework/random.py b/python/paddle/framework/random.py
index 6d7d704808d..33e7535934a 100644
--- a/python/paddle/framework/random.py
+++ b/python/paddle/framework/random.py
@@ -45,10 +45,51 @@ def seed(seed):
     if core.is_compiled_with_cuda():
         for i in range(core.get_cuda_device_count()):
             core.default_cuda_generator(i).manual_seed(seed)
+    elif core.is_compiled_with_xpu():
+        for i in range(core.get_xpu_device_count()):
+            core.default_xpu_generator(i).manual_seed(seed)
 
     return core.default_cpu_generator().manual_seed(seed)
 
 
+def get_rng_state(device=None):
+    """
+    Get all random states of random generators of specified device.
+    Args:
+        device(str): This parameter determines the specific running device.
+            It can be ``cpu``, ``gpu``, ``xpu``. Default is None.
+            If None, return the generator states of the current device (specified by ``set_device``).
+    Returns:
+        GeneratorState: object.
+    Examples:
+        .. code-block:: python
+            import paddle
+            sts = paddle.get_rng_state()
+    """
+    state_list = []
+    if device is None:
+        place = fluid.framework._current_expected_place()
+    else:
+        place = device._convert_to_place(device)
+
+    if isinstance(place, core.CPUPlace):
+        state_list.append(core.default_cpu_generator().get_state())
+    elif isinstance(place, core.CUDAPlace):
+        for i in range(core.get_cuda_device_count()):
+            state_list.append(core.default_cuda_generator(i).get_state())
+    elif isinstance(place, core.XPUPlace):
+        for i in range(core.get_xpu_device_count()):
+            state_list.append(core.default_xpu_generator(i).get_state())
+    else:
+        raise ValueError(
+            "get_rng_state is not implemented for current device: {}".format(
+                place
+            )
+        )
+
+    return state_list
+
+
 def get_cuda_rng_state():
     """
 
@@ -75,6 +116,59 @@
     return state_list
 
 
+def set_rng_state(state_list, device=None):
+    """
+
+    Sets generator states for all device generators.
+
+    Args:
+        state_list(list|tuple): The device states to set back to device generators. state_list is obtained from get_rng_state().
+        device(str): This parameter determines the specific running device.
+            It can be ``cpu``, ``gpu``, ``xpu``. Default is None.
+            If None, set the generator states of the current device (specified by ``set_device``).
+
+    Returns:
+        None.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            sts = paddle.get_rng_state()
+            paddle.set_rng_state(sts)
+
+    """
+    if device is None:
+        place = fluid.framework._current_expected_place()
+    else:
+        place = device._convert_to_place(device)
+
+    if isinstance(place, core.CUDAPlace):
+        if not len(state_list) == core.get_cuda_device_count():
+            raise ValueError(
+                "Length of gpu state list should be equal to the gpu device count"
+            )
+        for i in range(core.get_cuda_device_count()):
+            core.default_cuda_generator(i).set_state(state_list[i])
+    elif isinstance(place, core.XPUPlace):
+        if not len(state_list) == core.get_xpu_device_count():
+            raise ValueError(
+                "Length of xpu state list should be equal to the xpu device count"
+            )
+        for i in range(core.get_xpu_device_count()):
+            core.default_xpu_generator(i).set_state(state_list[i])
+    elif isinstance(place, core.CPUPlace):
+        if not len(state_list) == 1:
+            raise ValueError("Length of cpu state list should be equal to 1")
+        core.default_cpu_generator().set_state(state_list[0])
+    else:
+        raise ValueError(
+            "set_rng_state is not implemented for current device: {}".format(
+                place
+            )
+        )
+
+
 def set_cuda_rng_state(state_list):
     """
 
-- 
GitLab
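
A minimal usage sketch of the paddle.get_rng_state / paddle.set_rng_state pair introduced above (not part of the patch itself; it assumes a build where the current device's generator is registered, i.e. CPU, CUDA, or, with this change, XPU). It is the same save-and-replay pattern that swith_rng_state_tracker now relies on so that recomputed dropout matches the original forward pass:

    import paddle

    # get_rng_state()/set_rng_state() act on the generators of the current
    # device, so the same code runs on cpu, gpu, and xpu builds.
    paddle.seed(2023)
    saved = paddle.get_rng_state()

    a = paddle.rand([2, 3])      # advances the device generator
    paddle.set_rng_state(saved)  # roll the generator state back
    b = paddle.rand([2, 3])      # replays the same random values

    assert bool((a == b).all())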