提交 9ca88fa8 编写于 作者: C chenweihang

Adjust squeeze op and code the unittest, test passed

上级 bd57dec1
...@@ -33,11 +33,12 @@ class SqueezeOp : public framework::OperatorWithKernel { ...@@ -33,11 +33,12 @@ class SqueezeOp : public framework::OperatorWithKernel {
"Output(Out) of SqueezeOp should not be null."); "Output(Out) of SqueezeOp should not be null.");
const auto& x_dims = ctx->GetInputDim("X"); const auto& x_dims = ctx->GetInputDim("X");
// TODO(chenweihang): need check input tensor dims (<9). // Check input tensor dims (<9).
PADDLE_ENFORCE(x_dims.size() <= 9,
"Invalid dimnesions, dynamic dimensions must have "
"between [1, 9] dimensions.");
const auto& axes = ctx->Attrs().Get<std::vector<int>>("axes"); const auto& axes = ctx->Attrs().Get<std::vector<int>>("axes");
// TODO(chenweihang): need check axes is valid.
// PADDLE_ENFORCE();
for (int a : axes) { for (int a : axes) {
PADDLE_ENFORCE_LT(a, x_dims.size(), PADDLE_ENFORCE_LT(a, x_dims.size(),
"The axis must be less than input tensor's rank."); "The axis must be less than input tensor's rank.");
...@@ -45,7 +46,12 @@ class SqueezeOp : public framework::OperatorWithKernel { ...@@ -45,7 +46,12 @@ class SqueezeOp : public framework::OperatorWithKernel {
auto out_dims = GetOutputShape(axes, x_dims); auto out_dims = GetOutputShape(axes, x_dims);
ctx->SetOutputDim("Out", out_dims); ctx->SetOutputDim("Out", out_dims);
// TODO(chenweihang): need other check. // TODO(chenweihang): This share option is necessary?
if (x_dims[0] == out_dims[0]) {
// Only pass LoD when the first dimension of output and Input(X)
// are the same.
ctx->ShareLoD("X", "Out");
}
} }
static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims, static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
...@@ -67,12 +73,17 @@ class SqueezeOp : public framework::OperatorWithKernel { ...@@ -67,12 +73,17 @@ class SqueezeOp : public framework::OperatorWithKernel {
for (int idx = 0; idx < num_squeeze_dims; ++idx) { for (int idx = 0; idx < num_squeeze_dims; ++idx) {
int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size() int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size()
: squeeze_dims[idx]; : squeeze_dims[idx];
// TODO(chenweihang): shoude use PADALE_ENFORCE ? or if. // Check current index.
PADDLE_ENFORCE_GE(current, 0, "Invalid axis is given."); PADDLE_ENFORCE(current >= 0,
PADDLE_ENFORCE_LT(current, in_dims.size(), "Invalid axis is given."); "Invalid axis, negative axis is out of range.");
PADDLE_ENFORCE_EQ(in_dims[current], 1, "Invalid axis is given."); // PADDLE_ENFORCE_LT(current, in_dims.size(), "Invalid axis is given.");
PADDLE_ENFORCE(
if (!(should_squeeze[current])) ++cnt_squeezed_dims; in_dims[current] == 1,
"Invalid axis index, the axis will be squeezed should be 1.");
if (!(should_squeeze[current])) {
++cnt_squeezed_dims;
}
should_squeeze[current] = true; should_squeeze[current] = true;
} }
} }
...@@ -92,13 +103,14 @@ class SqueezeOp : public framework::OperatorWithKernel { ...@@ -92,13 +103,14 @@ class SqueezeOp : public framework::OperatorWithKernel {
class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker { class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("X", "(Tensor), Tensors with at least max(dims) dimensions."); AddInput("X", "(Tensor). The input tensor of squeeze operator.");
AddOutput("Out", "(Tensor), Reshaped tensor with same data as input."); AddOutput("Out", "(Tensor). The output tensor of squeeze operator.");
AddAttr<std::vector<int>>("axes", AddAttr<std::vector<int>>("axes",
"List of positive integers," "(std::vector<int>). List of positive integers,"
" indicate the dimensions to squeeze."); " indicate the dimensions to squeeze.")
.SetDefault({});
AddAttr<bool>("inplace", AddAttr<bool>("inplace",
"(default: false) Change the source tensor's shape without " "(default: false) Squeeze the source tensor's shape without "
"memory copy. When Attr(inplace) is set true, the output " "memory copy. When Attr(inplace) is set true, the output "
"tensor shares memory with Input(X), otherwise, a new output " "tensor shares memory with Input(X), otherwise, a new output "
"tensor is created, and its data are copied from Input(x).") "tensor is created, and its data are copied from Input(x).")
...@@ -110,6 +122,21 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -110,6 +122,21 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
Takes a parameter axes with a list of axes to squeeze. Takes a parameter axes with a list of axes to squeeze.
If axes is not provided, all the single dimensions will be removed from the shape. If axes is not provided, all the single dimensions will be removed from the shape.
If an axis is selected with shape entry not equal to one, an error is raised. If an axis is selected with shape entry not equal to one, an error is raised.
Examples:
Case 1:
Given
X.shape = (1, 3, 1, 5)
and
axes = [0]
we get:
Out.shape = (3, 1, 5)
Case 2:
Given
X.shape = (1, 3, 1, 5)
we get:
Out.shape = (3, 5)
)DOC"); )DOC");
} }
}; };
...@@ -120,9 +147,9 @@ class SqueezeGradOp : public framework::OperatorWithKernel { ...@@ -120,9 +147,9 @@ class SqueezeGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SqueezeOp should not be null."); "Input(X) of SqueezeGradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Output(Out@GRAD/) of SqueezeOp should not be null."); "Output(Out@GRAD) of SqueezeGradOp should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
} }
......
...@@ -33,7 +33,6 @@ class SqueezeKernel : public framework::OpKernel<T> { ...@@ -33,7 +33,6 @@ class SqueezeKernel : public framework::OpKernel<T> {
framework::DDim out_dims = out->dims(); framework::DDim out_dims = out->dims();
// TODO(chenweihang): Where is this attr be add.
bool inplace = ctx.Attr<bool>("inplace"); bool inplace = ctx.Attr<bool>("inplace");
out->Resize(out_dims); out->Resize(out_dims);
if (!inplace) { if (!inplace) {
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
# Correct: General.
class TestSqueezeOp1(OpTest):
def setUp(self):
ori_shape = (1, 3, 1, 5)
axes = (0, 2)
new_shape = (3, 5)
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
# Correct: There is mins axis.
class TestSqueezeOp2(OpTest):
def setUp(self):
ori_shape = (1, 3, 1, 5)
axes = (0, -2)
new_shape = (3, 5)
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
# Correct: No axes input.
class TestSqueezeOp3(OpTest):
def setUp(self):
ori_shape = (1, 3, 1, 5)
axes = ()
new_shape = (3, 5)
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
# Correct: Just part of axes be squeezed.
class TestSqueezeOp4(OpTest):
def setUp(self):
ori_shape = (1, 3, 1, 5, 1, 4, 1)
axes = (2, 6)
new_shape = (1, 3, 5, 1, 4)
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
# Correct: Inplace.
class TestSqueezeOpInplace1(OpTest):
def setUp(self):
ori_shape = (1, 3, 1, 5)
axes = (0, 2)
new_shape = (3, 5)
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inplace": True}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
# Correct: Inplace. There is mins axis.
class TestSqueezeOpInplace2(OpTest):
def setUp(self):
ori_shape = (1, 3, 1, 5)
axes = (0, -2)
new_shape = (3, 5)
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": True}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
# Correct: Inplace. No axes input.
class TestSqueezeOpInplace3(OpTest):
def setUp(self):
ori_shape = (1, 3, 1, 5)
axes = ()
new_shape = (3, 5)
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": True}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
# Correct: Inpalce. Just part of axes be squeezed.
class TestSqueezeOpInplace4(OpTest):
def setUp(self):
ori_shape = (1, 3, 1, 5, 1, 4, 1)
axes = (2, 6)
new_shape = (1, 3, 5, 1, 4)
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": True}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册