From 229ec32abf9596348b547aee5217dd5507dd32a3 Mon Sep 17 00:00:00 2001
From: qipengh
Date: Fri, 18 Feb 2022 10:06:47 +0800
Subject: [PATCH] [MLU]add matmul and matmul_v2 op (#39539)

* [MLU]add matmul and matmul_v2 op

* [MLU] fix data_type and del matmul

* [MLU] fix compile error

* [MLU] fix ci_check error
---
 .../fluid/imperative/gradient_accumulator.cc  |   4 +-
 paddle/fluid/operators/activation_op_mlu.cc   |  25 +-
 paddle/fluid/operators/concat_op_mlu.cc       |   8 +-
 paddle/fluid/operators/conv_op_mlu.cc         |  10 +-
 .../fluid/operators/fill_constant_op_mlu.cc   |   5 +-
 paddle/fluid/operators/matmul_v2_op_mlu.cc    | 319 +++++++++++++++
 paddle/fluid/operators/mean_op_mlu.cc         |  27 +-
 .../operators/metrics/accuracy_op_mlu.cc      |  37 +-
 paddle/fluid/operators/mlu/mlu_baseop.cc      |   5 +-
 paddle/fluid/operators/mlu/mlu_baseop.h       |  45 +--
 .../reduce_ops/reduce_mean_op_mlu.cc          |  24 +-
 paddle/fluid/operators/scale_op_mlu.cc        |  10 +-
 .../softmax_with_cross_entropy_op_mlu.cc      |   2 +-
 paddle/fluid/operators/split_op_mlu.cc        |   8 +-
 paddle/fluid/operators/sum_op_mlu.cc          |   8 +-
 paddle/fluid/operators/top_k_op_mlu.cc        |   2 +-
 paddle/fluid/operators/top_k_v2_op_mlu.cc     |   2 +-
 .../unittests/mlu/test_matmul_v2_op_mlu.py    | 378 ++++++++++++++++++
 18 files changed, 794 insertions(+), 125 deletions(-)
 create mode 100644 paddle/fluid/operators/matmul_v2_op_mlu.cc
 create mode 100644 python/paddle/fluid/tests/unittests/mlu/test_matmul_v2_op_mlu.py

diff --git a/paddle/fluid/imperative/gradient_accumulator.cc b/paddle/fluid/imperative/gradient_accumulator.cc
index 2020e2900c..7e61d3dab1 100644
--- a/paddle/fluid/imperative/gradient_accumulator.cc
+++ b/paddle/fluid/imperative/gradient_accumulator.cc
@@ -383,9 +383,9 @@ void TensorAdd(const VarType& src, VarType* dst) {
     operators::MLUCnnlTensorDesc src_tensor_desc(src_tensor);
     operators::MLUCnnlTensorDesc dst_tensor_desc(*dst_tensor);
     PADDLE_ENFORCE_MLU_SUCCESS(cnnlAssignAdd(
-        dev_ctx->cnnl_handle(), static_cast<void*>(&alpha),
+        dev_ctx->cnnl_handle(), static_cast<const void*>(&alpha),
         src_tensor_desc.get(), operators::GetBasePtr(&src_tensor), nullptr, 0,
-        static_cast<void*>(&beta), dst_tensor_desc.get(),
+        static_cast<const void*>(&beta), dst_tensor_desc.get(),
         operators::GetBasePtr(dst_tensor)));
     return;
   }
diff --git a/paddle/fluid/operators/activation_op_mlu.cc b/paddle/fluid/operators/activation_op_mlu.cc
index a2ac1e7d42..b9b2d9ed05 100644
--- a/paddle/fluid/operators/activation_op_mlu.cc
+++ b/paddle/fluid/operators/activation_op_mlu.cc
@@ -38,12 +38,10 @@ class ActivationMLUKernel : public framework::OpKernel<T> {
     output->mutable_data<T>(ctx.GetPlace());
 
     MLUCnnlActivationDesc act_desc(act_mode, alpha);
-    MLUCnnlTensorDesc input_desc(
-        *input, CNNL_LAYOUT_ARRAY,
-        ToCnnlDataType(framework::TransToProtoVarType(input->dtype())));
-    MLUCnnlTensorDesc output_desc(
-        *output, CNNL_LAYOUT_ARRAY,
-        ToCnnlDataType(framework::TransToProtoVarType(output->dtype())));
+    MLUCnnlTensorDesc input_desc(*input, CNNL_LAYOUT_ARRAY,
+                                 ToCnnlDataType(input->dtype()));
+    MLUCnnlTensorDesc output_desc(*output, CNNL_LAYOUT_ARRAY,
+                                  ToCnnlDataType(output->dtype()));
 
     MLUCnnl::Active(ctx, act_desc.get(), input_desc.get(),
                     reinterpret_cast<const void*>(input->data<T>()),
@@ -63,15 +61,12 @@ class ActivationGradMLUKernel : public framework::OpKernel<T> {
 
     dx->mutable_data<T>(ctx.GetPlace());
 
-    MLUCnnlTensorDesc dout_desc(
-        *dout, CNNL_LAYOUT_ARRAY,
-        ToCnnlDataType(framework::TransToProtoVarType(dout->dtype())));
-    MLUCnnlTensorDesc out_desc(
-        *out, CNNL_LAYOUT_ARRAY,
-        ToCnnlDataType(framework::TransToProtoVarType(out->dtype())));
-    MLUCnnlTensorDesc dx_desc(
-        *dx, CNNL_LAYOUT_ARRAY,
-        ToCnnlDataType(framework::TransToProtoVarType(dx->dtype())));
+    MLUCnnlTensorDesc dout_desc(*dout, CNNL_LAYOUT_ARRAY,
+                                ToCnnlDataType(dout->dtype()));
+    MLUCnnlTensorDesc out_desc(*out, CNNL_LAYOUT_ARRAY,
+                               ToCnnlDataType(out->dtype()));
+    MLUCnnlTensorDesc dx_desc(*dx, CNNL_LAYOUT_ARRAY,
+                              ToCnnlDataType(dx->dtype()));
     MLUCnnlActivationDesc act_desc(act_mode, alpha);
     MLUCnnl::ActiveGrad(
         ctx, act_desc.get(), nullptr, nullptr, nullptr, nullptr,
diff --git a/paddle/fluid/operators/concat_op_mlu.cc b/paddle/fluid/operators/concat_op_mlu.cc
index 1a6e29f3ac..73e374b629 100644
--- a/paddle/fluid/operators/concat_op_mlu.cc
+++ b/paddle/fluid/operators/concat_op_mlu.cc
@@ -61,15 +61,13 @@ class ConcatMLUKernel : public framework::OpKernel<T> {
     std::vector<cnnlTensorDescriptor_t> desc_vector;
     for (size_t i = 0; i < ins_size; i++) {
       input_descs.emplace_back(MLUCnnlTensorDesc(
-          *ins[i], CNNL_LAYOUT_ARRAY,
-          ToCnnlDataType(framework::TransToProtoVarType(ins[i]->dtype()))));
+          *ins[i], CNNL_LAYOUT_ARRAY, ToCnnlDataType(ins[i]->dtype())));
       desc_vector.push_back(input_descs.back().get());
       inputs.push_back(GetBasePtr(ins[i]));
     }
     // init out tensors
-    MLUCnnlTensorDesc output_desc(
-        *out, CNNL_LAYOUT_ARRAY,
-        ToCnnlDataType(framework::TransToProtoVarType(out->dtype())));
+    MLUCnnlTensorDesc output_desc(*out, CNNL_LAYOUT_ARRAY,
+                                  ToCnnlDataType(out->dtype()));
 
     // MLU should do sth
     MLUCnnl::Concat(ctx, ins_size_t, axis_t, desc_vector.data(), inputs.data(),
diff --git a/paddle/fluid/operators/conv_op_mlu.cc b/paddle/fluid/operators/conv_op_mlu.cc
index 2155de4d05..fa95f38f5f 100644
--- a/paddle/fluid/operators/conv_op_mlu.cc
+++ b/paddle/fluid/operators/conv_op_mlu.cc
@@ -80,14 +80,12 @@ class MLUConvOpKernel : public framework::OpKernel<T> {
                       true /*need_reshape_or_alloc*/);
 
     cnnlTensorLayout_t data_layout = CNNL_LAYOUT_NHWC;
-    MLUCnnlTensorDesc input_desc(
-        input_tensor, data_layout,
-        ToCnnlDataType(framework::TransToProtoVarType(input_tensor.dtype())));
+    MLUCnnlTensorDesc input_desc(input_tensor, data_layout,
+                                 ToCnnlDataType(input_tensor.dtype()));
     MLUCnnlTensorDesc filter_desc(trans_filter, data_layout,
                                   ToCnnlDataType(trans_filter.type()));
-    MLUCnnlTensorDesc output_desc(
-        output_tensor, data_layout,
-        ToCnnlDataType(framework::TransToProtoVarType(output_tensor.dtype())));
+    MLUCnnlTensorDesc output_desc(output_tensor, data_layout,
+                                  ToCnnlDataType(output_tensor.dtype()));
 
     MLUCnnlConvolutionDesc conv_desc(in_dims_size, paddings.data(),
                                      strides.data(), dilations.data(), groups,
diff --git a/paddle/fluid/operators/fill_constant_op_mlu.cc b/paddle/fluid/operators/fill_constant_op_mlu.cc
index 609af914d4..10e7c72d15 100644
--- a/paddle/fluid/operators/fill_constant_op_mlu.cc
+++ b/paddle/fluid/operators/fill_constant_op_mlu.cc
@@ -72,9 +72,8 @@ class FillConstantMLUKernel : public framework::OpKernel<T> {
       auto shape = GetShape(ctx);
       out_var->mutable_data<T>(shape, ctx.GetPlace());
 
-      MLUCnnlTensorDesc output_desc(
-          *out_var, CNNL_LAYOUT_ARRAY,
-          ToCnnlDataType(framework::TransToProtoVarType(out_var->dtype())));
+      MLUCnnlTensorDesc output_desc(*out_var, CNNL_LAYOUT_ARRAY,
+                                    ToCnnlDataType(out_var->dtype()));
       MLUCnnl::Fill(ctx, value, output_desc.get(), GetBasePtr(out_var));
     }
 };
diff --git a/paddle/fluid/operators/matmul_v2_op_mlu.cc b/paddle/fluid/operators/matmul_v2_op_mlu.cc
new file mode 100644
index 0000000000..1466949367
--- /dev/null
+++ b/paddle/fluid/operators/matmul_v2_op_mlu.cc
@@ -0,0 +1,319 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/matmul_v2_op.h"
+#include "paddle/fluid/operators/mlu/mlu_baseop.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+static void Mul(const framework::ExecutionContext& ctx, const Tensor& X,
+                const Tensor& Y, Tensor* Out) {
+  Out->mutable_data<T>(ctx.GetPlace());
+
+  MLUCnnlTensorDesc x_desc(X, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
+  MLUCnnlTensorDesc y_desc(Y, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
+  MLUCnnlTensorDesc out_desc(*Out, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
+
+  MLUCnnlOpTensorDesc mul_op_desc(CNNL_OP_TENSOR_MUL, ToCnnlDataType<T>(),
+                                  CNNL_NOT_PROPAGATE_NAN);
+  MLUCnnl::OpTensor(ctx, mul_op_desc.get(), x_desc.get(), GetBasePtr(&X),
+                    y_desc.get(), GetBasePtr(&Y), out_desc.get(),
+                    GetBasePtr(Out), ToCnnlDataType<T>());
+}
+
+template <typename T>
+static void MatMul2D(const framework::ExecutionContext& ctx, const Tensor& X,
+                     const Tensor& Y, Tensor* Out, const bool trans_x,
+                     const bool trans_y) {
+  Out->mutable_data<T>(ctx.GetPlace());
+
+  MLUCnnlTensorDesc x_desc(X, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
+  MLUCnnlTensorDesc y_desc(Y, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
+  MLUCnnlTensorDesc out_desc(*Out, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
+
+  MLUCnnl::Matmul(ctx, trans_x, trans_y, x_desc.get(), GetBasePtr(&X),
+                  y_desc.get(), GetBasePtr(&Y), out_desc.get(),
+                  GetBasePtr(Out));
+}
+
+template <typename T>
+static void MatMulND(const framework::ExecutionContext& ctx, const Tensor& X,
+                     const Tensor& Y, Tensor* Out, const bool trans_x,
+                     const bool trans_y) {
+  if (!Out->initialized()) {
+    Out->mutable_data<T>(ctx.GetPlace());
+  }
+
+  MLUCnnlTensorDesc x_desc(X, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
+  MLUCnnlTensorDesc y_desc(Y, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
+  MLUCnnlTensorDesc out_desc(*Out, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
+
+  MLUCnnl::BatchMatmul(ctx, trans_x, trans_y, x_desc.get(), GetBasePtr(&X),
+                       y_desc.get(), GetBasePtr(&Y), out_desc.get(),
+                       GetBasePtr(Out));
+}
+
+template <typename T>
+static void ReduceDims(const framework::ExecutionContext& ctx,
+                       const std::vector<int64_t>& dims,
+                       const std::vector<int64_t>& bcast_dims,
+                       const Tensor& in, Tensor* out) {
+  std::vector<int64_t> axes;
+  int64_t size = bcast_dims.size();
+  int64_t diff = bcast_dims.size() - dims.size();
+  for (int64_t i = 0; i < size; ++i) {
+    if (i < diff) {
+      axes.push_back(i);
+      continue;
+    }
+    if (bcast_dims[i] > dims[i - diff]) {
+      axes.push_back(i);
+    }
+  }
+  out->mutable_data<T>(ctx.GetPlace());
+
+  MLUCnnlTensorDesc in_desc(in, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
+  MLUCnnlTensorDesc out_desc(*out, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
+
+  std::vector<int> reduce_dims(axes.begin(), axes.end());
+  MLUCnnlReduceDesc reduce_desc(reduce_dims, CNNL_REDUCE_ADD,
+                                ToCnnlDataType<T>(), CNNL_NOT_PROPAGATE_NAN,
+                                CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES);
+
+  MLUCnnl::Reduce(ctx, true /*need_workspace*/, reduce_desc.get(), nullptr,
+                  in_desc.get(), GetBasePtr(&in), 0 /*indices_size*/, nullptr,
+                  nullptr, out_desc.get(), GetBasePtr(out));
+}
+
+template <typename T>
+class MatMulV2MLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* X = ctx.Input<Tensor>("X");
+    auto* Y = ctx.Input<Tensor>("Y");
+    auto* Out = ctx.Output<Tensor>("Out");
+    const bool trans_x = ctx.Attr<bool>("trans_x");
+    const bool trans_y = ctx.Attr<bool>("trans_y");
+
+    std::vector<int64_t> x_dims = framework::vectorize(X->dims());
+    std::vector<int64_t> y_dims = framework::vectorize(Y->dims());
+    std::vector<int64_t> out_dims = framework::vectorize(Out->dims());
+    int x_ndim = x_dims.size();
+    int y_ndim = y_dims.size();
+
+    // Case 1: [K] x [K] = [1]
+    // Equal: [1, K] x [K, 1] = [1, 1] => [1]
+    const bool all_one_dim = (x_ndim == 1 && y_ndim == 1);
+    if (all_one_dim) {
+      Out->Resize({1, 1});
+    }
+
+    // Resize dim 1 to 2
+    Tensor x_temp, y_temp;
+    x_temp.ShareDataWith(*X);
+    y_temp.ShareDataWith(*Y);
+    if (x_ndim == 1) {
+      x_dims.insert(x_dims.begin(), 1);
+      x_temp.Resize(framework::make_ddim(x_dims));
+      x_ndim = 2;
+      // matmul op of mlu needs `std::max(x->dim, y->dim) == out->dim`
+      if (out_dims.size() < y_dims.size()) {
+        std::vector<int64_t> temp_out_dims(out_dims.begin(), out_dims.end());
+        temp_out_dims.insert(temp_out_dims.end() - 1, 1);
+        Out->Resize(framework::make_ddim(temp_out_dims));
+      }
+    }
+    if (y_ndim == 1) {
+      y_dims.push_back(1);
+      y_temp.Resize(framework::make_ddim(y_dims));
+      y_ndim = 2;
+      // matmul op of mlu needs `std::max(x->dim, y->dim) == out->dim`
+      if (out_dims.size() < x_dims.size()) {
+        std::vector<int64_t> temp_out_dims(out_dims.begin(), out_dims.end());
+        temp_out_dims.push_back(1);
+        Out->Resize(framework::make_ddim(temp_out_dims));
+      }
+    }
+
+    const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1];
+    if (trans_y) {
+      PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], K,
+                        platform::errors::InvalidArgument(
+                            "Input(Y) has a wrong dimension. "
+                            "Y's dims[%d] must be equal to %d, "
+                            "but received Y's dims[%d] is %d.",
+                            y_ndim - 1, K, y_ndim - 1, y_dims[y_ndim - 1]));
+    } else {
+      PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], K,
+                        platform::errors::InvalidArgument(
+                            "Input(Y) has a wrong dimension. "
+                            "Y's dims[%d] must be equal to %d, "
+                            "but received Y's dims[%d] is %d.",
+                            y_ndim - 2, K, y_ndim - 2, y_dims[y_ndim - 2]));
+    }
+
+    if (x_ndim == 2 && y_ndim == 2) {
+      // Case 2: [M, K] x [K, N] = [M, N]
+      MatMul2D<T>(ctx, x_temp, y_temp, Out, trans_x, trans_y);
+    } else {
+      // Case 3: [B, M, K] x [K, N] = [B, M, N]
+      // Case 4: [B, M, K] x [B, K, N] = [B, M, N]
+      MatMulND<T>(ctx, x_temp, y_temp, Out, trans_x, trans_y);
+    }
+
+    if (framework::vectorize(Out->dims()) != out_dims) {
+      Out->Resize(framework::make_ddim(out_dims));
+    }
+  }
+};
+
+template <typename T>
+class MatMulGradV2MLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* X = ctx.Input<Tensor>("X");
+    auto* Y = ctx.Input<Tensor>("Y");
+    auto* dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* dY = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    const bool trans_x = ctx.Attr<bool>("trans_x");
+    const bool trans_y = ctx.Attr<bool>("trans_y");
+
+    std::vector<int64_t> x_dims = framework::vectorize(X->dims());
+    std::vector<int64_t> y_dims = framework::vectorize(Y->dims());
+    std::vector<int64_t> out_dims = framework::vectorize(dOut->dims());
+    int x_ndim = x_dims.size();
+    int y_ndim = y_dims.size();
+    int out_ndim = out_dims.size();
+
+    // Case 1: [K] x [K] = [1]
+    if (x_ndim == 1 && y_ndim == 1) {
+      if (dX) {
+        Mul<T>(ctx, *dOut, *Y, dX);
+      }
+      if (dY) {
+        Mul<T>(ctx, *dOut, *X, dY);
+      }
+      return;
+    }
+
+    // Resize dim 1 to 2
+    Tensor x_temp, y_temp, dout_temp;
+    x_temp.ShareDataWith(*X);
+    y_temp.ShareDataWith(*Y);
+    dout_temp.ShareDataWith(*dOut);
+    if (x_ndim == 1) {
+      x_dims.insert(x_dims.begin(), 1);
+      out_dims.insert(out_dims.end() - 1, 1);
+      x_temp.Resize(framework::make_ddim(x_dims));
+      dout_temp.Resize(framework::make_ddim(out_dims));
+      x_ndim = 2;
+      out_ndim += 1;
+    }
+    if (y_ndim == 1) {
+      y_dims.push_back(1);
+      out_dims.push_back(1);
+      y_temp.Resize(framework::make_ddim(y_dims));
+      dout_temp.Resize(framework::make_ddim(out_dims));
+      y_ndim = 2;
+      out_ndim += 1;
+    }
+
+    // Case 2: [M, K] x [K, N] = [M, N]
+    if (out_ndim == 2) {
+      if (dX) {
+        dX->Resize(framework::make_ddim(x_dims));
+        if (trans_x) {
+          MatMul2D<T>(ctx, y_temp, dout_temp, dX, trans_y, true);
+        } else {
+          MatMul2D<T>(ctx, dout_temp, y_temp, dX, false, !trans_y);
+        }
+        dX->Resize(X->dims());
+      }
+      if (dY) {
+        dY->Resize(framework::make_ddim(y_dims));
+        if (trans_y) {
+          MatMul2D<T>(ctx, dout_temp, x_temp, dY, true, trans_x);
+        } else {
+          MatMul2D<T>(ctx, x_temp, dout_temp, dY, !trans_x, false);
+        }
+        dY->Resize(Y->dims());
+      }
+      return;
+    }
+
+    // Case 3: [B, M, K] x [K, N] = [B, M, N]
+    // Case 4: [B, M, K] x [B, K, N] = [B, M, N]
+    std::vector<int64_t> x_bcast_dims(out_ndim, 1);
+    std::vector<int64_t> y_bcast_dims(out_ndim, 1);
+    std::copy(out_dims.begin(), out_dims.end() - 2, x_bcast_dims.begin());
+    std::copy(out_dims.begin(), out_dims.end() - 2, y_bcast_dims.begin());
+    std::copy(x_dims.end() - 2, x_dims.end(), x_bcast_dims.end() - 2);
+    std::copy(y_dims.end() - 2, y_dims.end(), y_bcast_dims.end() - 2);
+
+    if (dX) {
+      Tensor dx_temp(X->type());
+      if (x_dims != x_bcast_dims) {
+        dx_temp.Resize(framework::make_ddim(x_bcast_dims));
+      } else {
+        dX->mutable_data<T>(ctx.GetPlace());
+        dx_temp.ShareDataWith(*dX);
+      }
+
+      if (trans_x) {
+        MatMulND<T>(ctx, y_temp, dout_temp, &dx_temp, trans_y, true);
+      } else {
+        MatMulND<T>(ctx, dout_temp, y_temp, &dx_temp, false, !trans_y);
+      }
+
+      if (x_dims != x_bcast_dims) {
+        ReduceDims<T>(ctx, x_dims, x_bcast_dims, dx_temp, dX);
+      }
+    }
+
+    if (dY) {
+      Tensor dy_temp(Y->type());
+      if (y_dims != y_bcast_dims) {
+        dy_temp.Resize(framework::make_ddim(y_bcast_dims));
+      } else {
+        dY->mutable_data<T>(ctx.GetPlace());
+        dy_temp.ShareDataWith(*dY);
+      }
+
+      if (trans_y) {
+        MatMulND<T>(ctx, dout_temp, x_temp, &dy_temp, true, trans_x);
+      } else {
+        MatMulND<T>(ctx, x_temp, dout_temp, &dy_temp, !trans_x, false);
+      }
+
+      if (y_dims != y_bcast_dims) {
+        ReduceDims<T>(ctx, y_dims, y_bcast_dims, dy_temp, dY);
+      }
+    }
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_MLU_KERNEL(matmul_v2, ops::MatMulV2MLUKernel<float>,
+                       ops::MatMulV2MLUKernel<plat::float16>);
+REGISTER_OP_MLU_KERNEL(matmul_v2_grad, ops::MatMulGradV2MLUKernel<float>,
+                       ops::MatMulGradV2MLUKernel<plat::float16>);
diff --git a/paddle/fluid/operators/mean_op_mlu.cc b/paddle/fluid/operators/mean_op_mlu.cc
index c25b25aa50..f8246165c5 100644
--- a/paddle/fluid/operators/mean_op_mlu.cc
+++ b/paddle/fluid/operators/mean_op_mlu.cc
@@ -45,12 +45,10 @@ class MeanMLUKernel : public framework::OpKernel<T> {
       reduce_dims.push_back(i);
     }
 
-    MLUCnnlTensorDesc input_desc(
-        *input, CNNL_LAYOUT_ARRAY,
-        ToCnnlDataType(framework::TransToProtoVarType(input->dtype())));
-    MLUCnnlTensorDesc output_desc(
-        *output, CNNL_LAYOUT_ARRAY,
-        ToCnnlDataType(framework::TransToProtoVarType(output->dtype())));
+    MLUCnnlTensorDesc input_desc(*input, CNNL_LAYOUT_ARRAY,
+                                 ToCnnlDataType(input->dtype()));
+    MLUCnnlTensorDesc output_desc(*output, CNNL_LAYOUT_ARRAY,
+                                  ToCnnlDataType(output->dtype()));
 
     MLUCnnlReduceDesc reduction_desc(
         reduce_dims, CNNL_REDUCE_AVG, ToCnnlDataType<T>(),
@@ -90,21 +88,18 @@ class MeanMLUGradKernel : public framework::OpKernel<T> {
     }
 
     // means
-    Tensor mean_var(framework::TransToProtoVarType(output_grad->dtype()));
+    Tensor mean_var(output_grad->dtype());
     mean_var.mutable_data<T>(input_grad->dims(), context.GetPlace());
-    MLUCnnlTensorDesc mean_var_desc(
-        mean_var, CNNL_LAYOUT_ARRAY,
-        ToCnnlDataType(framework::TransToProtoVarType(mean_var.dtype())));
+    MLUCnnlTensorDesc mean_var_desc(mean_var, CNNL_LAYOUT_ARRAY,
+                                    ToCnnlDataType(mean_var.dtype()));
     auto value = static_cast<T>(1.0 / static_cast<float>(input_grad->numel()));
     MLUCnnl::Fill(context, value, mean_var_desc.get(), GetBasePtr(&mean_var));
 
     // means mul output_grad
-    MLUCnnlTensorDesc in_desc(
-        *output_grad, CNNL_LAYOUT_ARRAY,
-        ToCnnlDataType(framework::TransToProtoVarType(output_grad->dtype())));
-    MLUCnnlTensorDesc out_desc(
-        *input_grad, CNNL_LAYOUT_ARRAY,
-        ToCnnlDataType(framework::TransToProtoVarType(input_grad->dtype())));
+    MLUCnnlTensorDesc in_desc(*output_grad, CNNL_LAYOUT_ARRAY,
+                              ToCnnlDataType(output_grad->dtype()));
+    MLUCnnlTensorDesc out_desc(*input_grad, CNNL_LAYOUT_ARRAY,
+                               ToCnnlDataType(input_grad->dtype()));
 
     MLUCnnlOpTensorDesc op_tensor_desc(CNNL_OP_TENSOR_MUL, ToCnnlDataType<T>(),
                                        CNNL_NOT_PROPAGATE_NAN);
diff --git a/paddle/fluid/operators/metrics/accuracy_op_mlu.cc b/paddle/fluid/operators/metrics/accuracy_op_mlu.cc
index 0649f9172e..a22f66aff7 100644
--- a/paddle/fluid/operators/metrics/accuracy_op_mlu.cc
+++ b/paddle/fluid/operators/metrics/accuracy_op_mlu.cc
@@ -35,39 +35,40 @@ class AccuracyMLUKernel : public framework::OpKernel<T> {
     }
 
     // cast `indices` or `label` if their type is not INT32
-    Tensor indices_int32(VT::INT32);
-    Tensor label_int32(VT::INT32);
-    if (indices->type() != VT::INT32) {
-      PADDLE_ENFORCE_EQ(MLUSupportsCast(indices->type(), VT::INT32), true,
-                        platform::errors::Unavailable(
+    Tensor indices_int32(framework::TransToPtenDataType(VT::INT32));
+    Tensor label_int32(framework::TransToPtenDataType(VT::INT32));
+    auto indices_type = framework::TransToProtoVarType(indices->type());
+    if (indices_type != VT::INT32) {
+      PADDLE_ENFORCE_EQ(MLUSupportsCast(indices_type, VT::INT32), true,
+                        platform::errors::Unimplemented(
                             "In accuracy mlu kernel, cast indices from [%s] to "
                             "[%s] is not supported.",
-                            framework::DataTypeToString(indices->type()),
+                            framework::DataTypeToString(indices_type),
                             framework::DataTypeToString(VT::INT32)));
       indices_int32.Resize(indices->dims());
       indices_int32.mutable_data<int>(ctx.GetPlace());
       MLUCnnlTensorDesc org_indices_desc(*indices);
       MLUCnnlTensorDesc indices_int32_desc(indices_int32);
-      cnnlCastDataType_t cast_type =
-          GetCastDataType(indices->type(), VT::INT32);
+      cnnlCastDataType_t cast_type = GetCastDataType(indices_type, VT::INT32);
       MLUCnnl::Cast(ctx, cast_type, org_indices_desc.get(),
                     GetBasePtr(indices), indices_int32_desc.get(),
                     GetBasePtr(&indices_int32));
     } else {
       indices_int32.ShareDataWith(*indices);
     }
-    if (label->type() != VT::INT32) {
+    auto label_type = framework::TransToProtoVarType(label->type());
+    if (label_type != VT::INT32) {
       PADDLE_ENFORCE_EQ(
-          MLUSupportsCast(label->type(), VT::INT32), true,
-          platform::errors::Unavailable(
+          MLUSupportsCast(label_type, VT::INT32), true,
+          platform::errors::Unimplemented(
               "In accuracy mlu kernel, cast label from [%s] to [%s] "
               "is not supported.",
-              framework::DataTypeToString(label->type()),
+              framework::DataTypeToString(label_type),
               framework::DataTypeToString(VT::INT32)));
       label_int32.Resize(label->dims());
       label_int32.mutable_data<int>(ctx.GetPlace());
       MLUCnnlTensorDesc org_label_desc(*label);
       MLUCnnlTensorDesc label_int32_desc(label_int32);
-      cnnlCastDataType_t cast_type = GetCastDataType(label->type(), VT::INT32);
+      cnnlCastDataType_t cast_type = GetCastDataType(label_type, VT::INT32);
       MLUCnnl::Cast(ctx, cast_type, org_label_desc.get(), GetBasePtr(label),
                     label_int32_desc.get(), GetBasePtr(&label_int32));
     } else {
@@ -77,7 +78,7 @@ class AccuracyMLUKernel : public framework::OpKernel<T> {
     // equal
     MLUCnnlTensorDesc indices_int32_desc(indices_int32);
     MLUCnnlTensorDesc label_int32_desc(label_int32);
-    Tensor equal_tensor(VT::BOOL);
+    Tensor equal_tensor(framework::TransToPtenDataType(VT::BOOL));
     equal_tensor.Resize(indices->dims());
     equal_tensor.mutable_data<bool>(ctx.GetPlace());
     MLUCnnlTensorDesc equal_tensor_desc(equal_tensor);
@@ -87,7 +88,7 @@ class AccuracyMLUKernel : public framework::OpKernel<T> {
                    GetBasePtr(&equal_tensor));
 
     // cast equal
-    Tensor equal_fp32(VT::FP32);
+    Tensor equal_fp32(framework::TransToPtenDataType(VT::FP32));
     equal_fp32.Resize(indices->dims());
     equal_fp32.mutable_data<float>(ctx.GetPlace());
     MLUCnnlTensorDesc equal_fp32_desc(equal_fp32);
@@ -98,7 +99,7 @@ class AccuracyMLUKernel : public framework::OpKernel<T> {
 
     // [correct]
     // reduce_max
-    Tensor correct_max(VT::FP32);
+    Tensor correct_max(framework::TransToPtenDataType(VT::FP32));
     correct_max.Resize(framework::make_ddim({num_samples}));
     correct_max.mutable_data<float>(ctx.GetPlace());
     MLUCnnlTensorDesc correct_max_desc(correct_max);
@@ -111,7 +112,7 @@ class AccuracyMLUKernel : public framework::OpKernel<T> {
                     correct_max_desc.get(), GetBasePtr(&correct_max));
 
     // reduce_sum
-    Tensor correct_sum(VT::FP32);
+    Tensor correct_sum(framework::TransToPtenDataType(VT::FP32));
     correct_sum.Resize(correct->dims());
     correct_sum.mutable_data<float>(ctx.GetPlace());
     MLUCnnlTensorDesc correct_sum_desc(correct_sum);
@@ -137,7 +138,7 @@ class AccuracyMLUKernel : public framework::OpKernel<T> {
     MLUCnnl::Fill(ctx, num_samples, total_desc.get(), GetBasePtr(total));
 
     // use `total` of type `float32` for calculating accuracy
-    Tensor total_fp32(VT::FP32);
+    Tensor total_fp32(framework::TransToPtenDataType(VT::FP32));
     total_fp32.Resize(total->dims());
     total_fp32.mutable_data<float>(ctx.GetPlace());
     MLUCnnlTensorDesc total_fp32_desc(total_fp32);
diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc
index 82ea75943d..4937f528dd 100644
--- a/paddle/fluid/operators/mlu/mlu_baseop.cc
+++ b/paddle/fluid/operators/mlu/mlu_baseop.cc
@@ -178,9 +178,8 @@ MLUCnnlTensorDesc::MLUCnnlTensorDesc(const Tensor& tensor,
 }
 
 MLUCnnlTensorDesc::MLUCnnlTensorDesc(const Tensor& tensor)
-    : MLUCnnlTensorDesc(
-          tensor, CNNL_LAYOUT_ARRAY,
-          ToCnnlDataType(framework::TransToProtoVarType(tensor.dtype()))) {}
+    : MLUCnnlTensorDesc(tensor, CNNL_LAYOUT_ARRAY,
+                        ToCnnlDataType(tensor.dtype())) {}
 
 MLUCnnlTensorDesc::MLUCnnlTensorDesc(const Tensor& tensor,
                                      cnnlTensorLayout_t layout,
diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h
index 91eddaf792..97d30629dd 100644
--- a/paddle/fluid/operators/mlu/mlu_baseop.h
+++ b/paddle/fluid/operators/mlu/mlu_baseop.h
@@ -49,44 +49,32 @@ inline const void* GetBasePtr(const Tensor* t) { return t->data(); }
 
 inline void* GetBasePtr(Tensor* t) { return t->data(); }
 
-template <typename T>
-inline cnnlDataType_t ToCnnlDataType(const T& t) {
-  auto type = framework::ToDataType(t);
-  return ToCnnlDataType(type);
-}
-
-template <typename T>
-inline cnnlDataType_t ToCnnlDataType() {
-  auto type = framework::ToDataType(std::type_index(typeid(T)));
-  return ToCnnlDataType(type);
-}
-
-template <>
-inline cnnlDataType_t ToCnnlDataType(const framework::proto::VarType::Type& t) {
+inline cnnlDataType_t ToCnnlDataType(
+    const paddle::experimental::DataType& dtype) {
   cnnlDataType_t type = CNNL_DTYPE_FLOAT;
-  switch (t) {
-    case framework::proto::VarType::FP16:
+  switch (dtype) {
+    case DataType::FLOAT16:
       type = CNNL_DTYPE_HALF;
       break;
-    case framework::proto::VarType::FP32:
+    case DataType::FLOAT32:
       type = CNNL_DTYPE_FLOAT;
       break;
-    case framework::proto::VarType::INT8:
+    case DataType::INT8:
      type = CNNL_DTYPE_INT8;
       break;
-    case framework::proto::VarType::INT16:
+    case DataType::INT16:
       type = CNNL_DTYPE_INT16;
       break;
-    case framework::proto::VarType::INT32:
+    case DataType::INT32:
       type = CNNL_DTYPE_INT32;
       break;
-    case framework::proto::VarType::INT64:
+    case DataType::INT64:
       type = CNNL_DTYPE_INT64;
       break;
-    case framework::proto::VarType::BOOL:
+    case DataType::BOOL:
       type = CNNL_DTYPE_BOOL;
       break;
-    case framework::proto::VarType::UINT8:
+    case DataType::UINT8:
       type = CNNL_DTYPE_UINT8;
       break;
     default:
@@ -95,6 +83,17 @@ inline cnnlDataType_t ToCnnlDataType(const framework::proto::VarType::Type& t) {
   return type;
 }
 
+inline cnnlDataType_t ToCnnlDataType(
+    const paddle::framework::proto::VarType::Type& type) {
+  return ToCnnlDataType(framework::TransToPtenDataType(type));
+}
+
+template <typename T>
+inline cnnlDataType_t ToCnnlDataType() {
+  auto type = framework::ToDataType(std::type_index(typeid(T)));
+  return ToCnnlDataType(type);
+}
+
 // Converts (via narrowing) a type T value to a type U, and checks that the
 // value has no value change due to the conversion.
 template <typename WideT, typename NarrowT>
diff --git a/paddle/fluid/operators/reduce_ops/reduce_mean_op_mlu.cc b/paddle/fluid/operators/reduce_ops/reduce_mean_op_mlu.cc
index a7ea28314f..029b572e8d 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_mean_op_mlu.cc
+++ b/paddle/fluid/operators/reduce_ops/reduce_mean_op_mlu.cc
@@ -46,12 +46,10 @@ class ReduceMeanMLUKernel : public framework::OpKernel<T> {
       }
     }
 
-    MLUCnnlTensorDesc input_desc(
-        *input, CNNL_LAYOUT_ARRAY,
-        ToCnnlDataType(framework::TransToProtoVarType(input->dtype())));
-    MLUCnnlTensorDesc output_desc(
-        *output, CNNL_LAYOUT_ARRAY,
-        ToCnnlDataType(framework::TransToProtoVarType(output->dtype())));
+    MLUCnnlTensorDesc input_desc(*input, CNNL_LAYOUT_ARRAY,
+                                 ToCnnlDataType(input->dtype()));
+    MLUCnnlTensorDesc output_desc(*output, CNNL_LAYOUT_ARRAY,
+                                  ToCnnlDataType(output->dtype()));
 
     MLUCnnlReduceDesc reduction_desc(
         reduce_dims, CNNL_REDUCE_AVG, ToCnnlDataType<T>(),
@@ -91,8 +89,7 @@ class ReduceMeanGradMLUKernel : public framework::OpKernel<T> {
       reduce_numel *= input_dims[d];
     }
 
-    Tensor tmp_output_grad(
-        framework::TransToProtoVarType(output_grad->dtype()));
+    Tensor tmp_output_grad(output_grad->dtype());
     auto tmp_output_dims = input_dims;
     for (auto d : reduce_dims) {
       tmp_output_dims[d] = 1;
@@ -100,13 +97,10 @@ class ReduceMeanGradMLUKernel : public framework::OpKernel<T> {
     tmp_output_grad.ShareDataWith(*output_grad);
     tmp_output_grad.Resize(framework::make_ddim(tmp_output_dims));
 
-    MLUCnnlTensorDesc output_grad_desc(
-        tmp_output_grad, CNNL_LAYOUT_ARRAY,
-        ToCnnlDataType(
-            framework::TransToProtoVarType(tmp_output_grad.dtype())));
-    MLUCnnlTensorDesc input_grad_desc(
-        *input_grad, CNNL_LAYOUT_ARRAY,
-        ToCnnlDataType(framework::TransToProtoVarType(input_grad->dtype())));
+    MLUCnnlTensorDesc output_grad_desc(tmp_output_grad, CNNL_LAYOUT_ARRAY,
+                                       ToCnnlDataType(tmp_output_grad.dtype()));
+    MLUCnnlTensorDesc input_grad_desc(*input_grad, CNNL_LAYOUT_ARRAY,
+                                      ToCnnlDataType(input_grad->dtype()));
 
     auto value = static_cast<T>(1.0 / static_cast<float>(reduce_numel));
     MLUCnnl::Fill(context, value, input_grad_desc.get(),
diff --git a/paddle/fluid/operators/scale_op_mlu.cc b/paddle/fluid/operators/scale_op_mlu.cc
index fa222ecae1..dac37ce540 100644
--- a/paddle/fluid/operators/scale_op_mlu.cc
+++ b/paddle/fluid/operators/scale_op_mlu.cc
@@ -85,15 +85,13 @@ class ScaleMLUKernel : public framework::OpKernel<T> {
           ctx.AllocateTmpTensor<T, MLUDeviceContext>({1}, dev_ctx);
       MLUCnnlTensorDesc new_bias_desc(new_bias_tensor);
 
-      MLUCnnlOpTensorDesc mul_op_desc(
-          CNNL_OP_TENSOR_MUL,
-          ToCnnlDataType(framework::TransToProtoVarType(in->dtype())),
-          CNNL_NOT_PROPAGATE_NAN);
+      MLUCnnlOpTensorDesc mul_op_desc(CNNL_OP_TENSOR_MUL,
+                                      ToCnnlDataType(in->dtype()),
+                                      CNNL_NOT_PROPAGATE_NAN);
       MLUCnnl::OpTensor(
           ctx, mul_op_desc.get(), scale_desc.get(), GetBasePtr(&scale_tensor),
           bias_desc.get(), GetBasePtr(&bias_tensor), new_bias_desc.get(),
-          GetBasePtr(&new_bias_tensor),
-          ToCnnlDataType(framework::TransToProtoVarType(in->dtype())));
+          GetBasePtr(&new_bias_tensor), ToCnnlDataType(in->dtype()));
       MLUCnnl::Scale(ctx, axis, input_desc.get(), GetBasePtr(in),
                      scale_desc.get(), GetBasePtr(&scale_tensor),
                      new_bias_desc.get(), GetBasePtr(&new_bias_tensor),
diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op_mlu.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op_mlu.cc
index 0f14e6dabd..a51f68530c 100644
--- a/paddle/fluid/operators/softmax_with_cross_entropy_op_mlu.cc
+++ b/paddle/fluid/operators/softmax_with_cross_entropy_op_mlu.cc
@@ -87,7 +87,7 @@ class SoftmaxWithCrossEntropyMLUKernel : public framework::OpKernel<T> {
                       platform::errors::InvalidArgument(
                           "If soft_label=False, axis must be -1 or"
                           " can be regard as last dimention in mlu kernel."));
-    framework::Tensor labels_int32(VT::INT32);
+    framework::Tensor labels_int32(framework::TransToPtenDataType(VT::INT32));
     labels_int32.Resize(labels->dims());
     labels_int32.mutable_data<int32_t>(ctx.GetPlace());
 
diff --git a/paddle/fluid/operators/split_op_mlu.cc b/paddle/fluid/operators/split_op_mlu.cc
index cf913ffe94..adc3ea14e3 100644
--- a/paddle/fluid/operators/split_op_mlu.cc
+++ b/paddle/fluid/operators/split_op_mlu.cc
@@ -61,15 +61,13 @@ class SplitMLUKernel : public framework::OpKernel<T> {
     for (size_t i = 0; i < outs.size(); i++) {
       outs[i]->mutable_data<T>(ctx.GetPlace());
       output_descs.emplace_back(MLUCnnlTensorDesc(
-          *outs[i], CNNL_LAYOUT_ARRAY,
-          ToCnnlDataType(framework::TransToProtoVarType(outs[i]->dtype()))));
+          *outs[i], CNNL_LAYOUT_ARRAY, ToCnnlDataType(outs[i]->dtype())));
       desc_vector.push_back(output_descs.back().get());
       vct_tensor.push_back(GetBasePtr(outs[i]));
     }
     // init in tensors
-    MLUCnnlTensorDesc input_desc(
-        *in, CNNL_LAYOUT_ARRAY,
-        ToCnnlDataType(framework::TransToProtoVarType(in->dtype())));
+    MLUCnnlTensorDesc input_desc(*in, CNNL_LAYOUT_ARRAY,
+                                 ToCnnlDataType(in->dtype()));
 
     // MLU should do sth
     MLUCnnl::Split(ctx, num_tensor, axis, input_desc.get(), GetBasePtr(in),
diff --git a/paddle/fluid/operators/sum_op_mlu.cc b/paddle/fluid/operators/sum_op_mlu.cc
index b84ac3c798..179c038e83 100644
--- a/paddle/fluid/operators/sum_op_mlu.cc
+++ b/paddle/fluid/operators/sum_op_mlu.cc
@@ -43,15 +43,13 @@ class SumMLUKernel : public framework::OpKernel<T> {
     std::vector<cnnlTensorDescriptor_t> desc_vector;
     for (int i = 0; i < ins_size; i++) {
       input_descs.emplace_back(MLUCnnlTensorDesc(
-          *ins[i], CNNL_LAYOUT_ARRAY,
-          ToCnnlDataType(framework::TransToProtoVarType(ins[i]->dtype()))));
+          *ins[i], CNNL_LAYOUT_ARRAY, ToCnnlDataType(ins[i]->dtype())));
       desc_vector.push_back(input_descs.back().get());
       inputs.push_back(GetBasePtr(ins[i]));
     }
     // init out tensors
-    MLUCnnlTensorDesc output_desc(
-        *out, CNNL_LAYOUT_ARRAY,
-        ToCnnlDataType(framework::TransToProtoVarType(out->dtype())));
+    MLUCnnlTensorDesc output_desc(*out, CNNL_LAYOUT_ARRAY,
+                                  ToCnnlDataType(out->dtype()));
     uint32_t ins_size_t = static_cast<uint32_t>(ins_size);
     MLUCnnl::AddN(ctx, ins_size_t, desc_vector.data(), inputs.data(),
                   output_desc.get(), GetBasePtr(out));
diff --git a/paddle/fluid/operators/top_k_op_mlu.cc b/paddle/fluid/operators/top_k_op_mlu.cc
index e5064ed90d..a9f835f6fe 100644
--- a/paddle/fluid/operators/top_k_op_mlu.cc
+++ b/paddle/fluid/operators/top_k_op_mlu.cc
@@ -47,7 +47,7 @@ class TopkMLUKernel : public framework::OpKernel<T> {
     const bool sorted = true;
     const int axis = -1;
     // cnnl only support int32/int16 type of indices
-    framework::Tensor indices_int32(VT::INT32);
+    framework::Tensor indices_int32(framework::TransToPtenDataType(VT::INT32));
     indices_int32.Resize(indices->dims());
     indices_int32.mutable_data<int32_t>(place);
 
diff --git a/paddle/fluid/operators/top_k_v2_op_mlu.cc b/paddle/fluid/operators/top_k_v2_op_mlu.cc
index cc05e11495..7bada0179a 100644
--- a/paddle/fluid/operators/top_k_v2_op_mlu.cc
+++ b/paddle/fluid/operators/top_k_v2_op_mlu.cc
@@ -55,7 +55,7 @@ class TopkV2MLUKernel : public framework::OpKernel<T> {
       indices->mutable_data<int64_t>(place);
 
     // cnnl only support int32/int16 type of indices
-    framework::Tensor indices_int32(VT::INT32);
+    framework::Tensor indices_int32(framework::TransToPtenDataType(VT::INT32));
     indices_int32.Resize(indices->dims());
     indices_int32.mutable_data<int32_t>(place);
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_matmul_v2_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_matmul_v2_op_mlu.py
new file mode 100644
index 0000000000..011769c29d
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/test_matmul_v2_op_mlu.py
@@ -0,0 +1,378 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import unittest
+import sys
+sys.path.append("..")
+from op_test import OpTest
+import paddle
+import paddle.fluid as fluid
+
+from test_matmul_v2_op import reference_matmul
+
+paddle.enable_static()
+SEED = 2022
+
+
+class TestMatMulV2Op(OpTest):
+    """
+    case 1
+    """
+
+    def set_mlu(self):
+        self.__class__.use_mlu = True
+        self.place = paddle.device.MLUPlace(0)
+
+    def config(self):
+        self.x_shape = (100, )
+        self.y_shape = (100, )
+        self.trans_x = False
+        self.trans_y = False
+
+    def init_kernel_type(self):
+        self.dtype = "float32"
+
+    def setUp(self):
+        self.set_mlu()
+        self.init_kernel_type()
+        self.config()
+        self.op_type = "matmul_v2"
+        x = np.random.random(self.x_shape).astype(self.dtype)
+        y = np.random.random(self.y_shape).astype(self.dtype)
+        # -0.1 ~ 0.1
+        x = -0.1 + 0.2 * x
+        y = -0.1 + 0.2 * y
+        result = reference_matmul(x, y, self.trans_x, self.trans_y)
+        result = result.astype(self.dtype)
+        self.inputs = {
+            'X': x,
+            'Y': y,
+        }
+        self.attrs = {'trans_x': self.trans_x, 'trans_y': self.trans_y}
+        self.outputs = {'Out': result}
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, atol=1e-7)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(self.place, ['X', 'Y'], 'Out')
+
+
+class TestMatMuklOp2(TestMatMulV2Op):
+    """
+    case 2
+    """
+
+    def config(self):
+        self.x_shape = (100, )
+        self.y_shape = (1, 3, 2, 100)
+        self.trans_x = False
+        self.trans_y = True
+
+
+class TestMatMuklOp3(TestMatMulV2Op):
+    """
+    case 3
+    """
+
+    def config(self):
+        self.x_shape = (100, )
+        self.y_shape = (1, 1, 100, 2)
+        self.trans_x = False
+        self.trans_y = False
+
+
+class TestMatMuklOp4(TestMatMulV2Op):
+    """
+    case 4
+    """
+
+    def config(self):
+        self.x_shape = (100, )
+        self.y_shape = (1, 2, 100, 2)
+        self.trans_x = False
+        self.trans_y = False
+
+
+class TestMatMuklOp5(TestMatMulV2Op):
+    """
+    case 5
+    """
+
+    def config(self):
+        self.x_shape = (1, 1, 100, 1)
+        self.y_shape = (100, )
+        self.trans_x = True
+        self.trans_y = False
+
+
+class TestMatMuklOp6(TestMatMulV2Op):
+    """
+    case 6
+    """
+
+    def config(self):
+        self.x_shape = (1, 2, 102, 1)
+        self.y_shape = (102, )
+        self.trans_x = True
+        self.trans_y = False
+
+
+class TestMatMuklOp7(TestMatMulV2Op):
+    """
+    case 7
+    """
+
+    def config(self):
+        self.x_shape = (1, 2, 1, 100)
+        self.y_shape = (100, )
+        self.trans_x = False
+        self.trans_y = False
+
+
+class TestMatMuklOp8(TestMatMulV2Op):
+    """
+    case 8
+    """
+
+    def config(self):
+        self.x_shape = (1, 1, 2, 100)
+        self.y_shape = (1, 1, 100, 2)
+        self.trans_x = False
+        self.trans_y = False
+
+
+class TestMatMuklOp9(TestMatMulV2Op):
+    """
+    case 9
+    """
+
+    def config(self):
+        self.x_shape = (1, 1, 1, 100)
+        self.y_shape = (2, 1, 2, 100)
+        self.trans_x = False
+        self.trans_y = True
+
+
+class TestMatMuklOp10(TestMatMulV2Op):
+    """
+    case 10
+    """
+
+    def config(self):
+        self.x_shape = (1, 1, 25, 4)
+        self.y_shape = (1, 2, 4, 25)
+        self.trans_x = False
+        self.trans_y = False
+
+
+class TestMatMuklOp11(TestMatMulV2Op):
+    """
+    case 11
+    """
+
+    def config(self):
+        self.x_shape = (2, 1, 2, 100)
+        self.y_shape = (1, 1, 100, 2)
+        self.trans_x = False
+        self.trans_y = False
+
+
+class TestMatMuklOp12(TestMatMulV2Op):
+    """
+    case 12
+    """
+
+    def config(self):
+        self.x_shape = (2, 1, 4, 25)
+        self.y_shape = (1, 1, 4, 25)
+        self.trans_x = True
+        self.trans_y = False
+
+
+class TestMatMuklOp13(TestMatMulV2Op):
+    """
+    case 13
+    """
+
+    def config(self):
+        self.x_shape = (2, 2, 10, 10)
+        self.y_shape = (2, 2, 10, 10)
+        self.trans_x = True
+        self.trans_y = False
+
+
+class TestMatMuklOp14(TestMatMulV2Op):
+    """
+    case 14_1
+    """
+
+    def config(self):
+        self.x_shape = (3, 1, 6, 6)
+        self.y_shape = (1, 2, 6, 9)
+        self.trans_x = True
+        self.trans_y = False
+
+
+class TestMatMuklOp15(TestMatMulV2Op):
+    """
+    case 14_2
+    """
+
+    def config(self):
+        self.x_shape = (3, 1, 6, 6)
+        self.y_shape = (1, 2, 6, 9)
+        self.trans_x = False
+        self.trans_y = False
+
+
+class TestMatMuklOp16(TestMatMulV2Op):
+    """
+    case 16 : to check the gradient for special case
+    """
+
+    def config(self):
+        self.x_shape = (100)
+        self.y_shape = (1, 2, 2, 100, 2)
+        self.trans_x = False
+        self.trans_y = False
+
+
+class TestMatMuklOp17(TestMatMulV2Op):
+    """
+    case 17 : to check the gradient for special case
+    """
+
+    def config(self):
+        self.x_shape = (2, 1, 100)
+        self.y_shape = (100)
+        self.trans_x = False
+        self.trans_y = False
+
+
+class TestMatMuklOpBroadcast1(TestMatMulV2Op):
+    """
+    case 14_3
+    """
+
+    def config(self):
+        self.x_shape = (3, 1, 10, 10)
+        self.y_shape = (1, 2, 10, 10)
+        self.trans_x = True
+        self.trans_y = True
+
+
+class TestMatMuklOpBroadcast2(TestMatMulV2Op):
+    """
+    case 14_4
+    """
+
+    def config(self):
+        self.x_shape = (3, 1, 10, 10)
+        self.y_shape = (1, 2, 10, 10)
+        self.trans_x = False
+        self.trans_y = True
+
+
+#--------------------test matmul fp16--------------------
+
+
+def create_test_fp16_class(parent, atol=0.001, max_relative_error=2.5):
+    class TestMatMulOpFp16Case(parent):
+        def init_kernel_type(self):
+            self.dtype = np.float16
+
+        def test_check_output(self):
+            self.check_output_with_place(self.place, atol=atol)
+
+        def test_check_grad(self):
+            self.check_grad_with_place(
+                self.place, ['X', 'Y'],
+                'Out',
+                max_relative_error=max_relative_error)
+
+    cls_name = "{0}_{1}".format(parent.__name__, "Fp16")
+    TestMatMulOpFp16Case.__name__ = cls_name
+    globals()[cls_name] = TestMatMulOpFp16Case
+
+
+create_test_fp16_class(TestMatMulV2Op)
+create_test_fp16_class(TestMatMuklOp2)
+create_test_fp16_class(TestMatMuklOp3)
+create_test_fp16_class(TestMatMuklOp4)
+create_test_fp16_class(TestMatMuklOp5)
+create_test_fp16_class(TestMatMuklOp6)
+create_test_fp16_class(TestMatMuklOp7)
+create_test_fp16_class(TestMatMuklOp8)
+create_test_fp16_class(TestMatMuklOp9)
+create_test_fp16_class(TestMatMuklOp10)
+create_test_fp16_class(TestMatMuklOp11)
+create_test_fp16_class(TestMatMuklOp12)
+create_test_fp16_class(TestMatMuklOp13)
+create_test_fp16_class(TestMatMuklOp14)
+create_test_fp16_class(TestMatMuklOp15)
+create_test_fp16_class(TestMatMuklOp16)
+create_test_fp16_class(TestMatMuklOp17)
+
+
+class TestMatMulV2API(unittest.TestCase):
+    def setUp(self):
+        self.places = [paddle.CPUPlace()]
+        if paddle.is_compiled_with_mlu():
+            self.places.append(paddle.device.MLUPlace(0))
+
+    def check_static_result(self, place):
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            input_x = fluid.data(name="input_x", shape=[4, 3], dtype="float32")
+            input_y = fluid.data(name="input_y", shape=[3, 4], dtype="float32")
+
+            result = paddle.matmul(input_x, input_y)
+
+            x_np = np.random.random([4, 3]).astype("float32")
+            y_np = np.random.random([3, 4]).astype("float32")
+
+            exe = fluid.Executor(place)
+            fetches = exe.run(fluid.default_main_program(),
+                              feed={"input_x": x_np,
+                                    "input_y": y_np},
+                              fetch_list=[result])
+
+    def test_static(self):
+        for place in self.places:
+            self.check_static_result(place=place)
+
+    def test_dygraph(self):
+        for place in self.places:
+            with fluid.dygraph.guard(place):
+                input_x = np.random.random([4, 3]).astype("float32")
+                input_y = np.random.random([3, 4]).astype("float32")
+                x = paddle.to_tensor(input_x)
+                y = paddle.to_tensor(input_y)
+                result = paddle.matmul(x, y)
+
+    def test_dygraph_fp16(self):
+        if paddle.is_compiled_with_mlu():
+            place = paddle.device.MLUPlace(0)
+            with fluid.dygraph.guard(place):
+                input_x = np.random.random([4, 3]).astype("float16")
+                input_y = np.random.random([3, 4]).astype("float16")
+                x = paddle.to_tensor(input_x)
+                y = paddle.to_tensor(input_y)
+                result = paddle.matmul(x, y)
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab
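
The test file above imports reference_matmul from the CPU-side test_matmul_v2_op but never shows it. For readers following along without that file, here is a minimal numpy sketch of what the helper presumably computes: promote 1-D operands the way np.matmul already does, swap the last two axes when a transpose flag is set, then multiply. The name reference_matmul_sketch and this exact 1-D handling are illustrative assumptions, not the real helper.

import numpy as np

def reference_matmul_sketch(x, y, trans_x=False, trans_y=False):
    # A 1-D operand is left alone: np.matmul itself promotes it to a row
    # (for x) or column (for y) vector and squeezes the result, which is
    # why the transpose flags are only applied to operands with ndim >= 2.
    if trans_x and x.ndim >= 2:
        x = np.swapaxes(x, -1, -2)  # transpose the last two axes of X
    if trans_y and y.ndim >= 2:
        y = np.swapaxes(y, -1, -2)  # transpose the last two axes of Y
    return np.matmul(x, y)

# e.g. case 5 above: X:(1, 1, 100, 1) with trans_x=True against Y:(100,)
out = reference_matmul_sketch(np.ones((1, 1, 100, 1)), np.ones(100), trans_x=True)
assert out.shape == (1, 1, 1)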
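On the C++ side, the one genuinely broadcast-specific piece of the gradient kernel is ReduceDims: cnnl's BatchMatmul produces dx/dy in the fully broadcast batch shape, so the kernel sums (CNNL_REDUCE_ADD) over every leading axis introduced by rank promotion and every axis where the broadcast shape exceeds the operand's own shape. The same axis rule can be sanity-checked in numpy; this is a sketch, and reduce_to_shape is a hypothetical name, not part of the patch.

import numpy as np

def reduce_to_shape(grad, orig_dims, bcast_dims):
    # Mirror of ReduceDims' axis selection: axes i < diff (rank promotion)
    # plus axes where bcast_dims[i] > orig_dims[i - diff] (broadcasting).
    diff = len(bcast_dims) - len(orig_dims)
    axes = tuple(i for i in range(len(bcast_dims))
                 if i < diff or bcast_dims[i] > orig_dims[i - diff])
    return grad.sum(axis=axes).reshape(orig_dims)

# dX for case 14_1 above: X:(3, 1, 6, 6) broadcast against Y:(1, 2, 6, 9)
# gives an intermediate dx with batch shape (3, 2); axis 1 is summed away.
dx_full = np.ones((3, 2, 6, 6))
dx = reduce_to_shape(dx_full, (3, 1, 6, 6), (3, 2, 6, 6))
assert dx.shape == (3, 1, 6, 6)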