未验证 提交 fdf62e1e 编写于 作者: B Baibaifan 提交者: GitHub

Add varbase init name (#37947)

上级 a9e0d28c
......@@ -252,12 +252,16 @@ static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self,
InitVarBaseAndTensor(self, array, place, "");
}
static void InitVarBaseFromTensorWithArgDefault(
imperative::VarBase *self, const framework::Tensor &tensor) {
static void InitVarBaseFromTensorWithArgDefault(imperative::VarBase *self,
const framework::Tensor &tensor,
const std::string &name) {
VLOG(4) << "Init VarBase";
auto place = imperative::GetCurrentTracer()->ExpectedPlace();
new (self) imperative::VarBase(
imperative::GetCurrentTracer()->GenerateUniqueName("generated_tensor"));
auto name_ = name == ""
? imperative::GetCurrentTracer()->GenerateUniqueName(
"generated_tensor")
: name;
new (self) imperative::VarBase(name_);
self->SetPersistable(false);
self->SetType(framework::proto::VarType::LOD_TENSOR);
self->SetDataType(tensor.type());
......@@ -275,10 +279,14 @@ static void InitVarBaseFromTensorWithArgDefault(
template <typename P>
static void InitVarBaseFromTensorWithArg(imperative::VarBase *self,
const framework::Tensor &tensor,
const P &place) {
const P &place,
const std::string &name) {
VLOG(4) << "Init VarBase";
new (self) imperative::VarBase(
imperative::GetCurrentTracer()->GenerateUniqueName("generated_tensor"));
auto name_ = name == ""
? imperative::GetCurrentTracer()->GenerateUniqueName(
"generated_tensor")
: name;
new (self) imperative::VarBase(name_);
self->SetPersistable(false);
self->SetType(framework::proto::VarType::LOD_TENSOR);
self->SetDataType(tensor.type());
......@@ -917,17 +925,18 @@ void BindImperative(py::module *m_ptr) {
py::arg("zero_copy") = false, py::arg("name") = "",
py::arg("stop_gradient") = -1)
.def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value"))
.def("__init__", &InitVarBaseFromTensorWithArgDefault, py::arg("tensor"))
.def("__init__", &InitVarBaseFromTensorWithArgDefault, py::arg("tensor"),
py::arg("name") = "")
.def("__init__", &InitVarBaseFromTensorWithArg<platform::CPUPlace>,
py::arg("tensor"), py::arg("place"))
py::arg("tensor"), py::arg("place"), py::arg("name") = "")
.def("__init__", &InitVarBaseFromTensorWithArg<platform::XPUPlace>,
py::arg("tensor"), py::arg("place"))
py::arg("tensor"), py::arg("place"), py::arg("name") = "")
.def("__init__", &InitVarBaseFromTensorWithArg<platform::CUDAPlace>,
py::arg("tensor"), py::arg("place"))
py::arg("tensor"), py::arg("place"), py::arg("name") = "")
.def("__init__", &InitVarBaseFromTensorWithArg<platform::CUDAPinnedPlace>,
py::arg("tensor"), py::arg("place"))
py::arg("tensor"), py::arg("place"), py::arg("name") = "")
.def("__init__", &InitVarBaseFromTensorWithArg<platform::NPUPlace>,
py::arg("tensor"), py::arg("place"))
py::arg("tensor"), py::arg("place"), py::arg("name") = "")
.def("__init__", &InitVarBaseFromNumpyWithKwargs)
.def(
"__setitem_varbase__",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册