From bb5963da14ce6554fcef7a8ae1949b9843fc1b8a Mon Sep 17 00:00:00 2001 From: lilong12 Date: Wed, 16 Jun 2021 22:05:05 +0800 Subject: [PATCH] [CP] add a strategy to run program with fleet (#33511) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add raw program meta optimizer (#32597) * add raw program, test=develop * add precision unitest for executor all reduce (#33339) * fix dp (#33297) Co-authored-by: Yuang Liu Co-authored-by: 李季 <2042519524@qq.com> --- .../framework/distributed_strategy.proto | 1 + .../fleet/base/distributed_strategy.py | 26 +++ .../fleet/meta_optimizers/__init__.py | 1 + .../meta_optimizers/raw_program_optimizer.py | 197 ++++++++++++++++++ .../fluid/tests/unittests/CMakeLists.txt | 12 +- .../dist_fleet_raw_program_optimizer.py | 109 ++++++++++ .../fluid/tests/unittests/test_dist_base.py | 76 ++++++- .../test_dist_fleet_raw_program_optimizer.py | 45 ++++ .../test_fleet_raw_program_meta_optimizer.py | 53 +++++ .../unittests/test_raw_program_optimizer.py | 77 +++++++ 10 files changed, 592 insertions(+), 5 deletions(-) create mode 100755 python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py create mode 100644 python/paddle/fluid/tests/unittests/dist_fleet_raw_program_optimizer.py create mode 100644 python/paddle/fluid/tests/unittests/test_dist_fleet_raw_program_optimizer.py create mode 100644 python/paddle/fluid/tests/unittests/test_fleet_raw_program_meta_optimizer.py create mode 100644 python/paddle/fluid/tests/unittests/test_raw_program_optimizer.py diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 38831192c8c..181e3b68853 100644 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -175,6 +175,7 @@ message DistributedStrategy { optional float last_comm_group_size_MB = 27 [ default = 1 ]; optional bool find_unused_parameters = 28 [ default = false ]; optional bool tensor_parallel = 29 [ default = false ]; + optional bool without_graph_optimization = 30 [ default = false ]; optional RecomputeConfig recompute_configs = 101; optional AMPConfig amp_configs = 102; diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index 640bc00cb6c..f9cd623afef 100755 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -827,6 +827,32 @@ class DistributedStrategy(object): "sharding_configs") assign_configs_value(self.strategy.sharding_configs, configs) + @property + def without_graph_optimization(self): + """ + Run program using Executor other than ParallelExecutor. + + Examples: + + .. code-block:: python + + import paddle.distributed.fleet as fleet + strategy = fleet.DistributedStrategy() + strategy.without_graph_optimization = True + + """ + return self.strategy.without_graph_optimization + + @without_graph_optimization.setter + @is_strict_auto + def without_graph_optimization(self, flag): + if isinstance(flag, bool): + self.strategy.without_graph_optimization = flag + else: + print( + "WARNING: without_graph_optimization should have value of bool type" + ) + @property def pipeline(self): """ diff --git a/python/paddle/distributed/fleet/meta_optimizers/__init__.py b/python/paddle/distributed/fleet/meta_optimizers/__init__.py index 827835fde20..1788e044fe8 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/__init__.py +++ b/python/paddle/distributed/fleet/meta_optimizers/__init__.py @@ -28,3 +28,4 @@ from .sharding_optimizer import ShardingOptimizer from .dygraph_optimizer import HybridParallelOptimizer from .dygraph_optimizer import HybridParallelGradScaler from .tensor_parallel_optimizer import TensorParallelOptimizer +from .raw_program_optimizer import RawProgramOptimizer diff --git a/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py new file mode 100755 index 00000000000..b232d8c9c49 --- /dev/null +++ b/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py @@ -0,0 +1,197 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +from __future__ import print_function +from __future__ import division +import os + +import paddle.fluid as fluid +from paddle.fluid import core, unique_name +from ..base.private_helper_function import wait_server_ready +from .meta_optimizer_base import MetaOptimizerBase +from .common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY, CollectiveHelper, is_loss_grad_op, is_backward_op, is_optimizer_op + + +class RawProgramOptimizer(MetaOptimizerBase): + def __init__(self, optimizer): + super(RawProgramOptimizer, self).__init__(optimizer) + self.inner_opt = optimizer + self.meta_optimizers_white_list = [ + "RecomputeOptimizer", + "AMPOptimizer", + ] + self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ] + self.global_ring_id = 0 + + def _set_basic_info(self, loss, role_maker, user_defined_optimizer, + user_defined_strategy): + super(RawProgramOptimizer, self)._set_basic_info( + loss, role_maker, user_defined_optimizer, user_defined_strategy) + self.without_graph_optimization = user_defined_strategy.without_graph_optimization + + def _can_apply(self): + if not self.role_maker._is_collective: + return False + + if self.without_graph_optimization == True: + return True + return False + + def _disable_strategy(self, dist_strategy): + dist_strategy.without_graph_optimization = False + + def _enable_strategy(self, dist_strategy, context): + dist_strategy.without_graph_optimization = True + + def _broadcast_params(self, ring_id): + block = self.startup_program.global_block() + param = None + for param in block.iter_parameters(): + if param.is_distributed: + continue + + block.append_op( + type='c_broadcast', + inputs={'X': param}, + outputs={'Out': param}, + attrs={ + 'ring_id': ring_id, + 'root': 0, + OP_ROLE_KEY: OpRole.Forward + }) + + if not param: return # no parameter on this device + block.append_op( + type='c_sync_comm_stream', + inputs={'X': param}, + outputs={'Out': param}, + attrs={'ring_id': ring_id, + OP_ROLE_KEY: OpRole.Forward}) + + def _get_process_group_info(self): + # global ring info + self.global_endpoints = self.endpoints + self.global_rank = self.rank + self.global_nranks = self.nranks + + def _init_process_group(self): + self._get_process_group_info() + collective_helper = CollectiveHelper(self.role_maker, wait_port=False) + # Create global ring for all gpus (ring_id = 0) + collective_helper._init_communicator( + self.startup_program, self.current_endpoint, self.global_endpoints, + self.global_rank, self.global_ring_id, True, self.global_ring_id, + True) + self._broadcast_params(self.global_ring_id) + + def minimize_impl(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None): + self.endpoints = self.role_maker._get_trainer_endpoints() + self.current_endpoint = self.endpoints[self.role_maker._worker_index()] + self.rank = self.role_maker._worker_index() + self.nranks = self.role_maker._worker_num() + if startup_program is None: + startup_program = fluid.default_startup_program() + self.startup_program = startup_program + + block = loss.block + program = block.program + self.main_program = program + + optimize_ops, params_grads = self.inner_opt.minimize( + loss, startup_program, parameter_list, no_grad_set) + if self.nranks == 1: + return optimize_ops, params_grads + self._init_process_group() + + self.main_program = program + if self.nranks > 1: + self._transpile_main_program(loss) + return optimize_ops, params_grads + + def _transpile_main_program(self, loss): + self._insert_loss_grad_ops(loss) + self._insert_allreduce_ops() + + def _insert_loss_grad_ops(self, loss): + """ + In order to keep the learning rate consistent in different numbers of + training workers, we scale the loss grad by the number of workers + """ + block = self.main_program.global_block() + for idx, op in reversed(list(enumerate(block.ops))): + if is_loss_grad_op(op): + loss_grad_var = block.vars[op.output_arg_names[0]] + block._insert_op( + idx + 1, + type='scale', + inputs={'X': loss_grad_var}, + outputs={'Out': loss_grad_var}, + attrs={ + 'scale': 1.0 / self.nranks, + OP_ROLE_KEY: OpRole.Backward + }) + + def _insert_allreduce_ops(self): + block = self.main_program.global_block() + ring_id = self.global_ring_id + grad = None + for idx, op in reversed(list(enumerate(block.ops))): + if is_backward_op(op) and \ + OP_ROLE_VAR_KEY in op.attr_names: + op_role_var = op.attr(OP_ROLE_VAR_KEY) + if len(op_role_var) == 0: + continue + assert len(op_role_var) % 2 == 0 + offset = 1 + for i in range(0, len(op_role_var), 2): + param_name = op_role_var[i] + param = block.var(param_name) + grad_name = op_role_var[i + 1] + grad = block.var(grad_name) + if param.is_distributed: + continue + + block._insert_op( + idx + offset, + type='c_sync_calc_stream', + inputs={'X': grad}, + outputs={'Out': grad}, + attrs={OP_ROLE_KEY: OpRole.Backward, }) + offset += 1 + block._insert_op( + idx + offset, + type='c_allreduce_sum', + inputs={'X': grad}, + outputs={'Out': grad}, + attrs={ + 'ring_id': ring_id, + OP_ROLE_KEY: OpRole.Backward + }) + + if grad is None: + return + + for idx, op in enumerate(block.ops): + if is_optimizer_op(op): + block._insert_op( + idx, + type='c_sync_comm_stream', + inputs={'X': grad}, + outputs={'Out': grad}, + attrs={'ring_id': ring_id, + OP_ROLE_KEY: OpRole.Backward}) + break diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 37bcac49574..8341e9b93e6 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -17,6 +17,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding) list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding_over_height) list(APPEND DIST_TEST_OPS test_parallel_dygraph_transformer) list(APPEND DIST_TEST_OPS test_fleet_pipeline_meta_optimizer) +list(APPEND DIST_TEST_OPS test_fleet_raw_program_meta_optimizer) list(APPEND DIST_TEST_OPS test_fleet_graph_execution_meta_optimizer) list(APPEND DIST_TEST_OPS test_gen_nccl_id_op) list(APPEND DIST_TEST_OPS test_parallel_dygraph_unused_variables) @@ -54,6 +55,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_base_2) list(APPEND MIXED_DIST_TEST_OPS test_fleet_base_3) list(APPEND MIXED_DIST_TEST_OPS test_fleet_recompute_meta_optimizer) list(APPEND MIXED_DIST_TEST_OPS test_fleet_pipeline_meta_optimizer) +list(APPEND MIXED_DIST_TEST_OPS test_fleet_raw_program_meta_optimizer) list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_meta_optimizer) list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_init) list(APPEND MIXED_DIST_TEST_OPS test_fleet_gradient_merge_meta_optimizer) @@ -100,6 +102,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32) LIST(REMOVE_ITEM TEST_OPS test_collective_sendrecv_api) LIST(REMOVE_ITEM TEST_OPS test_collective_wait) LIST(REMOVE_ITEM TEST_OPS test_memcpy_op) + LIST(REMOVE_ITEM TEST_OPS test_raw_program_optimizer) endif() if(WIN32) @@ -571,7 +574,7 @@ endif() py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf) # Coverage pipeline use cuda 10.1 now, profiler will random hang in cuda 10.1, # see https://github.com/PaddlePaddle/Paddle/issues/29082 for details. -# We guess there are some bugs in cuda 10.1 or 10.2, +# We guess there are some bugs in cuda 10.1 or 10.2, # since this unittest is stable in cuda 11 (py3 pipeline) now. if(NOT WITH_COVERAGE) py_test_modules(test_parallel_executor_profiler MODULES test_parallel_executor_profiler) @@ -596,8 +599,8 @@ py_test_modules(test_fuse_bn_act_pass MODULES test_fuse_bn_act_pass ENVS FLAGS_c py_test_modules(test_fuse_bn_add_act_pass MODULES test_fuse_bn_add_act_pass ENVS FLAGS_cudnn_deterministic=1 FLAGS_cudnn_batchnorm_spatial_persistent=1 FLAGS_conv_workspace_size_limit=1000) # NOTE: These unittests will appear NaN steadily in windows CI. After analysis, -# it is found that windows CI will run all the training unittests with the ON_INFER option turned on, -# which will not appear in other CIs. The calculation behavior of some ops in inference mode is +# it is found that windows CI will run all the training unittests with the ON_INFER option turned on, +# which will not appear in other CIs. The calculation behavior of some ops in inference mode is # inconsistent with that in non-inference mode. if(NOT ON_INFER) py_test_modules(test_parallel_executor_seresnext_base_cpu MODULES test_parallel_executor_seresnext_base_cpu) @@ -640,7 +643,7 @@ if (WITH_XPU) add_subdirectory(xpu) endif() -# dist xpu tests: +# dist xpu tests: if (WITH_XPU_BKCL) py_test(test_collective_reduce_api_xpu SRCS "test_collective_reduce_api.py") py_test(test_collective_allreduce_api_xpu SRCS "test_collective_allreduce_api.py") @@ -708,6 +711,7 @@ if (WITH_DISTRIBUTE) set_tests_properties(test_dist_fleet_ctr2 PROPERTIES TIMEOUT 200) set_tests_properties(test_dist_fleet_sparse_embedding_ctr PROPERTIES TIMEOUT 200) set_tests_properties(test_dist_fleet_infer PROPERTIES TIMEOUT 200) + set_tests_properties(test_dist_fleet_raw_program_optimizer PROPERTIES TIMEOUT 120) endif() if (WITH_DISTRIBUTE AND NOT APPLE) diff --git a/python/paddle/fluid/tests/unittests/dist_fleet_raw_program_optimizer.py b/python/paddle/fluid/tests/unittests/dist_fleet_raw_program_optimizer.py new file mode 100644 index 00000000000..575c07390a3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dist_fleet_raw_program_optimizer.py @@ -0,0 +1,109 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from test_dist_base import TestDistRunnerBase, runtime_main +import unittest +import paddle +import os +import paddle.distributed.fleet as fleet +import paddle.distributed.fleet.base.role_maker as role_maker +import numpy as np +from functools import reduce +import paddle.fluid as fluid + +paddle.enable_static() + +DTYPE = "float32" +paddle.dataset.mnist.fetch() + +# Fix seed for test +fluid.default_startup_program().random_seed = 1 +fluid.default_main_program().random_seed = 1 + + +def cnn_model(data): + conv_pool_1 = fluid.nets.simple_img_conv_pool( + input=data, + filter_size=5, + num_filters=20, + pool_size=2, + pool_stride=2, + act="relu", + param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( + value=0.01))) + conv_pool_2 = fluid.nets.simple_img_conv_pool( + input=conv_pool_1, + filter_size=5, + num_filters=50, + pool_size=2, + pool_stride=2, + act="relu", + param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( + value=0.01))) + + SIZE = 10 + input_shape = conv_pool_2.shape + param_shape = [reduce(lambda a, b: a * b, input_shape[1:], 1)] + [SIZE] + scale = (2.0 / (param_shape[0]**2 * SIZE))**0.5 + + predict = fluid.layers.fc( + input=conv_pool_2, + size=SIZE, + act="softmax", + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01))) + return predict + + +class TestFleetMetaOptimizerPrecision(TestDistRunnerBase): + def get_model(self, batch_size=2, single_device=False): + # Input data + images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE) + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + + # Train program + predict = cnn_model(images) + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(x=cost) + + # Evaluator + batch_size_tensor = fluid.layers.create_tensor(dtype='int64') + batch_acc = fluid.layers.accuracy( + input=predict, label=label, total=batch_size_tensor) + + test_program = fluid.default_main_program().clone(for_test=True) + + # Reader + train_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=batch_size) + test_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=batch_size) + + optimizer = paddle.fluid.optimizer.Adam(0.01) + if single_device: + optimizer.minimize(avg_cost) + else: + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + strategy = paddle.distributed.fleet.DistributedStrategy() + strategy.without_graph_optimization = True + optimizer = fleet.distributed_optimizer( + optimizer, strategy=strategy) + optimizer.minimize(avg_cost) + + return test_program, avg_cost, train_reader, test_reader, batch_acc, predict + + +if __name__ == "__main__": + runtime_main(TestFleetMetaOptimizerPrecision) diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index edc510e4e76..78b06bd5333 100755 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -186,6 +186,76 @@ class TestDistRunnerBase(object): fleet.save_inference_model(exe, infer_save_dir_fleet, feeded_var_names, [avg_cost]) + def run_use_fleet_api_20_trainer(self, args): + """ + 1. remove codes for DistributedStrategy and leave the DistributedStrategy part to get_model() + 2. to run with fleet 2.0 api, set flags _use_fleet_api and _use_fleet_api_20 to True + 3. for now, not support test for model save + """ + assert args.update_method == "nccl2" or "bkcl" + + self.lr = args.lr + print_to_err("use_fleet 2.0", "fleet.node_num:") + + test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \ + self.get_model(batch_size=args.batch_size) + + if fluid.core.is_compiled_with_cuda(): + device_id = int(os.getenv("FLAGS_selected_gpus", "0")) + place = fluid.CUDAPlace(device_id) + elif fluid.core.is_compiled_with_xpu(): + device_id = int(os.getenv("FLAGS_selected_xpus", "0")) + place = fluid.XPUPlace(device_id) + else: + raise ValueError( + "fleet dygraph api must in paddlepaddle-xpu or paddlepaddle-gpu." + ) + + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + eprint(type(self).__name__, "run worker startup program done.") + + feed_var_list = [ + var + for var in fluid.default_main_program().global_block().vars.values() + if var.is_data + ] + + eprint("feed_var_list:", feed_var_list) + + if feed_var_list[0].name == 'label': + feed_var_list = feed_var_list[::-1] + + feeder = fluid.DataFeeder(feed_var_list, place) + reader_generator = train_reader() + + def get_data(): + origin_batch = next(reader_generator) + if args.update_method != "local" and args.use_reader_alloc: + new_batch = [] + for offset, item in enumerate(origin_batch): + if offset % 2 == args.trainer_id: + new_batch.append(item) + return new_batch + else: + return origin_batch + + print_to_err(type(self).__name__, "begin to train on trainer") + out_losses = [] + for i in six.moves.xrange(RUN_STEP): + loss, = exe.run(fluid.default_main_program(), + fetch_list=[avg_cost.name], + feed=feeder.feed(get_data())) + out_losses.append(loss[0]) + print_to_err(type(self).__name__, "run step %d finished" % i) + print_to_err(type(self).__name__, "trainer run finished") + print_to_err(type(self).__name__, "dist losses: {}".format(out_losses)) + + if six.PY2: + print(pickle.dumps(out_losses)) + else: + sys.stdout.buffer.write(pickle.dumps(out_losses)) + def run_use_fleet_api_trainer(self, args): assert args.update_method == "nccl2" or "bkcl" @@ -630,6 +700,7 @@ def runtime_main(test_class): parser.add_argument('--use_hallreduce', action='store_true') parser.add_argument('--use_pipeline', action='store_true') parser.add_argument('--use_fleet_api', action='store_true') + parser.add_argument('--use_fleet_api_20', action='store_true') parser.add_argument('--use_local_sgd', action='store_true') parser.add_argument('--ut4grad_allreduce', action='store_true') parser.add_argument( @@ -671,6 +742,8 @@ def runtime_main(test_class): model.run_pserver(args) elif args.use_fleet_api: model.run_use_fleet_api_trainer(args) + elif args.use_fleet_api_20: + model.run_use_fleet_api_20_trainer(args) elif args.use_pipeline: model.run_pipeline_trainer(args) else: @@ -734,6 +807,7 @@ class TestDistBase(unittest.TestCase): self._nccl_comm_num = 1 self._enable_backward_deps = False self._use_fleet_api = False + self._use_fleet_api_20 = False self._use_local_sgd = False self._ut4grad_allreduce = False self._use_hallreduce = False @@ -1060,7 +1134,7 @@ class TestDistBase(unittest.TestCase): tr_cmd += " --fuse_all_reduce {}".format(self._fuse_all_reduce) if self._use_fleet_api: - tr_cmd += " --use_fleet_api" + tr_cmd += " --use_fleet_api_20" if self._use_fleet_api_20 else " --use_fleet_api" if self._use_local_sgd: tr_cmd += " --use_local_sgd" if self._ut4grad_allreduce: diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_raw_program_optimizer.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_raw_program_optimizer.py new file mode 100644 index 00000000000..e729bfe0537 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_raw_program_optimizer.py @@ -0,0 +1,45 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from test_dist_base import TestDistBase +import paddle +import os + +paddle.enable_static() +flag_name = os.path.splitext(__file__)[0] + + +class TestFleetMetaOptimizerPrecision(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._use_reduce = False + self._use_reader_alloc = False + self._nccl2_mode = True + self._nccl2_reduce_layer = True + self._use_fleet_api = True + self._use_fleet_api_20 = True + + def test_dist_train(self): + import paddle.fluid as fluid + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "dist_fleet_raw_program_optimizer.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fleet_raw_program_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_raw_program_meta_optimizer.py new file mode 100644 index 00000000000..604109b262d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fleet_raw_program_meta_optimizer.py @@ -0,0 +1,53 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import os + +paddle.enable_static() + + +class TestFleetMetaOptimizer(unittest.TestCase): + def setUp(self): + os.environ["PADDLE_TRAINER_ID"] = "1" + os.environ[ + "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002" + + def test_pipeline_optimizer(self): + import paddle.distributed.fleet as fleet + import paddle.distributed.fleet.base.role_maker as role_maker + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + input_x = paddle.fluid.layers.data( + name="x", shape=[32], dtype='float32') + input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64') + fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh') + + fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh') + prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax') + cost = paddle.fluid.layers.cross_entropy( + input=prediction, label=input_y) + avg_cost = paddle.fluid.layers.mean(x=cost) + + strategy = paddle.distributed.fleet.DistributedStrategy() + strategy.without_graph_optimization = True + + optimizer = paddle.fluid.optimizer.Adam(0.01) + optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) + optimizer.minimize(avg_cost) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_raw_program_optimizer.py b/python/paddle/fluid/tests/unittests/test_raw_program_optimizer.py new file mode 100644 index 00000000000..34930e3577b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_raw_program_optimizer.py @@ -0,0 +1,77 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest + +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +import paddle.distributed.fleet as fleet +import numpy as np +import os + + +class TestRawProgramOptimizer(unittest.TestCase): + def setUp(self): + os.environ["PADDLE_TRAINER_ID"] = "0" + os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001" + + def mlp(self, input_x, input_y, hid_dim=128, label_dim=2): + fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim, activation='tanh') + fc_2 = paddle.static.nn.fc(x=fc_1, size=hid_dim, activation='tanh') + prediction = paddle.static.nn.fc(x=[fc_2], + size=label_dim, + activation='softmax') + cost = paddle.nn.functional.cross_entropy( + input=prediction, label=input_y) + avg_cost = paddle.mean(x=cost) + return avg_cost + + def gen_data(self): + return { + "x": np.random.random(size=(128, 32)).astype('float32'), + "y": np.random.randint( + 2, size=(128, 1)).astype('int64') + } + + def test_single_gpu(self): + paddle.enable_static() + fleet.init(is_collective=True) + sharding_program = paddle.static.Program() + sharding_startup_program = paddle.static.Program() + strategy = fleet.DistributedStrategy() + strategy.without_graph_optimization = True + with fluid.program_guard(sharding_program, sharding_startup_program): + with fluid.unique_name.guard(): + input_x = paddle.static.data( + name="x", shape=[None, 32], dtype='float32') + input_y = paddle.static.data( + name="y", shape=[None, 1], dtype='int64') + cost = self.mlp(input_x=input_x, input_y=input_y) + output_name = cost.name + optimizer = fleet.distributed_optimizer(fluid.optimizer.Adam(), + strategy) + optimizer.minimize(cost) + + trainer_id = fleet.worker_index() + exe = paddle.static.Executor(paddle.CUDAPlace(trainer_id)) + rank = fleet.worker_index() + exe.run(sharding_startup_program) + exe.run(program=sharding_program, feed=self.gen_data()) + + +if __name__ == "__main__": + unittest.main() -- GitLab