未验证 提交 7ce0ee69 编写于 作者: A Aurelius84 提交者: GitHub

[Eager]Fix NeedTransformPlace behavior if set skip_transform in yaml (#41920)

* [Eager]Fix NeedTransformPlace behavior if set skip_transform in yaml

* add unittest for full_like

* fix unittest
上级 caaaf2f0
...@@ -36,11 +36,17 @@ inline bool NeedTransformDataType(const DataType& input, ...@@ -36,11 +36,17 @@ inline bool NeedTransformDataType(const DataType& input,
inline bool NeedTransformPlace(const paddle::platform::Place& input, inline bool NeedTransformPlace(const paddle::platform::Place& input,
const Backend& target, const Backend& target,
const TransformFlag& transform_flag) { const TransformFlag& transform_flag) {
bool ret = // NOTE(dev): The default value of TransformFlag is True, if it is set with
input.GetType() == AllocationType::GPUPINNED || // False
(transform_flag.need_trans_backend() && target != Backend::ALL_BACKEND && // somewhere such as api.yaml or backward.yaml that means we should skip data
phi::TransToPhiBackend(input) != // transform. Because "stop_transform_" has highest priority.
(target != Backend::GPUDNN ? target : Backend::GPU)); if (!transform_flag.need_trans_backend()) {
return false;
}
bool ret = input.GetType() == AllocationType::GPUPINNED ||
(target != Backend::ALL_BACKEND &&
phi::TransToPhiBackend(input) !=
(target != Backend::GPUDNN ? target : Backend::GPU));
return ret; return ret;
} }
......
...@@ -22,6 +22,7 @@ import unittest ...@@ -22,6 +22,7 @@ import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
from paddle.fluid.framework import convert_np_dtype_to_dtype_ from paddle.fluid.framework import convert_np_dtype_to_dtype_
from paddle.fluid.framework import _test_eager_guard
class TestFullOp(unittest.TestCase): class TestFullOp(unittest.TestCase):
...@@ -133,5 +134,19 @@ class TestFullLikeOp3(TestFullLikeOp1): ...@@ -133,5 +134,19 @@ class TestFullLikeOp3(TestFullLikeOp1):
self.dtype = np.int64 self.dtype = np.int64
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFullLikeOp4(unittest.TestCase):
def test_skip_data_transform(self):
paddle.disable_static()
with _test_eager_guard():
x = paddle.to_tensor(
[1., 2., 3., 4.], place=paddle.CUDAPinnedPlace())
out = paddle.full_like(x, 1.)
self.assertTrue(
(out.numpy() == np.ones([4]).astype(np.float32)).all(), True)
paddle.enable_static()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -58,7 +58,7 @@ ...@@ -58,7 +58,7 @@
func : AdamaxInferMeta func : AdamaxInferMeta
kernel : kernel :
func : adamax func : adamax
- api : adamw - api : adamw
args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, Scalar beta1, Scalar beta2, Scalar epsilon, float lr_ratio, float coeff, bool with_decay, bool lazy_mode, int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow) args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, Scalar beta1, Scalar beta2, Scalar epsilon, float lr_ratio, float coeff, bool with_decay, bool lazy_mode, int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow)
output : Tensor(param_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(master_param_outs) output : Tensor(param_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(master_param_outs)
...@@ -460,7 +460,7 @@ ...@@ -460,7 +460,7 @@
- api : deformable_conv - api : deformable_conv
args : (Tensor x, Tensor offset, Tensor filter, Tensor mask, int[] strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step) args : (Tensor x, Tensor offset, Tensor filter, Tensor mask, int[] strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step)
output : Tensor(out) output : Tensor(out)
infer_meta : infer_meta :
func : DeformableConvInferMeta func : DeformableConvInferMeta
kernel : kernel :
func : deformable_conv func : deformable_conv
...@@ -793,6 +793,8 @@ ...@@ -793,6 +793,8 @@
param : [x, value, dtype] param : [x, value, dtype]
data_type : dtype > x data_type : dtype > x
backend : place > x backend : place > x
data_transform :
skip_transform : x
- api : gather - api : gather
args : (Tensor x, Tensor index, Scalar axis=0) args : (Tensor x, Tensor index, Scalar axis=0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册