Unverified commit 3206fa80, authored by ronnywang, committed by GitHub

[CustomDevice] add recompute support (#53044)

* [CustomDevice] add recompute support

* update
Parent 7e19d16f
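
For orientation before the diffs below, here is a minimal sketch of what this commit enables from the Python side. It is illustrative only: it assumes a CustomDevice plugin is installed that registers the hypothetical device type "npu", and a Paddle build containing this commit.

```python
import paddle
from paddle.distributed.fleet.utils import recompute

paddle.set_device("npu:0")  # hypothetical plugin-registered device type

layer = paddle.nn.Linear(64, 64)
x = paddle.randn([8, 64])
x.stop_gradient = False

# Before this commit, preserve_rng_state raised a RuntimeError on any
# non-GPU device; now the forward RNG state of a custom device is captured
# with paddle.get_rng_state(cur_device) and restored during recomputation.
y = recompute(layer, x, use_reentrant=False)
y.sum().backward()
```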
@@ -109,6 +109,9 @@ inline std::unique_ptr<DeviceContext> CreateDeviceContext(
     dev_ctx->SetAllocator(instance.GetAllocator(p).get());
     dev_ctx->SetGenerator(phi::DefaultXPUGenerator(p.GetDeviceId()).get());
 #endif
+  } else if (p.GetType() == phi::AllocationType::CUSTOM) {
+    dev_ctx->SetAllocator(instance.GetAllocator(p).get());
+    dev_ctx->SetGenerator(phi::DefaultCustomDeviceGenerator(p).get());
   } else {
     dev_ctx->SetAllocator(instance.GetAllocator(p).get());
     dev_ctx->SetGenerator(phi::DefaultCPUGenerator().get());
......
@@ -88,6 +88,7 @@ void BindGenerator(py::module* m_ptr) {
   m.def("default_cpu_generator", &phi::DefaultCPUGenerator);
   m.def("default_cuda_generator", &phi::DefaultCUDAGenerator);
   m.def("default_xpu_generator", &phi::DefaultXPUGenerator);
+  m.def("default_custom_device_generator", &phi::DefaultCustomDeviceGenerator);
   m.def("set_random_seed_generator", &phi::SetRandomSeedGenerator);
   m.def("get_random_seed_generator", &phi::GetRandomSeedGenerator);
 }
......
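
A hedged example of exercising the new binding directly; the device type "npu" and index 0 are assumptions, not part of the commit:

```python
from paddle.fluid import core

place = core.CustomPlace("npu", 0)  # hypothetical custom device
gen = core.default_custom_device_generator(place)
gen.manual_seed(42)       # reseed this device's generator
state = gen.get_state()   # snapshot; restore later with gen.set_state(state)
```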
@@ -99,6 +99,17 @@ const std::shared_ptr<Generator>& DefaultCPUGenerator() {
   return default_cpu_generator;
 }
 
+const std::shared_ptr<Generator>& DefaultCustomDeviceGenerator(
+    const phi::CustomPlace& place) {
+  static std::
+      unordered_map<phi::Place, std::shared_ptr<Generator>, phi::Place::Hash>
+          generators;
+  if (generators.find(place) == generators.end()) {
+    generators.insert({place, std::make_shared<Generator>(GetRandomSeed())});
+  }
+  return generators[place];
+}
+
 using RNGMap = std::unordered_map<std::string, std::shared_ptr<Generator>>;
 
 static RNGMap& GetRandomSeedGeneratorMap() {
......
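
Note that the generator is created lazily and cached per place, so repeated lookups for the same CustomPlace return the same underlying generator. A quick check, under the same hypothetical-plugin assumption:

```python
from paddle.fluid import core

p = core.CustomPlace("npu", 0)  # hypothetical device
core.default_custom_device_generator(p).manual_seed(7)
# A second lookup hits the cached entry, so the seed set above is still visible:
assert core.default_custom_device_generator(p).initial_seed() == 7
```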
@@ -25,6 +25,8 @@ limitations under the License. */
 #include <typeinfo>
 #include <utility>
 
+#include "paddle/phi/common/place.h"
+
 namespace phi {
 
 class Generator {
@@ -80,6 +82,9 @@ const std::shared_ptr<Generator>& DefaultCUDAGenerator(int64_t device_id = -1);
 const std::shared_ptr<Generator>& DefaultXPUGenerator(int64_t device_id = -1);
 
+const std::shared_ptr<Generator>& DefaultCustomDeviceGenerator(
+    const phi::CustomPlace& place);
+
 std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t);
 
 const std::shared_ptr<Generator>& SetRandomSeedGenerator(
......
@@ -205,7 +205,12 @@ class HybridParallelClipGrad:
         clip_var_fp16 = paddle.cast(clip_var, paddle.float16)
         # bf16 is not supported on XPU now
-        if not paddle.is_compiled_with_xpu():
+        if not (
+            paddle.is_compiled_with_xpu()
+            or isinstance(
+                paddle.framework._current_expected_place(), paddle.CustomPlace
+            )
+        ):
             clip_var_bf16 = paddle.cast(clip_var, paddle.bfloat16)
         for p, g in params_grads:
             if g is None:
......
@@ -222,13 +222,19 @@ def _recompute_without_reentrant(
     if preserve_rng_state:
         cur_device = paddle.get_device()
-        if 'gpu:' not in cur_device:
+        if 'gpu:' in cur_device:
+            fw_cuda_rng_state = paddle.get_cuda_rng_state()
+        elif (
+            cur_device.split(':')[0]
+            in paddle.device.get_all_custom_device_type()
+        ):
+            fw_cuda_rng_state = paddle.get_rng_state(cur_device)
+        else:
             raise RuntimeError(
                 "Recompute with RNG preserve is not supported on current device: {}.".format(
                     cur_device
                 )
             )
-        fw_cuda_rng_state = paddle.get_cuda_rng_state()
         fwd_cuda_rng_state_tracker = (
             get_rng_state_tracker().get_states_tracker()
         )
......
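
The device dispatch above reduces to the following pattern; a minimal sketch of capturing and restoring RNG state on a custom device (the device string is an assumption):

```python
import paddle

cur_device = paddle.get_device()  # e.g. "npu:0" on a plugin device
if cur_device.split(':')[0] in paddle.device.get_all_custom_device_type():
    rng_state = paddle.get_rng_state(cur_device)   # captured at forward time
    # ... recomputed forward runs here ...
    paddle.set_rng_state(rng_state, cur_device)    # dropout etc. replays identically
```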
@@ -154,14 +154,19 @@ def _broadcast_object_list_help(object_list, hcg):
 def broadcast_input_data(hcg, *inputs, **kwargs):
     cur_device = paddle.get_device()
     dev = cur_device.split(":")[0]
-    assert dev in [
-        "xpu",
-        "gpu",
-        "npu",
-    ], f"Only support xpu, gpu and npu now, but this is {dev}"
+    assert (
+        dev
+        in [
+            "xpu",
+            "gpu",
+        ]
+        or dev in paddle.device.get_all_custom_device_type()
+    ), f"Only xpu, gpu and custom devices are supported now, but this is {dev}"
     dev_idx = int(cur_device.split(':')[1])
     if dev == "gpu":
         place = paddle.CUDAPlace(dev_idx)
+    elif dev in paddle.device.get_all_custom_device_type():
+        place = paddle.CustomPlace(dev, dev_idx)
     else:
         place = eval(f"paddle.{dev.upper()}Place")(dev_idx)
......
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 # TODO: define random api
+import paddle
 from paddle import fluid
 from paddle.fluid import core
@@ -48,7 +49,18 @@ def seed(seed):
     elif core.is_compiled_with_xpu():
         for i in range(core.get_xpu_device_count()):
             core.default_xpu_generator(i).manual_seed(seed)
+    place = fluid.framework._current_expected_place()
+    if isinstance(place, core.CustomPlace):
+        dev_cnt = sum(
+            [
+                place.get_device_type() == s.split(':')[0]
+                for s in core.get_available_custom_device()
+            ]
+        )
+        for i in range(dev_cnt):
+            core.default_custom_device_generator(
+                core.CustomPlace(place.get_device_type(), i)
+            ).manual_seed(seed)
 
     return core.default_cpu_generator().manual_seed(seed)
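
With this hunk, paddle.seed also reseeds every visible device of the current custom device type, mirroring the existing CUDA/XPU behavior. A hedged usage example (device type assumed):

```python
import paddle

paddle.set_device("npu:0")  # hypothetical plugin device
paddle.seed(2023)           # seeds the CPU generator and every visible "npu" generator
```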
@@ -70,7 +82,7 @@ def get_rng_state(device=None):
     if device is None:
         place = fluid.framework._current_expected_place()
     else:
-        place = device._convert_to_place(device)
+        place = paddle.device._convert_to_place(device)
 
     if isinstance(place, core.CPUPlace):
         state_list.append(core.default_cpu_generator().get_state())
@@ -80,6 +92,19 @@ def get_rng_state(device=None):
     elif isinstance(place, core.XPUPlace):
         for i in range(core.get_xpu_device_count()):
             state_list.append(core.default_xpu_generator(i).get_state())
+    elif isinstance(place, core.CustomPlace):
+        dev_cnt = sum(
+            [
+                place.get_device_type() == s.split(':')[0]
+                for s in core.get_available_custom_device()
+            ]
+        )
+        for i in range(dev_cnt):
+            state_list.append(
+                core.default_custom_device_generator(
+                    core.CustomPlace(place.get_device_type(), i)
+                ).get_state()
+            )
     else:
         raise ValueError(
             "get_rng_state is not implemented for current device: {}".format(
@@ -157,6 +182,21 @@ def set_rng_state(state_list, device=None):
         )
         for i in range(core.get_xpu_device_count()):
             core.default_xpu_generator(i).set_state(state_list[i])
+    elif isinstance(place, core.CustomPlace):
+        dev_cnt = sum(
+            [
+                place.get_device_type() == s.split(':')[0]
+                for s in core.get_available_custom_device()
+            ]
+        )
+        if not len(state_list) == dev_cnt:
+            raise ValueError(
+                f"Length of custom device state list should be equal to the {place.get_device_type()} device count"
+            )
+        for i in range(dev_cnt):
+            core.default_custom_device_generator(
+                core.CustomPlace(place.get_device_type(), i)
+            ).set_state(state_list[i])
     elif isinstance(place, core.CPUPlace):
         if not len(state_list) == 1:
             raise ValueError("Length of cpu state list should be equal to 1")
......
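
Taken together, get_rng_state and set_rng_state now round-trip on custom devices. A minimal reproducibility sketch, again assuming a hypothetical "npu" plugin:

```python
import paddle

paddle.set_device("npu:0")       # hypothetical plugin device
states = paddle.get_rng_state()  # one state per visible "npu" device
x1 = paddle.rand([2, 2])
paddle.set_rng_state(states)     # restore every device generator
x2 = paddle.rand([2, 2])
# x1 and x2 should now be drawn from identical generator states
```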