未验证 提交 d05b73e4 编写于 作者: H huangjiyi 提交者: GitHub

register fluid kernels to phi [part 2] (#52044)

* update bipartite_match

* update

* fix bug

* fix test

* fix bug

* fix Kunlun-KP-Build

* Revert "fix Kunlun-KP-Build"

This reverts commit ceab63cc23079fd6839c826bb52db893fb056355.

* update
上级 ffff133b
...@@ -174,7 +174,6 @@ class BprLossGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -174,7 +174,6 @@ class BprLossGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
using CPUCtx = phi::CPUContext;
REGISTER_OPERATOR(bpr_loss, REGISTER_OPERATOR(bpr_loss,
ops::BprLossOp, ops::BprLossOp,
...@@ -182,9 +181,12 @@ REGISTER_OPERATOR(bpr_loss, ...@@ -182,9 +181,12 @@ REGISTER_OPERATOR(bpr_loss,
ops::BprLossGradMaker<paddle::framework::OpDesc>, ops::BprLossGradMaker<paddle::framework::OpDesc>,
ops::BprLossGradMaker<paddle::imperative::OpBase>); ops::BprLossGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(bpr_loss_grad, ops::BprLossGradientOp); REGISTER_OPERATOR(bpr_loss_grad, ops::BprLossGradientOp);
REGISTER_OP_CPU_KERNEL(bpr_loss,
ops::BprLossOpKernel<CPUCtx, float>, PD_REGISTER_STRUCT_KERNEL(
ops::BprLossOpKernel<CPUCtx, double>); bpr_loss, CPU, ALL_LAYOUT, ops::BprLossOpKernel, float, double) {}
REGISTER_OP_CPU_KERNEL(bpr_loss_grad, PD_REGISTER_STRUCT_KERNEL(bpr_loss_grad,
ops::BprLossGradientOpKernel<CPUCtx, float>, CPU,
ops::BprLossGradientOpKernel<CPUCtx, double>); ALL_LAYOUT,
ops::BprLossGradientOpKernel,
float,
double) {}
...@@ -35,7 +35,7 @@ struct TolerableValue { ...@@ -35,7 +35,7 @@ struct TolerableValue {
} }
}; };
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class BprLossOpKernel : public framework::OpKernel<T> { class BprLossOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -83,7 +83,7 @@ class BprLossOpKernel : public framework::OpKernel<T> { ...@@ -83,7 +83,7 @@ class BprLossOpKernel : public framework::OpKernel<T> {
} }
}; };
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class BprLossGradientOpKernel : public framework::OpKernel<T> { class BprLossGradientOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
......
...@@ -73,12 +73,15 @@ REGISTER_OP_WITHOUT_GRADIENT(c_allgather, ...@@ -73,12 +73,15 @@ REGISTER_OP_WITHOUT_GRADIENT(c_allgather,
ops::CAllGatherOp, ops::CAllGatherOp,
ops::CAllGatherOpMaker); ops::CAllGatherOpMaker);
REGISTER_OP_CPU_KERNEL(c_allgather, PD_REGISTER_STRUCT_KERNEL(c_allgather,
ops::CAllGatherOpCPUKernel<float>, CPU,
ops::CAllGatherOpCPUKernel<double>, ALL_LAYOUT,
ops::CAllGatherOpCPUKernel<int>, ops::CAllGatherOpCPUKernel,
ops::CAllGatherOpCPUKernel<int64_t>, float,
ops::CAllGatherOpCPUKernel<uint8_t>, double,
ops::CAllGatherOpCPUKernel<int8_t>, int,
ops::CAllGatherOpCPUKernel<bool>, int8_t,
ops::CAllGatherOpCPUKernel<plat::float16>); int64_t,
uint8_t,
bool,
plat::float16) {}
...@@ -25,7 +25,7 @@ limitations under the License. */ ...@@ -25,7 +25,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T> template <typename T, typename DeviceContext>
class CAllGatherOpCUDAKernel : public framework::OpKernel<T> { class CAllGatherOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -93,15 +93,19 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel<T> { ...@@ -93,15 +93,19 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(c_allgather, PD_REGISTER_STRUCT_KERNEL(c_allgather,
ops::CAllGatherOpCUDAKernel<float>, GPU,
ops::CAllGatherOpCUDAKernel<double>, ALL_LAYOUT,
ops::CAllGatherOpCUDAKernel,
float,
double,
#if NCCL_VERSION_CODE >= 21000 #if NCCL_VERSION_CODE >= 21000
ops::CAllGatherOpCUDAKernel<plat::bfloat16>, plat::bfloat16,
#endif #endif
ops::CAllGatherOpCUDAKernel<int>, int,
ops::CAllGatherOpCUDAKernel<uint8_t>, uint8_t,
ops::CAllGatherOpCUDAKernel<int8_t>, int8_t,
ops::CAllGatherOpCUDAKernel<int64_t>, int64_t,
ops::CAllGatherOpCUDAKernel<bool>, bool,
ops::CAllGatherOpCUDAKernel<plat::float16>); plat::float16) {
}
...@@ -32,7 +32,7 @@ limitations under the License. */ ...@@ -32,7 +32,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T> template <typename T, typename DeviceContext>
class CAllGatherOpCPUKernel : public framework::OpKernel<T> { class CAllGatherOpCPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
......
...@@ -35,6 +35,8 @@ class CAllReduceMaxOpMaker : public CAllReduceOpMaker { ...@@ -35,6 +35,8 @@ class CAllReduceMaxOpMaker : public CAllReduceOpMaker {
DECLARE_INPLACE_OP_INFERER(AllreduceMaxInplaceInferer, {"X", "Out"}); DECLARE_INPLACE_OP_INFERER(AllreduceMaxInplaceInferer, {"X", "Out"});
DEFINE_C_ALLREDUCE_CPU_KERNEL(CAllReduceMax, kRedMax)
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -45,10 +47,12 @@ REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_max, ...@@ -45,10 +47,12 @@ REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_max,
ops::CAllReduceOp, ops::CAllReduceOp,
ops::CAllReduceMaxOpMaker, ops::CAllReduceMaxOpMaker,
ops::AllreduceMaxInplaceInferer) ops::AllreduceMaxInplaceInferer)
PD_REGISTER_STRUCT_KERNEL(c_allreduce_max,
REGISTER_OP_CPU_KERNEL(c_allreduce_max, CPU,
ops::CAllReduceOpCPUKernel<ops::kRedMax, float>, ALL_LAYOUT,
ops::CAllReduceOpCPUKernel<ops::kRedMax, double>, ops::CAllReduceMaxCPUKernel,
ops::CAllReduceOpCPUKernel<ops::kRedMax, int>, float,
ops::CAllReduceOpCPUKernel<ops::kRedMax, int64_t>, double,
ops::CAllReduceOpCPUKernel<ops::kRedMax, plat::float16>); int,
int64_t,
plat::float16) {}
...@@ -14,13 +14,21 @@ limitations under the License. */ ...@@ -14,13 +14,21 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/c_allreduce_op.h" #include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace paddle {
namespace operators {
DEFINE_C_ALLREDUCE_CUDA_KERNEL(CAllReduceMax, kRedMax)
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( PD_REGISTER_STRUCT_KERNEL(c_allreduce_max,
c_allreduce_max, GPU,
ops::CAllReduceOpCUDAKernel<ops::kRedMax, float>, ALL_LAYOUT,
ops::CAllReduceOpCUDAKernel<ops::kRedMax, double>, ops::CAllReduceMaxCUDAKernel,
ops::CAllReduceOpCUDAKernel<ops::kRedMax, int>, float,
ops::CAllReduceOpCUDAKernel<ops::kRedMax, int64_t>, double,
ops::CAllReduceOpCUDAKernel<ops::kRedMax, plat::float16>) int,
int64_t,
plat::float16) {}
...@@ -35,6 +35,8 @@ class CAllReduceMinOpMaker : public CAllReduceOpMaker { ...@@ -35,6 +35,8 @@ class CAllReduceMinOpMaker : public CAllReduceOpMaker {
DECLARE_INPLACE_OP_INFERER(AllreduceMinInplaceInferer, {"X", "Out"}); DECLARE_INPLACE_OP_INFERER(AllreduceMinInplaceInferer, {"X", "Out"});
DEFINE_C_ALLREDUCE_CPU_KERNEL(CAllReduceMin, kRedMin)
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -46,9 +48,12 @@ REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_min, ...@@ -46,9 +48,12 @@ REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_min,
ops::CAllReduceMinOpMaker, ops::CAllReduceMinOpMaker,
ops::AllreduceMinInplaceInferer) ops::AllreduceMinInplaceInferer)
REGISTER_OP_CPU_KERNEL(c_allreduce_min, PD_REGISTER_STRUCT_KERNEL(c_allreduce_min,
ops::CAllReduceOpCPUKernel<ops::kRedMin, float>, CPU,
ops::CAllReduceOpCPUKernel<ops::kRedMin, double>, ALL_LAYOUT,
ops::CAllReduceOpCPUKernel<ops::kRedMin, int>, ops::CAllReduceMinCPUKernel,
ops::CAllReduceOpCPUKernel<ops::kRedMin, int64_t>, float,
ops::CAllReduceOpCPUKernel<ops::kRedMin, plat::float16>); double,
int,
int64_t,
plat::float16) {}
...@@ -14,13 +14,21 @@ limitations under the License. */ ...@@ -14,13 +14,21 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/c_allreduce_op.h" #include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace paddle {
namespace operators {
DEFINE_C_ALLREDUCE_CUDA_KERNEL(CAllReduceMin, kRedMin)
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( PD_REGISTER_STRUCT_KERNEL(c_allreduce_min,
c_allreduce_min, GPU,
ops::CAllReduceOpCUDAKernel<ops::kRedMin, float>, ALL_LAYOUT,
ops::CAllReduceOpCUDAKernel<ops::kRedMin, double>, ops::CAllReduceMinCUDAKernel,
ops::CAllReduceOpCUDAKernel<ops::kRedMin, int>, float,
ops::CAllReduceOpCUDAKernel<ops::kRedMin, int64_t>, double,
ops::CAllReduceOpCUDAKernel<ops::kRedMin, plat::float16>) int,
int64_t,
plat::float16) {}
...@@ -148,6 +148,10 @@ class CAllReduceOpCPUKernel : public framework::OpKernel<T> { ...@@ -148,6 +148,10 @@ class CAllReduceOpCPUKernel : public framework::OpKernel<T> {
} }
}; };
// Generates a class template named `<op_name>CPUKernel`, parameterized on
// (T, DeviceContext), that publicly inherits from
// CAllReduceOpCPUKernel<red_type, T> and adds no members of its own.
//
//   op_name  - identifier prefix; `CPUKernel` is token-pasted onto it.
//   red_type - reduce-type value forwarded as the base's first template
//              argument.
//
// DeviceContext is accepted but not forwarded to the base class.
// NOTE(review): presumably the extra parameter exists so the generated class
// matches the template shape PD_REGISTER_STRUCT_KERNEL expects - confirm
// against that macro's definition.
#define DEFINE_C_ALLREDUCE_CPU_KERNEL(op_name, red_type) \
template <typename T, typename DeviceContext> \
class op_name##CPUKernel : public CAllReduceOpCPUKernel<red_type, T> {};
#if defined(PADDLE_WITH_ASCEND_CL) #if defined(PADDLE_WITH_ASCEND_CL)
// return true if found_nan or return false; // return true if found_nan or return false;
inline bool ContainsNan(const paddle::platform::NPUDeviceContext& dev_ctx, inline bool ContainsNan(const paddle::platform::NPUDeviceContext& dev_ctx,
...@@ -527,6 +531,10 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> { ...@@ -527,6 +531,10 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
} }
}; };
// Generates a class template named `<op_name>CUDAKernel`, parameterized on
// (T, DeviceContext), that publicly inherits from
// CAllReduceOpCUDAKernel<red_type, T> and adds no members of its own.
//
//   op_name  - identifier prefix; `CUDAKernel` is token-pasted onto it.
//   red_type - reduce-type value forwarded as the base's first template
//              argument.
//
// DeviceContext is accepted but not forwarded to the base class.
// NOTE(review): presumably the extra parameter exists so the generated class
// matches the template shape PD_REGISTER_STRUCT_KERNEL expects - confirm
// against that macro's definition.
#define DEFINE_C_ALLREDUCE_CUDA_KERNEL(op_name, red_type) \
template <typename T, typename DeviceContext> \
class op_name##CUDAKernel : public CAllReduceOpCUDAKernel<red_type, T> {};
template <ReduceType red_type, typename T> template <ReduceType red_type, typename T>
class CAllReduceOpMLUKernel : public framework::OpKernel<T> { class CAllReduceOpMLUKernel : public framework::OpKernel<T> {
public: public:
......
...@@ -35,6 +35,8 @@ class CAllReduceProdOpMaker : public CAllReduceOpMaker { ...@@ -35,6 +35,8 @@ class CAllReduceProdOpMaker : public CAllReduceOpMaker {
DECLARE_INPLACE_OP_INFERER(AllreduceProdInplaceInferer, {"X", "Out"}); DECLARE_INPLACE_OP_INFERER(AllreduceProdInplaceInferer, {"X", "Out"});
DEFINE_C_ALLREDUCE_CPU_KERNEL(CAllReduceProd, kRedProd)
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -46,9 +48,12 @@ REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_prod, ...@@ -46,9 +48,12 @@ REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_prod,
ops::CAllReduceProdOpMaker, ops::CAllReduceProdOpMaker,
ops::AllreduceProdInplaceInferer) ops::AllreduceProdInplaceInferer)
REGISTER_OP_CPU_KERNEL(c_allreduce_prod, PD_REGISTER_STRUCT_KERNEL(c_allreduce_prod,
ops::CAllReduceOpCPUKernel<ops::kRedProd, float>, CPU,
ops::CAllReduceOpCPUKernel<ops::kRedProd, double>, ALL_LAYOUT,
ops::CAllReduceOpCPUKernel<ops::kRedProd, int>, ops::CAllReduceProdCPUKernel,
ops::CAllReduceOpCPUKernel<ops::kRedProd, int64_t>, float,
ops::CAllReduceOpCPUKernel<ops::kRedProd, plat::float16>) double,
int,
int64_t,
plat::float16) {}
...@@ -14,13 +14,21 @@ limitations under the License. */ ...@@ -14,13 +14,21 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/c_allreduce_op.h" #include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace paddle {
namespace operators {
DEFINE_C_ALLREDUCE_CUDA_KERNEL(CAllReduceProd, kRedProd)
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( PD_REGISTER_STRUCT_KERNEL(c_allreduce_prod,
c_allreduce_prod, GPU,
ops::CAllReduceOpCUDAKernel<ops::kRedProd, float>, ALL_LAYOUT,
ops::CAllReduceOpCUDAKernel<ops::kRedProd, double>, ops::CAllReduceProdCUDAKernel,
ops::CAllReduceOpCUDAKernel<ops::kRedProd, int>, float,
ops::CAllReduceOpCUDAKernel<ops::kRedProd, int64_t>, double,
ops::CAllReduceOpCUDAKernel<ops::kRedProd, plat::float16>) int,
int64_t,
plat::float16) {}
...@@ -56,6 +56,8 @@ class CAllReduceSumOpMaker : public CAllReduceOpMaker { ...@@ -56,6 +56,8 @@ class CAllReduceSumOpMaker : public CAllReduceOpMaker {
DECLARE_INPLACE_OP_INFERER(AllreduceSumInplaceInferer, {"X", "Out"}); DECLARE_INPLACE_OP_INFERER(AllreduceSumInplaceInferer, {"X", "Out"});
DEFINE_C_ALLREDUCE_CPU_KERNEL(CAllReduceSum, kRedSum)
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -67,9 +69,12 @@ REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_sum, ...@@ -67,9 +69,12 @@ REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_sum,
ops::CAllReduceSumOpMaker, ops::CAllReduceSumOpMaker,
ops::AllreduceSumInplaceInferer) ops::AllreduceSumInplaceInferer)
REGISTER_OP_CPU_KERNEL(c_allreduce_sum, PD_REGISTER_STRUCT_KERNEL(c_allreduce_sum,
ops::CAllReduceOpCPUKernel<ops::kRedSum, float>, CPU,
ops::CAllReduceOpCPUKernel<ops::kRedSum, double>, ALL_LAYOUT,
ops::CAllReduceOpCPUKernel<ops::kRedSum, int>, ops::CAllReduceSumCPUKernel,
ops::CAllReduceOpCPUKernel<ops::kRedSum, int64_t>, float,
ops::CAllReduceOpCPUKernel<ops::kRedSum, plat::float16>) double,
int,
int64_t,
plat::float16) {}
...@@ -14,16 +14,25 @@ limitations under the License. */ ...@@ -14,16 +14,25 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/c_allreduce_op.h" #include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace paddle {
namespace operators {
DEFINE_C_ALLREDUCE_CUDA_KERNEL(CAllReduceSum, kRedSum)
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( PD_REGISTER_STRUCT_KERNEL(c_allreduce_sum,
c_allreduce_sum, GPU,
ops::CAllReduceOpCUDAKernel<ops::kRedSum, float>, ALL_LAYOUT,
ops::CAllReduceSumCUDAKernel,
float,
#if NCCL_VERSION_CODE >= 21000 #if NCCL_VERSION_CODE >= 21000
ops::CAllReduceOpCUDAKernel<ops::kRedSum, plat::bfloat16>, plat::bfloat16,
#endif #endif
ops::CAllReduceOpCUDAKernel<ops::kRedSum, double>, double,
ops::CAllReduceOpCUDAKernel<ops::kRedSum, int>, int,
ops::CAllReduceOpCUDAKernel<ops::kRedSum, int64_t>, int64_t,
ops::CAllReduceOpCUDAKernel<ops::kRedSum, plat::float16>) plat::float16) {
}
...@@ -64,7 +64,7 @@ bool DistPairDescend(std::tuple<int, int, T> pair1, ...@@ -64,7 +64,7 @@ bool DistPairDescend(std::tuple<int, int, T> pair1,
return std::get<2>(pair1) > std::get<2>(pair2); return std::get<2>(pair1) > std::get<2>(pair2);
} }
template <typename T> template <typename T, typename DeviceContext>
class BipartiteMatchKernel : public framework::OpKernel<T> { class BipartiteMatchKernel : public framework::OpKernel<T> {
public: public:
// The match_indices must be initialized to -1 at first. // The match_indices must be initialized to -1 at first.
...@@ -318,6 +318,10 @@ REGISTER_OPERATOR( ...@@ -318,6 +318,10 @@ REGISTER_OPERATOR(
ops::BipartiteMatchOpMaker, ops::BipartiteMatchOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(bipartite_match,
ops::BipartiteMatchKernel<float>, PD_REGISTER_STRUCT_KERNEL(bipartite_match,
ops::BipartiteMatchKernel<double>); CPU,
ALL_LAYOUT,
ops::BipartiteMatchKernel,
float,
double) {}
...@@ -104,6 +104,6 @@ REGISTER_OPERATOR( ...@@ -104,6 +104,6 @@ REGISTER_OPERATOR(
ops::BoxClipOpMaker, ops::BoxClipOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(box_clip,
ops::BoxClipKernel<phi::CPUContext, float>, PD_REGISTER_STRUCT_KERNEL(
ops::BoxClipKernel<phi::CPUContext, double>); box_clip, CPU, ALL_LAYOUT, ops::BoxClipKernel, float, double) {}
...@@ -44,7 +44,7 @@ static __global__ void GPUBoxClip(const T *input, ...@@ -44,7 +44,7 @@ static __global__ void GPUBoxClip(const T *input,
} }
} }
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class GPUBoxClipKernel : public framework::OpKernel<T> { class GPUBoxClipKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
...@@ -74,6 +74,6 @@ class GPUBoxClipKernel : public framework::OpKernel<T> { ...@@ -74,6 +74,6 @@ class GPUBoxClipKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(box_clip,
ops::GPUBoxClipKernel<phi::GPUContext, float>, PD_REGISTER_STRUCT_KERNEL(
ops::GPUBoxClipKernel<phi::GPUContext, double>); box_clip, GPU, ALL_LAYOUT, ops::GPUBoxClipKernel, float, double) {}
...@@ -19,7 +19,7 @@ limitations under the License. */ ...@@ -19,7 +19,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class BoxClipKernel : public framework::OpKernel<T> { class BoxClipKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
......
...@@ -225,6 +225,9 @@ REGISTER_OPERATOR( ...@@ -225,6 +225,9 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(box_decoder_and_assign, PD_REGISTER_STRUCT_KERNEL(box_decoder_and_assign,
ops::BoxDecoderAndAssignKernel<phi::CPUContext, float>, CPU,
ops::BoxDecoderAndAssignKernel<phi::CPUContext, double>); ALL_LAYOUT,
ops::BoxDecoderAndAssignKernel,
float,
double) {}
...@@ -95,7 +95,7 @@ __global__ void AssignBoxKernel(const T* prior_box_data, ...@@ -95,7 +95,7 @@ __global__ void AssignBoxKernel(const T* prior_box_data,
} }
} }
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class BoxDecoderAndAssignCUDAKernel : public framework::OpKernel<T> { class BoxDecoderAndAssignCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -150,7 +150,10 @@ class BoxDecoderAndAssignCUDAKernel : public framework::OpKernel<T> { ...@@ -150,7 +150,10 @@ class BoxDecoderAndAssignCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
box_decoder_and_assign, PD_REGISTER_STRUCT_KERNEL(box_decoder_and_assign,
ops::BoxDecoderAndAssignCUDAKernel<phi::GPUContext, float>, GPU,
ops::BoxDecoderAndAssignCUDAKernel<phi::GPUContext, double>); ALL_LAYOUT,
ops::BoxDecoderAndAssignCUDAKernel,
float,
double) {}
...@@ -20,7 +20,7 @@ limitations under the License. */ ...@@ -20,7 +20,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class BoxDecoderAndAssignKernel : public framework::OpKernel<T> { class BoxDecoderAndAssignKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
......
...@@ -105,7 +105,7 @@ class TestBipartiteMatchOpWithLoD(OpTest): ...@@ -105,7 +105,7 @@ class TestBipartiteMatchOpWithLoD(OpTest):
} }
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_dygraph=False)
class TestBipartiteMatchOpWithoutLoD(OpTest): class TestBipartiteMatchOpWithoutLoD(OpTest):
...@@ -122,7 +122,7 @@ class TestBipartiteMatchOpWithoutLoD(OpTest): ...@@ -122,7 +122,7 @@ class TestBipartiteMatchOpWithoutLoD(OpTest):
} }
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_dygraph=False)
class TestBipartiteMatchOpWithoutLoDLargeScaleInput(OpTest): class TestBipartiteMatchOpWithoutLoDLargeScaleInput(OpTest):
...@@ -139,7 +139,7 @@ class TestBipartiteMatchOpWithoutLoDLargeScaleInput(OpTest): ...@@ -139,7 +139,7 @@ class TestBipartiteMatchOpWithoutLoDLargeScaleInput(OpTest):
} }
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_dygraph=False)
class TestBipartiteMatchOpWithPerPredictionType(OpTest): class TestBipartiteMatchOpWithPerPredictionType(OpTest):
...@@ -162,7 +162,7 @@ class TestBipartiteMatchOpWithPerPredictionType(OpTest): ...@@ -162,7 +162,7 @@ class TestBipartiteMatchOpWithPerPredictionType(OpTest):
} }
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_dygraph=False)
class TestBipartiteMatchOpWithEmptyLoD(OpTest): class TestBipartiteMatchOpWithEmptyLoD(OpTest):
...@@ -179,7 +179,7 @@ class TestBipartiteMatchOpWithEmptyLoD(OpTest): ...@@ -179,7 +179,7 @@ class TestBipartiteMatchOpWithEmptyLoD(OpTest):
} }
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_dygraph=False)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -52,7 +52,7 @@ def batch_box_clip(input_boxes, im_info, lod): ...@@ -52,7 +52,7 @@ def batch_box_clip(input_boxes, im_info, lod):
class TestBoxClipOp(OpTest): class TestBoxClipOp(OpTest):
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_dygraph=False)
def setUp(self): def setUp(self):
self.op_type = "box_clip" self.op_type = "box_clip"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册