Unverified commit eac973d1, authored by LiYuRio, committed by GitHub

forbid backward for comm (#47636)

Parent: ac2a94c7
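Every .cc hunk below makes the same change: the op's hand-written grad maker is deleted and the op is re-registered with REGISTER_OP_WITHOUT_GRADIENT, so building a backward op for a communication op now fails at graph-construction time instead of silently reusing another collective. As a simplified sketch (paraphrased, not the verbatim Paddle macro), registering "without gradient" amounts to plugging EmptyGradOpMaker into both the static-graph (OpDesc) and imperative (OpBase) paths:

// Simplified sketch of Paddle's REGISTER_OP_WITHOUT_GRADIENT, assuming the
// real REGISTER_OPERATOR classifies its trailing arguments by type; not the
// verbatim definition.
#define REGISTER_OP_WITHOUT_GRADIENT(op_type, ...)                    \
  REGISTER_OPERATOR(                                                  \
      op_type,                                                        \
      __VA_ARGS__,                                                    \
      paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, \
      paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>)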
@@ -61,31 +61,13 @@ Scatter tensors from all participators to all participators.
   }
 };
 
-template <typename T>
-class AllToAllOpGradMaker : public framework::SingleGradOpMaker<T> {
- public:
-  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-
- protected:
-  void Apply(GradOpPtr<T> retv) const override {
-    retv->SetType("alltoall");
-    retv->SetInput("X", this->OutputGrad("Out"));
-    retv->SetOutput("Out", this->InputGrad("X"));
-    retv->SetAttrMap(this->Attrs());
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OPERATOR(alltoall,
-                  ops::AllToAllOp,
-                  ops::AllToAllOpMaker,
-                  ops::AllToAllOpGradMaker<paddle::framework::OpDesc>,
-                  ops::AllToAllOpGradMaker<paddle::imperative::OpBase>)
+REGISTER_OP_WITHOUT_GRADIENT(alltoall, ops::AllToAllOp, ops::AllToAllOpMaker)
 
 REGISTER_OP_CPU_KERNEL(alltoall,
                        ops::AllToAllOpCPUKernel<float>,
...
@@ -95,7 +95,7 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(alltoall,
                         ops::AllToAllOpCUDAKernel<float>,
                         ops::AllToAllOpCUDAKernel<double>,
-#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000
                         ops::AllToAllOpCUDAKernel<plat::bfloat16>,
 #endif
                         ops::AllToAllOpCUDAKernel<int>,
...
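The .cu.cc hunks share a second change: the bfloat16 kernels lose the cuDNN half of their guard. These kernels only call NCCL, and ncclBfloat16 exists from NCCL 2.10.0 onward, which is precisely what the version code 21000 encodes; the same one-line change repeats below for c_allgather, c_allreduce_sum, c_broadcast, c_reducescatter, recv_v2, and send_v2. A small check of that encoding (assuming an nccl.h recent enough to define the NCCL_VERSION helper macro):

#include <nccl.h>

// NCCL encodes 2.10.0 as 2*10000 + 10*100 + 0 = 21000, so
// "#if NCCL_VERSION_CODE >= 21000" enables the bfloat16 kernels exactly on
// releases that ship ncclBfloat16; the old CUDNN_VERSION_MIN(8, 1, 0) term
// was irrelevant to a pure NCCL kernel.
static_assert(NCCL_VERSION(2, 10, 0) == 21000,
              "bfloat16 guard targets NCCL >= 2.10.0");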
@@ -63,31 +63,15 @@ reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/us
   }
 };
 
-template <typename T>
-class CAllGatherOpGradMaker : public framework::SingleGradOpMaker<T> {
- public:
-  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-
- protected:
-  void Apply(GradOpPtr<T> retv) const override {
-    retv->SetType("c_reducescatter");
-    retv->SetInput("X", this->OutputGrad("Out"));
-    retv->SetOutput("Out", this->InputGrad("X"));
-    retv->SetAttrMap(this->Attrs());
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OPERATOR(c_allgather,
-                  ops::CAllGatherOp,
-                  ops::CAllGatherOpGradMaker<paddle::framework::OpDesc>,
-                  ops::CAllGatherOpGradMaker<paddle::imperative::OpBase>,
-                  ops::CAllGatherOpMaker);
+REGISTER_OP_WITHOUT_GRADIENT(c_allgather,
+                             ops::CAllGatherOp,
+                             ops::CAllGatherOpMaker);
 
 REGISTER_OP_CPU_KERNEL(c_allgather,
                        ops::CAllGatherOpCPUKernel<float>,
...
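The maker deleted above encoded the adjoint pair of collectives: the gradient of all-gather is reduce-scatter (and the c_reducescatter grad maker removed further down pointed back at c_allgather). A single-process simulation of that duality, hypothetical helper code rather than Paddle API:

#include <vector>

// all_gather: every rank contributes its shard; every rank receives the
// concatenation of all shards.
std::vector<float> all_gather(const std::vector<std::vector<float>>& shards) {
  std::vector<float> out;
  for (const auto& s : shards) out.insert(out.end(), s.begin(), s.end());
  return out;  // replicated on each rank
}

// reduce_scatter: sum the full-length tensors from all ranks, then hand
// rank `rank` its own shard of the sum. Applied to the per-rank output
// gradients of all_gather, this is exactly the gradient of that rank's
// input shard, which is the relation the deleted grad maker expressed.
std::vector<float> reduce_scatter(const std::vector<std::vector<float>>& grads,
                                  int rank, int shard_len) {
  std::vector<float> shard(shard_len, 0.0f);
  for (const auto& g : grads)
    for (int j = 0; j < shard_len; ++j) shard[j] += g[rank * shard_len + j];
  return shard;
}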
@@ -96,7 +96,7 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(c_allgather,
                         ops::CAllGatherOpCUDAKernel<float>,
                         ops::CAllGatherOpCUDAKernel<double>,
-#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000
                         ops::CAllGatherOpCUDAKernel<plat::bfloat16>,
 #endif
                         ops::CAllGatherOpCUDAKernel<int>,
...
@@ -41,13 +41,10 @@ DECLARE_INPLACE_OP_INFERER(AllreduceMaxInplaceInferer, {"X", "Out"});
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OPERATOR(
-    c_allreduce_max,
-    ops::CAllReduceOp,
-    ops::CAllReduceMaxOpMaker,
-    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
-    ops::AllreduceMaxInplaceInferer)
+REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_max,
+                             ops::CAllReduceOp,
+                             ops::CAllReduceMaxOpMaker,
+                             ops::AllreduceMaxInplaceInferer)
 
 REGISTER_OP_CPU_KERNEL(c_allreduce_max,
                        ops::CAllReduceOpCPUKernel<ops::kRedMax, float>,
...
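Note what survives this rewrite: the inplace inferer declared in the context line above (DECLARE_INPLACE_OP_INFERER(AllreduceMaxInplaceInferer, {"X", "Out"})) is still passed to the new macro, so c_allreduce_max may keep writing its result into X's buffer; the min and prod registrations below follow identically. A simplified stand-in for the class that macro generates (illustrative only, not the actual Paddle definition):

#include <string>
#include <unordered_map>

// Stand-in for the generated inferer: a functor listing which input may
// share its buffer with which output. For c_allreduce_max that is
// {"X" -> "Out"}, i.e. the allreduce may run in place.
struct AllreduceMaxInplaceInfererSketch {
  std::unordered_map<std::string, std::string> operator()() const {
    return {{"X", "Out"}};
  }
};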
@@ -41,13 +41,10 @@ DECLARE_INPLACE_OP_INFERER(AllreduceMinInplaceInferer, {"X", "Out"});
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OPERATOR(
-    c_allreduce_min,
-    ops::CAllReduceOp,
-    ops::CAllReduceMinOpMaker,
-    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
-    ops::AllreduceMinInplaceInferer)
+REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_min,
+                             ops::CAllReduceOp,
+                             ops::CAllReduceMinOpMaker,
+                             ops::AllreduceMinInplaceInferer)
 
 REGISTER_OP_CPU_KERNEL(c_allreduce_min,
                        ops::CAllReduceOpCPUKernel<ops::kRedMin, float>,
...
@@ -41,13 +41,10 @@ DECLARE_INPLACE_OP_INFERER(AllreduceProdInplaceInferer, {"X", "Out"});
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OPERATOR(
-    c_allreduce_prod,
-    ops::CAllReduceOp,
-    ops::CAllReduceProdOpMaker,
-    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
-    ops::AllreduceProdInplaceInferer)
+REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_prod,
+                             ops::CAllReduceOp,
+                             ops::CAllReduceProdOpMaker,
+                             ops::AllreduceProdInplaceInferer)
 
 REGISTER_OP_CPU_KERNEL(c_allreduce_prod,
                        ops::CAllReduceOpCPUKernel<ops::kRedProd, float>,
...
@@ -20,7 +20,7 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
     c_allreduce_sum,
     ops::CAllReduceOpCUDAKernel<ops::kRedSum, float>,
-#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000
     ops::CAllReduceOpCUDAKernel<ops::kRedSum, plat::bfloat16>,
 #endif
     ops::CAllReduceOpCUDAKernel<ops::kRedSum, double>,
...
@@ -108,7 +108,7 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(c_broadcast,
                         ops::CBroadcastOpCUDAKernel<float>,
                         ops::CBroadcastOpCUDAKernel<double>,
-#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000
                         ops::CBroadcastOpCUDAKernel<plat::bfloat16>,
 #endif
                         ops::CBroadcastOpCUDAKernel<int>,
...
@@ -66,29 +66,15 @@ Reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/us
   }
 };
 
-template <typename T>
-class CReduceScatterOpGradMaker : public framework::SingleGradOpMaker<T> {
- public:
-  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-
- protected:
-  void Apply(GradOpPtr<T> retv) const override {
-    retv->SetType("c_allgather");
-    retv->SetInput("X", this->OutputGrad("Out"));
-    retv->SetOutput("Out", this->InputGrad("X"));
-    retv->SetAttrMap(this->Attrs());
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OPERATOR(c_reducescatter,
-                  ops::CReduceScatterOp,
-                  ops::CReduceScatterOpMaker);
+REGISTER_OP_WITHOUT_GRADIENT(c_reducescatter,
+                             ops::CReduceScatterOp,
+                             ops::CReduceScatterOpMaker);
 
 REGISTER_OP_CPU_KERNEL(c_reducescatter,
                        ops::CReduceScatterOpCPUKernel<float>,
...
@@ -84,7 +84,7 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(c_reducescatter,
                         ops::CReduceScatterOpCUDAKernel<float>,
                         ops::CReduceScatterOpCUDAKernel<double>,
-#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000
                         ops::CReduceScatterOpCUDAKernel<plat::bfloat16>,
 #endif
                         ops::CReduceScatterOpCUDAKernel<int>,
...
@@ -236,7 +236,7 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(recv_v2,
                         ops::RecvOpV2CUDAKernel<float>,
                         ops::RecvOpV2CUDAKernel<double>,
-#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000
                         ops::RecvOpV2CUDAKernel<plat::bfloat16>,
 #endif
                         ops::RecvOpV2CUDAKernel<int>,
...
@@ -221,7 +221,7 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(send_v2,
                         ops::SendOpV2CUDAKernel<float>,
                         ops::SendOpV2CUDAKernel<double>,
-#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000
                         ops::SendOpV2CUDAKernel<plat::bfloat16>,
 #endif
                         ops::SendOpV2CUDAKernel<int>,
...