From 5f89272c89befd113d1fa44e9055f47bcceb455e Mon Sep 17 00:00:00 2001 From: chenweihang Date: Mon, 9 Jul 2018 06:08:55 +0000 Subject: [PATCH] change the bit insert to array insert for understandability --- paddle/fluid/operators/unsqueeze_op.cc | 57 ++++++++----------- .../tests/unittests/test_unsqueeze_op.py | 8 +++ 2 files changed, 31 insertions(+), 34 deletions(-) diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc index d950da6a758..960bc6f241d 100644 --- a/paddle/fluid/operators/unsqueeze_op.cc +++ b/paddle/fluid/operators/unsqueeze_op.cc @@ -44,39 +44,37 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase { static framework::DDim GetOutputShape(const std::vector unsqz_dims, const framework::DDim &in_dims) { - unsigned int unsqz_mask = 0; - unsigned int front = 0, back = 0; - int output_dims_size = in_dims.size(); + int output_size = in_dims.size() + unsqz_dims.size(); + int cur_output_size = in_dims.size(); + std::vector output_shape(output_size, 0); + + // Validity Check: rank range. + PADDLE_ENFORCE(output_size <= 6, + "The output tensor's rank should be less than 6."); - // Simulate insert by bit calc. for (int axis : unsqz_dims) { - int cur = axis < 0 ? axis + output_dims_size + 1 : axis; + int cur = axis < 0 ? axis + cur_output_size + 1 : axis; // Vaildity Check: the axis bound PADDLE_ENFORCE( - cur >= 0 && cur <= output_dims_size, + cur >= 0 && cur <= cur_output_size, "The unsqueeze dims must be within range of current rank."); - // Save the front part. - front = unsqz_mask & ((1 << cur) - 1); - // Move the back part. - back = unsqz_mask & ~((1 << cur) - 1); - back <<= 1; - // Merge two part. - back |= (1 << cur); - unsqz_mask = front | back; + // Move old axis, and insert new axis + for (int i = cur_output_size; i >= cur; --i) { + if (output_shape[i] == 1) { + // Move axis + output_shape[i + 1] = 1; + output_shape[i] = 0; + } + } + output_shape[cur] = 1; // Add the output size. - output_dims_size++; - // Validity Check: rank range. - PADDLE_ENFORCE(output_dims_size <= 6, - "The output tensor's rank should be less than 6."); + cur_output_size++; } // Make output shape - std::vector output_shape(output_dims_size, 0); - for (int in_idx = 0, out_idx = 0; out_idx < output_dims_size; ++out_idx) { - if ((unsqz_mask & (1 << out_idx)) == 0) { + for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) { + if (output_shape[out_idx] == 0) { output_shape[out_idx] = in_dims[in_idx++]; - } else { - output_shape[out_idx] = 1; } } @@ -86,10 +84,7 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase { class UnsqueezeOp : public framework::OperatorBase { public: - UnsqueezeOp(const std::string &type, const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : OperatorBase(type, inputs, outputs, attrs) {} + using OperatorBase::OperatorBase; private: void RunImpl(const framework::Scope &scope, @@ -97,8 +92,6 @@ class UnsqueezeOp : public framework::OperatorBase { auto &axes = Attr>("axes"); auto x_dims = scope.FindVar(Input("X"))->Get().dims(); auto out_dims = UnsqueezeOpInferShape::GetOutputShape(axes, x_dims); - // auto out_dims = - // scope.FindVar(Output("Out"))->Get().dims(); framework::AttributeMap attrs; attrs["shape"] = framework::vectorize2int(out_dims); @@ -165,11 +158,7 @@ class UnsqueezeGradInferShape : public framework::InferShapeBase { class UnsqueezeGradOp : public framework::OperatorBase { public: - UnsqueezeGradOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : OperatorBase(type, inputs, outputs, attrs) {} + using OperatorBase::OperatorBase; private: void RunImpl(const framework::Scope &scope, diff --git a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py index d19d4e525a8..7a4aa0a40b5 100644 --- a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py +++ b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py @@ -66,6 +66,14 @@ class TestUnsqueezeOp3(TestUnsqueezeOp): self.new_shape = (1, 3, 2, 1, 1, 5) +# Correct: Reversed axes. +class TestUnsqueezeOp4(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 2, 5) + self.axes = (3, 1, 1) + self.new_shape = (3, 1, 1, 2, 5, 1) + + # Correct: Inplace. class TestUnsqueezeOpInplace1(TestUnsqueezeOp): def init_test_case(self): -- GitLab