From c6ebcc7f5ab2053506c17847b5785c5ec9f34aa7 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Wed, 26 Sep 2018 15:27:58 +0800 Subject: [PATCH] add decayed_adagrad support for dist train --- .../tests/unittests/test_dist_transpiler.py | 19 +++++++++++++++++++ .../fluid/transpiler/distribute_transpiler.py | 3 +++ 2 files changed, 22 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index ecde407e6d..54a1c68a37 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -264,6 +264,25 @@ class TestLRDecay(TranspilerTest): ]) +class TestDecayedAdagrad(TranspilerTest): + def net_conf(self): + x = fluid.layers.data(name='x', shape=[1000], dtype='float32') + y_predict = fluid.layers.fc(input=x, + size=1000, + act=None, + param_attr=fluid.ParamAttr(name='fc_w'), + bias_attr=fluid.ParamAttr(name='fc_b')) + y = fluid.layers.data(name='y', shape=[1], dtype='float32') + cost = fluid.layers.square_error_cost(input=y_predict, label=y) + avg_cost = fluid.layers.mean(cost) + opt = fluid.optimizer.DecayedAdagrad(learning_rate=0.1) + opt.minimize(avg_cost) + + def transpiler_test_impl(self): + pserver, startup = self.get_pserver(self.pserver1_ep) + trainer, _ = self.get_trainer() + + class TestLRDecayConditional(TranspilerTest): def net_conf(self): x = fluid.layers.data(name='x', shape=[1000], dtype='float32') diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 43071def7a..e4345198f0 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -1430,6 +1430,9 @@ to transpile() call.") elif op_type == "rmsprop": if varkey in ["Moment", "MeanSquare"]: return param_shape + elif op_type == "decayed_adagrad": + if varkey == "Moment": + return param_shape elif op_type == "sgd": pass return orig_shape -- GitLab