Commit 12db9f3c authored by: S superjomn

make the predictor work from a faked model

Parent 610ce3ae
cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite)
cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite)
cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite)
......@@ -17,6 +17,7 @@
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/optimizer.h"
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/model_parser/model_parser.h"
namespace paddle {
......@@ -30,7 +31,6 @@ class Predictor {
void Build(const std::string& model_path,
const std::vector<Place>& valid_places) {
CHECK(!scope_.get()) << "duplicate build found";
framework::proto::ProgramDesc prog;
LoadModel(model_path, scope_.get(), &prog);
framework::ProgramDesc prog_desc(prog);
......@@ -38,10 +38,31 @@ class Predictor {
Program program(prog_desc, scope_, valid_places);
Optimizer optimizer;
optimizer.Run(std::move(program), valid_places);
core::KernelPickFactor factor;
factor.ConsiderTarget();
optimizer.Run(std::move(program), valid_places, factor);
program_ = optimizer.GenRuntimeProgram();
}
// Get the offset-th column of the feed list.
Tensor* GetInput(size_t offset) {
auto* _feed_list = program_->exec_scope()->FindVar("feed");
CHECK(_feed_list) << "no feed variable in exec_scope";
auto* feed_list = _feed_list->GetMutable<std::vector<Tensor>>();
if (offset >= feed_list->size()) {
feed_list->resize(offset + 1);
}
return &feed_list->at(offset);
}
const Tensor* GetOutput(size_t offset) {
auto* _fetch_list = program_->exec_scope()->FindVar("fetch");
CHECK(_fetch_list) << "no fatch variable in exec_scope";
const auto& fetch_list = _fetch_list->Get<std::vector<Tensor>>();
CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow";
return &fetch_list.at(offset);
}
void Run() { program_->Run(); }
private:
......
......@@ -14,36 +14,22 @@
#include "paddle/fluid/lite/api/cxx_api.h"
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_executor.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
TEST(CXXApi, raw) {
Scope scope;
framework::proto::ProgramDesc prog;
LoadModel("/home/chunwei/project2/models/model2", &scope, &prog);
framework::ProgramDesc prog_desc(prog);
lite::Executor executor(&scope, {Place{TARGET(kHost), PRECISION(kFloat)}});
auto x = scope.Var("a")->GetMutable<Tensor>();
x->Resize({100, 100});
x->mutable_data<float>();
executor.PrepareWorkspace(prog_desc);
executor.Build(prog_desc);
executor.Run();
}
TEST(CXXApi, test) {
lite::Predictor predictor;
predictor.Build("/home/chunwei/project2/models/model2",
{Place{TARGET(kHost), PRECISION(kFloat)}});
auto* x = predictor.GetInputTensor("a");
x->Resize({100, 200});
x->mutable_data<float>();
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize({100, 100});
input_tensor->mutable_data<float>();
predictor.Run();
}
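// A minimal follow-up sketch, not part of this commit: it assumes the result
// can be read back through the new GetOutput API declared in cxx_api.h above,
// and that output slot 0 is populated by the fetch op after Run(). The test
// name is illustrative only.
TEST(CXXApi, fetch_output_sketch) {
  lite::Predictor predictor;
  predictor.Build("/home/chunwei/project2/models/model2",
                  {Place{TARGET(kHost), PRECISION(kFloat)}});
  auto* input_tensor = predictor.GetInput(0);
  input_tensor->Resize({100, 100});
  input_tensor->mutable_data<float>();
  predictor.Run();
  // The fetch kernel should have copied its input into slot 0 of "fetch".
  const Tensor* output_tensor = predictor.GetOutput(0);
  CHECK(output_tensor);
}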
} // namespace lite
......@@ -52,6 +38,10 @@ TEST(CXXApi, test) {
USE_LITE_OP(mul);
USE_LITE_OP(fc);
USE_LITE_OP(scale);
USE_LITE_KERNEL(fc, kHost, kFloat);
USE_LITE_KERNEL(mul, kHost, kFloat);
USE_LITE_KERNEL(scale, kHost, kFloat);
USE_LITE_OP(feed);
USE_LITE_OP(fetch);
USE_LITE_KERNEL(fc, kHost, kFloat, def);
USE_LITE_KERNEL(mul, kHost, kFloat, def);
USE_LITE_KERNEL(scale, kHost, kFloat, def);
USE_LITE_KERNEL(feed, kHost, kFloat, def);
USE_LITE_KERNEL(fetch, kHost, kFloat, def);
cc_library(mir_node SRCS node.cc)
cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node)
cc_library(mir_pass SRCS pass.cc DEPS mir_ssa_graph)
cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph)
cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir_passes)
cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
cc_library(mir_passes
SRCS static_kernel_pick_pass.cc
......
......@@ -36,8 +36,8 @@ void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
inst.place, inst.op_type, tmp);
CHECK(type) << "no param type found for " << inst.op_type << ":" << name
<< " " << inst.place;
if (type->tensor_place != inst.place) {
LOG(INFO) << "found IO unmatched tensor";
if (type->tensor_place != in->AsArgument().place) {
LOG(INFO) << "found IO unmatched tensor: " << in->AsArgument().name;
}
}
}
......
......@@ -36,7 +36,7 @@ class SSAGraph : GraphBase {
// @param program: the op program
// @param valid_places: the valid places user set for the system.
void Build(const Program &program, const std::vector<Place> &valid_places) {
// create inputs
// create temporary nodes.
for (const auto &name : program.tmp_vars) {
node_storage_.emplace_back();
auto &new_node = node_storage_.back();
......@@ -45,20 +45,33 @@ class SSAGraph : GraphBase {
arguments_[name] = &new_node;
}
// create weight nodes.
for (const auto &name : program.weights) {
node_storage_.emplace_back();
auto &new_node = node_storage_.back();
auto &arg = new_node.AsArgument();
arg.name = name;
arguments_[name] = &new_node;
}
for (auto &op : program.ops) {
node_storage_.emplace_back();
// TODO(Superjomn) remove the redundant valid_places here.
op->SetValidPlaces(valid_places);
auto &new_node = node_storage_.back();
node_storage_.back().AsInstruct(
op->op_type_, op->CreateKernels(valid_places), op, op->op_info());
auto kernels = op->CreateKernels(valid_places);
for (auto &kernel : kernels) {
op->AttachKernel(kernel.get());
}
node_storage_.back().AsInstruct(op->op_type_, std::move(kernels), op,
op->op_info());
CHECK(new_node.inlinks.empty()) << "duplicate Build found";
CHECK(new_node.outlinks.empty()) << "duplicate Build found";
// collect inputs and outputs
for (const std::string &name : op->op_info()->input_names()) {
auto *arg = arguments_.at(name);
auto *arg = Argument(name);
new_node.inlinks.push_back(arg);
arg->outlinks.push_back(&new_node);
}
......@@ -79,6 +92,12 @@ class SSAGraph : GraphBase {
MarkArgumentWeights(program);
}
mir::Node *Argument(const std::string &name) {
auto it = arguments_.find(name);
CHECK(it != arguments_.end()) << "no argument called " << name;
return it->second;
}
std::vector<mir::Node *> InstructTopologicalOrder();
// The inputs of the graph.
......
......@@ -20,12 +20,9 @@ namespace paddle {
namespace lite {
TEST(executor, test) {
std::vector<OpLite::Place> valid_places{
OpLite::Place{TARGET(kHost), PRECISION(kFloat)}};
std::vector<Place> valid_places{Place{TARGET(kHost), PRECISION(kFloat)}};
Scope scope;
Executor executor(&scope, valid_places);
auto scope = std::make_shared<lite::Scope>();
framework::ProgramDesc program;
program.MutableBlock(0)->Var("x");
......@@ -42,19 +39,18 @@ TEST(executor, test) {
op_desc.SetAttr("in_num_col_dims", static_cast<int>(1));
program.Flush();
auto* w = scope.Var("w")->GetMutable<Tensor>();
auto* w = scope->Var("w")->GetMutable<Tensor>();
w->Resize({20, 20});
auto* x = scope.Var("x")->GetMutable<Tensor>();
auto* x = scope->Var("x")->GetMutable<Tensor>();
x->Resize({1, 10, 20});
auto* bias = scope.Var("bias")->GetMutable<Tensor>();
auto* bias = scope->Var("bias")->GetMutable<Tensor>();
bias->Resize({1, 20});
bias->mutable_data<float>();
w->mutable_data<float>();
x->mutable_data<float>();
executor.PrepareWorkspace(program);
executor.Build(program);
lite::Executor executor(program, scope, valid_places);
executor.Run();
}
......@@ -62,4 +58,4 @@ TEST(executor, test) {
} // namespace paddle
USE_LITE_OP(fc);
USE_LITE_KERNEL(fc, kHost, kFloat);
USE_LITE_KERNEL(fc, kHost, kFloat, def);
......@@ -101,6 +101,9 @@ class OpLite : public Registry {
virtual bool AttachImpl(const framework::OpDesc &opdesc,
lite::Scope *scope) = 0;
// Assign op param to kernel.
virtual void AttachKernel(KernelBase *kernel) = 0;
// Specify the kernel to run by default. This will specify the value of
// `kernel_place_`.
virtual void StaticPickKernel(const std::vector<Place> &valid_targets) {
......
......@@ -11,3 +11,19 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/optimizer.h"
#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
namespace paddle {
namespace lite {
void Optimizer::SpecifyKernelPickTactic(core::KernelPickFactor factor) {
auto* pass = mir::PassManager::Global().LookUp<mir::StaticKernelPickPass>(
"static_kernel_pick_pass");
CHECK(pass);
*pass->mutable_kernel_pick_factors() = factor;
}
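// A minimal usage sketch, not part of this commit: how a caller is expected
// to drive the new kernel-pick tactic, mirroring Predictor::Build in
// cxx_api.h. The free-function wrapper and its name are illustrative only.
void RunOptimizerSketch(Program&& program,
                        const std::vector<Place>& valid_places) {
  core::KernelPickFactor factor;
  factor.ConsiderTarget();  // weigh only the target when picking kernels
  Optimizer optimizer;
  optimizer.Run(std::move(program), valid_places, factor);
  auto runtime_program = optimizer.GenRuntimeProgram();
  runtime_program->Run();
}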
} // namespace lite
} // namespace paddle
......@@ -19,6 +19,7 @@
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/types.h"
namespace paddle {
namespace lite {
......@@ -30,11 +31,14 @@ namespace lite {
class Optimizer {
public:
void Run(Program&& program, const std::vector<Place>& valid_places,
core::KernelPickFactor kernel_pick_factor,
const std::vector<std::string>& passes = {}) {
CHECK(!graph_) << "duplicate optimize found";
graph_.reset(new mir::SSAGraph);
graph_->Build(program, valid_places);
SpecifyKernelPickTactic(kernel_pick_factor);
RunPasses();
exec_scope_ = program.exec_scope;
}
// Generate a new program based on the mir graph.
......@@ -42,7 +46,10 @@ class Optimizer {
std::unique_ptr<Program> res;
auto pass = mir::PassManager::Global().LookUp<mir::GenerateProgramPass>(
"generate_program_pass");
return pass->GenProgram();
auto program = pass->GenProgram();
CHECK(exec_scope_);
program->set_exec_scope(exec_scope_);
return program;
}
// Generate C++ code which combines the inference program, model and weights.
......@@ -54,6 +61,8 @@ class Optimizer {
}
protected:
void SpecifyKernelPickTactic(core::KernelPickFactor factor);
// Run the default passes registered in the PassManager.
void RunPasses() { mir::PassManager::Global().Run(graph_); }
......@@ -62,6 +71,7 @@ class Optimizer {
private:
std::unique_ptr<mir::SSAGraph> graph_;
lite::Scope* exec_scope_{};
};
} // namespace lite
......
......@@ -35,23 +35,22 @@ struct Program {
std::list<std::shared_ptr<OpLite>> ops;
// The scope to run the kernels in. NOTE: not the root scope.
std::shared_ptr<lite::Scope> scope;
std::vector<Place> valid_places;
// Runtime scope.
lite::Scope* exec_scope{};
const framework::ProgramDesc desc;
explicit Program(const std::shared_ptr<Scope>& root) { scope = root; }
Program(const framework::ProgramDesc& desc,
const std::shared_ptr<Scope>& root,
const std::vector<Place>& valid_places) {
scope = root;
const std::vector<Place>& valid_places)
: scope(root), valid_places(valid_places), desc(desc) {
PrepareWorkspace(desc);
Build(desc, valid_places);
}
std::unique_ptr<Program> Clone() const {
std::unique_ptr<Program> res(new Program(scope));
res->tmp_vars = tmp_vars;
res->weights = weights;
res->ops = ops;
std::unique_ptr<Program> res(new Program(desc, scope, valid_places));
return res;
}
......@@ -64,7 +63,7 @@ struct Program {
// Create operators.
for (auto* op_desc : program.Block(0).AllOps()) {
auto op_type = op_desc->Type();
if (op_type == "feed" || op_type == "fetch") continue;
// if (op_type == "feed" || op_type == "fetch") continue;
LOG(INFO) << "create Op [" << op_type << "]";
ops.emplace_back(LiteOpRegistry::Global().Create(op_type));
// pick initial kernel
......@@ -77,11 +76,22 @@ struct Program {
void PrepareWorkspace(const framework::ProgramDesc& program) {
CHECK(!exec_scope) << "Duplicate PrepareWorkspace found";
exec_scope = &scope->NewScope();
// Create Feed and Fetch var.
scope->Var("feed")->GetMutable<std::vector<Tensor>>();
scope->Var("fetch")->GetMutable<std::vector<Tensor>>();
tmp_vars.push_back("feed");
tmp_vars.push_back("fetch");
for (auto var_desc : program.Block(0).AllVars()) {
if (!var_desc->Persistable()) {
LOG(INFO) << "get tmp var " << var_desc->Name();
tmp_vars.push_back(var_desc->Name());
auto* var = exec_scope->Var(var_desc->Name());
LOG(INFO) << "create tmp var " << var_desc->Name() << " " << var;
} else {
if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") continue;
LOG(INFO) << "get weight var " << var_desc->Name();
weights.push_back(var_desc->Name());
}
}
}
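// A minimal sketch, not part of this commit: how the "feed" slot created above
// would be filled before running. This mirrors Predictor::GetInput in
// cxx_api.h; the member name FillFeedSketch and the example shape are
// illustrative only.
void FillFeedSketch(size_t col) {
  auto* feed_list =
      exec_scope->FindVar("feed")->GetMutable<std::vector<Tensor>>();
  if (feed_list->size() <= col) feed_list->resize(col + 1);
  auto& input = feed_list->at(col);
  input.Resize({100, 100});  // example shape, as in cxx_api_test.cc
  input.mutable_data<float>();
}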
......@@ -118,11 +128,15 @@ class RuntimeProgram {
}
}
void set_exec_scope(lite::Scope* x) { exec_scope_ = x; }
lite::Scope* exec_scope() { return exec_scope_; }
size_t num_instructions() const { return instructions_.size(); }
private:
RuntimeProgram(const RuntimeProgram&) = delete;
std::vector<Instruction> instructions_;
lite::Scope* exec_scope_{};
};
} // namespace lite
......
......@@ -48,6 +48,45 @@ const Type* Type::Get<UnsupportedTy>(TargetType target) {
DataLayoutType::kNCHW>();
}
template <TargetType Target>
TensorListAnyTy* GetTensorListAnyTy() {
static TensorListAnyTy x(Target);
return &x;
}
template <TargetType Target>
TensorAnyTy* GetTensorAnyTy() {
static TensorAnyTy x(Target);
return &x;
}
template <>
const Type* Type::Get<TensorListAnyTy>(TargetType target) {
switch (target) {
case TargetType::kHost:
return GetTensorListAnyTy<TARGET(kHost)>();
case TargetType::kCUDA:
return GetTensorListAnyTy<TARGET(kCUDA)>();
case TargetType::kX86:
return GetTensorListAnyTy<TARGET(kX86)>();
default:
LOG(FATAL) << "unsupported type";
}
}
template <>
const Type* Type::Get<TensorAnyTy>(TargetType target) {
switch (target) {
case TargetType::kHost:
return GetTensorAnyTy<TARGET(kHost)>();
case TargetType::kCUDA:
return GetTensorAnyTy<TARGET(kCUDA)>();
case TargetType::kX86:
return GetTensorAnyTy<TARGET(kX86)>();
default:
LOG(FATAL) << "unsupported type";
}
}
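// A minimal usage sketch, not part of this commit: the singleton accessors
// above are what kernel registrations consume to describe their IO slots (see
// the feed/fetch kernel registrations below). The function name is
// illustrative only.
void TypeLookupSketch() {
  const Type* any_tensor = Type::Get<TensorAnyTy>(TARGET(kHost));
  const Type* any_tensor_list = Type::Get<TensorListAnyTy>(TARGET(kHost));
  (void)any_tensor;
  (void)any_tensor_list;
}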
template <>
const Type* Type::Get<TensorFp32NCHWTy>(TargetType target) {
switch (target) {
......
......@@ -60,6 +60,8 @@ class DataTypeBase {
// Tensor_Any represents a Tensor with any place, data, or layout. It is used
// in some IO kernels that don't care about the data.
Tensor_Any,
// Used by feed or fetch op.
TensorList_Any,
NumTypes, // Must remains as last defined ID.
};
......@@ -146,6 +148,13 @@ class TensorAnyTy : public Type {
: Type(ID::Tensor_Any, "TensorAny", true, target, PRECISION(kAny),
DATALAYOUT(kAny)) {}
};
// A list of tensors, with no assumptions on the data layout or data type.
class TensorListAnyTy : public Type {
public:
TensorListAnyTy(TargetType target)
: Type(ID::TensorList_Any, "TensorList_Any", false, target,
PRECISION(kAny), DATALAYOUT(kAny)) {}
};
class TensorFp32NCHWTy : public Type {
public:
TensorFp32NCHWTy(TargetType target)
......
......@@ -3,13 +3,15 @@ cc_library(relu_compute_host SRCS relu_compute.cc DEPS ${lite_kernel_deps})
cc_library(mul_compute_host SRCS mul_compute.cc DEPS ${lite_kernel_deps})
cc_library(scale_compute_host SRCS scale_compute.cc DEPS ${lite_kernel_deps})
cc_library(feed_compute_host SRCS feed_compute.cc DEPS ${lite_kernel_deps})
cc_library(fetch_compute_host SRCS fetch_compute.cc DEPS ${lite_kernel_deps})
cc_library(host_kernels DEPS
feed_compute_host
fetch_compute_host
fc_compute_host
relu_compute_host
mul_compute_host
scale_compute_host
feed_compute_host
DEPS ${lite_kernel_deps}
)
......
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <Eigen/Core>
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h"
......@@ -26,9 +25,9 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
using param_t = operators::FeedParam;
void Run() override {
auto &theparam = Param<operators::FeedParam>();
const Tensor &feed_item = theparam.feed_list->at(theparam.col);
theparam.out->CopyDataFrom(feed_item);
auto &param = Param<operators::FeedParam>();
const Tensor &feed_item = param.feed_list->at(param.col);
param.out->CopyDataFrom(feed_item);
}
};
......@@ -39,7 +38,7 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
REGISTER_LITE_KERNEL(feed, kHost, kFloat,
paddle::lite::kernels::host::FeedCompute, def)
.BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
.BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>(
TARGET(kHost))})
.BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kHost))})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
class FetchCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
public:
using param_t = operators::FetchParam;
void Run() override {
auto& param = Param<operators::FetchParam>();
auto* fetch_list = param.fetch_list;
if (fetch_list->size() <= static_cast<size_t>(param.col)) {
fetch_list->resize(param.col + 1);
}
auto& dst = fetch_list->at(param.col);
dst.CopyDataFrom(*param.input);
}
};
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(fetch, kHost, kFloat,
paddle::lite::kernels::host::FetchCompute, def)
.BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>(
TARGET(kHost))})
.BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorListAnyTy>(
TARGET(kHost))})
.Finalize();
......@@ -52,4 +52,8 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
REGISTER_LITE_KERNEL(scale, kHost, kFloat,
paddle::lite::kernels::host::ScaleCompute, def)
.BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kHost))})
.BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kHost))})
.Finalize();
......@@ -3,6 +3,7 @@ cc_library(relu_op_lite SRCS relu_op.cc DEPS op_lite)
cc_library(mul_op_lite SRCS mul_op.cc DEPS op_lite)
cc_library(scale_op_lite SRCS scale_op.cc DEPS op_lite)
cc_library(feed_op_lite SRCS feed_op.cc DEPS op_lite)
cc_library(fetch_op_lite SRCS fetch_op.cc DEPS op_lite)
cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS op_lite)
cc_library(op_params_lite SRCS op_params.cc DEPS tensor_lite)
......@@ -12,6 +13,7 @@ cc_library(ops_lite DEPS
mul_op_lite
scale_op_lite
feed_op_lite
fetch_op_lite
io_copy_op_lite
)
......
......@@ -67,6 +67,8 @@ class FcOpLite : public OpLite {
return true;
}
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "fc"; }
private:
......
......@@ -31,6 +31,9 @@ class FeedOp : public OpLite {
bool InferShape() const override { return true; }
protected:
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
protected:
bool AttachImpl(const framework::OpDesc& opdesc,
lite::Scope* scope) override {
......@@ -48,6 +51,7 @@ class FeedOp : public OpLite {
// NOTE need boost here
// TODO(Superjomn) drop the need of framework::op_desc
param_.col = boost::get<int>(opdesc.GetAttr("col"));
kernel_->SetParam(param_);
return true;
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
class FetchOp : public OpLite {
public:
explicit FetchOp(const std::string& type) : OpLite(type) {}
bool CheckShape() const override {
CHECK_OR_FALSE(param_.input);
CHECK_OR_FALSE(param_.fetch_list);
return true;
}
bool InferShape() const override { return true; }
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
protected:
bool AttachImpl(const framework::OpDesc& opdesc,
lite::Scope* scope) override {
auto _x = opdesc.Input("X").front();
auto* x = scope->FindVar(_x);
CHECK(x);
param_.input = &x->Get<Tensor>();
auto _out = opdesc.Output("Out").front();
auto* out = scope->FindVar(_out);
param_.fetch_list = out->GetMutable<std::vector<lite::Tensor>>();
param_.col = boost::get<int>(opdesc.GetAttr("col"));
return true;
}
std::string DebugString() const override { return "fetch"; }
private:
mutable FetchParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(fetch, paddle::lite::operators::FetchOp);
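// A minimal usage note (sketch, not part of this commit): a target binary must
// still reference the new op and its host kernel explicitly, as cxx_api_test.cc
// now does:
//   USE_LITE_OP(fetch);
//   USE_LITE_KERNEL(fetch, kHost, kFloat, def);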
......@@ -28,6 +28,8 @@ class IoCopyOp : public OpLite {
bool Run() override;
std::string DebugString() const override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
protected:
bool AttachImpl(const framework::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -36,6 +36,7 @@ class MulOpLite : public OpLite {
bool InferShape() const override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const framework::OpDesc &op_desc,
lite::Scope *scope) override {
......
......@@ -25,8 +25,14 @@ namespace lite {
namespace operators {
struct FeedParam {
const std::vector<Tensor>* feed_list;
Tensor* out;
const std::vector<Tensor>* feed_list{};
Tensor* out{};
int col;
};
struct FetchParam {
const Tensor* input{};
std::vector<Tensor>* fetch_list{};
int col;
};
......@@ -69,8 +75,8 @@ struct IoCopyParam {
Tensor* y{};
};
using param_t =
variant<FeedParam, FcParam, ReluParam, MulParam, ScaleParam, IoCopyParam>;
using param_t = variant<FeedParam, FetchParam, FcParam, ReluParam, MulParam,
ScaleParam, IoCopyParam>;
} // namespace operators
} // namespace lite
......
......@@ -34,6 +34,7 @@ class ReluOp : public OpLite {
bool AttachImpl(const framework::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "tanh"; }
private:
......
......@@ -43,6 +43,8 @@ class ScaleOp : public OpLite {
return true;
}
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const framework::OpDesc &op_desc,
lite::Scope *scope) override {
......