提交 49b2cf5f 编写于 作者: C chenweihang

adjust some code based reviewer's advice

上级 79333fa7
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -36,15 +36,13 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase { ...@@ -36,15 +36,13 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase {
PADDLE_ENFORCE(static_cast<int>(x_dims.size()) <= 6, PADDLE_ENFORCE(static_cast<int>(x_dims.size()) <= 6,
"Invalid dimensions, dynamic dimensions should within " "Invalid dimensions, dynamic dimensions should within "
"[1, 6] dimensions (Eigen limit)."); "[1, 6] dimensions (Eigen limit).");
// Validity Check: the range of unsqueeze aixs.
for (int axis : axes) {
PADDLE_ENFORCE(axis < 6,
"Invalid dimensions, input axis should within "
"[1, 6] dimensions (Eigen limit).");
}
auto out_dims = GetOutputShape(axes, x_dims); auto out_dims = GetOutputShape(axes, x_dims);
ctx->SetOutputDim("Out", out_dims); ctx->SetOutputDim("Out", out_dims);
if (x_dims[0] == out_dims[0]) {
// Only pass LoD when the first dimension of output and Input(X)
// are the same.
ctx->ShareLoD("X", "Out");
}
} }
static framework::DDim GetOutputShape(const std::vector<int> unsqz_dims, static framework::DDim GetOutputShape(const std::vector<int> unsqz_dims,
...@@ -102,6 +100,8 @@ class UnsqueezeOp : public framework::OperatorBase { ...@@ -102,6 +100,8 @@ class UnsqueezeOp : public framework::OperatorBase {
auto &axes = Attr<std::vector<int>>("axes"); auto &axes = Attr<std::vector<int>>("axes");
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims(); auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
auto out_dims = UnsqueezeOpInferShape::GetOutputShape(axes, x_dims); auto out_dims = UnsqueezeOpInferShape::GetOutputShape(axes, x_dims);
// auto out_dims =
// scope.FindVar(Output("Out"))->Get<framework::LoDTensor>().dims();
framework::AttributeMap attrs; framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(out_dims); attrs["shape"] = framework::vectorize2int(out_dims);
...@@ -121,7 +121,19 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -121,7 +121,19 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", "(Tensor). The output tensor of unsqueeze operator."); AddOutput("Out", "(Tensor). The output tensor of unsqueeze operator.");
AddAttr<std::vector<int>>("axes", AddAttr<std::vector<int>>("axes",
"(std::vector<int>). List of positive integers," "(std::vector<int>). List of positive integers,"
" indicate the dimensions to be inserted"); " indicate the dimensions to be inserted")
.AddCustomChecker([](const std::vector<int> &axes) {
// Validity Check: axes dims (<6).
PADDLE_ENFORCE(static_cast<int>(axes.size()) < 6,
"Invalid dimensions, dynamic dimensions should within "
"[1, 6] dimensions (Eigen limit).");
// Validity Check: the range of unsqueeze aixs.
for (int axis : axes) {
PADDLE_ENFORCE(axis < 6,
"Invalid dimensions, input axis should within "
"[1, 6] dimensions (Eigen limit).");
}
});
AddAttr<bool>( AddAttr<bool>(
"inplace", "inplace",
"(default: false) Unsqueeze the source tensor's shape without " "(default: false) Unsqueeze the source tensor's shape without "
......
...@@ -21,14 +21,11 @@ from op_test import OpTest ...@@ -21,14 +21,11 @@ from op_test import OpTest
# Correct: General. # Correct: General.
class TestUnsqueezeOp(OpTest): class TestUnsqueezeOp(OpTest):
def setUp(self): def setUp(self):
ori_shape = (3, 5) self.init_test_case()
axes = (0, 2)
new_shape = (1, 3, 1, 5)
self.op_type = "unsqueeze" self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")} self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inplace": False} self.attrs = {"axes": self.axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -36,194 +33,59 @@ class TestUnsqueezeOp(OpTest): ...@@ -36,194 +33,59 @@ class TestUnsqueezeOp(OpTest):
def test_check_grad(self): def test_check_grad(self):
self.check_grad(["X"], "Out") self.check_grad(["X"], "Out")
def init_test_case(self):
self.ori_shape = (3, 5)
self.axes = (1, 2)
self.new_shape = (3, 1, 1, 5)
# Correct: Single input index.
class TestUnsqueezeOp1(OpTest):
def setUp(self):
ori_shape = (3, 5)
axes = (-1, )
new_shape = (3, 5, 1)
self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self): # Correct: Single input index.
self.check_grad(["X"], "Out") class TestUnsqueezeOp1(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 5)
self.axes = (-1, )
self.new_shape = (3, 5, 1)
# Correct: Mixed input axis. # Correct: Mixed input axis.
class TestUnsqueezeOp2(OpTest): class TestUnsqueezeOp2(TestUnsqueezeOp):
def setUp(self): def init_test_case(self):
ori_shape = (3, 5) self.ori_shape = (3, 5)
axes = (0, -1) self.axes = (0, -1)
new_shape = (1, 3, 5, 1) self.new_shape = (1, 3, 5, 1)
self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
# Correct: There is duplicated axis. # Correct: There is duplicated axis.
class TestUnsqueezeOp3(OpTest): class TestUnsqueezeOp3(TestUnsqueezeOp):
def setUp(self): def init_test_case(self):
ori_shape = (3, 2, 5) self.ori_shape = (3, 2, 5)
axes = (0, 3, 3) self.axes = (0, 3, 3)
new_shape = (1, 3, 2, 1, 1, 5) self.new_shape = (1, 3, 2, 1, 1, 5)
self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
# Correct: Inplace. # Correct: Inplace.
class TestUnsqueezeOpInplace1(OpTest): class TestUnsqueezeOpInplace1(TestUnsqueezeOp):
def setUp(self): def init_test_case(self):
ori_shape = (3, 5) self.ori_shape = (3, 5)
axes = (0, 2) self.axes = (0, 2)
new_shape = (1, 3, 1, 5) self.new_shape = (1, 3, 1, 5)
self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inplace": True}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
# Correct: Inplace. There is mins index. # Correct: Inplace. There is mins index.
class TestUnsqueezeOpInplace2(OpTest): class TestUnsqueezeOpInplace2(TestUnsqueezeOp):
def setUp(self): def init_test_case(self):
ori_shape = (3, 5) self.ori_shape = (3, 5)
axes = (0, -2) self.axes = (0, -2)
new_shape = (1, 3, 1, 5) self.new_shape = (1, 3, 1, 5)
self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inplace": True}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
# Correct: Inplace. There is duplicated axis. # Correct: Inplace. There is duplicated axis.
class TestUnsqueezeOpInplace3(OpTest): class TestUnsqueezeOpInplace3(TestUnsqueezeOp):
def setUp(self): def init_test_case(self):
ori_shape = (3, 2, 5) self.ori_shape = (3, 2, 5)
axes = (0, 3, 3) self.axes = (0, 3, 3)
new_shape = (1, 3, 2, 1, 1, 5) self.new_shape = (1, 3, 2, 1, 1, 5)
self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inplace": True}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
'''
# Error: Output dimension is error.
class TestUnsqueezeOp4(OpTest):
def setUp(self):
ori_shape = (3, 5)
axes = (0, 3)
new_shape = (1, 3, 1, 1, 5)
self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
# Error: Input axis is large than output range.
class TestUnsqueezeOp5(OpTest):
def setUp(self):
ori_shape = (3, 5)
axes = (0, 4)
new_shape = (1, 3, 5, 1)
self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
# Error: Input axes is large than Eigen limit.
class TestUnsqueezeOp6(OpTest):
def setUp(self):
ori_shape = (3, 5)
axes = (0, 2, 10)
new_shape = (1, 3, 1, 5, 1)
self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
# Error: Input axes size is large than Eigen limit.
class TestUnsqueezeOp7(OpTest):
def setUp(self):
ori_shape = (3, 5)
axes = (0, 2, 2, 2, 2, 2)
new_shape = (1, 3, 1, 1, 5, 1)
self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
'''
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册