未验证 提交 7ce0ee69 编写于 作者: A Aurelius84 提交者: GitHub

[Eager]Fix NeedTransformPlace behavior if set skip_transform in yaml (#41920)

* [Eager]Fix NeedTransformPlace behavior if set skip_transform in yaml

* add unittest for full_like

* fix unittest
上级 caaaf2f0
......@@ -36,9 +36,15 @@ inline bool NeedTransformDataType(const DataType& input,
inline bool NeedTransformPlace(const paddle::platform::Place& input,
const Backend& target,
const TransformFlag& transform_flag) {
bool ret =
input.GetType() == AllocationType::GPUPINNED ||
(transform_flag.need_trans_backend() && target != Backend::ALL_BACKEND &&
// NOTE(dev): The default value of TransformFlag is True, if it is set with
// False
// somewhere such as api.yaml or backward.yaml that means we should skip data
// transform. Because "stop_transform_" has highest priority.
if (!transform_flag.need_trans_backend()) {
return false;
}
bool ret = input.GetType() == AllocationType::GPUPINNED ||
(target != Backend::ALL_BACKEND &&
phi::TransToPhiBackend(input) !=
(target != Backend::GPUDNN ? target : Backend::GPU));
return ret;
......
......@@ -22,6 +22,7 @@ import unittest
import numpy as np
from op_test import OpTest
from paddle.fluid.framework import convert_np_dtype_to_dtype_
from paddle.fluid.framework import _test_eager_guard
class TestFullOp(unittest.TestCase):
......@@ -133,5 +134,19 @@ class TestFullLikeOp3(TestFullLikeOp1):
self.dtype = np.int64
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFullLikeOp4(unittest.TestCase):
def test_skip_data_transform(self):
paddle.disable_static()
with _test_eager_guard():
x = paddle.to_tensor(
[1., 2., 3., 4.], place=paddle.CUDAPinnedPlace())
out = paddle.full_like(x, 1.)
self.assertTrue(
(out.numpy() == np.ones([4]).astype(np.float32)).all(), True)
paddle.enable_static()
if __name__ == "__main__":
unittest.main()
......@@ -793,6 +793,8 @@
param : [x, value, dtype]
data_type : dtype > x
backend : place > x
data_transform :
skip_transform : x
- api : gather
args : (Tensor x, Tensor index, Scalar axis=0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册