Unverified commit 28ca2e5f authored by W wangchaochaohu, committed by GitHub

strided_slice performance improvement test=develop (#20852)

Parent 6fcfd32e
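Summary of the change (my reading of the diff below, not part of the original commit message): the forward kernel previously always wrote the strided-slice result into a temporary tensor and then reversed that temporary into the output, even when no axis needed reversing. The patch computes a `need_reverse` flag and, when it is false, slices directly into the output, saving one temporary allocation and one full copy; the gradient kernel gets the mirrored treatment. The sketch below only illustrates that dispatch pattern under simplified assumptions; `strided_slice_sketch`, `SliceFn`, and `ReverseFn` are hypothetical names, not Paddle APIs.

```cpp
#include <cstddef>
#include <vector>

// Simplified sketch of the dispatch introduced by this commit (hypothetical
// helper, not the Paddle kernel). reverse_vector[axis] == 1 marks axes that
// must be reversed after slicing.
template <typename T, typename SliceFn, typename ReverseFn>
void strided_slice_sketch(const std::vector<T>& input, std::vector<T>* output,
                          const std::vector<int>& reverse_vector,
                          SliceFn slice, ReverseFn reverse) {
  bool need_reverse = false;
  for (size_t axis = 0; axis < reverse_vector.size(); ++axis) {
    if (reverse_vector[axis] == 1) {
      need_reverse = true;
      break;
    }
  }
  if (need_reverse) {
    // Slow path: slice into a temporary buffer, then reverse into the output.
    std::vector<T> tmp;
    slice(input, &tmp);
    reverse(tmp, output);
  } else {
    // Fast path (the optimization): slice straight into the output and skip
    // the temporary allocation plus the extra copy.
    slice(input, output);
  }
}
```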
@@ -241,8 +241,6 @@ class StridedSliceKernel : public framework::OpKernel<T> {
reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false;
}
framework::Tensor tmp;
auto out_dims_origin = out_dims;
if (decrease_axis.size() > 0) {
std::vector<int> new_out_shape;
@@ -263,21 +261,34 @@ class StridedSliceKernel : public framework::OpKernel<T> {
out_dims_origin = framework::make_ddim(new_out_shape);
}
tmp.mutable_data<T>(out_dims, context.GetPlace());
bool need_reverse = false;
for (size_t axis = 0; axis < axes.size(); axis++) {
if (reverse_vector[axis] == 1) {
need_reverse = true;
break;
}
}
out->Resize(out_dims);
out->mutable_data<T>(context.GetPlace());
auto in_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*in);
auto tmp_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
tmp);
auto out_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*out, out_dims);
tmp_t.device(place) =
in_t.stridedSlice(starts_indices, ends_indices, strides_indices);
out_t.device(place) = tmp_t.reverse(reverse_axis);
if (need_reverse) {
framework::Tensor tmp;
tmp.mutable_data<T>(out_dims, context.GetPlace());
auto tmp_t = framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(tmp);
tmp_t.device(place) =
in_t.stridedSlice(starts_indices, ends_indices, strides_indices);
out_t.device(place) = tmp_t.reverse(reverse_axis);
} else {
out_t.device(place) =
in_t.stridedSlice(starts_indices, ends_indices, strides_indices);
}
if (decrease_axis.size() > 0) {
out->Resize(out_dims_origin);
@@ -388,22 +399,33 @@ class StridedSliceGradKernel : public framework::OpKernel<T> {
reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false;
}
framework::Tensor reverse_input;
reverse_input.mutable_data<T>(in_dims, context.GetPlace());
bool need_reverse = false;
for (size_t axis = 0; axis < axes.size(); axis++) {
if (reverse_vector[axis] == 1) {
need_reverse = true;
break;
}
}
auto in_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*d_input);
auto reverse_in_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
reverse_input);
auto out_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*d_out, out_dims);
reverse_in_t.device(place) = in_t.reverse(reverse_axis);
out_t.stridedSlice(starts_indices, ends_indices, strides_indices)
.device(place) = reverse_in_t;
if (need_reverse) {
framework::Tensor reverse_input;
reverse_input.mutable_data<T>(in_dims, context.GetPlace());
auto reverse_in_t =
framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(reverse_input);
reverse_in_t.device(place) = in_t.reverse(reverse_axis);
out_t.stridedSlice(starts_indices, ends_indices, strides_indices)
.device(place) = reverse_in_t;
} else {
out_t.stridedSlice(starts_indices, ends_indices, strides_indices)
.device(place) = in_t;
}
}
};
} // namespace operators
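Side note on the backward kernel (again my reading of the diff, not author commentary): the gradient path applies the same `need_reverse` dispatch but in the opposite order, since it has to reverse the incoming gradient (`d_input`) before scattering it into the sliced region of the result gradient (`d_out`) through the `stridedSlice(...) = ...` assignment. A minimal sketch, with `strided_slice_grad_sketch`, `ReverseFn`, and `ScatterFn` as hypothetical stand-ins:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical sketch of the gradient-side dispatch: reverse first (only if
// some axis needs it), then scatter into the slice region of the result.
template <typename T, typename ReverseFn, typename ScatterFn>
void strided_slice_grad_sketch(const std::vector<T>& d_input,
                               std::vector<T>* d_out,
                               const std::vector<int>& reverse_vector,
                               ReverseFn reverse, ScatterFn scatter) {
  bool need_reverse = false;
  for (size_t axis = 0; axis < reverse_vector.size(); ++axis) {
    if (reverse_vector[axis] == 1) {
      need_reverse = true;
      break;
    }
  }
  if (need_reverse) {
    // Undo the forward reversal into a temporary, then scatter it.
    std::vector<T> reversed(d_input.size());
    reverse(d_input, &reversed);
    scatter(reversed, d_out);
  } else {
    // Fast path: scatter the gradient directly, no temporary buffer needed.
    scatter(d_input, d_out);
  }
}
```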