From 70540d1b0055ab3c241a116a4ab619d9f6958902 Mon Sep 17 00:00:00 2001
From: superjomn
Date: Mon, 29 Apr 2019 16:55:26 +0800
Subject: [PATCH] add a new lightweight OpDesc compatible with the original
 framework::OpDesc to support mobile

---
 CMakeLists.txt                                |   3 +-
 cmake/configure.cmake                         |   4 +
 paddle/fluid/lite/api/CMakeLists.txt          |   6 +-
 paddle/fluid/lite/api/cxx_api.h               |   3 +-
 paddle/fluid/lite/core/CMakeLists.txt         |   4 +-
 .../fluid/lite/core/mir/io_complement_pass.cc |  18 +-
 paddle/fluid/lite/core/op_executor.h          |   2 +
 paddle/fluid/lite/core/op_lite.cc             |   2 +-
 paddle/fluid/lite/core/op_lite.h              |   8 +-
 paddle/fluid/lite/core/program.h              |  28 +-
 paddle/fluid/lite/model_parser/CMakeLists.txt |   8 +
 .../fluid/lite/model_parser/compatible_pb.cc  |  15 +
 .../fluid/lite/model_parser/compatible_pb.h   |  45 +++
 .../fluid/lite/model_parser/pb/CMakeLists.txt |   2 +
 .../fluid/lite/model_parser/pb/block_desc.cc  |  13 +
 .../fluid/lite/model_parser/pb/block_desc.h   | 123 ++++++++
 paddle/fluid/lite/model_parser/pb/op_desc.cc  |  15 +
 paddle/fluid/lite/model_parser/pb/op_desc.h   | 234 +++++++++++++++
 .../lite/model_parser/pb/program_desc.cc      |  13 +
 .../fluid/lite/model_parser/pb/program_desc.h |  13 +
 paddle/fluid/lite/model_parser/pb/var_desc.cc | 271 ++++++++++++++++++
 paddle/fluid/lite/model_parser/pb/var_desc.h  | 123 ++++++++
 paddle/fluid/lite/operators/fc_op.h           |   6 +-
 paddle/fluid/lite/operators/feed_op.cc        |   5 +-
 paddle/fluid/lite/operators/fetch_op.cc       |   5 +-
 paddle/fluid/lite/operators/io_copy_op.cc     |   3 +-
 paddle/fluid/lite/operators/io_copy_op.h      |   2 +-
 paddle/fluid/lite/operators/mul_op.h          |   7 +-
 paddle/fluid/lite/operators/relu_op.cc        |   2 +-
 paddle/fluid/lite/operators/relu_op.h         |   2 +-
 paddle/fluid/lite/operators/scale_op.cc       |  10 +-
 paddle/fluid/lite/utils/varient.h             |   3 +-
 32 files changed, 929 insertions(+), 69 deletions(-)
 create mode 100644 paddle/fluid/lite/model_parser/compatible_pb.cc
 create mode 100644 paddle/fluid/lite/model_parser/compatible_pb.h
 create mode 100644 paddle/fluid/lite/model_parser/pb/CMakeLists.txt

diff --git a/CMakeLists.txt b/CMakeLists.txt
index a343a65912f..f088d872cfc 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -186,7 +186,8 @@ endif()
 
 # for lite
 option(LITE_WITH_CUDA "Enable CUDA in lite mode" ON)
-option(LITE_WITH_X86 "Enable X86 in lite mode" ON)
+option(LITE_WITH_X86 "Enable X86 in lite mode" ON)
+option(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK "Enable light-weight framework" ON)
 
 include(external/threadpool)
 include(flags) # set paddle compile flags
diff --git a/cmake/configure.cmake b/cmake/configure.cmake
index f859fd10a75..6c9c3fd4889 100644
--- a/cmake/configure.cmake
+++ b/cmake/configure.cmake
@@ -171,3 +171,7 @@ endif()
 if (LITE_WITH_X86)
     add_definitions("-DLITE_WITH_X86")
 endif()
+
+if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
+    add_definitions("-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK")
+endif()
diff --git a/paddle/fluid/lite/api/CMakeLists.txt b/paddle/fluid/lite/api/CMakeLists.txt
index ec0aab9063f..56262c92129 100644
--- a/paddle/fluid/lite/api/CMakeLists.txt
+++ b/paddle/fluid/lite/api/CMakeLists.txt
@@ -1,7 +1,7 @@
-cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host)
-cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite)
-
 if(LITE_WITH_CUDA)
     cc_library(cxx_api_lite_cuda SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host target_wrapper_cuda kernels_cuda)
     nv_test(test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda model_parser_lite)
+else()
+    cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host)
+    cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite target_wrapper_host host_kernels)
 endif()
diff --git a/paddle/fluid/lite/api/cxx_api.h b/paddle/fluid/lite/api/cxx_api.h
index ea577b7211e..ed2654c02a5 100644
--- a/paddle/fluid/lite/api/cxx_api.h
+++ b/paddle/fluid/lite/api/cxx_api.h
@@ -33,9 +33,8 @@ class Predictor {
              const std::vector<Place>& valid_places) {
     framework::proto::ProgramDesc prog;
     LoadModel(model_path, scope_.get(), &prog);
-    framework::ProgramDesc prog_desc(prog);
-    Program program(prog_desc, scope_, valid_places);
+    Program program(prog, scope_, valid_places);
 
     Optimizer optimizer;
     optimizer.KernelPickPreferPlace(prefer_place);
diff --git a/paddle/fluid/lite/core/CMakeLists.txt b/paddle/fluid/lite/core/CMakeLists.txt
index 2755baf1048..7b872d85593 100644
--- a/paddle/fluid/lite/core/CMakeLists.txt
+++ b/paddle/fluid/lite/core/CMakeLists.txt
@@ -5,10 +5,10 @@ cc_library(kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite)
 cc_library(variable_lite SRCS variable.cc)
 cc_library(op_registry_lite SRCS op_registry.cc)
 cc_library(scope_lite SRCS scope.cc)
-cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite)
+cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite compatible_pb_lite)
 cc_library(op_executor_lite SRCS op_executor.cc DEPS scope_lite tensor_lite op_lite op_registry_lite
     #TODO(Superjomn) remove these dependencies from original framework
-    proto_desc)
+    )
 cc_library(kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite)
 cc_library(types_lite SRCS types.cc)
 cc_library(type_system SRCS type_system.cc DEPS tensor_lite)
diff --git a/paddle/fluid/lite/core/mir/io_complement_pass.cc b/paddle/fluid/lite/core/mir/io_complement_pass.cc
index d1c7f73a749..4adba67673e 100644
--- a/paddle/fluid/lite/core/mir/io_complement_pass.cc
+++ b/paddle/fluid/lite/core/mir/io_complement_pass.cc
@@ -65,19 +65,6 @@ void IoComplementPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
   }
 }
 
-void UpdateOpdescInputName(framework::OpDesc* desc,
-                           const std::string& old_arg_name,
-                           const std::string& new_arg_name) {
-  for (auto& item : *desc->Proto()->mutable_inputs()) {
-    for (int i = 0; i < item.mutable_arguments()->size(); i++) {
-      auto* arg = item.mutable_arguments(i);
-      if (*arg == old_arg_name) {
-        *arg = new_arg_name;
-      }
-    }
-  }
-}
-
 void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to,
                                      const std::string& var, SSAGraph* graph,
                                      Node* inst_node,
@@ -99,11 +86,10 @@ void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to,
   inst_node->AsInstruct().op->scope()->Var(io_copy_output_name);
 
   // Create IoCopy Instruction.
-  framework::OpDesc op_desc;
+  lite::OpDesc op_desc;
   op_desc.SetType("io_copy");
   op_desc.SetInput("Input", {var});
   op_desc.SetOutput("Out", {io_copy_output_name});
-  op_desc.Flush();
 
   io_copy_op->Attach(op_desc, inst_node->AsInstruct().op->scope());
   auto kernels = io_copy_op->CreateKernels(valid_places);
@@ -126,7 +112,7 @@ void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to,
   auto desc_dummy = inst_node->AsInstruct().op->op_info()->desc();
   UpdateInputTo(&desc_dummy, var, io_copy_output_name);
 
-  framework::OpDesc desc_fake(desc_dummy, nullptr);
+  lite::OpDesc desc_fake(desc_dummy);
   inst_node->AsInstruct().op->Attach(desc_fake,
                                      inst_node->AsInstruct().op->scope());
diff --git a/paddle/fluid/lite/core/op_executor.h b/paddle/fluid/lite/core/op_executor.h
index d5e63a5c8d0..eb5e0a1d1be 100644
--- a/paddle/fluid/lite/core/op_executor.h
+++ b/paddle/fluid/lite/core/op_executor.h
@@ -23,6 +23,7 @@
 namespace paddle {
 namespace lite {
 
+/*
 // The Executor is used to run the operators.
 class Executor {
  public:
@@ -63,6 +64,7 @@ class RuntimeExecutor {
  private:
   RuntimeProgram* program_{};
 };
+ */
 
 }  // namespace lite
 }  // namespace paddle
diff --git a/paddle/fluid/lite/core/op_lite.cc b/paddle/fluid/lite/core/op_lite.cc
index f214155981f..d189fb15d99 100644
--- a/paddle/fluid/lite/core/op_lite.cc
+++ b/paddle/fluid/lite/core/op_lite.cc
@@ -76,7 +76,7 @@ bool OpLite::Run() {
   return true;
 }
 
-bool OpLite::Attach(const framework::OpDesc &opdesc, lite::Scope *scope) {
+bool OpLite::Attach(const OpDesc &opdesc, lite::Scope *scope) {
   CHECK(scope);
   scope_ = scope;
   op_info_.reset(new OpInfo);  // Force clean the out-of-date information.
diff --git a/paddle/fluid/lite/core/op_lite.h b/paddle/fluid/lite/core/op_lite.h
index 2092703d33a..85c9cec7fb8 100644
--- a/paddle/fluid/lite/core/op_lite.h
+++ b/paddle/fluid/lite/core/op_lite.h
@@ -19,12 +19,11 @@
 #include <memory>
 #include <string>
 #include <vector>
-#include "paddle/fluid/framework/lod_tensor.h"
-#include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/variable.h"
 #include "paddle/fluid/lite/core/context.h"
 #include "paddle/fluid/lite/core/kernel.h"
 #include "paddle/fluid/lite/core/scope.h"
+#include "paddle/fluid/lite/model_parser/compatible_pb.h"
 
 namespace paddle {
 namespace lite {
@@ -82,7 +81,7 @@ class OpLite : public Registry {
   virtual bool Run();
 
   // Link the external execution environment to the internal context.
-  bool Attach(const framework::OpDesc &opdesc, lite::Scope *scope);
+  bool Attach(const OpDesc &opdesc, lite::Scope *scope);
 
   const OpInfo *op_info() const { return op_info_.get(); }
   OpInfo *mutable_op_info() { return op_info_.get(); }
@@ -109,8 +108,7 @@ class OpLite : public Registry {
 
  protected:
   // Attach it with the runtime environment.
-  virtual bool AttachImpl(const framework::OpDesc &opdesc,
-                          lite::Scope *scope) = 0;
+  virtual bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) = 0;
 
   // Specify the kernel to run by default. This will specify the value of
   // `kernel_place_`.
diff --git a/paddle/fluid/lite/core/program.h b/paddle/fluid/lite/core/program.h
index 91b78981972..43b2f535f17 100644
--- a/paddle/fluid/lite/core/program.h
+++ b/paddle/fluid/lite/core/program.h
@@ -38,10 +38,10 @@ struct Program {
   std::vector<Place> valid_places;
   // Runtime scope.
   lite::Scope* exec_scope{};
-  const framework::ProgramDesc desc;
+  const framework::proto::ProgramDesc desc;
 
   explicit Program(const std::shared_ptr<Scope>& root) { scope = root; }
-  Program(const framework::ProgramDesc& desc,
+  Program(const framework::proto::ProgramDesc& desc,
           const std::shared_ptr<Scope>& root,
           const std::vector<Place>& valid_places)
       : scope(root), valid_places(valid_places), desc(desc) {
@@ -56,24 +56,25 @@ struct Program {
 
  private:
   // Build from a program and scope.
-  void Build(const framework::ProgramDesc& program,
+  void Build(const framework::proto::ProgramDesc& program,
              const std::vector<Place>& valid_places) {
     CHECK(ops.empty()) << "Executor duplicate Build found";
     // Create operators.
-    for (auto* op_desc : program.Block(0).AllOps()) {
-      auto op_type = op_desc->Type();
+    for (const auto& proto_op_desc : program.blocks(0).ops()) {
+      lite::OpDesc op_desc(proto_op_desc);
+      auto op_type = op_desc.Type();
       // if (op_type == "feed" || op_type == "fetch") continue;
       VLOG(4) << "create Op [" << op_type << "]";
       ops.emplace_back(LiteOpRegistry::Global().Create(op_type));
       // pick initial kernel
       ops.back()->PickKernel(valid_places);
-      ops.back()->Attach(*op_desc, exec_scope);
+      ops.back()->Attach(op_desc, exec_scope);
     }
   }
 
   // Create temporary variables.
-  void PrepareWorkspace(const framework::ProgramDesc& program) {
+  void PrepareWorkspace(const framework::proto::ProgramDesc& program) {
     CHECK(!exec_scope) << "Duplicate PrepareWorkspace found";
     exec_scope = &scope->NewScope();
     // Create Feed and Fetch var.
@@ -82,13 +83,14 @@ struct Program {
     tmp_vars.push_back("feed");
     tmp_vars.push_back("fetch");
 
-    for (auto var_desc : program.Block(0).AllVars()) {
-      if (!var_desc->Persistable()) {
-        tmp_vars.push_back(var_desc->Name());
-        exec_scope->Var(var_desc->Name());
+    for (auto proto_var_desc : program.blocks(0).vars()) {
+      lite::VarDesc var_desc(proto_var_desc);
+      if (!var_desc.Persistable()) {
+        tmp_vars.push_back(var_desc.Name());
+        exec_scope->Var(var_desc.Name());
       } else {
-        if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") continue;
-        weights.push_back(var_desc->Name());
+        if (var_desc.Name() == "feed" || var_desc.Name() == "fetch") continue;
+        weights.push_back(var_desc.Name());
       }
     }
   }
diff --git a/paddle/fluid/lite/model_parser/CMakeLists.txt b/paddle/fluid/lite/model_parser/CMakeLists.txt
index 36044b9ad9f..18d4f151789 100644
--- a/paddle/fluid/lite/model_parser/CMakeLists.txt
+++ b/paddle/fluid/lite/model_parser/CMakeLists.txt
@@ -1,3 +1,11 @@
 cc_library(model_parser_lite SRCS model_parser.cc DEPS variable_lite scope_lite tensor_lite)
 cc_library(runtime_lite SRCS runtime.cc)
 cc_test(test_model_parser_lite SRCS model_parser_test.cc DEPS model_parser_lite)
+if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
+    cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite var_desc_lite)
+else()
+    cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto)
+endif(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
+
+
+add_subdirectory(pb)
diff --git a/paddle/fluid/lite/model_parser/compatible_pb.cc b/paddle/fluid/lite/model_parser/compatible_pb.cc
new file mode 100644
index 00000000000..ee0f7c41acc
--- /dev/null
+++ b/paddle/fluid/lite/model_parser/compatible_pb.cc
@@ -0,0 +1,15 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/model_parser/compatible_pb.h"
diff --git a/paddle/fluid/lite/model_parser/compatible_pb.h b/paddle/fluid/lite/model_parser/compatible_pb.h
new file mode 100644
index 00000000000..c77d180031d
--- /dev/null
+++ b/paddle/fluid/lite/model_parser/compatible_pb.h
@@ -0,0 +1,45 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+/*
+ * This file defines the interface used to manipulate the protobuf message.
+ * It uses conditional compilation and type aliases to present one interface
+ * that is compatible with both framework::XXDesc and lite::pb::XXDesc.
+ */
+
+#include "paddle/fluid/framework/framework.pb.h"
+#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
+#include "paddle/fluid/lite/model_parser/pb/op_desc.h"
+#include "paddle/fluid/lite/model_parser/pb/var_desc.h"
+#else
+#include "paddle/fluid/framework/op_desc.h"
+#include "paddle/fluid/framework/var_desc.h"
+#endif  // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
+
+namespace paddle {
+namespace lite {
+
+#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
+using OpDesc = lite::pb::OpDesc;
+using VarDesc = lite::pb::VarDesc;
+#else   // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
+using Attribute = framework::Attribute;
+using OpDesc = framework::OpDesc;
+using VarDesc = framework::VarDesc;
+#endif  // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
+
+}  // namespace lite
+}  // namespace paddle
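[Annotation, not part of the patch: a minimal sketch of how code written against the alias layer above compiles in both build modes. The "mul" op and its argument names are made up for the example.]

#include "paddle/fluid/lite/model_parser/compatible_pb.h"

namespace paddle {
namespace lite {

// Builds a description of a hypothetical "mul" op. Under
// LITE_WITH_LIGHT_WEIGHT_FRAMEWORK this manipulates lite::pb::OpDesc;
// otherwise it manipulates framework::OpDesc, with no source change.
OpDesc MakeMulDesc() {
  OpDesc desc;
  desc.SetType("mul");
  desc.SetInput("X", {"x"});
  desc.SetInput("Y", {"w"});
  desc.SetOutput("Out", {"out"});
  desc.SetAttr("x_num_col_dims", 1);
  return desc;
}

}  // namespace lite
}  // namespace paddle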
diff --git a/paddle/fluid/lite/model_parser/pb/CMakeLists.txt b/paddle/fluid/lite/model_parser/pb/CMakeLists.txt
new file mode 100644
index 00000000000..d0f1af4cad2
--- /dev/null
+++ b/paddle/fluid/lite/model_parser/pb/CMakeLists.txt
@@ -0,0 +1,2 @@
+cc_library(var_desc_lite SRCS var_desc.cc DEPS framework_proto)
+cc_library(op_desc_lite SRCS op_desc.cc DEPS framework_proto)
diff --git a/paddle/fluid/lite/model_parser/pb/block_desc.cc b/paddle/fluid/lite/model_parser/pb/block_desc.cc
index e69de29bb2d..ce71e4de2b8 100644
--- a/paddle/fluid/lite/model_parser/pb/block_desc.cc
+++ b/paddle/fluid/lite/model_parser/pb/block_desc.cc
@@ -0,0 +1,13 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
diff --git a/paddle/fluid/lite/model_parser/pb/block_desc.h b/paddle/fluid/lite/model_parser/pb/block_desc.h
index e69de29bb2d..e45fbe6ecd0 100644
--- a/paddle/fluid/lite/model_parser/pb/block_desc.h
+++ b/paddle/fluid/lite/model_parser/pb/block_desc.h
@@ -0,0 +1,123 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <deque>
+#include <memory>
+#include <set>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "paddle/fluid/framework/op_desc.h"
+#include "paddle/fluid/framework/proto_desc.h"
+#include "paddle/fluid/framework/var_desc.h"
+#include "paddle/fluid/platform/macros.h"
+
+namespace paddle {
+namespace lite {
+
+class ProgramDesc;
+
+// For each protobuf message we provide an XXXBind class that optimizes
+// read/write speed. Local changes are synchronized back to the protobuf
+// message (by the `Sync` method) only when the message is requested.
+
+class BlockDesc {
+ public:
+  BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc);
+
+  BlockDesc(const BlockDesc &other, proto::BlockDesc *desc, ProgramDesc *prog);
+
+  int32_t ID() const { return desc_->idx(); }
+
+  int32_t Parent() const { return desc_->parent_idx(); }
+
+  int32_t ForwardBlockID() const { return desc_->forward_block_idx(); }
+
+  VarDesc *Var(const std::string &name_bytes);
+
+  VarDesc *FindVar(const std::string &name_bytes) const;
+
+  bool HasVar(const std::string &var_name) const;
+
+  VarDesc *RenameVar(const std::string &old_name, const std::string &new_name);
+
+  VarDesc *FindVarRecursive(const std::string &name_bytes) const;
+
+  VarDesc &FindRecursiveOrCreateVar(const std::string &name_bytes);
+
+  bool HasVarRecursive(const std::string &var_name) const;
+
+  std::set<std::string> LocalVarNames() const {
+    std::set<std::string> var_names;
+    for (auto &var : vars_) {
+      var_names.insert(var.first);
+    }
+    return var_names;
+  }
+
+  std::vector<VarDesc *> AllVars() const;
+
+  BlockDesc *ParentBlock() const;
+
+  BlockDesc *ForwardBlock() const;
+
+  void SetForwardBlockID(int32_t forward_block_id);
+
+  OpDesc *AppendOp();
+
+  void AppendAllocatedOp(std::unique_ptr<OpDesc> &&op_desc);
+
+  OpDesc *PrependOp();
+
+  void PrependAllocatedOp(std::unique_ptr<OpDesc> &&op_desc);
+
+  OpDesc *InsertOp(size_t index);
+
+  /*
+   * Only remove the op itself; do nothing to its input and output variables.
+   */
+  void RemoveOp(size_t s, size_t e);
+
+  void RemoveOpInternal(const OpDesc *op_desc);
+
+  void RemoveVar(const std::string &name) { vars_.erase(name); }
+
+  std::vector<OpDesc *> AllOps() const;
+
+  size_t OpSize() const { return ops_.size(); }
+
+  OpDesc *Op(int idx) const { return ops_.at(idx).get(); }
+
+  void Flush();
+
+  proto::BlockDesc *Proto();
+
+  ProgramDesc *Program() const { return this->prog_; }
+
+ private:
+  ProgramDesc *prog_;       // not_own
+  proto::BlockDesc *desc_;  // not_own
+  bool need_update_;
+
+  std::deque<std::unique_ptr<OpDesc>> ops_;
+  std::unordered_map<std::string, std::unique_ptr<VarDesc>> vars_;
+
+  DISABLE_COPY_AND_ASSIGN(BlockDesc);
+};
+
+}  // namespace lite
+}  // namespace paddle
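[Annotation, not part of the patch: block_desc.cc carries only the license header here, so the lazy synchronization described in the comment above is not yet implemented. A sketch of the intended pattern, mirroring framework::BlockDesc; the bodies below are illustrative, not the patch's code.]

// Reads force a synchronization; writes only mark the bind dirty.
proto::BlockDesc *BlockDesc::Proto() {
  Flush();  // push local ops_/vars_ into desc_ on demand
  return desc_;
}

void BlockDesc::Flush() {
  if (!need_update_) return;  // fast path: nothing changed locally
  desc_->clear_ops();
  for (auto &op : ops_) {
    // copy each op bind back into the repeated proto field
    *desc_->add_ops() = *op->Proto();
  }
  need_update_ = false;
}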
diff --git a/paddle/fluid/lite/model_parser/pb/op_desc.cc b/paddle/fluid/lite/model_parser/pb/op_desc.cc
index e69de29bb2d..c546eccc926 100644
--- a/paddle/fluid/lite/model_parser/pb/op_desc.cc
+++ b/paddle/fluid/lite/model_parser/pb/op_desc.cc
@@ -0,0 +1,15 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/model_parser/pb/op_desc.h"
diff --git a/paddle/fluid/lite/model_parser/pb/op_desc.h b/paddle/fluid/lite/model_parser/pb/op_desc.h
index e69de29bb2d..7b1c362a125 100644
--- a/paddle/fluid/lite/model_parser/pb/op_desc.h
+++ b/paddle/fluid/lite/model_parser/pb/op_desc.h
@@ -0,0 +1,234 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+/*
+ * This file implements a light-weight OpDesc, similar to framework::OpDesc.
+ * We removed the unnecessary methods and the underlying dependencies, such
+ * as framework::Operator and boost::variant, to make it usable on mobile.
+ */
+
+#include <algorithm>
+#include <map>
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/framework.pb.h"
+#include "paddle/fluid/lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace pb {
+
+using Attribute = variant<int, float, bool, std::string, std::vector<std::string>>;
+using VariableNameMap = std::map<std::string, std::vector<std::string>>;
+
+/*
+ * lite::pb::OpDesc is a light-weight wrapper around proto::OpDesc. Unlike
+ * framework::OpDesc, it keeps no local state besides desc_, which avoids
+ * the inconsistent states the original interface allows and that lead to
+ * bugs.
+ */ +class OpDesc { + public: + OpDesc() {} + + OpDesc(const framework::proto::OpDesc &desc) : desc_(desc) {} + + void CopyFrom(const OpDesc &op_desc) { desc_ = op_desc.ReadonlyProto(); } + + framework::proto::OpDesc *Proto() { return &desc_; } + const framework::proto::OpDesc &ReadonlyProto() const { return desc_; } + + std::string Type() const { return desc_.type(); } + + void SetType(const std::string &type) { desc_.set_type(type); } + + // Get the arguments of parameter called `param` + std::vector Input(const std::string ¶m) const { + return GetArguments(desc_.inputs(), param); + } + + std::vector InputArgumentNames() const { + return GetArgumentNames(desc_.inputs()); + } + + void SetInput(const std::string ¶m, + const std::vector &args) { + SetArgument(desc_.mutable_inputs(), param, args); + } + + std::vector Output(const std::string ¶m) const { + return GetArguments(desc_.outputs(), param); + } + + std::vector OutputArgumentNames() const { + return GetArgumentNames(desc_.outputs()); + } + + void SetOutput(const std::string ¶m, + const std::vector &args) { + SetArgument(desc_.mutable_outputs(), param, args); + } + + bool HasAttr(const std::string &name) const { + const auto &xs = desc_.attrs(); + auto it = std::find_if(xs.begin(), xs.end(), + [&](const framework::proto::OpDesc_Attr &x) { + return x.name() == name; + }); + return it != xs.end(); + } + + framework::proto::AttrType GetAttrType(const std::string &name) const { + const auto &xs = desc_.attrs(); + auto it = std::find_if(xs.begin(), xs.end(), + [&](const framework::proto::OpDesc_Attr &x) { + return x.name() == name; + }); + CHECK(it != xs.end()); + return it->type(); + } + + std::vector AttrNames() const { + std::vector res; + const auto &xs = desc_.attrs(); + std::transform( + xs.begin(), xs.end(), std::back_inserter(res), + [](const framework::proto::OpDesc_Attr &x) { return x.name(); }); + return res; + } + + template + void SetAttr(const std::string &name, const T &v) { + auto &xs = *desc_.mutable_attrs(); + auto it = std::find_if(xs.begin(), xs.end(), + [&](const framework::proto::OpDesc_Attr &x) { + return x.name() == name; + }); + if (it == xs.end()) { + auto *attr = xs.Add(); + attr->set_name(name); + it = std::find(xs.begin(), xs.end(), name); + } + + switch (typeid(T).hash_code()) { + case typeid(int).hash_code(): + it->set_type(framework::proto::INT); + it->set_i(v); + break; + case typeid(float).hash_code(): + it->set_type(framework::proto::FLOAT); + it->set_f(v); + break; + case typeid(std::string).hash_code(): + it->set_type(framework::proto::STRING); + it->set_s(v.c_str()); + break; + case typeid(std::string).hash_code(): + it->set_type(framework::proto::BOOLEAN); + it->set_b(v); + break; + default: + LOG(FATAL) << "unsupport attr type"; + } + } + + Attribute GetAttr(const std::string &name) const { + auto &xs = desc_.attrs(); + auto it = std::find_if(xs.begin(), xs.end(), + [&](const framework::proto::OpDesc_Attr &x) { + return x.name() == name; + }); + + Attribute res; + CHECK(it != xs.end()); + + switch (it->type()) { + case framework::proto::INT: + res.set(it->i()); + break; + case framework::proto::FLOAT: + res.set(it->f()); + break; + case framework::proto::STRING: + res.set(it->s()); + break; + case framework::proto::BOOLEAN: + res.set(it->b()); + break; + + default: + LOG(FATAL) << "unsupported attr type"; + } + + return res; + } + + private: + std::vector GetArguments( + const google::protobuf::RepeatedPtrField + &xs, + const std::string ¶m) const { + std::vector res; + auto it = 
+  std::vector<std::string> GetArguments(
+      const google::protobuf::RepeatedPtrField<framework::proto::OpDesc_Var>
+          &xs,
+      const std::string &param) const {
+    std::vector<std::string> res;
+    auto it = std::find_if(xs.begin(), xs.end(),
+                           [&](const framework::proto::OpDesc_Var &it) {
+                             return it.parameter() == param;
+                           });
+    CHECK(it != xs.end());
+
+    const auto &ys = it->arguments();
+    std::transform(ys.begin(), ys.end(), std::back_inserter(res),
+                   [](const std::string &x) { return x; });
+    return res;
+  }
+
+  void SetArgument(
+      google::protobuf::RepeatedPtrField<framework::proto::OpDesc_Var> *xs,
+      const std::string &param, const std::vector<std::string> &args) {
+    auto it = std::find_if(xs->begin(), xs->end(),
+                           [&](const framework::proto::OpDesc_Var &it) {
+                             return it.parameter() == param;
+                           });
+    if (it == xs->end()) {
+      auto *new_arg = xs->Add();
+      new_arg->set_parameter(param);
+      for (const auto &arg : args) {
+        *new_arg->mutable_arguments()->Add() = arg;
+      }
+    } else {
+      it->mutable_arguments()->Clear();
+      for (const auto &arg : args) {
+        *it->mutable_arguments()->Add() = arg;
+      }
+    }
+  }
+
+  std::vector<std::string> GetArgumentNames(
+      const google::protobuf::RepeatedPtrField<framework::proto::OpDesc_Var>
+          &xs) const {
+    std::vector<std::string> res;
+    std::transform(
+        xs.begin(), xs.end(), std::back_inserter(res),
+        [](const framework::proto::OpDesc_Var &x) { return x.parameter(); });
+    return res;
+  }
+
+ private:
+  framework::proto::OpDesc desc_;
+};
+
+}  // namespace pb
+}  // namespace lite
+}  // namespace paddle
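[Annotation, not part of the patch: a round trip through the new class; the op and attribute names below are made up.]

#include <glog/logging.h>
#include "paddle/fluid/lite/model_parser/pb/op_desc.h"

void OpDescRoundTrip() {
  paddle::lite::pb::OpDesc desc;
  desc.SetType("scale");
  desc.SetInput("X", {"x0"});
  desc.SetOutput("Out", {"out0"});
  desc.SetAttr("scale", 2.0f);

  // Reads and writes go straight to the wrapped proto::OpDesc, so no
  // Flush()/Sync() step is needed, unlike framework::OpDesc.
  CHECK(desc.HasAttr("scale"));
  CHECK_EQ(desc.Input("X").front(), "x0");
  float scale = desc.GetAttr("scale").get<float>();
  (void)scale;
}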
diff --git a/paddle/fluid/lite/model_parser/pb/program_desc.cc b/paddle/fluid/lite/model_parser/pb/program_desc.cc
index e69de29bb2d..ce71e4de2b8 100644
--- a/paddle/fluid/lite/model_parser/pb/program_desc.cc
+++ b/paddle/fluid/lite/model_parser/pb/program_desc.cc
@@ -0,0 +1,13 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
diff --git a/paddle/fluid/lite/model_parser/pb/program_desc.h b/paddle/fluid/lite/model_parser/pb/program_desc.h
index e69de29bb2d..ce71e4de2b8 100644
--- a/paddle/fluid/lite/model_parser/pb/program_desc.h
+++ b/paddle/fluid/lite/model_parser/pb/program_desc.h
@@ -0,0 +1,13 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
diff --git a/paddle/fluid/lite/model_parser/pb/var_desc.cc b/paddle/fluid/lite/model_parser/pb/var_desc.cc
index e69de29bb2d..2aaf5ee14d0 100644
--- a/paddle/fluid/lite/model_parser/pb/var_desc.cc
+++ b/paddle/fluid/lite/model_parser/pb/var_desc.cc
@@ -0,0 +1,271 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/model_parser/pb/var_desc.h"
+
+namespace paddle {
+namespace lite {
+namespace pb {
+
+using namespace framework;
+
+proto::VarType::Type VarDesc::GetType() const { return desc_.type().type(); }
+
+void VarDesc::SetType(proto::VarType::Type type) {
+  desc_.mutable_type()->set_type(type);
+}
+
+void VarDesc::SetShape(const std::vector<int64_t> &dims) {
+  VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
+}
+
+void VarDesc::SetTensorDescNum(size_t num) {
+  switch (desc_.type().type()) {
+    case proto::VarType::READER: {
+      auto *lod_tensors_ptr =
+          desc_.mutable_type()->mutable_reader()->mutable_lod_tensor();
+      lod_tensors_ptr->Clear();
+      for (size_t i = 0; i < num; ++i) {
+        lod_tensors_ptr->Add();
+      }
+      return;
+    } break;
+    default:
+      LOG(FATAL) << "Setting 'sub_tensor_number' is not supported by the "
+                    "type of var "
+                 << this->Name();
+  }
+}
+
+size_t VarDesc::GetTensorDescNum() const {
+  switch (desc_.type().type()) {
+    case proto::VarType::READER:
+      return desc_.type().reader().lod_tensor_size();
+      break;
+    default:
+      LOG(FATAL) << "Getting 'sub_tensor_number' is not supported by the "
+                    "type of var "
+                 << this->Name();
+  }
+}
+
+void VarDesc::SetShapes(
+    const std::vector<std::vector<int64_t>> &multiple_dims) {
+  if (multiple_dims.size() != GetTensorDescNum()) {
+    VLOG(3) << "WARNING: The number of given shapes(" << multiple_dims.size()
+            << ") doesn't match the existing tensor number("
+            << GetTensorDescNum()
+            << "). The Reader is going to be reinitialized.";
+    SetTensorDescNum(multiple_dims.size());
+  }
+  std::vector<proto::VarType::TensorDesc *> tensors = mutable_tensor_descs();
+  for (size_t i = 0; i < multiple_dims.size(); ++i) {
+    VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims());
+  }
+}
+
+std::vector<int64_t> VarDesc::GetShape() const {
+  return RepeatedToVector(tensor_desc().dims());
+}
+
+std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
+  std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
+  std::vector<std::vector<int64_t>> res;
+  res.reserve(descs.size());
+  for (const auto &tensor_desc : descs) {
+    res.push_back(RepeatedToVector(tensor_desc.dims()));
+  }
+  return res;
+}
+
+void VarDesc::SetDataType(proto::VarType::Type data_type) {
+  mutable_tensor_desc()->set_data_type(data_type);
+}
+
+void VarDesc::SetDataTypes(
+    const std::vector<proto::VarType::Type> &multiple_data_type) {
+  if (multiple_data_type.size() != GetTensorDescNum()) {
+    VLOG(3) << "WARNING: The number of given data types("
+            << multiple_data_type.size()
+            << ") doesn't match the existing tensor number("
+            << GetTensorDescNum()
+            << "). The Reader is going to be reinitialized.";
+    SetTensorDescNum(multiple_data_type.size());
+  }
+  std::vector<proto::VarType::TensorDesc *> tensor_descs =
+      mutable_tensor_descs();
+  for (size_t i = 0; i < multiple_data_type.size(); ++i) {
+    tensor_descs[i]->set_data_type(multiple_data_type[i]);
+  }
+}
+
+proto::VarType::Type VarDesc::GetDataType() const {
+  return tensor_desc().data_type();
+}
+
+std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
+  std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
+  std::vector<proto::VarType::Type> res;
+  res.reserve(descs.size());
+  for (const auto &tensor_desc : descs) {
+    res.push_back(tensor_desc.data_type());
+  }
+  return res;
+}
+
+void VarDesc::SetLoDLevel(int32_t lod_level) {
+  switch (desc_.type().type()) {
+    case proto::VarType::LOD_TENSOR:
+      desc_.mutable_type()->mutable_lod_tensor()->set_lod_level(lod_level);
+      break;
+    case proto::VarType::LOD_TENSOR_ARRAY:
+      desc_.mutable_type()->mutable_tensor_array()->set_lod_level(lod_level);
+      break;
+    default:
+      LOG(FATAL) << "Setting 'lod_level' is not supported by the type of var "
+                 << this->Name();
+  }
+}
+
+void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
+  if (multiple_lod_level.size() != GetTensorDescNum()) {
+    VLOG(3) << "WARNING: The number of given lod_levels("
+            << multiple_lod_level.size()
+            << ") doesn't match the existing tensor number("
+            << GetTensorDescNum()
+            << "). The Reader is going to be reinitialized.";
+    SetTensorDescNum(multiple_lod_level.size());
+  }
+  switch (desc_.type().type()) {
+    case proto::VarType::READER: {
+      size_t i = 0;
+      for (auto &lod_tensor :
+           *desc_.mutable_type()->mutable_reader()->mutable_lod_tensor()) {
+        lod_tensor.set_lod_level(multiple_lod_level[i++]);
+      }
+    } break;
+    default:
+      LOG(FATAL) << "Setting 'lod_levels' is not supported by the type of "
+                    "var "
+                 << this->Name();
+  }
+}
+
+int32_t VarDesc::GetLoDLevel() const {
+  switch (desc_.type().type()) {
+    case proto::VarType::LOD_TENSOR:
+      return desc_.type().lod_tensor().lod_level();
+    case proto::VarType::LOD_TENSOR_ARRAY:
+      return desc_.type().tensor_array().lod_level();
+    default:
+      LOG(FATAL) << "Getting 'lod_level' is not supported by the type of var "
+                 << this->Name();
+  }
+}
+
+std::vector<int32_t> VarDesc::GetLoDLevels() const {
+  std::vector<int32_t> res;
+  switch (desc_.type().type()) {
+    case proto::VarType::READER:
+      res.reserve(desc_.type().reader().lod_tensor_size());
+      for (auto &lod_tensor : desc_.type().reader().lod_tensor()) {
+        res.push_back(lod_tensor.lod_level());
+      }
+      return res;
+      break;
+    default:
+      LOG(FATAL) << "Getting 'lod_levels' is not supported by the type of "
+                    "var "
+                 << this->Name();
+  }
+}
+
+const proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
+  CHECK(desc_.has_type()) << "The var's type hasn't been set.";
+  CHECK(desc_.type().has_type()) << "The var type hasn't been set.";
+  switch (desc_.type().type()) {
+    case proto::VarType::SELECTED_ROWS:
+      return desc_.type().selected_rows();
+    case proto::VarType::LOD_TENSOR:
+      return desc_.type().lod_tensor().tensor();
+    case proto::VarType::LOD_TENSOR_ARRAY:
+      return desc_.type().tensor_array().tensor();
+    default:
+      LOG(FATAL) << "Getting 'tensor_desc' is not supported by the type of "
+                    "var "
+                 << this->Name();
+  }
+}
+
+std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
+  CHECK(desc_.has_type()) << "The var type hasn't been set.";
+  std::vector<proto::VarType::TensorDesc> res;
+  res.reserve(GetTensorDescNum());
+  switch (desc_.type().type()) {
+    case proto::VarType::READER:
+      for (const auto &lod_tensor : desc_.type().reader().lod_tensor()) {
+        res.push_back(lod_tensor.tensor());
+      }
+      return res;
+    default:
+      LOG(FATAL) << "Getting 'tensor_descs' is not supported by the type of "
+                    "var "
+                 << this->Name();
+  }
+}
+
+proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
+  CHECK(desc_.has_type()) << "The var type hasn't been set.";
+  CHECK(desc_.type().has_type()) << "The var type hasn't been set.";
+  switch (desc_.type().type()) {
+    case proto::VarType::SELECTED_ROWS:
+      return desc_.mutable_type()->mutable_selected_rows();
+    case proto::VarType::LOD_TENSOR:
+      return desc_.mutable_type()->mutable_lod_tensor()->mutable_tensor();
+    case proto::VarType::LOD_TENSOR_ARRAY:
+      return desc_.mutable_type()->mutable_tensor_array()->mutable_tensor();
+    default:
+      LOG(FATAL) << "Getting 'mutable_tensor_desc' is not supported by the "
+                    "type of var "
+                 << this->Name();
+  }
+}
+
+std::vector<proto::VarType::TensorDesc *> VarDesc::mutable_tensor_descs() {
+  CHECK(desc_.has_type()) << "The var type hasn't been set.";
+  CHECK(desc_.type().has_type()) << "The var type hasn't been set.";
+  std::vector<proto::VarType::TensorDesc *> res;
+  res.reserve(GetTensorDescNum());
+  switch (desc_.type().type()) {
+    case proto::VarType::READER:
+      for (auto &lod_tensor :
+           *desc_.mutable_type()->mutable_reader()->mutable_lod_tensor()) {
+        res.push_back(lod_tensor.mutable_tensor());
+      }
+      return res;
+    default:
+      LOG(FATAL) << "Getting 'mutable_tensor_descs' is not supported by the "
+                    "type of var "
+                 << this->Name();
+  }
+}
+
+}  // namespace pb
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/model_parser/pb/var_desc.h b/paddle/fluid/lite/model_parser/pb/var_desc.h
index e69de29bb2d..4975a0e0d43 100644
--- a/paddle/fluid/lite/model_parser/pb/var_desc.h
+++ b/paddle/fluid/lite/model_parser/pb/var_desc.h
@@ -0,0 +1,123 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <iterator>
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/framework.pb.h"
+
+namespace paddle {
+namespace lite {
+namespace pb {
+
+// Convert between std::vector and protobuf repeated fields.
+template <typename T>
+inline std::vector<T> RepeatedToVector(
+    const google::protobuf::RepeatedField<T> &repeated_field) {
+  std::vector<T> ret;
+  ret.reserve(repeated_field.size());
+  std::copy(repeated_field.begin(), repeated_field.end(),
+            std::back_inserter(ret));
+  return ret;
+}
+
+template <typename T, typename RepeatedField>
+inline void VectorToRepeated(const std::vector<T> &vec,
+                             RepeatedField *repeated_field) {
+  repeated_field->Clear();
+  repeated_field->Reserve(vec.size());
+  for (const auto &elem : vec) {
+    *repeated_field->Add() = elem;
+  }
+}
+
+// Specialize vector<bool>.
+template <typename RepeatedField>
+inline void VectorToRepeated(const std::vector<bool> &vec,
+                             RepeatedField *repeated_field) {
+  repeated_field->Clear();
+  repeated_field->Reserve(vec.size());
+  for (auto elem : vec) {
+    *repeated_field->Add() = elem;
+  }
+}
+
+class VarDesc {
+ public:
+  explicit VarDesc(const std::string &name) {
+    desc_.set_name(name);
+    // TODO(paddle-dev): Why default to lodtensor.
+    desc_.mutable_type()->set_type(framework::proto::VarType::LOD_TENSOR);
+  }
+
+  explicit VarDesc(const framework::proto::VarDesc &desc) : desc_(desc) {}
+
+  framework::proto::VarDesc *Proto() { return &desc_; }
+
+  std::string Name() const { return desc_.name(); }
+
+  void SetName(std::string name) { desc_.set_name(name); }
+
+  void SetTensorDescNum(size_t num);
+
+  size_t GetTensorDescNum() const;
+
+  void SetShape(const std::vector<int64_t> &dims);
+
+  void SetShapes(const std::vector<std::vector<int64_t>> &multiple_dims);
+
+  std::vector<int64_t> GetShape() const;
+
+  std::vector<std::vector<int64_t>> GetShapes() const;
+
+  void SetDataType(framework::proto::VarType::Type data_type);
+
+  void SetDataTypes(
+      const std::vector<framework::proto::VarType::Type> &multiple_data_type);
+
+  framework::proto::VarType::Type GetDataType() const;
+
+  std::vector<framework::proto::VarType::Type> GetDataTypes() const;
+
+  void SetLoDLevel(int32_t lod_level);
+
+  void SetLoDLevels(const std::vector<int32_t> &multiple_lod_level);
+
+  int32_t GetLoDLevel() const;
+
+  std::vector<int32_t> GetLoDLevels() const;
+
+  framework::proto::VarType::Type GetType() const;
+
+  void SetType(framework::proto::VarType::Type type);
+
+  bool Persistable() const { return desc_.persistable(); }
+
+  void SetPersistable(bool persistable) { desc_.set_persistable(persistable); }
+
+ private:
+  const framework::proto::VarType::TensorDesc &tensor_desc() const;
+  std::vector<framework::proto::VarType::TensorDesc> tensor_descs() const;
+  framework::proto::VarType::TensorDesc *mutable_tensor_desc();
+  std::vector<framework::proto::VarType::TensorDesc *> mutable_tensor_descs();
+
+  framework::proto::VarDesc desc_;
+};
+
+}  // namespace pb
+}  // namespace lite
+}  // namespace paddle
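[Annotation, not part of the patch: describing a persistable 2-D fp32 variable with the light-weight VarDesc. The variable name is made up; glog's CHECK macros are assumed available, as elsewhere in the patch.]

#include <glog/logging.h>
#include "paddle/fluid/lite/model_parser/pb/var_desc.h"

void DescribeWeightVar() {
  paddle::lite::pb::VarDesc var("fc_0.w_0");
  var.SetType(paddle::framework::proto::VarType::LOD_TENSOR);
  var.SetDataType(paddle::framework::proto::VarType::FP32);
  var.SetShape({768, 1024});  // goes through VectorToRepeated internally
  var.SetPersistable(true);

  CHECK_EQ(var.GetShape().size(), 2u);
  CHECK(var.Persistable());
}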
diff --git a/paddle/fluid/lite/operators/fc_op.h b/paddle/fluid/lite/operators/fc_op.h
index 15d1693b719..4c322e41b8f 100644
--- a/paddle/fluid/lite/operators/fc_op.h
+++ b/paddle/fluid/lite/operators/fc_op.h
@@ -46,8 +46,7 @@ class FcOpLite : public OpLite {
    */
 
   // TODO(Superjomn) replace framework::OpDesc with a lite one.
-  bool AttachImpl(const framework::OpDesc &op_desc,
-                  lite::Scope *scope) override {
+  bool AttachImpl(const OpDesc &op_desc, lite::Scope *scope) override {
     auto input = op_desc.Input("Input").front();
     auto W = op_desc.Input("W").front();
     auto bias = op_desc.Input("Bias").front();
@@ -58,8 +57,7 @@ class FcOpLite : public OpLite {
     param_.bias = scope->FindVar(bias)->GetMutable<lite::Tensor>();
     CHECK(scope->FindVar(out));
     param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
-    param_.in_num_col_dims =
-        boost::get<int>(op_desc.GetAttr("in_num_col_dims"));
+    param_.in_num_col_dims = op_desc.GetAttr("in_num_col_dims").get<int>();
 
     CHECK(kernel_);
     kernel_->SetParam(param_);
diff --git a/paddle/fluid/lite/operators/feed_op.cc b/paddle/fluid/lite/operators/feed_op.cc
index 47876b55e76..0b5ffcfd63a 100644
--- a/paddle/fluid/lite/operators/feed_op.cc
+++ b/paddle/fluid/lite/operators/feed_op.cc
@@ -35,8 +35,7 @@ class FeedOp : public OpLite {
   void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
 
  protected:
-  bool AttachImpl(const framework::OpDesc& opdesc,
-                  lite::Scope* scope) override {
+  bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override {
     auto feed_var_name = opdesc.Input("X").front();
     auto* feed_var = scope->FindVar(feed_var_name);
     CHECK(feed_var);
@@ -50,7 +49,7 @@ class FeedOp : public OpLite {
 
     // NOTE need boost here
     // TODO(Superjomn) drop the need of framework::op_desc
-    param_.col = boost::get<int>(opdesc.GetAttr("col"));
+    param_.col = opdesc.GetAttr("col").get<int>();
     return true;
   }
diff --git a/paddle/fluid/lite/operators/fetch_op.cc b/paddle/fluid/lite/operators/fetch_op.cc
index f4e53c6699a..b34d57645ef 100644
--- a/paddle/fluid/lite/operators/fetch_op.cc
+++ b/paddle/fluid/lite/operators/fetch_op.cc
@@ -33,8 +33,7 @@ class FetchOp : public OpLite {
   void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
 
  protected:
-  bool AttachImpl(const framework::OpDesc& opdesc,
-                  lite::Scope* scope) override {
+  bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override {
     auto _x = opdesc.Input("X").front();
     auto* x = scope->FindVar(_x);
     CHECK(x);
@@ -44,7 +43,7 @@ class FetchOp : public OpLite {
     auto* out = scope->FindVar(_out);
     param_.fetch_list = out->GetMutable<std::vector<lite::Tensor>>();
 
-    param_.col = boost::get<int>(opdesc.GetAttr("col"));
+    param_.col = opdesc.GetAttr("col").get<int>();
     return true;
   }
diff --git a/paddle/fluid/lite/operators/io_copy_op.cc b/paddle/fluid/lite/operators/io_copy_op.cc
index c9a71160731..220853fc263 100644
--- a/paddle/fluid/lite/operators/io_copy_op.cc
+++ b/paddle/fluid/lite/operators/io_copy_op.cc
@@ -29,8 +29,7 @@ bool IoCopyOp::InferShape() const { return true; }
 
 bool IoCopyOp::Run() { return OpLite::Run(); }
 
-bool IoCopyOp::AttachImpl(const paddle::framework::OpDesc &opdesc,
-                          paddle::lite::Scope *scope) {
+bool IoCopyOp::AttachImpl(const OpDesc &opdesc, paddle::lite::Scope *scope) {
   auto x = opdesc.Input("Input").front();
   auto out = opdesc.Output("Out").front();
   param_.x = GetTensor(scope, x);
diff --git a/paddle/fluid/lite/operators/io_copy_op.h b/paddle/fluid/lite/operators/io_copy_op.h
index 7d07f333576..efcd11bc309 100644
--- a/paddle/fluid/lite/operators/io_copy_op.h
+++ b/paddle/fluid/lite/operators/io_copy_op.h
@@ -31,7 +31,7 @@ class IoCopyOp : public OpLite {
   void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
 
  protected:
-  bool AttachImpl(const framework::OpDesc &opdesc, lite::Scope *scope) override;
+  bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override;
 
  private:
   operators::IoCopyParam param_;
diff --git a/paddle/fluid/lite/operators/mul_op.h b/paddle/fluid/lite/operators/mul_op.h
index a0e91ba9865..334321c457c 100644
--- a/paddle/fluid/lite/operators/mul_op.h
+++ b/paddle/fluid/lite/operators/mul_op.h
@@ -38,8 +38,7 @@ class MulOpLite : public OpLite {
   void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
 
   // TODO(Superjomn) replace framework::OpDesc with a lite one.
-  bool AttachImpl(const framework::OpDesc &op_desc,
-                  lite::Scope *scope) override {
+  bool AttachImpl(const OpDesc &op_desc, lite::Scope *scope) override {
     auto input = op_desc.Input("X").front();
     auto W = op_desc.Input("Y").front();
     auto out = op_desc.Output("Out").front();
@@ -48,8 +47,8 @@ class MulOpLite : public OpLite {
     param_.y = scope->FindVar(W)->GetMutable<lite::Tensor>();
     CHECK(scope->FindVar(out));
     param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
-    param_.x_num_col_dims = boost::get<int>(op_desc.GetAttr("x_num_col_dims"));
-    param_.y_num_col_dims = boost::get<int>(op_desc.GetAttr("y_num_col_dims"));
+    param_.x_num_col_dims = op_desc.GetAttr("x_num_col_dims").get<int>();
+    param_.y_num_col_dims = op_desc.GetAttr("y_num_col_dims").get<int>();
 
     return true;
   }
diff --git a/paddle/fluid/lite/operators/relu_op.cc b/paddle/fluid/lite/operators/relu_op.cc
index ea3dea6585d..00571a8e4c9 100644
--- a/paddle/fluid/lite/operators/relu_op.cc
+++ b/paddle/fluid/lite/operators/relu_op.cc
@@ -31,7 +31,7 @@ bool ReluOp::InferShape() const {
   return true;
 }
 
-bool ReluOp::AttachImpl(const framework::OpDesc &opdesc, lite::Scope *scope) {
+bool ReluOp::AttachImpl(const OpDesc &opdesc, lite::Scope *scope) {
   param_.input = const_cast<lite::Tensor *>(
       &scope->FindVar(opdesc.Input("Input").front())->Get<lite::Tensor>());
   param_.output =
diff --git a/paddle/fluid/lite/operators/relu_op.h b/paddle/fluid/lite/operators/relu_op.h
index 8fa311373f0..088f1314dac 100644
--- a/paddle/fluid/lite/operators/relu_op.h
+++ b/paddle/fluid/lite/operators/relu_op.h
@@ -32,7 +32,7 @@ class ReluOp : public OpLite {
 
   bool InferShape() const override;
 
-  bool AttachImpl(const framework::OpDesc &opdesc, lite::Scope *scope) override;
+  bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override;
 
   void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
 
   std::string DebugString() const override { return "relu"; }
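[Annotation, not part of the patch: the operator changes above and below all follow the same attribute-access pattern, replacing boost::get with the lite variant's typed get. A sketch, with a made-up helper name:]

#include "paddle/fluid/lite/model_parser/pb/op_desc.h"

// GetAttr() returns a lite::pb::Attribute (a lite variant); get<T>() asserts
// the stored type with LOG(FATAL) instead of throwing boost::bad_get.
float ReadScale(const paddle::lite::pb::OpDesc &desc) {
  return desc.GetAttr("scale").get<float>();
}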
diff --git a/paddle/fluid/lite/operators/scale_op.cc b/paddle/fluid/lite/operators/scale_op.cc
index 42a6a588914..95dfecb9eba 100644
--- a/paddle/fluid/lite/operators/scale_op.cc
+++ b/paddle/fluid/lite/operators/scale_op.cc
@@ -46,18 +46,16 @@ class ScaleOp : public OpLite {
   void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
 
   // TODO(Superjomn) replace framework::OpDesc with a lite one.
-  bool AttachImpl(const framework::OpDesc &op_desc,
-                  lite::Scope *scope) override {
+  bool AttachImpl(const OpDesc &op_desc, lite::Scope *scope) override {
     auto x = op_desc.Input("X").front();
     auto out = op_desc.Output("Out").front();
 
     param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
     CHECK(scope->FindVar(out));
     param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
-    param_.scale = boost::get<float>(op_desc.GetAttr("scale"));
-    param_.bias = boost::get<float>(op_desc.GetAttr("bias"));
-    param_.bias_after_scale =
-        boost::get<bool>(op_desc.GetAttr("bias_after_scale"));
+    param_.scale = op_desc.GetAttr("scale").get<float>();
+    param_.bias = op_desc.GetAttr("bias").get<float>();
+    param_.bias_after_scale = op_desc.GetAttr("bias_after_scale").get<bool>();
 
     CHECK(kernel_);
     kernel_->SetParam(param_);
diff --git a/paddle/fluid/lite/utils/varient.h b/paddle/fluid/lite/utils/varient.h
index 6a4b33a4fc6..40290f1fcef 100644
--- a/paddle/fluid/lite/utils/varient.h
+++ b/paddle/fluid/lite/utils/varient.h
@@ -114,7 +114,8 @@ struct variant {
     if (type_id == typeid(T).hash_code())
       return *reinterpret_cast<T *>(&data);
     else
-      throw std::bad_cast();
+      LOG(FATAL) << "unmatched type get: stored type hash is " << type_id
+                 << ", but " << typeid(T).name() << " was requested";
   }
   ~variant() { helper_t::destroy(type_id, &data); }
 };
-- 
GitLab