From 67bd5d9c3ef5e063ebed89f76da207cc92123a69 Mon Sep 17 00:00:00 2001
From: qipengh
Date: Mon, 13 Jun 2022 23:52:41 +0800
Subject: [PATCH] [MLU]add lookup_table_v2 op and fix amp feature of bert with
 mlu device (#43366)

---
 .../amp/check_finite_and_unscale_op_mlu.cc    | 36 +++++++++--
 paddle/fluid/operators/layer_norm_op_mlu.cc   |  7 ++-
 .../fluid/operators/lookup_table_v2_op_mlu.cc | 61 +++++++------------
 paddle/fluid/operators/mlu/mlu_baseop.cc      | 12 ++++
 paddle/fluid/operators/mlu/mlu_baseop.h       |  6 ++
 .../fluid/operators/optimizers/adam_op_mlu.cc |  2 +-
 paddle/fluid/operators/softmax_op.cc          | 20 +++---
 .../mlu/test_lookup_table_v2_op_mlu.py        |  2 -
 8 files changed, 89 insertions(+), 57 deletions(-)

diff --git a/paddle/fluid/operators/amp/check_finite_and_unscale_op_mlu.cc b/paddle/fluid/operators/amp/check_finite_and_unscale_op_mlu.cc
index 237cfcc6f1..48ca1e22df 100644
--- a/paddle/fluid/operators/amp/check_finite_and_unscale_op_mlu.cc
+++ b/paddle/fluid/operators/amp/check_finite_and_unscale_op_mlu.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/amp/check_finite_and_unscale_op.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/mlu/mlu_baseop.h"
 
 namespace paddle {
@@ -22,6 +23,8 @@ using Tensor = framework::Tensor;
 
 template <typename T>
 class CheckFiniteAndUnscaleMLUKernel : public framework::OpKernel<T> {
+  using MPDType = typename details::MPTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
     auto& dev_ctx = ctx.template device_context<platform::MLUDeviceContext>();
@@ -51,6 +54,7 @@ class CheckFiniteAndUnscaleMLUKernel : public framework::OpKernel<T> {
     }
 
       MLUCnnlTensorDesc x_desc(*x);
+      MLUCnnlTensorDesc out_desc(*out);
 
       MLUCnnl::IsNanInf(ctx, x_desc.get(), GetBasePtr(x),
                         GetBasePtr(&is_finite));
@@ -70,10 +74,34 @@ class CheckFiniteAndUnscaleMLUKernel : public framework::OpKernel<T> {
       // out = in/scale, if found_inf = false
       // But when found_inf is true, the data of Out should not be used.
       // So, on MLU, we always compute out with in/scale.
-      MLUCnnlTensorDesc out_desc(*out);
-      MLUCnnl::Div(ctx, CNNL_COMPUTATION_HIGH_PRECISION, x_desc.get(),
-                   GetBasePtr(x), scale_desc.get(), GetBasePtr(scale),
-                   out_desc.get(), GetBasePtr(out));
+      Tensor float_x;
+      Tensor float_out;
+      if (std::is_same<T, paddle::platform::float16>::value) {
+        float_x.Resize(x->dims());
+        float_out.Resize(out->dims());
+        float_x.mutable_data<MPDType>(ctx.GetPlace());
+        float_out.mutable_data<MPDType>(ctx.GetPlace());
+
+        MLUCnnlTensorDesc float_x_desc(float_x);
+        MLUCnnlTensorDesc float_out_desc(float_out);
+        auto cast_fp16_type =
+            GetCastDataType(DataType::FLOAT16, DataType::FLOAT32);
+        MLUCnnl::Cast(ctx, cast_fp16_type, x_desc.get(), GetBasePtr(x),
+                      float_x_desc.get(), GetBasePtr(&float_x));
+
+        MLUCnnl::Div(ctx, CNNL_COMPUTATION_HIGH_PRECISION, float_x_desc.get(),
+                     GetBasePtr(&float_x), scale_desc.get(), GetBasePtr(scale),
+                     float_out_desc.get(), GetBasePtr(&float_out));
+
+        auto cast_fp32_type =
+            GetCastDataType(DataType::FLOAT32, DataType::FLOAT16);
+        MLUCnnl::Cast(ctx, cast_fp32_type, float_out_desc.get(),
+                      GetBasePtr(&float_out), out_desc.get(), GetBasePtr(out));
+      } else {
+        MLUCnnl::Div(ctx, CNNL_COMPUTATION_HIGH_PRECISION, x_desc.get(),
+                     GetBasePtr(x), scale_desc.get(), GetBasePtr(scale),
+                     out_desc.get(), GetBasePtr(out));
+      }
     }
   }
 };
diff --git a/paddle/fluid/operators/layer_norm_op_mlu.cc b/paddle/fluid/operators/layer_norm_op_mlu.cc
index a368af86a3..919358febd 100644
--- a/paddle/fluid/operators/layer_norm_op_mlu.cc
+++ b/paddle/fluid/operators/layer_norm_op_mlu.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/mlu/mlu_baseop.h"
 
 namespace paddle {
@@ -122,6 +123,8 @@ class LayerNormMLUKernel : public framework::OpKernel<T> {
 
 template <typename T>
 class LayerNormGradMLUKernel : public framework::OpKernel<T> {
+  using MPDType = typename details::MPTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
@@ -207,14 +210,14 @@ class LayerNormGradMLUKernel : public framework::OpKernel<T> {
 
     if (dscale && (tmp_dscale.dtype() == DataType::FLOAT16 &&
                    dscale->dtype() == DataType::FLOAT32)) {
-      dscale->mutable_data<T>(place);
+      dscale->mutable_data<MPDType>(place);
       MLUCnnl::Cast(ctx, cast_fp16_to_fp32, float16_desc.get(),
                     GetBasePtr(&tmp_dscale), float32_desc.get(),
                     GetBasePtr(dscale));
     }
     if (dbias && (tmp_dbias.dtype() == DataType::FLOAT16 &&
                   dbias->dtype() == DataType::FLOAT32)) {
-      dbias->mutable_data<T>(place);
+      dbias->mutable_data<MPDType>(place);
       MLUCnnl::Cast(ctx, cast_fp16_to_fp32, float16_desc.get(),
                     GetBasePtr(&tmp_dbias), float32_desc.get(),
                     GetBasePtr(dbias));
diff --git a/paddle/fluid/operators/lookup_table_v2_op_mlu.cc b/paddle/fluid/operators/lookup_table_v2_op_mlu.cc
index c8ab269c02..b69a52c761 100644
--- a/paddle/fluid/operators/lookup_table_v2_op_mlu.cc
+++ b/paddle/fluid/operators/lookup_table_v2_op_mlu.cc
@@ -18,7 +18,6 @@ namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
-constexpr int64_t kNoPadding = -1;
 
 template <typename T>
 class LookupTableV2MLUKernel : public framework::OpKernel<T> {
@@ -27,6 +26,7 @@ class LookupTableV2MLUKernel : public framework::OpKernel<T> {
     auto *ids_t = ctx.Input<framework::LoDTensor>("Ids");      // int tensor
     auto *output_t = ctx.Output<framework::LoDTensor>("Out");  // float tensor
     auto *table_t = ctx.Input<framework::LoDTensor>("W");
+    int padding_idx = static_cast<int>(ctx.Attr<int64_t>("padding_idx"));
 
     auto *table_var = ctx.InputVar("W");
     PADDLE_ENFORCE_EQ(
@@ -38,43 +38,10 @@ class LookupTableV2MLUKernel : public framework::OpKernel<T> {
     MLUCnnlTensorDesc table_desc(*table_t);
     MLUCnnlTensorDesc output_desc(*output_t);
 
-    int64_t padding_idx = ctx.Attr<int64_t>("padding_idx");
-    if (padding_idx == kNoPadding) {
-      MLUCnnl::GatherFunctor(ctx, /*axis=*/0, /*batch_dims=*/0,
-                             table_desc.get(), GetBasePtr(table_t),
-                             ids_desc.get(), GetBasePtr(ids_t),
-                             output_desc.get(), GetBasePtr(output_t));
-    } else {
-      Tensor tmp_table_t(table_t->type());
-      tmp_table_t.mutable_data<T>(table_t->dims(), ctx.GetPlace());
-
-      Tensor index;
-      index.mutable_data<int32_t>({1, 1}, ctx.GetPlace());
-      auto idx_value = static_cast<int32_t>(padding_idx);
-      MLUCnnlTensorDesc index_desc(index);
-      MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &idx_value, index_desc.get(),
-                    GetBasePtr(&index));
-
-      auto update_dim = phi::make_ddim({1, table_t->dims()[1]});
-      Tensor update;
-      update.mutable_data<T>(update_dim, ctx.GetPlace());
-
-      auto update_value = static_cast<T>(0);
-      MLUCnnlTensorDesc update_desc(update);
-      MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &update_value,
-                    update_desc.get(), GetBasePtr(&update));
-
-      MLUCnnlTensorDesc tmp_table_desc(tmp_table_t);
-      MLUCnnl::ScatterNd(
-          ctx, CNNL_SCATTERND_UPDATE, index_desc.get(), GetBasePtr(&index),
-          update_desc.get(), GetBasePtr(&update), table_desc.get(),
-          GetBasePtr(table_t), tmp_table_desc.get(), GetBasePtr(&tmp_table_t));
-
-      MLUCnnl::GatherFunctor(ctx, /*axis=*/0, /*batch_dims=*/0,
-                             tmp_table_desc.get(), GetBasePtr(&tmp_table_t),
-                             ids_desc.get(), GetBasePtr(ids_t),
-                             output_desc.get(), GetBasePtr(output_t));
-    }
+    MLUCnnl::EmbeddingForward(ctx, padding_idx, table_desc.get(),
+                              GetBasePtr(table_t), ids_desc.get(),
+                              static_cast<const int *>(GetBasePtr(ids_t)),
+                              output_desc.get(), GetBasePtr(output_t));
   }
 };
 
@@ -82,6 +49,16 @@ template <typename T>
 class LookupTableV2GradMLUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *table_var = ctx.InputVar("W");
+    PADDLE_ENFORCE_EQ(table_var->IsType<framework::LoDTensor>(), true,
+                      platform::errors::PermissionDenied(
+                          "Unsupported Variable Type, the input W of "
+                          "LookupTableV2GradMLUKernel should be LoDTensor."));
+    bool is_sparse = ctx.Attr<bool>("is_sparse");
+    PADDLE_ENFORCE_EQ(
+        is_sparse, false,
+        platform::errors::InvalidArgument(
+            "LookupTableV2GradMLUKernel does NOT support is_sparse = True."));
     auto *ids_t = ctx.Input<framework::LoDTensor>("Ids");
     auto *output_grad_t =
         ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
@@ -91,6 +68,13 @@ class LookupTableV2GradMLUKernel : public framework::OpKernel<T> {
     int padding_idx = static_cast<int>(ctx.Attr<int64_t>("padding_idx"));
 
+    int64_t ids_numel = ids_t->numel();
+    PADDLE_ENFORCE_EQ(
+        ids_numel <= std::numeric_limits<int32_t>::max(), true,
+        platform::errors::OutOfRange(
+            "Number of ids is greater than int32_t::max, please check "
+            "the number of ids in LookupTableV2GradMLUKernel."));
+
     Tensor ids_int32(ids_t->dtype());
     if (ids_t->dtype() != DataType::INT32) {
       ids_int32.mutable_data<int>(ids_t->dims(), ctx.GetPlace());
@@ -125,5 +109,4 @@ REGISTER_OP_MLU_KERNEL(lookup_table_v2, ops::LookupTableV2MLUKernel<float>,
 
 REGISTER_OP_MLU_KERNEL(lookup_table_v2_grad,
                        ops::LookupTableV2GradMLUKernel<float>,
-                       ops::LookupTableV2GradMLUKernel<int>,
                        ops::LookupTableV2GradMLUKernel<plat::float16>);
diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc
index 4183181ac7..daae452b23 100644
--- a/paddle/fluid/operators/mlu/mlu_baseop.cc
+++ b/paddle/fluid/operators/mlu/mlu_baseop.cc
@@ -2802,6 +2802,18 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
       cnnlReciprocal(handle, input_desc, input, output_desc, output));
 }
 
+/* static */ void MLUCnnl::EmbeddingForward(
+    const ExecutionContext& ctx, const int padding_idx,
+    const cnnlTensorDescriptor_t weight_desc, const void* weight,
+    const cnnlTensorDescriptor_t indices_desc, const int* indices,
+    const cnnlTensorDescriptor_t output_desc, void* output) {
+  cnnlHandle_t handle = GetHandleFromCTX(ctx);
+
+  PADDLE_ENFORCE_MLU_SUCCESS(cnnlEmbeddingForward_v2(
+      handle, weight_desc, weight, indices_desc, indices, padding_idx,
+      nullptr /*max_norm*/, nullptr /*norm_type*/, output_desc, output));
+}
+
 /* static */ void MLUCnnl::EmbeddingBackward(
     const ExecutionContext& ctx, int padding_idx, bool scale_grad_by_freq,
     const cnnlTensorDescriptor_t indices_desc, const void* indices,
diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h
index 8d280618dc..288d74a135 100644
--- a/paddle/fluid/operators/mlu/mlu_baseop.h
+++ b/paddle/fluid/operators/mlu/mlu_baseop.h
@@ -1268,6 +1268,12 @@ class MLUCnnl {
                      const cnnlTensorDescriptor_t output_desc,
                      void* output);
 
+  static void EmbeddingForward(
+      const ExecutionContext& ctx, const int padding_idx,
+      const cnnlTensorDescriptor_t weight_desc, const void* weight,
+      const cnnlTensorDescriptor_t indices_desc, const int* indices,
+      const cnnlTensorDescriptor_t output_desc, void* output);
+
   static void EmbeddingBackward(
       const ExecutionContext& ctx, int padding_idx, bool scale_grad_by_freq,
       const cnnlTensorDescriptor_t indices_desc, const void* indices,
diff --git a/paddle/fluid/operators/optimizers/adam_op_mlu.cc b/paddle/fluid/operators/optimizers/adam_op_mlu.cc
index 9d33502123..36d0fb491a 100644
--- a/paddle/fluid/operators/optimizers/adam_op_mlu.cc
+++ b/paddle/fluid/operators/optimizers/adam_op_mlu.cc
@@ -237,8 +237,8 @@ class AdamWMLUKernel : public AdamMLUKernel<T> {
                                 ctx.device_context(), &skip_update_vec);
       skip_update = skip_update_vec[0];
     }
-    VLOG(3) << "Skip update" << skip_update;
     bool with_decay = ctx.Attr<bool>("with_decay");
+    VLOG(3) << "Skip update: " << skip_update << ", With decay: " << with_decay;
     if (!skip_update && with_decay) {
       if (ctx.HasInput("MasterParam")) {
         PADDLE_THROW(platform::errors::Unimplemented(
diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc
index 7304467833..d6287f4c76 100644
--- a/paddle/fluid/operators/softmax_op.cc
+++ b/paddle/fluid/operators/softmax_op.cc
@@ -57,15 +57,16 @@ class SoftmaxOp : public framework::OperatorWithKernel {
     }
 #endif
 
-#ifndef PADDLE_WITH_ASCEND_CL
     if (input_data_type == framework::proto::VarType::FP16) {
-      PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()) ||
-                            platform::is_xpu_place(ctx.GetPlace()),
-                        true,
-                        platform::errors::InvalidArgument(
-                            "float16 can only be used on GPU/XPU place"));
+      PADDLE_ENFORCE_EQ(
+          platform::is_gpu_place(ctx.GetPlace()) ||
+              platform::is_npu_place(ctx.GetPlace()) ||
+              platform::is_xpu_place(ctx.GetPlace()) ||
+              platform::is_mlu_place(ctx.GetPlace()),
+          true,
+          platform::errors::InvalidArgument(
+              "float16 can only be used on GPU/NPU/XPU/MLU place"));
     }
-#endif
 
     return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
                                    library_);
@@ -174,9 +175,10 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
     if (input_data_type == framework::proto::VarType::FP16) {
       if (!(platform::is_gpu_place(ctx.GetPlace()) ||
             platform::is_npu_place(ctx.GetPlace()) ||
-            platform::is_xpu_place(ctx.GetPlace())))
+            platform::is_xpu_place(ctx.GetPlace()) ||
+            platform::is_mlu_place(ctx.GetPlace())))
         PADDLE_THROW(platform::errors::InvalidArgument(
-            "float16 can only be used on GPU/NPU/XPU place"));
+            "float16 can only be used on GPU/NPU/XPU/MLU place"));
     }
 
     return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_lookup_table_v2_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_lookup_table_v2_op_mlu.py
index 17ef85dd2b..2efa8823fd 100644
--- a/python/paddle/fluid/tests/unittests/mlu/test_lookup_table_v2_op_mlu.py
+++ b/python/paddle/fluid/tests/unittests/mlu/test_lookup_table_v2_op_mlu.py
@@ -97,7 +97,6 @@ class TestLookupTableV2FP16(TestLookupTableV2):
     def set_mlu(self):
         self.__class__.use_mlu = True
         self.place = paddle.device.MLUPlace(0)
-        self.__class__.no_need_check_grad = True
 
 
 class TestLookupTableV2Dim32(TestLookupTableV2):
@@ -126,7 +125,6 @@ class TestLookupTableV2Dim32FP16(TestLookupTableV2):
     def set_mlu(self):
         self.__class__.use_mlu = True
         self.place = paddle.device.MLUPlace(0)
-        self.__class__.no_need_check_grad = True
 
 
 class TestLookupTableV2WithPadding(TestLookupTableV2):
-- 
GitLab
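
Reviewer note (not part of the patch): the short Python sketch below is a hypothetical smoke test for the paths this change touches. It assumes a PaddlePaddle build with MLU support; the device string, layer sizes, and optimizer choice are illustrative only and are not taken from the patch.

import paddle

# Assumption: this interpreter runs a Paddle build with MLU support.
paddle.set_device("mlu:0")

# lookup_table_v2 with a padding index goes through MLUCnnl::EmbeddingForward.
embedding = paddle.nn.Embedding(100, 32, padding_idx=0)
ids = paddle.randint(low=0, high=100, shape=[4, 16], dtype="int64")

opt = paddle.optimizer.AdamW(parameters=embedding.parameters())
scaler = paddle.amp.GradScaler()  # loss scaling drives check_finite_and_unscale

with paddle.amp.auto_cast():
    loss = embedding(ids).sum()

scaler.scale(loss).backward()  # runs lookup_table_v2_grad on the MLU device
scaler.step(opt)
scaler.update()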