Unverified · Commit dfff52ea, authored by Leo Chen, committed by GitHub

refine varbase init function (#34052)

* remove the check that kwargs must contain "value", making the initial numpy array optional

* refine code, reuse common function
Parent 1412d3bc
@@ -146,21 +146,33 @@ static const platform::Place PyObjectToPlace(const py::object &place_obj) {
   }
 }
 
-static void InitTensorForVarBase(imperative::VarBase *self,
-                                 const py::array &array,
-                                 const platform::Place place,
-                                 bool persistable = false,
-                                 bool zero_copy = false, std::string name = "",
-                                 int stop_gradient = -1) {
-  if (name == "") {
-    name =
-        imperative::GetCurrentTracer()->GenerateUniqueName("generated_tensor");
-  }
-  VLOG(5) << "Init Tensor as: / name: " << name
-          << " / persistable: " << persistable << " / zero_copy: " << zero_copy
-          << " / stop_gradient: " << stop_gradient;
-  new (self) imperative::VarBase(name);
+// only initialize varbase, but not its tensor.
+static void InitVarBaseOnly(imperative::VarBase *self, const std::string &name,
+                            bool persistable = false, int stop_gradient = -1) {
+  auto name_ = name == ""
+                   ? imperative::GetCurrentTracer()->GenerateUniqueName(
+                         "generated_tensor")
+                   : name;
+
+  VLOG(5) << "Init Tensor as: / name: " << name_
+          << " / persistable: " << persistable
+          << " / stop_gradient: " << stop_gradient;
+  new (self) imperative::VarBase(name_);
+  if (stop_gradient != -1) {
+    self->SetOverridedStopGradient(stop_gradient);
+  }
+  self->SetPersistable(persistable);
+  self->SetType(framework::proto::VarType::LOD_TENSOR);
+}
+
+// initialize varbase and its tensor.
+static void InitVarBaseAndTensor(
+    imperative::VarBase *self, const py::array &array,
+    const platform::Place &place, const std::string &name,
+    bool persistable = false, bool zero_copy = false, int stop_gradient = -1) {
+  InitVarBaseOnly(self, name, persistable, stop_gradient);
   auto *tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
+  VLOG(4) << "zero_copy: " << zero_copy;
   if (platform::is_cpu_place(place)) {
     SetTensorFromPyArray<platform::CPUPlace>(
         tensor, array, BOOST_GET_CONST(platform::CPUPlace, place), zero_copy);
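The `new (self) imperative::VarBase(name_)` line uses pybind11's legacy `__init__` style: pybind11 hands the bound function a reference to storage it has already allocated, and the function constructs the C++ object there with placement new. A minimal, self-contained sketch of the same pattern, using a hypothetical `Var` class rather than Paddle's actual types:

```cpp
#include <pybind11/pybind11.h>

#include <string>

namespace py = pybind11;

// Hypothetical stand-in for imperative::VarBase (illustration only).
struct Var {
  explicit Var(std::string n) : name(std::move(n)) {}
  std::string name;
};

PYBIND11_MODULE(demo, m) {
  py::class_<Var>(m, "Var")
      // pybind11 passes `self` as a reference to uninitialized storage;
      // placement new constructs the object in place, just like
      // `new (self) imperative::VarBase(name_)` in InitVarBaseOnly.
      .def("__init__",
           [](Var &self, const std::string &name) { new (&self) Var(name); })
      .def_readonly("name", &Var::name);
}
```

Newer pybind11 code would typically spell this as `py::init<...>()`; the placement-new form is what the patched helpers use.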
@@ -182,26 +194,15 @@ static void InitTensorForVarBase(imperative::VarBase *self,
         "Place should be one of "
         "CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/NPUPlace"));
   }
-  if (stop_gradient != -1) {
-    self->SetOverridedStopGradient(stop_gradient);
-  }
-  self->SetPersistable(persistable);
-  self->SetType(framework::proto::VarType::LOD_TENSOR);
   self->SetDataType(tensor->type());
 }
 
 static void InitVarBaseFromNumpyWithKwargs(imperative::VarBase *self,
                                            const py::kwargs &kwargs) {
   VLOG(4) << "Init VarBase from kwargs: ";
-  PADDLE_ENFORCE_EQ(
-      kwargs.contains("value"), true,
-      platform::errors::NotFound(
-          "The kwargs used to create Varbase misses argument: value"));
   auto persistable = kwargs.contains("persistable")
                          ? kwargs["persistable"].cast<bool>()
                          : false;
-  auto array = kwargs.contains("value") ? kwargs["value"].cast<py::array>()
-                                        : py::array();
   auto zero_copy =
       kwargs.contains("zero_copy") ? kwargs["zero_copy"].cast<bool>() : false;
   auto name = kwargs.contains("name") ? kwargs["name"].cast<std::string>() : "";
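Every optional argument in `InitVarBaseFromNumpyWithKwargs` is read with the same `kwargs.contains(key) ? kwargs[key].cast<T>() : default` pattern. If the repetition ever grows, it could be factored into a small generic helper; a sketch of such a helper, which is hypothetical and not part of this patch:

```cpp
#include <pybind11/pybind11.h>

#include <string>

namespace py = pybind11;

// Hypothetical helper (not part of the patch): read an optional kwarg,
// falling back to a default when the key is absent.
template <typename T>
T KwargOr(const py::kwargs &kwargs, const char *key, T default_value) {
  return kwargs.contains(key) ? kwargs[key].cast<T>() : default_value;
}

// Usage equivalent to the lines above:
//   auto persistable = KwargOr(kwargs, "persistable", false);
//   auto name = KwargOr<std::string>(kwargs, "name", "");
//   auto stop_gradient = KwargOr(kwargs, "stop_gradient", -1);
```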
@@ -209,10 +210,18 @@ static void InitVarBaseFromNumpyWithKwargs(imperative::VarBase *self,
                            ? kwargs["stop_gradient"].cast<int>()
                            : -1;
   auto default_place = imperative::GetCurrentTracer()->ExpectedPlace();
-  auto place = kwargs.contains("place") ? PyObjectToPlace(kwargs["place"])
-                                        : default_place;
-  InitTensorForVarBase(self, array, place, persistable, zero_copy, name,
-                       stop_gradient);
+
+  if (kwargs.contains("value")) {
+    auto array = kwargs["value"].cast<py::array>();
+    // place is only used when array is given, otherwise, it is meaningless and
+    // ignored
+    auto place = kwargs.contains("place") ? PyObjectToPlace(kwargs["place"])
+                                          : default_place;
+    InitVarBaseAndTensor(self, array, place, name, persistable, zero_copy,
+                         stop_gradient);
+  } else {
+    InitVarBaseOnly(self, name, persistable, stop_gradient);
+  }
 }
 
 template <typename P>
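The net behavioural change is in this hunk: `value` is now optional, so constructing a VarBase from kwargs without an initial numpy array yields a metadata-only variable whose tensor is left unallocated, while the original path (array plus place) is unchanged. A condensed sketch of that dispatch, with hypothetical stand-in types rather than Paddle's:

```cpp
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>

#include <new>
#include <string>

namespace py = pybind11;

// Hypothetical stand-in for imperative::VarBase (illustration only).
struct FakeVar {
  std::string name;
  bool has_tensor = false;
};

// Mirrors the shape of the patched InitVarBaseFromNumpyWithKwargs: the data
// path runs only when "value" is present; otherwise the object stays
// metadata-only.
void InitFromKwargs(FakeVar *self, const py::kwargs &kwargs) {
  new (self) FakeVar();
  self->name =
      kwargs.contains("name") ? kwargs["name"].cast<std::string>() : "";
  if (kwargs.contains("value")) {
    auto array = kwargs["value"].cast<py::array>();
    self->has_tensor = array.size() > 0;  // stand-in for SetTensorFromPyArray
  }
}
```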
@@ -247,7 +256,7 @@ static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self,
                                                const py::array &array) {
   auto place = imperative::GetCurrentTracer()->ExpectedPlace();
   VLOG(4) << "Init VarBase from numpy at " << place;
-  InitTensorForVarBase(self, array, place);
+  InitVarBaseAndTensor(self, array, place, "");
 }
 
 static void InitVarBaseFromTensorWithArgDefault(
......