Unverified commit 8d2896f1, authored by Dong Daxiang, committed by GitHub

【paddle.fleet】Fleet run graph in Executor and add two more strategies (#25844)

* split meta optimizer files
* add graph execution in Executor, update two properties in DistributedStrategy, and add unit tests for these features
Parent 6486fe8a
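
A minimal end-to-end sketch of what this change enables, adapted from the unit tests added below (assumptions: the PADDLE_* environment variables are set by a launcher, and the network is the small classifier used in the tests):

import numpy as np
import paddle
import paddle.fleet as fleet
import paddle.fluid.incubate.fleet.base.role_maker as role_maker

role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)

input_x = paddle.fluid.layers.data(name="x", shape=[32], dtype='float32')
input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')
fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
prediction = paddle.fluid.layers.fc(input=[fc_1], size=2, act='softmax')
cost = paddle.fluid.layers.cross_entropy(input=prediction, label=input_y)
avg_cost = paddle.fluid.layers.mean(x=cost)

strategy = paddle.fleet.DistributedStrategy()
strategy.nccl_comm_num = 2
strategy.fuse_grad_size_in_MB = 32         # new strategy field added in this PR
strategy._fuse_grad_size_in_TFLOPS = 50.0  # new (private) strategy field added in this PR

optimizer = paddle.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)

# With this PR, the graph compiled by the graph-execution meta optimizer is
# stored in Program._graph and picked up transparently by Executor.run.
exe = paddle.fluid.Executor(paddle.fluid.CPUPlace())
exe.run(paddle.fluid.default_startup_program())
feed = {"x": np.random.random(size=(128, 32)).astype('float32'),
        "y": np.random.randint(2, size=(128, 1)).astype('int64')}
cost_val = exe.run(feed=feed, fetch_list=[avg_cost.name])
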
......@@ -92,6 +92,8 @@ message DistributedStrategy {
optional int32 hierarchical_allreduce_inter_nranks = 16 [ default = 1 ];
optional bool sync_batch_norm = 17 [ default = false ];
optional bool fuse_all_reduce_ops = 18 [ default = true ];
optional int32 fuse_grad_size_in_MB = 19 [ default = 32 ];
optional float fuse_grad_size_in_TFLOPS = 20 [ default = 50 ];
// optional bool enable_backward_optimizer_op_deps = 19 [ default = true ];
optional RecomputeConfig recompute_configs = 101;
......
......@@ -378,6 +378,30 @@ class DistributedStrategy(object):
else:
print("WARNING: fuse_all_reduce_ops should have value of bool type")
@property
def fuse_grad_size_in_MB(self):
return self.strategy.fuse_grad_size_in_MB
@fuse_grad_size_in_MB.setter
def fuse_grad_size_in_MB(self, value):
if isinstance(value, int):
self.strategy.fuse_grad_size_in_MB = value
else:
print("WARNING: fuse_grad_size_in_MB should have value of int type")
@property
def _fuse_grad_size_in_TFLOPS(self):
return self.strategy.fuse_grad_size_in_TFLOPS
@_fuse_grad_size_in_TFLOPS.setter
def _fuse_grad_size_in_TFLOPS(self, value):
if isinstance(value, float):
self.strategy.fuse_grad_size_in_TFLOPS = value
else:
print(
"WARNING: fuse_grad_size_in_TFLOPS should have value of float type"
)
@property
def nccl_comm_num(self):
return self.strategy.nccl_comm_num
......
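
To illustrate the setters above: values of the wrong type are ignored with a warning instead of raising, so the previously stored value survives. A short sketch mirroring the unit tests added in this change:

import paddle.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.fuse_grad_size_in_MB = 50          # int: accepted
strategy.fuse_grad_size_in_MB = "40"        # str: rejected with a warning
assert strategy.fuse_grad_size_in_MB == 50  # value is unchanged

strategy._fuse_grad_size_in_TFLOPS = 0.1    # float: accepted
strategy._fuse_grad_size_in_TFLOPS = "0.3"  # str: rejected with a warning
assert strategy._fuse_grad_size_in_TFLOPS > 0.09
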
......@@ -327,6 +327,12 @@ class Fleet(object):
startup_program=startup_program,
parameter_list=parameter_list,
no_grad_set=no_grad_set)
else:
optimize_ops, params_grads = self.user_defined_optimizer.minimize(
loss,
startup_program=startup_program,
parameter_list=parameter_list,
no_grad_set=no_grad_set)
if graph_optimizer:
optimizer_ops, params_grads = graph_optimizer.minimize(
......@@ -338,7 +344,6 @@ class Fleet(object):
# if a graph optimizer takes effect, mostly
# optimizers_ops and params_grads are None
# i.e. users can not modify current computation graph anymore
if self._runtime_handle is None:
self._runtime_handle = RuntimeFactory()._create_runtime(
valid_strategy, self._role_maker, optimize_ops, params_grads)
......
......@@ -55,10 +55,6 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
current_endpoint = self.role_maker.get_trainer_endpoints()[trainer_id]
trainer_endpoints_env = ",".join(trainer_endpoints)
trainers_num = self.role_maker.worker_num()
if trainer_id == 0:
other_trainer_endpoints = trainer_endpoints[:]
other_trainer_endpoints.remove(current_endpoint)
wait_server_ready(other_trainer_endpoints)
nccl_id_var = startup_program.global_block().create_var(
name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
for i in range(1, build_strategy.nccl_comm_num):
......
......@@ -1154,6 +1154,23 @@ class Executor(object):
# For backward compatibility, run directly.
if not compiled:
# In distributed training, the compiled program is saved in Program._graph
has_compiled_graph = isinstance(program._graph,
compiler.CompiledProgram)
if has_compiled_graph:
program._graph._compile(scope, self.place)
# _graph in program does not support inference since the _graph is optimized
# through optimizer.minimize function and should not be used as inference graph
# assert not program._graph._is_inference
return self._run_parallel(
program._graph,
scope=scope,
feed=feed,
fetch_list=fetch_list,
fetch_var_name=fetch_var_name,
return_numpy=return_numpy,
return_merged=return_merged)
return self._run_program(
program,
feed=feed,
......
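
As a quick way to see which path Executor.run will take, one can inspect whether the program already carries a compiled graph. A hedged sketch; getattr is used defensively here in case _graph has not been populated on this Program:

import paddle.fluid as fluid
import paddle.fluid.compiler as compiler

prog = fluid.default_main_program()
if isinstance(getattr(prog, "_graph", None), compiler.CompiledProgram):
    # distributed training: the compiled graph is executed via _run_parallel
    print("Executor.run dispatches to the parallel-executor path")
else:
    # no compiled graph attached: Executor.run falls back to _run_program
    print("Executor.run runs the program directly")
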
......@@ -31,11 +31,13 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_api_input)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_checkpoint)
list(APPEND MIXED_DIST_TEST_OPS test_collective_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_base)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_recompute_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_graph_execution_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_pipeline_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_gradient_merge_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_localsgd_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_private_function)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_graph_executor)
foreach(TEST_OP ${MIXED_DIST_TEST_OPS})
list(REMOVE_ITEM TEST_OPS ${TEST_OP})
endforeach()
......@@ -364,11 +366,13 @@ if(WITH_DISTRIBUTE)
py_test_modules(test_communicator_sync MODULES test_communicator_sync ENVS ${dist_ENVS} FLAGS_communicator_send_queue_size=1 FLAGS_communicator_max_merge_var_num=1)
py_test_modules(test_collective_optimizer MODULES test_collective_optimizer)
if(NOT APPLE)
py_test_modules(test_fleet_base MODULES test_fleet_base ENVS ${dist_ENVS})
py_test_modules(test_fleet_meta_optimizer MODULES test_fleet_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_pipeline_meta_optimizer MODULES test_fleet_pipeline_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_gradient_merge_meta_optimizer MODULES test_fleet_gradient_merge_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_private_function MODULES test_fleet_private_function ENVS ${dist_ENVS})
py_test_modules(test_fleet_base MODULES test_fleet_base ENVS ${dist_ENVS})
py_test_modules(test_fleet_recompute_meta_optimizer MODULES test_fleet_recompute_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_graph_execution_meta_optimizer MODULES test_fleet_graph_execution_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_graph_executor MODULES test_fleet_graph_executor ENVS ${dist_ENVS})
py_test_modules(test_fleet_gradient_merge_meta_optimizer MODULES test_fleet_gradient_merge_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_pipeline_meta_optimizer MODULES test_fleet_pipeline_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_private_function MODULES test_fleet_private_function ENVS ${dist_ENVS})
if(NOT WIN32)
py_test_modules(test_fleet_localsgd_meta_optimizer MODULES test_fleet_localsgd_meta_optimizer ENVS ${dist_ENVS})
endif(NOT WIN32)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from multiprocessing import Pool, Process
import os
def launch_func(func, env_dict):
for key in env_dict:
os.environ[key] = env_dict[key]
proc = Process(target=func)
return proc
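
For reference, the helper above is used by the tests that follow roughly like this (an illustrative sketch; the worker function and environment dict are placeholders):

from launch_function_helper import launch_func

def worker():
    # the PADDLE_* variables passed through env_dict are visible here after fork
    print("running in a separate process")

env = {"PADDLE_TRAINER_ID": "0"}
proc = launch_func(worker, env)
proc.start()
proc.join()
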
......@@ -54,7 +54,6 @@ class TestStrategyConfig(unittest.TestCase):
configs = {"checkpoints": ["x", "y"]}
strategy.recompute_configs = configs
self.assertEqual(len(strategy.recompute_configs["checkpoints"]), 2)
print(strategy.recompute_configs)
def test_pipeline(self):
strategy = paddle.fleet.DistributedStrategy()
......@@ -145,6 +144,20 @@ class TestStrategyConfig(unittest.TestCase):
strategy.fuse_all_reduce_ops = "True"
self.assertEqual(strategy.fuse_all_reduce_ops, False)
def test_fuse_grad_size_in_MB(self):
strategy = paddle.fleet.DistributedStrategy()
strategy.fuse_grad_size_in_MB = 50
self.assertEqual(strategy.fuse_grad_size_in_MB, 50)
strategy.fuse_grad_size_in_MB = "40"
self.assertEqual(strategy.fuse_grad_size_in_MB, 50)
def test_fuse_grad_size_in_TFLOPS(self):
strategy = paddle.fleet.DistributedStrategy()
strategy._fuse_grad_size_in_TFLOPS = 0.1
self.assertGreater(strategy._fuse_grad_size_in_TFLOPS, 0.09)
strategy._fuse_grad_size_in_TFLOPS = "0.3"
self.assertGreater(strategy._fuse_grad_size_in_TFLOPS, 0.09)
def test_gradient_merge(self):
strategy = paddle.fleet.DistributedStrategy()
strategy.gradient_merge = True
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import os
from launch_function_helper import launch_func
class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase):
def test_graph_execution_optimizer_not_apply(self):
node_a = {
"PADDLE_TRAINER_ID": "0",
"PADDLE_CURRENT_ENDPOINT": "127.0.0.1:36003",
"PADDLE_TRAINERS_NUM": "2",
"PADDLE_TRAINER_ENDPOINTS": "127.0.0.1:36003,127.0.0.1:36004",
"http_proxy": "",
"https_proxy": ""
}
node_b = {
"PADDLE_TRAINER_ID": "1",
"PADDLE_CURRENT_ENDPOINT": "127.0.0.1:36004",
"PADDLE_TRAINERS_NUM": "2",
"PADDLE_TRAINER_ENDPOINTS": "127.0.0.1:36003,127.0.0.1:36004",
"http_proxy": "",
"https_proxy": ""
}
def node_func():
import paddle.fleet as fleet
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
input_x = paddle.fluid.layers.data(
name="x", shape=[32], dtype='float32')
input_y = paddle.fluid.layers.data(
name="y", shape=[1], dtype='int64')
fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
prediction = paddle.fluid.layers.fc(input=[fc_2],
size=2,
act='softmax')
cost = paddle.fluid.layers.cross_entropy(
input=prediction, label=input_y)
avg_cost = paddle.fluid.layers.mean(x=cost)
strategy = paddle.fleet.DistributedStrategy()
optimizer = paddle.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(
optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
proc_a = launch_func(node_func, node_a)
proc_a.start()
proc_b = launch_func(node_func, node_b)
proc_b.start()
proc_a.join()
proc_b.join()
def test_graph_execution_optimizer(self):
node_a = {
"PADDLE_TRAINER_ID": "0",
"PADDLE_CURRENT_ENDPOINT": "127.0.0.1:36001",
"PADDLE_TRAINERS_NUM": "2",
"PADDLE_TRAINER_ENDPOINTS": "127.0.0.1:36001,127.0.0.1:36002",
"http_proxy": "",
"https_proxy": ""
}
node_b = {
"PADDLE_TRAINER_ID": "1",
"PADDLE_CURRENT_ENDPOINT": "127.0.0.1:36002",
"PADDLE_TRAINERS_NUM": "2",
"PADDLE_TRAINER_ENDPOINTS": "127.0.0.1:36001,127.0.0.1:36002",
"http_proxy": "",
"https_proxy": ""
}
def node_func():
import paddle.fleet as fleet
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
input_x = paddle.fluid.layers.data(
name="x", shape=[32], dtype='float32')
input_y = paddle.fluid.layers.data(
name="y", shape=[1], dtype='int64')
fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
prediction = paddle.fluid.layers.fc(input=[fc_2],
size=2,
act='softmax')
cost = paddle.fluid.layers.cross_entropy(
input=prediction, label=input_y)
avg_cost = paddle.fluid.layers.mean(x=cost)
strategy = paddle.fleet.DistributedStrategy()
strategy.nccl_comm_num = 2
strategy.sync_nccl_allreduce = True
optimizer = paddle.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(
optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
exe = paddle.fluid.Executor(place=paddle.fluid.CPUPlace())
exe.run(paddle.fluid.default_startup_program())
import numpy as np
def gen_data():
return {
"x": np.random.random(size=(128, 32)).astype('float32'),
"y": np.random.randint(
2, size=(128, 1)).astype('int64')
}
for i in range(10):
cost_val = exe.run(feed=gen_data(), fetch_list=[avg_cost.name])
print("cost of step[{}] = {}".format(i, cost_val))
proc_a = launch_func(node_func, node_a)
proc_a.start()
proc_b = launch_func(node_func, node_b)
proc_b.start()
proc_a.join()
proc_b.join()
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import os
from launch_function_helper import launch_func
class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase):
def test_graph_execution_optimizer(self):
node_a = {
"PADDLE_TRAINER_ID": "0",
"PADDLE_CURRENT_ENDPOINT": "127.0.0.1:36001",
"PADDLE_TRAINERS_NUM": "2",
"PADDLE_TRAINER_ENDPOINTS": "127.0.0.1:36001,127.0.0.1:36002",
"http_proxy": "",
"https_proxy": ""
}
node_b = {
"PADDLE_TRAINER_ID": "1",
"PADDLE_CURRENT_ENDPOINT": "127.0.0.1:36002",
"PADDLE_TRAINERS_NUM": "2",
"PADDLE_TRAINER_ENDPOINTS": "127.0.0.1:36001,127.0.0.1:36002",
"http_proxy": "",
"https_proxy": ""
}
def node_func():
import paddle.fleet as fleet
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
input_x = paddle.fluid.layers.data(
name="x", shape=[32], dtype='float32')
input_y = paddle.fluid.layers.data(
name="y", shape=[1], dtype='int64')
fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
prediction = paddle.fluid.layers.fc(input=[fc_2],
size=2,
act='softmax')
cost = paddle.fluid.layers.cross_entropy(
input=prediction, label=input_y)
avg_cost = paddle.fluid.layers.mean(x=cost)
strategy = paddle.fleet.DistributedStrategy()
strategy.nccl_comm_num = 2
strategy.sync_nccl_allreduce = True
optimizer = paddle.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(
optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
exe = paddle.fluid.Executor(place=paddle.fluid.CPUPlace())
exe.run(paddle.fluid.default_startup_program())
import numpy as np
def gen_data():
return {
"x": np.random.random(size=(128, 32)).astype('float32'),
"y": np.random.randint(
2, size=(128, 1)).astype('int64')
}
for i in range(10):
cost_val = exe.run(feed=gen_data(), fetch_list=[avg_cost.name])
print("cost of step[{}] = {}".format(i, cost_val))
proc_a = launch_func(node_func, node_a)
proc_a.start()
# just for coverage
for key in node_b:
os.environ[key] = node_b[key]
node_func()
proc_a.join()
if __name__ == "__main__":
unittest.main()
......@@ -17,7 +17,7 @@ import paddle
import os
class TestFleetMetaOptimizer(unittest.TestCase):
class TestFleetRecomputeMetaOptimizer(unittest.TestCase):
def setUp(self):
os.environ["POD_IP"] = "127.0.0.1"
os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001"
......@@ -25,49 +25,6 @@ class TestFleetMetaOptimizer(unittest.TestCase):
os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \
"127.0.0.1:36001,127.0.0.2:36001"
def test_graph_execution_optimizer_not_apply(self):
import paddle.fleet as fleet
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
role = role_maker.PaddleCloudRoleMaker()
fleet.init(role)
input_x = paddle.fluid.layers.data(
name="x", shape=[32], dtype='float32')
input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')
fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax')
cost = paddle.fluid.layers.cross_entropy(
input=prediction, label=input_y)
avg_cost = paddle.fluid.layers.mean(x=cost)
strategy = paddle.fleet.DistributedStrategy()
optimizer = paddle.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
def test_graph_execution_optimizer(self):
import paddle.fleet as fleet
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
input_x = paddle.fluid.layers.data(
name="x", shape=[32], dtype='float32')
input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')
fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax')
cost = paddle.fluid.layers.cross_entropy(
input=prediction, label=input_y)
avg_cost = paddle.fluid.layers.mean(x=cost)
strategy = paddle.fleet.DistributedStrategy()
strategy.nccl_comm_num = 2
optimizer = paddle.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
def test_recompute_optimizer(self):
import paddle.fleet as fleet
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
......