未验证 提交 39f41cb4 编写于 作者: L liym27 提交者: GitHub

Performance optimization for dynamic setitem: Call op set_value to speed up...

Performance optimization for dynamic setitem: Call op set_value to speed up because the original call to TensorToPyArray will introduce unnecessary data copy. (#30817)
上级 bef46ccf
......@@ -583,26 +583,82 @@ void BindImperative(py::module *m_ptr) {
py::object &value_obj) {
auto self_tensor =
self->MutableVar()->GetMutable<framework::LoDTensor>();
auto self_numpy = TensorToPyArray(*self_tensor);
PyObject *index_ptr = !PyTuple_Check(_index.ptr())
? PyTuple_Pack(1, _index.ptr())
: _index.ptr();
// 1. Check argumnets
// 1.1 Check whether _index can be parsed.
bool parse_index = true;
const int size = PyTuple_GET_SIZE(index_ptr);
for (int dim = 0; dim < size; ++dim) {
PyObject *slice_item = PyTuple_GetItem(index_ptr, dim);
if (!(PyCheckInteger(slice_item) || PySlice_Check(slice_item))) {
parse_index = false;
break;
}
}
// 1.2 Check whether stride is 1.
std::vector<int> axes, starts, ends, strides, decrease_axis,
infer_flags;
bool stride_is_1 = true;
if (parse_index) {
ParseIndexingSlice(self_tensor, index_ptr, &axes, &starts, &ends,
&strides, &decrease_axis, &infer_flags);
stride_is_1 =
std::all_of(strides.cbegin(), strides.cend(),
[](int64_t stride) { return stride == 1; });
}
// 1.3 Check whether value obj is a tensor.
bool value_is_tensor = true;
if (py::isinstance<py::array>(value_obj) ||
py::isinstance<py::int_>(value_obj) ||
py::isinstance<py::float_>(value_obj)) {
auto value_numpy = value_obj;
self_numpy[_index] = value_numpy;
SetTensorFromPyArray(self_tensor, self_numpy,
self_tensor->place(), true);
value_is_tensor = false;
}
// 2. Call op set_value to speed up if the condition is met,
// otherwise call TensorToPyArray.
// TODO(liym27): Try not to call TensorToPyArray because it always
// copys data to cpu place, which reduces performance.
if (parse_index && stride_is_1 && value_is_tensor) {
framework::AttributeMap attrs = {
{"axes", axes}, {"starts", starts}, {"ends", ends}};
imperative::NameVarBaseMap ins = {{"Input", {self}}};
imperative::NameVarBaseMap outs = {{"Out", {self}}};
} else {
auto value =
value_obj.cast<std::shared_ptr<imperative::VarBase>>();
auto value_tensor =
value->MutableVar()->GetMutable<framework::LoDTensor>();
auto value_numpy = TensorToPyArray(*value_tensor);
value_obj.cast<std::shared_ptr<imperative::VarBase>>();
ins.insert({"ValueTensor", {value_tensor}});
self_numpy[_index] = value_numpy;
SetTensorFromPyArray(self_tensor, self_numpy,
self_tensor->place(), true);
const auto &tracer = imperative::GetCurrentTracer();
{
// Release gil and do tracing
py::gil_scoped_release release;
tracer->TraceOp("set_value", ins, outs, std::move(attrs));
}
} else {
auto self_numpy = TensorToPyArray(*self_tensor);
if (value_is_tensor) {
auto value =
value_obj.cast<std::shared_ptr<imperative::VarBase>>();
auto value_tensor =
value->MutableVar()->GetMutable<framework::LoDTensor>();
auto value_numpy = TensorToPyArray(*value_tensor);
self_numpy[_index] = value_numpy;
SetTensorFromPyArray(self_tensor, self_numpy,
self_tensor->place(), true);
} else {
auto value_numpy = value_obj;
self_numpy[_index] = value_numpy;
SetTensorFromPyArray(self_tensor, self_numpy,
self_tensor->place(), true);
}
}
// NOTE(liym27):
// Increase the version of VarBase self because __setitem__ is an
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册