未验证 提交 ea22fdb0 编写于 作者: L Leo Chen 提交者: GitHub

support fetch empty tensor on CPUPlace (#51735)

* support fetch empty tensor on CPUPlace

* fix the shape in unittest of empty output
上级 a698eb7d
...@@ -34,7 +34,7 @@ namespace operators { ...@@ -34,7 +34,7 @@ namespace operators {
static void DeepCopy(const phi::DenseTensor &src_item, static void DeepCopy(const phi::DenseTensor &src_item,
const std::string &fetch_var_name, const std::string &fetch_var_name,
phi::DenseTensor *dst_item) { phi::DenseTensor *dst_item) {
if (src_item.IsInitialized() && src_item.numel() > 0) { if (src_item.IsInitialized()) {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// Conversion from MKL-DNN to Paddle // Conversion from MKL-DNN to Paddle
if (src_item.layout() == phi::DataLayout::ONEDNN) { if (src_item.layout() == phi::DataLayout::ONEDNN) {
...@@ -58,9 +58,7 @@ static void DeepCopy(const phi::DenseTensor &src_item, ...@@ -58,9 +58,7 @@ static void DeepCopy(const phi::DenseTensor &src_item,
paddle::framework::TensorCopySync(src_item, platform::CPUPlace(), dst_item); paddle::framework::TensorCopySync(src_item, platform::CPUPlace(), dst_item);
#endif #endif
} else { } else {
// Not copy, if the src tensor is empty. VLOG(4) << "No copy";
dst_item->clear();
dst_item->Resize({0});
} }
dst_item->set_lod(src_item.lod()); dst_item->set_lod(src_item.lod());
} }
......
...@@ -378,6 +378,19 @@ class TestException(unittest.TestCase): ...@@ -378,6 +378,19 @@ class TestException(unittest.TestCase):
) )
class TestFetchEmptyTensor(unittest.TestCase):
def test_fetch(self):
places = [paddle.CPUPlace()]
if paddle.fluid.core.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
for place in places:
with paddle.static.program_guard(paddle.static.Program()):
out = paddle.empty([3, 0])
exe = paddle.static.Executor(place)
res = exe.run(fetch_list=[out])
self.assertEqual(res[0].shape, (3, 0))
class TestInplaceApiWithDataTransform(unittest.TestCase): class TestInplaceApiWithDataTransform(unittest.TestCase):
def test_increment(self): def test_increment(self):
if paddle.fluid.core.is_compiled_with_cuda(): if paddle.fluid.core.is_compiled_with_cuda():
......
...@@ -270,7 +270,11 @@ class TestMatrixNMSOp(OpTest): ...@@ -270,7 +270,11 @@ class TestMatrixNMSOp(OpTest):
) )
empty = len(det_outs) == 0 empty = len(det_outs) == 0
det_outs = np.array([], dtype=np.float32) if empty else det_outs det_outs = (
np.array([], dtype=np.float32).reshape([0, BOX_SIZE + 2])
if empty
else det_outs
)
index_outs = np.array([], dtype=np.float32) if empty else index_outs index_outs = np.array([], dtype=np.float32) if empty else index_outs
nmsed_outs = det_outs.astype('float32') nmsed_outs = det_outs.astype('float32')
......
...@@ -626,7 +626,9 @@ class TestMulticlassNMS2Op(TestMulticlassNMSOp): ...@@ -626,7 +626,9 @@ class TestMulticlassNMS2Op(TestMulticlassNMSOp):
det_outs = np.array(det_outs) det_outs = np.array(det_outs)
nmsed_outs = ( nmsed_outs = (
det_outs[:, :-1].astype('float32') if len(det_outs) else det_outs det_outs[:, :-1].astype('float32')
if len(det_outs)
else np.array([], dtype=np.float32).reshape([0, BOX_SIZE + 2])
) )
index_outs = ( index_outs = (
det_outs[:, -1:].astype('int') if len(det_outs) else det_outs det_outs[:, -1:].astype('int') if len(det_outs) else det_outs
...@@ -696,7 +698,9 @@ class TestMulticlassNMS2LoDInput(TestMulticlassNMSLoDInput): ...@@ -696,7 +698,9 @@ class TestMulticlassNMS2LoDInput(TestMulticlassNMSLoDInput):
det_outs = np.array(det_outs) det_outs = np.array(det_outs)
nmsed_outs = ( nmsed_outs = (
det_outs[:, :-1].astype('float32') if len(det_outs) else det_outs det_outs[:, :-1].astype('float32')
if len(det_outs)
else np.array([], dtype=np.float32).reshape([0, BOX_SIZE + 2])
) )
index_outs = ( index_outs = (
det_outs[:, -1:].astype('int') if len(det_outs) else det_outs det_outs[:, -1:].astype('int') if len(det_outs) else det_outs
...@@ -768,7 +772,9 @@ class TestMulticlassNMS3Op(TestMulticlassNMS2Op): ...@@ -768,7 +772,9 @@ class TestMulticlassNMS3Op(TestMulticlassNMS2Op):
det_outs = np.array(det_outs) det_outs = np.array(det_outs)
nmsed_outs = ( nmsed_outs = (
det_outs[:, :-1].astype('float32') if len(det_outs) else det_outs det_outs[:, :-1].astype('float32')
if len(det_outs)
else np.array([], dtype=np.float32).reshape([0, BOX_SIZE + 2])
) )
index_outs = ( index_outs = (
det_outs[:, -1:].astype('int') if len(det_outs) else det_outs det_outs[:, -1:].astype('int') if len(det_outs) else det_outs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册