未验证 提交 b106c424 编写于 作者: W wanghuancoder 提交者: GitHub

[Eager] refine gil use (#46452)

* refine gil use
上级 a02eb143
......@@ -27,6 +27,11 @@
#include "pybind11/pytypes.h"
namespace egr {
GradNodePyLayer::~GradNodePyLayer() {
pybind11::gil_scoped_acquire gil;
Py_XDECREF(ctx_);
}
paddle::small_vector<std::vector<paddle::experimental::Tensor>,
kSlotSmallVectorSize>
GradNodePyLayer::operator()(
......
......@@ -34,7 +34,7 @@ class GradNodePyLayer : public GradNodeBase {
Py_INCREF(ctx_);
}
~GradNodePyLayer() override { Py_XDECREF(ctx_); };
~GradNodePyLayer() override;
virtual paddle::small_vector<std::vector<paddle::experimental::Tensor>,
kSlotSmallVectorSize>
......
......@@ -107,12 +107,18 @@ static PyObject* eager_api_scale(PyObject* self,
PyObject* kwargs) {
EAGER_TRY
// TODO(jiabin): Sync Tensor and Variable here when we support
paddle::experimental::Tensor ret = egr::scale(
reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0))->tensor,
CastPyArg2AttrFloat(PyTuple_GET_ITEM(args, 1), 1),
CastPyArg2AttrFloat(PyTuple_GET_ITEM(args, 2), 2),
CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 3), 3),
CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 4), 4));
auto& tensor =
reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0))->tensor;
float scale = CastPyArg2AttrFloat(PyTuple_GET_ITEM(args, 1), 1);
float bias = CastPyArg2AttrFloat(PyTuple_GET_ITEM(args, 2), 2);
bool bias_after_scale = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 3), 3);
bool trace_backward = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 4), 4);
paddle::experimental::Tensor ret;
{
eager_gil_scoped_release guard;
ret = egr::scale(tensor, scale, bias, bias_after_scale, trace_backward);
}
return ToPyObject(ret);
EAGER_CATCH_AND_THROW_RETURN_NULL
}
......@@ -123,11 +129,10 @@ static PyObject* eager_api_run_backward(PyObject* self,
EAGER_TRY
auto tensors = CastPyArg2VectorOfTensor(PyTuple_GET_ITEM(args, 0), 0);
auto grad_tensors = CastPyArg2VectorOfTensor(PyTuple_GET_ITEM(args, 1), 1);
bool retain_graph = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 2), 2);
{
eager_gil_scoped_release guard;
egr::Backward(tensors,
grad_tensors,
CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 2), 2));
egr::Backward(tensors, grad_tensors, retain_graph);
}
RETURN_PY_NONE
EAGER_CATCH_AND_THROW_RETURN_NULL
......@@ -156,8 +161,8 @@ static PyObject* eager_api_run_partial_grad(PyObject* self,
only_inputs,
allow_unused,
no_grad_vars);
}
VLOG(1) << " in eager_api_run_partial_grad, after runing egr::Grad";
}
return ToPyObject(result, true /* return_py_none_if_not_initialize */);
EAGER_CATCH_AND_THROW_RETURN_NULL
}
......@@ -173,11 +178,14 @@ static PyObject* eager_api_tensor_copy(PyObject* self,
auto place = CastPyArg2Place(PyTuple_GET_ITEM(args, 2), 2);
bool blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 3), 3);
{
eager_gil_scoped_release guard;
dst = src.copy_to(place, blocking);
egr::EagerUtils::autograd_meta(&dst)->SetStopGradient(
egr::EagerUtils::autograd_meta(&(src))->StopGradient());
egr::EagerUtils::autograd_meta(&dst)->SetPersistable(
egr::EagerUtils::autograd_meta(&(src))->Persistable());
}
RETURN_PY_NONE
EAGER_CATCH_AND_THROW_RETURN_NULL
}
......@@ -378,7 +386,11 @@ static PyObject* eager_api_jit_function_call(PyObject* self,
CastPyArg2JitFunction(PyTuple_GET_ITEM(args, 0), 0);
std::vector<paddle::experimental::Tensor> ins =
CastPyArg2VectorOfTensor(PyTuple_GET_ITEM(args, 1), 1);
std::vector<paddle::experimental::Tensor> outs = (*function)(ins);
std::vector<paddle::experimental::Tensor> outs;
{
eager_gil_scoped_release guard;
outs = (*function)(ins);
}
return ToPyObject(outs);
EAGER_CATCH_AND_THROW_RETURN_NULL
}
......@@ -391,10 +403,13 @@ static PyObject* eager_api_run_costum_op(PyObject* self,
CastPyArg2CustomOpKernelContext(PyTuple_GET_ITEM(args, 0), 0);
std::string op_type = CastPyArg2AttrString(PyTuple_GET_ITEM(args, 1), 1);
bool trace_backward = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 2), 2);
{
eager_gil_scoped_release guard;
VLOG(7) << "Get things for python for Custom Op: " << op_type
<< ", trace_backward is: " << trace_backward;
auto meta_info_map = egr::Controller::Instance().GetOpMetaInfoMap();
PADDLE_ENFORCE_NE(meta_info_map.find(op_type),
PADDLE_ENFORCE_NE(
meta_info_map.find(op_type),
meta_info_map.end(),
paddle::platform::errors::NotFound(
"Can't find %s in Eager OpMetaInfoMap which should be "
......@@ -454,8 +469,8 @@ static PyObject* eager_api_run_costum_op(PyObject* self,
if (slot_map[0][0].find(i) != slot_map[0][0].end()) {
grad_node->SetGradOutMeta(in_tensors, slot_map[0][0][i]);
} else {
grad_node->SetGradOutMeta(in_tensors,
ins_auto_grad_metas.size() - 1 - no_grad_cnt);
grad_node->SetGradOutMeta(
in_tensors, ins_auto_grad_metas.size() - 1 - no_grad_cnt);
no_grad_cnt++;
}
}
......@@ -502,6 +517,7 @@ static PyObject* eager_api_run_costum_op(PyObject* self,
}
grad_node->SetAttrs(attrs);
}
}
RETURN_PY_NONE
EAGER_CATCH_AND_THROW_RETURN_NULL
}
......@@ -514,6 +530,9 @@ static PyObject* eager_api_sparse_coo_tensor(PyObject* self,
auto non_zero_elements = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 1), 1);
auto dense_shape = CastPyArg2VectorOfInt(PyTuple_GET_ITEM(args, 2), 2);
auto stop_gradient = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 3), 3);
paddle::experimental::Tensor tensor;
{
eager_gil_scoped_release guard;
PADDLE_ENFORCE(non_zero_indices.is_dense_tensor(),
paddle::platform::errors::Fatal(
"the non-zero indices must be a DenseTensor."));
......@@ -524,12 +543,11 @@ static PyObject* eager_api_sparse_coo_tensor(PyObject* self,
std::dynamic_pointer_cast<phi::DenseTensor>(non_zero_indices.impl());
auto dense_elements =
std::dynamic_pointer_cast<phi::DenseTensor>(non_zero_elements.impl());
// TODO(zhangkaihuo): After creating SparseCooTensor, call coalesced() to sort
// and merge duplicate indices
// TODO(zhangkaihuo): After creating SparseCooTensor, call coalesced() to
// sort and merge duplicate indices
std::shared_ptr<phi::SparseCooTensor> coo_tensor =
std::make_shared<phi::SparseCooTensor>(
*dense_indices, *dense_elements, phi::make_ddim(dense_shape));
paddle::experimental::Tensor tensor;
tensor.set_impl(coo_tensor);
auto name =
egr::Controller::Instance().GenerateUniqueName("generated_tensor");
......@@ -542,6 +560,7 @@ static PyObject* eager_api_sparse_coo_tensor(PyObject* self,
autograd_meta->SetGradNode(
std::make_shared<egr::GradNodeAccumulation>(autograd_meta));
}
}
return ToPyObject(tensor);
EAGER_CATCH_AND_THROW_RETURN_NULL
}
......@@ -555,6 +574,9 @@ static PyObject* eager_api_sparse_csr_tensor(PyObject* self,
auto non_zero_elements = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 2), 2);
auto dense_shape = CastPyArg2VectorOfInt(PyTuple_GET_ITEM(args, 3), 3);
auto stop_gradient = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 4), 4);
paddle::experimental::Tensor tensor;
{
eager_gil_scoped_release guard;
PADDLE_ENFORCE(non_zero_crows.is_dense_tensor(),
paddle::platform::errors::Fatal(
"the compressed non-zero rows must be a DenseTensor."));
......@@ -576,7 +598,6 @@ static PyObject* eager_api_sparse_csr_tensor(PyObject* self,
*dense_cols,
*dense_elements,
phi::make_ddim(dense_shape));
paddle::experimental::Tensor tensor;
tensor.set_impl(csr_tensor);
auto name =
egr::Controller::Instance().GenerateUniqueName("generated_tensor");
......@@ -589,6 +610,7 @@ static PyObject* eager_api_sparse_csr_tensor(PyObject* self,
autograd_meta->SetGradNode(
std::make_shared<egr::GradNodeAccumulation>(autograd_meta));
}
}
return ToPyObject(tensor);
EAGER_CATCH_AND_THROW_RETURN_NULL
}
......@@ -626,6 +648,8 @@ static PyObject* eager_api_async_read(PyObject* self,
auto& buffer = GetTensorFromArgs("async_read", "buffer", args, 3, false);
auto& offset = GetTensorFromArgs("async_read", "offset", args, 4, false);
auto& count = GetTensorFromArgs("async_read", "count", args, 5, false);
{
eager_gil_scoped_release guard;
PADDLE_ENFORCE_EQ(
src.is_gpu_pinned(),
true,
......@@ -683,7 +707,8 @@ static PyObject* eager_api_async_read(PyObject* self,
"`src` and `buffer` should have same tensor shape, "
"except for the first dimension."));
for (int i = 1; i < src_tensor.dims().size(); i++) {
PADDLE_ENFORCE_EQ(src_tensor.dims()[i],
PADDLE_ENFORCE_EQ(
src_tensor.dims()[i],
dst_tensor->dims()[i],
platform::errors::InvalidArgument(
"`src` and `dst` should have the same tensor shape, "
......@@ -724,27 +749,27 @@ static PyObject* eager_api_async_read(PyObject* self,
for (int64_t i = 0; i < count_tensor.numel(); i++) {
numel += count_data[i];
}
PADDLE_ENFORCE_LE(
numel + index_tensor.numel(),
PADDLE_ENFORCE_LE(numel + index_tensor.numel(),
buffer_tensor->dims()[0],
platform::errors::InvalidArgument("Buffer tensor size is too small."));
PADDLE_ENFORCE_LE(
numel + index_tensor.numel(),
platform::errors::InvalidArgument(
"Buffer tensor size is too small."));
PADDLE_ENFORCE_LE(numel + index_tensor.numel(),
dst_tensor->dims()[0],
platform::errors::InvalidArgument("Target tensor size is too small."));
platform::errors::InvalidArgument(
"Target tensor size is too small."));
int64_t src_offset, dst_offset = 0, c;
auto* src_data = src_tensor.data<float>();
for (int64_t i = 0; i < offset_tensor.numel(); i++) {
src_offset = offset_data[i], c = count_data[i];
PADDLE_ENFORCE_LE(
src_offset + c,
PADDLE_ENFORCE_LE(src_offset + c,
src_tensor.dims()[0],
platform::errors::InvalidArgument("Invalid offset or count index."));
PADDLE_ENFORCE_LE(
dst_offset + c,
platform::errors::InvalidArgument(
"Invalid offset or count index."));
PADDLE_ENFORCE_LE(dst_offset + c,
dst_tensor->dims()[0],
platform::errors::InvalidArgument("Invalid offset or count index."));
platform::errors::InvalidArgument(
"Invalid offset or count index."));
cudaMemcpyAsync(dst_data + (dst_offset * size),
src_data + (src_offset * size),
c * size * sizeof(float),
......@@ -753,10 +778,10 @@ static PyObject* eager_api_async_read(PyObject* self,
dst_offset += c;
}
} else {
PADDLE_ENFORCE_LE(
index_tensor.numel(),
PADDLE_ENFORCE_LE(index_tensor.numel(),
buffer_tensor->dims()[0],
platform::errors::InvalidArgument("Buffer tensor size is too small."));
platform::errors::InvalidArgument(
"Buffer tensor size is too small."));
}
// Select the index data to the buffer
......@@ -784,6 +809,7 @@ static PyObject* eager_api_async_read(PyObject* self,
index_tensor.numel() * size * sizeof(float),
cudaMemcpyHostToDevice,
stream);
}
RETURN_PY_NONE
EAGER_CATCH_AND_THROW_RETURN_NULL
}
......@@ -796,6 +822,8 @@ static PyObject* eager_api_async_write(PyObject* self,
auto& dst = GetTensorFromArgs("async_write", "dst", args, 1, false);
auto& offset = GetTensorFromArgs("async_write", "offset", args, 2, false);
auto& count = GetTensorFromArgs("async_write", "count", args, 3, false);
{
eager_gil_scoped_release guard;
PADDLE_ENFORCE_EQ(
src.is_gpu(),
true,
......@@ -847,7 +875,8 @@ static PyObject* eager_api_async_write(PyObject* self,
"`src` and `dst` should have the same tensor shape, "
"except for the first dimension."));
for (int i = 1; i < src_tensor.dims().size(); i++) {
PADDLE_ENFORCE_EQ(src_tensor.dims()[i],
PADDLE_ENFORCE_EQ(
src_tensor.dims()[i],
dst_tensor->dims()[i],
platform::errors::InvalidArgument(
"`src` and `dst` should have the same tensor shape, "
......@@ -879,6 +908,7 @@ static PyObject* eager_api_async_write(PyObject* self,
stream);
src_offset += c;
}
}
RETURN_PY_NONE
EAGER_CATCH_AND_THROW_RETURN_NULL
}
......@@ -929,7 +959,6 @@ static PyObject* eager_api_to_uva_tensor(PyObject* self,
"float64, int8, int16, int32, int64,"
"please check your input or input array data type."));
}
return ToPyObject(*(new_tensor.get()));
EAGER_CATCH_AND_THROW_RETURN_NULL
}
......
......@@ -156,6 +156,7 @@ static PyObject* tensor_method_numpy(TensorObject* self,
}
if (self->tensor.is_cpu() || self->tensor.is_gpu_pinned()) {
eager_gil_scoped_release guard;
platform::CPUPlace place;
if (self->tensor.is_selected_rows()) {
VLOG(6) << "Getting SelectedRows's numpy value";
......@@ -186,6 +187,7 @@ static PyObject* tensor_method_numpy(TensorObject* self,
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
} else if (self->tensor.is_gpu()) {
eager_gil_scoped_release guard;
#if defined(PADDLE_WITH_CUDA)
gpuMemcpyKind kind = cudaMemcpyDeviceToHost;
#elif defined(PADDLE_WITH_HIP)
......@@ -244,6 +246,7 @@ static PyObject* tensor_method_numpy(TensorObject* self,
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
} else if (self->tensor.is_custom_device()) {
eager_gil_scoped_release guard;
if (self->tensor.is_selected_rows()) {
VLOG(6) << "Getting SelectedRows's numpy value";
auto* selected_rows =
......@@ -311,8 +314,8 @@ static PyObject* tensor_method_numpy_for_string_tensor(TensorObject* self,
const auto* st_ptr = string_tensor->data();
auto numel = self->tensor.numel();
auto tensor_dims = self->tensor.shape();
// Get the max unicode length of StringTensor to create numpy unicode string
// array.
// Get the max unicode length of StringTensor to create numpy unicode
// string array.
auto* longest_pstring = std::max_element(
st_ptr, st_ptr + numel, [](const auto& a, const auto& b) {
auto a_unicode_len =
......@@ -394,7 +397,10 @@ static PyObject* tensor_method__copy_to(TensorObject* self,
EAGER_TRY
auto place = CastPyArg2Place(PyTuple_GET_ITEM(args, 0), 0);
bool blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 1), 1);
auto cp_tensor = self->tensor.copy_to(place, blocking);
paddle::experimental::Tensor cp_tensor;
{
eager_gil_scoped_release guard;
cp_tensor = self->tensor.copy_to(place, blocking);
if (!blocking) {
IncreaseTensorReferenceCountUntilCopyComplete(self->tensor, place);
}
......@@ -402,6 +408,7 @@ static PyObject* tensor_method__copy_to(TensorObject* self,
egr::EagerUtils::autograd_meta(&cp_tensor)
->SetPersistable(
egr::EagerUtils::autograd_meta(&(self->tensor))->Persistable());
}
return ToPyObject(cp_tensor);
EAGER_CATCH_AND_THROW_RETURN_NULL
}
......@@ -410,11 +417,15 @@ static PyObject* tensor_method_cpu(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
auto cp_tensor = self->tensor.copy_to(phi::CPUPlace(), true);
paddle::experimental::Tensor cp_tensor;
{
eager_gil_scoped_release guard;
cp_tensor = self->tensor.copy_to(phi::CPUPlace(), true);
egr::EagerUtils::autograd_meta(&cp_tensor)->SetStopGradient(true);
egr::EagerUtils::autograd_meta(&cp_tensor)
->SetPersistable(
egr::EagerUtils::autograd_meta(&(self->tensor))->Persistable());
}
return ToPyObject(cp_tensor);
EAGER_CATCH_AND_THROW_RETURN_NULL
}
......@@ -450,6 +461,7 @@ static PyObject* tensor_method_copy_(TensorObject* self,
VLOG(6) << "Start Copy Tensor " << src_tensor.name() << " to "
<< self->tensor.name();
if (!self->tensor.initialized()) {
eager_gil_scoped_release guard;
egr::EagerUtils::autograd_meta(&(self->tensor))
->SetStopGradient(
egr::EagerUtils::autograd_meta(&(src_tensor))->StopGradient());
......@@ -461,6 +473,7 @@ static PyObject* tensor_method_copy_(TensorObject* self,
}
} else {
if (src_tensor.initialized()) {
eager_gil_scoped_release guard;
self->tensor.copy_(src_tensor, self->tensor.place(), blocking);
}
}
......@@ -476,7 +489,9 @@ static PyObject* tensor_method_clone(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
paddle::experimental::Tensor out;
{
eager_gil_scoped_release guard;
PADDLE_ENFORCE_EQ(
self->tensor.initialized(),
true,
......@@ -485,7 +500,8 @@ static PyObject* tensor_method_clone(TensorObject* self,
"uninitialized tensor %s, please check your code.",
self->tensor.name()));
auto out = assign_ad_func(self->tensor);
out = assign_ad_func(self->tensor);
}
return ToPyObject(out);
EAGER_CATCH_AND_THROW_RETURN_NULL
}
......@@ -495,6 +511,7 @@ static PyObject* tensor_retain_grads(TensorObject* self,
PyObject* kwargs) {
EAGER_TRY
if (egr::Controller::Instance().HasGrad()) {
eager_gil_scoped_release guard;
auto meta = egr::EagerUtils::autograd_meta(&(self->tensor));
if (!meta->GetMutableGradNode()) {
VLOG(6) << "Make grad node of tensor: " << self->tensor.name()
......@@ -535,6 +552,7 @@ static PyObject* tensor_clear_gradient(TensorObject* self,
}
if (grad->impl()) {
eager_gil_scoped_release guard;
if (grad->is_selected_rows()) {
auto selected_rows =
std::dynamic_pointer_cast<phi::SelectedRows>(grad->impl());
......@@ -577,6 +595,7 @@ static PyObject* tensor__zero_grads(TensorObject* self,
VLOG(4) << "ZeroGrads " << self->tensor.name();
if (egr::egr_utils_api::IsLeafTensor(self->tensor)) {
eager_gil_scoped_release guard;
// Add RetainGrad as PostHook to AccumulationNode
paddle::experimental::Tensor* grad =
egr::EagerUtils::mutable_grad(self->tensor);
......@@ -595,6 +614,7 @@ static PyObject* tensor__zero_grads(TensorObject* self,
}
}
} else {
eager_gil_scoped_release guard;
auto meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor);
if (meta->MutableGrad()->initialized()) {
if (meta->MutableGrad()->is_dense_tensor()) {
......@@ -855,6 +875,7 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
decrease_axis.end());
if (op_type == "slice") {
eager_gil_scoped_release guard;
out = slice_ad_func(self->tensor,
slice_axes_tmp,
slice_starts,
......@@ -862,6 +883,7 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
infer_flags_tmp,
decrease_axis_tmp);
} else if (op_type == "strided_slice") {
eager_gil_scoped_release guard;
out = strided_slice_ad_func(
self->tensor, slice_axes, slice_starts, slice_ends, slice_strides);
} else {
......@@ -886,6 +908,9 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
none_axes.pop_back();
}
if (!none_axes.empty()) {
paddle::experimental::Tensor new_out;
{
eager_gil_scoped_release guard;
// Deal with cases that decrease_axes is not empty
// For example:
// # x.shape: (2,3,4)
......@@ -899,15 +924,15 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
}
axis -= len;
}
paddle::experimental::Tensor new_out;
new_out = unsqueeze_ad_func(out, none_axes);
}
return ToPyObject(new_out);
}
}
// the index is a list
if (list_select_flag) {
eager_gil_scoped_release guard;
auto select_index = paddle::experimental::Tensor(
egr::Controller::Instance().GenerateUniqueName());
auto idx_tensor = std::make_shared<phi::DenseTensor>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册