未验证 提交 4593597d 编写于 作者: B Baibaifan 提交者: GitHub

add_c_sync_npu_kernel (#32687) (#32723)

上级 6a1957e7
......@@ -46,7 +46,7 @@ Call calculation stream synchronization.
};
template <typename T>
class CSyncCalcStreamCudaKernel : public framework::OpKernel<T> {
class CSyncCalcStreamKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && !defined(_WIN32)
......@@ -86,5 +86,6 @@ namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(c_sync_calc_stream, ops::CSyncCalcStreamOp,
ops::CSyncCalcStreamOpMaker);
REGISTER_OP_CUDA_KERNEL(c_sync_calc_stream,
ops::CSyncCalcStreamCudaKernel<float>);
REGISTER_OP_CUDA_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel<float>);
REGISTER_OP_NPU_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel<float>);
......@@ -35,7 +35,7 @@ namespace m = paddle::operators::math;
USE_OP(elementwise_add);
USE_OP_DEVICE_KERNEL(elementwise_add, NPU);
USE_NO_KERNEL_OP(c_sync_calc_stream);
USE_OP_DEVICE_KERNEL(c_sync_calc_stream, NPU);
template <typename T>
void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
......
......@@ -58,7 +58,7 @@ Call communication stream synchronization.
};
template <typename T>
class CSyncCommStreamCudaKernel : public framework::OpKernel<T> {
class CSyncCommStreamKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto place = ctx.GetPlace();
......@@ -97,5 +97,6 @@ namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(c_sync_comm_stream, ops::CSyncCommStreamOp,
ops::CSyncCommStreamOpMaker);
REGISTER_OP_CUDA_KERNEL(c_sync_comm_stream,
ops::CSyncCommStreamCudaKernel<float>);
REGISTER_OP_CUDA_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel<float>);
REGISTER_OP_NPU_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel<float>);
......@@ -43,7 +43,7 @@ namespace p = paddle::platform;
namespace m = paddle::operators::math;
USE_OP(c_broadcast);
USE_NO_KERNEL_OP(c_sync_comm_stream);
USE_OP_DEVICE_KERNEL(c_sync_comm_stream, NPU);
USE_NO_KERNEL_OP(c_gen_hccl_id);
USE_NO_KERNEL_OP(c_comm_init_hccl);
USE_OP_DEVICE_KERNEL(c_broadcast, NPU);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册