Unverified commit d752a7f2, authored by T taixiurong, committed by GitHub

xpu-paddlepaddle-31 optimize matmul test=kunlun (#43975)

Parent 33540e10
...@@ -20,276 +20,40 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/xpu_api_wrapper.h"
-#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace paddle {
namespace operators {
using framework::Tensor;
static framework::DDim RowMatrixFromVector(const framework::DDim &x_dim) {
if (x_dim.size() > 1) {
return x_dim;
}
return phi::make_ddim({1, x_dim[0]});
}
static framework::Tensor FoldInitDims(const framework::Tensor &input) {
auto output = input;
auto in_dims = input.dims();
if (in_dims.size() == 3) {
output.Resize({in_dims[0] * in_dims[1], in_dims[2]});
}
return output;
}
/**
* Get column matrix shape from a vector shape. If the rank of y_dim > 1, the
* original y_dim is returned.
*/
static framework::DDim ColumnMatrixFromVector(const framework::DDim &y_dim) {
if (y_dim.size() > 1) {
return y_dim;
}
return phi::make_ddim({y_dim[0], 1});
}
static void ReshapeTensorIntoMatrixSequence(
framework::Tensor *x, const phi::funcs::MatDescriptor &descriptor) {
int64_t h, w;
h = descriptor.height_;
w = descriptor.width_;
if (descriptor.trans_) {
std::swap(w, h);
}
if (descriptor.batch_size_) {
x->Resize({descriptor.batch_size_, h, w});
} else {
x->Resize({h, w});
}
}
/**
* Reshape the x,y,out tensor to 3-D or 2-D tensor by matrix descriptor
* Out = matmul(x, y)
*
* This method will first calculate X,Y matrix sequence, and then calculate
* the out shape.
*
* Assume X = [BatchSize, H1, W1], Y = [BatchSize, H2, W2]
* The out = [BatchSize, H1, W2]
*
* If there is no batch size in `X` and `Y`, the out will be [H1, W2]
* If any of `X` and `Y` has batch size BatchSize, the out will have the
* BatchSize.
*/
static void ReshapeXYOutIntoMatrixSequence(framework::Tensor *x,
framework::Tensor *y,
framework::Tensor *out,
bool trans_x,
bool trans_y) {
auto x_dim = RowMatrixFromVector(x->dims());
auto y_dim = ColumnMatrixFromVector(y->dims());
auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(x_dim, 0, trans_x);
auto mat_dim_y = phi::funcs::CreateMatrixDescriptor(y_dim, 0, trans_y);
if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) {
out->Resize({mat_dim_x.height_, mat_dim_y.width_});
} else {
out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_),
mat_dim_x.height_,
mat_dim_y.width_});
}
ReshapeTensorIntoMatrixSequence(x, mat_dim_x);
ReshapeTensorIntoMatrixSequence(y, mat_dim_y);
}
template <typename T, typename FCT>
static void MatMulXPUFunction(const Tensor *x,
const Tensor *y,
Tensor *out,
bool trans_x,
bool trans_y,
const paddle::framework::ExecutionContext &ctx) {
using XPUType = typename XPUTypeTrait<T>::Type;
const auto &x_dims = x->dims();
const auto &y_dims = y->dims();
auto &dev_ctx =
ctx.template device_context<paddle::platform::XPUDeviceContext>();
auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(
RowMatrixFromVector(x_dims), 0, trans_x);
auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(
ColumnMatrixFromVector(y_dims), 0, trans_y);
if (x_dims.size() == 3 && y_dims.size() <= 2) {
// if transpose_X is true, the transpose costs much time
if (!trans_x) {
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
} else {
mat_dim_b.batch_size_ = mat_dim_a.batch_size_;
mat_dim_b.height_ = mat_dim_b.height_ / mat_dim_b.batch_size_;
}
}
if (mat_dim_a.width_ == mat_dim_b.height_) {
if (mat_dim_a.batch_size_ == 0 && mat_dim_b.batch_size_ == 1) {
mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0;
}
if (mat_dim_a.batch_size_ == 1 && mat_dim_b.batch_size_ == 0) {
mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0;
}
}
PADDLE_ENFORCE_EQ(mat_dim_a.width_,
mat_dim_b.height_,
platform::errors::InvalidArgument(
"Shape mistake in matmul_op, the "
"first tensor width must be same as "
"second tensor height, but received "
"width:%d, height:%d x_dims = %s , y_dims = %s",
mat_dim_a.width_,
mat_dim_b.height_,
x_dims.to_str().c_str(),
y_dims.to_str().c_str()));
PADDLE_ENFORCE_EQ(mat_dim_a.batch_size_,
mat_dim_b.batch_size_,
platform::errors::InvalidArgument(
"Shape mistake in matmul_op, the two input"
"tensor batch_size must be same, but received first "
"tensor batch_size:%d, second "
"tensor batch_size:%d, x_dims = %s , y_dims = %s",
mat_dim_a.batch_size_,
mat_dim_b.batch_size_,
x_dims.to_str().c_str(),
y_dims.to_str().c_str()));
float alpha = static_cast<T>(ctx.Attr<float>("alpha"));
T *data_c = out->data<T>();
int m = mat_dim_a.height_;
int n = mat_dim_b.width_;
int k = mat_dim_a.width_;
int batch_size = mat_dim_a.batch_size_;
int ldx = mat_dim_a.trans_ ? m : k;
int ldy = mat_dim_b.trans_ ? k : n;
int ldout = n;
if (batch_size <= 1) {
int r = 0;
r = xpu_fc_wrapper<XPUType, FCT>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType *>(x->data<T>()),
reinterpret_cast<const XPUType *>(y->data<T>()),
reinterpret_cast<XPUType *>(data_c),
m,
n,
k,
mat_dim_a.trans_,
mat_dim_b.trans_,
nullptr,
nullptr,
nullptr,
ldx,
ldy,
ldout,
alpha,
0,
nullptr,
xpu::Activation_t::LINEAR);
PADDLE_ENFORCE_EQ(
r,
XPU_SUCCESS,
platform::errors::External(
"XPU fc kernel return wrong value[%d %s]", r, XPUAPIErrorMsg[r]));
} else {
// batch matmul
int r = xpu::fc_batched<XPUType, XPUType, XPUType, FCT>(
dev_ctx.x_context(), // Context* ctx,
batch_size, // int batch_size,
mat_dim_a.trans_, // bool x_trans,
mat_dim_b.trans_, // bool w_trans,
m, // int m,
n, // int n,
k, // int k,
alpha, // float alpha,
reinterpret_cast<const XPUType *>(x->data<T>()), // const TX* x,
mat_dim_a.stride_, // int stride_a,
reinterpret_cast<const XPUType *>(y->data<T>()), // const TW* w,
mat_dim_b.stride_, // int stride_b,
0.0, // float beta,
reinterpret_cast<XPUType *>(data_c), // TY* y,
m * n, // int stride_c,
nullptr, // const float* x_maxptr,
nullptr); // const float* w_maxptr
PADDLE_ENFORCE_EQ(r,
XPU_SUCCESS,
platform::errors::External(
"XPU fc_batched kernel return wrong value[%d %s] "
"x_dims = %s , y_dims = %s",
r,
XPUAPIErrorMsg[r],
x_dims.to_str().c_str(),
y_dims.to_str().c_str()));
}
}
template <typename DeviceContext, typename T>
class MatMulXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;

 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<framework::Tensor>("X");
    auto* y = context.Input<framework::Tensor>("Y");
    auto* out = context.Output<framework::Tensor>("Out");
    out->mutable_data<T>(context.GetPlace());
    bool trans_x = context.Attr<bool>("transpose_X");
    bool trans_y = context.Attr<bool>("transpose_Y");
-    if (std::is_same<paddle::platform::float16, T>::value) {
-      MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, context);
-    } else {
-      if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) {
-        MatMulXPUFunction<T, int32_t>(x, y, out, trans_x, trans_y, context);
-      } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) {
-        MatMulXPUFunction<T, float>(x, y, out, trans_x, trans_y, context);
-      } else {
-        MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, context);
-      }
-    }
+    float alpha = static_cast<T>(context.Attr<float>("alpha"));
+    const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x->data<T>());
+    const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y->data<T>());
+    XPUType* out_ptr = reinterpret_cast<XPUType*>(out->data<T>());
+    auto x_dims = x->dims();
+    auto y_dims = y->dims();
+    XpuFcInfo fc_info;
+    GetFCInfo(x_dims, y_dims, trans_x, trans_y, &fc_info);
+    auto& dev_ctx =
+        context.template device_context<paddle::platform::XPUDeviceContext>();
+    xpu::Context* xpu_ctx = dev_ctx.x_context();
+    MatMulXPUFunction<XPUType>(xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, alpha);
  }
};
// Reshape a rank-3 tensor from P x M x N to M x (P * N).
// (Warning: This requires transposing data and writes into new memory.)
// Identity op if the tensor is not of rank 3.
template <typename DeviceContext, typename T>
static framework::Tensor XPUFoldHeadAndLastDims(
const DeviceContext &context, const framework::Tensor &input) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto in_dims = input.dims();
if (in_dims.size() != 3) {
return input;
}
framework::Tensor output;
output.Resize({in_dims[1], in_dims[0], in_dims[2]});
output.mutable_data<T>(context.GetPlace());
std::vector<int> in_shape_host = {static_cast<int>(in_dims[0]),
static_cast<int>(in_dims[1]),
static_cast<int>(in_dims[2])};
std::vector<int> axis_host = {1, 0, 2};
int r = xpu::transpose(context.x_context(),
reinterpret_cast<const XPUType *>(input.data<T>()),
reinterpret_cast<XPUType *>(output.data<T>()),
in_shape_host,
axis_host);
PADDLE_ENFORCE_EQ(r,
XPU_SUCCESS,
platform::errors::External(
"XPU transpose kernel return wrong value[%d %s]",
r,
XPUAPIErrorMsg[r]));
output.Resize({in_dims[1], in_dims[0] * in_dims[2]});
return output;
}
// Using dimensional constraints on matrix multiplication, it is
// straight-forward to check the following table for when X and Y
// are both matrices.
...@@ -317,107 +81,68 @@ static framework::Tensor XPUFoldHeadAndLastDims(
// to X: (P * M) x K, dOut: (P * M) x N.
template <typename DeviceContext, typename T>
class MatMulGradXPUKernel : public framework::OpKernel<T> {
- public:
+  using XPUType = typename XPUTypeTrait<T>::Type;
void MatMul(const framework::ExecutionContext &context,
const framework::Tensor &a,
bool trans_a,
const framework::Tensor &b,
bool trans_b,
framework::Tensor *out) const {
out->mutable_data<T>(context.GetPlace());
if (std::is_same<paddle::platform::float16, T>::value) {
MatMulXPUFunction<T, int16_t>(&a, &b, out, trans_a, trans_b, context);
} else {
if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) {
MatMulXPUFunction<T, int32_t>(&a, &b, out, trans_a, trans_b, context);
} else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) {
MatMulXPUFunction<T, float>(&a, &b, out, trans_a, trans_b, context);
} else {
MatMulXPUFunction<T, int16_t>(&a, &b, out, trans_a, trans_b, context);
}
}
}
void CalcInputGrad(const framework::ExecutionContext &context,
const framework::Tensor &a,
bool trans_a,
bool is_fold_init_dims_a,
const framework::Tensor &b,
bool trans_b,
bool is_fold_init_dims_b,
framework::Tensor *out) const {
if (out == nullptr) return;
bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) &&
out->dims().size() == 2;
if (!need_combine) {
MatMul(context, a, trans_a, b, trans_b, out);
} else {
auto &dev_ctx = context.template device_context<DeviceContext>();
MatMul(context,
is_fold_init_dims_a
? FoldInitDims(a)
: XPUFoldHeadAndLastDims<DeviceContext, T>(dev_ctx, a),
trans_a,
is_fold_init_dims_b
? FoldInitDims(b)
: XPUFoldHeadAndLastDims<DeviceContext, T>(dev_ctx, b),
trans_b,
out);
}
}
- void Compute(const framework::ExecutionContext &context) const override {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
    auto x = *context.Input<framework::Tensor>("X");
    auto y = *context.Input<framework::Tensor>("Y");
    auto dout =
        *context.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
    auto* dy = context.Output<framework::Tensor>(framework::GradVarName("Y"));
    bool transpose_x = context.Attr<bool>("transpose_X");
    bool transpose_y = context.Attr<bool>("transpose_Y");
+    float alpha = static_cast<T>(context.Attr<float>("alpha"));
-    ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);
-    framework::DDim dx_dims;
    if (dx) {
-      dx_dims = dx->dims();
-      if (dx_dims != x.dims()) {
-        dx->Resize(x.dims());
-      }
+      dx->mutable_data<T>(context.GetPlace());
    }
-    framework::DDim dy_dims;
    if (dy) {
-      dy_dims = dy->dims();
-      if (dy_dims != y.dims()) {
-        dy->Resize(y.dims());
-      }
+      dy->mutable_data<T>(context.GetPlace());
    }
if (transpose_x && transpose_y) {
CalcInputGrad(context, y, true, true, dout, true, false, dx);
CalcInputGrad(context, dout, true, true, x, true, false, dy);
} else if (transpose_x) {
CalcInputGrad(context, y, false, false, dout, true, false, dx);
CalcInputGrad(context, x, false, false, dout, false, true, dy);
} else if (transpose_y) {
CalcInputGrad(context, dout, false, false, y, false, true, dx);
CalcInputGrad(context, dout, true, true, x, false, true, dy);
} else {
CalcInputGrad(context, dout, false, false, y, true, false, dx);
CalcInputGrad(context, x, true, true, dout, false, true, dy);
}
auto& dev_ctx =
context.template device_context<paddle::platform::XPUDeviceContext>();
const XPUType* dout_ptr = reinterpret_cast<const XPUType*>(dout.data<T>());
const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x.data<T>());
const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y.data<T>());
xpu::Context* xpu_ctx = dev_ctx.x_context();
XpuFcInfo info_forward;
GetFCInfo(x.dims(), y.dims(), transpose_x, transpose_y, &info_forward);
xpu::ctx_guard RAII_GUARD(xpu_ctx);
// begin calculate
const XPUType* a_1 = reinterpret_cast<const XPUType*>(NULL);
const XPUType* b_1 = reinterpret_cast<const XPUType*>(NULL);
const XPUType* a_2 = reinterpret_cast<const XPUType*>(NULL);
const XPUType* b_2 = reinterpret_cast<const XPUType*>(NULL);
XPUType* c_1 = (dx == NULL) ? reinterpret_cast<XPUType*>(NULL)
: reinterpret_cast<XPUType*>(dx->data<T>());
XPUType* c_2 = (dy == NULL) ? reinterpret_cast<XPUType*>(NULL)
: reinterpret_cast<XPUType*>(dy->data<T>());
XpuFcInfo info_dx;
XpuFcInfo info_dy;
std::tuple<XpuFcInfo,
XpuFcInfo,
const XPUType*,
const XPUType*,
const XPUType*,
const XPUType*>
fc_info = MatmulGradFcInfo(xpu_ctx,
&RAII_GUARD,
info_forward,
transpose_x,
transpose_y,
x_ptr,
y_ptr,
dout_ptr);
std::tie(info_dx, info_dy, a_1, b_1, a_2, b_2) = fc_info;
    if (dx) {
-      if (dx_dims != x.dims()) {
-        dx->Resize(dx_dims);
-      }
+      MatMulXPUFunction<XPUType>(xpu_ctx, a_1, b_1, c_1, info_dx, alpha);
    }
    if (dy) {
-      if (dy_dims != y.dims()) {
-        dy->Resize(dy_dims);
-      }
+      MatMulXPUFunction<XPUType>(xpu_ctx, a_2, b_2, c_2, info_dy, alpha);
    }
  }
};
......
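For orientation, the removed MatMulXPUFunction above computed the GEMM geometry (m, n, k and the leading dimensions ldx/ldy/ldout) inline before calling xpu_fc_wrapper; after this commit that bookkeeping appears to live in GetFCInfo/XpuFcInfo, declared in the included xpu_api_wrapper.h and not shown in this diff. Below is a minimal self-contained sketch of the same geometry rule; the GemmShape struct and the concrete shapes are illustrative only and are not part of the diff.

```cpp
#include <cstdio>

// Illustrative stand-in for the bookkeeping the removed code did inline
// (see the ldx/ldy/ldout assignments above); not the actual XpuFcInfo type.
struct GemmShape {
  int m, n, k;           // out is m x n, reduction length is k
  int ldx, ldy, ldout;   // leading dimensions of the row-major buffers
};

// For out = op(x) * op(y) with row-major storage:
//   op(x) is m x k, op(y) is k x n, out is m x n.
GemmShape MakeGemmShape(int m, int n, int k, bool trans_x, bool trans_y) {
  GemmShape s;
  s.m = m;
  s.n = n;
  s.k = k;
  s.ldx = trans_x ? m : k;  // when trans_x, the buffer is stored as k x m
  s.ldy = trans_y ? k : n;  // when trans_y, the buffer is stored as n x k
  s.ldout = n;
  return s;
}

int main() {
  // Example: X is 64 x 128, Y is 128 x 32, no transposes.
  GemmShape s = MakeGemmShape(64, 32, 128, false, false);
  std::printf("m=%d n=%d k=%d ldx=%d ldy=%d ldout=%d\n",
              s.m, s.n, s.k, s.ldx, s.ldy, s.ldout);
  return 0;
}
```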
...@@ -16,146 +16,17 @@
#include <string>
#include <vector>
#include "paddle/fluid/operators/matmul_v2_op.h"
#include "paddle/fluid/operators/xpu_api_wrapper.h"
namespace paddle {
namespace operators {
template <typename T, typename FCT>
static void MatMulXPUFunction(const Tensor* x,
const Tensor* y,
Tensor* out,
bool trans_x,
bool trans_y,
const paddle::framework::ExecutionContext& ctx) {
using XPUType = typename XPUTypeTrait<T>::Type;
const auto& x_dims = x->dims();
const auto& y_dims = y->dims();
auto& dev_ctx =
ctx.template device_context<paddle::platform::XPUDeviceContext>();
auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(
RowMatrixFromVector(x_dims), 0, trans_x);
auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(
ColumnMatrixFromVector(y_dims), 0, trans_y);
if (x_dims.size() >= 3 && y_dims.size() <= 2) {
// if transpose_X is true, the transpose costs much time
if (!trans_x) {
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
} else {
mat_dim_b.batch_size_ = mat_dim_a.batch_size_;
mat_dim_b.height_ = mat_dim_b.height_ / mat_dim_b.batch_size_;
}
}
if (mat_dim_a.width_ == mat_dim_b.height_) {
if (mat_dim_a.batch_size_ == 0 && mat_dim_b.batch_size_ == 1) {
mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0;
}
if (mat_dim_a.batch_size_ == 1 && mat_dim_b.batch_size_ == 0) {
mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0;
}
}
PADDLE_ENFORCE_EQ(mat_dim_a.width_,
mat_dim_b.height_,
platform::errors::InvalidArgument(
"Shape mistake in matmul_v2_op xdims = %s ydims = %s "
"x_trans = %d y_trans = %d",
x_dims.to_str(),
y_dims.to_str(),
mat_dim_a.trans_,
mat_dim_b.trans_));
PADDLE_ENFORCE_EQ(mat_dim_a.batch_size_,
mat_dim_b.batch_size_,
platform::errors::InvalidArgument(
"Shape mistake in matmul_v2_op xdims = %s ydims = %s "
"x_trans = %d y_trans = %d",
x_dims.to_str(),
y_dims.to_str(),
mat_dim_a.trans_,
mat_dim_b.trans_));
T* data_c = out->data<T>();
int m = mat_dim_a.height_;
int n = mat_dim_b.width_;
int k = mat_dim_a.width_;
int batch_size = mat_dim_a.batch_size_;
int ldx = mat_dim_a.trans_ ? m : k;
int ldy = mat_dim_b.trans_ ? k : n;
int ldout = n;
if (batch_size <= 1) {
int r = 0;
r = xpu_fc_wrapper<XPUType, FCT>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x->data<T>()),
reinterpret_cast<const XPUType*>(y->data<T>()),
reinterpret_cast<XPUType*>(data_c),
m,
n,
k,
mat_dim_a.trans_,
mat_dim_b.trans_,
nullptr,
nullptr,
nullptr,
ldx,
ldy,
ldout,
1.0,
0,
nullptr,
xpu::Activation_t::LINEAR);
PADDLE_ENFORCE_EQ(
r,
XPU_SUCCESS,
platform::errors::External(
"XPU fc kernel return wrong value[%d %s] , m = %d, n = "
"%d, "
"k = %d, a_tr = %d, b_tr = %d",
r,
XPUAPIErrorMsg[r],
m,
n,
k,
mat_dim_a.trans_,
mat_dim_b.trans_));
} else {
// batch matmul
int r = xpu::fc_batched<XPUType, XPUType, XPUType, FCT>(
dev_ctx.x_context(), // Context* ctx,
batch_size, // int batch_size,
mat_dim_a.trans_, // bool x_trans,
mat_dim_b.trans_, // bool w_trans,
m, // int m,
n, // int n,
k, // int k,
1.0, // float alpha,
reinterpret_cast<const XPUType*>(x->data<T>()), // const TX* x,
mat_dim_a.stride_, // int stride_a,
reinterpret_cast<const XPUType*>(y->data<T>()), // const TW* w,
mat_dim_b.stride_, // int stride_b,
0.0, // float beta,
reinterpret_cast<XPUType*>(data_c), // TY* y,
m * n, // int stride_c,
nullptr, // const float* x_maxptr,
nullptr); // const float* w_maxptr
PADDLE_ENFORCE_EQ(r,
XPU_SUCCESS,
platform::errors::External(
"XPU fc_batched kernel return wrong value[%d %s]",
r,
XPUAPIErrorMsg[r]));
}
}
template <typename T>
class MatMulV2XPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;

 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
...@@ -164,160 +35,84 @@ class MatMulV2XPUKernel : public framework::OpKernel<T> {
    bool trans_x = ctx.Attr<bool>("trans_x");
    bool trans_y = ctx.Attr<bool>("trans_y");
    out->mutable_data<T>(ctx.GetPlace());
-    if (std::is_same<paddle::platform::float16, T>::value) {
-      MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, ctx);
-    } else {
-      if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) {
-        MatMulXPUFunction<T, int32_t>(x, y, out, trans_x, trans_y, ctx);
-      } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) {
-        MatMulXPUFunction<T, float>(x, y, out, trans_x, trans_y, ctx);
-      } else {
-        MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, ctx);
-      }
-    }
+    const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x->data<T>());
+    const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y->data<T>());
+    XPUType* out_ptr = reinterpret_cast<XPUType*>(out->data<T>());
+    auto x_dims = x->dims();
+    auto y_dims = y->dims();
+    XpuFcInfo fc_info;
+    GetFCInfo(x_dims, y_dims, trans_x, trans_y, &fc_info);
+    auto& dev_ctx =
+        ctx.template device_context<paddle::platform::XPUDeviceContext>();
+    xpu::Context* xpu_ctx = dev_ctx.x_context();
+    MatMulXPUFunction<XPUType>(xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, 1.0f);
  }
};
template <typename DeviceContext, typename T>
static framework::Tensor XPUFoldHeadAndLastDims(
const DeviceContext& context, const framework::Tensor& input) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto in_dims = input.dims();
if (in_dims.size() != 3) {
return input;
}
framework::Tensor output;
output.Resize({in_dims[1], in_dims[0], in_dims[2]});
output.mutable_data<T>(context.GetPlace());
std::vector<int> in_shape_host = {static_cast<int>(in_dims[0]),
static_cast<int>(in_dims[1]),
static_cast<int>(in_dims[2])};
std::vector<int> axis_host = {1, 0, 2};
int r = xpu::transpose(context.x_context(),
reinterpret_cast<const XPUType*>(input.data<T>()),
reinterpret_cast<XPUType*>(output.data<T>()),
in_shape_host,
axis_host);
PADDLE_ENFORCE_EQ(r,
XPU_SUCCESS,
platform::errors::External(
"XPU transpose kernel return wrong value[%d %s]",
r,
XPUAPIErrorMsg[r]));
output.Resize({in_dims[1], in_dims[0] * in_dims[2]});
return output;
}
template <typename T>
class MatMulV2XPUGradKernel : public framework::OpKernel<T> {
- public:
+  using XPUType = typename XPUTypeTrait<T>::Type;
void MatMul(const framework::ExecutionContext& ctx,
const framework::Tensor& a,
bool trans_a,
const framework::Tensor& b,
bool trans_b,
framework::Tensor* out) const {
out->mutable_data<T>(ctx.GetPlace());
if (std::is_same<paddle::platform::float16, T>::value) {
MatMulXPUFunction<T, int16_t>(&a, &b, out, trans_a, trans_b, ctx);
} else {
if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) {
MatMulXPUFunction<T, int32_t>(&a, &b, out, trans_a, trans_b, ctx);
} else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) {
MatMulXPUFunction<T, float>(&a, &b, out, trans_a, trans_b, ctx);
} else {
MatMulXPUFunction<T, int16_t>(&a, &b, out, trans_a, trans_b, ctx);
}
}
}
void CalcInputGrad(const framework::ExecutionContext& context,
const framework::Tensor& a,
bool trans_a,
bool is_fold_init_dims_a,
const framework::Tensor& b,
bool trans_b,
bool is_fold_init_dims_b,
framework::Tensor* out) const {
if (out == nullptr) return;
bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) &&
out->dims().size() == 2;
if (!need_combine) {
MatMul(context, a, trans_a, b, trans_b, out);
} else {
auto& dev_ctx =
context.template device_context<paddle::platform::XPUDeviceContext>();
MatMul(
context,
is_fold_init_dims_a
? FoldInitDims(a)
: XPUFoldHeadAndLastDims<paddle::platform::XPUDeviceContext, T>(
dev_ctx, a),
trans_a,
is_fold_init_dims_b
? FoldInitDims(b)
: XPUFoldHeadAndLastDims<paddle::platform::XPUDeviceContext, T>(
dev_ctx, b),
trans_b,
out);
}
}
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    bool transpose_x = context.Attr<bool>("trans_x");
    bool transpose_y = context.Attr<bool>("trans_y");
    auto x = *context.Input<framework::Tensor>("X");
    auto y = *context.Input<framework::Tensor>("Y");
    auto dout =
        *context.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
    auto* dy = context.Output<framework::Tensor>(framework::GradVarName("Y"));
-    ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);
-    framework::DDim dx_dims;
    if (dx) {
-      dx_dims = dx->dims();
-      if (dx_dims != x.dims()) {
-        dx->Resize(x.dims());
-      }
+      dx->mutable_data<T>(context.GetPlace());
    }
-    framework::DDim dy_dims;
    if (dy) {
-      dy_dims = dy->dims();
-      if (dy_dims != y.dims()) {
-        dy->Resize(y.dims());
-      }
+      dy->mutable_data<T>(context.GetPlace());
    }
if (transpose_x && transpose_y) {
CalcInputGrad(context, y, true, true, dout, true, false, dx);
CalcInputGrad(context, dout, true, true, x, true, false, dy);
} else if (transpose_x) {
CalcInputGrad(context, y, false, false, dout, true, false, dx);
CalcInputGrad(context, x, false, false, dout, false, true, dy);
} else if (transpose_y) {
CalcInputGrad(context, dout, false, false, y, false, true, dx);
CalcInputGrad(context, dout, true, true, x, false, true, dy);
} else {
CalcInputGrad(context, dout, false, false, y, true, false, dx);
CalcInputGrad(context, x, true, true, dout, false, true, dy);
}
auto& dev_ctx =
context.template device_context<paddle::platform::XPUDeviceContext>();
const XPUType* dout_ptr = reinterpret_cast<const XPUType*>(dout.data<T>());
const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x.data<T>());
const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y.data<T>());
xpu::Context* xpu_ctx = dev_ctx.x_context();
XpuFcInfo info_forward;
GetFCInfo(x.dims(), y.dims(), transpose_x, transpose_y, &info_forward);
xpu::ctx_guard RAII_GUARD(xpu_ctx);
// begin calculate
const XPUType* a_1 = reinterpret_cast<const XPUType*>(NULL);
const XPUType* b_1 = reinterpret_cast<const XPUType*>(NULL);
const XPUType* a_2 = reinterpret_cast<const XPUType*>(NULL);
const XPUType* b_2 = reinterpret_cast<const XPUType*>(NULL);
XPUType* c_1 = (dx == NULL) ? reinterpret_cast<XPUType*>(NULL)
: reinterpret_cast<XPUType*>(dx->data<T>());
XPUType* c_2 = (dy == NULL) ? reinterpret_cast<XPUType*>(NULL)
: reinterpret_cast<XPUType*>(dy->data<T>());
XpuFcInfo info_dx;
XpuFcInfo info_dy;
std::tuple<XpuFcInfo,
XpuFcInfo,
const XPUType*,
const XPUType*,
const XPUType*,
const XPUType*>
fc_info = MatmulGradFcInfo(xpu_ctx,
&RAII_GUARD,
info_forward,
transpose_x,
transpose_y,
x_ptr,
y_ptr,
dout_ptr);
std::tie(info_dx, info_dy, a_1, b_1, a_2, b_2) = fc_info;
    if (dx) {
-      if (dx_dims != x.dims()) {
-        dx->Resize(dx_dims);
-      }
+      MatMulXPUFunction<XPUType>(xpu_ctx, a_1, b_1, c_1, info_dx, 1.0f);
    }
    if (dy) {
-      if (dy_dims != y.dims()) {
-        dy->Resize(dy_dims);
-      }
+      MatMulXPUFunction<XPUType>(xpu_ctx, a_2, b_2, c_2, info_dy, 1.0f);
    }
  }
};
......
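Both gradient kernels above now obtain their two GEMM descriptions from MatmulGradFcInfo instead of the removed CalcInputGrad dispatch, but the underlying math is the standard matmul backward rule: for the no-transpose case, dX = dOut * Y^T and dY = X^T * dOut (the removed dispatch spelled out the other transpose combinations explicitly). The following self-contained check verifies those two identities with plain loops and finite differences; the shapes and values are illustrative only and this is not code from the diff.

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// With loss = sum(X * Y) (i.e. dOut = ones(M, N)), the backward identities give
//   dX[i][l] = sum_j Y[l][j]   (row of dOut times Y^T)
//   dY[l][j] = sum_i X[i][l]   (X^T times column of dOut)
int main() {
  const int M = 2, K = 3, N = 4;
  std::vector<double> X(M * K), Y(K * N);
  for (int i = 0; i < M * K; ++i) X[i] = 0.1 * (i + 1);
  for (int i = 0; i < K * N; ++i) Y[i] = 0.05 * (i + 2);

  auto loss = [&](const std::vector<double>& x, const std::vector<double>& y) {
    double s = 0.0;
    for (int i = 0; i < M; ++i)
      for (int j = 0; j < N; ++j)
        for (int l = 0; l < K; ++l) s += x[i * K + l] * y[l * N + j];
    return s;
  };

  const double eps = 1e-6;
  double max_err = 0.0;
  // Check dX against a finite difference of the loss.
  for (int i = 0; i < M * K; ++i) {
    int l = i % K;
    double analytic = 0.0;
    for (int j = 0; j < N; ++j) analytic += Y[l * N + j];
    std::vector<double> Xp = X;
    Xp[i] += eps;
    double numeric = (loss(Xp, Y) - loss(X, Y)) / eps;
    max_err = std::fmax(max_err, std::fabs(analytic - numeric));
  }
  // Check dY the same way.
  for (int i = 0; i < K * N; ++i) {
    int l = i / N;
    double analytic = 0.0;
    for (int r = 0; r < M; ++r) analytic += X[r * K + l];
    std::vector<double> Yp = Y;
    Yp[i] += eps;
    double numeric = (loss(X, Yp) - loss(X, Y)) / eps;
    max_err = std::fmax(max_err, std::fabs(analytic - numeric));
  }
  std::printf("max |analytic - numeric| = %g\n", max_err);
  return 0;
}
```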
...@@ -49,50 +49,23 @@ class MulXPUKernel : public framework::OpKernel<T> {
                            *y, context.template Attr<int>("y_num_col_dims"))
                      : *y;
    z->mutable_data<T>(context.GetPlace());
-    auto z_dim = z->dims();
-    if (z_dim.size() != 2) {
-      z->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
-    }
    bool trans_a = false;
    bool trans_b = false;
-    int m = x_matrix.dims()[0];
-    int k = x_matrix.dims()[1];
-    int k1 = y_matrix.dims()[0];
-    int n = y_matrix.dims()[1];
-    PADDLE_ENFORCE_EQ(
-        k, k1, platform::errors::InvalidArgument("Shape mistake in mul_op"));
-    T alpha = static_cast<T>(1.0);
-    T beta = static_cast<T>(0.0);
-    const T* data_a = x_matrix.data<T>();
-    const T* data_b = y_matrix.data<T>();
-    T* data_c = z->data<T>();
-    auto& dev_ctx = context.template device_context<DeviceContext>();
int ret = xpu_fc_wrapper<XPUType, int16_t>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(data_a),
reinterpret_cast<const XPUType*>(data_b),
reinterpret_cast<XPUType*>(data_c),
m,
n,
k,
trans_a,
trans_b,
nullptr,
nullptr,
nullptr,
k,
n,
n,
alpha,
beta,
nullptr,
xpu::Activation_t::LINEAR);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "xpu_fc_wrapper");
if (z_dim.size() != 2) {
z->Resize(z_dim);
}
+    const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x_matrix.data<T>());
+    const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y_matrix.data<T>());
+    XPUType* out_ptr = reinterpret_cast<XPUType*>(z->data<T>());
+    auto x_dims = x_matrix.dims();
+    auto y_dims = y_matrix.dims();
+    XpuFcInfo fc_info;
+    GetFCInfo(x_dims, y_dims, trans_a, trans_b, &fc_info);
+    auto& dev_ctx =
+        context.template device_context<paddle::platform::XPUDeviceContext>();
+    xpu::Context* xpu_ctx = dev_ctx.x_context();
+    MatMulXPUFunction<XPUType>(xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, 1.0f);
  }
};
...@@ -125,98 +98,51 @@ class MulGradXPUKernel : public framework::OpKernel<T> {
      dy->set_lod(y->lod());
    }
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
XpuFcInfo info_forward;
GetFCInfo(x_matrix.dims(), y_matrix.dims(), false, false, &info_forward);
const XPUType* dout_ptr = reinterpret_cast<const XPUType*>(dout->data<T>());
const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x->data<T>());
const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y->data<T>());
xpu::Context* xpu_ctx = dev_ctx.x_context();
xpu::ctx_guard RAII_GUARD(xpu_ctx);
// begin calculate
const XPUType* a_1 = reinterpret_cast<const XPUType*>(NULL);
const XPUType* b_1 = reinterpret_cast<const XPUType*>(NULL);
const XPUType* a_2 = reinterpret_cast<const XPUType*>(NULL);
const XPUType* b_2 = reinterpret_cast<const XPUType*>(NULL);
XPUType* c_1 =
(dx == NULL)
? reinterpret_cast<XPUType*>(NULL)
: reinterpret_cast<XPUType*>(dx->mutable_data<T>(ctx.GetPlace()));
XPUType* c_2 =
(dy == NULL)
? reinterpret_cast<XPUType*>(NULL)
: reinterpret_cast<XPUType*>(dy->mutable_data<T>(ctx.GetPlace()));
XpuFcInfo info_dx;
XpuFcInfo info_dy;
std::tuple<XpuFcInfo,
XpuFcInfo,
const XPUType*,
const XPUType*,
const XPUType*,
const XPUType*>
fc_info = MatmulGradFcInfo(xpu_ctx,
&RAII_GUARD,
info_forward,
false,
false,
x_ptr,
y_ptr,
dout_ptr);
std::tie(info_dx, info_dy, a_1, b_1, a_2, b_2) = fc_info;
    if (dx) {
-      dx->mutable_data<T>(ctx.GetPlace());
Tensor dx_matrix = dx->dims().size() > 2
? framework::ReshapeToMatrix(*dx, x_num_col_dims)
: *dx;
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
// blas.MatMul(dout_mat, false, y_matrix, true, &dx_matrix);
bool trans_a = false;
bool trans_b = true;
int m = dout_mat.dims()[0];
int k = dout_mat.dims()[1];
int n = y_matrix.dims()[0];
int k1 = y_matrix.dims()[1];
PADDLE_ENFORCE_EQ(
k, k1, platform::errors::InvalidArgument("Shape mistake in mul_op"));
int lda = (!trans_a) ? k : m;
int ldb = (!trans_b) ? n : k;
int ldc = n;
T alpha = static_cast<T>(1.0);
T beta = static_cast<T>(0.0);
const T* data_a = dout->data<T>();
const T* data_b = y_matrix.data<T>();
T* data_c = dx_matrix.data<T>();
int ret = xpu_fc_wrapper<XPUType, int16_t>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(data_a),
reinterpret_cast<const XPUType*>(data_b),
reinterpret_cast<XPUType*>(data_c),
m,
n,
k,
trans_a,
trans_b,
nullptr,
nullptr,
nullptr,
lda,
ldb,
ldc,
alpha,
beta,
nullptr,
xpu::Activation_t::LINEAR);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "xpu_fc_wrapper");
+      MatMulXPUFunction<XPUType>(xpu_ctx, a_1, b_1, c_1, info_dx, 1.0f);
    }
    if (dy) {
-      dy->mutable_data<T>(ctx.GetPlace());
Tensor dy_matrix = dy->dims().size() > 2
? framework::ReshapeToMatrix(*dy, y_num_col_dims)
: *dy;
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
// blas.MatMul(x_matrix, true, dout_mat, false, &dy_matrix);
bool trans_a = true;
bool trans_b = false;
int k = x_matrix.dims()[0];
int m = x_matrix.dims()[1];
int k1 = dout_mat.dims()[0];
int n = dout_mat.dims()[1];
PADDLE_ENFORCE_EQ(
k, k1, platform::errors::InvalidArgument("Shape mistake in mul_op"));
int lda = (!trans_a) ? k : m;
int ldb = (!trans_b) ? n : k;
int ldc = n;
T alpha = static_cast<T>(1.0);
T beta = static_cast<T>(0.0);
const T* data_a = x_matrix.data<T>();
const T* data_b = dout->data<T>();
T* data_c = dy_matrix.data<T>();
int ret = xpu_fc_wrapper<XPUType, int16_t>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(data_a),
reinterpret_cast<const XPUType*>(data_b),
reinterpret_cast<XPUType*>(data_c),
m,
n,
k,
trans_a,
trans_b,
nullptr,
nullptr,
nullptr,
lda,
ldb,
ldc,
alpha,
beta,
nullptr,
xpu::Activation_t::LINEAR);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "xpu_fc_wrapper");
+      MatMulXPUFunction<XPUType>(xpu_ctx, a_2, b_2, c_2, info_dy, 1.0f);
    }
  }
};
......
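The mul kernels above first flatten X and Y to matrices via ReshapeToMatrix with x_num_col_dims / y_num_col_dims before entering the shared matmul path. A small sketch of that flattening rule as I understand it: the first num_col_dims dimensions collapse into the rows and the rest into the columns. The FlattenToMatrix helper and the example shape below are illustrative only, not Paddle API.

```cpp
#include <cstdio>
#include <utility>
#include <vector>

// Collapse an N-D shape into a 2-D (rows, cols) pair:
// rows = product of the first num_col_dims dims, cols = product of the rest.
std::pair<long, long> FlattenToMatrix(const std::vector<long>& dims,
                                      int num_col_dims) {
  long rows = 1, cols = 1;
  for (int i = 0; i < num_col_dims; ++i) rows *= dims[i];
  for (int i = num_col_dims; i < static_cast<int>(dims.size()); ++i)
    cols *= dims[i];
  return {rows, cols};
}

int main() {
  std::vector<long> x_dims = {2, 3, 4, 5};
  auto mat = FlattenToMatrix(x_dims, 2);  // {2,3,4,5} with num_col_dims=2 -> 6 x 20
  std::printf("flattened to %ld x %ld\n", mat.first, mat.second);
  return 0;
}
```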
...@@ -281,11 +281,18 @@ XPUOpMap& get_kl2_ops() {
                     pOpKernelType(vartype::INT64, XPUPlace()),
                     pOpKernelType(vartype::FP16, XPUPlace()),
                     pOpKernelType(vartype::FP32, XPUPlace())})},
-    {"matmul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"matmul_grad",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                   pOpKernelType(vartype::FP16, XPUPlace())})},
    {"matmul_v2_grad",
-     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                   pOpKernelType(vartype::FP16, XPUPlace())})},
-    {"matmul_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"matmul_v2",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                   pOpKernelType(vartype::FP16, XPUPlace())})},
-    {"matmul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"matmul",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                   pOpKernelType(vartype::FP16, XPUPlace())})},
    {"mean_grad",
     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                   pOpKernelType(vartype::FP16, XPUPlace())})},
......
...@@ -84,7 +84,9 @@ type_dict_str_to_numpy = {
xpu_test_op_white_list = []
xpu_test_device_type_white_list = ['xpu1_float64']
-xpu_test_op_type_white_list = ['dropout_float16', 'dropout_grad_float16']
+xpu_test_op_type_white_list = [
+    'dropout_float16', 'dropout_grad_float16', 'matmul_v2_float16'
+]
xpu_test_device_op_white_list = []
xpu_test_device_op_type_white_list = []
......
...@@ -303,7 +303,8 @@ class TestMatmulBaseGenerator(XPUOpTest):
        X = np.random.random(shape_X).astype(self.dtype)
        Y = np.random.random(shape_Y).astype(self.dtype)
-        Out = reference_matmul(X, Y, transpose_X, transpose_Y)
+        Out = reference_matmul(X, Y, transpose_X,
+                               transpose_Y).astype(self.dtype)
        self.inputs = {'X': X, 'Y': Y}
        self.attrs = {'transpose_X': transpose_X, 'transpose_Y': transpose_Y}
        self.outputs = {'Out': Out}
......