Unverified commit e6aacd1e, authored by W wangguanqun, committed by GitHub

add trainer desc config to distributed strategy (#34457)

* add trainer desc config to distributed strategy

* code style modified
Parent 41c4f723
@@ -146,6 +146,13 @@ message AsyncConfig {
optional int32 use_ps_gpu = 12 [ default = 0 ];
}
message TrainerDescConfig {
optional string dump_fields_path = 1;
repeated string dump_fields = 2;
repeated string dump_param = 3;
repeated string stat_var_names = 4;
}
message PipelineConfig {
optional int32 micro_batch_size = 1 [ default = 1 ];
optional int32 accumulate_steps = 2 [ default = 1 ];
@@ -206,6 +213,7 @@ message DistributedStrategy {
optional ShardingConfig sharding_configs = 111;
optional HybridConfig hybrid_configs = 112;
optional TensorParallelConfig tensor_parallel_configs = 113;
optional TrainerDescConfig trainer_desc_configs = 114;
optional BuildStrategy build_strategy = 201;
optional ExecutionStrategy execution_strategy = 202;
optional GradientScaleConfig gradient_scale_configs = 203;
......
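On the Python side (see the DistributedStrategy property added below), this new message is exposed as a plain dict keyed by the proto field names. A minimal sketch of the mapping, assuming only the four fields declared above; the dump-field name is a placeholder, not taken from a real model:

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
# Each key mirrors one field of the TrainerDescConfig message above.
strategy.trainer_desc_configs = {
    "dump_fields_path": "./dump_data",       # optional string, field 1
    "dump_fields": ["embedding_0.tmp_0"],    # repeated string, field 2 (placeholder name)
    "dump_param": [],                        # repeated string, field 3
    "stat_var_names": [],                    # repeated string, field 4
}
print(strategy.trainer_desc_configs)         # read back as a dict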
@@ -360,6 +360,45 @@ class DistributedStrategy(object):
                          "a_sync_configs")
        assign_configs_value(self.strategy.a_sync_configs, configs)

    @property
    def trainer_desc_configs(self):
        """
        Set trainer desc configurations.

        **Notes**:
            dump_fields_path(str): the path where dumped fields are saved

            dump_fields(list(str)): the fields that you want to dump

            dump_param(list(str)): the parameters that you want to dump

            stat_var_names(list(str)): the names of stat variables

        Examples:

          .. code-block:: python

            import paddle.distributed.fleet as fleet
            role_maker = fleet.PaddleCloudRoleMaker()
            fleet.init(role_maker)

            strategy = fleet.DistributedStrategy()
            configs = {"dump_fields_path": "./dump_data", "dump_fields": ["xxx", "yyy"]}
            strategy.trainer_desc_configs = configs

            # code block for defining loss and local optimizer
            # sgd = fleet.distributed_optimizer(optimizer, strategy)
        """
        return get_msg_dict(self.strategy.trainer_desc_configs)

    @trainer_desc_configs.setter
    @is_strict_auto
    def trainer_desc_configs(self, configs):
        check_configs_key(self.strategy.trainer_desc_configs, configs,
                          "trainer_desc_configs")
        assign_configs_value(self.strategy.trainer_desc_configs, configs)

    @property
    def amp(self):
        """
......
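The setter above validates the given keys with check_configs_key before assigning them, so only fields declared in TrainerDescConfig are accepted. A small illustrative sketch; the rejection behavior is inferred from the setter code shown above, and the unknown key name is made up:

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()

# Keys defined in TrainerDescConfig pass the check and land in the proto.
strategy.trainer_desc_configs = {"dump_fields_path": "./dump_data"}

# A key that is not a TrainerDescConfig field should be rejected by check_configs_key.
try:
    strategy.trainer_desc_configs = {"not_a_field": 1}
except Exception as exc:
    print("rejected:", exc)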
@@ -1476,6 +1476,14 @@ class Fleet(object):
            context["graph_optimize_ops"] = optimize_ops
            context["graph_optimize_grads"] = params_grads

        program = paddle.static.default_main_program()
        opt_info = {}
        opt_info["mpi_size"] = self.worker_num()
        opt_info["mpi_rank"] = self.worker_index()
        for k, v in self._user_defined_strategy.trainer_desc_configs.items():
            opt_info[k] = v
        program._fleet_opt = opt_info

        if self._runtime_handle is None:
            self._runtime_handle = RuntimeFactory()._create_runtime(context)
......
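At minimize time, the change above copies mpi_size, mpi_rank, and every trainer_desc_configs entry into program._fleet_opt. Roughly, for trainer 0 of a two-trainer job using the configs from the docstring example, the resulting dict would look like the sketch below (values illustrative, not taken from a real run):

# Rough shape of program._fleet_opt after optimizer.minimize(avg_cost)
# on trainer 0 of a 2-trainer job (illustrative values only).
expected_fleet_opt = {
    "mpi_size": 2,                       # from self.worker_num()
    "mpi_rank": 0,                       # from self.worker_index()
    "dump_fields_path": "./dump_data",   # copied from trainer_desc_configs
    "dump_fields": ["xxx", "yyy"],
    "dump_param": [],
    "stat_var_names": [],
}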
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import unittest

import paddle
import paddle.distributed.fleet.base.role_maker as role_maker
import paddle.fluid.transpiler.details.program_utils as pu

paddle.enable_static()


class TestDistStrategyTrainerDescConfig(unittest.TestCase):
    def setUp(self):
        os.environ["PADDLE_PSERVER_NUMS"] = "2"
        os.environ["PADDLE_TRAINERS_NUM"] = "2"
        os.environ["POD_IP"] = "127.0.0.1"
        os.environ["PADDLE_PORT"] = "36001"
        os.environ["PADDLE_TRAINER_ID"] = "0"
        os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \
            "127.0.0.1:36001,127.0.0.2:36001"

    def test_trainer_desc_config(self):
        os.environ["TRAINING_ROLE"] = "TRAINER"
        import paddle.distributed.fleet as fleet

        fleet.init(role_maker.PaddleCloudRoleMaker())

        x = paddle.fluid.layers.data(name='x', shape=[1], dtype='float32')
        y = paddle.fluid.layers.data(name='y', shape=[1], dtype='float32')
        cost = paddle.fluid.layers.square_error_cost(input=x, label=y)
        avg_cost = paddle.fluid.layers.mean(cost)

        strategy = paddle.distributed.fleet.DistributedStrategy()
        config = {
            "dump_fields_path": "dump_data",
            "dump_fields": ["xxx", "yyy"],
            "dump_param": []
        }
        strategy.trainer_desc_configs = config

        optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
        optimizer.minimize(avg_cost)

        program = paddle.static.default_main_program()
        self.assertEqual(program._fleet_opt["dump_fields_path"], "dump_data")
        self.assertEqual(len(program._fleet_opt["dump_fields"]), 2)
        self.assertEqual(len(program._fleet_opt["dump_param"]), 0)
        self.assertEqual(program._fleet_opt["mpi_size"],
                         int(os.environ["PADDLE_TRAINERS_NUM"]))


if __name__ == "__main__":
    unittest.main()
@@ -255,6 +255,19 @@ class TestStrategyConfig(unittest.TestCase):
        strategy.a_sync_configs = configs
        self.assertEqual(strategy.a_sync_configs["k_steps"], 1000)

    def test_trainer_desc_configs(self):
        strategy = paddle.distributed.fleet.DistributedStrategy()
        configs = {
            "dump_fields_path": "dump_data",
            "dump_fields": ["xxx", "yyy"],
            "dump_param": []
        }
        strategy.trainer_desc_configs = configs
        self.assertEqual(strategy.trainer_desc_configs["dump_fields_path"],
                         "dump_data")
        self.assertEqual(len(strategy.trainer_desc_configs["dump_fields"]), 2)
        self.assertEqual(len(strategy.trainer_desc_configs["dump_param"]), 0)

    def test_elastic(self):
        strategy = paddle.distributed.fleet.DistributedStrategy()
        strategy.elastic = True
......