提交 80126a74 编写于 作者: C chenweihang

small fix based reviewer's advice

上级 a6d94e8d
...@@ -28,9 +28,6 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase { ...@@ -28,9 +28,6 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase {
"Output(Out) of UnsqueezeOp should not be null."); "Output(Out) of UnsqueezeOp should not be null.");
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes"); const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
PADDLE_ENFORCE(!axes.empty(),
"The unsqueeze axes information must be set by Attr(axes).");
const auto &x_dims = ctx->GetInputDim("X"); const auto &x_dims = ctx->GetInputDim("X");
// Validity Check: input tensor dims (<6). // Validity Check: input tensor dims (<6).
PADDLE_ENFORCE(static_cast<int>(x_dims.size()) <= 6, PADDLE_ENFORCE(static_cast<int>(x_dims.size()) <= 6,
...@@ -123,6 +120,9 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -123,6 +120,9 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
"(std::vector<int>). List of positive integers," "(std::vector<int>). List of positive integers,"
" indicate the dimensions to be inserted") " indicate the dimensions to be inserted")
.AddCustomChecker([](const std::vector<int> &axes) { .AddCustomChecker([](const std::vector<int> &axes) {
PADDLE_ENFORCE(
!axes.empty(),
"The unsqueeze axes information must be set by Attr(axes).");
// Validity Check: axes dims (<6). // Validity Check: axes dims (<6).
PADDLE_ENFORCE(static_cast<int>(axes.size()) < 6, PADDLE_ENFORCE(static_cast<int>(axes.size()) < 6,
"Invalid dimensions, dynamic dimensions should within " "Invalid dimensions, dynamic dimensions should within "
......
...@@ -24,7 +24,7 @@ class TestUnsqueezeOp(OpTest): ...@@ -24,7 +24,7 @@ class TestUnsqueezeOp(OpTest):
self.init_test_case() self.init_test_case()
self.op_type = "unsqueeze" self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")} self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
self.attrs = {"axes": self.axes, "inplace": False} self.init_attrs()
self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)} self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)}
def test_check_output(self): def test_check_output(self):
...@@ -38,6 +38,9 @@ class TestUnsqueezeOp(OpTest): ...@@ -38,6 +38,9 @@ class TestUnsqueezeOp(OpTest):
self.axes = (1, 2) self.axes = (1, 2)
self.new_shape = (3, 1, 1, 5) self.new_shape = (3, 1, 1, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": False}
# Correct: Single input index. # Correct: Single input index.
class TestUnsqueezeOp1(TestUnsqueezeOp): class TestUnsqueezeOp1(TestUnsqueezeOp):
...@@ -70,6 +73,9 @@ class TestUnsqueezeOpInplace1(TestUnsqueezeOp): ...@@ -70,6 +73,9 @@ class TestUnsqueezeOpInplace1(TestUnsqueezeOp):
self.axes = (0, 2) self.axes = (0, 2)
self.new_shape = (1, 3, 1, 5) self.new_shape = (1, 3, 1, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
# Correct: Inplace. There is mins index. # Correct: Inplace. There is mins index.
class TestUnsqueezeOpInplace2(TestUnsqueezeOp): class TestUnsqueezeOpInplace2(TestUnsqueezeOp):
...@@ -78,6 +84,9 @@ class TestUnsqueezeOpInplace2(TestUnsqueezeOp): ...@@ -78,6 +84,9 @@ class TestUnsqueezeOpInplace2(TestUnsqueezeOp):
self.axes = (0, -2) self.axes = (0, -2)
self.new_shape = (1, 3, 1, 5) self.new_shape = (1, 3, 1, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
# Correct: Inplace. There is duplicated axis. # Correct: Inplace. There is duplicated axis.
class TestUnsqueezeOpInplace3(TestUnsqueezeOp): class TestUnsqueezeOpInplace3(TestUnsqueezeOp):
...@@ -86,6 +95,9 @@ class TestUnsqueezeOpInplace3(TestUnsqueezeOp): ...@@ -86,6 +95,9 @@ class TestUnsqueezeOpInplace3(TestUnsqueezeOp):
self.axes = (0, 3, 3) self.axes = (0, 3, 3)
self.new_shape = (1, 3, 2, 1, 1, 5) self.new_shape = (1, 3, 2, 1, 1, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册