未验证 提交 eca6638c 作者: zhouweiwei2014 提交者: GitHub

Replace the reshape op with reshape2 in paddle.nn.initializer.Dirac (#42396)

上级 683f152a
......@@ -1037,11 +1037,11 @@ class TestDiracInitializer1(unittest.TestCase):
block = start_prog.global_block()
self.assertEqual(len(block.ops), self.num_ops)
self.assertEqual(block.ops[0].type, 'fill_constant')
self.assertEqual(block.ops[1].type, 'reshape')
self.assertEqual(block.ops[1].type, 'reshape2')
self.assertEqual(block.ops[2].type, 'assign_value')
self.assertEqual(block.ops[3].type, 'assign_value')
self.assertEqual(block.ops[4].type, 'scatter')
self.assertEqual(block.ops[5].type, 'reshape')
self.assertEqual(block.ops[5].type, 'reshape2')
exe = paddle.static.Executor()
exe.run(start_prog)
......
......@@ -168,14 +168,22 @@ class Dirac(Initializer):
idx_list.append(offset)
if framework.in_dygraph_mode():
with fluid.dygraph.no_grad():
tmp_out = _C_ops.reshape(out_var, 'shape', [-1])
tmp_out, _ = _C_ops.reshape2(out_var, None, 'shape', [-1])
tmp_out._share_underline_tensor_to(out_var)
else:
x_shape = block.create_var(
name=unique_name.generate(".".join([out_var.name, "XShape"])),
dtype=out_var.dtype,
shape=out_var.shape,
type=VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=True)
block.append_op(
type="reshape",
type="reshape2",
inputs={"X": out_var},
attrs={'shape': [-1]},
outputs={"Out": out_var},
outputs={"Out": out_var,
"XShape": x_shape},
stop_gradient=True)
index_tensor = block.create_var(
......@@ -229,7 +237,8 @@ class Dirac(Initializer):
tmp_out = _C_ops.final_state_scatter(out_var, index_tensor,
value_tensor, True)
tmp_out._share_underline_tensor_to(out_var)
tmp_reshape_out = _C_ops.reshape(out_var, 'shape', origin_shape)
tmp_reshape_out, _ = _C_ops.reshape2(out_var, None, 'shape',
origin_shape)
tmp_reshape_out._share_underline_tensor_to(out_var)
if var.dtype != VarDesc.VarType.FP32:
tmp_cast_out = _C_ops.cast(out_var, 'in_dtype',
......@@ -248,11 +257,19 @@ class Dirac(Initializer):
attrs={'overwrite': True},
outputs={"Out": out_var},
stop_gradient=True)
x_shape = block.create_var(
name=unique_name.generate(".".join([out_var.name, "XShape"])),
dtype=out_var.dtype,
shape=out_var.shape,
type=VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=True)
block.append_op(
type="reshape",
type="reshape2",
inputs={"X": out_var},
attrs={'shape': origin_shape},
outputs={"Out": out_var},
outputs={"Out": out_var,
"XShape": x_shape},
stop_gradient=True)
if var.dtype != VarDesc.VarType.FP32:
block.append_op(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册