diff --git a/paddle/fluid/operators/mkldnn/slice_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/slice_mkldnn_op.cc index e16c41829b1a657082325024924ca1a134988c67..facf5ca4b8397f3d499dfa751157fafd26a388ff 100644 --- a/paddle/fluid/operators/mkldnn/slice_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/slice_mkldnn_op.cc @@ -12,8 +12,29 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/operators/utils.h" #include "paddle/fluid/platform/mkldnn_reuse.h" +static mkldnn::memory::format_tag get_plain_format_tag( + const paddle::framework::Tensor* tensor) { + auto tensor_dims_size = tensor->dims().size(); + + switch (tensor_dims_size) { + case 1: + return mkldnn::memory::format_tag::a; + case 2: + return mkldnn::memory::format_tag::ab; + case 3: + return mkldnn::memory::format_tag::abc; + case 4: + return mkldnn::memory::format_tag::abcd; + case 5: + return mkldnn::memory::format_tag::abcde; + } + + return mkldnn::memory::format_tag::abcdef; +} + namespace paddle { namespace operators { @@ -35,7 +56,6 @@ class SliceMKLDNNKernel : public framework::OpKernel { auto* out = ctx.Output("Out"); auto x_vec_dims = framework::vectorize(x->dims()); - auto out_vec_dims = framework::vectorize(out->dims()); auto axes_int = ctx.Attr>("axes"); auto starts_int = ctx.Attr>("starts"); @@ -48,8 +68,22 @@ class SliceMKLDNNKernel : public framework::OpKernel { std::vector ends(ctx.Attr>("ends").begin(), ctx.Attr>("ends").end()); + auto starts_tensor_list = ctx.MultiInput("StartsTensorList"); + if (ctx.HasInput("StartsTensor")) { + starts = GetDataFromTensor(ctx.Input("StartsTensor")); + } else if (starts_tensor_list.size() > 0) { + starts = GetDataFromTensorList(starts_tensor_list); + } + auto decrease_axis = ctx.Attr>("decrease_axis"); + auto ends_tensor_list = ctx.MultiInput("EndsTensorList"); + if (ctx.HasInput("EndsTensor")) { + ends = GetDataFromTensor(ctx.Input("EndsTensor")); + } else if (ends_tensor_list.size() > 0) { + ends = GetDataFromTensorList(ends_tensor_list); + } + std::vector offsets(x_vec_dims.size(), 0); std::vector slice_dims(x_vec_dims); @@ -61,6 +95,8 @@ class SliceMKLDNNKernel : public framework::OpKernel { slice_dims[axes[i]] = ends[i] - starts[i]; } + out->Resize(framework::make_ddim(slice_dims)); + mkldnn::memory::data_type x_type = framework::ToMKLDNNDataType(x->type()); auto key = platform::CreateKey(dev_ctx, x_vec_dims, axes, starts, ends, x->format(), x_type); @@ -73,20 +109,35 @@ class SliceMKLDNNKernel : public framework::OpKernel { auto slice_mem_p = reorder_handler.AcquireSubmemory(slice_dims, offsets, reorder_src_memory_p); auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( - out, slice_dims, 0, x->format(), ctx.GetPlace()); + out, slice_dims, 0, get_plain_format_tag(x), ctx.GetPlace()); auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p); auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); reorder_p->execute(astream, *slice_mem_p, *reorder_dst_memory_p); - astream.wait(); + std::vector new_out_dims(slice_dims.size() - decrease_axis.size()); + + if (new_out_dims.size() == 0) { + new_out_dims.emplace_back(1); + } else { + for (const auto& axis : decrease_axis) { + slice_dims[axis] = 0; + } + + int i = 0; + for (const auto& slice_dim : slice_dims) { + if (slice_dim != 0) new_out_dims[i++] = slice_dim; + } + } + + astream.wait(); + out->Resize(framework::make_ddim(new_out_dims)); out->set_layout(framework::DataLayout::kMKLDNN); out->set_format(platform::GetMKLDNNFormat( - reorder_dst_memory_p->get_desc().reshape(out_vec_dims))); + reorder_dst_memory_p->get_desc().reshape(new_out_dims))); } }; - template class SliceGradMKLDNNKernel : public framework::OpKernel { public: @@ -116,6 +167,20 @@ class SliceGradMKLDNNKernel : public framework::OpKernel { std::vector ends(ctx.Attr>("ends").begin(), ctx.Attr>("ends").end()); + auto starts_tensor_list = ctx.MultiInput("StartsTensorList"); + if (ctx.HasInput("StartsTensor")) { + starts = GetDataFromTensor(ctx.Input("StartsTensor")); + } else if (starts_tensor_list.size() > 0) { + starts = GetDataFromTensorList(starts_tensor_list); + } + + auto ends_tensor_list = ctx.MultiInput("EndsTensorList"); + if (ctx.HasInput("EndsTensor")) { + ends = GetDataFromTensor(ctx.Input("EndsTensor")); + } else if (ends_tensor_list.size() > 0) { + ends = GetDataFromTensorList(ends_tensor_list); + } + auto decrease_axis = ctx.Attr>("decrease_axis"); std::vector offsets(dx_vec_dims.size(), 0); @@ -172,4 +237,4 @@ REGISTER_OP_KERNEL(slice, MKLDNN, paddle::platform::CPUPlace, namespace ops = paddle::operators; REGISTER_OP_KERNEL(slice_grad, MKLDNN, paddle::platform::CPUPlace, ops::SliceGradMKLDNNKernel, - ops::SliceGradMKLDNNKernel); \ No newline at end of file + ops::SliceGradMKLDNNKernel); diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_slice_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_slice_mkldnn_op.py index caebcffd0e966af1dd55eec9fb4b900673c8e66d..443e4d90c3a8a21a07f66a6eed3002aa37538766 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_slice_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_slice_mkldnn_op.py @@ -150,6 +150,46 @@ class TestSlice3DOneDNNOp(TestSliceDecrease1AxisOneDNNOp): self.out = self.input[:, :, -1] +class TestSliceOneDNNOp_decs_dim_starts_ListTensor( + TestSliceDecrease1AxisOneDNNOp): + def set_inputs(self): + starts_tensor = [] + for index, ele in enumerate(self.starts): + starts_tensor.append(("x1", np.ones((1)).astype('int32') * 2)) + self.inputs = {'Input': self.input, 'StartsTensorList': starts_tensor} + + def config(self): + self.input = np.random.random([5, 4, 5]).astype("float32") + self.starts = [1] + self.ends = [3] + self.axes = [2] + self.decrease_axis = [] + self.infer_flags = [1, 1, 1] + self.out = self.input[:, :, 2:3] + + +class TestSlice4DInferDimsOneDNNOp(TestSliceDecrease1AxisOneDNNOp): + def config(self): + self.input = np.random.random([1, 1, 10, 10]).astype("float32") + self.starts = [1, 2] + self.ends = [9, 9] + self.axes = [2, 3] + self.decrease_axis = [1] + self.infer_flags = [-1, -1] + self.out = self.input[:, :, 1:9, 2:9] + + +class TestSlice4DInferDimsOneDNNOp2(TestSliceDecrease1AxisOneDNNOp): + def config(self): + self.input = np.random.random([1, 1, 10, 10]).astype("float32") + self.starts = [4, 2] + self.ends = [7, 8] + self.axes = [2, 3] + self.decrease_axis = [0, 1] + self.infer_flags = [-1, -1] + self.out = self.input[:, :, 4:7, 2:8] + + # BF16 TESTS def create_bf16_test_class(parent): @OpTestTool.skip_if_not_cpu_bf16()