From 18959c0994038b367d4772b4be98ed92c4de1d75 Mon Sep 17 00:00:00 2001 From: superjomn Date: Fri, 19 Apr 2019 12:29:28 +0800 Subject: [PATCH] make op pointer shared_ptr to support kernel infershape --- paddle/fluid/lite/api/cxx_api.h | 2 +- paddle/fluid/lite/core/mir/node.h | 4 ++++ paddle/fluid/lite/core/mir/ssa_graph.h | 12 +++--------- .../lite/core/mir/variable_place_inference_pass.h | 4 ++-- paddle/fluid/lite/core/op_registry.h | 6 ++++-- paddle/fluid/lite/kernels/CMakeLists.txt | 1 + paddle/fluid/lite/kernels/host/CMakeLists.txt | 12 ++++++------ paddle/fluid/lite/kernels/host/mul_compute.cc | 6 ++++++ paddle/fluid/lite/utils/factory.h | 6 +++--- 9 files changed, 30 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/lite/api/cxx_api.h b/paddle/fluid/lite/api/cxx_api.h index 9b561e3be25..d0d4e3d6c4d 100644 --- a/paddle/fluid/lite/api/cxx_api.h +++ b/paddle/fluid/lite/api/cxx_api.h @@ -25,7 +25,7 @@ struct Config {}; class Predictor { public: void Build(const std::string& model_path, - const std::vector& valid_places) { + const std::vector& valid_places) { CHECK(!executor_.get()) << "duplicate build found"; framework::proto::ProgramDesc prog; LoadModel(model_path, &scope_, &prog); diff --git a/paddle/fluid/lite/core/mir/node.h b/paddle/fluid/lite/core/mir/node.h index f7d5dc699b0..b5af9d3d04c 100644 --- a/paddle/fluid/lite/core/mir/node.h +++ b/paddle/fluid/lite/core/mir/node.h @@ -46,6 +46,8 @@ class Node { // The kernel instances this Instruct contains. std::vector> valid_kernels; std::shared_ptr op_info; + // TODO(Superjomn) make this a shared_ptr for resource safety. + std::shared_ptr op; // we hold op to run InferShape }; struct Argument { @@ -64,9 +66,11 @@ class Node { Instruct& AsInstruct(const std::string& op_type, std::vector>&& kernels, + const std::shared_ptr& op, const std::shared_ptr& op_info) { auto& x = AsInstruct(); x.op_type = op_type; + x.op = op; x.valid_kernels = std::move(kernels); x.op_info = op_info; return x; diff --git a/paddle/fluid/lite/core/mir/ssa_graph.h b/paddle/fluid/lite/core/mir/ssa_graph.h index 63b0cdb7f69..f9f49e3e9e4 100644 --- a/paddle/fluid/lite/core/mir/ssa_graph.h +++ b/paddle/fluid/lite/core/mir/ssa_graph.h @@ -34,13 +34,7 @@ namespace mir { struct Program { std::list tmp_vars; std::list weights; - std::list> ops; - lite::Scope *scope{}; -}; - -// Program of kernel. -struct KernelProgram { - std::list> instructions; + std::list> ops; lite::Scope *scope{}; }; @@ -67,7 +61,7 @@ class SSAGraph : GraphBase { op->SetValidPlaces(valid_places); auto &new_node = node_storage_.back(); node_storage_.back().AsInstruct( - op->op_type_, op->CreateKernels(valid_places), op->op_info()); + op->op_type_, op->CreateKernels(valid_places), op, op->op_info()); CHECK(new_node.inlinks.empty()) << "duplicate Build found"; CHECK(new_node.outlinks.empty()) << "duplicate Build found"; @@ -122,7 +116,7 @@ class SSAGraph : GraphBase { const std::list &nodes() const { return node_storage_; } std::list &mutable_nodes() { return node_storage_; } - mir::Node *RetriveArgument(const std::string &arg) { + mir::Node *RetrieveArgument(const std::string &arg) { auto it = arguments_.find(arg); if (it != arguments_.end()) { return it->second; diff --git a/paddle/fluid/lite/core/mir/variable_place_inference_pass.h b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h index 0eeb7c9cce1..608894504d7 100644 --- a/paddle/fluid/lite/core/mir/variable_place_inference_pass.h +++ b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h @@ -63,7 +63,7 @@ class VariablePlaceInferencePass : public DebugPass { // kernel's declaration. for (auto& arg_name : arg_names) { - auto* node = graph->RetriveArgument(arg_name); + auto* node = graph->RetrieveArgument(arg_name); CHECK(node) << "argument " << arg_name << " not exists in the graph"; auto& arg_node = node->AsArgument(); if (arg_node.place.is_valid()) continue; @@ -82,7 +82,7 @@ class VariablePlaceInferencePass : public DebugPass { // kernel's declaration. for (auto& arg_name : arg_names) { - auto* node = graph->RetriveArgument(arg_name); + auto* node = graph->RetrieveArgument(arg_name); CHECK(node) << "argument " << arg_name << " not exists in the graph"; auto& arg_node = node->AsArgument(); if (arg_node.place.is_valid()) continue; diff --git a/paddle/fluid/lite/core/op_registry.h b/paddle/fluid/lite/core/op_registry.h index 95d62f2fee4..04c19fbf873 100644 --- a/paddle/fluid/lite/core/op_registry.h +++ b/paddle/fluid/lite/core/op_registry.h @@ -27,7 +27,7 @@ namespace lite { using KernelFunc = std::function; using KernelFuncCreator = std::function()>; -class LiteOpRegistry final : public Factory { +class LiteOpRegistry final : public Factory> { public: static LiteOpRegistry &Global() { static auto *x = new LiteOpRegistry; @@ -51,7 +51,9 @@ class OpLiteRegistor : public Registor { }; template -using KernelRegistryForTarget = Factory>; +using KernelRegistryForTarget = + Factory, + std::unique_ptr>>; class KernelRegistry final { public: diff --git a/paddle/fluid/lite/kernels/CMakeLists.txt b/paddle/fluid/lite/kernels/CMakeLists.txt index a7a894de125..ebbfb2139e5 100644 --- a/paddle/fluid/lite/kernels/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/CMakeLists.txt @@ -1,3 +1,4 @@ +set(lite_kernel_deps type_system kernel_lite op_registry_lite) add_subdirectory(host) add_subdirectory(arm) add_subdirectory(cuda) diff --git a/paddle/fluid/lite/kernels/host/CMakeLists.txt b/paddle/fluid/lite/kernels/host/CMakeLists.txt index 60e500630d5..03501dce5a7 100644 --- a/paddle/fluid/lite/kernels/host/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/host/CMakeLists.txt @@ -1,8 +1,8 @@ -cc_library(fc_compute_host SRCS fc_compute.cc DEPS tensor_lite) -cc_library(relu_compute_host SRCS relu_compute.cc DEPS tensor_lite) -cc_library(mul_compute_host SRCS mul_compute.cc DEPS tensor_lite) -cc_library(scale_compute_host SRCS scale_compute.cc DEPS tensor_lite) -cc_library(feed_compute_host SRCS feed_compute.cc DEPS tensor_lite) +cc_library(fc_compute_host SRCS fc_compute.cc DEPS ${lite_kernel_deps}) +cc_library(relu_compute_host SRCS relu_compute.cc DEPS ${lite_kernel_deps}) +cc_library(mul_compute_host SRCS mul_compute.cc DEPS ${lite_kernel_deps}) +cc_library(scale_compute_host SRCS scale_compute.cc DEPS ${lite_kernel_deps}) +cc_library(feed_compute_host SRCS feed_compute.cc DEPS ${lite_kernel_deps}) cc_library(host_kernels DEPS fc_compute_host @@ -10,7 +10,7 @@ cc_library(host_kernels DEPS mul_compute_host scale_compute_host feed_compute_host - DEPS kernel_lite + DEPS ${lite_kernel_deps} ) cc_test(test_fc_compute SRCS fc_compute_test.cc DEPS fc_compute_host fc_op_lite) diff --git a/paddle/fluid/lite/kernels/host/mul_compute.cc b/paddle/fluid/lite/kernels/host/mul_compute.cc index 755f5683dee..a9667dd8312 100644 --- a/paddle/fluid/lite/kernels/host/mul_compute.cc +++ b/paddle/fluid/lite/kernels/host/mul_compute.cc @@ -68,4 +68,10 @@ class MulCompute : public OpKernel { REGISTER_LITE_KERNEL(mul, kHost, kFloat, paddle::lite::kernels::host::MulCompute) + .BindInput("X", {paddle::lite::Type::Get( + TARGET(kHost))}) + .BindInput("Y", {paddle::lite::Type::Get( + TARGET(kHost))}) + .BindOutput("Out", {paddle::lite::Type::Get( + TARGET(kHost))}) .Finalize(); diff --git a/paddle/fluid/lite/utils/factory.h b/paddle/fluid/lite/utils/factory.h index 37f2c2293cb..395390b3b5b 100644 --- a/paddle/fluid/lite/utils/factory.h +++ b/paddle/fluid/lite/utils/factory.h @@ -33,12 +33,12 @@ namespace lite { * // Retrive a creator. * auto some_type_instance = Factory::Global().Create("some_key"); */ -template +template class Factory { public: using item_t = ItemType; - using self_t = Factory; - using item_ptr_t = std::unique_ptr; + using self_t = Factory; + using item_ptr_t = ItemTypePtr; using creator_t = std::function; static Factory& Global() { -- GitLab