// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once /* * This file implements a light-weight OpDesc using NaiveBuffer. */ #include #include #include #include #include #include "lite/model_parser/desc_apis.h" #include "lite/model_parser/naive_buffer/proto/framework.nb.h" namespace paddle { namespace lite { namespace naive_buffer { /* * The lite::naive_buffer::OpDesc, an light-weight implementation of wrapper of * lite::naive_buffer::proto::OpDesc. */ class OpDesc : public OpDescAPI { public: using var_list_t = ListBuilder; using str_list_t = ListBuilder; using attr_list_t = ListBuilder; OpDesc() = delete; explicit OpDesc(proto::OpDesc *desc) : desc_(desc) { CHECK(desc_); } void CopyFrom(OpDesc &op_desc) { CHECK(op_desc.Proto()) << "Source proto::OpDesc pointer can't be null"; desc_ = op_desc.Proto(); } proto::OpDesc *Proto() { return desc_; } const proto::OpDesc &ReadonlyProto() const { return *desc_; } std::string Type() const override { auto &builder = desc_->GetField("type"); return builder.data(); } void SetType(const std::string &type) override { auto *builder = desc_->GetMutableField("type"); CHECK(builder); return builder->set(type); } // Get the arguments of parameter called `param` std::vector Input(const std::string ¶m) const override { return GetArguments(desc_->GetField("inputs"), param); } std::vector InputArgumentNames() const override { return GetArgumentNames(desc_->GetField("inputs")); } void SetInput(const std::string ¶m, const std::vector &args) override { SetArgument(desc_->GetMutableField("inputs"), param, args); } std::vector Output(const std::string ¶m) const override { return GetArguments(desc_->GetField("outputs"), param); } std::vector OutputArgumentNames() const override { return GetArgumentNames(desc_->GetField("outputs")); } void SetOutput(const std::string ¶m, const std::vector &args) override { SetArgument(desc_->GetMutableField("outputs"), param, args); } bool HasAttr(const std::string &name) const override { const auto &xs = desc_->GetField("attrs"); auto it = std::find_if(xs.begin(), xs.end(), [&](const proto::OpDesc::Attr &x) { auto &builder = x.GetField("name"); return builder.data() == name; }); return it != xs.end(); } AttrType GetAttrType(const std::string &name) const override { const auto &xs = desc_->GetField("attrs"); auto it = std::find_if(xs.begin(), xs.end(), [&](const proto::OpDesc::Attr &x) { auto &builder = x.GetField("name"); return builder.data() == name; }); CHECK(it != xs.end()); #define DEF_ONE(type__) \ case proto::OpDesc::AttrType::type__: \ return AttrType::type__; auto &builder = it->GetField>("type"); switch (builder.data()) { DEF_ONE(INT); DEF_ONE(FLOAT); DEF_ONE(STRING); DEF_ONE(INTS); DEF_ONE(FLOATS); DEF_ONE(STRINGS); DEF_ONE(BOOLEAN); DEF_ONE(BOOLEANS); DEF_ONE(BLOCK); DEF_ONE(LONG); DEF_ONE(BLOCKS); DEF_ONE(LONGS); default: LOG(FATAL) << "Unknown attribute type"; return static_cast(-1); } #undef DEF_ONE } std::vector AttrNames() const override { std::vector res; const auto &xs = desc_->GetField("attrs"); std::transform(xs.begin(), xs.end(), std::back_inserter(res), [](const proto::OpDesc::Attr &x) { auto &builder = x.GetField("name"); return builder.data(); }); return res; } template void SetAttr(const std::string &name, const T &v); template T GetAttr(const std::string &name) const; std::string DebugString() const { return "Not Implemented"; } private: std::vector GetArguments(const var_list_t &xs, const std::string ¶m) const { std::vector res; auto it = std::find_if(xs.begin(), xs.end(), [&](const proto::OpDesc::Var &it) { auto &builder = it.GetField("parameter"); return builder.data() == param; }); CHECK(it != xs.end()); auto &list_builder = it->GetField("arguments"); std::transform(list_builder.begin(), list_builder.end(), std::back_inserter(res), [](const StringBuilder &x) { return x.data(); }); return res; } void SetArgument(var_list_t *xs, const std::string ¶m, const std::vector &args) { auto it = std::find_if(xs->begin(), xs->end(), [&](const proto::OpDesc::Var &it) { auto &builder = it.GetField("parameter"); return builder.data() == param; }); if (it == xs->end()) { auto *new_arg = xs->New(); auto *param_builder = new_arg->GetMutableField("parameter"); CHECK(param_builder); param_builder->set(param); auto *arg_builder = new_arg->GetMutableField("arguments"); CHECK(arg_builder); for (const auto &arg : args) { arg_builder->New()->set(arg); } } else { auto *arg_builder = it->GetMutableField("arguments"); CHECK(arg_builder); arg_builder->Clear(); for (const auto &arg : args) { arg_builder->New()->set(arg); } } } std::vector GetArgumentNames(const var_list_t &xs) const { std::vector res; std::transform(xs.begin(), xs.end(), std::back_inserter(res), [](const proto::OpDesc::Var &x) { auto &builder = x.GetField("parameter"); return builder.data(); }); return res; } private: // Don't owned by naive_buffer::OpDesc proto::OpDesc *desc_; }; template <> void OpDesc::SetAttr(const std::string &name, const std::string &v); template <> void OpDesc::SetAttr>(const std::string &name, const std::vector &v); } // namespace naive_buffer } // namespace lite } // namespace paddle