未验证 提交 eca6638c 编写于 作者: zhouweiwei2014 提交者: GitHub

Replace the reshape op with reshape2 in paddle.nn.initializer.dirac (#42396)

上级 683f152a
...@@ -1037,11 +1037,11 @@ class TestDiracInitializer1(unittest.TestCase): ...@@ -1037,11 +1037,11 @@ class TestDiracInitializer1(unittest.TestCase):
block = start_prog.global_block() block = start_prog.global_block()
self.assertEqual(len(block.ops), self.num_ops) self.assertEqual(len(block.ops), self.num_ops)
self.assertEqual(block.ops[0].type, 'fill_constant') self.assertEqual(block.ops[0].type, 'fill_constant')
self.assertEqual(block.ops[1].type, 'reshape') self.assertEqual(block.ops[1].type, 'reshape2')
self.assertEqual(block.ops[2].type, 'assign_value') self.assertEqual(block.ops[2].type, 'assign_value')
self.assertEqual(block.ops[3].type, 'assign_value') self.assertEqual(block.ops[3].type, 'assign_value')
self.assertEqual(block.ops[4].type, 'scatter') self.assertEqual(block.ops[4].type, 'scatter')
self.assertEqual(block.ops[5].type, 'reshape') self.assertEqual(block.ops[5].type, 'reshape2')
exe = paddle.static.Executor() exe = paddle.static.Executor()
exe.run(start_prog) exe.run(start_prog)
......
...@@ -168,14 +168,22 @@ class Dirac(Initializer): ...@@ -168,14 +168,22 @@ class Dirac(Initializer):
idx_list.append(offset) idx_list.append(offset)
if framework.in_dygraph_mode(): if framework.in_dygraph_mode():
with fluid.dygraph.no_grad(): with fluid.dygraph.no_grad():
tmp_out = _C_ops.reshape(out_var, 'shape', [-1]) tmp_out, _ = _C_ops.reshape2(out_var, None, 'shape', [-1])
tmp_out._share_underline_tensor_to(out_var) tmp_out._share_underline_tensor_to(out_var)
else: else:
x_shape = block.create_var(
name=unique_name.generate(".".join([out_var.name, "XShape"])),
dtype=out_var.dtype,
shape=out_var.shape,
type=VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=True)
block.append_op( block.append_op(
type="reshape", type="reshape2",
inputs={"X": out_var}, inputs={"X": out_var},
attrs={'shape': [-1]}, attrs={'shape': [-1]},
outputs={"Out": out_var}, outputs={"Out": out_var,
"XShape": x_shape},
stop_gradient=True) stop_gradient=True)
index_tensor = block.create_var( index_tensor = block.create_var(
...@@ -229,7 +237,8 @@ class Dirac(Initializer): ...@@ -229,7 +237,8 @@ class Dirac(Initializer):
tmp_out = _C_ops.final_state_scatter(out_var, index_tensor, tmp_out = _C_ops.final_state_scatter(out_var, index_tensor,
value_tensor, True) value_tensor, True)
tmp_out._share_underline_tensor_to(out_var) tmp_out._share_underline_tensor_to(out_var)
tmp_reshape_out = _C_ops.reshape(out_var, 'shape', origin_shape) tmp_reshape_out, _ = _C_ops.reshape2(out_var, None, 'shape',
origin_shape)
tmp_reshape_out._share_underline_tensor_to(out_var) tmp_reshape_out._share_underline_tensor_to(out_var)
if var.dtype != VarDesc.VarType.FP32: if var.dtype != VarDesc.VarType.FP32:
tmp_cast_out = _C_ops.cast(out_var, 'in_dtype', tmp_cast_out = _C_ops.cast(out_var, 'in_dtype',
...@@ -248,11 +257,19 @@ class Dirac(Initializer): ...@@ -248,11 +257,19 @@ class Dirac(Initializer):
attrs={'overwrite': True}, attrs={'overwrite': True},
outputs={"Out": out_var}, outputs={"Out": out_var},
stop_gradient=True) stop_gradient=True)
x_shape = block.create_var(
name=unique_name.generate(".".join([out_var.name, "XShape"])),
dtype=out_var.dtype,
shape=out_var.shape,
type=VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=True)
block.append_op( block.append_op(
type="reshape", type="reshape2",
inputs={"X": out_var}, inputs={"X": out_var},
attrs={'shape': origin_shape}, attrs={'shape': origin_shape},
outputs={"Out": out_var}, outputs={"Out": out_var,
"XShape": x_shape},
stop_gradient=True) stop_gradient=True)
if var.dtype != VarDesc.VarType.FP32: if var.dtype != VarDesc.VarType.FP32:
block.append_op( block.append_op(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册