提交 027fc62c 编写于 作者: Y Yu Yang

Use Vec2Repeated Repeated2Vec

上级 f5aa8b4d
...@@ -17,6 +17,25 @@ limitations under the License. */ ...@@ -17,6 +17,25 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
template <typename T>
inline std::vector<T> RepeatedToVector(
const google::protobuf::RepeatedField<T> &repeated_field) {
std::vector<T> ret;
ret.reserve(repeated_field.size());
std::copy(
repeated_field.begin(), repeated_field.end(), std::back_inserter(ret));
return ret;
}
template <typename T, typename RepeatedField>
inline void VectorToRepeated(const std::vector<T> &vec,
RepeatedField *repeated_field) {
repeated_field->Reserve(vec.size());
for (auto &elem : vec) {
*repeated_field->Add() = elem;
}
}
void BindProgramDesc(py::module &m) { void BindProgramDesc(py::module &m) {
using namespace paddle::framework; // NOLINT using namespace paddle::framework; // NOLINT
py::class_<ProgramDesc>(m, "ProgramDesc", "") py::class_<ProgramDesc>(m, "ProgramDesc", "")
...@@ -70,10 +89,7 @@ void BindVarDsec(py::module &m) { ...@@ -70,10 +89,7 @@ void BindVarDsec(py::module &m) {
[](VarDesc &self, const std::string &name) { self.set_name(name); }) [](VarDesc &self, const std::string &name) { self.set_name(name); })
.def("set_shape", .def("set_shape",
[](VarDesc &self, const std::vector<int64_t> &dims) { [](VarDesc &self, const std::vector<int64_t> &dims) {
LoDTensorDesc *lod_tensor_desc = self.mutable_lod_tensor(); VectorToRepeated(dims, self.mutable_lod_tensor()->mutable_dims());
for (const int64_t &i : dims) {
lod_tensor_desc->add_dims(i);
}
}) })
.def("set_data_type", .def("set_data_type",
[](VarDesc &self, int type_id) { [](VarDesc &self, int type_id) {
...@@ -82,12 +98,7 @@ void BindVarDsec(py::module &m) { ...@@ -82,12 +98,7 @@ void BindVarDsec(py::module &m) {
}) })
.def("shape", [](VarDesc &self) { .def("shape", [](VarDesc &self) {
const LoDTensorDesc &lod_tensor_desc = self.lod_tensor(); const LoDTensorDesc &lod_tensor_desc = self.lod_tensor();
int rank = lod_tensor_desc.dims_size(); return RepeatedToVector(lod_tensor_desc.dims());
std::vector<int64_t> res(rank);
for (int i = 0; i < rank; ++i) {
res[i] = lod_tensor_desc.dims(i);
}
return res;
}); });
} }
......
...@@ -27,25 +27,6 @@ namespace py = pybind11; ...@@ -27,25 +27,6 @@ namespace py = pybind11;
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
template <typename T>
inline std::vector<T> RepeatedToVector(
const google::protobuf::RepeatedField<T>& repeated_field) {
std::vector<T> ret;
ret.reserve(repeated_field.size());
std::copy(
repeated_field.begin(), repeated_field.end(), std::back_inserter(ret));
return ret;
}
template <typename T, typename RepeatedField>
inline void VectorToRepeated(const std::vector<T>& vec,
RepeatedField* repeated_field) {
repeated_field->Reserve(vec.size());
for (auto& elem : vec) {
*repeated_field->Add() = elem;
}
}
void BindProgramDesc(py::module& m); void BindProgramDesc(py::module& m);
void BindBlockDesc(py::module& m); void BindBlockDesc(py::module& m);
void BindVarDsec(py::module& m); void BindVarDsec(py::module& m);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册