Unverified commit 3f630658, authored by W wanghuancoder and committed by GitHub

test del python varbase (#55788)

* del python varbase
Parent d7aef892
...@@ -69,7 +69,6 @@ namespace paddle {
namespace pybind {
std::atomic<int> VarBaseUniqueNameID{0};
PyTypeObject *g_varbase_pytype = nullptr;
namespace py = ::pybind11;
...@@ -646,1500 +645,6 @@ void BindImperative(py::module *m_ptr) {
egr::Controller::Instance().SetCurrentTracer(tracer);
imperative::SetCurrentTracer(tracer);
});
py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>> varbase(
m, "VarBase", R"DOC()DOC");
g_varbase_pytype = (PyTypeObject *)varbase.ptr(); // NOLINT
varbase.def_static("_alive_vars", &imperative::VarBase::AliveVarNames)
.def("__init__",
[](imperative::VarBase &self) {
std::string name =
imperative::GetCurrentTracer()->GenerateUniqueName(
"generated_tensor");
new (&self) imperative::VarBase(name);
})
.def("__init__",
[](imperative::VarBase &self,
framework::proto::VarType::Type dtype,
const std::vector<int64_t> &dims,
const py::handle &name,
framework::proto::VarType::Type type,
bool persistable) {
VLOG(4) << "Init VarBase";
std::string act_name = "";
if (!name.ptr() || name.ptr() == Py_None) {
act_name = imperative::GetCurrentTracer()->GenerateUniqueName(
"generated_tensor");
} else {
act_name = name.cast<std::string>();
}
new (&self) imperative::VarBase(act_name);
self.SetPersistable(persistable);
self.SetType(type);
self.SetDataType(dtype);
if (type == framework::proto::VarType::LOD_TENSOR) {
auto *tensor = self.MutableVar()->GetMutable<phi::DenseTensor>();
tensor->Resize(phi::make_ddim(dims));
}
})
.def("__init__",
&InitVarBaseFromNumpyWithArg<platform::CPUPlace>,
py::arg("value"),
py::arg("place"),
py::arg("persistable") = false,
py::arg("zero_copy") = false,
py::arg("name") = "",
py::arg("stop_gradient") = -1)
.def("__init__",
&InitVarBaseFromNumpyWithArg<platform::XPUPlace>,
py::arg("value"),
py::arg("place"),
py::arg("persistable") = false,
py::arg("zero_copy") = false,
py::arg("name") = "",
py::arg("stop_gradient") = -1)
.def("__init__",
&InitVarBaseFromNumpyWithArg<platform::CUDAPlace>,
py::arg("value"),
py::arg("place"),
py::arg("persistable") = false,
py::arg("zero_copy") = false,
py::arg("name") = "",
py::arg("stop_gradient") = -1)
.def("__init__",
&InitVarBaseFromNumpyWithArg<platform::CUDAPinnedPlace>,
py::arg("value"),
py::arg("place"),
py::arg("persistable") = false,
py::arg("zero_copy") = false,
py::arg("name") = "",
py::arg("stop_gradient") = -1)
.def("__init__",
&InitVarBaseFromNumpyWithArg<platform::CustomPlace>,
py::arg("value"),
py::arg("place"),
py::arg("persistable") = false,
py::arg("zero_copy") = false,
py::arg("name") = "",
py::arg("stop_gradient") = -1)
.def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value"))
.def("__init__",
&InitVarBaseFromTensorWithArgDefault,
py::arg("tensor"),
py::arg("name") = "")
.def("__init__",
&InitVarBaseFromTensorWithArg<platform::CPUPlace>,
py::arg("tensor"),
py::arg("place"),
py::arg("name") = "")
.def("__init__",
&InitVarBaseFromTensorWithArg<platform::XPUPlace>,
py::arg("tensor"),
py::arg("place"),
py::arg("name") = "")
.def("__init__",
&InitVarBaseFromTensorWithArg<platform::CUDAPlace>,
py::arg("tensor"),
py::arg("place"),
py::arg("name") = "")
.def("__init__",
&InitVarBaseFromTensorWithArg<platform::CUDAPinnedPlace>,
py::arg("tensor"),
py::arg("place"),
py::arg("name") = "")
.def("__init__",
&InitVarBaseFromTensorWithArg<platform::CustomPlace>,
py::arg("tensor"),
py::arg("place"),
py::arg("name") = "")
.def("__init__", &InitVarBaseFromNumpyWithKwargs)
.def(
"__setitem_varbase__",
[](std::shared_ptr<imperative::VarBase> &self,
py::handle _index,
py::object &value_obj) {
VLOG(4) << "Call __setitem_varbase__";
auto self_tensor =
self->MutableVar()->GetMutable<phi::DenseTensor>();
// NOTE(zhiqiu): PyTuple_Pack increases the refcount of its items, while PyTuple_New does not; see
// https://github.com/python/cpython/blob/24b63c695ae0a95b06379eaadace66735abac1e2/Objects/tupleobject.c#L251
PyObject *index_ptr = !PyTuple_Check(_index.ptr())
? PyTuple_Pack(1, _index.ptr())
: _index.ptr();
DEFINE_PADDLE_SCOPE_GUARD([index_ptr, &_index]() {
if (!PyTuple_Check(_index.ptr())) {
Py_DECREF(index_ptr);
VLOG(4) << "Call Py_DECREF";
}
});
auto is_tensor = [](py::handle var) {
if (!var.ptr() || var.ptr() == Py_None) {
return false;
}
try {
py::cast<std::shared_ptr<imperative::VarBase>>(var);
return true;
} catch (py::cast_error &) {
return false;
}
};
// NOTE(liym27):
// Increase the version of VarBase self because __setitem__ is an
// inplace operator for the VarBase self.
self->BumpInplaceVersion();
// 1. Check arguments
bool parse_index = true;
// Check whether _index can be parsed.
const int size = PyTuple_GET_SIZE(index_ptr);
for (int dim = 0; dim < size; ++dim) {
PyObject *slice_item = PyTuple_GetItem(index_ptr, dim);
if (!(PyCheckInteger(slice_item) || PySlice_Check(slice_item) ||
slice_item == Py_Ellipsis || slice_item == Py_None)) {
parse_index = false;
break;
}
}
// 2. Call op set_value to speed up if the condition is met,
// otherwise call TensorToPyArray.
// TODO(liym27): Try not to call TensorToPyArray because it always
// copies data to cpu place, which reduces performance.
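// Illustrative note (not part of the original source): indices built only
// from ints/slices/None/Ellipsis, e.g. `x[0:2, 1] = v`, presumably take the
// set_value fast path below, while Tensor or list indices such as
// `x[[0, 2]] = v` fall through to the numpy-based branch.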
if (parse_index) {
std::vector<int> axes, starts, ends, steps, decrease_axes,
none_axes, infer_flags;
std::vector<int64_t> list_select_idxs;
// if index is a list, list_select_flag will be true
bool list_select_flag = false;
ParseIndexingSlice(self_tensor,
index_ptr,
&axes,
&starts,
&ends,
&steps,
&decrease_axes,
&none_axes,
&infer_flags,
&list_select_idxs,
&list_select_flag);
framework::AttributeMap attrs = {{"axes", axes},
{"starts", starts},
{"ends", ends},
{"steps", steps},
{"decrease_axes", decrease_axes},
{"none_axes", none_axes}};
imperative::NameVarBaseMap ins = {{"Input", {self}}};
imperative::NameVarBaseMap outs = {{"Out", {self}}};
const auto &tracer = imperative::GetCurrentTracer();
if (tracer->HasGrad()) {
PADDLE_ENFORCE_EQ(
self->IsLeaf() && !self->OverridedStopGradient(),
false,
platform::errors::InvalidArgument(
"Leaf Tensor (%s) that doesn't stop gradient can't use "
"inplace strategy.",
self->Name()));
}
if (py::isinstance<imperative::VarBase>(value_obj.ptr())) {
auto value_tensor =
value_obj.cast<std::shared_ptr<imperative::VarBase>>();
ins.insert({"ValueTensor", {value_tensor}});
// pass the stop_gradient from value to tensor
if (!value_tensor->OverridedStopGradient() &&
self->OverridedStopGradient()) {
self->SetOverridedStopGradient(false);
}
} else if (py::isinstance<py::array>(value_obj)) {
auto value_tensor = std::shared_ptr<imperative::VarBase>(
new imperative::VarBase(false,
tracer->GenerateUniqueName()));
py::object value = value_obj;
if (self->DataType() == framework::proto::VarType::FP32) {
if (!py::isinstance<py::array_t<float>>(value_obj)) {
value = pybind11::detail::CastNumpyArray<float>(value_obj);
}
} else if (self->DataType() ==
framework::proto::VarType::FP64) {
if (!py::isinstance<py::array_t<double>>(value_obj)) {
value = pybind11::detail::CastNumpyArray<double>(value_obj);
}
} else if (self->DataType() ==
framework::proto::VarType::INT32) {
if (!py::isinstance<py::array_t<int32_t>>(value_obj)) {
value =
pybind11::detail::CastNumpyArray<int32_t>(value_obj);
}
} else if (self->DataType() ==
framework::proto::VarType::INT64) {
if (!py::isinstance<py::array_t<int64_t>>(value_obj)) {
value =
pybind11::detail::CastNumpyArray<int64_t>(value_obj);
}
} else if (self->DataType() ==
framework::proto::VarType::BOOL) {
if (!py::isinstance<py::array_t<bool>>(value_obj)) {
value = pybind11::detail::CastNumpyArray<bool>(value_obj);
}
} else if (self->DataType() ==
framework::proto::VarType::COMPLEX64) {
if (!py::isinstance<py::array_t<std::complex<float>>>(
value_obj)) {
value =
pybind11::detail::CastNumpyArray<std::complex<float>>(
value_obj);
}
} else if (self->DataType() ==
framework::proto::VarType::COMPLEX128) {
if (!py::isinstance<py::array_t<std::complex<double>>>(
value_obj)) {
value =
pybind11::detail::CastNumpyArray<std::complex<double>>(
value_obj);
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"When assign a numpy.np value to a paddle.Tensor, "
"the data type of the paddle.Tensor must be bool, "
"float32, float64, complex64, complex128, int32 or "
"int64, "
"please check the type of tensor."));
}
SetTensorFromPyArray(
value_tensor->MutableVar()->GetMutable<phi::DenseTensor>(),
value,
self->Place(),
false);
ins.insert({"ValueTensor", {value_tensor}});
} else {
// convert the value to self data type
if (py::isinstance<py::float_>(value_obj) ||
py::isinstance<py::int_>(value_obj) ||
py::isinstance<py::bool_>(value_obj) ||
PyComplex_Check(value_obj.ptr())) {
if (self->DataType() == framework::proto::VarType::FP32) {
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj.cast<float>()};
} else if (self->DataType() ==
framework::proto::VarType::FP64) {
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj.cast<double>()};
} else if (self->DataType() ==
framework::proto::VarType::INT32) {
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj.cast<int32_t>()};
} else if (self->DataType() ==
framework::proto::VarType::INT64) {
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj.cast<int64_t>()};
} else if (self->DataType() ==
framework::proto::VarType::BOOL) {
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj.cast<bool>()};
} else if (self->DataType() ==
framework::proto::VarType::FP16) {
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj.cast<float>()};
} else if (self->DataType() ==
framework::proto::VarType::COMPLEX64) {
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj.cast<std::complex<float>>()};
} else if (self->DataType() ==
framework::proto::VarType::COMPLEX128) {
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj.cast<std::complex<double>>()};
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"When assign a value to a paddle.Tensor, "
"the data type of the paddle.Tensor must be bool, "
"float32, float64, complex64, complex128, int32, int64 "
"or float16, "
"please check the type of tensor."));
}
attrs["shape"] = std::vector<int64_t>{1};
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Value type error. The assign value allows "
"numpy.ndarray, integer, float or bool, "
"but received %s.",
Py_TYPE(value_obj.ptr())));
}
}
{
// Release gil and do tracing
py::gil_scoped_release release;
tracer->TraceOp("set_value",
ins,
outs,
std::move(attrs),
{{"Input", "Out"}});
}
} else {
auto self_numpy = TensorToPyArray(*self_tensor);
VLOG(4) << "parse_index is false";
if (is_tensor(_index)) {
VLOG(4) << "index is tensor";
auto index_var =
py::cast<std::shared_ptr<imperative::VarBase>>(_index);
auto index_tensor =
index_var->MutableVar()->GetMutable<phi::DenseTensor>();
auto index_numpy = TensorToPyArray(*index_tensor);
self_numpy[index_numpy] = value_obj;
} else {
VLOG(4) << "index is not tensor";
self_numpy[_index] = value_obj;
}
SetTensorFromPyArray(
self_tensor, self_numpy, self_tensor->place(), false);
}
})
.def("_getitem_index_not_tensor",
[](std::shared_ptr<imperative::VarBase> &self, py::handle _index) {
VLOG(4) << "Call _getitem_index_not_tensor";
std::vector<int> slice_axes, slice_starts, slice_ends,
slice_strides, decrease_axis, none_axes, infer_flags;
std::vector<int64_t> list_select_idxs;
// if index is a list, list_select_flag will be true
bool list_select_flag = false;
auto tensor = self->MutableVar()->GetMutable<phi::DenseTensor>();
ParseIndexingSlice(tensor,
_index.ptr(),
&slice_axes,
&slice_starts,
&slice_ends,
&slice_strides,
&decrease_axis,
&none_axes,
&infer_flags,
&list_select_idxs,
&list_select_flag);
// release gil and do tracing
py::gil_scoped_release release;
const auto &tracer = imperative::GetCurrentTracer();
auto out = slice_axes.empty() && !list_select_flag
? self
: std::shared_ptr<imperative::VarBase>(
new imperative::VarBase(
tracer->GenerateUniqueName()));
if (!slice_axes.empty()) {
imperative::NameVarBaseMap ins = {{"Input", {self}}};
framework::AttributeMap attrs = {
{"axes", slice_axes},
{"starts", slice_starts},
{"ends", slice_ends},
{"infer_flags", infer_flags},
{"decrease_axis", decrease_axis}};
imperative::NameVarBaseMap outs = {{"Out", {out}}};
std::string op_type = "slice";
for (auto stride : slice_strides) {
if (stride != 1) {
op_type = "strided_slice";
attrs.insert({"strides", slice_strides});
attrs.erase("decrease_axis");
break;
}
}
tracer->TraceOp(op_type, ins, outs, std::move(attrs));
}
bool set_to_1d = FLAGS_set_to_1d;
if (set_to_1d) {
// NOTE(zoooo0820): When all axes are decreased, the output
// will be 1-D with FLAGS_set_to_1d=True. In this case, one
// `None` should be popped out, otherwise the output shape will be
// not correct.
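// Illustrative case (assumed): for x with shape (2, 3), `x[0, 1, None]`
// decreases both real axes; with FLAGS_set_to_1d the sliced result is
// already 1-D, so keeping the trailing None would presumably yield shape
// (1, 1) instead of the intended (1,).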
if (static_cast<int>(decrease_axis.size()) ==
tensor->dims().size()) {
VLOG(1) << "Warning: In Tensor '__getitem__', if the number "
"of scalar "
"elements "
"in the index is equal to the rank of the Tensor, "
"the output "
"should "
"be 0-D. In order to be consistent with the "
"behavior of previous "
"versions, it will be processed to 1-D. But it is "
"not correct and "
"will be "
"removed in release 2.6. "
"If 1-D is still wanted, please modify the index "
"element from "
"scalar to slice "
"(e.g. 'x[i]' => 'x[i:i+1]'). ";
if (!none_axes.empty()) {
none_axes.pop_back();
}
}
}
if (!none_axes.empty()) {
// Deal with cases that decrease_axes is not empty
// For example:
// # x.shape: (2,3,4)
// out = x[0, 0:2, None] # out.shape : (2, 1, 4)
for (auto &axis : none_axes) {
int len = 0;
for (int da : decrease_axis) {
if (da < axis) {
len++;
}
}
axis -= len;
}
imperative::NameVarBaseMap ins = {{"X", {out}}};
framework::AttributeMap attrs = {{"axes", none_axes}};
auto new_out = std::shared_ptr<imperative::VarBase>(
new imperative::VarBase(tracer->GenerateUniqueName()));
auto out_xshape = std::shared_ptr<imperative::VarBase>(
new imperative::VarBase(tracer->GenerateUniqueName()));
imperative::NameVarBaseMap outs = {{"Out", {new_out}},
{"XShape", {out_xshape}}};
tracer->TraceOp("unsqueeze2", ins, outs, std::move(attrs));
return new_out;
}
// the index is a list
if (list_select_flag) {
auto select_index = std::shared_ptr<imperative::VarBase>(
new imperative::VarBase(tracer->GenerateUniqueName()));
auto *idx_tensor =
select_index->MutableVar()->GetMutable<phi::DenseTensor>();
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(
tracer->ExpectedPlace());
paddle::framework::TensorFromVector(
list_select_idxs, *dev_ctx, idx_tensor);
imperative::NameVarBaseMap ins = {{"X", {self}},
{"Index", {select_index}}};
imperative::NameVarBaseMap outs = {{"Out", {out}}};
tracer->TraceOp("index_select", ins, outs, {{"dim", 0}});
}
return out;
})
.def(
"_getitem_from_offset",
[](std::shared_ptr<imperative::VarBase> &self, const py::args &args) {
const auto &tensor = self->Var().Get<phi::DenseTensor>();
PADDLE_ENFORCE_EQ(
tensor.IsInitialized(),
true,
platform::errors::InvalidArgument(
"Tensor of %s is Empty, please check if it has no data.",
self->Name()));
const auto &tensor_dims = tensor.dims();
std::vector<size_t> dims(tensor_dims.size());
std::vector<size_t> strides(tensor_dims.size());
size_t numel = 1;
for (int i = tensor_dims.size() - 1; i >= 0; --i) {
strides[i] = numel;
dims[i] = static_cast<size_t>(tensor_dims[i]);
numel *= dims[i];
}
size_t offset = 0;
if (args.empty()) {
PADDLE_ENFORCE_EQ(
numel,
1,
platform::errors::InvalidArgument(
"only one element tensors can be converted to Python "
"scalars when no input coordinates"));
} else if (args.size() == 1) {
offset = args[0].cast<size_t>();
PADDLE_ENFORCE_LT(
offset,
numel,
platform::errors::InvalidArgument(
"index %d is out of bounds for size %d", offset, numel));
} else {
PADDLE_ENFORCE_EQ(args.size(),
dims.size(),
platform::errors::InvalidArgument(
"incorrect number of indices for Tensor"));
for (size_t i = 0; i < args.size(); ++i) {
size_t index = args[i].cast<size_t>();
PADDLE_ENFORCE_LT(
index,
dims[i],
platform::errors::InvalidArgument(
"index %d is out fo bounds for axis %d with size %d",
index,
i,
dims[i]));
offset += index * strides[i];
}
}
#define TENSOR_TO_PY_SCALAR(T, proto_type) \
if (framework::TransToProtoVarType(tensor.dtype()) == proto_type) { \
std::string py_dtype_str = details::TensorDTypeToPyDTypeStr(proto_type); \
T b = TensorGetElement<T>(tensor, offset); \
return py::array( \
py::dtype(py_dtype_str.c_str()), {}, {}, static_cast<void *>(&b)); \
}
_ForEachDataType_(TENSOR_TO_PY_SCALAR);
#undef TENSOR_TO_PY_SCALAR
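// Expansion sketch (assumed, for illustration only): for FP32 data the macro
// above expands to roughly
//   if (framework::TransToProtoVarType(tensor.dtype()) ==
//       framework::proto::VarType::FP32) {
//     float b = TensorGetElement<float>(tensor, offset);
//     return py::array(py::dtype("float32"), {}, {}, static_cast<void *>(&b));
//   }
// and _ForEachDataType_ applies it to every supported dtype in turn.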
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported tensor data type: %s", tensor.dtype()));
},
py::return_value_policy::copy)
.def("_inplace_version",
[](imperative::VarBase &self) -> uint32_t {
const auto &var = self.MutableVar();
PADDLE_ENFORCE_EQ(
var->IsInitialized(),
true,
platform::errors::InvalidArgument(
"Tensor of %s is Empty, please check if it has no data.",
self.Name()));
return var->CurrentInplaceVersion();
})
.def(
"_bump_inplace_version",
[](std::shared_ptr<imperative::VarBase> &self) {
// NOTE(liym27): _bump_inplace_version is only used for inplace
// operation
self->BumpInplaceVersion();
},
R"DOC(
**Notes**:
**This API is ONLY available in Dygraph mode.**
**This is a very low level API. Users should not use it directly.**
Bump the version whenever the Tensor is modified through an inplace operation.
)DOC")
.def(
"numpy",
[](imperative::VarBase &self) -> py::array {
const auto &tensor = self.MutableVar()->Get<phi::DenseTensor>();
PADDLE_ENFORCE_EQ(
tensor.IsInitialized(),
true,
platform::errors::InvalidArgument(
"Tensor of %s is Empty, please check if it has no data.",
self.Name()));
return TensorToPyArray(tensor, true);
},
R"DOC(
Returns a numpy array that shows the value of the current Tensor.
Returns:
ndarray: The numpy value of the current Tensor.
Return type:
ndarray: dtype is the same as the current Tensor.
Examples:
.. code-block:: python
import paddle
import numpy as np
data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
linear = paddle.nn.Linear(32, 64)
data = paddle.to_tensor(data)
x = linear(data)
print(x.numpy())
)DOC")
.def(
"detach",
[](const imperative::VarBase &self)
-> std::shared_ptr<imperative::VarBase> {
PADDLE_ENFORCE_EQ(
self.Var().IsInitialized(),
true,
platform::errors::InvalidArgument(
"Tensor %s has not been initialized!", self.Name()));
PADDLE_ENFORCE_EQ(
self.Var().IsType<phi::DenseTensor>() ||
self.Var().IsType<phi::SelectedRows>(),
true,
platform::errors::InvalidArgument(
"Type of Tensor[%s] must be LoDTensor or SelectedRows!",
self.Name()));
auto detach_var = std::make_shared<imperative::VarBase>(
true, "detach_" + self.Name());
detach_var->SetPersistable(self.Persistable());
detach_var->SetType(self.Type());
detach_var->SetDataType(self.DataType());
if (self.Var().IsType<phi::DenseTensor>()) {
const auto &origin_tensor = self.Var().Get<phi::DenseTensor>();
PADDLE_ENFORCE_EQ(
origin_tensor.IsInitialized(),
true,
platform::errors::InvalidArgument(
"Tensor %s has not been initialized!", self.Name()));
auto *detach_tensor =
detach_var->MutableVar()->GetMutable<phi::DenseTensor>();
detach_tensor->ShareDataWith(origin_tensor);
// NOTE(liym27): Call ShareInplaceVersionCounterWith to share the
// same TensorInplaceVersion, which is used to check whether
// inplace operations are correct.
detach_tensor->ShareInplaceVersionCounterWith(origin_tensor);
} else {
const auto &origin_selected_rows =
self.Var().Get<phi::SelectedRows>();
PADDLE_ENFORCE_EQ(
origin_selected_rows.value().IsInitialized(),
true,
platform::errors::InvalidArgument(
"Tensor %s has not been initialized!", self.Name()));
auto *detach_selected_rows =
detach_var->MutableVar()->GetMutable<phi::SelectedRows>();
detach_selected_rows->set_height(origin_selected_rows.height());
detach_selected_rows->set_rows(origin_selected_rows.rows());
detach_selected_rows->mutable_value()->ShareDataWith(
origin_selected_rows.value());
detach_selected_rows->mutable_value()
->ShareInplaceVersionCounterWith(
origin_selected_rows.value());
}
VLOG(3) << "The detached Tensor(" << detach_var->Name()
<< ") share data with " << self.Name();
return detach_var;
},
py::return_value_policy::take_ownership,
R"DOC(
Returns a new Tensor, detached from the current graph.
It shares data with the original Tensor and never makes a Tensor copy.
In addition, the detached Tensor doesn't provide gradient propagation.
Returns: The detached Tensor.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([1.0], stop_gradient=False)
detach_x = x.detach()
detach_x[:] = 10.0
print(x) # Tensor(shape=[1], dtype=float32, place=CPUPlace, stop_gradient=False,
# [10.])
y = x**2
y.backward()
print(x.grad) # [20.0]
print(detach_x.grad) # None, 'stop_gradient=True' by default
detach_x.stop_gradient = False # Set stop_gradient to be False, supported auto-grad
z = detach_x**3
z.backward()
print(x.grad) # [20.0], detach_x is detached from x's graph, not affect each other
print(detach_x.grad) # [300.0], detach_x has its own graph
# Due to sharing of data with origin Tensor, There are some unsafe operations:
y = 2 * x
detach_x[:] = 5.0
y.backward()
# It will raise Error:
# one of the variables needed for gradient computation has been modified by an inplace operation.
)DOC")
.def("clear_gradient",
&imperative::VarBase::ClearGradient,
py::arg("set_to_zero") = true,
R"DOC(
Only for a Tensor that has a gradient; normally we use this for Parameters, since other temporary Tensors don't have gradients.
The gradient of the current Tensor will be set to ``0``.
Returns: None
Examples:
.. code-block:: python
import paddle
input = paddle.uniform([10, 2])
linear = paddle.nn.Linear(2, 3)
out = linear(input)
out.backward()
print("Before clear_gradient, linear.weight.grad: {}".format(linear.weight.grad))
linear.weight.clear_gradient()
print("After clear_gradient, linear.weight.grad: {}".format(linear.weight.grad))
)DOC")
.def("_gradient_set_empty",
&imperative::VarBase::_GradientSetEmpty,
py::arg("set_is_empty") = true)
.def("_is_gradient_set_empty", &imperative::VarBase::_IsGradientSetEmpty)
.def(
"clone",
[](std::shared_ptr<imperative::VarBase> &self) {
const auto &tensor = self->Var().Get<phi::DenseTensor>();
PADDLE_ENFORCE_EQ(tensor.IsInitialized(),
true,
platform::errors::InvalidArgument(
"%s has not been initialized", self->Name()));
auto tracer = imperative::GetCurrentTracer();
auto new_var = std::make_shared<imperative::VarBase>(
true, tracer->GenerateUniqueName(self->Name() + "_clone"));
framework::AttributeMap attrs;
imperative::NameVarBaseMap ins = {{"X", {self}}};
imperative::NameVarBaseMap outs = {{"Out", {new_var}}};
tracer->TraceOp("assign", ins, outs, attrs);
return new_var;
},
py::return_value_policy::copy,
R"DOC(
Returns a new Tensor, which is a clone of the original Tensor, and it remains in the current graph.
It always makes a Tensor copy.
In addition, the cloned Tensor provides gradient propagation.
Returns: The cloned Tensor.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor(1.0, stop_gradient=False)
clone_x = x.clone()
y = clone_x**2
y.backward()
print(clone_x.stop_gradient) # False
print(clone_x.grad) # [2.0], support gradient propagation
print(x.stop_gradient) # False
print(x.grad) # [2.0], clone_x support gradient propagation for x
x = paddle.to_tensor(1.0)
clone_x = x.clone()
clone_x.stop_gradient = False
z = clone_x**3
z.backward()
print(clone_x.stop_gradient) # False
print(clone_x.grad) # [3.0], support gradient propagation
print(x.stop_gradient) # True
print(x.grad) # None
)DOC")
.def("_grad_name", &imperative::VarBase::GradVarName)
.def(
"_grad_value",
[](imperative::VarBase &self) {
return self.MutableGradVar()->Get<phi::DenseTensor>();
},
py::return_value_policy::reference)
.def("_set_grad_type",
[](imperative::VarBase &self, framework::proto::VarType::Type type) {
self.MutableGradVarBase()->SetType(type);
})
.def("_reset_grad_inplace_version",
[](imperative::VarBase &self, bool set_to_zero) {
/*
*** This interface is a complete hack ***
reset_grad_inplace_version removes all inplace-related records from the
Grad VarBase/VariableWrapper; its essential purpose is to let you use
inplace operations as if you were using their non-inplace versions,
which of course will cause unexpected consequences if not used with
care.
Make sure you fully understand what you're doing before making use of
this interface, and prepare for the worst.
*/
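// Illustrative only (assumed usage pattern): a caller that has deliberately
// mutated the gradient buffer in place and wants the inplace bookkeeping
// dropped afterwards might do
//
//     param._reset_grad_inplace_version(True)   # set_to_zero=True
//
// accepting the risks described above.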
py::gil_scoped_release release;
if (self.HasGradVar()) {
auto grad_var = self.GradVarBase();
auto var_wrapper = grad_var->SharedVar();
if (var_wrapper) {
var_wrapper->ResetInplaceVersion(set_to_zero);
}
}
})
.def(
"_grad_ivar",
[](const imperative::VarBase &self) {
auto &grad_var = self.GradVarBase();
if (grad_var && grad_var->Var().IsInitialized()) {
auto *tensor =
grad_var->MutableVar()->IsType<phi::DenseTensor>()
? grad_var->MutableVar()->GetMutable<phi::DenseTensor>()
: grad_var->MutableVar()
->GetMutable<phi::SelectedRows>()
->mutable_value();
if (tensor->IsInitialized()) {
return grad_var;
}
}
return std::shared_ptr<imperative::VarBase>(nullptr);
},
py::return_value_policy::copy)
.def("_set_grad_ivar",
[](imperative::VarBase &self, imperative::VarBase &grad) {
self.SetGradVarBase(grad);
})
.def("_is_sparse",
[](imperative::VarBase &self) {
return self.Var().IsType<phi::SelectedRows>();
})
.def(
"_allreduce",
[](imperative::VarBase &self,
const imperative::ParallelStrategy &strategy) {
if (strategy.nranks_ > 1) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#if NCCL_VERSION_CODE >= 2212
imperative::AllReduce(self.Var(), self.MutableVar(), strategy);
#else
if (!self.Var().IsType<phi::SelectedRows>()) {
imperative::AllReduce(self.Var(), self.MutableVar(), strategy);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Imperative SelectedRows allreduce is not supported when "
"paddle is compiled with NCCL version lower than v2.2.12. "
"You can set is_sparse=False for the Layer containing "
"this argument, such as Embedding(is_sparse=False)."));
}
#endif // NCCL_VERSION_CODE
#else
PADDLE_THROW(platform::errors::Unimplemented(
"Imperative allreduce is not supported when paddle is "
"not compiled with NCCL."));
#endif // PADDLE_WITH_NCCL or PADDLE_WITH_RCCL
}
},
py::call_guard<py::gil_scoped_release>())
.def("_register_grad_hook",
[](imperative::VarBase &self, const py::handle &hook) {
PADDLE_ENFORCE_EQ(
!self.OverridedStopGradient() && self.HasGradVar(),
true,
platform::errors::InvalidArgument(
"Cannot register gradient hook on a Tensor that stop "
"gradient or without gradient."));
return self.GradVarBase()->AddVariableWrapperHook(
std::make_shared<PyVariableWrapperHook>(hook.ptr()));
})
.def("_remove_grad_hook",
[](imperative::VarBase &self, int64_t hook_id) {
PADDLE_ENFORCE_EQ(
!self.OverridedStopGradient() && self.HasGradVar(),
true,
platform::errors::InvalidArgument(
"Cannot remove gradient hook on a Tensor that stop "
"gradient or without gradient."));
return self.GradVarBase()->RemoveVariableWrapperHook(hook_id);
})
.def("_register_void_function_post_hook",
[](imperative::VarBase &self, const py::handle &hook) {
PADDLE_ENFORCE_EQ(
!self.OverridedStopGradient() && self.HasGradVar(),
true,
platform::errors::InvalidArgument(
"Cannot register void function post hook on a Tensor that "
"stop "
"gradient or without gradient."));
auto py_func = PyObjectCast<std::function<void()>>(hook.ptr());
auto grad_node = self.MutableGradVarBase()->GradNode();
for (auto &cur_op : *grad_node) {
cur_op.AddVoidFunctionPostHook(
std::make_shared<std::function<void()>>(py_func));
}
})
.def(
"_register_backward_hook",
[](imperative::VarBase &self, const py::handle &hook) {
PADDLE_ENFORCE_EQ(
self.IsLeaf(),
true,
platform::errors::InvalidArgument(
"Only can register backward hook for leaf Tensor."));
PADDLE_ENFORCE_EQ(
!self.OverridedStopGradient() && self.HasGradVar(),
true,
platform::errors::InvalidArgument(
"Cannot register backward hook on a Tensor that stop "
"gradient or without gradient."));
auto py_func = PyObjectCast<std::function<void()>>(hook.ptr());
self.GradVarBase()->AddVoidHook(
std::make_shared<std::function<void()>>(py_func));
},
R"DOC(
Registers a backward hook for current Tensor.
This hook will be called every time the gradient of current Tensor has been fully calculated.
There are two differences with `_register_grad_hook`:
1. This backward hook will be executed after gradient accumulation has completed across batches,
but the hook registered by `_register_grad_hook` will be executed as soon as gradient accumulation
has completed in the current batch.
2. This backward hook function should have the following signature:
hook() -> None
It requires no input and no return value.
Args:
hook(function): A backward hook to be registered for Tensor.gradient
Returns:
None
)DOC")
.def(
"cpu",
[](const std::shared_ptr<imperative::VarBase> &self) {
if (platform::is_cpu_place(self->Place())) {
return self;
} else {
auto new_var = self->NewVarBase(platform::CPUPlace(), true);
new_var->SetOverridedStopGradient(self->OverridedStopGradient());
return new_var;
}
},
R"DOC(
Returns a copy of this Tensor in CPU memory.
If this Tensor is already in CPU memory, then no copy is performed and the original Tensor is returned.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor(1.0, place=paddle.CUDAPlace(0))
print(x.place) # CUDAPlace(0)
y = x.cpu()
print(y.place) # CPUPlace
)DOC")
.def(
"pin_memory",
[](const std::shared_ptr<imperative::VarBase> &self) {
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
PADDLE_THROW(platform::errors::PermissionDenied(
"Cannot copy this Tensor to pinned memory in CPU version "
"Paddle, "
"Please recompile or reinstall Paddle with CUDA support."));
#endif
if (platform::is_cuda_pinned_place(self->Place())) {
return self;
} else {
auto new_var =
self->NewVarBase(platform::CUDAPinnedPlace(), true);
new_var->SetOverridedStopGradient(self->OverridedStopGradient());
return new_var;
}
},
R"DOC(
Returns a copy of this Tensor in pinned memory.
If this Tensor is already in pinned memory, then no copy is performed and the original Tensor is returned.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor(1.0, place=paddle.CUDAPlace(0))
print(x.place) # CUDAPlace(0)
y = x.pin_memory()
print(y.place) # CUDAPinnedPlace
)DOC")
.def(
"cuda",
[](const std::shared_ptr<imperative::VarBase> &self,
py::handle &handle,
bool blocking) {
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
PADDLE_THROW(platform::errors::PermissionDenied(
"Cannot copy this Tensor to GPU in CPU version Paddle, "
"Please recompile or reinstall Paddle with CUDA support."));
#else
int device_count = platform::GetGPUDeviceCount();
int device_id = 0;
if (handle == py::none()) {
auto default_place =
imperative::GetCurrentTracer()->ExpectedPlace();
device_id = default_place.GetDeviceId();
} else {
PyObject *py_obj = handle.ptr();
PADDLE_ENFORCE_EQ(
PyCheckInteger(py_obj), true,
platform::errors::InvalidArgument(
" 'device_id' must be a positive integer"));
device_id = py::cast<int>(handle);
}
PADDLE_ENFORCE_GE(
device_id, 0,
platform::errors::InvalidArgument(
"Can not copy Tensor to Invalid CUDAPlace(%d), device id "
"must inside [0, %d)",
device_id, device_count));
PADDLE_ENFORCE_LT(
device_id, device_count,
platform::errors::InvalidArgument(
"Can not copy Tensor to Invalid CUDAPlace(%d), device id "
"must inside [0, %d)",
device_id, device_count));
platform::CUDAPlace place = platform::CUDAPlace(device_id);
if (platform::is_same_place(self->Place(), place)) {
return self;
} else {
auto new_var = self->NewVarBase(place, blocking);
new_var->SetOverridedStopGradient(self->OverridedStopGradient());
return new_var;
}
#endif
},
py::arg("device_id") = py::none(),
py::arg("blocking") = true,
R"DOC(
Returns a copy of this Tensor in GPU memory.
If this Tensor is already in GPU memory and device_id is default,
then no copy is performed and the original Tensor is returned.
Args:
device_id(int, optional): The destination GPU device id. Default: None, means current device.
blocking(bool, optional): If False and the source is in pinned memory, the copy will be
asynchronous with respect to the host. Otherwise, the argument has no effect. Default: True.
Examples:
.. code-block:: python
# required: gpu
import paddle
x = paddle.to_tensor(1.0, place=paddle.CPUPlace())
print(x.place) # Place(cpu)
y = x.cuda()
print(y.place) # Place(gpu:0)
y = x.cuda(None)
print(y.place) # Place(gpu:0)
paddle.device.set_device("gpu:1")
y = x.cuda(None)
print(y.place) # Place(gpu:1)
)DOC")
.def(
"_share_memory",
[](const std::shared_ptr<imperative::VarBase> &self) {
#ifndef _WIN32
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(self->Place()),
true,
platform::errors::InvalidArgument(
"Sharing memory only support CPU Tensor currently"));
// 1. get LoDTensor
auto *t = self->MutableVar()->GetMutable<phi::DenseTensor>();
// 2. allocate shared memory
void *data_ptr = t->data();
size_t data_size =
t->numel() * framework::SizeOfType(
framework::TransToProtoVarType(t->dtype()));
auto shared_writer_holder =
memory::allocation::AllocateMemoryMapWriterAllocation(
data_size);
// 3. maintain mmap fd set & backup ipc_name
const std::string &ipc_name = shared_writer_holder->ipc_name();
memory::allocation::MemoryMapFdSet::Instance().Insert(ipc_name);
// 4. copy data & reset holder
memory::Copy(platform::CPUPlace(),
shared_writer_holder->ptr(),
platform::CPUPlace(),
data_ptr,
data_size);
t->ResetHolder(shared_writer_holder);
return *t;
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Sharing memory in Windows OS is not supported currently"));
#endif
},
py::return_value_policy::reference)
#if defined(PADDLE_WITH_CUDA)
.def(
"_uva",
[](const std::shared_ptr<imperative::VarBase> &self, int device_id) {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(self->Place()),
true,
platform::errors::InvalidArgument(
"Unified virtual addressing only support "
"CPU Tensor currently."));
auto *self_tensor =
self->MutableVar()->GetMutable<phi::DenseTensor>();
tensor_uva(self_tensor, device_id);
},
py::arg("device_id") = 0,
py::return_value_policy::reference,
R"DOC(
Returns self tensor with the UVA(unified virtual addressing).
Args:
device_id(int, optional): The destination GPU device id. Default: 0.
Examples:
.. code-block:: python
# required: gpu
import paddle
x = paddle.to_tensor([1, 2, 3], place=paddle.CPUPlace())
x._uva()
print(x)
)DOC")
#endif
.def("copy_", &imperative::VarBase::CopyFrom)
.def(
"_copy_to",
[](const std::shared_ptr<imperative::VarBase> &self,
const platform::CPUPlace &place,
bool blocking) {
auto new_var = self->NewVarBase(place, blocking);
// Note(zhiqiu): Since NewVarBase may use GpuCopyAsync to
// copy data from the tensor of self to the tensor of new varbase,
// we need to ensure that the varbase self is not destructed until
// the GpuCopyAsync is completed. Otherwise, the memory may be freed
// when varbase self is destructed.
// To do that, we increase the reference count of self by 1 and
// add a cuda event to wait for the GpuCopyAsync's completion.
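// Illustrative scenario (assumed): `y = x._copy_to(paddle.CUDAPlace(0), False)`
// may return before the device copy finishes; without the extra reference
// taken below, `x` being garbage-collected on the Python side could free the
// source buffer while the asynchronous copy is still reading from it.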
if (!blocking) {
IncreaseVarbaseReferenceCountUntilCopyComplete(self, place);
}
return new_var;
},
py::return_value_policy::copy)
.def(
"_copy_to",
[](const std::shared_ptr<imperative::VarBase> &self,
const platform::CUDAPinnedPlace &place,
bool blocking) {
auto new_var = self->NewVarBase(place, blocking);
if (!blocking) {
IncreaseVarbaseReferenceCountUntilCopyComplete(self, place);
}
return new_var;
},
py::return_value_policy::copy)
.def(
"_copy_to",
[](const std::shared_ptr<imperative::VarBase> &self,
const platform::XPUPlace &place,
bool blocking) {
auto new_var = self->NewVarBase(place, blocking);
if (!blocking) {
IncreaseVarbaseReferenceCountUntilCopyComplete(self, place);
}
return new_var;
},
py::return_value_policy::copy)
.def(
"_copy_to",
[](const std::shared_ptr<imperative::VarBase> &self,
const platform::CUDAPlace &place,
bool blocking) {
auto new_var = self->NewVarBase(place, blocking);
if (!blocking) {
IncreaseVarbaseReferenceCountUntilCopyComplete(self, place);
}
return new_var;
},
py::return_value_policy::copy)
.def(
"_copy_to",
[](const std::shared_ptr<imperative::VarBase> &self,
const platform::IPUPlace &place,
bool blocking) {
auto new_var = self->NewVarBase(place, blocking);
if (!blocking) {
IncreaseVarbaseReferenceCountUntilCopyComplete(self, place);
}
return new_var;
},
py::return_value_policy::copy)
.def(
"_copy_to",
[](const std::shared_ptr<imperative::VarBase> &self,
const platform::CustomPlace &place,
bool blocking) {
auto new_var = self->NewVarBase(place, blocking);
if (!blocking) {
IncreaseVarbaseReferenceCountUntilCopyComplete(self, place);
}
return new_var;
},
py::return_value_policy::copy)
.def(
"_copy_to",
[](const std::shared_ptr<imperative::VarBase> &self,
const platform::Place &place,
bool blocking) {
auto new_var = self->NewVarBase(place, blocking);
if (!blocking) {
IncreaseVarbaseReferenceCountUntilCopyComplete(self, place);
}
return new_var;
},
py::return_value_policy::copy)
.def(
"value",
[](imperative::VarBase &self) { return self.MutableVar(); },
py::return_value_policy::reference)
.def("_clear",
[](const std::shared_ptr<imperative::VarBase> &self) {
auto *t = self->MutableVar()->GetMutable<phi::DenseTensor>();
PADDLE_ENFORCE_EQ(
t->IsInitialized(),
true,
platform::errors::InvalidArgument(
"Tensor %s has not been initialized!", self->Name()));
t->clear();
})
.def("_offset",
[](const std::shared_ptr<imperative::VarBase> &self) {
auto *t = self->MutableVar()->GetMutable<phi::DenseTensor>();
PADDLE_ENFORCE_EQ(
t->IsInitialized(),
true,
platform::errors::InvalidArgument(
"Tensor %s has not been initialized!", self->Name()));
return t->offset();
})
.def("_share_buffer_to",
[](const std::shared_ptr<imperative::VarBase> &self,
std::shared_ptr<imperative::VarBase> &dst) {
auto *src = self->MutableVar()->GetMutable<phi::DenseTensor>();
auto *dst_ = dst->MutableVar()->GetMutable<phi::DenseTensor>();
PADDLE_ENFORCE_EQ(
src->IsInitialized(),
true,
platform::errors::InvalidArgument(
"Tensor %s has not been initialized!", self->Name()));
dst_->ShareBufferWith(*src);
dst_->ShareDataTypeWith(*src);
})
.def("_is_shared_buffer_with",
[](const std::shared_ptr<imperative::VarBase> &self,
std::shared_ptr<imperative::VarBase> &dst) {
auto *src = self->MutableVar()->GetMutable<phi::DenseTensor>();
auto *dst_ = dst->MutableVar()->GetMutable<phi::DenseTensor>();
if (!src->IsInitialized() || !dst_->IsInitialized()) {
return false;
}
return dst_->IsSharedBufferWith(*src);
})
.def("_share_underline_tensor_to",
[](const std::shared_ptr<imperative::VarBase> &self,
std::shared_ptr<imperative::VarBase> &dst) {
auto *src = self->MutableVar()->GetMutable<phi::DenseTensor>();
auto *dst_ = dst->MutableVar()->GetMutable<phi::DenseTensor>();
PADDLE_ENFORCE_EQ(
src->IsInitialized(),
true,
platform::errors::InvalidArgument(
"Tensor %s has not been initialized!", self->Name()));
dst_->ShareBufferWith(*src);
dst_->ShareDataTypeWith(*src);
dst_->Resize(src->dims());
})
.def("_is_shared_underline_tensor_with",
[](const std::shared_ptr<imperative::VarBase> &self,
std::shared_ptr<imperative::VarBase> &dst) {
auto *src = self->MutableVar()->GetMutable<phi::DenseTensor>();
auto *dst_ = dst->MutableVar()->GetMutable<phi::DenseTensor>();
if (!src->IsInitialized() || !dst_->IsInitialized()) {
return false;
}
return dst_->IsSharedBufferWith(*src);
})
.def("_slice",
[](const std::shared_ptr<imperative::VarBase> &self,
int64_t begin_idx,
int64_t end_idx) {
auto *t = self->MutableVar()->GetMutable<phi::DenseTensor>();
PADDLE_ENFORCE_EQ(
t->IsInitialized(),
true,
platform::errors::InvalidArgument(
"Tensor %s has not been initialized!", self->Name()));
return t->Slice(begin_idx, end_idx);
})
.def("_copy_gradient_from",
[](std::shared_ptr<imperative::VarBase> &self,
const imperative::VarBase &src) { self->_CopyGradientFrom(src); })
.def("_numel",
[](std::shared_ptr<imperative::VarBase> &self) {
auto *t = self->MutableVar()->GetMutable<phi::DenseTensor>();
return t->numel();
})
.def("element_size", &imperative::VarBase::ElementSize, R"DOC(
Returns the size in bytes of an element in the Tensor.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor(1, dtype='bool')
x.element_size() # 1
x = paddle.to_tensor(1, dtype='float16')
x.element_size() # 2
x = paddle.to_tensor(1, dtype='float32')
x.element_size() # 4
x = paddle.to_tensor(1, dtype='float64')
x.element_size() # 8
x = paddle.to_tensor(1, dtype='complex128')
x.element_size() # 16
)DOC")
.def_property(
"name", &imperative::VarBase::Name, &imperative::VarBase::SetName)
.def_property("stop_gradient",
&imperative::VarBase::OverridedStopGradient,
&imperative::VarBase::SetOverridedStopGradient)
.def_property("persistable",
&imperative::VarBase::Persistable,
&imperative::VarBase::SetPersistable)
.def_property_readonly(
"shape",
[](imperative::VarBase &self) {
if (self.Var().IsType<phi::DenseTensor>()) {
auto value = phi::vectorize<int>(
self.Var().Get<phi::DenseTensor>().dims());
auto tensor = self.Var().Get<phi::DenseTensor>();
auto tmp_value = value;
auto desired_layout =
paddle::imperative::LayoutAutoTune::Instance()
.GetDesiredLayout();
auto default_layout =
paddle::imperative::LayoutAutoTune::Instance()
.GetDefaultLayout();
bool change_dim =
(desired_layout != default_layout &&
tensor.layout() == desired_layout && value.size() == 4);
VLOG(6) << "'Shape' method, layout autotune,"
<< " desired_layout: " << desired_layout
<< " default_layout: " << default_layout
<< " tensor layout: " << tensor.layout()
<< " tensor's shape size is : " << value.size();
if (change_dim &&
phi::DataLayoutToString(desired_layout) == "NCHW") {
VLOG(6) << "layout autotune get Shape from NHWC -> NCHW "
<< value[0] << " " << value[1] << " " << value[2] << " "
<< value[3] << " to " << tmp_value[3] << " "
<< tmp_value[1] << " " << tmp_value[2] << " "
<< tmp_value[1];
// NCHW -> NHWC
value[1] = tmp_value[2];
value[2] = tmp_value[3];
value[3] = tmp_value[1];
} else if (change_dim &&
phi::DataLayoutToString(desired_layout) == "NHWC") {
VLOG(6) << "layout autotune get Shape from NHWC -> NCHW "
<< value[0] << " " << value[1] << " " << value[2] << " "
<< value[3] << " to " << tmp_value[0] << " "
<< tmp_value[3] << " " << tmp_value[1] << " "
<< tmp_value[2];
// NHWC -> NCHW
value[1] = tmp_value[3];
value[2] = tmp_value[1];
value[3] = tmp_value[2];
}
return value;
} else if (self.Var().IsType<phi::SelectedRows>()) {
return phi::vectorize<int>(
self.Var().Get<phi::SelectedRows>().value().dims());
} else if (self.Var().IsType<framework::Strings>()) {
return std::vector<int>{static_cast<int>(
self.Var().Get<framework::Strings>().size())};
} else if (self.Var().IsType<framework::Vocab>()) {
return std::vector<int>{
static_cast<int>(self.Var().Get<framework::Vocab>().size())};
} else {
VLOG(2) << "It is meaningless to get shape of "
"variable type "
<< GetTypeName(self);
return std::vector<int>();
}
})
.def_property_readonly(
"layout",
[](imperative::VarBase &self) {
if (self.Var().IsType<phi::DenseTensor>()) {
auto layout = self.Var().Get<phi::DenseTensor>().layout();
return phi::DataLayoutToString(layout);
}
return std::string("");
})
.def_property_readonly("is_leaf",
&imperative::VarBase::IsLeaf,
R"DOC(
Whether a Tensor is a leaf Tensor.
For a Tensor whose stop_gradient is ``True``, it will be a leaf Tensor.
For a Tensor whose stop_gradient is ``False``, it will be a leaf Tensor too if it is created by the user.
Returns:
bool: Whether a Tensor is leaf Tensor.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor(1.)
print(x.is_leaf) # True
x = paddle.to_tensor(1., stop_gradient=True)
y = x + 1
print(x.is_leaf) # True
print(y.is_leaf) # True
x = paddle.to_tensor(1., stop_gradient=False)
y = x + 1
print(x.is_leaf) # True
print(y.is_leaf) # False
)DOC")
.def_property_readonly(
"place",
[](imperative::VarBase &self) { return self.Place(); },
py::return_value_policy::copy)
.def_property_readonly("_place_str",
[](imperative::VarBase &self) {
std::stringstream ostr;
ostr << self.Place();
return ostr.str();
})
.def_property_readonly("type", &imperative::VarBase::Type)
.def_property_readonly("dtype", &imperative::VarBase::DataType);
py::class_<imperative::jit::ProgramDescTracer>(m, "ProgramDescTracer", "")
.def("create_program_desc",
&imperative::jit::ProgramDescTracer::CreateProgramDesc)
...
...@@ -61,7 +61,6 @@ class OpAttrTypeMap {
ops_attrtype_map_;
};
extern PyTypeObject* g_varbase_pytype;
extern PyTypeObject* g_vartype_pytype;
extern PyTypeObject* g_blockdesc_pytype;
extern PyTypeObject* p_tensor_type;
...@@ -71,7 +70,6 @@ bool PyObject_CheckBool(PyObject** obj) { return PyBool_Check(*obj); }
bool PyObject_CheckLongOrToLong(PyObject** obj) {
if ((PyLong_Check(*obj) && !PyBool_Check(*obj)) ||
PyObject_TypeCheck(*obj, g_vartype_pytype) || // NOLINT
PyObject_TypeCheck(*obj, g_varbase_pytype) || // NOLINT
(PyObject_TypeCheck(*obj, p_tensor_type) && // NOLINT
(((TensorObject*)(*obj))->tensor.numel() == 1))) { // NOLINT
return true;
...@@ -92,7 +90,6 @@ bool PyObject_CheckLongOrToLong(PyObject** obj) {
bool PyObject_CheckFloatOrToFloat(PyObject** obj) {
// sometimes users provide PyLong or numpy.int64 but attr is float
if (PyFloat_Check(*obj) || PyLong_Check(*obj) ||
PyObject_TypeCheck(*obj, g_varbase_pytype) || // NOLINT
(PyObject_TypeCheck(*obj, p_tensor_type) && // NOLINT
(((TensorObject*)(*obj))->tensor.numel() == 1))) { // NOLINT
return true;
...@@ -111,7 +108,6 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj) {
bool PyObject_CheckComplexOrToComplex(PyObject** obj) {
if (PyComplex_Check(*obj) || PyLong_Check(*obj) || PyFloat_Check(*obj) ||
PyObject_TypeCheck(*obj, g_vartype_pytype) || // NOLINT
PyObject_TypeCheck(*obj, g_varbase_pytype) || // NOLINT
PyObject_TypeCheck(*obj, p_tensor_type)) { // NOLINT
return true;
}
...@@ -926,138 +922,6 @@ void ConstructAttrMapFromPyArgs(
}
}
std::shared_ptr<imperative::VarBase> GetVarBaseFromArgs(
const std::string& op_type,
const std::string& arg_name,
PyObject* args,
ssize_t arg_idx,
bool dispensable) {
::pybind11::detail::instance* inst =
(::pybind11::detail::instance*)PyTuple_GET_ITEM(args, arg_idx);
if (PyTuple_Check((PyObject*)inst)) { // NOLINT
inst = (::pybind11::detail::instance*)PyTuple_GET_ITEM(inst, 0);
}
if (inst == nullptr || (PyObject*)inst == Py_None) { // NOLINT
if (!dispensable) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be Tensor, but got None",
op_type,
arg_name,
arg_idx));
}
return nullptr;
}
if (!PyObject_TypeCheck((PyObject*)inst, g_varbase_pytype)) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be Tensor, but got "
"%s",
op_type,
arg_name,
arg_idx,
((PyTypeObject*)((PyObject*)inst)->ob_type)->tp_name)); // NOLINT
}
void** vh = inst->simple_layout ? inst->simple_value_holder
: &inst->nonsimple.values_and_holders[0];
return reinterpret_cast<std::shared_ptr<paddle::imperative::VarBase>&>(vh[1]);
}
std::vector<std::shared_ptr<imperative::VarBase>> GetVarBaseListFromArgs(
const std::string& op_type,
const std::string& arg_name,
PyObject* args,
ssize_t arg_idx,
bool dispensable) {
PyObject* list = PyTuple_GET_ITEM(args, arg_idx);
if (list == nullptr || list == Py_None) {
if (!dispensable) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of Tensor, but got "
"None",
op_type,
arg_name,
arg_idx)); // NOLINT
}
return {};
}
std::vector<std::shared_ptr<imperative::VarBase>> result;
if (PyList_Check(list)) {
Py_ssize_t len = PyList_Size(list);
if (len == 0) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of Tensors, but got "
"empty list",
op_type,
arg_name,
arg_idx));
}
::pybind11::detail::instance* item = nullptr;
for (Py_ssize_t i = 0; i < len; i++) {
item = (::pybind11::detail::instance*)PyList_GetItem(list, i);
if (!PyObject_TypeCheck((PyObject*)item, g_varbase_pytype)) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of Tensors, but "
"got list of "
"%s",
op_type,
arg_name,
arg_idx,
((PyTypeObject*)((PyObject*)item)->ob_type)->tp_name)); // NOLINT
}
void** vh = item->simple_layout ? item->simple_value_holder
: &item->nonsimple.values_and_holders[0];
result.emplace_back(
reinterpret_cast<std::shared_ptr<paddle::imperative::VarBase>&>(
vh[1]));
}
} else if (PyTuple_Check(list)) {
Py_ssize_t len = PyTuple_Size(list);
if (len == 0) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of Tensors, but got "
"empty list",
op_type,
arg_name,
arg_idx));
}
::pybind11::detail::instance* item = nullptr;
for (Py_ssize_t i = 0; i < len; i++) {
item = (::pybind11::detail::instance*)PyTuple_GetItem(list, i); // NOLINT
if (!PyObject_TypeCheck((PyObject*)item, g_varbase_pytype)) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of Tensors, but "
"got list of "
"%s",
op_type,
arg_name,
arg_idx,
((PyTypeObject*)((PyObject*)item)->ob_type)->tp_name)); // NOLINT
}
void** vh = item->simple_layout ? item->simple_value_holder
: &item->nonsimple.values_and_holders[0];
result.emplace_back(
reinterpret_cast<std::shared_ptr<paddle::imperative::VarBase>&>(
vh[1]));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of Tensors, but got "
"%s",
op_type,
arg_name,
arg_idx,
((PyTypeObject*)list->ob_type)->tp_name)); // NOLINT
}
return result;
}
unsigned long GetUnsignedLongFromArgs( // NOLINT
const std::string& op_type,
const std::string& arg_name,
...
...@@ -194,20 +194,6 @@ void ConstructAttrMapFromPyArgs(
ssize_t attr_end,
paddle::framework::AttributeMap& attrs); // NOLINT
std::shared_ptr<imperative::VarBase> GetVarBaseFromArgs(
const std::string& op_type,
const std::string& arg_name,
PyObject* args,
ssize_t arg_idx,
bool dispensable = false);
std::vector<std::shared_ptr<imperative::VarBase>> GetVarBaseListFromArgs(
const std::string& op_type,
const std::string& arg_name,
PyObject* args,
ssize_t arg_idx,
bool dispensable = false);
unsigned long GetUnsignedLongFromArgs( // NOLINT
const std::string& op_type,
const std::string& arg_name,
...