未验证 提交 8aec0580 编写于 作者: B Bo Zhang 提交者: GitHub

Reduce redundant cpu computation in slice compute (#50348)

* conflict

* add UpdateSliceAttrs
上级 097402d9
...@@ -101,6 +101,49 @@ inline void CheckAndUpdateSliceAttrs(const DDim in_dims, ...@@ -101,6 +101,49 @@ inline void CheckAndUpdateSliceAttrs(const DDim in_dims,
} }
} }
template <typename T = int64_t>
inline void UpdateSliceAttrs(const DDim in_dims,
const std::vector<T>& axes,
std::vector<T>* starts,
std::vector<T>* ends,
std::vector<int64_t>* steps = nullptr,
std::vector<T>* infer_flags = nullptr) {
for (size_t i = 0; i < axes.size(); ++i) {
T axis = axes[i];
if (infer_flags != nullptr && (*infer_flags)[i] == -1) {
continue;
}
T dim_value = in_dims[axis];
if (dim_value > 0) {
T step = steps == nullptr ? 1 : (*steps)[i];
T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
start = std::max(start, static_cast<T>(0));
T end =
0 < step && (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i];
end = std::min(end, dim_value);
if (step > 0) {
start = std::min(start, dim_value);
end = std::max(end, static_cast<T>(0));
} else {
// NOTE: When step < 0, start should less and equal to
// dim_value-1
// "end is -1" means contain the 0-th element of this axis.
start = std::min(start, dim_value - 1);
if (end < -1) {
end += dim_value;
}
end = std::max(end, static_cast<T>(-1));
}
(*starts)[i] = start;
(*ends)[i] = end;
} else if (dim_value == 0) {
(*starts)[i] = 0;
(*ends)[i] = 0;
}
}
}
template <typename T = int64_t> template <typename T = int64_t>
inline phi::DDim GetSliceDims(const phi::DDim in_dims, inline phi::DDim GetSliceDims(const phi::DDim in_dims,
const std::vector<T>& axes, const std::vector<T>& axes,
......
...@@ -53,7 +53,7 @@ void SliceCompute(const Context& ctx, ...@@ -53,7 +53,7 @@ void SliceCompute(const Context& ctx,
} }
} }
funcs::CheckAndUpdateSliceAttrs<int64_t>(in_dims, axes, &starts, &ends); funcs::UpdateSliceAttrs<int64_t>(in_dims, axes, &starts, &ends);
slice_dims = funcs::GetSliceDims<int64_t>( slice_dims = funcs::GetSliceDims<int64_t>(
in_dims, axes, starts, ends, nullptr, nullptr); in_dims, axes, starts, ends, nullptr, nullptr);
out_dims = funcs::GetDecreasedDims<int64_t>(slice_dims, decrease_axis); out_dims = funcs::GetDecreasedDims<int64_t>(slice_dims, decrease_axis);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册