未验证 提交 229ec32a 编写于 作者: Q qipengh 提交者: GitHub

[MLU]add matmul and matmul_v2 op (#39539)

* [MLU]add matmul and matmul_v2 op

* [MLU] fix data_type and del matmul

* [MLU] fix compile error

* [MLU] fix ci_check error
上级 50fb57c9
...@@ -383,9 +383,9 @@ void TensorAdd(const VarType& src, VarType* dst) { ...@@ -383,9 +383,9 @@ void TensorAdd(const VarType& src, VarType* dst) {
operators::MLUCnnlTensorDesc src_tensor_desc(src_tensor); operators::MLUCnnlTensorDesc src_tensor_desc(src_tensor);
operators::MLUCnnlTensorDesc dst_tensor_desc(*dst_tensor); operators::MLUCnnlTensorDesc dst_tensor_desc(*dst_tensor);
PADDLE_ENFORCE_MLU_SUCCESS(cnnlAssignAdd( PADDLE_ENFORCE_MLU_SUCCESS(cnnlAssignAdd(
dev_ctx->cnnl_handle(), static_cast<void*>(&alpha), dev_ctx->cnnl_handle(), static_cast<const void*>(&alpha),
src_tensor_desc.get(), operators::GetBasePtr(&src_tensor), nullptr, 0, src_tensor_desc.get(), operators::GetBasePtr(&src_tensor), nullptr, 0,
static_cast<void*>(&beta), dst_tensor_desc.get(), static_cast<const void*>(&beta), dst_tensor_desc.get(),
operators::GetBasePtr(dst_tensor))); operators::GetBasePtr(dst_tensor)));
return; return;
} }
......
...@@ -38,12 +38,10 @@ class ActivationMLUKernel : public framework::OpKernel<T> { ...@@ -38,12 +38,10 @@ class ActivationMLUKernel : public framework::OpKernel<T> {
output->mutable_data<T>(ctx.GetPlace()); output->mutable_data<T>(ctx.GetPlace());
MLUCnnlActivationDesc act_desc(act_mode, alpha); MLUCnnlActivationDesc act_desc(act_mode, alpha);
MLUCnnlTensorDesc input_desc( MLUCnnlTensorDesc input_desc(*input, CNNL_LAYOUT_ARRAY,
*input, CNNL_LAYOUT_ARRAY, ToCnnlDataType(input->dtype()));
ToCnnlDataType(framework::TransToProtoVarType(input->dtype()))); MLUCnnlTensorDesc output_desc(*output, CNNL_LAYOUT_ARRAY,
MLUCnnlTensorDesc output_desc( ToCnnlDataType(output->dtype()));
*output, CNNL_LAYOUT_ARRAY,
ToCnnlDataType(framework::TransToProtoVarType(output->dtype())));
MLUCnnl::Active(ctx, act_desc.get(), input_desc.get(), MLUCnnl::Active(ctx, act_desc.get(), input_desc.get(),
reinterpret_cast<const void*>(input->data<T>()), reinterpret_cast<const void*>(input->data<T>()),
...@@ -63,15 +61,12 @@ class ActivationGradMLUKernel : public framework::OpKernel<T> { ...@@ -63,15 +61,12 @@ class ActivationGradMLUKernel : public framework::OpKernel<T> {
dx->mutable_data<T>(ctx.GetPlace()); dx->mutable_data<T>(ctx.GetPlace());
MLUCnnlTensorDesc dout_desc( MLUCnnlTensorDesc dout_desc(*dout, CNNL_LAYOUT_ARRAY,
*dout, CNNL_LAYOUT_ARRAY, ToCnnlDataType(dout->dtype()));
ToCnnlDataType(framework::TransToProtoVarType(dout->dtype()))); MLUCnnlTensorDesc out_desc(*out, CNNL_LAYOUT_ARRAY,
MLUCnnlTensorDesc out_desc( ToCnnlDataType(out->dtype()));
*out, CNNL_LAYOUT_ARRAY, MLUCnnlTensorDesc dx_desc(*dx, CNNL_LAYOUT_ARRAY,
ToCnnlDataType(framework::TransToProtoVarType(out->dtype()))); ToCnnlDataType(dx->dtype()));
MLUCnnlTensorDesc dx_desc(
*dx, CNNL_LAYOUT_ARRAY,
ToCnnlDataType(framework::TransToProtoVarType(dx->dtype())));
MLUCnnlActivationDesc act_desc(act_mode, alpha); MLUCnnlActivationDesc act_desc(act_mode, alpha);
MLUCnnl::ActiveGrad( MLUCnnl::ActiveGrad(
ctx, act_desc.get(), nullptr, nullptr, nullptr, nullptr, ctx, act_desc.get(), nullptr, nullptr, nullptr, nullptr,
......
...@@ -61,15 +61,13 @@ class ConcatMLUKernel : public framework::OpKernel<T> { ...@@ -61,15 +61,13 @@ class ConcatMLUKernel : public framework::OpKernel<T> {
std::vector<cnnlTensorDescriptor_t> desc_vector; std::vector<cnnlTensorDescriptor_t> desc_vector;
for (size_t i = 0; i < ins_size; i++) { for (size_t i = 0; i < ins_size; i++) {
input_descs.emplace_back(MLUCnnlTensorDesc( input_descs.emplace_back(MLUCnnlTensorDesc(
*ins[i], CNNL_LAYOUT_ARRAY, *ins[i], CNNL_LAYOUT_ARRAY, ToCnnlDataType(ins[i]->dtype())));
ToCnnlDataType(framework::TransToProtoVarType(ins[i]->dtype()))));
desc_vector.push_back(input_descs.back().get()); desc_vector.push_back(input_descs.back().get());
inputs.push_back(GetBasePtr(ins[i])); inputs.push_back(GetBasePtr(ins[i]));
} }
// init out tensors // init out tensors
MLUCnnlTensorDesc output_desc( MLUCnnlTensorDesc output_desc(*out, CNNL_LAYOUT_ARRAY,
*out, CNNL_LAYOUT_ARRAY, ToCnnlDataType(out->dtype()));
ToCnnlDataType(framework::TransToProtoVarType(out->dtype())));
// MLU should do sth // MLU should do sth
MLUCnnl::Concat(ctx, ins_size_t, axis_t, desc_vector.data(), inputs.data(), MLUCnnl::Concat(ctx, ins_size_t, axis_t, desc_vector.data(), inputs.data(),
......
...@@ -80,14 +80,12 @@ class MLUConvOpKernel : public framework::OpKernel<T> { ...@@ -80,14 +80,12 @@ class MLUConvOpKernel : public framework::OpKernel<T> {
true /*need_reshape_or_alloc*/); true /*need_reshape_or_alloc*/);
cnnlTensorLayout_t data_layout = CNNL_LAYOUT_NHWC; cnnlTensorLayout_t data_layout = CNNL_LAYOUT_NHWC;
MLUCnnlTensorDesc input_desc( MLUCnnlTensorDesc input_desc(input_tensor, data_layout,
input_tensor, data_layout, ToCnnlDataType(input_tensor.dtype()));
ToCnnlDataType(framework::TransToProtoVarType(input_tensor.dtype())));
MLUCnnlTensorDesc filter_desc(trans_filter, data_layout, MLUCnnlTensorDesc filter_desc(trans_filter, data_layout,
ToCnnlDataType(trans_filter.type())); ToCnnlDataType(trans_filter.type()));
MLUCnnlTensorDesc output_desc( MLUCnnlTensorDesc output_desc(output_tensor, data_layout,
output_tensor, data_layout, ToCnnlDataType(output_tensor.dtype()));
ToCnnlDataType(framework::TransToProtoVarType(output_tensor.dtype())));
MLUCnnlConvolutionDesc conv_desc(in_dims_size, paddings.data(), MLUCnnlConvolutionDesc conv_desc(in_dims_size, paddings.data(),
strides.data(), dilations.data(), groups, strides.data(), dilations.data(), groups,
......
...@@ -72,9 +72,8 @@ class FillConstantMLUKernel : public framework::OpKernel<T> { ...@@ -72,9 +72,8 @@ class FillConstantMLUKernel : public framework::OpKernel<T> {
auto shape = GetShape(ctx); auto shape = GetShape(ctx);
out_var->mutable_data<T>(shape, ctx.GetPlace()); out_var->mutable_data<T>(shape, ctx.GetPlace());
MLUCnnlTensorDesc output_desc( MLUCnnlTensorDesc output_desc(*out_var, CNNL_LAYOUT_ARRAY,
*out_var, CNNL_LAYOUT_ARRAY, ToCnnlDataType(out_var->dtype()));
ToCnnlDataType(framework::TransToProtoVarType(out_var->dtype())));
MLUCnnl::Fill(ctx, value, output_desc.get(), GetBasePtr(out_var)); MLUCnnl::Fill(ctx, value, output_desc.get(), GetBasePtr(out_var));
} }
}; };
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/matmul_v2_op.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
static void Mul(const framework::ExecutionContext& ctx, const Tensor& X,
const Tensor& Y, Tensor* Out) {
Out->mutable_data<T>(ctx.GetPlace());
MLUCnnlTensorDesc x_desc(X, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
MLUCnnlTensorDesc y_desc(Y, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
MLUCnnlTensorDesc out_desc(*Out, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
MLUCnnlOpTensorDesc mul_op_desc(CNNL_OP_TENSOR_MUL, ToCnnlDataType<T>(),
CNNL_NOT_PROPAGATE_NAN);
MLUCnnl::OpTensor(ctx, mul_op_desc.get(), x_desc.get(), GetBasePtr(&X),
y_desc.get(), GetBasePtr(&Y), out_desc.get(),
GetBasePtr(Out), ToCnnlDataType<T>());
}
template <typename T>
static void MatMul2D(const framework::ExecutionContext& ctx, const Tensor& X,
const Tensor& Y, Tensor* Out, const bool trans_x,
const bool trans_y) {
Out->mutable_data<T>(ctx.GetPlace());
MLUCnnlTensorDesc x_desc(X, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
MLUCnnlTensorDesc y_desc(Y, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
MLUCnnlTensorDesc out_desc(*Out, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
MLUCnnl::Matmul(ctx, trans_x, trans_y, x_desc.get(), GetBasePtr(&X),
y_desc.get(), GetBasePtr(&Y), out_desc.get(),
GetBasePtr(Out));
}
template <typename T>
static void MatMulND(const framework::ExecutionContext& ctx, const Tensor& X,
const Tensor& Y, Tensor* Out, const bool trans_x,
const bool trans_y) {
if (!Out->initialized()) {
Out->mutable_data<T>(ctx.GetPlace());
}
MLUCnnlTensorDesc x_desc(X, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
MLUCnnlTensorDesc y_desc(Y, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
MLUCnnlTensorDesc out_desc(*Out, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
MLUCnnl::BatchMatmul(ctx, trans_x, trans_y, x_desc.get(), GetBasePtr(&X),
y_desc.get(), GetBasePtr(&Y), out_desc.get(),
GetBasePtr(Out));
}
template <typename T>
static void ReduceDims(const framework::ExecutionContext& ctx,
const std::vector<int64_t>& dims,
const std::vector<int64_t>& bcast_dims, const Tensor& in,
Tensor* out) {
std::vector<int64_t> axes;
int64_t size = bcast_dims.size();
int64_t diff = bcast_dims.size() - dims.size();
for (int64_t i = 0; i < size; ++i) {
if (i < diff) {
axes.push_back(i);
continue;
}
if (bcast_dims[i] > dims[i - diff]) {
axes.push_back(i);
}
}
out->mutable_data<T>(ctx.GetPlace());
MLUCnnlTensorDesc in_desc(in, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
MLUCnnlTensorDesc out_desc(*out, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
std::vector<int> reduce_dims(axes.begin(), axes.end());
MLUCnnlReduceDesc reduce_desc(reduce_dims, CNNL_REDUCE_ADD,
ToCnnlDataType<T>(), CNNL_NOT_PROPAGATE_NAN,
CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES);
MLUCnnl::Reduce(ctx, true /*need_workspace*/, reduce_desc.get(), nullptr,
in_desc.get(), GetBasePtr(&in), 0 /*indices_size*/, nullptr,
nullptr, out_desc.get(), GetBasePtr(out));
}
template <typename T>
class MatMulV2MLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* X = ctx.Input<framework::Tensor>("X");
auto* Y = ctx.Input<framework::Tensor>("Y");
auto* Out = ctx.Output<framework::Tensor>("Out");
const bool trans_x = ctx.Attr<bool>("trans_x");
const bool trans_y = ctx.Attr<bool>("trans_y");
std::vector<int64_t> x_dims = framework::vectorize(X->dims());
std::vector<int64_t> y_dims = framework::vectorize(Y->dims());
std::vector<int64_t> out_dims = framework::vectorize(Out->dims());
int x_ndim = x_dims.size();
int y_ndim = y_dims.size();
// Case 1: [K] x [K] = [1]
// Equal: [1, K] x [K, 1] = [1, 1] => [1]
const bool all_one_dim = (x_ndim == 1 && y_ndim == 1);
if (all_one_dim) {
Out->Resize({1, 1});
}
// Resize dim 1 to 2
Tensor x_temp, y_temp;
x_temp.ShareDataWith(*X);
y_temp.ShareDataWith(*Y);
if (x_ndim == 1) {
x_dims.insert(x_dims.begin(), 1);
x_temp.Resize(framework::make_ddim(x_dims));
x_ndim = 2;
// matmul op of mlu needs `std::max(x->dim, y->dim) == out->dim`
if (out_dims.size() < y_dims.size()) {
std::vector<int64_t> temp_out_dims(out_dims.begin(), out_dims.end());
temp_out_dims.insert(temp_out_dims.end() - 1, 1);
Out->Resize(framework::make_ddim(temp_out_dims));
}
}
if (y_ndim == 1) {
y_dims.push_back(1);
y_temp.Resize(framework::make_ddim(y_dims));
y_ndim = 2;
// matmul op of mlu needs `std::max(x->dim, y->dim) == out->dim`
if (out_dims.size() < x_dims.size()) {
std::vector<int64_t> temp_out_dims(out_dims.begin(), out_dims.end());
temp_out_dims.push_back(1);
Out->Resize(framework::make_ddim(temp_out_dims));
}
}
const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1];
if (trans_y) {
PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], K,
platform::errors::InvalidArgument(
"Input(Y) has error dim."
"Y'dims[%d] must be equal to %d"
"But received Y'dims[%d] is %d",
y_ndim - 1, K, y_ndim - 1, y_dims[y_ndim - 1]));
} else {
PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], K,
platform::errors::InvalidArgument(
"Input(Y) has error dim."
"Y'dims[%d] must be equal to %d"
"But received Y'dims[%d] is %d",
y_ndim - 2, K, y_ndim - 2, y_dims[y_ndim - 2]));
}
if (x_ndim == 2 && y_ndim == 2) {
// Case 2: [M, K] x [K, N] = [M, N]
MatMul2D<T>(ctx, x_temp, y_temp, Out, trans_x, trans_y);
} else {
// Case 3: [B, M, K] x [K, N] = [B, M, N]
// Case 4: [B, M, K] x [B, K, N] = [B, M, N]
MatMulND<T>(ctx, x_temp, y_temp, Out, trans_x, trans_y);
}
if (framework::vectorize(Out->dims()) != out_dims) {
Out->Resize(framework::make_ddim(out_dims));
}
}
};
template <typename T>
class MatMulGradV2MLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* X = ctx.Input<framework::Tensor>("X");
auto* Y = ctx.Input<framework::Tensor>("Y");
auto* dOut = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dX = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dY = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
const bool trans_x = ctx.Attr<bool>("trans_x");
const bool trans_y = ctx.Attr<bool>("trans_y");
std::vector<int64_t> x_dims = framework::vectorize(X->dims());
std::vector<int64_t> y_dims = framework::vectorize(Y->dims());
std::vector<int64_t> out_dims = framework::vectorize(dOut->dims());
int x_ndim = x_dims.size();
int y_ndim = y_dims.size();
int out_ndim = out_dims.size();
// Case 1: [K] x [K] = [1]
if (x_ndim == 1 && y_ndim == 1) {
if (dX) {
Mul<T>(ctx, *dOut, *Y, dX);
}
if (dY) {
Mul<T>(ctx, *dOut, *X, dY);
}
return;
}
// Resize dim 1 to 2
Tensor x_temp, y_temp, dout_temp;
x_temp.ShareDataWith(*X);
y_temp.ShareDataWith(*Y);
dout_temp.ShareDataWith(*dOut);
if (x_ndim == 1) {
x_dims.insert(x_dims.begin(), 1);
out_dims.insert(out_dims.end() - 1, 1);
x_temp.Resize(framework::make_ddim(x_dims));
dout_temp.Resize(framework::make_ddim(out_dims));
x_ndim = 2;
out_ndim += 1;
}
if (y_ndim == 1) {
y_dims.push_back(1);
out_dims.push_back(1);
y_temp.Resize(framework::make_ddim(y_dims));
dout_temp.Resize(framework::make_ddim(out_dims));
y_ndim = 2;
out_ndim += 1;
}
// Case 2: [M, K] x [K, N] = [M, N]
if (out_ndim == 2) {
if (dX) {
dX->Resize(framework::make_ddim(x_dims));
if (trans_x) {
MatMul2D<T>(ctx, y_temp, dout_temp, dX, trans_y, true);
} else {
MatMul2D<T>(ctx, dout_temp, y_temp, dX, false, !trans_y);
}
dX->Resize(X->dims());
}
if (dY) {
dY->Resize(framework::make_ddim(y_dims));
if (trans_y) {
MatMul2D<T>(ctx, dout_temp, x_temp, dY, true, trans_x);
} else {
MatMul2D<T>(ctx, x_temp, dout_temp, dY, !trans_x, false);
}
dY->Resize(Y->dims());
}
return;
}
// Case 3: [B, M, K] x [K, N] = [B, M, N]
// Case 4: [B, M, K] x [B, K, N] = [B, M, N]
std::vector<int64_t> x_bcast_dims(out_ndim, 1);
std::vector<int64_t> y_bcast_dims(out_ndim, 1);
std::copy(out_dims.begin(), out_dims.end() - 2, x_bcast_dims.begin());
std::copy(out_dims.begin(), out_dims.end() - 2, y_bcast_dims.begin());
std::copy(x_dims.end() - 2, x_dims.end(), x_bcast_dims.end() - 2);
std::copy(y_dims.end() - 2, y_dims.end(), y_bcast_dims.end() - 2);
if (dX) {
Tensor dx_temp(X->type());
if (x_dims != x_bcast_dims) {
dx_temp.Resize(framework::make_ddim(x_bcast_dims));
} else {
dX->mutable_data<T>(ctx.GetPlace());
dx_temp.ShareDataWith(*dX);
}
if (trans_x) {
MatMulND<T>(ctx, y_temp, dout_temp, &dx_temp, trans_y, true);
} else {
MatMulND<T>(ctx, dout_temp, y_temp, &dx_temp, false, !trans_y);
}
if (x_dims != x_bcast_dims) {
ReduceDims<T>(ctx, x_dims, x_bcast_dims, dx_temp, dX);
}
}
if (dY) {
Tensor dy_temp(Y->type());
if (y_dims != y_bcast_dims) {
dy_temp.Resize(framework::make_ddim(y_bcast_dims));
} else {
dY->mutable_data<T>(ctx.GetPlace());
dy_temp.ShareDataWith(*dY);
}
if (trans_y) {
MatMulND<T>(ctx, dout_temp, x_temp, &dy_temp, true, trans_x);
} else {
MatMulND<T>(ctx, x_temp, dout_temp, &dy_temp, !trans_x, false);
}
if (y_dims != y_bcast_dims) {
ReduceDims<T>(ctx, y_dims, y_bcast_dims, dy_temp, dY);
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_MLU_KERNEL(matmul_v2, ops::MatMulV2MLUKernel<float>,
ops::MatMulV2MLUKernel<plat::float16>);
REGISTER_OP_MLU_KERNEL(matmul_v2_grad, ops::MatMulGradV2MLUKernel<float>,
ops::MatMulGradV2MLUKernel<plat::float16>);
...@@ -45,12 +45,10 @@ class MeanMLUKernel : public framework::OpKernel<T> { ...@@ -45,12 +45,10 @@ class MeanMLUKernel : public framework::OpKernel<T> {
reduce_dims.push_back(i); reduce_dims.push_back(i);
} }
MLUCnnlTensorDesc input_desc( MLUCnnlTensorDesc input_desc(*input, CNNL_LAYOUT_ARRAY,
*input, CNNL_LAYOUT_ARRAY, ToCnnlDataType(input->dtype()));
ToCnnlDataType(framework::TransToProtoVarType(input->dtype()))); MLUCnnlTensorDesc output_desc(*output, CNNL_LAYOUT_ARRAY,
MLUCnnlTensorDesc output_desc( ToCnnlDataType(output->dtype()));
*output, CNNL_LAYOUT_ARRAY,
ToCnnlDataType(framework::TransToProtoVarType(output->dtype())));
MLUCnnlReduceDesc reduction_desc( MLUCnnlReduceDesc reduction_desc(
reduce_dims, CNNL_REDUCE_AVG, ToCnnlDataType<T>(), reduce_dims, CNNL_REDUCE_AVG, ToCnnlDataType<T>(),
...@@ -90,21 +88,18 @@ class MeanMLUGradKernel : public framework::OpKernel<T> { ...@@ -90,21 +88,18 @@ class MeanMLUGradKernel : public framework::OpKernel<T> {
} }
// means // means
Tensor mean_var(framework::TransToProtoVarType(output_grad->dtype())); Tensor mean_var(output_grad->dtype());
mean_var.mutable_data<T>(input_grad->dims(), context.GetPlace()); mean_var.mutable_data<T>(input_grad->dims(), context.GetPlace());
MLUCnnlTensorDesc mean_var_desc( MLUCnnlTensorDesc mean_var_desc(mean_var, CNNL_LAYOUT_ARRAY,
mean_var, CNNL_LAYOUT_ARRAY, ToCnnlDataType(mean_var.dtype()));
ToCnnlDataType(framework::TransToProtoVarType(mean_var.dtype())));
auto value = static_cast<T>(1.0 / static_cast<float>(input_grad->numel())); auto value = static_cast<T>(1.0 / static_cast<float>(input_grad->numel()));
MLUCnnl::Fill(context, value, mean_var_desc.get(), GetBasePtr(&mean_var)); MLUCnnl::Fill(context, value, mean_var_desc.get(), GetBasePtr(&mean_var));
// means mul output_grad // means mul output_grad
MLUCnnlTensorDesc in_desc( MLUCnnlTensorDesc in_desc(*output_grad, CNNL_LAYOUT_ARRAY,
*output_grad, CNNL_LAYOUT_ARRAY, ToCnnlDataType(output_grad->dtype()));
ToCnnlDataType(framework::TransToProtoVarType(output_grad->dtype()))); MLUCnnlTensorDesc out_desc(*input_grad, CNNL_LAYOUT_ARRAY,
MLUCnnlTensorDesc out_desc( ToCnnlDataType(input_grad->dtype()));
*input_grad, CNNL_LAYOUT_ARRAY,
ToCnnlDataType(framework::TransToProtoVarType(input_grad->dtype())));
MLUCnnlOpTensorDesc op_tensor_desc(CNNL_OP_TENSOR_MUL, ToCnnlDataType<T>(), MLUCnnlOpTensorDesc op_tensor_desc(CNNL_OP_TENSOR_MUL, ToCnnlDataType<T>(),
CNNL_NOT_PROPAGATE_NAN); CNNL_NOT_PROPAGATE_NAN);
......
...@@ -35,39 +35,40 @@ class AccuracyMLUKernel : public framework::OpKernel<T> { ...@@ -35,39 +35,40 @@ class AccuracyMLUKernel : public framework::OpKernel<T> {
} }
// cast `indices` or `label` if their type is not INT32 // cast `indices` or `label` if their type is not INT32
Tensor indices_int32(VT::INT32); Tensor indices_int32(framework::TransToPtenDataType(VT::INT32));
Tensor label_int32(VT::INT32); Tensor label_int32(framework::TransToPtenDataType(VT::INT32));
if (indices->type() != VT::INT32) { auto indices_type = framework::TransToProtoVarType(indices->type());
PADDLE_ENFORCE_EQ(MLUSupportsCast(indices->type(), VT::INT32), true, if (indices_type != VT::INT32) {
platform::errors::Unavailable( PADDLE_ENFORCE_EQ(MLUSupportsCast(indices_type, VT::INT32), true,
platform::errors::Unimplemented(
"In accuracy mlu kernel, cast indices from [%s] to " "In accuracy mlu kernel, cast indices from [%s] to "
"[%s] is not supported.", "[%s] is not supported.",
framework::DataTypeToString(indices->type()), framework::DataTypeToString(indices_type),
framework::DataTypeToString(VT::INT32))); framework::DataTypeToString(VT::INT32)));
indices_int32.Resize(indices->dims()); indices_int32.Resize(indices->dims());
indices_int32.mutable_data<int>(ctx.GetPlace()); indices_int32.mutable_data<int>(ctx.GetPlace());
MLUCnnlTensorDesc org_indices_desc(*indices); MLUCnnlTensorDesc org_indices_desc(*indices);
MLUCnnlTensorDesc indices_int32_desc(indices_int32); MLUCnnlTensorDesc indices_int32_desc(indices_int32);
cnnlCastDataType_t cast_type = cnnlCastDataType_t cast_type = GetCastDataType(indices_type, VT::INT32);
GetCastDataType(indices->type(), VT::INT32);
MLUCnnl::Cast(ctx, cast_type, org_indices_desc.get(), GetBasePtr(indices), MLUCnnl::Cast(ctx, cast_type, org_indices_desc.get(), GetBasePtr(indices),
indices_int32_desc.get(), GetBasePtr(&indices_int32)); indices_int32_desc.get(), GetBasePtr(&indices_int32));
} else { } else {
indices_int32.ShareDataWith(*indices); indices_int32.ShareDataWith(*indices);
} }
if (label->type() != VT::INT32) { auto label_type = framework::TransToProtoVarType(label->type());
if (label_type != VT::INT32) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
MLUSupportsCast(label->type(), VT::INT32), true, MLUSupportsCast(label_type, VT::INT32), true,
platform::errors::Unavailable( platform::errors::Unimplemented(
"In accuracy mlu kernel, cast label from [%s] to [%s] " "In accuracy mlu kernel, cast label from [%s] to [%s] "
"is not supported.", "is not supported.",
framework::DataTypeToString(label->type()), framework::DataTypeToString(label_type),
framework::DataTypeToString(VT::INT32))); framework::DataTypeToString(VT::INT32)));
label_int32.Resize(label->dims()); label_int32.Resize(label->dims());
label_int32.mutable_data<int>(ctx.GetPlace()); label_int32.mutable_data<int>(ctx.GetPlace());
MLUCnnlTensorDesc org_label_desc(*label); MLUCnnlTensorDesc org_label_desc(*label);
MLUCnnlTensorDesc label_int32_desc(label_int32); MLUCnnlTensorDesc label_int32_desc(label_int32);
cnnlCastDataType_t cast_type = GetCastDataType(label->type(), VT::INT32); cnnlCastDataType_t cast_type = GetCastDataType(label_type, VT::INT32);
MLUCnnl::Cast(ctx, cast_type, org_label_desc.get(), GetBasePtr(label), MLUCnnl::Cast(ctx, cast_type, org_label_desc.get(), GetBasePtr(label),
label_int32_desc.get(), GetBasePtr(&label_int32)); label_int32_desc.get(), GetBasePtr(&label_int32));
} else { } else {
...@@ -77,7 +78,7 @@ class AccuracyMLUKernel : public framework::OpKernel<T> { ...@@ -77,7 +78,7 @@ class AccuracyMLUKernel : public framework::OpKernel<T> {
// equal // equal
MLUCnnlTensorDesc indices_int32_desc(indices_int32); MLUCnnlTensorDesc indices_int32_desc(indices_int32);
MLUCnnlTensorDesc label_int32_desc(label_int32); MLUCnnlTensorDesc label_int32_desc(label_int32);
Tensor equal_tensor(VT::BOOL); Tensor equal_tensor(framework::TransToPtenDataType(VT::BOOL));
equal_tensor.Resize(indices->dims()); equal_tensor.Resize(indices->dims());
equal_tensor.mutable_data<bool>(ctx.GetPlace()); equal_tensor.mutable_data<bool>(ctx.GetPlace());
MLUCnnlTensorDesc equal_tensor_desc(equal_tensor); MLUCnnlTensorDesc equal_tensor_desc(equal_tensor);
...@@ -87,7 +88,7 @@ class AccuracyMLUKernel : public framework::OpKernel<T> { ...@@ -87,7 +88,7 @@ class AccuracyMLUKernel : public framework::OpKernel<T> {
GetBasePtr(&equal_tensor)); GetBasePtr(&equal_tensor));
// cast equal // cast equal
Tensor equal_fp32(VT::FP32); Tensor equal_fp32(framework::TransToPtenDataType(VT::FP32));
equal_fp32.Resize(indices->dims()); equal_fp32.Resize(indices->dims());
equal_fp32.mutable_data<float>(ctx.GetPlace()); equal_fp32.mutable_data<float>(ctx.GetPlace());
MLUCnnlTensorDesc equal_fp32_desc(equal_fp32); MLUCnnlTensorDesc equal_fp32_desc(equal_fp32);
...@@ -98,7 +99,7 @@ class AccuracyMLUKernel : public framework::OpKernel<T> { ...@@ -98,7 +99,7 @@ class AccuracyMLUKernel : public framework::OpKernel<T> {
// [correct] // [correct]
// reduce_max // reduce_max
Tensor correct_max(VT::FP32); Tensor correct_max(framework::TransToPtenDataType(VT::FP32));
correct_max.Resize(framework::make_ddim({num_samples})); correct_max.Resize(framework::make_ddim({num_samples}));
correct_max.mutable_data<float>(ctx.GetPlace()); correct_max.mutable_data<float>(ctx.GetPlace());
MLUCnnlTensorDesc correct_max_desc(correct_max); MLUCnnlTensorDesc correct_max_desc(correct_max);
...@@ -111,7 +112,7 @@ class AccuracyMLUKernel : public framework::OpKernel<T> { ...@@ -111,7 +112,7 @@ class AccuracyMLUKernel : public framework::OpKernel<T> {
correct_max_desc.get(), GetBasePtr(&correct_max)); correct_max_desc.get(), GetBasePtr(&correct_max));
// reduce_sum // reduce_sum
Tensor correct_sum(VT::FP32); Tensor correct_sum(framework::TransToPtenDataType(VT::FP32));
correct_sum.Resize(correct->dims()); correct_sum.Resize(correct->dims());
correct_sum.mutable_data<float>(ctx.GetPlace()); correct_sum.mutable_data<float>(ctx.GetPlace());
MLUCnnlTensorDesc correct_sum_desc(correct_sum); MLUCnnlTensorDesc correct_sum_desc(correct_sum);
...@@ -137,7 +138,7 @@ class AccuracyMLUKernel : public framework::OpKernel<T> { ...@@ -137,7 +138,7 @@ class AccuracyMLUKernel : public framework::OpKernel<T> {
MLUCnnl::Fill(ctx, num_samples, total_desc.get(), GetBasePtr(total)); MLUCnnl::Fill(ctx, num_samples, total_desc.get(), GetBasePtr(total));
// use `total` of type `float32` for calculating accuracy // use `total` of type `float32` for calculating accuracy
Tensor total_fp32(VT::FP32); Tensor total_fp32(framework::TransToPtenDataType(VT::FP32));
total_fp32.Resize(total->dims()); total_fp32.Resize(total->dims());
total_fp32.mutable_data<float>(ctx.GetPlace()); total_fp32.mutable_data<float>(ctx.GetPlace());
MLUCnnlTensorDesc total_fp32_desc(total_fp32); MLUCnnlTensorDesc total_fp32_desc(total_fp32);
......
...@@ -178,9 +178,8 @@ MLUCnnlTensorDesc::MLUCnnlTensorDesc(const Tensor& tensor, ...@@ -178,9 +178,8 @@ MLUCnnlTensorDesc::MLUCnnlTensorDesc(const Tensor& tensor,
} }
MLUCnnlTensorDesc::MLUCnnlTensorDesc(const Tensor& tensor) MLUCnnlTensorDesc::MLUCnnlTensorDesc(const Tensor& tensor)
: MLUCnnlTensorDesc( : MLUCnnlTensorDesc(tensor, CNNL_LAYOUT_ARRAY,
tensor, CNNL_LAYOUT_ARRAY, ToCnnlDataType(tensor.dtype())) {}
ToCnnlDataType(framework::TransToProtoVarType(tensor.dtype()))) {}
MLUCnnlTensorDesc::MLUCnnlTensorDesc(const Tensor& tensor, MLUCnnlTensorDesc::MLUCnnlTensorDesc(const Tensor& tensor,
cnnlTensorLayout_t layout, cnnlTensorLayout_t layout,
......
...@@ -49,44 +49,32 @@ inline const void* GetBasePtr(const Tensor* t) { return t->data(); } ...@@ -49,44 +49,32 @@ inline const void* GetBasePtr(const Tensor* t) { return t->data(); }
inline void* GetBasePtr(Tensor* t) { return t->data(); } inline void* GetBasePtr(Tensor* t) { return t->data(); }
template <typename T> inline cnnlDataType_t ToCnnlDataType(
inline cnnlDataType_t ToCnnlDataType(const T& t) { const paddle::experimental::DataType& dtype) {
auto type = framework::ToDataType(t);
return ToCnnlDataType(type);
}
template <typename T>
inline cnnlDataType_t ToCnnlDataType() {
auto type = framework::ToDataType(std::type_index(typeid(T)));
return ToCnnlDataType(type);
}
template <>
inline cnnlDataType_t ToCnnlDataType(const framework::proto::VarType::Type& t) {
cnnlDataType_t type = CNNL_DTYPE_FLOAT; cnnlDataType_t type = CNNL_DTYPE_FLOAT;
switch (t) { switch (dtype) {
case framework::proto::VarType::FP16: case DataType::FLOAT16:
type = CNNL_DTYPE_HALF; type = CNNL_DTYPE_HALF;
break; break;
case framework::proto::VarType::FP32: case DataType::FLOAT32:
type = CNNL_DTYPE_FLOAT; type = CNNL_DTYPE_FLOAT;
break; break;
case framework::proto::VarType::INT8: case DataType::INT8:
type = CNNL_DTYPE_INT8; type = CNNL_DTYPE_INT8;
break; break;
case framework::proto::VarType::INT16: case DataType::INT16:
type = CNNL_DTYPE_INT16; type = CNNL_DTYPE_INT16;
break; break;
case framework::proto::VarType::INT32: case DataType::INT32:
type = CNNL_DTYPE_INT32; type = CNNL_DTYPE_INT32;
break; break;
case framework::proto::VarType::INT64: case DataType::INT64:
type = CNNL_DTYPE_INT64; type = CNNL_DTYPE_INT64;
break; break;
case framework::proto::VarType::BOOL: case DataType::BOOL:
type = CNNL_DTYPE_BOOL; type = CNNL_DTYPE_BOOL;
break; break;
case framework::proto::VarType::UINT8: case DataType::UINT8:
type = CNNL_DTYPE_UINT8; type = CNNL_DTYPE_UINT8;
break; break;
default: default:
...@@ -95,6 +83,17 @@ inline cnnlDataType_t ToCnnlDataType(const framework::proto::VarType::Type& t) { ...@@ -95,6 +83,17 @@ inline cnnlDataType_t ToCnnlDataType(const framework::proto::VarType::Type& t) {
return type; return type;
} }
inline cnnlDataType_t ToCnnlDataType(
const paddle::framework::proto::VarType::Type& type) {
return ToCnnlDataType(framework::TransToPtenDataType(type));
}
template <typename T>
inline cnnlDataType_t ToCnnlDataType() {
auto type = framework::ToDataType(std::type_index(typeid(T)));
return ToCnnlDataType(type);
}
// Converts (via narrowing) a type T value to a type U, and checks that the // Converts (via narrowing) a type T value to a type U, and checks that the
// value has no value change due to the conversion. // value has no value change due to the conversion.
template <typename WideT, typename NarrowT> template <typename WideT, typename NarrowT>
......
...@@ -46,12 +46,10 @@ class ReduceMeanMLUKernel : public framework::OpKernel<T> { ...@@ -46,12 +46,10 @@ class ReduceMeanMLUKernel : public framework::OpKernel<T> {
} }
} }
MLUCnnlTensorDesc input_desc( MLUCnnlTensorDesc input_desc(*input, CNNL_LAYOUT_ARRAY,
*input, CNNL_LAYOUT_ARRAY, ToCnnlDataType(input->dtype()));
ToCnnlDataType(framework::TransToProtoVarType(input->dtype()))); MLUCnnlTensorDesc output_desc(*output, CNNL_LAYOUT_ARRAY,
MLUCnnlTensorDesc output_desc( ToCnnlDataType(output->dtype()));
*output, CNNL_LAYOUT_ARRAY,
ToCnnlDataType(framework::TransToProtoVarType(output->dtype())));
MLUCnnlReduceDesc reduction_desc( MLUCnnlReduceDesc reduction_desc(
reduce_dims, CNNL_REDUCE_AVG, ToCnnlDataType<T>(), reduce_dims, CNNL_REDUCE_AVG, ToCnnlDataType<T>(),
...@@ -91,8 +89,7 @@ class ReduceMeanGradMLUKernel : public framework::OpKernel<T> { ...@@ -91,8 +89,7 @@ class ReduceMeanGradMLUKernel : public framework::OpKernel<T> {
reduce_numel *= input_dims[d]; reduce_numel *= input_dims[d];
} }
Tensor tmp_output_grad( Tensor tmp_output_grad(output_grad->dtype());
framework::TransToProtoVarType(output_grad->dtype()));
auto tmp_output_dims = input_dims; auto tmp_output_dims = input_dims;
for (auto d : reduce_dims) { for (auto d : reduce_dims) {
tmp_output_dims[d] = 1; tmp_output_dims[d] = 1;
...@@ -100,13 +97,10 @@ class ReduceMeanGradMLUKernel : public framework::OpKernel<T> { ...@@ -100,13 +97,10 @@ class ReduceMeanGradMLUKernel : public framework::OpKernel<T> {
tmp_output_grad.ShareDataWith(*output_grad); tmp_output_grad.ShareDataWith(*output_grad);
tmp_output_grad.Resize(framework::make_ddim(tmp_output_dims)); tmp_output_grad.Resize(framework::make_ddim(tmp_output_dims));
MLUCnnlTensorDesc output_grad_desc( MLUCnnlTensorDesc output_grad_desc(tmp_output_grad, CNNL_LAYOUT_ARRAY,
tmp_output_grad, CNNL_LAYOUT_ARRAY, ToCnnlDataType(tmp_output_grad.dtype()));
ToCnnlDataType( MLUCnnlTensorDesc input_grad_desc(*input_grad, CNNL_LAYOUT_ARRAY,
framework::TransToProtoVarType(tmp_output_grad.dtype()))); ToCnnlDataType(input_grad->dtype()));
MLUCnnlTensorDesc input_grad_desc(
*input_grad, CNNL_LAYOUT_ARRAY,
ToCnnlDataType(framework::TransToProtoVarType(input_grad->dtype())));
auto value = static_cast<T>(1.0 / static_cast<float>(reduce_numel)); auto value = static_cast<T>(1.0 / static_cast<float>(reduce_numel));
MLUCnnl::Fill(context, value, input_grad_desc.get(), MLUCnnl::Fill(context, value, input_grad_desc.get(),
......
...@@ -85,15 +85,13 @@ class ScaleMLUKernel : public framework::OpKernel<T> { ...@@ -85,15 +85,13 @@ class ScaleMLUKernel : public framework::OpKernel<T> {
ctx.AllocateTmpTensor<T, MLUDeviceContext>({1}, dev_ctx); ctx.AllocateTmpTensor<T, MLUDeviceContext>({1}, dev_ctx);
MLUCnnlTensorDesc new_bias_desc(new_bias_tensor); MLUCnnlTensorDesc new_bias_desc(new_bias_tensor);
MLUCnnlOpTensorDesc mul_op_desc( MLUCnnlOpTensorDesc mul_op_desc(CNNL_OP_TENSOR_MUL,
CNNL_OP_TENSOR_MUL, ToCnnlDataType(in->dtype()),
ToCnnlDataType(framework::TransToProtoVarType(in->dtype())),
CNNL_NOT_PROPAGATE_NAN); CNNL_NOT_PROPAGATE_NAN);
MLUCnnl::OpTensor( MLUCnnl::OpTensor(
ctx, mul_op_desc.get(), scale_desc.get(), GetBasePtr(&scale_tensor), ctx, mul_op_desc.get(), scale_desc.get(), GetBasePtr(&scale_tensor),
bias_desc.get(), GetBasePtr(&bias_tensor), new_bias_desc.get(), bias_desc.get(), GetBasePtr(&bias_tensor), new_bias_desc.get(),
GetBasePtr(&new_bias_tensor), GetBasePtr(&new_bias_tensor), ToCnnlDataType(in->dtype()));
ToCnnlDataType(framework::TransToProtoVarType(in->dtype())));
MLUCnnl::Scale(ctx, axis, input_desc.get(), GetBasePtr(in), MLUCnnl::Scale(ctx, axis, input_desc.get(), GetBasePtr(in),
scale_desc.get(), GetBasePtr(&scale_tensor), scale_desc.get(), GetBasePtr(&scale_tensor),
new_bias_desc.get(), GetBasePtr(&new_bias_tensor), new_bias_desc.get(), GetBasePtr(&new_bias_tensor),
......
...@@ -87,7 +87,7 @@ class SoftmaxWithCrossEntropyMLUKernel : public framework::OpKernel<T> { ...@@ -87,7 +87,7 @@ class SoftmaxWithCrossEntropyMLUKernel : public framework::OpKernel<T> {
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"If soft_label=False, axis must be -1 or" "If soft_label=False, axis must be -1 or"
" can be regard as last dimention in mlu kernel.")); " can be regard as last dimention in mlu kernel."));
framework::Tensor labels_int32(VT::INT32); framework::Tensor labels_int32(framework::TransToPtenDataType(VT::INT32));
labels_int32.Resize(labels->dims()); labels_int32.Resize(labels->dims());
labels_int32.mutable_data<int32_t>(ctx.GetPlace()); labels_int32.mutable_data<int32_t>(ctx.GetPlace());
......
...@@ -61,15 +61,13 @@ class SplitMLUKernel : public framework::OpKernel<T> { ...@@ -61,15 +61,13 @@ class SplitMLUKernel : public framework::OpKernel<T> {
for (size_t i = 0; i < outs.size(); i++) { for (size_t i = 0; i < outs.size(); i++) {
outs[i]->mutable_data<T>(ctx.GetPlace()); outs[i]->mutable_data<T>(ctx.GetPlace());
output_descs.emplace_back(MLUCnnlTensorDesc( output_descs.emplace_back(MLUCnnlTensorDesc(
*outs[i], CNNL_LAYOUT_ARRAY, *outs[i], CNNL_LAYOUT_ARRAY, ToCnnlDataType(outs[i]->dtype())));
ToCnnlDataType(framework::TransToProtoVarType(outs[i]->dtype()))));
desc_vector.push_back(output_descs.back().get()); desc_vector.push_back(output_descs.back().get());
vct_tensor.push_back(GetBasePtr(outs[i])); vct_tensor.push_back(GetBasePtr(outs[i]));
} }
// init in tensors // init in tensors
MLUCnnlTensorDesc input_desc( MLUCnnlTensorDesc input_desc(*in, CNNL_LAYOUT_ARRAY,
*in, CNNL_LAYOUT_ARRAY, ToCnnlDataType(in->dtype()));
ToCnnlDataType(framework::TransToProtoVarType(in->dtype())));
// MLU should do sth // MLU should do sth
MLUCnnl::Split(ctx, num_tensor, axis, input_desc.get(), GetBasePtr(in), MLUCnnl::Split(ctx, num_tensor, axis, input_desc.get(), GetBasePtr(in),
......
...@@ -43,15 +43,13 @@ class SumMLUKernel : public framework::OpKernel<T> { ...@@ -43,15 +43,13 @@ class SumMLUKernel : public framework::OpKernel<T> {
std::vector<cnnlTensorDescriptor_t> desc_vector; std::vector<cnnlTensorDescriptor_t> desc_vector;
for (int i = 0; i < ins_size; i++) { for (int i = 0; i < ins_size; i++) {
input_descs.emplace_back(MLUCnnlTensorDesc( input_descs.emplace_back(MLUCnnlTensorDesc(
*ins[i], CNNL_LAYOUT_ARRAY, *ins[i], CNNL_LAYOUT_ARRAY, ToCnnlDataType(ins[i]->dtype())));
ToCnnlDataType(framework::TransToProtoVarType(ins[i]->dtype()))));
desc_vector.push_back(input_descs.back().get()); desc_vector.push_back(input_descs.back().get());
inputs.push_back(GetBasePtr(ins[i])); inputs.push_back(GetBasePtr(ins[i]));
} }
// init out tensors // init out tensors
MLUCnnlTensorDesc output_desc( MLUCnnlTensorDesc output_desc(*out, CNNL_LAYOUT_ARRAY,
*out, CNNL_LAYOUT_ARRAY, ToCnnlDataType(out->dtype()));
ToCnnlDataType(framework::TransToProtoVarType(out->dtype())));
uint32_t ins_size_t = static_cast<uint32_t>(ins_size); uint32_t ins_size_t = static_cast<uint32_t>(ins_size);
MLUCnnl::AddN(ctx, ins_size_t, desc_vector.data(), inputs.data(), MLUCnnl::AddN(ctx, ins_size_t, desc_vector.data(), inputs.data(),
output_desc.get(), GetBasePtr(out)); output_desc.get(), GetBasePtr(out));
......
...@@ -47,7 +47,7 @@ class TopkMLUKernel : public framework::OpKernel<T> { ...@@ -47,7 +47,7 @@ class TopkMLUKernel : public framework::OpKernel<T> {
const bool sorted = true; const bool sorted = true;
const int axis = -1; const int axis = -1;
// cnnl only support int32/int16 type of indices // cnnl only support int32/int16 type of indices
framework::Tensor indices_int32(VT::INT32); framework::Tensor indices_int32(framework::TransToPtenDataType(VT::INT32));
indices_int32.Resize(indices->dims()); indices_int32.Resize(indices->dims());
indices_int32.mutable_data<int32_t>(place); indices_int32.mutable_data<int32_t>(place);
......
...@@ -55,7 +55,7 @@ class TopkV2MLUKernel : public framework::OpKernel<T> { ...@@ -55,7 +55,7 @@ class TopkV2MLUKernel : public framework::OpKernel<T> {
indices->mutable_data<int64_t>(place); indices->mutable_data<int64_t>(place);
// cnnl only support int32/int16 type of indices // cnnl only support int32/int16 type of indices
framework::Tensor indices_int32(VT::INT32); framework::Tensor indices_int32(framework::TransToPtenDataType(VT::INT32));
indices_int32.Resize(indices->dims()); indices_int32.Resize(indices->dims());
indices_int32.mutable_data<int32_t>(place); indices_int32.mutable_data<int32_t>(place);
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from test_matmul_v2_op import reference_matmul
paddle.enable_static()
SEED = 2022
class TestMatMulV2Op(OpTest):
"""
case 1
"""
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.device.MLUPlace(0)
def config(self):
self.x_shape = (100, )
self.y_shape = (100, )
self.trans_x = False
self.trans_y = False
def init_kernel_type(self):
self.dtype = "float32"
def setUp(self):
self.set_mlu()
self.init_kernel_type()
self.config()
self.op_type = "matmul_v2"
x = np.random.random(self.x_shape).astype(self.dtype)
y = np.random.random(self.y_shape).astype(self.dtype)
# -0.1 ~ 0.1
x = -0.1 + 0.2 * x
y = -0.1 + 0.2 * y
result = reference_matmul(x, y, self.trans_x, self.trans_y)
result = result.astype(self.dtype)
self.inputs = {
'X': x,
'Y': y,
}
self.attrs = {'trans_x': self.trans_x, 'trans_y': self.trans_y}
self.outputs = {'Out': result}
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-7)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X', 'Y'], 'Out')
class TestMatMuklOp2(TestMatMulV2Op):
"""
case 2
"""
def config(self):
self.x_shape = (100, )
self.y_shape = (1, 3, 2, 100)
self.trans_x = False
self.trans_y = True
class TestMatMuklOp3(TestMatMulV2Op):
"""
case 3
"""
def config(self):
self.x_shape = (100, )
self.y_shape = (1, 1, 100, 2)
self.trans_x = False
self.trans_y = False
class TestMatMuklOp4(TestMatMulV2Op):
"""
case 4
"""
def config(self):
self.x_shape = (100, )
self.y_shape = (1, 2, 100, 2)
self.trans_x = False
self.trans_y = False
class TestMatMuklOp5(TestMatMulV2Op):
"""
case 5
"""
def config(self):
self.x_shape = (1, 1, 100, 1)
self.y_shape = (100, )
self.trans_x = True
self.trans_y = False
class TestMatMuklOp6(TestMatMulV2Op):
"""
case 6
"""
def config(self):
self.x_shape = (1, 2, 102, 1)
self.y_shape = (102, )
self.trans_x = True
self.trans_y = False
class TestMatMuklOp7(TestMatMulV2Op):
"""
case 7
"""
def config(self):
self.x_shape = (1, 2, 1, 100)
self.y_shape = (100, )
self.trans_x = False
self.trans_y = False
class TestMatMuklOp8(TestMatMulV2Op):
"""
case 8
"""
def config(self):
self.x_shape = (1, 1, 2, 100)
self.y_shape = (1, 1, 100, 2)
self.trans_x = False
self.trans_y = False
class TestMatMuklOp9(TestMatMulV2Op):
"""
case 9
"""
def config(self):
self.x_shape = (1, 1, 1, 100)
self.y_shape = (2, 1, 2, 100)
self.trans_x = False
self.trans_y = True
class TestMatMuklOp10(TestMatMulV2Op):
"""
case 10
"""
def config(self):
self.x_shape = (1, 1, 25, 4)
self.y_shape = (1, 2, 4, 25)
self.trans_x = False
self.trans_y = False
class TestMatMuklOp11(TestMatMulV2Op):
"""
case 11
"""
def config(self):
self.x_shape = (2, 1, 2, 100)
self.y_shape = (1, 1, 100, 2)
self.trans_x = False
self.trans_y = False
class TestMatMuklOp12(TestMatMulV2Op):
"""
case 12
"""
def config(self):
self.x_shape = (2, 1, 4, 25)
self.y_shape = (1, 1, 4, 25)
self.trans_x = True
self.trans_y = False
class TestMatMuklOp13(TestMatMulV2Op):
"""
case 13
"""
def config(self):
self.x_shape = (2, 2, 10, 10)
self.y_shape = (2, 2, 10, 10)
self.trans_x = True
self.trans_y = False
class TestMatMuklOp14(TestMatMulV2Op):
"""
case 14_1
"""
def config(self):
self.x_shape = (3, 1, 6, 6)
self.y_shape = (1, 2, 6, 9)
self.trans_x = True
self.trans_y = False
class TestMatMuklOp15(TestMatMulV2Op):
"""
case 14_2
"""
def config(self):
self.x_shape = (3, 1, 6, 6)
self.y_shape = (1, 2, 6, 9)
self.trans_x = False
self.trans_y = False
class TestMatMuklOp16(TestMatMulV2Op):
"""
case 16 : to check the gradient for special case
"""
def config(self):
self.x_shape = (100)
self.y_shape = (1, 2, 2, 100, 2)
self.trans_x = False
self.trans_y = False
class TestMatMuklOp17(TestMatMulV2Op):
"""
case 17 : to check the gradient for special case
"""
def config(self):
self.x_shape = (2, 1, 100)
self.y_shape = (100)
self.trans_x = False
self.trans_y = False
class TestMatMuklOpBroadcast1(TestMatMulV2Op):
"""
case 14_3
"""
def config(self):
self.x_shape = (3, 1, 10, 10)
self.y_shape = (1, 2, 10, 10)
self.trans_x = True
self.trans_y = True
class TestMatMuklOpBroadcast2(TestMatMulV2Op):
"""
case 14_4
"""
def config(self):
self.x_shape = (3, 1, 10, 10)
self.y_shape = (1, 2, 10, 10)
self.trans_x = False
self.trans_y = True
#--------------------test matmul fp16--------------------
def create_test_fp16_class(parent, atol=0.001, max_relative_error=2.5):
class TestMatMulOpFp16Case(parent):
def init_kernel_type(self):
self.dtype = np.float16
def test_check_output(self):
self.check_output_with_place(self.place, atol=atol)
def test_check_grad(self):
self.check_grad_with_place(
self.place, ['X', 'Y'],
'Out',
max_relative_error=max_relative_error)
cls_name = "{0}_{1}".format(parent.__name__, "Fp16")
TestMatMulOpFp16Case.__name__ = cls_name
globals()[cls_name] = TestMatMulOpFp16Case
create_test_fp16_class(TestMatMulV2Op)
create_test_fp16_class(TestMatMuklOp2)
create_test_fp16_class(TestMatMuklOp3)
create_test_fp16_class(TestMatMuklOp4)
create_test_fp16_class(TestMatMuklOp5)
create_test_fp16_class(TestMatMuklOp6)
create_test_fp16_class(TestMatMuklOp7)
create_test_fp16_class(TestMatMuklOp8)
create_test_fp16_class(TestMatMuklOp9)
create_test_fp16_class(TestMatMuklOp10)
create_test_fp16_class(TestMatMuklOp11)
create_test_fp16_class(TestMatMuklOp12)
create_test_fp16_class(TestMatMuklOp13)
create_test_fp16_class(TestMatMuklOp14)
create_test_fp16_class(TestMatMuklOp15)
create_test_fp16_class(TestMatMuklOp16)
create_test_fp16_class(TestMatMuklOp17)
class TestMatMulV2API(unittest.TestCase):
def setUp(self):
self.places = [paddle.CPUPlace()]
if paddle.is_compiled_with_mlu():
self.places.append(paddle.device.MLUPlace(0))
def check_static_result(self, place):
with fluid.program_guard(fluid.Program(), fluid.Program()):
input_x = fluid.data(name="input_x", shape=[4, 3], dtype="float32")
input_y = fluid.data(name="input_y", shape=[3, 4], dtype="float32")
result = paddle.matmul(input_x, input_y)
x_np = np.random.random([4, 3]).astype("float32")
y_np = np.random.random([3, 4]).astype("float32")
exe = fluid.Executor(place)
fetches = exe.run(fluid.default_main_program(),
feed={"input_x": x_np,
"input_y": y_np},
fetch_list=[result])
def test_static(self):
for place in self.places:
self.check_static_result(place=place)
def test_dygraph(self):
for place in self.places:
with fluid.dygraph.guard(place):
input_x = np.random.random([4, 3]).astype("float32")
input_y = np.random.random([3, 4]).astype("float32")
x = paddle.to_tensor(input_x)
y = paddle.to_tensor(input_y)
result = paddle.matmul(x, y)
def test_dygraph_fp16(self):
if paddle.is_compiled_with_mlu():
place = paddle.device.MLUPlace(0)
with fluid.dygraph.guard(place):
input_x = np.random.random([4, 3]).astype("float16")
input_y = np.random.random([3, 4]).astype("float16")
x = paddle.to_tensor(input_x)
y = paddle.to_tensor(input_y)
result = paddle.matmul(x, y)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册