diff --git a/paddle/fluid/lite/core/CMakeLists.txt b/paddle/fluid/lite/core/CMakeLists.txt index c96727cfb8852ba8f024195051ca0be3e74b681c..b656c494ddfb49f709f21cf12cfccf71a8a42246 100644 --- a/paddle/fluid/lite/core/CMakeLists.txt +++ b/paddle/fluid/lite/core/CMakeLists.txt @@ -8,3 +8,4 @@ cc_library(scope_lite SRCS scope.cc) cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite) cc_test(test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86) +cc_test(test_op_lite SRCS op_lite_test.cc DEPS op_lite) diff --git a/paddle/fluid/lite/core/kernel.h b/paddle/fluid/lite/core/kernel.h index e9eeb71a302aad90f29e5686fbb18acb64d8910d..af1adb9338cb7ede4391bd27172a82580418d91d 100644 --- a/paddle/fluid/lite/core/kernel.h +++ b/paddle/fluid/lite/core/kernel.h @@ -26,6 +26,8 @@ namespace paddle { namespace lite { +// An base with virtual functions to unify all the kernel implementation on +// different targets. class KernelBase { public: virtual void Run() = 0; @@ -45,8 +47,12 @@ class KernelBase { return param_.get(); } - protected: + virtual TargetType target() const = 0; + virtual PrecisionType precision() const = 0; + virtual ~KernelBase() = default; + + protected: core::any_context_t context_; mutable operators::param_t param_; }; diff --git a/paddle/fluid/lite/core/kernel_test.cc b/paddle/fluid/lite/core/kernel_test.cc index 42ebe2ed29d042ad3f7ba2d01fca501c6c1f9864..325744aeee25d635db6079fdcd0e557b2295ffe7 100644 --- a/paddle/fluid/lite/core/kernel_test.cc +++ b/paddle/fluid/lite/core/kernel_test.cc @@ -28,6 +28,10 @@ class SomeKernel : public OpKernel { LOG(INFO) << param().in_num_col_dims; test_code = param().in_num_col_dims; } + + TargetType target() const override { return TARGET(kHost); } + PrecisionType precision() const override { return PRECISION(kFloat); } + }; TEST(Kernel, test) { diff --git a/paddle/fluid/lite/core/memory.h b/paddle/fluid/lite/core/memory.h index 36f247c20937a85b57ce998b4d07cf6474418783..0267dd43ed1fb636148c018262edc2781baaaa83 100644 --- a/paddle/fluid/lite/core/memory.h +++ b/paddle/fluid/lite/core/memory.h @@ -28,9 +28,6 @@ static void* TargetMalloc(TargetType target, size_t size) { case static_cast(TargetType::kCUDA): data = TargetWrapper::Malloc(size); break; - case static_cast(TargetType::kARM): - data = TargetWrapper::Malloc(size); - break; case static_cast(TargetType::kHost): data = TargetWrapper::Malloc(size); break; @@ -48,9 +45,6 @@ static void TargetFree(TargetType target, void* data) { case static_cast(TargetType::kCUDA): TargetWrapper::Free(data); break; - case static_cast(TargetType::kARM): - TargetWrapper::Free(data); - break; default: LOG(FATAL) << "Unknown type"; } diff --git a/paddle/fluid/lite/core/op_lite.cc b/paddle/fluid/lite/core/op_lite.cc index f277ecb580015448f60c939c400be4c10b83547b..53098a2a957dc7333395305101fb602447417f96 100644 --- a/paddle/fluid/lite/core/op_lite.cc +++ b/paddle/fluid/lite/core/op_lite.cc @@ -12,4 +12,36 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/lite/core/op_lite.h" #include "op_lite.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { + +std::vector> OpLite::CreateKernels( + const std::vector &places) { + std::vector> kernels; + CHECK(!op_type_.empty()) << "op_type_ should be set first"; + + for (auto place : places) { + kernels.emplace_back(KernelRegistry::Global().Create(op_type_, place.target, + place.precision)); + } + + return kernels; +} + +void OpLite::PickKernel(const std::vector &valid_places, + OpLite::KernelStrategy kernel_strategy) { + switch (kernel_strategy) { + case KernelStrategy::kStatic: + StaticPickKernel(valid_places); + break; + default: + LOG(FATAL) << "unsupported kernel strategy"; + } +} + +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/op_lite.h b/paddle/fluid/lite/core/op_lite.h index 76f5851e43f4d25192a46fdb61b432f0a9f2e94f..fa72d95051a64daa9c8a92f13b7291ac767620fe 100644 --- a/paddle/fluid/lite/core/op_lite.h +++ b/paddle/fluid/lite/core/op_lite.h @@ -70,7 +70,15 @@ class OpLite : public Registry { // Inference the outputs' shape. virtual bool InferShape() const { return true; } // Run this operator. - virtual bool Run() = 0; + virtual bool Run() { + CHECK(kernel_); + SyncInputEvents(); + + kernel_->Run(); + + RecordOutputEvents(); + return true; + } // Build the operator, attach it with the runtime environment. virtual bool Build(const framework::OpDesc &opdesc, lite::Scope *scope) = 0; // Human-readable information. @@ -79,21 +87,31 @@ class OpLite : public Registry { const Place &kernel_place() const { return kernel_place_; } protected: + void PickKernel(const std::vector &valid_places, + KernelStrategy kernel_strategy = KernelStrategy::kStatic); + // Specify the kernel to run by default. This will specify the value of // `kernel_place_`. virtual void StaticPickKernel(const std::vector &valid_targets) = 0; - void PickKernel(const std::vector &valid_places, - KernelStrategy kernel_strategy = KernelStrategy::kStatic); + // Wait until all the inputs' events are ready. + void SyncInputEvents() {} + + // Record the output events, and that will tell all the dependent operators + // some inputs are ready. + void RecordOutputEvents() {} // Create all the kernels for the valid targets. - void CreateKernels(); + std::vector> CreateKernels( + const std::vector &places); virtual ~OpLite() = default; protected: std::unique_ptr op_context_; Place kernel_place_; + std::unique_ptr kernel_; + std::string op_type_; }; } // namespace lite diff --git a/paddle/fluid/lite/core/op_lite_test.cc b/paddle/fluid/lite/core/op_lite_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0892da86606d5f9fe1be636f8287fa8b867a4116 --- /dev/null +++ b/paddle/fluid/lite/core/op_lite_test.cc @@ -0,0 +1,13 @@ +#include +#include "paddle/fluid/lite/core/op_lite.h" + +namespace paddle { +namespace lite { + +TEST(OpLite, test) { + + +} + +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/op_registry.h b/paddle/fluid/lite/core/op_registry.h index deb1654bc35a409020d92ab69df52b2d29927868..bf3bacb346c5301b4c5ae7cd900b40034cf0f914 100644 --- a/paddle/fluid/lite/core/op_registry.h +++ b/paddle/fluid/lite/core/op_registry.h @@ -59,7 +59,6 @@ class KernelRegistry final { KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // - KernelRegistryForTarget *, // KernelRegistryForTarget * // >; @@ -77,7 +76,6 @@ registries_[0].set( *>(&KernelRegistryForTarget::Global()); // Currently, just register 2 kernel targets. - INIT_FOR(kARM, kFloat); INIT_FOR(kHost, kFloat); #undef INIT_FOR } @@ -97,6 +95,42 @@ registries_[0].set( ->Register(name, std::move(creator)); } + template + std::unique_ptr Create(const std::string &op_type) { + using kernel_registor_t = KernelRegistryForTarget; + return registries_[GetKernelOffset()] + .template get() + ->Create(op_type); + } + + std::unique_ptr Create(const std::string &op_type, + TargetType target, + PrecisionType precision) { +#define CREATE_KERNEL(target__) \ + switch (precision) { \ + case PRECISION(kFloat): \ + return Create(op_type); \ + default: \ + CHECK(false) << "not supported kernel place yet"; \ + } + + switch (target) { + case TARGET(kHost): { + CREATE_KERNEL(kHost); + } break; + case TARGET(kX86): { + CREATE_KERNEL(kX86); + } break; + case TARGET(kCUDA): { + CREATE_KERNEL(kCUDA); + } break; + default: + CHECK(false) << "not supported kernel place"; + } + +#undef CREATE_KERNEL + } + // Get a kernel registry offset in all the registries. template static constexpr int GetKernelOffset() { diff --git a/paddle/fluid/lite/core/target_wrapper.h b/paddle/fluid/lite/core/target_wrapper.h index 60de8c224db223727f62fff8f8220493af301ec7..c7bd2ce2172eb0e1ce5063128b6e933094687a1a 100644 --- a/paddle/fluid/lite/core/target_wrapper.h +++ b/paddle/fluid/lite/core/target_wrapper.h @@ -18,7 +18,7 @@ namespace paddle { namespace lite { -enum class TargetType { kHost = 0, kX86, kCUDA, kARM, kLastAsPlaceHolder }; +enum class TargetType { kHost = 0, kX86, kCUDA, kLastAsPlaceHolder }; // Some helper macro to get a specific TargetType. #define TARGET(item__) paddle::lite::TargetType::item__ #define TARGET_VAL(item__) static_cast(TARGET(item__)) diff --git a/paddle/fluid/lite/core/types.h b/paddle/fluid/lite/core/types.h index 566d407f8508f59103dfc05850cd72bd6d73517f..0e95882c56056e93d02e65fb34152e2818da0972 100644 --- a/paddle/fluid/lite/core/types.h +++ b/paddle/fluid/lite/core/types.h @@ -21,9 +21,8 @@ namespace paddle { namespace lite { namespace core { -using any_context_t = variant, // - Context, // - Context // +using any_context_t = variant, // + Context // >; } // namespace core