diff --git a/paddle/fluid/operators/amp/check_finite_and_unscale_op.cu b/paddle/fluid/operators/amp/check_finite_and_unscale_op.cu
index cf9df34a2467f8461c4c284b4848c54b76edf452..6b60d989d2c9c96a54b09c8f1110960a29279b22 100644
--- a/paddle/fluid/operators/amp/check_finite_and_unscale_op.cu
+++ b/paddle/fluid/operators/amp/check_finite_and_unscale_op.cu
@@ -20,8 +20,9 @@ namespace paddle {
 namespace operators {
 
 template <typename T>
-__global__ void GpuInverse(const T* s, T* o) {
+__global__ void InverseAndMemset(const T* s, T* o, bool* found_inf) {
   *o = Inverse(*s);
+  *found_inf = false;
 }
 
 template <typename T>
@@ -30,10 +31,11 @@ __global__ void CheckFiniteAndUnscale(const T* in, const T* scale, int num,
   const int idx = threadIdx.x + blockIdx.x * blockDim.x;
 
   if (idx < num) {
-    if (!isfinite(in[idx])) {
+    T val = in[idx] * (*scale);
+    out[idx] = val;
+    if (!isfinite(val)) {
       *found_inf = true;
     }
-    out[idx] = *found_inf ? in[idx] : in[idx] * (*scale);
   }
 }
 
@@ -49,13 +51,13 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
 
     const T* scale_data = scale->data<T>();
     bool* found_inf_data = found_inf->mutable_data<bool>(dev_ctx.GetPlace());
-    cudaMemset(found_inf_data, false, found_inf->numel() * sizeof(bool));
 
     framework::Tensor inverse_scale =
         ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({1}, dev_ctx);
     T* inverse_scale_v = inverse_scale.template data<T>();
 
-    GpuInverse<<<1, 1, 0, dev_ctx.stream()>>>(scale_data, inverse_scale_v);
+    InverseAndMemset<<<1, 1, 0, dev_ctx.stream()>>>(
+        scale_data, inverse_scale_v, found_inf_data);
 
     for (size_t i = 0; i < xs.size(); ++i) {
       const auto* x = xs[i];
diff --git a/paddle/fluid/operators/amp/update_loss_scaling_op.cu b/paddle/fluid/operators/amp/update_loss_scaling_op.cu
index 2bc60423d247447adf18eb3ef050ca9b395a2e2f..4da45df7ecfdb900cdbcd71fb7de24de93bb3ec4 100644
--- a/paddle/fluid/operators/amp/update_loss_scaling_op.cu
+++ b/paddle/fluid/operators/amp/update_loss_scaling_op.cu
@@ -61,13 +61,14 @@ class LazyZeroInputs {
     bool has_inf{false};
     memory::Copy(platform::CPUPlace(), &has_inf, gpu_place, found_inf_data,
                  sizeof(bool), dev_ctx.stream());
+    dev_ctx.Wait();  // wait for the async copy
     if (has_inf) {
       VLOG(1) << "-- UpdateLossScaling: Infinite values are found in grads. --";
       for (size_t i = 0; i < xs.size(); ++i) {
         auto* out = outs[i];
         T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
         int num = out->numel();
-        cudaMemset(out_data, 0, num * sizeof(T));
+        cudaMemsetAsync(out_data, 0, num * sizeof(T), dev_ctx.stream());
       }
     }
   }
diff --git a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py
index d861aa7579f461c61032e77e38e59f9376df0210..24e0b196d4974ae9f8e3fe0612691aa53b48c2f3 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py
@@ -53,6 +53,15 @@ class AMPOptimizer(MetaOptimizerBase):
             config['incr_ratio'], config['decr_ratio'],
             config['use_dynamic_loss_scaling'])
 
+        # if worker_num > 1, all cards will communicate with each other;
+        # set is_distributed so amp can overlap communication and
+        # computation by splitting the check_finite_and_unscale op.
+        is_distributed = self.role_maker._worker_num() > 1
+        if self.user_defined_strategy.sharding:
+            # FIXME(wangxi): sharding fails when check_finite_and_unscale is split
+            is_distributed = False
+        self.wrapped_opt._set_distributed(is_distributed)
+
     def _can_apply(self):
         if not self.role_maker._is_collective:
             return False
diff --git a/python/paddle/fluid/contrib/mixed_precision/decorator.py b/python/paddle/fluid/contrib/mixed_precision/decorator.py
index 529c664e7083ccd86d65464302dbaac7bffaab3c..a4279cde42b5a9f74ca8ab0a52eaa7366f4a4156 100644
--- a/python/paddle/fluid/contrib/mixed_precision/decorator.py
+++ b/python/paddle/fluid/contrib/mixed_precision/decorator.py
@@ -61,6 +61,7 @@ class OptimizerWithMixedPrecision(object):
         self._param_grads = None
         self._train_program = None
 
+        self._is_distributed = False
         self._scaled_loss = None
         self._loss_scaling = None
         self._init_loss_scaling = init_loss_scaling
@@ -73,6 +74,12 @@ class OptimizerWithMixedPrecision(object):
             self._num_good_steps = None
             self._num_bad_steps = None
 
+    def _set_distributed(self, flag):
+        # if distributed, all cards will communicate with each other;
+        # overlap communication and computation by splitting the
+        # check_finite_and_unscale op.
+        self._is_distributed = flag
+
     def get_loss_scaling(self):
         """Return the real-time loss scaling factor.
         """
@@ -168,13 +175,28 @@ class OptimizerWithMixedPrecision(object):
         """
         grads = [g for _, g in params_grads]
 
-        with self._train_program._optimized_guard(grads):
-            grads, found_inf = check_finite_and_unscale(
-                grads, self._loss_scaling, name="find_infinite_scale")
+        if not self._is_distributed:
+            with self._train_program._optimized_guard(grads):
+                grads, found_inf = check_finite_and_unscale(
+                    grads, self._loss_scaling, name="find_infinite_scale")
+        else:
+            # if distributed, split check_finite_and_unscale to overlap
+            # unscaling with communication
+            found_infs = []
+            for p, g in params_grads:
+                with self._train_program._optimized_guard([p, g]):
+                    _, found_inf = check_finite_and_unscale(
+                        [g, ], self._loss_scaling, name="find_infinite_scale")
+                    found_infs.append(found_inf)
 
         if self._use_dynamic_loss_scaling:
-            with self._train_program._optimized_guard(grads):
-                grads = update_loss_scaling(
+            if self._is_distributed:
+                with self._train_program._optimized_guard([]):
+                    all_infs = layers.concat(found_infs)
+                    found_inf = layers.reduce_any(all_infs)
+
+            with self._train_program._optimized_guard([]):
+                update_loss_scaling(
                     grads,
                     found_inf,
                     self._loss_scaling,
@@ -186,13 +208,7 @@ class OptimizerWithMixedPrecision(object):
                     self._decr_ratio,
                     name="update_loss_scaling")
 
-        params_unscaled_grads = []
-        for pg, new_g in zip(params_grads, grads):
-            params_unscaled_grads.append((pg[0], new_g))
-        # apply_gradient append all ops in global block, thus we shouldn't
-        # apply gradient in the switch branch.
-        optimize_ops = self._optimizer.apply_gradients(params_unscaled_grads)
-
+        optimize_ops = self._optimizer.apply_gradients(params_grads)
         return optimize_ops
 
     def apply_optimize(self, loss, startup_program, params_grads):
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py
index eb4ac1356eaaff4359854176ad18edb0cef178e6..30f6607df9d8ad99e1c2bb87b03e157390f848bb 100644
--- a/python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py
@@ -19,6 +19,7 @@ import paddle.distributed.fleet as fleet
 from paddle.distributed.fleet.meta_optimizers import AMPOptimizer
 import os
 from fleet_meta_optimizer_base import TestFleetMetaOptimizer
+import paddle.distributed.fleet.base.role_maker as role_maker
 
 paddle.enable_static()
 
@@ -32,7 +33,10 @@ class TestFleetAMPOptimizer(TestFleetMetaOptimizer):
         opt = fluid.optimizer.MomentumOptimizer(
             learning_rate=0.001, momentum=0.9)
         opt = AMPOptimizer(opt)
-        opt.user_defined_strategy = strategy
+
+        self.set_strategy(strategy, 'amp')
+        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
+        opt._set_basic_info(avg_cost, role, opt, strategy)
         params_grads = opt.backward(avg_cost, startup_prog)
 
         ops = [op.type for op in avg_cost.block.ops]
@@ -47,7 +51,10 @@ class TestFleetAMPOptimizer(TestFleetMetaOptimizer):
         opt = fluid.optimizer.MomentumOptimizer(
             learning_rate=0.001, momentum=0.9)
         opt = AMPOptimizer(opt)
-        opt.user_defined_strategy = strategy
+
+        self.set_strategy(strategy, 'amp')
+        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
+        opt._set_basic_info(avg_cost, role, opt, strategy)
         params_grads = opt.backward(avg_cost, startup_prog)
         with fluid.program_guard(train_prog, startup_prog):
             opt.apply_gradients(params_grads)
@@ -64,7 +71,10 @@ class TestFleetAMPOptimizer(TestFleetMetaOptimizer):
         opt = fluid.optimizer.MomentumOptimizer(
             learning_rate=0.001, momentum=0.9)
         opt = AMPOptimizer(opt)
-        opt.user_defined_strategy = strategy
+
+        self.set_strategy(strategy, 'amp')
+        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
+        opt._set_basic_info(avg_cost, role, opt, strategy)
         params_grads = opt.backward(avg_cost, startup_prog)
         opt.apply_optimize(avg_cost, startup_prog, params_grads)
 
@@ -83,6 +93,22 @@ class TestFleetAMPOptimizer(TestFleetMetaOptimizer):
         self.assertIn('cast', ops)
         self.assertIn('check_finite_and_unscale', ops)
 
+    def test_amp_distributed_optimizer(self):
+        """ test amp when distributed """
+        train_prog, startup_prog = fluid.Program(), fluid.Program()
+        avg_cost, strategy = self.net(train_prog, startup_prog)
+        self.set_strategy(strategy, 'amp')
+        self.optimizer(avg_cost, strategy, train_prog, startup_prog)
+
+        ops = [op.type for op in avg_cost.block.ops]
+        self.assertIn('cast', ops)
+        self.assertIn('check_finite_and_unscale', ops)
+        check_count = 0
+        for name in ops:
+            if name == 'check_finite_and_unscale':
+                check_count += 1
+        self.assertEqual(check_count, len(train_prog.all_parameters()))
+
     def test_amp_recompute_optimizer(self):
         """ test amp + recompute """
         train_prog, startup_prog = fluid.Program(), fluid.Program()
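Note: the sketch below is not part of the patch; it is a minimal, self-contained NumPy illustration (assuming only numpy) of the semantics the diff introduces. The local check_finite_and_unscale stands in for the op whose CUDA kernels are patched above: found_inf is cleared up front, every gradient is multiplied by the inverse loss scaling unconditionally, and the flag is raised if any unscaled value is non-finite. The distributed part mirrors decorator.py: one check per gradient, with the per-gradient flags reduced at the end the way layers.concat + layers.reduce_any restore a single found_inf for update_loss_scaling.

import numpy as np

def check_finite_and_unscale(grads, loss_scaling):
    # Mirrors the patched kernel pair: InverseAndMemset clears found_inf and
    # computes 1/scale; CheckFiniteAndUnscale always writes the unscaled value
    # and only sets the flag when the result is non-finite.
    inverse_scale = 1.0 / loss_scaling
    found_inf = False
    outs = []
    for g in grads:
        out = g * inverse_scale
        if not np.all(np.isfinite(out)):
            found_inf = True
        outs.append(out)
    return outs, found_inf

# Non-distributed path: one fused check over the whole gradient list.
grads = [np.array([2.0, 4.0]), np.array([np.inf, 8.0])]
_, found_inf = check_finite_and_unscale(grads, loss_scaling=2.0)

# Distributed path: one check per gradient, then reduce the per-gradient
# flags, as concat + reduce_any do in the patch.
found_infs = [check_finite_and_unscale([g], loss_scaling=2.0)[1] for g in grads]
assert found_inf and any(found_infs)

Splitting the check this way lets the unscaling of gradients whose communication has already finished run while other gradients are still in flight, which is the overlap the is_distributed flag enables.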