diff --git a/paddle/fluid/framework/op_registry.h b/paddle/fluid/framework/op_registry.h
index 43ab227a9478707445892c14723801992d0041aa..3314e41cc51d74f87be0e2cd5eba9bb260c16be7 100644
--- a/paddle/fluid/framework/op_registry.h
+++ b/paddle/fluid/framework/op_registry.h
@@ -76,6 +76,20 @@ class OpRegistry {
 template <typename PlaceType, bool at_end, size_t I, typename... KernelTypes>
 struct OpKernelRegistrarFunctor;
 
+template <typename PlaceType, typename T, typename Func>
+inline void RegisterKernelClass(const char* op_type, const char* library_type,
+                                Func func) {
+  std::string library(library_type);
+  std::string data_layout = "ANYLAYOUT";
+  if (library == "MKLDNN") {
+    data_layout = "MKLDNNLAYOUT";
+  }
+  OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
+                   StringToDataLayout(data_layout),
+                   StringToLibraryType(library_type));
+  OperatorWithKernel::AllOpKernels()[op_type][key] = func;
+}
+
 template <typename PlaceType, size_t I, typename... KernelTypes>
 struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
   using KERNEL_TYPE =
@@ -83,16 +97,10 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
 
   void operator()(const char* op_type, const char* library_type) const {
     using T = typename KERNEL_TYPE::ELEMENT_TYPE;
-    std::string library(library_type);
-    std::string data_layout = "ANYLAYOUT";
-    if (library == "MKLDNN") {
-      data_layout = "MKLDNNLAYOUT";
-    }
-    OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
-                     StringToDataLayout(data_layout),
-                     StringToLibraryType(library_type));
-    OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE);
-
+    RegisterKernelClass<PlaceType, T>(
+        op_type, library_type, [](const framework::ExecutionContext& ctx) {
+          KERNEL_TYPE().Compute(ctx);
+        });
     constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
     OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
         func;
@@ -116,6 +124,47 @@ class OpKernelRegistrar : public Registrar {
   }
 };
 
+template <typename PlaceType, bool at_end, size_t I,
+          typename... DataTypeAndKernelType>
+struct OpKernelRegistrarFunctorEx;
+
+template <typename PlaceType, typename... DataTypeAndKernelType>
+class OpKernelRegistrarEx : public Registrar {
+ public:
+  explicit OpKernelRegistrarEx(const char* op_type, const char* library_type) {
+    OpKernelRegistrarFunctorEx<PlaceType, false, 0, DataTypeAndKernelType...>
+        func;
+    func(op_type, library_type);
+  }
+};
+
+template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
+struct OpKernelRegistrarFunctorEx<PlaceType, true, I,
+                                  DataTypeAndKernelType...> {
+  void operator()(const char* op_type, const char* library_type) const {}
+};
+
+template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
+struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
+                                  DataTypeAndKernelType...> {
+  using Functor =
+      typename std::tuple_element<I + 1,
+                                  std::tuple<DataTypeAndKernelType...>>::type;
+  using T =
+      typename std::tuple_element<I,
+                                  std::tuple<DataTypeAndKernelType...>>::type;
+
+  void operator()(const char* op_type, const char* library_type) const {
+    RegisterKernelClass<PlaceType, T>(op_type, library_type, Functor());
+
+    constexpr auto size =
+        std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value;
+    OpKernelRegistrarFunctorEx<PlaceType, I + 2 == size, I + 2,
+                               DataTypeAndKernelType...>
+        func;
+    func(op_type, library_type);
+  }
+};
+
 /**
  * check if MACRO is used in GLOBAL NAMESPACE.
  */
@@ -174,6 +223,25 @@ class OpKernelRegistrar : public Registrar {
 #define REGISTER_OP_CPU_KERNEL(op_type, ...)                                 \
   REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
 
+#define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, ...)      \
+  STATIC_ASSERT_GLOBAL_NAMESPACE(                                           \
+      __reg_op_kernel_##op_type##_##library_type##__,                       \
+      "REGISTER_OP_KERNEL_EX must be called in global namespace");          \
+  static ::paddle::framework::OpKernelRegistrarEx<place_class, __VA_ARGS__> \
+      __op_kernel_registrar_##op_type##_##library_type##__(#op_type,        \
+                                                           #library_type);  \
+  int TouchOpKernelRegistrar_##op_type##_##library_type() {                 \
+    __op_kernel_registrar_##op_type##_##library_type##__.Touch();           \
+    return 0;                                                               \
+  }
+
+#define REGISTER_OP_CUDA_KERNEL_FUNCTOR(op_type, ...)                 \
+  REGISTER_OP_KERNEL_EX(op_type, CUDA, ::paddle::platform::CUDAPlace, \
+                        __VA_ARGS__)
+
+#define REGISTER_OP_CPU_KERNEL_FUNCTOR(op_type, ...) \
+  REGISTER_OP_KERNEL_EX(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
+
 /**
  * Macro to mark what Operator and Kernel
  * we will use and tell the compiler to
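NOTE: REGISTER_OP_KERNEL_EX consumes its variadic arguments as (data type, kernel functor) pairs, which is why OpKernelRegistrarFunctorEx reads tuple elements I and I + 1 and recurses with I + 2 until the index reaches the size of the pack. Below is a self-contained sketch of the same compile-time walk; the names (PairWalker, FloatKernel, DoubleKernel) are toy stand-ins for illustration, not Paddle code.

    // Standalone illustration of the (type, functor) pair walk used by
    // OpKernelRegistrarFunctorEx; any C++11 compiler accepts it.
    #include <cstddef>
    #include <cstdio>
    #include <tuple>

    struct FloatKernel {
      void operator()() const { std::printf("register float kernel\n"); }
    };
    struct DoubleKernel {
      void operator()() const { std::printf("register double kernel\n"); }
    };

    template <bool at_end, std::size_t I, typename... Pairs>
    struct PairWalker;

    template <std::size_t I, typename... Pairs>
    struct PairWalker<false, I, Pairs...> {
      // Element I is the data type, element I + 1 the kernel functor.
      using T = typename std::tuple_element<I, std::tuple<Pairs...>>::type;
      using Functor =
          typename std::tuple_element<I + 1, std::tuple<Pairs...>>::type;

      void operator()() const {
        Functor()();  // where Paddle calls RegisterKernelClass<PlaceType, T>
        PairWalker<I + 2 == sizeof...(Pairs), I + 2, Pairs...>()();
      }
    };

    template <std::size_t I, typename... Pairs>
    struct PairWalker<true, I, Pairs...> {
      void operator()() const {}  // index reached the end of the pack: stop
    };

    int main() {
      // Mirrors REGISTER_OP_CPU_KERNEL_FUNCTOR(op, float, FloatKernel,
      //                                        double, DoubleKernel):
      PairWalker<false, 0, float, FloatKernel, double, DoubleKernel>()();
      return 0;
    }
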
diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index 71cd5a39083471af52598cc2a1d4c591d3780624..3cf8e8696d739e3f2894e490161b9fb5b459bc41 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -651,7 +651,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
     dev_ctx = pool.Get(expected_kernel_key.place_);
   }
 
-  kernel_iter->second->Compute(ExecutionContext(*this, exec_scope, *dev_ctx));
+  kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx));
 
   if (!transfered_inplace_vars.empty()) {
     // there is inplace variable has been transfered.
diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index 1550d5df172f0599e1b42e7f1ccf51ac4dd1e0c3..01d750efbb8aaa35701f6caa7ec103ec21dd529e 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -347,9 +347,9 @@ class OpKernel : public OpKernelBase {
 
 class OperatorWithKernel : public OperatorBase {
  public:
+  using OpKernelFunc = std::function<void(const ExecutionContext&)>;
   using OpKernelMap =
-      std::unordered_map<OpKernelType, std::unique_ptr<OpKernelBase>,
-                         OpKernelType::Hash>;
+      std::unordered_map<OpKernelType, OpKernelFunc, OpKernelType::Hash>;
 
   OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
                      const VariableNameMap& outputs, const AttributeMap& attrs)
diff --git a/paddle/fluid/operators/fc_mkldnn_op.cc b/paddle/fluid/operators/fc_mkldnn_op.cc
index 847b7b0c12e1679501dbe83d578b23ca2aef3e9e..99fa659a351249a4a93f71700e1c646465861aba 100644
--- a/paddle/fluid/operators/fc_mkldnn_op.cc
+++ b/paddle/fluid/operators/fc_mkldnn_op.cc
@@ -115,6 +115,7 @@ class MKLDNNMemory {
 
 template <typename T>
 class FCMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
+ public:
   void Compute(const paddle::framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                    "It must use CPUPlace.");
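NOTE: these two framework edits are the heart of the patch. A registered kernel is now any callable convertible to OpKernelFunc, so RunImpl invokes the map entry directly instead of calling Compute through a unique_ptr<OpKernelBase>. A minimal standalone sketch of why a legacy OpKernel subclass (wrapped in a lambda, as RegisterKernelClass now does) and a plain functor like the new ReshapeKernel can share one map; the types here are simplified stand-ins, not the real framework classes:

    #include <functional>
    #include <iostream>
    #include <string>
    #include <unordered_map>

    struct ExecutionContext { std::string op; };  // stand-in for the real one

    using OpKernelFunc = std::function<void(const ExecutionContext&)>;

    struct LegacyKernel {  // plays the role of an OpKernel<T> subclass
      void Compute(const ExecutionContext& ctx) const {
        std::cout << "legacy Compute for " << ctx.op << "\n";
      }
    };

    struct FunctorKernel {  // plays the role of ReshapeKernel
      void operator()(const ExecutionContext& ctx) const {
        std::cout << "functor kernel for " << ctx.op << "\n";
      }
    };

    int main() {
      std::unordered_map<std::string, OpKernelFunc> kernels;
      // Old-style kernels are adapted with a lambda, as RegisterKernelClass
      // does; functor kernels are stored directly.
      kernels["legacy_op"] = [](const ExecutionContext& ctx) {
        LegacyKernel().Compute(ctx);
      };
      kernels["reshape"] = FunctorKernel();

      // Matches the call-site change in OperatorWithKernel::RunImpl:
      kernels["reshape"](ExecutionContext{"reshape"});
      kernels["legacy_op"](ExecutionContext{"legacy_op"});
      return 0;
    }
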
+ ctx->ShareLoD("X", /*->*/ "Out"); + return; + } + + auto x_dims = ctx->GetInputDim("X"); + auto out_dims = ValidateShape(shape, x_dims); + ctx->SetOutputDim("Out", out_dims); + if (x_dims[0] == out_dims[0]) { + // Only pass LoD when the first dimension of output and Input(X) + // are the same. + ctx->ShareLoD("X", /*->*/ "Out"); + } + } + + static framework::DDim ValidateShape(const std::vector shape, + const framework::DDim &in_dims) { + const int64_t in_size = framework::product(in_dims); + // only one dimension can be set to -1, whose size will be automatically + // infered. + const int64_t unk_dim_val = -1; + const int64_t copy_dim_val = 0; + + std::vector output_shape(shape.size(), 0); + int64_t capacity = 1; + int unk_dim_idx = -1; + for (size_t i = 0; i < shape.size(); ++i) { + if (shape[i] == unk_dim_val) { + PADDLE_ENFORCE( + unk_dim_idx == -1, + "Only one input dimension of Attr(shape) can be unknown."); + unk_dim_idx = i; + } else if (shape[i] == copy_dim_val) { + PADDLE_ENFORCE( + static_cast(i) < in_dims.size(), + "The index of dimension to copy from input shape must be less " + "than the size of input shape."); + } else { + PADDLE_ENFORCE( + shape[i] > 0, + "Each input dimension of Attr(shape) must not be negtive except " + "one unknown dimension."); + } + + capacity *= (shape[i] ? shape[i] : in_dims[i]); + output_shape[i] = + (shape[i] ? static_cast(shape[i]) : in_dims[i]); + } + + if (unk_dim_idx != -1) { + if (in_size > 0) { + // in_size < 0 and is un-determinate in compile time, skip the check, + // for example, in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8], + // capacity = -24, in_size = -8, output_shape[0] = 0 + // the following check will fail. + output_shape[unk_dim_idx] = -in_size / capacity; + PADDLE_ENFORCE_EQ(output_shape[unk_dim_idx] * capacity, -in_size, + "Invalid shape is given."); + } else { + output_shape[unk_dim_idx] = -1; + } + } else { + PADDLE_ENFORCE_EQ(capacity, in_size, "Invalid shape is given."); + } + return framework::make_ddim(output_shape); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } +}; + class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -107,19 +201,93 @@ class ReshapeGradOp : public framework::OperatorWithKernel { } }; +class ReshapeKernel { + public: + void operator()(const framework::ExecutionContext &ctx) const { + auto *out = ctx.Output("Out"); + auto *in = ctx.Input("X"); + + auto *shape_tensor = ctx.HasInput("Shape") + ? ctx.Input("Shape") + : nullptr; + + framework::DDim out_dims = out->dims(); + + if (shape_tensor) { + auto *shape_data = shape_tensor->data(); + framework::Tensor cpu_shape_tensor; + if (platform::is_gpu_place(ctx.GetPlace())) { + TensorCopySync(*shape_tensor, platform::CPUPlace(), &cpu_shape_tensor); + shape_data = cpu_shape_tensor.data(); + } + auto shape = + std::vector(shape_data, shape_data + shape_tensor->numel()); + out_dims = ReshapeOp::ValidateShape(shape, in->dims()); + } + if (!in->lod().empty()) { + PADDLE_ENFORCE_EQ( + out_dims[0], in->dims()[0], + "Reshape operator cannot reshape an input sequence batch " + "into an output sequence batch that has a different " + "number of time steps. 
Please consider using " + "sequence_reshape op."); + } + + bool inplace = ctx.Attr("inplace"); + out->Resize(out_dims); + if (!inplace) { + out->mutable_data(ctx.GetPlace(), in->type()); + framework::TensorCopySync(*in, ctx.GetPlace(), out); + out->Resize(out_dims); + } else { + out->ShareDataWith(*in); + out->Resize(out_dims); + } + } +}; + +class ReshapeGradKernel { + public: + void operator()(const framework::ExecutionContext &ctx) const { + auto *d_out = ctx.Input(framework::GradVarName("Out")); + auto *d_x = ctx.Output(framework::GradVarName("X")); + + d_x->mutable_data(ctx.GetPlace(), d_out->type()); + bool inplace = ctx.Attr("inplace"); + + auto in_dims = d_x->dims(); + if (!inplace) { + framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x); + ctx.device_context().Wait(); + d_x->Resize(in_dims); + } else { + d_x->ShareDataWith(*d_out); + d_x->Resize(in_dims); + } + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -using CPU = paddle::platform::CPUDeviceContext; REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp); -REGISTER_OP_CPU_KERNEL(reshape, ops::ReshapeKernel, - ops::ReshapeKernel, - ops::ReshapeKernel, - ops::ReshapeKernel); -REGISTER_OP_CPU_KERNEL(reshape_grad, ops::ReshapeGradKernel, - ops::ReshapeGradKernel, - ops::ReshapeGradKernel, - ops::ReshapeGradKernel); +REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double, + ops::ReshapeKernel, int, ops::ReshapeKernel, + int64_t, ops::ReshapeKernel); +REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel, + double, ops::ReshapeGradKernel, int, + ops::ReshapeGradKernel, int64_t, + ops::ReshapeGradKernel); + +#ifdef PADDLE_WITH_CUDA +REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double, + ops::ReshapeKernel, int, ops::ReshapeKernel, + int64_t, ops::ReshapeKernel); +REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel, + double, ops::ReshapeGradKernel, int, + ops::ReshapeGradKernel, int64_t, + ops::ReshapeGradKernel); +#endif diff --git a/paddle/fluid/operators/reshape_op.cu b/paddle/fluid/operators/reshape_op.cu deleted file mode 100644 index c628c634e2bc9ae260948a6e7ccf786cbd6c5c3c..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/reshape_op.cu +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include "paddle/fluid/operators/reshape_op.h" -using CUDA = paddle::platform::CUDADeviceContext; - -REGISTER_OP_CUDA_KERNEL(reshape, paddle::operators::ReshapeKernel, - paddle::operators::ReshapeKernel, - paddle::operators::ReshapeKernel, - paddle::operators::ReshapeKernel); -REGISTER_OP_CUDA_KERNEL(reshape_grad, - paddle::operators::ReshapeGradKernel, - paddle::operators::ReshapeGradKernel, - paddle::operators::ReshapeGradKernel, - paddle::operators::ReshapeGradKernel); diff --git a/paddle/fluid/operators/reshape_op.h b/paddle/fluid/operators/reshape_op.h deleted file mode 100644 index 3dd8c7c11eca241e747bfa129962032d882ce44c..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/reshape_op.h +++ /dev/null @@ -1,189 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include - -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -class ReshapeOp : public framework::OperatorWithKernel { - public: - ReshapeOp(const std::string &type, const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : OperatorWithKernel(type, inputs, outputs, attrs) {} - - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of ReshapeOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of ReshapeOp should not be null."); - - const std::vector &shape = ctx->Attrs().Get>("shape"); - PADDLE_ENFORCE(!shape.empty(), - "The shape information must be set by Attr(shape)."); - - if (ctx->HasInput("Shape") && ctx->IsRuntime()) { - // If true, set the shape of Output(Out) according to Input(Shape) in - // ReshapeKernel with ExecutionContext. Also check LoD in ReshapeKernel. - ctx->ShareLoD("X", /*->*/ "Out"); - return; - } - - auto x_dims = ctx->GetInputDim("X"); - auto out_dims = ValidateShape(shape, x_dims); - ctx->SetOutputDim("Out", out_dims); - if (x_dims[0] == out_dims[0]) { - // Only pass LoD when the first dimension of output and Input(X) - // are the same. - ctx->ShareLoD("X", /*->*/ "Out"); - } - } - - static framework::DDim ValidateShape(const std::vector shape, - const framework::DDim &in_dims) { - const int64_t in_size = framework::product(in_dims); - // only one dimension can be set to -1, whose size will be automatically - // infered. 
diff --git a/paddle/fluid/operators/reshape_op.h b/paddle/fluid/operators/reshape_op.h
deleted file mode 100644
index 3dd8c7c11eca241e747bfa129962032d882ce44c..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/reshape_op.h
+++ /dev/null
@@ -1,189 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include <string>
-#include <vector>
-
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/op_registry.h"
-
-namespace paddle {
-namespace operators {
-
-class ReshapeOp : public framework::OperatorWithKernel {
- public:
-  ReshapeOp(const std::string &type, const framework::VariableNameMap &inputs,
-            const framework::VariableNameMap &outputs,
-            const framework::AttributeMap &attrs)
-      : OperatorWithKernel(type, inputs, outputs, attrs) {}
-
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input(X) of ReshapeOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Out) of ReshapeOp should not be null.");
-
-    const std::vector<int> &shape = ctx->Attrs().Get<std::vector<int>>("shape");
-    PADDLE_ENFORCE(!shape.empty(),
-                   "The shape information must be set by Attr(shape).");
-
-    if (ctx->HasInput("Shape") && ctx->IsRuntime()) {
-      // If true, set the shape of Output(Out) according to Input(Shape) in
-      // ReshapeKernel with ExecutionContext. Also check LoD in ReshapeKernel.
-      ctx->ShareLoD("X", /*->*/ "Out");
-      return;
-    }
-
-    auto x_dims = ctx->GetInputDim("X");
-    auto out_dims = ValidateShape(shape, x_dims);
-    ctx->SetOutputDim("Out", out_dims);
-    if (x_dims[0] == out_dims[0]) {
-      // Only pass LoD when the first dimension of output and Input(X)
-      // are the same.
-      ctx->ShareLoD("X", /*->*/ "Out");
-    }
-  }
-
-  static framework::DDim ValidateShape(const std::vector<int> shape,
-                                       const framework::DDim &in_dims) {
-    const int64_t in_size = framework::product(in_dims);
-    // only one dimension can be set to -1, whose size will be automatically
-    // infered.
-    const int64_t unk_dim_val = -1;
-    const int64_t copy_dim_val = 0;
-
-    std::vector<int64_t> output_shape(shape.size(), 0);
-    int64_t capacity = 1;
-    int unk_dim_idx = -1;
-    for (size_t i = 0; i < shape.size(); ++i) {
-      if (shape[i] == unk_dim_val) {
-        PADDLE_ENFORCE(
-            unk_dim_idx == -1,
-            "Only one input dimension of Attr(shape) can be unknown.");
-        unk_dim_idx = i;
-      } else if (shape[i] == copy_dim_val) {
-        PADDLE_ENFORCE(
-            static_cast<int>(i) < in_dims.size(),
-            "The index of dimension to copy from input shape must be less "
-            "than the size of input shape.");
-      } else {
-        PADDLE_ENFORCE(
-            shape[i] > 0,
-            "Each input dimension of Attr(shape) must not be negtive except "
-            "one unknown dimension.");
-      }
-
-      capacity *= (shape[i] ? shape[i] : in_dims[i]);
-      output_shape[i] =
-          (shape[i] ? static_cast<int64_t>(shape[i]) : in_dims[i]);
-    }
-
-    if (unk_dim_idx != -1) {
-      if (in_size > 0) {
-        // in_size < 0 and is un-determinate in compile time, skip the check,
-        // for example, in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8],
-        // capacity = -24, in_size = -8, output_shape[0] = 0
-        // the following check will fail.
-        output_shape[unk_dim_idx] = -in_size / capacity;
-        PADDLE_ENFORCE_EQ(output_shape[unk_dim_idx] * capacity, -in_size,
-                          "Invalid shape is given.");
-      } else {
-        output_shape[unk_dim_idx] = -1;
-      }
-    } else {
-      PADDLE_ENFORCE_EQ(capacity, in_size, "Invalid shape is given.");
-    }
-    return framework::make_ddim(output_shape);
-  }
-
- protected:
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
-        ctx.device_context());
-  }
-};
-
-template <typename DeviceContext, typename T>
-class ReshapeKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const {
-    auto *out = ctx.Output<framework::LoDTensor>("Out");
-    auto *in = ctx.Input<framework::LoDTensor>("X");
-
-    auto *shape_tensor = ctx.HasInput("Shape")
-                             ? ctx.Input<framework::LoDTensor>("Shape")
-                             : nullptr;
-
-    framework::DDim out_dims = out->dims();
-
-    if (shape_tensor) {
-      auto *shape_data = shape_tensor->data<int>();
-      framework::Tensor cpu_shape_tensor;
-      if (platform::is_gpu_place(ctx.GetPlace())) {
-        TensorCopySync(*shape_tensor, platform::CPUPlace(), &cpu_shape_tensor);
-        shape_data = cpu_shape_tensor.data<int>();
-      }
-      auto shape =
-          std::vector<int>(shape_data, shape_data + shape_tensor->numel());
-      out_dims = ReshapeOp::ValidateShape(shape, in->dims());
-    }
-    if (!in->lod().empty()) {
-      PADDLE_ENFORCE_EQ(
-          out_dims[0], in->dims()[0],
-          "Reshape operator cannot reshape an input sequence batch "
-          "into an output sequence batch that has a different "
-          "number of time steps. Please consider using "
-          "sequence_reshape op.");
-    }
-
-    bool inplace = ctx.Attr<bool>("inplace");
-    out->Resize(out_dims);
-    if (!inplace) {
-      out->mutable_data<T>(ctx.GetPlace());
-      framework::TensorCopySync(*in, ctx.GetPlace(), out);
-      out->Resize(out_dims);
-    } else {
-      out->ShareDataWith(*in);
-      out->Resize(out_dims);
-    }
-  }
-};
-
-template <typename DeviceContext, typename T>
-class ReshapeGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const {
-    auto *d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
-
-    d_x->mutable_data<T>(ctx.GetPlace());
-    bool inplace = ctx.Attr<bool>("inplace");
-
-    auto in_dims = d_x->dims();
-    if (!inplace) {
-      framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x);
-      ctx.device_context().Wait();
-      d_x->Resize(in_dims);
-    } else {
-      d_x->ShareDataWith(*d_out);
-      d_x->Resize(in_dims);
-    }
-  }
-};
-}  // namespace operators
-}  // namespace paddle