Unverified commit eac973d1, authored by LiYuRio, committed by GitHub

forbid backward for comm (#47636)

Parent ac2a94c7
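Note on the first pattern in this diff: for the ops that defined a real backward (alltoall, c_allgather, c_reducescatter) the hand-written `*GradMaker` classes are deleted, and for the allreduce variants the explicit `EmptyGradOpMaker` pair collapses into the shorthand macro; in both cases registration becomes `REGISTER_OP_WITHOUT_GRADIENT`. A minimal sketch of what that macro is assumed to expand to (modeled on `paddle/fluid/framework/op_registry.h`, not quoted from it):

```cpp
// Sketch, not the verbatim Paddle source. Assumption: the macro forwards
// its arguments to REGISTER_OPERATOR and appends an EmptyGradOpMaker for
// both static-graph (OpDesc) and imperative (OpBase) modes, so requesting
// this op's backward yields no grad op at all.

#define REGISTER_OPERATOR(op_type, ...) /* stub so the sketch compiles */

#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, ...)               \
  REGISTER_OPERATOR(                                                       \
      op_type,                                                             \
      op_class,                                                            \
      __VA_ARGS__,                                                         \
      paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,      \
      paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>)

// Before: alltoall advertised itself as its own gradient.
//   REGISTER_OPERATOR(alltoall, ops::AllToAllOp, ops::AllToAllOpMaker,
//                     ops::AllToAllOpGradMaker<paddle::framework::OpDesc>,
//                     ops::AllToAllOpGradMaker<paddle::imperative::OpBase>)
// After: no gradient; extra registrars such as the allreduce inplace
// inferers still pass through the variadic arguments.
//   REGISTER_OP_WITHOUT_GRADIENT(alltoall, ops::AllToAllOp, ops::AllToAllOpMaker)
```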
@@ -61,31 +61,13 @@ Scatter tensors from all participators to all participators.
   }
 };
 
-template <typename T>
-class AllToAllOpGradMaker : public framework::SingleGradOpMaker<T> {
- public:
-  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-
- protected:
-  void Apply(GradOpPtr<T> retv) const override {
-    retv->SetType("alltoall");
-    retv->SetInput("X", this->OutputGrad("Out"));
-    retv->SetOutput("Out", this->InputGrad("X"));
-    retv->SetAttrMap(this->Attrs());
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OPERATOR(alltoall,
-                  ops::AllToAllOp,
-                  ops::AllToAllOpMaker,
-                  ops::AllToAllOpGradMaker<paddle::framework::OpDesc>,
-                  ops::AllToAllOpGradMaker<paddle::imperative::OpBase>)
+REGISTER_OP_WITHOUT_GRADIENT(alltoall, ops::AllToAllOp, ops::AllToAllOpMaker)
 
 REGISTER_OP_CPU_KERNEL(alltoall,
                        ops::AllToAllOpCPUKernel<float>,
......
@@ -95,7 +95,7 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(alltoall,
                         ops::AllToAllOpCUDAKernel<float>,
                         ops::AllToAllOpCUDAKernel<double>,
-#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000
                         ops::AllToAllOpCUDAKernel<plat::bfloat16>,
 #endif
                         ops::AllToAllOpCUDAKernel<int>,
......
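The second, independent pattern shows up in every CUDA kernel registration below: the guard around the bfloat16 kernels drops the cuDNN 8.1 requirement and keeps only the NCCL check. That matches where the dependency actually lives, since bf16 collectives need NCCL's `ncclBfloat16` datatype and use no cuDNN. A standalone sketch of what the new guard means, assuming NCCL's post-2.9 version encoding (`MAJOR * 10000 + MINOR * 100 + PATCH`, so 21000 is 2.10.0):

```cpp
// Standalone sketch; assumes nccl.h from NCCL >= 2.10 conventions.
// 21000 == NCCL_VERSION(2, 10, 0), the first release defining
// ncclBfloat16, so the new guard compiles the bf16 kernels exactly when
// the library can run them, independent of the installed cuDNN.
#include <nccl.h>

#if NCCL_VERSION_CODE >= 21000
static_assert(NCCL_VERSION_CODE >= NCCL_VERSION(2, 10, 0),
              "bf16 collectives assume NCCL >= 2.10");
inline constexpr ncclDataType_t kBf16 = ncclBfloat16;  // exists from 2.10
#endif
```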
@@ -63,31 +63,15 @@ reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/us
   }
 };
 
-template <typename T>
-class CAllGatherOpGradMaker : public framework::SingleGradOpMaker<T> {
- public:
-  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-
- protected:
-  void Apply(GradOpPtr<T> retv) const override {
-    retv->SetType("c_reducescatter");
-    retv->SetInput("X", this->OutputGrad("Out"));
-    retv->SetOutput("Out", this->InputGrad("X"));
-    retv->SetAttrMap(this->Attrs());
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OPERATOR(c_allgather,
-                  ops::CAllGatherOp,
-                  ops::CAllGatherOpGradMaker<paddle::framework::OpDesc>,
-                  ops::CAllGatherOpGradMaker<paddle::imperative::OpBase>,
-                  ops::CAllGatherOpMaker);
+REGISTER_OP_WITHOUT_GRADIENT(c_allgather,
+                             ops::CAllGatherOp,
+                             ops::CAllGatherOpMaker);
 
 REGISTER_OP_CPU_KERNEL(c_allgather,
                        ops::CAllGatherOpCPUKernel<float>,
......
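For context on what the deleted makers encoded: these collectives are linear, and their gradients are each other's adjoints. All-to-all is its own adjoint (hence `retv->SetType("alltoall")` above), while the gradient of all-gather is reduce-scatter, exactly the `c_reducescatter` type the deleted `CAllGatherOpGradMaker` produced (the `c_reducescatter` hunk further down removes the mirror-image maker). A toy single-process check of that adjoint identity, with ranks simulated in plain C++ (none of these helper names are Paddle APIs):

```cpp
// Verifies <gy, AllGather(x)> == <ReduceScatter(gy), x>, the identity
// that justified pairing c_allgather with c_reducescatter as gradients.
#include <cassert>
#include <cmath>
#include <numeric>
#include <vector>

using Shard = std::vector<double>;

// Each simulated "rank" owns one shard; all-gather hands every rank the
// concatenation of all shards.
std::vector<Shard> AllGather(const std::vector<Shard>& shards) {
  Shard full;
  for (const auto& s : shards) full.insert(full.end(), s.begin(), s.end());
  return std::vector<Shard>(shards.size(), full);
}

// Reduce-scatter sums the per-rank tensors and hands each rank its slice.
std::vector<Shard> ReduceScatter(const std::vector<Shard>& fulls) {
  Shard total(fulls[0].size(), 0.0);
  for (const auto& f : fulls)
    for (size_t i = 0; i < f.size(); ++i) total[i] += f[i];
  const size_t n = fulls.size(), len = total.size() / n;
  std::vector<Shard> out;
  for (size_t r = 0; r < n; ++r)
    out.emplace_back(total.begin() + r * len, total.begin() + (r + 1) * len);
  return out;
}

double Dot(const Shard& a, const Shard& b) {
  return std::inner_product(a.begin(), a.end(), b.begin(), 0.0);
}

int main() {
  std::vector<Shard> x = {{1, 2, 3}, {4, 5, 6}};     // per-rank inputs
  std::vector<Shard> gy = {{1, 0, 2, 0, 1, 0},       // per-rank upstream
                           {0, 3, 0, 1, 0, 2}};      // gradients

  auto y = AllGather(x);
  auto gx = ReduceScatter(gy);

  double lhs = 0, rhs = 0;
  for (size_t r = 0; r < 2; ++r) lhs += Dot(gy[r], y[r]);
  for (size_t r = 0; r < 2; ++r) rhs += Dot(gx[r], x[r]);
  assert(std::fabs(lhs - rhs) < 1e-12);  // adjoint identity holds
  return 0;
}
```

After this commit the ops simply no longer claim that pairing; autograd through them is forbidden rather than wired to the adjoint op.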
@@ -96,7 +96,7 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(c_allgather,
                         ops::CAllGatherOpCUDAKernel<float>,
                         ops::CAllGatherOpCUDAKernel<double>,
-#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000
                         ops::CAllGatherOpCUDAKernel<plat::bfloat16>,
 #endif
                         ops::CAllGatherOpCUDAKernel<int>,
......
@@ -41,13 +41,10 @@ DECLARE_INPLACE_OP_INFERER(AllreduceMaxInplaceInferer, {"X", "Out"});
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OPERATOR(
-    c_allreduce_max,
-    ops::CAllReduceOp,
-    ops::CAllReduceMaxOpMaker,
-    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
-    ops::AllreduceMaxInplaceInferer)
+REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_max,
+                             ops::CAllReduceOp,
+                             ops::CAllReduceMaxOpMaker,
+                             ops::AllreduceMaxInplaceInferer)
 
 REGISTER_OP_CPU_KERNEL(c_allreduce_max,
                        ops::CAllReduceOpCPUKernel<ops::kRedMax, float>,
......
@@ -41,13 +41,10 @@ DECLARE_INPLACE_OP_INFERER(AllreduceMinInplaceInferer, {"X", "Out"});
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OPERATOR(
-    c_allreduce_min,
-    ops::CAllReduceOp,
-    ops::CAllReduceMinOpMaker,
-    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
-    ops::AllreduceMinInplaceInferer)
+REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_min,
+                             ops::CAllReduceOp,
+                             ops::CAllReduceMinOpMaker,
+                             ops::AllreduceMinInplaceInferer)
 
 REGISTER_OP_CPU_KERNEL(c_allreduce_min,
                        ops::CAllReduceOpCPUKernel<ops::kRedMin, float>,
......
@@ -41,13 +41,10 @@ DECLARE_INPLACE_OP_INFERER(AllreduceProdInplaceInferer, {"X", "Out"});
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OPERATOR(
-    c_allreduce_prod,
-    ops::CAllReduceOp,
-    ops::CAllReduceProdOpMaker,
-    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
-    ops::AllreduceProdInplaceInferer)
+REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_prod,
+                             ops::CAllReduceOp,
+                             ops::CAllReduceProdOpMaker,
+                             ops::AllreduceProdInplaceInferer)
 
 REGISTER_OP_CPU_KERNEL(c_allreduce_prod,
                        ops::CAllReduceOpCPUKernel<ops::kRedProd, float>,
......
@@ -20,7 +20,7 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
     c_allreduce_sum,
     ops::CAllReduceOpCUDAKernel<ops::kRedSum, float>,
-#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000
     ops::CAllReduceOpCUDAKernel<ops::kRedSum, plat::bfloat16>,
 #endif
     ops::CAllReduceOpCUDAKernel<ops::kRedSum, double>,
......
@@ -108,7 +108,7 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(c_broadcast,
                         ops::CBroadcastOpCUDAKernel<float>,
                         ops::CBroadcastOpCUDAKernel<double>,
-#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000
                         ops::CBroadcastOpCUDAKernel<plat::bfloat16>,
 #endif
                         ops::CBroadcastOpCUDAKernel<int>,
......
@@ -66,29 +66,15 @@ Reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/us
   }
 };
 
-template <typename T>
-class CReduceScatterOpGradMaker : public framework::SingleGradOpMaker<T> {
- public:
-  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-
- protected:
-  void Apply(GradOpPtr<T> retv) const override {
-    retv->SetType("c_allgather");
-    retv->SetInput("X", this->OutputGrad("Out"));
-    retv->SetOutput("Out", this->InputGrad("X"));
-    retv->SetAttrMap(this->Attrs());
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OPERATOR(c_reducescatter,
-                  ops::CReduceScatterOp,
-                  ops::CReduceScatterOpMaker);
+REGISTER_OP_WITHOUT_GRADIENT(c_reducescatter,
+                             ops::CReduceScatterOp,
+                             ops::CReduceScatterOpMaker);
 
 REGISTER_OP_CPU_KERNEL(c_reducescatter,
                        ops::CReduceScatterOpCPUKernel<float>,
......
@@ -84,7 +84,7 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(c_reducescatter,
                         ops::CReduceScatterOpCUDAKernel<float>,
                         ops::CReduceScatterOpCUDAKernel<double>,
-#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000
                         ops::CReduceScatterOpCUDAKernel<plat::bfloat16>,
 #endif
                         ops::CReduceScatterOpCUDAKernel<int>,
......
@@ -236,7 +236,7 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(recv_v2,
                         ops::RecvOpV2CUDAKernel<float>,
                         ops::RecvOpV2CUDAKernel<double>,
-#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000
                         ops::RecvOpV2CUDAKernel<plat::bfloat16>,
 #endif
                         ops::RecvOpV2CUDAKernel<int>,
......
@@ -221,7 +221,7 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(send_v2,
                         ops::SendOpV2CUDAKernel<float>,
                         ops::SendOpV2CUDAKernel<double>,
-#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000
                         ops::SendOpV2CUDAKernel<plat::bfloat16>,
 #endif
                         ops::SendOpV2CUDAKernel<int>,
......