Unverified commit eb38c85f, authored by H huangjiyi, committed by GitHub

register fluid kernels to phi [part 5] (#52486)

* update

* fix bug

* update

* fix bug
Parent 5bac67d4
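The same migration pattern repeats across every file below: the fluid REGISTER_OP_CPU_KERNEL / REGISTER_OP_CUDA_KERNEL macros are replaced by phi's PD_REGISTER_STRUCT_KERNEL(op, backend, layout, kernel, dtypes...), and the kernel class templates swap their parameter order from <DeviceContext, T> to <T, DeviceContext>. A minimal sketch of the before/after shape, taken from the fill_zeros_like change in this diff (class bodies elided, dtype list shortened for illustration):

// Before: device-context-first template, one macro argument per kernel instantiation.
template <typename DeviceContext, typename T>
class FillZerosLikeKernel : public framework::OpKernel<T> { /* ... */ };

REGISTER_OP_CPU_KERNEL(fill_zeros_like,
                       ops::FillZerosLikeKernel<phi::CPUContext, int>,
                       ops::FillZerosLikeKernel<phi::CPUContext, float>);

// After: dtype-first template, one phi registration listing backend, layout,
// the kernel template, and all supported dtypes.
template <typename T, typename DeviceContext>
class FillZerosLikeKernel : public framework::OpKernel<T> { /* ... */ };

PD_REGISTER_STRUCT_KERNEL(
    fill_zeros_like, CPU, ALL_LAYOUT, ops::FillZerosLikeKernel, int, float) {}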
@@ -116,6 +116,7 @@ class FetchV2Op : public framework::OperatorWithKernel {
   }
 };
+template <typename T, typename DeviceContext>
 class FetchV2Kernel {
  public:
  void operator()(const framework::ExecutionContext &ctx) const {
@@ -228,28 +229,19 @@ REGISTER_OPERATOR(
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OP_CPU_KERNEL_FUNCTOR(fetch_v2,
-    float, ops::FetchV2Kernel, double, ops::FetchV2Kernel,
-    int8_t, ops::FetchV2Kernel, uint8_t, ops::FetchV2Kernel,
-    int, ops::FetchV2Kernel, int64_t, ops::FetchV2Kernel,
-    bool, ops::FetchV2Kernel, paddle::platform::bfloat16, ops::FetchV2Kernel,
-    paddle::platform::complex<float>, ops::FetchV2Kernel,
-    paddle::platform::complex<double>, ops::FetchV2Kernel,
-    plat::float16, ops::FetchV2Kernel, int16_t, ops::FetchV2Kernel);
+PD_REGISTER_STRUCT_KERNEL(fetch_v2, CPU, ALL_LAYOUT, ops::FetchV2Kernel,
+    float, double, int, int8_t, int16_t, int64_t, uint8_t, bool,
+    plat::float16, plat::bfloat16, plat::complex<float>,
+    plat::complex<double>) {}
@@ -206,6 +206,6 @@ REGISTER_OPERATOR(
     ops::FCOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OP_CPU_KERNEL(fc,
-                       ops::FCOpKernel<phi::CPUContext, float>,
-                       ops::FCOpKernel<phi::CPUContext, double>);
+PD_REGISTER_STRUCT_KERNEL(fc, CPU, ALL_LAYOUT, ops::FCOpKernel, float, double) {
+}
@@ -15,7 +15,6 @@ limitations under the License. */
 #include "paddle/fluid/operators/fc_op.h"
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(fc,
-                        ops::FCOpKernel<phi::GPUContext, phi::dtype::float16>,
-                        ops::FCOpKernel<phi::GPUContext, float>,
-                        ops::FCOpKernel<phi::GPUContext, double>);
+PD_REGISTER_STRUCT_KERNEL(
+    fc, GPU, ALL_LAYOUT, ops::FCOpKernel, float, double, phi::dtype::float16) {}
@@ -51,7 +51,7 @@ inline void FCOutputSize(const framework::DDim& in_dims,
   out_dims.push_back(w_dims1);
 }
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class FCOpKernel : public framework::OpKernel<T> {
  public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
......
@@ -80,6 +80,8 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(FillZerosLikeOp2NoNeedBufferVarsInferer,
 }  // namespace paddle
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like,
                              ops::FillZerosLikeOp,
                              ops::FillZerosLikeOpMaker);
@@ -92,24 +94,26 @@ REGISTER_OPERATOR(
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OP_CPU_KERNEL(fill_zeros_like,
-    ops::FillZerosLikeKernel<phi::CPUContext, int>,
-    ops::FillZerosLikeKernel<phi::CPUContext, int64_t>,
-    ops::FillZerosLikeKernel<phi::CPUContext, float>,
-    ops::FillZerosLikeKernel<phi::CPUContext, double>,
-    ops::FillZerosLikeKernel<phi::CPUContext, bool>,
-    ops::FillZerosLikeKernel<phi::CPUContext, paddle::platform::complex<float>>,
-    ops::FillZerosLikeKernel<phi::CPUContext, paddle::platform::complex<double>>);
-REGISTER_OP_CPU_KERNEL(fill_zeros_like2,
-    ops::FillZerosLikeKernel<phi::CPUContext, int>,
-    ops::FillZerosLikeKernel<phi::CPUContext, int64_t>,
-    ops::FillZerosLikeKernel<phi::CPUContext, float>,
-    ops::FillZerosLikeKernel<phi::CPUContext, double>,
-    ops::FillZerosLikeKernel<phi::CPUContext, bool>,
-    ops::FillZerosLikeKernel<phi::CPUContext, paddle::platform::complex<float>>,
-    ops::FillZerosLikeKernel<phi::CPUContext, paddle::platform::complex<double>>);
+PD_REGISTER_STRUCT_KERNEL(fill_zeros_like, CPU, ALL_LAYOUT,
+    ops::FillZerosLikeKernel, int, int64_t, float, double, bool,
+    plat::complex<float>, plat::complex<double>) {}
+PD_REGISTER_STRUCT_KERNEL(fill_zeros_like2, CPU, ALL_LAYOUT,
+    ops::FillZerosLikeKernel2, int, int64_t, float, double, bool,
+    plat::complex<float>, plat::complex<double>) {}
@@ -19,26 +19,30 @@ limitations under the License. */
 #include "paddle/fluid/platform/float16.h"
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(fill_zeros_like,
-    ops::FillZerosLikeKernel<phi::GPUContext, int>,
-    ops::FillZerosLikeKernel<phi::GPUContext, int64_t>,
-    ops::FillZerosLikeKernel<phi::GPUContext, float>,
-    ops::FillZerosLikeKernel<phi::GPUContext, double>,
-    ops::FillZerosLikeKernel<phi::GPUContext, paddle::platform::float16>,
-    ops::FillZerosLikeKernel<phi::GPUContext, bool>,
-    ops::FillZerosLikeKernel<phi::GPUContext, paddle::platform::complex<float>>,
-    ops::FillZerosLikeKernel<phi::GPUContext, paddle::platform::complex<double>>);
-REGISTER_OP_CUDA_KERNEL(fill_zeros_like2,
-    ops::FillZerosLikeKernel<phi::GPUContext, int>,
-    ops::FillZerosLikeKernel<phi::GPUContext, int64_t>,
-    ops::FillZerosLikeKernel<phi::GPUContext, float>,
-    ops::FillZerosLikeKernel<phi::GPUContext, double>,
-    ops::FillZerosLikeKernel<phi::GPUContext, paddle::platform::float16>,
-    ops::FillZerosLikeKernel<phi::GPUContext, bool>,
-    ops::FillZerosLikeKernel<phi::GPUContext, paddle::platform::complex<float>>,
-    ops::FillZerosLikeKernel<phi::GPUContext, paddle::platform::complex<double>>);
+namespace plat = paddle::platform;
+PD_REGISTER_STRUCT_KERNEL(fill_zeros_like, GPU, ALL_LAYOUT,
+    ops::FillZerosLikeKernel, int, int64_t, float, double, plat::float16,
+    bool, plat::complex<float>, plat::complex<double>) {}
+PD_REGISTER_STRUCT_KERNEL(fill_zeros_like2, GPU, ALL_LAYOUT,
+    ops::FillZerosLikeKernel2, int, int64_t, float, double, plat::float16,
+    bool, plat::complex<float>, plat::complex<double>) {}
@@ -19,7 +19,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class FillZerosLikeKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& context) const override {
@@ -33,5 +33,8 @@ class FillZerosLikeKernel : public framework::OpKernel<T> {
   }
 };
+template <typename T, typename DeviceContext>
+class FillZerosLikeKernel2 : public FillZerosLikeKernel<T, DeviceContext> {};
 }  // namespace operators
 }  // namespace paddle
@@ -162,14 +162,20 @@ REGISTER_OPERATOR(filter_by_instag,
 REGISTER_OPERATOR(filter_by_instag_grad, ops::FilterByInstagOpGrad);
-REGISTER_OP_CPU_KERNEL(filter_by_instag,
-    ops::FilterByInstagKernel<float>, ops::FilterByInstagKernel<double>,
-    ops::FilterByInstagKernel<int32_t>, ops::FilterByInstagKernel<int64_t>);
-REGISTER_OP_CPU_KERNEL(filter_by_instag_grad,
-    ops::FilterByInstagGradKernel<float>, ops::FilterByInstagGradKernel<double>,
-    ops::FilterByInstagGradKernel<int32_t>, ops::FilterByInstagGradKernel<int64_t>);
+PD_REGISTER_STRUCT_KERNEL(filter_by_instag, CPU, ALL_LAYOUT,
+    ops::FilterByInstagKernel, float, double, int32_t, int64_t) {}
+PD_REGISTER_STRUCT_KERNEL(filter_by_instag_grad, CPU, ALL_LAYOUT,
+    ops::FilterByInstagGradKernel, float, double, int32_t, int64_t) {}
@@ -325,7 +325,7 @@ __global__ void copy_grad_kernel(const size_t N,
 #endif
-template <typename T>
+template <typename T, typename DeviceContext>
 class FilterByInstagGPUKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& context) const override {
@@ -553,7 +553,7 @@ class FilterByInstagGPUKernel : public framework::OpKernel<T> {
   }
 };
-template <typename T>
+template <typename T, typename DeviceContext>
 class FilterByInstagGradGPUKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& context) const override {
@@ -620,14 +620,20 @@ class FilterByInstagGradGPUKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(filter_by_instag,
-    ops::FilterByInstagGPUKernel<float>, ops::FilterByInstagGPUKernel<double>,
-    ops::FilterByInstagGPUKernel<int32_t>, ops::FilterByInstagGPUKernel<int64_t>);
-REGISTER_OP_CUDA_KERNEL(filter_by_instag_grad,
-    ops::FilterByInstagGradGPUKernel<float>, ops::FilterByInstagGradGPUKernel<double>,
-    ops::FilterByInstagGradGPUKernel<int32_t>, ops::FilterByInstagGradGPUKernel<int64_t>);
+PD_REGISTER_STRUCT_KERNEL(filter_by_instag, GPU, ALL_LAYOUT,
+    ops::FilterByInstagGPUKernel, float, double, int32_t, int64_t) {}
+PD_REGISTER_STRUCT_KERNEL(filter_by_instag_grad, GPU, ALL_LAYOUT,
+    ops::FilterByInstagGradGPUKernel, float, double, int32_t, int64_t) {}
@@ -34,7 +34,7 @@ using SelectedRows = phi::SelectedRows;
 template <typename T>
 using Vector = phi::Vector<T>;
-template <typename T>
+template <typename T, typename DeviceContext>
 class FilterByInstagKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& context) const override {
@@ -191,7 +191,7 @@ class FilterByInstagKernel : public framework::OpKernel<T> {
   }
 };
-template <typename T>
+template <typename T, typename DeviceContext>
 class FilterByInstagGradKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& context) const override {
......
@@ -164,9 +164,8 @@ REGISTER_OPERATOR(fsp,
                  ops::FSPGradOpMaker<paddle::framework::OpDesc>,
                  ops::FSPGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(fsp_grad, ops::FSPOpGrad);
-REGISTER_OP_CPU_KERNEL(fsp,
-                       ops::FSPOpKernel<phi::CPUContext, float>,
-                       ops::FSPOpKernel<phi::CPUContext, double>);
-REGISTER_OP_CPU_KERNEL(fsp_grad,
-                       ops::FSPGradOpKernel<phi::CPUContext, float>,
-                       ops::FSPGradOpKernel<phi::CPUContext, double>);
+PD_REGISTER_STRUCT_KERNEL(
+    fsp, CPU, ALL_LAYOUT, ops::FSPOpKernel, float, double) {}
+PD_REGISTER_STRUCT_KERNEL(
+    fsp_grad, CPU, ALL_LAYOUT, ops::FSPGradOpKernel, float, double) {}
@@ -16,10 +16,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(fsp,
-                        ops::FSPOpKernel<phi::GPUContext, float>,
-                        ops::FSPOpKernel<phi::GPUContext, double>);
-REGISTER_OP_CUDA_KERNEL(fsp_grad,
-                        ops::FSPGradOpKernel<phi::GPUContext, float>,
-                        ops::FSPGradOpKernel<phi::GPUContext, double>);
+PD_REGISTER_STRUCT_KERNEL(
+    fsp, GPU, ALL_LAYOUT, ops::FSPOpKernel, float, double) {}
+PD_REGISTER_STRUCT_KERNEL(
+    fsp_grad, GPU, ALL_LAYOUT, ops::FSPGradOpKernel, float, double) {}
@@ -20,7 +20,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class FSPOpKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& context) const override {
@@ -64,7 +64,7 @@ class FSPOpKernel : public framework::OpKernel<T> {
   }
 };
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class FSPGradOpKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& context) const override {
......
@@ -33,9 +33,11 @@ namespace platform = paddle::platform;
 namespace op = paddle::operators;
 USE_OP_ITSELF(batch_norm);
+USE_OP_ITSELF(fused_bn_add_activation);
+USE_OP_ITSELF(fused_bn_add_activation_grad);
 PD_DECLARE_KERNEL(batch_norm, GPU, ALL_LAYOUT);
-USE_CUDA_ONLY_OP(fused_bn_add_activation);
-USE_CUDA_ONLY_OP(fused_bn_add_activation_grad);
+PD_DECLARE_KERNEL(fused_bn_add_activation, GPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(fused_bn_add_activation_grad, GPU, ALL_LAYOUT);
 template <typename T>
 void InitRandomTensor(const std::vector<int64_t> &dims,
......
@@ -75,7 +75,7 @@ static void AllReduce(phi::DenseTensor &tensor,  // NOLINT
 #endif
 }
-template <typename T>
+template <typename T, typename DeviceContext>
 class FusedAttentionOpKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext &ctx) const override {
@@ -402,7 +402,7 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
   }
 };
-template <typename T>
+template <typename T, typename DeviceContext>
 class FusedAttentionGradKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext &ctx) const override {
@@ -826,11 +826,18 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(fused_attention,
-    ops::FusedAttentionOpKernel<float>, ops::FusedAttentionOpKernel<double>,
-    ops::FusedAttentionOpKernel<plat::float16>);
-REGISTER_OP_CUDA_KERNEL(fused_attention_grad,
-    ops::FusedAttentionGradKernel<float>, ops::FusedAttentionGradKernel<double>,
-    ops::FusedAttentionGradKernel<plat::float16>);
+PD_REGISTER_STRUCT_KERNEL(fused_attention, GPU, ALL_LAYOUT,
+    ops::FusedAttentionOpKernel, float, double, plat::float16) {}
+PD_REGISTER_STRUCT_KERNEL(fused_attention_grad, GPU, ALL_LAYOUT,
+    ops::FusedAttentionGradKernel, float, double, plat::float16) {}
@@ -25,7 +25,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-template <typename T>
+template <typename T, typename DeviceContext>
 class FusedBiasDropoutResidualLnOpKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext &ctx) const override {
@@ -91,7 +91,7 @@ class FusedBiasDropoutResidualLnOpKernel : public framework::OpKernel<T> {
   }
 };
-template <typename T>
+template <typename T, typename DeviceContext>
 class FusedBiasDropoutResidualLnGradKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext &ctx) const override {
@@ -176,12 +176,18 @@ class FusedBiasDropoutResidualLnGradKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(fused_bias_dropout_residual_layer_norm,
-    ops::FusedBiasDropoutResidualLnOpKernel<float>,
-    ops::FusedBiasDropoutResidualLnOpKernel<double>,
-    ops::FusedBiasDropoutResidualLnOpKernel<plat::float16>);
-REGISTER_OP_CUDA_KERNEL(fused_bias_dropout_residual_layer_norm_grad,
-    ops::FusedBiasDropoutResidualLnGradKernel<float>,
-    ops::FusedBiasDropoutResidualLnGradKernel<double>,
-    ops::FusedBiasDropoutResidualLnGradKernel<plat::float16>);
+PD_REGISTER_STRUCT_KERNEL(fused_bias_dropout_residual_layer_norm, GPU,
+    ALL_LAYOUT, ops::FusedBiasDropoutResidualLnOpKernel, float, double,
+    plat::float16) {}
+PD_REGISTER_STRUCT_KERNEL(fused_bias_dropout_residual_layer_norm_grad, GPU,
+    ALL_LAYOUT, ops::FusedBiasDropoutResidualLnGradKernel, float, double,
+    plat::float16) {}
@@ -36,10 +36,15 @@ template <typename T>
 using BatchNormParamType = typename CudnnDataType<T>::BatchNormParamType;
 template <typename T>
-class FusedBatchNormActKernel<phi::GPUContext, T>
+class FusedBatchNormActKernel<T, phi::GPUContext>
     : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext &ctx) const override {
+#if CUDNN_VERSION < 7401
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "The fused_batch_norm_act operator is not supported on GPU "
+        "when CUDNN version < 7.4.1"));
+#endif
     PADDLE_ENFORCE_EQ(
         platform::is_gpu_place(ctx.GetPlace()),
         true,
@@ -231,10 +236,15 @@ class FusedBatchNormActKernel<phi::GPUContext, T>
 };
 template <typename T>
-class FusedBatchNormActGradKernel<phi::GPUContext, T>
+class FusedBatchNormActGradKernel<T, phi::GPUContext>
     : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext &ctx) const override {
+#if CUDNN_VERSION < 7401
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "The fused_batch_norm_act operator is not supported on GPU "
+        "when CUDNN version < 7.4.1"));
+#endif
     PADDLE_ENFORCE_EQ(
         platform::is_gpu_place(ctx.GetPlace()),
         true,
@@ -415,17 +425,19 @@ class FusedBatchNormActGradKernel<phi::GPUContext, T>
 }  // namespace operators
 }  // namespace paddle
-#if CUDNN_VERSION >= 7401
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(fused_batch_norm_act,
-    ops::FusedBatchNormActKernel<phi::GPUContext, float>,
-    ops::FusedBatchNormActKernel<phi::GPUContext, double>,
-    ops::FusedBatchNormActKernel<phi::GPUContext, plat::float16>);
-REGISTER_OP_CUDA_KERNEL(fused_batch_norm_act_grad,
-    ops::FusedBatchNormActGradKernel<phi::GPUContext, float>,
-    ops::FusedBatchNormActGradKernel<phi::GPUContext, double>,
-    ops::FusedBatchNormActGradKernel<phi::GPUContext, plat::float16>);
-#endif
+PD_REGISTER_STRUCT_KERNEL(fused_batch_norm_act, GPU, ALL_LAYOUT,
+    ops::FusedBatchNormActKernel, float, double, plat::float16) {}
+PD_REGISTER_STRUCT_KERNEL(fused_batch_norm_act_grad, GPU, ALL_LAYOUT,
+    ops::FusedBatchNormActGradKernel, float, double, plat::float16) {}
@@ -88,13 +88,13 @@ class FusedBatchNormActOpInferVarType
   }
 };
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class FusedBatchNormActKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& ctx) const override;
 };
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class FusedBatchNormActGradKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& ctx) const override;
......
@@ -36,10 +36,15 @@ template <typename T>
 using BatchNormParamType = typename CudnnDataType<T>::BatchNormParamType;
 template <typename T>
-class FusedBatchNormAddActKernel<phi::GPUContext, T>
+class FusedBatchNormAddActKernel<T, phi::GPUContext>
     : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext &ctx) const override {
+#if CUDNN_VERSION < 7401
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "The fused_bn_add_activation operator is not supported on GPU "
+        "when CUDNN version < 7.4.1"));
+#endif
     PADDLE_ENFORCE_EQ(
         platform::is_gpu_place(ctx.GetPlace()),
         true,
@@ -208,10 +213,15 @@ class FusedBatchNormAddActKernel<phi::GPUContext, T>
 };
 template <typename T>
-class FusedBatchNormAddActGradKernel<phi::GPUContext, T>
+class FusedBatchNormAddActGradKernel<T, phi::GPUContext>
     : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext &ctx) const override {
+#if CUDNN_VERSION < 7401
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "The fused_bn_add_activation operator is not supported on GPU "
+        "when CUDNN version < 7.4.1"));
+#endif
     PADDLE_ENFORCE_EQ(
         platform::is_gpu_place(ctx.GetPlace()),
         true,
@@ -362,13 +372,15 @@ class FusedBatchNormAddActGradKernel<phi::GPUContext, T>
 }  // namespace operators
 }  // namespace paddle
-#if CUDNN_VERSION >= 7401
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(fused_bn_add_activation,
-    ops::FusedBatchNormAddActKernel<phi::GPUContext, plat::float16>);
-REGISTER_OP_CUDA_KERNEL(fused_bn_add_activation_grad,
-    ops::FusedBatchNormAddActGradKernel<phi::GPUContext, plat::float16>);
-#endif
+PD_REGISTER_STRUCT_KERNEL(fused_bn_add_activation, GPU, ALL_LAYOUT,
+    ops::FusedBatchNormAddActKernel, plat::float16) {}
+PD_REGISTER_STRUCT_KERNEL(fused_bn_add_activation_grad, GPU, ALL_LAYOUT,
+    ops::FusedBatchNormAddActGradKernel, plat::float16) {}
@@ -89,13 +89,13 @@ class FusedBatchNormAddActOpInferVarType
   }
 };
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class FusedBatchNormAddActKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& ctx) const override;
 };
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class FusedBatchNormAddActGradKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& ctx) const override;
......
@@ -461,15 +461,19 @@ REGISTER_OPERATOR(
 REGISTER_OPERATOR(fused_elemwise_activation_grad,
                  ops::FusedElemwiseActivationOpGrad);
-REGISTER_OP_CPU_KERNEL(fused_elemwise_activation,
-    ops::FusedElemwiseActivationKernel<phi::CPUContext, float>,
-    ops::FusedElemwiseActivationKernel<phi::CPUContext, double>);
-REGISTER_OP_CPU_KERNEL(fused_elemwise_activation_grad,
-    ops::FusedElemwiseActivationGradKernel<phi::CPUContext, float>,
-    ops::FusedElemwiseActivationGradKernel<phi::CPUContext, double>);
+PD_REGISTER_STRUCT_KERNEL(fused_elemwise_activation, CPU, ALL_LAYOUT,
+    ops::FusedElemwiseActivationKernel, float, double) {}
+PD_REGISTER_STRUCT_KERNEL(fused_elemwise_activation_grad, CPU, ALL_LAYOUT,
+    ops::FusedElemwiseActivationGradKernel, float, double) {}
 // for memory optimization, we register the fused_elemwise_add_activation OP
 REGISTER_OPERATOR(
@@ -482,12 +486,16 @@ REGISTER_OPERATOR(fused_elemwise_add_activation_grad,
                  ops::FusedElemwiseAddActivationNoNeddBufVarInferer,
                  ops::FusedElemwiseAddActivationOpGrad);
-REGISTER_OP_CPU_KERNEL(fused_elemwise_add_activation,
-    ops::FusedElemwiseActivationKernel<phi::CPUContext, float>,
-    ops::FusedElemwiseActivationKernel<phi::CPUContext, double>);
-REGISTER_OP_CPU_KERNEL(fused_elemwise_add_activation_grad,
-    ops::FusedElemwiseActivationGradKernel<phi::CPUContext, float>,
-    ops::FusedElemwiseActivationGradKernel<phi::CPUContext, double>);
+PD_REGISTER_STRUCT_KERNEL(fused_elemwise_add_activation, CPU, ALL_LAYOUT,
+    ops::FusedElemwiseAddActivationKernel, float, double) {}
+PD_REGISTER_STRUCT_KERNEL(fused_elemwise_add_activation_grad, CPU, ALL_LAYOUT,
+    ops::FusedElemwiseAddActivationGradKernel, float, double) {}
@@ -15,30 +15,34 @@ limitations under the License. */
 #include "paddle/fluid/operators/fused/fused_elemwise_activation_op.h"
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(fused_elemwise_activation,
-    ops::FusedElemwiseActivationKernel<phi::GPUContext, float>,
-    ops::FusedElemwiseActivationKernel<phi::GPUContext, double>,
-    ops::FusedElemwiseActivationKernel<phi::GPUContext, paddle::platform::float16>);
-REGISTER_OP_CUDA_KERNEL(fused_elemwise_activation_grad,
-    ops::FusedElemwiseActivationGradKernel<phi::GPUContext, float>,
-    ops::FusedElemwiseActivationGradKernel<phi::GPUContext, double>,
-    ops::FusedElemwiseActivationGradKernel<phi::GPUContext, paddle::platform::float16>);
-REGISTER_OP_CUDA_KERNEL(fused_elemwise_add_activation,
-    ops::FusedElemwiseActivationKernel<phi::GPUContext, float>,
-    ops::FusedElemwiseActivationKernel<phi::GPUContext, double>,
-    ops::FusedElemwiseActivationKernel<phi::GPUContext, paddle::platform::float16>);
-REGISTER_OP_CUDA_KERNEL(fused_elemwise_add_activation_grad,
-    ops::FusedElemwiseActivationGradKernel<phi::GPUContext, float>,
-    ops::FusedElemwiseActivationGradKernel<phi::GPUContext, double>,
-    ops::FusedElemwiseActivationGradKernel<phi::GPUContext, paddle::platform::float16>);
+namespace plat = paddle::platform;
+PD_REGISTER_STRUCT_KERNEL(fused_elemwise_activation, GPU, ALL_LAYOUT,
+    ops::FusedElemwiseActivationKernel, float, double, plat::float16) {}
+PD_REGISTER_STRUCT_KERNEL(fused_elemwise_activation_grad, GPU, ALL_LAYOUT,
+    ops::FusedElemwiseActivationGradKernel, float, double, plat::float16) {}
+PD_REGISTER_STRUCT_KERNEL(fused_elemwise_add_activation, GPU, ALL_LAYOUT,
+    ops::FusedElemwiseAddActivationKernel, float, double, plat::float16) {}
+PD_REGISTER_STRUCT_KERNEL(fused_elemwise_add_activation_grad, GPU, ALL_LAYOUT,
+    ops::FusedElemwiseAddActivationGradKernel, float, double, plat::float16) {}
@@ -616,7 +616,7 @@ static void RunGradFunctors(const framework::ExecutionContext &ctx,
   }
 }
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class FusedElemwiseActivationKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext &ctx) const override {
@@ -655,7 +655,7 @@ class FusedElemwiseActivationKernel : public framework::OpKernel<T> {
   }
 };
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext &ctx) const override {
@@ -765,5 +765,14 @@ class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> {
     }
   }
 };
+template <typename T, typename DeviceContext>
+class FusedElemwiseAddActivationKernel
+    : public FusedElemwiseActivationKernel<T, DeviceContext> {};
+template <typename T, typename DeviceContext>
+class FusedElemwiseAddActivationGradKernel
+    : public FusedElemwiseActivationGradKernel<T, DeviceContext> {};
 }  // namespace operators
 }  // namespace paddle
@@ -29,7 +29,7 @@
 namespace paddle {
 namespace operators {
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class EmbeddingEltWiseLayerNormKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext &context) const override {
@@ -145,14 +145,18 @@ class EmbeddingEltWiseLayerNormKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 #if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000
-REGISTER_OP_CUDA_KERNEL(fused_embedding_eltwise_layernorm,
-    ops::EmbeddingEltWiseLayerNormKernel<phi::GPUContext, float>,
-    ops::EmbeddingEltWiseLayerNormKernel<phi::GPUContext, paddle::platform::float16>);
+PD_REGISTER_STRUCT_KERNEL(fused_embedding_eltwise_layernorm, GPU, ALL_LAYOUT,
+    ops::EmbeddingEltWiseLayerNormKernel, float, plat::float16) {}
 #else
-REGISTER_OP_CUDA_KERNEL(fused_embedding_eltwise_layernorm,
-    ops::EmbeddingEltWiseLayerNormKernel<phi::GPUContext, float>);
+PD_REGISTER_STRUCT_KERNEL(fused_embedding_eltwise_layernorm, GPU, ALL_LAYOUT,
+    ops::EmbeddingEltWiseLayerNormKernel, float) {}
 #endif
@@ -270,7 +270,7 @@ This operator fuse the X into LSTM, more details can refer to LSTM op.
 )DOC");
 }
-template <typename T>
+template <typename T, typename DeviceContext>
 class FusedEmbeddingFCLSTMKernel : public framework::OpKernel<T> {
  public:
 #define INIT_VEC_FUNC \
@@ -396,7 +396,6 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel<T> {
   GET_Ht(ct, gates, ht)
   void SeqCompute(const framework::ExecutionContext& ctx) const {
-    using DeviceContext = phi::CPUContext;
     INIT_BASE_INPUT_OUTPUT
     INIT_BASE_SIZES
     INIT_VEC_FUNC
@@ -502,7 +501,6 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel<T> {
   }
   void BatchCompute(const framework::ExecutionContext& ctx) const {
-    using DeviceContext = phi::CPUContext;
     INIT_BASE_INPUT_OUTPUT
     if (ids->lod()[0].size() == 2) {
       SeqCompute(ctx);
@@ -682,6 +680,9 @@ REGISTER_OPERATOR(fused_embedding_fc_lstm,
                  ops::FusedEmbeddingFCLSTMOp,
                  ops::FusedEmbeddingFCLSTMOpMaker);
-REGISTER_OP_CPU_KERNEL(fused_embedding_fc_lstm,
-    ops::FusedEmbeddingFCLSTMKernel<float>,
-    ops::FusedEmbeddingFCLSTMKernel<double>);
+PD_REGISTER_STRUCT_KERNEL(fused_embedding_fc_lstm, CPU, ALL_LAYOUT,
+    ops::FusedEmbeddingFCLSTMKernel, float, double) {}
@@ -201,9 +201,15 @@ REGISTER_OPERATOR(fused_embedding_seq_pool_grad,
                  ops::FusedEmbeddingSeqPoolOpGrad,
                  ops::FusedEmbeddingSeqPoolOpGradVarTypeInference);
-REGISTER_OP_CPU_KERNEL(fused_embedding_seq_pool,
-    ops::FusedEmbeddingSeqPoolKernel<float>,
-    ops::FusedEmbeddingSeqPoolKernel<double>);
-REGISTER_OP_CPU_KERNEL(fused_embedding_seq_pool_grad,
-    ops::FusedEmbeddingSeqPoolGradKernel<float>,
-    ops::FusedEmbeddingSeqPoolGradKernel<double>);
+PD_REGISTER_STRUCT_KERNEL(fused_embedding_seq_pool, CPU, ALL_LAYOUT,
+    ops::FusedEmbeddingSeqPoolKernel, float, double) {}
+PD_REGISTER_STRUCT_KERNEL(fused_embedding_seq_pool_grad, CPU, ALL_LAYOUT,
+    ops::FusedEmbeddingSeqPoolGradKernel, float, double) {}
@@ -135,7 +135,7 @@ inline int FusedEmbeddingSeqPoolLastDim(const framework::DDim &table_dims,
   return last_dim;
 }
-template <typename T>
+template <typename T, typename DeviceContext>
 class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext &context) const override {
@@ -224,7 +224,7 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
   }
 };
-template <typename T>
+template <typename T, typename DeviceContext>
 class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext &context) const override {
......
@@ -374,7 +374,7 @@ void AddReluAddLayerNorm(gpuStream_t stream,
   }
 }
-template <typename T>
+template <typename T, typename DeviceContext>
 class FusedFCElementwiseLayerNormOpKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& ctx) const override {
@@ -449,8 +449,12 @@ class FusedFCElementwiseLayerNormOpKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(fused_fc_elementwise_layernorm,
-    ops::FusedFCElementwiseLayerNormOpKernel<phi::dtype::float16>,
-    ops::FusedFCElementwiseLayerNormOpKernel<float>,
-    ops::FusedFCElementwiseLayerNormOpKernel<double>);
+namespace plat = paddle::platform;
+PD_REGISTER_STRUCT_KERNEL(fused_fc_elementwise_layernorm, GPU, ALL_LAYOUT,
+    ops::FusedFCElementwiseLayerNormOpKernel, float, double, plat::float16) {}
@@ -65,7 +65,7 @@ static void AllReduce(phi::DenseTensor& tensor,  // NOLINT
 #endif
 }
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class FusedFeedForwardKernel : public framework::OpKernel<T> {
  public:
  void MatMul(const phi::GPUContext& ctx,
@@ -301,7 +301,7 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
   }
 };
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
  public:
  void MatMulGrad(const phi::GPUContext& ctx,
@@ -628,14 +628,19 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(fused_feedforward,
-    ops::FusedFeedForwardKernel<phi::GPUContext, float>,
-    ops::FusedFeedForwardKernel<phi::GPUContext, double>,
-    ops::FusedFeedForwardKernel<phi::GPUContext, paddle::platform::float16>);
-REGISTER_OP_CUDA_KERNEL(fused_feedforward_grad,
-    ops::FusedFeedForwardGradKernel<phi::GPUContext, float>,
-    ops::FusedFeedForwardGradKernel<phi::GPUContext, double>,
-    ops::FusedFeedForwardGradKernel<phi::GPUContext, paddle::platform::float16>);
+namespace plat = paddle::platform;
+PD_REGISTER_STRUCT_KERNEL(fused_feedforward, GPU, ALL_LAYOUT,
+    ops::FusedFeedForwardKernel, float, double, plat::float16) {}
+PD_REGISTER_STRUCT_KERNEL(fused_feedforward_grad, GPU, ALL_LAYOUT,
+    ops::FusedFeedForwardGradKernel, float, double, plat::float16) {}
@@ -354,7 +354,7 @@ void ComputeOutputLinearBackward(const framework::ExecutionContext &ctx,
                                  use_fused_matmul_bias);
 }
-template <typename T>
+template <typename T, typename DeviceContext>
 class FusedGateAttentionOpKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext &ctx) const override {
@@ -446,7 +446,7 @@ class FusedGateAttentionOpKernel : public framework::OpKernel<T> {
   }
 };
-template <typename T>
+template <typename T, typename DeviceContext>
 class FusedGateAttentionGradKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext &ctx) const override {
@@ -565,23 +565,35 @@ class FusedGateAttentionGradKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 #ifdef PADDLE_WITH_HIP
-REGISTER_OP_CUDA_KERNEL(fused_gate_attention,
-    ops::FusedGateAttentionOpKernel<float>,
-    ops::FusedGateAttentionOpKernel<plat::float16>,
-    ops::FusedGateAttentionOpKernel<plat::bfloat16>);
-REGISTER_OP_CUDA_KERNEL(fused_gate_attention_grad,
-    ops::FusedGateAttentionGradKernel<float>,
-    ops::FusedGateAttentionGradKernel<plat::float16>,
-    ops::FusedGateAttentionGradKernel<plat::bfloat16>);
+PD_REGISTER_STRUCT_KERNEL(fused_gate_attention, GPU, ALL_LAYOUT,
+    ops::FusedGateAttentionOpKernel, float, plat::float16, plat::bfloat16) {}
+PD_REGISTER_STRUCT_KERNEL(fused_gate_attention_grad, GPU, ALL_LAYOUT,
+    ops::FusedGateAttentionGradKernel, float, plat::float16, plat::bfloat16) {}
 #else
-REGISTER_OP_CUDA_KERNEL(fused_gate_attention,
-    ops::FusedGateAttentionOpKernel<float>,
-    ops::FusedGateAttentionOpKernel<double>,
-    ops::FusedGateAttentionOpKernel<plat::float16>,
-    ops::FusedGateAttentionOpKernel<plat::bfloat16>);
-REGISTER_OP_CUDA_KERNEL(fused_gate_attention_grad,
-    ops::FusedGateAttentionGradKernel<float>,
-    ops::FusedGateAttentionGradKernel<double>,
-    ops::FusedGateAttentionGradKernel<plat::float16>,
-    ops::FusedGateAttentionGradKernel<plat::bfloat16>);
+PD_REGISTER_STRUCT_KERNEL(fused_gate_attention, GPU, ALL_LAYOUT,
+    ops::FusedGateAttentionOpKernel, float, double, plat::float16,
+    plat::bfloat16) {}
+PD_REGISTER_STRUCT_KERNEL(fused_gate_attention_grad, GPU, ALL_LAYOUT,
+    ops::FusedGateAttentionGradKernel, float, double, plat::float16,
+    plat::bfloat16) {}
 #endif
@@ -61,10 +61,15 @@ phi::funcs::MatmulFusedType GetFwdFusedEpilogueType(
   return fused_type;
 }
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& ctx) const override {
+#if CUDA_VERSION < 11060
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "The fused_gemm_epilogue operator only support CUDA 11.6 "
+        "or higher version."));
+#endif
     auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
     const phi::DenseTensor* x = ctx.Input<phi::DenseTensor>("X");
@@ -119,10 +124,15 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
   }
 };
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& ctx) const override {
+#if CUDA_VERSION < 11060
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "The fused_gemm_epilogue operator only support CUDA 11.6 "
+        "or higher version."));
+#endif
     auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
     const phi::DenseTensor* dout = ctx.Input<phi::DenseTensor>("DOut");
@@ -172,21 +182,21 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
 }  // namespace operators
 }  // namespace paddle
-#if CUDA_VERSION >= 11060
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(fused_gemm_epilogue,
-    ops::FusedGemmEpilogueKernel<phi::GPUContext, float>,
-    ops::FusedGemmEpilogueKernel<phi::GPUContext, double>,
-    ops::FusedGemmEpilogueKernel<phi::GPUContext, paddle::platform::float16>,
-    ops::FusedGemmEpilogueKernel<phi::GPUContext, paddle::platform::bfloat16>);
-REGISTER_OP_CUDA_KERNEL(fused_gemm_epilogue_grad,
-    ops::FusedGemmEpilogueGradKernel<phi::GPUContext, float>,
-    ops::FusedGemmEpilogueGradKernel<phi::GPUContext, double>,
-    ops::FusedGemmEpilogueGradKernel<phi::GPUContext, paddle::platform::float16>,
-    ops::FusedGemmEpilogueGradKernel<phi::GPUContext, paddle::platform::bfloat16>);
-#endif
+namespace plat = paddle::platform;
+PD_REGISTER_STRUCT_KERNEL(fused_gemm_epilogue, GPU, ALL_LAYOUT,
+    ops::FusedGemmEpilogueKernel, float, double, plat::float16,
+    plat::bfloat16) {}
+PD_REGISTER_STRUCT_KERNEL(fused_gemm_epilogue_grad, GPU, ALL_LAYOUT,
+    ops::FusedGemmEpilogueGradKernel, float, double, plat::float16,
+    plat::bfloat16) {}
@@ -18,7 +18,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-template <typename T>
+template <typename T, typename DeviceContext>
 class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext &ctx) const override {
@@ -662,6 +662,9 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(fused_multi_transformer_int8,
-    ops::FusedMultiTransformerINT8OpKernel<plat::float16>,
-    ops::FusedMultiTransformerINT8OpKernel<float>);
+PD_REGISTER_STRUCT_KERNEL(fused_multi_transformer_int8, GPU, ALL_LAYOUT,
+    ops::FusedMultiTransformerINT8OpKernel, float, plat::float16) {}
@@ -19,7 +19,7 @@ namespace operators {
 #if CUDA_VERSION >= 11060  // Use cublasLt to fuse FFN operation.
-template <typename T>
+template <typename T, typename DeviceContext>
 class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext &ctx) const override {
@@ -685,7 +685,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
 #else
-template <typename T>
+template <typename T, typename DeviceContext>
 class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext &ctx) const override {
@@ -1370,6 +1370,9 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(fused_multi_transformer,
-    ops::FusedMultiTransformerOpKernel<plat::float16>,
-    ops::FusedMultiTransformerOpKernel<float>);
+PD_REGISTER_STRUCT_KERNEL(fused_multi_transformer, GPU, ALL_LAYOUT,
+    ops::FusedMultiTransformerOpKernel, float, plat::float16) {}
@@ -290,7 +290,13 @@ REGISTER_OPERATOR(fused_seqpool_cvm,
                  ops::FusedSeqpoolCVMGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(fused_seqpool_cvm_grad, ops::FusedSeqpoolCVMGradOp)
-REGISTER_OP_CPU_KERNEL(fused_seqpool_cvm,
-                       ops::FusedSeqpoolCVMOpCPUKernel<float>)
-REGISTER_OP_CPU_KERNEL(fused_seqpool_cvm_grad,
-                       ops::FusedSeqpoolCVMGradOpCPUKernel<float>)
+PD_REGISTER_STRUCT_KERNEL(fused_seqpool_cvm, CPU, ALL_LAYOUT,
+    ops::FusedSeqpoolCVMOpCPUKernel, float) {}
+PD_REGISTER_STRUCT_KERNEL(fused_seqpool_cvm_grad, CPU, ALL_LAYOUT,
+    ops::FusedSeqpoolCVMGradOpCPUKernel, float) {}
@@ -420,7 +420,7 @@ void FusedSeqpoolCVMGrad(const framework::ExecutionContext &ctx,
   }
 }
-template <typename T>
+template <typename T, typename DeviceContext>
 class FusedSeqpoolCVMCUDAKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext &ctx) const override {
@@ -505,7 +505,7 @@ class FusedSeqpoolCVMCUDAKernel : public framework::OpKernel<T> {
   }
 };
-template <typename T>
+template <typename T, typename DeviceContext>
 class FusedSeqpoolCVMGradCUDAKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext &ctx) const override {
@@ -588,8 +588,11 @@ class FusedSeqpoolCVMGradCUDAKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(fused_seqpool_cvm,
-                        ops::FusedSeqpoolCVMCUDAKernel<float>);
-REGISTER_OP_CUDA_KERNEL(fused_seqpool_cvm_grad,
-                        ops::FusedSeqpoolCVMGradCUDAKernel<float>);
+PD_REGISTER_STRUCT_KERNEL(
+    fused_seqpool_cvm, GPU, ALL_LAYOUT, ops::FusedSeqpoolCVMCUDAKernel, float) {
+}
+PD_REGISTER_STRUCT_KERNEL(fused_seqpool_cvm_grad, GPU, ALL_LAYOUT,
+    ops::FusedSeqpoolCVMGradCUDAKernel, float) {}
@@ -23,7 +23,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-template <typename T>
+template <typename T, typename DeviceContext>
 class FusedSeqpoolCVMOpCPUKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& ctx) const override {
@@ -33,7 +33,7 @@ class FusedSeqpoolCVMOpCPUKernel : public framework::OpKernel<T> {
   }
 };
-template <typename T>
+template <typename T, typename DeviceContext>
 class FusedSeqpoolCVMGradOpCPUKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& ctx) const override {
......
@@ -34,10 +34,15 @@ using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
 template <typename T>
 using CudnnDataType = platform::CudnnDataType<T>;
-template <typename T>
+template <typename T, typename DeviceContext>
 class CUDNNConvInceptionFusionOpKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& ctx) const override {
+#if CUDNN_VERSION < 7100
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "The conv2d_inception_fusion operator is not supported on GPU "
+        "when CUDNN version < 7.1.0"));
+#endif
     auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
     auto* input = ctx.Input<phi::DenseTensor>("Input");
     auto filters = ctx.MultiInput<phi::DenseTensor>("Filter");
@@ -336,9 +341,10 @@ class CUDNNConvInceptionFusionOpKernel : public framework::OpKernel<T> {
 }  // namespace operators
 }  // namespace paddle
-#if CUDNN_VERSION >= 7100
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(conv2d_inception_fusion,
-                        ops::CUDNNConvInceptionFusionOpKernel<float>,
-                        ops::CUDNNConvInceptionFusionOpKernel<double>);
-#endif
+PD_REGISTER_STRUCT_KERNEL(conv2d_inception_fusion, GPU, ALL_LAYOUT,
+    ops::CUDNNConvInceptionFusionOpKernel, float, double) {}
@@ -18,7 +18,10 @@ limitations under the License. */
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(fusion_group,
-                        ops::FusionGroupKernel<phi::GPUContext, float>,
-                        ops::FusionGroupKernel<phi::GPUContext, double>,
-                        ops::FusionGroupKernel<phi::GPUContext, plat::float16>);
+PD_REGISTER_STRUCT_KERNEL(fusion_group, GPU, ALL_LAYOUT,
+    ops::FusionGroupKernel, float, double, plat::float16) {}
...@@ -42,7 +42,7 @@ static void MutableMultiTypeData(std::vector<phi::DenseTensor*>* var, ...@@ -42,7 +42,7 @@ static void MutableMultiTypeData(std::vector<phi::DenseTensor*>* var,
} }
} }
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class FusionGroupKernel : public framework::OpKernel<T> { class FusionGroupKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
......
...@@ -234,4 +234,5 @@ void elementwise_cuda_kernel_0(size_t n, float *x, float* y, float* z) { ...@@ -234,4 +234,5 @@ void elementwise_cuda_kernel_0(size_t n, float *x, float* y, float* z) {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
USE_CUDA_ONLY_OP(fusion_group); USE_OP_ITSELF(fusion_group);
PD_DECLARE_KERNEL(fusion_group, GPU, ALL_LAYOUT);
...@@ -249,7 +249,7 @@ more details can refer to GRU op. ...@@ -249,7 +249,7 @@ more details can refer to GRU op.
)DOC"); )DOC");
} }
template <typename T> template <typename T, typename DeviceContext>
class FusionGRUKernel : public framework::OpKernel<T> { class FusionGRUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -303,7 +303,6 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -303,7 +303,6 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T* xx_data = xx->mutable_data<T>(place) T* xx_data = xx->mutable_data<T>(place)
void SeqCompute(const framework::ExecutionContext& ctx) const { void SeqCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = phi::CPUContext;
INIT_BASE_DEFINES; INIT_BASE_DEFINES;
INIT_OTHER_DEFINES; INIT_OTHER_DEFINES;
const int N = x_lod[0].size() - 1; const int N = x_lod[0].size() - 1;
...@@ -394,7 +393,6 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -394,7 +393,6 @@ class FusionGRUKernel : public framework::OpKernel<T> {
} }
void BatchCompute(const framework::ExecutionContext& ctx) const { void BatchCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = phi::CPUContext;
INIT_BASE_DEFINES; INIT_BASE_DEFINES;
if (x_lod[0].size() == 2) { if (x_lod[0].size() == 2) {
xx->Resize({total_T, D3}); xx->Resize({total_T, D3});
...@@ -551,9 +549,8 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -551,9 +549,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_gru, ops::FusionGRUOp, ops::FusionGRUOpMaker); REGISTER_OPERATOR(fusion_gru, ops::FusionGRUOp, ops::FusionGRUOpMaker);
REGISTER_OP_CPU_KERNEL(fusion_gru, PD_REGISTER_STRUCT_KERNEL(
ops::FusionGRUKernel<float>, fusion_gru, CPU, ALL_LAYOUT, ops::FusionGRUKernel, float, double) {}
ops::FusionGRUKernel<double>);
/* ========================== register checkpoint ===========================*/ /* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION(fusion_gru) REGISTER_OP_VERSION(fusion_gru)
......
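In the CPU-only fused kernels the local "using DeviceContext = phi::CPUContext;" aliases are removed, since DeviceContext now arrives as a class template parameter and the CPU registration pins it to phi::CPUContext. A minimal sketch of the resulting shape, with hypothetical names:

// Sketch: DeviceContext comes from the class template, so no local alias is needed.
template <typename T, typename DeviceContext>
class SomeFusedCpuKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    // dev_ctx is a phi::CPUContext when the kernel is registered for CPU.
    // ... math kernel body ...
  }
};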
...@@ -298,11 +298,10 @@ This operator fuse the X into LSTM, more details can refer to LSTM op. ...@@ -298,11 +298,10 @@ This operator fuse the X into LSTM, more details can refer to LSTM op.
)DOC"); )DOC");
} }
template <typename T> template <typename T, typename DeviceContext>
class FuisonLSTMKernel : public framework::OpKernel<T> { class FuisonLSTMKernel : public framework::OpKernel<T> {
public: public:
#define INIT_BASE_DEFINES \ #define INIT_BASE_DEFINES \
using DeviceContext = phi::CPUContext; \
auto* x = ctx.Input<phi::DenseTensor>("X"); \ auto* x = ctx.Input<phi::DenseTensor>("X"); \
auto* h0 = ctx.Input<phi::DenseTensor>("H0"); \ auto* h0 = ctx.Input<phi::DenseTensor>("H0"); \
auto* c0 = ctx.Input<phi::DenseTensor>("C0"); \ auto* c0 = ctx.Input<phi::DenseTensor>("C0"); \
...@@ -580,6 +579,5 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -580,6 +579,5 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_lstm, ops::FusionLSTMOp, ops::FusionLSTMOpMaker); REGISTER_OPERATOR(fusion_lstm, ops::FusionLSTMOp, ops::FusionLSTMOpMaker);
REGISTER_OP_CPU_KERNEL(fusion_lstm, PD_REGISTER_STRUCT_KERNEL(
ops::FuisonLSTMKernel<float>, fusion_lstm, CPU, ALL_LAYOUT, ops::FuisonLSTMKernel, float, double) {}
ops::FuisonLSTMKernel<double>);
...@@ -141,7 +141,7 @@ static void fc_relu(const T* x, ...@@ -141,7 +141,7 @@ static void fc_relu(const T* x,
} }
} }
template <typename T> template <typename T, typename DeviceContext>
class FusionRepeatedFCReluKernel : public framework::OpKernel<T> { class FusionRepeatedFCReluKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -201,6 +201,9 @@ REGISTER_OPERATOR(fusion_repeated_fc_relu, ...@@ -201,6 +201,9 @@ REGISTER_OPERATOR(fusion_repeated_fc_relu,
ops::FusionRepeatedFCReluOp, ops::FusionRepeatedFCReluOp,
ops::FusionRepeatedFCReluOpMaker); ops::FusionRepeatedFCReluOpMaker);
REGISTER_OP_CPU_KERNEL(fusion_repeated_fc_relu, PD_REGISTER_STRUCT_KERNEL(fusion_repeated_fc_relu,
ops::FusionRepeatedFCReluKernel<float>, CPU,
ops::FusionRepeatedFCReluKernel<double>); ALL_LAYOUT,
ops::FusionRepeatedFCReluKernel,
float,
double) {}
...@@ -148,11 +148,10 @@ Fusion Sequence Conv and ElementwiseAdd Operator. ...@@ -148,11 +148,10 @@ Fusion Sequence Conv and ElementwiseAdd Operator.
)DOC"); )DOC");
} }
template <typename T> template <typename T, typename DeviceContext>
class FusionSeqConvEltAddReluKernel : public framework::OpKernel<T> { class FusionSeqConvEltAddReluKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
using DeviceContext = phi::CPUContext;
auto* x = ctx.Input<phi::DenseTensor>("X"); auto* x = ctx.Input<phi::DenseTensor>("X");
auto* w = ctx.Input<phi::DenseTensor>("Filter"); auto* w = ctx.Input<phi::DenseTensor>("Filter");
auto* b = ctx.Input<phi::DenseTensor>("Bias"); auto* b = ctx.Input<phi::DenseTensor>("Bias");
...@@ -283,6 +282,9 @@ REGISTER_OPERATOR(fusion_seqconv_eltadd_relu, ...@@ -283,6 +282,9 @@ REGISTER_OPERATOR(fusion_seqconv_eltadd_relu,
ops::FusionSeqConvEltAddReluOp, ops::FusionSeqConvEltAddReluOp,
ops::FusionSeqConvEltAddReluOpMaker); ops::FusionSeqConvEltAddReluOpMaker);
REGISTER_OP_CPU_KERNEL(fusion_seqconv_eltadd_relu, PD_REGISTER_STRUCT_KERNEL(fusion_seqconv_eltadd_relu,
ops::FusionSeqConvEltAddReluKernel<float>, CPU,
ops::FusionSeqConvEltAddReluKernel<double>); ALL_LAYOUT,
ops::FusionSeqConvEltAddReluKernel,
float,
double) {}
...@@ -147,11 +147,10 @@ The concat axis should be 1. ...@@ -147,11 +147,10 @@ The concat axis should be 1.
)DOC"); )DOC");
} }
template <typename T> template <typename T, typename DeviceContext>
class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> { class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
using DeviceContext = phi::CPUContext;
auto ins = ctx.MultiInput<phi::DenseTensor>("X"); auto ins = ctx.MultiInput<phi::DenseTensor>("X");
auto* w = ctx.Input<phi::DenseTensor>("FCWeight"); auto* w = ctx.Input<phi::DenseTensor>("FCWeight");
auto* b = ctx.Input<phi::DenseTensor>("FCBias"); auto* b = ctx.Input<phi::DenseTensor>("FCBias");
...@@ -295,6 +294,9 @@ REGISTER_OPERATOR(fusion_seqexpand_concat_fc, ...@@ -295,6 +294,9 @@ REGISTER_OPERATOR(fusion_seqexpand_concat_fc,
ops::FusionSeqExpandConcatFCOp, ops::FusionSeqExpandConcatFCOp,
ops::FusionSeqExpandConcatFCOpMaker); ops::FusionSeqExpandConcatFCOpMaker);
REGISTER_OP_CPU_KERNEL(fusion_seqexpand_concat_fc, PD_REGISTER_STRUCT_KERNEL(fusion_seqexpand_concat_fc,
ops::FusionSeqExpandConcatFCOpKernel<float>, CPU,
ops::FusionSeqExpandConcatFCOpKernel<double>); ALL_LAYOUT,
ops::FusionSeqExpandConcatFCOpKernel,
float,
double) {}
...@@ -92,7 +92,7 @@ Fusion Sequence Pool of pooltype(sum, average and sqrt) and Concat Operator. ...@@ -92,7 +92,7 @@ Fusion Sequence Pool of pooltype(sum, average and sqrt) and Concat Operator.
)DOC"); )DOC");
} }
template <typename T> template <typename T, typename DeviceContext>
class FusionSeqPoolConcatKernel : public framework::OpKernel<T> { class FusionSeqPoolConcatKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -173,6 +173,9 @@ REGISTER_OPERATOR(fusion_seqpool_concat, ...@@ -173,6 +173,9 @@ REGISTER_OPERATOR(fusion_seqpool_concat,
ops::FusionSeqPoolConcatOp, ops::FusionSeqPoolConcatOp,
ops::FusionSeqPoolConcatOpMaker); ops::FusionSeqPoolConcatOpMaker);
REGISTER_OP_CPU_KERNEL(fusion_seqpool_concat, PD_REGISTER_STRUCT_KERNEL(fusion_seqpool_concat,
ops::FusionSeqPoolConcatKernel<float>, CPU,
ops::FusionSeqPoolConcatKernel<double>); ALL_LAYOUT,
ops::FusionSeqPoolConcatKernel,
float,
double) {}
...@@ -96,7 +96,7 @@ Fusion Sequence Pool of pooltype(sum, average and sqrt) and Concat Operator. ...@@ -96,7 +96,7 @@ Fusion Sequence Pool of pooltype(sum, average and sqrt) and Concat Operator.
)DOC"); )DOC");
} }
template <typename T> template <typename T, typename DeviceContext>
class FusionSeqPoolCVMConcatKernel : public framework::OpKernel<T> { class FusionSeqPoolCVMConcatKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -172,6 +172,9 @@ REGISTER_OPERATOR( ...@@ -172,6 +172,9 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(fusion_seqpool_cvm_concat, PD_REGISTER_STRUCT_KERNEL(fusion_seqpool_cvm_concat,
ops::FusionSeqPoolCVMConcatKernel<float>, CPU,
ops::FusionSeqPoolCVMConcatKernel<double>); ALL_LAYOUT,
ops::FusionSeqPoolCVMConcatKernel,
float,
double) {}
...@@ -84,7 +84,7 @@ void FusionSquaredMatSubOpMaker::Make() { ...@@ -84,7 +84,7 @@ void FusionSquaredMatSubOpMaker::Make() {
)DOC"); )DOC");
} }
template <typename T> template <typename T, typename DeviceContext>
class FusionSquaredMatSubKernel : public framework::OpKernel<T> { class FusionSquaredMatSubKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -151,6 +151,9 @@ REGISTER_OPERATOR(fusion_squared_mat_sub, ...@@ -151,6 +151,9 @@ REGISTER_OPERATOR(fusion_squared_mat_sub,
ops::FusionSquaredMatSubOp, ops::FusionSquaredMatSubOp,
ops::FusionSquaredMatSubOpMaker); ops::FusionSquaredMatSubOpMaker);
REGISTER_OP_CPU_KERNEL(fusion_squared_mat_sub, PD_REGISTER_STRUCT_KERNEL(fusion_squared_mat_sub,
ops::FusionSquaredMatSubKernel<float>, CPU,
ops::FusionSquaredMatSubKernel<double>); ALL_LAYOUT,
ops::FusionSquaredMatSubKernel,
float,
double) {}
...@@ -24,7 +24,7 @@ namespace operators { ...@@ -24,7 +24,7 @@ namespace operators {
template <typename T> template <typename T>
using CudnnDataType = platform::CudnnDataType<T>; using CudnnDataType = platform::CudnnDataType<T>;
template <typename T> template <typename T, typename DeviceContext>
class TransposeFlattenConcatFusionKernel : public framework::OpKernel<T> { class TransposeFlattenConcatFusionKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -119,6 +119,10 @@ class TransposeFlattenConcatFusionKernel : public framework::OpKernel<T> { ...@@ -119,6 +119,10 @@ class TransposeFlattenConcatFusionKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(fusion_transpose_flatten_concat,
ops::TransposeFlattenConcatFusionKernel<float>, PD_REGISTER_STRUCT_KERNEL(fusion_transpose_flatten_concat,
ops::TransposeFlattenConcatFusionKernel<double>); GPU,
ALL_LAYOUT,
ops::TransposeFlattenConcatFusionKernel,
float,
double) {}
...@@ -102,7 +102,10 @@ REGISTER_OPERATOR( ...@@ -102,7 +102,10 @@ REGISTER_OPERATOR(
ops::SoftmaxMaskFuseUpperTriangleGradOpMaker<paddle::imperative::OpBase>); ops::SoftmaxMaskFuseUpperTriangleGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(fused_softmax_mask_upper_triangle_grad, REGISTER_OPERATOR(fused_softmax_mask_upper_triangle_grad,
ops::SoftmaxMaskFuseUpperTriangleOpGrad); ops::SoftmaxMaskFuseUpperTriangleOpGrad);
REGISTER_OP_CPU_KERNEL(
fused_softmax_mask_upper_triangle, PD_REGISTER_STRUCT_KERNEL(fused_softmax_mask_upper_triangle,
ops::SoftmaxMaskFuseUpperTriangleCPUKernel<phi::CPUContext, float>, CPU,
ops::SoftmaxMaskFuseUpperTriangleCPUKernel<phi::CPUContext, double>); ALL_LAYOUT,
ops::SoftmaxMaskFuseUpperTriangleCPUKernel,
float,
double) {}
...@@ -354,7 +354,7 @@ __global__ void SoftmaxMaskFuseUpperTriangleGradGPUKernel(const T* grad_input, ...@@ -354,7 +354,7 @@ __global__ void SoftmaxMaskFuseUpperTriangleGradGPUKernel(const T* grad_input,
} }
} }
template <typename Place, typename T> template <typename T, typename DeviceContext>
class SoftmaxMaskFuseUpperTriangleKernel : public framework::OpKernel<T> { class SoftmaxMaskFuseUpperTriangleKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -386,7 +386,8 @@ class SoftmaxMaskFuseUpperTriangleKernel : public framework::OpKernel<T> { ...@@ -386,7 +386,8 @@ class SoftmaxMaskFuseUpperTriangleKernel : public framework::OpKernel<T> {
"received the last dimension of x is %d", "received the last dimension of x is %d",
key_seq_len)); key_seq_len));
auto& place = *context.template device_context<Place>().eigen_device(); auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto stream = context.cuda_device_context().stream(); auto stream = context.cuda_device_context().stream();
int pow2_index = get_pow2_index_value(key_seq_len); int pow2_index = get_pow2_index_value(key_seq_len);
...@@ -470,7 +471,7 @@ class SoftmaxMaskFuseUpperTriangleKernel : public framework::OpKernel<T> { ...@@ -470,7 +471,7 @@ class SoftmaxMaskFuseUpperTriangleKernel : public framework::OpKernel<T> {
} }
}; };
template <typename Place, typename T> template <typename T, typename DeviceContext>
class SoftmaxMaskFuseUpperTriangleGradKernel : public framework::OpKernel<T> { class SoftmaxMaskFuseUpperTriangleGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -491,7 +492,8 @@ class SoftmaxMaskFuseUpperTriangleGradKernel : public framework::OpKernel<T> { ...@@ -491,7 +492,8 @@ class SoftmaxMaskFuseUpperTriangleGradKernel : public framework::OpKernel<T> {
auto query_seq_len = y_dim[2]; auto query_seq_len = y_dim[2];
auto key_seq_len = y_dim[3]; auto key_seq_len = y_dim[3];
auto& place = *context.template device_context<Place>().eigen_device(); auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto stream = context.cuda_device_context().stream(); auto stream = context.cuda_device_context().stream();
int pow2_index = get_pow2_index_value(key_seq_len); int pow2_index = get_pow2_index_value(key_seq_len);
...@@ -602,14 +604,18 @@ class SoftmaxMaskFuseUpperTriangleGradKernel : public framework::OpKernel<T> { ...@@ -602,14 +604,18 @@ class SoftmaxMaskFuseUpperTriangleGradKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
fused_softmax_mask_upper_triangle, PD_REGISTER_STRUCT_KERNEL(fused_softmax_mask_upper_triangle,
ops::SoftmaxMaskFuseUpperTriangleKernel<phi::GPUContext, plat::float16>, GPU,
ops::SoftmaxMaskFuseUpperTriangleKernel<phi::GPUContext, plat::bfloat16>, ALL_LAYOUT,
ops::SoftmaxMaskFuseUpperTriangleKernel<phi::GPUContext, float>); ops::SoftmaxMaskFuseUpperTriangleKernel,
REGISTER_OP_CUDA_KERNEL( float,
fused_softmax_mask_upper_triangle_grad, plat::float16,
ops::SoftmaxMaskFuseUpperTriangleGradKernel<phi::GPUContext, plat::float16>, plat::bfloat16) {}
ops::SoftmaxMaskFuseUpperTriangleGradKernel<phi::GPUContext, PD_REGISTER_STRUCT_KERNEL(fused_softmax_mask_upper_triangle_grad,
plat::bfloat16>, GPU,
ops::SoftmaxMaskFuseUpperTriangleGradKernel<phi::GPUContext, float>); ALL_LAYOUT,
ops::SoftmaxMaskFuseUpperTriangleGradKernel,
float,
plat::float16,
plat::bfloat16) {}
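Kernels that were parameterized as <Place, T> are flipped to <T, DeviceContext>, the order PD_REGISTER_STRUCT_KERNEL instantiates them with, and the eigen device is obtained through the DeviceContext parameter. A hedged sketch of that access pattern, with a hypothetical kernel mirroring the file above:

// Sketch: device access inside a migrated GPU struct kernel.
template <typename T, typename DeviceContext>
class SomeGpuFusedKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto& place =
        *context.template device_context<DeviceContext>().eigen_device();
    auto stream = context.cuda_device_context().stream();
    // ... launch CUDA work on stream, evaluate Eigen expressions on place ...
  }
};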
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class SoftmaxMaskFuseUpperTriangleCPUKernel : public framework::OpKernel<T> { class SoftmaxMaskFuseUpperTriangleCPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
......
...@@ -79,7 +79,7 @@ __global__ void MaximumFirst(T* mat, int num_raws, int num_cols, T max_value) { ...@@ -79,7 +79,7 @@ __global__ void MaximumFirst(T* mat, int num_raws, int num_cols, T max_value) {
} }
} }
template <typename T> template <typename T, typename DeviceContext>
class FusedTokenPruneOpCUDAKernel : public framework::OpKernel<T> { class FusedTokenPruneOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -283,6 +283,9 @@ class FusedTokenPruneOpCUDAKernel : public framework::OpKernel<T> { ...@@ -283,6 +283,9 @@ class FusedTokenPruneOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(fused_token_prune, PD_REGISTER_STRUCT_KERNEL(fused_token_prune,
ops::FusedTokenPruneOpCUDAKernel<float>, GPU,
ops::FusedTokenPruneOpCUDAKernel<double>); ALL_LAYOUT,
ops::FusedTokenPruneOpCUDAKernel,
float,
double) {}
...@@ -156,4 +156,5 @@ The paper that proposed Follow The Regularized Leader (FTRL): ...@@ -156,4 +156,5 @@ The paper that proposed Follow The Regularized Leader (FTRL):
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(ftrl, ops::FTRLOp, ops::FTRLOpMaker); REGISTER_OP_WITHOUT_GRADIENT(ftrl, ops::FTRLOp, ops::FTRLOpMaker);
REGISTER_OP_CPU_KERNEL(ftrl, ops::FTRLOpKernel<phi::CPUContext, float>);
PD_REGISTER_STRUCT_KERNEL(ftrl, CPU, ALL_LAYOUT, ops::FTRLOpKernel, float) {}
...@@ -13,4 +13,4 @@ specific language governing permissions and limitations under the License. */ ...@@ -13,4 +13,4 @@ specific language governing permissions and limitations under the License. */
#include "paddle/fluid/operators/optimizers/ftrl_op.h" #include "paddle/fluid/operators/optimizers/ftrl_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(ftrl, ops::FTRLOpKernel<phi::GPUContext, float>); PD_REGISTER_STRUCT_KERNEL(ftrl, GPU, ALL_LAYOUT, ops::FTRLOpKernel, float) {}
...@@ -113,7 +113,7 @@ class SparseFTRLFunctor { ...@@ -113,7 +113,7 @@ class SparseFTRLFunctor {
} }
}; };
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class FTRLOpKernel : public framework::OpKernel<T> { class FTRLOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
......
...@@ -541,4 +541,5 @@ REGISTER_OPERATOR(faster_tokenizer, ...@@ -541,4 +541,5 @@ REGISTER_OPERATOR(faster_tokenizer,
ops::FasterTokenizerOp, ops::FasterTokenizerOp,
ops::FasterTokenizerOpMaker); ops::FasterTokenizerOpMaker);
REGISTER_OP_CPU_KERNEL(faster_tokenizer, ops::FasterTokenizerKernel<int64_t>); PD_REGISTER_STRUCT_KERNEL(
faster_tokenizer, CPU, ALL_LAYOUT, ops::FasterTokenizerKernel, int64_t) {}
...@@ -122,7 +122,7 @@ class BertTokenizer { ...@@ -122,7 +122,7 @@ class BertTokenizer {
InvVocab inv_vocab_; InvVocab inv_vocab_;
}; };
template <typename T> template <typename T, typename DeviceContext>
class FasterTokenizerKernel : public framework::OpKernel<T> { class FasterTokenizerKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
......