未验证 提交 28ca2e5f 编写于 作者: W wangchaochaohu 提交者: GitHub

strided_slice perforamnce improvement test=develop (#20852)

上级 6fcfd32e
...@@ -241,8 +241,6 @@ class StridedSliceKernel : public framework::OpKernel<T> { ...@@ -241,8 +241,6 @@ class StridedSliceKernel : public framework::OpKernel<T> {
reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false; reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false;
} }
framework::Tensor tmp;
auto out_dims_origin = out_dims; auto out_dims_origin = out_dims;
if (decrease_axis.size() > 0) { if (decrease_axis.size() > 0) {
std::vector<int> new_out_shape; std::vector<int> new_out_shape;
...@@ -263,21 +261,34 @@ class StridedSliceKernel : public framework::OpKernel<T> { ...@@ -263,21 +261,34 @@ class StridedSliceKernel : public framework::OpKernel<T> {
out_dims_origin = framework::make_ddim(new_out_shape); out_dims_origin = framework::make_ddim(new_out_shape);
} }
tmp.mutable_data<T>(out_dims, context.GetPlace()); bool need_reverse = false;
for (size_t axis = 0; axis < axes.size(); axis++) {
if (reverse_vector[axis] == 1) {
need_reverse = true;
break;
}
}
out->Resize(out_dims); out->Resize(out_dims);
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
auto in_t = auto in_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From( framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*in); *in);
auto tmp_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
tmp);
auto out_t = auto out_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From( framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*out, out_dims); *out, out_dims);
tmp_t.device(place) = if (need_reverse) {
in_t.stridedSlice(starts_indices, ends_indices, strides_indices); framework::Tensor tmp;
out_t.device(place) = tmp_t.reverse(reverse_axis); tmp.mutable_data<T>(out_dims, context.GetPlace());
auto tmp_t = framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(tmp);
tmp_t.device(place) =
in_t.stridedSlice(starts_indices, ends_indices, strides_indices);
out_t.device(place) = tmp_t.reverse(reverse_axis);
} else {
out_t.device(place) =
in_t.stridedSlice(starts_indices, ends_indices, strides_indices);
}
if (decrease_axis.size() > 0) { if (decrease_axis.size() > 0) {
out->Resize(out_dims_origin); out->Resize(out_dims_origin);
...@@ -388,22 +399,33 @@ class StridedSliceGradKernel : public framework::OpKernel<T> { ...@@ -388,22 +399,33 @@ class StridedSliceGradKernel : public framework::OpKernel<T> {
reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false; reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false;
} }
framework::Tensor reverse_input; bool need_reverse = false;
reverse_input.mutable_data<T>(in_dims, context.GetPlace()); for (size_t axis = 0; axis < axes.size(); axis++) {
if (reverse_vector[axis] == 1) {
need_reverse = true;
break;
}
}
auto in_t = auto in_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From( framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*d_input); *d_input);
auto reverse_in_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
reverse_input);
auto out_t = auto out_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From( framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*d_out, out_dims); *d_out, out_dims);
if (need_reverse) {
reverse_in_t.device(place) = in_t.reverse(reverse_axis); framework::Tensor reverse_input;
out_t.stridedSlice(starts_indices, ends_indices, strides_indices) reverse_input.mutable_data<T>(in_dims, context.GetPlace());
.device(place) = reverse_in_t; auto reverse_in_t =
framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(reverse_input);
reverse_in_t.device(place) = in_t.reverse(reverse_axis);
out_t.stridedSlice(starts_indices, ends_indices, strides_indices)
.device(place) = reverse_in_t;
} else {
out_t.stridedSlice(starts_indices, ends_indices, strides_indices)
.device(place) = in_t;
}
} }
}; };
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册