未验证 提交 11e34ae0 编写于 作者: H heliqi 提交者: GitHub

Fix paddle.queeze_ bug (#49903)

* fix queeze_ bug

* fix slove use squeeze_kernel

* fix slove use squeeze_kernel

* fix slove use squeeze_kernel

* add test case
上级 22b5241f
......@@ -169,7 +169,7 @@ static void linalg_solve(const Context& dev_ctx,
out_tmp.Resize(out->dims());
out_tmp = *out;
phi::SqueezeInferKernel<T, Context>(dev_ctx, out_tmp, {-1}, out);
phi::Squeeze<T, Context>(dev_ctx, out_tmp, {-1}, out);
} else {
PADDLE_ENFORCE_EQ(
x_dim[x_dim_size - 1],
......
......@@ -25,11 +25,7 @@ void SqueezeInferKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out) {
auto x_dims = x.dims();
std::vector<int32_t> tmp(axes.GetData().begin(), axes.GetData().end());
auto out_dims = funcs::GetOutputSqueezeShape(tmp, x_dims, true);
out->Resize(out_dims);
auto out_dims = out->dims();
dev_ctx.template Alloc<T>(out);
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
out->Resize(out_dims); // copy will reset the dims.
......
......@@ -17,6 +17,7 @@
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/unary.h"
namespace phi {
......@@ -33,4 +34,14 @@ void SqueezeKernel(const Context& dev_ctx,
DenseTensor* out,
DenseTensor* xshape);
template <typename T, typename Context>
void Squeeze(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out) {
MetaTensor meta_out(out);
SqueezeInferMeta(x, axes, &meta_out);
SqueezeInferKernel<T, Context>(dev_ctx, x, axes, out);
}
} // namespace phi
......@@ -155,5 +155,37 @@ class TestSqueeze2AxesTensorList(UnittestBase):
self.assertEqual(infer_out.shape, (2, 3, 10))
# test api
class TestSqueezeAPI(unittest.TestCase):
def setUp(self):
self.executed_api()
def executed_api(self):
self.squeeze = paddle.squeeze
def test_api(self):
paddle.disable_static()
input_data = np.random.random([3, 2, 1]).astype("float32")
x = paddle.to_tensor(input_data)
out = self.squeeze(x, axis=2)
out.backward()
self.assertEqual(out.shape, [3, 2])
paddle.enable_static()
def test_error(self):
def test_axes_type():
x2 = paddle.static.data(name="x2", shape=[2, 1, 25], dtype="int32")
self.squeeze(x2, axis=2.1)
self.assertRaises(TypeError, test_axes_type)
class TestSqueezeInplaceAPI(TestSqueezeAPI):
def executed_api(self):
self.squeeze = paddle.squeeze_
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册