Commit 65be35af authored by Yan Chunwei, committed by GitHub

Lite/refactor cp desc (#17831)

Parent: b7cf0984
@@ -25,7 +25,9 @@ cc_library(op_registry_lite SRCS op_registry.cc DEPS framework_proto_lite)
 cc_library(scope_lite SRCS scope.cc DEPS ${tensor_lite})
 cc_library(cpu_info_lite SRCS cpu_info.cc)
 cc_library(context_lite SRCS context.cc DEPS ${tensor_lite} any_lite cpu_info_lite)
-cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite compatible_pb_lite target_wrapper_lite ${tensor_lite})
+cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite target_wrapper_lite
+    cpp_op_desc_lite
+    ${tensor_lite})
 cc_library(types_lite SRCS types.cc)
 cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} target_wrapper_lite)
......
@@ -90,7 +90,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
   inst_node->AsStmt().op->scope()->Var(io_copy_output_name);
 
   // Create IoCopy Instruction.
-  lite::OpDesc op_desc;
+  cpp::OpDesc op_desc;
   op_desc.SetType("io_copy");
   op_desc.SetInput("Input", {var});
   op_desc.SetOutput("Out", {io_copy_output_name});
@@ -104,8 +104,6 @@ void TypeTargetTransformPass::AddIoCopyInst(
   // Update the original instruction OpDesc.
   // Update its input to the io_copy_output_name
-  auto& inst = inst_node->AsStmt();
-  auto inst_program_desc = inst.op_info()->desc();
 
   // Add new link, var -> new_inst, new_inst->newarg, newarg->inst
   DirectedLink(graph->Argument(var), io_copy_inst);
@@ -113,11 +111,11 @@ void TypeTargetTransformPass::AddIoCopyInst(
   DirectedLink(io_copy_output_arg, inst_node);
 
   // reset opdesc and update kernel information
-  auto desc_dummy = inst_node->AsStmt().op->op_info()->desc();
-  UpdateInputTo(&desc_dummy, var, io_copy_output_name);
-  lite::OpDesc desc_fake(desc_dummy);
-  inst_node->AsStmt().op->Attach(desc_fake, inst_node->AsStmt().op->scope());
+  UpdateInputTo(inst_node->AsStmt().op->mutable_op_info(), var,
+                io_copy_output_name);
+  inst_node->AsStmt().op->Attach(*inst_node->AsStmt().op->op_info(),
+                                 inst_node->AsStmt().op->scope());
 
   std::string tmp;
   if (inst_node->AsStmt().op_info()->GetInputArgname("a", &tmp)) {
......
@@ -24,10 +24,10 @@ namespace paddle {
 namespace lite {
 namespace mir {
 
-static void UpdateInputTo(framework::proto::OpDesc* desc,
-                          const std::string& from, const std::string& to) {
+static void UpdateInputTo(cpp::OpDesc* desc, const std::string& from,
+                          const std::string& to) {
   for (auto& item : *desc->mutable_inputs()) {
-    for (auto& input : *item.mutable_arguments()) {
+    for (auto& input : item.second) {
       if (input == from) {
         input = to;
       }
......
@@ -65,7 +65,7 @@ class VariablePlaceInferencePass : public DebugPass {
       // check if inputs's place is set, if not set, update them with the
       // kernel's declaration.
       auto type = inst.picked_kernel().GetInputDeclType(arg_name);
-      auto arg_names = inst.op_info()->input_argument().at(arg_name);
+      auto arg_names = inst.op_info()->inputs().at(arg_name);
       for (auto& arg_name : arg_names) {
         VLOG(3) << "--- var " << arg_name;
@@ -82,7 +82,7 @@ class VariablePlaceInferencePass : public DebugPass {
     for (auto& arg_name : inst.op_info()->output_argnames()) {
       VLOG(3) << "-- output arg_name " << arg_name;
       auto type = inst.picked_kernel().GetOutputDeclType(arg_name);
-      auto arg_names = inst.op_info()->output_argument().at(arg_name);
+      auto arg_names = inst.op_info()->outputs().at(arg_name);
       // check if outputs's place is set, if not set, update them with the
       // kernel's declaration.
       for (auto& arg_name : arg_names) {
......
@@ -68,13 +68,13 @@ bool OpLite::Run() {
   return true;
 }
 
-bool OpLite::Attach(const OpDesc &opdesc, lite::Scope *scope) {
+bool OpLite::Attach(const cpp::OpDesc &opdesc, lite::Scope *scope) {
   // valid_places_.clear();
   CHECK(scope != nullptr);
   // CHECK(!op_info_.get());
   scope_ = scope;
-  op_info_.reset(new OpInfo);  // Force clean the out-of-date infomation.
-  op_info_->Build(opdesc.ReadonlyProto());
+  op_info_.reset(
+      new OpInfo(opdesc));  // Force clean the out-of-date infomation.
   return AttachImpl(opdesc, scope);
 }
@@ -92,94 +92,5 @@ Tensor *OpLite::GetMutableTensor(lite::Scope *scope,
   return var->GetMutable<lite::Tensor>();
 }
 
-bool OpInfo::GetInputArgname(const std::string &value_name,
-                             std::string *out) const {
-  for (auto &item : input_argument_) {
-    auto it = std::find(item.second.begin(), item.second.end(), value_name);
-    if (it != item.second.end()) {
-      *out = item.first;
-      return true;
-    }
-  }
-  return false;
-}
-
-bool OpInfo::GetOutputArgname(const std::string &value_name,
-                              std::string *out) const {
-  for (auto &item : output_argument_) {
-    auto it = std::find(item.second.begin(), item.second.end(), value_name);
-    if (it != item.second.end()) {
-      *out = item.first;
-      return true;
-    }
-  }
-  return false;
-}
-
-void OpInfo::ExtractInputsAndOutputs(const framework::proto::OpDesc &opdesc) {
-  for (const auto &item : opdesc.inputs()) {
-    for (const auto &x : item.arguments()) {
-      input_names_.push_back(x);
-    }
-  }
-  for (const auto &item : opdesc.outputs()) {
-    for (const auto &x : item.arguments()) {
-      output_names_.push_back(x);
-    }
-  }
-}
-
-void OpInfo::CollectInputAndOutputArgnames(
-    const framework::proto::OpDesc &opdesc) {
-  for (const auto &item : opdesc.inputs()) {
-    input_argnames_.push_back(item.parameter());
-  }
-  for (const auto &item : opdesc.outputs()) {
-    output_argnames_.push_back(item.parameter());
-  }
-}
-
-void OpInfo::CollectArguments(const framework::proto::OpDesc &opdesc) {
-  for (const auto &item : opdesc.inputs()) {
-    for (auto &x : item.arguments()) {
-      input_argument_[item.parameter()].push_back(x);
-    }
-  }
-  for (const auto &item : opdesc.outputs()) {
-    for (auto &x : item.arguments()) {
-      output_argument_[item.parameter()].push_back(x);
-    }
-  }
-}
-
-void OpInfo::Build(const framework::proto::OpDesc &desc) {
-  ExtractInputsAndOutputs(desc);
-  CollectInputAndOutputArgnames(desc);
-  CollectArguments(desc);
-  desc_.reset(new framework::proto::OpDesc(desc));
-}
-
-const std::map<std::string, std::list<std::string>> &OpInfo::input_argument()
-    const {
-  return input_argument_;
-}
-
-const std::map<std::string, std::list<std::string>> &OpInfo::output_argument()
-    const {
-  return output_argument_;
-}
-
-const std::list<std::string> &OpInfo::input_argnames() const {
-  return input_argnames_;
-}
-
-const std::list<std::string> &OpInfo::output_argnames() const {
-  return output_argnames_;
-}
-
-const framework::proto::OpDesc &OpInfo::desc() const {
-  CHECK(desc_) << "desc has't set";
-  return *desc_;
-}
-
 }  // namespace lite
 }  // namespace paddle
@@ -23,7 +23,7 @@
 #include "paddle/fluid/lite/core/context.h"
 #include "paddle/fluid/lite/core/kernel.h"
 #include "paddle/fluid/lite/core/scope.h"
-#include "paddle/fluid/lite/model_parser/compatible_pb.h"
+#include "paddle/fluid/lite/model_parser/cpp/op_desc.h"
 
 namespace paddle {
 namespace lite {
@@ -71,7 +71,7 @@ class OpLite : public Registry {
   virtual bool Run();
 
   // Link the external execution environ to internal context.
-  bool Attach(const OpDesc &opdesc, lite::Scope *scope);
+  bool Attach(const cpp::OpDesc &opdesc, lite::Scope *scope);
 
   const OpInfo *op_info() const { return op_info_.get(); }
   OpInfo *mutable_op_info() { return op_info_.get(); }
@@ -94,7 +94,7 @@ class OpLite : public Registry {
  protected:
   // Attach it with the runtime environment.
-  virtual bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) = 0;
+  virtual bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) = 0;
 
   // Specify the kernel to run by default. This will specify the value of
   // `kernel_place_`.
@@ -144,40 +144,61 @@ class OpLite : public Registry {
  * Operator Information, such as some description. It will be shared by all the
  * kernels of the same operator.
  */
-class OpInfo {
+class OpInfo : public cpp::OpDesc {
  public:
-  // To avoid the bugs from legancy framework::OpDesc, we use the ProtoBuf
-  // message instead.
-  void Build(const framework::proto::OpDesc &desc);
-
-  const framework::proto::OpDesc &desc() const;
-  framework::proto::OpDesc *mutable_desc() { return desc_.get(); }
-  const std::list<std::string> &input_names() const { return input_names_; }
-  const std::list<std::string> &output_names() const { return output_names_; }
-  const std::map<std::string, std::list<std::string>> &input_argument() const;
-  const std::map<std::string, std::list<std::string>> &output_argument() const;
-  bool GetInputArgname(const std::string &value_name, std::string *out) const;
-  bool GetOutputArgname(const std::string &value_name, std::string *out) const;
-
-  const std::list<std::string> &input_argnames() const;
-  const std::list<std::string> &output_argnames() const;
-
- private:
-  void ExtractInputsAndOutputs(const framework::proto::OpDesc &opdesc);
-
-  void CollectInputAndOutputArgnames(const framework::proto::OpDesc &opdesc);
-
-  void CollectArguments(const framework::proto::OpDesc &opdesc);
-
- private:
-  std::list<std::string> input_names_;
-  std::list<std::string> output_names_;
-  std::list<std::string> input_argnames_;
-  std::list<std::string> output_argnames_;
-  std::map<std::string, std::list<std::string>> input_argument_;
-  std::map<std::string, std::list<std::string>> output_argument_;
-  // NOTE too heavy.
-  std::unique_ptr<framework::proto::OpDesc> desc_;
+  OpInfo(const OpInfo &) = default;
+  OpInfo(const cpp::OpDesc &other) : cpp::OpDesc(other) {}
+
+  // Collect all the input variable's name.
+  std::vector<std::string> input_names() const {
+    std::vector<std::string> res;
+    for (auto &param : InputArgumentNames()) {
+      for (auto &x : Input(param)) {
+        res.push_back(x);
+      }
+    }
+    return res;
+  }
+
+  // Collect all the output variable's name.
+  std::vector<std::string> output_names() const {
+    std::vector<std::string> res;
+    for (auto &param : OutputArgumentNames()) {
+      for (auto &x : Output(param)) {
+        res.push_back(x);
+      }
+    }
+    return res;
+  }
+
+  std::vector<std::string> input_argnames() const {
+    return InputArgumentNames();
+  }
+
+  std::vector<std::string> output_argnames() const {
+    return OutputArgumentNames();
+  }
+
+  bool GetInputArgname(const std::string &value_name, std::string *out) const {
+    for (auto &item : inputs_) {
+      auto it = std::find(item.second.begin(), item.second.end(), value_name);
+      if (it != item.second.end()) {
+        *out = item.first;
+        return true;
+      }
+    }
+    return false;
+  }
+
+  bool GetOutputArgname(const std::string &value_name,
+                        std::string *out) const {
+    for (auto &item : outputs_) {
+      auto it = std::find(item.second.begin(), item.second.end(), value_name);
+      if (it != item.second.end()) {
+        *out = item.first;
+        return true;
+      }
+    }
+    return false;
+  }
 };
 
 }  // namespace lite
......
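A minimal usage sketch of the refactored OpInfo follows (an editor's illustration, not part of the commit; it assumes the Lite headers shown above, and OpInfoSketch is a hypothetical helper name):

// Sketch: OpInfo is now a thin wrapper over cpp::OpDesc, so it can be
// copy-constructed from one and queried through the inherited storage.
#include <string>
#include "paddle/fluid/lite/core/op_lite.h"

void OpInfoSketch() {
  paddle::lite::cpp::OpDesc desc;
  desc.SetType("fc");
  desc.SetInput("Input", {"x"});
  desc.SetInput("W", {"w"});
  desc.SetOutput("Out", {"out"});

  // The new OpInfo(const cpp::OpDesc&) constructor replaces Build().
  paddle::lite::OpInfo info(desc);

  // Reverse lookup: which parameter does the variable "w" belong to?
  std::string param;
  if (info.GetInputArgname("w", &param)) {
    // param == "W"
  }
  auto vars = info.input_names();  // {"x", "w"}, in parameter-name order
}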
@@ -39,11 +39,11 @@ std::string RuntimeProgram::SerializeProgram(
   auto program_dummy = desc;
   program_dummy.mutable_blocks(0)->clear_ops();
   for (auto &node : instructions_) {
-    auto desc_dummy = node.op()->op_info()->desc();
-    OpDesc desc(desc_dummy);
-    desc.SetAttr(kKernelTypeAttr, node.kernel()->SerializedKernelType());
+    pb::OpDesc pb_desc;
+    TransformOpDescCppToPb(*node.op()->op_info(), &pb_desc);
+    pb_desc.SetAttr(kKernelTypeAttr, node.kernel()->SerializedKernelType());
     // append new opdesc
-    *program_dummy.mutable_blocks(0)->add_ops() = *desc.Proto();
+    *program_dummy.mutable_blocks(0)->add_ops() = *pb_desc.Proto();
   }
   return program_dummy.SerializeAsString();
 }
......
@@ -22,6 +22,7 @@
 #include "paddle/fluid/lite/core/mir/node.h"
 #include "paddle/fluid/lite/core/op_lite.h"
 #include "paddle/fluid/lite/core/op_registry.h"
+#include "paddle/fluid/lite/model_parser/compatible_pb.h"
 #ifdef LITE_WITH_PROFILE
 #include "paddle/fluid/lite/core/profile/basic_profiler.h"
 #endif  // LITE_WITH_PROFILE
@@ -67,7 +68,7 @@ struct Program {
     CHECK(ops.empty()) << "Executor duplicate Build found";
     // Create operators.
     for (const auto& proto_op_desc : program.blocks(0).ops()) {
-      lite::OpDesc op_desc(proto_op_desc);
+      pb::OpDesc op_desc(proto_op_desc);
       auto op_type = op_desc.Type();
       // if (op_type == "feed" || op_type == "fetch") continue;
       VLOG(4) << "create Op [" << op_type << "]";
@@ -75,7 +76,10 @@ struct Program {
       auto op = LiteOpRegistry::Global().Create(op_type);
       CHECK(op) << "no Op found for " << op_type;
       ops.emplace_back(std::move(op));
-      ops.back()->Attach(op_desc, exec_scope);
+
+      cpp::OpDesc cpp_op_desc;
+      TransformOpDescPbToCpp(op_desc, &cpp_op_desc);
+      ops.back()->Attach(cpp_op_desc, exec_scope);
     }
   }
......
@@ -11,11 +11,7 @@ if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
 endif()
 
-if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
-    cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite framework_proto_lite var_desc_lite)
-else()
-    cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto_lite proto_desc)
-endif(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
+cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite framework_proto_lite var_desc_lite)
 
 set(model_parser_deps variable_lite scope_lite ${tensor_lite} scope_lite
     target_wrapper_host
@@ -27,4 +23,7 @@ if (LITE_WITH_CUDA)
 endif()
 cc_library(model_parser_lite SRCS model_parser.cc DEPS ${model_parser_deps})
 
+cc_test(test_op_desc_lite SRCS op_desc_test.cc DEPS cpp_op_desc_lite any_lite op_desc_lite compatible_pb_lite)
+
 add_subdirectory(pb)
+add_subdirectory(cpp)
@@ -13,3 +13,114 @@
// limitations under the License.

#include "paddle/fluid/lite/model_parser/compatible_pb.h"
#include "compatible_pb.h"
namespace paddle {
namespace lite {
void InputsPbToCpp(const pb::OpDesc &pb_desc, cpp::OpDesc *cpp_desc) {
for (const std::string &param : pb_desc.InputArgumentNames()) {
cpp_desc->SetInput(param, pb_desc.Input(param));
}
}
void InputsCppToPb(const cpp::OpDesc &cpp_desc, pb::OpDesc *pb_desc) {
for (const std::string &param : cpp_desc.InputArgumentNames()) {
pb_desc->SetInput(param, cpp_desc.Input(param));
}
}
void OutputsPbToCpp(const pb::OpDesc &pb_desc, cpp::OpDesc *cpp_desc) {
for (const std::string &param : pb_desc.OutputArgumentNames()) {
cpp_desc->SetOutput(param, pb_desc.Output(param));
}
}
void OutputsCppToPb(const cpp::OpDesc &cpp_desc, pb::OpDesc *pb_desc) {
for (const std::string &param : cpp_desc.OutputArgumentNames()) {
pb_desc->SetOutput(param, cpp_desc.Output(param));
}
}
void AttrsPbToCpp(const pb::OpDesc &pb_desc, cpp::OpDesc *cpp_desc) {
using AttrType = OpDescAPI::AttrType;
auto set_attr = [&](const std::string &name, AttrType type) {
switch (type) {
case AttrType::INT:
cpp_desc->SetAttr<int32_t>(name, pb_desc.GetAttr<int32_t>(name));
break;
case AttrType::FLOAT:
cpp_desc->SetAttr<float>(name, pb_desc.GetAttr<float>(name));
break;
case AttrType::STRING:
cpp_desc->SetAttr<std::string>(name,
pb_desc.GetAttr<std::string>(name));
break;
case AttrType::INTS:
cpp_desc->SetAttr<std::vector<int>>(
name, pb_desc.GetAttr<std::vector<int>>(name));
break;
case AttrType::FLOATS:
cpp_desc->SetAttr<std::vector<float>>(
name, pb_desc.GetAttr<std::vector<float>>(name));
break;
case AttrType::BOOLEAN:
cpp_desc->SetAttr<bool>(name, pb_desc.GetAttr<bool>(name));
break;
case AttrType::STRINGS:
cpp_desc->SetAttr<std::vector<std::string>>(
name, pb_desc.GetAttr<std::vector<std::string>>(name));
break;
default:
LOG(FATAL) << "Unsupported attr type found " << static_cast<int>(type);
}
};
for (const auto &attr_name : pb_desc.AttrNames()) {
auto type = pb_desc.GetAttrType(attr_name);
set_attr(attr_name, type);
}
}
void AttrsCppToPb(const cpp::OpDesc &cpp_desc, pb::OpDesc *pb_desc) {
using AttrType = OpDescAPI::AttrType;
auto set_attr = [&](const std::string &name, AttrType type) {
switch (type) {
#define IMPL_ONE(type__, T) \
case AttrType::type__: \
pb_desc->SetAttr<T>(name, cpp_desc.GetAttr<T>(name)); \
break;
IMPL_ONE(INT, int32_t);
IMPL_ONE(FLOAT, float);
IMPL_ONE(STRING, std::string);
IMPL_ONE(STRINGS, std::vector<std::string>);
IMPL_ONE(FLOATS, std::vector<float>);
IMPL_ONE(INTS, std::vector<int>);
IMPL_ONE(BOOLEAN, bool);
default:
LOG(FATAL) << "Unsupported attr type found: " << static_cast<int>(type);
}
};
#undef IMPL_ONE
for (const auto &attr_name : cpp_desc.AttrNames()) {
auto type = cpp_desc.GetAttrType(attr_name);
set_attr(attr_name, type);
}
}
void TransformOpDescPbToCpp(const pb::OpDesc &pb_desc, cpp::OpDesc *cpp_desc) {
cpp_desc->SetType(pb_desc.Type());
InputsPbToCpp(pb_desc, cpp_desc);
OutputsPbToCpp(pb_desc, cpp_desc);
AttrsPbToCpp(pb_desc, cpp_desc);
}
void TransformOpDescCppToPb(const cpp::OpDesc &cpp_desc, pb::OpDesc *pb_desc) {
pb_desc->SetType(cpp_desc.Type());
InputsCppToPb(cpp_desc, pb_desc);
OutputsCppToPb(cpp_desc, pb_desc);
AttrsCppToPb(cpp_desc, pb_desc);
}
} // namespace lite
} // namespace paddle
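To make the intended round trip concrete, a short sketch (an editor's illustration, not part of the commit; it assumes the headers above, and RoundTripSketch is a hypothetical helper name):

// Sketch: pb -> cpp on load (as Program::Build does), cpp -> pb on save
// (as RuntimeProgram::SerializeProgram does), via the functions above.
#include "paddle/fluid/lite/model_parser/compatible_pb.h"

void RoundTripSketch() {
  paddle::lite::pb::OpDesc pb_desc;
  pb_desc.SetType("mul");
  pb_desc.SetInput("X", {"x"});
  pb_desc.SetInput("Y", {"w"});
  pb_desc.SetOutput("Out", {"out"});
  pb_desc.SetAttr<int>("x_num_col_dims", 1);

  // pb -> cpp: the representation every OpLite::Attach now consumes.
  paddle::lite::cpp::OpDesc cpp_desc;
  paddle::lite::TransformOpDescPbToCpp(pb_desc, &cpp_desc);

  // cpp -> pb: back to protobuf for serialization.
  paddle::lite::pb::OpDesc restored;
  paddle::lite::TransformOpDescCppToPb(cpp_desc, &restored);
  // restored carries the same type, inputs, outputs and attributes.
}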
@@ -20,39 +20,28 @@
  * lite::pb::XXDesc.
  */
 
-#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
 #include "paddle/fluid/lite/core/framework.pb.h"
+#include "paddle/fluid/lite/model_parser/cpp/op_desc.h"
 #include "paddle/fluid/lite/model_parser/pb/op_desc.h"
 #include "paddle/fluid/lite/model_parser/pb/var_desc.h"
-#else
-#include "paddle/fluid/framework/op_desc.h"
-#include "paddle/fluid/framework/var_desc.h"
-#endif  // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
 
 namespace paddle {
 namespace lite {
 
-#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
 using Attribute = lite::pb::Attribute;
 using OpDesc = lite::pb::OpDesc;
 using VarDesc = lite::pb::VarDesc;
-#else   // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
-using Attribute = framework::Attribute;
-using OpDesc = framework::OpDesc;
-using VarDesc = framework::VarDesc;
-#endif  // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
 
-#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
 template <typename T>
 T GetAttr(const Attribute& x) {
   return x.get<T>();
 }
-#else
-template <typename T>
-T GetAttr(const Attribute& x) {
-  return boost::get<T>(x);
-}
-#endif  // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
+
+/// Transform an OpDesc from pb to cpp format.
+void TransformOpDescPbToCpp(const pb::OpDesc& pb_desc, cpp::OpDesc* cpp_desc);
+
+/// Transform an OpDesc from cpp to pb format.
+void TransformOpDescCppToPb(const cpp::OpDesc& cpp_desc, pb::OpDesc* pb_desc);
 
 }  // namespace lite
 }  // namespace paddle
cc_library(cpp_op_desc_lite SRCS op_desc.cc DEPS any_lite)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/model_parser/cpp/op_desc.h"
#include <set>
namespace paddle {
namespace lite {
namespace cpp {
#define SET_ATTR_IMPL(T, repr__) \
template <> \
void OpDesc::SetAttr<T>(const std::string& name, const T& v) { \
attr_types_[name] = AttrType::repr__; \
attrs_[name].set<T>(v); \
}
SET_ATTR_IMPL(int32_t, INT);
SET_ATTR_IMPL(float, FLOAT);
SET_ATTR_IMPL(std::string, STRING);
SET_ATTR_IMPL(bool, BOOLEAN);
SET_ATTR_IMPL(std::vector<int>, INTS);
SET_ATTR_IMPL(std::vector<float>, FLOATS);
SET_ATTR_IMPL(std::vector<std::string>, STRINGS);
std::pair<OpDesc::attrs_t::const_iterator, OpDesc::attr_types_t::const_iterator>
FindAttr(const cpp::OpDesc& desc, const std::string& name) {
auto it = desc.attrs().find(name);
CHECK(it != desc.attrs().end()) << "No attributes called " << name
<< " found";
auto attr_it = desc.attr_types().find(name);
CHECK(attr_it != desc.attr_types().end());
return std::make_pair(it, attr_it);
}
#define GET_IMPL_ONE(T, repr__) \
template <> \
T OpDesc::GetAttr<T>(const std::string& name) const { \
auto pair = FindAttr(*this, name); \
CHECK(pair.second->second == AttrType::repr__); \
return pair.first->second.get<T>(); \
}
GET_IMPL_ONE(int32_t, INT);
GET_IMPL_ONE(float, FLOAT);
GET_IMPL_ONE(std::string, STRING);
GET_IMPL_ONE(bool, BOOLEAN);
GET_IMPL_ONE(std::vector<int64_t>, LONGS);
GET_IMPL_ONE(std::vector<float>, FLOATS);
GET_IMPL_ONE(std::vector<int>, INTS);
GET_IMPL_ONE(std::vector<std::string>, STRINGS);
} // namespace cpp
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/lite/model_parser/desc_apis.h"
#include "paddle/fluid/lite/utils/any.h"
#include "paddle/fluid/lite/utils/varient.h"
namespace paddle {
namespace lite {
namespace cpp {
/*
* The cpp::OpDesc is the internal representation for Op. All the internal
* imprementation should use it, not the pb::OpDesc.
*/
class OpDesc : public OpDescAPI {
public:
using attrs_t = std::map<std::string, Any>;
using attr_types_t = std::map<std::string, AttrType>;
protected:
std::string type_;
std::map<std::string, std::vector<std::string>> inputs_;
std::map<std::string, std::vector<std::string>> outputs_;
std::map<std::string, Any> attrs_;
std::map<std::string, AttrType> attr_types_;
public:
OpDesc() = default;
std::string Type() const override { return type_; }
void SetType(const std::string& x) override { type_ = x; }
const std::map<std::string, std::vector<std::string>>& inputs() const {
return inputs_;
}
const std::map<std::string, std::vector<std::string>>& outputs() const {
return outputs_;
}
std::map<std::string, std::vector<std::string>>* mutable_inputs() {
return &inputs_;
}
std::map<std::string, std::vector<std::string>>* mutable_outputs() {
return &outputs_;
}
std::vector<std::string> Input(const std::string& param) const override {
auto it = inputs_.find(param);
CHECK(it != inputs_.end());
return it->second;
}
std::vector<std::string> InputArgumentNames() const override {
std::vector<std::string> res;
for (const auto& x : inputs_) res.push_back(x.first);
return res;
}
std::vector<std::string> OutputArgumentNames() const override {
std::vector<std::string> res;
for (const auto& x : outputs_) res.push_back(x.first);
return res;
}
std::vector<std::string> Output(const std::string& param) const override {
auto it = outputs_.find(param);
CHECK(it != outputs_.end());
return it->second;
}
void SetInput(const std::string& param,
const std::vector<std::string>& args) override {
inputs_[param] = args;
}
void SetOutput(const std::string& param,
const std::vector<std::string>& args) override {
outputs_[param] = args;
}
bool HasAttr(const std::string& name) const override {
return attrs_.count(name);
}
AttrType GetAttrType(const std::string& name) const override {
auto it = attr_types_.find(name);
CHECK(it != attr_types_.end());
return it->second;
}
std::vector<std::string> AttrNames() const override {
std::vector<std::string> res;
for (const auto& x : attrs_) {
res.push_back(x.first);
}
return res;
}
template <typename T>
void SetAttr(const std::string& name, const T& v);
template <typename T>
T GetAttr(const std::string& name) const;
const std::map<std::string, Any>& attrs() const { return attrs_; }
const std::map<std::string, AttrType>& attr_types() const {
return attr_types_;
}
};
} // namespace cpp
} // namespace lite
} // namespace paddle
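The split between attrs_ (values stored as Any) and attr_types_ (type tags) is what lets GetAttr<T> verify the stored type before unpacking. A small sketch (an editor's illustration, not part of the commit; AttrTagSketch is a hypothetical helper name):

// Sketch: SetAttr<T> records an AttrType tag beside the Any value, and
// GetAttr<T> CHECKs that tag before calling Any::get<T>().
#include <string>
#include "paddle/fluid/lite/model_parser/cpp/op_desc.h"

void AttrTagSketch() {
  paddle::lite::cpp::OpDesc desc;
  desc.SetAttr<int32_t>("axis", 1);              // tags "axis" as AttrType::INT
  desc.SetAttr<std::string>("mode", "channel");  // tags "mode" as STRING

  int32_t axis = desc.GetAttr<int32_t>("axis");  // OK: tag matches
  (void)axis;
  // desc.GetAttr<float>("axis") would trip the CHECK, since the tag is INT.
}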
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <vector>
namespace paddle {
namespace lite {
/*
* Compatible interfaces for all the different kinds of opdesc. All the OpDesc
* classes should implement this.
* NOTE Some interfaces are weried, we remain them unchanged to keep compatible
* with framework::OpDesc in Fluid framework.
*/
class OpDescAPI {
public:
// The AttrType is used to make the proto::AttrType portable.
enum class AttrType {
INT = 0,
FLOAT = 1,
STRING = 2,
INTS = 3,
FLOATS = 4,
STRINGS = 5,
BOOLEAN = 6,
BOOLEANS = 7,
BLOCK = 8,
LONG = 9,
BLOCKS = 10,
LONGS = 11,
UNK,
};
virtual ~OpDescAPI() = default;
/// Get operator's type.
virtual std::string Type() const = 0;
/// Set operator's type.
virtual void SetType(const std::string& type) = 0;
/// Get arguments given the parameter.
virtual std::vector<std::string> Input(const std::string& param) const = 0;
/// Get parameters.
virtual std::vector<std::string> InputArgumentNames() const = 0;
/// Get arguments given the parameter.
virtual std::vector<std::string> Output(const std::string& param) const = 0;
/// Get parameters.
virtual std::vector<std::string> OutputArgumentNames() const = 0;
/// Set a input given the parameter and arguments.
virtual void SetInput(const std::string& param,
const std::vector<std::string>& args) = 0;
virtual void SetOutput(const std::string& param,
const std::vector<std::string>& args) = 0;
/// Tell whether this desc has an attribute.
virtual bool HasAttr(const std::string& name) const = 0;
/// Get the type of an attribute.
virtual AttrType GetAttrType(const std::string& name) const = 0;
virtual std::vector<std::string> AttrNames() const = 0;
/// Set an attribute.
template <typename T>
void SetAttr(const std::string& name, const T& v);
/// Get an attribute.
template <typename T>
T GetAttr(const std::string& name) const;
};
} // namespace lite
} // namespace paddle
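Since both pb::OpDesc and cpp::OpDesc now implement this interface, utilities can be written once against OpDescAPI and stay format-agnostic. A sketch (an editor's illustration, not part of the commit; DumpIoSketch is a hypothetical helper name):

// Sketch: dump an op's inputs/outputs through the common OpDescAPI,
// regardless of whether the concrete desc is pb- or cpp-backed.
#include <iostream>
#include "paddle/fluid/lite/model_parser/desc_apis.h"

void DumpIoSketch(const paddle::lite::OpDescAPI& desc) {
  std::cout << "op: " << desc.Type() << "\n";
  for (const auto& param : desc.InputArgumentNames()) {
    for (const auto& arg : desc.Input(param)) {
      std::cout << "  in  " << param << " <- " << arg << "\n";
    }
  }
  for (const auto& param : desc.OutputArgumentNames()) {
    for (const auto& arg : desc.Output(param)) {
      std::cout << "  out " << param << " -> " << arg << "\n";
    }
  }
}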
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/model_parser/cpp/op_desc.h"
#include <gtest/gtest.h>
#include "paddle/fluid/lite/model_parser/compatible_pb.h"
#include "paddle/fluid/lite/model_parser/pb/op_desc.h"
namespace paddle {
namespace lite {
template <typename OpDesc>
void TestX() {
OpDesc desc;
desc.SetInput("X", {"a", "b"});
auto X = desc.Input("X");
ASSERT_EQ(X.size(), 2UL);
ASSERT_EQ(X[0], "a");
ASSERT_EQ(X[1], "b");
desc.SetOutput("Y", {"c", "d"});
auto Y = desc.Output("Y");
ASSERT_EQ(Y.size(), 2UL);
ASSERT_EQ(Y[0], "c");
ASSERT_EQ(Y[1], "d");
desc.template SetAttr<int32_t>("aint", 100);
ASSERT_TRUE(desc.HasAttr("aint"));
ASSERT_FALSE(desc.HasAttr("afloat"));
ASSERT_EQ(desc.template GetAttr<int32_t>("aint"), 100);
}
TEST(OpDesc, Basic) {
TestX<pb::OpDesc>();
TestX<cpp::OpDesc>();
}
TEST(OpDesc, CppToPb) {
cpp::OpDesc desc;
desc.SetInput("X", {"a", "b"});
desc.SetOutput("Y", {"c", "d"});
desc.template SetAttr<int32_t>("aint", 100);
pb::OpDesc pb_desc;
TransformOpDescCppToPb(desc, &pb_desc);
{
auto& desc = pb_desc;
auto X = desc.Input("X");
ASSERT_EQ(X.size(), 2UL);
ASSERT_EQ(X[0], "a");
ASSERT_EQ(X[1], "b");
auto Y = desc.Output("Y");
ASSERT_EQ(Y.size(), 2UL);
ASSERT_EQ(Y[0], "c");
ASSERT_EQ(Y[1], "d");
ASSERT_TRUE(desc.HasAttr("aint"));
ASSERT_FALSE(desc.HasAttr("afloat"));
ASSERT_EQ(desc.template GetAttr<int32_t>("aint"), 100);
}
}
TEST(OpDesc, PbToCpp) {
pb::OpDesc desc;
desc.SetInput("X", {"a", "b"});
desc.SetOutput("Y", {"c", "d"});
desc.template SetAttr<int32_t>("aint", 100);
cpp::OpDesc cpp_desc;
TransformOpDescPbToCpp(desc, &cpp_desc);
{
auto& desc = cpp_desc;
auto X = desc.Input("X");
ASSERT_EQ(X.size(), 2UL);
ASSERT_EQ(X[0], "a");
ASSERT_EQ(X[1], "b");
auto Y = desc.Output("Y");
ASSERT_EQ(Y.size(), 2UL);
ASSERT_EQ(Y[0], "c");
ASSERT_EQ(Y[1], "d");
ASSERT_TRUE(desc.HasAttr("aint"));
ASSERT_FALSE(desc.HasAttr("afloat"));
ASSERT_EQ(desc.template GetAttr<int32_t>("aint"), 100);
}
}
} // namespace lite
} // namespace paddle
@@ -18,10 +18,9 @@ namespace paddle {
 namespace lite {
 namespace pb {
 
-template <>
-void OpDesc::SetAttr<std::string>(const std::string &name,
-                                  const std::string &v) {
-  auto &xs = *desc_.mutable_attrs();
+google::protobuf::internal::RepeatedPtrIterator<framework::proto::OpDesc_Attr>
+FindAttr(framework::proto::OpDesc *desc, const std::string &name) {
+  auto &xs = *desc->mutable_attrs();
   auto it = std::find_if(
       xs.begin(), xs.end(),
       [&](const framework::proto::OpDesc_Attr &x) { return x.name() == name; });
@@ -33,33 +32,95 @@ void OpDesc::SetAttr<std::string>(const std::string &name,
                         return x.name() == name;
                       });
   }
+  return it;
+}
+
+#define SET_IMPL_ONE(T, ty__, pb_f__)                            \
+  template <>                                                    \
+  void OpDesc::SetAttr<T>(const std::string &name, const T &v) { \
+    auto it = FindAttr(&desc_, name);                            \
+    it->set_type(framework::proto::ty__);                        \
+    it->set_##pb_f__(v);                                         \
+  }
+
+SET_IMPL_ONE(int, INT, i);
+SET_IMPL_ONE(float, FLOAT, f);
+SET_IMPL_ONE(bool, FLOAT, f);
+
+template <>
+void OpDesc::SetAttr<std::vector<int>>(const std::string &name,
+                                       const std::vector<int> &v) {
+  auto it = FindAttr(&desc_, name);
+  it->set_type(framework::proto::INTS);
+  it->clear_ints();
+  for (auto &i : v) {
+    it->add_ints(i);
+  }
+}
+
+template <>
+void OpDesc::SetAttr<std::string>(const std::string &name,
+                                  const std::string &v) {
+  auto it = FindAttr(&desc_, name);
   it->set_type(framework::proto::STRING);
   it->set_s(v.c_str());
 }
 
 template <>
-void OpDesc::SetAttr<std::vector<int>>(const std::string &name,
-                                       const std::vector<int> &v) {
-  auto &xs = *desc_.mutable_attrs();
+void OpDesc::SetAttr<std::vector<float>>(const std::string &name,
+                                         const std::vector<float> &v) {
+  auto it = FindAttr(&desc_, name);
+  it->set_type(framework::proto::FLOATS);
+  it->clear_floats();
+  for (auto &i : v) {
+    it->add_floats(i);
+  }
+}
+
+template <>
+void OpDesc::SetAttr<std::vector<std::string>>(
+    const std::string &name, const std::vector<std::string> &v) {
+  auto it = FindAttr(&desc_, name);
+  it->set_type(framework::proto::STRINGS);
+  it->clear_strings();
+  for (auto &i : v) {
+    it->add_strings(i);
+  }
+}
+
+google::protobuf::internal::RepeatedPtrIterator<
+    const framework::proto::OpDesc_Attr>
+GetFindAttr(const framework::proto::OpDesc &desc, const std::string &name) {
+  auto &xs = desc.attrs();
   auto it = std::find_if(
       xs.begin(), xs.end(),
       [&](const framework::proto::OpDesc_Attr &x) { return x.name() == name; });
-  if (it == xs.end()) {
-    auto *attr = xs.Add();
-    attr->set_name(name);
-    it = std::find_if(xs.begin(), xs.end(),
-                      [&](const framework::proto::OpDesc_Attr &x) {
-                        return x.name() == name;
-                      });
-  }
-  it->set_type(framework::proto::INTS);
-  it->clear_ints();
-  for (auto &i : v) {
-    it->add_ints(i);
-  }
-}
+  return it;
+}
+
+#define GET_ATTR_IMPL(T, pb_f__)                        \
+  template <>                                           \
+  T OpDesc::GetAttr<T>(const std::string &name) const { \
+    auto it = GetFindAttr(desc_, name);                 \
+    return it->pb_f__();                                \
+  }
+
+#define GET_ATTRS_IMPL(T, pb_f__)                       \
+  template <>                                           \
+  T OpDesc::GetAttr<T>(const std::string &name) const { \
+    auto it = GetFindAttr(desc_, name);                 \
+    T res;                                              \
+    for (const auto &v : it->pb_f__()) {                \
+      res.push_back(v);                                 \
+    }                                                   \
+    return res;                                         \
+  }
+
+GET_ATTR_IMPL(int32_t, i);
+GET_ATTR_IMPL(float, f);
+GET_ATTR_IMPL(bool, b);
+GET_ATTRS_IMPL(std::vector<int>, ints);
+GET_ATTRS_IMPL(std::vector<float>, floats);
+GET_ATTRS_IMPL(std::vector<std::string>, strings);
+GET_ATTR_IMPL(std::string, s);
 
 }  // namespace pb
 }  // namespace lite
......
@@ -27,6 +27,7 @@
 #include <unordered_map>
 #include <vector>
 #include "paddle/fluid/lite/core/framework.pb.h"
+#include "paddle/fluid/lite/model_parser/desc_apis.h"
 #include "paddle/fluid/lite/utils/all.h"
 
 namespace paddle {
@@ -43,7 +44,7 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>;
  * except the desc_, to avoid the inconsistent state, which is normal in the
  * original interface and results in bugs.
  */
-class OpDesc {
+class OpDesc : public OpDescAPI {
  public:
   OpDesc() {}
@@ -54,38 +55,38 @@ class OpDesc {
   framework::proto::OpDesc *Proto() { return &desc_; }
   const framework::proto::OpDesc &ReadonlyProto() const { return desc_; }
 
-  std::string Type() const { return desc_.type(); }
-  void SetType(const std::string &type) { desc_.set_type(type); }
+  std::string Type() const override { return desc_.type(); }
+  void SetType(const std::string &type) override { desc_.set_type(type); }
 
   // Get the arguments of parameter called `param`
-  std::vector<std::string> Input(const std::string &param) const {
+  std::vector<std::string> Input(const std::string &param) const override {
     return GetArguments(desc_.inputs(), param);
   }
 
-  std::vector<std::string> InputArgumentNames() const {
+  std::vector<std::string> InputArgumentNames() const override {
     return GetArgumentNames(desc_.inputs());
   }
 
   void SetInput(const std::string &param,
-                const std::vector<std::string> &args) {
+                const std::vector<std::string> &args) override {
     SetArgument(desc_.mutable_inputs(), param, args);
   }
 
-  std::vector<std::string> Output(const std::string &param) const {
+  std::vector<std::string> Output(const std::string &param) const override {
     return GetArguments(desc_.outputs(), param);
   }
 
-  std::vector<std::string> OutputArgumentNames() const {
+  std::vector<std::string> OutputArgumentNames() const override {
     return GetArgumentNames(desc_.outputs());
   }
 
   void SetOutput(const std::string &param,
-                 const std::vector<std::string> &args) {
+                 const std::vector<std::string> &args) override {
     SetArgument(desc_.mutable_outputs(), param, args);
   }
 
-  bool HasAttr(const std::string &name) const {
+  bool HasAttr(const std::string &name) const override {
     const auto &xs = desc_.attrs();
     auto it = std::find_if(xs.begin(), xs.end(),
                            [&](const framework::proto::OpDesc_Attr &x) {
@@ -94,17 +95,38 @@ class OpDesc {
     return it != xs.end();
   }
 
-  framework::proto::AttrType GetAttrType(const std::string &name) const {
+  AttrType GetAttrType(const std::string &name) const override {
     const auto &xs = desc_.attrs();
     auto it = std::find_if(xs.begin(), xs.end(),
                            [&](const framework::proto::OpDesc_Attr &x) {
                              return x.name() == name;
                            });
     CHECK(it != xs.end());
-    return it->type();
+#define DEF_ONE(type__)                      \
+  case framework::proto::AttrType::type__: \
+    return AttrType::type__;
+    switch (it->type()) {
+      DEF_ONE(INT);
+      DEF_ONE(FLOAT);
+      DEF_ONE(STRING);
+      DEF_ONE(INTS);
+      DEF_ONE(FLOATS);
+      DEF_ONE(STRINGS);
+      DEF_ONE(BOOLEAN);
+      DEF_ONE(BOOLEANS);
+      DEF_ONE(BLOCK);
+      DEF_ONE(LONG);
+      DEF_ONE(BLOCKS);
+      DEF_ONE(LONGS);
+      default:
+        LOG(ERROR) << "Unknown attribute type";
+        return AttrType::UNK;
+    }
+#undef DEF_ONE
   }
 
-  std::vector<std::string> AttrNames() const {
+  std::vector<std::string> AttrNames() const override {
     std::vector<std::string> res;
     const auto &xs = desc_.attrs();
     std::transform(
@@ -114,72 +136,10 @@ class OpDesc {
   }
 
   template <typename T>
-  void SetAttr(const std::string &name, const T &v) {
-    auto &xs = *desc_.mutable_attrs();
-    auto it = std::find_if(xs.begin(), xs.end(),
-                           [&](const framework::proto::OpDesc_Attr &x) {
-                             return x.name() == name;
-                           });
-    if (it == xs.end()) {
-      auto *attr = xs.Add();
-      attr->set_name(name);
-      it = std::find_if(xs.begin(), xs.end(),
-                        [&](const framework::proto::OpDesc_Attr &x) {
-                          return x.name() == name;
-                        });
-    }
-
-    size_t hash = typeid(T).hash_code();
-    if (hash == typeid(int).hash_code()) {  // NOLINT
-      it->set_type(framework::proto::INT);
-      it->set_i(v);
-    } else if (hash == typeid(float).hash_code()) {  // NOLINT
-      it->set_type(framework::proto::FLOAT);
-      it->set_f(v);
-    } else if (hash == typeid(bool).hash_code()) {  // NOLINT
-      it->set_type(framework::proto::BOOLEAN);
-      it->set_b(v);
-    } else {
-      LOG(FATAL) << "unsupport attr type";
-    }
-  }
-
-  Attribute GetAttr(const std::string &name) const {
-    auto &xs = desc_.attrs();
-    auto it = std::find_if(xs.begin(), xs.end(),
-                           [&](const framework::proto::OpDesc_Attr &x) {
-                             return x.name() == name;
-                           });
-
-    Attribute res;
-    CHECK(it != xs.end());
-
-    switch (it->type()) {
-      case framework::proto::INT:
-        res.set<int>(it->i());
-        break;
-      case framework::proto::FLOAT:
-        res.set<float>(it->f());
-        break;
-      case framework::proto::STRING:
-        res.set<std::string>(it->s());
-        break;
-      case framework::proto::BOOLEAN:
-        res.set<bool>(it->b());
-        break;
-      case framework::proto::INTS: {
-        std::vector<int> values;
-        const auto &ys = it->ints();
-        std::transform(ys.begin(), ys.end(), std::back_inserter(values),
-                       [](const int &x) { return x; });
-        res.set<std::vector<int>>(values);
-      } break;
-      default:
-        LOG(FATAL) << "unsupported attr type";
-    }
-    return res;
-  }
+  void SetAttr(const std::string &name, const T &v);
+
+  template <typename T>
+  T GetAttr(const std::string &name) const;
 
  private:
   std::vector<std::string> GetArguments(
......
@@ -33,7 +33,7 @@ class ActivationOp : public OpLite {
     return true;
   }
 
-  bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override {
+  bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override {
     auto X_name = opdesc.Input("X").front();
     auto Out_name = opdesc.Output("Out").front();
@@ -66,7 +66,7 @@ class ActivationGradOp : public OpLite {
     return true;
   }
 
-  bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override {
+  bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override {
     auto Out_grad_name = opdesc.Input(framework::GradVarName("Out")).front();
     auto X_grad_name = opdesc.Output(framework::GradVarName("X")).front();
......
@@ -54,7 +54,7 @@ bool ConcatOpLite::InferShape() const {
 }
 
 // TODO(Superjomn) replace framework::OpDesc with a lite one.
-bool ConcatOpLite::AttachImpl(const OpDesc &op_desc, lite::Scope *scope) {
+bool ConcatOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
   auto inputs = op_desc.Input("X");
   auto out = op_desc.Output("Out").front();
@@ -63,7 +63,7 @@ bool ConcatOpLite::AttachImpl(const OpDesc &op_desc, lite::Scope *scope) {
   }
   CHECK(scope->FindVar(out));
   param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
-  param_.axis = GetAttr<int>(op_desc.GetAttr("axis"));
+  param_.axis = op_desc.GetAttr<int>("axis");
   return true;
 }
......
@@ -32,7 +32,7 @@ class ConcatOpLite : public OpLite {
   bool InferShape() const override;
 
-  bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override;
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
 
   void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
 
   std::string DebugString() const override { return "concat"; }
......
@@ -42,7 +42,7 @@ TEST(concat_op_lite, test) {
   }
 
   // prepare op desc
-  lite::OpDesc desc;
+  cpp::OpDesc desc;
   desc.SetType("concat");
   desc.SetInput("X", {"x0", "x1"});
   desc.SetOutput("Out", {"output"});
......
@@ -42,7 +42,7 @@ class DropoutOpLite : public OpLite {
   void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
 
   // TODO(Superjomn) replace framework::OpDesc with a lite one.
-  bool AttachImpl(const OpDesc& op_desc, lite::Scope* scope) override {
+  bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override {
     auto input = op_desc.Input("X").front();
     auto out = op_desc.Output("Out").front();
     auto Mask = op_desc.Output("Mask").front();
@@ -51,14 +51,14 @@ class DropoutOpLite : public OpLite {
     param_.output = GetMutableVar<lite::Tensor>(scope, out);
     param_.mask = GetMutableVar<lite::Tensor>(scope, Mask);
 
-    param_.dropout_prob = boost::get<float>(op_desc.GetAttr("dropout_prob"));
+    param_.dropout_prob = op_desc.GetAttr<float>("dropout_prob");
     if (op_desc.HasAttr("axis")) {
-      param_.is_test = boost::get<bool>(op_desc.GetAttr("is_test"));
+      param_.is_test = op_desc.GetAttr<bool>("is_test");
     }
-    param_.fix_seed = boost::get<bool>(op_desc.GetAttr("fix_seed"));
-    param_.seed = boost::get<int>(op_desc.GetAttr("seed"));
+    param_.fix_seed = op_desc.GetAttr<bool>("fix_seed");
+    param_.seed = op_desc.GetAttr<int>("seed");
     param_.dropout_implementation =
-        boost::get<int>(op_desc.GetAttr("dropout_implementation"));
+        op_desc.GetAttr<int>("dropout_implementation");
     return true;
   }
......
@@ -36,7 +36,7 @@ class ElementwiseOp : public OpLite {
     return true;
   }
 
-  bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override {
+  bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override {
     auto X_name = opdesc.Input("X").front();
     auto Y_name = opdesc.Input("Y").front();
     auto Out_name = opdesc.Output("Out").front();
@@ -44,7 +44,7 @@ class ElementwiseOp : public OpLite {
     param_.X = GetVar<lite::Tensor>(scope, X_name);
     param_.Y = GetVar<lite::Tensor>(scope, Y_name);
     param_.Out = GetMutableVar<lite::Tensor>(scope, Out_name);
-    param_.axis = boost::get<int>(opdesc.GetAttr("axis"));
+    param_.axis = opdesc.GetAttr<int>("axis");
     return true;
   }
@@ -75,8 +75,8 @@ class ElementwiseGradExplicitOp : public OpLite {
     return true;
   }
 
-  bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override {
-    CHECK_EQ(opdesc.Inputs().size(), 1UL);
+  bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override {
+    CHECK_EQ(opdesc.InputArgumentNames().size(), 1UL);
     auto Out_name = opdesc.Input(framework::GradVarName("Out")).front();
     auto X_name = opdesc.Output(framework::GradVarName("X")).front();
     auto Y_name = opdesc.Output(framework::GradVarName("Y")).front();
@@ -84,7 +84,7 @@ class ElementwiseGradExplicitOp : public OpLite {
     param_.Out_grad = GetVar<lite::Tensor>(scope, Out_name);
     param_.X_grad = GetMutableVar<lite::Tensor>(scope, X_name);
     param_.Y_grad = GetMutableVar<Tensor>(scope, Y_name);
-    param_.axis = boost::get<int>(opdesc.GetAttr("axis"));
+    param_.axis = opdesc.GetAttr<int>("axis");
     return true;
   }
......
@@ -46,7 +46,7 @@ class FcOpLite : public OpLite {
    */
 
   // TODO(Superjomn) replace framework::OpDesc with a lite one.
-  bool AttachImpl(const OpDesc &op_desc, lite::Scope *scope) override {
+  bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
     auto input = op_desc.Input("Input").front();
     auto W = op_desc.Input("W").front();
     auto bias = op_desc.Input("Bias").front();
@@ -57,7 +57,7 @@ class FcOpLite : public OpLite {
     param_.bias = scope->FindVar(bias)->GetMutable<lite::Tensor>();
     CHECK(scope->FindVar(out));
     param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
-    param_.in_num_col_dims = GetAttr<int>(op_desc.GetAttr("in_num_col_dims"));
+    param_.in_num_col_dims = op_desc.GetAttr<int>("in_num_col_dims");
 
     return true;
   }
......
@@ -47,7 +47,7 @@ TEST(fc_op_lite, TestX86) {
   }
 
   // prepare op desc
-  lite::OpDesc desc;
+  cpp::OpDesc desc;
   desc.SetType("fc");
   desc.SetInput("Input", {"x"});
   desc.SetInput("W", {"w"});
......
@@ -34,7 +34,7 @@ class FeedOp : public OpLite {
   void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
 
  protected:
-  bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override {
+  bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override {
     auto feed_var_name = opdesc.Input("X").front();
     auto* feed_var = scope->FindVar(feed_var_name);
     CHECK(feed_var);
@@ -48,7 +48,7 @@ class FeedOp : public OpLite {
     // NOTE need boost here
     // TODO(Superjomn) drop the need of framework::op_desc
-    param_.col = GetAttr<int>(opdesc.GetAttr("col"));
+    param_.col = opdesc.GetAttr<int>("col");
     return true;
   }
......
@@ -33,7 +33,7 @@ class FetchOp : public OpLite {
   void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
 
  protected:
-  bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override {
+  bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override {
     auto _x = opdesc.Input("X").front();
     auto* x = scope->FindVar(_x);
     CHECK(x);
@@ -43,7 +43,7 @@ class FetchOp : public OpLite {
     auto* out = scope->FindVar(_out);
     param_.fetch_list = out->GetMutable<std::vector<lite::Tensor>>();
 
-    param_.col = GetAttr<int>(opdesc.GetAttr("col"));
+    param_.col = opdesc.GetAttr<int>("col");
     return true;
   }
......
@@ -33,14 +33,14 @@ class FillConstantOp : public OpLite {
     return true;
   }
 
-  bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override {
+  bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override {
     auto Out_name = opdesc.Output("Out").front();
 
     param_.Out = GetMutableVar<Tensor>(scope, Out_name);
-    param_.dtype = GetAttr<int>(opdesc.GetAttr("dtype"));
-    param_.shape = GetAttr<std::vector<int64_t>>(opdesc.GetAttr("shape"));
-    param_.value = GetAttr<float>(opdesc.GetAttr("value"));
-    param_.force_cpu = GetAttr<bool>(opdesc.GetAttr("force_cpu"));
+    param_.dtype = opdesc.GetAttr<int>("dtype");
+    param_.shape = opdesc.GetAttr<std::vector<int64_t>>("shape");
+    param_.value = opdesc.GetAttr<float>("value");
+    param_.force_cpu = opdesc.GetAttr<bool>("force_cpu");
     return true;
   }
......
@@ -29,7 +29,8 @@ bool IoCopyOp::InferShape() const {
   return true;
 }
 bool IoCopyOp::Run() { return OpLite::Run(); }
-bool IoCopyOp::AttachImpl(const OpDesc &opdesc, paddle::lite::Scope *scope) {
+bool IoCopyOp::AttachImpl(const cpp::OpDesc &opdesc,
+                          paddle::lite::Scope *scope) {
   auto x = opdesc.Input("Input").front();
   auto out = opdesc.Output("Out").front();
   param_.x = GetTensor(scope, x);
......
@@ -31,7 +31,7 @@ class IoCopyOp : public OpLite {
   void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
 
  protected:
-  bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override;
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
 
  private:
   operators::IoCopyParam param_;
......
@@ -37,7 +37,7 @@ class MeanOp : public OpLite {
     return true;
   }
-  bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override {
+  bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override {
     auto X_name = opdesc.Input("X").front();
     auto Out_name = opdesc.Output("Out").front();
@@ -72,8 +72,8 @@ class MeanGradOp : public OpLite {
     return true;
   }
-  bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override {
-    CHECK_EQ(opdesc.Inputs().size(), 3UL);
+  bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override {
+    CHECK_EQ(opdesc.InputArgumentNames().size(), 3UL);
     auto X_name = opdesc.Input("X").front();
     auto Out_grad_name = opdesc.Input(framework::GradVarName("Out")).front();
     auto X_grad_name = opdesc.Output(framework::GradVarName("X")).front();
......
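Note that the migration is not purely a rename of the parameter type: cpp::OpDesc spells the argument-name query differently, so arity checks move from Inputs() to InputArgumentNames(). A sketch of the mean_grad guard after the change:

  // cpp::OpDesc enumerates input argument names explicitly
  CHECK_EQ(opdesc.InputArgumentNames().size(), 3UL);  // mean_grad expects exactly three inputs
  auto X_name = opdesc.Input("X").front();
  auto Out_grad_name = opdesc.Input(framework::GradVarName("Out")).front();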
@@ -85,7 +85,7 @@ bool MulGradOpLite::InferShape() const {
   return true;
 }
-bool MulGradOpLite::AttachImpl(const OpDesc &op_desc, lite::Scope *scope) {
+bool MulGradOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
   auto X_name = op_desc.Input("X").front();
   auto Y_name = op_desc.Input("Y").front();
   auto Out_grad_name = op_desc.Output(framework::GradVarName("Out")).front();
......
@@ -37,7 +37,7 @@ class MulOpLite : public OpLite {
   void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
   // TODO(Superjomn) replace framework::OpDesc with a lite one.
-  bool AttachImpl(const OpDesc &op_desc, lite::Scope *scope) override {
+  bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
     auto input = op_desc.Input("X").front();
     auto W = op_desc.Input("Y").front();
     auto out = op_desc.Output("Out").front();
@@ -49,8 +49,8 @@ class MulOpLite : public OpLite {
     param_.y = var->GetMutable<Tensor>();
     CHECK(scope->FindVar(out));
     param_.output = scope->FindVar(out)->GetMutable<Tensor>();
-    param_.x_num_col_dims = GetAttr<int>(op_desc.GetAttr("x_num_col_dims"));
-    param_.y_num_col_dims = GetAttr<int>(op_desc.GetAttr("y_num_col_dims"));
+    param_.x_num_col_dims = op_desc.GetAttr<int>("x_num_col_dims");
+    param_.y_num_col_dims = op_desc.GetAttr<int>("y_num_col_dims");
     return true;
   }
@@ -73,7 +73,7 @@ class MulGradOpLite : public OpLite {
   void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
-  bool AttachImpl(const OpDesc &op_desc, lite::Scope *scope) override;
+  bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override;
   std::string DebugString() const override { return "mul_grad"; }
......
@@ -30,7 +30,7 @@ bool ReluOp::InferShape() const {
   return true;
 }
-bool ReluOp::AttachImpl(const OpDesc &opdesc, lite::Scope *scope) {
+bool ReluOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
   param_.input = const_cast<lite::Tensor *>(
       &scope->FindVar(opdesc.Input("Input").front())->Get<lite::Tensor>());
   param_.output =
......
@@ -32,7 +32,7 @@ class ReluOp : public OpLite {
   bool InferShape() const override;
-  bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override;
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
   void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
   std::string DebugString() const override { return "relu"; }
......
@@ -33,7 +33,7 @@ bool ReshapeOp::InferShape() const {
   return true;
 }
-bool ReshapeOp::AttachImpl(const OpDesc &opdesc, lite::Scope *scope) {
+bool ReshapeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
   auto x_var = scope->FindVar(opdesc.Input("X").front());
   auto output_var = scope->FindVar(opdesc.Output("Out").front());
   CHECK(x_var);
@@ -49,9 +49,9 @@ bool ReshapeOp::AttachImpl(const OpDesc &opdesc, lite::Scope *scope) {
           const_cast<lite::Tensor *>(&(actual_shape_var->Get<lite::Tensor>()));
     }
   }
-  param_.shape = GetAttr<std::vector<int>>(opdesc.GetAttr("shape"));
+  param_.shape = (opdesc.GetAttr<std::vector<int>>("shape"));
   if (opdesc.HasAttr("inplace")) {
-    param_.inplace = GetAttr<bool>(opdesc.GetAttr("inplace"));
+    param_.inplace = opdesc.GetAttr<bool>("inplace");
   }
   CHECK(param_.x) << "Input(X) of ReshapeOp should not be null.";
   CHECK(param_.output) << "Output(Out) of ReshapeOp should not be null.";
@@ -70,14 +70,14 @@ bool Reshape2Op::InferShape() const {
   ReshapeOp::InferShape();
   auto x_dims = param_.x->dims();
   std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 0);
-  for (int i = 0; i < x_dims.size(); i++) {
+  for (size_t i = 0; i < x_dims.size(); i++) {
     xshape_dims[i + 1] = x_dims[i];
   }
   param_.xshape->Resize(DDim(xshape_dims));
   return true;
 }
-bool Reshape2Op::AttachImpl(const OpDesc &opdesc, lite::Scope *scope) {
+bool Reshape2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
   ReshapeOp::AttachImpl(opdesc, scope);
   auto xshape_var = scope->FindVar(opdesc.Output("XShape").front());
   CHECK(xshape_var);
......
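Optional attributes keep working the same way: HasAttr survives the migration unchanged, and only the typed read moves onto the desc. A sketch of the optional "inplace" attribute from the reshape hunk above:

  param_.shape = opdesc.GetAttr<std::vector<int>>("shape");
  if (opdesc.HasAttr("inplace")) {  // the attribute may be absent in older models
    param_.inplace = opdesc.GetAttr<bool>("inplace");
  }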
@@ -32,7 +32,7 @@ class ReshapeOp : public OpLite {
   bool InferShape() const override;
-  bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override;
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
   void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
   std::string DebugString() const override { return "reshape"; }
@@ -50,7 +50,7 @@ class Reshape2Op : public ReshapeOp {
   bool InferShape() const override;
-  bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override;
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
   void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
   std::string DebugString() const override { return "reshape2"; }
......
@@ -47,7 +47,7 @@ TEST(reshape_op_lite, test) {
   for (auto& has_actual_shape : {true, false}) {
     for (auto& inplace : {true, false}) {
       // prepare op desc
-      lite::OpDesc desc;
+      cpp::OpDesc desc;
       desc.SetType("reshape");
       desc.SetInput("X", {"x"});
       if (has_actual_shape) {
@@ -68,7 +68,7 @@ TEST(reshape_op_lite, test) {
       // check output dims
       auto output_dims = output->dims();
       CHECK_EQ(output_dims.size(), shape.second.size());
-      for (int i = 0; i < output_dims.size(); i++) {
+      for (size_t i = 0; i < output_dims.size(); i++) {
         CHECK_EQ(output_dims[i], shape.second[i]);
       }
     }
@@ -102,7 +102,7 @@ TEST(reshape2_op_lite, test) {
   for (auto& has_actual_shape : {true, false}) {
     for (auto& inplace : {true, false}) {
       // prepare op desc
-      lite::OpDesc desc;
+      cpp::OpDesc desc;
       desc.SetType("reshape");
       desc.SetInput("X", {"x"});
       if (has_actual_shape) {
@@ -132,7 +132,7 @@ TEST(reshape2_op_lite, test) {
       auto xshape_dims = xshape->dims();
       CHECK_EQ(xshape_dims.size(), x_dims.size() + 1);
       CHECK_EQ(xshape_dims[0], 0);
-      for (int i = 0; i < x_dims.size(); i++) {
+      for (size_t i = 0; i < x_dims.size(); i++) {
        CHECK_EQ(xshape_dims[i + 1], x_dims[i]);
       }
     }
......
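The test updates are mechanical: construct a cpp::OpDesc instead of a lite::OpDesc and the rest of the harness is untouched. A minimal sketch of the setup; the SetAttr call is an assumption here (the hunk elides the attribute lines), mirroring the templated style of GetAttr:

  cpp::OpDesc desc;
  desc.SetType("reshape");
  desc.SetInput("X", {"x"});
  desc.SetOutput("Out", {"output"});
  // assumed API, symmetric with GetAttr<T>: attach the target shape
  desc.SetAttr<std::vector<int>>("shape", {1, 4});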
@@ -29,14 +29,14 @@ bool ScaleOp::InferShape() const {
   return true;
 }
-bool ScaleOp::AttachImpl(const OpDesc &op_desc, lite::Scope *scope) {
+bool ScaleOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
   auto x = op_desc.Input("X").front();
   auto output = op_desc.Output("Out").front();
   param_.x = scope->FindVar(x)->GetMutable<Tensor>();
   param_.output = scope->FindVar(output)->GetMutable<Tensor>();
-  param_.scale = GetAttr<float>(op_desc.GetAttr("scale"));
-  param_.bias = GetAttr<float>(op_desc.GetAttr("bias"));
-  param_.bias_after_scale = GetAttr<bool>(op_desc.GetAttr("bias_after_scale"));
+  param_.scale = op_desc.GetAttr<float>("scale");
+  param_.bias = op_desc.GetAttr<float>("bias");
+  param_.bias_after_scale = op_desc.GetAttr<bool>("bias_after_scale");
   CHECK(param_.x);
   CHECK(param_.output);
   return true;
......
@@ -32,7 +32,7 @@ class ScaleOp : public OpLite {
   bool InferShape() const override;
-  bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override;
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
   void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
   std::string DebugString() const override { return "scale"; }
......
@@ -29,7 +29,7 @@ TEST(scale_op_lite, test) {
   output->Resize(DDim(std::vector<int64_t>{1, 1}));
   // prepare op desc
-  lite::OpDesc desc;
+  cpp::OpDesc desc;
   desc.SetType("scale");
   desc.SetInput("X", {"x"});
   desc.SetOutput("Out", {"output"});
@@ -48,7 +48,7 @@ TEST(scale_op_lite, test) {
   auto x_dims = x->dims();
   auto output_dims = output->dims();
   CHECK_EQ(output_dims.size(), x_dims.size());
-  for (int i = 0; i < output_dims.size(); i++) {
+  for (size_t i = 0; i < output_dims.size(); i++) {
     CHECK_EQ(output_dims[i], x_dims[i]);
   }
 }
......
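Several index loops in the tests switch from int to size_t for the same reason: dims().size() returns an unsigned type, and the old form trips -Wsign-compare. The corrected loop shape, as used in the scale test above:

  for (size_t i = 0; i < output_dims.size(); i++) {
    CHECK_EQ(output_dims[i], x_dims[i]);
  }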
@@ -24,7 +24,8 @@ bool SoftmaxOp::CheckShape() const {
   CHECK_OR_FALSE(param_.output);
   auto x_dims = param_.x->dims();
   auto x_rank = x_dims.size();
-  CHECK_OR_FALSE(param_.axis >= -x_rank && param_.axis < x_rank);
+  CHECK_OR_FALSE(param_.axis >= -static_cast<int>(x_rank) &&
+                 param_.axis < static_cast<int>(x_rank));
   return true;
 }
@@ -33,12 +34,12 @@ bool SoftmaxOp::InferShape() const {
   return true;
 }
-bool SoftmaxOp::AttachImpl(const OpDesc &opdesc, lite::Scope *scope) {
+bool SoftmaxOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
   param_.x = const_cast<lite::Tensor *>(
       &scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
   param_.output =
       scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
-  param_.axis = GetAttr<int>(opdesc.GetAttr("axis"));
+  param_.axis = opdesc.GetAttr<int>("axis");
   CHECK(param_.x);
   CHECK(param_.output);
   return true;
......
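The softmax hunk also fixes a signed/unsigned comparison: x_dims.size() is unsigned while axis may legitimately be negative, so the rank is cast before the range check. The corrected guard:

  auto x_rank = x_dims.size();  // size_t, unsigned
  CHECK_OR_FALSE(param_.axis >= -static_cast<int>(x_rank) &&
                 param_.axis < static_cast<int>(x_rank));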
@@ -32,7 +32,7 @@ class SoftmaxOp : public OpLite {
   bool InferShape() const override;
-  bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override;
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
   void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
   std::string DebugString() const override { return "softmax"; }
......
@@ -37,7 +37,7 @@ TEST(softmax_op_lite, test) {
   }
   // prepare op desc
-  lite::OpDesc desc;
+  cpp::OpDesc desc;
   desc.SetType("softmax");
   desc.SetInput("X", {"x"});
   desc.SetOutput("Out", {"output"});
......