// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

/*
 * This file implements a light-weight OpDesc like the framework::OpDesc. We
 * delete the unnecessary methods, and remove the underlying dependencies, such
 * as framework::Operator and boost::variant, to make it runnable on mobile.
 */

#include <algorithm>
#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/lite/core/framework.pb.h"
#include "paddle/fluid/lite/utils/all.h"

namespace paddle {
namespace lite {
namespace pb {

using Attribute =
    variant<int, float, bool, std::vector<std::string>, std::vector<int>>;
using VariableNameMap = std::map<std::string, std::vector<std::string>>;

/*
 * The lite::OpDesc, a light-weight implementation of a wrapper of
 * proto::OpDesc.
 * Unlike the original one in framework::OpDesc, we remove the local members
 * except the desc_, to avoid the inconsistent state, which is normal in the
 * original interface and results in bugs.
*/ class OpDesc { public: OpDesc() {} explicit OpDesc(const framework::proto::OpDesc &desc) : desc_(desc) {} void CopyFrom(const OpDesc &op_desc) { desc_ = op_desc.ReadonlyProto(); } framework::proto::OpDesc *Proto() { return &desc_; } const framework::proto::OpDesc &ReadonlyProto() const { return desc_; } std::string Type() const { return desc_.type(); } void SetType(const std::string &type) { desc_.set_type(type); } // Get the arguments of parameter called `param` std::vector<std::string> Input(const std::string ¶m) const { return GetArguments(desc_.inputs(), param); } std::vector<std::string> InputArgumentNames() const { return GetArgumentNames(desc_.inputs()); } void SetInput(const std::string ¶m, const std::vector<std::string> &args) { SetArgument(desc_.mutable_inputs(), param, args); } std::vector<std::string> Output(const std::string ¶m) const { return GetArguments(desc_.outputs(), param); } std::vector<std::string> OutputArgumentNames() const { return GetArgumentNames(desc_.outputs()); } void SetOutput(const std::string ¶m, const std::vector<std::string> &args) { SetArgument(desc_.mutable_outputs(), param, args); } bool HasAttr(const std::string &name) const { const auto &xs = desc_.attrs(); auto it = std::find_if(xs.begin(), xs.end(), [&](const framework::proto::OpDesc_Attr &x) { return x.name() == name; }); return it != xs.end(); } framework::proto::AttrType GetAttrType(const std::string &name) const { const auto &xs = desc_.attrs(); auto it = std::find_if(xs.begin(), xs.end(), [&](const framework::proto::OpDesc_Attr &x) { return x.name() == name; }); CHECK(it != xs.end()); return it->type(); } std::vector<std::string> AttrNames() const { std::vector<std::string> res; const auto &xs = desc_.attrs(); std::transform( xs.begin(), xs.end(), std::back_inserter(res), [](const framework::proto::OpDesc_Attr &x) { return x.name(); }); return res; } template <typename T> void SetAttr(const std::string &name, const T &v) { auto &xs = *desc_.mutable_attrs(); auto it 
= std::find_if(xs.begin(), xs.end(), [&](const framework::proto::OpDesc_Attr &x) { return x.name() == name; }); if (it == xs.end()) { auto *attr = xs.Add(); attr->set_name(name); it = std::find_if(xs.begin(), xs.end(), [&](const framework::proto::OpDesc_Attr &x) { return x.name() == name; }); } size_t hash = typeid(T).hash_code(); if (hash == typeid(int).hash_code()) { // NOLINT it->set_type(framework::proto::INT); it->set_i(v); } else if (hash == typeid(float).hash_code()) { // NOLINT it->set_type(framework::proto::FLOAT); it->set_f(v); } else if (hash == typeid(bool).hash_code()) { // NOLINT it->set_type(framework::proto::BOOLEAN); it->set_b(v); } else { LOG(FATAL) << "unsupport attr type"; } } Attribute GetAttr(const std::string &name) const { auto &xs = desc_.attrs(); auto it = std::find_if(xs.begin(), xs.end(), [&](const framework::proto::OpDesc_Attr &x) { return x.name() == name; }); Attribute res; CHECK(it != xs.end()); switch (it->type()) { case framework::proto::INT: res.set<int>(it->i()); break; case framework::proto::FLOAT: res.set<float>(it->f()); break; case framework::proto::STRING: res.set<std::string>(it->s()); break; case framework::proto::BOOLEAN: res.set<bool>(it->b()); break; case framework::proto::INTS: { std::vector<int> values; const auto &ys = it->ints(); std::transform(ys.begin(), ys.end(), std::back_inserter(values), [](const int &x) { return x; }); res.set<std::vector<int>>(values); } break; default: LOG(FATAL) << "unsupported attr type"; } return res; } private: std::vector<std::string> GetArguments( const google::protobuf::RepeatedPtrField<framework::proto::OpDesc_Var> &xs, const std::string ¶m) const { std::vector<std::string> res; auto it = std::find_if(xs.begin(), xs.end(), [&](const framework::proto::OpDesc_Var &it) { return it.parameter() == param; }); CHECK(it != xs.end()); const auto &ys = it->arguments(); std::transform(ys.begin(), ys.end(), std::back_inserter(res), [](const std::string &x) { return x; }); return res; } void 
SetArgument( google::protobuf::RepeatedPtrField<framework::proto::OpDesc_Var> *xs, const std::string ¶m, const std::vector<std::string> &args) { auto it = std::find_if(xs->begin(), xs->end(), [&](const framework::proto::OpDesc_Var &it) { return it.parameter() == param; }); if (it == xs->end()) { auto *new_arg = xs->Add(); new_arg->set_parameter(param); for (const auto &arg : args) { *new_arg->mutable_arguments()->Add() = arg; } } else { it->mutable_arguments()->Clear(); for (const auto &arg : args) { *it->mutable_arguments()->Add() = arg; } } } std::vector<std::string> GetArgumentNames( const google::protobuf::RepeatedPtrField<framework::proto::OpDesc_Var> &xs) const { std::vector<std::string> res; std::transform( xs.begin(), xs.end(), std::back_inserter(res), [](const framework::proto::OpDesc_Var &x) { return x.parameter(); }); return res; } private: framework::proto::OpDesc desc_; }; template <> void OpDesc::SetAttr<std::string>(const std::string &name, const std::string &v); template <> void OpDesc::SetAttr<std::vector<int>>(const std::string &name, const std::vector<int> &v); } // namespace pb } // namespace lite } // namespace paddle