From c3ae0d4009658bbbb74c0f36b4a19c8e099ec4ba Mon Sep 17 00:00:00 2001 From: Baibaifan <39549453+Baibaifan@users.noreply.github.com> Date: Thu, 13 May 2021 11:11:51 +0800 Subject: [PATCH] solved some npu bugs (#32793) --- paddle/fluid/framework/operator.cc | 8 +- paddle/fluid/framework/operator.h | 8 ++ paddle/fluid/framework/section_worker.cc | 16 ++- .../operators/collective/recv_v2_op_npu.cc | 15 +-- .../fluid/operators/lookup_table_v2_op_npu.cc | 5 + python/paddle/distributed/collective.py | 99 +++++++++++++++++-- .../fleet/meta_optimizers/sharding/utils.py | 9 +- .../meta_optimizers/sharding_optimizer.py | 19 ++-- python/paddle/fluid/dataset.py | 4 +- python/paddle/fluid/layers/nn.py | 4 +- python/paddle/fluid/optimizer.py | 8 +- .../npu/test_lookup_table_v2_op_npu.py | 2 +- 12 files changed, 166 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 955c917b2c1..c27f48f73c8 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1228,6 +1228,8 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx, // will be executed and a warning will be given at the same time. if (SupportGPU()) { expected_kernel_key.place_ = dev_ctx->GetPlace(); + } else if (SupportNPU()) { + expected_kernel_key.place_ = dev_ctx->GetPlace(); } else { expected_kernel_key.place_ = platform::CPUPlace(); LOG_FIRST_N(WARNING, 1) @@ -1299,7 +1301,11 @@ void OperatorWithKernel::TransferInplaceVarsBack( auto* transformed_tensor = GetLoDTensorOrSelectedRowsValueFromVar(*var); auto original_dims = original_tensor->dims(); original_tensor->ShareDataWith(*transformed_tensor); - original_tensor->Resize(original_dims); + // In order to solve the problem that the output latitude of NPU reshape + // operator is not changed when inplace. + if (type_ != "reshape2" && type_ != "reshape2_grad") { + original_tensor->Resize(original_dims); + } } } diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 3fc61581eca..fc01513a866 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -154,6 +154,7 @@ class OperatorBase { std::string DebugString() const { return DebugStringEx(nullptr); } virtual bool SupportGPU() const { return false; } + virtual bool SupportNPU() const { return false; } const std::string& Type() const { return type_; } @@ -490,6 +491,13 @@ class OperatorWithKernel : public OperatorBase { return platform::is_gpu_place(kern_pair.first.place_); }); } + bool SupportNPU() const override { + auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_); + return std::any_of(op_kernels.begin(), op_kernels.end(), + [](OpKernelMap::const_reference kern_pair) { + return platform::is_npu_place(kern_pair.first.place_); + }); + } bool SupportsMKLDNN(proto::VarType::Type data_type) const; bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx, diff --git a/paddle/fluid/framework/section_worker.cc b/paddle/fluid/framework/section_worker.cc index 00ff50abadd..993b9ac52c5 100644 --- a/paddle/fluid/framework/section_worker.cc +++ b/paddle/fluid/framework/section_worker.cc @@ -110,8 +110,22 @@ void SectionWorker::TrainFiles() { BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size)); } } +#elif defined(PADDLE_WITH_ASCEND_CL) + if (IsFastEagerDeletionModeEnabled()) { + VLOG(4) << "Use unsafe fast gc for NPU."; + gc.reset(new NPUUnsafeFastGarbageCollector( + BOOST_GET_CONST(platform::NPUPlace, place_), max_memory_size)); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Please set FLAGS_fast_eager_deletion_mode=true to use " + "GarbageCollector on NPU.")); + // TODO(zhiqiu): fix bugs and enable NPUDefaultStreamGarbageCollector. + VLOG(4) << "Use default stream gc for NPU."; + gc.reset(new NPUDefaultStreamGarbageCollector( + BOOST_GET_CONST(platform::NPUPlace, place_), max_memory_size)); + } #endif - } + } // max_memory_size >= 0 if (schedule_mode_ == 0) { // F-then-B scheduler which runs Forward phase for all microbatches, diff --git a/paddle/fluid/operators/collective/recv_v2_op_npu.cc b/paddle/fluid/operators/collective/recv_v2_op_npu.cc index 69f1f4681a3..52a23c50c0e 100644 --- a/paddle/fluid/operators/collective/recv_v2_op_npu.cc +++ b/paddle/fluid/operators/collective/recv_v2_op_npu.cc @@ -27,10 +27,11 @@ class CRecvOpASCENDKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { #if defined(PADDLE_WITH_ASCEND_CL) - auto x = ctx.Output("Out"); - void* ptr = reinterpret_cast(const_cast(x->data())); - int numel = x->numel(); - HcclDataType dtype = platform::ToHCCLDataType(x->type()); + auto out = ctx.Output("Out"); + out->mutable_data(out->dims(), ctx.GetPlace()); + void* ptr = reinterpret_cast(const_cast(out->data())); + int numel = out->numel(); + HcclDataType dtype = platform::ToHCCLDataType(out->type()); int ring_id = ctx.Attr("ring_id"); auto place = ctx.GetPlace(); @@ -54,8 +55,10 @@ class CRecvOpASCENDKernel : public framework::OpKernel { int root = peer; VLOG(3) << "begin hccl recv, parameter is: " - << "root " << root << ", comm: " << comm->comm() - << ", stream: " << stream; + << "ring_id:" << ring_id << ", nranks:" << nranks + << ", peer:" << peer << ", numel:" << numel << ", ptr:" << ptr + << ", dtype:" << dtype << ", root:" << root + << ", comm: " << comm->comm() << ", stream: " << stream; PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclBroadcast( ptr, numel, dtype, (uint32_t)root, comm->comm(), stream)); diff --git a/paddle/fluid/operators/lookup_table_v2_op_npu.cc b/paddle/fluid/operators/lookup_table_v2_op_npu.cc index 9574b325ef7..87618b954d2 100644 --- a/paddle/fluid/operators/lookup_table_v2_op_npu.cc +++ b/paddle/fluid/operators/lookup_table_v2_op_npu.cc @@ -29,6 +29,11 @@ class LookupTableV2NPUKernel : public framework::OpKernel { auto *output_t = ctx.Output("Out"); // float tensor auto *table_t = ctx.Input("W"); + // It seems cann 20.1 accepts int64, but cann 20.2+ not. + PADDLE_ENFORCE_EQ(ids_t->type(), framework::proto::VarType::INT32, + platform::errors::Unimplemented( + "The index of LookupTableV2 should be int32.")); + auto *table_var = ctx.InputVar("W"); PADDLE_ENFORCE_EQ( table_var->IsType(), true, diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index ba4c3b09f9f..e28ef1e94b1 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -25,6 +25,7 @@ from ..fluid.data_feeder import check_type from ..fluid.data_feeder import check_dtype from ..fluid.layers.tensor import fill_constant from ..fluid.layers import utils +from ..fluid.dygraph import layers from ..fluid.dygraph.parallel import prepare_context import paddle from .fleet import fleet @@ -875,6 +876,84 @@ def _mp_allreduce(tensor, raise NotImplementedError("No support _mp_allreduce in dygraph mode.") +class _Linear(layers.Layer): + """ + Linear + """ + + def __init__(self, + in_features, + out_features, + weight_attr=None, + bias_attr=None, + name=None): + super(_Linear, self).__init__() + self._dtype = self._helper.get_default_dtype() + self._weight_attr = weight_attr + self._bias_attr = bias_attr + self.weight = self.create_parameter( + shape=[in_features, out_features], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False) + self.bias = self.create_parameter( + shape=[out_features], + attr=self._bias_attr, + dtype=self._dtype, + is_bias=True) + self.name = name + + def forward(self, input): + out = _linear( + x=input, weight=self.weight, bias=self.bias, name=self.name) + return out + + def extra_repr(self): + name_str = ', name={}'.format(self.name) if self.name else '' + return 'in_features={}, out_features={}, dtype={}{}'.format( + self.weight.shape[0], self.weight.shape[1], self._dtype, name_str) + + +def _linear(x, weight, bias=None, name=None): + """ + Fuction Linear + """ + if in_dygraph_mode(): + pre_bias = _varbase_creator(dtype=x.dtype) + core.ops.matmul(x, weight, pre_bias, 'transpose_X', False, + 'transpose_Y', False, "alpha", 1) + return dygraph_utils._append_bias_in_dygraph( + pre_bias, bias, axis=len(x.shape) - 1) + else: + helper = LayerHelper('linear', **locals()) + dtype = x.dtype + + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], + 'linear') + check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear') + + inputs = {'X': [x], 'Y': [weight]} + attrs = { + 'transpose_X': False, + 'transpose_Y': False, + 'alpha': 1, + } + tmp = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type='matmul_v2', inputs=inputs, outputs={'Out': tmp}, attrs=attrs) + if bias is not None: + res = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type='elementwise_add', + inputs={'X': [tmp], + 'Y': [bias]}, + outputs={'Out': [res]}, + attrs={'axis': len(x.shape) - 1}) + else: + res = tmp + return res + + def _parallel_linear(x, num_rows, num_cols, @@ -900,12 +979,20 @@ def _parallel_linear(x, else: x = _c_identity(x, group=group) - linear = paddle.nn.Linear( - num_rows, - num_cols, - weight_attr=param_attr, - bias_attr=bias_attr, - name=name) + if core.is_compiled_with_npu(): + linear = _Linear( + num_rows, + num_cols, + weight_attr=param_attr, + bias_attr=bias_attr, + name=name) + else: + linear = paddle.nn.Linear( + num_rows, + num_cols, + weight_attr=param_attr, + bias_attr=bias_attr, + name=name) linear_out = linear(x) startup_block = paddle.static.default_startup_program().global_block() diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py index f4ceb2d287a..ca3606c16e5 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py @@ -402,13 +402,18 @@ def get_grad_device(grad_name, shard): return shard.global_param2device[base_name] -def get_first_check_finite_and_unscale_op_idx(block): +def get_first_check_finite_and_unscale_op_idx(block, raise_error=True): for idx, op in enumerate(block.ops): if op.type == "check_finite_and_unscale": return idx - raise ValueError("check_finite_and_unscale does not exist in block") + if raise_error: + raise ValueError( + "amp is turned on but check_finite_and_unscale op does not exist in main block" + ) + + return -1 def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root): diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 82e54a89e10..aafb15e0a01 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -298,7 +298,7 @@ class ShardingOptimizer(MetaOptimizerBase): print("persistable FP32 grad: ") print(accumulated_grad_names) first_optimize_op_index = get_first_check_finite_and_unscale_op_idx( - main_block) + main_block, raise_error=self.user_defined_strategy.amp) insert_reduce_ops( main_block, first_optimize_op_index, @@ -309,14 +309,15 @@ class ShardingOptimizer(MetaOptimizerBase): use_calc_stream=True) if self.hybrid_dp and self.hybrid_dp_mode == "pp_hybrid_dp": first_optimize_op_index = get_first_check_finite_and_unscale_op_idx( - main_block) - insert_allreduce_ops( - main_block, - first_optimize_op_index, - self.dp_ring_id, - accumulated_grad_names, - core.op_proto_and_checker_maker.OpRole.Optimize, - use_calc_stream=True) + main_block, raise_error=self.user_defined_strategy.amp) + if first_optimize_op_index >= 0: + insert_allreduce_ops( + main_block, + first_optimize_op_index, + self.dp_ring_id, + accumulated_grad_names, + core.op_proto_and_checker_maker.OpRole.Optimize, + use_calc_stream=True) # if not use sharding, adapt amp/clip, for remain parallelism. # cast --> amp --> clip --> opt diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py index b4cd3326dde..2b9d5128560 100644 --- a/python/paddle/fluid/dataset.py +++ b/python/paddle/fluid/dataset.py @@ -252,9 +252,11 @@ class DatasetBase(object): slot_var.type = "float" elif var.dtype == core.VarDesc.VarType.INT64: slot_var.type = "uint64" + elif var.dtype == core.VarDesc.VarType.INT32: + slot_var.type = "uint32" else: raise ValueError( - "Currently, fluid.dataset only supports dtype=float32 and dtype=int64" + "Currently, fluid.dataset only supports dtype=float32, dtype=int32 and dtype=int64" ) def set_hdfs_config(self, fs_name, fs_ugi): diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index aa021c463bf..f87485c6a8f 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -14772,7 +14772,7 @@ def shard_index(input, index_num, nshards, shard_id, ignore_value=-1): the size of the last shard will be less than the calculated `shard_size` Args: - input (Tensor): Input indices with data type int64. It's last dimension must be 1. + input (Tensor): Input indices with data type int64 or int32. It's last dimension must be 1. index_num (int): An integer defining the range of the index. nshards (int): The number of shards. shard_id (int): The index of the current shard. @@ -14793,7 +14793,7 @@ def shard_index(input, index_num, nshards, shard_id, ignore_value=-1): print(shard_label) # [[-1], [1]] """ - check_variable_and_dtype(input, 'input', ['int64'], 'shard_index') + check_variable_and_dtype(input, 'input', ['int64', 'int32'], 'shard_index') op_type = 'shard_index' helper = LayerHelper(op_type, **locals()) if shard_id < 0 or shard_id >= nshards: diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 41b2843ea33..83c4398e41a 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -4200,6 +4200,8 @@ class PipelineOptimizer(object): op.type == 'elementwise_div'): device = "gpu:all" op._set_attr(self._op_device_key, device) + elif op.type == "alloc_float_status": + op._set_attr(self._op_device_key, "gpu:all") else: other_known_ops = [ 'update_loss_scaling', @@ -4207,6 +4209,7 @@ class PipelineOptimizer(object): 'concat', 'sum', 'check_finite_and_unscale', + 'alloc_float_status', ] assert op.type in other_known_ops, "For other ops without " \ "op_device set, they must be one of {}, but it " \ @@ -4272,8 +4275,9 @@ class PipelineOptimizer(object): "{} has not been set.".format(op.type)) if device == "gpu:all": continue dev_type = device.split(':')[0] - assert dev_type == "gpu", ("Now only gpu devices are supported " - "for pipeline parallelism.") + assert dev_type == "gpu" or dev_type == 'npu', ( + "Now only gpu and npu devices are supported " + "for pipeline parallelism.") if not device in device_list: device_list.append(device) return device_list diff --git a/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py index 400ddd9d4aa..2463ddb7137 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py @@ -41,7 +41,7 @@ class TestLookupTableV2(OpTest): vocab = 10 dim = 20 w = np.ones([vocab, dim]).astype(self.dtype) - x = np.random.randint(0, vocab, size=(bsz, seqlen)).astype(np.int64) + x = np.random.randint(0, vocab, size=(bsz, seqlen)).astype(np.int32) out = np.ones([bsz, seqlen, dim]).astype(self.dtype) self.inputs = { -- GitLab