From 1ce478f100efe6e35e164e294bf8e6682b360fdd Mon Sep 17 00:00:00 2001 From: yuyang18 Date: Mon, 2 Jul 2018 16:13:52 +0800 Subject: [PATCH] Polish reshape op --- paddle/fluid/framework/op_registry.h | 84 ++++++++++++++++--- paddle/fluid/operators/reshape_op.cc | 74 ++++++++++++++-- .../{reshape_op.cu => reshape_op.cu.cc} | 18 ++-- paddle/fluid/operators/reshape_op.h | 71 +++------------- 4 files changed, 157 insertions(+), 90 deletions(-) rename paddle/fluid/operators/{reshape_op.cu => reshape_op.cu.cc} (51%) diff --git a/paddle/fluid/framework/op_registry.h b/paddle/fluid/framework/op_registry.h index 43ab227a94..f0278cc49c 100644 --- a/paddle/fluid/framework/op_registry.h +++ b/paddle/fluid/framework/op_registry.h @@ -76,6 +76,19 @@ class OpRegistry { template struct OpKernelRegistrarFunctor; +template +inline void RegisterKernelClass(const char* op_type, const char* library_type) { + std::string library(library_type); + std::string data_layout = "ANYLAYOUT"; + if (library == "MKLDNN") { + data_layout = "MKLDNNLAYOUT"; + } + OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(), + StringToDataLayout(data_layout), + StringToLibraryType(library_type)); + OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KernelType()); +} + template struct OpKernelRegistrarFunctor { using KERNEL_TYPE = @@ -83,16 +96,7 @@ struct OpKernelRegistrarFunctor { void operator()(const char* op_type, const char* library_type) const { using T = typename KERNEL_TYPE::ELEMENT_TYPE; - std::string library(library_type); - std::string data_layout = "ANYLAYOUT"; - if (library == "MKLDNN") { - data_layout = "MKLDNNLAYOUT"; - } - OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(), - StringToDataLayout(data_layout), - StringToLibraryType(library_type)); - OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE); - + RegisterKernelClass(op_type, library_type); constexpr auto size = std::tuple_size>::value; OpKernelRegistrarFunctor func; @@ -116,6 +120,47 @@ class OpKernelRegistrar : public Registrar { } }; +template +struct OpKernelRegistrarFunctorEx; + +template +class OpKernelRegistrarEx : public Registrar { + public: + explicit OpKernelRegistrarEx(const char* op_type, const char* library_type) { + OpKernelRegistrarFunctorEx + func; + func(op_type, library_type); + } +}; + +template +struct OpKernelRegistrarFunctorEx { + void operator()(const char* op_type, const char* library_type) const {} +}; + +template +struct OpKernelRegistrarFunctorEx { + using KERNEL_TYPE = + typename std::tuple_element>::type; + using T = + typename std::tuple_element>::type; + + void operator()(const char* op_type, const char* library_type) const { + RegisterKernelClass(op_type, library_type); + + constexpr auto size = + std::tuple_size>::value; + OpKernelRegistrarFunctorEx= size, I + 2, + DataTypeAndKernelType...> + func; + func(op_type, library_type); + } +}; + /** * check if MACRO is used in GLOBAL NAMESPACE. */ @@ -174,6 +219,25 @@ class OpKernelRegistrar : public Registrar { #define REGISTER_OP_CPU_KERNEL(op_type, ...) \ REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__) +#define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, ...) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __reg_op_kernel_##op_type##_##library_type##__, \ + "REGISTER_OP_KERNEL_EX must be called in global namespace"); \ + static ::paddle::framework::OpKernelRegistrarEx \ + __op_kernel_registrar_##op_type##_##library_type##__(#op_type, \ + #library_type); \ + int TouchOpKernelRegistrar_##op_type##_##library_type() { \ + __op_kernel_registrar_##op_type##_##library_type##__.Touch(); \ + return 0; \ + } + +#define REGISTER_OP_CUDA_KERNEL_EX(op_type, ...) \ + REGISTER_OP_KERNEL_EX(p_type, CUDA, ::paddle::platform::CUDAPlace, \ + __VA_ARGS__) + +#define REGISTER_OP_CPU_KERNEL_EX(op_type, ...) \ + REGISTER_OP_KERNEL_EX(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__) + /** * Macro to mark what Operator and Kernel * we will use and tell the compiler to diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 7f743f577f..ed07e6c2f7 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -107,19 +107,75 @@ class ReshapeGradOp : public framework::OperatorWithKernel { } }; +void ReshapeKernel::Compute(const framework::ExecutionContext &ctx) const { + auto *out = ctx.Output("Out"); + auto *in = ctx.Input("X"); + + auto *shape_tensor = ctx.HasInput("Shape") + ? ctx.Input("Shape") + : nullptr; + + framework::DDim out_dims = out->dims(); + + if (shape_tensor) { + auto *shape_data = shape_tensor->data(); + framework::Tensor cpu_shape_tensor; + if (platform::is_gpu_place(ctx.GetPlace())) { + TensorCopySync(*shape_tensor, platform::CPUPlace(), &cpu_shape_tensor); + shape_data = cpu_shape_tensor.data(); + } + auto shape = + std::vector(shape_data, shape_data + shape_tensor->numel()); + out_dims = ReshapeOp::ValidateShape(shape, in->dims()); + } + if (!in->lod().empty()) { + PADDLE_ENFORCE_EQ(out_dims[0], in->dims()[0], + "Reshape operator cannot reshape an input sequence batch " + "into an output sequence batch that has a different " + "number of time steps. Please consider using " + "sequence_reshape op."); + } + + bool inplace = ctx.Attr("inplace"); + out->Resize(out_dims); + if (!inplace) { + out->mutable_data(ctx.GetPlace(), in->type()); + framework::TensorCopySync(*in, ctx.GetPlace(), out); + out->Resize(out_dims); + } else { + out->ShareDataWith(*in); + out->Resize(out_dims); + } +} +void ReshapeGradKernelBase::Compute( + const framework::ExecutionContext &ctx) const { + auto *d_out = ctx.Input(framework::GradVarName("Out")); + auto *d_x = ctx.Output(framework::GradVarName("X")); + + d_x->mutable_data(ctx.GetPlace(), d_out->type()); + bool inplace = ctx.Attr("inplace"); + + auto in_dims = d_x->dims(); + if (!inplace) { + framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x); + ctx.device_context().Wait(); + d_x->Resize(in_dims); + } else { + d_x->ShareDataWith(*d_out); + d_x->Resize(in_dims); + } +} } // namespace operators } // namespace paddle namespace ops = paddle::operators; -using CPU = paddle::platform::CPUDeviceContext; REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp); -REGISTER_OP_CPU_KERNEL(reshape, ops::ReshapeKernel, - ops::ReshapeKernel, - ops::ReshapeKernel, - ops::ReshapeKernel); -REGISTER_OP_CPU_KERNEL(reshape_grad, ops::ReshapeGradKernel, - ops::ReshapeGradKernel, - ops::ReshapeGradKernel, - ops::ReshapeGradKernel); +REGISTER_OP_CPU_KERNEL_EX(reshape, float, ops::ReshapeKernel, double, + ops::ReshapeKernel, int, ops::ReshapeKernel, int64_t, + ops::ReshapeKernel); +REGISTER_OP_CPU_KERNEL(reshape_grad, ops::ReshapeGradKernel, + ops::ReshapeGradKernel, + ops::ReshapeGradKernel, + ops::ReshapeGradKernel); diff --git a/paddle/fluid/operators/reshape_op.cu b/paddle/fluid/operators/reshape_op.cu.cc similarity index 51% rename from paddle/fluid/operators/reshape_op.cu rename to paddle/fluid/operators/reshape_op.cu.cc index c628c634e2..8a09321eef 100644 --- a/paddle/fluid/operators/reshape_op.cu +++ b/paddle/fluid/operators/reshape_op.cu.cc @@ -13,14 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/reshape_op.h" -using CUDA = paddle::platform::CUDADeviceContext; - -REGISTER_OP_CUDA_KERNEL(reshape, paddle::operators::ReshapeKernel, - paddle::operators::ReshapeKernel, - paddle::operators::ReshapeKernel, - paddle::operators::ReshapeKernel); +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL_EX(reshape, float, ops::ReshapeKernel, double, + ops::ReshapeKernel, int, ops::ReshapeKernel, int64_t, + ops::ReshapeKernel); REGISTER_OP_CUDA_KERNEL(reshape_grad, - paddle::operators::ReshapeGradKernel, - paddle::operators::ReshapeGradKernel, - paddle::operators::ReshapeGradKernel, - paddle::operators::ReshapeGradKernel); + paddle::operators::ReshapeGradKernel, + paddle::operators::ReshapeGradKernel, + paddle::operators::ReshapeGradKernel, + paddle::operators::ReshapeGradKernel); diff --git a/paddle/fluid/operators/reshape_op.h b/paddle/fluid/operators/reshape_op.h index 3dd8c7c11e..c0b57d11d3 100644 --- a/paddle/fluid/operators/reshape_op.h +++ b/paddle/fluid/operators/reshape_op.h @@ -118,72 +118,21 @@ class ReshapeOp : public framework::OperatorWithKernel { } }; -template -class ReshapeKernel : public framework::OpKernel { +class ReshapeKernel : public framework::OpKernelBase { public: - void Compute(const framework::ExecutionContext &ctx) const { - auto *out = ctx.Output("Out"); - auto *in = ctx.Input("X"); - - auto *shape_tensor = ctx.HasInput("Shape") - ? ctx.Input("Shape") - : nullptr; - - framework::DDim out_dims = out->dims(); - - if (shape_tensor) { - auto *shape_data = shape_tensor->data(); - framework::Tensor cpu_shape_tensor; - if (platform::is_gpu_place(ctx.GetPlace())) { - TensorCopySync(*shape_tensor, platform::CPUPlace(), &cpu_shape_tensor); - shape_data = cpu_shape_tensor.data(); - } - auto shape = - std::vector(shape_data, shape_data + shape_tensor->numel()); - out_dims = ReshapeOp::ValidateShape(shape, in->dims()); - } - if (!in->lod().empty()) { - PADDLE_ENFORCE_EQ( - out_dims[0], in->dims()[0], - "Reshape operator cannot reshape an input sequence batch " - "into an output sequence batch that has a different " - "number of time steps. Please consider using " - "sequence_reshape op."); - } + void Compute(const framework::ExecutionContext &ctx) const final; +}; - bool inplace = ctx.Attr("inplace"); - out->Resize(out_dims); - if (!inplace) { - out->mutable_data(ctx.GetPlace()); - framework::TensorCopySync(*in, ctx.GetPlace(), out); - out->Resize(out_dims); - } else { - out->ShareDataWith(*in); - out->Resize(out_dims); - } - } +class ReshapeGradKernelBase : public framework::OpKernelBase { + public: + void Compute(const framework::ExecutionContext &ctx) const; }; -template -class ReshapeGradKernel : public framework::OpKernel { +template +class ReshapeGradKernel : public ReshapeGradKernelBase { public: - void Compute(const framework::ExecutionContext &ctx) const { - auto *d_out = ctx.Input(framework::GradVarName("Out")); - auto *d_x = ctx.Output(framework::GradVarName("X")); - - d_x->mutable_data(ctx.GetPlace()); - bool inplace = ctx.Attr("inplace"); - - auto in_dims = d_x->dims(); - if (!inplace) { - framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x); - ctx.device_context().Wait(); - d_x->Resize(in_dims); - } else { - d_x->ShareDataWith(*d_out); - d_x->Resize(in_dims); - } - } + // Tell register element type. + using ELEMENT_TYPE = T; }; } // namespace operators } // namespace paddle -- GitLab