Commit 12db9f3c authored by superjomn

make the predictor work from a faked model

Parent 610ce3ae
-cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite)
+cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite)
 cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite)
@@ -17,6 +17,7 @@
 #include "paddle/fluid/lite/core/op_lite.h"
 #include "paddle/fluid/lite/core/optimizer.h"
 #include "paddle/fluid/lite/core/program.h"
+#include "paddle/fluid/lite/core/types.h"
 #include "paddle/fluid/lite/model_parser/model_parser.h"

 namespace paddle {
@@ -30,7 +31,6 @@ class Predictor {
   void Build(const std::string& model_path,
              const std::vector<Place>& valid_places) {
-    CHECK(!scope_.get()) << "duplicate build found";
     framework::proto::ProgramDesc prog;
     LoadModel(model_path, scope_.get(), &prog);
     framework::ProgramDesc prog_desc(prog);
@@ -38,10 +38,31 @@ class Predictor {
     Program program(prog_desc, scope_, valid_places);

     Optimizer optimizer;
-    optimizer.Run(std::move(program), valid_places);
+    core::KernelPickFactor factor;
+    factor.ConsiderTarget();
+    optimizer.Run(std::move(program), valid_places, factor);
     program_ = optimizer.GenRuntimeProgram();
   }

+  // Get offset-th col of feed.
+  Tensor* GetInput(size_t offset) {
+    auto* _feed_list = program_->exec_scope()->FindVar("feed");
+    CHECK(_feed_list) << "no feed variable in exec_scope";
+    auto* feed_list = _feed_list->GetMutable<std::vector<Tensor>>();
+    if (offset >= feed_list->size()) {
+      feed_list->resize(offset + 1);
+    }
+    return &feed_list->at(offset);
+  }
+
+  const Tensor* GetOutput(size_t offset) {
+    auto* _fetch_list = program_->exec_scope()->FindVar("fetch");
+    CHECK(_fetch_list) << "no fetch variable in exec_scope";
+    auto fetch_list = _fetch_list->Get<std::vector<Tensor>>();
+    CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow";
+    return &fetch_list.at(offset);
+  }
+
   void Run() { program_->Run(); }

  private:
......
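The new `GetInput`/`GetOutput` accessors expose the `feed`/`fetch` variables as indexed slots, replacing the old name-based `GetInputTensor`. A minimal usage sketch, assuming only the API visible in this diff; the model path is a placeholder:

```cpp
#include <algorithm>

#include "paddle/fluid/lite/api/cxx_api.h"

namespace paddle {
namespace lite {

void RunFakedModel() {
  Predictor predictor;
  predictor.Build("<path/to/model>",  // placeholder model directory
                  {Place{TARGET(kHost), PRECISION(kFloat)}});

  // Slot 0 of the "feed" variable; GetInput() grows the feed list on demand.
  auto* input = predictor.GetInput(0);
  input->Resize({100, 100});
  auto* data = input->mutable_data<float>();
  std::fill(data, data + 100 * 100, 1.f);  // placeholder input values

  predictor.Run();

  // Slot 0 of the "fetch" variable, filled by the fetch kernel during Run().
  const auto* output = predictor.GetOutput(0);
  (void)output;
}

}  // namespace lite
}  // namespace paddle
```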
@@ -14,36 +14,22 @@
 #include "paddle/fluid/lite/api/cxx_api.h"
 #include <gtest/gtest.h>
+#include "paddle/fluid/lite/core/mir/passes.h"
 #include "paddle/fluid/lite/core/op_executor.h"
 #include "paddle/fluid/lite/core/op_registry.h"

 namespace paddle {
 namespace lite {

-TEST(CXXApi, raw) {
-  Scope scope;
-  framework::proto::ProgramDesc prog;
-  LoadModel("/home/chunwei/project2/models/model2", &scope, &prog);
-  framework::ProgramDesc prog_desc(prog);
-
-  lite::Executor executor(&scope, {Place{TARGET(kHost), PRECISION(kFloat)}});
-
-  auto x = scope.Var("a")->GetMutable<Tensor>();
-  x->Resize({100, 100});
-  x->mutable_data<float>();
-
-  executor.PrepareWorkspace(prog_desc);
-  executor.Build(prog_desc);
-  executor.Run();
-}
-
 TEST(CXXApi, test) {
   lite::Predictor predictor;
   predictor.Build("/home/chunwei/project2/models/model2",
                   {Place{TARGET(kHost), PRECISION(kFloat)}});
-  auto* x = predictor.GetInputTensor("a");
-  x->Resize({100, 200});
-  x->mutable_data<float>();
+
+  auto* input_tensor = predictor.GetInput(0);
+  input_tensor->Resize({100, 100});
+  input_tensor->mutable_data<float>();
+
+  predictor.Run();
 }

 }  // namespace lite
@@ -52,6 +38,10 @@ TEST(CXXApi, test) {
 USE_LITE_OP(mul);
 USE_LITE_OP(fc);
 USE_LITE_OP(scale);
-USE_LITE_KERNEL(fc, kHost, kFloat);
-USE_LITE_KERNEL(mul, kHost, kFloat);
-USE_LITE_KERNEL(scale, kHost, kFloat);
+USE_LITE_OP(feed);
+USE_LITE_OP(fetch);
+USE_LITE_KERNEL(fc, kHost, kFloat, def);
+USE_LITE_KERNEL(mul, kHost, kFloat, def);
+USE_LITE_KERNEL(scale, kHost, kFloat, def);
+USE_LITE_KERNEL(feed, kHost, kFloat, def);
+USE_LITE_KERNEL(fetch, kHost, kFloat, def);
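`USE_LITE_KERNEL` now takes a fourth argument, the kernel alias (`def`), matching the last argument of `REGISTER_LITE_KERNEL`. Below is a self-contained sketch of the general register/use macro pattern this follows — illustrative names, not the actual Paddle Lite macros — whose point is to reference a symbol from the kernel's translation unit so the linker keeps its static registration:

```cpp
#include <iostream>

// Defined in the kernel's translation unit.
#define REGISTER_FAKE_KERNEL(op, alias)                \
  int touch_##op##_##alias() {                         \
    std::cout << "kernel " #op "/" #alias " linked\n"; \
    return 0;                                          \
  }

// Placed in the binary that needs the kernel; the extern reference forces
// the kernel's object file to be pulled in at link time.
#define USE_FAKE_KERNEL(op, alias)    \
  extern int touch_##op##_##alias();  \
  static int use_##op##_##alias = touch_##op##_##alias();

REGISTER_FAKE_KERNEL(feed, def)
USE_FAKE_KERNEL(feed, def)

int main() { return 0; }
```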
 cc_library(mir_node SRCS node.cc)
 cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node)
 cc_library(mir_pass SRCS pass.cc DEPS mir_ssa_graph)
-cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph)
+cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir_passes)
 cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
 cc_library(mir_passes
   SRCS static_kernel_pick_pass.cc
......
@@ -36,8 +36,8 @@ void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
           inst.place, inst.op_type, tmp);
       CHECK(type) << "no param type found for " << inst.op_type << ":" << name
                   << " " << inst.place;
-      if (type->tensor_place != inst.place) {
-        LOG(INFO) << "found IO unmatched tensor";
+      if (type->tensor_place != in->AsArgument().place) {
+        LOG(INFO) << "found IO unmatched tensor: " << in->AsArgument().name;
       }
     }
   }
......
@@ -36,7 +36,7 @@ class SSAGraph : GraphBase {
   // @param program: the op program
   // @param valid_places: the valid places user set for the system.
   void Build(const Program &program, const std::vector<Place> &valid_places) {
-    // create inputs
+    // create temporary nodes.
     for (const auto &name : program.tmp_vars) {
       node_storage_.emplace_back();
       auto &new_node = node_storage_.back();
@@ -45,20 +45,33 @@ class SSAGraph : GraphBase {
       arguments_[name] = &new_node;
     }

+    // create weight nodes.
+    for (const auto &name : program.weights) {
+      node_storage_.emplace_back();
+      auto &new_node = node_storage_.back();
+      auto &arg = new_node.AsArgument();
+      arg.name = name;
+      arguments_[name] = &new_node;
+    }
+
     for (auto &op : program.ops) {
       node_storage_.emplace_back();
       // TODO(Superjomn) remove one valid_places here.
       op->SetValidPlaces(valid_places);
       auto &new_node = node_storage_.back();
-      node_storage_.back().AsInstruct(
-          op->op_type_, op->CreateKernels(valid_places), op, op->op_info());
+      auto kernels = op->CreateKernels(valid_places);
+      for (auto &kernel : kernels) {
+        op->AttachKernel(kernel.get());
+      }
+      node_storage_.back().AsInstruct(op->op_type_, std::move(kernels), op,
+                                      op->op_info());

       CHECK(new_node.inlinks.empty()) << "duplicate Build found";
       CHECK(new_node.outlinks.empty()) << "duplicate Build found";

       // collect inputs and outputs
       for (const std::string &name : op->op_info()->input_names()) {
-        auto *arg = arguments_.at(name);
+        auto *arg = Argument(name);
         new_node.inlinks.push_back(arg);
         arg->outlinks.push_back(&new_node);
       }
@@ -79,6 +92,12 @@ class SSAGraph : GraphBase {
     MarkArgumentWeights(program);
   }

+  mir::Node *Argument(const std::string &name) {
+    auto it = arguments_.find(name);
+    CHECK(it != arguments_.end()) << "no argument called " << name;
+    return it->second;
+  }
+
   std::vector<mir::Node *> InstructTopologicalOrder();

   // The inputs of the graph.
......
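`SSAGraph::Build` now creates argument nodes for weights as well as temporaries, attaches each created kernel to its op, and resolves inputs through the checked `Argument` lookup. A toy, self-contained sketch of the bipartite structure being built — illustrative types, not the real `mir::Node`:

```cpp
#include <list>
#include <map>
#include <string>
#include <vector>

struct Node {
  enum class Kind { kArg, kInst } kind;
  std::string name;  // variable name for kArg, op type for kInst
  std::vector<Node*> inlinks, outlinks;
};

struct Graph {
  std::list<Node> storage;  // std::list keeps node pointers stable
  std::map<std::string, Node*> arguments;

  // Mirrors SSAGraph::Argument's checked name-to-node lookup.
  Node* Argument(const std::string& name) { return arguments.at(name); }

  void AddArg(const std::string& name) {
    storage.push_back({Node::Kind::kArg, name, {}, {}});
    arguments[name] = &storage.back();
  }

  void AddInst(const std::string& op, const std::vector<std::string>& ins,
               const std::vector<std::string>& outs) {
    storage.push_back({Node::Kind::kInst, op, {}, {}});
    Node* inst = &storage.back();
    for (auto& in : ins) {  // wire argument -> instruction edges
      Node* arg = Argument(in);
      inst->inlinks.push_back(arg);
      arg->outlinks.push_back(inst);
    }
    for (auto& out : outs) {  // wire instruction -> argument edges
      Node* arg = Argument(out);
      inst->outlinks.push_back(arg);
      arg->inlinks.push_back(inst);
    }
  }
};

int main() {
  Graph g;
  for (const char* v : {"x", "w", "bias", "out"}) g.AddArg(v);
  g.AddInst("fc", {"x", "w", "bias"}, {"out"});
  return 0;
}
```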
@@ -20,12 +20,9 @@ namespace paddle {
 namespace lite {

 TEST(executor, test) {
-  std::vector<OpLite::Place> valid_places{
-      OpLite::Place{TARGET(kHost), PRECISION(kFloat)}};
+  std::vector<Place> valid_places{Place{TARGET(kHost), PRECISION(kFloat)}};

-  Scope scope;
-  Executor executor(&scope, valid_places);
+  auto scope = std::make_shared<lite::Scope>();

   framework::ProgramDesc program;
   program.MutableBlock(0)->Var("x");
@@ -42,19 +39,18 @@ TEST(executor, test) {
   op_desc.SetAttr("in_num_col_dims", static_cast<int>(1));
   program.Flush();

-  auto* w = scope.Var("w")->GetMutable<Tensor>();
+  auto* w = scope->Var("w")->GetMutable<Tensor>();
   w->Resize({20, 20});
-  auto* x = scope.Var("x")->GetMutable<Tensor>();
+  auto* x = scope->Var("x")->GetMutable<Tensor>();
   x->Resize({1, 10, 20});
-  auto* bias = scope.Var("bias")->GetMutable<Tensor>();
+  auto* bias = scope->Var("bias")->GetMutable<Tensor>();
   bias->Resize({1, 20});

   bias->mutable_data<float>();
   w->mutable_data<float>();
   x->mutable_data<float>();

-  executor.PrepareWorkspace(program);
-  executor.Build(program);
+  lite::Executor executor(program, scope, valid_places);
   executor.Run();
 }

@@ -62,4 +58,4 @@ TEST(executor, test) {
 }  // namespace paddle

 USE_LITE_OP(fc);
-USE_LITE_KERNEL(fc, kHost, kFloat);
+USE_LITE_KERNEL(fc, kHost, kFloat, def);
@@ -101,6 +101,9 @@ class OpLite : public Registry {
   virtual bool AttachImpl(const framework::OpDesc &opdesc,
                           lite::Scope *scope) = 0;

+  // Assign op param to kernel.
+  virtual void AttachKernel(KernelBase *kernel) = 0;
+
   // Specify the kernel to run by default. This will specify the value of
   // `kernel_place_`.
   virtual void StaticPickKernel(const std::vector<Place> &valid_targets) {
......
@@ -11,3 +11,19 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+
+#include "paddle/fluid/lite/core/optimizer.h"
+#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
+
+namespace paddle {
+namespace lite {
+
+void Optimizer::SpecifyKernelPickTactic(core::KernelPickFactor factor) {
+  auto* pass = mir::PassManager::Global().LookUp<mir::StaticKernelPickPass>(
+      "static_kernel_pick_pass");
+  CHECK(pass);
+  *pass->mutable_kernel_pick_factors() = factor;
+}
+
+}  // namespace lite
+}  // namespace paddle
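`SpecifyKernelPickTactic` looks up a registered pass by name and mutates its configuration in place. A self-contained sketch of that look-up-then-configure pattern, with illustrative stand-ins for `PassManager` and `StaticKernelPickPass`:

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>

struct Pass {
  virtual ~Pass() = default;
};

struct StaticKernelPickPass : public Pass {
  int pick_factors = 0;
  int* mutable_kernel_pick_factors() { return &pick_factors; }
};

class PassManager {
 public:
  static PassManager& Global() {
    static PassManager x;  // process-wide singleton, as in mir::PassManager
    return x;
  }

  void Register(const std::string& name, std::unique_ptr<Pass> pass) {
    passes_[name] = std::move(pass);
  }

  // LookUp returns the pass downcast to the type the caller expects.
  template <typename T>
  T* LookUp(const std::string& name) {
    auto it = passes_.find(name);
    return it == passes_.end() ? nullptr : dynamic_cast<T*>(it->second.get());
  }

 private:
  std::map<std::string, std::unique_ptr<Pass>> passes_;
};

int main() {
  PassManager::Global().Register("static_kernel_pick_pass",
                                 std::make_unique<StaticKernelPickPass>());
  auto* pass = PassManager::Global().LookUp<StaticKernelPickPass>(
      "static_kernel_pick_pass");
  assert(pass);
  *pass->mutable_kernel_pick_factors() = 1;  // the equivalent of the factor copy
  return 0;
}
```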
@@ -19,6 +19,7 @@
 #include "paddle/fluid/lite/core/mir/pass_manager.h"
 #include "paddle/fluid/lite/core/mir/ssa_graph.h"
 #include "paddle/fluid/lite/core/program.h"
+#include "paddle/fluid/lite/core/types.h"

 namespace paddle {
 namespace lite {
@@ -30,11 +31,14 @@ namespace lite {
 class Optimizer {
  public:
   void Run(Program&& program, const std::vector<Place>& valid_places,
+           core::KernelPickFactor kernel_pick_factor,
            const std::vector<std::string>& passes = {}) {
     CHECK(!graph_) << "duplicate optimize found";
     graph_.reset(new mir::SSAGraph);
     graph_->Build(program, valid_places);
+    SpecifyKernelPickTactic(kernel_pick_factor);
     RunPasses();
+    exec_scope_ = program.exec_scope;
   }

   // Generate a new program based on the mir graph.
@@ -42,7 +46,10 @@ class Optimizer {
     std::unique_ptr<Program> res;
     auto pass = mir::PassManager::Global().LookUp<mir::GenerateProgramPass>(
         "generate_program_pass");
-    return pass->GenProgram();
+    auto program = pass->GenProgram();
+    CHECK(exec_scope_);
+    program->set_exec_scope(exec_scope_);
+    return program;
   }

   // Generate C++ code which combines the inference program, model and weights.
@@ -54,6 +61,8 @@ class Optimizer {
   }

  protected:
+  void SpecifyKernelPickTactic(core::KernelPickFactor factor);
+
   // Run the default passes registered in the PassManager.
   void RunPasses() { mir::PassManager::Global().Run(graph_); }

@@ -62,6 +71,7 @@ class Optimizer {

  private:
   std::unique_ptr<mir::SSAGraph> graph_;
+  lite::Scope* exec_scope_{};
 };

 }  // namespace lite
......
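`Optimizer::Run` now threads a `core::KernelPickFactor` through to the static kernel pick pass. A hedged sketch of what such a factor type can look like — a small bit set of criteria with chainable `Consider*` setters; the real `core::KernelPickFactor` may differ:

```cpp
#include <cstdint>

class PickFactor {
 public:
  // Each bit marks one criterion the kernel picker should weigh.
  enum class Factor : uint8_t { kTarget = 1, kPrecision = 2, kLayout = 4 };

  PickFactor& ConsiderTarget() {
    bits_ |= uint8_t(Factor::kTarget);
    return *this;  // returning *this makes the setters chainable
  }
  PickFactor& ConsiderPrecision() {
    bits_ |= uint8_t(Factor::kPrecision);
    return *this;
  }

  bool IsTargetConsidered() const { return bits_ & uint8_t(Factor::kTarget); }

 private:
  uint8_t bits_{};
};

int main() {
  PickFactor factor;
  factor.ConsiderTarget();  // as in Predictor::Build above
  return factor.IsTargetConsidered() ? 0 : 1;
}
```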
@@ -35,23 +35,22 @@ struct Program {
   std::list<std::shared_ptr<OpLite>> ops;
   // the scope to run the kernels, NOTE not the root scope.
   std::shared_ptr<lite::Scope> scope;
+  std::vector<Place> valid_places;
   // Runtime scope.
   lite::Scope* exec_scope{};
+  const framework::ProgramDesc desc;

   explicit Program(const std::shared_ptr<Scope>& root) { scope = root; }
   Program(const framework::ProgramDesc& desc,
           const std::shared_ptr<Scope>& root,
-          const std::vector<Place>& valid_places) {
-    scope = root;
+          const std::vector<Place>& valid_places)
+      : scope(root), valid_places(valid_places), desc(desc) {
     PrepareWorkspace(desc);
     Build(desc, valid_places);
   }

   std::unique_ptr<Program> Clone() const {
-    std::unique_ptr<Program> res(new Program(scope));
-    res->tmp_vars = tmp_vars;
-    res->weights = weights;
-    res->ops = ops;
+    std::unique_ptr<Program> res(new Program(desc, scope, valid_places));
     return res;
   }
@@ -64,7 +63,7 @@ struct Program {
     // Create operators.
     for (auto* op_desc : program.Block(0).AllOps()) {
       auto op_type = op_desc->Type();
-      if (op_type == "feed" || op_type == "fetch") continue;
+      // if (op_type == "feed" || op_type == "fetch") continue;
       LOG(INFO) << "create Op [" << op_type << "]";
       ops.emplace_back(LiteOpRegistry::Global().Create(op_type));
       // pick initial kernel
@@ -77,11 +76,22 @@ struct Program {
   void PrepareWorkspace(const framework::ProgramDesc& program) {
     CHECK(!exec_scope) << "Duplicate PrepareWorkspace found";
     exec_scope = &scope->NewScope();
+    // Create Feed and Fetch var.
+    scope->Var("feed")->GetMutable<std::vector<Tensor>>();
+    scope->Var("fetch")->GetMutable<std::vector<Tensor>>();
+    tmp_vars.push_back("feed");
+    tmp_vars.push_back("fetch");

     for (auto var_desc : program.Block(0).AllVars()) {
       if (!var_desc->Persistable()) {
+        LOG(INFO) << "get tmp var " << var_desc->Name();
+        tmp_vars.push_back(var_desc->Name());
         auto* var = exec_scope->Var(var_desc->Name());
         LOG(INFO) << "create tmp var " << var_desc->Name() << " " << var;
+      } else {
+        if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") continue;
+        LOG(INFO) << "get weight var " << var_desc->Name();
+        weights.push_back(var_desc->Name());
       }
     }
   }
@@ -118,11 +128,15 @@ class RuntimeProgram {
     }
   }

+  void set_exec_scope(lite::Scope* x) { exec_scope_ = x; }
+  lite::Scope* exec_scope() { return exec_scope_; }
+
   size_t num_instructions() const { return instructions_.size(); }

  private:
   RuntimeProgram(const RuntimeProgram&) = delete;
   std::vector<Instruction> instructions_;
+  lite::Scope* exec_scope_{};
 };

 }  // namespace lite
......
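`PrepareWorkspace` now pre-creates the `feed`/`fetch` variables as `std::vector<Tensor>` slots and splits the remaining variables into temporaries and weights. A toy sketch of that classification, with stand-in types instead of the real `Scope`/`VarDesc`:

```cpp
#include <string>
#include <vector>

struct VarDesc {
  std::string name;
  bool persistable;
};

struct Workspace {
  std::vector<std::string> tmp_vars, weights;

  void Prepare(const std::vector<VarDesc>& vars) {
    // "feed" and "fetch" are always present and tracked as temporaries.
    tmp_vars.push_back("feed");
    tmp_vars.push_back("fetch");
    for (const auto& v : vars) {
      if (!v.persistable) {
        tmp_vars.push_back(v.name);  // created fresh in the exec scope
      } else {
        if (v.name == "feed" || v.name == "fetch") continue;
        weights.push_back(v.name);   // later marked as weight arguments
      }
    }
  }
};

int main() {
  Workspace ws;
  ws.Prepare({{"x", false}, {"w", true}, {"feed", true}, {"fetch", true}});
  // ws.tmp_vars == {"feed", "fetch", "x"}, ws.weights == {"w"}
  return 0;
}
```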
@@ -48,6 +48,45 @@ const Type* Type::Get<UnsupportedTy>(TargetType target) {
                              DataLayoutType::kNCHW>();
 }

+template <TargetType Target>
+TensorListAnyTy* GetTensorListAnyTy() {
+  static TensorListAnyTy x(Target);
+  return &x;
+}
+
+template <TargetType Target>
+TensorAnyTy* GetTensorAnyTy() {
+  static TensorAnyTy x(Target);
+  return &x;
+}
+
+template <>
+const Type* Type::Get<TensorListAnyTy>(TargetType target) {
+  switch (target) {
+    case TargetType::kHost:
+      return GetTensorListAnyTy<TARGET(kHost)>();
+    case TargetType::kCUDA:
+      return GetTensorListAnyTy<TARGET(kCUDA)>();
+    case TargetType::kX86:
+      return GetTensorListAnyTy<TARGET(kX86)>();
+    default:
+      LOG(FATAL) << "unsupported type";
+  }
+}
+
+template <>
+const Type* Type::Get<TensorAnyTy>(TargetType target) {
+  switch (target) {
+    case TargetType::kHost:
+      return GetTensorAnyTy<TARGET(kHost)>();
+    case TargetType::kCUDA:
+      return GetTensorAnyTy<TARGET(kCUDA)>();
+    case TargetType::kX86:
+      return GetTensorAnyTy<TARGET(kX86)>();
+    default:
+      LOG(FATAL) << "unsupported type";
+  }
+}
+
 template <>
 const Type* Type::Get<TensorFp32NCHWTy>(TargetType target) {
   switch (target) {
......
@@ -60,6 +60,8 @@ class DataTypeBase {
     // Tensor_Any represents a Tensor with any place, data, layout. It is used
     // in some IO kernels those doesn't care the data.
     Tensor_Any,
+    // Used by feed or fetch op.
+    TensorList_Any,
     NumTypes,  // Must remains as last defined ID.
   };
@@ -146,6 +148,13 @@ class TensorAnyTy : public Type {
       : Type(ID::Tensor_Any, "TensorAny", true, target, PRECISION(kAny),
             DATALAYOUT(kAny)) {}
 };
+// A list of tensor, and no assumption on the data layout or data type.
+class TensorListAnyTy : public Type {
+ public:
+  TensorListAnyTy(TargetType target)
+      : Type(ID::TensorList_Any, "TensorList_Any", false, target,
+             PRECISION(kAny), DATALAYOUT(kAny)) {}
+};
 class TensorFp32NCHWTy : public Type {
  public:
   TensorFp32NCHWTy(TargetType target)
......
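`Type::Get` hands out one function-local static instance per (type, target) pair, so pointer equality doubles as type identity. A self-contained sketch of the pattern:

```cpp
#include <cassert>

enum class TargetType { kHost, kX86, kCUDA };

struct TensorListAny {
  explicit TensorListAny(TargetType t) : target(t) {}
  TargetType target;
};

template <TargetType Target>
const TensorListAny* GetTensorListAny() {
  static TensorListAny x(Target);  // one static instance per target
  return &x;
}

// Dispatch from the runtime enum to the compile-time singleton.
const TensorListAny* Get(TargetType target) {
  switch (target) {
    case TargetType::kHost: return GetTensorListAny<TargetType::kHost>();
    case TargetType::kX86:  return GetTensorListAny<TargetType::kX86>();
    case TargetType::kCUDA: return GetTensorListAny<TargetType::kCUDA>();
  }
  return nullptr;
}

int main() {
  assert(Get(TargetType::kHost) == Get(TargetType::kHost));  // same singleton
  assert(Get(TargetType::kHost) != Get(TargetType::kX86));
  return 0;
}
```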
@@ -3,13 +3,15 @@ cc_library(relu_compute_host SRCS relu_compute.cc DEPS ${lite_kernel_deps})
 cc_library(mul_compute_host SRCS mul_compute.cc DEPS ${lite_kernel_deps})
 cc_library(scale_compute_host SRCS scale_compute.cc DEPS ${lite_kernel_deps})
 cc_library(feed_compute_host SRCS feed_compute.cc DEPS ${lite_kernel_deps})
+cc_library(fetch_compute_host SRCS fetch_compute.cc DEPS ${lite_kernel_deps})

 cc_library(host_kernels DEPS
+    feed_compute_host
+    fetch_compute_host
     fc_compute_host
     relu_compute_host
     mul_compute_host
     scale_compute_host
-    feed_compute_host
     DEPS ${lite_kernel_deps}
 )
......
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include <Eigen/Core>
 #include "paddle/fluid/lite/core/op_registry.h"
 #include "paddle/fluid/lite/core/type_system.h"
@@ -26,9 +25,9 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
   using param_t = operators::FeedParam;

   void Run() override {
-    auto &theparam = Param<operators::FeedParam>();
-    const Tensor &feed_item = theparam.feed_list->at(theparam.col);
-    theparam.out->CopyDataFrom(feed_item);
+    auto &param = Param<operators::FeedParam>();
+    const Tensor &feed_item = param.feed_list->at(param.col);
+    param.out->CopyDataFrom(feed_item);
   }
 };
@@ -39,7 +38,7 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
 REGISTER_LITE_KERNEL(feed, kHost, kFloat,
                      paddle::lite::kernels::host::FeedCompute, def)
-    .BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
+    .BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>(
                         TARGET(kHost))})
     .BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
                            TARGET(kHost))})
......
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/core/op_registry.h"
+#include "paddle/fluid/lite/core/type_system.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace host {
+
+class FetchCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::FetchParam;
+
+  void Run() override {
+    auto& param = Param<operators::FetchParam>();
+    auto* fetch_list = param.fetch_list;
+    if (fetch_list->size() <= static_cast<size_t>(param.col)) {
+      fetch_list->resize(param.col + 1);
+    }
+
+    auto& dst = fetch_list->at(param.col);
+    dst.CopyDataFrom(*param.input);
+  }
+};
+
+}  // namespace host
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(fetch, kHost, kFloat,
+                     paddle::lite::kernels::host::FetchCompute, def)
+    .BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>(
+                        TARGET(kHost))})
+    .BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorListAnyTy>(
+                           TARGET(kHost))})
+    .Finalize();
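Together the two kernels implement the round trip: `feed` copies slot `col` of the feed list into its output, and `fetch` grows the fetch list on demand before copying its input into slot `col`. A toy, self-contained sketch of that data flow, with `std::vector<float>` standing in for `lite::Tensor`:

```cpp
#include <cassert>
#include <vector>

using Tensor = std::vector<float>;  // stand-in for lite::Tensor

void FeedRun(const std::vector<Tensor>& feed_list, int col, Tensor* out) {
  *out = feed_list.at(col);  // the CopyDataFrom equivalent
}

void FetchRun(const Tensor& input, int col, std::vector<Tensor>* fetch_list) {
  if (fetch_list->size() <= static_cast<size_t>(col)) {
    fetch_list->resize(col + 1);  // grow the fetch list on demand
  }
  fetch_list->at(col) = input;
}

int main() {
  std::vector<Tensor> feed_list{{1, 2, 3}};
  Tensor x;
  FeedRun(feed_list, 0, &x);  // feed slot 0 into the graph input

  std::vector<Tensor> fetch_list;
  FetchRun(x, 0, &fetch_list);  // fetch the result into slot 0
  assert(fetch_list.at(0) == feed_list.at(0));
  return 0;
}
```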
@@ -52,4 +52,8 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
 REGISTER_LITE_KERNEL(scale, kHost, kFloat,
                      paddle::lite::kernels::host::ScaleCompute, def)
+    .BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
+                        TARGET(kHost))})
+    .BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
+                           TARGET(kHost))})
     .Finalize();
@@ -3,6 +3,7 @@ cc_library(relu_op_lite SRCS relu_op.cc DEPS op_lite)
 cc_library(mul_op_lite SRCS mul_op.cc DEPS op_lite)
 cc_library(scale_op_lite SRCS scale_op.cc DEPS op_lite)
 cc_library(feed_op_lite SRCS feed_op.cc DEPS op_lite)
+cc_library(fetch_op_lite SRCS fetch_op.cc DEPS op_lite)
 cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS op_lite)

 cc_library(op_params_lite SRCS op_params.cc DEPS tensor_lite)
@@ -12,6 +13,7 @@ cc_library(ops_lite DEPS
     mul_op_lite
     scale_op_lite
     feed_op_lite
+    fetch_op_lite
     io_copy_op_lite
 )
......
@@ -67,6 +67,8 @@ class FcOpLite : public OpLite {
     return true;
   }

+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+
   std::string DebugString() const override { return "fc"; }

  private:
......
@@ -31,6 +31,9 @@ class FeedOp : public OpLite {

   bool InferShape() const override { return true; }

+ protected:
+  void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
+
  protected:
   bool AttachImpl(const framework::OpDesc& opdesc,
                   lite::Scope* scope) override {
@@ -48,6 +51,7 @@ class FeedOp : public OpLite {
     // NOTE need boost here
     // TODO(Superjomn) drop the need of framework::op_desc
     param_.col = boost::get<int>(opdesc.GetAttr("col"));
+    kernel_->SetParam(param_);
     return true;
   }
......
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/core/op_lite.h"
+#include "paddle/fluid/lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class FetchOp : public OpLite {
+ public:
+  explicit FetchOp(const std::string& type) : OpLite(type) {}
+
+  bool CheckShape() const override {
+    CHECK_OR_FALSE(param_.input);
+    CHECK_OR_FALSE(param_.fetch_list);
+    return true;
+  }
+
+  bool InferShape() const override { return true; }
+
+  void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
+
+ protected:
+  bool AttachImpl(const framework::OpDesc& opdesc,
+                  lite::Scope* scope) override {
+    auto _x = opdesc.Input("X").front();
+    auto* x = scope->FindVar(_x);
+    CHECK(x);
+    param_.input = &x->Get<Tensor>();
+
+    auto _out = opdesc.Output("Out").front();
+    auto* out = scope->FindVar(_out);
+    param_.fetch_list = out->GetMutable<std::vector<lite::Tensor>>();
+
+    param_.col = boost::get<int>(opdesc.GetAttr("col"));
+    return true;
+  }
+
+  std::string DebugString() const override { return "fetch"; }
+
+ private:
+  mutable FetchParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(fetch, paddle::lite::operators::FetchOp);
@@ -28,6 +28,8 @@ class IoCopyOp : public OpLite {
   bool Run() override;
   std::string DebugString() const override;

+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+
  protected:
   bool AttachImpl(const framework::OpDesc &opdesc, lite::Scope *scope) override;
......
@@ -36,6 +36,7 @@ class MulOpLite : public OpLite {
   bool InferShape() const override;

+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
   // TODO(Superjomn) replace framework::OpDesc with a lite one.
   bool AttachImpl(const framework::OpDesc &op_desc,
                   lite::Scope *scope) override {
......
@@ -25,8 +25,14 @@ namespace lite {
 namespace operators {

 struct FeedParam {
-  const std::vector<Tensor>* feed_list;
-  Tensor* out;
+  const std::vector<Tensor>* feed_list{};
+  Tensor* out{};
+  int col;
+};
+
+struct FetchParam {
+  const Tensor* input{};
+  std::vector<Tensor>* fetch_list{};
   int col;
 };
@@ -69,8 +75,8 @@ struct IoCopyParam {
   Tensor* y{};
 };

-using param_t =
-    variant<FeedParam, FcParam, ReluParam, MulParam, ScaleParam, IoCopyParam>;
+using param_t = variant<FeedParam, FetchParam, FcParam, ReluParam, MulParam,
+                        ScaleParam, IoCopyParam>;

 }  // namespace operators
 }  // namespace lite
......
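Every op fills exactly one of these parameter structs and hands it to its kernel via `SetParam`; `param_t` enumerates the closed set of alternatives. A sketch using `std::variant` in place of the in-tree `variant` type:

```cpp
#include <variant>
#include <vector>

struct Tensor;  // opaque for the sketch

struct FeedParam {
  const std::vector<Tensor>* feed_list{};
  Tensor* out{};
  int col{};
};

struct FetchParam {
  const Tensor* input{};
  std::vector<Tensor>* fetch_list{};
  int col{};
};

// One closed set of parameter types shared by all kernels.
using param_t = std::variant<FeedParam, FetchParam>;

int main() {
  param_t p = FetchParam{};  // each kernel instance holds one alternative
  auto& fetch = std::get<FetchParam>(p);  // kernel recovers its concrete type
  (void)fetch;
  return 0;
}
```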
@@ -34,6 +34,7 @@ class ReluOp : public OpLite {
   bool AttachImpl(const framework::OpDesc &opdesc, lite::Scope *scope) override;

+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
   std::string DebugString() const override { return "tanh"; }

  private:
......
@@ -43,6 +43,8 @@ class ScaleOp : public OpLite {
     return true;
   }

+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+
   // TODO(Superjomn) replace framework::OpDesc with a lite one.
   bool AttachImpl(const framework::OpDesc &op_desc,
                   lite::Scope *scope) override {
......