Unverified commit 67bd5d9c, authored by Q qipengh, committed by GitHub

[MLU] Add lookup_table_v2 op and fix AMP feature of BERT with MLU device (#43366)

Parent affe25b7
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/amp/check_finite_and_unscale_op.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/mlu/mlu_baseop.h"
 namespace paddle {
@@ -22,6 +23,8 @@ using Tensor = framework::Tensor;
 template <typename T>
 class CheckFiniteAndUnscaleMLUKernel : public framework::OpKernel<T> {
+  using MPDType = typename details::MPTypeTrait<T>::Type;
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
     auto& dev_ctx = ctx.template device_context<platform::MLUDeviceContext>();
@@ -51,6 +54,7 @@ class CheckFiniteAndUnscaleMLUKernel : public framework::OpKernel<T> {
     }
     MLUCnnlTensorDesc x_desc(*x);
+    MLUCnnlTensorDesc out_desc(*out);
     MLUCnnl::IsNanInf(ctx, x_desc.get(), GetBasePtr(x),
                       GetBasePtr(&is_finite));
@@ -70,12 +74,36 @@ class CheckFiniteAndUnscaleMLUKernel : public framework::OpKernel<T> {
     // out = in/scale, if found_inf = false
     // But when found_inf is true, the data of Out should not be used.
     // So, on MLU, we always compute out with in/scale.
-    MLUCnnlTensorDesc out_desc(*out);
+    Tensor float_x;
+    Tensor float_out;
+    if (std::is_same<T, paddle::platform::float16>::value) {
+      float_x.Resize(x->dims());
+      float_out.Resize(out->dims());
+      float_x.mutable_data<MPDType>(ctx.GetPlace());
+      float_out.mutable_data<MPDType>(ctx.GetPlace());
+      MLUCnnlTensorDesc float_x_desc(float_x);
+      MLUCnnlTensorDesc float_out_desc(float_out);
+      auto cast_fp16_type =
+          GetCastDataType(DataType::FLOAT16, DataType::FLOAT32);
+      MLUCnnl::Cast(ctx, cast_fp16_type, x_desc.get(), GetBasePtr(x),
+                    float_x_desc.get(), GetBasePtr(&float_x));
+      MLUCnnl::Div(ctx, CNNL_COMPUTATION_HIGH_PRECISION, float_x_desc.get(),
+                   GetBasePtr(&float_x), scale_desc.get(), GetBasePtr(scale),
+                   float_out_desc.get(), GetBasePtr(&float_out));
+      auto cast_fp32_type =
+          GetCastDataType(DataType::FLOAT32, DataType::FLOAT16);
+      MLUCnnl::Cast(ctx, cast_fp32_type, float_out_desc.get(),
+                    GetBasePtr(&float_out), out_desc.get(), GetBasePtr(out));
+    } else {
       MLUCnnl::Div(ctx, CNNL_COMPUTATION_HIGH_PRECISION, x_desc.get(),
                    GetBasePtr(x), scale_desc.get(), GetBasePtr(scale),
                    out_desc.get(), GetBasePtr(out));
+    }
     }
   }
 };
 }  // namespace operators
......
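Note: the FP16 branch above follows the standard mixed-precision pattern: cast the half-precision input up to FP32, run the divide at full precision, then cast the result back down; `out` is always computed as `in/scale` and is simply ignored when an inf/nan was found. A minimal host-side sketch of the same pattern, with `float` widened to `double` standing in for `float16` widened to `float` (standard C++ has no portable half type):

```cpp
#include <iostream>
#include <vector>

// Cast-up / compute / cast-down: the divide runs in the wide type, and the
// result is narrowed back to the storage type, mirroring the FP16 branch.
template <typename Storage, typename Wide>
std::vector<Storage> Unscale(const std::vector<Storage>& x, Wide scale) {
  std::vector<Storage> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    Wide w = static_cast<Wide>(x[i]) / scale;  // high-precision divide
    out[i] = static_cast<Storage>(w);          // cast back down
  }
  return out;
}

int main() {
  std::vector<float> grads = {1.0f, 2.5f, -3.0f};
  for (float g : Unscale<float, double>(grads, 1024.0)) std::cout << g << '\n';
}
```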
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/mlu/mlu_baseop.h"
 namespace paddle {
@@ -122,6 +123,8 @@ class LayerNormMLUKernel : public framework::OpKernel<T> {
 template <typename T>
 class LayerNormGradMLUKernel : public framework::OpKernel<T> {
+  using MPDType = typename details::MPTypeTrait<T>::Type;
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
@@ -207,14 +210,14 @@ class LayerNormGradMLUKernel : public framework::OpKernel<T> {
     if (dscale && (tmp_dscale.dtype() == DataType::FLOAT16 &&
                    dscale->dtype() == DataType::FLOAT32)) {
-      dscale->mutable_data<T>(place);
+      dscale->mutable_data<MPDType>(place);
       MLUCnnl::Cast(ctx, cast_fp16_to_fp32, float16_desc.get(),
                     GetBasePtr(&tmp_dscale), float32_desc.get(),
                     GetBasePtr(dscale));
     }
     if (dbias && (tmp_dbias.dtype() == DataType::FLOAT16 &&
                   dbias->dtype() == DataType::FLOAT32)) {
-      dbias->mutable_data<T>(place);
+      dbias->mutable_data<MPDType>(place);
       MLUCnnl::Cast(ctx, cast_fp16_to_fp32, float16_desc.get(),
                     GetBasePtr(&tmp_dbias), float32_desc.get(),
                     GetBasePtr(dbias));
......
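Note: `MPDType` is the "master precision" companion of `T`. A sketch of the trait shape assumed here (the real definition lives in `fp16_type_traits.h`): every type computes as itself except `float16`, which computes as `float`. That is why `dscale`/`dbias`, which the conditions above require to be declared FLOAT32 even when `T` is `float16`, must be allocated with `mutable_data<MPDType>` rather than `mutable_data<T>`.

```cpp
#include <type_traits>

// Stand-in for paddle::platform::float16; illustration only.
struct float16 { unsigned short bits; };

template <typename T> struct MPTypeTrait { using Type = T; };
template <> struct MPTypeTrait<float16> { using Type = float; };

static_assert(std::is_same<MPTypeTrait<float>::Type, float>::value,
              "float computes as itself");
static_assert(std::is_same<MPTypeTrait<float16>::Type, float>::value,
              "float16 computes in float ('master') precision");

int main() { return 0; }
```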
@@ -18,7 +18,6 @@ namespace paddle {
 namespace operators {
 using Tensor = framework::Tensor;
-constexpr int64_t kNoPadding = -1;
 template <typename T>
 class LookupTableV2MLUKernel : public framework::OpKernel<T> {
@@ -27,6 +26,7 @@ class LookupTableV2MLUKernel : public framework::OpKernel<T> {
     auto *ids_t = ctx.Input<framework::LoDTensor>("Ids");      // int tensor
     auto *output_t = ctx.Output<framework::LoDTensor>("Out");  // float tensor
     auto *table_t = ctx.Input<framework::LoDTensor>("W");
+    int padding_idx = static_cast<int>(ctx.Attr<int64_t>("padding_idx"));
     auto *table_var = ctx.InputVar("W");
     PADDLE_ENFORCE_EQ(
@@ -38,50 +38,27 @@ class LookupTableV2MLUKernel : public framework::OpKernel<T> {
     MLUCnnlTensorDesc table_desc(*table_t);
     MLUCnnlTensorDesc output_desc(*output_t);
-    int64_t padding_idx = ctx.Attr<int64_t>("padding_idx");
-    if (padding_idx == kNoPadding) {
-      MLUCnnl::GatherFunctor(ctx, /*axis=*/0, /*batch_dims=*/0,
-                             table_desc.get(), GetBasePtr(table_t),
-                             ids_desc.get(), GetBasePtr(ids_t),
-                             output_desc.get(), GetBasePtr(output_t));
-    } else {
-      Tensor tmp_table_t(table_t->type());
-      tmp_table_t.mutable_data<T>(table_t->dims(), ctx.GetPlace());
-      Tensor index;
-      index.mutable_data<int32_t>({1, 1}, ctx.GetPlace());
-      auto idx_value = static_cast<int32_t>(padding_idx);
-      MLUCnnlTensorDesc index_desc(index);
-      MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &idx_value, index_desc.get(),
-                    GetBasePtr(&index));
-      auto update_dim = phi::make_ddim({1, table_t->dims()[1]});
-      Tensor update;
-      update.mutable_data<T>(update_dim, ctx.GetPlace());
-      auto update_value = static_cast<T>(0);
-      MLUCnnlTensorDesc update_desc(update);
-      MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &update_value,
-                    update_desc.get(), GetBasePtr(&update));
-      MLUCnnlTensorDesc tmp_table_desc(tmp_table_t);
-      MLUCnnl::ScatterNd(
-          ctx, CNNL_SCATTERND_UPDATE, index_desc.get(), GetBasePtr(&index),
-          update_desc.get(), GetBasePtr(&update), table_desc.get(),
-          GetBasePtr(table_t), tmp_table_desc.get(), GetBasePtr(&tmp_table_t));
-      MLUCnnl::GatherFunctor(ctx, /*axis=*/0, /*batch_dims=*/0,
-                             tmp_table_desc.get(), GetBasePtr(&tmp_table_t),
-                             ids_desc.get(), GetBasePtr(ids_t),
-                             output_desc.get(), GetBasePtr(output_t));
-    }
+    MLUCnnl::EmbeddingForward(ctx, padding_idx, table_desc.get(),
+                              GetBasePtr(table_t), ids_desc.get(),
+                              static_cast<const int *>(GetBasePtr(ids_t)),
+                              output_desc.get(), GetBasePtr(output_t));
   }
 };
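Note: the deleted `Fill` + `ScatterNd` + `GatherFunctor` sequence materialized a full zero-padded copy of the embedding table just to blank the `padding_idx` row before gathering. The single `EmbeddingForward` call expresses the same semantics directly. A CPU reference sketch of those semantics (illustration only, not the MLU code path; with `padding_idx = -1` and non-negative ids no row matches, which mirrors the removed `kNoPadding` behavior):

```cpp
#include <cstdint>
#include <vector>

// Gather rows of `table` by id; rows whose id equals padding_idx stay zero.
void EmbeddingForwardRef(const std::vector<float>& table, int64_t dim,
                         const std::vector<int>& ids, int padding_idx,
                         std::vector<float>* out) {
  out->assign(ids.size() * dim, 0.0f);
  for (size_t i = 0; i < ids.size(); ++i) {
    if (ids[i] == padding_idx) continue;  // leave the padded row as zeros
    for (int64_t j = 0; j < dim; ++j) {
      (*out)[i * dim + j] = table[ids[i] * dim + j];
    }
  }
}

int main() {
  std::vector<float> table = {0, 0, 1, 1, 2, 2};  // 3 rows, dim = 2
  std::vector<float> out;
  EmbeddingForwardRef(table, 2, {2, 0, 1}, /*padding_idx=*/0, &out);
  // out == {2, 2, 0, 0, 1, 1}: the row for id 0 is padded to zeros
}
```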
 template <typename T>
 class LookupTableV2GradMLUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    auto *table_var = ctx.InputVar("W");
-    PADDLE_ENFORCE_EQ(table_var->IsType<framework::LoDTensor>(), true,
-                      platform::errors::PermissionDenied(
-                          "Unsupported Variable Type , idx in "
-                          "LookupTableV2GradMLUKernel should be LoDTensor."));
-    bool is_sparse = ctx.Attr<bool>("is_sparse");
-    PADDLE_ENFORCE_EQ(
-        is_sparse, false,
-        platform::errors::InvalidArgument(
-            "LookupTableV2GradMLUKernel dose NOT support is_sparse = True."));
     auto *ids_t = ctx.Input<framework::LoDTensor>("Ids");
     auto *output_grad_t =
         ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
@@ -91,6 +68,13 @@ class LookupTableV2GradMLUKernel : public framework::OpKernel<T> {
     int padding_idx = static_cast<int>(ctx.Attr<int64_t>("padding_idx"));
+    int64_t ids_numel = ids_t->numel();
+    PADDLE_ENFORCE_EQ(
+        ids_numel <= std::numeric_limits<int32_t>::max(), true,
+        platform::errors::OutOfRange(
+            "The number of ids is greater than int32_t::max; please check "
+            "the number of ids in LookupTableV2GradMLUKernel."));
     Tensor ids_int32(ids_t->dtype());
     if (ids_t->dtype() != DataType::INT32) {
       ids_int32.mutable_data<int>(ids_t->dims(), ctx.GetPlace());
@@ -125,5 +109,4 @@ REGISTER_OP_MLU_KERNEL(lookup_table_v2, ops::LookupTableV2MLUKernel<float>,
 REGISTER_OP_MLU_KERNEL(lookup_table_v2_grad,
                        ops::LookupTableV2GradMLUKernel<float>,
-                       ops::LookupTableV2GradMLUKernel<int>,
                        ops::LookupTableV2GradMLUKernel<plat::float16>);
......
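Note: cnnl takes 32-bit indices, so the grad kernel narrows 64-bit ids to int32, and the new `PADDLE_ENFORCE_EQ` guards that narrowing. A sketch of the same guard in isolation (assumed semantics: reject any id count that cannot be represented as `int32_t`):

```cpp
#include <cstdint>
#include <limits>
#include <stdexcept>

// Reject id counts that would overflow a 32-bit index before casting down.
void CheckIdsFitInt32(int64_t ids_numel) {
  if (ids_numel > std::numeric_limits<int32_t>::max()) {
    throw std::out_of_range("number of ids exceeds int32_t::max");
  }
}

int main() {
  CheckIdsFitInt32(int64_t{1} << 20);  // fine
  // CheckIdsFitInt32(int64_t{1} << 40);  // would throw
}
```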
@@ -2802,6 +2802,18 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
       cnnlReciprocal(handle, input_desc, input, output_desc, output));
 }
+/* static */ void MLUCnnl::EmbeddingForward(
+    const ExecutionContext& ctx, const int padding_idx,
+    const cnnlTensorDescriptor_t weight_desc, const void* weight,
+    const cnnlTensorDescriptor_t indices_desc, const int* indices,
+    const cnnlTensorDescriptor_t output_desc, void* output) {
+  cnnlHandle_t handle = GetHandleFromCTX(ctx);
+
+  PADDLE_ENFORCE_MLU_SUCCESS(cnnlEmbeddingForward_v2(
+      handle, weight_desc, weight, indices_desc, indices, padding_idx,
+      nullptr /*max_norm*/, nullptr /*norm_type*/, output_desc, output));
+}
+
 /* static */ void MLUCnnl::EmbeddingBackward(
     const ExecutionContext& ctx, int padding_idx, bool scale_grad_by_freq,
     const cnnlTensorDescriptor_t indices_desc, const void* indices,
......
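Note: the wrapper follows this file's usual shape: fetch the cnnl handle from the execution context, then wrap the raw library call in a success check; judging by the inline `/*max_norm*/` and `/*norm_type*/` comments, passing `nullptr` presumably leaves embedding renormalization disabled. A simplified, self-contained sketch of the enforce-success idiom (an assumption about what `PADDLE_ENFORCE_MLU_SUCCESS` does, illustration only):

```cpp
#include <sstream>
#include <stdexcept>

enum Status { kSuccess = 0, kBadParam = 1 };

// Evaluate a vendor call and raise with the failing expression on error.
#define ENFORCE_SUCCESS(expr)                            \
  do {                                                   \
    Status s_ = (expr);                                  \
    if (s_ != kSuccess) {                                \
      std::ostringstream os_;                            \
      os_ << #expr << " failed with status " << s_;      \
      throw std::runtime_error(os_.str());               \
    }                                                    \
  } while (0)

Status FakeEmbeddingForward(int padding_idx) {
  return padding_idx < -1 ? kBadParam : kSuccess;
}

int main() {
  ENFORCE_SUCCESS(FakeEmbeddingForward(-1));  // passes
  // ENFORCE_SUCCESS(FakeEmbeddingForward(-2));  // would throw
}
```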
@@ -1268,6 +1268,12 @@ class MLUCnnl {
                      const cnnlTensorDescriptor_t output_desc,
                      void* output);
+  static void EmbeddingForward(
+      const ExecutionContext& ctx, const int padding_idx,
+      const cnnlTensorDescriptor_t weight_desc, const void* weight,
+      const cnnlTensorDescriptor_t indices_desc, const int* indices,
+      const cnnlTensorDescriptor_t output_desc, void* output);
+
   static void EmbeddingBackward(
       const ExecutionContext& ctx, int padding_idx, bool scale_grad_by_freq,
       const cnnlTensorDescriptor_t indices_desc, const void* indices,
......
@@ -237,8 +237,8 @@ class AdamWMLUKernel : public AdamMLUKernel<T> {
                           ctx.device_context(), &skip_update_vec);
       skip_update = skip_update_vec[0];
     }
-    VLOG(3) << "Skip update" << skip_update;
     bool with_decay = ctx.Attr<bool>("with_decay");
+    VLOG(3) << "Skip update: " << skip_update << ", With decay: " << with_decay;
     if (!skip_update && with_decay) {
       if (ctx.HasInput("MasterParam")) {
         PADDLE_THROW(platform::errors::Unimplemented(
......
@@ -57,15 +57,16 @@ class SoftmaxOp : public framework::OperatorWithKernel {
   }
 #endif
-#ifndef PADDLE_WITH_ASCEND_CL
   if (input_data_type == framework::proto::VarType::FP16) {
-    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()) ||
-                          platform::is_xpu_place(ctx.GetPlace()),
+    PADDLE_ENFORCE_EQ(
+        platform::is_gpu_place(ctx.GetPlace()) ||
+            platform::is_npu_place(ctx.GetPlace()) ||
+            platform::is_xpu_place(ctx.GetPlace()) ||
+            platform::is_mlu_place(ctx.GetPlace()),
         true,
         platform::errors::InvalidArgument(
-            "float16 can only be used on GPU/XPU place"));
+            "float16 can only be used on GPU/NPU/XPU/MLU place"));
   }
-#endif
   return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
                                  library_);
@@ -174,9 +175,10 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
   if (input_data_type == framework::proto::VarType::FP16) {
     if (!(platform::is_gpu_place(ctx.GetPlace()) ||
           platform::is_npu_place(ctx.GetPlace()) ||
-          platform::is_xpu_place(ctx.GetPlace())))
+          platform::is_xpu_place(ctx.GetPlace()) ||
+          platform::is_mlu_place(ctx.GetPlace())))
       PADDLE_THROW(platform::errors::InvalidArgument(
-          "float16 can only be used on GPU/NPU/XPU place"));
+          "float16 can only be used on GPU/NPU/XPU/MLU place"));
   }
   return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
......
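Note: with the MLU softmax kernels in place, the FP16 dispatch gate is widened to admit MLU, and the forward check drops the Ascend-only `#ifndef` in favor of listing NPU explicitly. A standalone sketch of the whitelist pattern (`Place` and the helper are illustrative stand-ins for Paddle's `platform::is_*_place` checks):

```cpp
#include <stdexcept>

enum class Place { kCPU, kGPU, kXPU, kNPU, kMLU };

// FP16 kernels are only registered on these places, so fail fast elsewhere.
void CheckFp16Place(Place p) {
  bool ok = p == Place::kGPU || p == Place::kNPU || p == Place::kXPU ||
            p == Place::kMLU;
  if (!ok) {
    throw std::invalid_argument(
        "float16 can only be used on GPU/NPU/XPU/MLU place");
  }
}

int main() {
  CheckFp16Place(Place::kMLU);  // allowed after this patch
  // CheckFp16Place(Place::kCPU);  // would throw
}
```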
@@ -97,7 +97,6 @@ class TestLookupTableV2FP16(TestLookupTableV2):
     def set_mlu(self):
         self.__class__.use_mlu = True
         self.place = paddle.device.MLUPlace(0)
-        self.__class__.no_need_check_grad = True
 class TestLookupTableV2Dim32(TestLookupTableV2):
@@ -126,7 +125,6 @@ class TestLookupTableV2Dim32FP16(TestLookupTableV2):
     def set_mlu(self):
         self.__class__.use_mlu = True
         self.place = paddle.device.MLUPlace(0)
-        self.__class__.no_need_check_grad = True
 class TestLookupTableV2WithPadding(TestLookupTableV2):
......