Unverified commit db6242e9, authored by Leo Chen, committed by GitHub

[NPU] release gil before op run (#35370)

* release gil before op run

* support npu grad test

* fix op_test
Parent commit: 3dab2e20
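
The pybind changes below wrap every OperatorBase run binding in a pybind11::gil_scoped_release guard, so the interpreter's GIL is dropped for the duration of the C++ operator execution and other Python threads are no longer blocked. A minimal sketch of the same pattern outside of Paddle, assuming a hypothetical module gil_demo and function slow_compute (only pybind11::gil_scoped_release itself is the real API used in the diff):

```cpp
// Sketch only: gil_demo and slow_compute are made-up names; the point is the
// gil_scoped_release guard, which is what the diff adds around self.Run().
#include <pybind11/pybind11.h>

#include <chrono>
#include <thread>

namespace py = pybind11;

// Long-running C++ work that never touches Python objects.
double slow_compute(double x) {
  std::this_thread::sleep_for(std::chrono::milliseconds(100));
  return x * 2.0;
}

PYBIND11_MODULE(gil_demo, m) {
  // Holds the GIL for the whole call: other Python threads stall.
  m.def("run_holding_gil", [](double x) { return slow_compute(x); });

  // Releases the GIL while the C++ work runs; the RAII guard reacquires it
  // automatically when `release` goes out of scope, before the result is
  // converted back to a Python object.
  m.def("run_releasing_gil", [](double x) {
    py::gil_scoped_release release;
    return slow_compute(x);
  });
}
```

The guard is only safe when the released section does not call back into the Python C API, which is the assumption this commit makes for OperatorBase::Run.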
@@ -1849,19 +1849,32 @@ All parameter, weight, gradient are variables in Paddle.
       })
       .def("run",
            [](OperatorBase &self, const Scope &scope,
-              const platform::CPUPlace &place) { self.Run(scope, place); })
+              const platform::CPUPlace &place) {
+             pybind11::gil_scoped_release release;
+             self.Run(scope, place);
+           })
       .def("run",
            [](OperatorBase &self, const Scope &scope,
-              const platform::XPUPlace &place) { self.Run(scope, place); })
+              const platform::XPUPlace &place) {
+             pybind11::gil_scoped_release release;
+             self.Run(scope, place);
+           })
       .def("run",
            [](OperatorBase &self, const Scope &scope,
-              const platform::NPUPlace &place) { self.Run(scope, place); })
+              const platform::NPUPlace &place) {
+             pybind11::gil_scoped_release release;
+             self.Run(scope, place);
+           })
       .def("run",
            [](OperatorBase &self, const Scope &scope,
-              const platform::CUDAPlace &place) { self.Run(scope, place); })
+              const platform::CUDAPlace &place) {
+             pybind11::gil_scoped_release release;
+             self.Run(scope, place);
+           })
       .def("run",
            [](OperatorBase &self, const Scope &scope,
               const platform::CUDAPinnedPlace &place) {
+             pybind11::gil_scoped_release release;
              self.Run(scope, place);
            })
       .def("type",
......
@@ -216,6 +216,7 @@ T TensorGetElement(const framework::Tensor &self, size_t offset) {
   PADDLE_ENFORCE_LT(offset, self.numel(),
                     platform::errors::InvalidArgument(
                         "The offset exceeds the size of tensor."));
   T b = static_cast<T>(0);
   if (platform::is_cpu_place(self.place())) {
     b = self.data<T>()[offset];
@@ -231,8 +232,17 @@ T TensorGetElement(const framework::Tensor &self, size_t offset) {
     auto p = BOOST_GET_CONST(platform::CUDAPlace, self.place());
     paddle::memory::Copy(platform::CPUPlace(), &b, p, a + offset, sizeof(T),
                          nullptr);
 #endif
+  } else if (platform::is_npu_place(self.place())) {
+#if defined(PADDLE_WITH_ASCEND_CL)
+    const T *a = self.data<T>();
+    auto p = BOOST_GET_CONST(platform::NPUPlace, self.place());
+    paddle::memory::Copy(platform::CPUPlace(), &b, p, a + offset, sizeof(T),
+                         nullptr);
+#endif
   }
+  VLOG(10) << "TensorGetElement, place: " << self.place()
+           << ", offset: " << offset << ", element: " << b;
   return b;
 }
@@ -241,6 +251,8 @@ void TensorSetElement(framework::Tensor *self, size_t offset, T elem) {
   PADDLE_ENFORCE_LT(offset, self->numel(),
                     platform::errors::InvalidArgument(
                         "The offset exceeds the size of tensor."));
+  VLOG(10) << "TensorSetElement, place: " << self->place()
+           << ", offset: " << offset << ", element: " << elem;
   if (platform::is_cpu_place(self->place())) {
     self->mutable_data<T>(self->place())[offset] = elem;
   } else if (platform::is_xpu_place(self->place())) {
@@ -255,6 +267,13 @@ void TensorSetElement(framework::Tensor *self, size_t offset, T elem) {
     T *a = self->mutable_data<T>(p);
     paddle::memory::Copy(p, a + offset, platform::CPUPlace(), &elem, sizeof(T),
                          nullptr);
 #endif
+  } else if (platform::is_npu_place(self->place())) {
+#if defined(PADDLE_WITH_ASCEND_CL)
+    auto p = BOOST_GET_CONST(platform::NPUPlace, self->place());
+    T *a = self->mutable_data<T>(p);
+    paddle::memory::Copy(p, a + offset, platform::CPUPlace(), &elem, sizeof(T),
+                         nullptr);
+#endif
   }
 }
@@ -676,7 +695,7 @@ inline py::array TensorToPyArray(const framework::Tensor &tensor,
   size_t numel = 1;
   for (int i = tensor_dims.size() - 1; i >= 0; --i) {
-    py_dims[i] = (size_t)tensor_dims[i];
+    py_dims[i] = static_cast<size_t>(tensor_dims[i]);
     py_strides[i] = sizeof_dtype * numel;
     numel *= py_dims[i];
   }
......
@@ -1491,18 +1491,9 @@ class OpTest(unittest.TestCase):
         if not type(output_names) is list:
             output_names = [output_names]
-        # FIXME: Replace numeric_place with place to calculate numeric_grads.
-        # NOTE(liym27): There is an unknown error when call op.run() on NPUPlace, which
-        # needs to be fixed.
-        if hasattr(self.__class__,
-                   "use_npu") and self.__class__.use_npu == True:
-            numeric_place = paddle.CPUPlace()
-        else:
-            numeric_place = place
         numeric_grads = user_defined_grads or [
             get_numeric_gradient(
-                numeric_place,
+                place,
                 self.scope,
                 self.op,
                 self.inputs,
......