未验证 提交 76154c94 编写于 作者: W WangZhen 提交者: GitHub

Fix TransDataBackend Error when call unsqueeze using MKL Tensor (#46094)

* Fix TransDataBackend Error when call unsqueeze using MKL Tensor

* Add UT

* Refine UT
上级 16439bb9
...@@ -307,7 +307,7 @@ paddle::optional<std::vector<phi::DenseTensor>> PrepareData( ...@@ -307,7 +307,7 @@ paddle::optional<std::vector<phi::DenseTensor>> PrepareData(
void TransDataBackend(const phi::DenseTensor* tensor, void TransDataBackend(const phi::DenseTensor* tensor,
Backend target_backend, Backend target_backend,
phi::DenseTensor* out) { phi::DenseTensor* out) {
if (tensor) { if (tensor && tensor->initialized()) {
*out = TransDataPlace(*tensor, phi::TransToPhiPlace(target_backend)); *out = TransDataPlace(*tensor, phi::TransToPhiPlace(target_backend));
} }
} }
......
...@@ -18,7 +18,7 @@ import unittest ...@@ -18,7 +18,7 @@ import unittest
import tempfile import tempfile
import numpy as np import numpy as np
from paddle.static import InputSpec from paddle.static import InputSpec
from paddle.fluid.framework import _enable_legacy_dygraph from paddle.fluid.framework import _dygraph_place_guard
from paddle.jit.layer import Layer from paddle.jit.layer import Layer
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator
...@@ -51,9 +51,14 @@ class Net(paddle.nn.Layer): ...@@ -51,9 +51,14 @@ class Net(paddle.nn.Layer):
class TestMultiLoad(unittest.TestCase): class TestMultiLoad(unittest.TestCase):
def test_multi_load(self): def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory() self.temp_dir = tempfile.TemporaryDirectory()
def tearDown(self):
self.temp_dir.cleanup()
def test_multi_load(self):
x = paddle.full([2, 4], 2) x = paddle.full([2, 4], 2)
model = Net() model = Net()
program_translator = ProgramTranslator() program_translator = ProgramTranslator()
...@@ -74,8 +79,41 @@ class TestMultiLoad(unittest.TestCase): ...@@ -74,8 +79,41 @@ class TestMultiLoad(unittest.TestCase):
np.testing.assert_allclose(forward_out1, forward_out2[0], rtol=1e-05) np.testing.assert_allclose(forward_out1, forward_out2[0], rtol=1e-05)
np.testing.assert_allclose(infer_out1, infer_out2[0], rtol=1e-05) np.testing.assert_allclose(infer_out1, infer_out2[0], rtol=1e-05)
class SaveLinear(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.linear = paddle.nn.Linear(80, 80)
@paddle.jit.to_static(
input_spec=[InputSpec(shape=[None, 80], dtype='float32')])
def forward(self, x):
out = self.linear(x)
return out
class TestMKLOutput(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
def tearDown(self):
self.temp_dir.cleanup() self.temp_dir.cleanup()
def test_mkl_output(self):
with _dygraph_place_guard(place=paddle.CPUPlace()):
net = SaveLinear()
model_path = os.path.join(self.temp_dir.name, 'save_linear')
paddle.jit.save(net, model_path, combine_params=True)
layer = Layer()
layer.load(model_path, paddle.CPUPlace())
x = paddle.ones([498, 80])
out = layer.forward(x)
out = paddle.unsqueeze(out[0], 0)
np.testing.assert_equal(out.shape, [1, 498, 80])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册