Unverified commit 058ca61d, authored by huangjiyi, committed by GitHub

register fluid kernels to phi [part 3] (#52101)

* update

* fix compile bug

* fix bug

* fix bug

* revert crop_op

* fix xpu compile

* fix cinn compile

* fix bug

* fix bug

* fix bug

* fix bug

* update

* update

* update
Parent 4d97b25d
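
Every file in this diff applies the same mechanical migration: kernel class templates flip from <DeviceContext, T> to <T, DeviceContext>, and the per-dtype REGISTER_OP_CPU_KERNEL / REGISTER_OP_CUDA_KERNEL registrations collapse into a single variadic PD_REGISTER_STRUCT_KERNEL call naming the backend, layout, kernel class, and dtype list. A minimal before/after sketch of that shape, using a hypothetical my_op/MyKernel rather than any op touched in this commit:

// Before: fluid-style registration, one explicit instantiation per dtype,
// with DeviceContext as the first template parameter.
template <typename DeviceContext, typename T>
class MyKernel : public framework::OpKernel<T> { /* ... */ };

REGISTER_OP_CPU_KERNEL(my_op,
                       ops::MyKernel<phi::CPUContext, float>,
                       ops::MyKernel<phi::CPUContext, double>);

// After: phi-style struct-kernel registration, with T first. The macro takes
// a backend tag and a variadic dtype list and instantiates the class once per
// dtype; the trailing {} appears to serve as a hook for per-kernel
// configuration and is left empty throughout this commit.
template <typename T, typename DeviceContext>
class MyKernel : public framework::OpKernel<T> { /* ... */ };

PD_REGISTER_STRUCT_KERNEL(
    my_op, CPU, ALL_LAYOUT, ops::MyKernel, float, double) {}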
......@@ -26,8 +26,14 @@
#include "paddle/fluid/framework/program_desc.h"
USE_OP_ITSELF(mul);
USE_OP(cinn_launch);
USE_OP_ITSELF(elementwise_add);
USE_OP_ITSELF(cinn_launch);
PD_DECLARE_KERNEL(cinn_launch, CPU, ALL_LAYOUT);
#ifdef PADDLE_WITH_CUDA
PD_DECLARE_KERNEL(cinn_launch, GPU, ALL_LAYOUT);
#endif
namespace paddle::framework {
using Name2VarInfoMap =
......
......@@ -54,7 +54,7 @@ class ClearFloatStatusMaker : public framework::OpProtoAndCheckerMaker {
}
};
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class ClearFloatStatusKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -67,7 +67,6 @@ class ClearFloatStatusKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
using CPU = phi::CPUContext;
REGISTER_OPERATOR(
clear_float_status,
......@@ -76,5 +75,5 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(clear_float_status,
ops::ClearFloatStatusKernel<CPU, float>);
PD_REGISTER_STRUCT_KERNEL(
clear_float_status, CPU, ALL_LAYOUT, ops::ClearFloatStatusKernel, float) {}
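
The <T, DeviceContext> ordering is what lets one variadic registration expand into an instantiation per dtype with the backend-derived context appended. A self-contained toy (all names invented, C++17) showing the shape of that expansion:

#include <iostream>

struct CPUContext {};  // stand-in for phi::CPUContext

template <typename T, typename DeviceContext>
struct ToyKernel {
  static void Register() { std::cout << "registered one dtype\n"; }
};

// Like the real macro, the registrar is handed the context type derived from
// the backend tag plus the dtype list, and stamps out Kernel<T, Context> for
// every T in the list.
template <template <typename, typename> class Kernel, typename Context,
          typename... Ts>
void RegisterForTypes() {
  (Kernel<Ts, Context>::Register(), ...);  // C++17 fold expression
}

int main() { RegisterForTypes<ToyKernel, CPUContext, float, double>(); }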
......@@ -145,7 +145,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(CenterLossGradNoNeedBufVarsInferer, "X");
} // namespace paddle
namespace ops = paddle::operators;
using CPUCtx = phi::CPUContext;
namespace plat = paddle::platform;
REGISTER_OPERATOR(center_loss,
ops::CenterLossOp,
......@@ -157,10 +157,11 @@ REGISTER_OPERATOR(center_loss_grad,
ops::CenterLossGradOp,
ops::CenterLossGradNoNeedBufVarsInferer);
REGISTER_OP_CPU_KERNEL(center_loss,
ops::CenterLossKernel<CPUCtx, float>,
ops::CenterLossKernel<CPUCtx, double>);
REGISTER_OP_CPU_KERNEL(center_loss_grad,
ops::CenterLossGradKernel<CPUCtx, float>,
ops::CenterLossGradKernel<CPUCtx, double>);
PD_REGISTER_STRUCT_KERNEL(
center_loss, CPU, ALL_LAYOUT, ops::CenterLossKernel, float, double) {}
PD_REGISTER_STRUCT_KERNEL(center_loss_grad,
CPU,
ALL_LAYOUT,
ops::CenterLossGradKernel,
float,
double) {}
......@@ -81,7 +81,7 @@ __global__ void UpdateCenters(T *centers,
}
}
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class CenterLossCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
......@@ -150,11 +150,12 @@ class CenterLossCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
using GPUCtx = phi::GPUContext;
REGISTER_OP_CUDA_KERNEL(center_loss,
ops::CenterLossCUDAKernel<GPUCtx, float>,
ops::CenterLossCUDAKernel<GPUCtx, double>);
REGISTER_OP_CUDA_KERNEL(center_loss_grad,
ops::CenterLossGradKernel<GPUCtx, float>,
ops::CenterLossGradKernel<GPUCtx, double>);
PD_REGISTER_STRUCT_KERNEL(
center_loss, GPU, ALL_LAYOUT, ops::CenterLossCUDAKernel, float, double) {}
PD_REGISTER_STRUCT_KERNEL(center_loss_grad,
GPU,
ALL_LAYOUT,
ops::CenterLossGradKernel,
float,
double) {}
......@@ -40,7 +40,7 @@ struct SubFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a - b; }
};
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class CenterLossKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
......@@ -133,7 +133,7 @@ class CenterLossKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class CenterLossGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
......
......@@ -197,5 +197,6 @@ namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(chunk_eval,
ops::ChunkEvalOp,
ops::ChunkEvalOpMaker);
REGISTER_OP_CPU_KERNEL(chunk_eval,
ops::ChunkEvalKernel<paddle::platform::CPUPlace, float>);
PD_REGISTER_STRUCT_KERNEL(
chunk_eval, CPU, ALL_LAYOUT, ops::ChunkEvalKernel, float) {}
......@@ -23,7 +23,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class ChunkEvalKernel : public framework::OpKernel<T> {
public:
struct Segment {
......
......@@ -118,5 +118,9 @@ REGISTER_OPERATOR(
ops::CinnInstructionRunOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(cinn_instruction_run,
ops::CinnInstructionRunOpKernel<phi::CPUContext, float>);
PD_REGISTER_STRUCT_KERNEL(cinn_instruction_run,
CPU,
ALL_LAYOUT,
ops::CinnInstructionRunOpKernel,
float) {}
......@@ -18,6 +18,8 @@ limitations under the License. */
namespace ops = paddle::operators;
/* see [Why use single type kernel] */
REGISTER_OP_CUDA_KERNEL(
cinn_instruction_run,
ops::CinnInstructionRunOpKernel<phi::GPUContext, float>);
PD_REGISTER_STRUCT_KERNEL(cinn_instruction_run,
GPU,
ALL_LAYOUT,
ops::CinnInstructionRunOpKernel,
float) {}
......@@ -34,7 +34,7 @@ using CinnInstruction = ::cinn::hlir::framework::Instruction;
using CinnCompiledObject = framework::paddle2cinn::CinnCompiledObject;
using CinnCompiler = framework::paddle2cinn::CinnCompiler;
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class CinnInstructionRunOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......
......@@ -26,12 +26,16 @@ limitations under the License. */
#include "paddle/fluid/platform/init.h"
#include "paddle/phi/core/kernel_registry.h"
USE_OP(cinn_launch);
USE_OP(cinn_instruction_run);
USE_OP_ITSELF(cinn_launch);
USE_OP_ITSELF(cinn_instruction_run);
USE_OP_ITSELF(elementwise_add);
PD_DECLARE_KERNEL(cinn_launch, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(cinn_instruction_run, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
#ifdef PADDLE_WITH_CUDA
PD_DECLARE_KERNEL(cinn_launch, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(cinn_instruction_run, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, KPS, ALL_LAYOUT);
#endif
......
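The test files get the mirror-image change: USE_OP used to pull in an operator definition together with its fluid kernels, so once the kernels live in the phi registry the tests switch to USE_OP_ITSELF for the op alone and declare each needed kernel explicitly per backend. The recurring shape, taken from the CINN tests above (note that on CUDA builds the elementwise add kernel is declared under the KPS backend tag rather than GPU):

USE_OP_ITSELF(cinn_launch);                       // op definition only
PD_DECLARE_KERNEL(cinn_launch, CPU, ALL_LAYOUT);  // plus its phi kernels
#ifdef PADDLE_WITH_CUDA
PD_DECLARE_KERNEL(cinn_launch, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, KPS, ALL_LAYOUT);  // add is a KPS kernel on GPU
#endif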
......@@ -36,7 +36,12 @@ limitations under the License. */
#include "paddle/fluid/operators/cinn/cinn_op_helper.h"
#include "paddle/phi/core/ddim.h"
USE_OP(cinn_instruction_run);
USE_OP_ITSELF(cinn_instruction_run);
PD_DECLARE_KERNEL(cinn_instruction_run, CPU, ALL_LAYOUT);
#ifdef PADDLE_WITH_CUDA
PD_DECLARE_KERNEL(cinn_instruction_run, GPU, ALL_LAYOUT);
#endif
namespace paddle {
namespace operators::details {
......
......@@ -195,5 +195,5 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
/* see [Why use single type kernel] */
REGISTER_OP_CPU_KERNEL(cinn_launch,
ops::CinnLaunchOpKernel<phi::CPUContext, float>);
PD_REGISTER_STRUCT_KERNEL(
cinn_launch, CPU, ALL_LAYOUT, ops::CinnLaunchOpKernel, float) {}
......@@ -34,5 +34,8 @@ void SetCinnRandomSeed<phi::GPUContext>() {
} // namespace paddle
/* see [Why use single type kernel] */
REGISTER_OP_CUDA_KERNEL(
cinn_launch, paddle::operators::CinnLaunchOpKernel<phi::GPUContext, float>);
PD_REGISTER_STRUCT_KERNEL(cinn_launch,
GPU,
ALL_LAYOUT,
paddle::operators::CinnLaunchOpKernel,
float) {}
......@@ -60,7 +60,7 @@ void SetCinnRandomSeed();
} // namespace details
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class CinnLaunchOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......
......@@ -31,16 +31,20 @@ limitations under the License. */
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"
USE_OP(cinn_launch);
USE_OP(cinn_instruction_run);
USE_OP_ITSELF(cinn_launch);
USE_OP_ITSELF(cinn_instruction_run);
USE_OP_ITSELF(elementwise_add);
DECLARE_double(eager_delete_tensor_gb);
DECLARE_bool(enable_pe_launch_cinn);
DECLARE_bool(enable_interpretercore_launch_cinn);
DECLARE_bool(enable_cinn_auto_tune);
PD_DECLARE_KERNEL(cinn_launch, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(cinn_instruction_run, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
#ifdef PADDLE_WITH_CUDA
PD_DECLARE_KERNEL(cinn_launch, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(cinn_instruction_run, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, KPS, ALL_LAYOUT);
#endif
......
......@@ -65,9 +65,12 @@ REGISTER_OP_WITHOUT_GRADIENT(c_broadcast,
ops::CBroadcastOp,
ops::CBroadcastOpMaker);
REGISTER_OP_CPU_KERNEL(c_broadcast,
ops::CBroadcastOpCPUKernel<float>,
ops::CBroadcastOpCPUKernel<double>,
ops::CBroadcastOpCPUKernel<int>,
ops::CBroadcastOpCPUKernel<int64_t>,
ops::CBroadcastOpCPUKernel<plat::float16>);
PD_REGISTER_STRUCT_KERNEL(c_broadcast,
CPU,
ALL_LAYOUT,
ops::CBroadcastOpCPUKernel,
float,
double,
int,
int64_t,
plat::float16) {}
......@@ -25,7 +25,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
template <typename T, typename DeviceContext>
class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -92,12 +92,16 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(c_broadcast,
ops::CBroadcastOpCUDAKernel<float>,
ops::CBroadcastOpCUDAKernel<double>,
PD_REGISTER_STRUCT_KERNEL(c_broadcast,
GPU,
ALL_LAYOUT,
ops::CBroadcastOpCUDAKernel,
int,
int64_t,
float,
double,
#if NCCL_VERSION_CODE >= 21000
ops::CBroadcastOpCUDAKernel<plat::bfloat16>,
plat::bfloat16,
#endif
ops::CBroadcastOpCUDAKernel<int>,
ops::CBroadcastOpCUDAKernel<int64_t>,
ops::CBroadcastOpCUDAKernel<plat::float16>);
plat::float16) {
}
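
One wrinkle in the collective ops: the NCCL version gate moves inside the macro invocation, so the preprocessor conditional now selects whether plat::bfloat16 appears in the variadic dtype list at all. (Strictly speaking, preprocessing directives inside a macro argument list are undefined behavior per the C++ standard, but the compilers these files target accept the construct.) Distilled from the c_broadcast registration above:

PD_REGISTER_STRUCT_KERNEL(c_broadcast,
                          GPU,
                          ALL_LAYOUT,
                          ops::CBroadcastOpCUDAKernel,
                          float,
                          double,
#if NCCL_VERSION_CODE >= 21000
                          plat::bfloat16,  // only on NCCL >= 2.10
#endif
                          plat::float16) {}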
......@@ -33,7 +33,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
template <typename T, typename DeviceContext>
class CBroadcastOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......
......@@ -113,9 +113,12 @@ REGISTER_OPERATOR(c_concat,
ops::CConcatOpGradMaker<paddle::imperative::OpBase>,
ops::CConcatOpMaker);
REGISTER_OP_CPU_KERNEL(c_concat,
ops::CConcatOpCPUKernel<float>,
ops::CConcatOpCPUKernel<double>,
ops::CConcatOpCPUKernel<int>,
ops::CConcatOpCPUKernel<int64_t>,
ops::CConcatOpCPUKernel<plat::float16>);
PD_REGISTER_STRUCT_KERNEL(c_concat,
CPU,
ALL_LAYOUT,
ops::CConcatOpCPUKernel,
float,
double,
int,
int64_t,
plat::float16) {}
......@@ -28,7 +28,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
template <typename T, typename DeviceContext>
class CConcatOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -129,12 +129,16 @@ class CConcatOpCUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(c_concat,
ops::CConcatOpCUDAKernel<float>,
ops::CConcatOpCUDAKernel<double>,
ops::CConcatOpCUDAKernel<int>,
ops::CConcatOpCUDAKernel<int64_t>,
PD_REGISTER_STRUCT_KERNEL(c_concat,
GPU,
ALL_LAYOUT,
ops::CConcatOpCUDAKernel,
float,
double,
int,
int64_t,
#if NCCL_VERSION_CODE >= 21000
ops::CConcatOpCUDAKernel<plat::bfloat16>,
plat::bfloat16,
#endif
ops::CConcatOpCUDAKernel<plat::float16>);
plat::float16) {
}
......@@ -25,7 +25,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
template <typename T, typename DeviceContext>
class CConcatOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......
......@@ -184,12 +184,17 @@ REGISTER_OPERATOR(c_embedding_grad,
ops::CEmbeddingGradOpNoBufferVarsInferer,
ops::CEmbeddingOpGradVarTypeInference);
REGISTER_OP_CPU_KERNEL(c_embedding,
ops::CEmbeddingOpCPUKernel<float>,
ops::CEmbeddingOpCPUKernel<double>,
ops::CEmbeddingOpCPUKernel<plat::float16>);
REGISTER_OP_CPU_KERNEL(c_embedding_grad,
ops::CEmbeddingGradOpCPUKernel<float>,
ops::CEmbeddingGradOpCPUKernel<double>,
ops::CEmbeddingGradOpCPUKernel<plat::float16>);
PD_REGISTER_STRUCT_KERNEL(c_embedding,
CPU,
ALL_LAYOUT,
ops::CEmbeddingOpCPUKernel,
float,
double,
plat::float16) {}
PD_REGISTER_STRUCT_KERNEL(c_embedding_grad,
CPU,
ALL_LAYOUT,
ops::CEmbeddingGradOpCPUKernel,
float,
double,
plat::float16) {}
......@@ -82,7 +82,7 @@ __global__ void CEmbeddingGrad(T *table,
}
}
template <typename T>
template <typename T, typename DeviceContext>
class CEmbeddingCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
......@@ -136,7 +136,7 @@ class CEmbeddingCUDAKernel : public framework::OpKernel<T> {
}
};
template <typename T>
template <typename T, typename DeviceContext>
class CEmbeddingGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
......@@ -195,17 +195,27 @@ class CEmbeddingGradCUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(c_embedding,
ops::CEmbeddingCUDAKernel<float>,
ops::CEmbeddingCUDAKernel<double>,
PD_REGISTER_STRUCT_KERNEL(c_embedding,
GPU,
ALL_LAYOUT,
ops::CEmbeddingCUDAKernel,
float,
double,
#if NCCL_VERSION_CODE >= 21000
ops::CEmbeddingCUDAKernel<plat::bfloat16>,
plat::bfloat16,
#endif
ops::CEmbeddingCUDAKernel<plat::float16>);
REGISTER_OP_CUDA_KERNEL(c_embedding_grad,
ops::CEmbeddingGradCUDAKernel<float>,
ops::CEmbeddingGradCUDAKernel<double>,
plat::float16) {
}
PD_REGISTER_STRUCT_KERNEL(c_embedding_grad,
GPU,
ALL_LAYOUT,
ops::CEmbeddingGradCUDAKernel,
float,
double,
#if NCCL_VERSION_CODE >= 21000
ops::CEmbeddingGradCUDAKernel<plat::bfloat16>,
plat::bfloat16,
#endif
ops::CEmbeddingGradCUDAKernel<plat::float16>);
plat::float16) {
}
......@@ -51,7 +51,7 @@ void GetIdsEmbedding(const TIds* ids,
}
}
template <typename T>
template <typename T, typename DeviceContext>
class CEmbeddingOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -112,7 +112,7 @@ void UpdateEmbedding(const TIds* ids,
}
}
template <typename T>
template <typename T, typename DeviceContext>
class CEmbeddingGradOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......
......@@ -87,9 +87,12 @@ REGISTER_OPERATOR(c_identity,
ops::CIdentityOpGradMaker<paddle::imperative::OpBase>,
ops::CIdentityOpMaker);
REGISTER_OP_CPU_KERNEL(c_identity,
ops::CIdentityOpCPUKernel<float>,
ops::CIdentityOpCPUKernel<double>,
ops::CIdentityOpCPUKernel<int>,
ops::CIdentityOpCPUKernel<int64_t>,
ops::CIdentityOpCPUKernel<plat::float16>);
PD_REGISTER_STRUCT_KERNEL(c_identity,
CPU,
ALL_LAYOUT,
ops::CIdentityOpCPUKernel,
float,
double,
int,
int64_t,
plat::float16) {}
......@@ -17,12 +17,16 @@ limitations under the License. */
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(c_identity,
ops::CIdentityOpKernel<float>,
ops::CIdentityOpKernel<double>,
ops::CIdentityOpKernel<int>,
ops::CIdentityOpKernel<int64_t>,
PD_REGISTER_STRUCT_KERNEL(c_identity,
GPU,
ALL_LAYOUT,
ops::CIdentityOpKernel,
float,
double,
int,
int64_t,
#if NCCL_VERSION_CODE >= 21000
ops::CIdentityOpKernel<plat::bfloat16>,
plat::bfloat16,
#endif
ops::CIdentityOpKernel<plat::float16>);
plat::float16) {
}
......@@ -25,7 +25,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
template <typename T, typename DeviceContext>
class CIdentityOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -34,7 +34,7 @@ class CIdentityOpCPUKernel : public framework::OpKernel<T> {
}
};
template <typename T>
template <typename T, typename DeviceContext>
class CIdentityOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......
......@@ -15,8 +15,8 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(c_identity,
ops::CIdentityOpKernel<float>,
ops::CIdentityOpKernel<double>,
ops::CIdentityOpKernel<int>,
ops::CIdentityOpKernel<int64_t>,
ops::CIdentityOpKernel<plat::float16>);
ops::CIdentityOpKernel<float, plat::NPUPlace>,
ops::CIdentityOpKernel<double, plat::NPUPlace>,
ops::CIdentityOpKernel<int, plat::NPUPlace>,
ops::CIdentityOpKernel<int64_t, plat::NPUPlace>,
ops::CIdentityOpKernel<plat::float16, plat::NPUPlace>);
......@@ -15,8 +15,8 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(c_identity,
ops::CIdentityOpKernel<float>,
ops::CIdentityOpKernel<double>,
ops::CIdentityOpKernel<int>,
ops::CIdentityOpKernel<int64_t>,
ops::CIdentityOpKernel<plat::float16>);
ops::CIdentityOpKernel<float, plat::XPUPlace>,
ops::CIdentityOpKernel<double, plat::XPUPlace>,
ops::CIdentityOpKernel<int, plat::XPUPlace>,
ops::CIdentityOpKernel<int64_t, plat::XPUPlace>,
ops::CIdentityOpKernel<plat::float16, plat::XPUPlace>);
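
The NPU and XPU builds keep their legacy REGISTER_OP_NPU_KERNEL / REGISTER_OP_XPU_KERNEL macros, but because the shared kernel template now takes two parameters, every registration that names a concrete instantiation must supply the second argument itself, with the place type standing in for the device context. A toy reduction of the same change (names invented):

struct XPUPlace {};  // stand-in for plat::XPUPlace

template <typename T, typename DeviceContext>
struct IdentityLikeKernel {};

// Before the migration, IdentityLikeKernel<float> was a complete type name;
// after it, the place/context argument has to be spelled out:
IdentityLikeKernel<float, XPUPlace> registered_for_xpu;

int main() {}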
......@@ -33,6 +33,8 @@ class CReduceMaxOpMaker : public CReduceOpMaker {
std::string GetName() const override { return "Max"; }
};
DEFINE_C_REDUCE_CPU_KERNEL(CReduceMax, kRedMax)
} // namespace operators
} // namespace paddle
......@@ -43,9 +45,12 @@ REGISTER_OP_WITHOUT_GRADIENT(c_reduce_max,
ops::CReduceOp,
ops::CReduceMaxOpMaker);
REGISTER_OP_CPU_KERNEL(c_reduce_max,
ops::CReduceOpCPUKernel<ops::kRedMax, float>,
ops::CReduceOpCPUKernel<ops::kRedMax, double>,
ops::CReduceOpCPUKernel<ops::kRedMax, int>,
ops::CReduceOpCPUKernel<ops::kRedMax, int64_t>,
ops::CReduceOpCPUKernel<ops::kRedMax, plat::float16>);
PD_REGISTER_STRUCT_KERNEL(c_reduce_max,
CPU,
ALL_LAYOUT,
ops::CReduceMaxCPUKernel,
float,
double,
int,
int64_t,
plat::float16) {}
......@@ -14,12 +14,21 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/c_reduce_op.h"
namespace paddle {
namespace operators {
DEFINE_C_REDUCE_CUDA_KERNEL(CReduceMax, kRedMax);
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(c_reduce_max,
ops::CReduceOpCUDAKernel<ops::kRedMax, float>,
ops::CReduceOpCUDAKernel<ops::kRedMax, double>,
ops::CReduceOpCUDAKernel<ops::kRedMax, int>,
ops::CReduceOpCUDAKernel<ops::kRedMax, int64_t>,
ops::CReduceOpCUDAKernel<ops::kRedMax, plat::float16>)
PD_REGISTER_STRUCT_KERNEL(c_reduce_max,
GPU,
ALL_LAYOUT,
ops::CReduceMaxCUDAKernel,
float,
double,
int,
int64_t,
plat::float16) {}
......@@ -33,6 +33,7 @@ class CReduceMinOpMaker : public CReduceOpMaker {
std::string GetName() const override { return "Min"; }
};
DEFINE_C_REDUCE_CPU_KERNEL(CReduceMin, kRedMin)
} // namespace operators
} // namespace paddle
......@@ -43,9 +44,12 @@ REGISTER_OP_WITHOUT_GRADIENT(c_reduce_min,
ops::CReduceOp,
ops::CReduceMinOpMaker);
REGISTER_OP_CPU_KERNEL(c_reduce_min,
ops::CReduceOpCPUKernel<ops::kRedMin, float>,
ops::CReduceOpCPUKernel<ops::kRedMin, double>,
ops::CReduceOpCPUKernel<ops::kRedMin, int>,
ops::CReduceOpCPUKernel<ops::kRedMin, int64_t>,
ops::CReduceOpCPUKernel<ops::kRedMin, plat::float16>);
PD_REGISTER_STRUCT_KERNEL(c_reduce_min,
CPU,
ALL_LAYOUT,
ops::CReduceMinCPUKernel,
float,
double,
int,
int64_t,
plat::float16) {}
......@@ -14,12 +14,21 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/c_reduce_op.h"
namespace paddle {
namespace operators {
DEFINE_C_REDUCE_CUDA_KERNEL(CReduceMin, kRedMin);
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(c_reduce_min,
ops::CReduceOpCUDAKernel<ops::kRedMin, float>,
ops::CReduceOpCUDAKernel<ops::kRedMin, double>,
ops::CReduceOpCUDAKernel<ops::kRedMin, int>,
ops::CReduceOpCUDAKernel<ops::kRedMin, int64_t>,
ops::CReduceOpCUDAKernel<ops::kRedMin, plat::float16>)
PD_REGISTER_STRUCT_KERNEL(c_reduce_min,
GPU,
ALL_LAYOUT,
ops::CReduceMinCUDAKernel,
float,
double,
int,
int64_t,
plat::float16) {}
......@@ -127,6 +127,10 @@ class CReduceOpCPUKernel : public framework::OpKernel<T> {
}
};
#define DEFINE_C_REDUCE_CPU_KERNEL(op_name, red_type) \
template <typename T, typename DeviceContext> \
class op_name##CPUKernel : public CReduceOpCPUKernel<red_type, T> {};
template <ReduceType red_type, typename T>
class CReduceOpASCENDKernel : public framework::OpKernel<T> {
public:
......@@ -278,6 +282,10 @@ class CReduceOpCUDAKernel : public framework::OpKernel<T> {
}
};
#define DEFINE_C_REDUCE_CUDA_KERNEL(op_name, red_type) \
template <typename T, typename DeviceContext> \
class op_name##CUDAKernel : public CReduceOpCUDAKernel<red_type, T> {};
template <ReduceType red_type, typename T>
class CReduceOpMLUKernel : public framework::OpKernel<T> {
public:
......
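The reduce collectives share one CReduceOpCPUKernel / CReduceOpCUDAKernel implementation parameterized by a non-type ReduceType, which does not match the <T, DeviceContext> shape the struct-kernel macro expects. The new DEFINE_C_REDUCE_CPU_KERNEL / DEFINE_C_REDUCE_CUDA_KERNEL macros bridge the gap by stamping out a thin per-op subclass. A compilable toy model of that bridge (the context type and kernel names here are invented):

#include <iostream>

enum ReduceType { kRedMax, kRedMin, kRedSum, kRedProd };
struct CPUContext {};  // stand-in for phi::CPUContext

// Shared implementation, templated on the reduction kind and the dtype.
template <ReduceType red_type, typename T>
struct SharedReduceCPUKernel {
  void Compute() const { std::cout << "reduce kind " << red_type << '\n'; }
};

// Mirrors DEFINE_C_REDUCE_CPU_KERNEL: generate a <T, DeviceContext>-shaped
// subclass per concrete op so the struct-kernel registrar can instantiate it.
#define DEFINE_TOY_REDUCE_CPU_KERNEL(op_name, red_type)                  \
  template <typename T, typename DeviceContext>                          \
  struct op_name##CPUKernel : public SharedReduceCPUKernel<red_type, T> {};

DEFINE_TOY_REDUCE_CPU_KERNEL(CReduceMax, kRedMax)

int main() { CReduceMaxCPUKernel<float, CPUContext>{}.Compute(); }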
......@@ -33,6 +33,8 @@ class CReduceProdOpMaker : public CReduceOpMaker {
std::string GetName() const override { return "Prod"; }
};
DEFINE_C_REDUCE_CPU_KERNEL(CReduceProd, kRedProd)
} // namespace operators
} // namespace paddle
......@@ -43,9 +45,12 @@ REGISTER_OP_WITHOUT_GRADIENT(c_reduce_prod,
ops::CReduceOp,
ops::CReduceProdOpMaker);
REGISTER_OP_CPU_KERNEL(c_reduce_prod,
ops::CReduceOpCPUKernel<ops::kRedProd, float>,
ops::CReduceOpCPUKernel<ops::kRedProd, double>,
ops::CReduceOpCPUKernel<ops::kRedProd, int>,
ops::CReduceOpCPUKernel<ops::kRedProd, int64_t>,
ops::CReduceOpCPUKernel<ops::kRedProd, plat::float16>)
PD_REGISTER_STRUCT_KERNEL(c_reduce_prod,
CPU,
ALL_LAYOUT,
ops::CReduceProdCPUKernel,
float,
double,
int,
int64_t,
plat::float16) {}
......@@ -14,12 +14,21 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/c_reduce_op.h"
namespace paddle {
namespace operators {
DEFINE_C_REDUCE_CUDA_KERNEL(CReduceProd, kRedProd);
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(c_reduce_prod,
ops::CReduceOpCUDAKernel<ops::kRedProd, float>,
ops::CReduceOpCUDAKernel<ops::kRedProd, double>,
ops::CReduceOpCUDAKernel<ops::kRedProd, int>,
ops::CReduceOpCUDAKernel<ops::kRedProd, int64_t>,
ops::CReduceOpCUDAKernel<ops::kRedProd, plat::float16>)
PD_REGISTER_STRUCT_KERNEL(c_reduce_prod,
GPU,
ALL_LAYOUT,
ops::CReduceProdCUDAKernel,
float,
double,
int,
int64_t,
plat::float16) {}
......@@ -33,6 +33,8 @@ class CReduceSumOpMaker : public CReduceOpMaker {
std::string GetName() const override { return "Sum"; }
};
DEFINE_C_REDUCE_CPU_KERNEL(CReduceSum, kRedSum)
} // namespace operators
} // namespace paddle
......@@ -43,9 +45,12 @@ REGISTER_OP_WITHOUT_GRADIENT(c_reduce_sum,
ops::CReduceOp,
ops::CReduceSumOpMaker);
REGISTER_OP_CPU_KERNEL(c_reduce_sum,
ops::CReduceOpCPUKernel<ops::kRedSum, float>,
ops::CReduceOpCPUKernel<ops::kRedSum, double>,
ops::CReduceOpCPUKernel<ops::kRedSum, int>,
ops::CReduceOpCPUKernel<ops::kRedSum, int64_t>,
ops::CReduceOpCPUKernel<ops::kRedSum, plat::float16>)
PD_REGISTER_STRUCT_KERNEL(c_reduce_sum,
CPU,
ALL_LAYOUT,
ops::CReduceSumCPUKernel,
float,
double,
int,
int64_t,
plat::float16) {}
......@@ -14,12 +14,21 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/c_reduce_op.h"
namespace paddle {
namespace operators {
DEFINE_C_REDUCE_CUDA_KERNEL(CReduceSum, kRedSum);
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(c_reduce_sum,
ops::CReduceOpCUDAKernel<ops::kRedSum, float>,
ops::CReduceOpCUDAKernel<ops::kRedSum, double>,
ops::CReduceOpCUDAKernel<ops::kRedSum, int>,
ops::CReduceOpCUDAKernel<ops::kRedSum, int64_t>,
ops::CReduceOpCUDAKernel<ops::kRedSum, plat::float16>)
PD_REGISTER_STRUCT_KERNEL(c_reduce_sum,
GPU,
ALL_LAYOUT,
ops::CReduceSumCUDAKernel,
float,
double,
int,
int64_t,
plat::float16) {}
......@@ -73,9 +73,12 @@ REGISTER_OP_WITHOUT_GRADIENT(c_reducescatter,
ops::CReduceScatterOp,
ops::CReduceScatterOpMaker);
REGISTER_OP_CPU_KERNEL(c_reducescatter,
ops::CReduceScatterOpCPUKernel<float>,
ops::CReduceScatterOpCPUKernel<double>,
ops::CReduceScatterOpCPUKernel<int>,
ops::CReduceScatterOpCPUKernel<int64_t>,
ops::CReduceScatterOpCPUKernel<plat::float16>);
PD_REGISTER_STRUCT_KERNEL(c_reducescatter,
CPU,
ALL_LAYOUT,
ops::CReduceScatterOpCPUKernel,
float,
double,
int,
int64_t,
plat::float16) {}
......@@ -22,7 +22,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
template <typename T, typename DeviceContext>
class CReduceScatterOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -81,12 +81,16 @@ class CReduceScatterOpCUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(c_reducescatter,
ops::CReduceScatterOpCUDAKernel<float>,
ops::CReduceScatterOpCUDAKernel<double>,
PD_REGISTER_STRUCT_KERNEL(c_reducescatter,
GPU,
ALL_LAYOUT,
ops::CReduceScatterOpCUDAKernel,
float,
double,
#if NCCL_VERSION_CODE >= 21000
ops::CReduceScatterOpCUDAKernel<plat::bfloat16>,
plat::bfloat16,
#endif
ops::CReduceScatterOpCUDAKernel<int>,
ops::CReduceScatterOpCUDAKernel<int64_t>,
ops::CReduceScatterOpCUDAKernel<plat::float16>);
int,
int64_t,
plat::float16) {
}
......@@ -27,7 +27,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
template <typename T, typename DeviceContext>
class CReduceScatterOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......
......@@ -88,9 +88,12 @@ namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(c_scatter, ops::CScatterOp, ops::CScatterOpMaker);
REGISTER_OP_CPU_KERNEL(c_scatter,
ops::CScatterOpCPUKernel<float>,
ops::CScatterOpCPUKernel<double>,
ops::CScatterOpCPUKernel<int>,
ops::CScatterOpCPUKernel<int64_t>,
ops::CScatterOpCPUKernel<plat::float16>);
PD_REGISTER_STRUCT_KERNEL(c_scatter,
CPU,
ALL_LAYOUT,
ops::CScatterOpCPUKernel,
float,
double,
int,
int64_t,
plat::float16) {}
......@@ -22,7 +22,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
template <typename T, typename DeviceContext>
class CScatterOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -113,9 +113,12 @@ class CScatterOpCUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(c_scatter,
ops::CScatterOpCUDAKernel<float>,
ops::CScatterOpCUDAKernel<double>,
ops::CScatterOpCUDAKernel<int>,
ops::CScatterOpCUDAKernel<int64_t>,
ops::CScatterOpCUDAKernel<plat::float16>);
PD_REGISTER_STRUCT_KERNEL(c_scatter,
GPU,
ALL_LAYOUT,
ops::CScatterOpCUDAKernel,
float,
double,
int,
int64_t,
plat::float16) {}
......@@ -31,7 +31,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
template <typename T, typename DeviceContext>
class CScatterOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......
......@@ -203,7 +203,10 @@ REGISTER_OPERATOR(c_softmax_with_cross_entropy_grad,
ops::CSoftmaxWithCrossEntropyOpGrad,
ops::CSoftmaxWithCrossEntropyGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(c_softmax_with_cross_entropy,
ops::CSoftmaxWithCrossEntropyOpCPUKernel<float>,
ops::CSoftmaxWithCrossEntropyOpCPUKernel<double>,
ops::CSoftmaxWithCrossEntropyOpCPUKernel<plat::float16>);
PD_REGISTER_STRUCT_KERNEL(c_softmax_with_cross_entropy,
CPU,
ALL_LAYOUT,
ops::CSoftmaxWithCrossEntropyOpCPUKernel,
float,
double,
plat::float16) {}
......@@ -102,7 +102,7 @@ __global__ void MaskLabelByIndexGrad(T* logits_grad,
}
}
template <typename T>
template <typename T, typename DeviceContext>
class CSoftmaxWithCrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -430,7 +430,7 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::GPUContext, T> {
}
};
template <typename T>
template <typename T, typename DeviceContext>
class CSoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -494,14 +494,17 @@ class CSoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
c_softmax_with_cross_entropy,
ops::CSoftmaxWithCrossEntropyOpCUDAKernel<float>,
ops::CSoftmaxWithCrossEntropyOpCUDAKernel<double>,
ops::CSoftmaxWithCrossEntropyOpCUDAKernel<plat::float16>);
REGISTER_OP_CUDA_KERNEL(
c_softmax_with_cross_entropy_grad,
ops::CSoftmaxWithCrossEntropyGradCUDAKernel<float>,
ops::CSoftmaxWithCrossEntropyGradCUDAKernel<paddle::platform::float16>,
ops::CSoftmaxWithCrossEntropyGradCUDAKernel<double>);
PD_REGISTER_STRUCT_KERNEL(c_softmax_with_cross_entropy,
GPU,
ALL_LAYOUT,
ops::CSoftmaxWithCrossEntropyOpCUDAKernel,
float,
double,
plat::float16) {}
PD_REGISTER_STRUCT_KERNEL(c_softmax_with_cross_entropy_grad,
GPU,
ALL_LAYOUT,
ops::CSoftmaxWithCrossEntropyGradCUDAKernel,
float,
double,
plat::float16) {}
......@@ -29,7 +29,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
template <typename T, typename DeviceContext>
class CSoftmaxWithCrossEntropyOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......
......@@ -121,9 +121,12 @@ REGISTER_OPERATOR(c_split,
ops::CSplitOpGradMaker<paddle::imperative::OpBase>,
ops::CSplitOpMaker);
REGISTER_OP_CPU_KERNEL(c_split,
ops::CSplitOpCPUKernel<float>,
ops::CSplitOpCPUKernel<double>,
ops::CSplitOpCPUKernel<int>,
ops::CSplitOpCPUKernel<int64_t>,
ops::CSplitOpCPUKernel<plat::float16>);
PD_REGISTER_STRUCT_KERNEL(c_split,
CPU,
ALL_LAYOUT,
ops::CSplitOpCPUKernel,
float,
double,
int,
int64_t,
plat::float16) {}
......@@ -52,7 +52,7 @@ __global__ void SplitFromRank(const T* input,
}
}
template <typename T>
template <typename T, typename DeviceContext>
class CSplitOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -115,9 +115,12 @@ class CSplitOpCUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(c_split,
ops::CSplitOpCUDAKernel<float>,
ops::CSplitOpCUDAKernel<double>,
ops::CSplitOpCUDAKernel<int>,
ops::CSplitOpCUDAKernel<int64_t>,
ops::CSplitOpCUDAKernel<plat::float16>);
PD_REGISTER_STRUCT_KERNEL(c_split,
GPU,
ALL_LAYOUT,
ops::CSplitOpCUDAKernel,
float,
double,
int,
int64_t,
plat::float16) {}
......@@ -25,7 +25,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
template <typename T, typename DeviceContext>
class CSplitOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......
......@@ -32,28 +32,23 @@ Call calculation stream synchronization.
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
using MLU = plat::MLUPlace;
REGISTER_OP_WITHOUT_GRADIENT(c_sync_calc_stream,
ops::CSyncCalcStreamOp,
ops::CSyncCalcStreamOpMaker);
REGISTER_OP_CUDA_KERNEL(c_sync_calc_stream,
ops::CSyncCalcStreamKernel<float>,
ops::CSyncCalcStreamKernel<double>,
ops::CSyncCalcStreamKernel<int>,
ops::CSyncCalcStreamKernel<int64_t>,
ops::CSyncCalcStreamKernel<paddle::platform::float16>);
REGISTER_OP_NPU_KERNEL(c_sync_calc_stream,
ops::CSyncCalcStreamKernel<float>,
ops::CSyncCalcStreamKernel<double>,
ops::CSyncCalcStreamKernel<int>,
ops::CSyncCalcStreamKernel<int64_t>,
ops::CSyncCalcStreamKernel<paddle::platform::float16>);
ops::CSyncCalcStreamKernel<float, MLU>,
ops::CSyncCalcStreamKernel<double, MLU>,
ops::CSyncCalcStreamKernel<int, MLU>,
ops::CSyncCalcStreamKernel<int64_t, MLU>,
ops::CSyncCalcStreamKernel<plat::float16, MLU>);
REGISTER_OP_MLU_KERNEL(c_sync_calc_stream,
ops::CSyncCalcStreamKernel<float>,
ops::CSyncCalcStreamKernel<double>,
ops::CSyncCalcStreamKernel<int>,
ops::CSyncCalcStreamKernel<int64_t>,
ops::CSyncCalcStreamKernel<paddle::platform::float16>);
ops::CSyncCalcStreamKernel<float, MLU>,
ops::CSyncCalcStreamKernel<double, MLU>,
ops::CSyncCalcStreamKernel<int, MLU>,
ops::CSyncCalcStreamKernel<int64_t, MLU>,
ops::CSyncCalcStreamKernel<plat::float16, MLU>);
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_sync_calc_stream_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
PD_REGISTER_STRUCT_KERNEL(c_sync_calc_stream,
GPU,
ALL_LAYOUT,
ops::CSyncCalcStreamKernel,
float,
double,
int,
int64_t,
plat::float16) {}
......@@ -35,7 +35,7 @@ class CSyncCalcStreamOp : public framework::OperatorWithKernel {
}
};
template <typename T>
template <typename T, typename DeviceContext>
class CSyncCalcStreamKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......
......@@ -36,7 +36,9 @@ limitations under the License. */
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_KERNEL(c_sync_calc_stream, KP, plat::XPUPlace,
ops::CSyncCalcStreamKernel<float>);
REGISTER_OP_KERNEL(c_sync_calc_stream,
KP,
plat::XPUPlace,
ops::CSyncCalcStreamKernel<float, plat::XPUPlace>);
#endif
......@@ -17,4 +17,5 @@ limitations under the License. */
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel<float>)
REGISTER_OP_XPU_KERNEL(c_sync_calc_stream,
ops::CSyncCalcStreamKernel<float, plat::XPUPlace>)
......@@ -48,13 +48,14 @@ Call communication stream synchronization.
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(c_sync_comm_stream,
ops::CSyncCommStreamOp,
ops::CSyncCommStreamOpMaker);
REGISTER_OP_CUDA_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel<float>);
REGISTER_OP_NPU_KERNEL(c_sync_comm_stream,
ops::CSyncCommStreamKernel<float, plat::NPUPlace>);
REGISTER_OP_NPU_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel<float>);
REGISTER_OP_MLU_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel<float>);
REGISTER_OP_MLU_KERNEL(c_sync_comm_stream,
ops::CSyncCommStreamKernel<float, plat::MLUPlace>);
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_sync_comm_stream_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
PD_REGISTER_STRUCT_KERNEL(
c_sync_comm_stream, GPU, ALL_LAYOUT, ops::CSyncCommStreamKernel, float) {}
......@@ -33,7 +33,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
template <typename T, typename DeviceContext>
class CSyncCommStreamKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......
......@@ -36,7 +36,9 @@ limitations under the License. */
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_KERNEL(c_sync_comm_stream, KP, plat::XPUPlace,
ops::CSyncCommStreamKernel<float>);
REGISTER_OP_KERNEL(c_sync_comm_stream,
KP,
plat::XPUPlace,
ops::CSyncCommStreamKernel<float, plat::XPUPlace>);
#endif
......@@ -17,4 +17,5 @@ limitations under the License. */
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel<float>);
REGISTER_OP_XPU_KERNEL(c_sync_comm_stream,
ops::CSyncCommStreamKernel<float, plat::XPUPlace>);
......@@ -152,7 +152,7 @@ However, the output only shares the LoD information with input X.
};
template <typename T>
class ConvShiftKernel<platform::CPUPlace, T> : public framework::OpKernel<T> {
class ConvShiftKernel<T, phi::CPUContext> : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *X = context.Input<phi::DenseTensor>("X");
......@@ -182,8 +182,7 @@ class ConvShiftKernel<platform::CPUPlace, T> : public framework::OpKernel<T> {
};
template <typename T>
class ConvShiftGradKernel<platform::CPUPlace, T>
: public framework::OpKernel<T> {
class ConvShiftGradKernel<T, phi::CPUContext> : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *X = context.Input<phi::DenseTensor>("X");
......@@ -262,8 +261,7 @@ REGISTER_OPERATOR(conv_shift,
ops::ConvShiftGradOpMaker<paddle::framework::OpDesc>,
ops::ConvShiftGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(conv_shift_grad, ops::ConvShiftGradOp);
REGISTER_OP_CPU_KERNEL(conv_shift,
ops::ConvShiftKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
conv_shift_grad,
ops::ConvShiftGradKernel<paddle::platform::CPUPlace, float>);
PD_REGISTER_STRUCT_KERNEL(
conv_shift, CPU, ALL_LAYOUT, ops::ConvShiftKernel, float) {}
PD_REGISTER_STRUCT_KERNEL(
conv_shift_grad, CPU, ALL_LAYOUT, ops::ConvShiftGradKernel, float) {}
......@@ -122,7 +122,7 @@ __global__ void ConvShiftDy(const T *x,
} // namespace
template <typename T>
class ConvShiftKernel<phi::GPUContext, T> : public framework::OpKernel<T> {
class ConvShiftKernel<T, phi::GPUContext> : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const phi::DenseTensor *X = context.Input<phi::DenseTensor>("X");
......@@ -151,7 +151,7 @@ class ConvShiftKernel<phi::GPUContext, T> : public framework::OpKernel<T> {
};
template <typename T>
class ConvShiftGradKernel<phi::GPUContext, T> : public framework::OpKernel<T> {
class ConvShiftGradKernel<T, phi::GPUContext> : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const phi::DenseTensor *X = context.Input<phi::DenseTensor>("X");
......@@ -209,7 +209,8 @@ class ConvShiftGradKernel<phi::GPUContext, T> : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(conv_shift,
ops::ConvShiftKernel<phi::GPUContext, float>);
REGISTER_OP_CUDA_KERNEL(conv_shift_grad,
ops::ConvShiftGradKernel<phi::GPUContext, float>);
PD_REGISTER_STRUCT_KERNEL(
conv_shift, GPU, ALL_LAYOUT, ops::ConvShiftKernel, float) {}
PD_REGISTER_STRUCT_KERNEL(
conv_shift_grad, GPU, ALL_LAYOUT, ops::ConvShiftGradKernel, float) {}
......@@ -18,13 +18,13 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class ConvShiftKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override;
};
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class ConvShiftGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override;
......
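conv_shift is the one spot where the reorder touches explicit specializations rather than registrations alone: the header declares the kernel templates over <T, DeviceContext> with only Compute's signature, and the .cc and .cu files each provide a partial specialization on the context type, so those specializations have to flip their argument order too. A self-contained toy of the same layout (names invented):

#include <iostream>

struct CPUContext {};  // stand-ins for phi::CPUContext / phi::GPUContext
struct GPUContext {};

// Header: primary template, declared but never defined here.
template <typename T, typename DeviceContext>
struct ShiftLikeKernel;

// .cc file: CPU partial specialization.
template <typename T>
struct ShiftLikeKernel<T, CPUContext> {
  void Compute() const { std::cout << "cpu path\n"; }
};

// .cu file: GPU partial specialization.
template <typename T>
struct ShiftLikeKernel<T, GPUContext> {
  void Compute() const { std::cout << "gpu path\n"; }
};

int main() { ShiftLikeKernel<float, CPUContext>{}.Compute(); }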
......@@ -165,7 +165,7 @@ class CorrelationOpGrad : public framework::OperatorWithKernel {
}
};
template <typename T>
template <typename T, typename DeviceContext>
class CorrelationKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -186,6 +186,6 @@ REGISTER_OPERATOR(correlation,
ops::CorrelationOpGradMaker<paddle::framework::OpDesc>,
ops::CorrelationOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(correlation_grad, ops::CorrelationOpGrad);
REGISTER_OP_CPU_KERNEL(correlation,
ops::CorrelationKernel<float>,
ops::CorrelationKernel<double>);
PD_REGISTER_STRUCT_KERNEL(
correlation, CPU, ALL_LAYOUT, ops::CorrelationKernel, float, double) {}
......@@ -175,7 +175,7 @@ __global__ void correlation_forward(T *output,
}
// class CorrelationKernel<phi::GPUContext, T>
template <typename T>
template <typename T, typename DeviceContext>
class CorrelationCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
......@@ -443,7 +443,7 @@ __global__ void correlation_backward_input2(int item,
}
}
template <typename T>
template <typename T, typename DeviceContext>
class CorrelationCUDAGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
......@@ -564,9 +564,12 @@ class CorrelationCUDAGradKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(correlation,
ops::CorrelationCUDAKernel<float>,
ops::CorrelationCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(correlation_grad,
ops::CorrelationCUDAGradKernel<float>,
ops::CorrelationCUDAGradKernel<double>);
PD_REGISTER_STRUCT_KERNEL(
correlation, GPU, ALL_LAYOUT, ops::CorrelationCUDAKernel, float, double) {}
PD_REGISTER_STRUCT_KERNEL(correlation_grad,
GPU,
ALL_LAYOUT,
ops::CorrelationCUDAGradKernel,
float,
double) {}
......@@ -247,6 +247,6 @@ REGISTER_OPERATOR(cos_sim,
ops::CosSimGradOpMaker<paddle::framework::OpDesc>,
ops::CosSimGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(cos_sim_grad, ops::CosSimOpGrad);
REGISTER_OP_CPU_KERNEL(cos_sim, ops::CosSimKernel<phi::CPUContext, float>);
REGISTER_OP_CPU_KERNEL(cos_sim_grad,
ops::CosSimGradKernel<phi::CPUContext, float>);
PD_REGISTER_STRUCT_KERNEL(cos_sim, CPU, ALL_LAYOUT, ops::CosSimKernel, float) {}
PD_REGISTER_STRUCT_KERNEL(
cos_sim_grad, CPU, ALL_LAYOUT, ops::CosSimGradKernel, float) {}
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/cos_sim_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(cos_sim, ops::CosSimKernel<phi::GPUContext, float>);
REGISTER_OP_CUDA_KERNEL(cos_sim_grad,
ops::CosSimGradKernel<phi::GPUContext, float>);
PD_REGISTER_STRUCT_KERNEL(cos_sim, GPU, ALL_LAYOUT, ops::CosSimKernel, float) {}
PD_REGISTER_STRUCT_KERNEL(
cos_sim_grad, GPU, ALL_LAYOUT, ops::CosSimGradKernel, float) {}
......@@ -21,7 +21,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class CosSimKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -68,7 +68,7 @@ class CosSimKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class CosSimGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......
......@@ -216,6 +216,6 @@ namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(crf_decoding,
ops::CRFDecodingOp,
ops::CRFDecodingOpMaker);
REGISTER_OP_CPU_KERNEL(crf_decoding,
ops::CRFDecodingOpKernel<phi::CPUContext, float>,
ops::CRFDecodingOpKernel<phi::CPUContext, double>);
PD_REGISTER_STRUCT_KERNEL(
crf_decoding, CPU, ALL_LAYOUT, ops::CRFDecodingOpKernel, float, double) {}
......@@ -25,7 +25,7 @@ namespace operators {
using framework::LoD;
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class CRFDecodingOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......
......@@ -420,7 +420,6 @@ class CrossEntropyGradOpMaker2 : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
using CPUCtx = phi::CPUContext;
REGISTER_OPERATOR(cross_entropy,
ops::CrossEntropyOpBase,
......@@ -429,12 +428,14 @@ REGISTER_OPERATOR(cross_entropy,
ops::CrossEntropyGradOpMaker<paddle::framework::OpDesc>,
ops::CrossEntropyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(cross_entropy_grad, ops::CrossEntropyGradientOp);
REGISTER_OP_CPU_KERNEL(cross_entropy,
ops::CrossEntropyOpKernel<CPUCtx, float>,
ops::CrossEntropyOpKernel<CPUCtx, double>);
REGISTER_OP_CPU_KERNEL(cross_entropy_grad,
ops::CrossEntropyGradientOpKernel<CPUCtx, float>,
ops::CrossEntropyGradientOpKernel<CPUCtx, double>);
PD_REGISTER_STRUCT_KERNEL(
cross_entropy, CPU, ALL_LAYOUT, ops::CrossEntropyOpKernel, float, double) {}
PD_REGISTER_STRUCT_KERNEL(cross_entropy_grad,
CPU,
ALL_LAYOUT,
ops::CrossEntropyGradientOpKernel,
float,
double) {}
REGISTER_OPERATOR(cross_entropy2,
ops::CrossEntropyOp2,
......@@ -443,9 +444,15 @@ REGISTER_OPERATOR(cross_entropy2,
ops::CrossEntropyGradOpMaker2<paddle::framework::OpDesc>,
ops::CrossEntropyGradOpMaker2<paddle::imperative::OpBase>);
REGISTER_OPERATOR(cross_entropy_grad2, ops::CrossEntropyGradientOp2);
REGISTER_OP_CPU_KERNEL(cross_entropy2,
ops::CrossEntropyOpKernel2<CPUCtx, float>,
ops::CrossEntropyOpKernel2<CPUCtx, double>);
REGISTER_OP_CPU_KERNEL(cross_entropy_grad2,
ops::CrossEntropyGradientOpKernel2<CPUCtx, float>,
ops::CrossEntropyGradientOpKernel2<CPUCtx, double>);
PD_REGISTER_STRUCT_KERNEL(cross_entropy2,
CPU,
ALL_LAYOUT,
ops::CrossEntropyOpKernel2,
float,
double) {}
PD_REGISTER_STRUCT_KERNEL(cross_entropy_grad2,
CPU,
ALL_LAYOUT,
ops::CrossEntropyGradientOpKernel2,
float,
double) {}
......@@ -17,25 +17,33 @@ limitations under the License. */
namespace plat = paddle::platform;
namespace ops = paddle::operators;
using CUDACtx = phi::GPUContext;
REGISTER_OP_CUDA_KERNEL(cross_entropy,
ops::CrossEntropyOpKernel<CUDACtx, float>,
ops::CrossEntropyOpKernel<CUDACtx, double>,
ops::CrossEntropyOpKernel<CUDACtx, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
cross_entropy_grad,
ops::CrossEntropyGradientOpKernel<CUDACtx, float>,
ops::CrossEntropyGradientOpKernel<CUDACtx, double>,
ops::CrossEntropyGradientOpKernel<CUDACtx, plat::float16>);
REGISTER_OP_CUDA_KERNEL(cross_entropy2,
ops::CrossEntropyOpKernel2<CUDACtx, float>,
ops::CrossEntropyOpKernel2<CUDACtx, double>,
ops::CrossEntropyOpKernel2<CUDACtx, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
cross_entropy_grad2,
ops::CrossEntropyGradientOpKernel2<CUDACtx, float>,
ops::CrossEntropyGradientOpKernel2<CUDACtx, double>,
ops::CrossEntropyGradientOpKernel2<CUDACtx, plat::float16>);
PD_REGISTER_STRUCT_KERNEL(cross_entropy,
GPU,
ALL_LAYOUT,
ops::CrossEntropyOpKernel,
float,
double,
plat::float16) {}
PD_REGISTER_STRUCT_KERNEL(cross_entropy_grad,
GPU,
ALL_LAYOUT,
ops::CrossEntropyGradientOpKernel,
float,
double,
plat::float16) {}
PD_REGISTER_STRUCT_KERNEL(cross_entropy2,
GPU,
ALL_LAYOUT,
ops::CrossEntropyOpKernel2,
float,
double,
plat::float16) {}
PD_REGISTER_STRUCT_KERNEL(cross_entropy_grad2,
GPU,
ALL_LAYOUT,
ops::CrossEntropyGradientOpKernel2,
float,
double,
plat::float16) {}
......@@ -24,7 +24,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class CrossEntropyOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -121,7 +121,7 @@ class XeGradFunctor {
size_t ignore_index_;
};
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -239,7 +239,7 @@ struct HardLabelCrossEntropyBackwardFunctor {
int64_t feature_size_;
};
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class CrossEntropyOpKernel2 : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -266,7 +266,7 @@ class CrossEntropyOpKernel2 : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class CrossEntropyGradientOpKernel2 : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......
......@@ -128,6 +128,6 @@ REGISTER_OPERATOR(
ops::CTCAlignOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(ctc_align,
ops::CTCAlignKernel<phi::CPUContext, int>,
ops::CTCAlignKernel<phi::CPUContext, int64_t>);
PD_REGISTER_STRUCT_KERNEL(
ctc_align, CPU, ALL_LAYOUT, ops::CTCAlignKernel, int, int64_t) {}
......@@ -76,7 +76,7 @@ __global__ void PaddingMergeAndDelCudaKernel(const int64_t num_token,
}
}
template <typename T>
template <typename T, typename DeviceContext>
class CTCAlignOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -165,6 +165,7 @@ class CTCAlignOpCUDAKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(ctc_align,
paddle::operators::CTCAlignOpCUDAKernel<int>,
paddle::operators::CTCAlignOpCUDAKernel<int64_t>);
namespace ops = paddle::operators;
PD_REGISTER_STRUCT_KERNEL(
ctc_align, GPU, ALL_LAYOUT, ops::CTCAlignOpCUDAKernel, int, int64_t) {}
......@@ -24,7 +24,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class CTCAlignKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......
......@@ -292,7 +292,7 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
template <typename T>
template <typename T, typename DeviceContext>
class NotImpleKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -312,8 +312,8 @@ REGISTER_OPERATOR(cudnn_lstm,
ops::CudnnLSTMGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(cudnn_lstm_grad, ops::CudnnLSTMGradOp);
REGISTER_OP_CPU_KERNEL(cudnn_lstm, ops::NotImpleKernel<float>);
REGISTER_OP_CPU_KERNEL(cudnn_lstm_grad, ops::NotImpleKernel<float>);
PD_REGISTER_STRUCT_KERNEL(
cudnn_lstm, CPU, ALL_LAYOUT, ops::NotImpleKernel, float) {}
// TODO(Shixiaowei02) Add ModifyInput support
REGISTER_OP_VERSION(cudnn_lstm)
......
......@@ -199,7 +199,7 @@ void LSTMInferece(const bool &has_seq_length,
}
}
template <typename T>
template <typename T, typename DeviceContext>
class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
......@@ -436,7 +436,7 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
}
};
template <typename T>
template <typename T, typename DeviceContext>
class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
......@@ -727,13 +727,17 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
#ifdef PADDLE_WITH_HIP
// MIOPEN does not support double
REGISTER_OP_CUDA_KERNEL(cudnn_lstm, ops::CudnnLSTMGPUKernel<float>);
REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad, ops::CudnnLSTMGPUGradKernel<float>);
PD_REGISTER_STRUCT_KERNEL(
cudnn_lstm, GPU, ALL_LAYOUT, ops::CudnnLSTMGPUKernel, float) {}
PD_REGISTER_STRUCT_KERNEL(
cudnn_lstm_grad, GPU, ALL_LAYOUT, ops::CudnnLSTMGPUGradKernel, float) {}
#else
REGISTER_OP_CUDA_KERNEL(cudnn_lstm,
ops::CudnnLSTMGPUKernel<float>,
ops::CudnnLSTMGPUKernel<double>);
REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad,
ops::CudnnLSTMGPUGradKernel<float>,
ops::CudnnLSTMGPUGradKernel<double>);
PD_REGISTER_STRUCT_KERNEL(
cudnn_lstm, GPU, ALL_LAYOUT, ops::CudnnLSTMGPUKernel, float, double) {}
PD_REGISTER_STRUCT_KERNEL(cudnn_lstm_grad,
GPU,
ALL_LAYOUT,
ops::CudnnLSTMGPUGradKernel,
float,
double) {}
#endif
......@@ -180,8 +180,7 @@ REGISTER_OPERATOR(cvm_grad,
ops::CVMGradientOp,
ops::CVMGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(cvm, ops::CVMOpKernel<float>, ops::CVMOpKernel<double>);
REGISTER_OP_CPU_KERNEL(cvm_grad,
ops::CVMGradOpKernel<float>,
ops::CVMGradOpKernel<double>);
PD_REGISTER_STRUCT_KERNEL(
cvm, CPU, ALL_LAYOUT, ops::CVMOpKernel, float, double) {}
PD_REGISTER_STRUCT_KERNEL(
cvm_grad, CPU, ALL_LAYOUT, ops::CVMGradOpKernel, float, double) {}
......@@ -81,7 +81,7 @@ __global__ void CvmGradComputeKernel(const bool use_cvm,
}
}
template <typename T>
template <typename T, typename DeviceContext>
class CVMCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -122,7 +122,7 @@ class CVMCUDAKernel : public framework::OpKernel<T> {
}
};
template <typename T>
template <typename T, typename DeviceContext>
class CVMGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -187,9 +187,8 @@ class CVMGradCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(cvm,
ops::CVMCUDAKernel<float>,
ops::CVMCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(cvm_grad,
ops::CVMGradCUDAKernel<float>,
ops::CVMGradCUDAKernel<double>);
PD_REGISTER_STRUCT_KERNEL(
cvm, GPU, ALL_LAYOUT, ops::CVMCUDAKernel, float, double) {}
PD_REGISTER_STRUCT_KERNEL(
cvm_grad, GPU, ALL_LAYOUT, ops::CVMGradCUDAKernel, float, double) {}
......@@ -54,7 +54,7 @@ void CvmGradComputeKernel(const bool use_cvm,
(*DY) += item_width - cvm_offset;
}
template <typename T>
template <typename T, typename DeviceContext>
class CVMOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -95,7 +95,7 @@ class CVMOpKernel : public framework::OpKernel<T> {
}
};
template <typename T>
template <typename T, typename DeviceContext>
class CVMGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......
......@@ -140,9 +140,12 @@ REGISTER_OPERATOR(
ops::CollectFpnProposalsOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(collect_fpn_proposals,
ops::CollectFpnProposalsOpKernel<float>,
ops::CollectFpnProposalsOpKernel<double>);
PD_REGISTER_STRUCT_KERNEL(collect_fpn_proposals,
CPU,
ALL_LAYOUT,
ops::CollectFpnProposalsOpKernel,
float,
double) {}
REGISTER_OP_VERSION(collect_fpn_proposals)
.AddCheckpoint(
R"ROC(
......
......@@ -51,7 +51,7 @@ static __global__ void GetLengthLoD(const int nthreads,
}
}
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -267,7 +267,10 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
collect_fpn_proposals,
ops::GPUCollectFpnProposalsOpKernel<phi::GPUContext, float>,
ops::GPUCollectFpnProposalsOpKernel<phi::GPUContext, double>);
PD_REGISTER_STRUCT_KERNEL(collect_fpn_proposals,
GPU,
ALL_LAYOUT,
ops::GPUCollectFpnProposalsOpKernel,
float,
double) {}
......@@ -57,7 +57,7 @@ static inline bool CompareByBatchid(ScoreWithID<T> a, ScoreWithID<T> b) {
return a.batch_id < b.batch_id;
}
template <typename T>
template <typename T, typename DeviceContext>
class CollectFpnProposalsOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......
......@@ -37,7 +37,7 @@ PD_DECLARE_KERNEL(fused_transpose, OneDNN, ONEDNN);
USE_OP_ITSELF(shape);
PD_DECLARE_KERNEL(shape, OneDNN, ONEDNN);
USE_OP_ITSELF(crop);
USE_OP_DEVICE_KERNEL(crop, CPU);
PD_DECLARE_KERNEL(crop, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(pool2d, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(relu, CPU, ALL_LAYOUT);
......