Unverified commit d2cdc7e3, authored by ShenLiang and committed by GitHub

[BugFix] Fix segmentation fault in order setting (#52293)

* fix bug in proto

* add unit test
Parent 155018ee
@@ -55,7 +55,6 @@ message HybridConfig {
   optional int32 mp_degree = 2 [ default = 1 ];
   optional int32 pp_degree = 3 [ default = 1 ];
   optional int32 sharding_degree = 4 [ default = 1 ];
-  repeated string order = 5 ;
 }

 message AMPConfig {
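The fix removes the `repeated string order = 5;` field from the message and keeps the order list on the Python side instead. Repeated fields behave differently from scalar fields under the generic setattr-style assignment the config helpers use, which is a plausible reason to keep `order` out of the protobuf message entirely. A minimal sketch of that behavior (not Paddle code; FieldMask is used only because it ships with protobuf and has a repeated string field analogous to `order`):

# Minimal sketch: how protobuf's Python API treats a repeated string field.
from google.protobuf.field_mask_pb2 import FieldMask

mask = FieldMask()
try:
    # Plain attribute assignment, which works for scalar fields, is
    # rejected for repeated fields, so setattr-based helpers trip on it.
    mask.paths = ['dp', 'pp', 'sharding', 'mp']
except AttributeError as err:
    print(err)

# The supported idiom: clear in place, then extend.
del mask.paths[:]
mask.paths.extend(['dp', 'pp', 'sharding', 'mp'])
print(list(mask.paths))  # ['dp', 'pp', 'sharding', 'mp']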
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
+
 import google.protobuf
 import google.protobuf.text_format
@@ -149,6 +151,7 @@ class DistributedStrategy:
         if _global_flags().is_public(key):
             self.strategy.sync_nccl_allreduce = bool(_global_flags()[key])
+        self.hybrid_parallel_order = ['dp', 'pp', 'sharding', 'mp']
         self.__lock_attr = True
         logger.info("distributed strategy initialized")
......@@ -1691,8 +1694,13 @@ class DistributedStrategy:
@hybrid_configs.setter
def hybrid_configs(self, configs):
hybrid_config = copy.deepcopy(configs)
if "order" in hybrid_config:
self.hybrid_parallel_order = hybrid_config["order"]
hybrid_config.pop('order')
check_configs_key(
self.strategy.hybrid_configs, configs, "hybrid_configs"
self.strategy.hybrid_configs, hybrid_config, "hybrid_configs"
)
assign_configs_value(self.strategy.hybrid_configs, configs)
......
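Together with the default set in __init__ above, the setter now routes "order" around the protobuf message: it is popped from a deep copy of the dict before check_configs_key validates the remaining keys, and lands on hybrid_parallel_order instead. A usage sketch, assuming a Paddle build that includes this patch (the degree values are illustrative):

# Usage sketch of the patched setter.
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 2,
    "mp_degree": 2,
    "pp_degree": 1,
    # "order" never reaches the protobuf message; it is stored on
    # the Python-side hybrid_parallel_order attribute instead.
    "order": ['sharding', 'mp', 'dp', 'pp'],
}
print(strategy.hybrid_parallel_order)  # ['sharding', 'mp', 'dp', 'pp']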
@@ -412,9 +412,7 @@ class Fleet:
             "mp": ['model', self.mp_degree],
         }

-        order = self.hybrid_configs["order"]
-        if not order:
-            order = ['dp', 'pp', 'sharding', 'mp']
+        order = self._user_defined_strategy.hybrid_parallel_order
         if order[:].sort() != list(d_hybrid_degree.keys())[:].sort():
             raise AssertionError(
                 'The order of hybrid_config setting is incorrect.'
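For reference, a standalone sketch of the membership check this hunk performs, with illustrative degrees. One caveat worth flagging: list.sort() sorts in place and returns None, so the `order[:].sort() != ...sort()` comparison above compares None with None and never fires; sorted() returns new lists and is the robust spelling used in the sketch.

# Standalone sketch of the order validation (degrees are illustrative).
d_hybrid_degree = {
    "dp": ['data', 2],
    "pp": ['pipe', 1],
    "sharding": ['sharding', 1],
    "mp": ['model', 2],
}
order = ['sharding', 'mp', 'dp', 'pp']

# sorted() returns a new list, unlike list.sort(), which returns None.
if sorted(order) != sorted(d_hybrid_degree.keys()):
    raise AssertionError('The order of hybrid_config setting is incorrect.')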
@@ -84,6 +84,21 @@ class TestStrategyConfig(unittest.TestCase):
         self.assertEqual(strategy.hybrid_configs["mp_degree"], 2)
         self.assertEqual(strategy.hybrid_configs["pp_degree"], 4)

+    def test_hybrid_parallel_configs_order(self):
+        strategy = paddle.distributed.fleet.DistributedStrategy()
+        strategy.hybrid_configs = {
+            "dp_degree": 1,
+            "mp_degree": 2,
+            "pp_degree": 4,
+            "order": ['sharding', 'mp', 'dp', 'pp'],
+        }
+        self.assertEqual(strategy.hybrid_configs["dp_degree"], 1)
+        self.assertEqual(strategy.hybrid_configs["mp_degree"], 2)
+        self.assertEqual(strategy.hybrid_configs["pp_degree"], 4)
+        self.assertEqual(
+            strategy.hybrid_parallel_order, ['sharding', 'mp', 'dp', 'pp']
+        )
+
     def test_localsgd(self):
         strategy = paddle.distributed.fleet.DistributedStrategy()
         strategy.localsgd = True