From fdf62e1ea95ff0bf5ff7a3bc03346669c5760f50 Mon Sep 17 00:00:00 2001 From: Baibaifan <39549453+Baibaifan@users.noreply.github.com> Date: Thu, 9 Dec 2021 11:14:48 +0800 Subject: [PATCH] Add varbase init name (#37947) --- paddle/fluid/pybind/imperative.cc | 35 +++++++++++++++++++------------ 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index dc97d98e8c..080323bbc2 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -252,12 +252,16 @@ static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self, InitVarBaseAndTensor(self, array, place, ""); } -static void InitVarBaseFromTensorWithArgDefault( - imperative::VarBase *self, const framework::Tensor &tensor) { +static void InitVarBaseFromTensorWithArgDefault(imperative::VarBase *self, + const framework::Tensor &tensor, + const std::string &name) { VLOG(4) << "Init VarBase"; auto place = imperative::GetCurrentTracer()->ExpectedPlace(); - new (self) imperative::VarBase( - imperative::GetCurrentTracer()->GenerateUniqueName("generated_tensor")); + auto name_ = name == "" + ? imperative::GetCurrentTracer()->GenerateUniqueName( + "generated_tensor") + : name; + new (self) imperative::VarBase(name_); self->SetPersistable(false); self->SetType(framework::proto::VarType::LOD_TENSOR); self->SetDataType(tensor.type()); @@ -275,10 +279,14 @@ static void InitVarBaseFromTensorWithArgDefault( template static void InitVarBaseFromTensorWithArg(imperative::VarBase *self, const framework::Tensor &tensor, - const P &place) { + const P &place, + const std::string &name) { VLOG(4) << "Init VarBase"; - new (self) imperative::VarBase( - imperative::GetCurrentTracer()->GenerateUniqueName("generated_tensor")); + auto name_ = name == "" + ? imperative::GetCurrentTracer()->GenerateUniqueName( + "generated_tensor") + : name; + new (self) imperative::VarBase(name_); self->SetPersistable(false); self->SetType(framework::proto::VarType::LOD_TENSOR); self->SetDataType(tensor.type()); @@ -917,17 +925,18 @@ void BindImperative(py::module *m_ptr) { py::arg("zero_copy") = false, py::arg("name") = "", py::arg("stop_gradient") = -1) .def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value")) - .def("__init__", &InitVarBaseFromTensorWithArgDefault, py::arg("tensor")) + .def("__init__", &InitVarBaseFromTensorWithArgDefault, py::arg("tensor"), + py::arg("name") = "") .def("__init__", &InitVarBaseFromTensorWithArg, - py::arg("tensor"), py::arg("place")) + py::arg("tensor"), py::arg("place"), py::arg("name") = "") .def("__init__", &InitVarBaseFromTensorWithArg, - py::arg("tensor"), py::arg("place")) + py::arg("tensor"), py::arg("place"), py::arg("name") = "") .def("__init__", &InitVarBaseFromTensorWithArg, - py::arg("tensor"), py::arg("place")) + py::arg("tensor"), py::arg("place"), py::arg("name") = "") .def("__init__", &InitVarBaseFromTensorWithArg, - py::arg("tensor"), py::arg("place")) + py::arg("tensor"), py::arg("place"), py::arg("name") = "") .def("__init__", &InitVarBaseFromTensorWithArg, - py::arg("tensor"), py::arg("place")) + py::arg("tensor"), py::arg("place"), py::arg("name") = "") .def("__init__", &InitVarBaseFromNumpyWithKwargs) .def( "__setitem_varbase__", -- GitLab