未验证 提交 8fd724a5 编写于 作者: B Baibaifan 提交者: GitHub

add_c_sync_npu_kernel (#32687)

上级 43527a2b
@@ -46,7 +46,7 @@ Call calculation stream synchronization.
 };
 template <typename T>
-class CSyncCalcStreamCudaKernel : public framework::OpKernel<T> {
+class CSyncCalcStreamKernel : public framework::OpKernel<T> {
 public:
 void Compute(const framework::ExecutionContext& ctx) const override {
 #if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && !defined(_WIN32)
@@ -86,5 +86,6 @@ namespace ops = paddle::operators;
 REGISTER_OP_WITHOUT_GRADIENT(c_sync_calc_stream, ops::CSyncCalcStreamOp,
 ops::CSyncCalcStreamOpMaker);
-REGISTER_OP_CUDA_KERNEL(c_sync_calc_stream,
-ops::CSyncCalcStreamCudaKernel<float>);
+REGISTER_OP_CUDA_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel<float>);
+REGISTER_OP_NPU_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel<float>);
@@ -35,7 +35,7 @@ namespace m = paddle::operators::math;
 USE_OP(elementwise_add);
 USE_OP_DEVICE_KERNEL(elementwise_add, NPU);
-USE_NO_KERNEL_OP(c_sync_calc_stream);
+USE_OP_DEVICE_KERNEL(c_sync_calc_stream, NPU);
 template <typename T>
 void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
......
@@ -58,7 +58,7 @@ Call communication stream synchronization.
 };
 template <typename T>
-class CSyncCommStreamCudaKernel : public framework::OpKernel<T> {
+class CSyncCommStreamKernel : public framework::OpKernel<T> {
 public:
 void Compute(const framework::ExecutionContext& ctx) const override {
 auto place = ctx.GetPlace();
@@ -97,5 +97,6 @@ namespace ops = paddle::operators;
 REGISTER_OP_WITHOUT_GRADIENT(c_sync_comm_stream, ops::CSyncCommStreamOp,
 ops::CSyncCommStreamOpMaker);
-REGISTER_OP_CUDA_KERNEL(c_sync_comm_stream,
-ops::CSyncCommStreamCudaKernel<float>);
+REGISTER_OP_CUDA_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel<float>);
+REGISTER_OP_NPU_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel<float>);
@@ -43,7 +43,7 @@ namespace p = paddle::platform;
 namespace m = paddle::operators::math;
 USE_OP(c_broadcast);
-USE_NO_KERNEL_OP(c_sync_comm_stream);
+USE_OP_DEVICE_KERNEL(c_sync_comm_stream, NPU);
 USE_NO_KERNEL_OP(c_gen_hccl_id);
 USE_NO_KERNEL_OP(c_comm_init_hccl);
 USE_OP_DEVICE_KERNEL(c_broadcast, NPU);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册