未验证 提交 dfff52ea 编写于 作者: L Leo Chen 提交者: GitHub

refine varbase init function (#34052)

* remove check on kwargs

* refine code, reuse commom function
上级 1412d3bc
......@@ -146,21 +146,33 @@ static const platform::Place PyObjectToPlace(const py::object &place_obj) {
}
}
static void InitTensorForVarBase(imperative::VarBase *self,
const py::array &array,
const platform::Place place,
bool persistable = false,
bool zero_copy = false, std::string name = "",
int stop_gradient = -1) {
if (name == "") {
name =
imperative::GetCurrentTracer()->GenerateUniqueName("generated_tensor");
}
VLOG(5) << "Init Tensor as: / name: " << name
<< " / persistable: " << persistable << " / zero_copy: " << zero_copy
// only initialize varbase, but not its tensor.
static void InitVarBaseOnly(imperative::VarBase *self, const std::string &name,
bool persistable = false, int stop_gradient = -1) {
auto name_ = name == ""
? imperative::GetCurrentTracer()->GenerateUniqueName(
"generated_tensor")
: name;
VLOG(5) << "Init Tensor as: / name: " << name_
<< " / persistable: " << persistable
<< " / stop_gradient: " << stop_gradient;
new (self) imperative::VarBase(name);
new (self) imperative::VarBase(name_);
if (stop_gradient != -1) {
self->SetOverridedStopGradient(stop_gradient);
}
self->SetPersistable(persistable);
self->SetType(framework::proto::VarType::LOD_TENSOR);
}
// initialize varbase and its tensor.
static void InitVarBaseAndTensor(
imperative::VarBase *self, const py::array &array,
const platform::Place &place, const std::string &name,
bool persistable = false, bool zero_copy = false, int stop_gradient = -1) {
InitVarBaseOnly(self, name, persistable, stop_gradient);
auto *tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
VLOG(4) << "zero_copy: " << zero_copy;
if (platform::is_cpu_place(place)) {
SetTensorFromPyArray<platform::CPUPlace>(
tensor, array, BOOST_GET_CONST(platform::CPUPlace, place), zero_copy);
......@@ -182,26 +194,15 @@ static void InitTensorForVarBase(imperative::VarBase *self,
"Place should be one of "
"CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/NPUPlace"));
}
if (stop_gradient != -1) {
self->SetOverridedStopGradient(stop_gradient);
}
self->SetPersistable(persistable);
self->SetType(framework::proto::VarType::LOD_TENSOR);
self->SetDataType(tensor->type());
}
static void InitVarBaseFromNumpyWithKwargs(imperative::VarBase *self,
const py::kwargs &kwargs) {
VLOG(4) << "Init VarBase from kwargs: ";
PADDLE_ENFORCE_EQ(
kwargs.contains("value"), true,
platform::errors::NotFound(
"The kwargs used to create Varbase misses argument: value"));
auto persistable = kwargs.contains("persistable")
? kwargs["persistable"].cast<bool>()
: false;
auto array = kwargs.contains("value") ? kwargs["value"].cast<py::array>()
: py::array();
auto zero_copy =
kwargs.contains("zero_copy") ? kwargs["zero_copy"].cast<bool>() : false;
auto name = kwargs.contains("name") ? kwargs["name"].cast<std::string>() : "";
......@@ -209,10 +210,18 @@ static void InitVarBaseFromNumpyWithKwargs(imperative::VarBase *self,
? kwargs["stop_gradient"].cast<int>()
: -1;
auto default_place = imperative::GetCurrentTracer()->ExpectedPlace();
auto place = kwargs.contains("place") ? PyObjectToPlace(kwargs["place"])
: default_place;
InitTensorForVarBase(self, array, place, persistable, zero_copy, name,
stop_gradient);
if (kwargs.contains("value")) {
auto array = kwargs["value"].cast<py::array>();
// place is only used when array is given, otherwise, it is meaningless and
// ignored
auto place = kwargs.contains("place") ? PyObjectToPlace(kwargs["place"])
: default_place;
InitVarBaseAndTensor(self, array, place, name, persistable, zero_copy,
stop_gradient);
} else {
InitVarBaseOnly(self, name, persistable, stop_gradient);
}
}
template <typename P>
......@@ -247,7 +256,7 @@ static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self,
const py::array &array) {
auto place = imperative::GetCurrentTracer()->ExpectedPlace();
VLOG(4) << "Init VarBase from numpy at " << place;
InitTensorForVarBase(self, array, place);
InitVarBaseAndTensor(self, array, place, "");
}
static void InitVarBaseFromTensorWithArgDefault(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册