From 8b7f45a8895a42cdd3e69576f9c9b3e4dce86245 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 24 Oct 2018 11:27:23 +0800 Subject: [PATCH] add longs in framework --- paddle/fluid/framework/framework.proto | 2 ++ paddle/fluid/framework/op_desc.cc | 5 +++++ paddle/fluid/framework/type_defs.h | 2 +- paddle/fluid/pybind/protobuf.cc | 2 ++ 4 files changed, 10 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto index c99406799ba..2545e6c6f18 100644 --- a/paddle/fluid/framework/framework.proto +++ b/paddle/fluid/framework/framework.proto @@ -35,6 +35,7 @@ enum AttrType { BLOCK = 8; LONG = 9; BLOCKS = 10; + LONGS = 11; } // OpDesc describes an instance of a C++ framework::OperatorBase @@ -55,6 +56,7 @@ message OpDesc { optional int32 block_idx = 12; optional int64 l = 13; repeated int32 blocks_idx = 14; + optional int64 longs = 15; }; message Var { diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index c293cf92b4f..7f81fb86415 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -421,6 +421,11 @@ struct SetAttrDescVisitor : public boost::static_visitor { } void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); } void operator()(int64_t v) const { attr_->set_l(v); } + + void operator()(const std::vector &v) const { + VectorToRepeated(v, attr_->mutable_longs()); + } + void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); } }; diff --git a/paddle/fluid/framework/type_defs.h b/paddle/fluid/framework/type_defs.h index e099e40f121..2de6233a9e0 100644 --- a/paddle/fluid/framework/type_defs.h +++ b/paddle/fluid/framework/type_defs.h @@ -36,7 +36,7 @@ using Attribute = boost::variant, std::vector, std::vector, bool, std::vector, BlockDesc*, int64_t, - std::vector>; + std::vector, std::vector>; using AttributeMap = std::unordered_map; diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 3b22718a8c6..0edfbfaa875 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -259,6 +259,8 @@ void BindOpDesc(pybind11::module *m) { pybind11::enum_(*m, "AttrType", "") .value("INT", pd::proto::AttrType::INT) .value("INTS", pd::proto::AttrType::INTS) + .value("LONG", pd::proto::AttrType::FLOAT) + .value("LONGS", pd::proto::AttrType::FLOAT) .value("FLOAT", pd::proto::AttrType::FLOAT) .value("FLOATS", pd::proto::AttrType::FLOATS) .value("STRING", pd::proto::AttrType::STRING) -- GitLab