未验证 提交 fdf62e1e 编写于 作者: B Baibaifan 提交者: GitHub

Add varbase init name (#37947)

上级 a9e0d28c
...@@ -252,12 +252,16 @@ static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self, ...@@ -252,12 +252,16 @@ static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self,
InitVarBaseAndTensor(self, array, place, ""); InitVarBaseAndTensor(self, array, place, "");
} }
static void InitVarBaseFromTensorWithArgDefault( static void InitVarBaseFromTensorWithArgDefault(imperative::VarBase *self,
imperative::VarBase *self, const framework::Tensor &tensor) { const framework::Tensor &tensor,
const std::string &name) {
VLOG(4) << "Init VarBase"; VLOG(4) << "Init VarBase";
auto place = imperative::GetCurrentTracer()->ExpectedPlace(); auto place = imperative::GetCurrentTracer()->ExpectedPlace();
new (self) imperative::VarBase( auto name_ = name == ""
imperative::GetCurrentTracer()->GenerateUniqueName("generated_tensor")); ? imperative::GetCurrentTracer()->GenerateUniqueName(
"generated_tensor")
: name;
new (self) imperative::VarBase(name_);
self->SetPersistable(false); self->SetPersistable(false);
self->SetType(framework::proto::VarType::LOD_TENSOR); self->SetType(framework::proto::VarType::LOD_TENSOR);
self->SetDataType(tensor.type()); self->SetDataType(tensor.type());
...@@ -275,10 +279,14 @@ static void InitVarBaseFromTensorWithArgDefault( ...@@ -275,10 +279,14 @@ static void InitVarBaseFromTensorWithArgDefault(
template <typename P> template <typename P>
static void InitVarBaseFromTensorWithArg(imperative::VarBase *self, static void InitVarBaseFromTensorWithArg(imperative::VarBase *self,
const framework::Tensor &tensor, const framework::Tensor &tensor,
const P &place) { const P &place,
const std::string &name) {
VLOG(4) << "Init VarBase"; VLOG(4) << "Init VarBase";
new (self) imperative::VarBase( auto name_ = name == ""
imperative::GetCurrentTracer()->GenerateUniqueName("generated_tensor")); ? imperative::GetCurrentTracer()->GenerateUniqueName(
"generated_tensor")
: name;
new (self) imperative::VarBase(name_);
self->SetPersistable(false); self->SetPersistable(false);
self->SetType(framework::proto::VarType::LOD_TENSOR); self->SetType(framework::proto::VarType::LOD_TENSOR);
self->SetDataType(tensor.type()); self->SetDataType(tensor.type());
...@@ -917,17 +925,18 @@ void BindImperative(py::module *m_ptr) { ...@@ -917,17 +925,18 @@ void BindImperative(py::module *m_ptr) {
py::arg("zero_copy") = false, py::arg("name") = "", py::arg("zero_copy") = false, py::arg("name") = "",
py::arg("stop_gradient") = -1) py::arg("stop_gradient") = -1)
.def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value")) .def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value"))
.def("__init__", &InitVarBaseFromTensorWithArgDefault, py::arg("tensor")) .def("__init__", &InitVarBaseFromTensorWithArgDefault, py::arg("tensor"),
py::arg("name") = "")
.def("__init__", &InitVarBaseFromTensorWithArg<platform::CPUPlace>, .def("__init__", &InitVarBaseFromTensorWithArg<platform::CPUPlace>,
py::arg("tensor"), py::arg("place")) py::arg("tensor"), py::arg("place"), py::arg("name") = "")
.def("__init__", &InitVarBaseFromTensorWithArg<platform::XPUPlace>, .def("__init__", &InitVarBaseFromTensorWithArg<platform::XPUPlace>,
py::arg("tensor"), py::arg("place")) py::arg("tensor"), py::arg("place"), py::arg("name") = "")
.def("__init__", &InitVarBaseFromTensorWithArg<platform::CUDAPlace>, .def("__init__", &InitVarBaseFromTensorWithArg<platform::CUDAPlace>,
py::arg("tensor"), py::arg("place")) py::arg("tensor"), py::arg("place"), py::arg("name") = "")
.def("__init__", &InitVarBaseFromTensorWithArg<platform::CUDAPinnedPlace>, .def("__init__", &InitVarBaseFromTensorWithArg<platform::CUDAPinnedPlace>,
py::arg("tensor"), py::arg("place")) py::arg("tensor"), py::arg("place"), py::arg("name") = "")
.def("__init__", &InitVarBaseFromTensorWithArg<platform::NPUPlace>, .def("__init__", &InitVarBaseFromTensorWithArg<platform::NPUPlace>,
py::arg("tensor"), py::arg("place")) py::arg("tensor"), py::arg("place"), py::arg("name") = "")
.def("__init__", &InitVarBaseFromNumpyWithKwargs) .def("__init__", &InitVarBaseFromNumpyWithKwargs)
.def( .def(
"__setitem_varbase__", "__setitem_varbase__",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册