提交 0cef33a4 编写于 作者: C chenweihang

adjust the dims range to [1,6] and fix some problem

上级 9ca88fa8
...@@ -33,10 +33,10 @@ class SqueezeOp : public framework::OperatorWithKernel { ...@@ -33,10 +33,10 @@ class SqueezeOp : public framework::OperatorWithKernel {
"Output(Out) of SqueezeOp should not be null."); "Output(Out) of SqueezeOp should not be null.");
const auto& x_dims = ctx->GetInputDim("X"); const auto& x_dims = ctx->GetInputDim("X");
// Check input tensor dims (<9). // Check input tensor dims (<6) Eigen limit.
PADDLE_ENFORCE(x_dims.size() <= 9, PADDLE_ENFORCE(x_dims.size() <= 6,
"Invalid dimnesions, dynamic dimensions must have " "Invalid dimnesions, dynamic dimensions must have "
"between [1, 9] dimensions."); "between [1, 6] dimensions (Eigen limit).");
const auto& axes = ctx->Attrs().Get<std::vector<int>>("axes"); const auto& axes = ctx->Attrs().Get<std::vector<int>>("axes");
for (int a : axes) { for (int a : axes) {
......
...@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/squeeze_op.h" #include "paddle/fluid/operators/squeeze_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -27,7 +27,7 @@ class TestSqueezeOp1(OpTest): ...@@ -27,7 +27,7 @@ class TestSqueezeOp1(OpTest):
self.op_type = "squeeze" self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")} self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": False} self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self): def test_check_output(self):
...@@ -46,7 +46,7 @@ class TestSqueezeOp2(OpTest): ...@@ -46,7 +46,7 @@ class TestSqueezeOp2(OpTest):
self.op_type = "squeeze" self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")} self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": False} self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self): def test_check_output(self):
...@@ -65,7 +65,7 @@ class TestSqueezeOp3(OpTest): ...@@ -65,7 +65,7 @@ class TestSqueezeOp3(OpTest):
self.op_type = "squeeze" self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")} self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": False} self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self): def test_check_output(self):
...@@ -78,13 +78,13 @@ class TestSqueezeOp3(OpTest): ...@@ -78,13 +78,13 @@ class TestSqueezeOp3(OpTest):
# Correct: Just part of axes be squeezed. # Correct: Just part of axes be squeezed.
class TestSqueezeOp4(OpTest): class TestSqueezeOp4(OpTest):
def setUp(self): def setUp(self):
ori_shape = (1, 3, 1, 5, 1, 4, 1) ori_shape = (3, 1, 5, 1, 4, 1)
axes = (2, 6) axes = (1, -1)
new_shape = (1, 3, 5, 1, 4) new_shape = (3, 5, 1, 4)
self.op_type = "squeeze" self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")} self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": False} self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self): def test_check_output(self):
...@@ -122,7 +122,7 @@ class TestSqueezeOpInplace2(OpTest): ...@@ -122,7 +122,7 @@ class TestSqueezeOpInplace2(OpTest):
self.op_type = "squeeze" self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")} self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": True} self.attrs = {"axes": axes, "inplace": True}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self): def test_check_output(self):
...@@ -141,7 +141,7 @@ class TestSqueezeOpInplace3(OpTest): ...@@ -141,7 +141,7 @@ class TestSqueezeOpInplace3(OpTest):
self.op_type = "squeeze" self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")} self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": True} self.attrs = {"axes": axes, "inplace": True}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self): def test_check_output(self):
...@@ -154,13 +154,13 @@ class TestSqueezeOpInplace3(OpTest): ...@@ -154,13 +154,13 @@ class TestSqueezeOpInplace3(OpTest):
# Correct: Inpalce. Just part of axes be squeezed. # Correct: Inpalce. Just part of axes be squeezed.
class TestSqueezeOpInplace4(OpTest): class TestSqueezeOpInplace4(OpTest):
def setUp(self): def setUp(self):
ori_shape = (1, 3, 1, 5, 1, 4, 1) ori_shape = (3, 1, 5, 1, 4, 1)
axes = (2, 6) axes = (1, -1)
new_shape = (1, 3, 5, 1, 4) new_shape = (3, 5, 1, 4)
self.op_type = "squeeze" self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")} self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": True} self.attrs = {"axes": axes, "inplace": True}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self): def test_check_output(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册