提交 8b7f45a8 编写于 作者: T tangwei12

add longs in framework

上级 f3729db6
...@@ -35,6 +35,7 @@ enum AttrType { ...@@ -35,6 +35,7 @@ enum AttrType {
BLOCK = 8; BLOCK = 8;
LONG = 9; LONG = 9;
BLOCKS = 10; BLOCKS = 10;
LONGS = 11;
} }
// OpDesc describes an instance of a C++ framework::OperatorBase // OpDesc describes an instance of a C++ framework::OperatorBase
...@@ -55,6 +56,7 @@ message OpDesc { ...@@ -55,6 +56,7 @@ message OpDesc {
optional int32 block_idx = 12; optional int32 block_idx = 12;
optional int64 l = 13; optional int64 l = 13;
repeated int32 blocks_idx = 14; repeated int32 blocks_idx = 14;
optional int64 longs = 15;
}; };
message Var { message Var {
......
...@@ -421,6 +421,11 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> { ...@@ -421,6 +421,11 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
} }
void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); } void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); }
void operator()(int64_t v) const { attr_->set_l(v); } void operator()(int64_t v) const { attr_->set_l(v); }
void operator()(const std::vector<int64_t> &v) const {
VectorToRepeated(v, attr_->mutable_longs());
}
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); } void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
}; };
......
...@@ -36,7 +36,7 @@ using Attribute = ...@@ -36,7 +36,7 @@ using Attribute =
boost::variant<boost::blank, int, float, std::string, std::vector<int>, boost::variant<boost::blank, int, float, std::string, std::vector<int>,
std::vector<float>, std::vector<std::string>, bool, std::vector<float>, std::vector<std::string>, bool,
std::vector<bool>, BlockDesc*, int64_t, std::vector<bool>, BlockDesc*, int64_t,
std::vector<BlockDesc*>>; std::vector<BlockDesc*>, std::vector<int64_t>>;
using AttributeMap = std::unordered_map<std::string, Attribute>; using AttributeMap = std::unordered_map<std::string, Attribute>;
......
...@@ -259,6 +259,8 @@ void BindOpDesc(pybind11::module *m) { ...@@ -259,6 +259,8 @@ void BindOpDesc(pybind11::module *m) {
pybind11::enum_<pd::proto::AttrType>(*m, "AttrType", "") pybind11::enum_<pd::proto::AttrType>(*m, "AttrType", "")
.value("INT", pd::proto::AttrType::INT) .value("INT", pd::proto::AttrType::INT)
.value("INTS", pd::proto::AttrType::INTS) .value("INTS", pd::proto::AttrType::INTS)
.value("LONG", pd::proto::AttrType::FLOAT)
.value("LONGS", pd::proto::AttrType::FLOAT)
.value("FLOAT", pd::proto::AttrType::FLOAT) .value("FLOAT", pd::proto::AttrType::FLOAT)
.value("FLOATS", pd::proto::AttrType::FLOATS) .value("FLOATS", pd::proto::AttrType::FLOATS)
.value("STRING", pd::proto::AttrType::STRING) .value("STRING", pd::proto::AttrType::STRING)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册