From d3da0ef9f6f9c5b1fe17bdc3d8e4f079ed80e1ee Mon Sep 17 00:00:00 2001
From: Wu Yi
Date: Tue, 14 Aug 2018 13:18:27 +0800
Subject: [PATCH] Fix dist train with rmsprop (#12649)

* fix dist train with rmsprop

* add rmsprop transpiler test

* update by comment
---
 .../tests/unittests/test_dist_transpiler.py   | 30 ++++++++++++++++
 .../fluid/transpiler/distribute_transpiler.py | 35 +++++++++++++++----
 2 files changed, 58 insertions(+), 7 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
index 3bc5e6ada5..b6f4f0726f 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
@@ -536,5 +536,35 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
         self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
 
 
+class TestRMSPropOptimizer(TranspilerTest):
+    def net_conf(self):
+        x = fluid.layers.data(name='x', shape=[1000], dtype='float32')
+        y_predict = fluid.layers.fc(input=x,
+                                    size=1000,
+                                    act=None,
+                                    param_attr=fluid.ParamAttr(name='fc_w'),
+                                    bias_attr=fluid.ParamAttr(name='fc_b'))
+        y = fluid.layers.data(name='y', shape=[1], dtype='float32')
+        cost = fluid.layers.square_error_cost(input=y_predict, label=y)
+        avg_cost = fluid.layers.mean(cost)
+        optimizer = fluid.optimizer.RMSProp(learning_rate=0.1)
+        optimizer.minimize(avg_cost)
+        return
+
+    def transpiler_test_impl(self):
+        pserver, startup = self.get_pserver(self.pserver1_ep)
+        pserver2, startup2 = self.get_pserver(self.pserver2_ep)
+
+        self.assertEqual(len(pserver.blocks), 3)
+        # block1~2: optimize pass
+        self.assertEqual([op.type for op in pserver.blocks[1].ops],
+                         ["sum", "scale", "rmsprop"])
+        # the variable #fc_w will be split into two blocks
+        fc_w_var = startup.global_block().var("fc_w.block1")
+        self.assertEqual(fc_w_var.shape, (500, 1000))
+        moment_var = startup.global_block().var("momentum_1")
+        self.assertEqual(moment_var.shape, (500, 1000))
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index 63422a6ec5..1dd9578ef8 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -1182,18 +1182,39 @@ class DistributeTranspiler(object):
         program = optimize_block.program
         pserver_block = program.global_block()
         new_inputs = dict()
+
         # update param/grad shape first, then other inputs like
         # moment can use the updated shape
+        def _get_param_block(opt_op):
+            # param is already created on global program
+            param_block = None
+            for p in self.param_grad_ep_mapping[endpoint]["params"]:
+                if same_or_split_var(p.name, opt_op.input("Param")[0]):
+                    param_block = p
+                    break
+            return param_block
+
         for key in opt_op.input_names:
             if key == "Grad":
                 new_inputs[key] = merged_var
+            # For RMSProp optimizer
+            elif key == "Moment" or key == "MeanSquare":
+                param_block = _get_param_block(opt_op)
+                if not param_block:
+                    return
+                moment_var = origin_program.global_block().vars[opt_op.input(
+                    key)[0]]
+                tmpvar = pserver_block.create_var(
+                    name=moment_var.name,
+                    persistable=moment_var.persistable,
+                    dtype=moment_var.dtype,
+                    # change to use same shape as param
+                    # TODO(typhoonzero): didn't append .block in the var name,
+                    # may affect checkpoint saving? Need to verify.
+                    shape=param_block.shape)
+                new_inputs[key] = tmpvar
             elif key == "Param":
-                # param is already created on global program
-                param_block = None
-                for p in self.param_grad_ep_mapping[endpoint]["params"]:
-                    if same_or_split_var(p.name, opt_op.input(key)[0]):
-                        param_block = p
-                        break
+                param_block = _get_param_block(opt_op)
                 if not param_block:
                     return
                 tmpvar = pserver_block.create_var(
@@ -1219,7 +1240,7 @@ class DistributeTranspiler(object):
 
         for key in opt_op.input_names:
             new_shape = None
-            if key in ["Param", "Grad", "LearningRate"]:
+            if key in ["Param", "Grad", "LearningRate", "Moment", "MeanSquare"]:
                 continue
             var = self.origin_program.global_block().vars[opt_op.input(key)[0]]
             # update accumulator variable shape
--
GitLab
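
Note on the shape logic this patch enforces: RMSProp's accumulator inputs
("Moment", "MeanSquare") must be created on each parameter server with the
shape of the parameter block assigned to that server, not the full parameter
shape. The standalone sketch below is not PaddlePaddle code; the helper names
(split_rows, accumulator_shapes) are made up for illustration and the even
row split is an assumption consistent with the new unit test, where fc_w is
(1000, 1000) and each of two pservers gets a (500, 1000) block.

    # Hypothetical helpers illustrating how a (rows, cols) parameter is split
    # row-wise across parameter servers and how the optimizer accumulators
    # inherit the per-server block shape.
    def split_rows(total_rows, num_pservers):
        # Distribute rows as evenly as possible; the first `rem` servers get
        # one extra row when the split is uneven.
        base, rem = divmod(total_rows, num_pservers)
        return [base + (1 if i < rem else 0) for i in range(num_pservers)]

    def accumulator_shapes(param_shape, num_pservers):
        # Moment/MeanSquare on each pserver use the same shape as the param block.
        return [(rows,) + tuple(param_shape[1:])
                for rows in split_rows(param_shape[0], num_pservers)]

    if __name__ == "__main__":
        # Matches the assertions in TestRMSPropOptimizer: fc_w is (1000, 1000),
        # two pservers -> each block (and momentum_1 on it) is (500, 1000).
        print(accumulator_shapes((1000, 1000), 2))  # [(500, 1000), (500, 1000)]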