未验证 提交 d51daede 编写于 作者: W Wu Yi 提交者: GitHub

add ftrl support for dist train test=develop (#14176)

上级 f37bd035
......@@ -283,6 +283,25 @@ class TestDecayedAdagrad(TranspilerTest):
trainer, _ = self.get_trainer()
class TestFtrl(TranspilerTest):
def net_conf(self):
x = fluid.layers.data(name='x', shape=[1000], dtype='float32')
y_predict = fluid.layers.fc(input=x,
size=1000,
act=None,
param_attr=fluid.ParamAttr(name='fc_w'),
bias_attr=fluid.ParamAttr(name='fc_b'))
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
opt = fluid.optimizer.Ftrl(learning_rate=0.1)
opt.minimize(avg_cost)
def transpiler_test_impl(self):
pserver, startup = self.get_pserver(self.pserver1_ep)
trainer, _ = self.get_trainer()
class TestLRDecayConditional(TranspilerTest):
def net_conf(self):
x = fluid.layers.data(name='x', shape=[1000], dtype='float32')
......
......@@ -1456,6 +1456,9 @@ to transpile() call.")
elif op_type == "decayed_adagrad":
if varkey == "Moment":
return param_shape
elif op_type == "ftrl":
if varkey in ["SquaredAccumulator", "LinearAccumulator"]:
return param_shape
elif op_type == "sgd":
pass
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册