Unverified · Commit 1b5eba8a · authored by YuanRisheng · committed by GitHub

[PHI]Unify fluid kernel (Part4) (#52626)

* unify kernel

* fix ci bugs

* fix py3 bugs

* fix py3 bugs

* perfect code
Parent 5e29f30c
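The recurring pattern in this diff: the legacy REGISTER_OP_CPU_KERNEL / REGISTER_OP_CUDA_KERNEL macros take one explicit kernel instantiation per dtype, while the unified PD_REGISTER_STRUCT_KERNEL takes the kernel template once, together with a backend tag, a layout, and a dtype list (and the kernel classes gain a trailing DeviceContext template parameter). A minimal standalone sketch of how one variadic registration call can fan out over a dtype pack — toy RegisterKernel/ToyKernel names, not Paddle's real registry:

#include <cstdio>
#include <functional>
#include <map>
#include <string>
#include <typeinfo>

struct CPUContext {};  // stand-in for a device context type

// Kernel template in the unified form: dtype first, device context second.
template <typename T, typename DeviceContext>
struct ToyKernel {
  static void Compute() { std::printf("compute<%s>\n", typeid(T).name()); }
};

std::map<std::string, std::function<void()>> registry;

// One call registers the kernel template for every dtype in the pack,
// mirroring PD_REGISTER_STRUCT_KERNEL(op, CPU, ALL_LAYOUT, Kernel, dtypes...).
template <template <typename, typename> class Kernel, typename... Ts>
void RegisterKernel(const std::string& op) {
  (registry.emplace(op + "/" + typeid(Ts).name(),
                    &Kernel<Ts, CPUContext>::Compute),
   ...);
}

int main() {
  // The old macros forced one explicit instantiation per dtype at the call
  // site; here the dtype list is passed once, as in the new macro.
  RegisterKernel<ToyKernel, float, double, int>("recv_v2");
  for (auto& [name, fn] : registry) fn();
}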
@@ -96,7 +96,7 @@ register_operators(EXCLUDES py_func_op dgc_op generated_op1 generated_op2 genera
     recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op activation_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
 op_library(generated_op UNITY SRCS generated_op1.cc generated_op2.cc generated_op3.cc generated_op4.cc DEPS ${OP_HEADER_DEPS})
-op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache ${OP_HEADER_DEPS})
+op_library(run_program_op DEPS executor_cache ${OP_HEADER_DEPS})
 target_link_libraries(run_program_op cuda_graph_with_memory_pool)
 op_library(quantize_linear_op DEPS phi)
 op_library(save_combine_op DEPS string_array phi)
......
@@ -114,9 +114,12 @@ namespace plat = paddle::platform;
 REGISTER_OP_WITHOUT_GRADIENT(recv_v2, ops::RecvOpV2, ops::RecvOpV2Maker);
-REGISTER_OP_CPU_KERNEL(recv_v2,
-                       ops::RecvOpV2CPUKernel<float>,
-                       ops::RecvOpV2CPUKernel<double>,
-                       ops::RecvOpV2CPUKernel<int>,
-                       ops::RecvOpV2CPUKernel<int64_t>,
-                       ops::RecvOpV2CPUKernel<plat::float16>);
+PD_REGISTER_STRUCT_KERNEL(recv_v2,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::RecvOpV2CPUKernel,
+                          float,
+                          double,
+                          int,
+                          int64_t,
+                          plat::float16) {}
@@ -105,7 +105,7 @@ framework::DDim recv_shape_info(const platform::Place &place,
 }
 #endif
 
-template <typename T>
+template <typename T, typename DeviceContext>
 class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
@@ -232,13 +232,17 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(recv_v2,
-                        ops::RecvOpV2CUDAKernel<float>,
-                        ops::RecvOpV2CUDAKernel<double>,
+PD_REGISTER_STRUCT_KERNEL(recv_v2,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::RecvOpV2CUDAKernel,
+                          float,
+                          double,
 #if NCCL_VERSION_CODE >= 21000
-                        ops::RecvOpV2CUDAKernel<plat::bfloat16>,
+                          plat::bfloat16,
 #endif
-                        ops::RecvOpV2CUDAKernel<int>,
-                        ops::RecvOpV2CUDAKernel<int64_t>,
-                        ops::RecvOpV2CUDAKernel<int8_t>,
-                        ops::RecvOpV2CUDAKernel<plat::float16>);
+                          int,
+                          int64_t,
+                          int8_t,
+                          plat::float16) {
+}
@@ -24,7 +24,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-template <typename T>
+template <typename T, typename DeviceContext>
 class RecvOpV2CPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
......
@@ -499,12 +499,12 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) {
       run_program,
       device_type,
       paddle::operators::
-          RunProgramOpKernel<paddle::platform::CustomDeviceContext, float>);
+          RunProgramOpKernel<float, paddle::platform::CustomDeviceContext>);
   REGISTER_OP_CUSTOM_DEVICE_KERNEL(
       run_program_grad,
       device_type,
       paddle::operators ::
-          RunProgramGradOpKernel<paddle::platform::CustomDeviceContext, float>);
+          RunProgramGradOpKernel<float, paddle::platform::CustomDeviceContext>);
   REGISTER_OP_CUSTOM_DEVICE_KERNEL(
       save_combine,
       device_type,
......
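These custom-device call sites change only because the kernels' template parameter order flips from <DeviceContext, T> to <T, DeviceContext> (see the run_program_op.h hunk further down), so every explicit instantiation has to be updated in lockstep. A toy illustration, with a hypothetical RunProgramLikeKernel standing in for the real class:

#include <type_traits>

struct CustomDeviceContext {};  // stand-in for the real context type

// New parameter order: dtype first, device context second.
template <typename T, typename DeviceContext>
struct RunProgramLikeKernel {
  using value_type = T;
};

// Correct under the new order: value_type is float.
using Ok = RunProgramLikeKernel<float, CustomDeviceContext>;
static_assert(std::is_same<Ok::value_type, float>::value, "T is the dtype");

// The pre-change spelling still names a valid type, but T silently becomes
// the device context; the error only surfaces where T is used as a number.
// That is why call sites must flip together with the template definition.
using Swapped = RunProgramLikeKernel<CustomDeviceContext, float>;
static_assert(std::is_same<Swapped::value_type, CustomDeviceContext>::value, "");

int main() {}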
@@ -248,7 +248,7 @@ static inline T JaccardOverlap(const std::vector<T>& box1,
   }
 }
 
-template <typename T>
+template <typename T, typename DeviceContext>
 class RetinanetDetectionOutputKernel : public framework::OpKernel<T> {
  public:
   void NMSFast(const std::vector<std::vector<T>>& cls_dets,
@@ -671,6 +671,9 @@ REGISTER_OPERATOR(
     ops::RetinanetDetectionOutputOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OP_CPU_KERNEL(retinanet_detection_output,
-                       ops::RetinanetDetectionOutputKernel<float>,
-                       ops::RetinanetDetectionOutputKernel<double>);
+PD_REGISTER_STRUCT_KERNEL(retinanet_detection_output,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::RetinanetDetectionOutputKernel,
+                          float,
+                          double) {}
@@ -242,7 +242,7 @@ void bilinear_interpolate(const T* in_data,
   val[0] = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
 }
 
-template <typename T>
+template <typename T, typename DeviceContext>
 class CPUROIPerspectiveTransformOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -390,7 +390,7 @@ T get_feature_gradient(
   return weight;
 }
 
-template <typename T>
+template <typename T, typename DeviceContext>
 class CPUROIPerspectiveTransformGradOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -690,7 +690,13 @@ REGISTER_OPERATOR(
     ops::ROIPerspectiveTransformGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(roi_perspective_transform_grad,
                   ops::ROIPerspectiveTransformGradOp);
-REGISTER_OP_CPU_KERNEL(roi_perspective_transform,
-                       ops::CPUROIPerspectiveTransformOpKernel<float>);
-REGISTER_OP_CPU_KERNEL(roi_perspective_transform_grad,
-                       ops::CPUROIPerspectiveTransformGradOpKernel<float>);
+PD_REGISTER_STRUCT_KERNEL(roi_perspective_transform,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::CPUROIPerspectiveTransformOpKernel,
+                          float) {}
+PD_REGISTER_STRUCT_KERNEL(roi_perspective_transform_grad,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::CPUROIPerspectiveTransformGradOpKernel,
+                          float) {}
@@ -363,7 +363,7 @@ __global__ void RoiTransformKernel(const float* input_data,
   }
 }
 
-template <typename T>
+template <typename T, typename DeviceContext>
 class CUDAROIPerspectiveTransformOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -507,7 +507,7 @@ __global__ void RoiTransformGradKernel(int out_size,
   }
 }
 
-template <typename T>
+template <typename T, typename DeviceContext>
 class CUDAROIPerspectiveTransformGradOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -539,7 +539,13 @@ class CUDAROIPerspectiveTransformGradOpKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(roi_perspective_transform,
-                        ops::CUDAROIPerspectiveTransformOpKernel<float>);
-REGISTER_OP_CUDA_KERNEL(roi_perspective_transform_grad,
-                        ops::CUDAROIPerspectiveTransformGradOpKernel<float>);
+PD_REGISTER_STRUCT_KERNEL(roi_perspective_transform,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::CUDAROIPerspectiveTransformOpKernel,
+                          float) {}
+PD_REGISTER_STRUCT_KERNEL(roi_perspective_transform_grad,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::CUDAROIPerspectiveTransformGradOpKernel,
+                          float) {}
@@ -392,7 +392,7 @@ std::vector<phi::DenseTensor> SampleRpnFgBgGt(
   return loc_score_tgtlbl_gt;
 }
 
-template <typename T>
+template <typename T, typename DeviceContext>
 class RpnTargetAssignKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
@@ -995,7 +995,7 @@ std::vector<phi::DenseTensor> GetAllFgBgGt(
   return loc_score_tgtlbl_gt;
 }
 
-template <typename T>
+template <typename T, typename DeviceContext>
 class RetinanetTargetAssignKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
@@ -1236,15 +1236,21 @@ REGISTER_OPERATOR(
     ops::RpnTargetAssignOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OP_CPU_KERNEL(rpn_target_assign,
-                       ops::RpnTargetAssignKernel<float>,
-                       ops::RpnTargetAssignKernel<double>);
+PD_REGISTER_STRUCT_KERNEL(rpn_target_assign,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::RpnTargetAssignKernel,
+                          float,
+                          double) {}
 REGISTER_OPERATOR(
     retinanet_target_assign,
     ops::RetinanetTargetAssignOp,
     ops::RetinanetTargetAssignOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OP_CPU_KERNEL(retinanet_target_assign,
-                       ops::RetinanetTargetAssignKernel<float>,
-                       ops::RetinanetTargetAssignKernel<double>);
+PD_REGISTER_STRUCT_KERNEL(retinanet_target_assign,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::RetinanetTargetAssignKernel,
+                          float,
+                          double) {}
@@ -23,7 +23,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-template <typename T>
+template <typename T, typename DeviceContext>
 class ResNetUnitKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
@@ -223,7 +223,7 @@ class ResNetUnitKernel : public framework::OpKernel<T> {
   }
 };
 
-template <typename T>
+template <typename T, typename DeviceContext>
 class ResNetUnitGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
@@ -419,7 +419,11 @@ class ResNetUnitGradKernel : public framework::OpKernel<T> {
 #if CUDNN_VERSION >= 8000
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(resnet_unit, ops::ResNetUnitKernel<plat::float16>);
-REGISTER_OP_CUDA_KERNEL(resnet_unit_grad,
-                        ops::ResNetUnitGradKernel<plat::float16>);
+PD_REGISTER_STRUCT_KERNEL(
+    resnet_unit, GPU, ALL_LAYOUT, ops::ResNetUnitKernel, plat::float16) {}
+PD_REGISTER_STRUCT_KERNEL(resnet_unit_grad,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::ResNetUnitGradKernel,
+                          plat::float16) {}
 #endif
@@ -122,7 +122,7 @@ template struct ChannelDequantizeFunctorV2<phi::GPUContext, double>;
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-using CUDA = phi::GPUContext;
 PD_REGISTER_STRUCT_KERNEL(dequantize_linear,
                           GPU,
                           ALL_LAYOUT,
......
@@ -242,7 +242,7 @@ class RankAttentionGradOpCUDAKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-using GPUCtx = phi::GPUContext;
 PD_REGISTER_STRUCT_KERNEL(rank_attention,
                           GPU,
                           ALL_LAYOUT,
......
@@ -11,13 +11,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/row_conv_op.h"
-
 #include <memory>
 #include <string>
 #include <vector>
 
 #include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/enforce.h"
 
 namespace paddle {
@@ -136,8 +134,8 @@ https://github.com/PaddlePaddle/Paddle/issues/2228#issuecomment-303903645 .
   }
 };
 
-template <typename T>
-class RowConvKernel<phi::CPUContext, T> : public framework::OpKernel<T> {
+template <typename T, typename DeviceContext>
+class RowConvKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
     auto *x = context.Input<phi::DenseTensor>("X");
@@ -211,8 +209,8 @@ class RowConvKernel<phi::CPUContext, T> : public framework::OpKernel<T> {
   }
 };
 
-template <typename T>
-class RowConvGradKernel<phi::CPUContext, T> : public framework::OpKernel<T> {
+template <typename T, typename DeviceContext>
+class RowConvGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
     auto *x = context.Input<phi::DenseTensor>("X");
@@ -351,6 +349,7 @@ REGISTER_OPERATOR(row_conv,
                   ops::RowConvGradOpMaker<paddle::framework::OpDesc>,
                   ops::RowConvGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(row_conv_grad, ops::RowConvGradOp);
-REGISTER_OP_CPU_KERNEL(row_conv, ops::RowConvKernel<phi::CPUContext, float>);
-REGISTER_OP_CPU_KERNEL(row_conv_grad,
-                       ops::RowConvGradKernel<phi::CPUContext, float>);
+PD_REGISTER_STRUCT_KERNEL(
+    row_conv, CPU, ALL_LAYOUT, ops::RowConvKernel, float) {}
+PD_REGISTER_STRUCT_KERNEL(
+    row_conv_grad, CPU, ALL_LAYOUT, ops::RowConvGradKernel, float) {}
@@ -11,7 +11,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/row_conv_op.h"
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/phi/backends/gpu/gpu_device_function.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
@@ -319,8 +319,8 @@ __global__ void RowConvGradFilter(const T *in,
 }  // namespace
 
-template <typename T>
-class RowConvKernel<phi::GPUContext, T> : public framework::OpKernel<T> {
+template <typename T, typename DeviceContext>
+class RowConvKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
     auto *X = context.Input<phi::DenseTensor>("X");
@@ -373,8 +373,8 @@ class RowConvKernel<phi::GPUContext, T> : public framework::OpKernel<T> {
   }
 };
 
-template <typename T>
-class RowConvGradKernel<phi::GPUContext, T> : public framework::OpKernel<T> {
+template <typename T, typename DeviceContext>
+class RowConvGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
     auto *X = context.Input<phi::DenseTensor>("X");
@@ -491,6 +491,7 @@ class RowConvGradKernel<phi::GPUContext, T> : public framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(row_conv, ops::RowConvKernel<phi::GPUContext, float>);
-REGISTER_OP_CUDA_KERNEL(row_conv_grad,
-                        ops::RowConvGradKernel<phi::GPUContext, float>);
+PD_REGISTER_STRUCT_KERNEL(
+    row_conv, GPU, ALL_LAYOUT, ops::RowConvKernel, float) {}
+PD_REGISTER_STRUCT_KERNEL(
+    row_conv_grad, GPU, ALL_LAYOUT, ops::RowConvGradKernel, float) {}
(deleted file: paddle/fluid/operators/row_conv_op.h)
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include "paddle/fluid/framework/op_registry.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename DeviceContext, typename T>
-class RowConvKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &context) const override;
-};
-
-template <typename DeviceContext, typename T>
-class RowConvGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &context) const override;
-};
-
-}  // namespace operators
-}  // namespace paddle
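With row_conv_op.h deleted, the CPU and GPU files above stop specializing a shared RowConvKernel declaration and instead each define an ordinary primary template in their own translation unit; device dispatch moves from the type system into the registry. A minimal sketch of that structural difference, using a toy context type rather than the real Paddle classes:

struct CPUContext {};  // stand-in for phi::CPUContext

// Before: a shared header declared the primary template, and each device
// file provided a partial specialization of it.
namespace before {
template <typename DeviceContext, typename T>
class RowConvKernel;  // declared in row_conv_op.h, never defined

template <typename T>
class RowConvKernel<CPUContext, T> {  // defined in row_conv_op.cc
 public:
  void Compute() const {}
};
}  // namespace before

// After: each translation unit owns an unspecialized primary template; the
// registration macro, not a specialization, binds it to one backend.
namespace after_cc {
template <typename T, typename DeviceContext>
class RowConvKernel {  // defined directly in row_conv_op.cc
 public:
  void Compute() const {}
};
}  // namespace after_cc

int main() {
  before::RowConvKernel<CPUContext, float>{}.Compute();
  after_cc::RowConvKernel<float, CPUContext>{}.Compute();
}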
@@ -253,7 +253,7 @@ REGISTER_OPERATOR(run_program,
 REGISTER_OPERATOR(run_program_grad, ops::RunProgramGradOp);
 
 /* see [Why use single type kernel] */
-REGISTER_OP_CPU_KERNEL(run_program,
-                       ops::RunProgramOpKernel<phi::CPUContext, float>)
-REGISTER_OP_CPU_KERNEL(run_program_grad,
-                       ops::RunProgramGradOpKernel<phi::CPUContext, float>)
+PD_REGISTER_STRUCT_KERNEL(
+    run_program, CPU, ALL_LAYOUT, ops::RunProgramOpKernel, float) {}
+PD_REGISTER_STRUCT_KERNEL(
+    run_program_grad, CPU, ALL_LAYOUT, ops::RunProgramGradOpKernel, float) {}
@@ -20,7 +20,7 @@ namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
 /* see [Why use single type kernel] */
-REGISTER_OP_CUDA_KERNEL(run_program,
-                        ops::RunProgramOpKernel<phi::GPUContext, float>);
-REGISTER_OP_CUDA_KERNEL(run_program_grad,
-                        ops::RunProgramGradOpKernel<phi::GPUContext, float>);
+PD_REGISTER_STRUCT_KERNEL(
+    run_program, GPU, ALL_LAYOUT, ops::RunProgramOpKernel, float) {}
+PD_REGISTER_STRUCT_KERNEL(
+    run_program_grad, GPU, ALL_LAYOUT, ops::RunProgramGradOpKernel, float) {}
@@ -197,7 +197,7 @@ static cudaStreamCaptureMode StringToCUDAGraphCaptureMode(
 }  // namespace details
 
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class RunProgramOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
@@ -395,7 +395,7 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
   }
 };
 
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class RunProgramGradOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
......
@@ -328,7 +328,6 @@ register_unity_group(
   pool_op.cu.cc
   pool_cudnn_op.cu.cc
   pool_with_index_op.cu.cc
-  run_program_op.cu.cc
   softmax_op.cu.cc
   softmax_cudnn_op.cu.cc
   spp_op.cu.cc
@@ -354,7 +353,6 @@ register_unity_group(
   rnn_op.cu.cc
   split_op.cu.cc
   assign_value_op.cu.cc
-  run_program_op.cu.cc
   warpctc_op.cu.cc)
 register_unity_group(
   cu
@@ -492,6 +490,7 @@ register_unity_group(
   pixel_shuffle_op.cu
   prelu_op.cu
   prroi_pool_op.cu
+  run_program_op.cu
   pull_box_extended_sparse_op.cu
   pull_box_sparse_op.cu)
 register_unity_group(
......