未验证 提交 ea22fdb0 编写于 作者: L Leo Chen 提交者: GitHub

Support fetching empty tensors on CPUPlace (#51735)

* Support fetching empty tensors on CPUPlace

* Fix the expected shape of the empty output in the unit tests
上级 a698eb7d
......@@ -34,7 +34,7 @@ namespace operators {
static void DeepCopy(const phi::DenseTensor &src_item,
const std::string &fetch_var_name,
phi::DenseTensor *dst_item) {
if (src_item.IsInitialized() && src_item.numel() > 0) {
if (src_item.IsInitialized()) {
#ifdef PADDLE_WITH_MKLDNN
// Conversion from MKL-DNN to Paddle
if (src_item.layout() == phi::DataLayout::ONEDNN) {
......@@ -58,9 +58,7 @@ static void DeepCopy(const phi::DenseTensor &src_item,
paddle::framework::TensorCopySync(src_item, platform::CPUPlace(), dst_item);
#endif
} else {
// Not copy, if the src tensor is empty.
dst_item->clear();
dst_item->Resize({0});
VLOG(4) << "No copy";
}
dst_item->set_lod(src_item.lod());
}
......
......@@ -378,6 +378,19 @@ class TestException(unittest.TestCase):
)
# Test case: fetching a tensor that has a zero-sized dimension through a
# static-graph Executor must return a result that keeps the (3, 0) shape.
# NOTE(review): indentation is not preserved in this view, so the exact
# nesting of the statements below (in particular the final assertion) should
# be confirmed against the real source file.
class TestFetchEmptyTensor(unittest.TestCase):
def test_fetch(self):
# CPUPlace is always exercised.
places = [paddle.CPUPlace()]
# CUDAPlace(0) is added only when is_compiled_with_cuda() reports true.
if paddle.fluid.core.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
for place in places:
# A new Program is passed to program_guard for each place.
with paddle.static.program_guard(paddle.static.Program()):
# Create a tensor with shape [3, 0], i.e. one containing no elements.
out = paddle.empty([3, 0])
exe = paddle.static.Executor(place)
# No program argument is given to run(); only the fetch list is passed.
res = exe.run(fetch_list=[out])
# The first fetched result must still report the shape (3, 0).
self.assertEqual(res[0].shape, (3, 0))
class TestInplaceApiWithDataTransform(unittest.TestCase):
def test_increment(self):
if paddle.fluid.core.is_compiled_with_cuda():
......
......@@ -270,7 +270,11 @@ class TestMatrixNMSOp(OpTest):
)
empty = len(det_outs) == 0
det_outs = np.array([], dtype=np.float32) if empty else det_outs
det_outs = (
np.array([], dtype=np.float32).reshape([0, BOX_SIZE + 2])
if empty
else det_outs
)
index_outs = np.array([], dtype=np.float32) if empty else index_outs
nmsed_outs = det_outs.astype('float32')
......
......@@ -626,7 +626,9 @@ class TestMulticlassNMS2Op(TestMulticlassNMSOp):
det_outs = np.array(det_outs)
nmsed_outs = (
det_outs[:, :-1].astype('float32') if len(det_outs) else det_outs
det_outs[:, :-1].astype('float32')
if len(det_outs)
else np.array([], dtype=np.float32).reshape([0, BOX_SIZE + 2])
)
index_outs = (
det_outs[:, -1:].astype('int') if len(det_outs) else det_outs
......@@ -696,7 +698,9 @@ class TestMulticlassNMS2LoDInput(TestMulticlassNMSLoDInput):
det_outs = np.array(det_outs)
nmsed_outs = (
det_outs[:, :-1].astype('float32') if len(det_outs) else det_outs
det_outs[:, :-1].astype('float32')
if len(det_outs)
else np.array([], dtype=np.float32).reshape([0, BOX_SIZE + 2])
)
index_outs = (
det_outs[:, -1:].astype('int') if len(det_outs) else det_outs
......@@ -768,7 +772,9 @@ class TestMulticlassNMS3Op(TestMulticlassNMS2Op):
det_outs = np.array(det_outs)
nmsed_outs = (
det_outs[:, :-1].astype('float32') if len(det_outs) else det_outs
det_outs[:, :-1].astype('float32')
if len(det_outs)
else np.array([], dtype=np.float32).reshape([0, BOX_SIZE + 2])
)
index_outs = (
det_outs[:, -1:].astype('int') if len(det_outs) else det_outs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册