From f3d1fac24df45b409125ccc86d273d5a0be093e9 Mon Sep 17 00:00:00 2001 From: superjomn Date: Tue, 2 Apr 2019 19:31:05 +0800 Subject: [PATCH] fix kernel registry --- paddle/fluid/lite/CMakeLists.txt | 1 + paddle/fluid/lite/core/kernel.h | 11 +++--- paddle/fluid/lite/core/op_lite.h | 8 ++--- paddle/fluid/lite/core/op_registry.cc | 2 +- paddle/fluid/lite/core/op_registry.h | 40 +++++++++++++++------- paddle/fluid/lite/core/scope.cc | 3 +- paddle/fluid/lite/core/scope.h | 2 +- paddle/fluid/lite/core/scope_test.cc | 2 +- paddle/fluid/lite/core/tensor.cc | 2 +- paddle/fluid/lite/core/tensor.h | 1 - paddle/fluid/lite/core/variable.cc | 2 +- paddle/fluid/lite/operators/CMakeLists.txt | 5 ++- paddle/fluid/lite/operators/fc_op.cc | 2 +- paddle/fluid/lite/operators/fc_op.h | 30 +++++++++------- paddle/fluid/lite/utils/all.h | 2 +- 15 files changed, 68 insertions(+), 45 deletions(-) diff --git a/paddle/fluid/lite/CMakeLists.txt b/paddle/fluid/lite/CMakeLists.txt index 0af924424ac..5a471a3c38b 100644 --- a/paddle/fluid/lite/CMakeLists.txt +++ b/paddle/fluid/lite/CMakeLists.txt @@ -4,3 +4,4 @@ add_subdirectory(cuda) add_subdirectory(operators) add_subdirectory(kernels) add_subdirectory(model_parser) +add_subdirectory(utils) diff --git a/paddle/fluid/lite/core/kernel.h b/paddle/fluid/lite/core/kernel.h index 1808a0ddcd8..6355cb659de 100644 --- a/paddle/fluid/lite/core/kernel.h +++ b/paddle/fluid/lite/core/kernel.h @@ -18,10 +18,11 @@ #include #include #include -#include "paddle/fluid/framework/op_desc.h" #include "context.h" -#include "target_wrapper.h" +#include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/lite/operators/op_params.h" #include "paddle/fluid/lite/utils/all.h" +#include "target_wrapper.h" namespace paddle { namespace lite { @@ -39,11 +40,11 @@ class OpKernel { void SetContext(context_ptr_t&& ctx) { context_ = std::move(ctx); } - void SetParam(any param) { param_ = param; } + void SetParam(operators::param_t param) { param_ = param; } template Param& 
param() const { - return *any_cast(¶m_); + return param_.get(); } virtual void Run() { CHECK(false) << "Not Implemented"; } @@ -52,7 +53,7 @@ class OpKernel { protected: context_ptr_t context_; - mutable any param_; + mutable operators::param_t param_; }; } // namespace lite diff --git a/paddle/fluid/lite/core/op_lite.h b/paddle/fluid/lite/core/op_lite.h index a540fc8c83c..d660ff9c1f8 100644 --- a/paddle/fluid/lite/core/op_lite.h +++ b/paddle/fluid/lite/core/op_lite.h @@ -18,11 +18,12 @@ #include #include #include -#include "context.h" -#include "kernel.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/lite/core/context.h" +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/scope.h" namespace paddle { namespace lite { @@ -66,8 +67,7 @@ class OpLite : public Registry { // Run this operator. virtual bool Run() = 0; // Build the operator, attach it with the runtime environment. - virtual bool Build(const framework::OpDesc &opdesc, - framework::Scope *scope) = 0; + virtual bool Build(const framework::OpDesc &opdesc, lite::Scope *scope) = 0; // Human-readable information. virtual std::string DebugString() const = 0; diff --git a/paddle/fluid/lite/core/op_registry.cc b/paddle/fluid/lite/core/op_registry.cc index 23c940458ff..6556f2e6112 100644 --- a/paddle/fluid/lite/core/op_registry.cc +++ b/paddle/fluid/lite/core/op_registry.cc @@ -12,4 +12,4 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "op_registry.h" \ No newline at end of file +#include "paddle/fluid/lite/core/op_registry.h" \ No newline at end of file diff --git a/paddle/fluid/lite/core/op_registry.h b/paddle/fluid/lite/core/op_registry.h index 3be1ee5a6d3..dfe54f5cbd2 100644 --- a/paddle/fluid/lite/core/op_registry.h +++ b/paddle/fluid/lite/core/op_registry.h @@ -15,9 +15,9 @@ #include #include #include -#include "kernel.h" -#include "op_lite.h" -#include "target_wrapper.h" +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/target_wrapper.h" #include "paddle/fluid/lite/utils/all.h" namespace paddle { @@ -50,16 +50,32 @@ class OpLiteRegistor : public Registor { }; template -class KernelRegistryForTarget : public Factory> {}; +using KernelRegistryForTarget = Factory>; class KernelRegistry final { public: + using any_kernel_registor_t = variant< + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget * // + >; + KernelRegistry() { -#define INIT_FOR(target__, precision__) \ - registries_[KernelRegistry::GetKernelOffset()] = \ - &KernelRegistryForTarget::Global(); +/* +using kernel_target_t = + KernelRegistryForTarget; +registries_[0].set( + &KernelRegistryForTarget::Global()); + */ +#define INIT_FOR(target__, precision__) \ + registries_[KernelRegistry::GetKernelOffset()] \ + .set \ + *>(&KernelRegistryForTarget::Global()); // Currently, just register 2 kernel targets. 
INIT_FOR(kARM, kFloat); INIT_FOR(kHost, kFloat); @@ -76,8 +92,8 @@ class KernelRegistry final { typename KernelRegistryForTarget::creator_t &&creator) { using kernel_registor_t = KernelRegistryForTarget; - any_cast( - registries_[GetKernelOffset()]) + registries_[GetKernelOffset()] + .template get() ->Register(name, std::move(creator)); } @@ -88,7 +104,7 @@ class KernelRegistry final { } private: - std::array registries_; + std::array registries_; }; template diff --git a/paddle/fluid/lite/core/scope.cc b/paddle/fluid/lite/core/scope.cc index 6e36f72e32c..4a662ebf489 100644 --- a/paddle/fluid/lite/core/scope.cc +++ b/paddle/fluid/lite/core/scope.cc @@ -12,8 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "scope.h" -#include "scope.h" +#include "paddle/fluid/lite/core/scope.h" namespace paddle { namespace lite { diff --git a/paddle/fluid/lite/core/scope.h b/paddle/fluid/lite/core/scope.h index 56ded2b65d2..1709dfe88d0 100644 --- a/paddle/fluid/lite/core/scope.h +++ b/paddle/fluid/lite/core/scope.h @@ -19,7 +19,7 @@ #include #include #include -#include "variable.h" +#include "paddle/fluid/lite/core/variable.h" namespace paddle { namespace lite { diff --git a/paddle/fluid/lite/core/scope_test.cc b/paddle/fluid/lite/core/scope_test.cc index 63c8c2f58a3..43eeb232ee1 100644 --- a/paddle/fluid/lite/core/scope_test.cc +++ b/paddle/fluid/lite/core/scope_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "scope.h" +#include "paddle/fluid/lite/core/scope.h" #include namespace paddle { diff --git a/paddle/fluid/lite/core/tensor.cc b/paddle/fluid/lite/core/tensor.cc index 28d76f0793f..4354bb6cb44 100644 --- a/paddle/fluid/lite/core/tensor.cc +++ b/paddle/fluid/lite/core/tensor.cc @@ -12,4 +12,4 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "tensor.h" +#include "paddle/fluid/lite/core/tensor.h" diff --git a/paddle/fluid/lite/core/tensor.h b/paddle/fluid/lite/core/tensor.h index d95548b1915..68a76ed18cc 100644 --- a/paddle/fluid/lite/core/tensor.h +++ b/paddle/fluid/lite/core/tensor.h @@ -57,7 +57,6 @@ using LoD = std::vector>; // A light-weight tensor implementation. class Tensor { public: - void SyncEventTree(); Tensor() = default; template diff --git a/paddle/fluid/lite/core/variable.cc b/paddle/fluid/lite/core/variable.cc index 3ef5001f49e..79a311b5ce4 100644 --- a/paddle/fluid/lite/core/variable.cc +++ b/paddle/fluid/lite/core/variable.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "variable.h" +#include "paddle/fluid/lite/core/variable.h" namespace paddle { namespace lite {} // namespace lite diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt index cbcc848f782..954c8d2ba33 100644 --- a/paddle/fluid/lite/operators/CMakeLists.txt +++ b/paddle/fluid/lite/operators/CMakeLists.txt @@ -1,2 +1,5 @@ -cc_library(fc_op_lite SRCS fc_op.cc DEPS op_lite) +cc_library(fc_op_lite SRCS fc_op.cc DEPS op_lite op_params_lite tensor_lite) cc_library(relu_op_lite SRCS relu_op.cc DEPS op_lite) +cc_library(op_params_lite SRCS op_params.cc DEPS tensor_lite) + +cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite) diff --git a/paddle/fluid/lite/operators/fc_op.cc b/paddle/fluid/lite/operators/fc_op.cc index 827e921bb8f..ef0a1404ddb 100644 --- a/paddle/fluid/lite/operators/fc_op.cc +++ b/paddle/fluid/lite/operators/fc_op.cc @@ -53,7 +53,7 @@ bool FcOpLite::InferShape() const { const auto w_dims = param_.w->dims(); // Set output dims - std::vector output_dims(param_.in_num_col_dims + 1, 0); + std::vector output_dims(param_.in_num_col_dims + 1, 0); for (int i = 0; i < param_.in_num_col_dims; ++i) { output_dims[i] = input_dims[i]; } diff --git 
a/paddle/fluid/lite/operators/fc_op.h b/paddle/fluid/lite/operators/fc_op.h index f3a1eaca92b..720e6039c50 100644 --- a/paddle/fluid/lite/operators/fc_op.h +++ b/paddle/fluid/lite/operators/fc_op.h @@ -15,23 +15,15 @@ #include #include #include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/scope.h" #include "paddle/fluid/lite/core/tensor.h" +#include "paddle/fluid/lite/operators/op_params.h" #include "paddle/fluid/lite/utils/all.h" namespace paddle { namespace lite { namespace operators { -struct FcParam { - Tensor* input{nullptr}; - Tensor* w{nullptr}; - Tensor* bias{nullptr}; - Tensor* output{nullptr}; - // the input matrix dimentions. - lite::DDim in_mat_dims; - int in_num_col_dims{0}; -}; - class FcOpLite : public OpLite { public: FcOpLite() {} @@ -42,9 +34,21 @@ class FcOpLite : public OpLite { bool Run() override { return false; } - bool Build(const framework::OpDesc& opdesc, - framework::Scope* scope) override { - return false; + // TODO(Superjomn) replace framework::OpDesc with a lite one. 
+ bool Build(const framework::OpDesc& op_desc, lite::Scope* scope) override { + auto input = op_desc.Input("Input").front(); + auto W = op_desc.Input("W").front(); + auto bias = op_desc.Input("Bias").front(); + auto out = op_desc.Output("Out").front(); + + param_.input = scope->FindVar(input)->GetMutable(); + param_.w = scope->FindVar(W)->GetMutable(); + param_.bias = scope->FindVar(bias)->GetMutable(); + param_.output = scope->FindVar(out)->GetMutable(); + param_.in_num_col_dims = + boost::any_cast(op_desc.GetAttr("in_num_col_dims")); + + return true; } std::string DebugString() const override { return "fc"; } diff --git a/paddle/fluid/lite/utils/all.h b/paddle/fluid/lite/utils/all.h index df07541a179..7730bfb9030 100644 --- a/paddle/fluid/lite/utils/all.h +++ b/paddle/fluid/lite/utils/all.h @@ -14,7 +14,7 @@ #pragma once -#include "paddle/fluid/lite/utils/varient.h" #include "paddle/fluid/lite/utils/check.h" #include "paddle/fluid/lite/utils/factory.h" #include "paddle/fluid/lite/utils/macros.h" +#include "paddle/fluid/lite/utils/varient.h" -- GitLab