From 7ce0ee69bd299e6f0952cbff33a5e9b69d821680 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Tue, 19 Apr 2022 09:42:54 +0800 Subject: [PATCH] [Eager]Fix NeedTransformPlace behavior if set skip_transform in yaml (#41920) * [Eager]Fix NeedTransformPlace behavior if set skip_transform in yaml * add unittest for full_like * fix unittest --- paddle/phi/api/lib/data_transform.cc | 16 +++++++++++----- .../fluid/tests/unittests/test_full_like_op.py | 15 +++++++++++++++ python/paddle/utils/code_gen/api.yaml | 6 ++++-- 3 files changed, 30 insertions(+), 7 deletions(-) diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index d4e92ded32..65cb37d414 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -36,11 +36,17 @@ inline bool NeedTransformDataType(const DataType& input, inline bool NeedTransformPlace(const paddle::platform::Place& input, const Backend& target, const TransformFlag& transform_flag) { - bool ret = - input.GetType() == AllocationType::GPUPINNED || - (transform_flag.need_trans_backend() && target != Backend::ALL_BACKEND && - phi::TransToPhiBackend(input) != - (target != Backend::GPUDNN ? target : Backend::GPU)); + // NOTE(dev): The default value of TransformFlag is True, if it is set with + // False + // somewhere such as api.yaml or backward.yaml that means we should skip data + // transform. Because "stop_transform_" has highest priority. + if (!transform_flag.need_trans_backend()) { + return false; + } + bool ret = input.GetType() == AllocationType::GPUPINNED || + (target != Backend::ALL_BACKEND && + phi::TransToPhiBackend(input) != + (target != Backend::GPUDNN ? target : Backend::GPU)); return ret; } diff --git a/python/paddle/fluid/tests/unittests/test_full_like_op.py b/python/paddle/fluid/tests/unittests/test_full_like_op.py index 05a310a9c5..d3fea677a4 100644 --- a/python/paddle/fluid/tests/unittests/test_full_like_op.py +++ b/python/paddle/fluid/tests/unittests/test_full_like_op.py @@ -22,6 +22,7 @@ import unittest import numpy as np from op_test import OpTest from paddle.fluid.framework import convert_np_dtype_to_dtype_ +from paddle.fluid.framework import _test_eager_guard class TestFullOp(unittest.TestCase): @@ -133,5 +134,19 @@ class TestFullLikeOp3(TestFullLikeOp1): self.dtype = np.int64 +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFullLikeOp4(unittest.TestCase): + def test_skip_data_transform(self): + paddle.disable_static() + with _test_eager_guard(): + x = paddle.to_tensor( + [1., 2., 3., 4.], place=paddle.CUDAPinnedPlace()) + out = paddle.full_like(x, 1.) + self.assertTrue( + (out.numpy() == np.ones([4]).astype(np.float32)).all(), True) + paddle.enable_static() + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 351eca5a6b..58b80950e5 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -58,7 +58,7 @@ func : AdamaxInferMeta kernel : func : adamax - + - api : adamw args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, Scalar beta1, Scalar beta2, Scalar epsilon, float lr_ratio, float coeff, bool with_decay, bool lazy_mode, int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow) output : Tensor(param_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(master_param_outs) @@ -460,7 +460,7 @@ - api : deformable_conv args : (Tensor x, Tensor offset, Tensor filter, Tensor mask, int[] strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step) output : Tensor(out) - infer_meta : + infer_meta : func : DeformableConvInferMeta kernel : func : deformable_conv @@ -793,6 +793,8 @@ param : [x, value, dtype] data_type : dtype > x backend : place > x + data_transform : + skip_transform : x - api : gather args : (Tensor x, Tensor index, Scalar axis=0) -- GitLab