Unverified · Commit 8d2cb14f authored by mapingshuo, committed by GitHub

support gradient merge with recompute, test=develop (#27834)

* support gradient merge with recompute, test=develop
test=develop
Parent 274071a1
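Before the diffs, a usage sketch of what this change enables: running the GradientMerge and Recompute meta optimizers through a single DistributedStrategy. The network, environment setup, and gradient-merge settings are lifted from the old test further down; the recompute_configs line and its checkpoint names are illustrative assumptions, not part of this commit.

    import os
    import paddle
    import paddle.distributed.fleet as fleet
    import paddle.distributed.fleet.base.role_maker as role_maker

    # Single-node collective setup, as in the original test's setUp().
    os.environ["POD_IP"] = "127.0.0.1"
    os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001"
    os.environ["PADDLE_TRAINERS_NUM"] = "2"

    paddle.enable_static()
    role = role_maker.PaddleCloudRoleMaker(is_collective=True)
    fleet.init(role)

    # Small static-graph network, copied from the removed test below.
    input_x = paddle.fluid.layers.data(name="x", shape=[32], dtype='float32')
    input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')
    fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
    fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
    prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax')
    cost = paddle.fluid.layers.cross_entropy(input=prediction, label=input_y)
    avg_cost = paddle.fluid.layers.mean(x=cost)

    # Enable both meta optimizers in one DistributedStrategy.
    strategy = paddle.distributed.fleet.DistributedStrategy()
    strategy.gradient_merge = True
    strategy.gradient_merge_configs = {"k_steps": 2, "avg": True}
    strategy.recompute = True
    # Checkpoint names are placeholders; pick real intermediate variables.
    strategy.recompute_configs = {"checkpoints": [fc_1.name, fc_2.name]}

    optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
    optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
    optimizer.minimize(avg_cost)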
@@ -19,11 +19,12 @@ class GradientMergeOptimizer(MetaOptimizerBase):
def __init__(self, optimizer):
super(GradientMergeOptimizer, self).__init__(optimizer)
self.inner_opt = optimizer
self.wrapped_opt = GM(optimizer)
self.wrapped_opt = None
self.meta_optimizers_white_list = [
"LarsOptimizer",
"LambOptimizer",
"GraphExecutionOptimizer",
"RecomputeOptimizer",
]
self.meta_optimizers_black_list = []
@@ -31,6 +32,10 @@ class GradientMergeOptimizer(MetaOptimizerBase):
user_defined_strategy):
super(GradientMergeOptimizer, self)._set_basic_info(
loss, role_maker, user_defined_optimizer, user_defined_strategy)
def _init_wrapped_opt(self):
config = self.user_defined_strategy.gradient_merge_configs
self.wrapped_opt = GM(self.inner_opt)
self.wrapped_opt._set_k_steps(
self.user_defined_strategy.gradient_merge_configs["k_steps"])
self.wrapped_opt._set_avg(
@@ -49,7 +54,7 @@ class GradientMergeOptimizer(MetaOptimizerBase):
dist_strategy.gradient_merge_configs = {}
def _enable_strategy(self, dist_strategy, context):
# we currently do not support auto-enable gradient merge
# we currently do not support auto-enable GradientMerge
return
def minimize_impl(self,
@@ -57,6 +62,7 @@ class GradientMergeOptimizer(MetaOptimizerBase):
startup_program=None,
parameter_list=None,
no_grad_set=None):
self._init_wrapped_opt()
optimize_ops, params_grads = \
self.wrapped_opt.minimize(loss, startup_program,
parameter_list, no_grad_set)
......
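The structural change above: the GM wrapper is no longer built in __init__; it is created lazily in _init_wrapped_opt(), which minimize_impl now calls first, so the wrapper is constructed from whatever self.inner_opt is at minimize time. "RecomputeOptimizer" is also added to the white list, presumably so the two meta optimizers can be composed, per the commit title. For context, gradient merge itself is plain gradient accumulation governed by the k_steps and avg knobs seen above; the toy loop below only illustrates that idea and is not Paddle's implementation (the function name and the grad_fn callback are made up).

    import numpy as np

    def sgd_with_gradient_merge(params, grad_fn, batches,
                                k_steps=2, avg=True, lr=0.01):
        # Accumulate gradients over k_steps micro-batches, then apply one update.
        acc = [np.zeros_like(p) for p in params]
        for step, batch in enumerate(batches, start=1):
            grads = grad_fn(params, batch)            # gradients for this micro-batch
            acc = [a + g for a, g in zip(acc, grads)]
            if step % k_steps == 0:                   # merged update point
                if avg:
                    acc = [a / float(k_steps) for a in acc]
                params = [p - lr * a for p, a in zip(params, acc)]
                acc = [np.zeros_like(p) for p in params]
        return params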
@@ -118,5 +118,8 @@ class TestFleetMetaOptimizer(unittest.TestCase):
'init_k_steps': 1,
'begin_step': 1,
}
elif name == "gradient_merge":
strategy.gradient_merge = True
strategy.gradient_merge_configs = {"k_steps": 2, "avg": True}
else:
raise NotImplementedError()
@@ -18,35 +18,36 @@ import os
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
from fleet_meta_optimizer_base import TestFleetMetaOptimizer
class TestFleetGradientMergeMetaOptimizer(unittest.TestCase):
def setUp(self):
os.environ["POD_IP"] = "127.0.0.1"
os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001"
os.environ["PADDLE_TRAINERS_NUM"] = "2"
os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \
"127.0.0.1:36001,127.0.0.2:36001"
paddle.enable_static()
class TestFleetGradientMergeMetaOptimizer(TestFleetMetaOptimizer):
def test_gradient_merge_optimizer(self):
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
input_x = paddle.fluid.layers.data(
name="x", shape=[32], dtype='float32')
input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')
fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax')
cost = paddle.fluid.layers.cross_entropy(
input=prediction, label=input_y)
avg_cost = paddle.fluid.layers.mean(x=cost)
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.gradient_merge = True
strategy.gradient_merge_configs = {"k_steps": 2, "avg": True}
optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
)
avg_cost, strategy = self.net(train_prog, startup_prog)
self.set_strategy(strategy, 'gradient_merge')
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
vars = [x.name for x in train_prog.list_vars()]
with open("main_program", 'w') as f:
f.write(str(train_prog))
self.assertIn('@GradientMerge', ''.join(vars))
def test_recom_gm_optimizer(self):
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
)
avg_cost, strategy = self.net(train_prog, startup_prog)
self.set_strategy(strategy, 'gradient_merge')
self.set_strategy(strategy, 'recompute')
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
vars = [x.name for x in train_prog.list_vars()]
self.assertIn('@GradientMerge', ''.join(vars))
self.assertIn('subprog', ''.join(vars))
if __name__ == "__main__":
......