未验证 提交 93f0e594 编写于 作者: A Aurelius84 提交者: GitHub

[Cherry-Pick]Fix expand_sig infershape BUG under static graph mode and...

[Cherry-Pick]Fix expand_sig infershape BUG under static graph mode and NeedTransformPlace behavior if set skip_transform in yaml (#41973)

* [Phi]Fix expand_sig infershape BUG under static graph mode (#41936)

* [Phi]Fix expand_sig infershape BUG under static graph mode

* [Phi]Fix expand_sig infershape BUG under static graph mode

* [Phi]Fix unittest

* [Phi]Fix unittest

* [Eager]Fix NeedTransformPlace behavior if set skip_transform in yaml (#41920)

* [Eager]Fix NeedTransformPlace behavior if set skip_transform in yaml

* add unittest for full_like

* fix unittest
上级 2ea5e02c
...@@ -37,9 +37,15 @@ inline bool NeedTransformDataType(const DataType& input, ...@@ -37,9 +37,15 @@ inline bool NeedTransformDataType(const DataType& input,
inline bool NeedTransformPlace(const paddle::platform::Place& input, inline bool NeedTransformPlace(const paddle::platform::Place& input,
const Backend& target, const Backend& target,
const TransformFlag& transform_flag) { const TransformFlag& transform_flag) {
bool ret = // NOTE(dev): The default value of TransformFlag is True, if it is set with
input.GetType() == AllocationType::GPUPINNED || // False
(transform_flag.need_trans_backend() && target != Backend::ALL_BACKEND && // somewhere such as api.yaml or backward.yaml that means we should skip data
// transform. Because "stop_transform_" has highest priority.
if (!transform_flag.need_trans_backend()) {
return false;
}
bool ret = input.GetType() == AllocationType::GPUPINNED ||
(target != Backend::ALL_BACKEND &&
phi::TransToPhiBackend(input) != phi::TransToPhiBackend(input) !=
(target != Backend::GPUDNN ? target : Backend::GPU)); (target != Backend::GPUDNN ? target : Backend::GPU));
return ret; return ret;
......
...@@ -17,6 +17,11 @@ ...@@ -17,6 +17,11 @@
namespace phi { namespace phi {
KernelSignature ExpandOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ExpandOpArgumentMapping(const ArgumentMappingContext& ctx) {
const auto& shape = paddle::any_cast<std::vector<int>>(ctx.Attr("shape"));
// Infer output shape by Attr("shape") in CompileTime if it is specified.
if (!ctx.IsRuntime() && !shape.empty()) {
return KernelSignature("expand", {"X"}, {"shape"}, {"Out"});
}
if (ctx.HasInput("Shape")) { if (ctx.HasInput("Shape")) {
return KernelSignature("expand", {"X"}, {"Shape"}, {"Out"}); return KernelSignature("expand", {"X"}, {"Shape"}, {"Out"});
} else if (ctx.InputSize("expand_shapes_tensor") > 0) { } else if (ctx.InputSize("expand_shapes_tensor") > 0) {
...@@ -27,6 +32,12 @@ KernelSignature ExpandOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -27,6 +32,12 @@ KernelSignature ExpandOpArgumentMapping(const ArgumentMappingContext& ctx) {
} }
KernelSignature ExpandGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ExpandGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
const auto& shape = paddle::any_cast<std::vector<int>>(ctx.Attr("shape"));
// Infer output shape by Attr("shape") in CompileTime if it is specified.
if (!ctx.IsRuntime() && !shape.empty()) {
return KernelSignature(
"expand_grad", {"X", "Out@GRAD"}, {"shape"}, {"X@GRAD"});
}
if (ctx.HasInput("Shape")) { if (ctx.HasInput("Shape")) {
return KernelSignature("expand_grad", return KernelSignature("expand_grad",
{"X", GradVarName("Out")}, {"X", GradVarName("Out")},
......
...@@ -231,6 +231,18 @@ class TestExpandV2API(unittest.TestCase): ...@@ -231,6 +231,18 @@ class TestExpandV2API(unittest.TestCase):
assert np.array_equal(res_3, np.tile(input, (1, 1))) assert np.array_equal(res_3, np.tile(input, (1, 1)))
class TestExpandInferShape(unittest.TestCase):
def test_shape_with_var(self):
with program_guard(Program(), Program()):
x = paddle.static.data(shape=[-1, 1, 3], name='x')
fake_var = paddle.randn([2, 3])
target_shape = [
-1, paddle.shape(fake_var)[0], paddle.shape(fake_var)[1]
]
out = paddle.expand(x, shape=target_shape)
self.assertListEqual(list(out.shape), [-1, -1, -1])
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static() paddle.enable_static()
unittest.main() unittest.main()
...@@ -22,6 +22,7 @@ import unittest ...@@ -22,6 +22,7 @@ import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
from paddle.fluid.framework import convert_np_dtype_to_dtype_ from paddle.fluid.framework import convert_np_dtype_to_dtype_
from paddle.fluid.framework import _test_eager_guard
class TestFullOp(unittest.TestCase): class TestFullOp(unittest.TestCase):
...@@ -133,5 +134,19 @@ class TestFullLikeOp3(TestFullLikeOp1): ...@@ -133,5 +134,19 @@ class TestFullLikeOp3(TestFullLikeOp1):
self.dtype = np.int64 self.dtype = np.int64
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFullLikeOp4(unittest.TestCase):
def test_skip_data_transform(self):
paddle.disable_static()
with _test_eager_guard():
x = paddle.to_tensor(
[1., 2., 3., 4.], place=paddle.CUDAPinnedPlace())
out = paddle.full_like(x, 1.)
self.assertTrue(
(out.numpy() == np.ones([4]).astype(np.float32)).all(), True)
paddle.enable_static()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -781,6 +781,8 @@ ...@@ -781,6 +781,8 @@
param : [x, value, dtype] param : [x, value, dtype]
data_type : dtype > x data_type : dtype > x
backend : place > x backend : place > x
data_transform :
skip_transform : x
- api : gather - api : gather
args : (Tensor x, Tensor index, Scalar axis=0) args : (Tensor x, Tensor index, Scalar axis=0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册