未验证 提交 34fafb11 编写于 作者: H heliqi 提交者: GitHub

[cherry-pick]Fix paddle.queeze_ bug (#49937)

* Fix paddle.queeze_ bug (#49903)

* fix queeze_ bug

* fix slove use squeeze_kernel

* fix slove use squeeze_kernel

* fix slove use squeeze_kernel

* add test case

* Update squeeze_kernel.h
上级 0699afb1
......@@ -169,7 +169,7 @@ static void linalg_solve(const Context& dev_ctx,
out_tmp.Resize(out->dims());
out_tmp = *out;
phi::SqueezeKernel<T, Context>(dev_ctx, out_tmp, {-1}, out);
phi::Squeeze<T, Context>(dev_ctx, out_tmp, {-1}, out);
} else {
PADDLE_ENFORCE_EQ(
x_dim[x_dim_size - 1],
......
......@@ -23,11 +23,7 @@ void SqueezeKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out) {
auto x_dims = x.dims();
std::vector<int32_t> tmp(axes.GetData().begin(), axes.GetData().end());
auto out_dims = funcs::GetOutputSqueezeShape(tmp, x_dims, true);
out->Resize(out_dims);
auto out_dims = out->dims();
dev_ctx.template Alloc<T>(out);
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
out->Resize(out_dims); // copy will reset the dims.
......
......@@ -17,6 +17,7 @@
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/unary.h"
namespace phi {
......@@ -33,4 +34,14 @@ void SqueezeWithXShapeKernel(const Context& dev_ctx,
DenseTensor* out,
DenseTensor* xshape);
template <typename T, typename Context>
void Squeeze(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out) {
MetaTensor meta_out(out);
SqueezeInferMeta(x, axes, &meta_out);
SqueezeKernel<T, Context>(dev_ctx, x, axes, out);
}
} // namespace phi
......@@ -28,7 +28,6 @@ paddle.enable_static()
# Correct: General.
class TestSqueezeOp(OpTest):
def setUp(self):
self.op_type = "squeeze2"
self.python_api = paddle.squeeze
......@@ -40,7 +39,7 @@ class TestSqueezeOp(OpTest):
self.init_attrs()
self.outputs = {
"Out": self.inputs["X"].reshape(self.new_shape),
"XShape": np.random.random(self.ori_shape).astype("float64")
"XShape": np.random.random(self.ori_shape).astype("float64"),
}
def test_check_output(self):
......@@ -60,7 +59,6 @@ class TestSqueezeOp(OpTest):
# Correct: There is mins axis.
class TestSqueezeOp1(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (1, 20, 1, 5)
self.axes = (0, -2)
......@@ -69,7 +67,6 @@ class TestSqueezeOp1(TestSqueezeOp):
# Correct: No axes input.
class TestSqueezeOp2(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (1, 20, 1, 5)
self.axes = ()
......@@ -78,7 +75,6 @@ class TestSqueezeOp2(TestSqueezeOp):
# Correct: Just part of axes be squeezed.
class TestSqueezeOp3(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (6, 1, 5, 1, 4, 1)
self.axes = (1, -1)
......@@ -86,7 +82,6 @@ class TestSqueezeOp3(TestSqueezeOp):
class TestSqueeze2AxesTensor(UnittestBase):
def init_info(self):
self.shapes = [[2, 3, 4]]
self.save_path = os.path.join(self.temp_dir.name, 'squeeze_tensor')
......@@ -123,7 +118,6 @@ class TestSqueeze2AxesTensor(UnittestBase):
class TestSqueeze2AxesTensorList(UnittestBase):
def init_info(self):
self.shapes = [[2, 3, 4]]
self.save_path = os.path.join(self.temp_dir.name, 'squeeze_tensor')
......@@ -140,7 +134,7 @@ class TestSqueeze2AxesTensorList(UnittestBase):
# axes is a list[Variable]
axes = [
paddle.full([1], 0, dtype='int32'),
paddle.full([1], 2, dtype='int32')
paddle.full([1], 2, dtype='int32'),
]
out = paddle.squeeze(feat, axes)
out2 = paddle.fluid.layers.squeeze(feat, axes)
......@@ -162,5 +156,37 @@ class TestSqueeze2AxesTensorList(UnittestBase):
self.assertEqual(infer_out.shape, (2, 3, 10))
# test api
class TestSqueezeAPI(unittest.TestCase):
def setUp(self):
self.executed_api()
def executed_api(self):
self.squeeze = paddle.squeeze
def test_api(self):
paddle.disable_static()
input_data = np.random.random([3, 2, 1]).astype("float32")
x = paddle.to_tensor(input_data)
out = self.squeeze(x, axis=2)
out.backward()
self.assertEqual(out.shape, [3, 2])
paddle.enable_static()
def test_error(self):
def test_axes_type():
x2 = paddle.static.data(name="x2", shape=[2, 1, 25], dtype="int32")
self.squeeze(x2, axis=2.1)
self.assertRaises(TypeError, test_axes_type)
class TestSqueezeInplaceAPI(TestSqueezeAPI):
def executed_api(self):
self.squeeze = paddle.squeeze_
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册