From dfff52eab67240ecba499f3cdf0f1f32b12bbeb5 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Fri, 9 Jul 2021 20:01:59 +0800 Subject: [PATCH] refine varbase init function (#34052) * remove check on kwargs * refine code, reuse commom function --- paddle/fluid/pybind/imperative.cc | 67 ++++++++++++++++++------------- 1 file changed, 38 insertions(+), 29 deletions(-) diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index af7f03dc197..619301e3b45 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -146,21 +146,33 @@ static const platform::Place PyObjectToPlace(const py::object &place_obj) { } } -static void InitTensorForVarBase(imperative::VarBase *self, - const py::array &array, - const platform::Place place, - bool persistable = false, - bool zero_copy = false, std::string name = "", - int stop_gradient = -1) { - if (name == "") { - name = - imperative::GetCurrentTracer()->GenerateUniqueName("generated_tensor"); - } - VLOG(5) << "Init Tensor as: / name: " << name - << " / persistable: " << persistable << " / zero_copy: " << zero_copy +// only initialize varbase, but not its tensor. +static void InitVarBaseOnly(imperative::VarBase *self, const std::string &name, + bool persistable = false, int stop_gradient = -1) { + auto name_ = name == "" + ? imperative::GetCurrentTracer()->GenerateUniqueName( + "generated_tensor") + : name; + + VLOG(5) << "Init Tensor as: / name: " << name_ + << " / persistable: " << persistable << " / stop_gradient: " << stop_gradient; - new (self) imperative::VarBase(name); + new (self) imperative::VarBase(name_); + if (stop_gradient != -1) { + self->SetOverridedStopGradient(stop_gradient); + } + self->SetPersistable(persistable); + self->SetType(framework::proto::VarType::LOD_TENSOR); +} + +// initialize varbase and its tensor. +static void InitVarBaseAndTensor( + imperative::VarBase *self, const py::array &array, + const platform::Place &place, const std::string &name, + bool persistable = false, bool zero_copy = false, int stop_gradient = -1) { + InitVarBaseOnly(self, name, persistable, stop_gradient); auto *tensor = self->MutableVar()->GetMutable(); + VLOG(4) << "zero_copy: " << zero_copy; if (platform::is_cpu_place(place)) { SetTensorFromPyArray( tensor, array, BOOST_GET_CONST(platform::CPUPlace, place), zero_copy); @@ -182,26 +194,15 @@ static void InitTensorForVarBase(imperative::VarBase *self, "Place should be one of " "CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/NPUPlace")); } - if (stop_gradient != -1) { - self->SetOverridedStopGradient(stop_gradient); - } - self->SetPersistable(persistable); - self->SetType(framework::proto::VarType::LOD_TENSOR); self->SetDataType(tensor->type()); } static void InitVarBaseFromNumpyWithKwargs(imperative::VarBase *self, const py::kwargs &kwargs) { VLOG(4) << "Init VarBase from kwargs: "; - PADDLE_ENFORCE_EQ( - kwargs.contains("value"), true, - platform::errors::NotFound( - "The kwargs used to create Varbase misses argument: value")); auto persistable = kwargs.contains("persistable") ? kwargs["persistable"].cast() : false; - auto array = kwargs.contains("value") ? kwargs["value"].cast() - : py::array(); auto zero_copy = kwargs.contains("zero_copy") ? kwargs["zero_copy"].cast() : false; auto name = kwargs.contains("name") ? kwargs["name"].cast() : ""; @@ -209,10 +210,18 @@ static void InitVarBaseFromNumpyWithKwargs(imperative::VarBase *self, ? kwargs["stop_gradient"].cast() : -1; auto default_place = imperative::GetCurrentTracer()->ExpectedPlace(); - auto place = kwargs.contains("place") ? PyObjectToPlace(kwargs["place"]) - : default_place; - InitTensorForVarBase(self, array, place, persistable, zero_copy, name, - stop_gradient); + + if (kwargs.contains("value")) { + auto array = kwargs["value"].cast(); + // place is only used when array is given, otherwise, it is meaningless and + // ignored + auto place = kwargs.contains("place") ? PyObjectToPlace(kwargs["place"]) + : default_place; + InitVarBaseAndTensor(self, array, place, name, persistable, zero_copy, + stop_gradient); + } else { + InitVarBaseOnly(self, name, persistable, stop_gradient); + } } template @@ -247,7 +256,7 @@ static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self, const py::array &array) { auto place = imperative::GetCurrentTracer()->ExpectedPlace(); VLOG(4) << "Init VarBase from numpy at " << place; - InitTensorForVarBase(self, array, place); + InitVarBaseAndTensor(self, array, place, ""); } static void InitVarBaseFromTensorWithArgDefault( -- GitLab