Unverified commit dcdd18ae, authored by tangwei12, committed by GitHub

fix bug with half (#22378) (#22415)

* fix bug with half communicator
Parent commit: f0431607
......@@ -271,13 +271,21 @@ class DistributedTranspiler(Fleet):
elif isinstance(config, DistributeTranspilerConfig):
if config.sync_mode:
self._transpile_config = SyncStrategy()
elif config.geo_sgd_mode:
else:
if config.runtime_split_send_recv:
if config.geo_sgd_mode:
self._transpile_config = GeoStrategy(
config.geo_sgd_need_push_nums)
elif config.runtime_split_send_recv and config.half_async:
elif config.half_async:
self._transpile_config = HalfAsyncStrategy()
else:
self._transpile_config = AsyncStrategy()
else:
self._transpile_config = HalfAsyncStrategy()
# for half_async compatibility
config.half_async = True
config.runtime_split_send_recv = True
self._transpile_config.set_program_config(config)
else:
raise TypeError(
......@@ -359,7 +367,7 @@ class TranspilerOptimizer(DistributedOptimizer):
"In {} mode, strategy must be an instance of DistributeTranspilerConfig, SyncStrategy, HalfAsyncStrategy, AsyncStrategy, or GeoStrategy".
format(fleet._mode))
else:
self._strategy = DistributedStrategy()
self._strategy = StrategyFactory.create_sync_strategy()
def backward(self,
loss,
......
......@@ -48,20 +48,20 @@ class TrainerRuntimeConfig(object):
def get_communicator_flags(self):
    """Collect the trainer-side communicator settings as flag strings.

    Returns:
        dict: ``flag_name -> value`` where every value is converted with
        ``str()``, because these flags are consumed as environment-variable
        style strings by the communicator runtime.
    """
    # NOTE: the diff residue assigned each key twice (raw value, then
    # str(value)); only the stringified assignments are kept here.
    flag_values = {
        "communicator_max_merge_var_num": self.max_merge_var_num,
        "communicator_send_queue_size": self.send_queue_size,
        "communicator_independent_recv_thread": self.independent_recv_thread,
        "communicator_min_send_grad_num_before_recv":
        self.min_send_grad_num_before_recv,
        "communicator_thread_pool_size": self.thread_pool_size,
        "communicator_send_wait_times": self.send_wait_times,
        "communicator_is_sgd_optimizer": self.is_sgd_optimizer,
    }
    return {name: str(value) for name, value in flag_values.items()}
def __repr__(self):
......
......@@ -16,6 +16,8 @@ import unittest
import paddle.fluid as fluid
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import TrainerRuntimeConfig, StrategyFactory
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
import os
......@@ -105,7 +107,7 @@ class TestStrategyFactor(unittest.TestCase):
self.assertIn('communicator_send_queue_size',
trainer_communicator_flags)
self.assertEqual(
trainer_communicator_flags['communicator_send_queue_size'], 100)
trainer_communicator_flags['communicator_send_queue_size'], '100')
# test set_trainer_runtime_config exception
trainer_runtime_config_dict['unknown'] = None
......@@ -166,5 +168,37 @@ class TestStrategyFactor(unittest.TestCase):
server_runtime_config_illegal)
class TestCreateDefaultStrategy(unittest.TestCase):
    """Verify that fleet.distributed_optimizer works without an explicit
    strategy argument (a default strategy must be created)."""

    def test_default_strategy(self):
        # Two-worker / two-server role description for fleet.init.
        worker_role = role_maker.UserDefinedRoleMaker(
            current_id=0,
            role=role_maker.Role.WORKER,
            worker_num=2,
            server_endpoints=["127.0.0.1:6001", "127.0.0.1:6002"])
        fleet.init(worker_role)
        # No strategy passed: fleet must fall back to its default one.
        sgd = fluid.optimizer.SGD(0.0001)
        fleet.distributed_optimizer(sgd)
class TestHalfAsyncStrategy(unittest.TestCase):
    """Exercise the half-async compatibility path: async config without
    runtime_split_send_recv handed to fleet.distributed_optimizer."""

    def test_half_async_strategy(self):
        worker_role = role_maker.UserDefinedRoleMaker(
            current_id=0,
            role=role_maker.Role.WORKER,
            worker_num=2,
            server_endpoints=["127.0.0.1:6001", "127.0.0.1:6002"])
        fleet.init(worker_role)

        # Async mode (sync_mode/geo_sgd_mode off) with
        # runtime_split_send_recv disabled selects the half-async strategy.
        config = DistributeTranspilerConfig()
        config.sync_mode = False
        config.geo_sgd_mode = False
        config.runtime_split_send_recv = False

        sgd = fluid.optimizer.SGD(0.0001)
        fleet.distributed_optimizer(sgd, config)
# Run the full test suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.