Unverified commit 073f7ced authored by jameszhang, committed by GitHub

[KUNLUN] update xccl lib & use native Reduce in dygraph (#49941)

* update xccl lib & use native Reduce in dygraph

* minor
Parent 5670644c
......
@@ -16,7 +16,7 @@ else()
endif()
set(XPU_XCCL_BASE_URL
"https://klx-sdk-release-public.su.bcebos.com/xccl/release/1.0.6")
"https://klx-sdk-release-public.su.bcebos.com/xccl/release/1.0.7")
if(WITH_AARCH64)
set(XPU_XRE_DIR_NAME "xre-kylin_aarch64")
......
......
@@ -352,41 +352,17 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Reduce(
const phi::DenseTensor& input,
BKCLContext_t comm,
const XPUStream& stream) {
-        phi::DenseTensor output_t;
-        paddle::framework::TensorCopy(*output, platform::XPUPlace(), &output_t);
-        const auto& place = input.place();
-        auto* calc_ctx = static_cast<phi::XPUContext*>(
-            platform::DeviceContextPool::Instance().Get(place));
-        switch (input.dtype()) {
-          case phi::DataType::FLOAT32:
-            calc_ctx->template Alloc<float>(&output_t);
-            break;
-          case phi::DataType::FLOAT16:
-            calc_ctx->template Alloc<float16>(&output_t);
-            break;
-          case phi::DataType::INT32:
-            calc_ctx->template Alloc<int>(&output_t);
-            break;
-          default:
-            VLOG(0) << "Error: type " << input.dtype() << " not supported for "
-                    << GetBackendName();
-            break;
-        }
-        int ret =
-            bkcl_all_reduce(comm,
-                            input.data(),
-                            output_t.data(),
-                            input.numel(),
-                            platform::ToBKCLDataType(
-                                framework::TransToProtoVarType(input.type())),
-                            ToBKCLRedType(opts.reduce_op),
-                            stream);
-        if (rank_ == opts.root_rank) {
-          *output = output_t;
-        }
-        return ret;
+        return bkcl_reduce(comm,
+                           input.data(),
+                           output->data(),
+                           input.numel(),
+                           platform::ToBKCLDataType(
+                               framework::TransToProtoVarType(input.type())),
+                           ToBKCLRedType(opts.reduce_op),
+                           opts.root_rank,
+                           stream);
},
-      CommType::ALLREDUCE,
+      CommType::REDUCE,
sync_op,
use_calc_stream);
}
......
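The diff above replaces an emulated reduce with the library's native one. Previously, `ProcessGroupBKCL::Reduce` ran `bkcl_all_reduce` into a per-rank scratch tensor and copied the result into `*output` only on the root rank; the new code calls `bkcl_reduce` directly, which aggregates into the root's buffer (presumably why the xccl library is bumped to 1.0.7 in the same commit). Below is a minimal sketch of the two call shapes, taken from the signatures visible in the diff; the header path and the `BKCL_FLOAT` / `BKCL_ADD` enum names are assumptions, and buffer management is illustrative only, not Paddle code.

```cpp
#include "xpu/bkcl.h"  // assumed header name; ships with the Kunlun XCCL package

// Before: Reduce emulated with AllReduce. Every rank computes the full
// AllReduce result into a scratch buffer; only the root keeps it.
int EmulatedReduce(BKCLContext_t comm, const float* send, float* scratch,
                   float* recv, size_t count, int rank, int root,
                   XPUStream stream) {
  int ret = bkcl_all_reduce(comm, send, scratch, count, BKCL_FLOAT,
                            BKCL_ADD, stream);
  if (rank == root) {
    // The old code copied scratch into the output tensor here
    // (framework::TensorCopy); non-root ranks discarded the result.
  }
  return ret;
}

// After: native Reduce. The library aggregates directly into the root's
// receive buffer; non-root ranks do no extra allocation or device copy.
int NativeReduce(BKCLContext_t comm, const float* send, float* recv,
                 size_t count, int root, XPUStream stream) {
  return bkcl_reduce(comm, send, recv, count, BKCL_FLOAT, BKCL_ADD,
                     root, stream);
}
```

Besides dropping the per-rank scratch allocation, dtype switch, and root-side copy, the native call lets the task be tagged `CommType::REDUCE` instead of `CommType::ALLREDUCE`, so the task is recorded as the collective actually performed.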