未验证 提交 d3da0ef9 编写于 作者: W Wu Yi 提交者: GitHub

Fix dist train with rmsprop (#12649)

* fix dist train with rmsprop

* add rmsprop transpiler test

* update by comment
上级 989cae25
......@@ -536,5 +536,35 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
class TestRMSPropOptimizer(TranspilerTest):
def net_conf(self):
x = fluid.layers.data(name='x', shape=[1000], dtype='float32')
y_predict = fluid.layers.fc(input=x,
size=1000,
act=None,
param_attr=fluid.ParamAttr(name='fc_w'),
bias_attr=fluid.ParamAttr(name='fc_b'))
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
optimizer = fluid.optimizer.RMSProp(learning_rate=0.1)
optimizer.minimize(avg_cost)
return
def transpiler_test_impl(self):
pserver, startup = self.get_pserver(self.pserver1_ep)
pserver2, startup2 = self.get_pserver(self.pserver2_ep)
self.assertEqual(len(pserver.blocks), 3)
# block1~2: optimize pass
self.assertEqual([op.type for op in pserver.blocks[1].ops],
["sum", "scale", "rmsprop"])
# the variable #fc_w will be split into two blocks
fc_w_var = startup.global_block().var("fc_w.block1")
self.assertEqual(fc_w_var.shape, (500, 1000))
moment_var = startup.global_block().var("momentum_1")
self.assertEqual(moment_var.shape, (500, 1000))
if __name__ == "__main__":
unittest.main()
......@@ -1182,18 +1182,39 @@ class DistributeTranspiler(object):
program = optimize_block.program
pserver_block = program.global_block()
new_inputs = dict()
# update param/grad shape first, then other inputs like
# moment can use the updated shape
def _get_param_block(opt_op):
# param is already created on global program
param_block = None
for p in self.param_grad_ep_mapping[endpoint]["params"]:
if same_or_split_var(p.name, opt_op.input("Param")[0]):
param_block = p
break
return param_block
for key in opt_op.input_names:
if key == "Grad":
new_inputs[key] = merged_var
# For RMSProp optimizer
elif key == "Moment" or key == "MeanSquare":
param_block = _get_param_block(opt_op)
if not param_block:
return
moment_var = origin_program.global_block().vars[opt_op.input(
key)[0]]
tmpvar = pserver_block.create_var(
name=moment_var.name,
persistable=moment_var.persistable,
dtype=moment_var.dtype,
# change to use same shape as param
# TODO(typhoonzero): didn't append .block in the var name,
# may affect checkpoint saving? Need to verify.
shape=param_block.shape)
new_inputs[key] = tmpvar
elif key == "Param":
# param is already created on global program
param_block = None
for p in self.param_grad_ep_mapping[endpoint]["params"]:
if same_or_split_var(p.name, opt_op.input(key)[0]):
param_block = p
break
param_block = _get_param_block(opt_op)
if not param_block:
return
tmpvar = pserver_block.create_var(
......@@ -1219,7 +1240,7 @@ class DistributeTranspiler(object):
for key in opt_op.input_names:
new_shape = None
if key in ["Param", "Grad", "LearningRate"]:
if key in ["Param", "Grad", "LearningRate", "Moment", "MeanSquare"]:
continue
var = self.origin_program.global_block().vars[opt_op.input(key)[0]]
# update accumulator variable shape
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册