未验证 提交 98c17a68 编写于 作者: Q QingshuChen 提交者: GitHub

suport recompute for kunlun (#49069)

上级 644dfc60
......@@ -254,9 +254,13 @@ void BindThreadedSSAGraphExecutor::RunMultiDeviceOpAsync(
auto dev_ctxes = op->DeviceContext();
auto &inputs = op->Inputs();
for (auto &input : inputs) {
auto dev_ctxes = input->GeneratedOp()->DeviceContext();
for (auto &item : dev_ctxes) {
((platform::XPUDeviceContext *)(item.second))->Wait();
if (input && input->GeneratedOp() != nullptr) {
auto dev_ctxes = input->GeneratedOp()->DeviceContext();
for (auto &item : dev_ctxes) {
((platform::XPUDeviceContext *)(item.second))->Wait();
}
} else {
VLOG(3) << "No generated op:" << op->Name();
}
}
op->Run(strategy_.use_device_);
......
......@@ -20,11 +20,43 @@ limitations under the License. */
#include <utility>
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/xpu/xpu_info.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
const std::shared_ptr<Generator>& DefaultXPUGenerator(int64_t device_id) {
#if defined(PADDLE_WITH_XPU)
static int64_t num_xpu_devices = -1;
static std::once_flag num_devices_init_flag;
static std::deque<std::once_flag> xpu_device_flags;
static std::vector<std::shared_ptr<Generator>> default_xpu_generators;
std::call_once(num_devices_init_flag, []() {
num_xpu_devices = paddle::platform::GetXPUDeviceCount();
xpu_device_flags.resize(num_xpu_devices);
default_xpu_generators.resize(num_xpu_devices);
});
if (device_id < 0) {
PADDLE_THROW(platform::errors::InvalidArgument(
"xpu device id shoule be greater than 0"));
}
std::call_once(xpu_device_flags[device_id], [device_id]() {
default_xpu_generators[device_id] =
std::make_shared<Generator>(GetRandomSeed(), device_id);
VLOG(4) << "initial seed: "
<< default_xpu_generators[device_id]->GetCurrentSeed();
});
return default_xpu_generators[device_id];
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"getDefaultXPUGenerator only support in XPU place"));
#endif
}
const std::shared_ptr<Generator>& DefaultCUDAGenerator(int64_t device_id) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
......
......@@ -107,6 +107,8 @@ const std::shared_ptr<Generator>& DefaultCPUGenerator();
const std::shared_ptr<Generator>& DefaultCUDAGenerator(int64_t device_id = -1);
const std::shared_ptr<Generator>& DefaultXPUGenerator(int64_t device_id = -1);
std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t);
const std::shared_ptr<Generator>& SetRandomSeedGenerator(
......
......@@ -170,6 +170,13 @@ std::unique_ptr<DeviceContext> CreateDeviceContext(
cuda_ctx->PartialInitWithAllocator();
dev_ctx->SetGenerator(
framework::DefaultCUDAGenerator(p.GetDeviceId()).get());
#endif
} else if (is_xpu_place(p)) {
#if defined(PADDLE_WITH_XPU)
dev_ctx->SetAllocator(
memory::allocation::AllocatorFacade::Instance().GetAllocator(p).get());
dev_ctx->SetGenerator(
framework::DefaultXPUGenerator(p.GetDeviceId()).get());
#endif
} else {
dev_ctx->SetAllocator(
......
......@@ -90,6 +90,7 @@ void BindGenerator(py::module* m_ptr) {
.def("random", &framework::Generator::Random64);
m.def("default_cpu_generator", &framework::DefaultCPUGenerator);
m.def("default_cuda_generator", &framework::DefaultCUDAGenerator);
m.def("default_xpu_generator", &framework::DefaultXPUGenerator);
m.def("set_random_seed_generator", &framework::SetRandomSeedGenerator);
m.def("get_random_seed_generator", &framework::GetRandomSeedGenerator);
}
......
......@@ -58,6 +58,9 @@ void DropoutRawKernel(const Context& dev_ctx,
} else {
seed_data = fix_seed ? seed : 0;
}
if (seed_data == 0) {
seed_data = dev_ctx.GetGenerator()->Random64();
}
auto* mask_data = dev_ctx.template Alloc<T>(mask);
// Special case when dropout_prob is 1.0
......
......@@ -327,6 +327,8 @@ from .tensor.einsum import einsum # noqa: F401
from .framework.random import seed # noqa: F401
from .framework.random import get_cuda_rng_state # noqa: F401
from .framework.random import set_cuda_rng_state # noqa: F401
from .framework.random import get_rng_state # noqa: F401
from .framework.random import set_rng_state # noqa: F401
from .framework import ParamAttr # noqa: F401
from .framework import CPUPlace # noqa: F401
from .framework import IPUPlace # noqa: F401
......@@ -424,6 +426,7 @@ __all__ = [ # noqa
'save',
'multinomial',
'get_cuda_rng_state',
'get_rng_state',
'rank',
'empty_like',
'eye',
......@@ -606,6 +609,7 @@ __all__ = [ # noqa
'unique',
'unique_consecutive',
'set_cuda_rng_state',
'set_rng_state',
'set_printoptions',
'std',
'flatten',
......
......@@ -56,16 +56,15 @@ def check_recompute_necessary(inputs):
@contextlib.contextmanager
def swith_rng_state_tracker(rng_state, tracker):
orig_cuda_rng_state = paddle.get_cuda_rng_state()
orig_cuda_rng_tracker = get_rng_state_tracker().get_states_tracker()
paddle.set_cuda_rng_state(rng_state)
orig_rng_state = paddle.get_rng_state()
orig_rng_tracker = get_rng_state_tracker().get_states_tracker()
paddle.set_rng_state(rng_state)
get_rng_state_tracker().set_states_tracker(tracker)
try:
yield
finally:
paddle.set_cuda_rng_state(orig_cuda_rng_state)
get_rng_state_tracker().set_states_tracker(orig_cuda_rng_tracker)
paddle.set_rng_state(orig_rng_state)
get_rng_state_tracker().set_states_tracker(orig_rng_tracker)
class LegacyRecomputeFunction(LegacyPyLayer):
......@@ -95,15 +94,8 @@ class LegacyRecomputeFunction(LegacyPyLayer):
# NOTE recompute with restore RNG only support one senario where one process for one cuda gpu.
# one process with multiple gpu and mix-gpu-cpu senarios are not support
if ctx.preserve_rng_state:
cur_device = paddle.get_device()
if 'gpu:' not in cur_device:
raise RuntimeError(
"Recompute with RNG perserve is not support current device: {}.".format(
cur_device
)
)
ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()
ctx.fwd_cuda_rng_state_tracker = (
ctx.fw_rng_state = paddle.get_rng_state()
ctx.fwd_rng_state_tracker = (
get_rng_state_tracker().get_states_tracker()
)
......@@ -156,7 +148,7 @@ class LegacyRecomputeFunction(LegacyPyLayer):
# need restore auto_cast state as well as w/b list
if ctx.preserve_rng_state:
with swith_rng_state_tracker(
ctx.fw_cuda_rng_state, ctx.fwd_cuda_rng_state_tracker
ctx.fw_rng_state, ctx.fwd_rng_state_tracker
):
with paddle.amp.auto_cast(
enable=ctx.is_fw_autocast,
......@@ -244,15 +236,8 @@ class RecomputeFunction(PyLayer):
# NOTE recompute with restore RNG only support one senario where one process for one cuda gpu.
# one process with multiple gpu and mix-gpu-cpu senarios are not support
if ctx.preserve_rng_state:
cur_device = paddle.get_device()
if 'gpu:' not in cur_device:
raise RuntimeError(
"Recompute with RNG perserve is not support current device: {}.".format(
cur_device
)
)
ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()
ctx.fwd_cuda_rng_state_tracker = (
ctx.fw_rng_state = paddle.get_rng_state()
ctx.fwd_rng_state_tracker = (
get_rng_state_tracker().get_states_tracker()
)
......@@ -305,7 +290,7 @@ class RecomputeFunction(PyLayer):
# need restore auto_cast state as well as w/b list
if ctx.preserve_rng_state:
with swith_rng_state_tracker(
ctx.fw_cuda_rng_state, ctx.fwd_cuda_rng_state_tracker
ctx.fw_rng_state, ctx.fwd_rng_state_tracker
):
with paddle.amp.auto_cast(
enable=ctx.is_fw_autocast,
......
......@@ -312,15 +312,6 @@ class TestRecompute(unittest.TestCase):
recompute_block=[2], recompute_kwargs=kwargs
)
def test_recompute_cpu_rng(self):
paddle.set_device("cpu")
for flag in [True, False]:
with self.assertRaises(RuntimeError):
loss_ref, param_ref, grad_ref = run_model(
recompute_block=[2],
recompute_kwargs={"use_reentrant": flag},
)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import numpy as np
import paddle
from paddle.distributed.fleet.utils import recompute
def get_fc_block(block_idx, input_size, is_last=False):
block_name = "block_" + str(block_idx)
block = paddle.nn.Sequential(
(
block_name + "_fc_0",
paddle.nn.Linear(input_size, input_size, bias_attr=False),
),
(block_name + "_dropout", paddle.nn.Dropout(p=0.5)),
(block_name + "_relu_1", paddle.nn.ReLU()),
(
block_name + "_fc_1",
paddle.nn.Linear(input_size, input_size, bias_attr=False),
),
(block_name + "_relu_2", paddle.nn.ReLU()),
)
if is_last:
block.add_sublayer(
block_name + "_fc_2",
paddle.nn.Linear(input_size, 1, bias_attr=False),
)
else:
block.add_sublayer(
block_name + "_fc_2",
paddle.nn.Linear(input_size, input_size, bias_attr=False),
)
return block
class Naive_fc_net(paddle.nn.Layer):
def __init__(
self, input_size=10, recompute_blocks=[1, 3], recompute_kwargs={}
):
super(Naive_fc_net, self).__init__()
self.recompute_blocks = recompute_blocks
self.recompute_kwargs = recompute_kwargs
self.runfunc0 = get_fc_block(0, input_size, is_last=False)
self.runfunc1 = get_fc_block(1, input_size, is_last=False)
self.runfunc2 = get_fc_block(2, input_size, is_last=False)
self.runfunc3 = get_fc_block(3, input_size, is_last=False)
self.runfunc4 = get_fc_block(4, input_size, is_last=True)
self.total_func = [
self.runfunc0,
self.runfunc1,
self.runfunc2,
self.runfunc3,
self.runfunc4,
]
def forward(self, inputs):
nums = len(self.total_func)
for i in range(nums):
if i in self.recompute_blocks:
inputs = recompute(
self.total_func[i], inputs, **{"preserve_rng_state": True}
)
else:
inputs = self.total_func[i](inputs)
return inputs
def run_model(xpu_state, recompute_block=[], recompute_kwargs={}):
gen = paddle.seed(10)
random.seed(10)
batch_size, input_size = 1, 10
model = Naive_fc_net(
input_size,
recompute_blocks=recompute_block,
recompute_kwargs=recompute_kwargs,
)
optimizer = paddle.optimizer.SGD(
learning_rate=0.01, parameters=model.parameters()
)
loss_ = []
param_ = []
grad_ = []
for _ in range(5):
x = paddle.rand(shape=[batch_size, input_size], dtype="float32")
y_pred = model(x)
loss = y_pred.mean()
loss_.append(loss.item())
loss.backward()
optimizer.step()
param_.append(model.parameters()[9])
grad_.append(model.parameters()[3]._grad_ivar())
optimizer.clear_grad()
return loss_, param_, grad_
xpu_state = paddle.get_rng_state()
# without recompute
loss_ref, param_ref, grad_ref = run_model(xpu_state, recompute_block=[])
loss, param, grad = run_model(xpu_state, recompute_block=[1, 3])
# The result of the recompute_loss should be the same as the normal_loss.
np.testing.assert_allclose(loss_ref, loss, rtol=1e-05, atol=1e-05)
......@@ -45,10 +45,51 @@ def seed(seed):
if core.is_compiled_with_cuda():
for i in range(core.get_cuda_device_count()):
core.default_cuda_generator(i).manual_seed(seed)
elif core.is_compiled_with_xpu():
for i in range(core.get_xpu_device_count()):
core.default_xpu_generator(i).manual_seed(seed)
return core.default_cpu_generator().manual_seed(seed)
def get_rng_state(device=None):
"""
Get all random states of random generators of specified device.
Args:
device(str): This parameter determines the specific running device.
It can be ``cpu``, ``gpu``, ``xpu``, Default is None.
If None, return the generators of current device (specified by ``set_device``).
Returns:
GeneratorState: object.
Examples:
.. code-block:: python
import paddle
sts = paddle.get_rng_state()
"""
state_list = []
if device is None:
place = fluid.framework._current_expected_place()
else:
place = device._convert_to_place(device)
if isinstance(place, core.CPUPlace):
state_list.append(core.default_cpu_generator().get_state())
elif isinstance(place, core.CUDAPlace):
for i in range(core.get_cuda_device_count()):
state_list.append(core.default_cuda_generator(i).get_state())
elif isinstance(place, core.XPUPlace):
for i in range(core.get_xpu_device_count()):
state_list.append(core.default_xpu_generator(i).get_state())
else:
raise ValueError(
"get_rng_state is not implemented for current device: {}".format(
place
)
)
return state_list
def get_cuda_rng_state():
"""
......@@ -75,6 +116,59 @@ def get_cuda_rng_state():
return state_list
def set_rng_state(state_list, device=None):
"""
Sets generator state for all device generators.
Args:
state_list(list|tuple): The device states to set back to device generators. state_list is obtained from get_rng_state().
device(str): This parameter determines the specific running device.
It can be ``cpu``, ``gpu``, ``xpu``, Default is None.
If None, return the generators of current device (specified by ``set_device``).
Returns:
None.
Examples:
.. code-block:: python
import paddle
sts = paddle.get_rng_state()
paddle.set_rng_state(sts)
"""
if device is None:
place = fluid.framework._current_expected_place()
else:
place = device._convert_to_place(device)
if isinstance(place, core.CUDAPlace):
if not len(state_list) == core.get_cuda_device_count():
raise ValueError(
"Length of gpu state list shoule be equal to the gpu device count"
)
for i in range(core.get_cuda_device_count()):
core.default_cuda_generator(i).set_state(state_list[i])
elif isinstance(place, core.XPUPlace):
if not len(state_list) == core.get_xpu_device_count():
raise ValueError(
"Length of xpu state list shoule be equal to the xpu device count"
)
for i in range(core.get_xpu_device_count()):
core.default_xpu_generator(i).set_state(state_list[i])
elif isinstance(place, core.CPUPlace):
if not len(state_list) == 1:
raise ValueError("Length of cpu state list shoule be equal to 1")
core.default_cpu_generator().set_state(state_list[0])
else:
raise ValueError(
"set_rng_state is not implemented for current device: {}".format(
place
)
)
def set_cuda_rng_state(state_list):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册