提交 027fc62c 编写于 作者: Y Yu Yang

Use Vec2Repeated Repeated2Vec

上级 f5aa8b4d
......@@ -17,6 +17,25 @@ limitations under the License. */
namespace paddle {
namespace pybind {
template <typename T>
inline std::vector<T> RepeatedToVector(
const google::protobuf::RepeatedField<T> &repeated_field) {
std::vector<T> ret;
ret.reserve(repeated_field.size());
std::copy(
repeated_field.begin(), repeated_field.end(), std::back_inserter(ret));
return ret;
}
template <typename T, typename RepeatedField>
inline void VectorToRepeated(const std::vector<T> &vec,
RepeatedField *repeated_field) {
repeated_field->Reserve(vec.size());
for (auto &elem : vec) {
*repeated_field->Add() = elem;
}
}
void BindProgramDesc(py::module &m) {
using namespace paddle::framework; // NOLINT
py::class_<ProgramDesc>(m, "ProgramDesc", "")
......@@ -70,10 +89,7 @@ void BindVarDsec(py::module &m) {
[](VarDesc &self, const std::string &name) { self.set_name(name); })
.def("set_shape",
[](VarDesc &self, const std::vector<int64_t> &dims) {
LoDTensorDesc *lod_tensor_desc = self.mutable_lod_tensor();
for (const int64_t &i : dims) {
lod_tensor_desc->add_dims(i);
}
VectorToRepeated(dims, self.mutable_lod_tensor()->mutable_dims());
})
.def("set_data_type",
[](VarDesc &self, int type_id) {
......@@ -82,12 +98,7 @@ void BindVarDsec(py::module &m) {
})
.def("shape", [](VarDesc &self) {
const LoDTensorDesc &lod_tensor_desc = self.lod_tensor();
int rank = lod_tensor_desc.dims_size();
std::vector<int64_t> res(rank);
for (int i = 0; i < rank; ++i) {
res[i] = lod_tensor_desc.dims(i);
}
return res;
return RepeatedToVector(lod_tensor_desc.dims());
});
}
......
......@@ -27,25 +27,6 @@ namespace py = pybind11;
namespace paddle {
namespace pybind {
template <typename T>
inline std::vector<T> RepeatedToVector(
const google::protobuf::RepeatedField<T>& repeated_field) {
std::vector<T> ret;
ret.reserve(repeated_field.size());
std::copy(
repeated_field.begin(), repeated_field.end(), std::back_inserter(ret));
return ret;
}
template <typename T, typename RepeatedField>
inline void VectorToRepeated(const std::vector<T>& vec,
RepeatedField* repeated_field) {
repeated_field->Reserve(vec.size());
for (auto& elem : vec) {
*repeated_field->Add() = elem;
}
}
void BindProgramDesc(py::module& m);
void BindBlockDesc(py::module& m);
void BindVarDsec(py::module& m);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册