未验证 提交 82866d4a 编写于 作者: Y yuyang18

Add kernel functor registration and shrink reshape op

* Shrink reshape_op library size
* Users can register a standard C++ functor as an op kernel
上级 75ae426a
......@@ -146,7 +146,7 @@ struct OpKernelRegistrarFunctorEx<PlaceType, true, I,
template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
DataTypeAndKernelType...> {
using KERNEL_TYPE =
using Functor =
typename std::tuple_element<I + 1,
std::tuple<DataTypeAndKernelType...>>::type;
using T =
......@@ -154,10 +154,7 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
std::tuple<DataTypeAndKernelType...>>::type;
void operator()(const char* op_type, const char* library_type) const {
RegisterKernelClass<PlaceType, T>(
op_type, library_type, [](const framework::ExecutionContext& ctx) {
KERNEL_TYPE().Compute(ctx);
});
RegisterKernelClass<PlaceType, T>(op_type, library_type, Functor());
constexpr auto size =
std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value;
......@@ -238,11 +235,11 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
return 0; \
}
#define REGISTER_OP_CUDA_KERNEL_EX(op_type, ...) \
REGISTER_OP_KERNEL_EX(p_type, CUDA, ::paddle::platform::CUDAPlace, \
#define REGISTER_OP_CUDA_KERNEL_FUNCTOR(op_type, ...) \
REGISTER_OP_KERNEL_EX(op_type, CUDA, ::paddle::platform::CUDAPlace, \
__VA_ARGS__)
#define REGISTER_OP_CPU_KERNEL_EX(op_type, ...) \
#define REGISTER_OP_CPU_KERNEL_FUNCTOR(op_type, ...) \
REGISTER_OP_KERNEL_EX(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
/**
......
......@@ -107,7 +107,7 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
}
};
void ReshapeKernel::Compute(const framework::ExecutionContext &ctx) const {
void ReshapeKernel::operator()(const framework::ExecutionContext &ctx) const {
auto *out = ctx.Output<framework::LoDTensor>("Out");
auto *in = ctx.Input<framework::LoDTensor>("X");
......@@ -147,7 +147,7 @@ void ReshapeKernel::Compute(const framework::ExecutionContext &ctx) const {
out->Resize(out_dims);
}
}
void ReshapeGradKernelBase::Compute(
void ReshapeGradKernel::operator()(
const framework::ExecutionContext &ctx) const {
auto *d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
......@@ -172,10 +172,10 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp);
REGISTER_OP_CPU_KERNEL_EX(reshape, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel, int64_t,
ops::ReshapeKernel);
REGISTER_OP_CPU_KERNEL(reshape_grad, ops::ReshapeGradKernel<float>,
ops::ReshapeGradKernel<double>,
ops::ReshapeGradKernel<int>,
ops::ReshapeGradKernel<int64_t>);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel,
int64_t, ops::ReshapeKernel);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
double, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel);
......@@ -14,11 +14,11 @@ limitations under the License. */
#include "paddle/fluid/operators/reshape_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL_EX(reshape, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel, int64_t,
ops::ReshapeKernel);
REGISTER_OP_CUDA_KERNEL(reshape_grad,
paddle::operators::ReshapeGradKernel<float>,
paddle::operators::ReshapeGradKernel<double>,
paddle::operators::ReshapeGradKernel<int>,
paddle::operators::ReshapeGradKernel<int64_t>);
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel,
int64_t, ops::ReshapeKernel);
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
double, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel);
......@@ -118,21 +118,15 @@ class ReshapeOp : public framework::OperatorWithKernel {
}
};
class ReshapeKernel : public framework::OpKernelBase {
class ReshapeKernel {
public:
void Compute(const framework::ExecutionContext &ctx) const final;
void operator()(const framework::ExecutionContext &ctx) const;
};
class ReshapeGradKernelBase : public framework::OpKernelBase {
class ReshapeGradKernel {
public:
void Compute(const framework::ExecutionContext &ctx) const;
void operator()(const framework::ExecutionContext &ctx) const;
};
template <typename T>
class ReshapeGradKernel : public ReshapeGradKernelBase {
public:
// Tell register element type.
using ELEMENT_TYPE = T;
};
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册