提交 0cef33a4 编写于 作者: C chenweihang

adjust the dims range to [1,6] and fix some problem

上级 9ca88fa8
......@@ -33,10 +33,10 @@ class SqueezeOp : public framework::OperatorWithKernel {
"Output(Out) of SqueezeOp should not be null.");
const auto& x_dims = ctx->GetInputDim("X");
// Check input tensor dims (<9).
PADDLE_ENFORCE(x_dims.size() <= 9,
// Check input tensor dims (<6) Eigen limit.
PADDLE_ENFORCE(x_dims.size() <= 6,
"Invalid dimnesions, dynamic dimensions must have "
"between [1, 9] dimensions.");
"between [1, 6] dimensions (Eigen limit).");
const auto& axes = ctx->Attrs().Get<std::vector<int>>("axes");
for (int a : axes) {
......
......@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/squeeze_op.h"
namespace ops = paddle::operators;
......
......@@ -27,7 +27,7 @@ class TestSqueezeOp1(OpTest):
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": False}
self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
......@@ -46,7 +46,7 @@ class TestSqueezeOp2(OpTest):
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": False}
self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
......@@ -65,7 +65,7 @@ class TestSqueezeOp3(OpTest):
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": False}
self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
......@@ -78,13 +78,13 @@ class TestSqueezeOp3(OpTest):
# Correct: Just part of axes be squeezed.
class TestSqueezeOp4(OpTest):
def setUp(self):
ori_shape = (1, 3, 1, 5, 1, 4, 1)
axes = (2, 6)
new_shape = (1, 3, 5, 1, 4)
ori_shape = (3, 1, 5, 1, 4, 1)
axes = (1, -1)
new_shape = (3, 5, 1, 4)
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": False}
self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
......@@ -122,7 +122,7 @@ class TestSqueezeOpInplace2(OpTest):
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": True}
self.attrs = {"axes": axes, "inplace": True}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
......@@ -141,7 +141,7 @@ class TestSqueezeOpInplace3(OpTest):
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": True}
self.attrs = {"axes": axes, "inplace": True}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
......@@ -154,13 +154,13 @@ class TestSqueezeOpInplace3(OpTest):
# Correct: Inpalce. Just part of axes be squeezed.
class TestSqueezeOpInplace4(OpTest):
def setUp(self):
ori_shape = (1, 3, 1, 5, 1, 4, 1)
axes = (2, 6)
new_shape = (1, 3, 5, 1, 4)
ori_shape = (3, 1, 5, 1, 4, 1)
axes = (1, -1)
new_shape = (3, 5, 1, 4)
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": True}
self.attrs = {"axes": axes, "inplace": True}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册