Unverified commit 67bd5d9c, authored by qipengh, committed by GitHub

[MLU]add lookup_table_v2 op and fix amp feature of bert with mlu device (#43366)

Parent affe25b7
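
To put the change in context, here is a minimal dynamic-graph sketch of the BERT-style AMP flow this commit targets: an embedding lookup (lowered to lookup_table_v2) trained with loss scaling. The device string, model sizes, and hyperparameters are illustrative assumptions; the sketch uses "cpu" so it runs anywhere, and the MLU device would be set instead to exercise the kernels below.

    import paddle

    paddle.set_device("cpu")  # assumption: point this at the MLU device to run on MLU

    embedding = paddle.nn.Embedding(100, 32, padding_idx=0)  # lowers to lookup_table_v2
    linear = paddle.nn.Linear(32, 2)
    params = embedding.parameters() + linear.parameters()
    optimizer = paddle.optimizer.AdamW(learning_rate=1e-3, parameters=params)
    scaler = paddle.amp.GradScaler(init_loss_scaling=1024.0)

    ids = paddle.randint(0, 100, shape=[8, 16])
    with paddle.amp.auto_cast():          # fp16 forward where supported
        loss = linear(embedding(ids)).mean()
    scaled = scaler.scale(loss)           # multiply the loss by the loss scale
    scaled.backward()
    scaler.minimize(optimizer, scaled)    # unscale grads, skip the step on inf/nan
    optimizer.clear_grad()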
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/amp/check_finite_and_unscale_op.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle {
......@@ -22,6 +23,8 @@ using Tensor = framework::Tensor;
template <typename T>
class CheckFiniteAndUnscaleMLUKernel : public framework::OpKernel<T> {
using MPDType = typename details::MPTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto& dev_ctx = ctx.template device_context<platform::MLUDeviceContext>();
......@@ -51,6 +54,7 @@ class CheckFiniteAndUnscaleMLUKernel : public framework::OpKernel<T> {
}
MLUCnnlTensorDesc x_desc(*x);
MLUCnnlTensorDesc out_desc(*out);
MLUCnnl::IsNanInf(ctx, x_desc.get(), GetBasePtr(x),
GetBasePtr(&is_finite));
......@@ -70,10 +74,34 @@ class CheckFiniteAndUnscaleMLUKernel : public framework::OpKernel<T> {
// out = in/scale, if found_inf = false
// But when found_inf is true, the data of Out should not be used.
// So, on MLU, we always compute out with in/scale.
MLUCnnlTensorDesc out_desc(*out);
MLUCnnl::Div(ctx, CNNL_COMPUTATION_HIGH_PRECISION, x_desc.get(),
GetBasePtr(x), scale_desc.get(), GetBasePtr(scale),
out_desc.get(), GetBasePtr(out));
Tensor float_x;
Tensor float_out;
if (std::is_same<T, paddle::platform::float16>::value) {
float_x.Resize(x->dims());
float_out.Resize(out->dims());
float_x.mutable_data<MPDType>(ctx.GetPlace());
float_out.mutable_data<MPDType>(ctx.GetPlace());
MLUCnnlTensorDesc float_x_desc(float_x);
MLUCnnlTensorDesc float_out_desc(float_out);
auto cast_fp16_type =
GetCastDataType(DataType::FLOAT16, DataType::FLOAT32);
MLUCnnl::Cast(ctx, cast_fp16_type, x_desc.get(), GetBasePtr(x),
float_x_desc.get(), GetBasePtr(&float_x));
MLUCnnl::Div(ctx, CNNL_COMPUTATION_HIGH_PRECISION, float_x_desc.get(),
GetBasePtr(&float_x), scale_desc.get(), GetBasePtr(scale),
float_out_desc.get(), GetBasePtr(&float_out));
auto cast_fp32_type =
GetCastDataType(DataType::FLOAT32, DataType::FLOAT16);
MLUCnnl::Cast(ctx, cast_fp32_type, float_out_desc.get(),
GetBasePtr(&float_out), out_desc.get(), GetBasePtr(out));
} else {
MLUCnnl::Div(ctx, CNNL_COMPUTATION_HIGH_PRECISION, x_desc.get(),
GetBasePtr(x), scale_desc.get(), GetBasePtr(scale),
out_desc.get(), GetBasePtr(out));
}
}
}
};
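
For reference, a NumPy paraphrase of the semantics implemented above (a sketch, not the kernel itself): the division by the loss scale is always performed, float16 inputs are promoted to float32 for the divide and cast back, and callers must ignore the outputs whenever found_inf is set.

    import numpy as np

    def check_finite_and_unscale(xs, scale):
        # found_inf is true if any input contains inf or nan
        found_inf = any(not np.isfinite(x).all() for x in xs)
        outs = []
        for x in xs:
            # mirror the Cast -> Div -> Cast path used for float16 inputs
            outs.append((x.astype(np.float32) / scale).astype(x.dtype))
        return outs, found_inf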
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle {
......@@ -122,6 +123,8 @@ class LayerNormMLUKernel : public framework::OpKernel<T> {
template <typename T>
class LayerNormGradMLUKernel : public framework::OpKernel<T> {
using MPDType = typename details::MPTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
......@@ -207,14 +210,14 @@ class LayerNormGradMLUKernel : public framework::OpKernel<T> {
if (dscale && (tmp_dscale.dtype() == DataType::FLOAT16 &&
dscale->dtype() == DataType::FLOAT32)) {
dscale->mutable_data<T>(place);
dscale->mutable_data<MPDType>(place);
MLUCnnl::Cast(ctx, cast_fp16_to_fp32, float16_desc.get(),
GetBasePtr(&tmp_dscale), float32_desc.get(),
GetBasePtr(dscale));
}
if (dbias && (tmp_dbias.dtype() == DataType::FLOAT16 &&
dbias->dtype() == DataType::FLOAT32)) {
dbias->mutable_data<T>(place);
dbias->mutable_data<MPDType>(place);
MLUCnnl::Cast(ctx, cast_fp16_to_fp32, float16_desc.get(),
GetBasePtr(&tmp_dbias), float32_desc.get(),
GetBasePtr(dbias));
......
......@@ -18,7 +18,6 @@ namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
constexpr int64_t kNoPadding = -1;
template <typename T>
class LookupTableV2MLUKernel : public framework::OpKernel<T> {
......@@ -27,6 +26,7 @@ class LookupTableV2MLUKernel : public framework::OpKernel<T> {
auto *ids_t = ctx.Input<framework::LoDTensor>("Ids"); // int tensor
auto *output_t = ctx.Output<framework::LoDTensor>("Out"); // float tensor
auto *table_t = ctx.Input<framework::LoDTensor>("W");
int padding_idx = static_cast<int>(ctx.Attr<int64_t>("padding_idx"));
auto *table_var = ctx.InputVar("W");
PADDLE_ENFORCE_EQ(
......@@ -38,43 +38,10 @@ class LookupTableV2MLUKernel : public framework::OpKernel<T> {
MLUCnnlTensorDesc table_desc(*table_t);
MLUCnnlTensorDesc output_desc(*output_t);
int64_t padding_idx = ctx.Attr<int64_t>("padding_idx");
if (padding_idx == kNoPadding) {
MLUCnnl::GatherFunctor(ctx, /*axis=*/0, /*batch_dims=*/0,
table_desc.get(), GetBasePtr(table_t),
ids_desc.get(), GetBasePtr(ids_t),
output_desc.get(), GetBasePtr(output_t));
} else {
Tensor tmp_table_t(table_t->type());
tmp_table_t.mutable_data<T>(table_t->dims(), ctx.GetPlace());
Tensor index;
index.mutable_data<int32_t>({1, 1}, ctx.GetPlace());
auto idx_value = static_cast<int32_t>(padding_idx);
MLUCnnlTensorDesc index_desc(index);
MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &idx_value, index_desc.get(),
GetBasePtr(&index));
auto update_dim = phi::make_ddim({1, table_t->dims()[1]});
Tensor update;
update.mutable_data<T>(update_dim, ctx.GetPlace());
auto update_value = static_cast<T>(0);
MLUCnnlTensorDesc update_desc(update);
MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &update_value,
update_desc.get(), GetBasePtr(&update));
MLUCnnlTensorDesc tmp_table_desc(tmp_table_t);
MLUCnnl::ScatterNd(
ctx, CNNL_SCATTERND_UPDATE, index_desc.get(), GetBasePtr(&index),
update_desc.get(), GetBasePtr(&update), table_desc.get(),
GetBasePtr(table_t), tmp_table_desc.get(), GetBasePtr(&tmp_table_t));
MLUCnnl::GatherFunctor(ctx, /*axis=*/0, /*batch_dims=*/0,
tmp_table_desc.get(), GetBasePtr(&tmp_table_t),
ids_desc.get(), GetBasePtr(ids_t),
output_desc.get(), GetBasePtr(output_t));
}
MLUCnnl::EmbeddingForward(ctx, padding_idx, table_desc.get(),
GetBasePtr(table_t), ids_desc.get(),
static_cast<const int *>(GetBasePtr(ids_t)),
output_desc.get(), GetBasePtr(output_t));
}
};
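
The removed Fill/ScatterNd/GatherFunctor sequence emulated padding_idx by zeroing one row of a copied table before gathering; cnnlEmbeddingForward_v2 accepts padding_idx directly. A NumPy paraphrase of the forward semantics (informal, not Paddle code):

    import numpy as np

    def lookup_table_v2(table, ids, padding_idx=-1):
        out = table[ids]                   # gather rows of the table along axis 0
        if padding_idx != -1:              # kNoPadding == -1
            out[ids == padding_idx] = 0.0  # padding entries produce zero vectors
        return out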
......@@ -82,6 +49,16 @@ template <typename T>
class LookupTableV2GradMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *table_var = ctx.InputVar("W");
PADDLE_ENFORCE_EQ(table_var->IsType<framework::LoDTensor>(), true,
platform::errors::PermissionDenied(
"Unsupported Variable Type , idx in "
"LookupTableV2GradMLUKernel should be LoDTensor."));
bool is_sparse = ctx.Attr<bool>("is_sparse");
PADDLE_ENFORCE_EQ(
is_sparse, false,
platform::errors::InvalidArgument(
"LookupTableV2GradMLUKernel dose NOT support is_sparse = True."));
auto *ids_t = ctx.Input<framework::LoDTensor>("Ids");
auto *output_grad_t =
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
......@@ -91,6 +68,13 @@ class LookupTableV2GradMLUKernel : public framework::OpKernel<T> {
int padding_idx = static_cast<int>(ctx.Attr<int64_t>("padding_idx"));
int64_t ids_numel = ids_t->numel();
PADDLE_ENFORCE_EQ(
ids_numel <= std::numeric_limits<int32_t>::max(), true,
platform::errors::OutOfRange(
"Number of ids greater than int32_t::max , please check "
"number of ids in LookupTableV2GradMLUKernel."));
Tensor ids_int32(ids_t->dtype());
if (ids_t->dtype() != DataType::INT32) {
ids_int32.mutable_data<int>(ids_t->dims(), ctx.GetPlace());
......@@ -125,5 +109,4 @@ REGISTER_OP_MLU_KERNEL(lookup_table_v2, ops::LookupTableV2MLUKernel<float>,
REGISTER_OP_MLU_KERNEL(lookup_table_v2_grad,
ops::LookupTableV2GradMLUKernel<float>,
ops::LookupTableV2GradMLUKernel<int>,
ops::LookupTableV2GradMLUKernel<plat::float16>);
......@@ -2802,6 +2802,18 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
cnnlReciprocal(handle, input_desc, input, output_desc, output));
}
/* static */ void MLUCnnl::EmbeddingForward(
const ExecutionContext& ctx, const int padding_idx,
const cnnlTensorDescriptor_t weight_desc, const void* weight,
const cnnlTensorDescriptor_t indices_desc, const int* indices,
const cnnlTensorDescriptor_t output_desc, void* output) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);
PADDLE_ENFORCE_MLU_SUCCESS(cnnlEmbeddingForward_v2(
handle, weight_desc, weight, indices_desc, indices, padding_idx,
nullptr /*max_norm*/, nullptr /*norm_type*/, output_desc, output));
}
/* static */ void MLUCnnl::EmbeddingBackward(
const ExecutionContext& ctx, int padding_idx, bool scale_grad_by_freq,
const cnnlTensorDescriptor_t indices_desc, const void* indices,
......
......@@ -1268,6 +1268,12 @@ class MLUCnnl {
const cnnlTensorDescriptor_t output_desc,
void* output);
static void EmbeddingForward(
const ExecutionContext& ctx, const int padding_idx,
const cnnlTensorDescriptor_t weight_desc, const void* weight,
const cnnlTensorDescriptor_t indices_desc, const int* indices,
const cnnlTensorDescriptor_t output_desc, void* output);
static void EmbeddingBackward(
const ExecutionContext& ctx, int padding_idx, bool scale_grad_by_freq,
const cnnlTensorDescriptor_t indices_desc, const void* indices,
......
......@@ -237,8 +237,8 @@ class AdamWMLUKernel : public AdamMLUKernel<T> {
ctx.device_context(), &skip_update_vec);
skip_update = skip_update_vec[0];
}
VLOG(3) << "Skip update" << skip_update;
bool with_decay = ctx.Attr<bool>("with_decay");
VLOG(3) << "Skip update: " << skip_update << ", With decay: " << with_decay;
if (!skip_update && with_decay) {
if (ctx.HasInput("MasterParam")) {
PADDLE_THROW(platform::errors::Unimplemented(
......
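
In the AdamW hunk above, the kernel combines two flags: skip_update (typically wired to found_inf under AMP) and with_decay (decoupled weight decay). A simplified sketch of that decision, ignoring the Adam moment updates and MasterParam handling (all names here are illustrative):

    def adamw_step(param, grad, lr, weight_decay, skip_update, with_decay):
        if skip_update:                    # inf/nan gradients: leave param untouched
            return param
        if with_decay:                     # decoupled weight decay (AdamW)
            param = param * (1.0 - lr * weight_decay)
        return param - lr * grad           # stand-in for the full Adam update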
......@@ -57,15 +57,16 @@ class SoftmaxOp : public framework::OperatorWithKernel {
}
#endif
#ifndef PADDLE_WITH_ASCEND_CL
if (input_data_type == framework::proto::VarType::FP16) {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()) ||
platform::is_xpu_place(ctx.GetPlace()),
true,
platform::errors::InvalidArgument(
"float16 can only be used on GPU/XPU place"));
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()) ||
platform::is_npu_place(ctx.GetPlace()) ||
platform::is_xpu_place(ctx.GetPlace()) ||
platform::is_mlu_place(ctx.GetPlace()),
true,
platform::errors::InvalidArgument(
"float16 can only be used on GPU/NPU/XPU/MLU place"));
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
library_);
......@@ -174,9 +175,10 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
if (input_data_type == framework::proto::VarType::FP16) {
if (!(platform::is_gpu_place(ctx.GetPlace()) ||
platform::is_npu_place(ctx.GetPlace()) ||
platform::is_xpu_place(ctx.GetPlace())))
platform::is_xpu_place(ctx.GetPlace()) ||
platform::is_mlu_place(ctx.GetPlace())))
PADDLE_THROW(platform::errors::InvalidArgument(
"float16 can only be used on GPU/NPU/XPU place"));
"float16 can only be used on GPU/NPU/XPU/MLU place"));
}
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
......
......@@ -97,7 +97,6 @@ class TestLookupTableV2FP16(TestLookupTableV2):
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.device.MLUPlace(0)
self.__class__.no_need_check_grad = True
class TestLookupTableV2Dim32(TestLookupTableV2):
......@@ -126,7 +125,6 @@ class TestLookupTableV2Dim32FP16(TestLookupTableV2):
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.device.MLUPlace(0)
self.__class__.no_need_check_grad = True
class TestLookupTableV2WithPadding(TestLookupTableV2):
......