Commit 65bfecc9 authored by S superjomn

make kernel param-type-recorder and typesystem work

Parent 1efa91dd
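Editor's note: this commit wires kernel parameter-type declarations into a global ParamTypeRegistry and has the place-inference pass query them. A minimal sketch of the intended lookup, based only on the Retrieve/DebugString APIs added in the diff below (the exact call sites in the real code may differ):

// Sketch only, not part of the commit: query the param-type registry for the
// type declared for the "Input" argument of the "fc" kernel.
#include "paddle/fluid/lite/core/kernel.h"

void LookUpFcInputType() {
  using paddle::lite::ParamTypeRegistry;
  paddle::lite::Place place{TARGET(kHost), PRECISION(kFloat),
                            DATALAYOUT(kNCHW)};
  const paddle::lite::ParamType* type =
      ParamTypeRegistry::Global().Retrieve<ParamTypeRegistry::IO::kInput>(
          place, "fc", "Input");
  if (type != nullptr) {
    // DebugString() prints the tensor place recorded for this argument.
    LOG(INFO) << "fc:Input declared as " << type->DebugString();
  }
}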
cc_library(memory_lite SRCS memory.cc)
cc_library(target_wrapper_lite SRCS target_wrapper.cc)
cc_library(tensor_lite SRCS tensor.cc DEPS memory_lite)
-cc_library(kernel_lite SRCS kernel.cc DEPS type_system)
+cc_library(kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite)
cc_library(variable_lite SRCS variable.cc)
cc_library(op_registry_lite SRCS op_registry.cc)
cc_library(scope_lite SRCS scope.cc)
......
...@@ -17,29 +17,18 @@
namespace paddle {
namespace lite {

-bool operator<(const Place &a, const Place &b) {
-  if (a.target != b.target)
-    return a.target < b.target;
-  else if (a.precision != b.precision)
-    return a.precision < b.precision;
-  else if (a.layout != b.layout)
-    return a.layout < b.layout;
-  return true;
-}
-
bool ParamTypeRegistry::KeyCmp::operator()(
    const ParamTypeRegistry::key_t &a,
    const ParamTypeRegistry::key_t &b) const {
-  if (a.kernel_type != b.kernel_type)
-    return a.kernel_type < b.kernel_type;
-  else if (a.io != b.io)
-    return a.io < b.io;
-  else if (a.arg_name != b.arg_name)
-    return a.arg_name < b.arg_name;
-  else if (!(a.place == b.place)) {
-    return a.place < b.place;
-  }
-  return true;
+  return a.hash() < b.hash();
+}
+
+std::ostream &operator<<(std::ostream &os,
+                         const ParamTypeRegistry::KernelIdTy &other) {
+  std::string io_s = other.io == ParamTypeRegistry::IO::kInput ? "in" : "out";
+  os << other.kernel_type << ":" << other.arg_name << ":" << io_s << ":"
+     << other.place.DebugString();
+  return os;
}

}  // namespace lite
......
...@@ -55,6 +55,7 @@ class KernelBase {
  void Torch() {}

+  virtual Place place() const = 0;
  virtual TargetType target() const = 0;
  virtual PrecisionType precision() const = 0;
  virtual DataLayoutType layout() const = 0;
...@@ -87,7 +88,9 @@ struct ParamType {
      : element_type_hash(element_type_hash) {}
  ParamType(size_t element_type_hash, const Place& place)
      : element_type_hash(element_type_hash), tensor_place(place) {}
-  ParamType(const Type* type) : type_(type) {}
+  ParamType(const Type* type) : type_(type) { tensor_place = type_->place(); }
+
+  std::string DebugString() const { return tensor_place.DebugString(); }
};
/*
...@@ -167,15 +170,32 @@ class ParamTypeRegistry {
                const std::string& arg_name, ParamType data_type) {
    KernelIdTy key{kernel_type, place, io, arg_name};
    types_[key] = data_type;
+    CHECK(types_.count(key));
  }

-  ParamType Retrive(const Place& place, int offset);
+  template <IO io>
+  const ParamType* Retrieve(const Place& place, const std::string& op_type,
+                            const std::string& arg_name) {
+    KernelIdTy key{op_type, place, io, arg_name};
+    LOG(INFO) << "Looking for " << key;
+    auto it = types_.find(key);
+    if (it == types_.end()) return nullptr;
+    return &it->second;
+  }

  static ParamTypeRegistry& Global() {
    static ParamTypeRegistry x;
    return x;
  }

+  friend std::ostream& operator<<(std::ostream& os,
+                                  const ParamTypeRegistry& other) {
+    for (auto& item : other.types_) {
+      os << item.first << " " << item.second.DebugString() << "\n";
+    }
+    return os;
+  }
+
 private:
  ParamTypeRegistry() = default;
...@@ -186,6 +206,16 @@ class ParamTypeRegistry {
    Place place;
    IO io;
    std::string arg_name;

+    size_t hash() const {
+      std::hash<std::string> h;
+      size_t hash = h(kernel_type);
+      hash = hash_combine(hash, place.hash());
+      hash = hash_combine(hash, std::hash<int>()(static_cast<int>(io)));
+      hash = hash_combine(hash, std::hash<std::string>()(arg_name));
+      return hash;
+    }
+
+    friend std::ostream& operator<<(std::ostream& os, const KernelIdTy& other);
  };

  using key_t = KernelIdTy;
...@@ -213,6 +243,7 @@ class OpKernel : public KernelBase {
  TargetType target() const override { return Target; }
  PrecisionType precision() const override { return Precision; }
  DataLayoutType layout() const override { return DataLayout; }
+  Place place() const override { return Place{Target, Precision, DataLayout}; }
  std::string name() const override {
    return op_type() + ":" + TargetToStr(Target) + "/" +
           PrecisionToStr(Precision) + "/" + DataLayoutToStr(DataLayout);
......
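Editor's note: KernelIdTy::hash() above relies on a hash_combine helper that is not shown in this diff. A boost-style combiner consistent with those call sites, which pass already-computed hash values, would look roughly like the following; the actual helper in the codebase may differ.

// Assumed helper, not part of this commit: mixes a second hash value into a
// seed, in the style of boost::hash_combine.
#include <cstddef>

inline size_t hash_combine(size_t seed, size_t value) {
  return seed ^ (value + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}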
...@@ -8,6 +8,7 @@ cc_library(mir_passes
    io_complement_pass.cc
    graph_visualize_pass.cc
    generate_program_pass.cc
+    variable_place_inference_pass.cc
    demo_pass.cc
    DEPS mir_pass types_lite)
......
...@@ -20,7 +20,7 @@ namespace lite {
namespace mir {

void GenerateProgramPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
-  for (auto& item : graph->TopoloticalOrder()) {
+  for (auto& item : graph->InstructTopologicalOrder()) {
    if (item->IsInstruct()) {
      auto& instruct = item->AsInstruct();
      kernels_.emplace_back(std::move(instruct.valid_kernels.front()));
......
...@@ -19,7 +19,9 @@ namespace paddle {
namespace lite {
namespace mir {

-void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph> &graph) {}
+void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph> &graph) {
+  // Start from the inputs of the graph; those should have their place set.
+}

}  // namespace mir
}  // namespace lite
......
...@@ -19,6 +19,7 @@
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
+#include "paddle/fluid/lite/core/op_lite.h"

namespace paddle {
namespace lite {
...@@ -44,11 +45,15 @@ class Node {
    Place place;
    // The kernel instances this Instruct contains.
    std::vector<std::unique_ptr<KernelBase>> valid_kernels;
+    std::shared_ptr<OpInfo> op_info;
  };
  struct Argument {
    std::string name;
    Place place;
+    // Weight is a special kind of argument; it is marked as a weight
+    // explicitly so that weight-related optimizations can take place.
+    bool is_weight{false};
  };
  Argument& AsArgument(const std::string& name) {
...@@ -57,9 +62,13 @@ class Node {
    return x;
  }

-  Instruct& AsInstruct(const std::string& op_type) {
+  Instruct& AsInstruct(const std::string& op_type,
+                       std::vector<std::unique_ptr<KernelBase>>&& kernels,
+                       const std::shared_ptr<lite::OpInfo>& op_info) {
    auto& x = AsInstruct();
    x.op_type = op_type;
+    x.valid_kernels = std::move(kernels);
+    x.op_info = op_info;
    return x;
  }
......
...@@ -23,5 +23,6 @@ namespace mir {}  // namespace mir
USE_MIR_PASS(demo);
USE_MIR_PASS(static_kernel_pick_pass);
+USE_MIR_PASS(variable_place_inference_pass);
USE_MIR_PASS(io_complement_pass);
USE_MIR_PASS(generate_program_pass);
...@@ -16,6 +16,79 @@
namespace paddle {
namespace lite {
-namespace mir {}  // namespace mir
+namespace mir {
bool SSAGraph::CheckBidirectionalConnection() {
LOG(INFO) << "node count " << node_storage_.size();
for (auto &node : node_storage_) {
for (auto *in : node.inlinks) {
CHECK(in->outlinks.end() !=
std::find(in->outlinks.begin(), in->outlinks.end(), &node));
}
for (auto *out : node.outlinks) {
CHECK(out->inlinks.end() !=
std::find(out->inlinks.begin(), out->inlinks.end(), &node));
}
}
return true;
}
std::map<mir::Node *, std::set<mir::Node *>> SSAGraph::BuildOperationAdjList() {
std::map<mir::Node *, std::set<mir::Node *>> adj_list;
for (auto &n : mutable_nodes()) {
if (!n.IsInstruct()) continue;
if (adj_list.find(&n) == adj_list.end()) {
adj_list[&n] = std::set<mir::Node *>();
}
std::vector<mir::Node *> nodes;
for (auto &var : n.inlinks) {
for (auto &adj_n : var->inlinks) {
PADDLE_ENFORCE(adj_n->IsInstruct());
nodes.push_back(adj_n);
}
}
std::sort(nodes.begin(), nodes.end(),
[](mir::Node *node1, mir::Node *node2) { return node1 > node2; });
adj_list[&n].insert(std::make_move_iterator(nodes.begin()),
std::make_move_iterator(nodes.end()));
}
return adj_list;
}
void SSAGraph::SortHelper(
const std::map<mir::Node *, std::set<mir::Node *>> &adj_list,
mir::Node *node, std::set<mir::Node *> *visited,
std::vector<mir::Node *> *ret) {
visited->insert(node);
for (auto adj : adj_list.at(node)) {
if (visited->find(adj) == visited->end()) {
SortHelper(adj_list, adj, visited, ret);
}
}
ret->push_back(node);
}
std::vector<mir::Node *> SSAGraph::InstructTopologicalOrder() {
CheckBidirectionalConnection();
std::stack<mir::Node *> stack;
std::set<mir::Node *> visited;
std::vector<mir::Node *> res;
auto adj_list = BuildOperationAdjList();
for (auto adj : adj_list) {
if (visited.find(adj.first) == visited.end()) {
SortHelper(adj_list, adj.first, &visited, &res);
}
}
return res;
}
} // namespace mir
}  // namespace lite
}  // namespace paddle
...@@ -19,6 +19,7 @@
#include <stack>
#include <string>
#include <vector>
+#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/op_lite.h"
...@@ -34,7 +35,13 @@ struct Program {
  std::list<std::string> tmp_vars;
  std::list<std::string> weights;
  std::list<std::unique_ptr<OpLite>> ops;
-  lite::Scope *scope;
+  lite::Scope *scope{};
+};
+
+// Program of kernels.
+struct KernelProgram {
+  std::list<std::unique_ptr<KernelBase>> instructions;
+  lite::Scope *scope{};
};
// A Graph for MIR. It is built from a list of Op and a scope.
...@@ -59,17 +66,19 @@ class SSAGraph : GraphBase {
      // TODO(Superjomn) remove one valid_places here.
      op->SetValidPlaces(valid_places);
      auto &new_node = node_storage_.back();
-      auto &new_kernel = node_storage_.back().AsInstruct(op->op_type_);
-      new_kernel.valid_kernels = op->CreateKernels(valid_places);
+      node_storage_.back().AsInstruct(
+          op->op_type_, op->CreateKernels(valid_places), op->op_info());
      CHECK(new_node.inlinks.empty()) << "duplicate Build found";
      CHECK(new_node.outlinks.empty()) << "duplicate Build found";
      // collect inputs and outputs
-      for (const std::string &name : op->input_names()) {
-        new_node.inlinks.push_back(arguments_.at(name));
+      for (const std::string &name : op->op_info()->input_names()) {
+        auto *arg = arguments_.at(name);
+        new_node.inlinks.push_back(arg);
+        arg->outlinks.push_back(&new_node);
      }
-      for (const std::string &name : op->output_names()) {
+      for (const std::string &name : op->op_info()->output_names()) {
        if (!arguments_.count(name)) {
          node_storage_.emplace_back();
          auto &new_node = node_storage_.back();
...@@ -77,33 +86,35 @@ class SSAGraph : GraphBase {
          arg.name = name;
          arguments_.emplace(name, &new_node);
        }
-        new_node.outlinks.push_back(arguments_.at(name));
+        auto *arg = arguments_.at(name);
+        new_node.outlinks.push_back(arg);
+        arg->inlinks.push_back(&new_node);
      }
    }
+
+    MarkArgumentWeights(program);
  }
-  void sort_utils(mir::Node *n, std::map<mir::Node *, bool> &visited,
-                  std::stack<mir::Node *> &stack) {
-    visited[n] = true;
-    for (auto &out : n->outlinks) {
-      if (!visited[out]) {
-        sort_utils(out, visited, stack);
-      }
-    }
-  }
-
-  std::vector<mir::Node *> TopoloticalOrder() {
-    std::map<mir::Node *, bool> visited;
-    std::stack<mir::Node *> stack;
-    std::vector<mir::Node *> res;
-    for (auto &n : mutable_nodes()) {
-      if (!visited[&n]) sort_utils(&n, visited, stack);
-    }
-    while (!stack.empty()) {
-      res.push_back(stack.top());
-      stack.pop();
-    }
-    return res;
-  }
+  std::vector<mir::Node *> InstructTopologicalOrder();
+
+  // The inputs of the graph.
+  std::vector<mir::Node *> inputs() {
+    std::vector<mir::Node *> res;
+    for (auto &node : node_storage_) {
+      if (node.inlinks.empty()) {
+        res.push_back(&node);
+      }
+    }
+    return res;
+  }
+
+  // The outputs of the graph.
+  std::vector<mir::Node *> outputs() {
+    std::vector<mir::Node *> res;
+    for (auto &node : node_storage_) {
+      if (node.outlinks.empty()) {
+        res.push_back(&node);
+      }
+    }
+    return res;
+  }
...@@ -111,6 +122,31 @@ class SSAGraph : GraphBase {
  const std::list<mir::Node> &nodes() const { return node_storage_; }
  std::list<mir::Node> &mutable_nodes() { return node_storage_; }
mir::Node *RetriveArgument(const std::string &arg) {
auto it = arguments_.find(arg);
if (it != arguments_.end()) {
return it->second;
}
return nullptr;
}
private:
// Check the bidirectional connection.
bool CheckBidirectionalConnection();
void MarkArgumentWeights(const Program &program) {
for (const auto &name : program.weights) {
arguments_[name]->AsArgument().is_weight = true;
}
}
// Build operator inlink edge table.
std::map<mir::Node *, std::set<mir::Node *>> BuildOperationAdjList();
void SortHelper(const std::map<mir::Node *, std::set<mir::Node *>> &adj_list,
mir::Node *node, std::set<mir::Node *> *visited,
std::vector<mir::Node *> *ret);
 private:
  std::list<mir::Node> node_storage_;
  std::map<std::string, mir::Node *> arguments_;
......
...@@ -47,6 +47,7 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
    // TODO(Superjomn) reconsider this.
    instruct.valid_kernels.clear();
    instruct.valid_kernels.emplace_back(std::move(scored.front().second));
+    instruct.place = instruct.valid_kernels.front()->place();
    LOG(INFO) << "pick " << instruct.valid_kernels.front()->name();
  }
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/variable_place_inference_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void VariablePlaceInferencePass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
MarkInputPlace(graph.get());
InferenceArgumentPlace(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(variable_place_inference_pass,
paddle::lite::mir::VariablePlaceInferencePass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
namespace paddle {
namespace lite {
namespace mir {
/*
 * Mark the place of the variables in the SSAGraph; it will infer the
 * variables' place from the kernels that output them.
*/
class VariablePlaceInferencePass : public DebugPass {
public:
void Apply(std::unique_ptr<mir::SSAGraph>& graph) override;
private:
// Mark the place of input arguments.
void MarkInputPlace(SSAGraph* graph) {
for (const auto& v : graph->inputs()) {
      // the feed op might be in the inputs
if (v->IsInstruct()) {
LOG(INFO) << "found kernel in inputs " << v->AsInstruct().op_type;
continue;
}
auto& arg = v->AsArgument();
arg.place.target = argument_default_target_;
      // The other place fields can't be determined yet, until their first
      // usage by some kernel.
}
}
void InferenceArgumentPlace(SSAGraph* graph) {
LOG(INFO) << "param-type-registry:\n" << ParamTypeRegistry::Global();
for (auto& x : graph->InstructTopologicalOrder()) {
auto& inst = x->AsInstruct();
CHECK(inst.place.is_valid())
<< "kernel's place should be set when loaded";
// deal with inputs
for (auto& arg_name : inst.op_info->input_argnames()) {
auto type =
ParamTypeRegistry::Global().Retrieve<ParamTypeRegistry::IO::kInput>(
inst.place, inst.op_type, arg_name);
CHECK(type) << "no param-type found for " << inst.op_type << ":"
<< arg_name << " " << inst.place.DebugString();
auto arg_names = inst.op_info->input_argument().at(arg_name);
        // Check whether the inputs' place is set; if not, update it with the
        // kernel's declaration.
for (auto& arg_name : arg_names) {
auto* node = graph->RetriveArgument(arg_name);
          CHECK(node) << "argument " << arg_name << " does not exist in the graph";
auto& arg_node = node->AsArgument();
if (arg_node.place.is_valid()) continue;
UpdatePlace(&arg_node.place, type->tensor_place);
}
}
for (auto& arg_name : inst.op_info->output_argnames()) {
auto type = ParamTypeRegistry::Global()
.Retrieve<ParamTypeRegistry::IO::kOutput>(
inst.place, inst.op_type, arg_name);
CHECK(type) << "no param-type found for " << inst.op_type << ":"
<< arg_name << " " << inst.place.DebugString();
auto arg_names = inst.op_info->output_argument().at(arg_name);
        // Check whether the outputs' place is set; if not, update it with the
        // kernel's declaration.
for (auto& arg_name : arg_names) {
auto* node = graph->RetriveArgument(arg_name);
          CHECK(node) << "argument " << arg_name << " does not exist in the graph";
auto& arg_node = node->AsArgument();
if (arg_node.place.is_valid()) continue;
UpdatePlace(&arg_node.place, type->tensor_place);
}
}
}
}
  // Update the kUnk fields of `me` from the corresponding fields of `other`.
void UpdatePlace(Place* me, const Place& other) {
CHECK(other.is_valid());
if (me->target == TARGET(kUnk)) {
me->target = other.target;
}
if (me->precision == PRECISION(kUnk)) {
me->precision = other.precision;
}
if (me->layout == DATALAYOUT(kUnk)) {
me->layout = other.layout;
}
}
private:
// The default target for arguments, e.g. load weights to CPU memory for CUDA
// computation by default.
TargetType argument_default_target_{TARGET(kHost)};
};
} // namespace mir
} // namespace lite
} // namespace paddle
...@@ -54,5 +54,12 @@ bool OpLite::Run() {
  return true;
}
bool OpLite::Attach(const framework::OpDesc &opdesc, lite::Scope *scope) {
CHECK(!op_info_) << "op_info duplicate build found";
op_info_ = std::make_shared<OpInfo>();
op_info_->Build(opdesc);
return AttachImpl(opdesc, scope);
}
}  // namespace lite
}  // namespace paddle
...@@ -17,6 +17,7 @@
#include <glog/logging.h>
#include <boost/variant.hpp>
#include <map>
+#include <memory>
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_desc.h"
...@@ -41,6 +42,8 @@ class Node;
class SSAGraph;
}
class OpInfo;
/**
 * The base class of light-weight operators, currently just used in inference
 * to eliminate the overhead of some operations in the current framework.
...@@ -78,10 +81,10 @@ class OpLite : public Registry {
  // Run this operator.
  virtual bool Run();

-  bool Attach(const framework::OpDesc &opdesc, lite::Scope *scope) {
-    ExtractInputsAndOutputs(opdesc);
-    return AttachImpl(opdesc, scope);
-  }
+  bool Attach(const framework::OpDesc &opdesc, lite::Scope *scope);
+
+  const std::shared_ptr<OpInfo> &op_info() const { return op_info_; }
+  std::shared_ptr<OpInfo> &mutable_op_info() { return op_info_; }
  // Human-readable information.
  virtual std::string DebugString() const = 0;
...@@ -91,9 +94,6 @@ class OpLite : public Registry {
  void PickKernel(const std::vector<Place> &valid_places,
                  KernelStrategy kernel_strategy = KernelStrategy::kStatic);

-  const std::list<std::string> &input_names() const { return input_names_; }
-  const std::list<std::string> &output_names() const { return output_names_; }
-
  virtual ~OpLite() = default;
 protected:
...@@ -101,19 +101,6 @@ class OpLite : public Registry {
  virtual bool AttachImpl(const framework::OpDesc &opdesc,
                          lite::Scope *scope) = 0;
-  void ExtractInputsAndOutputs(const framework::OpDesc &opdesc) {
-    for (const auto &item : opdesc.Inputs()) {
-      for (const auto &x : item.second) {
-        input_names_.push_back(x);
-      }
-    }
-    for (const auto &item : opdesc.Outputs()) {
-      for (const auto &x : item.second) {
-        output_names_.push_back(x);
-      }
-    }
-  }
-
  // Specify the kernel to run by default. This will specify the value of
  // `kernel_place_`.
  virtual void StaticPickKernel(const std::vector<Place> &valid_targets) {
...@@ -141,8 +128,80 @@ class OpLite : public Registry {
  std::string op_type_;
  std::vector<Place> valid_places_;
  Place kernel_place_{TARGET(kHost), PRECISION(kFloat)};
+
+  std::shared_ptr<OpInfo> op_info_;
};
/*
 * Operator information, such as its description. It is shared by all the
 * kernels of the same operator.
*/
class OpInfo {
public:
void Build(const framework::OpDesc &desc) {
ExtractInputsAndOutputs(desc);
CollectInputAndOutputArgnames(desc);
CollectArguments(desc);
}
const std::list<std::string> &input_names() const { return input_names_; }
const std::list<std::string> &output_names() const { return output_names_; }
const std::map<std::string, std::list<std::string>> &input_argument() {
return input_argument_;
}
const std::map<std::string, std::list<std::string>> &output_argument() {
return output_argument_;
}
const std::list<std::string> &input_argnames() const {
return input_argnames_;
}
const std::list<std::string> &output_argnames() const {
return output_argnames_;
}
private:
void ExtractInputsAndOutputs(const framework::OpDesc &opdesc) {
for (const auto &item : opdesc.Inputs()) {
for (const auto &x : item.second) {
input_names_.push_back(x);
}
}
for (const auto &item : opdesc.Outputs()) {
for (const auto &x : item.second) {
output_names_.push_back(x);
}
}
}
void CollectInputAndOutputArgnames(const framework::OpDesc &opdesc) {
for (const auto &item : opdesc.InputNames()) {
input_argnames_.push_back(item);
}
for (const auto &item : opdesc.OutputNames()) {
output_argnames_.push_back(item);
}
}
void CollectArguments(const framework::OpDesc &opdesc) {
for (const auto &item : opdesc.Inputs()) {
for (auto &x : item.second) {
input_argument_[item.first].push_back(x);
}
}
for (const auto &item : opdesc.Outputs()) {
for (auto &x : item.second) {
output_argument_[item.first].push_back(x);
}
}
}
private:
  std::list<std::string> input_names_;
  std::list<std::string> output_names_;
std::list<std::string> input_argnames_;
std::list<std::string> output_argnames_;
std::map<std::string, std::list<std::string>> input_argument_;
std::map<std::string, std::list<std::string>> output_argument_;
};
}  // namespace lite
......
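Editor's note: for reference, a minimal sketch of how the new OpInfo is meant to be populated from an op description. The framework::OpDesc setters used here exist in fluid, but this exact snippet is illustrative and not part of the commit.

// Illustrative only: build an OpInfo from a hand-made fc op description.
paddle::framework::OpDesc desc;
desc.SetType("fc");
desc.SetInput("Input", {"x"});
desc.SetInput("W", {"w"});
desc.SetInput("Bias", {"b"});
desc.SetOutput("Out", {"out"});

paddle::lite::OpInfo info;
info.Build(desc);
// info.input_argnames() now holds the argument slots (Input, W, Bias),
// while info.input_names() holds the concrete variable names (x, w, b).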
...@@ -14,6 +14,7 @@
#include "paddle/fluid/lite/core/optimizer.h"
#include <gtest/gtest.h>
+#include "paddle/fluid/lite/core/mir/generate_program_pass.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
...@@ -36,6 +37,12 @@ TEST(Optimizer, test) {
      .ConsiderPrecision();
  optimizer.Run(std::move(program), places);
auto* program_pass =
mir::PassManager::Global().LookUp<mir::GenerateProgramPass>(
"generate_program_pass");
auto& kernels = program_pass->kernels();
LOG(INFO) << "get kernels: " << kernels.size();
}

}  // namespace lite
......
...@@ -34,6 +34,14 @@ Type::Get<false /*is_unsupported*/, true /*is_tensor*/, TargetType::kX86,
  return &x;
}
template <>
const Type*
Type::Get<false /*is_unsupported*/, true /*is_tensor*/, TargetType::kHost,
PrecisionType::kFloat, DataLayoutType::kNCHW>() {
static TensorFp32NCHWTy x(TargetType::kHost);
return &x;
}
template <>
const Type* Type::Get<UnsupportedTy>(TargetType target) {
  return Get<false, false, TargetType::kHost, PrecisionType::kFloat,
...@@ -46,6 +54,10 @@ const Type* Type::Get<TensorFp32NCHWTy>(TargetType target) {
    case TargetType::kX86:
      return Get<false, true, TargetType::kX86, PrecisionType::kFloat,
                 DataLayoutType::kNCHW>();
+    case TargetType::kHost:
+      return Get<false, true, TargetType::kHost, PrecisionType::kFloat,
+                 DataLayoutType::kNCHW>();
    default:
      LOG(FATAL) << "unsupported target " << TargetToStr(target);
      return nullptr;
......
...@@ -52,8 +52,13 @@ PrecisionType FcCompute::precision() const { return PRECISION(kFloat); }
}  // namespace paddle

REGISTER_LITE_KERNEL(fc, kHost, kFloat, paddle::lite::kernels::host::FcCompute)
-    .BindInput(0, {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
-                      TARGET(kX86))})
-    .BindOutput(0, {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
-                       TARGET(kX86))})
+    .BindInput("Input",
+               {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
+                   TARGET(kHost))})
+    .BindInput("Bias", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
+                           TARGET(kHost))})
+    .BindInput("W", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
+                        TARGET(kHost))})
+    .BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
+                           TARGET(kHost))})
    .Finalize();