未验证 提交 dcdd18ae 编写于 作者: T tangwei12 提交者: GitHub

fix bug with half (#22378) (#22415)

* fix bug with half communicator
上级 f0431607
...@@ -271,13 +271,21 @@ class DistributedTranspiler(Fleet): ...@@ -271,13 +271,21 @@ class DistributedTranspiler(Fleet):
elif isinstance(config, DistributeTranspilerConfig): elif isinstance(config, DistributeTranspilerConfig):
if config.sync_mode: if config.sync_mode:
self._transpile_config = SyncStrategy() self._transpile_config = SyncStrategy()
elif config.geo_sgd_mode:
self._transpile_config = GeoStrategy(
config.geo_sgd_need_push_nums)
elif config.runtime_split_send_recv and config.half_async:
self._transpile_config = HalfAsyncStrategy()
else: else:
self._transpile_config = AsyncStrategy() if config.runtime_split_send_recv:
if config.geo_sgd_mode:
self._transpile_config = GeoStrategy(
config.geo_sgd_need_push_nums)
elif config.half_async:
self._transpile_config = HalfAsyncStrategy()
else:
self._transpile_config = AsyncStrategy()
else:
self._transpile_config = HalfAsyncStrategy()
# for half_async compatibility
config.half_async = True
config.runtime_split_send_recv = True
self._transpile_config.set_program_config(config) self._transpile_config.set_program_config(config)
else: else:
raise TypeError( raise TypeError(
...@@ -359,7 +367,7 @@ class TranspilerOptimizer(DistributedOptimizer): ...@@ -359,7 +367,7 @@ class TranspilerOptimizer(DistributedOptimizer):
"In {} mode, strategy must be an instance of DistributeTranspilerConfig, SyncStrategy, HalfAsyncStrategy, AsyncStrategy, or GeoStrategy". "In {} mode, strategy must be an instance of DistributeTranspilerConfig, SyncStrategy, HalfAsyncStrategy, AsyncStrategy, or GeoStrategy".
format(fleet._mode)) format(fleet._mode))
else: else:
self._strategy = DistributedStrategy() self._strategy = StrategyFactory.create_sync_strategy()
def backward(self, def backward(self,
loss, loss,
......
...@@ -48,20 +48,20 @@ class TrainerRuntimeConfig(object): ...@@ -48,20 +48,20 @@ class TrainerRuntimeConfig(object):
def get_communicator_flags(self): def get_communicator_flags(self):
_communicator_flags = dict() _communicator_flags = dict()
_communicator_flags[ _communicator_flags["communicator_max_merge_var_num"] = str(
"communicator_max_merge_var_num"] = self.max_merge_var_num self.max_merge_var_num)
_communicator_flags[ _communicator_flags["communicator_send_queue_size"] = str(
"communicator_send_queue_size"] = self.send_queue_size self.send_queue_size)
_communicator_flags[ _communicator_flags["communicator_independent_recv_thread"] = str(
"communicator_independent_recv_thread"] = self.independent_recv_thread self.independent_recv_thread)
_communicator_flags[ _communicator_flags["communicator_min_send_grad_num_before_recv"] = str(
"communicator_min_send_grad_num_before_recv"] = self.min_send_grad_num_before_recv self.min_send_grad_num_before_recv)
_communicator_flags[ _communicator_flags["communicator_thread_pool_size"] = str(
"communicator_thread_pool_size"] = self.thread_pool_size self.thread_pool_size)
_communicator_flags[ _communicator_flags["communicator_send_wait_times"] = str(
"communicator_send_wait_times"] = self.send_wait_times self.send_wait_times)
_communicator_flags[ _communicator_flags["communicator_is_sgd_optimizer"] = str(
"communicator_is_sgd_optimizer"] = self.is_sgd_optimizer self.is_sgd_optimizer)
return _communicator_flags return _communicator_flags
def __repr__(self): def __repr__(self):
......
...@@ -16,6 +16,8 @@ import unittest ...@@ -16,6 +16,8 @@ import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import TrainerRuntimeConfig, StrategyFactory from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import TrainerRuntimeConfig, StrategyFactory
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
import os import os
...@@ -105,7 +107,7 @@ class TestStrategyFactor(unittest.TestCase): ...@@ -105,7 +107,7 @@ class TestStrategyFactor(unittest.TestCase):
self.assertIn('communicator_send_queue_size', self.assertIn('communicator_send_queue_size',
trainer_communicator_flags) trainer_communicator_flags)
self.assertEqual( self.assertEqual(
trainer_communicator_flags['communicator_send_queue_size'], 100) trainer_communicator_flags['communicator_send_queue_size'], '100')
# test set_trainer_runtime_config exception # test set_trainer_runtime_config exception
trainer_runtime_config_dict['unknown'] = None trainer_runtime_config_dict['unknown'] = None
...@@ -166,5 +168,37 @@ class TestStrategyFactor(unittest.TestCase): ...@@ -166,5 +168,37 @@ class TestStrategyFactor(unittest.TestCase):
server_runtime_config_illegal) server_runtime_config_illegal)
class TestCreateDefaultStrategy(unittest.TestCase):
def test_default_strategy(self):
role = role_maker.UserDefinedRoleMaker(
current_id=0,
role=role_maker.Role.WORKER,
worker_num=2,
server_endpoints=["127.0.0.1:6001", "127.0.0.1:6002"])
fleet.init(role)
optimizer = fluid.optimizer.SGD(0.0001)
optimizer = fleet.distributed_optimizer(optimizer)
class TestHalfAsyncStrategy(unittest.TestCase):
def test_half_async_strategy(self):
role = role_maker.UserDefinedRoleMaker(
current_id=0,
role=role_maker.Role.WORKER,
worker_num=2,
server_endpoints=["127.0.0.1:6001", "127.0.0.1:6002"])
fleet.init(role)
half_async_config = DistributeTranspilerConfig()
half_async_config.sync_mode = False
half_async_config.geo_sgd_mode = False
half_async_config.runtime_split_send_recv = False
optimizer = fluid.optimizer.SGD(0.0001)
optimizer = fleet.distributed_optimizer(optimizer, half_async_config)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册