From 65bfecc9ec84da3dc448f3e9afc75635a60e8d5a Mon Sep 17 00:00:00 2001 From: superjomn Date: Thu, 18 Apr 2019 21:28:49 +0800 Subject: [PATCH] make kernel param-type-recorder and typesystem work --- paddle/fluid/lite/core/CMakeLists.txt | 3 +- paddle/fluid/lite/core/kernel.cc | 29 ++--- paddle/fluid/lite/core/kernel.h | 35 +++++- paddle/fluid/lite/core/mir/CMakeLists.txt | 1 + .../lite/core/mir/generate_program_pass.cc | 2 +- .../fluid/lite/core/mir/io_complement_pass.cc | 4 +- paddle/fluid/lite/core/mir/node.h | 11 +- paddle/fluid/lite/core/mir/passes.h | 1 + paddle/fluid/lite/core/mir/ssa_graph.cc | 75 ++++++++++- paddle/fluid/lite/core/mir/ssa_graph.h | 84 +++++++++---- .../lite/core/mir/static_kernel_pick_pass.cc | 1 + .../core/mir/variable_place_inference_pass.cc | 34 +++++ .../core/mir/variable_place_inference_pass.h | 116 ++++++++++++++++++ paddle/fluid/lite/core/op_lite.cc | 7 ++ paddle/fluid/lite/core/op_lite.h | 99 ++++++++++++--- paddle/fluid/lite/core/optimizer_test.cc | 7 ++ paddle/fluid/lite/core/type_system.cc | 12 ++ paddle/fluid/lite/kernels/host/fc_compute.cc | 13 +- 18 files changed, 459 insertions(+), 75 deletions(-) diff --git a/paddle/fluid/lite/core/CMakeLists.txt b/paddle/fluid/lite/core/CMakeLists.txt index c7e86290853..1c58ceda545 100644 --- a/paddle/fluid/lite/core/CMakeLists.txt +++ b/paddle/fluid/lite/core/CMakeLists.txt @@ -1,6 +1,7 @@ cc_library(memory_lite SRCS memory.cc) +cc_library(target_wrapper_lite SRCS target_wrapper.cc) cc_library(tensor_lite SRCS tensor.cc DEPS memory_lite) -cc_library(kernel_lite SRCS kernel.cc DEPS type_system) +cc_library(kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite) cc_library(variable_lite SRCS variable.cc) cc_library(op_registry_lite SRCS op_registry.cc) cc_library(scope_lite SRCS scope.cc) diff --git a/paddle/fluid/lite/core/kernel.cc b/paddle/fluid/lite/core/kernel.cc index a4268c78375..7cea5986dd7 100644 --- a/paddle/fluid/lite/core/kernel.cc +++ b/paddle/fluid/lite/core/kernel.cc @@ -17,29 +17,18 @@ namespace paddle { namespace lite { -bool operator<(const Place &a, const Place &b) { - if (a.target != b.target) - return a.target < b.target; - else if (a.precision != b.precision) - return a.precision < b.precision; - else if (a.layout != b.layout) - return a.layout < b.layout; - return true; -} - bool ParamTypeRegistry::KeyCmp::operator()( const ParamTypeRegistry::key_t &a, const ParamTypeRegistry::key_t &b) const { - if (a.kernel_type != b.kernel_type) - return a.kernel_type < b.kernel_type; - else if (a.io != b.io) - return a.io < b.io; - else if (a.arg_name != b.arg_name) - return a.arg_name < b.arg_name; - else if (!(a.place == b.place)) { - return a.place < b.place; - } - return true; + return a.hash() < b.hash(); +} + +std::ostream &operator<<(std::ostream &os, + const ParamTypeRegistry::KernelIdTy &other) { + std::string io_s = other.io == ParamTypeRegistry::IO::kInput ? "in" : "out"; + os << other.kernel_type << ":" << other.arg_name << ":" << io_s << ":" + << other.place.DebugString(); + return os; } } // namespace lite diff --git a/paddle/fluid/lite/core/kernel.h b/paddle/fluid/lite/core/kernel.h index f5e31233104..fef32ba3bc2 100644 --- a/paddle/fluid/lite/core/kernel.h +++ b/paddle/fluid/lite/core/kernel.h @@ -55,6 +55,7 @@ class KernelBase { void Torch() {} + virtual Place place() const = 0; virtual TargetType target() const = 0; virtual PrecisionType precision() const = 0; virtual DataLayoutType layout() const = 0; @@ -87,7 +88,9 @@ struct ParamType { : element_type_hash(element_type_hash) {} ParamType(size_t element_type_hash, const Place& place) : element_type_hash(element_type_hash), tensor_place(place) {} - ParamType(const Type* type) : type_(type) {} + ParamType(const Type* type) : type_(type) { tensor_place = type_->place(); } + + std::string DebugString() const { return tensor_place.DebugString(); } }; /* @@ -167,15 +170,32 @@ class ParamTypeRegistry { const std::string& arg_name, ParamType data_type) { KernelIdTy key{kernel_type, place, io, arg_name}; types_[key] = data_type; + CHECK(types_.count(key)); } - ParamType Retrive(const Place& place, int offset); + template + const ParamType* Retrieve(const Place& place, const std::string& op_type, + const std::string& arg_name) { + KernelIdTy key{op_type, place, io, arg_name}; + LOG(INFO) << "Looking for " << key; + auto it = types_.find(key); + if (it == types_.end()) return nullptr; + return &it->second; + } static ParamTypeRegistry& Global() { static ParamTypeRegistry x; return x; } + friend std::ostream& operator<<(std::ostream& os, + const ParamTypeRegistry& other) { + for (auto& item : other.types_) { + os << item.first << " " << item.second.DebugString() << "\n"; + } + return os; + } + private: ParamTypeRegistry() = default; @@ -186,6 +206,16 @@ class ParamTypeRegistry { Place place; IO io; std::string arg_name; + + size_t hash() const { + std::hash h; + size_t hash = h(kernel_type); + hash = hash_combine(hash, place.hash()); + hash = hash_combine(hash, std::hash()(static_cast(io))); + hash = hash_combine(hash, std::hash()(arg_name)); + return hash; + } + friend std::ostream& operator<<(std::ostream& os, const KernelIdTy& other); }; using key_t = KernelIdTy; @@ -213,6 +243,7 @@ class OpKernel : public KernelBase { TargetType target() const override { return Target; } PrecisionType precision() const override { return Precision; } DataLayoutType layout() const override { return DataLayout; } + Place place() const override { return Place{Target, Precision, DataLayout}; } std::string name() const override { return op_type() + ":" + TargetToStr(Target) + "/" + PrecisionToStr(Precision) + "/" + DataLayoutToStr(DataLayout); diff --git a/paddle/fluid/lite/core/mir/CMakeLists.txt b/paddle/fluid/lite/core/mir/CMakeLists.txt index 757e5e141fd..85fdfa24c19 100644 --- a/paddle/fluid/lite/core/mir/CMakeLists.txt +++ b/paddle/fluid/lite/core/mir/CMakeLists.txt @@ -8,6 +8,7 @@ cc_library(mir_passes io_complement_pass.cc graph_visualize_pass.cc generate_program_pass.cc + variable_place_inference_pass.cc demo_pass.cc DEPS mir_pass types_lite) diff --git a/paddle/fluid/lite/core/mir/generate_program_pass.cc b/paddle/fluid/lite/core/mir/generate_program_pass.cc index 659a959fec6..484edb11327 100644 --- a/paddle/fluid/lite/core/mir/generate_program_pass.cc +++ b/paddle/fluid/lite/core/mir/generate_program_pass.cc @@ -20,7 +20,7 @@ namespace lite { namespace mir { void GenerateProgramPass::Apply(std::unique_ptr& graph) { - for (auto& item : graph->TopoloticalOrder()) { + for (auto& item : graph->InstructTopologicalOrder()) { if (item->IsInstruct()) { auto& instruct = item->AsInstruct(); kernels_.emplace_back(std::move(instruct.valid_kernels.front())); diff --git a/paddle/fluid/lite/core/mir/io_complement_pass.cc b/paddle/fluid/lite/core/mir/io_complement_pass.cc index 8511a3921f5..c1b3a2c248a 100644 --- a/paddle/fluid/lite/core/mir/io_complement_pass.cc +++ b/paddle/fluid/lite/core/mir/io_complement_pass.cc @@ -19,7 +19,9 @@ namespace paddle { namespace lite { namespace mir { -void IoComplementPass::Apply(std::unique_ptr &graph) {} +void IoComplementPass::Apply(std::unique_ptr &graph) { + // Start from inputs of the graph, those should should have place set. +} } // namespace mir } // namespace lite diff --git a/paddle/fluid/lite/core/mir/node.h b/paddle/fluid/lite/core/mir/node.h index 15b18b5e822..f7d5dc699b0 100644 --- a/paddle/fluid/lite/core/mir/node.h +++ b/paddle/fluid/lite/core/mir/node.h @@ -19,6 +19,7 @@ #include #include #include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_lite.h" namespace paddle { namespace lite { @@ -44,11 +45,15 @@ class Node { Place place; // The kernel instances this Instruct contains. std::vector> valid_kernels; + std::shared_ptr op_info; }; struct Argument { std::string name; Place place; + // Weight is a special kind of argument, it is marked as weight explicitly + // so that some weight related optimization can take place. + bool is_weight{false}; }; Argument& AsArgument(const std::string& name) { @@ -57,9 +62,13 @@ class Node { return x; } - Instruct& AsInstruct(const std::string& op_type) { + Instruct& AsInstruct(const std::string& op_type, + std::vector>&& kernels, + const std::shared_ptr& op_info) { auto& x = AsInstruct(); x.op_type = op_type; + x.valid_kernels = std::move(kernels); + x.op_info = op_info; return x; } diff --git a/paddle/fluid/lite/core/mir/passes.h b/paddle/fluid/lite/core/mir/passes.h index 237b2e889df..50ecd267887 100644 --- a/paddle/fluid/lite/core/mir/passes.h +++ b/paddle/fluid/lite/core/mir/passes.h @@ -23,5 +23,6 @@ namespace mir {} // namespace mir USE_MIR_PASS(demo); USE_MIR_PASS(static_kernel_pick_pass); +USE_MIR_PASS(variable_place_inference_pass); USE_MIR_PASS(io_complement_pass); USE_MIR_PASS(generate_program_pass); diff --git a/paddle/fluid/lite/core/mir/ssa_graph.cc b/paddle/fluid/lite/core/mir/ssa_graph.cc index cfab5b35f01..0ab203af4d8 100644 --- a/paddle/fluid/lite/core/mir/ssa_graph.cc +++ b/paddle/fluid/lite/core/mir/ssa_graph.cc @@ -16,6 +16,79 @@ namespace paddle { namespace lite { -namespace mir {} // namespace mir +namespace mir { + +bool SSAGraph::CheckBidirectionalConnection() { + LOG(INFO) << "node count " << node_storage_.size(); + for (auto &node : node_storage_) { + for (auto *in : node.inlinks) { + CHECK(in->outlinks.end() != + std::find(in->outlinks.begin(), in->outlinks.end(), &node)); + } + for (auto *out : node.outlinks) { + CHECK(out->inlinks.end() != + std::find(out->inlinks.begin(), out->inlinks.end(), &node)); + } + } + return true; +} + +std::map> SSAGraph::BuildOperationAdjList() { + std::map> adj_list; + + for (auto &n : mutable_nodes()) { + if (!n.IsInstruct()) continue; + if (adj_list.find(&n) == adj_list.end()) { + adj_list[&n] = std::set(); + } + std::vector nodes; + for (auto &var : n.inlinks) { + for (auto &adj_n : var->inlinks) { + PADDLE_ENFORCE(adj_n->IsInstruct()); + nodes.push_back(adj_n); + } + } + std::sort(nodes.begin(), nodes.end(), + [](mir::Node *node1, mir::Node *node2) { return node1 > node2; }); + adj_list[&n].insert(std::make_move_iterator(nodes.begin()), + std::make_move_iterator(nodes.end())); + } + return adj_list; +} + +void SSAGraph::SortHelper( + const std::map> &adj_list, + mir::Node *node, std::set *visited, + std::vector *ret) { + visited->insert(node); + + for (auto adj : adj_list.at(node)) { + if (visited->find(adj) == visited->end()) { + SortHelper(adj_list, adj, visited, ret); + } + } + + ret->push_back(node); +} + +std::vector SSAGraph::InstructTopologicalOrder() { + CheckBidirectionalConnection(); + + std::stack stack; + std::set visited; + std::vector res; + + auto adj_list = BuildOperationAdjList(); + + for (auto adj : adj_list) { + if (visited.find(adj.first) == visited.end()) { + SortHelper(adj_list, adj.first, &visited, &res); + } + } + + return res; +} + +} // namespace mir } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/core/mir/ssa_graph.h b/paddle/fluid/lite/core/mir/ssa_graph.h index 95d2d2d98d2..63b0cdb7f69 100644 --- a/paddle/fluid/lite/core/mir/ssa_graph.h +++ b/paddle/fluid/lite/core/mir/ssa_graph.h @@ -19,6 +19,7 @@ #include #include #include +#include "paddle/fluid/lite/core/kernel.h" #include "paddle/fluid/lite/core/mir/node.h" #include "paddle/fluid/lite/core/op_lite.h" @@ -34,7 +35,13 @@ struct Program { std::list tmp_vars; std::list weights; std::list> ops; - lite::Scope *scope; + lite::Scope *scope{}; +}; + +// Program of kernel. +struct KernelProgram { + std::list> instructions; + lite::Scope *scope{}; }; // An Graph for MIR. It is built from a list of Op and a scope. @@ -59,17 +66,19 @@ class SSAGraph : GraphBase { // TODO(Superjomn) remove one valid_places here. op->SetValidPlaces(valid_places); auto &new_node = node_storage_.back(); - auto &new_kernel = node_storage_.back().AsInstruct(op->op_type_); - new_kernel.valid_kernels = op->CreateKernels(valid_places); + node_storage_.back().AsInstruct( + op->op_type_, op->CreateKernels(valid_places), op->op_info()); CHECK(new_node.inlinks.empty()) << "duplicate Build found"; CHECK(new_node.outlinks.empty()) << "duplicate Build found"; // collect inputs and outputs - for (const std::string &name : op->input_names()) { - new_node.inlinks.push_back(arguments_.at(name)); + for (const std::string &name : op->op_info()->input_names()) { + auto *arg = arguments_.at(name); + new_node.inlinks.push_back(arg); + arg->outlinks.push_back(&new_node); } - for (const std::string &name : op->output_names()) { + for (const std::string &name : op->op_info()->output_names()) { if (!arguments_.count(name)) { node_storage_.emplace_back(); auto &new_node = node_storage_.back(); @@ -77,33 +86,35 @@ class SSAGraph : GraphBase { arg.name = name; arguments_.emplace(name, &new_node); } - new_node.outlinks.push_back(arguments_.at(name)); + auto *arg = arguments_.at(name); + new_node.outlinks.push_back(arg); + arg->inlinks.push_back(&new_node); } } + + MarkArgumentWeights(program); } - void sort_utils(mir::Node *n, std::map &visited, - std::stack &stack) { - visited[n] = true; - for (auto &out : n->outlinks) { - if (!visited[out]) { - sort_utils(out, visited, stack); + std::vector InstructTopologicalOrder(); + + // The inputs of the graph. + std::vector inputs() { + std::vector res; + for (auto &node : node_storage_) { + if (node.inlinks.empty()) { + res.push_back(&node); } } + return res; } - std::vector TopoloticalOrder() { - std::map visited; - std::stack stack; + // The outputs of the graph. + std::vector outputs() { std::vector res; - - for (auto &n : mutable_nodes()) { - if (!visited[&n]) sort_utils(&n, visited, stack); - } - - while (!stack.empty()) { - res.push_back(stack.top()); - stack.pop(); + for (auto &node : node_storage_) { + if (node.outlinks.empty()) { + res.push_back(&node); + } } return res; } @@ -111,6 +122,31 @@ class SSAGraph : GraphBase { const std::list &nodes() const { return node_storage_; } std::list &mutable_nodes() { return node_storage_; } + mir::Node *RetriveArgument(const std::string &arg) { + auto it = arguments_.find(arg); + if (it != arguments_.end()) { + return it->second; + } + return nullptr; + } + + private: + // Check the bidirectional connection. + bool CheckBidirectionalConnection(); + + void MarkArgumentWeights(const Program &program) { + for (const auto &name : program.weights) { + arguments_[name]->AsArgument().is_weight = true; + } + } + + // Build operator inlink edge table. + std::map> BuildOperationAdjList(); + + void SortHelper(const std::map> &adj_list, + mir::Node *node, std::set *visited, + std::vector *ret); + private: std::list node_storage_; std::map arguments_; diff --git a/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc b/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc index 8a324a6cca2..a86470f3fce 100644 --- a/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc +++ b/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc @@ -47,6 +47,7 @@ void StaticKernelPickPass::Apply(std::unique_ptr& graph) { // TODO(Superjomn) reconsider this. instruct.valid_kernels.clear(); instruct.valid_kernels.emplace_back(std::move(scored.front().second)); + instruct.place = instruct.valid_kernels.front()->place(); LOG(INFO) << "pick " << instruct.valid_kernels.front()->name(); } } diff --git a/paddle/fluid/lite/core/mir/variable_place_inference_pass.cc b/paddle/fluid/lite/core/mir/variable_place_inference_pass.cc index e69de29bb2d..50665066365 100644 --- a/paddle/fluid/lite/core/mir/variable_place_inference_pass.cc +++ b/paddle/fluid/lite/core/mir/variable_place_inference_pass.cc @@ -0,0 +1,34 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/mir/variable_place_inference_pass.h" +#include +#include "paddle/fluid/lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +void VariablePlaceInferencePass::Apply(std::unique_ptr& graph) { + MarkInputPlace(graph.get()); + + InferenceArgumentPlace(graph.get()); +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(variable_place_inference_pass, + paddle::lite::mir::VariablePlaceInferencePass); diff --git a/paddle/fluid/lite/core/mir/variable_place_inference_pass.h b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h index 8b137891791..0eeb7c9cce1 100644 --- a/paddle/fluid/lite/core/mir/variable_place_inference_pass.h +++ b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h @@ -1 +1,117 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include "paddle/fluid/lite/core/mir/pass.h" +#include "paddle/fluid/lite/core/target_wrapper.h" + +namespace paddle { +namespace lite { +namespace mir { + +/* + * Mark the place of the variables in the SSAGrpah, it will inference the + * variables' place by the kernels outputs them. + */ +class VariablePlaceInferencePass : public DebugPass { + public: + void Apply(std::unique_ptr& graph) override; + + private: + // Mark the place of input arguments. + void MarkInputPlace(SSAGraph* graph) { + for (const auto& v : graph->inputs()) { + // the feed op might in the inputs + if (v->IsInstruct()) { + LOG(INFO) << "found kernel in inputs " << v->AsInstruct().op_type; + continue; + } + + auto& arg = v->AsArgument(); + arg.place.target = argument_default_target_; + // the other place description can't be determined yet, until their first + // usage by some kernel. + } + } + + void InferenceArgumentPlace(SSAGraph* graph) { + LOG(INFO) << "param-type-registry:\n" << ParamTypeRegistry::Global(); + for (auto& x : graph->InstructTopologicalOrder()) { + auto& inst = x->AsInstruct(); + CHECK(inst.place.is_valid()) + << "kernel's place should be set when loaded"; + // deal with inputs + for (auto& arg_name : inst.op_info->input_argnames()) { + auto type = + ParamTypeRegistry::Global().Retrieve( + inst.place, inst.op_type, arg_name); + CHECK(type) << "no param-type found for " << inst.op_type << ":" + << arg_name << " " << inst.place.DebugString(); + auto arg_names = inst.op_info->input_argument().at(arg_name); + // check if inputs's place is set, if not set, update them with the + // kernel's declaration. + + for (auto& arg_name : arg_names) { + auto* node = graph->RetriveArgument(arg_name); + CHECK(node) << "argument " << arg_name << " not exists in the graph"; + auto& arg_node = node->AsArgument(); + if (arg_node.place.is_valid()) continue; + UpdatePlace(&arg_node.place, type->tensor_place); + } + } + + for (auto& arg_name : inst.op_info->output_argnames()) { + auto type = ParamTypeRegistry::Global() + .Retrieve( + inst.place, inst.op_type, arg_name); + CHECK(type) << "no param-type found for " << inst.op_type << ":" + << arg_name << " " << inst.place.DebugString(); + auto arg_names = inst.op_info->output_argument().at(arg_name); + // check if outputs's place is set, if not set, update them with the + // kernel's declaration. + + for (auto& arg_name : arg_names) { + auto* node = graph->RetriveArgument(arg_name); + CHECK(node) << "argument " << arg_name << " not exists in the graph"; + auto& arg_node = node->AsArgument(); + if (arg_node.place.is_valid()) continue; + UpdatePlace(&arg_node.place, type->tensor_place); + } + } + } + } + + // Update me's kUnk fields by other's fields. + void UpdatePlace(Place* me, const Place& other) { + CHECK(other.is_valid()); + if (me->target == TARGET(kUnk)) { + me->target = other.target; + } + if (me->precision == PRECISION(kUnk)) { + me->precision = other.precision; + } + if (me->layout == DATALAYOUT(kUnk)) { + me->layout = other.layout; + } + } + + private: + // The default target for arguments, e.g. load weights to CPU memory for CUDA + // computation by default. + TargetType argument_default_target_{TARGET(kHost)}; +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/op_lite.cc b/paddle/fluid/lite/core/op_lite.cc index a57ee119cc8..72b50c5a8c1 100644 --- a/paddle/fluid/lite/core/op_lite.cc +++ b/paddle/fluid/lite/core/op_lite.cc @@ -54,5 +54,12 @@ bool OpLite::Run() { return true; } +bool OpLite::Attach(const framework::OpDesc &opdesc, lite::Scope *scope) { + CHECK(!op_info_) << "op_info duplicate build found"; + op_info_ = std::make_shared(); + op_info_->Build(opdesc); + return AttachImpl(opdesc, scope); +} + } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/core/op_lite.h b/paddle/fluid/lite/core/op_lite.h index 0ab0550be08..ffb8b7b3ce0 100644 --- a/paddle/fluid/lite/core/op_lite.h +++ b/paddle/fluid/lite/core/op_lite.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_desc.h" @@ -41,6 +42,8 @@ class Node; class SSAGraph; } +class OpInfo; + /** * The base class of an light-weight operators, currently just used in inference * to eliminate overhead of some operations in current framework. @@ -78,10 +81,10 @@ class OpLite : public Registry { // Run this operator. virtual bool Run(); - bool Attach(const framework::OpDesc &opdesc, lite::Scope *scope) { - ExtractInputsAndOutputs(opdesc); - return AttachImpl(opdesc, scope); - } + bool Attach(const framework::OpDesc &opdesc, lite::Scope *scope); + + const std::shared_ptr &op_info() const { return op_info_; } + std::shared_ptr &mutable_op_info() { return op_info_; } // Human-readable information. virtual std::string DebugString() const = 0; @@ -91,9 +94,6 @@ class OpLite : public Registry { void PickKernel(const std::vector &valid_places, KernelStrategy kernel_strategy = KernelStrategy::kStatic); - const std::list &input_names() const { return input_names_; } - const std::list &output_names() const { return output_names_; } - virtual ~OpLite() = default; protected: @@ -101,19 +101,6 @@ class OpLite : public Registry { virtual bool AttachImpl(const framework::OpDesc &opdesc, lite::Scope *scope) = 0; - void ExtractInputsAndOutputs(const framework::OpDesc &opdesc) { - for (const auto &item : opdesc.Inputs()) { - for (const auto &x : item.second) { - input_names_.push_back(x); - } - } - for (const auto &item : opdesc.Outputs()) { - for (const auto &x : item.second) { - output_names_.push_back(x); - } - } - } - // Specify the kernel to run by default. This will specify the value of // `kernel_place_`. virtual void StaticPickKernel(const std::vector &valid_targets) { @@ -141,8 +128,80 @@ class OpLite : public Registry { std::string op_type_; std::vector valid_places_; Place kernel_place_{TARGET(kHost), PRECISION(kFloat)}; + std::shared_ptr op_info_; +}; + +/* + * Operator Information, such as some description. It will be shared by all the + * kernels of the same operator. + */ +class OpInfo { + public: + void Build(const framework::OpDesc &desc) { + ExtractInputsAndOutputs(desc); + CollectInputAndOutputArgnames(desc); + CollectArguments(desc); + } + + const std::list &input_names() const { return input_names_; } + const std::list &output_names() const { return output_names_; } + const std::map> &input_argument() { + return input_argument_; + } + const std::map> &output_argument() { + return output_argument_; + } + + const std::list &input_argnames() const { + return input_argnames_; + } + const std::list &output_argnames() const { + return output_argnames_; + } + + private: + void ExtractInputsAndOutputs(const framework::OpDesc &opdesc) { + for (const auto &item : opdesc.Inputs()) { + for (const auto &x : item.second) { + input_names_.push_back(x); + } + } + for (const auto &item : opdesc.Outputs()) { + for (const auto &x : item.second) { + output_names_.push_back(x); + } + } + } + + void CollectInputAndOutputArgnames(const framework::OpDesc &opdesc) { + for (const auto &item : opdesc.InputNames()) { + input_argnames_.push_back(item); + } + for (const auto &item : opdesc.OutputNames()) { + output_argnames_.push_back(item); + } + } + + void CollectArguments(const framework::OpDesc &opdesc) { + for (const auto &item : opdesc.Inputs()) { + for (auto &x : item.second) { + input_argument_[item.first].push_back(x); + } + } + for (const auto &item : opdesc.Outputs()) { + for (auto &x : item.second) { + output_argument_[item.first].push_back(x); + } + } + } + + private: std::list input_names_; std::list output_names_; + std::list input_argnames_; + std::list output_argnames_; + std::map> input_argument_; + std::map> output_argument_; }; } // namespace lite diff --git a/paddle/fluid/lite/core/optimizer_test.cc b/paddle/fluid/lite/core/optimizer_test.cc index 5d718754ddb..f363bf3e835 100644 --- a/paddle/fluid/lite/core/optimizer_test.cc +++ b/paddle/fluid/lite/core/optimizer_test.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/lite/core/optimizer.h" #include +#include "paddle/fluid/lite/core/mir/generate_program_pass.h" #include "paddle/fluid/lite/core/mir/pass_manager.h" #include "paddle/fluid/lite/core/mir/passes.h" #include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h" @@ -36,6 +37,12 @@ TEST(Optimizer, test) { .ConsiderPrecision(); optimizer.Run(std::move(program), places); + + auto* program_pass = + mir::PassManager::Global().LookUp( + "generate_program_pass"); + auto& kernels = program_pass->kernels(); + LOG(INFO) << "get kernels: " << kernels.size(); } } // namespace lite diff --git a/paddle/fluid/lite/core/type_system.cc b/paddle/fluid/lite/core/type_system.cc index 1d2ad240edc..95e4db7f105 100644 --- a/paddle/fluid/lite/core/type_system.cc +++ b/paddle/fluid/lite/core/type_system.cc @@ -34,6 +34,14 @@ Type::Get +const Type* +Type::Get() { + static TensorFp32NCHWTy x(TargetType::kHost); + return &x; +} + template <> const Type* Type::Get(TargetType target) { return Get(TargetType target) { case TargetType::kX86: return Get(); + case TargetType::kHost: + return Get(); + default: LOG(FATAL) << "unsupported target " << TargetToStr(target); return nullptr; diff --git a/paddle/fluid/lite/kernels/host/fc_compute.cc b/paddle/fluid/lite/kernels/host/fc_compute.cc index bbbc13e30ab..9c2b6c6205b 100644 --- a/paddle/fluid/lite/kernels/host/fc_compute.cc +++ b/paddle/fluid/lite/kernels/host/fc_compute.cc @@ -52,8 +52,13 @@ PrecisionType FcCompute::precision() const { return PRECISION(kFloat); } } // namespace paddle REGISTER_LITE_KERNEL(fc, kHost, kFloat, paddle::lite::kernels::host::FcCompute) - .BindInput(0, {paddle::lite::Type::Get( - TARGET(kX86))}) - .BindOutput(0, {paddle::lite::Type::Get( - TARGET(kX86))}) + .BindInput("Input", + {paddle::lite::Type::Get( + TARGET(kHost))}) + .BindInput("Bias", {paddle::lite::Type::Get( + TARGET(kHost))}) + .BindInput("W", {paddle::lite::Type::Get( + TARGET(kHost))}) + .BindOutput("Out", {paddle::lite::Type::Get( + TARGET(kHost))}) .Finalize(); -- GitLab