diff --git a/paddle/fluid/lite/api/cxx_api.h b/paddle/fluid/lite/api/cxx_api.h index 9b561e3be25db59ad5cc5b2512eed2003aa3d4a5..d0d4e3d6c4d6ba1aaf58c7cbd108f39197a37a4c 100644 --- a/paddle/fluid/lite/api/cxx_api.h +++ b/paddle/fluid/lite/api/cxx_api.h @@ -25,7 +25,7 @@ struct Config {}; class Predictor { public: void Build(const std::string& model_path, - const std::vector& valid_places) { + const std::vector& valid_places) { CHECK(!executor_.get()) << "duplicate build found"; framework::proto::ProgramDesc prog; LoadModel(model_path, &scope_, &prog); diff --git a/paddle/fluid/lite/core/mir/node.h b/paddle/fluid/lite/core/mir/node.h index f7d5dc699b0f4f9a71e596cf56973fdd04ba28ae..b5af9d3d04ce25ecd36490ed5e70ca895da0d6f5 100644 --- a/paddle/fluid/lite/core/mir/node.h +++ b/paddle/fluid/lite/core/mir/node.h @@ -46,6 +46,8 @@ class Node { // The kernel instances this Instruct contains. std::vector> valid_kernels; std::shared_ptr op_info; + // TODO(Superjomn) make this a shared_ptr for resource safety. + std::shared_ptr op; // we hold op to run InferShape }; struct Argument { @@ -64,9 +66,11 @@ class Node { Instruct& AsInstruct(const std::string& op_type, std::vector>&& kernels, + const std::shared_ptr& op, const std::shared_ptr& op_info) { auto& x = AsInstruct(); x.op_type = op_type; + x.op = op; x.valid_kernels = std::move(kernels); x.op_info = op_info; return x; diff --git a/paddle/fluid/lite/core/mir/ssa_graph.h b/paddle/fluid/lite/core/mir/ssa_graph.h index 63b0cdb7f69f611b997d7fe936d76a785e55fc1b..f9f49e3e9e41d3725928c496572a44ef1c9818cd 100644 --- a/paddle/fluid/lite/core/mir/ssa_graph.h +++ b/paddle/fluid/lite/core/mir/ssa_graph.h @@ -34,13 +34,7 @@ namespace mir { struct Program { std::list tmp_vars; std::list weights; - std::list> ops; - lite::Scope *scope{}; -}; - -// Program of kernel. -struct KernelProgram { - std::list> instructions; + std::list> ops; lite::Scope *scope{}; }; @@ -67,7 +61,7 @@ class SSAGraph : GraphBase { op->SetValidPlaces(valid_places); auto &new_node = node_storage_.back(); node_storage_.back().AsInstruct( - op->op_type_, op->CreateKernels(valid_places), op->op_info()); + op->op_type_, op->CreateKernels(valid_places), op, op->op_info()); CHECK(new_node.inlinks.empty()) << "duplicate Build found"; CHECK(new_node.outlinks.empty()) << "duplicate Build found"; @@ -122,7 +116,7 @@ class SSAGraph : GraphBase { const std::list &nodes() const { return node_storage_; } std::list &mutable_nodes() { return node_storage_; } - mir::Node *RetriveArgument(const std::string &arg) { + mir::Node *RetrieveArgument(const std::string &arg) { auto it = arguments_.find(arg); if (it != arguments_.end()) { return it->second; diff --git a/paddle/fluid/lite/core/mir/variable_place_inference_pass.h b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h index 0eeb7c9cce1a449487ff836db9bf2079df0f2df8..608894504d7ef8aaa3a1fbdf58bbe517051381d4 100644 --- a/paddle/fluid/lite/core/mir/variable_place_inference_pass.h +++ b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h @@ -63,7 +63,7 @@ class VariablePlaceInferencePass : public DebugPass { // kernel's declaration. for (auto& arg_name : arg_names) { - auto* node = graph->RetriveArgument(arg_name); + auto* node = graph->RetrieveArgument(arg_name); CHECK(node) << "argument " << arg_name << " not exists in the graph"; auto& arg_node = node->AsArgument(); if (arg_node.place.is_valid()) continue; @@ -82,7 +82,7 @@ class VariablePlaceInferencePass : public DebugPass { // kernel's declaration. for (auto& arg_name : arg_names) { - auto* node = graph->RetriveArgument(arg_name); + auto* node = graph->RetrieveArgument(arg_name); CHECK(node) << "argument " << arg_name << " not exists in the graph"; auto& arg_node = node->AsArgument(); if (arg_node.place.is_valid()) continue; diff --git a/paddle/fluid/lite/core/op_registry.h b/paddle/fluid/lite/core/op_registry.h index 95d62f2fee4fea03efd95e80101fca3f20e9089f..04c19fbf873a9433eba27456b2e3ed2b1e07ca8d 100644 --- a/paddle/fluid/lite/core/op_registry.h +++ b/paddle/fluid/lite/core/op_registry.h @@ -27,7 +27,7 @@ namespace lite { using KernelFunc = std::function; using KernelFuncCreator = std::function()>; -class LiteOpRegistry final : public Factory { +class LiteOpRegistry final : public Factory> { public: static LiteOpRegistry &Global() { static auto *x = new LiteOpRegistry; @@ -51,7 +51,9 @@ class OpLiteRegistor : public Registor { }; template -using KernelRegistryForTarget = Factory>; +using KernelRegistryForTarget = + Factory, + std::unique_ptr>>; class KernelRegistry final { public: diff --git a/paddle/fluid/lite/kernels/CMakeLists.txt b/paddle/fluid/lite/kernels/CMakeLists.txt index a7a894de125709dfdf79c0093779df5e654b60cb..ebbfb2139e5321bea32dc76dd2617e770b4e483e 100644 --- a/paddle/fluid/lite/kernels/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/CMakeLists.txt @@ -1,3 +1,4 @@ +set(lite_kernel_deps type_system kernel_lite op_registry_lite) add_subdirectory(host) add_subdirectory(arm) add_subdirectory(cuda) diff --git a/paddle/fluid/lite/kernels/host/CMakeLists.txt b/paddle/fluid/lite/kernels/host/CMakeLists.txt index 60e500630d5b24af73bc6a3f925669db28fd596f..03501dce5a76a6fec8140e0a5ec6b754fd4cd060 100644 --- a/paddle/fluid/lite/kernels/host/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/host/CMakeLists.txt @@ -1,8 +1,8 @@ -cc_library(fc_compute_host SRCS fc_compute.cc DEPS tensor_lite) -cc_library(relu_compute_host SRCS relu_compute.cc DEPS tensor_lite) -cc_library(mul_compute_host SRCS mul_compute.cc DEPS tensor_lite) -cc_library(scale_compute_host SRCS scale_compute.cc DEPS tensor_lite) -cc_library(feed_compute_host SRCS feed_compute.cc DEPS tensor_lite) +cc_library(fc_compute_host SRCS fc_compute.cc DEPS ${lite_kernel_deps}) +cc_library(relu_compute_host SRCS relu_compute.cc DEPS ${lite_kernel_deps}) +cc_library(mul_compute_host SRCS mul_compute.cc DEPS ${lite_kernel_deps}) +cc_library(scale_compute_host SRCS scale_compute.cc DEPS ${lite_kernel_deps}) +cc_library(feed_compute_host SRCS feed_compute.cc DEPS ${lite_kernel_deps}) cc_library(host_kernels DEPS fc_compute_host @@ -10,7 +10,7 @@ cc_library(host_kernels DEPS mul_compute_host scale_compute_host feed_compute_host - DEPS kernel_lite + DEPS ${lite_kernel_deps} ) cc_test(test_fc_compute SRCS fc_compute_test.cc DEPS fc_compute_host fc_op_lite) diff --git a/paddle/fluid/lite/kernels/host/mul_compute.cc b/paddle/fluid/lite/kernels/host/mul_compute.cc index 755f5683dee7f9a8dff5399163955fa01b3e171c..a9667dd8312b4436efa65bedca96a4a8d2e86643 100644 --- a/paddle/fluid/lite/kernels/host/mul_compute.cc +++ b/paddle/fluid/lite/kernels/host/mul_compute.cc @@ -68,4 +68,10 @@ class MulCompute : public OpKernel { REGISTER_LITE_KERNEL(mul, kHost, kFloat, paddle::lite::kernels::host::MulCompute) + .BindInput("X", {paddle::lite::Type::Get( + TARGET(kHost))}) + .BindInput("Y", {paddle::lite::Type::Get( + TARGET(kHost))}) + .BindOutput("Out", {paddle::lite::Type::Get( + TARGET(kHost))}) .Finalize(); diff --git a/paddle/fluid/lite/utils/factory.h b/paddle/fluid/lite/utils/factory.h index 37f2c2293cb7655ff332827fd7910746adb79974..395390b3b5b3a9b173914a114b81c377a7b4c0cf 100644 --- a/paddle/fluid/lite/utils/factory.h +++ b/paddle/fluid/lite/utils/factory.h @@ -33,12 +33,12 @@ namespace lite { * // Retrive a creator. * auto some_type_instance = Factory::Global().Create("some_key"); */ -template +template class Factory { public: using item_t = ItemType; - using self_t = Factory; - using item_ptr_t = std::unique_ptr; + using self_t = Factory; + using item_ptr_t = ItemTypePtr; using creator_t = std::function; static Factory& Global() {