Commit 70540d1b authored by superjomn

add a new lightweight OpDesc compatible with the original framework::OpDesc

to support mobile
Parent 6e19097b
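In practice, pass and operator code now builds op descriptions through the lite::OpDesc alias instead of framework::OpDesc. A minimal sketch of the resulting call pattern, mirroring the io_copy construction in the io_complement pass below (illustrative only, not part of the commit):

#include <string>
#include "paddle/fluid/lite/model_parser/compatible_pb.h"

// Build the same kind of lightweight description the io_complement pass
// creates below; no framework::OpDesc, boost::variant, or Flush() involved.
paddle::lite::OpDesc MakeIoCopyDesc(const std::string &in,
                                    const std::string &out) {
  paddle::lite::OpDesc op_desc;
  op_desc.SetType("io_copy");
  op_desc.SetInput("Input", {in});
  op_desc.SetOutput("Out", {out});
  return op_desc;
}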
@@ -187,6 +187,7 @@ endif()
 # for lite
 option(LITE_WITH_CUDA "Enable CUDA in lite mode" ON)
 option(LITE_WITH_X86 "Enable X86 in lite mode" ON)
+option(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK "Enable light-weight framework" ON)
 include(external/threadpool)
 include(flags) # set paddle compile flags
...
@@ -171,3 +171,7 @@ endif()
 if (LITE_WITH_X86)
   add_definitions("-DLITE_WITH_X86")
 endif()
+if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
+  add_definitions("-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK")
+endif()
+cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host)
+cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite)
 if(LITE_WITH_CUDA)
   cc_library(cxx_api_lite_cuda SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host target_wrapper_cuda kernels_cuda)
   nv_test(test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda model_parser_lite)
-else()
-  cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host)
-  cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite target_wrapper_host host_kernels)
 endif()
@@ -33,9 +33,8 @@ class Predictor {
                  const std::vector<Place>& valid_places) {
     framework::proto::ProgramDesc prog;
     LoadModel(model_path, scope_.get(), &prog);
-    framework::ProgramDesc prog_desc(prog);
-    Program program(prog_desc, scope_, valid_places);
+    Program program(prog, scope_, valid_places);
     Optimizer optimizer;
     optimizer.KernelPickPreferPlace(prefer_place);
...
@@ -5,10 +5,10 @@ cc_library(kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite)
 cc_library(variable_lite SRCS variable.cc)
 cc_library(op_registry_lite SRCS op_registry.cc)
 cc_library(scope_lite SRCS scope.cc)
-cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite)
+cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite compatible_pb_lite)
 cc_library(op_executor_lite SRCS op_executor.cc DEPS scope_lite tensor_lite op_lite op_registry_lite
   #TODO(Superjomn) remove these dependencies from original framework
-  proto_desc)
+  )
 cc_library(kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite)
 cc_library(types_lite SRCS types.cc)
 cc_library(type_system SRCS type_system.cc DEPS tensor_lite)
...
@@ -65,19 +65,6 @@ void IoComplementPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
   }
 }
-void UpdateOpdescInputName(framework::OpDesc* desc,
-                           const std::string& old_arg_name,
-                           const std::string& new_arg_name) {
-  for (auto& item : *desc->Proto()->mutable_inputs()) {
-    for (int i = 0; i < item.mutable_arguments()->size(); i++) {
-      auto* arg = item.mutable_arguments(i);
-      if (*arg == old_arg_name) {
-        *arg = new_arg_name;
-      }
-    }
-  }
-}
 void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to,
                                      const std::string& var, SSAGraph* graph,
                                      Node* inst_node,
@@ -99,11 +86,10 @@ void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to,
   inst_node->AsInstruct().op->scope()->Var(io_copy_output_name);
   // Create IoCopy Instruction.
-  framework::OpDesc op_desc;
+  lite::OpDesc op_desc;
   op_desc.SetType("io_copy");
   op_desc.SetInput("Input", {var});
   op_desc.SetOutput("Out", {io_copy_output_name});
-  op_desc.Flush();
   io_copy_op->Attach(op_desc, inst_node->AsInstruct().op->scope());
   auto kernels = io_copy_op->CreateKernels(valid_places);
@@ -126,7 +112,7 @@ void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to,
   auto desc_dummy = inst_node->AsInstruct().op->op_info()->desc();
   UpdateInputTo(&desc_dummy, var, io_copy_output_name);
-  framework::OpDesc desc_fake(desc_dummy, nullptr);
+  lite::OpDesc desc_fake(desc_dummy);
   inst_node->AsInstruct().op->Attach(desc_fake,
                                      inst_node->AsInstruct().op->scope());
...
@@ -23,6 +23,7 @@
 namespace paddle {
 namespace lite {
+/*
 // The Executor is used to run the operators.
 class Executor {
  public:
@@ -63,6 +64,7 @@ class RuntimeExecutor {
  private:
   RuntimeProgram* program_{};
 };
+*/
 } // namespace lite
 } // namespace paddle
@@ -76,7 +76,7 @@ bool OpLite::Run() {
   return true;
 }
-bool OpLite::Attach(const framework::OpDesc &opdesc, lite::Scope *scope) {
+bool OpLite::Attach(const OpDesc &opdesc, lite::Scope *scope) {
   CHECK(scope);
   scope_ = scope;
   op_info_.reset(new OpInfo);  // Force clean the out-of-date information.
...
@@ -19,12 +19,11 @@
 #include <map>
 #include <memory>
 #include <string>
-#include "paddle/fluid/framework/lod_tensor.h"
-#include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/variable.h"
 #include "paddle/fluid/lite/core/context.h"
 #include "paddle/fluid/lite/core/kernel.h"
 #include "paddle/fluid/lite/core/scope.h"
+#include "paddle/fluid/lite/model_parser/compatible_pb.h"
 namespace paddle {
 namespace lite {
@@ -82,7 +81,7 @@ class OpLite : public Registry {
   virtual bool Run();
   // Link the external execution environment to the internal context.
-  bool Attach(const framework::OpDesc &opdesc, lite::Scope *scope);
+  bool Attach(const OpDesc &opdesc, lite::Scope *scope);
   const OpInfo *op_info() const { return op_info_.get(); }
   OpInfo *mutable_op_info() { return op_info_.get(); }
@@ -109,8 +108,7 @@ class OpLite : public Registry {
  protected:
   // Attach it to the runtime environment.
-  virtual bool AttachImpl(const framework::OpDesc &opdesc,
-                          lite::Scope *scope) = 0;
+  virtual bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) = 0;
   // Specify the kernel to run by default. This will specify the value of
   // `kernel_place_`.
...
@@ -38,10 +38,10 @@ struct Program {
   std::vector<Place> valid_places;
   // Runtime scope.
   lite::Scope* exec_scope{};
-  const framework::ProgramDesc desc;
+  const framework::proto::ProgramDesc desc;
   explicit Program(const std::shared_ptr<Scope>& root) { scope = root; }
-  Program(const framework::ProgramDesc& desc,
+  Program(const framework::proto::ProgramDesc& desc,
          const std::shared_ptr<Scope>& root,
          const std::vector<Place>& valid_places)
       : scope(root), valid_places(valid_places), desc(desc) {
@@ -56,24 +56,25 @@ struct Program {
  private:
   // Build from a program and scope.
-  void Build(const framework::ProgramDesc& program,
+  void Build(const framework::proto::ProgramDesc& program,
             const std::vector<Place>& valid_places) {
     CHECK(ops.empty()) << "Executor duplicate Build found";
     // Create operators.
-    for (auto* op_desc : program.Block(0).AllOps()) {
-      auto op_type = op_desc->Type();
+    for (const auto& proto_op_desc : program.blocks(0).ops()) {
+      lite::OpDesc op_desc(proto_op_desc);
+      auto op_type = op_desc.Type();
       // if (op_type == "feed" || op_type == "fetch") continue;
       VLOG(4) << "create Op [" << op_type << "]";
       ops.emplace_back(LiteOpRegistry::Global().Create(op_type));
       // pick initial kernel
       ops.back()->PickKernel(valid_places);
-      ops.back()->Attach(*op_desc, exec_scope);
+      ops.back()->Attach(op_desc, exec_scope);
     }
   }
   // Create temporary variables.
-  void PrepareWorkspace(const framework::ProgramDesc& program) {
+  void PrepareWorkspace(const framework::proto::ProgramDesc& program) {
     CHECK(!exec_scope) << "Duplicate PrepareWorkspace found";
     exec_scope = &scope->NewScope();
     // Create Feed and Fetch var.
@@ -82,13 +83,14 @@ struct Program {
     tmp_vars.push_back("feed");
     tmp_vars.push_back("fetch");
-    for (auto var_desc : program.Block(0).AllVars()) {
-      if (!var_desc->Persistable()) {
-        tmp_vars.push_back(var_desc->Name());
-        exec_scope->Var(var_desc->Name());
+    for (auto proto_var_desc : program.blocks(0).vars()) {
+      lite::VarDesc var_desc(proto_var_desc);
+      if (!var_desc.Persistable()) {
+        tmp_vars.push_back(var_desc.Name());
+        exec_scope->Var(var_desc.Name());
       } else {
-        if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") continue;
-        weights.push_back(var_desc->Name());
+        if (var_desc.Name() == "feed" || var_desc.Name() == "fetch") continue;
+        weights.push_back(var_desc.Name());
       }
     }
   }
...
 cc_library(model_parser_lite SRCS model_parser.cc DEPS variable_lite scope_lite tensor_lite scope_lite)
 cc_library(runtime_lite SRCS runtime.cc)
 cc_test(test_model_parser_lite SRCS model_parser_test.cc DEPS model_parser_lite)
+if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
+  cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite var_desc_lite)
+else()
+  cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto)
+endif(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
+add_subdirectory(pb)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/model_parser/compatible_pb.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
/*
 * This file defines the interface for manipulating the protobuf messages. We
 * use conditional compilation to expose an interface compatible with both the
 * framework::XXDesc and the lite::pb::XXDesc implementations.
 */
#include "paddle/fluid/framework/framework.pb.h"
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#include "paddle/fluid/lite/model_parser/pb/op_desc.h"
#include "paddle/fluid/lite/model_parser/pb/var_desc.h"
#else
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
namespace paddle {
namespace lite {
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
using OpDesc = lite::pb::OpDesc;
using VarDesc = lite::pb::VarDesc;
#else // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
using Attribute = framework::Attribute;
using OpDesc = framework::OpDesc;
using VarDesc = framework::VarDesc;
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
} // namespace lite
} // namespace paddle
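With these aliases in place, code written against lite::OpDesc compiles unchanged in both configurations. A hedged sketch (the helper is illustrative, not part of the commit):

namespace paddle {
namespace lite {

// OpDesc resolves to lite::pb::OpDesc on mobile builds and framework::OpDesc
// otherwise; both expose Type()/Input()/Output() with matching signatures.
inline std::string FirstInput(const OpDesc &desc, const std::string &param) {
  return desc.Input(param).front();
}

}  // namespace lite
}  // namespace paddle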
cc_library(var_desc_lite SRCS var_desc.cc DEPS framework_proto)
cc_library(op_desc_lite SRCS op_desc.cc DEPS framework_proto)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <deque>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace lite {
class ProgramDesc;
// For each protobuf message we provide a XXXBind class that optimizes
// read/write speed. Local changes are synchronized back into the protobuf
// message (by the `Flush` method) only when the message itself is needed.
class BlockDesc {
public:
BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc);
BlockDesc(const BlockDesc &other, proto::BlockDesc *desc, ProgramDesc *prog);
int32_t ID() const { return desc_->idx(); }
int32_t Parent() const { return desc_->parent_idx(); }
int32_t ForwardBlockID() const { return desc_->forward_block_idx(); }
VarDesc *Var(const std::string &name_bytes);
VarDesc *FindVar(const std::string &name_bytes) const;
bool HasVar(const std::string &var_name) const;
VarDesc *RenameVar(const std::string &old_name, const std::string &new_name);
VarDesc *FindVarRecursive(const std::string &name_bytes) const;
VarDesc &FindRecursiveOrCreateVar(const std::string &name_bytes);
bool HasVarRecursive(const std::string &var_name) const;
std::set<std::string> LocalVarNames() const {
std::set<std::string> var_names;
for (auto &var : vars_) {
var_names.insert(var.first);
}
return var_names;
}
std::vector<VarDesc *> AllVars() const;
BlockDesc *ParentBlock() const;
BlockDesc *ForwardBlock() const;
void SetForwardBlockID(int32_t forward_block_id);
OpDesc *AppendOp();
void AppendAllocatedOp(std::unique_ptr<OpDesc> &&op_desc);
OpDesc *PrependOp();
void PrependAllocatedOp(std::unique_ptr<OpDesc> &&op_desc);
OpDesc *InsertOp(size_t index);
/*
* Only remove op itself,
* do nothing to its input and output variables
*/
void RemoveOp(size_t s, size_t e);
void RemoveOpInternal(const OpDesc *op_desc);
void RemoveVar(const std::string &name) { vars_.erase(name); }
std::vector<OpDesc *> AllOps() const;
size_t OpSize() const { return ops_.size(); }
OpDesc *Op(int idx) const { return ops_.at(idx).get(); }
void Flush();
proto::BlockDesc *Proto();
ProgramDesc *Program() const { return this->prog_; }
private:
ProgramDesc *prog_; // not_own
proto::BlockDesc *desc_; // not_own
bool need_update_;
std::deque<std::unique_ptr<OpDesc>> ops_;
std::unordered_map<std::string, std::unique_ptr<VarDesc>> vars_;
DISABLE_COPY_AND_ASSIGN(BlockDesc);
};
} // namespace lite
} // namespace paddle
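A short sketch of the read-modify-flush pattern described in the class comment above (hypothetical helper; assumes a BlockDesc already bound to a proto::BlockDesc and an OpDesc type exposing SetType):

// Mutations are staged in the local ops_/vars_ containers; the underlying
// proto message is rewritten only when Flush() is called.
void AppendScaleOp(paddle::lite::BlockDesc *block) {
  auto *op = block->AppendOp();  // staged locally, proto untouched so far
  op->SetType("scale");          // configure the staged op
  block->Flush();                // synchronize local changes into the proto
}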
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/model_parser/pb/op_desc.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
/*
 * This file implements a light-weight OpDesc like the framework::OpDesc. We
 * delete the unnecessary methods and remove the underlying dependencies, such
 * as framework::Operator and boost::variant, to make it runnable on mobile.
 */
#include <algorithm>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
namespace pb {
// std::string is included so that GetAttr can return STRING attributes.
using Attribute =
    variant<int, float, bool, std::string, std::vector<std::string>>;
using VariableNameMap = std::map<std::string, std::vector<std::string>>;
/*
 * The lite::OpDesc is a light-weight wrapper around proto::OpDesc. Unlike
 * framework::OpDesc, it keeps no local members other than desc_, which avoids
 * the inconsistent states that the original interface permits and that result
 * in bugs.
 */
class OpDesc {
public:
OpDesc() {}
OpDesc(const framework::proto::OpDesc &desc) : desc_(desc) {}
void CopyFrom(const OpDesc &op_desc) { desc_ = op_desc.ReadonlyProto(); }
framework::proto::OpDesc *Proto() { return &desc_; }
const framework::proto::OpDesc &ReadonlyProto() const { return desc_; }
std::string Type() const { return desc_.type(); }
void SetType(const std::string &type) { desc_.set_type(type); }
// Get the arguments of parameter called `param`
std::vector<std::string> Input(const std::string &param) const {
return GetArguments(desc_.inputs(), param);
}
std::vector<std::string> InputArgumentNames() const {
return GetArgumentNames(desc_.inputs());
}
void SetInput(const std::string &param,
const std::vector<std::string> &args) {
SetArgument(desc_.mutable_inputs(), param, args);
}
std::vector<std::string> Output(const std::string &param) const {
return GetArguments(desc_.outputs(), param);
}
std::vector<std::string> OutputArgumentNames() const {
return GetArgumentNames(desc_.outputs());
}
void SetOutput(const std::string &param,
const std::vector<std::string> &args) {
SetArgument(desc_.mutable_outputs(), param, args);
}
bool HasAttr(const std::string &name) const {
const auto &xs = desc_.attrs();
auto it = std::find_if(xs.begin(), xs.end(),
[&](const framework::proto::OpDesc_Attr &x) {
return x.name() == name;
});
return it != xs.end();
}
framework::proto::AttrType GetAttrType(const std::string &name) const {
const auto &xs = desc_.attrs();
auto it = std::find_if(xs.begin(), xs.end(),
[&](const framework::proto::OpDesc_Attr &x) {
return x.name() == name;
});
CHECK(it != xs.end());
return it->type();
}
std::vector<std::string> AttrNames() const {
std::vector<std::string> res;
const auto &xs = desc_.attrs();
std::transform(
xs.begin(), xs.end(), std::back_inserter(res),
[](const framework::proto::OpDesc_Attr &x) { return x.name(); });
return res;
}
  // NOTE a switch over `typeid(T).hash_code()` cannot be used here, since
  // case labels must be compile-time constants; dispatch on the value type
  // through the `SetAttrValue` overloads below instead.
  template <typename T>
  void SetAttr(const std::string &name, const T &v) {
    auto &xs = *desc_.mutable_attrs();
    auto it = std::find_if(xs.begin(), xs.end(),
                           [&](const framework::proto::OpDesc_Attr &x) {
                             return x.name() == name;
                           });
    if (it == xs.end()) {
      auto *attr = xs.Add();
      attr->set_name(name);
      it = xs.end() - 1;
    }
    SetAttrValue(&*it, v);
  }
  static void SetAttrValue(framework::proto::OpDesc_Attr *attr, int v) {
    attr->set_type(framework::proto::INT);
    attr->set_i(v);
  }
  static void SetAttrValue(framework::proto::OpDesc_Attr *attr, float v) {
    attr->set_type(framework::proto::FLOAT);
    attr->set_f(v);
  }
  static void SetAttrValue(framework::proto::OpDesc_Attr *attr,
                           const std::string &v) {
    attr->set_type(framework::proto::STRING);
    attr->set_s(v);
  }
  static void SetAttrValue(framework::proto::OpDesc_Attr *attr, bool v) {
    attr->set_type(framework::proto::BOOLEAN);
    attr->set_b(v);
  }
Attribute GetAttr(const std::string &name) const {
auto &xs = desc_.attrs();
auto it = std::find_if(xs.begin(), xs.end(),
[&](const framework::proto::OpDesc_Attr &x) {
return x.name() == name;
});
Attribute res;
CHECK(it != xs.end());
switch (it->type()) {
case framework::proto::INT:
res.set<int>(it->i());
break;
case framework::proto::FLOAT:
res.set<float>(it->f());
break;
case framework::proto::STRING:
res.set<std::string>(it->s());
break;
case framework::proto::BOOLEAN:
res.set<bool>(it->b());
break;
default:
LOG(FATAL) << "unsupported attr type";
}
return res;
}
private:
std::vector<std::string> GetArguments(
const google::protobuf::RepeatedPtrField<framework::proto::OpDesc_Var>
&xs,
const std::string &param) const {
std::vector<std::string> res;
auto it = std::find_if(xs.begin(), xs.end(),
[&](const framework::proto::OpDesc_Var &it) {
return it.parameter() == param;
});
CHECK(it != xs.end());
const auto &ys = it->arguments();
std::transform(ys.begin(), ys.end(), std::back_inserter(res),
[](const std::string &x) { return x; });
return res;
}
void SetArgument(
google::protobuf::RepeatedPtrField<framework::proto::OpDesc_Var> *xs,
const std::string &param, const std::vector<std::string> &args) {
auto it = std::find_if(xs->begin(), xs->end(),
[&](const framework::proto::OpDesc_Var &it) {
return it.parameter() == param;
});
if (it == xs->end()) {
auto *new_arg = xs->Add();
new_arg->set_parameter(param);
for (const auto &arg : args) {
*new_arg->mutable_arguments()->Add() = arg;
}
} else {
it->mutable_arguments()->Clear();
for (const auto &arg : args) {
*it->mutable_arguments()->Add() = arg;
}
}
}
std::vector<std::string> GetArgumentNames(
const google::protobuf::RepeatedPtrField<framework::proto::OpDesc_Var>
&xs) const {
std::vector<std::string> res;
std::transform(
xs.begin(), xs.end(), std::back_inserter(res),
[](const framework::proto::OpDesc_Var &x) { return x.parameter(); });
return res;
}
private:
framework::proto::OpDesc desc_;
};
} // namespace pb
} // namespace lite
} // namespace paddle
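A minimal usage sketch of pb::OpDesc, exercising the setters and the variant-based GetAttr; the op and attribute names follow the fc_op change later in this commit:

#include "paddle/fluid/lite/model_parser/pb/op_desc.h"

void OpDescRoundTrip() {
  paddle::lite::pb::OpDesc desc;
  desc.SetType("fc");
  desc.SetInput("Input", {"x"});
  desc.SetInput("W", {"w"});
  desc.SetOutput("Out", {"out"});
  desc.SetAttr("in_num_col_dims", 1);

  CHECK_EQ(desc.Type(), "fc");
  CHECK_EQ(desc.Input("W").front(), "w");
  // Attributes come back as the lite variant, read with get<T>() rather than
  // boost::get<T>().
  CHECK_EQ(desc.GetAttr("in_num_col_dims").get<int>(), 1);
}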
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/model_parser/pb/var_desc.h"
namespace paddle {
namespace lite {
namespace pb {
using namespace framework;
proto::VarType::Type VarDesc::GetType() const { return desc_.type().type(); }
void VarDesc::SetType(proto::VarType::Type type) {
desc_.mutable_type()->set_type(type);
}
void VarDesc::SetShape(const std::vector<int64_t> &dims) {
VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
}
void VarDesc::SetTensorDescNum(size_t num) {
switch (desc_.type().type()) {
case proto::VarType::READER: {
auto *lod_tensors_ptr =
desc_.mutable_type()->mutable_reader()->mutable_lod_tensor();
lod_tensors_ptr->Clear();
for (size_t i = 0; i < num; ++i) {
lod_tensors_ptr->Add();
}
return;
} break;
default:
LOG(FATAL) << "Setting 'sub_tensor_number' is not supported by the type "
"of var %s."
<< this->Name();
}
}
size_t VarDesc::GetTensorDescNum() const {
switch (desc_.type().type()) {
case proto::VarType::READER:
return desc_.type().reader().lod_tensor_size();
break;
default:
LOG(FATAL) << "Getting 'sub_tensor_number' is not supported by the type "
"of var %s."
<< this->Name();
}
}
void VarDesc::SetShapes(
const std::vector<std::vector<int64_t>> &multiple_dims) {
if (multiple_dims.size() != GetTensorDescNum()) {
VLOG(3) << "WARNING: The number of given shapes(" << multiple_dims.size()
<< ") doesn't match the existing tensor number("
<< GetTensorDescNum()
<< "). The Reader is going to be reinitialized.";
SetTensorDescNum(multiple_dims.size());
}
std::vector<proto::VarType::TensorDesc *> tensors = mutable_tensor_descs();
for (size_t i = 0; i < multiple_dims.size(); ++i) {
VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims());
}
}
std::vector<int64_t> VarDesc::GetShape() const {
return RepeatedToVector(tensor_desc().dims());
}
std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
std::vector<std::vector<int64_t>> res;
res.reserve(descs.size());
for (const auto &tensor_desc : descs) {
res.push_back(RepeatedToVector(tensor_desc.dims()));
}
return res;
}
void VarDesc::SetDataType(proto::VarType::Type data_type) {
mutable_tensor_desc()->set_data_type(data_type);
}
void VarDesc::SetDataTypes(
const std::vector<proto::VarType::Type> &multiple_data_type) {
if (multiple_data_type.size() != GetTensorDescNum()) {
VLOG(3) << "WARNING: The number of given data types("
<< multiple_data_type.size()
<< ") doesn't match the existing tensor number("
<< GetTensorDescNum()
<< "). The Reader is going to be reinitialized.";
SetTensorDescNum(multiple_data_type.size());
}
std::vector<proto::VarType::TensorDesc *> tensor_descs =
mutable_tensor_descs();
for (size_t i = 0; i < multiple_data_type.size(); ++i) {
tensor_descs[i]->set_data_type(multiple_data_type[i]);
}
}
proto::VarType::Type VarDesc::GetDataType() const {
return tensor_desc().data_type();
}
std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
std::vector<proto::VarType::Type> res;
res.reserve(descs.size());
for (const auto &tensor_desc : descs) {
res.push_back(tensor_desc.data_type());
}
return res;
}
void VarDesc::SetLoDLevel(int32_t lod_level) {
switch (desc_.type().type()) {
case proto::VarType::LOD_TENSOR:
desc_.mutable_type()->mutable_lod_tensor()->set_lod_level(lod_level);
break;
case proto::VarType::LOD_TENSOR_ARRAY:
desc_.mutable_type()->mutable_tensor_array()->set_lod_level(lod_level);
break;
default:
      LOG(FATAL)
          << "Setting 'lod_level' is not supported by the type of var "
          << this->Name();
}
}
void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
if (multiple_lod_level.size() != GetTensorDescNum()) {
VLOG(3) << "WARNING: The number of given lod_levels("
<< multiple_lod_level.size()
<< ") doesn't match the existing tensor number("
<< GetTensorDescNum()
<< "). The Reader is going to be reinitialized.";
SetTensorDescNum(multiple_lod_level.size());
}
switch (desc_.type().type()) {
case proto::VarType::READER: {
size_t i = 0;
for (auto &lod_tensor :
*desc_.mutable_type()->mutable_reader()->mutable_lod_tensor()) {
lod_tensor.set_lod_level(multiple_lod_level[i++]);
}
} break;
default:
      LOG(FATAL)
          << "Setting 'lod_levels' is not supported by the type of var "
          << this->Name();
}
}
int32_t VarDesc::GetLoDLevel() const {
switch (desc_.type().type()) {
case proto::VarType::LOD_TENSOR:
return desc_.type().lod_tensor().lod_level();
case proto::VarType::LOD_TENSOR_ARRAY:
return desc_.type().tensor_array().lod_level();
default:
      LOG(FATAL)
          << "Getting 'lod_level' is not supported by the type of var "
          << this->Name();
}
}
std::vector<int32_t> VarDesc::GetLoDLevels() const {
std::vector<int32_t> res;
switch (desc_.type().type()) {
case proto::VarType::READER:
res.reserve(desc_.type().reader().lod_tensor_size());
for (auto &lod_tensor : desc_.type().reader().lod_tensor()) {
res.push_back(lod_tensor.lod_level());
}
return res;
break;
default:
      LOG(FATAL)
          << "Getting 'lod_levels' is not supported by the type of var "
          << this->Name();
}
}
const proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
CHECK(desc_.has_type()) << "The var's type hasn't been set.";
CHECK(desc_.type().has_type()) << "The var type hasn't been set.";
switch (desc_.type().type()) {
case proto::VarType::SELECTED_ROWS:
return desc_.type().selected_rows();
case proto::VarType::LOD_TENSOR:
return desc_.type().lod_tensor().tensor();
case proto::VarType::LOD_TENSOR_ARRAY:
return desc_.type().tensor_array().tensor();
default:
      LOG(FATAL)
          << "Getting 'tensor_desc' is not supported by the type of var "
          << this->Name();
}
}
std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
CHECK(desc_.has_type()) << "The var type hasn't been set.";
std::vector<proto::VarType::TensorDesc> res;
res.reserve(GetTensorDescNum());
switch (desc_.type().type()) {
case proto::VarType::READER:
for (const auto &lod_tensor : desc_.type().reader().lod_tensor()) {
res.push_back(lod_tensor.tensor());
}
return res;
default:
      LOG(FATAL)
          << "Getting 'tensor_descs' is not supported by the type of var "
          << this->Name();
}
}
proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
CHECK(desc_.has_type()) << "The var type hasn't been set.";
CHECK(desc_.type().has_type()) << "The var type hasn't been set.";
switch (desc_.type().type()) {
case proto::VarType::SELECTED_ROWS:
return desc_.mutable_type()->mutable_selected_rows();
case proto::VarType::LOD_TENSOR:
return desc_.mutable_type()->mutable_lod_tensor()->mutable_tensor();
case proto::VarType::LOD_TENSOR_ARRAY:
return desc_.mutable_type()->mutable_tensor_array()->mutable_tensor();
default:
LOG(FATAL) << "Getting 'mutable_tensor_desc' is not supported by the "
"type of var "
"%s."
<< this->Name();
}
}
std::vector<proto::VarType::TensorDesc *> VarDesc::mutable_tensor_descs() {
CHECK(desc_.has_type()) << "The var type hasn't been set.";
CHECK(desc_.type().has_type()) << "The var type hasn't been set.";
std::vector<proto::VarType::TensorDesc *> res;
res.reserve(GetTensorDescNum());
switch (desc_.type().type()) {
case proto::VarType::READER:
for (auto &lod_tensor :
*desc_.mutable_type()->mutable_reader()->mutable_lod_tensor()) {
res.push_back(lod_tensor.mutable_tensor());
}
return res;
default:
      LOG(FATAL)
          << "Getting 'tensor_descs' is not supported by the type of var "
          << this->Name();
}
}
} // namespace pb
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
namespace paddle {
namespace lite {
namespace pb {
// convert between std::vector and protobuf repeated.
template <typename T>
inline std::vector<T> RepeatedToVector(
const google::protobuf::RepeatedField<T> &repeated_field) {
std::vector<T> ret;
ret.reserve(repeated_field.size());
std::copy(repeated_field.begin(), repeated_field.end(),
std::back_inserter(ret));
return ret;
}
template <typename T, typename RepeatedField>
inline void VectorToRepeated(const std::vector<T> &vec,
RepeatedField *repeated_field) {
repeated_field->Clear();
repeated_field->Reserve(vec.size());
for (const auto &elem : vec) {
*repeated_field->Add() = elem;
}
}
// Specialize vector<bool>.
template <typename RepeatedField>
inline void VectorToRepeated(const std::vector<bool> &vec,
RepeatedField *repeated_field) {
repeated_field->Clear();
repeated_field->Reserve(vec.size());
for (auto elem : vec) {
*repeated_field->Add() = elem;
}
}
class VarDesc {
public:
explicit VarDesc(const std::string &name) {
desc_.set_name(name);
// TODO(paddle-dev): Why default to lodtensor.
desc_.mutable_type()->set_type(framework::proto::VarType::LOD_TENSOR);
}
explicit VarDesc(const framework::proto::VarDesc &desc) : desc_(desc) {}
framework::proto::VarDesc *Proto() { return &desc_; }
std::string Name() const { return desc_.name(); }
void SetName(std::string name) { desc_.set_name(name); }
void SetTensorDescNum(size_t num);
size_t GetTensorDescNum() const;
void SetShape(const std::vector<int64_t> &dims);
void SetShapes(const std::vector<std::vector<int64_t>> &multiple_dims);
std::vector<int64_t> GetShape() const;
std::vector<std::vector<int64_t>> GetShapes() const;
void SetDataType(framework::proto::VarType::Type data_type);
void SetDataTypes(
const std::vector<framework::proto::VarType::Type> &multiple_data_type);
framework::proto::VarType::Type GetDataType() const;
std::vector<framework::proto::VarType::Type> GetDataTypes() const;
void SetLoDLevel(int32_t lod_level);
void SetLoDLevels(const std::vector<int32_t> &multiple_lod_level);
int32_t GetLoDLevel() const;
std::vector<int32_t> GetLoDLevels() const;
framework::proto::VarType::Type GetType() const;
void SetType(framework::proto::VarType::Type type);
bool Persistable() const { return desc_.persistable(); }
void SetPersistable(bool persistable) { desc_.set_persistable(persistable); }
private:
const framework::proto::VarType::TensorDesc &tensor_desc() const;
std::vector<framework::proto::VarType::TensorDesc> tensor_descs() const;
framework::proto::VarType::TensorDesc *mutable_tensor_desc();
std::vector<framework::proto::VarType::TensorDesc *> mutable_tensor_descs();
framework::proto::VarDesc desc_;
};
} // namespace pb
} // namespace lite
} // namespace paddle
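For reference, this is exactly how Program::PrepareWorkspace above consumes the class: wrap each proto::VarDesc and query Name()/Persistable(). A self-contained sketch (the helper name is illustrative):

#include <string>
#include <vector>
#include "paddle/fluid/lite/model_parser/pb/var_desc.h"

// Collect the names of the non-persistable variables of the first block,
// the same classification PrepareWorkspace performs.
std::vector<std::string> TmpVarNames(
    const paddle::framework::proto::ProgramDesc &program) {
  std::vector<std::string> names;
  for (const auto &proto_var : program.blocks(0).vars()) {
    paddle::lite::pb::VarDesc var(proto_var);
    if (!var.Persistable()) names.push_back(var.Name());
  }
  return names;
}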
@@ -46,8 +46,7 @@ class FcOpLite : public OpLite {
   */
   // TODO(Superjomn) replace framework::OpDesc with a lite one.
-  bool AttachImpl(const framework::OpDesc &op_desc,
-                  lite::Scope *scope) override {
+  bool AttachImpl(const OpDesc &op_desc, lite::Scope *scope) override {
     auto input = op_desc.Input("Input").front();
     auto W = op_desc.Input("W").front();
     auto bias = op_desc.Input("Bias").front();
@@ -58,8 +57,7 @@ class FcOpLite : public OpLite {
     param_.bias = scope->FindVar(bias)->GetMutable<Tensor>();
     CHECK(scope->FindVar(out));
     param_.output = scope->FindVar(out)->GetMutable<Tensor>();
-    param_.in_num_col_dims =
-        boost::get<int>(op_desc.GetAttr("in_num_col_dims"));
+    param_.in_num_col_dims = op_desc.GetAttr("in_num_col_dims").get<int>();
     CHECK(kernel_);
     kernel_->SetParam(param_);
...
@@ -35,8 +35,7 @@ class FeedOp : public OpLite {
   void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
  protected:
-  bool AttachImpl(const framework::OpDesc& opdesc,
-                  lite::Scope* scope) override {
+  bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override {
     auto feed_var_name = opdesc.Input("X").front();
     auto* feed_var = scope->FindVar(feed_var_name);
     CHECK(feed_var);
@@ -50,7 +49,7 @@ class FeedOp : public OpLite {
     // NOTE need boost here
     // TODO(Superjomn) drop the need of framework::op_desc
-    param_.col = boost::get<int>(opdesc.GetAttr("col"));
+    param_.col = opdesc.GetAttr("col").get<int>();
     return true;
   }
...
@@ -33,8 +33,7 @@ class FetchOp : public OpLite {
   void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
  protected:
-  bool AttachImpl(const framework::OpDesc& opdesc,
-                  lite::Scope* scope) override {
+  bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override {
     auto _x = opdesc.Input("X").front();
     auto* x = scope->FindVar(_x);
     CHECK(x);
@@ -44,7 +43,7 @@ class FetchOp : public OpLite {
     auto* out = scope->FindVar(_out);
     param_.fetch_list = out->GetMutable<std::vector<lite::Tensor>>();
-    param_.col = boost::get<int>(opdesc.GetAttr("col"));
+    param_.col = opdesc.GetAttr("col").get<int>();
     return true;
   }
...
@@ -29,8 +29,7 @@ bool IoCopyOp::InferShape() const {
   return true;
 }
 bool IoCopyOp::Run() { return OpLite::Run(); }
-bool IoCopyOp::AttachImpl(const paddle::framework::OpDesc &opdesc,
-                          paddle::lite::Scope *scope) {
+bool IoCopyOp::AttachImpl(const OpDesc &opdesc, paddle::lite::Scope *scope) {
   auto x = opdesc.Input("Input").front();
   auto out = opdesc.Output("Out").front();
   param_.x = GetTensor(scope, x);
...
@@ -31,7 +31,7 @@ class IoCopyOp : public OpLite {
   void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
  protected:
-  bool AttachImpl(const framework::OpDesc &opdesc, lite::Scope *scope) override;
+  bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override;
  private:
   operators::IoCopyParam param_;
...
@@ -38,8 +38,7 @@ class MulOpLite : public OpLite {
   void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
   // TODO(Superjomn) replace framework::OpDesc with a lite one.
-  bool AttachImpl(const framework::OpDesc &op_desc,
-                  lite::Scope *scope) override {
+  bool AttachImpl(const OpDesc &op_desc, lite::Scope *scope) override {
     auto input = op_desc.Input("X").front();
     auto W = op_desc.Input("Y").front();
     auto out = op_desc.Output("Out").front();
@@ -48,8 +47,8 @@ class MulOpLite : public OpLite {
     param_.y = scope->FindVar(W)->GetMutable<Tensor>();
     CHECK(scope->FindVar(out));
     param_.output = scope->FindVar(out)->GetMutable<Tensor>();
-    param_.x_num_col_dims = boost::get<int>(op_desc.GetAttr("x_num_col_dims"));
-    param_.y_num_col_dims = boost::get<int>(op_desc.GetAttr("y_num_col_dims"));
+    param_.x_num_col_dims = op_desc.GetAttr("x_num_col_dims").get<int>();
+    param_.y_num_col_dims = op_desc.GetAttr("y_num_col_dims").get<int>();
     return true;
   }
...
@@ -31,7 +31,7 @@ bool ReluOp::InferShape() const {
   return true;
 }
-bool ReluOp::AttachImpl(const framework::OpDesc &opdesc, lite::Scope *scope) {
+bool ReluOp::AttachImpl(const OpDesc &opdesc, lite::Scope *scope) {
   param_.input = const_cast<Tensor *>(
       &scope->FindVar(opdesc.Input("Input").front())->Get<Tensor>());
   param_.output =
...
@@ -32,7 +32,7 @@ class ReluOp : public OpLite {
   bool InferShape() const override;
-  bool AttachImpl(const framework::OpDesc &opdesc, lite::Scope *scope) override;
+  bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override;
   void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
   std::string DebugString() const override { return "tanh"; }
...
@@ -46,18 +46,16 @@ class ScaleOp : public OpLite {
   void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
   // TODO(Superjomn) replace framework::OpDesc with a lite one.
-  bool AttachImpl(const framework::OpDesc &op_desc,
-                  lite::Scope *scope) override {
+  bool AttachImpl(const OpDesc &op_desc, lite::Scope *scope) override {
     auto x = op_desc.Input("X").front();
     auto out = op_desc.Output("Out").front();
     param_.x = scope->FindVar(x)->GetMutable<Tensor>();
     CHECK(scope->FindVar(out));
     param_.output = scope->FindVar(out)->GetMutable<Tensor>();
-    param_.scale = boost::get<float>(op_desc.GetAttr("scale"));
-    param_.bias = boost::get<float>(op_desc.GetAttr("bias"));
-    param_.bias_after_scale =
-        boost::get<bool>(op_desc.GetAttr("bias_after_scale"));
+    param_.scale = op_desc.GetAttr("scale").get<float>();
+    param_.bias = op_desc.GetAttr("bias").get<float>();
+    param_.bias_after_scale = op_desc.GetAttr("bias_after_scale").get<bool>();
     CHECK(kernel_);
     kernel_->SetParam(param_);
...
@@ -114,7 +114,8 @@ struct variant {
     if (type_id == typeid(T).hash_code())
       return *reinterpret_cast<T*>(&data);
     else
-      throw std::bad_cast();
+      LOG(FATAL) << "unmatched type get, should be " << type_id << " but get "
+                 << typeid(T).name();
   }
   ~variant() { helper_t::destroy(type_id, &data); }
 };
...
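The change above replaces the only throw in the lite utils with a fatal log, which suits exception-free mobile builds. A hedged sketch of the resulting contract (assuming the variant template lives in paddle::lite, as the lite/utils include suggests):

#include "paddle/fluid/lite/utils/all.h"

void VariantGetContract() {
  paddle::lite::variant<int, float> v;
  v.set<int>(3);
  int ok = v.get<int>();  // stored type matches: returns 3
  (void)ok;
  // v.get<float>();      // mismatched type: now LOG(FATAL)s instead of throwing
}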