From 3206fa807f412e52fb7691cea3e5f3798bc1a4cc Mon Sep 17 00:00:00 2001
From: ronnywang
Date: Wed, 19 Apr 2023 19:23:48 +0800
Subject: [PATCH] [CustomDevice] add recompute support (#53044)

* [CustomDevice] add recompute support

* update
---
 paddle/fluid/platform/device_context.cc        |  3 ++
 paddle/fluid/pybind/generator_py.cc            |  1 +
 paddle/phi/core/generator.cc                   | 11 +++++
 paddle/phi/core/generator.h                    |  5 +++
 .../hybrid_parallel_optimizer.py               |  7 ++-
 .../distributed/fleet/recompute/recompute.py   | 10 ++++-
 .../fleet/utils/hybrid_parallel_util.py        | 15 ++++---
 python/paddle/framework/random.py              | 44 ++++++++++++++++++-
 8 files changed, 86 insertions(+), 10 deletions(-)

diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index 8e15fb99490..ee12b42c805 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -109,6 +109,9 @@ inline std::unique_ptr<DeviceContext> CreateDeviceContext(
     dev_ctx->SetAllocator(instance.GetAllocator(p).get());
     dev_ctx->SetGenerator(phi::DefaultXPUGenerator(p.GetDeviceId()).get());
 #endif
+  } else if (p.GetType() == phi::AllocationType::CUSTOM) {
+    dev_ctx->SetAllocator(instance.GetAllocator(p).get());
+    dev_ctx->SetGenerator(phi::DefaultCustomDeviceGenerator(p).get());
   } else {
     dev_ctx->SetAllocator(instance.GetAllocator(p).get());
     dev_ctx->SetGenerator(phi::DefaultCPUGenerator().get());
diff --git a/paddle/fluid/pybind/generator_py.cc b/paddle/fluid/pybind/generator_py.cc
index 32765145731..99621b1463e 100644
--- a/paddle/fluid/pybind/generator_py.cc
+++ b/paddle/fluid/pybind/generator_py.cc
@@ -88,6 +88,7 @@ void BindGenerator(py::module* m_ptr) {
   m.def("default_cpu_generator", &phi::DefaultCPUGenerator);
   m.def("default_cuda_generator", &phi::DefaultCUDAGenerator);
   m.def("default_xpu_generator", &phi::DefaultXPUGenerator);
+  m.def("default_custom_device_generator", &phi::DefaultCustomDeviceGenerator);
   m.def("set_random_seed_generator", &phi::SetRandomSeedGenerator);
   m.def("get_random_seed_generator", &phi::GetRandomSeedGenerator);
 }
diff --git a/paddle/phi/core/generator.cc b/paddle/phi/core/generator.cc
index 51674af6699..64bff05654d 100644
--- a/paddle/phi/core/generator.cc
+++ b/paddle/phi/core/generator.cc
@@ -99,6 +99,17 @@ const std::shared_ptr<Generator>& DefaultCPUGenerator() {
   return default_cpu_generator;
 }
 
+const std::shared_ptr<Generator>& DefaultCustomDeviceGenerator(
+    const phi::CustomPlace& place) {
+  static std::
+      unordered_map<phi::Place, std::shared_ptr<Generator>, phi::Place::Hash>
+          generators;
+  if (generators.find(place) == generators.end()) {
+    generators.insert({place, std::make_shared<Generator>(GetRandomSeed())});
+  }
+  return generators[place];
+}
+
 using RNGMap = std::unordered_map<std::string, std::shared_ptr<Generator>>;
 
 static RNGMap& GetRandomSeedGeneratorMap() {
diff --git a/paddle/phi/core/generator.h b/paddle/phi/core/generator.h
index 5473be29780..a1d985dc772 100644
--- a/paddle/phi/core/generator.h
+++ b/paddle/phi/core/generator.h
@@ -25,6 +25,8 @@ limitations under the License.
 */
 #include 
 #include 
+#include "paddle/phi/common/place.h"
+
 namespace phi {
 
 class Generator {
@@ -80,6 +82,9 @@ const std::shared_ptr<Generator>& DefaultCUDAGenerator(int64_t device_id = -1);
 
 const std::shared_ptr<Generator>& DefaultXPUGenerator(int64_t device_id = -1);
 
+const std::shared_ptr<Generator>& DefaultCustomDeviceGenerator(
+    const phi::CustomPlace& place);
+
 std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t);
 
 const std::shared_ptr<Generator>& SetRandomSeedGenerator(
diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
index 405ef5492af..ab1b270e2fd 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
@@ -205,7 +205,12 @@ class HybridParallelClipGrad:
         clip_var_fp16 = paddle.cast(clip_var, paddle.float16)
 
         # bf16 is not supported on XPU now
-        if not paddle.is_compiled_with_xpu():
+        if not (
+            paddle.is_compiled_with_xpu()
+            or isinstance(
+                paddle.framework._current_expected_place(), paddle.CustomPlace
+            )
+        ):
             clip_var_bf16 = paddle.cast(clip_var, paddle.bfloat16)
         for p, g in params_grads:
             if g is None:
diff --git a/python/paddle/distributed/fleet/recompute/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py
index b0b9885c33e..8bbbe8e4e7e 100755
--- a/python/paddle/distributed/fleet/recompute/recompute.py
+++ b/python/paddle/distributed/fleet/recompute/recompute.py
@@ -222,13 +222,19 @@ def _recompute_without_reentrant(
 
     if preserve_rng_state:
         cur_device = paddle.get_device()
-        if 'gpu:' not in cur_device:
+        if 'gpu:' in cur_device:
+            fw_cuda_rng_state = paddle.get_cuda_rng_state()
+        elif (
+            cur_device.split(':')[0]
+            in paddle.device.get_all_custom_device_type()
+        ):
+            fw_cuda_rng_state = paddle.get_rng_state(cur_device)
+        else:
             raise RuntimeError(
                 "Recompute with RNG perserve is not support current device: {}.".format(
                     cur_device
                 )
             )
-        fw_cuda_rng_state = paddle.get_cuda_rng_state()
         fwd_cuda_rng_state_tracker = (
             get_rng_state_tracker().get_states_tracker()
         )
diff --git a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
index fc7b463bd81..04093ebcb35 100644
--- a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
+++ b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -154,14 +154,19 @@ def _broadcast_object_list_help(object_list, hcg):
 def broadcast_input_data(hcg, *inputs, **kwargs):
     cur_device = paddle.get_device()
     dev = cur_device.split(":")[0]
-    assert dev in [
-        "xpu",
-        "gpu",
-        "npu",
-    ], f"Only support xpu, gpu and npu now, but this is {dev}"
+    assert (
+        dev
+        in [
+            "xpu",
+            "gpu",
+        ]
+        or dev in paddle.device.get_all_custom_device_type()
+    ), f"Only support xpu, gpu and custom_device now, but this is {dev}"
     dev_idx = int(cur_device.split(':')[1])
     if dev == "gpu":
         place = paddle.CUDAPlace(dev_idx)
+    elif dev in paddle.device.get_all_custom_device_type():
+        place = paddle.CustomPlace(dev, dev_idx)
     else:
         place = eval(f"paddle.{dev.upper()}Place")(dev_idx)
 
diff --git a/python/paddle/framework/random.py b/python/paddle/framework/random.py
index b6f4dd6c817..fff7f5eecd9 100644
--- a/python/paddle/framework/random.py
+++ b/python/paddle/framework/random.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 # TODO: define random api
+import paddle
 from paddle import fluid
 from paddle.fluid import core
 
@@ -48,7 +49,18 @@ def seed(seed):
     elif core.is_compiled_with_xpu():
         for i in range(core.get_xpu_device_count()):
             core.default_xpu_generator(i).manual_seed(seed)
-
+    place = fluid.framework._current_expected_place()
+    if isinstance(place, core.CustomPlace):
+        dev_cnt = sum(
+            [
+                place.get_device_type() == s.split(':')[0]
+                for s in core.get_available_custom_device()
+            ]
+        )
+        for i in range(dev_cnt):
+            core.default_custom_device_generator(
+                core.CustomPlace(place.get_device_type(), i)
+            ).manual_seed(seed)
     return core.default_cpu_generator().manual_seed(seed)
 
 
@@ -70,7 +82,7 @@ def get_rng_state(device=None):
     if device is None:
         place = fluid.framework._current_expected_place()
     else:
-        place = device._convert_to_place(device)
+        place = paddle.device._convert_to_place(device)
 
     if isinstance(place, core.CPUPlace):
         state_list.append(core.default_cpu_generator().get_state())
@@ -80,6 +92,19 @@
     elif isinstance(place, core.XPUPlace):
         for i in range(core.get_xpu_device_count()):
             state_list.append(core.default_xpu_generator(i).get_state())
+    elif isinstance(place, core.CustomPlace):
+        dev_cnt = sum(
+            [
+                place.get_device_type() == s.split(':')[0]
+                for s in core.get_available_custom_device()
+            ]
+        )
+        for i in range(dev_cnt):
+            state_list.append(
+                core.default_custom_device_generator(
+                    core.CustomPlace(place.get_device_type(), i)
+                ).get_state()
+            )
     else:
         raise ValueError(
             "get_rng_state is not implemented for current device: {}".format(
@@ -157,6 +182,21 @@ def set_rng_state(state_list, device=None):
             )
         for i in range(core.get_xpu_device_count()):
             core.default_xpu_generator(i).set_state(state_list[i])
+    elif isinstance(place, core.CustomPlace):
+        dev_cnt = sum(
+            [
+                place.get_device_type() == s.split(':')[0]
+                for s in core.get_available_custom_device()
+            ]
+        )
+        if not len(state_list) == dev_cnt:
+            raise ValueError(
+                f"Length of custom device state list shoule be equal to the {place.get_dtype_type()} device count"
+            )
+        for i in range(dev_cnt):
+            core.default_custom_device_generator(
+                core.CustomPlace(place.get_device_type(), i)
+            ).set_state(state_list[i])
     elif isinstance(place, core.CPUPlace):
         if not len(state_list) == 1:
             raise ValueError("Length of cpu state list shoule be equal to 1")
-- 
GitLab
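
The patch above only touches framework internals, so nothing in it shows the user-facing flow it enables. The following is a minimal sketch of that flow and is not part of the commit: it assumes a PaddlePaddle build with CustomDevice support, an installed device plugin (whatever name paddle.device.get_all_custom_device_type() reports), and a Paddle version whose recompute API accepts the use_reentrant keyword; device name, shapes, and layer sizes are arbitrary.

# Illustrative sketch only; requires an installed CustomDevice plugin.
import paddle
from paddle.distributed.fleet.utils import recompute

dev_types = paddle.device.get_all_custom_device_type()
assert dev_types, "no CustomDevice plugin is registered"
paddle.set_device(f"{dev_types[0]}:0")

# seed() now also seeds the per-CustomPlace generator added in generator.cc,
# and get_rng_state() reaches it through the new pybind binding.
paddle.seed(2023)
custom_state = paddle.get_rng_state()

# With this patch, recompute with preserve_rng_state (the default) no longer
# raises on a custom device: _recompute_without_reentrant saves the forward
# RNG state via paddle.get_rng_state(cur_device) instead of
# paddle.get_cuda_rng_state().
block = paddle.nn.Sequential(
    paddle.nn.Linear(16, 16),
    paddle.nn.Dropout(0.5),  # RNG-dependent op, so the saved state matters
    paddle.nn.Linear(16, 16),
)
x = paddle.randn([4, 16])
x.stop_gradient = False
y = recompute(block, x, use_reentrant=False)  # patched non-reentrant path
y.sum().backward()

The hybrid-parallel changes (the CustomPlace checks in HybridParallelClipGrad and broadcast_input_data) are exercised the same way once fleet is initialized on custom devices; they are not shown in the sketch.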