未验证 提交 3206fa80 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] add recompute support (#53044)

* [CustomDevice] add recompute support

* update
上级 7e19d16f
...@@ -109,6 +109,9 @@ inline std::unique_ptr<DeviceContext> CreateDeviceContext( ...@@ -109,6 +109,9 @@ inline std::unique_ptr<DeviceContext> CreateDeviceContext(
dev_ctx->SetAllocator(instance.GetAllocator(p).get()); dev_ctx->SetAllocator(instance.GetAllocator(p).get());
dev_ctx->SetGenerator(phi::DefaultXPUGenerator(p.GetDeviceId()).get()); dev_ctx->SetGenerator(phi::DefaultXPUGenerator(p.GetDeviceId()).get());
#endif #endif
} else if (p.GetType() == phi::AllocationType::CUSTOM) {
dev_ctx->SetAllocator(instance.GetAllocator(p).get());
dev_ctx->SetGenerator(phi::DefaultCustomDeviceGenerator(p).get());
} else { } else {
dev_ctx->SetAllocator(instance.GetAllocator(p).get()); dev_ctx->SetAllocator(instance.GetAllocator(p).get());
dev_ctx->SetGenerator(phi::DefaultCPUGenerator().get()); dev_ctx->SetGenerator(phi::DefaultCPUGenerator().get());
......
...@@ -88,6 +88,7 @@ void BindGenerator(py::module* m_ptr) { ...@@ -88,6 +88,7 @@ void BindGenerator(py::module* m_ptr) {
m.def("default_cpu_generator", &phi::DefaultCPUGenerator); m.def("default_cpu_generator", &phi::DefaultCPUGenerator);
m.def("default_cuda_generator", &phi::DefaultCUDAGenerator); m.def("default_cuda_generator", &phi::DefaultCUDAGenerator);
m.def("default_xpu_generator", &phi::DefaultXPUGenerator); m.def("default_xpu_generator", &phi::DefaultXPUGenerator);
m.def("default_custom_device_generator", &phi::DefaultCustomDeviceGenerator);
m.def("set_random_seed_generator", &phi::SetRandomSeedGenerator); m.def("set_random_seed_generator", &phi::SetRandomSeedGenerator);
m.def("get_random_seed_generator", &phi::GetRandomSeedGenerator); m.def("get_random_seed_generator", &phi::GetRandomSeedGenerator);
} }
......
...@@ -99,6 +99,17 @@ const std::shared_ptr<Generator>& DefaultCPUGenerator() { ...@@ -99,6 +99,17 @@ const std::shared_ptr<Generator>& DefaultCPUGenerator() {
return default_cpu_generator; return default_cpu_generator;
} }
const std::shared_ptr<Generator>& DefaultCustomDeviceGenerator(
const phi::CustomPlace& place) {
static std::
unordered_map<phi::Place, std::shared_ptr<Generator>, phi::Place::Hash>
generators;
if (generators.find(place) == generators.end()) {
generators.insert({place, std::make_shared<Generator>(GetRandomSeed())});
}
return generators[place];
}
using RNGMap = std::unordered_map<std::string, std::shared_ptr<Generator>>; using RNGMap = std::unordered_map<std::string, std::shared_ptr<Generator>>;
static RNGMap& GetRandomSeedGeneratorMap() { static RNGMap& GetRandomSeedGeneratorMap() {
......
...@@ -25,6 +25,8 @@ limitations under the License. */ ...@@ -25,6 +25,8 @@ limitations under the License. */
#include <typeinfo> #include <typeinfo>
#include <utility> #include <utility>
#include "paddle/phi/common/place.h"
namespace phi { namespace phi {
class Generator { class Generator {
...@@ -80,6 +82,9 @@ const std::shared_ptr<Generator>& DefaultCUDAGenerator(int64_t device_id = -1); ...@@ -80,6 +82,9 @@ const std::shared_ptr<Generator>& DefaultCUDAGenerator(int64_t device_id = -1);
const std::shared_ptr<Generator>& DefaultXPUGenerator(int64_t device_id = -1); const std::shared_ptr<Generator>& DefaultXPUGenerator(int64_t device_id = -1);
const std::shared_ptr<Generator>& DefaultCustomDeviceGenerator(
const phi::CustomPlace& place);
std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t); std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t);
const std::shared_ptr<Generator>& SetRandomSeedGenerator( const std::shared_ptr<Generator>& SetRandomSeedGenerator(
......
...@@ -205,7 +205,12 @@ class HybridParallelClipGrad: ...@@ -205,7 +205,12 @@ class HybridParallelClipGrad:
clip_var_fp16 = paddle.cast(clip_var, paddle.float16) clip_var_fp16 = paddle.cast(clip_var, paddle.float16)
# bf16 is not supported on XPU now # bf16 is not supported on XPU now
if not paddle.is_compiled_with_xpu(): if not (
paddle.is_compiled_with_xpu()
or isinstance(
paddle.framework._current_expected_place(), paddle.CustomPlace
)
):
clip_var_bf16 = paddle.cast(clip_var, paddle.bfloat16) clip_var_bf16 = paddle.cast(clip_var, paddle.bfloat16)
for p, g in params_grads: for p, g in params_grads:
if g is None: if g is None:
......
...@@ -222,13 +222,19 @@ def _recompute_without_reentrant( ...@@ -222,13 +222,19 @@ def _recompute_without_reentrant(
if preserve_rng_state: if preserve_rng_state:
cur_device = paddle.get_device() cur_device = paddle.get_device()
if 'gpu:' not in cur_device: if 'gpu:' in cur_device:
fw_cuda_rng_state = paddle.get_cuda_rng_state()
elif (
cur_device.split(':')[0]
in paddle.device.get_all_custom_device_type()
):
fw_cuda_rng_state = paddle.get_rng_state(cur_device)
else:
raise RuntimeError( raise RuntimeError(
"Recompute with RNG perserve is not support current device: {}.".format( "Recompute with RNG perserve is not support current device: {}.".format(
cur_device cur_device
) )
) )
fw_cuda_rng_state = paddle.get_cuda_rng_state()
fwd_cuda_rng_state_tracker = ( fwd_cuda_rng_state_tracker = (
get_rng_state_tracker().get_states_tracker() get_rng_state_tracker().get_states_tracker()
) )
......
...@@ -154,14 +154,19 @@ def _broadcast_object_list_help(object_list, hcg): ...@@ -154,14 +154,19 @@ def _broadcast_object_list_help(object_list, hcg):
def broadcast_input_data(hcg, *inputs, **kwargs): def broadcast_input_data(hcg, *inputs, **kwargs):
cur_device = paddle.get_device() cur_device = paddle.get_device()
dev = cur_device.split(":")[0] dev = cur_device.split(":")[0]
assert dev in [ assert (
dev
in [
"xpu", "xpu",
"gpu", "gpu",
"npu", ]
], f"Only support xpu, gpu and npu now, but this is {dev}" or dev in paddle.device.get_all_custom_device_type()
), f"Only support xpu, gpu and custom_device now, but this is {dev}"
dev_idx = int(cur_device.split(':')[1]) dev_idx = int(cur_device.split(':')[1])
if dev == "gpu": if dev == "gpu":
place = paddle.CUDAPlace(dev_idx) place = paddle.CUDAPlace(dev_idx)
elif dev in paddle.device.get_all_custom_device_type():
place = paddle.CustomPlace(dev, dev_idx)
else: else:
place = eval(f"paddle.{dev.upper()}Place")(dev_idx) place = eval(f"paddle.{dev.upper()}Place")(dev_idx)
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# TODO: define random api # TODO: define random api
import paddle
from paddle import fluid from paddle import fluid
from paddle.fluid import core from paddle.fluid import core
...@@ -48,7 +49,18 @@ def seed(seed): ...@@ -48,7 +49,18 @@ def seed(seed):
elif core.is_compiled_with_xpu(): elif core.is_compiled_with_xpu():
for i in range(core.get_xpu_device_count()): for i in range(core.get_xpu_device_count()):
core.default_xpu_generator(i).manual_seed(seed) core.default_xpu_generator(i).manual_seed(seed)
place = fluid.framework._current_expected_place()
if isinstance(place, core.CustomPlace):
dev_cnt = sum(
[
place.get_device_type() == s.split(':')[0]
for s in core.get_available_custom_device()
]
)
for i in range(dev_cnt):
core.default_custom_device_generator(
core.CustomPlace(place.get_device_type(), i)
).manual_seed(seed)
return core.default_cpu_generator().manual_seed(seed) return core.default_cpu_generator().manual_seed(seed)
...@@ -70,7 +82,7 @@ def get_rng_state(device=None): ...@@ -70,7 +82,7 @@ def get_rng_state(device=None):
if device is None: if device is None:
place = fluid.framework._current_expected_place() place = fluid.framework._current_expected_place()
else: else:
place = device._convert_to_place(device) place = paddle.device._convert_to_place(device)
if isinstance(place, core.CPUPlace): if isinstance(place, core.CPUPlace):
state_list.append(core.default_cpu_generator().get_state()) state_list.append(core.default_cpu_generator().get_state())
...@@ -80,6 +92,19 @@ def get_rng_state(device=None): ...@@ -80,6 +92,19 @@ def get_rng_state(device=None):
elif isinstance(place, core.XPUPlace): elif isinstance(place, core.XPUPlace):
for i in range(core.get_xpu_device_count()): for i in range(core.get_xpu_device_count()):
state_list.append(core.default_xpu_generator(i).get_state()) state_list.append(core.default_xpu_generator(i).get_state())
elif isinstance(place, core.CustomPlace):
dev_cnt = sum(
[
place.get_device_type() == s.split(':')[0]
for s in core.get_available_custom_device()
]
)
for i in range(dev_cnt):
state_list.append(
core.default_custom_device_generator(
core.CustomPlace(place.get_device_type(), i)
).get_state()
)
else: else:
raise ValueError( raise ValueError(
"get_rng_state is not implemented for current device: {}".format( "get_rng_state is not implemented for current device: {}".format(
...@@ -157,6 +182,21 @@ def set_rng_state(state_list, device=None): ...@@ -157,6 +182,21 @@ def set_rng_state(state_list, device=None):
) )
for i in range(core.get_xpu_device_count()): for i in range(core.get_xpu_device_count()):
core.default_xpu_generator(i).set_state(state_list[i]) core.default_xpu_generator(i).set_state(state_list[i])
elif isinstance(place, core.CustomPlace):
dev_cnt = sum(
[
place.get_device_type() == s.split(':')[0]
for s in core.get_available_custom_device()
]
)
if not len(state_list) == dev_cnt:
raise ValueError(
f"Length of custom device state list shoule be equal to the {place.get_dtype_type()} device count"
)
for i in range(dev_cnt):
core.default_custom_device_generator(
core.CustomPlace(place.get_device_type(), i)
).set_state(state_list[i])
elif isinstance(place, core.CPUPlace): elif isinstance(place, core.CPUPlace):
if not len(state_list) == 1: if not len(state_list) == 1:
raise ValueError("Length of cpu state list shoule be equal to 1") raise ValueError("Length of cpu state list shoule be equal to 1")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册