提交 9ca88fa8 编写于 作者: C chenweihang

Adjust squeeze op and code the unittest, test passed

上级 bd57dec1
......@@ -33,11 +33,12 @@ class SqueezeOp : public framework::OperatorWithKernel {
"Output(Out) of SqueezeOp should not be null.");
const auto& x_dims = ctx->GetInputDim("X");
// TODO(chenweihang): need check input tensor dims (<9).
// Check input tensor dims (<9).
PADDLE_ENFORCE(x_dims.size() <= 9,
"Invalid dimnesions, dynamic dimensions must have "
"between [1, 9] dimensions.");
const auto& axes = ctx->Attrs().Get<std::vector<int>>("axes");
// TODO(chenweihang): need check axes is valid.
// PADDLE_ENFORCE();
for (int a : axes) {
PADDLE_ENFORCE_LT(a, x_dims.size(),
"The axis must be less than input tensor's rank.");
......@@ -45,7 +46,12 @@ class SqueezeOp : public framework::OperatorWithKernel {
auto out_dims = GetOutputShape(axes, x_dims);
ctx->SetOutputDim("Out", out_dims);
// TODO(chenweihang): need other check.
// TODO(chenweihang): This share option is necessary?
if (x_dims[0] == out_dims[0]) {
// Only pass LoD when the first dimension of output and Input(X)
// are the same.
ctx->ShareLoD("X", "Out");
}
}
static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
......@@ -67,12 +73,17 @@ class SqueezeOp : public framework::OperatorWithKernel {
for (int idx = 0; idx < num_squeeze_dims; ++idx) {
int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size()
: squeeze_dims[idx];
// TODO(chenweihang): shoude use PADALE_ENFORCE ? or if.
PADDLE_ENFORCE_GE(current, 0, "Invalid axis is given.");
PADDLE_ENFORCE_LT(current, in_dims.size(), "Invalid axis is given.");
PADDLE_ENFORCE_EQ(in_dims[current], 1, "Invalid axis is given.");
if (!(should_squeeze[current])) ++cnt_squeezed_dims;
// Check current index.
PADDLE_ENFORCE(current >= 0,
"Invalid axis, negative axis is out of range.");
// PADDLE_ENFORCE_LT(current, in_dims.size(), "Invalid axis is given.");
PADDLE_ENFORCE(
in_dims[current] == 1,
"Invalid axis index, the axis will be squeezed should be 1.");
if (!(should_squeeze[current])) {
++cnt_squeezed_dims;
}
should_squeeze[current] = true;
}
}
......@@ -92,13 +103,14 @@ class SqueezeOp : public framework::OperatorWithKernel {
class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), Tensors with at least max(dims) dimensions.");
AddOutput("Out", "(Tensor), Reshaped tensor with same data as input.");
AddInput("X", "(Tensor). The input tensor of squeeze operator.");
AddOutput("Out", "(Tensor). The output tensor of squeeze operator.");
AddAttr<std::vector<int>>("axes",
"List of positive integers,"
" indicate the dimensions to squeeze.");
"(std::vector<int>). List of positive integers,"
" indicate the dimensions to squeeze.")
.SetDefault({});
AddAttr<bool>("inplace",
"(default: false) Change the source tensor's shape without "
"(default: false) Squeeze the source tensor's shape without "
"memory copy. When Attr(inplace) is set true, the output "
"tensor shares memory with Input(X), otherwise, a new output "
"tensor is created, and its data are copied from Input(x).")
......@@ -110,6 +122,21 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
Takes a parameter axes with a list of axes to squeeze.
If axes is not provided, all the single dimensions will be removed from the shape.
If an axis is selected with shape entry not equal to one, an error is raised.
Examples:
Case 1:
Given
X.shape = (1, 3, 1, 5)
and
axes = [0]
we get:
Out.shape = (3, 1, 5)
Case 2:
Given
X.shape = (1, 3, 1, 5)
we get:
Out.shape = (3, 5)
)DOC");
}
};
......@@ -120,9 +147,9 @@ class SqueezeGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SqueezeOp should not be null.");
"Input(X) of SqueezeGradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Output(Out@GRAD/) of SqueezeOp should not be null.");
"Output(Out@GRAD) of SqueezeGradOp should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
......
......@@ -33,7 +33,6 @@ class SqueezeKernel : public framework::OpKernel<T> {
framework::DDim out_dims = out->dims();
// TODO(chenweihang): Where is this attr be add.
bool inplace = ctx.Attr<bool>("inplace");
out->Resize(out_dims);
if (!inplace) {
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
# Correct: General.
class TestSqueezeOp1(OpTest):
def setUp(self):
ori_shape = (1, 3, 1, 5)
axes = (0, 2)
new_shape = (3, 5)
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
# Correct: There is mins axis.
class TestSqueezeOp2(OpTest):
def setUp(self):
ori_shape = (1, 3, 1, 5)
axes = (0, -2)
new_shape = (3, 5)
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
# Correct: No axes input.
class TestSqueezeOp3(OpTest):
def setUp(self):
ori_shape = (1, 3, 1, 5)
axes = ()
new_shape = (3, 5)
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
# Correct: Just part of axes be squeezed.
class TestSqueezeOp4(OpTest):
def setUp(self):
ori_shape = (1, 3, 1, 5, 1, 4, 1)
axes = (2, 6)
new_shape = (1, 3, 5, 1, 4)
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
# Correct: Inplace.
class TestSqueezeOpInplace1(OpTest):
def setUp(self):
ori_shape = (1, 3, 1, 5)
axes = (0, 2)
new_shape = (3, 5)
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inplace": True}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
# Correct: Inplace. There is mins axis.
class TestSqueezeOpInplace2(OpTest):
def setUp(self):
ori_shape = (1, 3, 1, 5)
axes = (0, -2)
new_shape = (3, 5)
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": True}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
# Correct: Inplace. No axes input.
class TestSqueezeOpInplace3(OpTest):
def setUp(self):
ori_shape = (1, 3, 1, 5)
axes = ()
new_shape = (3, 5)
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": True}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
# Correct: Inpalce. Just part of axes be squeezed.
class TestSqueezeOpInplace4(OpTest):
def setUp(self):
ori_shape = (1, 3, 1, 5, 1, 4, 1)
axes = (2, 6)
new_shape = (1, 3, 5, 1, 4)
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": True}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册