未验证 提交 7a79fd88 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

Fix unsqueeze with empty axis bug (#51828)

上级 6ac7cabe
......@@ -4791,15 +4791,10 @@ void UnsqueezeInferMeta(const MetaTensor& x,
std::vector<int64_t> vec_out_dims(output_size, -1);
out->set_dtype(x.dtype());
out->set_dims(phi::make_ddim(vec_out_dims));
} else if (!axes.GetData().empty()) {
std::vector<int32_t> tmp;
tmp.reserve(axes.GetData().size());
std::for_each(axes.GetData().begin(),
axes.GetData().end(),
[&tmp](const int64_t& t) { tmp.push_back(t); });
auto out_dims = funcs::GetUnsqueezeShape(tmp, x_dims);
} else {
auto out_dims = funcs::GetUnsqueezeShape(axes.GetData(), x_dims);
out->set_dims(out_dims);
if (x_dims[0] == out_dims[0]) {
if (x_dims.size() > 0 && x_dims[0] == out_dims[0]) {
out->share_lod(x);
}
out->set_dtype(x.dtype());
......
......@@ -103,7 +103,7 @@ inline DDim GetOutputSqueezeShape(const std::vector<int> squeeze_dims,
return phi::make_ddim(output_shape);
}
inline DDim GetUnsqueezeShape(const std::vector<int> unsqz_dims,
inline DDim GetUnsqueezeShape(const std::vector<int64_t> unsqz_dims,
const DDim& in_dims) {
int output_size = in_dims.size() + static_cast<int>(unsqz_dims.size());
int cur_output_size = in_dims.size();
......
......@@ -28,12 +28,7 @@ void UnsqueezeInferKernel(const Context& dev_ctx,
auto x_dims = x.dims();
auto out_dims = out->dims();
if (axes.FromTensor()) {
std::vector<int32_t> tmp;
tmp.reserve(axes.GetData().size());
std::for_each(axes.GetData().begin(),
axes.GetData().end(),
[&tmp](const int64_t& t) { tmp.push_back(t); });
out_dims = funcs::GetUnsqueezeShape(tmp, x_dims);
out_dims = funcs::GetUnsqueezeShape(axes.GetData(), x_dims);
}
out->Resize(out_dims);
dev_ctx.template Alloc<T>(out);
......
......@@ -108,6 +108,22 @@ class TestUnsqueezeOp4(TestUnsqueezeOp):
self.new_shape = (10, 1, 1, 2, 5, 1)
# axis is empty, x is ND
class TestUnsqueezeOp5(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = ()
self.axes = ()
self.new_shape = ()
# axis is empty, x is 0D
class TestUnsqueezeOp6(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (10, 2, 5)
self.axes = ()
self.new_shape = (10, 2, 5)
class TestUnsqueezeOp_ZeroDim1(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = ()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册