Unverified commit ae867a84 authored by Haohongxiang, committed by GitHub

[Dygraph] Fix bugs of supporting ProcessGroupNCCL on DCU (#43682)

* fix bugs

* update

* update

* update

* code style

* code style check
Parent 292b7254
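The core of this change is a set of build and preprocessor guard updates: wherever the CMake files or the pybind sources checked only for NCCL (WITH_NCCL / PADDLE_WITH_NCCL), they now also accept RCCL (WITH_RCCL / PADDLE_WITH_RCCL), which is what Paddle links against on DCU (ROCm) builds. Below is a minimal standalone sketch of that guard pattern; the HAS_GPU_COLLECTIVES helper macro and the main function are illustrative assumptions, not code from this commit.

#include <cstdio>

// Assumed helper macro; in Paddle the PADDLE_WITH_* macros are derived from
// the CMake options WITH_NCCL / WITH_RCCL.
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#define HAS_GPU_COLLECTIVES 1
#else
#define HAS_GPU_COLLECTIVES 0
#endif

int main() {
  // Before this fix, an RCCL-only (DCU) build took the 0 branch in several
  // places, so ProcessGroupNCCL was never compiled into the Python module.
  std::printf("GPU collectives available: %d\n", HAS_GPU_COLLECTIVES);
  return 0;
}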
......@@ -129,7 +129,7 @@ endif()
if(NOT ON_INFER)
set(PYBIND_DEPS ${PYBIND_DEPS} processgroup eager_reducer)
if(WITH_NCCL)
if(WITH_NCCL OR WITH_RCCL)
set(PYBIND_DEPS ${PYBIND_DEPS} processgroup_nccl)
if(WITH_PSCORE)
set(PYBIND_DEPS ${PYBIND_DEPS} processgroup_heter)
......
......@@ -31,7 +31,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/phi/api/all.h"
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#endif
......@@ -61,11 +61,15 @@ std::shared_ptr<distributed::EagerReducer> CreateEagerReducer(
const std::vector<std::vector<size_t>> &group_indices,
const std::vector<bool> &is_sparse_gradient,
std::shared_ptr<distributed::ProcessGroup> process_group,
const std::vector<size_t> &group_size_limits, bool find_unused_parameters) {
const std::vector<size_t> &group_size_limits,
bool find_unused_parameters) {
auto params = CastPyArg2VectorOfTensor(py_tensors.ptr(), 0);
return std::make_shared<distributed::EagerReducer>(
params, group_indices, is_sparse_gradient, process_group,
group_size_limits, find_unused_parameters);
return std::make_shared<distributed::EagerReducer>(params,
group_indices,
is_sparse_gradient,
process_group,
group_size_limits,
find_unused_parameters);
}
#if defined(PADDLE_WITH_GLOO)
......@@ -111,7 +115,8 @@ void BindDistributed(py::module *m) {
.def("name", &distributed::ProcessGroup::GetBackendName)
.def(
"allreduce",
[](distributed::ProcessGroup &self, py::handle py_tensor,
[](distributed::ProcessGroup &self,
py::handle py_tensor,
distributed::ReduceOp op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::AllreduceOptions opts;
......@@ -121,12 +126,14 @@ void BindDistributed(py::module *m) {
std::vector<phi::DenseTensor> tensors = {*dense};
return self.AllReduce(tensors, tensors, opts);
},
py::arg("tensor"), py::arg("op") = distributed::ReduceOp::SUM,
py::arg("tensor"),
py::arg("op") = distributed::ReduceOp::SUM,
py::call_guard<py::gil_scoped_release>())
.def(
"broadcast",
[](distributed::ProcessGroup &self, py::handle py_tensor,
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int source_rank) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::BroadcastOptions opts;
......@@ -136,7 +143,8 @@ void BindDistributed(py::module *m) {
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Broadcast(tensors, tensors, opts);
},
py::arg("tensor"), py::arg("source_rank"),
py::arg("tensor"),
py::arg("source_rank"),
py::call_guard<py::gil_scoped_release>())
.def(
......@@ -151,7 +159,8 @@ void BindDistributed(py::module *m) {
.def(
"send",
[](distributed::ProcessGroup &self, py::handle py_tensor,
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int dst) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
......@@ -159,12 +168,14 @@ void BindDistributed(py::module *m) {
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Send(tensors, dst);
},
py::arg("tensor"), py::arg("dst"),
py::arg("tensor"),
py::arg("dst"),
py::call_guard<py::gil_scoped_release>())
.def(
"recv",
[](distributed::ProcessGroup &self, py::handle py_tensor,
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int src) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
......@@ -172,12 +183,14 @@ void BindDistributed(py::module *m) {
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Recv(tensors, src);
},
py::arg("tensor"), py::arg("src"),
py::arg("tensor"),
py::arg("src"),
py::call_guard<py::gil_scoped_release>())
.def(
"all_gather",
[](distributed::ProcessGroup &self, py::handle py_in_tensor,
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
......@@ -189,12 +202,14 @@ void BindDistributed(py::module *m) {
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.AllGather(in_tensors, out_tensors);
},
py::arg("in"), py::arg("out"),
py::arg("in"),
py::arg("out"),
py::call_guard<py::gil_scoped_release>())
.def(
"alltoall",
[](distributed::ProcessGroup &self, py::handle py_in_tensor,
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
......@@ -206,13 +221,16 @@ void BindDistributed(py::module *m) {
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.AllToAll(in_tensors, out_tensors);
},
py::arg("in"), py::arg("out"),
py::arg("in"),
py::arg("out"),
py::call_guard<py::gil_scoped_release>())
.def(
"reduce",
[](distributed::ProcessGroup &self, py::handle py_in_tensor,
int dst, distributed::ReduceOp op) {
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
int dst,
distributed::ReduceOp op) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
distributed::ReduceOptions opts;
opts.reduce_op = op;
......@@ -222,14 +240,17 @@ void BindDistributed(py::module *m) {
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Reduce(tensors, tensors, opts);
},
py::arg("tensor"), py::arg("dst"),
py::arg("tensor"),
py::arg("dst"),
py::arg("op") = distributed::ReduceOp::SUM,
py::call_guard<py::gil_scoped_release>())
.def(
"scatter",
[](distributed::ProcessGroup &self, py::handle py_in_tensor,
py::handle py_out_tensor, int src) {
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
int src) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
distributed::ScatterOptions opts;
......@@ -242,17 +263,25 @@ void BindDistributed(py::module *m) {
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.Scatter(in_tensors, out_tensors, opts);
},
py::arg("in"), py::arg("out"), py::arg("src"),
py::arg("in"),
py::arg("out"),
py::arg("src"),
py::call_guard<py::gil_scoped_release>());
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)
py::class_<distributed::ProcessGroupNCCL,
std::shared_ptr<distributed::ProcessGroupNCCL>>(
*m, "ProcessGroupNCCL", ProcessGroup)
.def(py::init<const std::shared_ptr<distributed::Store> &, int, int,
const platform::CUDAPlace &, int>(),
py::arg("store"), py::arg("rank"), py::arg("world_size"),
py::arg("place"), py::arg("group_id") = 0,
.def(py::init<const std::shared_ptr<distributed::Store> &,
int,
int,
const platform::CUDAPlace &,
int>(),
py::arg("store"),
py::arg("rank"),
py::arg("world_size"),
py::arg("place"),
py::arg("group_id") = 0,
py::call_guard<py::gil_scoped_release>());
#endif
......@@ -261,29 +290,53 @@ void BindDistributed(py::module *m) {
py::class_<distributed::ProcessGroupHeter,
std::shared_ptr<distributed::ProcessGroupHeter>>(
*m, "ProcessGroupHeter", ProcessGroup)
.def(py::init<const std::shared_ptr<distributed::Store> &, int, int,
.def(py::init<const std::shared_ptr<distributed::Store> &,
int,
int,
#if defined(PADDLE_WITH_ASCEND_CL)
const platform::NPUPlace &,
#else
const platform::CUDAPlace &,
#endif
int, int, int, int, int, bool, std::string, int, int>(),
py::arg("store"), py::arg("rank"), py::arg("world_size"),
py::arg("place"), py::arg("gid") = 0, py::arg("local_rank") = 0,
py::arg("local_size") = 1, py::arg("gloo_rank") = 0,
py::arg("gloo_size") = 1, py::arg("with_switch") = false,
py::arg("switch_endpoint") = "", py::arg("src_rank") = "",
py::arg("dst_rank") = "", py::call_guard<py::gil_scoped_release>());
int,
int,
int,
int,
int,
bool,
std::string,
int,
int>(),
py::arg("store"),
py::arg("rank"),
py::arg("world_size"),
py::arg("place"),
py::arg("gid") = 0,
py::arg("local_rank") = 0,
py::arg("local_size") = 1,
py::arg("gloo_rank") = 0,
py::arg("gloo_size") = 1,
py::arg("with_switch") = false,
py::arg("switch_endpoint") = "",
py::arg("src_rank") = "",
py::arg("dst_rank") = "",
py::call_guard<py::gil_scoped_release>());
#endif
#if defined(PADDLE_WITH_ASCEND_CL)
py::class_<distributed::ProcessGroupHCCL,
std::shared_ptr<distributed::ProcessGroupHCCL>>(
*m, "ProcessGroupHCCL", ProcessGroup)
.def(py::init<const std::shared_ptr<distributed::Store> &, int, int,
const platform::NPUPlace &, int>(),
py::arg("store"), py::arg("rank"), py::arg("world_size"),
py::arg("place"), py::arg("group_id") = 0,
.def(py::init<const std::shared_ptr<distributed::Store> &,
int,
int,
const platform::NPUPlace &,
int>(),
py::arg("store"),
py::arg("rank"),
py::arg("world_size"),
py::arg("place"),
py::arg("group_id") = 0,
py::call_guard<py::gil_scoped_release>());
#endif
......@@ -291,22 +344,29 @@ void BindDistributed(py::module *m) {
py::class_<distributed::ProcessGroup::Task,
std::shared_ptr<distributed::ProcessGroup::Task>>(*m, "task")
.def("is_completed", &distributed::ProcessGroup::Task::IsCompleted)
.def("wait", &distributed::ProcessGroup::Task::Wait,
.def("wait",
&distributed::ProcessGroup::Task::Wait,
py::arg("timeout") = kWaitTimeout,
py::call_guard<py::gil_scoped_release>())
.def("synchronize", &distributed::ProcessGroup::Task::Synchronize,
.def("synchronize",
&distributed::ProcessGroup::Task::Synchronize,
py::call_guard<py::gil_scoped_release>());
#if defined(PADDLE_WITH_GLOO)
py::class_<ProcessGroupGloo, std::shared_ptr<ProcessGroupGloo>>(
*m, "ProcessGroupGloo", ProcessGroup)
.def(py::init<const std::shared_ptr<paddle::distributed::Store> &, int,
int, const platform::CPUPlace &, int,
.def(py::init<const std::shared_ptr<paddle::distributed::Store> &,
int,
int,
const platform::CPUPlace &,
int,
std::shared_ptr<GlooOptions> &>(),
py::call_guard<py::gil_scoped_release>())
.def(py::init([](const std::shared_ptr<paddle::distributed::Store> &store,
int rank, int world_size,
const platform::CPUPlace &place, int gid) {
int rank,
int world_size,
const platform::CPUPlace &place,
int gid) {
auto opts = GlooOptions::create();
char *ifname = getenv(GLOO_SOCKET_IFNAME_ENV.c_str());
if (ifname && strlen(ifname) > 1) {
......@@ -315,11 +375,14 @@ void BindDistributed(py::module *m) {
} else {
opts->device = ProcessGroupGloo::createDefaultDevice();
}
return std::make_shared<ProcessGroupGloo>(store, rank, world_size,
place, gid, opts);
return std::make_shared<ProcessGroupGloo>(
store, rank, world_size, place, gid, opts);
}),
py::arg("store"), py::arg("rank"), py::arg("world_size"),
py::arg("place"), py::arg("group_id") = 0,
py::arg("store"),
py::arg("rank"),
py::arg("world_size"),
py::arg("place"),
py::arg("group_id") = 0,
py::call_guard<py::gil_scoped_release>())
.def_static("create_default_device",
&ProcessGroupGloo::createDefaultDevice);
......@@ -327,21 +390,23 @@ void BindDistributed(py::module *m) {
m->def(
"eager_assign_group_by_size",
[](py::handle py_tensors, std::vector<bool> is_sparse_gradient,
[](py::handle py_tensors,
std::vector<bool> is_sparse_gradient,
std::vector<size_t> group_size_limits,
std::vector<int64_t> tensor_indices) {
auto tensors = CastPyArg2VectorOfTensor(py_tensors.ptr(), 0);
return distributed::Eager_AssignGroupBySize(
tensors, is_sparse_gradient, group_size_limits, tensor_indices);
},
py::arg("tensors"), py::arg("is_sparse_gradient"),
py::arg("tensors"),
py::arg("is_sparse_gradient"),
py::arg("group_size_limits") = std::vector<size_t>{25 * 1024 * 1024},
py::arg("tensor_indices") = std::vector<int64_t>{},
py::call_guard<py::gil_scoped_release>());
py::class_<distributed::EagerReducer,
std::shared_ptr<distributed::EagerReducer>>(*m, "EagerReducer",
R"DOC()DOC")
std::shared_ptr<distributed::EagerReducer>>(
*m, "EagerReducer", R"DOC()DOC")
.def(py::init(&CreateEagerReducer))
.def(
"prepare_for_backward",
......@@ -349,7 +414,8 @@ void BindDistributed(py::module *m) {
auto params = CastPyArg2VectorOfTensor(py_tensors.ptr(), 0);
self.PrepareForBackward(params);
},
py::arg("tensors"), py::call_guard<py::gil_scoped_release>());
py::arg("tensors"),
py::call_guard<py::gil_scoped_release>());
}
} // end namespace pybind
......
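The ProcessGroupNCCL registration above follows a standard pybind11 recipe: a shared_ptr-held class, one py::arg per constructor parameter, a defaulted group_id, and a GIL release guard around the constructor call. The standalone sketch below reproduces that recipe with toy stand-in types; Store, CUDAPlace, and ToyProcessGroup here are illustrative assumptions, not Paddle's real classes.

#include <pybind11/pybind11.h>
#include <memory>

namespace py = pybind11;

struct Store {};  // stand-in for distributed::Store

struct CUDAPlace {  // stand-in for platform::CUDAPlace
  explicit CUDAPlace(int id = 0) : device_id(id) {}
  int device_id;
};

struct ToyProcessGroup {
  ToyProcessGroup(std::shared_ptr<Store> store, int rank, int world_size,
                  const CUDAPlace& place, int gid)
      : rank(rank), world_size(world_size),
        device_id(place.device_id), group_id(gid) {
    (void)store;  // a real process group would keep the store for rendezvous
  }
  int rank, world_size, device_id, group_id;
};

PYBIND11_MODULE(toy_pg, m) {
  py::class_<Store, std::shared_ptr<Store>>(m, "Store").def(py::init<>());

  py::class_<CUDAPlace>(m, "CUDAPlace")
      .def(py::init<int>(), py::arg("device_id") = 0);

  py::class_<ToyProcessGroup, std::shared_ptr<ToyProcessGroup>>(
      m, "ToyProcessGroup")
      .def(py::init<std::shared_ptr<Store>, int, int, const CUDAPlace&, int>(),
           py::arg("store"),
           py::arg("rank"),
           py::arg("world_size"),
           py::arg("place"),
           py::arg("group_id") = 0,
           // Release the GIL while the (potentially blocking) ctor runs.
           py::call_guard<py::gil_scoped_release>())
      .def_readonly("rank", &ToyProcessGroup::rank);
}

Assuming the module is built as toy_pg, Python can then construct the group with keyword arguments and rely on the default, e.g. toy_pg.ToyProcessGroup(store=toy_pg.Store(), rank=0, world_size=1, place=toy_pg.CUDAPlace(0)).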
......@@ -149,7 +149,8 @@ Py_ssize_t GetSliceIndexFromPyObject(PyObject* obj) {
VLOG(6) << "Call GetSliceIndexFromTensor in Eager";
paddle::experimental::Tensor tensor = CastPyArg2Tensor(obj, 0);
PADDLE_ENFORCE_EQ(
tensor.initialized(), true,
tensor.initialized(),
true,
paddle::platform::errors::InvalidArgument(
"We can only support initialized tensor in slice, however we got "
"uninitialized tensor %s, please check your code.",
......@@ -167,7 +168,8 @@ bool PyCheckTensor(PyObject* obj) {
return PyObject_IsInstance(obj, reinterpret_cast<PyObject*>(p_tensor_type));
}
static PyObject* tensor_method_numpy(TensorObject* self, PyObject* args,
static PyObject* tensor_method_numpy(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
auto& api = pybind11::detail::npy_api::get();
......@@ -179,8 +181,11 @@ static PyObject* tensor_method_numpy(TensorObject* self, PyObject* args,
PyObject* array = api.PyArray_NewFromDescr_(
api.PyArray_Type_,
api.PyArray_DescrFromType_(pybind11::detail::npy_api::NPY_FLOAT_), 1,
py_dims, py_strides, nullptr,
api.PyArray_DescrFromType_(pybind11::detail::npy_api::NPY_FLOAT_),
1,
py_dims,
py_strides,
nullptr,
pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
nullptr);
......@@ -199,8 +204,12 @@ static PyObject* tensor_method_numpy(TensorObject* self, PyObject* args,
}
PyObject* array = api.PyArray_NewFromDescr_(
api.PyArray_Type_, api.PyArray_DescrFromType_(numpy_dtype),
tensor_dims.size(), py_dims, py_strides, nullptr,
api.PyArray_Type_,
api.PyArray_DescrFromType_(numpy_dtype),
tensor_dims.size(),
py_dims,
py_strides,
nullptr,
pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
nullptr);
......@@ -210,8 +219,12 @@ static PyObject* tensor_method_numpy(TensorObject* self, PyObject* args,
py_dims[0] = 0;
py_strides[0] = 0;
PyObject* array = api.PyArray_NewFromDescr_(
api.PyArray_Type_, api.PyArray_DescrFromType_(numpy_dtype), 1,
py_dims, py_strides, nullptr,
api.PyArray_Type_,
api.PyArray_DescrFromType_(numpy_dtype),
1,
py_dims,
py_strides,
nullptr,
pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
nullptr);
......@@ -233,7 +246,9 @@ static PyObject* tensor_method_numpy(TensorObject* self, PyObject* args,
paddle::memory::Copy(
place,
reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data),
place, dense_tensor->data(), sizeof_dtype * numel);
place,
dense_tensor->data(),
sizeof_dtype * numel);
} else {
VLOG(6) << "Getting DenseTensor's numpy value";
auto dense_tensor =
......@@ -242,11 +257,18 @@ static PyObject* tensor_method_numpy(TensorObject* self, PyObject* args,
paddle::memory::Copy(
place,
reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data),
place, dense_tensor->data(), sizeof_dtype * numel);
place,
dense_tensor->data(),
sizeof_dtype * numel);
}
#if defined(PADDLE_WITH_CUDA)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
} else if (self->tensor.is_gpu()) {
#if defined(PADDLE_WITH_CUDA)
gpuMemcpyKind kind = cudaMemcpyDeviceToHost;
#elif defined(PADDLE_WITH_HIP)
gpuMemcpyKind kind = hipMemcpyDeviceToHost;
#endif
if (self->tensor.is_selected_rows()) {
VLOG(6) << "Getting SelectedRows's numpy value";
auto* selected_rows =
......@@ -254,19 +276,21 @@ static PyObject* tensor_method_numpy(TensorObject* self, PyObject* args,
auto* dense_tensor = static_cast<paddle::framework::LoDTensor*>(
selected_rows->mutable_value());
paddle::platform::GpuMemcpySync(
pybind11::detail::array_proxy(array)->data, dense_tensor->data(),
pybind11::detail::array_proxy(array)->data,
dense_tensor->data(),
paddle::framework::DataTypeSize(dense_tensor->dtype()) *
dense_tensor->numel(),
cudaMemcpyDeviceToHost);
kind);
} else {
VLOG(6) << "Getting DenseTensor's numpy value";
auto dense_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
paddle::platform::GpuMemcpySync(
pybind11::detail::array_proxy(array)->data, dense_tensor->data(),
pybind11::detail::array_proxy(array)->data,
dense_tensor->data(),
paddle::framework::DataTypeSize(dense_tensor->dtype()) *
dense_tensor->numel(),
cudaMemcpyDeviceToHost);
kind);
}
#endif
} else {
......@@ -294,8 +318,11 @@ static PyObject* tensor_method_numpy_for_string_tensor(TensorObject* self,
PyObject* array = api.PyArray_NewFromDescr_(
api.PyArray_Type_,
api.PyArray_DescrFromType_(pybind11::detail::npy_api::NPY_UNICODE_), 1,
py_dims, py_strides, nullptr,
api.PyArray_DescrFromType_(pybind11::detail::npy_api::NPY_UNICODE_),
1,
py_dims,
py_strides,
nullptr,
pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
nullptr);
......@@ -334,7 +361,9 @@ static PyObject* tensor_method_numpy_for_string_tensor(TensorObject* self,
curr_unicode_len);
}
py::array array(py::dtype("U" + std::to_string(max_unicode_length)),
tensor_dims, {}, py_array_data);
tensor_dims,
{},
py_array_data);
return array.release().ptr();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
......@@ -384,7 +413,8 @@ static void IncreaseTensorReferenceCountUntilCopyComplete(
gc->DirectClearCallback(callback);
}
static PyObject* tensor_method__copy_to(TensorObject* self, PyObject* args,
static PyObject* tensor_method__copy_to(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
auto place = CastPyArg2Place(PyTuple_GET_ITEM(args, 0), 0);
......@@ -401,7 +431,8 @@ static PyObject* tensor_method__copy_to(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_method_cpu(TensorObject* self, PyObject* args,
static PyObject* tensor_method_cpu(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
auto cp_tensor = self->tensor.copy_to(phi::CPUPlace(), true);
......@@ -434,7 +465,8 @@ static PyObject* tensor_method_reconstruct_from_(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_method_copy_(TensorObject* self, PyObject* args,
static PyObject* tensor_method_copy_(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
paddle::experimental::Tensor src_tensor =
......@@ -465,7 +497,8 @@ static PyObject* tensor_method_copy_(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_retain_grads(TensorObject* self, PyObject* args,
static PyObject* tensor_retain_grads(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
if (egr::Controller::Instance().HasGrad()) {
......@@ -482,7 +515,8 @@ static PyObject* tensor_retain_grads(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_clear_gradient(TensorObject* self, PyObject* args,
static PyObject* tensor_clear_gradient(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
VLOG(4) << "ClearGradient " << self->tensor.name();
......@@ -543,7 +577,8 @@ static PyObject* tensor_clear_gradient(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__zero_grads(TensorObject* self, PyObject* args,
static PyObject* tensor__zero_grads(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
VLOG(4) << "ZeroGrads " << self->tensor.name();
......@@ -586,12 +621,14 @@ static PyObject* tensor__zero_grads(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__share_buffer_to(TensorObject* self, PyObject* args,
static PyObject* tensor__share_buffer_to(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
paddle::experimental::Tensor* dst_ptr =
&(reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0))->tensor);
PADDLE_ENFORCE_EQ(self->tensor.initialized(), true,
PADDLE_ENFORCE_EQ(self->tensor.initialized(),
true,
platform::errors::InvalidArgument(
"Tensor %s has not been initialized! please initialize "
"src tensor before share_buffer_with to other.",
......@@ -616,7 +653,8 @@ static PyObject* tensor__is_shared_buffer_with(TensorObject* self,
EAGER_TRY
paddle::experimental::Tensor* dst_ptr =
&(reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0))->tensor);
PADDLE_ENFORCE_EQ(self->tensor.initialized(), true,
PADDLE_ENFORCE_EQ(self->tensor.initialized(),
true,
platform::errors::InvalidArgument(
"Tensor %s has not been initialized! please initialize "
"src tensor before share_buffer_with to other.",
......@@ -640,7 +678,8 @@ static PyObject* tensor__share_underline_tensor_to(TensorObject* self,
EAGER_TRY
paddle::experimental::Tensor* src_ptr =
&(reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0))->tensor);
PADDLE_ENFORCE_EQ(self->tensor.initialized(), true,
PADDLE_ENFORCE_EQ(self->tensor.initialized(),
true,
platform::errors::InvalidArgument(
"Tensor %s has not been initialized! please initialize "
"src tensor before share_buffer_with to other.",
......@@ -657,7 +696,8 @@ static PyObject* tensor__is_shared_underline_tensor_with(TensorObject* self,
EAGER_TRY
paddle::experimental::Tensor src_tensor =
CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
PADDLE_ENFORCE_EQ(src_tensor.initialized(), true,
PADDLE_ENFORCE_EQ(src_tensor.initialized(),
true,
platform::errors::InvalidArgument(
"Tensor %s has not been initialized! please initialize "
"src tensor before share_buffer_with to other.",
......@@ -671,11 +711,13 @@ static PyObject* tensor__is_shared_underline_tensor_with(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_method_detach(TensorObject* self, PyObject* args,
static PyObject* tensor_method_detach(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
PADDLE_ENFORCE_EQ(
self->tensor.initialized(), true,
self->tensor.initialized(),
true,
platform::errors::InvalidArgument("Tensor %s has not been initialized!",
self->tensor.name()));
......@@ -745,15 +787,24 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
// if index is a list, list_select_flag will be true
bool list_select_flag = false;
PADDLE_ENFORCE_EQ(
self->tensor.initialized(), true,
self->tensor.initialized(),
true,
platform::errors::InvalidArgument(
"tensor %s has not been initialized, we can only slice initialized "
"tensor please init it first with numpy or other tensor.",
self->tensor.name()));
auto tensor = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
ParseIndexingSlice(tensor, _index, &slice_axes, &slice_starts, &slice_ends,
&slice_strides, &decrease_axis, &none_axes, &infer_flags,
&list_select_idxs, &list_select_flag);
ParseIndexingSlice(tensor,
_index,
&slice_axes,
&slice_starts,
&slice_ends,
&slice_strides,
&decrease_axis,
&none_axes,
&infer_flags,
&list_select_idxs,
&list_select_flag);
auto out = slice_axes.empty() && !list_select_flag
? self->tensor
......@@ -782,9 +833,12 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
decrease_axis.end());
if (op_type == "slice") {
out = slice_final_state_dygraph_function(
self->tensor, slice_axes_tmp, slice_starts, slice_ends,
infer_flags_tmp, decrease_axis_tmp);
out = slice_final_state_dygraph_function(self->tensor,
slice_axes_tmp,
slice_starts,
slice_ends,
infer_flags_tmp,
decrease_axis_tmp);
} else if (op_type == "strided_slice") {
out = strided_slice_final_state_dygraph_function(
self->tensor, slice_axes, slice_starts, slice_ends, slice_strides);
......@@ -839,27 +893,29 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
select_index.set_impl(idx_tensor);
auto* dev_ctx = platform::DeviceContextPool::Instance().Get(
egr::Controller::Instance().GetExpectedPlace());
paddle::framework::TensorFromVector(list_select_idxs, *dev_ctx,
idx_tensor.get());
paddle::framework::TensorFromVector(
list_select_idxs, *dev_ctx, idx_tensor.get());
framework::AttributeMap attrs = {{"dim", 0}};
out = index_select_final_state_dygraph_function(self->tensor, select_index,
0);
out = index_select_final_state_dygraph_function(
self->tensor, select_index, 0);
}
return ToPyObject(out);
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__getitem_from_offset(TensorObject* self, PyObject* args,
static PyObject* tensor__getitem_from_offset(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
auto ptr = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
PADDLE_ENFORCE_NOT_NULL(
ptr, platform::errors::InvalidArgument("%s is not a DenseTensor.",
self->tensor.name()));
PADDLE_ENFORCE_NOT_NULL(ptr,
platform::errors::InvalidArgument(
"%s is not a DenseTensor.", self->tensor.name()));
const auto& tensor = *ptr;
PADDLE_ENFORCE_EQ(
tensor.IsInitialized(), true,
tensor.IsInitialized(),
true,
platform::errors::InvalidArgument(
"Tensor of %s is Empty, please check if it has no data.",
self->tensor.name()));
......@@ -877,27 +933,33 @@ static PyObject* tensor__getitem_from_offset(TensorObject* self, PyObject* args,
}
size_t offset = 0;
if (PyTuple_Size(args) == 0) {
PADDLE_ENFORCE_EQ(numel, 1,
PADDLE_ENFORCE_EQ(numel,
1,
platform::errors::InvalidArgument(
"only one element tensors can be converted to Python "
"scalars when no input coordinates"));
} else if (PyTuple_Size(args) == 1) {
offset = CastPyArg2AttrLong(PyTuple_GET_ITEM(args, 0), 0);
PADDLE_ENFORCE_LT(
offset, numel,
offset,
numel,
platform::errors::InvalidArgument(
"index %d is out of bounds for size %d", offset, numel));
} else {
PADDLE_ENFORCE_EQ(PyTuple_Size(args), dims.size(),
PADDLE_ENFORCE_EQ(PyTuple_Size(args),
dims.size(),
platform::errors::InvalidArgument(
"incorrect number of indices for Tensor"));
for (Py_ssize_t i = 0; i < PyTuple_Size(args); ++i) {
size_t index = CastPyArg2AttrLong(PyTuple_GET_ITEM(args, i), i);
PADDLE_ENFORCE_LT(
index, dims[i],
index,
dims[i],
platform::errors::InvalidArgument(
"index %d is out fo bounds for axis %d with size %d", index, i,
"index %d is out fo bounds for axis %d with size %d",
index,
i,
dims[i]));
offset += index * strides[i];
}
......@@ -929,14 +991,19 @@ static PyObject* tensor__getitem_from_offset(TensorObject* self, PyObject* args,
py_strides[0] = 1; \
auto& api = pybind11::detail::npy_api::get(); \
PyObject* array = api.PyArray_NewFromDescr_( \
api.PyArray_Type_, api.PyArray_DescrFromType_(numpy_dtype), 1, \
py_dims, py_strides, nullptr, \
api.PyArray_Type_, \
api.PyArray_DescrFromType_(numpy_dtype), \
1, \
py_dims, \
py_strides, \
nullptr, \
pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ | \
pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_, \
nullptr); \
std::memcpy( \
reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data), \
static_cast<void*>(&b), sizeof(b)); \
static_cast<void*>(&b), \
sizeof(b)); \
return array; \
}
......@@ -991,9 +1058,17 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
infer_flags, list_select_idxs;
// if index is a list, list_select_flag will be true
bool list_select_flag = false;
ParseIndexingSlice(self_tensor, index_ptr, &axes, &starts, &ends, &steps,
&decrease_axes, &none_axes, &infer_flags,
&list_select_idxs, &list_select_flag);
ParseIndexingSlice(self_tensor,
index_ptr,
&axes,
&starts,
&ends,
&steps,
&decrease_axes,
&none_axes,
&infer_flags,
&list_select_idxs,
&list_select_flag);
framework::AttributeMap attrs = {{"axes", axes},
{"starts", starts},
......@@ -1058,16 +1133,22 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
SetTensorFromPyArray(
static_cast<phi::DenseTensor*>(value_tensor_tmp.impl().get()),
value, platform::Place(platform::CUDAPlace(0)), false);
value,
platform::Place(platform::CUDAPlace(0)),
false);
#else
SetTensorFromPyArray(
static_cast<phi::DenseTensor*>(value_tensor_tmp.impl().get()),
value, platform::Place(platform::CPUPlace()), false);
value,
platform::Place(platform::CPUPlace()),
false);
#endif
} else {
SetTensorFromPyArray(
static_cast<phi::DenseTensor*>(value_tensor_tmp.impl().get()),
value, value_tensor_tmp.place(), false);
value,
value_tensor_tmp.place(),
false);
}
value_tensor = value_tensor_tmp;
......@@ -1117,8 +1198,8 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
// Release gil and do tracing
py::gil_scoped_release release;
// use inplace set_value_ operator
self->tensor = set_value__dygraph_function(self->tensor, value_tensor, {},
{}, {}, attrs);
self->tensor = set_value__dygraph_function(
self->tensor, value_tensor, {}, {}, {}, attrs);
}
if (PyCheckTensor(value_obj)) {
// pass the stop_gradient from value to tensor.
......@@ -1144,15 +1225,19 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
}
if (!self->tensor.initialized()) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
SetTensorFromPyArray(self_tensor, self_numpy,
platform::Place(platform::CUDAPlace(0)), false);
SetTensorFromPyArray(self_tensor,
self_numpy,
platform::Place(platform::CUDAPlace(0)),
false);
#else
SetTensorFromPyArray(self_tensor, self_numpy,
platform::Place(platform::CPUPlace()), false);
SetTensorFromPyArray(self_tensor,
self_numpy,
platform::Place(platform::CPUPlace()),
false);
#endif
} else {
SetTensorFromPyArray(self_tensor, self_numpy, self->tensor.place(),
false);
SetTensorFromPyArray(
self_tensor, self_numpy, self->tensor.place(), false);
}
}
RETURN_PY_NONE
......@@ -1160,7 +1245,8 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_register_grad_hook(TensorObject* self, PyObject* args,
static PyObject* tensor_register_grad_hook(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
int64_t hook_id;
......@@ -1187,7 +1273,8 @@ static PyObject* tensor_register_grad_hook(TensorObject* self, PyObject* args,
auto accumulation_grad_node =
std::dynamic_pointer_cast<egr::GradNodeAccumulation>(grad_node);
hook_id = accumulation_grad_node->RegisterGradientHook(
rank_info.first, rank_info.second,
rank_info.first,
rank_info.second,
std::make_shared<PyTensorHook>(hook_func));
} else {
......@@ -1200,14 +1287,16 @@ static PyObject* tensor_register_grad_hook(TensorObject* self, PyObject* args,
PyObject* hook_func = PyTuple_GET_ITEM(args, 0);
hook_id = grad_node->RegisterGradientHook(
rank_info.first, rank_info.second,
rank_info.first,
rank_info.second,
std::make_shared<PyTensorHook>(hook_func));
}
return ToPyObject(hook_id);
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_remove_grad_hook(TensorObject* self, PyObject* args,
static PyObject* tensor_remove_grad_hook(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
VLOG(6) << "Remove the registered hook for tensor: " << self->tensor.name();
......@@ -1220,14 +1309,16 @@ static PyObject* tensor_remove_grad_hook(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_register_reduce_hook(TensorObject* self, PyObject* args,
static PyObject* tensor_register_reduce_hook(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
VLOG(4) << "Register reduce hook for tensor: " << self->tensor.name();
std::shared_ptr<egr::GradNodeBase> grad_node =
egr::EagerUtils::grad_node(self->tensor);
PADDLE_ENFORCE_EQ(egr::egr_utils_api::IsLeafTensor(self->tensor), true,
PADDLE_ENFORCE_EQ(egr::egr_utils_api::IsLeafTensor(self->tensor),
true,
platform::errors::InvalidArgument(
"Only can register backward hook for leaf Tensor."));
PADDLE_ENFORCE_EQ(
......@@ -1253,7 +1344,8 @@ static PyObject* tensor_register_reduce_hook(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__set_grad_type(TensorObject* self, PyObject* args,
static PyObject* tensor__set_grad_type(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
auto var_type = pybind::CastPyArg2ProtoType(PyTuple_GET_ITEM(args, 0), 0);
......@@ -1269,7 +1361,8 @@ static PyObject* tensor__set_grad_type(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__clear(TensorObject* self, PyObject* args,
static PyObject* tensor__clear(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
self->tensor.reset();
......@@ -1278,26 +1371,31 @@ static PyObject* tensor__clear(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__copy_gradient_from(TensorObject* self, PyObject* args,
static PyObject* tensor__copy_gradient_from(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
auto src = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
if (self->tensor.initialized()) {
PADDLE_ENFORCE_EQ(self->tensor.dtype(), src.dtype(),
PADDLE_ENFORCE_EQ(self->tensor.dtype(),
src.dtype(),
platform::errors::PreconditionNotMet(
"Tensor %s has different data type with Tensor %s",
self->tensor.name(), src.name()));
self->tensor.name(),
src.name()));
PADDLE_ENFORCE_EQ(self->tensor.impl()->type_info().id(),
src.impl()->type_info().id(),
platform::errors::PreconditionNotMet(
"Tensor %s has different type with Tensor %s, Tensor "
"ShareGradientDataWith cannot be performed!",
self->tensor.name(), src.name()));
self->tensor.name(),
src.name()));
}
VLOG(6) << "Tensor copy gradient from: " << src.name();
auto* p_grad = egr::EagerUtils::mutable_grad(self->tensor);
if (p_grad) {
PADDLE_ENFORCE_EQ(src.initialized(), true,
PADDLE_ENFORCE_EQ(src.initialized(),
true,
platform::errors::InvalidArgument(
"Tensor %s has not been initialized", src.name()));
p_grad->set_impl(src.impl());
......@@ -1307,7 +1405,8 @@ static PyObject* tensor__copy_gradient_from(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_method_set_vocab(TensorObject* self, PyObject* args,
static PyObject* tensor_method_set_vocab(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
using Vocab = std::unordered_map<std::wstring, int>;
......@@ -1337,7 +1436,8 @@ static PyObject* tensor_method_get_map_tensor(TensorObject* self,
PyObject* kwargs) {
EAGER_TRY
PADDLE_ENFORCE_EQ(
egr::IsVariableCompatTensor(self->tensor), true,
egr::IsVariableCompatTensor(self->tensor),
true,
paddle::platform::errors::Fatal(
"this method is only effective for VariableCompatTensor"));
using Vocab = std::unordered_map<std::wstring, int>;
......@@ -1417,7 +1517,8 @@ static PyObject* tensor_method_get_non_zero_cols(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_method_is_dense(TensorObject* self, PyObject* args,
static PyObject* tensor_method_is_dense(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
if (!self->tensor.defined()) {
......@@ -1427,7 +1528,8 @@ static PyObject* tensor_method_is_dense(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_method_is_sparse(TensorObject* self, PyObject* args,
static PyObject* tensor_method_is_sparse(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
if (!self->tensor.defined()) {
......@@ -1438,7 +1540,8 @@ static PyObject* tensor_method_is_sparse(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_method_is_sparse_coo(TensorObject* self, PyObject* args,
static PyObject* tensor_method_is_sparse_coo(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
if (!self->tensor.defined()) {
......@@ -1448,7 +1551,8 @@ static PyObject* tensor_method_is_sparse_coo(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_method_is_sparse_csr(TensorObject* self, PyObject* args,
static PyObject* tensor_method_is_sparse_csr(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
if (!self->tensor.defined()) {
......@@ -1458,7 +1562,8 @@ static PyObject* tensor_method_is_sparse_csr(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_method_to_sparse_csr(TensorObject* self, PyObject* args,
static PyObject* tensor_method_to_sparse_csr(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
auto csr_tensor = self->tensor.to_sparse_csr();
......@@ -1472,7 +1577,8 @@ static PyObject* tensor_method_to_sparse_csr(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__inplace_version(TensorObject* self, PyObject* args,
static PyObject* tensor__inplace_version(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
uint32_t inplace_version = self->tensor.current_inplace_version();
......@@ -1481,7 +1587,8 @@ static PyObject* tensor__inplace_version(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_method_element_size(TensorObject* self, PyObject* args,
static PyObject* tensor_method_element_size(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
uint32_t element_size = framework::DataTypeSize(self->tensor.dtype());
......@@ -1510,7 +1617,8 @@ static PyObject* tensor_method_is_selected_rows(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_method_get_rows(TensorObject* self, PyObject* args,
static PyObject* tensor_method_get_rows(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
PADDLE_ENFORCE(self->tensor.is_selected_rows(),
......@@ -1522,7 +1630,8 @@ static PyObject* tensor_method_get_rows(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_methon_element_size(TensorObject* self, PyObject* args,
static PyObject* tensor_methon_element_size(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
return ToPyObject(paddle::experimental::SizeOf(self->tensor.dtype()));
......@@ -1550,11 +1659,13 @@ static PyObject* tensor__reset_grad_inplace_version(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_method__share_memory(TensorObject* self, PyObject* args,
static PyObject* tensor_method__share_memory(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
#ifndef _WIN32
PADDLE_ENFORCE_EQ(platform::is_cpu_place(self->tensor.place()), true,
PADDLE_ENFORCE_EQ(platform::is_cpu_place(self->tensor.place()),
true,
platform::errors::InvalidArgument(
"Sharing memory only support CPU Tensor currently"));
// 1. get LoDTensor
......@@ -1571,8 +1682,11 @@ static PyObject* tensor_method__share_memory(TensorObject* self, PyObject* args,
const std::string& ipc_name = shared_writer_holder->ipc_name();
memory::allocation::MemoryMapFdSet::Instance().Insert(ipc_name);
// 4. copy data & reset holder
memory::Copy(platform::CPUPlace(), shared_writer_holder->ptr(),
platform::CPUPlace(), data_ptr, data_size);
memory::Copy(platform::CPUPlace(),
shared_writer_holder->ptr(),
platform::CPUPlace(),
data_ptr,
data_size);
t->ResetHolder(shared_writer_holder);
return ToPyObject(t);
#else
......@@ -1584,12 +1698,14 @@ static PyObject* tensor_method__share_memory(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__offset(TensorObject* self, PyObject* args,
static PyObject* tensor__offset(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
auto t = std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
PADDLE_ENFORCE_EQ(
t->IsInitialized(), true,
t->IsInitialized(),
true,
platform::errors::InvalidArgument("Tensor %s has not been initialized!",
self->tensor.name()));
......@@ -1597,12 +1713,14 @@ static PyObject* tensor__offset(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__grad_name(TensorObject* self, PyObject* args,
static PyObject* tensor__grad_name(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
paddle::experimental::Tensor* grad =
egr::EagerUtils::mutable_grad(self->tensor);
PADDLE_ENFORCE_EQ(grad != nullptr, true,
PADDLE_ENFORCE_EQ(grad != nullptr,
true,
platform::errors::InvalidArgument(
"Detected NULL grad. Please check if you have manually "
"cleared the grad inside autograd_meta"));
......@@ -1610,12 +1728,14 @@ static PyObject* tensor__grad_name(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__grad_value(TensorObject* self, PyObject* args,
static PyObject* tensor__grad_value(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
paddle::experimental::Tensor* grad =
egr::EagerUtils::mutable_grad(self->tensor);
PADDLE_ENFORCE_EQ(grad != nullptr, true,
PADDLE_ENFORCE_EQ(grad != nullptr,
true,
platform::errors::InvalidArgument(
"Detected NULL grad. Please check if you have manually "
"cleared the grad inside autograd_meta"));
......@@ -1635,12 +1755,14 @@ static PyObject* tensor__grad_value(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__unset_fake_empty(TensorObject* self, PyObject* args,
static PyObject* tensor__unset_fake_empty(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
paddle::experimental::Tensor* grad =
egr::EagerUtils::mutable_grad(self->tensor);
PADDLE_ENFORCE_EQ(grad != nullptr, true,
PADDLE_ENFORCE_EQ(grad != nullptr,
true,
platform::errors::InvalidArgument(
"Detected NULL grad. Please check if you have manually "
"cleared the grad inside autograd_meta"));
......@@ -1656,15 +1778,18 @@ static PyObject* tensor__unset_fake_empty(TensorObject* self, PyObject* args,
}
#if defined(PADDLE_WITH_CUDA)
static PyObject* tensor_method__uva(TensorObject* self, PyObject* args,
static PyObject* tensor_method__uva(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
VLOG(4) << "Running in tensor_method__uva.";
PADDLE_ENFORCE_EQ(self->tensor.is_dense_tensor(), true,
PADDLE_ENFORCE_EQ(self->tensor.is_dense_tensor(),
true,
platform::errors::InvalidArgument(
"Unified virtual addressing only support "
"DenseTensor currently."));
PADDLE_ENFORCE_EQ(platform::is_cpu_place(self->tensor.place()), true,
PADDLE_ENFORCE_EQ(platform::is_cpu_place(self->tensor.place()),
true,
platform::errors::InvalidArgument(
"Unified virtual addressing only support "
"CPU Tensor currently."));
......@@ -1692,130 +1817,211 @@ static PyObject* tensor_method__is_string_tensor_hold_allocation(
}
PyMethodDef variable_methods[] = {
{"numpy", (PyCFunction)(void (*)(void))tensor_method_numpy,
METH_VARARGS | METH_KEYWORDS, NULL},
{"numpy",
(PyCFunction)(void (*)(void))tensor_method_numpy,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_is_initialized",
(PyCFunction)(void (*)(void))tensor_method__is_initialized,
METH_VARARGS | METH_KEYWORDS, NULL},
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_is_dense_tensor_hold_allocation",
(PyCFunction)(void (*)(
void))tensor_method__is_dense_tensor_hold_allocation,
METH_VARARGS | METH_KEYWORDS, NULL},
{"_copy_to", (PyCFunction)(void (*)(void))tensor_method__copy_to,
METH_VARARGS | METH_KEYWORDS, NULL},
{"copy_", (PyCFunction)(void (*)(void))tensor_method_copy_,
METH_VARARGS | METH_KEYWORDS, NULL},
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_copy_to",
(PyCFunction)(void (*)(void))tensor_method__copy_to,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"copy_",
(PyCFunction)(void (*)(void))tensor_method_copy_,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"reconstruct_from_",
(PyCFunction)(void (*)(void))tensor_method_reconstruct_from_,
METH_VARARGS | METH_KEYWORDS, NULL},
{"retain_grads", (PyCFunction)(void (*)(void))tensor_retain_grads,
METH_VARARGS | METH_KEYWORDS, NULL},
{"clear_gradient", (PyCFunction)(void (*)(void))tensor_clear_gradient,
METH_VARARGS | METH_KEYWORDS, NULL},
{"is_dense", (PyCFunction)(void (*)(void))tensor_method_is_dense,
METH_VARARGS | METH_KEYWORDS, NULL},
{"_zero_grads", (PyCFunction)(void (*)(void))tensor__zero_grads,
METH_VARARGS | METH_KEYWORDS, NULL},
{"_share_buffer_to", (PyCFunction)(void (*)(void))tensor__share_buffer_to,
METH_VARARGS | METH_KEYWORDS, NULL},
METH_VARARGS | METH_KEYWORDS,
NULL},
{"retain_grads",
(PyCFunction)(void (*)(void))tensor_retain_grads,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"clear_gradient",
(PyCFunction)(void (*)(void))tensor_clear_gradient,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"is_dense",
(PyCFunction)(void (*)(void))tensor_method_is_dense,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_zero_grads",
(PyCFunction)(void (*)(void))tensor__zero_grads,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_share_buffer_to",
(PyCFunction)(void (*)(void))tensor__share_buffer_to,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_is_shared_buffer_with",
(PyCFunction)(void (*)(void))tensor__is_shared_buffer_with,
METH_VARARGS | METH_KEYWORDS, NULL},
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_share_underline_tensor_to",
(PyCFunction)(void (*)(void))tensor__share_underline_tensor_to,
METH_VARARGS | METH_KEYWORDS, NULL},
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_is_shared_underline_tensor_with",
(PyCFunction)(void (*)(void))tensor__is_shared_underline_tensor_with,
METH_VARARGS | METH_KEYWORDS, NULL},
{"detach", (PyCFunction)(void (*)(void))tensor_method_detach,
METH_VARARGS | METH_KEYWORDS, NULL},
METH_VARARGS | METH_KEYWORDS,
NULL},
{"detach",
(PyCFunction)(void (*)(void))tensor_method_detach,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"get_tensor",
(PyCFunction)(void (*)(void))tensor_method_get_underline_tensor,
METH_VARARGS | METH_KEYWORDS, NULL},
METH_VARARGS | METH_KEYWORDS,
NULL},
{"get_selected_rows",
(PyCFunction)(void (*)(void))tensor_method_get_underline_selected_rows,
METH_VARARGS | METH_KEYWORDS, NULL},
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_getitem_index_not_tensor",
(PyCFunction)(void (*)(void))tensor__getitem_index_not_tensor,
METH_VARARGS | METH_KEYWORDS, NULL},
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_getitem_from_offset",
(PyCFunction)(void (*)(void))tensor__getitem_from_offset,
METH_VARARGS | METH_KEYWORDS, NULL},
METH_VARARGS | METH_KEYWORDS,
NULL},
{"__setitem_eager_tensor__",
(PyCFunction)(void (*)(void))tensor_method__setitem_eager_tensor,
METH_VARARGS | METH_KEYWORDS, NULL},
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_register_grad_hook",
(PyCFunction)(void (*)(void))tensor_register_grad_hook,
METH_VARARGS | METH_KEYWORDS, NULL},
{"_remove_grad_hook", (PyCFunction)(void (*)(void))tensor_remove_grad_hook,
METH_VARARGS | METH_KEYWORDS, NULL},
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_remove_grad_hook",
(PyCFunction)(void (*)(void))tensor_remove_grad_hook,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_register_backward_hook",
(PyCFunction)(void (*)(void))tensor_register_reduce_hook,
METH_VARARGS | METH_KEYWORDS, NULL},
{"_set_grad_type", (PyCFunction)(void (*)(void))tensor__set_grad_type,
METH_VARARGS | METH_KEYWORDS, NULL},
{"_clear", (PyCFunction)(void (*)(void))tensor__clear,
METH_VARARGS | METH_KEYWORDS, NULL},
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_set_grad_type",
(PyCFunction)(void (*)(void))tensor__set_grad_type,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_clear",
(PyCFunction)(void (*)(void))tensor__clear,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_copy_gradient_from",
(PyCFunction)(void (*)(void))tensor__copy_gradient_from,
METH_VARARGS | METH_KEYWORDS, NULL},
METH_VARARGS | METH_KEYWORDS,
NULL},
/** the methods to adapt old dygraph, will be removed in the future **/
{"set_string_list",
(PyCFunction)(void (*)(void))tensor_method_set_string_list,
METH_VARARGS | METH_KEYWORDS, NULL},
{"set_vocab", (PyCFunction)(void (*)(void))tensor_method_set_vocab,
METH_VARARGS | METH_KEYWORDS, NULL},
METH_VARARGS | METH_KEYWORDS,
NULL},
{"set_vocab",
(PyCFunction)(void (*)(void))tensor_method_set_vocab,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"get_map_tensor",
(PyCFunction)(void (*)(void))tensor_method_get_map_tensor,
METH_VARARGS | METH_KEYWORDS, NULL},
METH_VARARGS | METH_KEYWORDS,
NULL},
/***the method of sparse tensor****/
{"indices", (PyCFunction)(void (*)(void))tensor_method_get_non_zero_indices,
METH_VARARGS | METH_KEYWORDS, NULL},
{"values", (PyCFunction)(void (*)(void))tensor_method_get_non_zero_elements,
METH_VARARGS | METH_KEYWORDS, NULL},
{"crows", (PyCFunction)(void (*)(void))tensor_method_get_non_zero_crows,
METH_VARARGS | METH_KEYWORDS, NULL},
{"cols", (PyCFunction)(void (*)(void))tensor_method_get_non_zero_cols,
METH_VARARGS | METH_KEYWORDS, NULL},
{"is_sparse", (PyCFunction)(void (*)(void))tensor_method_is_sparse,
METH_VARARGS | METH_KEYWORDS, NULL},
{"is_sparse_coo", (PyCFunction)(void (*)(void))tensor_method_is_sparse_coo,
METH_VARARGS | METH_KEYWORDS, NULL},
{"is_sparse_csr", (PyCFunction)(void (*)(void))tensor_method_is_sparse_csr,
METH_VARARGS | METH_KEYWORDS, NULL},
{"to_sparse_csr", (PyCFunction)(void (*)(void))tensor_method_to_sparse_csr,
METH_VARARGS | METH_KEYWORDS, NULL},
{"element_size", (PyCFunction)(void (*)(void))tensor_method_element_size,
METH_VARARGS | METH_KEYWORDS, NULL},
{"indices",
(PyCFunction)(void (*)(void))tensor_method_get_non_zero_indices,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"values",
(PyCFunction)(void (*)(void))tensor_method_get_non_zero_elements,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"crows",
(PyCFunction)(void (*)(void))tensor_method_get_non_zero_crows,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"cols",
(PyCFunction)(void (*)(void))tensor_method_get_non_zero_cols,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"is_sparse",
(PyCFunction)(void (*)(void))tensor_method_is_sparse,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"is_sparse_coo",
(PyCFunction)(void (*)(void))tensor_method_is_sparse_coo,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"is_sparse_csr",
(PyCFunction)(void (*)(void))tensor_method_is_sparse_csr,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"to_sparse_csr",
(PyCFunction)(void (*)(void))tensor_method_to_sparse_csr,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"element_size",
(PyCFunction)(void (*)(void))tensor_method_element_size,
METH_VARARGS | METH_KEYWORDS,
NULL},
/***the method of sparse tensor****/
{"_inplace_version", (PyCFunction)(void (*)(void))tensor__inplace_version,
METH_VARARGS | METH_KEYWORDS, NULL},
{"_inplace_version",
(PyCFunction)(void (*)(void))tensor__inplace_version,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_bump_inplace_version",
(PyCFunction)(void (*)(void))tensor__bump_inplace_version,
METH_VARARGS | METH_KEYWORDS, NULL},
METH_VARARGS | METH_KEYWORDS,
NULL},
{"is_selected_rows",
(PyCFunction)(void (*)(void))tensor_method_is_selected_rows,
METH_VARARGS | METH_KEYWORDS, NULL},
{"rows", (PyCFunction)(void (*)(void))tensor_method_get_rows,
METH_VARARGS | METH_KEYWORDS, NULL},
{"element_size", (PyCFunction)(void (*)(void))tensor_methon_element_size,
METH_VARARGS | METH_KEYWORDS, NULL},
METH_VARARGS | METH_KEYWORDS,
NULL},
{"rows",
(PyCFunction)(void (*)(void))tensor_method_get_rows,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"element_size",
(PyCFunction)(void (*)(void))tensor_methon_element_size,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_reset_grad_inplace_version",
(PyCFunction)(void (*)(void))tensor__reset_grad_inplace_version,
METH_VARARGS | METH_KEYWORDS, NULL},
{"_share_memory", (PyCFunction)(void (*)(void))tensor_method__share_memory,
METH_VARARGS | METH_KEYWORDS, NULL},
{"_offset", (PyCFunction)(void (*)(void))tensor__offset,
METH_VARARGS | METH_KEYWORDS, NULL},
{"_grad_name", (PyCFunction)(void (*)(void))tensor__grad_name,
METH_VARARGS | METH_KEYWORDS, NULL},
{"_grad_value", (PyCFunction)(void (*)(void))tensor__grad_value,
METH_VARARGS | METH_KEYWORDS, NULL},
{"_unset_fake_empty", (PyCFunction)(void (*)(void))tensor__unset_fake_empty,
METH_VARARGS | METH_KEYWORDS, NULL},
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_share_memory",
(PyCFunction)(void (*)(void))tensor_method__share_memory,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_offset",
(PyCFunction)(void (*)(void))tensor__offset,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_grad_name",
(PyCFunction)(void (*)(void))tensor__grad_name,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_grad_value",
(PyCFunction)(void (*)(void))tensor__grad_value,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_unset_fake_empty",
(PyCFunction)(void (*)(void))tensor__unset_fake_empty,
METH_VARARGS | METH_KEYWORDS,
NULL},
#if defined(PADDLE_WITH_CUDA)
{"_tensor_uva", (PyCFunction)(void (*)(void))tensor_method__uva,
METH_VARARGS | METH_KEYWORDS, NULL},
{"_tensor_uva",
(PyCFunction)(void (*)(void))tensor_method__uva,
METH_VARARGS | METH_KEYWORDS,
NULL},
#endif
{NULL, NULL, 0, NULL}};
......@@ -1823,14 +2029,17 @@ PyMethodDef variable_methods[] = {
PyMethodDef string_tensor_variable_methods[] = {
{"numpy",
(PyCFunction)(void (*)(void))tensor_method_numpy_for_string_tensor,
METH_VARARGS | METH_KEYWORDS, NULL},
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_is_initialized",
(PyCFunction)(void (*)(void))tensor_method__is_initialized,
METH_VARARGS | METH_KEYWORDS, NULL},
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_is_string_tensor_hold_allocation",
(PyCFunction)(void (*)(
void))tensor_method__is_string_tensor_hold_allocation,
METH_VARARGS | METH_KEYWORDS, NULL},
METH_VARARGS | METH_KEYWORDS,
NULL},
// TODO(zhoushunjie): Need to add _copy_to, copy_ for StringTensor.
{NULL, NULL, 0, NULL}};
......
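Inside tensor_method_numpy, the diff above also replaces the hard-coded cudaMemcpyDeviceToHost with a gpuMemcpyKind chosen per platform, so the same GpuMemcpySync call site serves both CUDA and HIP/ROCm (DCU) builds. The sketch below isolates that selection pattern; the alias gpu_memcpy_kind and the wrapper GpuMemcpySyncSketch are assumptions for illustration, while cudaMemcpy and hipMemcpy are the real runtime calls.

#include <cstddef>

#if defined(PADDLE_WITH_CUDA)
#include <cuda_runtime.h>
using gpu_memcpy_kind = cudaMemcpyKind;
constexpr gpu_memcpy_kind kDeviceToHost = cudaMemcpyDeviceToHost;
inline void GpuMemcpySyncSketch(void* dst, const void* src, std::size_t n,
                                gpu_memcpy_kind kind) {
  cudaMemcpy(dst, src, n, kind);  // synchronous copy via the CUDA runtime
}
#elif defined(PADDLE_WITH_HIP)
#include <hip/hip_runtime.h>
using gpu_memcpy_kind = hipMemcpyKind;
constexpr gpu_memcpy_kind kDeviceToHost = hipMemcpyDeviceToHost;
inline void GpuMemcpySyncSketch(void* dst, const void* src, std::size_t n,
                                gpu_memcpy_kind kind) {
  hipMemcpy(dst, src, n, kind);  // synchronous copy via the HIP runtime
}
#endif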