From eca6638c599591c69fe40aa196f5fd42db7efbe2 Mon Sep 17 00:00:00 2001 From: Zhou Wei <1183042833@qq.com> Date: Fri, 29 Apr 2022 20:54:56 +0800 Subject: [PATCH] modify reshape to reshape2 in paddle.nn.initializer.dirac (#42396) --- .../fluid/tests/unittests/test_initializer.py | 4 +-- python/paddle/nn/initializer/dirac.py | 29 +++++++++++++++---- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_initializer.py b/python/paddle/fluid/tests/unittests/test_initializer.py index 3a9387082e..52137b22a7 100644 --- a/python/paddle/fluid/tests/unittests/test_initializer.py +++ b/python/paddle/fluid/tests/unittests/test_initializer.py @@ -1037,11 +1037,11 @@ class TestDiracInitializer1(unittest.TestCase): block = start_prog.global_block() self.assertEqual(len(block.ops), self.num_ops) self.assertEqual(block.ops[0].type, 'fill_constant') - self.assertEqual(block.ops[1].type, 'reshape') + self.assertEqual(block.ops[1].type, 'reshape2') self.assertEqual(block.ops[2].type, 'assign_value') self.assertEqual(block.ops[3].type, 'assign_value') self.assertEqual(block.ops[4].type, 'scatter') - self.assertEqual(block.ops[5].type, 'reshape') + self.assertEqual(block.ops[5].type, 'reshape2') exe = paddle.static.Executor() exe.run(start_prog) diff --git a/python/paddle/nn/initializer/dirac.py b/python/paddle/nn/initializer/dirac.py index c7cb1052d2..9c84b01ecb 100644 --- a/python/paddle/nn/initializer/dirac.py +++ b/python/paddle/nn/initializer/dirac.py @@ -168,14 +168,22 @@ class Dirac(Initializer): idx_list.append(offset) if framework.in_dygraph_mode(): with fluid.dygraph.no_grad(): - tmp_out = _C_ops.reshape(out_var, 'shape', [-1]) + tmp_out, _ = _C_ops.reshape2(out_var, None, 'shape', [-1]) tmp_out._share_underline_tensor_to(out_var) else: + x_shape = block.create_var( + name=unique_name.generate(".".join([out_var.name, "XShape"])), + dtype=out_var.dtype, + shape=out_var.shape, + type=VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=True) block.append_op( - type="reshape", + type="reshape2", inputs={"X": out_var}, attrs={'shape': [-1]}, - outputs={"Out": out_var}, + outputs={"Out": out_var, + "XShape": x_shape}, stop_gradient=True) index_tensor = block.create_var( @@ -229,7 +237,8 @@ class Dirac(Initializer): tmp_out = _C_ops.final_state_scatter(out_var, index_tensor, value_tensor, True) tmp_out._share_underline_tensor_to(out_var) - tmp_reshape_out = _C_ops.reshape(out_var, 'shape', origin_shape) + tmp_reshape_out, _ = _C_ops.reshape2(out_var, None, 'shape', + origin_shape) tmp_reshape_out._share_underline_tensor_to(out_var) if var.dtype != VarDesc.VarType.FP32: tmp_cast_out = _C_ops.cast(out_var, 'in_dtype', @@ -248,11 +257,19 @@ class Dirac(Initializer): attrs={'overwrite': True}, outputs={"Out": out_var}, stop_gradient=True) + x_shape = block.create_var( + name=unique_name.generate(".".join([out_var.name, "XShape"])), + dtype=out_var.dtype, + shape=out_var.shape, + type=VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=True) block.append_op( - type="reshape", + type="reshape2", inputs={"X": out_var}, attrs={'shape': origin_shape}, - outputs={"Out": out_var}, + outputs={"Out": out_var, + "XShape": x_shape}, stop_gradient=True) if var.dtype != VarDesc.VarType.FP32: block.append_op( -- GitLab