Unverified commit d2cdc7e3, authored by ShenLiang, committed by GitHub

[BugFix]Fix segment fault in order setting (#52293)

* fix bug in proto

* add utest
Parent 155018ee
@@ -55,7 +55,6 @@ message HybridConfig {
   optional int32 mp_degree = 2 [ default = 1 ];
   optional int32 pp_degree = 3 [ default = 1 ];
   optional int32 sharding_degree = 4 [ default = 1 ];
-  repeated string order = 5 ;
 }
 
 message AMPConfig {
...
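The `repeated string order = 5` field is dropped from `HybridConfig`: the parallelism order is now kept as a plain Python attribute on `DistributedStrategy` (see the next file) instead of being round-tripped through the protobuf message, which is the path behind the segmentation fault this commit fixes. A minimal sketch of the observable change, based on the unit test added below (degree and order values are illustrative):

```python
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 2,
    "mp_degree": 2,
    "pp_degree": 1,
    "order": ['dp', 'pp', 'sharding', 'mp'],
}
# 'order' now lives on the Python object, not in the protobuf-backed dict:
print(strategy.hybrid_parallel_order)      # ['dp', 'pp', 'sharding', 'mp']
print("order" in strategy.hybrid_configs)  # False
```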
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
+
 import google.protobuf
 import google.protobuf.text_format
 
@@ -149,6 +151,7 @@ class DistributedStrategy:
         if _global_flags().is_public(key):
             self.strategy.sync_nccl_allreduce = bool(_global_flags()[key])
 
+        self.hybrid_parallel_order = ['dp', 'pp', 'sharding', 'mp']
         self.__lock_attr = True
         logger.info("distributed strategy initialized")
@@ -1691,8 +1694,13 @@ class DistributedStrategy:
 
     @hybrid_configs.setter
     def hybrid_configs(self, configs):
+        hybrid_config = copy.deepcopy(configs)
+        if "order" in hybrid_config:
+            self.hybrid_parallel_order = hybrid_config["order"]
+            hybrid_config.pop('order')
+
         check_configs_key(
-            self.strategy.hybrid_configs, configs, "hybrid_configs"
+            self.strategy.hybrid_configs, hybrid_config, "hybrid_configs"
         )
         assign_configs_value(self.strategy.hybrid_configs, configs)
...
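The setter now deep-copies the incoming dict and strips `order` before the remaining keys are validated against the protobuf message, so a user-supplied `order` never reaches `check_configs_key`, where a key with no matching proto field would be rejected. A minimal, self-contained sketch of this pop-before-validate pattern; `StrategyStub` and `VALID_PROTO_KEYS` are hypothetical stand-ins for `DistributedStrategy` and the proto field check, not Paddle internals:

```python
import copy

# Hypothetical stand-in: the keys that remain as fields on the proto message.
VALID_PROTO_KEYS = {"dp_degree", "mp_degree", "pp_degree", "sharding_degree"}

class StrategyStub:
    # Default seeded in __init__, mirroring the hunk above.
    hybrid_parallel_order = ['dp', 'pp', 'sharding', 'mp']

def set_hybrid_configs(strategy, configs):
    # Deep-copy so popping 'order' never mutates the caller's dict.
    hybrid_config = copy.deepcopy(configs)
    if "order" in hybrid_config:
        strategy.hybrid_parallel_order = hybrid_config.pop("order")
    # Only proto-backed keys are left to validate and assign.
    for key in hybrid_config:
        assert key in VALID_PROTO_KEYS, f"unknown hybrid_configs key: {key}"

s = StrategyStub()
set_hybrid_configs(s, {"dp_degree": 2, "order": ['mp', 'dp', 'pp', 'sharding']})
print(s.hybrid_parallel_order)  # ['mp', 'dp', 'pp', 'sharding']
```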
@@ -412,9 +412,7 @@ class Fleet:
             "mp": ['model', self.mp_degree],
         }
 
-        order = self.hybrid_configs["order"]
-        if not order:
-            order = ['dp', 'pp', 'sharding', 'mp']
+        order = self._user_defined_strategy.hybrid_parallel_order
         if order[:].sort() != list(d_hybrid_degree.keys())[:].sort():
             raise AssertionError(
                 'The order of hybrid_config setting is incorrect.'
...
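With the proto field gone, `Fleet` reads the order from `self._user_defined_strategy.hybrid_parallel_order`, which always exists because `__init__` seeds it with the default `['dp', 'pp', 'sharding', 'mp']`, so the old `if not order` fallback is no longer needed. One caveat in the unchanged context line: `order[:].sort()` sorts a temporary copy in place and returns `None`, so the comparison is always `None != None` and the assertion can never fire. A hedged sketch of a check that would actually validate the order (the diff itself leaves this line untouched):

```python
# Example values only; in Fleet these come from the configured degrees.
d_hybrid_degree = {"dp": 2, "pp": 1, "sharding": 1, "mp": 2}
order = ['sharding', 'mp', 'dp', 'pp']

# sorted() returns a new list, unlike list.sort(), which returns None.
if sorted(order) != sorted(d_hybrid_degree.keys()):
    raise AssertionError('The order of hybrid_config setting is incorrect.')
```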
@@ -84,6 +84,21 @@ class TestStrategyConfig(unittest.TestCase):
         self.assertEqual(strategy.hybrid_configs["mp_degree"], 2)
         self.assertEqual(strategy.hybrid_configs["pp_degree"], 4)
 
+    def test_hybrid_parallel_configs_order(self):
+        strategy = paddle.distributed.fleet.DistributedStrategy()
+        strategy.hybrid_configs = {
+            "dp_degree": 1,
+            "mp_degree": 2,
+            "pp_degree": 4,
+            "order": ['sharding', 'mp', 'dp', 'pp'],
+        }
+        self.assertEqual(strategy.hybrid_configs["dp_degree"], 1)
+        self.assertEqual(strategy.hybrid_configs["mp_degree"], 2)
+        self.assertEqual(strategy.hybrid_configs["pp_degree"], 4)
+        self.assertEqual(
+            strategy.hybrid_parallel_order, ['sharding', 'mp', 'dp', 'pp']
+        )
+
     def test_localsgd(self):
         strategy = paddle.distributed.fleet.DistributedStrategy()
         strategy.localsgd = True
...
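The new test drives the whole path above: `order` is captured by the setter, the degree keys still round-trip through the proto, and `hybrid_parallel_order` reflects the user's list. For completeness, a hedged sketch of how the configured order is consumed when a collective job initializes; the launcher command and degree values are illustrative, not part of this diff:

```python
# Run under a distributed launcher, e.g.:
#   python -m paddle.distributed.launch --devices 0,1,2,3 train.py
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 2,
    "mp_degree": 2,
    "pp_degree": 1,
    "order": ['sharding', 'mp', 'dp', 'pp'],
}
# fleet.init builds the hybrid communication groups following this order.
fleet.init(is_collective=True, strategy=strategy)
```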