提交 65bfecc9 编写于 作者: S superjomn

make kernel param-type-recorder and typesystem work

上级 1efa91dd
cc_library(memory_lite SRCS memory.cc)
cc_library(target_wrapper_lite SRCS target_wrapper.cc)
cc_library(tensor_lite SRCS tensor.cc DEPS memory_lite)
cc_library(kernel_lite SRCS kernel.cc DEPS type_system)
cc_library(kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite)
cc_library(variable_lite SRCS variable.cc)
cc_library(op_registry_lite SRCS op_registry.cc)
cc_library(scope_lite SRCS scope.cc)
......
......@@ -17,29 +17,18 @@
namespace paddle {
namespace lite {
bool operator<(const Place &a, const Place &b) {
if (a.target != b.target)
return a.target < b.target;
else if (a.precision != b.precision)
return a.precision < b.precision;
else if (a.layout != b.layout)
return a.layout < b.layout;
return true;
}
bool ParamTypeRegistry::KeyCmp::operator()(
const ParamTypeRegistry::key_t &a,
const ParamTypeRegistry::key_t &b) const {
if (a.kernel_type != b.kernel_type)
return a.kernel_type < b.kernel_type;
else if (a.io != b.io)
return a.io < b.io;
else if (a.arg_name != b.arg_name)
return a.arg_name < b.arg_name;
else if (!(a.place == b.place)) {
return a.place < b.place;
}
return true;
return a.hash() < b.hash();
}
std::ostream &operator<<(std::ostream &os,
const ParamTypeRegistry::KernelIdTy &other) {
std::string io_s = other.io == ParamTypeRegistry::IO::kInput ? "in" : "out";
os << other.kernel_type << ":" << other.arg_name << ":" << io_s << ":"
<< other.place.DebugString();
return os;
}
} // namespace lite
......
......@@ -55,6 +55,7 @@ class KernelBase {
void Torch() {}
virtual Place place() const = 0;
virtual TargetType target() const = 0;
virtual PrecisionType precision() const = 0;
virtual DataLayoutType layout() const = 0;
......@@ -87,7 +88,9 @@ struct ParamType {
: element_type_hash(element_type_hash) {}
ParamType(size_t element_type_hash, const Place& place)
: element_type_hash(element_type_hash), tensor_place(place) {}
ParamType(const Type* type) : type_(type) {}
ParamType(const Type* type) : type_(type) { tensor_place = type_->place(); }
std::string DebugString() const { return tensor_place.DebugString(); }
};
/*
......@@ -167,15 +170,32 @@ class ParamTypeRegistry {
const std::string& arg_name, ParamType data_type) {
KernelIdTy key{kernel_type, place, io, arg_name};
types_[key] = data_type;
CHECK(types_.count(key));
}
ParamType Retrive(const Place& place, int offset);
template <IO io>
const ParamType* Retrieve(const Place& place, const std::string& op_type,
const std::string& arg_name) {
KernelIdTy key{op_type, place, io, arg_name};
LOG(INFO) << "Looking for " << key;
auto it = types_.find(key);
if (it == types_.end()) return nullptr;
return &it->second;
}
static ParamTypeRegistry& Global() {
static ParamTypeRegistry x;
return x;
}
friend std::ostream& operator<<(std::ostream& os,
const ParamTypeRegistry& other) {
for (auto& item : other.types_) {
os << item.first << " " << item.second.DebugString() << "\n";
}
return os;
}
private:
ParamTypeRegistry() = default;
......@@ -186,6 +206,16 @@ class ParamTypeRegistry {
Place place;
IO io;
std::string arg_name;
size_t hash() const {
std::hash<std::string> h;
size_t hash = h(kernel_type);
hash = hash_combine(hash, place.hash());
hash = hash_combine(hash, std::hash<int>()(static_cast<int>(io)));
hash = hash_combine(hash, std::hash<std::string>()(arg_name));
return hash;
}
friend std::ostream& operator<<(std::ostream& os, const KernelIdTy& other);
};
using key_t = KernelIdTy;
......@@ -213,6 +243,7 @@ class OpKernel : public KernelBase {
TargetType target() const override { return Target; }
PrecisionType precision() const override { return Precision; }
DataLayoutType layout() const override { return DataLayout; }
Place place() const override { return Place{Target, Precision, DataLayout}; }
std::string name() const override {
return op_type() + ":" + TargetToStr(Target) + "/" +
PrecisionToStr(Precision) + "/" + DataLayoutToStr(DataLayout);
......
......@@ -8,6 +8,7 @@ cc_library(mir_passes
io_complement_pass.cc
graph_visualize_pass.cc
generate_program_pass.cc
variable_place_inference_pass.cc
demo_pass.cc
DEPS mir_pass types_lite)
......
......@@ -20,7 +20,7 @@ namespace lite {
namespace mir {
void GenerateProgramPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
for (auto& item : graph->TopoloticalOrder()) {
for (auto& item : graph->InstructTopologicalOrder()) {
if (item->IsInstruct()) {
auto& instruct = item->AsInstruct();
kernels_.emplace_back(std::move(instruct.valid_kernels.front()));
......
......@@ -19,7 +19,9 @@ namespace paddle {
namespace lite {
namespace mir {
void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph> &graph) {}
void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph> &graph) {
// Start from inputs of the graph, those should should have place set.
}
} // namespace mir
} // namespace lite
......
......@@ -19,6 +19,7 @@
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
namespace paddle {
namespace lite {
......@@ -44,11 +45,15 @@ class Node {
Place place;
// The kernel instances this Instruct contains.
std::vector<std::unique_ptr<KernelBase>> valid_kernels;
std::shared_ptr<OpInfo> op_info;
};
struct Argument {
std::string name;
Place place;
// Weight is a special kind of argument, it is marked as weight explicitly
// so that some weight related optimization can take place.
bool is_weight{false};
};
Argument& AsArgument(const std::string& name) {
......@@ -57,9 +62,13 @@ class Node {
return x;
}
Instruct& AsInstruct(const std::string& op_type) {
Instruct& AsInstruct(const std::string& op_type,
std::vector<std::unique_ptr<KernelBase>>&& kernels,
const std::shared_ptr<lite::OpInfo>& op_info) {
auto& x = AsInstruct();
x.op_type = op_type;
x.valid_kernels = std::move(kernels);
x.op_info = op_info;
return x;
}
......
......@@ -23,5 +23,6 @@ namespace mir {} // namespace mir
USE_MIR_PASS(demo);
USE_MIR_PASS(static_kernel_pick_pass);
USE_MIR_PASS(variable_place_inference_pass);
USE_MIR_PASS(io_complement_pass);
USE_MIR_PASS(generate_program_pass);
......@@ -16,6 +16,79 @@
namespace paddle {
namespace lite {
namespace mir {} // namespace mir
namespace mir {
bool SSAGraph::CheckBidirectionalConnection() {
LOG(INFO) << "node count " << node_storage_.size();
for (auto &node : node_storage_) {
for (auto *in : node.inlinks) {
CHECK(in->outlinks.end() !=
std::find(in->outlinks.begin(), in->outlinks.end(), &node));
}
for (auto *out : node.outlinks) {
CHECK(out->inlinks.end() !=
std::find(out->inlinks.begin(), out->inlinks.end(), &node));
}
}
return true;
}
std::map<mir::Node *, std::set<mir::Node *>> SSAGraph::BuildOperationAdjList() {
std::map<mir::Node *, std::set<mir::Node *>> adj_list;
for (auto &n : mutable_nodes()) {
if (!n.IsInstruct()) continue;
if (adj_list.find(&n) == adj_list.end()) {
adj_list[&n] = std::set<mir::Node *>();
}
std::vector<mir::Node *> nodes;
for (auto &var : n.inlinks) {
for (auto &adj_n : var->inlinks) {
PADDLE_ENFORCE(adj_n->IsInstruct());
nodes.push_back(adj_n);
}
}
std::sort(nodes.begin(), nodes.end(),
[](mir::Node *node1, mir::Node *node2) { return node1 > node2; });
adj_list[&n].insert(std::make_move_iterator(nodes.begin()),
std::make_move_iterator(nodes.end()));
}
return adj_list;
}
void SSAGraph::SortHelper(
const std::map<mir::Node *, std::set<mir::Node *>> &adj_list,
mir::Node *node, std::set<mir::Node *> *visited,
std::vector<mir::Node *> *ret) {
visited->insert(node);
for (auto adj : adj_list.at(node)) {
if (visited->find(adj) == visited->end()) {
SortHelper(adj_list, adj, visited, ret);
}
}
ret->push_back(node);
}
std::vector<mir::Node *> SSAGraph::InstructTopologicalOrder() {
CheckBidirectionalConnection();
std::stack<mir::Node *> stack;
std::set<mir::Node *> visited;
std::vector<mir::Node *> res;
auto adj_list = BuildOperationAdjList();
for (auto adj : adj_list) {
if (visited.find(adj.first) == visited.end()) {
SortHelper(adj_list, adj.first, &visited, &res);
}
}
return res;
}
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -19,6 +19,7 @@
#include <stack>
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/op_lite.h"
......@@ -34,7 +35,13 @@ struct Program {
std::list<std::string> tmp_vars;
std::list<std::string> weights;
std::list<std::unique_ptr<OpLite>> ops;
lite::Scope *scope;
lite::Scope *scope{};
};
// Program of kernel.
struct KernelProgram {
std::list<std::unique_ptr<KernelBase>> instructions;
lite::Scope *scope{};
};
// An Graph for MIR. It is built from a list of Op and a scope.
......@@ -59,17 +66,19 @@ class SSAGraph : GraphBase {
// TODO(Superjomn) remove one valid_places here.
op->SetValidPlaces(valid_places);
auto &new_node = node_storage_.back();
auto &new_kernel = node_storage_.back().AsInstruct(op->op_type_);
new_kernel.valid_kernels = op->CreateKernels(valid_places);
node_storage_.back().AsInstruct(
op->op_type_, op->CreateKernels(valid_places), op->op_info());
CHECK(new_node.inlinks.empty()) << "duplicate Build found";
CHECK(new_node.outlinks.empty()) << "duplicate Build found";
// collect inputs and outputs
for (const std::string &name : op->input_names()) {
new_node.inlinks.push_back(arguments_.at(name));
for (const std::string &name : op->op_info()->input_names()) {
auto *arg = arguments_.at(name);
new_node.inlinks.push_back(arg);
arg->outlinks.push_back(&new_node);
}
for (const std::string &name : op->output_names()) {
for (const std::string &name : op->op_info()->output_names()) {
if (!arguments_.count(name)) {
node_storage_.emplace_back();
auto &new_node = node_storage_.back();
......@@ -77,33 +86,35 @@ class SSAGraph : GraphBase {
arg.name = name;
arguments_.emplace(name, &new_node);
}
new_node.outlinks.push_back(arguments_.at(name));
auto *arg = arguments_.at(name);
new_node.outlinks.push_back(arg);
arg->inlinks.push_back(&new_node);
}
}
MarkArgumentWeights(program);
}
void sort_utils(mir::Node *n, std::map<mir::Node *, bool> &visited,
std::stack<mir::Node *> &stack) {
visited[n] = true;
for (auto &out : n->outlinks) {
if (!visited[out]) {
sort_utils(out, visited, stack);
std::vector<mir::Node *> InstructTopologicalOrder();
// The inputs of the graph.
std::vector<mir::Node *> inputs() {
std::vector<mir::Node *> res;
for (auto &node : node_storage_) {
if (node.inlinks.empty()) {
res.push_back(&node);
}
}
return res;
}
std::vector<mir::Node *> TopoloticalOrder() {
std::map<mir::Node *, bool> visited;
std::stack<mir::Node *> stack;
// The outputs of the graph.
std::vector<mir::Node *> outputs() {
std::vector<mir::Node *> res;
for (auto &n : mutable_nodes()) {
if (!visited[&n]) sort_utils(&n, visited, stack);
}
while (!stack.empty()) {
res.push_back(stack.top());
stack.pop();
for (auto &node : node_storage_) {
if (node.outlinks.empty()) {
res.push_back(&node);
}
}
return res;
}
......@@ -111,6 +122,31 @@ class SSAGraph : GraphBase {
const std::list<mir::Node> &nodes() const { return node_storage_; }
std::list<mir::Node> &mutable_nodes() { return node_storage_; }
mir::Node *RetriveArgument(const std::string &arg) {
auto it = arguments_.find(arg);
if (it != arguments_.end()) {
return it->second;
}
return nullptr;
}
private:
// Check the bidirectional connection.
bool CheckBidirectionalConnection();
void MarkArgumentWeights(const Program &program) {
for (const auto &name : program.weights) {
arguments_[name]->AsArgument().is_weight = true;
}
}
// Build operator inlink edge table.
std::map<mir::Node *, std::set<mir::Node *>> BuildOperationAdjList();
void SortHelper(const std::map<mir::Node *, std::set<mir::Node *>> &adj_list,
mir::Node *node, std::set<mir::Node *> *visited,
std::vector<mir::Node *> *ret);
private:
std::list<mir::Node> node_storage_;
std::map<std::string, mir::Node *> arguments_;
......
......@@ -47,6 +47,7 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
// TODO(Superjomn) reconsider this.
instruct.valid_kernels.clear();
instruct.valid_kernels.emplace_back(std::move(scored.front().second));
instruct.place = instruct.valid_kernels.front()->place();
LOG(INFO) << "pick " << instruct.valid_kernels.front()->name();
}
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/variable_place_inference_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void VariablePlaceInferencePass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
MarkInputPlace(graph.get());
InferenceArgumentPlace(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(variable_place_inference_pass,
paddle::lite::mir::VariablePlaceInferencePass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
namespace paddle {
namespace lite {
namespace mir {
/*
* Mark the place of the variables in the SSAGrpah, it will inference the
* variables' place by the kernels outputs them.
*/
class VariablePlaceInferencePass : public DebugPass {
public:
void Apply(std::unique_ptr<mir::SSAGraph>& graph) override;
private:
// Mark the place of input arguments.
void MarkInputPlace(SSAGraph* graph) {
for (const auto& v : graph->inputs()) {
// the feed op might in the inputs
if (v->IsInstruct()) {
LOG(INFO) << "found kernel in inputs " << v->AsInstruct().op_type;
continue;
}
auto& arg = v->AsArgument();
arg.place.target = argument_default_target_;
// the other place description can't be determined yet, until their first
// usage by some kernel.
}
}
void InferenceArgumentPlace(SSAGraph* graph) {
LOG(INFO) << "param-type-registry:\n" << ParamTypeRegistry::Global();
for (auto& x : graph->InstructTopologicalOrder()) {
auto& inst = x->AsInstruct();
CHECK(inst.place.is_valid())
<< "kernel's place should be set when loaded";
// deal with inputs
for (auto& arg_name : inst.op_info->input_argnames()) {
auto type =
ParamTypeRegistry::Global().Retrieve<ParamTypeRegistry::IO::kInput>(
inst.place, inst.op_type, arg_name);
CHECK(type) << "no param-type found for " << inst.op_type << ":"
<< arg_name << " " << inst.place.DebugString();
auto arg_names = inst.op_info->input_argument().at(arg_name);
// check if inputs's place is set, if not set, update them with the
// kernel's declaration.
for (auto& arg_name : arg_names) {
auto* node = graph->RetriveArgument(arg_name);
CHECK(node) << "argument " << arg_name << " not exists in the graph";
auto& arg_node = node->AsArgument();
if (arg_node.place.is_valid()) continue;
UpdatePlace(&arg_node.place, type->tensor_place);
}
}
for (auto& arg_name : inst.op_info->output_argnames()) {
auto type = ParamTypeRegistry::Global()
.Retrieve<ParamTypeRegistry::IO::kOutput>(
inst.place, inst.op_type, arg_name);
CHECK(type) << "no param-type found for " << inst.op_type << ":"
<< arg_name << " " << inst.place.DebugString();
auto arg_names = inst.op_info->output_argument().at(arg_name);
// check if outputs's place is set, if not set, update them with the
// kernel's declaration.
for (auto& arg_name : arg_names) {
auto* node = graph->RetriveArgument(arg_name);
CHECK(node) << "argument " << arg_name << " not exists in the graph";
auto& arg_node = node->AsArgument();
if (arg_node.place.is_valid()) continue;
UpdatePlace(&arg_node.place, type->tensor_place);
}
}
}
}
// Update me's kUnk fields by other's fields.
void UpdatePlace(Place* me, const Place& other) {
CHECK(other.is_valid());
if (me->target == TARGET(kUnk)) {
me->target = other.target;
}
if (me->precision == PRECISION(kUnk)) {
me->precision = other.precision;
}
if (me->layout == DATALAYOUT(kUnk)) {
me->layout = other.layout;
}
}
private:
// The default target for arguments, e.g. load weights to CPU memory for CUDA
// computation by default.
TargetType argument_default_target_{TARGET(kHost)};
};
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -54,5 +54,12 @@ bool OpLite::Run() {
return true;
}
bool OpLite::Attach(const framework::OpDesc &opdesc, lite::Scope *scope) {
CHECK(!op_info_) << "op_info duplicate build found";
op_info_ = std::make_shared<OpInfo>();
op_info_->Build(opdesc);
return AttachImpl(opdesc, scope);
}
} // namespace lite
} // namespace paddle
......@@ -17,6 +17,7 @@
#include <glog/logging.h>
#include <boost/variant.hpp>
#include <map>
#include <memory>
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_desc.h"
......@@ -41,6 +42,8 @@ class Node;
class SSAGraph;
}
class OpInfo;
/**
* The base class of an light-weight operators, currently just used in inference
* to eliminate overhead of some operations in current framework.
......@@ -78,10 +81,10 @@ class OpLite : public Registry {
// Run this operator.
virtual bool Run();
bool Attach(const framework::OpDesc &opdesc, lite::Scope *scope) {
ExtractInputsAndOutputs(opdesc);
return AttachImpl(opdesc, scope);
}
bool Attach(const framework::OpDesc &opdesc, lite::Scope *scope);
const std::shared_ptr<OpInfo> &op_info() const { return op_info_; }
std::shared_ptr<OpInfo> &mutable_op_info() { return op_info_; }
// Human-readable information.
virtual std::string DebugString() const = 0;
......@@ -91,9 +94,6 @@ class OpLite : public Registry {
void PickKernel(const std::vector<Place> &valid_places,
KernelStrategy kernel_strategy = KernelStrategy::kStatic);
const std::list<std::string> &input_names() const { return input_names_; }
const std::list<std::string> &output_names() const { return output_names_; }
virtual ~OpLite() = default;
protected:
......@@ -101,19 +101,6 @@ class OpLite : public Registry {
virtual bool AttachImpl(const framework::OpDesc &opdesc,
lite::Scope *scope) = 0;
void ExtractInputsAndOutputs(const framework::OpDesc &opdesc) {
for (const auto &item : opdesc.Inputs()) {
for (const auto &x : item.second) {
input_names_.push_back(x);
}
}
for (const auto &item : opdesc.Outputs()) {
for (const auto &x : item.second) {
output_names_.push_back(x);
}
}
}
// Specify the kernel to run by default. This will specify the value of
// `kernel_place_`.
virtual void StaticPickKernel(const std::vector<Place> &valid_targets) {
......@@ -141,8 +128,80 @@ class OpLite : public Registry {
std::string op_type_;
std::vector<Place> valid_places_;
Place kernel_place_{TARGET(kHost), PRECISION(kFloat)};
std::shared_ptr<OpInfo> op_info_;
};
/*
* Operator Information, such as some description. It will be shared by all the
* kernels of the same operator.
*/
class OpInfo {
public:
void Build(const framework::OpDesc &desc) {
ExtractInputsAndOutputs(desc);
CollectInputAndOutputArgnames(desc);
CollectArguments(desc);
}
const std::list<std::string> &input_names() const { return input_names_; }
const std::list<std::string> &output_names() const { return output_names_; }
const std::map<std::string, std::list<std::string>> &input_argument() {
return input_argument_;
}
const std::map<std::string, std::list<std::string>> &output_argument() {
return output_argument_;
}
const std::list<std::string> &input_argnames() const {
return input_argnames_;
}
const std::list<std::string> &output_argnames() const {
return output_argnames_;
}
private:
void ExtractInputsAndOutputs(const framework::OpDesc &opdesc) {
for (const auto &item : opdesc.Inputs()) {
for (const auto &x : item.second) {
input_names_.push_back(x);
}
}
for (const auto &item : opdesc.Outputs()) {
for (const auto &x : item.second) {
output_names_.push_back(x);
}
}
}
void CollectInputAndOutputArgnames(const framework::OpDesc &opdesc) {
for (const auto &item : opdesc.InputNames()) {
input_argnames_.push_back(item);
}
for (const auto &item : opdesc.OutputNames()) {
output_argnames_.push_back(item);
}
}
void CollectArguments(const framework::OpDesc &opdesc) {
for (const auto &item : opdesc.Inputs()) {
for (auto &x : item.second) {
input_argument_[item.first].push_back(x);
}
}
for (const auto &item : opdesc.Outputs()) {
for (auto &x : item.second) {
output_argument_[item.first].push_back(x);
}
}
}
private:
std::list<std::string> input_names_;
std::list<std::string> output_names_;
std::list<std::string> input_argnames_;
std::list<std::string> output_argnames_;
std::map<std::string, std::list<std::string>> input_argument_;
std::map<std::string, std::list<std::string>> output_argument_;
};
} // namespace lite
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/lite/core/optimizer.h"
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/mir/generate_program_pass.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
......@@ -36,6 +37,12 @@ TEST(Optimizer, test) {
.ConsiderPrecision();
optimizer.Run(std::move(program), places);
auto* program_pass =
mir::PassManager::Global().LookUp<mir::GenerateProgramPass>(
"generate_program_pass");
auto& kernels = program_pass->kernels();
LOG(INFO) << "get kernels: " << kernels.size();
}
} // namespace lite
......
......@@ -34,6 +34,14 @@ Type::Get<false /*is_unsupported*/, true /*is_tensor*/, TargetType::kX86,
return &x;
}
template <>
const Type*
Type::Get<false /*is_unsupported*/, true /*is_tensor*/, TargetType::kHost,
PrecisionType::kFloat, DataLayoutType::kNCHW>() {
static TensorFp32NCHWTy x(TargetType::kHost);
return &x;
}
template <>
const Type* Type::Get<UnsupportedTy>(TargetType target) {
return Get<false, false, TargetType::kHost, PrecisionType::kFloat,
......@@ -46,6 +54,10 @@ const Type* Type::Get<TensorFp32NCHWTy>(TargetType target) {
case TargetType::kX86:
return Get<false, true, TargetType::kX86, PrecisionType::kFloat,
DataLayoutType::kNCHW>();
case TargetType::kHost:
return Get<false, true, TargetType::kHost, PrecisionType::kFloat,
DataLayoutType::kNCHW>();
default:
LOG(FATAL) << "unsupported target " << TargetToStr(target);
return nullptr;
......
......@@ -52,8 +52,13 @@ PrecisionType FcCompute::precision() const { return PRECISION(kFloat); }
} // namespace paddle
REGISTER_LITE_KERNEL(fc, kHost, kFloat, paddle::lite::kernels::host::FcCompute)
.BindInput(0, {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kX86))})
.BindOutput(0, {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kX86))})
.BindInput("Input",
{paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kHost))})
.BindInput("Bias", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kHost))})
.BindInput("W", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kHost))})
.BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kHost))})
.Finalize();
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册