diff --git a/python/paddle/jit/dy2static/basic_api_transformer.py b/python/paddle/jit/dy2static/basic_api_transformer.py index af111b55e79a60116d99f4884dad768dbd78cec4..40c5a5f511bde7accb19e72af665dd173dac7060 100644 --- a/python/paddle/jit/dy2static/basic_api_transformer.py +++ b/python/paddle/jit/dy2static/basic_api_transformer.py @@ -152,8 +152,8 @@ class NameloadJstTransformer(BaseTransformer): """ Can't convert name of function call, bacause this will affect CallTransformer. """ - node.args = [self.generic_visit(arg) for arg in node.args] - node.func = self.generic_visit(node.func) + node.args = [self.visit(arg) for arg in node.args] + node.func = self.visit(node.func) return node def visit_Attribute(self, node): diff --git a/test/dygraph_to_static/test_jit_setitem.py b/test/dygraph_to_static/test_jit_setitem.py index 18069d404a938b11b132a6988a435f0a401c3b47..59841ed431f086c01127034e9205be9584c6473d 100644 --- a/test/dygraph_to_static/test_jit_setitem.py +++ b/test/dygraph_to_static/test_jit_setitem.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. + import unittest import numpy as np import paddle +import paddle.nn.functional as F class TestSetItemBase(unittest.TestCase): @@ -231,5 +233,31 @@ class TestCase14(TestSetItemBase): return (y,) +class TestCase15(TestSetItemBase): + # Test gradient of value tensor + def init_func(self): + def foo(x, H, W): + B, _, _, C = x.shape + pad_list = paddle.zeros([4], dtype="int32") + pad_list[3] = H // 2 + pad_list[1] = W // 2 + + # 问题在这里,进去F.pad以后,pad_list是初始变量而非赋值后的变量 + # 在修改前,赋值前后的变量是同一个,没有问题 + # 修改后,期望接收赋值后的变量,接收赋值前变量结果是不对的 + x = F.pad(x, pad_list, data_format="NHWC") + return x + + return foo + + def run_dygraph(self, func): + # 注释这句看结果diff + x = paddle.ones((1, 6, 6, 3)) + H = paddle.full([1], 6, dtype='int32') + W = paddle.full([1], 6, dtype='int32') + y = func(x, H, W) + return (y,) + + if __name__ == '__main__': unittest.main()