提交 80126a74 编写于 作者: C chenweihang

small fix based reviewer's advice

上级 a6d94e8d
......@@ -28,9 +28,6 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase {
"Output(Out) of UnsqueezeOp should not be null.");
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
PADDLE_ENFORCE(!axes.empty(),
"The unsqueeze axes information must be set by Attr(axes).");
const auto &x_dims = ctx->GetInputDim("X");
// Validity Check: input tensor dims (<6).
PADDLE_ENFORCE(static_cast<int>(x_dims.size()) <= 6,
......@@ -123,6 +120,9 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
"(std::vector<int>). List of positive integers,"
" indicate the dimensions to be inserted")
.AddCustomChecker([](const std::vector<int> &axes) {
PADDLE_ENFORCE(
!axes.empty(),
"The unsqueeze axes information must be set by Attr(axes).");
// Validity Check: axes dims (<6).
PADDLE_ENFORCE(static_cast<int>(axes.size()) < 6,
"Invalid dimensions, dynamic dimensions should within "
......
......@@ -24,7 +24,7 @@ class TestUnsqueezeOp(OpTest):
self.init_test_case()
self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
self.attrs = {"axes": self.axes, "inplace": False}
self.init_attrs()
self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)}
def test_check_output(self):
......@@ -38,6 +38,9 @@ class TestUnsqueezeOp(OpTest):
self.axes = (1, 2)
self.new_shape = (3, 1, 1, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": False}
# Correct: Single input index.
class TestUnsqueezeOp1(TestUnsqueezeOp):
......@@ -70,6 +73,9 @@ class TestUnsqueezeOpInplace1(TestUnsqueezeOp):
self.axes = (0, 2)
self.new_shape = (1, 3, 1, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
# Correct: Inplace. There is mins index.
class TestUnsqueezeOpInplace2(TestUnsqueezeOp):
......@@ -78,6 +84,9 @@ class TestUnsqueezeOpInplace2(TestUnsqueezeOp):
self.axes = (0, -2)
self.new_shape = (1, 3, 1, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
# Correct: Inplace. There is duplicated axis.
class TestUnsqueezeOpInplace3(TestUnsqueezeOp):
......@@ -86,6 +95,9 @@ class TestUnsqueezeOpInplace3(TestUnsqueezeOp):
self.axes = (0, 3, 3)
self.new_shape = (1, 3, 2, 1, 1, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册