From 9f1f1b0a9b70cebe9ff2f0a213b6f234fac674f5 Mon Sep 17 00:00:00 2001 From: houj04 <35131887+houj04@users.noreply.github.com> Date: Fri, 19 Aug 2022 15:22:19 +0800 Subject: [PATCH] [XPU] c_allreduce support int. update bkcl to 1.0.5. test=kunlun (#45248) --- cmake/external/xpu.cmake | 2 +- paddle/fluid/operators/collective/c_allreduce_sum_op_xpu.cc | 3 ++- paddle/fluid/platform/device/xpu/xpu2_op_list.h | 5 ++++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index bb2e864f5cd..08f6be6de11 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -25,7 +25,7 @@ else() endif() set(XPU_XCCL_BASE_URL - "https://klx-sdk-release-public.su.bcebos.com/xccl/release/1.0.4") + "https://klx-sdk-release-public.su.bcebos.com/xccl/release/1.0.5") if(WITH_AARCH64) set(XPU_XRE_DIR_NAME "xre-kylin_aarch64") diff --git a/paddle/fluid/operators/collective/c_allreduce_sum_op_xpu.cc b/paddle/fluid/operators/collective/c_allreduce_sum_op_xpu.cc index 16b3ff335b6..a4d1c62e821 100644 --- a/paddle/fluid/operators/collective/c_allreduce_sum_op_xpu.cc +++ b/paddle/fluid/operators/collective/c_allreduce_sum_op_xpu.cc @@ -19,4 +19,5 @@ namespace plat = paddle::platform; REGISTER_OP_XPU_KERNEL(c_allreduce_sum, ops::CAllReduceOpXPUKernel, - ops::CAllReduceOpXPUKernel) + ops::CAllReduceOpXPUKernel, + ops::CAllReduceOpXPUKernel) diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index c246924e14b..83f0c21315b 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -80,13 +80,16 @@ XPUOpMap& get_kl2_ops() { pOpKernelType(vartype::INT64, XPUPlace())})}, {"c_allreduce_sum", XPUKernelSet({pOpKernelType(vartype::FP16, XPUPlace()), - pOpKernelType(vartype::FP32, XPUPlace())})}, + pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace())})}, {"c_identity", XPUKernelSet({pOpKernelType(vartype::FP16, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP64, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace())})}, + {"c_sync_calc_stream", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"c_sync_comm_stream", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"cast", -- GitLab