From 76154c941fa950c285cc98e14cd337a9eab49d1e Mon Sep 17 00:00:00 2001 From: WangZhen <23097963+0x45f@users.noreply.github.com> Date: Mon, 19 Sep 2022 10:48:17 +0800 Subject: [PATCH] Fix TransDataBackend Error when call unsqueeze using MKL Tensor (#46094) * Fix TransDataBackend Error when call unsqueeze using MKL Tensor * Add UT * Refine UT --- paddle/phi/api/lib/data_transform.cc | 2 +- .../fluid/tests/unittests/test_jit_layer.py | 42 ++++++++++++++++++- 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 23ff797d77c..72e65ae5286 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -307,7 +307,7 @@ paddle::optional> PrepareData( void TransDataBackend(const phi::DenseTensor* tensor, Backend target_backend, phi::DenseTensor* out) { - if (tensor) { + if (tensor && tensor->initialized()) { *out = TransDataPlace(*tensor, phi::TransToPhiPlace(target_backend)); } } diff --git a/python/paddle/fluid/tests/unittests/test_jit_layer.py b/python/paddle/fluid/tests/unittests/test_jit_layer.py index bc5658127b2..5a03e0ac3b8 100644 --- a/python/paddle/fluid/tests/unittests/test_jit_layer.py +++ b/python/paddle/fluid/tests/unittests/test_jit_layer.py @@ -18,7 +18,7 @@ import unittest import tempfile import numpy as np from paddle.static import InputSpec -from paddle.fluid.framework import _enable_legacy_dygraph +from paddle.fluid.framework import _dygraph_place_guard from paddle.jit.layer import Layer from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator @@ -51,9 +51,14 @@ class Net(paddle.nn.Layer): class TestMultiLoad(unittest.TestCase): - def test_multi_load(self): + def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() + def tearDown(self): + self.temp_dir.cleanup() + + def test_multi_load(self): + x = paddle.full([2, 4], 2) model = Net() program_translator = ProgramTranslator() @@ -74,8 +79,41 @@ class TestMultiLoad(unittest.TestCase): np.testing.assert_allclose(forward_out1, forward_out2[0], rtol=1e-05) np.testing.assert_allclose(infer_out1, infer_out2[0], rtol=1e-05) + +class SaveLinear(paddle.nn.Layer): + + def __init__(self): + super().__init__() + self.linear = paddle.nn.Linear(80, 80) + + @paddle.jit.to_static( + input_spec=[InputSpec(shape=[None, 80], dtype='float32')]) + def forward(self, x): + out = self.linear(x) + return out + + +class TestMKLOutput(unittest.TestCase): + + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + + def tearDown(self): self.temp_dir.cleanup() + def test_mkl_output(self): + with _dygraph_place_guard(place=paddle.CPUPlace()): + net = SaveLinear() + model_path = os.path.join(self.temp_dir.name, 'save_linear') + paddle.jit.save(net, model_path, combine_params=True) + + layer = Layer() + layer.load(model_path, paddle.CPUPlace()) + x = paddle.ones([498, 80]) + out = layer.forward(x) + out = paddle.unsqueeze(out[0], 0) + np.testing.assert_equal(out.shape, [1, 498, 80]) + if __name__ == '__main__': unittest.main() -- GitLab