Unverified  Commit 2e1ac529  authored by houj04, committed by GitHub

[XPU] remove scale_loss in parallel.py (#53337)

* [XPU] remove scale_loss in parallel.py

* [XPU] throw Unimplemented when using Reducer
Parent eee9c788
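
Note on the change: with this commit, `Tensor.backward()` no longer applies an implicit division of the loss by `world_size` when Paddle is compiled with XPU; gradient averaging is left to the data-parallel reducer, the same as on GPU. Below is a minimal sketch of the dygraph data-parallel loop this affects; the model, optimizer, and data are placeholders, and the script is assumed to be launched with `paddle.distributed.launch` on XPU devices.

```python
import paddle
import paddle.distributed as dist

# Minimal data-parallel step sketch (placeholder model/optimizer/data).
# After this change, backward() behaves the same on XPU as on GPU:
# no hidden loss scaling; the reducer averages gradients across ranks.
dist.init_parallel_env()
model = paddle.nn.Linear(10, 1)            # placeholder model
dp_model = paddle.DataParallel(model)      # eager-mode reducer handles averaging
opt = paddle.optimizer.SGD(
    learning_rate=0.01, parameters=dp_model.parameters()
)

x = paddle.randn([8, 10])                  # placeholder data
loss = dp_model(x).mean()
loss.backward()                            # unscaled loss
opt.step()
opt.clear_grad()
```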
@@ -61,7 +61,9 @@ void Group::DivNRanks(const platform::DeviceContext &context, int64_t nranks) {
     VLOG(4) << "after div 2" << *tensor;
   } else if (platform::is_xpu_place(tensor->place())) {
 #ifdef PADDLE_WITH_XPU_BKCL
-    // TODO(liuyuhui) support xpu about div nranks in the future
+    PADDLE_THROW(
+        platform::errors::Unimplemented("DivNRanks is not supported on XPU / "
+                                        "XPU_BKCL, use EagerReducer instead."));
 #endif
   }
 }
...
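
Note: `DivNRanks` is the step where the legacy `Reducer` turns the all-reduced gradient sum into an average by dividing by the number of ranks; with this patch it now fails loudly on XPU/BKCL instead of silently skipping the division, and the eager-mode `EagerReducer` is the supported path. A conceptual sketch of the operation, not Paddle internals (the helper name and the explicit `all_reduce` call are illustrative only):

```python
import paddle.distributed as dist

def average_gradient(grad, nranks):
    # Illustrative only: data-parallel training sums gradients across ranks
    # with all_reduce and then divides by nranks -- the "div nranks" step
    # that the legacy Reducer no longer performs on XPU.
    dist.all_reduce(grad)        # in-place sum over all ranks
    return grad / nranks
```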
@@ -149,7 +149,7 @@ void *Alloc<platform::XPUPlace>(const platform::XPUPlace &place, size_t size) {
   VLOG(10) << "Allocate " << size << " bytes on " << platform::Place(place);
   void *p = nullptr;
-  platform::XPUDeviceGuard gurad(place.device);
+  platform::XPUDeviceGuard guard(place.device);
   int ret = xpu_malloc(reinterpret_cast<void **>(&p), size);
   if (ret != XPU_SUCCESS) {
     VLOG(10) << "xpu memory malloc(" << size << ") failed, try again";
@@ -182,7 +182,7 @@ void Free<platform::XPUPlace>(const platform::XPUPlace &place,
   VLOG(10) << "Free " << size << " bytes on " << platform::Place(place);
   VLOG(10) << "Free pointer=" << p << " on " << platform::Place(place);
-  platform::XPUDeviceGuard gurad(place.device);
+  platform::XPUDeviceGuard guard(place.device);
   xpu_free(p);
 #else
   PADDLE_THROW(
...
@@ -268,7 +268,7 @@ void BufferedReader::ReadAsync(size_t i) {
       xpu_ptrs.emplace_back(xpu[i].mutable_data(place_, cpu[i].type()));
     }
-    platform::XPUDeviceGuard gurad(place_.device);
+    platform::XPUDeviceGuard guard(place_.device);
     int r = xpu_event_record(events_[i].get(), compute_stream_);
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_event_record");
     r = xpu_stream_wait_event(stream_.get(), events_[i].get());
...
@@ -22,14 +22,14 @@ XpuStreamResourcePool::XpuStreamResourcePool() {
   pool_.reserve(dev_cnt);
   for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) {
     auto creator = [dev_idx] {
-      platform::XPUDeviceGuard gurad(dev_idx);
+      platform::XPUDeviceGuard guard(dev_idx);
       xpuStream stream;
       xpu_stream_create(&stream);
       return stream;
     };
     auto deleter = [dev_idx](xpuStream stream) {
-      platform::XPUDeviceGuard gurad(dev_idx);
+      platform::XPUDeviceGuard guard(dev_idx);
       xpu_stream_destroy(stream);
     };
@@ -63,14 +63,14 @@ XpuEventResourcePool::XpuEventResourcePool() {
   pool_.reserve(dev_cnt);
   for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) {
     auto creator = [dev_idx] {
-      platform::XPUDeviceGuard gurad(dev_idx);
+      platform::XPUDeviceGuard guard(dev_idx);
       xpuEventHandle event;
       xpu_event_create(&event);
       return event;
     };
     auto deleter = [dev_idx](xpuEventHandle event) {
-      platform::XPUDeviceGuard gurad(dev_idx);
+      platform::XPUDeviceGuard guard(dev_idx);
       xpu_event_destroy(event);
     };
...
@@ -33,7 +33,7 @@ void MatMul(const Context& dev_ctx,
     MatMulXPUFunction<T, int32_t>(a, b, out, trans_a, trans_b, xpu_ctx);
   } else if (fccal_type == XPUFCCalcType::FC_FLOAT) {
     MatMulXPUFunction<T, float>(a, b, out, trans_a, trans_b, xpu_ctx);
-  } else if (fccal_type == XPUFCCalcType::FC_INT_WITH_LL) {
+  } else if (fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) {
     MatMulXPUFunction<T, int_with_ll_t>(a, b, out, trans_a, trans_b, xpu_ctx);
   } else {
     MatMulXPUFunction<T, int16_t>(a, b, out, trans_a, trans_b, xpu_ctx);
...
@@ -68,7 +68,7 @@ void BmmKernel(const Context& dev_ctx,
     MatMulXPUFunction<T, int32_t>(x, y, out, trans_x, trans_y, xpu_ctx);
   } else if (fccal_type == XPUFCCalcType::FC_FLOAT) {
     MatMulXPUFunction<T, float>(x, y, out, trans_x, trans_y, xpu_ctx);
-  } else if (fccal_type == XPUFCCalcType::FC_INT_WITH_LL) {
+  } else if (fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) {
     MatMulXPUFunction<T, int_with_ll_t>(x, y, out, trans_x, trans_y, xpu_ctx);
   } else {
     MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, xpu_ctx);
...
@@ -30,7 +30,7 @@ enum XPUFCCalcType {
   FC_INT16 = 0,
   FC_INT32,
   FC_FLOAT,
-  FC_INT_WITH_LL,
+  FC_INT32_WITH_LL,
 };
 
 template <typename T>
@@ -42,8 +42,8 @@ XPUFCCalcType FCCalcType() {
     return XPUFCCalcType::FC_INT32;
   } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) {
     return XPUFCCalcType::FC_FLOAT;
-  } else if (std::getenv("XPU_PADDLE_FC_INT_WITH_LL") != nullptr) {
-    return XPUFCCalcType::FC_INT_WITH_LL;
+  } else if (std::getenv("XPU_PADDLE_FC_INT32_WITH_LL") != nullptr) {
+    return XPUFCCalcType::FC_INT32_WITH_LL;
   }
   return XPUFCCalcType::FC_INT16;
 }
...
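
Note: the enum value and the environment variable that select the int32-with-long-long FC compute path are renamed together (`FC_INT_WITH_LL` → `FC_INT32_WITH_LL`, `XPU_PADDLE_FC_INT_WITH_LL` → `XPU_PADDLE_FC_INT32_WITH_LL`), so launch scripts have to switch to the new spelling. A small sketch of opting into that path from Python:

```python
import os

# Old name (no longer read by FCCalcType): XPU_PADDLE_FC_INT_WITH_LL
# New name after this patch:               XPU_PADDLE_FC_INT32_WITH_LL
# Set it before running any XPU matmul/bmm so the kernel dispatch sees it.
os.environ["XPU_PADDLE_FC_INT32_WITH_LL"] = "1"
```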
@@ -47,7 +47,7 @@ from paddle.distributed.fleet.launch_utils import check_backend
 # (TODO: GhostScreaming) It will be removed later.
 from paddle.framework import _set_expected_place
 from paddle.framework import base as imperative_base
-from paddle.framework import core, in_dygraph_mode, to_variable
+from paddle.framework import core, in_dygraph_mode
 from paddle.nn.layer import layers
 from paddle.utils import deprecated
@@ -117,21 +117,6 @@ def _split_tensors(coalesced_grads_and_grad_vars):
             assert g_var.shape == g_shape
 
 
-def scale_loss(loss):
-    # TODO(liuyuhui) Currently only for xpu. Will be removed in the future.
-    if not paddle.distributed.ParallelEnv().world_size > 1:
-        return loss
-
-    loss_scale = to_variable(
-        np.array([paddle.distributed.ParallelEnv().world_size]).astype(
-            "float32"
-        )
-    )
-    loss_scale.stop_gradient = True
-    scaled_loss = loss / loss_scale
-    return scaled_loss
-
-
 @imperative_base.no_grad
 @framework.dygraph_only
 def build_groups(vars, group_size):
...
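
Note: the deleted `scale_loss` helper divided the loss by `ParallelEnv().world_size` on multi-card runs (and was only ever applied on XPU); the now-unused `to_variable` import goes with it. A workflow that still wants that behavior can scale explicitly. A hedged sketch of such a replacement (the helper name is hypothetical, not a Paddle API):

```python
import paddle.distributed as dist

def scale_loss_manually(loss):
    # Hypothetical user-side replacement for the removed helper: divide the
    # loss by the number of ranks so per-rank gradients end up averaged.
    # Only needed if gradients are not already averaged by the reducer.
    world_size = dist.get_world_size()
    return loss / world_size if world_size > 1 else loss
```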
@@ -275,8 +275,6 @@ def monkey_patch_tensor():
                 # 4: [5000.]
         """
 
-        from paddle.distributed.parallel import scale_loss
-
         if framework._non_static_mode():
             if in_profiler_mode():
                 record_event = profiler.RecordEvent(
@@ -306,30 +304,15 @@ def monkey_patch_tensor():
             if _grad_scalar:
                 # When using amp with Fleet DistributedStrategy, we do loss scaling implicitly.
                 self = _grad_scalar.scale(self)
-            if paddle.is_compiled_with_xpu():
-                # TODO(liuyuhui): Currently only for xpu. Will be removed in the future.
-                scaled_loss = scale_loss(self)
-                if framework.global_var._in_eager_mode_:
-                    core.eager.run_backward(
-                        [scaled_loss], grad_tensor, retain_graph
-                    )
-                else:
-                    core.dygraph_run_backward(
-                        [scaled_loss],
-                        [grad_tensor],
-                        retain_graph,
-                        framework._dygraph_tracer(),
-                    )
+            if framework.global_var._in_eager_mode_:
+                core.eager.run_backward([self], grad_tensor, retain_graph)
             else:
-                if framework.global_var._in_eager_mode_:
-                    core.eager.run_backward([self], grad_tensor, retain_graph)
-                else:
-                    core.dygraph_run_backward(
-                        [self],
-                        [grad_tensor],
-                        retain_graph,
-                        framework._dygraph_tracer(),
-                    )
+                core.dygraph_run_backward(
+                    [self],
+                    [grad_tensor],
+                    retain_graph,
+                    framework._dygraph_tracer(),
+                )
             if in_profiler_mode():
                 record_event.end()
         else:
...