From ae60589f9f4b730f2346cd13e32eb1c8794d3d63 Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Tue, 4 Jun 2019 19:35:50 +0800 Subject: [PATCH] Lite/refactor cp desc (#17831) --- paddle/fluid/lite/core/CMakeLists.txt | 4 +- .../core/mir/type_target_transform_pass.cc | 12 +- .../core/mir/type_target_transform_pass.h | 6 +- .../core/mir/variable_place_inference_pass.h | 4 +- paddle/fluid/lite/core/op_lite.cc | 95 +------------ paddle/fluid/lite/core/op_lite.h | 93 ++++++++----- paddle/fluid/lite/core/program.cc | 8 +- paddle/fluid/lite/core/program.h | 8 +- paddle/fluid/lite/model_parser/CMakeLists.txt | 9 +- .../fluid/lite/model_parser/compatible_pb.cc | 111 +++++++++++++++ .../fluid/lite/model_parser/compatible_pb.h | 25 +--- .../lite/model_parser/cpp/CMakeLists.txt | 1 + paddle/fluid/lite/model_parser/cpp/op_desc.cc | 66 +++++++++ paddle/fluid/lite/model_parser/cpp/op_desc.h | 126 ++++++++++++++++++ paddle/fluid/lite/model_parser/desc_apis.h | 85 ++++++++++++ .../fluid/lite/model_parser/op_desc_test.cc | 107 +++++++++++++++ paddle/fluid/lite/model_parser/pb/op_desc.cc | 99 +++++++++++--- paddle/fluid/lite/model_parser/pb/op_desc.h | 116 ++++++---------- paddle/fluid/lite/operators/activation_ops.cc | 4 +- paddle/fluid/lite/operators/concat_op.cc | 4 +- paddle/fluid/lite/operators/concat_op.h | 2 +- paddle/fluid/lite/operators/concat_op_test.cc | 2 +- paddle/fluid/lite/operators/dropout_op.cc | 12 +- .../fluid/lite/operators/elementwise_ops.cc | 10 +- paddle/fluid/lite/operators/fc_op.h | 4 +- paddle/fluid/lite/operators/fc_op_test.cc | 2 +- paddle/fluid/lite/operators/feed_op.cc | 4 +- paddle/fluid/lite/operators/fetch_op.cc | 4 +- .../fluid/lite/operators/fill_constant_op.cc | 10 +- paddle/fluid/lite/operators/io_copy_op.cc | 3 +- paddle/fluid/lite/operators/io_copy_op.h | 2 +- paddle/fluid/lite/operators/mean_op.cc | 6 +- paddle/fluid/lite/operators/mul_op.cc | 2 +- paddle/fluid/lite/operators/mul_op.h | 8 +- paddle/fluid/lite/operators/relu_op.cc | 2 +- paddle/fluid/lite/operators/relu_op.h | 2 +- paddle/fluid/lite/operators/reshape_op.cc | 10 +- paddle/fluid/lite/operators/reshape_op.h | 4 +- .../fluid/lite/operators/reshape_op_test.cc | 8 +- paddle/fluid/lite/operators/scale_op.cc | 8 +- paddle/fluid/lite/operators/scale_op.h | 2 +- paddle/fluid/lite/operators/scale_op_test.cc | 4 +- paddle/fluid/lite/operators/softmax_op.cc | 7 +- paddle/fluid/lite/operators/softmax_op.h | 2 +- .../fluid/lite/operators/softmax_op_test.cc | 2 +- 45 files changed, 774 insertions(+), 331 deletions(-) create mode 100644 paddle/fluid/lite/model_parser/cpp/CMakeLists.txt create mode 100644 paddle/fluid/lite/model_parser/cpp/op_desc.cc create mode 100644 paddle/fluid/lite/model_parser/cpp/op_desc.h create mode 100644 paddle/fluid/lite/model_parser/desc_apis.h create mode 100644 paddle/fluid/lite/model_parser/op_desc_test.cc diff --git a/paddle/fluid/lite/core/CMakeLists.txt b/paddle/fluid/lite/core/CMakeLists.txt index d37e44c733f..94085934112 100644 --- a/paddle/fluid/lite/core/CMakeLists.txt +++ b/paddle/fluid/lite/core/CMakeLists.txt @@ -25,7 +25,9 @@ cc_library(op_registry_lite SRCS op_registry.cc DEPS framework_proto_lite) cc_library(scope_lite SRCS scope.cc DEPS ${tensor_lite}) cc_library(cpu_info_lite SRCS cpu_info.cc) cc_library(context_lite SRCS context.cc DEPS ${tensor_lite} any_lite cpu_info_lite) -cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite compatible_pb_lite target_wrapper_lite ${tensor_lite}) +cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite target_wrapper_lite + cpp_op_desc_lite + ${tensor_lite}) cc_library(types_lite SRCS types.cc) cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} target_wrapper_lite) diff --git a/paddle/fluid/lite/core/mir/type_target_transform_pass.cc b/paddle/fluid/lite/core/mir/type_target_transform_pass.cc index ddd07970166..25789d34dca 100644 --- a/paddle/fluid/lite/core/mir/type_target_transform_pass.cc +++ b/paddle/fluid/lite/core/mir/type_target_transform_pass.cc @@ -90,7 +90,7 @@ void TypeTargetTransformPass::AddIoCopyInst( inst_node->AsStmt().op->scope()->Var(io_copy_output_name); // Create IoCopy Instruction. - lite::OpDesc op_desc; + cpp::OpDesc op_desc; op_desc.SetType("io_copy"); op_desc.SetInput("Input", {var}); op_desc.SetOutput("Out", {io_copy_output_name}); @@ -104,8 +104,6 @@ void TypeTargetTransformPass::AddIoCopyInst( // Update the original instruction OpDesc. // Update its input to the io_copy_output_name - auto& inst = inst_node->AsStmt(); - auto inst_program_desc = inst.op_info()->desc(); // Add new link, var -> new_inst, new_inst->newarg, newarg->inst DirectedLink(graph->Argument(var), io_copy_inst); @@ -113,11 +111,11 @@ void TypeTargetTransformPass::AddIoCopyInst( DirectedLink(io_copy_output_arg, inst_node); // reset opdesc and update kernel information - auto desc_dummy = inst_node->AsStmt().op->op_info()->desc(); - UpdateInputTo(&desc_dummy, var, io_copy_output_name); + UpdateInputTo(inst_node->AsStmt().op->mutable_op_info(), var, + io_copy_output_name); - lite::OpDesc desc_fake(desc_dummy); - inst_node->AsStmt().op->Attach(desc_fake, inst_node->AsStmt().op->scope()); + inst_node->AsStmt().op->Attach(*inst_node->AsStmt().op->op_info(), + inst_node->AsStmt().op->scope()); std::string tmp; if (inst_node->AsStmt().op_info()->GetInputArgname("a", &tmp)) { diff --git a/paddle/fluid/lite/core/mir/type_target_transform_pass.h b/paddle/fluid/lite/core/mir/type_target_transform_pass.h index f8557f44e3c..838c0bcdabc 100644 --- a/paddle/fluid/lite/core/mir/type_target_transform_pass.h +++ b/paddle/fluid/lite/core/mir/type_target_transform_pass.h @@ -24,10 +24,10 @@ namespace paddle { namespace lite { namespace mir { -static void UpdateInputTo(framework::proto::OpDesc* desc, - const std::string& from, const std::string& to) { +static void UpdateInputTo(cpp::OpDesc* desc, const std::string& from, + const std::string& to) { for (auto& item : *desc->mutable_inputs()) { - for (auto& input : *item.mutable_arguments()) { + for (auto& input : item.second) { if (input == from) { input = to; } diff --git a/paddle/fluid/lite/core/mir/variable_place_inference_pass.h b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h index 79c977b331f..4d555d638a9 100644 --- a/paddle/fluid/lite/core/mir/variable_place_inference_pass.h +++ b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h @@ -65,7 +65,7 @@ class VariablePlaceInferencePass : public DebugPass { // check if inputs's place is set, if not set, update them with the // kernel's declaration. auto type = inst.picked_kernel().GetInputDeclType(arg_name); - auto arg_names = inst.op_info()->input_argument().at(arg_name); + auto arg_names = inst.op_info()->inputs().at(arg_name); for (auto& arg_name : arg_names) { VLOG(3) << "--- var " << arg_name; @@ -82,7 +82,7 @@ class VariablePlaceInferencePass : public DebugPass { for (auto& arg_name : inst.op_info()->output_argnames()) { VLOG(3) << "-- output arg_name " << arg_name; auto type = inst.picked_kernel().GetOutputDeclType(arg_name); - auto arg_names = inst.op_info()->output_argument().at(arg_name); + auto arg_names = inst.op_info()->outputs().at(arg_name); // check if outputs's place is set, if not set, update them with the // kernel's declaration. for (auto& arg_name : arg_names) { diff --git a/paddle/fluid/lite/core/op_lite.cc b/paddle/fluid/lite/core/op_lite.cc index 47e0c441e75..bd98b23bf25 100644 --- a/paddle/fluid/lite/core/op_lite.cc +++ b/paddle/fluid/lite/core/op_lite.cc @@ -68,13 +68,13 @@ bool OpLite::Run() { return true; } -bool OpLite::Attach(const OpDesc &opdesc, lite::Scope *scope) { +bool OpLite::Attach(const cpp::OpDesc &opdesc, lite::Scope *scope) { // valid_places_.clear(); CHECK(scope != nullptr); // CHECK(!op_info_.get()); scope_ = scope; - op_info_.reset(new OpInfo); // Force clean the out-of-date infomation. - op_info_->Build(opdesc.ReadonlyProto()); + op_info_.reset( + new OpInfo(opdesc)); // Force clean the out-of-date infomation. return AttachImpl(opdesc, scope); } @@ -92,94 +92,5 @@ Tensor *OpLite::GetMutableTensor(lite::Scope *scope, return var->GetMutable(); } -bool OpInfo::GetInputArgname(const std::string &value_name, - std::string *out) const { - for (auto &item : input_argument_) { - auto it = std::find(item.second.begin(), item.second.end(), value_name); - if (it != item.second.end()) { - *out = item.first; - return true; - } - } - return false; -} -bool OpInfo::GetOutputArgname(const std::string &value_name, - std::string *out) const { - for (auto &item : output_argument_) { - auto it = std::find(item.second.begin(), item.second.end(), value_name); - if (it != item.second.end()) { - *out = item.first; - return true; - } - } - return false; -} - -void OpInfo::ExtractInputsAndOutputs(const framework::proto::OpDesc &opdesc) { - for (const auto &item : opdesc.inputs()) { - for (const auto &x : item.arguments()) { - input_names_.push_back(x); - } - } - for (const auto &item : opdesc.outputs()) { - for (const auto &x : item.arguments()) { - output_names_.push_back(x); - } - } -} - -void OpInfo::CollectInputAndOutputArgnames( - const framework::proto::OpDesc &opdesc) { - for (const auto &item : opdesc.inputs()) { - input_argnames_.push_back(item.parameter()); - } - for (const auto &item : opdesc.outputs()) { - output_argnames_.push_back(item.parameter()); - } -} - -void OpInfo::CollectArguments(const framework::proto::OpDesc &opdesc) { - for (const auto &item : opdesc.inputs()) { - for (auto &x : item.arguments()) { - input_argument_[item.parameter()].push_back(x); - } - } - for (const auto &item : opdesc.outputs()) { - for (auto &x : item.arguments()) { - output_argument_[item.parameter()].push_back(x); - } - } -} - -void OpInfo::Build(const framework::proto::OpDesc &desc) { - ExtractInputsAndOutputs(desc); - CollectInputAndOutputArgnames(desc); - CollectArguments(desc); - desc_.reset(new framework::proto::OpDesc(desc)); -} - -const std::map> &OpInfo::input_argument() - const { - return input_argument_; -} - -const std::map> &OpInfo::output_argument() - const { - return output_argument_; -} - -const std::list &OpInfo::input_argnames() const { - return input_argnames_; -} - -const std::list &OpInfo::output_argnames() const { - return output_argnames_; -} - -const framework::proto::OpDesc &OpInfo::desc() const { - CHECK(desc_) << "desc has't set"; - return *desc_; -} - } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/core/op_lite.h b/paddle/fluid/lite/core/op_lite.h index 8845760234d..e55806eb41d 100644 --- a/paddle/fluid/lite/core/op_lite.h +++ b/paddle/fluid/lite/core/op_lite.h @@ -23,7 +23,7 @@ #include "paddle/fluid/lite/core/context.h" #include "paddle/fluid/lite/core/kernel.h" #include "paddle/fluid/lite/core/scope.h" -#include "paddle/fluid/lite/model_parser/compatible_pb.h" +#include "paddle/fluid/lite/model_parser/cpp/op_desc.h" namespace paddle { namespace lite { @@ -71,7 +71,7 @@ class OpLite : public Registry { virtual bool Run(); // Link the external execution environ to internal context. - bool Attach(const OpDesc &opdesc, lite::Scope *scope); + bool Attach(const cpp::OpDesc &opdesc, lite::Scope *scope); const OpInfo *op_info() const { return op_info_.get(); } OpInfo *mutable_op_info() { return op_info_.get(); } @@ -94,7 +94,7 @@ class OpLite : public Registry { protected: // Attach it with the runtime environment. - virtual bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) = 0; + virtual bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) = 0; // Specify the kernel to run by default. This will specify the value of // `kernel_place_`. @@ -144,40 +144,61 @@ class OpLite : public Registry { * Operator Information, such as some description. It will be shared by all the * kernels of the same operator. */ -class OpInfo { +class OpInfo : public cpp::OpDesc { public: - // To avoid the bugs from legancy framework::OpDesc, we use the ProtoBuf - // message instead. - void Build(const framework::proto::OpDesc &desc); - - const framework::proto::OpDesc &desc() const; - framework::proto::OpDesc *mutable_desc() { return desc_.get(); } - const std::list &input_names() const { return input_names_; } - const std::list &output_names() const { return output_names_; } - const std::map> &input_argument() const; - const std::map> &output_argument() const; - bool GetInputArgname(const std::string &value_name, std::string *out) const; - bool GetOutputArgname(const std::string &value_name, std::string *out) const; - - const std::list &input_argnames() const; - const std::list &output_argnames() const; - - private: - void ExtractInputsAndOutputs(const framework::proto::OpDesc &opdesc); - - void CollectInputAndOutputArgnames(const framework::proto::OpDesc &opdesc); - - void CollectArguments(const framework::proto::OpDesc &opdesc); - - private: - std::list input_names_; - std::list output_names_; - std::list input_argnames_; - std::list output_argnames_; - std::map> input_argument_; - std::map> output_argument_; - // NOTE too heavy. - std::unique_ptr desc_; + OpInfo(const OpInfo &) = default; + OpInfo(const cpp::OpDesc &other) : cpp::OpDesc(other) {} + + // Collect all the input variable's name. + std::vector input_names() const { + std::vector res; + for (auto ¶m : InputArgumentNames()) { + for (auto &x : Input(param)) { + res.push_back(x); + } + } + return res; + } + + // Collect all the output variable's name. + std::vector output_names() const { + std::vector res; + for (auto ¶m : OutputArgumentNames()) { + for (auto &x : Output(param)) { + res.push_back(x); + } + } + return res; + } + + std::vector input_argnames() const { + return InputArgumentNames(); + } + + std::vector output_argnames() const { + return OutputArgumentNames(); + } + + bool GetInputArgname(const std::string &value_name, std::string *out) const { + for (auto &item : inputs_) { + auto it = std::find(item.second.begin(), item.second.end(), value_name); + if (it != item.second.end()) { + *out = item.first; + return true; + } + } + return false; + } + bool GetOutputArgname(const std::string &value_name, std::string *out) const { + for (auto &item : outputs_) { + auto it = std::find(item.second.begin(), item.second.end(), value_name); + if (it != item.second.end()) { + *out = item.first; + return true; + } + } + return false; + } }; } // namespace lite diff --git a/paddle/fluid/lite/core/program.cc b/paddle/fluid/lite/core/program.cc index 0ec9590d09c..20133287e87 100644 --- a/paddle/fluid/lite/core/program.cc +++ b/paddle/fluid/lite/core/program.cc @@ -39,11 +39,11 @@ std::string RuntimeProgram::SerializeProgram( auto program_dummy = desc; program_dummy.mutable_blocks(0)->clear_ops(); for (auto &node : instructions_) { - auto desc_dummy = node.op()->op_info()->desc(); - OpDesc desc(desc_dummy); - desc.SetAttr(kKernelTypeAttr, node.kernel()->SerializedKernelType()); + pb::OpDesc pb_desc; + TransformOpDescCppToPb(*node.op()->op_info(), &pb_desc); + pb_desc.SetAttr(kKernelTypeAttr, node.kernel()->SerializedKernelType()); // append new opdesc - *program_dummy.mutable_blocks(0)->add_ops() = *desc.Proto(); + *program_dummy.mutable_blocks(0)->add_ops() = *pb_desc.Proto(); } return program_dummy.SerializeAsString(); } diff --git a/paddle/fluid/lite/core/program.h b/paddle/fluid/lite/core/program.h index 20b61da3573..1ebd6b437c7 100644 --- a/paddle/fluid/lite/core/program.h +++ b/paddle/fluid/lite/core/program.h @@ -22,6 +22,7 @@ #include "paddle/fluid/lite/core/mir/node.h" #include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/model_parser/compatible_pb.h" #ifdef LITE_WITH_PROFILE #include "paddle/fluid/lite/core/profile/basic_profiler.h" #endif // LITE_WITH_PROFILE @@ -67,7 +68,7 @@ struct Program { CHECK(ops.empty()) << "Executor duplicate Build found"; // Create operators. for (const auto& proto_op_desc : program.blocks(0).ops()) { - lite::OpDesc op_desc(proto_op_desc); + pb::OpDesc op_desc(proto_op_desc); auto op_type = op_desc.Type(); // if (op_type == "feed" || op_type == "fetch") continue; VLOG(4) << "create Op [" << op_type << "]"; @@ -75,7 +76,10 @@ struct Program { auto op = LiteOpRegistry::Global().Create(op_type); CHECK(op) << "no Op found for " << op_type; ops.emplace_back(std::move(op)); - ops.back()->Attach(op_desc, exec_scope); + + cpp::OpDesc cpp_op_desc; + TransformOpDescPbToCpp(op_desc, &cpp_op_desc); + ops.back()->Attach(cpp_op_desc, exec_scope); } } diff --git a/paddle/fluid/lite/model_parser/CMakeLists.txt b/paddle/fluid/lite/model_parser/CMakeLists.txt index 9fe73c49c61..a284f0388fa 100644 --- a/paddle/fluid/lite/model_parser/CMakeLists.txt +++ b/paddle/fluid/lite/model_parser/CMakeLists.txt @@ -11,11 +11,7 @@ if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) endif() -if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) - cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite framework_proto_lite var_desc_lite) -else() - cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto_lite proto_desc) -endif(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) +cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite framework_proto_lite var_desc_lite) set(model_parser_deps variable_lite scope_lite ${tensor_lite} scope_lite target_wrapper_host @@ -27,4 +23,7 @@ if (LITE_WITH_CUDA) endif() cc_library(model_parser_lite SRCS model_parser.cc DEPS ${model_parser_deps}) +cc_test(test_op_desc_lite SRCS op_desc_test.cc DEPS cpp_op_desc_lite any_lite op_desc_lite compatible_pb_lite) + add_subdirectory(pb) +add_subdirectory(cpp) diff --git a/paddle/fluid/lite/model_parser/compatible_pb.cc b/paddle/fluid/lite/model_parser/compatible_pb.cc index ee0f7c41acc..00c1478c8a4 100644 --- a/paddle/fluid/lite/model_parser/compatible_pb.cc +++ b/paddle/fluid/lite/model_parser/compatible_pb.cc @@ -13,3 +13,114 @@ // limitations under the License. #include "paddle/fluid/lite/model_parser/compatible_pb.h" +#include "compatible_pb.h" + +namespace paddle { +namespace lite { + +void InputsPbToCpp(const pb::OpDesc &pb_desc, cpp::OpDesc *cpp_desc) { + for (const std::string ¶m : pb_desc.InputArgumentNames()) { + cpp_desc->SetInput(param, pb_desc.Input(param)); + } +} + +void InputsCppToPb(const cpp::OpDesc &cpp_desc, pb::OpDesc *pb_desc) { + for (const std::string ¶m : cpp_desc.InputArgumentNames()) { + pb_desc->SetInput(param, cpp_desc.Input(param)); + } +} + +void OutputsPbToCpp(const pb::OpDesc &pb_desc, cpp::OpDesc *cpp_desc) { + for (const std::string ¶m : pb_desc.OutputArgumentNames()) { + cpp_desc->SetOutput(param, pb_desc.Output(param)); + } +} + +void OutputsCppToPb(const cpp::OpDesc &cpp_desc, pb::OpDesc *pb_desc) { + for (const std::string ¶m : cpp_desc.OutputArgumentNames()) { + pb_desc->SetOutput(param, cpp_desc.Output(param)); + } +} + +void AttrsPbToCpp(const pb::OpDesc &pb_desc, cpp::OpDesc *cpp_desc) { + using AttrType = OpDescAPI::AttrType; + auto set_attr = [&](const std::string &name, AttrType type) { + switch (type) { + case AttrType::INT: + cpp_desc->SetAttr(name, pb_desc.GetAttr(name)); + break; + case AttrType::FLOAT: + cpp_desc->SetAttr(name, pb_desc.GetAttr(name)); + break; + case AttrType::STRING: + cpp_desc->SetAttr(name, + pb_desc.GetAttr(name)); + break; + case AttrType::INTS: + cpp_desc->SetAttr>( + name, pb_desc.GetAttr>(name)); + break; + case AttrType::FLOATS: + cpp_desc->SetAttr>( + name, pb_desc.GetAttr>(name)); + break; + case AttrType::BOOLEAN: + cpp_desc->SetAttr(name, pb_desc.GetAttr(name)); + break; + case AttrType::STRINGS: + cpp_desc->SetAttr>( + name, pb_desc.GetAttr>(name)); + break; + default: + LOG(FATAL) << "Unsupported attr type found " << static_cast(type); + } + }; + + for (const auto &attr_name : pb_desc.AttrNames()) { + auto type = pb_desc.GetAttrType(attr_name); + set_attr(attr_name, type); + } +} + +void AttrsCppToPb(const cpp::OpDesc &cpp_desc, pb::OpDesc *pb_desc) { + using AttrType = OpDescAPI::AttrType; + auto set_attr = [&](const std::string &name, AttrType type) { + switch (type) { +#define IMPL_ONE(type__, T) \ + case AttrType::type__: \ + pb_desc->SetAttr(name, cpp_desc.GetAttr(name)); \ + break; + IMPL_ONE(INT, int32_t); + IMPL_ONE(FLOAT, float); + IMPL_ONE(STRING, std::string); + IMPL_ONE(STRINGS, std::vector); + IMPL_ONE(FLOATS, std::vector); + IMPL_ONE(INTS, std::vector); + IMPL_ONE(BOOLEAN, bool); + default: + LOG(FATAL) << "Unsupported attr type found: " << static_cast(type); + } + }; +#undef IMPL_ONE + for (const auto &attr_name : cpp_desc.AttrNames()) { + auto type = cpp_desc.GetAttrType(attr_name); + set_attr(attr_name, type); + } +} + +void TransformOpDescPbToCpp(const pb::OpDesc &pb_desc, cpp::OpDesc *cpp_desc) { + cpp_desc->SetType(pb_desc.Type()); + InputsPbToCpp(pb_desc, cpp_desc); + OutputsPbToCpp(pb_desc, cpp_desc); + AttrsPbToCpp(pb_desc, cpp_desc); +} + +void TransformOpDescCppToPb(const cpp::OpDesc &cpp_desc, pb::OpDesc *pb_desc) { + pb_desc->SetType(cpp_desc.Type()); + InputsCppToPb(cpp_desc, pb_desc); + OutputsCppToPb(cpp_desc, pb_desc); + AttrsCppToPb(cpp_desc, pb_desc); +} + +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/model_parser/compatible_pb.h b/paddle/fluid/lite/model_parser/compatible_pb.h index cef1406f985..23041ea1fe5 100644 --- a/paddle/fluid/lite/model_parser/compatible_pb.h +++ b/paddle/fluid/lite/model_parser/compatible_pb.h @@ -20,39 +20,28 @@ * lite::pb::XXDesc. */ -#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK #include "paddle/fluid/lite/core/framework.pb.h" +#include "paddle/fluid/lite/model_parser/cpp/op_desc.h" #include "paddle/fluid/lite/model_parser/pb/op_desc.h" #include "paddle/fluid/lite/model_parser/pb/var_desc.h" -#else -#include "paddle/fluid/framework/op_desc.h" -#include "paddle/fluid/framework/var_desc.h" -#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK namespace paddle { namespace lite { -#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK using Attribute = lite::pb::Attribute; using OpDesc = lite::pb::OpDesc; using VarDesc = lite::pb::VarDesc; -#else // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK -using Attribute = framework::Attribute; -using OpDesc = framework::OpDesc; -using VarDesc = framework::VarDesc; -#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK -#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK template T GetAttr(const Attribute& x) { return x.get(); } -#else -template -T GetAttr(const Attribute& x) { - return boost::get(x); -} -#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK + +/// Transform an OpDesc from pb to cpp format. +void TransformOpDescPbToCpp(const pb::OpDesc& pb_desc, cpp::OpDesc* cpp_desc); + +/// Transform an OpDesc from cpp to pb format. +void TransformOpDescCppToPb(const cpp::OpDesc& cpp_desc, pb::OpDesc* pb_desc); } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/model_parser/cpp/CMakeLists.txt b/paddle/fluid/lite/model_parser/cpp/CMakeLists.txt new file mode 100644 index 00000000000..71073179991 --- /dev/null +++ b/paddle/fluid/lite/model_parser/cpp/CMakeLists.txt @@ -0,0 +1 @@ +cc_library(cpp_op_desc_lite SRCS op_desc.cc DEPS any_lite) diff --git a/paddle/fluid/lite/model_parser/cpp/op_desc.cc b/paddle/fluid/lite/model_parser/cpp/op_desc.cc new file mode 100644 index 00000000000..01ee4703143 --- /dev/null +++ b/paddle/fluid/lite/model_parser/cpp/op_desc.cc @@ -0,0 +1,66 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/model_parser/cpp/op_desc.h" +#include + +namespace paddle { +namespace lite { +namespace cpp { + +#define SET_ATTR_IMPL(T, repr__) \ + template <> \ + void OpDesc::SetAttr(const std::string& name, const T& v) { \ + attr_types_[name] = AttrType::repr__; \ + attrs_[name].set(v); \ + } + +SET_ATTR_IMPL(int32_t, INT); +SET_ATTR_IMPL(float, FLOAT); +SET_ATTR_IMPL(std::string, STRING); +SET_ATTR_IMPL(bool, BOOLEAN); +SET_ATTR_IMPL(std::vector, INTS); +SET_ATTR_IMPL(std::vector, FLOATS); +SET_ATTR_IMPL(std::vector, STRINGS); + +std::pair +FindAttr(const cpp::OpDesc& desc, const std::string& name) { + auto it = desc.attrs().find(name); + CHECK(it != desc.attrs().end()) << "No attributes called " << name + << " found"; + auto attr_it = desc.attr_types().find(name); + CHECK(attr_it != desc.attr_types().end()); + return std::make_pair(it, attr_it); +} + +#define GET_IMPL_ONE(T, repr__) \ + template <> \ + T OpDesc::GetAttr(const std::string& name) const { \ + auto pair = FindAttr(*this, name); \ + CHECK(pair.second->second == AttrType::repr__); \ + return pair.first->second.get(); \ + } + +GET_IMPL_ONE(int32_t, INT); +GET_IMPL_ONE(float, FLOAT); +GET_IMPL_ONE(std::string, STRING); +GET_IMPL_ONE(bool, BOOLEAN); +GET_IMPL_ONE(std::vector, LONGS); +GET_IMPL_ONE(std::vector, FLOATS); +GET_IMPL_ONE(std::vector, INTS); +GET_IMPL_ONE(std::vector, STRINGS); + +} // namespace cpp +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/model_parser/cpp/op_desc.h b/paddle/fluid/lite/model_parser/cpp/op_desc.h new file mode 100644 index 00000000000..b70c1692659 --- /dev/null +++ b/paddle/fluid/lite/model_parser/cpp/op_desc.h @@ -0,0 +1,126 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include "paddle/fluid/lite/model_parser/desc_apis.h" +#include "paddle/fluid/lite/utils/any.h" +#include "paddle/fluid/lite/utils/varient.h" + +namespace paddle { +namespace lite { +namespace cpp { + +/* + * The cpp::OpDesc is the internal representation for Op. All the internal + * imprementation should use it, not the pb::OpDesc. + */ +class OpDesc : public OpDescAPI { + public: + using attrs_t = std::map; + using attr_types_t = std::map; + + protected: + std::string type_; + std::map> inputs_; + std::map> outputs_; + std::map attrs_; + std::map attr_types_; + + public: + OpDesc() = default; + + std::string Type() const override { return type_; } + void SetType(const std::string& x) override { type_ = x; } + + const std::map>& inputs() const { + return inputs_; + } + const std::map>& outputs() const { + return outputs_; + } + std::map>* mutable_inputs() { + return &inputs_; + } + std::map>* mutable_outputs() { + return &outputs_; + } + std::vector Input(const std::string& param) const override { + auto it = inputs_.find(param); + CHECK(it != inputs_.end()); + return it->second; + } + + std::vector InputArgumentNames() const override { + std::vector res; + for (const auto& x : inputs_) res.push_back(x.first); + return res; + } + std::vector OutputArgumentNames() const override { + std::vector res; + for (const auto& x : outputs_) res.push_back(x.first); + return res; + } + + std::vector Output(const std::string& param) const override { + auto it = outputs_.find(param); + CHECK(it != outputs_.end()); + return it->second; + } + + void SetInput(const std::string& param, + const std::vector& args) override { + inputs_[param] = args; + } + + void SetOutput(const std::string& param, + const std::vector& args) override { + outputs_[param] = args; + } + + bool HasAttr(const std::string& name) const override { + return attrs_.count(name); + } + + AttrType GetAttrType(const std::string& name) const override { + auto it = attr_types_.find(name); + CHECK(it != attr_types_.end()); + return it->second; + } + + std::vector AttrNames() const override { + std::vector res; + for (const auto& x : attrs_) { + res.push_back(x.first); + } + return res; + } + + template + void SetAttr(const std::string& name, const T& v); + + template + T GetAttr(const std::string& name) const; + + const std::map& attrs() const { return attrs_; } + const std::map& attr_types() const { + return attr_types_; + } +}; + +} // namespace cpp +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/model_parser/desc_apis.h b/paddle/fluid/lite/model_parser/desc_apis.h new file mode 100644 index 00000000000..d28f82a0e73 --- /dev/null +++ b/paddle/fluid/lite/model_parser/desc_apis.h @@ -0,0 +1,85 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include + +namespace paddle { +namespace lite { + +/* + * Compatible interfaces for all the different kinds of opdesc. All the OpDesc + * classes should implement this. + * NOTE Some interfaces are weried, we remain them unchanged to keep compatible + * with framework::OpDesc in Fluid framework. + */ +class OpDescAPI { + public: + // The AttrType is used to make the proto::AttrType portable. + enum class AttrType { + INT = 0, + FLOAT = 1, + STRING = 2, + INTS = 3, + FLOATS = 4, + STRINGS = 5, + BOOLEAN = 6, + BOOLEANS = 7, + BLOCK = 8, + LONG = 9, + BLOCKS = 10, + LONGS = 11, + UNK, + }; + + virtual ~OpDescAPI() = default; + + /// Get operator's type. + virtual std::string Type() const = 0; + /// Set operator's type. + virtual void SetType(const std::string& type) = 0; + /// Get arguments given the parameter. + virtual std::vector Input(const std::string& param) const = 0; + /// Get parameters. + virtual std::vector InputArgumentNames() const = 0; + /// Get arguments given the parameter. + virtual std::vector Output(const std::string& param) const = 0; + /// Get parameters. + virtual std::vector OutputArgumentNames() const = 0; + /// Set a input given the parameter and arguments. + virtual void SetInput(const std::string& param, + const std::vector& args) = 0; + virtual void SetOutput(const std::string& param, + const std::vector& args) = 0; + /// Tell whether this desc has an attribute. + virtual bool HasAttr(const std::string& name) const = 0; + + /// Get the type of an attribute. + virtual AttrType GetAttrType(const std::string& name) const = 0; + + virtual std::vector AttrNames() const = 0; + + /// Set an attribute. + template + void SetAttr(const std::string& name, const T& v); + + /// Get an attribute. + template + T GetAttr(const std::string& name) const; +}; + +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/model_parser/op_desc_test.cc b/paddle/fluid/lite/model_parser/op_desc_test.cc new file mode 100644 index 00000000000..df74c626040 --- /dev/null +++ b/paddle/fluid/lite/model_parser/op_desc_test.cc @@ -0,0 +1,107 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/model_parser/cpp/op_desc.h" +#include +#include "paddle/fluid/lite/model_parser/compatible_pb.h" +#include "paddle/fluid/lite/model_parser/pb/op_desc.h" + +namespace paddle { +namespace lite { + +template +void TestX() { + OpDesc desc; + + desc.SetInput("X", {"a", "b"}); + auto X = desc.Input("X"); + ASSERT_EQ(X.size(), 2UL); + ASSERT_EQ(X[0], "a"); + ASSERT_EQ(X[1], "b"); + + desc.SetOutput("Y", {"c", "d"}); + auto Y = desc.Output("Y"); + ASSERT_EQ(Y.size(), 2UL); + ASSERT_EQ(Y[0], "c"); + ASSERT_EQ(Y[1], "d"); + + desc.template SetAttr("aint", 100); + ASSERT_TRUE(desc.HasAttr("aint")); + ASSERT_FALSE(desc.HasAttr("afloat")); + ASSERT_EQ(desc.template GetAttr("aint"), 100); +} + +TEST(OpDesc, Basic) { + TestX(); + TestX(); +} + +TEST(OpDesc, CppToPb) { + cpp::OpDesc desc; + + desc.SetInput("X", {"a", "b"}); + desc.SetOutput("Y", {"c", "d"}); + desc.template SetAttr("aint", 100); + + pb::OpDesc pb_desc; + + TransformOpDescCppToPb(desc, &pb_desc); + { + auto& desc = pb_desc; + auto X = desc.Input("X"); + ASSERT_EQ(X.size(), 2UL); + ASSERT_EQ(X[0], "a"); + ASSERT_EQ(X[1], "b"); + + auto Y = desc.Output("Y"); + ASSERT_EQ(Y.size(), 2UL); + ASSERT_EQ(Y[0], "c"); + ASSERT_EQ(Y[1], "d"); + + ASSERT_TRUE(desc.HasAttr("aint")); + ASSERT_FALSE(desc.HasAttr("afloat")); + ASSERT_EQ(desc.template GetAttr("aint"), 100); + } +} + +TEST(OpDesc, PbToCpp) { + pb::OpDesc desc; + + desc.SetInput("X", {"a", "b"}); + desc.SetOutput("Y", {"c", "d"}); + desc.template SetAttr("aint", 100); + + cpp::OpDesc cpp_desc; + + TransformOpDescPbToCpp(desc, &cpp_desc); + { + auto& desc = cpp_desc; + auto X = desc.Input("X"); + ASSERT_EQ(X.size(), 2UL); + ASSERT_EQ(X[0], "a"); + ASSERT_EQ(X[1], "b"); + + auto Y = desc.Output("Y"); + ASSERT_EQ(Y.size(), 2UL); + ASSERT_EQ(Y[0], "c"); + ASSERT_EQ(Y[1], "d"); + + ASSERT_TRUE(desc.HasAttr("aint")); + ASSERT_FALSE(desc.HasAttr("afloat")); + ASSERT_EQ(desc.template GetAttr("aint"), 100); + } +} + +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/model_parser/pb/op_desc.cc b/paddle/fluid/lite/model_parser/pb/op_desc.cc index 27ccc5c686a..1de4fb275e4 100644 --- a/paddle/fluid/lite/model_parser/pb/op_desc.cc +++ b/paddle/fluid/lite/model_parser/pb/op_desc.cc @@ -18,10 +18,9 @@ namespace paddle { namespace lite { namespace pb { -template <> -void OpDesc::SetAttr(const std::string &name, - const std::string &v) { - auto &xs = *desc_.mutable_attrs(); +google::protobuf::internal::RepeatedPtrIterator +FindAttr(framework::proto::OpDesc *desc, const std::string &name) { + auto &xs = *desc->mutable_attrs(); auto it = std::find_if( xs.begin(), xs.end(), [&](const framework::proto::OpDesc_Attr &x) { return x.name() == name; }); @@ -33,33 +32,95 @@ void OpDesc::SetAttr(const std::string &name, return x.name() == name; }); } + return it; +} + +#define SET_IMPL_ONE(T, ty__, pb_f__) \ + template <> \ + void OpDesc::SetAttr(const std::string &name, const T &v) { \ + auto it = FindAttr(&desc_, name); \ + it->set_type(framework::proto::ty__); \ + it->set_##pb_f__(v); \ + } +SET_IMPL_ONE(int, INT, i); +SET_IMPL_ONE(float, FLOAT, f); +SET_IMPL_ONE(bool, FLOAT, f); + +template <> +void OpDesc::SetAttr>(const std::string &name, + const std::vector &v) { + auto it = FindAttr(&desc_, name); + it->set_type(framework::proto::INTS); + it->clear_ints(); + for (auto &i : v) { + it->add_ints(i); + } +} +template <> +void OpDesc::SetAttr(const std::string &name, + const std::string &v) { + auto it = FindAttr(&desc_, name); it->set_type(framework::proto::STRING); it->set_s(v.c_str()); } template <> -void OpDesc::SetAttr>(const std::string &name, - const std::vector &v) { - auto &xs = *desc_.mutable_attrs(); +void OpDesc::SetAttr>(const std::string &name, + const std::vector &v) { + auto it = FindAttr(&desc_, name); + it->set_type(framework::proto::FLOATS); + it->clear_floats(); + for (auto &i : v) { + it->add_floats(i); + } +} + +template <> +void OpDesc::SetAttr>( + const std::string &name, const std::vector &v) { + auto it = FindAttr(&desc_, name); + it->set_type(framework::proto::STRINGS); + it->clear_strings(); + for (auto &i : v) { + it->add_strings(i); + } +} + +google::protobuf::internal::RepeatedPtrIterator< + const framework::proto::OpDesc_Attr> +GetFindAttr(const framework::proto::OpDesc &desc, const std::string &name) { + auto &xs = desc.attrs(); auto it = std::find_if( xs.begin(), xs.end(), [&](const framework::proto::OpDesc_Attr &x) { return x.name() == name; }); - if (it == xs.end()) { - auto *attr = xs.Add(); - attr->set_name(name); - it = std::find_if(xs.begin(), xs.end(), - [&](const framework::proto::OpDesc_Attr &x) { - return x.name() == name; - }); + return it; +} + +#define GET_ATTR_IMPL(T, pb_f__) \ + template <> \ + T OpDesc::GetAttr(const std::string &name) const { \ + auto it = GetFindAttr(desc_, name); \ + return it->pb_f__(); \ } - it->set_type(framework::proto::INTS); - it->clear_ints(); - for (auto &i : v) { - it->add_ints(i); +#define GET_ATTRS_IMPL(T, pb_f__) \ + template <> \ + T OpDesc::GetAttr(const std::string &name) const { \ + auto it = GetFindAttr(desc_, name); \ + T res; \ + for (const auto &v : it->pb_f__()) { \ + res.push_back(v); \ + } \ + return res; \ } -} +GET_ATTR_IMPL(int32_t, i); +GET_ATTR_IMPL(float, f); +GET_ATTR_IMPL(bool, b); +GET_ATTRS_IMPL(std::vector, ints); +GET_ATTRS_IMPL(std::vector, floats); +GET_ATTRS_IMPL(std::vector, strings); +GET_ATTR_IMPL(std::string, s); } // namespace pb } // namespace lite diff --git a/paddle/fluid/lite/model_parser/pb/op_desc.h b/paddle/fluid/lite/model_parser/pb/op_desc.h index 0be809da8db..e8772e162a5 100644 --- a/paddle/fluid/lite/model_parser/pb/op_desc.h +++ b/paddle/fluid/lite/model_parser/pb/op_desc.h @@ -27,6 +27,7 @@ #include #include #include "paddle/fluid/lite/core/framework.pb.h" +#include "paddle/fluid/lite/model_parser/desc_apis.h" #include "paddle/fluid/lite/utils/all.h" namespace paddle { @@ -43,7 +44,7 @@ using VariableNameMap = std::map>; * except the desc_, to avoid the inconsistent state, which is normal in the * original interface and results in bugs. */ -class OpDesc { +class OpDesc : public OpDescAPI { public: OpDesc() {} @@ -54,38 +55,38 @@ class OpDesc { framework::proto::OpDesc *Proto() { return &desc_; } const framework::proto::OpDesc &ReadonlyProto() const { return desc_; } - std::string Type() const { return desc_.type(); } + std::string Type() const override { return desc_.type(); } - void SetType(const std::string &type) { desc_.set_type(type); } + void SetType(const std::string &type) override { desc_.set_type(type); } // Get the arguments of parameter called `param` - std::vector Input(const std::string ¶m) const { + std::vector Input(const std::string ¶m) const override { return GetArguments(desc_.inputs(), param); } - std::vector InputArgumentNames() const { + std::vector InputArgumentNames() const override { return GetArgumentNames(desc_.inputs()); } void SetInput(const std::string ¶m, - const std::vector &args) { + const std::vector &args) override { SetArgument(desc_.mutable_inputs(), param, args); } - std::vector Output(const std::string ¶m) const { + std::vector Output(const std::string ¶m) const override { return GetArguments(desc_.outputs(), param); } - std::vector OutputArgumentNames() const { + std::vector OutputArgumentNames() const override { return GetArgumentNames(desc_.outputs()); } void SetOutput(const std::string ¶m, - const std::vector &args) { + const std::vector &args) override { SetArgument(desc_.mutable_outputs(), param, args); } - bool HasAttr(const std::string &name) const { + bool HasAttr(const std::string &name) const override { const auto &xs = desc_.attrs(); auto it = std::find_if(xs.begin(), xs.end(), [&](const framework::proto::OpDesc_Attr &x) { @@ -94,17 +95,38 @@ class OpDesc { return it != xs.end(); } - framework::proto::AttrType GetAttrType(const std::string &name) const { + AttrType GetAttrType(const std::string &name) const override { const auto &xs = desc_.attrs(); auto it = std::find_if(xs.begin(), xs.end(), [&](const framework::proto::OpDesc_Attr &x) { return x.name() == name; }); CHECK(it != xs.end()); - return it->type(); +#define DEF_ONE(type__) \ + case framework::proto::AttrType::type__: \ + return AttrType::type__; + + switch (it->type()) { + DEF_ONE(INT); + DEF_ONE(FLOAT); + DEF_ONE(STRING); + DEF_ONE(INTS); + DEF_ONE(FLOATS); + DEF_ONE(STRINGS); + DEF_ONE(BOOLEAN); + DEF_ONE(BOOLEANS); + DEF_ONE(BLOCK); + DEF_ONE(LONG); + DEF_ONE(BLOCKS); + DEF_ONE(LONGS); + default: + LOG(ERROR) << "Unknown attribute type"; + return AttrType::UNK; + } +#undef DEF_ONE } - std::vector AttrNames() const { + std::vector AttrNames() const override { std::vector res; const auto &xs = desc_.attrs(); std::transform( @@ -114,72 +136,10 @@ class OpDesc { } template - void SetAttr(const std::string &name, const T &v) { - auto &xs = *desc_.mutable_attrs(); - auto it = std::find_if(xs.begin(), xs.end(), - [&](const framework::proto::OpDesc_Attr &x) { - return x.name() == name; - }); - if (it == xs.end()) { - auto *attr = xs.Add(); - attr->set_name(name); - it = std::find_if(xs.begin(), xs.end(), - [&](const framework::proto::OpDesc_Attr &x) { - return x.name() == name; - }); - } + void SetAttr(const std::string &name, const T &v); - size_t hash = typeid(T).hash_code(); - if (hash == typeid(int).hash_code()) { // NOLINT - it->set_type(framework::proto::INT); - it->set_i(v); - } else if (hash == typeid(float).hash_code()) { // NOLINT - it->set_type(framework::proto::FLOAT); - it->set_f(v); - } else if (hash == typeid(bool).hash_code()) { // NOLINT - it->set_type(framework::proto::BOOLEAN); - it->set_b(v); - } else { - LOG(FATAL) << "unsupport attr type"; - } - } - - Attribute GetAttr(const std::string &name) const { - auto &xs = desc_.attrs(); - auto it = std::find_if(xs.begin(), xs.end(), - [&](const framework::proto::OpDesc_Attr &x) { - return x.name() == name; - }); - - Attribute res; - CHECK(it != xs.end()); - switch (it->type()) { - case framework::proto::INT: - res.set(it->i()); - break; - case framework::proto::FLOAT: - res.set(it->f()); - break; - case framework::proto::STRING: - res.set(it->s()); - break; - case framework::proto::BOOLEAN: - res.set(it->b()); - break; - case framework::proto::INTS: { - std::vector values; - const auto &ys = it->ints(); - std::transform(ys.begin(), ys.end(), std::back_inserter(values), - [](const int &x) { return x; }); - res.set>(values); - } break; - - default: - LOG(FATAL) << "unsupported attr type"; - } - - return res; - } + template + T GetAttr(const std::string &name) const; private: std::vector GetArguments( diff --git a/paddle/fluid/lite/operators/activation_ops.cc b/paddle/fluid/lite/operators/activation_ops.cc index 4b99c4d9e06..8cda67af14a 100644 --- a/paddle/fluid/lite/operators/activation_ops.cc +++ b/paddle/fluid/lite/operators/activation_ops.cc @@ -33,7 +33,7 @@ class ActivationOp : public OpLite { return true; } - bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override { + bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override { auto X_name = opdesc.Input("X").front(); auto Out_name = opdesc.Output("Out").front(); @@ -66,7 +66,7 @@ class ActivationGradOp : public OpLite { return true; } - bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override { + bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override { auto Out_grad_name = opdesc.Input(framework::GradVarName("Out")).front(); auto X_grad_name = opdesc.Output(framework::GradVarName("X")).front(); diff --git a/paddle/fluid/lite/operators/concat_op.cc b/paddle/fluid/lite/operators/concat_op.cc index e8fd910f9d0..e9b773ceebd 100644 --- a/paddle/fluid/lite/operators/concat_op.cc +++ b/paddle/fluid/lite/operators/concat_op.cc @@ -54,7 +54,7 @@ bool ConcatOpLite::InferShape() const { } // TODO(Superjomn) replace framework::OpDesc with a lite one. -bool ConcatOpLite::AttachImpl(const OpDesc &op_desc, lite::Scope *scope) { +bool ConcatOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { auto inputs = op_desc.Input("X"); auto out = op_desc.Output("Out").front(); @@ -63,7 +63,7 @@ bool ConcatOpLite::AttachImpl(const OpDesc &op_desc, lite::Scope *scope) { } CHECK(scope->FindVar(out)); param_.output = scope->FindVar(out)->GetMutable(); - param_.axis = GetAttr(op_desc.GetAttr("axis")); + param_.axis = op_desc.GetAttr("axis"); return true; } diff --git a/paddle/fluid/lite/operators/concat_op.h b/paddle/fluid/lite/operators/concat_op.h index 86f58be45f3..17408289a61 100644 --- a/paddle/fluid/lite/operators/concat_op.h +++ b/paddle/fluid/lite/operators/concat_op.h @@ -32,7 +32,7 @@ class ConcatOpLite : public OpLite { bool InferShape() const override; - bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override; + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } std::string DebugString() const override { return "concat"; } diff --git a/paddle/fluid/lite/operators/concat_op_test.cc b/paddle/fluid/lite/operators/concat_op_test.cc index d5a412893ee..3af3fc8ef78 100644 --- a/paddle/fluid/lite/operators/concat_op_test.cc +++ b/paddle/fluid/lite/operators/concat_op_test.cc @@ -42,7 +42,7 @@ TEST(concat_op_lite, test) { } // prepare op desc - lite::OpDesc desc; + cpp::OpDesc desc; desc.SetType("concat"); desc.SetInput("X", {"x0", "x1"}); desc.SetOutput("Out", {"output"}); diff --git a/paddle/fluid/lite/operators/dropout_op.cc b/paddle/fluid/lite/operators/dropout_op.cc index cc0761b2bc7..b5b50dc3d16 100644 --- a/paddle/fluid/lite/operators/dropout_op.cc +++ b/paddle/fluid/lite/operators/dropout_op.cc @@ -42,7 +42,7 @@ class DropoutOpLite : public OpLite { void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } // TODO(Superjomn) replace framework::OpDesc with a lite one. - bool AttachImpl(const OpDesc& op_desc, lite::Scope* scope) override { + bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override { auto input = op_desc.Input("X").front(); auto out = op_desc.Output("Out").front(); auto Mask = op_desc.Output("Mask").front(); @@ -51,14 +51,14 @@ class DropoutOpLite : public OpLite { param_.output = GetMutableVar(scope, out); param_.mask = GetMutableVar(scope, Mask); - param_.dropout_prob = boost::get(op_desc.GetAttr("dropout_prob")); + param_.dropout_prob = op_desc.GetAttr("dropout_prob"); if (op_desc.HasAttr("axis")) { - param_.is_test = boost::get(op_desc.GetAttr("is_test")); + param_.is_test = op_desc.GetAttr("is_test"); } - param_.fix_seed = boost::get(op_desc.GetAttr("fix_seed")); - param_.seed = boost::get(op_desc.GetAttr("seed")); + param_.fix_seed = op_desc.GetAttr("fix_seed"); + param_.seed = op_desc.GetAttr("seed"); param_.dropout_implementation = - boost::get(op_desc.GetAttr("dropout_implementation")); + op_desc.GetAttr("dropout_implementation"); return true; } diff --git a/paddle/fluid/lite/operators/elementwise_ops.cc b/paddle/fluid/lite/operators/elementwise_ops.cc index 0ca89cccf30..b400b1ab26c 100644 --- a/paddle/fluid/lite/operators/elementwise_ops.cc +++ b/paddle/fluid/lite/operators/elementwise_ops.cc @@ -36,7 +36,7 @@ class ElementwiseOp : public OpLite { return true; } - bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override { + bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override { auto X_name = opdesc.Input("X").front(); auto Y_name = opdesc.Input("Y").front(); auto Out_name = opdesc.Output("Out").front(); @@ -44,7 +44,7 @@ class ElementwiseOp : public OpLite { param_.X = GetVar(scope, X_name); param_.Y = GetVar(scope, Y_name); param_.Out = GetMutableVar(scope, Out_name); - param_.axis = boost::get(opdesc.GetAttr("axis")); + param_.axis = opdesc.GetAttr("axis"); return true; } @@ -75,8 +75,8 @@ class ElementwiseGradExplicitOp : public OpLite { return true; } - bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override { - CHECK_EQ(opdesc.Inputs().size(), 1UL); + bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override { + CHECK_EQ(opdesc.InputArgumentNames().size(), 1UL); auto Out_name = opdesc.Input(framework::GradVarName("Out")).front(); auto X_name = opdesc.Output(framework::GradVarName("X")).front(); auto Y_name = opdesc.Output(framework::GradVarName("Y")).front(); @@ -84,7 +84,7 @@ class ElementwiseGradExplicitOp : public OpLite { param_.Out_grad = GetVar(scope, Out_name); param_.X_grad = GetMutableVar(scope, X_name); param_.Y_grad = GetMutableVar(scope, Y_name); - param_.axis = boost::get(opdesc.GetAttr("axis")); + param_.axis = opdesc.GetAttr("axis"); return true; } diff --git a/paddle/fluid/lite/operators/fc_op.h b/paddle/fluid/lite/operators/fc_op.h index a6043fa7b1f..0e738018322 100644 --- a/paddle/fluid/lite/operators/fc_op.h +++ b/paddle/fluid/lite/operators/fc_op.h @@ -46,7 +46,7 @@ class FcOpLite : public OpLite { */ // TODO(Superjomn) replace framework::OpDesc with a lite one. - bool AttachImpl(const OpDesc &op_desc, lite::Scope *scope) override { + bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { auto input = op_desc.Input("Input").front(); auto W = op_desc.Input("W").front(); auto bias = op_desc.Input("Bias").front(); @@ -57,7 +57,7 @@ class FcOpLite : public OpLite { param_.bias = scope->FindVar(bias)->GetMutable(); CHECK(scope->FindVar(out)); param_.output = scope->FindVar(out)->GetMutable(); - param_.in_num_col_dims = GetAttr(op_desc.GetAttr("in_num_col_dims")); + param_.in_num_col_dims = op_desc.GetAttr("in_num_col_dims"); return true; } diff --git a/paddle/fluid/lite/operators/fc_op_test.cc b/paddle/fluid/lite/operators/fc_op_test.cc index 9ef91dbc147..880b8a820e5 100644 --- a/paddle/fluid/lite/operators/fc_op_test.cc +++ b/paddle/fluid/lite/operators/fc_op_test.cc @@ -47,7 +47,7 @@ TEST(fc_op_lite, TestX86) { } // prepare op desc - lite::OpDesc desc; + cpp::OpDesc desc; desc.SetType("fc"); desc.SetInput("Input", {"x"}); desc.SetInput("W", {"w"}); diff --git a/paddle/fluid/lite/operators/feed_op.cc b/paddle/fluid/lite/operators/feed_op.cc index 45a7c198cb6..8c7d33e9e59 100644 --- a/paddle/fluid/lite/operators/feed_op.cc +++ b/paddle/fluid/lite/operators/feed_op.cc @@ -34,7 +34,7 @@ class FeedOp : public OpLite { void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } protected: - bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override { + bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override { auto feed_var_name = opdesc.Input("X").front(); auto* feed_var = scope->FindVar(feed_var_name); CHECK(feed_var); @@ -48,7 +48,7 @@ class FeedOp : public OpLite { // NOTE need boost here // TODO(Superjomn) drop the need of framework::op_desc - param_.col = GetAttr(opdesc.GetAttr("col")); + param_.col = opdesc.GetAttr("col"); return true; } diff --git a/paddle/fluid/lite/operators/fetch_op.cc b/paddle/fluid/lite/operators/fetch_op.cc index 337a6ecc9d5..51efda776b2 100644 --- a/paddle/fluid/lite/operators/fetch_op.cc +++ b/paddle/fluid/lite/operators/fetch_op.cc @@ -33,7 +33,7 @@ class FetchOp : public OpLite { void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } protected: - bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override { + bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override { auto _x = opdesc.Input("X").front(); auto* x = scope->FindVar(_x); CHECK(x); @@ -43,7 +43,7 @@ class FetchOp : public OpLite { auto* out = scope->FindVar(_out); param_.fetch_list = out->GetMutable>(); - param_.col = GetAttr(opdesc.GetAttr("col")); + param_.col = opdesc.GetAttr("col"); return true; } diff --git a/paddle/fluid/lite/operators/fill_constant_op.cc b/paddle/fluid/lite/operators/fill_constant_op.cc index f701dd48775..b762f0d3c92 100644 --- a/paddle/fluid/lite/operators/fill_constant_op.cc +++ b/paddle/fluid/lite/operators/fill_constant_op.cc @@ -33,14 +33,14 @@ class FillConstantOp : public OpLite { return true; } - bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override { + bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override { auto Out_name = opdesc.Output("Out").front(); param_.Out = GetMutableVar(scope, Out_name); - param_.dtype = GetAttr(opdesc.GetAttr("dtype")); - param_.shape = GetAttr>(opdesc.GetAttr("shape")); - param_.value = GetAttr(opdesc.GetAttr("value")); - param_.force_cpu = GetAttr(opdesc.GetAttr("force_cpu")); + param_.dtype = opdesc.GetAttr("dtype"); + param_.shape = opdesc.GetAttr>("shape"); + param_.value = opdesc.GetAttr("value"); + param_.force_cpu = opdesc.GetAttr("force_cpu"); return true; } diff --git a/paddle/fluid/lite/operators/io_copy_op.cc b/paddle/fluid/lite/operators/io_copy_op.cc index 220853fc263..44d49a30a0e 100644 --- a/paddle/fluid/lite/operators/io_copy_op.cc +++ b/paddle/fluid/lite/operators/io_copy_op.cc @@ -29,7 +29,8 @@ bool IoCopyOp::InferShape() const { return true; } bool IoCopyOp::Run() { return OpLite::Run(); } -bool IoCopyOp::AttachImpl(const OpDesc &opdesc, paddle::lite::Scope *scope) { +bool IoCopyOp::AttachImpl(const cpp::OpDesc &opdesc, + paddle::lite::Scope *scope) { auto x = opdesc.Input("Input").front(); auto out = opdesc.Output("Out").front(); param_.x = GetTensor(scope, x); diff --git a/paddle/fluid/lite/operators/io_copy_op.h b/paddle/fluid/lite/operators/io_copy_op.h index efcd11bc309..dd95ef8d33a 100644 --- a/paddle/fluid/lite/operators/io_copy_op.h +++ b/paddle/fluid/lite/operators/io_copy_op.h @@ -31,7 +31,7 @@ class IoCopyOp : public OpLite { void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } protected: - bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override; + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; private: operators::IoCopyParam param_; diff --git a/paddle/fluid/lite/operators/mean_op.cc b/paddle/fluid/lite/operators/mean_op.cc index 20e3709872f..411dcbb735a 100644 --- a/paddle/fluid/lite/operators/mean_op.cc +++ b/paddle/fluid/lite/operators/mean_op.cc @@ -37,7 +37,7 @@ class MeanOp : public OpLite { return true; } - bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override { + bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override { auto X_name = opdesc.Input("X").front(); auto Out_name = opdesc.Output("Out").front(); @@ -72,8 +72,8 @@ class MeanGradOp : public OpLite { return true; } - bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override { - CHECK_EQ(opdesc.Inputs().size(), 3UL); + bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override { + CHECK_EQ(opdesc.InputArgumentNames().size(), 3UL); auto X_name = opdesc.Input("X").front(); auto Out_grad_name = opdesc.Input(framework::GradVarName("Out")).front(); auto X_grad_name = opdesc.Output(framework::GradVarName("X")).front(); diff --git a/paddle/fluid/lite/operators/mul_op.cc b/paddle/fluid/lite/operators/mul_op.cc index 8958d91956a..70eb37dd09b 100644 --- a/paddle/fluid/lite/operators/mul_op.cc +++ b/paddle/fluid/lite/operators/mul_op.cc @@ -85,7 +85,7 @@ bool MulGradOpLite::InferShape() const { return true; } -bool MulGradOpLite::AttachImpl(const OpDesc &op_desc, lite::Scope *scope) { +bool MulGradOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { auto X_name = op_desc.Input("X").front(); auto Y_name = op_desc.Input("Y").front(); auto Out_grad_name = op_desc.Output(framework::GradVarName("Out")).front(); diff --git a/paddle/fluid/lite/operators/mul_op.h b/paddle/fluid/lite/operators/mul_op.h index 73827753bd2..e21540d2c6f 100644 --- a/paddle/fluid/lite/operators/mul_op.h +++ b/paddle/fluid/lite/operators/mul_op.h @@ -37,7 +37,7 @@ class MulOpLite : public OpLite { void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } // TODO(Superjomn) replace framework::OpDesc with a lite one. - bool AttachImpl(const OpDesc &op_desc, lite::Scope *scope) override { + bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { auto input = op_desc.Input("X").front(); auto W = op_desc.Input("Y").front(); auto out = op_desc.Output("Out").front(); @@ -49,8 +49,8 @@ class MulOpLite : public OpLite { param_.y = var->GetMutable(); CHECK(scope->FindVar(out)); param_.output = scope->FindVar(out)->GetMutable(); - param_.x_num_col_dims = GetAttr(op_desc.GetAttr("x_num_col_dims")); - param_.y_num_col_dims = GetAttr(op_desc.GetAttr("y_num_col_dims")); + param_.x_num_col_dims = op_desc.GetAttr("x_num_col_dims"); + param_.y_num_col_dims = op_desc.GetAttr("y_num_col_dims"); return true; } @@ -73,7 +73,7 @@ class MulGradOpLite : public OpLite { void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } - bool AttachImpl(const OpDesc &op_desc, lite::Scope *scope) override; + bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override; std::string DebugString() const override { return "mul_grad"; } diff --git a/paddle/fluid/lite/operators/relu_op.cc b/paddle/fluid/lite/operators/relu_op.cc index 4fa02c5eb94..b073e2db43a 100644 --- a/paddle/fluid/lite/operators/relu_op.cc +++ b/paddle/fluid/lite/operators/relu_op.cc @@ -30,7 +30,7 @@ bool ReluOp::InferShape() const { return true; } -bool ReluOp::AttachImpl(const OpDesc &opdesc, lite::Scope *scope) { +bool ReluOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { param_.input = const_cast( &scope->FindVar(opdesc.Input("Input").front())->Get()); param_.output = diff --git a/paddle/fluid/lite/operators/relu_op.h b/paddle/fluid/lite/operators/relu_op.h index ffb03368788..945a9680a75 100644 --- a/paddle/fluid/lite/operators/relu_op.h +++ b/paddle/fluid/lite/operators/relu_op.h @@ -32,7 +32,7 @@ class ReluOp : public OpLite { bool InferShape() const override; - bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override; + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } std::string DebugString() const override { return "relu"; } diff --git a/paddle/fluid/lite/operators/reshape_op.cc b/paddle/fluid/lite/operators/reshape_op.cc index bf43f52340c..6fc9c1af1e6 100644 --- a/paddle/fluid/lite/operators/reshape_op.cc +++ b/paddle/fluid/lite/operators/reshape_op.cc @@ -33,7 +33,7 @@ bool ReshapeOp::InferShape() const { return true; } -bool ReshapeOp::AttachImpl(const OpDesc &opdesc, lite::Scope *scope) { +bool ReshapeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { auto x_var = scope->FindVar(opdesc.Input("X").front()); auto output_var = scope->FindVar(opdesc.Output("Out").front()); CHECK(x_var); @@ -49,9 +49,9 @@ bool ReshapeOp::AttachImpl(const OpDesc &opdesc, lite::Scope *scope) { const_cast(&(actual_shape_var->Get())); } } - param_.shape = GetAttr>(opdesc.GetAttr("shape")); + param_.shape = (opdesc.GetAttr>("shape")); if (opdesc.HasAttr("inplace")) { - param_.inplace = GetAttr(opdesc.GetAttr("inplace")); + param_.inplace = opdesc.GetAttr("inplace"); } CHECK(param_.x) << "Input(X) of ReshapeOp should not be null."; CHECK(param_.output) << "Output(Out) of ReshapeOp should not be null."; @@ -70,14 +70,14 @@ bool Reshape2Op::InferShape() const { ReshapeOp::InferShape(); auto x_dims = param_.x->dims(); std::vector xshape_dims(x_dims.size() + 1, 0); - for (int i = 0; i < x_dims.size(); i++) { + for (size_t i = 0; i < x_dims.size(); i++) { xshape_dims[i + 1] = x_dims[i]; } param_.xshape->Resize(DDim(xshape_dims)); return true; } -bool Reshape2Op::AttachImpl(const OpDesc &opdesc, lite::Scope *scope) { +bool Reshape2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { ReshapeOp::AttachImpl(opdesc, scope); auto xshape_var = scope->FindVar(opdesc.Output("XShape").front()); CHECK(xshape_var); diff --git a/paddle/fluid/lite/operators/reshape_op.h b/paddle/fluid/lite/operators/reshape_op.h index d96da8d5d01..4f7e0b9c134 100644 --- a/paddle/fluid/lite/operators/reshape_op.h +++ b/paddle/fluid/lite/operators/reshape_op.h @@ -32,7 +32,7 @@ class ReshapeOp : public OpLite { bool InferShape() const override; - bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override; + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } std::string DebugString() const override { return "reshape"; } @@ -50,7 +50,7 @@ class Reshape2Op : public ReshapeOp { bool InferShape() const override; - bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override; + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } std::string DebugString() const override { return "reshape2"; } diff --git a/paddle/fluid/lite/operators/reshape_op_test.cc b/paddle/fluid/lite/operators/reshape_op_test.cc index 41f6999c1d8..4bf137f16fe 100644 --- a/paddle/fluid/lite/operators/reshape_op_test.cc +++ b/paddle/fluid/lite/operators/reshape_op_test.cc @@ -47,7 +47,7 @@ TEST(reshape_op_lite, test) { for (auto& has_actual_shape : {true, false}) { for (auto& inplace : {true, false}) { // prepare op desc - lite::OpDesc desc; + cpp::OpDesc desc; desc.SetType("reshape"); desc.SetInput("X", {"x"}); if (has_actual_shape) { @@ -68,7 +68,7 @@ TEST(reshape_op_lite, test) { // check output dims auto output_dims = output->dims(); CHECK_EQ(output_dims.size(), shape.second.size()); - for (int i = 0; i < output_dims.size(); i++) { + for (size_t i = 0; i < output_dims.size(); i++) { CHECK_EQ(output_dims[i], shape.second[i]); } } @@ -102,7 +102,7 @@ TEST(reshape2_op_lite, test) { for (auto& has_actual_shape : {true, false}) { for (auto& inplace : {true, false}) { // prepare op desc - lite::OpDesc desc; + cpp::OpDesc desc; desc.SetType("reshape"); desc.SetInput("X", {"x"}); if (has_actual_shape) { @@ -132,7 +132,7 @@ TEST(reshape2_op_lite, test) { auto xshape_dims = xshape->dims(); CHECK_EQ(xshape_dims.size(), x_dims.size() + 1); CHECK_EQ(xshape_dims[0], 0); - for (int i = 0; i < x_dims.size(); i++) { + for (size_t i = 0; i < x_dims.size(); i++) { CHECK_EQ(xshape_dims[i + 1], x_dims[i]); } } diff --git a/paddle/fluid/lite/operators/scale_op.cc b/paddle/fluid/lite/operators/scale_op.cc index 0a6dec991a0..fb55366488c 100644 --- a/paddle/fluid/lite/operators/scale_op.cc +++ b/paddle/fluid/lite/operators/scale_op.cc @@ -29,14 +29,14 @@ bool ScaleOp::InferShape() const { return true; } -bool ScaleOp::AttachImpl(const OpDesc &op_desc, lite::Scope *scope) { +bool ScaleOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { auto x = op_desc.Input("X").front(); auto output = op_desc.Output("Out").front(); param_.x = scope->FindVar(x)->GetMutable(); param_.output = scope->FindVar(output)->GetMutable(); - param_.scale = GetAttr(op_desc.GetAttr("scale")); - param_.bias = GetAttr(op_desc.GetAttr("bias")); - param_.bias_after_scale = GetAttr(op_desc.GetAttr("bias_after_scale")); + param_.scale = op_desc.GetAttr("scale"); + param_.bias = op_desc.GetAttr("bias"); + param_.bias_after_scale = op_desc.GetAttr("bias_after_scale"); CHECK(param_.x); CHECK(param_.output); return true; diff --git a/paddle/fluid/lite/operators/scale_op.h b/paddle/fluid/lite/operators/scale_op.h index 8866e6a29b7..43493710bba 100644 --- a/paddle/fluid/lite/operators/scale_op.h +++ b/paddle/fluid/lite/operators/scale_op.h @@ -32,7 +32,7 @@ class ScaleOp : public OpLite { bool InferShape() const override; - bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override; + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } std::string DebugString() const override { return "scale"; } diff --git a/paddle/fluid/lite/operators/scale_op_test.cc b/paddle/fluid/lite/operators/scale_op_test.cc index ad61a27a1c3..33ab91ff05c 100644 --- a/paddle/fluid/lite/operators/scale_op_test.cc +++ b/paddle/fluid/lite/operators/scale_op_test.cc @@ -29,7 +29,7 @@ TEST(scale_op_lite, test) { output->Resize(DDim(std::vector{1, 1})); // prepare op desc - lite::OpDesc desc; + cpp::OpDesc desc; desc.SetType("scale"); desc.SetInput("X", {"x"}); desc.SetOutput("Out", {"output"}); @@ -48,7 +48,7 @@ TEST(scale_op_lite, test) { auto x_dims = x->dims(); auto output_dims = output->dims(); CHECK_EQ(output_dims.size(), x_dims.size()); - for (int i = 0; i < output_dims.size(); i++) { + for (size_t i = 0; i < output_dims.size(); i++) { CHECK_EQ(output_dims[i], x_dims[i]); } } diff --git a/paddle/fluid/lite/operators/softmax_op.cc b/paddle/fluid/lite/operators/softmax_op.cc index 518d6a3d36a..41d7b335e80 100644 --- a/paddle/fluid/lite/operators/softmax_op.cc +++ b/paddle/fluid/lite/operators/softmax_op.cc @@ -24,7 +24,8 @@ bool SoftmaxOp::CheckShape() const { CHECK_OR_FALSE(param_.output); auto x_dims = param_.x->dims(); auto x_rank = x_dims.size(); - CHECK_OR_FALSE(param_.axis >= -x_rank && param_.axis < x_rank); + CHECK_OR_FALSE(param_.axis >= -static_cast(x_rank) && + param_.axis < static_cast(x_rank)); return true; } @@ -33,12 +34,12 @@ bool SoftmaxOp::InferShape() const { return true; } -bool SoftmaxOp::AttachImpl(const OpDesc &opdesc, lite::Scope *scope) { +bool SoftmaxOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { param_.x = const_cast( &scope->FindVar(opdesc.Input("X").front())->Get()); param_.output = scope->FindVar(opdesc.Output("Out").front())->GetMutable(); - param_.axis = GetAttr(opdesc.GetAttr("axis")); + param_.axis = opdesc.GetAttr("axis"); CHECK(param_.x); CHECK(param_.output); return true; diff --git a/paddle/fluid/lite/operators/softmax_op.h b/paddle/fluid/lite/operators/softmax_op.h index 062f707c6e0..515e4493c99 100644 --- a/paddle/fluid/lite/operators/softmax_op.h +++ b/paddle/fluid/lite/operators/softmax_op.h @@ -32,7 +32,7 @@ class SoftmaxOp : public OpLite { bool InferShape() const override; - bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override; + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } std::string DebugString() const override { return "softmax"; } diff --git a/paddle/fluid/lite/operators/softmax_op_test.cc b/paddle/fluid/lite/operators/softmax_op_test.cc index f999564541a..4659b35cd7b 100644 --- a/paddle/fluid/lite/operators/softmax_op_test.cc +++ b/paddle/fluid/lite/operators/softmax_op_test.cc @@ -37,7 +37,7 @@ TEST(softmax_op_lite, test) { } // prepare op desc - lite::OpDesc desc; + cpp::OpDesc desc; desc.SetType("softmax"); desc.SetInput("X", {"x"}); desc.SetOutput("Out", {"output"}); -- GitLab