diff --git a/paddle/fluid/lite/core/CMakeLists.txt b/paddle/fluid/lite/core/CMakeLists.txt index ddfbc83f6cbe71df7c64e226cfccd1fd73e416fa..c96727cfb8852ba8f024195051ca0be3e74b681c 100644 --- a/paddle/fluid/lite/core/CMakeLists.txt +++ b/paddle/fluid/lite/core/CMakeLists.txt @@ -7,3 +7,4 @@ cc_library(op_registry_lite SRCS op_registry.cc) cc_library(scope_lite SRCS scope.cc) cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite) +cc_test(test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86) diff --git a/paddle/fluid/lite/core/kernel.h b/paddle/fluid/lite/core/kernel.h index 6355cb659deb96e357bb60031e0fd9faf8dab740..e9eeb71a302aad90f29e5686fbb18acb64d8910d 100644 --- a/paddle/fluid/lite/core/kernel.h +++ b/paddle/fluid/lite/core/kernel.h @@ -14,46 +14,54 @@ #pragma once -#include -#include #include #include -#include "context.h" #include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/lite/core/context.h" +#include "paddle/fluid/lite/core/target_wrapper.h" +#include "paddle/fluid/lite/core/types.h" #include "paddle/fluid/lite/operators/op_params.h" #include "paddle/fluid/lite/utils/all.h" -#include "target_wrapper.h" namespace paddle { namespace lite { -// Light-weight kernel implementation. -// The OpKernel is designed to implement the specific algorithm on a target -// device. -template -class OpKernel { +class KernelBase { public: - using context_t = Context; - using context_ptr_t = std::unique_ptr; - - OpKernel() = default; + virtual void Run() = 0; - void SetContext(context_ptr_t&& ctx) { context_ = std::move(ctx); } + template + void SetContext(std::unique_ptr>&& ctx) { + context_.set>>(std::move(ctx)); + } - void SetParam(operators::param_t param) { param_ = param; } + template + void SetParam(T param) { + param_.set(param); + } template Param& param() const { return param_.get(); } + protected: + virtual ~KernelBase() = default; + core::any_context_t context_; + mutable operators::param_t param_; +}; + +// Light-weight kernel implementation. +// The OpKernel is designed to implement the specific algorithm on a target +// device. +template +class OpKernel : public KernelBase { + public: virtual void Run() { CHECK(false) << "Not Implemented"; } - virtual ~OpKernel() = default; + OpKernel() = default; - protected: - context_ptr_t context_; - mutable operators::param_t param_; + virtual ~OpKernel() = default; }; } // namespace lite diff --git a/paddle/fluid/lite/core/kernel_test.cc b/paddle/fluid/lite/core/kernel_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..42ebe2ed29d042ad3f7ba2d01fca501c6c1f9864 --- /dev/null +++ b/paddle/fluid/lite/core/kernel_test.cc @@ -0,0 +1,44 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/kernel.h" +#include +#include "paddle/fluid/lite/core/op_lite.h" + +namespace paddle { +namespace lite { +namespace core { + +int test_code{-1}; +class SomeKernel : public OpKernel { + public: + void Run() override { + LOG(INFO) << "SomeKernel executed"; + LOG(INFO) << param().in_num_col_dims; + test_code = param().in_num_col_dims; + } +}; + +TEST(Kernel, test) { + SomeKernel kernel; + operators::FcParam param; + param.in_num_col_dims = 100; + kernel.SetParam(param); + kernel.Run(); + ASSERT_EQ(test_code, 100); +} + +} // namespace core +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/op_lite.h b/paddle/fluid/lite/core/op_lite.h index d660ff9c1f830460076b1b92fc95d32ae6cf2f1a..76f5851e43f4d25192a46fdb61b432f0a9f2e94f 100644 --- a/paddle/fluid/lite/core/op_lite.h +++ b/paddle/fluid/lite/core/op_lite.h @@ -57,7 +57,12 @@ class OpLite : public Registry { kRuntime, }; - OpLite() {} + struct Place { + TargetType target{TARGET(kHost)}; + PrecisionType precision{PRECISION(kFloat)}; + }; + + OpLite() = default; OpLite(std::unique_ptr &&x) : op_context_(std::move(x)) {} // Check the shape. @@ -71,12 +76,14 @@ class OpLite : public Registry { // Human-readable information. virtual std::string DebugString() const = 0; + const Place &kernel_place() const { return kernel_place_; } + protected: - // Specify the kernel to run by default. - virtual void StaticPickKernel( - const std::vector &valid_targets) = 0; + // Specify the kernel to run by default. This will specify the value of + // `kernel_place_`. + virtual void StaticPickKernel(const std::vector &valid_targets) = 0; - void PickKernel(const std::vector &valid_places, + void PickKernel(const std::vector &valid_places, KernelStrategy kernel_strategy = KernelStrategy::kStatic); // Create all the kernels for the valid targets. @@ -86,6 +93,7 @@ class OpLite : public Registry { protected: std::unique_ptr op_context_; + Place kernel_place_; }; } // namespace lite diff --git a/paddle/fluid/lite/core/op_registry.h b/paddle/fluid/lite/core/op_registry.h index dfe54f5cbd275ff1ab728eeb3db8135c33cf96af..deb1654bc35a409020d92ab69df52b2d29927868 100644 --- a/paddle/fluid/lite/core/op_registry.h +++ b/paddle/fluid/lite/core/op_registry.h @@ -54,14 +54,14 @@ using KernelRegistryForTarget = Factory>; class KernelRegistry final { public: - using any_kernel_registor_t = variant< - KernelRegistryForTarget *, // - KernelRegistryForTarget *, // - KernelRegistryForTarget *, // - KernelRegistryForTarget *, // - KernelRegistryForTarget *, // - KernelRegistryForTarget * // - >; + using any_kernel_registor_t = + variant *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget * // + >; KernelRegistry() { /* diff --git a/paddle/fluid/lite/core/types.h b/paddle/fluid/lite/core/types.h new file mode 100644 index 0000000000000000000000000000000000000000..566d407f8508f59103dfc05850cd72bd6d73517f --- /dev/null +++ b/paddle/fluid/lite/core/types.h @@ -0,0 +1,31 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/lite/core/context.h" +#include "paddle/fluid/lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace core { + +using any_context_t = variant, // + Context, // + Context // + >; + +} // namespace core +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/operators/fc_op.h b/paddle/fluid/lite/operators/fc_op.h index 720e6039c505f6557e1e55784bc769edbb72ce67..c2a7e9a8685d111f9d17aca9aa5630bb2ea89408 100644 --- a/paddle/fluid/lite/operators/fc_op.h +++ b/paddle/fluid/lite/operators/fc_op.h @@ -14,6 +14,7 @@ #include #include +#include "paddle/fluid/lite/core/kernel.h" #include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/scope.h" #include "paddle/fluid/lite/core/tensor.h" @@ -53,8 +54,7 @@ class FcOpLite : public OpLite { std::string DebugString() const override { return "fc"; } - void StaticPickKernel(const std::vector& valid_targets) override { - } + void StaticPickKernel(const std::vector& valid_targets) override {} private: mutable FcParam param_; diff --git a/paddle/fluid/lite/utils/factory.h b/paddle/fluid/lite/utils/factory.h index 0d74d8b3a392f44230649deec9bc80fda2820cf6..680e9635c0276f1da856fc2a259cd0b4e341dd0e 100644 --- a/paddle/fluid/lite/utils/factory.h +++ b/paddle/fluid/lite/utils/factory.h @@ -19,6 +19,18 @@ namespace paddle { namespace lite { +/* + * Factor for any Type creator. + * + * Usage: + * + * struct SomeType; + * // Register a creator. + * Factory::Global().Register("some_key", [] -> + * std::unique_ptr { ... }); + * // Retrive a creator. + * auto some_type_instance = Factory::Global().Create("some_key"); + */ template class Factory { public: @@ -55,6 +67,7 @@ class Registor { public: Registor(std::function&& functor) { functor(); } + // Touch will do nothing. int Touch() { return 0; } }; diff --git a/paddle/fluid/lite/x86/target_wrapper.cc b/paddle/fluid/lite/x86/target_wrapper.cc index 533565787c24d3602cfaf59d800953611715678c..83250bcb498b8020eb3d0f417b93080c5aaee61e 100644 --- a/paddle/fluid/lite/x86/target_wrapper.cc +++ b/paddle/fluid/lite/x86/target_wrapper.cc @@ -12,22 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "target_wrapper.h" +#include "paddle/fluid/lite/core/target_wrapper.h" #include namespace paddle { -namespace framework { namespace lite { template <> -void TargetWrapper::MemcpySync(void* dst, void* src, size_t size, - IoDirection dir) { - std::copy_n(reinterpret_cast(src), size, - reinterpret_cast(dst)); +void TargetWrapper::MemcpySync(void *dst, void *src, size_t size, + IoDirection dir) { + std::copy_n(reinterpret_cast(src), size, + reinterpret_cast(dst)); } -template class TargetWrapper; +template class TargetWrapper; } // namespace lite -} // namespace framework } // namespace paddle diff --git a/paddle/fluid/lite/x86/target_wrapper.h b/paddle/fluid/lite/x86/target_wrapper.h index 222eb80f7f39ce42d9ff2a5cf5e4a9b7c1b4e364..4364f8cbca8671fbc022e5d724f333abdc541534 100644 --- a/paddle/fluid/lite/x86/target_wrapper.h +++ b/paddle/fluid/lite/x86/target_wrapper.h @@ -16,9 +16,7 @@ #include "paddle/fluid/lite/core/target_wrapper.h" namespace paddle { -namespace framework { namespace lite { namespace x86 {} // namespace x86 } // namespace lite -} // namespace framework } // namespace paddle