未验证 提交 82866d4a 编写于 作者: Y yuyang18

Add register kernel functor and shrink reshape op

* Shrink reshape_op library size
* User can register a standard C++ functor as a op kernel
上级 75ae426a
...@@ -146,7 +146,7 @@ struct OpKernelRegistrarFunctorEx<PlaceType, true, I, ...@@ -146,7 +146,7 @@ struct OpKernelRegistrarFunctorEx<PlaceType, true, I,
template <typename PlaceType, size_t I, typename... DataTypeAndKernelType> template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
struct OpKernelRegistrarFunctorEx<PlaceType, false, I, struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
DataTypeAndKernelType...> { DataTypeAndKernelType...> {
using KERNEL_TYPE = using Functor =
typename std::tuple_element<I + 1, typename std::tuple_element<I + 1,
std::tuple<DataTypeAndKernelType...>>::type; std::tuple<DataTypeAndKernelType...>>::type;
using T = using T =
...@@ -154,10 +154,7 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I, ...@@ -154,10 +154,7 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
std::tuple<DataTypeAndKernelType...>>::type; std::tuple<DataTypeAndKernelType...>>::type;
void operator()(const char* op_type, const char* library_type) const { void operator()(const char* op_type, const char* library_type) const {
RegisterKernelClass<PlaceType, T>( RegisterKernelClass<PlaceType, T>(op_type, library_type, Functor());
op_type, library_type, [](const framework::ExecutionContext& ctx) {
KERNEL_TYPE().Compute(ctx);
});
constexpr auto size = constexpr auto size =
std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value; std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value;
...@@ -238,11 +235,11 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I, ...@@ -238,11 +235,11 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
return 0; \ return 0; \
} }
#define REGISTER_OP_CUDA_KERNEL_EX(op_type, ...) \ #define REGISTER_OP_CUDA_KERNEL_FUNCTOR(op_type, ...) \
REGISTER_OP_KERNEL_EX(p_type, CUDA, ::paddle::platform::CUDAPlace, \ REGISTER_OP_KERNEL_EX(op_type, CUDA, ::paddle::platform::CUDAPlace, \
__VA_ARGS__) __VA_ARGS__)
#define REGISTER_OP_CPU_KERNEL_EX(op_type, ...) \ #define REGISTER_OP_CPU_KERNEL_FUNCTOR(op_type, ...) \
REGISTER_OP_KERNEL_EX(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__) REGISTER_OP_KERNEL_EX(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
/** /**
......
...@@ -107,7 +107,7 @@ class ReshapeGradOp : public framework::OperatorWithKernel { ...@@ -107,7 +107,7 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
} }
}; };
void ReshapeKernel::Compute(const framework::ExecutionContext &ctx) const { void ReshapeKernel::operator()(const framework::ExecutionContext &ctx) const {
auto *out = ctx.Output<framework::LoDTensor>("Out"); auto *out = ctx.Output<framework::LoDTensor>("Out");
auto *in = ctx.Input<framework::LoDTensor>("X"); auto *in = ctx.Input<framework::LoDTensor>("X");
...@@ -147,7 +147,7 @@ void ReshapeKernel::Compute(const framework::ExecutionContext &ctx) const { ...@@ -147,7 +147,7 @@ void ReshapeKernel::Compute(const framework::ExecutionContext &ctx) const {
out->Resize(out_dims); out->Resize(out_dims);
} }
} }
void ReshapeGradKernelBase::Compute( void ReshapeGradKernel::operator()(
const framework::ExecutionContext &ctx) const { const framework::ExecutionContext &ctx) const {
auto *d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out")); auto *d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X")); auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
...@@ -172,10 +172,10 @@ namespace ops = paddle::operators; ...@@ -172,10 +172,10 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker, REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp); REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp);
REGISTER_OP_CPU_KERNEL_EX(reshape, float, ops::ReshapeKernel, double, REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel, int64_t, ops::ReshapeKernel, int, ops::ReshapeKernel,
ops::ReshapeKernel); int64_t, ops::ReshapeKernel);
REGISTER_OP_CPU_KERNEL(reshape_grad, ops::ReshapeGradKernel<float>, REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
ops::ReshapeGradKernel<double>, double, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel<int>, ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel<int64_t>); ops::ReshapeGradKernel);
...@@ -14,11 +14,11 @@ limitations under the License. */ ...@@ -14,11 +14,11 @@ limitations under the License. */
#include "paddle/fluid/operators/reshape_op.h" #include "paddle/fluid/operators/reshape_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL_EX(reshape, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel, int64_t, REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
ops::ReshapeKernel); ops::ReshapeKernel, int, ops::ReshapeKernel,
REGISTER_OP_CUDA_KERNEL(reshape_grad, int64_t, ops::ReshapeKernel);
paddle::operators::ReshapeGradKernel<float>, REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
paddle::operators::ReshapeGradKernel<double>, double, ops::ReshapeGradKernel, int,
paddle::operators::ReshapeGradKernel<int>, ops::ReshapeGradKernel, int64_t,
paddle::operators::ReshapeGradKernel<int64_t>); ops::ReshapeGradKernel);
...@@ -118,21 +118,15 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -118,21 +118,15 @@ class ReshapeOp : public framework::OperatorWithKernel {
} }
}; };
class ReshapeKernel : public framework::OpKernelBase { class ReshapeKernel {
public: public:
void Compute(const framework::ExecutionContext &ctx) const final; void operator()(const framework::ExecutionContext &ctx) const;
}; };
class ReshapeGradKernelBase : public framework::OpKernelBase { class ReshapeGradKernel {
public: public:
void Compute(const framework::ExecutionContext &ctx) const; void operator()(const framework::ExecutionContext &ctx) const;
}; };
template <typename T>
class ReshapeGradKernel : public ReshapeGradKernelBase {
public:
// Tell register element type.
using ELEMENT_TYPE = T;
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册