From 36cdb6e27a02bd32370e24ac5cad2fd7692e8ac5 Mon Sep 17 00:00:00 2001 From: jakpiase <62569058+jakpiase@users.noreply.github.com> Date: Tue, 7 Sep 2021 09:40:36 +0200 Subject: [PATCH] Fix for reshape2 oneDNN op (#35455) * fix for reshape2 * added reviewers sugestions --- .../operators/mkldnn/reshape_mkldnn_op.cc | 25 +++++++++++- .../mkldnn/test_reshape_mkldnn_op.py | 39 +++++++++++++++---- 2 files changed, 55 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc index 244430e69f2..d5e428bd805 100644 --- a/paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc @@ -22,6 +22,25 @@ using paddle::framework::LoDTensor; using platform::to_void_cast; using platform::GetMKLDNNFormat; +static std::vector extract_shape( + const std::vector& list_new_shape_tensor) { + std::vector vec_new_shape; + vec_new_shape.reserve(list_new_shape_tensor.size()); + + for (const auto& tensor : list_new_shape_tensor) { + PADDLE_ENFORCE_EQ( + tensor->dims(), framework::make_ddim({1}), + platform::errors::InvalidArgument( + "If the element type of 'shape' in ReshapeOp is Tensor, " + "the element's shape must be [1]. But received the element's shape " + "is [%s]", + tensor->dims())); + vec_new_shape.emplace_back(*tensor->data()); + } + + return vec_new_shape; +} + template class ReshapeMKLDNNKernel : public framework::OpKernel { public: @@ -59,7 +78,11 @@ class ReshapeMKLDNNKernel : public framework::OpKernel { } if (ctx.Type().find("reshape") != std::string::npos) { - if (ctx.HasInput("Shape")) { + auto list_new_shape_tensor = ctx.MultiInput("ShapeTensor"); + if (list_new_shape_tensor.size() > 0) { + auto new_shape = extract_shape(list_new_shape_tensor); + out_dims = ValidateShape(new_shape, x_dims); + } else if (ctx.HasInput("Shape")) { auto* shape_tensor = ctx.Input("Shape"); auto* shape_data = shape_tensor->data(); diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reshape_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_reshape_mkldnn_op.py index a28827207ee..78e5af3311b 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_reshape_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_reshape_mkldnn_op.py @@ -72,9 +72,9 @@ class TestReshape2OneDNNOpDimInfer1(TestReshape2OneDNNOp): class TestReshape2OneDNNOpDimInfer2(TestReshape2OneDNNOp): def init_data(self): - self.ori_shape = (10, 2, 6) - self.new_shape = (10, 0, 3, -1) - self.infered_shape = (10, 2, 3, -1) + self.ori_shape = (6, 20) + self.new_shape = (0, -1, 20) + self.actual_shape = (2, 3, 20) def set_additional_inputs(self): self.inputs["Shape"] = np.array(self.actual_shape, dtype="int32") @@ -85,11 +85,6 @@ class TestReshape2OneDNNOpDimInfer2(TestReshape2OneDNNOp): 'XShape': np.random.random(self.ori_shape).astype("float32") } - def init_data1(self): - self.ori_shape = (6, 20) - self.new_shape = (0, -1, 20) - self.actual_shape = (2, 3, 20) - class TestReshape2OneDNNOp_attr_OnlyShape(TestReshape2OneDNNOp): def set_additional_inputs(self): @@ -119,6 +114,34 @@ class TestReshape2OneDNNOpDimInfer1_attr_OnlyShape( self.shape = (5, -1, -1) +class TestReshape2OneDNNOpDimInfer1_attr_ShapeTensor(TestReshape2OneDNNOp): + def set_additional_inputs(self): + shape_tensor = [] + for index, ele in enumerate(self.new_shape): + shape_tensor.append(("x" + str(index), np.ones( + (1)).astype('int32') * ele)) + + self.inputs["ShapeTensor"] = shape_tensor + + def init_data(self): + self.ori_shape = (5, 20) + self.new_shape = (5, -1, 10) + self.infered_shape = (5, -1, 10) + self.shape = (5, -1, -1) + + +class TestReshape2OneDNNOpDimInfer1_attr_ShapeTensorAndShape( + TestReshape2OneDNNOpDimInfer1_attr_ShapeTensor): + def set_additional_inputs(self): + shape_tensor = [] + for index, ele in enumerate(self.new_shape): + shape_tensor.append(("x" + str(index), np.ones( + (1)).astype('int32') * ele)) + + self.inputs["Shape"] = np.array((1, 2, 3, 4), dtype="int32") + self.inputs["ShapeTensor"] = shape_tensor + + class TestReshapeOneDNNOp(TestReshape2OneDNNOp): def set_op_type(self): self.op_type = "reshape" -- GitLab