提交 5f89272c 编写于 作者: C chenweihang

change the bit insert to array insert for understandability

上级 fccdc1ab
...@@ -44,39 +44,37 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase { ...@@ -44,39 +44,37 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase {
static framework::DDim GetOutputShape(const std::vector<int> unsqz_dims, static framework::DDim GetOutputShape(const std::vector<int> unsqz_dims,
const framework::DDim &in_dims) { const framework::DDim &in_dims) {
unsigned int unsqz_mask = 0; int output_size = in_dims.size() + unsqz_dims.size();
unsigned int front = 0, back = 0; int cur_output_size = in_dims.size();
int output_dims_size = in_dims.size(); std::vector<int64_t> output_shape(output_size, 0);
// Validity Check: rank range.
PADDLE_ENFORCE(output_size <= 6,
"The output tensor's rank should be less than 6.");
// Simulate insert by bit calc.
for (int axis : unsqz_dims) { for (int axis : unsqz_dims) {
int cur = axis < 0 ? axis + output_dims_size + 1 : axis; int cur = axis < 0 ? axis + cur_output_size + 1 : axis;
// Vaildity Check: the axis bound // Vaildity Check: the axis bound
PADDLE_ENFORCE( PADDLE_ENFORCE(
cur >= 0 && cur <= output_dims_size, cur >= 0 && cur <= cur_output_size,
"The unsqueeze dims must be within range of current rank."); "The unsqueeze dims must be within range of current rank.");
// Save the front part. // Move old axis, and insert new axis
front = unsqz_mask & ((1 << cur) - 1); for (int i = cur_output_size; i >= cur; --i) {
// Move the back part. if (output_shape[i] == 1) {
back = unsqz_mask & ~((1 << cur) - 1); // Move axis
back <<= 1; output_shape[i + 1] = 1;
// Merge two part. output_shape[i] = 0;
back |= (1 << cur); }
unsqz_mask = front | back; }
output_shape[cur] = 1;
// Add the output size. // Add the output size.
output_dims_size++; cur_output_size++;
// Validity Check: rank range.
PADDLE_ENFORCE(output_dims_size <= 6,
"The output tensor's rank should be less than 6.");
} }
// Make output shape // Make output shape
std::vector<int64_t> output_shape(output_dims_size, 0); for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) {
for (int in_idx = 0, out_idx = 0; out_idx < output_dims_size; ++out_idx) { if (output_shape[out_idx] == 0) {
if ((unsqz_mask & (1 << out_idx)) == 0) {
output_shape[out_idx] = in_dims[in_idx++]; output_shape[out_idx] = in_dims[in_idx++];
} else {
output_shape[out_idx] = 1;
} }
} }
...@@ -86,10 +84,7 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase { ...@@ -86,10 +84,7 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase {
class UnsqueezeOp : public framework::OperatorBase { class UnsqueezeOp : public framework::OperatorBase {
public: public:
UnsqueezeOp(const std::string &type, const framework::VariableNameMap &inputs, using OperatorBase::OperatorBase;
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private: private:
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
...@@ -97,8 +92,6 @@ class UnsqueezeOp : public framework::OperatorBase { ...@@ -97,8 +92,6 @@ class UnsqueezeOp : public framework::OperatorBase {
auto &axes = Attr<std::vector<int>>("axes"); auto &axes = Attr<std::vector<int>>("axes");
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims(); auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
auto out_dims = UnsqueezeOpInferShape::GetOutputShape(axes, x_dims); auto out_dims = UnsqueezeOpInferShape::GetOutputShape(axes, x_dims);
// auto out_dims =
// scope.FindVar(Output("Out"))->Get<framework::LoDTensor>().dims();
framework::AttributeMap attrs; framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(out_dims); attrs["shape"] = framework::vectorize2int(out_dims);
...@@ -165,11 +158,7 @@ class UnsqueezeGradInferShape : public framework::InferShapeBase { ...@@ -165,11 +158,7 @@ class UnsqueezeGradInferShape : public framework::InferShapeBase {
class UnsqueezeGradOp : public framework::OperatorBase { class UnsqueezeGradOp : public framework::OperatorBase {
public: public:
UnsqueezeGradOp(const std::string &type, using OperatorBase::OperatorBase;
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private: private:
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
......
...@@ -66,6 +66,14 @@ class TestUnsqueezeOp3(TestUnsqueezeOp): ...@@ -66,6 +66,14 @@ class TestUnsqueezeOp3(TestUnsqueezeOp):
self.new_shape = (1, 3, 2, 1, 1, 5) self.new_shape = (1, 3, 2, 1, 1, 5)
# Correct: Reversed axes.
class TestUnsqueezeOp4(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 2, 5)
self.axes = (3, 1, 1)
self.new_shape = (3, 1, 1, 2, 5, 1)
# Correct: Inplace. # Correct: Inplace.
class TestUnsqueezeOpInplace1(TestUnsqueezeOp): class TestUnsqueezeOpInplace1(TestUnsqueezeOp):
def init_test_case(self): def init_test_case(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册