diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto index c99406799ba5f664c4b9f80e0567b293e4ffea51..2545e6c6f186b67ef1d22e7b5714d6f3913716ac 100644 --- a/paddle/fluid/framework/framework.proto +++ b/paddle/fluid/framework/framework.proto @@ -35,6 +35,7 @@ enum AttrType { BLOCK = 8; LONG = 9; BLOCKS = 10; + LONGS = 11; } // OpDesc describes an instance of a C++ framework::OperatorBase @@ -55,6 +56,7 @@ message OpDesc { optional int32 block_idx = 12; optional int64 l = 13; repeated int32 blocks_idx = 14; + optional int64 longs = 15; }; message Var { diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index c293cf92b4f3d530109c76850df184af9cad7399..7f81fb8641545e88750f118d6e80d1f1ae9fc860 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -421,6 +421,11 @@ struct SetAttrDescVisitor : public boost::static_visitor { } void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); } void operator()(int64_t v) const { attr_->set_l(v); } + + void operator()(const std::vector &v) const { + VectorToRepeated(v, attr_->mutable_longs()); + } + void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); } }; diff --git a/paddle/fluid/framework/type_defs.h b/paddle/fluid/framework/type_defs.h index e099e40f121ff13657e563eb608feecbca0551be..2de6233a9e0d320ec9a06d547db3575eb61925c0 100644 --- a/paddle/fluid/framework/type_defs.h +++ b/paddle/fluid/framework/type_defs.h @@ -36,7 +36,7 @@ using Attribute = boost::variant, std::vector, std::vector, bool, std::vector, BlockDesc*, int64_t, - std::vector>; + std::vector, std::vector>; using AttributeMap = std::unordered_map; diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 3b22718a8c6f994dbc2dc3e7aaa19a7163f716ba..0edfbfaa875d1258f2c61915d26362d684320941 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -259,6 +259,8 @@ void BindOpDesc(pybind11::module *m) { pybind11::enum_(*m, "AttrType", "") .value("INT", pd::proto::AttrType::INT) .value("INTS", pd::proto::AttrType::INTS) + .value("LONG", pd::proto::AttrType::FLOAT) + .value("LONGS", pd::proto::AttrType::FLOAT) .value("FLOAT", pd::proto::AttrType::FLOAT) .value("FLOATS", pd::proto::AttrType::FLOATS) .value("STRING", pd::proto::AttrType::STRING)