未验证 提交 d752a7f2 编写于 作者: T taixiurong 提交者: GitHub

xpu-paddlepaddle-31 优化matmul test=kunlun (#43975)

上级 33540e10
...@@ -20,275 +20,39 @@ limitations under the License. */ ...@@ -20,275 +20,39 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/xpu_api_wrapper.h" #include "paddle/fluid/operators/xpu_api_wrapper.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using framework::Tensor; using framework::Tensor;
static framework::DDim RowMatrixFromVector(const framework::DDim &x_dim) {
if (x_dim.size() > 1) {
return x_dim;
}
return phi::make_ddim({1, x_dim[0]});
}
static framework::Tensor FoldInitDims(const framework::Tensor &input) {
auto output = input;
auto in_dims = input.dims();
if (in_dims.size() == 3) {
output.Resize({in_dims[0] * in_dims[1], in_dims[2]});
}
return output;
}
/**
* Get column matrix shape from a vector shape. If the ran of y_dim > 1, the
* original y_dim is returned.
*/
static framework::DDim ColumnMatrixFromVector(const framework::DDim &y_dim) {
if (y_dim.size() > 1) {
return y_dim;
}
return phi::make_ddim({y_dim[0], 1});
}
static void ReshapeTensorIntoMatrixSequence(
framework::Tensor *x, const phi::funcs::MatDescriptor &descriptor) {
int64_t h, w;
h = descriptor.height_;
w = descriptor.width_;
if (descriptor.trans_) {
std::swap(w, h);
}
if (descriptor.batch_size_) {
x->Resize({descriptor.batch_size_, h, w});
} else {
x->Resize({h, w});
}
}
/**
* Reshape the x,y,out tensor to 3-D or 2-D tensor by matrix descriptor
* Out = matmul(x, y)
*
* This method will first calculate X,Y matrix sequence, and then calculate
* the out shape.
*
* Assume X = [BatchSize, H1, W1], Y = [BatchSize, H2, W2]
* The out = [BatchSize, H1, W2]
*
* If there is no batch size in `X` and `Y`, the out will be [H1, W2]
* If any of `X` and `Y` has batch size BatchSize, the out will have the
* BatchSize.
*/
static void ReshapeXYOutIntoMatrixSequence(framework::Tensor *x,
framework::Tensor *y,
framework::Tensor *out,
bool trans_x,
bool trans_y) {
auto x_dim = RowMatrixFromVector(x->dims());
auto y_dim = ColumnMatrixFromVector(y->dims());
auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(x_dim, 0, trans_x);
auto mat_dim_y = phi::funcs::CreateMatrixDescriptor(y_dim, 0, trans_y);
if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) {
out->Resize({mat_dim_x.height_, mat_dim_y.width_});
} else {
out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_),
mat_dim_x.height_,
mat_dim_y.width_});
}
ReshapeTensorIntoMatrixSequence(x, mat_dim_x);
ReshapeTensorIntoMatrixSequence(y, mat_dim_y);
}
template <typename T, typename FCT>
static void MatMulXPUFunction(const Tensor *x,
const Tensor *y,
Tensor *out,
bool trans_x,
bool trans_y,
const paddle::framework::ExecutionContext &ctx) {
using XPUType = typename XPUTypeTrait<T>::Type;
const auto &x_dims = x->dims();
const auto &y_dims = y->dims();
auto &dev_ctx =
ctx.template device_context<paddle::platform::XPUDeviceContext>();
auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(
RowMatrixFromVector(x_dims), 0, trans_x);
auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(
ColumnMatrixFromVector(y_dims), 0, trans_y);
if (x_dims.size() == 3 && y_dims.size() <= 2) {
// if transpose_X is true, the transpose cost much time
if (!trans_x) {
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
} else {
mat_dim_b.batch_size_ = mat_dim_a.batch_size_;
mat_dim_b.height_ = mat_dim_b.height_ / mat_dim_b.batch_size_;
}
}
if (mat_dim_a.width_ == mat_dim_b.height_) {
if (mat_dim_a.batch_size_ == 0 && mat_dim_b.batch_size_ == 1) {
mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0;
}
if (mat_dim_a.batch_size_ == 1 && mat_dim_b.batch_size_ == 0) {
mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0;
}
}
PADDLE_ENFORCE_EQ(mat_dim_a.width_,
mat_dim_b.height_,
platform::errors::InvalidArgument(
"Shape mistake in matmul_op, the "
"first tensor width must be same as "
"second tensor height, but received "
"width:%d, height:%d x_dims = %s , y_dims = %s",
mat_dim_a.width_,
mat_dim_b.height_,
x_dims.to_str().c_str(),
y_dims.to_str().c_str()));
PADDLE_ENFORCE_EQ(mat_dim_a.batch_size_,
mat_dim_b.batch_size_,
platform::errors::InvalidArgument(
"Shape mistake in matmul_op, the two input"
"tensor batch_size must be same, but received first "
"tensor batch_size:%d, second "
"tensor batch_size:%d, x_dims = %s , y_dims = %s",
mat_dim_a.batch_size_,
mat_dim_b.batch_size_,
x_dims.to_str().c_str(),
y_dims.to_str().c_str()));
float alpha = static_cast<T>(ctx.Attr<float>("alpha"));
T *data_c = out->data<T>();
int m = mat_dim_a.height_;
int n = mat_dim_b.width_;
int k = mat_dim_a.width_;
int batch_size = mat_dim_a.batch_size_;
int ldx = mat_dim_a.trans_ ? m : k;
int ldy = mat_dim_b.trans_ ? k : n;
int ldout = n;
if (batch_size <= 1) {
int r = 0;
r = xpu_fc_wrapper<XPUType, FCT>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType *>(x->data<T>()),
reinterpret_cast<const XPUType *>(y->data<T>()),
reinterpret_cast<XPUType *>(data_c),
m,
n,
k,
mat_dim_a.trans_,
mat_dim_b.trans_,
nullptr,
nullptr,
nullptr,
ldx,
ldy,
ldout,
alpha,
0,
nullptr,
xpu::Activation_t::LINEAR);
PADDLE_ENFORCE_EQ(
r,
XPU_SUCCESS,
platform::errors::External(
"XPU fc kernel return wrong value[%d %s]", r, XPUAPIErrorMsg[r]));
} else {
// batch matmul
int r = xpu::fc_batched<XPUType, XPUType, XPUType, FCT>(
dev_ctx.x_context(), // Context* ctx,
batch_size, // int batch_size,
mat_dim_a.trans_, // bool x_trans,
mat_dim_b.trans_, // bool w_trans,
m, // int m,
n, // int n,
k, // int k,
alpha, // float alpha,
reinterpret_cast<const XPUType *>(x->data<T>()), // const TX* x,
mat_dim_a.stride_, // int stride_a,
reinterpret_cast<const XPUType *>(y->data<T>()), // const TW* w,
mat_dim_b.stride_, // int stride_b,
0.0, // float beta,
reinterpret_cast<XPUType *>(data_c), // TY* y,
m * n, // int stride_c,
nullptr, // const float* x_maxptr,
nullptr); // const float* w_maxptr
PADDLE_ENFORCE_EQ(r,
XPU_SUCCESS,
platform::errors::External(
"XPU fc_batched kernel return wrong value[%d %s] "
"x_dims = %s , y_dims = %s",
r,
XPUAPIErrorMsg[r],
x_dims.to_str().c_str(),
y_dims.to_str().c_str()));
}
}
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MatMulXPUKernel : public framework::OpKernel<T> { class MatMulXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto *x = context.Input<framework::Tensor>("X"); auto* x = context.Input<framework::Tensor>("X");
auto *y = context.Input<framework::Tensor>("Y"); auto* y = context.Input<framework::Tensor>("Y");
auto *out = context.Output<framework::Tensor>("Out"); auto* out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
bool trans_x = context.Attr<bool>("transpose_X"); bool trans_x = context.Attr<bool>("transpose_X");
bool trans_y = context.Attr<bool>("transpose_Y"); bool trans_y = context.Attr<bool>("transpose_Y");
if (std::is_same<paddle::platform::float16, T>::value) { float alpha = static_cast<T>(context.Attr<float>("alpha"));
MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, context); const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x->data<T>());
} else { const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y->data<T>());
if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) { XPUType* out_ptr = reinterpret_cast<XPUType*>(out->data<T>());
MatMulXPUFunction<T, int32_t>(x, y, out, trans_x, trans_y, context); auto x_dims = x->dims();
} else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) { auto y_dims = y->dims();
MatMulXPUFunction<T, float>(x, y, out, trans_x, trans_y, context);
} else {
MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, context);
}
}
}
};
// Reshape a rank-3 tensor from P x M x N to M x (P * N).
// (Warning: This requires transposing data and writes into new memory.)
// Identity op if the tensor is not of rank 3.
template <typename DeviceContext, typename T>
static framework::Tensor XPUFoldHeadAndLastDims(
const DeviceContext &context, const framework::Tensor &input) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto in_dims = input.dims();
if (in_dims.size() != 3) {
return input;
}
framework::Tensor output; XpuFcInfo fc_info;
output.Resize({in_dims[1], in_dims[0], in_dims[2]}); GetFCInfo(x_dims, y_dims, trans_x, trans_y, &fc_info);
output.mutable_data<T>(context.GetPlace()); auto& dev_ctx =
std::vector<int> in_shape_host = {static_cast<int>(in_dims[0]), context.template device_context<paddle::platform::XPUDeviceContext>();
static_cast<int>(in_dims[1]), xpu::Context* xpu_ctx = dev_ctx.x_context();
static_cast<int>(in_dims[2])};
std::vector<int> axis_host = {1, 0, 2};
int r = xpu::transpose(context.x_context(),
reinterpret_cast<const XPUType *>(input.data<T>()),
reinterpret_cast<XPUType *>(output.data<T>()),
in_shape_host,
axis_host);
PADDLE_ENFORCE_EQ(r,
XPU_SUCCESS,
platform::errors::External(
"XPU transpose kernel return wrong value[%d %s]",
r,
XPUAPIErrorMsg[r]));
output.Resize({in_dims[1], in_dims[0] * in_dims[2]});
return output; MatMulXPUFunction<XPUType>(xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, alpha);
} }
};
// Using dimensional constraints on matrix multiplication, it is // Using dimensional constraints on matrix multiplication, it is
// straight-forward to check the following table for when X and Y // straight-forward to check the following table for when X and Y
...@@ -317,107 +81,68 @@ static framework::Tensor XPUFoldHeadAndLastDims( ...@@ -317,107 +81,68 @@ static framework::Tensor XPUFoldHeadAndLastDims(
// to X: (P * M) x K, dOut: (P * M) x N. // to X: (P * M) x K, dOut: (P * M) x N.
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MatMulGradXPUKernel : public framework::OpKernel<T> { class MatMulGradXPUKernel : public framework::OpKernel<T> {
public: using XPUType = typename XPUTypeTrait<T>::Type;
void MatMul(const framework::ExecutionContext &context,
const framework::Tensor &a,
bool trans_a,
const framework::Tensor &b,
bool trans_b,
framework::Tensor *out) const {
out->mutable_data<T>(context.GetPlace());
if (std::is_same<paddle::platform::float16, T>::value) {
MatMulXPUFunction<T, int16_t>(&a, &b, out, trans_a, trans_b, context);
} else {
if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) {
MatMulXPUFunction<T, int32_t>(&a, &b, out, trans_a, trans_b, context);
} else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) {
MatMulXPUFunction<T, float>(&a, &b, out, trans_a, trans_b, context);
} else {
MatMulXPUFunction<T, int16_t>(&a, &b, out, trans_a, trans_b, context);
}
}
}
void CalcInputGrad(const framework::ExecutionContext &context,
const framework::Tensor &a,
bool trans_a,
bool is_fold_init_dims_a,
const framework::Tensor &b,
bool trans_b,
bool is_fold_init_dims_b,
framework::Tensor *out) const {
if (out == nullptr) return;
bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) &&
out->dims().size() == 2;
if (!need_combine) {
MatMul(context, a, trans_a, b, trans_b, out);
} else {
auto &dev_ctx = context.template device_context<DeviceContext>();
MatMul(context,
is_fold_init_dims_a
? FoldInitDims(a)
: XPUFoldHeadAndLastDims<DeviceContext, T>(dev_ctx, a),
trans_a,
is_fold_init_dims_b
? FoldInitDims(b)
: XPUFoldHeadAndLastDims<DeviceContext, T>(dev_ctx, b),
trans_b,
out);
}
}
void Compute(const framework::ExecutionContext &context) const override { public:
void Compute(const framework::ExecutionContext& context) const override {
auto x = *context.Input<framework::Tensor>("X"); auto x = *context.Input<framework::Tensor>("X");
auto y = *context.Input<framework::Tensor>("Y"); auto y = *context.Input<framework::Tensor>("Y");
auto dout = auto dout =
*context.Input<framework::Tensor>(framework::GradVarName("Out")); *context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto *dx = context.Output<framework::Tensor>(framework::GradVarName("X")); auto* dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto *dy = context.Output<framework::Tensor>(framework::GradVarName("Y")); auto* dy = context.Output<framework::Tensor>(framework::GradVarName("Y"));
bool transpose_x = context.Attr<bool>("transpose_X"); bool transpose_x = context.Attr<bool>("transpose_X");
bool transpose_y = context.Attr<bool>("transpose_Y"); bool transpose_y = context.Attr<bool>("transpose_Y");
float alpha = static_cast<T>(context.Attr<float>("alpha"));
ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);
framework::DDim dx_dims;
if (dx) { if (dx) {
dx_dims = dx->dims(); dx->mutable_data<T>(context.GetPlace());
if (dx_dims != x.dims()) {
dx->Resize(x.dims());
} }
}
framework::DDim dy_dims;
if (dy) { if (dy) {
dy_dims = dy->dims(); dy->mutable_data<T>(context.GetPlace());
if (dy_dims != y.dims()) { }
dy->Resize(y.dims()); auto& dev_ctx =
} context.template device_context<paddle::platform::XPUDeviceContext>();
}
const XPUType* dout_ptr = reinterpret_cast<const XPUType*>(dout.data<T>());
if (transpose_x && transpose_y) { const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x.data<T>());
CalcInputGrad(context, y, true, true, dout, true, false, dx); const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y.data<T>());
CalcInputGrad(context, dout, true, true, x, true, false, dy);
} else if (transpose_x) { xpu::Context* xpu_ctx = dev_ctx.x_context();
CalcInputGrad(context, y, false, false, dout, true, false, dx);
CalcInputGrad(context, x, false, false, dout, false, true, dy); XpuFcInfo info_forward;
} else if (transpose_y) { GetFCInfo(x.dims(), y.dims(), transpose_x, transpose_y, &info_forward);
CalcInputGrad(context, dout, false, false, y, false, true, dx); xpu::ctx_guard RAII_GUARD(xpu_ctx);
CalcInputGrad(context, dout, true, true, x, false, true, dy); // begin calculate
} else { const XPUType* a_1 = reinterpret_cast<const XPUType*>(NULL);
CalcInputGrad(context, dout, false, false, y, true, false, dx); const XPUType* b_1 = reinterpret_cast<const XPUType*>(NULL);
CalcInputGrad(context, x, true, true, dout, false, true, dy); const XPUType* a_2 = reinterpret_cast<const XPUType*>(NULL);
} const XPUType* b_2 = reinterpret_cast<const XPUType*>(NULL);
XPUType* c_1 = (dx == NULL) ? reinterpret_cast<XPUType*>(NULL)
: reinterpret_cast<XPUType*>(dx->data<T>());
XPUType* c_2 = (dy == NULL) ? reinterpret_cast<XPUType*>(NULL)
: reinterpret_cast<XPUType*>(dy->data<T>());
XpuFcInfo info_dx;
XpuFcInfo info_dy;
std::tuple<XpuFcInfo,
XpuFcInfo,
const XPUType*,
const XPUType*,
const XPUType*,
const XPUType*>
fc_info = MatmulGradFcInfo(xpu_ctx,
&RAII_GUARD,
info_forward,
transpose_x,
transpose_y,
x_ptr,
y_ptr,
dout_ptr);
std::tie(info_dx, info_dy, a_1, b_1, a_2, b_2) = fc_info;
if (dx) { if (dx) {
if (dx_dims != x.dims()) { MatMulXPUFunction<XPUType>(xpu_ctx, a_1, b_1, c_1, info_dx, alpha);
dx->Resize(dx_dims);
} }
}
if (dy) { if (dy) {
if (dy_dims != y.dims()) { MatMulXPUFunction<XPUType>(xpu_ctx, a_2, b_2, c_2, info_dy, alpha);
dy->Resize(dy_dims);
}
} }
} }
}; };
......
...@@ -16,146 +16,17 @@ ...@@ -16,146 +16,17 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/operators/matmul_v2_op.h" #include "paddle/fluid/operators/matmul_v2_op.h"
#include "paddle/fluid/operators/xpu_api_wrapper.h" #include "paddle/fluid/operators/xpu_api_wrapper.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T, typename FCT>
static void MatMulXPUFunction(const Tensor* x,
const Tensor* y,
Tensor* out,
bool trans_x,
bool trans_y,
const paddle::framework::ExecutionContext& ctx) {
using XPUType = typename XPUTypeTrait<T>::Type;
const auto& x_dims = x->dims();
const auto& y_dims = y->dims();
auto& dev_ctx =
ctx.template device_context<paddle::platform::XPUDeviceContext>();
auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(
RowMatrixFromVector(x_dims), 0, trans_x);
auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(
ColumnMatrixFromVector(y_dims), 0, trans_y);
if (x_dims.size() >= 3 && y_dims.size() <= 2) {
// if transpose_X is true, the transpose cost much time
if (!trans_x) {
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
} else {
mat_dim_b.batch_size_ = mat_dim_a.batch_size_;
mat_dim_b.height_ = mat_dim_b.height_ / mat_dim_b.batch_size_;
}
}
if (mat_dim_a.width_ == mat_dim_b.height_) {
if (mat_dim_a.batch_size_ == 0 && mat_dim_b.batch_size_ == 1) {
mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0;
}
if (mat_dim_a.batch_size_ == 1 && mat_dim_b.batch_size_ == 0) {
mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0;
}
}
PADDLE_ENFORCE_EQ(mat_dim_a.width_,
mat_dim_b.height_,
platform::errors::InvalidArgument(
"Shape mistake in matmul_v2_op xdims = %s ydims = %s "
"x_trans = %d y_trans = %d",
x_dims.to_str(),
y_dims.to_str(),
mat_dim_a.trans_,
mat_dim_b.trans_));
PADDLE_ENFORCE_EQ(mat_dim_a.batch_size_,
mat_dim_b.batch_size_,
platform::errors::InvalidArgument(
"Shape mistake in matmul_v2_op xdims = %s ydims = %s "
"x_trans = %d y_trans = %d",
x_dims.to_str(),
y_dims.to_str(),
mat_dim_a.trans_,
mat_dim_b.trans_));
T* data_c = out->data<T>();
int m = mat_dim_a.height_;
int n = mat_dim_b.width_;
int k = mat_dim_a.width_;
int batch_size = mat_dim_a.batch_size_;
int ldx = mat_dim_a.trans_ ? m : k;
int ldy = mat_dim_b.trans_ ? k : n;
int ldout = n;
if (batch_size <= 1) {
int r = 0;
r = xpu_fc_wrapper<XPUType, FCT>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x->data<T>()),
reinterpret_cast<const XPUType*>(y->data<T>()),
reinterpret_cast<XPUType*>(data_c),
m,
n,
k,
mat_dim_a.trans_,
mat_dim_b.trans_,
nullptr,
nullptr,
nullptr,
ldx,
ldy,
ldout,
1.0,
0,
nullptr,
xpu::Activation_t::LINEAR);
PADDLE_ENFORCE_EQ(
r,
XPU_SUCCESS,
platform::errors::External(
"XPU fc kernel return wrong value[%d %s] , m = %d, n = "
"%d, "
"k = %d, a_tr = %d, b_tr = %d",
r,
XPUAPIErrorMsg[r],
m,
n,
k,
mat_dim_a.trans_,
mat_dim_b.trans_));
} else {
// batch matmul
int r = xpu::fc_batched<XPUType, XPUType, XPUType, FCT>(
dev_ctx.x_context(), // Context* ctx,
batch_size, // int batch_size,
mat_dim_a.trans_, // bool x_trans,
mat_dim_b.trans_, // bool w_trans,
m, // int m,
n, // int n,
k, // int k,
1.0, // float alpha,
reinterpret_cast<const XPUType*>(x->data<T>()), // const TX* x,
mat_dim_a.stride_, // int stride_a,
reinterpret_cast<const XPUType*>(y->data<T>()), // const TW* w,
mat_dim_b.stride_, // int stride_b,
0.0, // float beta,
reinterpret_cast<XPUType*>(data_c), // TY* y,
m * n, // int stride_c,
nullptr, // const float* x_maxptr,
nullptr); // const float* w_maxptr
PADDLE_ENFORCE_EQ(r,
XPU_SUCCESS,
platform::errors::External(
"XPU fc_batched kernel return wrong value[%d %s]",
r,
XPUAPIErrorMsg[r]));
}
}
template <typename T> template <typename T>
class MatMulV2XPUKernel : public framework::OpKernel<T> { class MatMulV2XPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override { void Compute(const paddle::framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X"); auto* x = ctx.Input<Tensor>("X");
...@@ -164,160 +35,84 @@ class MatMulV2XPUKernel : public framework::OpKernel<T> { ...@@ -164,160 +35,84 @@ class MatMulV2XPUKernel : public framework::OpKernel<T> {
bool trans_x = ctx.Attr<bool>("trans_x"); bool trans_x = ctx.Attr<bool>("trans_x");
bool trans_y = ctx.Attr<bool>("trans_y"); bool trans_y = ctx.Attr<bool>("trans_y");
out->mutable_data<T>(ctx.GetPlace()); out->mutable_data<T>(ctx.GetPlace());
if (std::is_same<paddle::platform::float16, T>::value) { const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x->data<T>());
MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, ctx); const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y->data<T>());
} else { XPUType* out_ptr = reinterpret_cast<XPUType*>(out->data<T>());
if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) { auto x_dims = x->dims();
MatMulXPUFunction<T, int32_t>(x, y, out, trans_x, trans_y, ctx); auto y_dims = y->dims();
} else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) {
MatMulXPUFunction<T, float>(x, y, out, trans_x, trans_y, ctx); XpuFcInfo fc_info;
} else { GetFCInfo(x_dims, y_dims, trans_x, trans_y, &fc_info);
MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, ctx); auto& dev_ctx =
} ctx.template device_context<paddle::platform::XPUDeviceContext>();
} xpu::Context* xpu_ctx = dev_ctx.x_context();
MatMulXPUFunction<XPUType>(xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, 1.0f);
} }
}; };
template <typename DeviceContext, typename T>
static framework::Tensor XPUFoldHeadAndLastDims(
const DeviceContext& context, const framework::Tensor& input) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto in_dims = input.dims();
if (in_dims.size() != 3) {
return input;
}
framework::Tensor output;
output.Resize({in_dims[1], in_dims[0], in_dims[2]});
output.mutable_data<T>(context.GetPlace());
std::vector<int> in_shape_host = {static_cast<int>(in_dims[0]),
static_cast<int>(in_dims[1]),
static_cast<int>(in_dims[2])};
std::vector<int> axis_host = {1, 0, 2};
int r = xpu::transpose(context.x_context(),
reinterpret_cast<const XPUType*>(input.data<T>()),
reinterpret_cast<XPUType*>(output.data<T>()),
in_shape_host,
axis_host);
PADDLE_ENFORCE_EQ(r,
XPU_SUCCESS,
platform::errors::External(
"XPU transpose kernel return wrong value[%d %s]",
r,
XPUAPIErrorMsg[r]));
output.Resize({in_dims[1], in_dims[0] * in_dims[2]});
return output;
}
template <typename T> template <typename T>
class MatMulV2XPUGradKernel : public framework::OpKernel<T> { class MatMulV2XPUGradKernel : public framework::OpKernel<T> {
public: using XPUType = typename XPUTypeTrait<T>::Type;
void MatMul(const framework::ExecutionContext& ctx,
const framework::Tensor& a,
bool trans_a,
const framework::Tensor& b,
bool trans_b,
framework::Tensor* out) const {
out->mutable_data<T>(ctx.GetPlace());
if (std::is_same<paddle::platform::float16, T>::value) {
MatMulXPUFunction<T, int16_t>(&a, &b, out, trans_a, trans_b, ctx);
} else {
if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) {
MatMulXPUFunction<T, int32_t>(&a, &b, out, trans_a, trans_b, ctx);
} else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) {
MatMulXPUFunction<T, float>(&a, &b, out, trans_a, trans_b, ctx);
} else {
MatMulXPUFunction<T, int16_t>(&a, &b, out, trans_a, trans_b, ctx);
}
}
}
void CalcInputGrad(const framework::ExecutionContext& context,
const framework::Tensor& a,
bool trans_a,
bool is_fold_init_dims_a,
const framework::Tensor& b,
bool trans_b,
bool is_fold_init_dims_b,
framework::Tensor* out) const {
if (out == nullptr) return;
bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) &&
out->dims().size() == 2;
if (!need_combine) {
MatMul(context, a, trans_a, b, trans_b, out);
} else {
auto& dev_ctx =
context.template device_context<paddle::platform::XPUDeviceContext>();
MatMul(
context,
is_fold_init_dims_a
? FoldInitDims(a)
: XPUFoldHeadAndLastDims<paddle::platform::XPUDeviceContext, T>(
dev_ctx, a),
trans_a,
is_fold_init_dims_b
? FoldInitDims(b)
: XPUFoldHeadAndLastDims<paddle::platform::XPUDeviceContext, T>(
dev_ctx, b),
trans_b,
out);
}
}
public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
bool transpose_x = context.Attr<bool>("trans_x"); bool transpose_x = context.Attr<bool>("trans_x");
bool transpose_y = context.Attr<bool>("trans_y"); bool transpose_y = context.Attr<bool>("trans_y");
auto x = *context.Input<framework::Tensor>("X"); auto x = *context.Input<framework::Tensor>("X");
auto y = *context.Input<framework::Tensor>("Y"); auto y = *context.Input<framework::Tensor>("Y");
auto dout = auto dout =
*context.Input<framework::Tensor>(framework::GradVarName("Out")); *context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = context.Output<framework::Tensor>(framework::GradVarName("X")); auto* dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = context.Output<framework::Tensor>(framework::GradVarName("Y")); auto* dy = context.Output<framework::Tensor>(framework::GradVarName("Y"));
ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);
framework::DDim dx_dims;
if (dx) { if (dx) {
dx_dims = dx->dims(); dx->mutable_data<T>(context.GetPlace());
if (dx_dims != x.dims()) {
dx->Resize(x.dims());
}
} }
framework::DDim dy_dims;
if (dy) { if (dy) {
dy_dims = dy->dims(); dy->mutable_data<T>(context.GetPlace());
if (dy_dims != y.dims()) {
dy->Resize(y.dims());
}
}
if (transpose_x && transpose_y) {
CalcInputGrad(context, y, true, true, dout, true, false, dx);
CalcInputGrad(context, dout, true, true, x, true, false, dy);
} else if (transpose_x) {
CalcInputGrad(context, y, false, false, dout, true, false, dx);
CalcInputGrad(context, x, false, false, dout, false, true, dy);
} else if (transpose_y) {
CalcInputGrad(context, dout, false, false, y, false, true, dx);
CalcInputGrad(context, dout, true, true, x, false, true, dy);
} else {
CalcInputGrad(context, dout, false, false, y, true, false, dx);
CalcInputGrad(context, x, true, true, dout, false, true, dy);
} }
auto& dev_ctx =
context.template device_context<paddle::platform::XPUDeviceContext>();
const XPUType* dout_ptr = reinterpret_cast<const XPUType*>(dout.data<T>());
const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x.data<T>());
const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y.data<T>());
xpu::Context* xpu_ctx = dev_ctx.x_context();
XpuFcInfo info_forward;
GetFCInfo(x.dims(), y.dims(), transpose_x, transpose_y, &info_forward);
xpu::ctx_guard RAII_GUARD(xpu_ctx);
// begin calculate
const XPUType* a_1 = reinterpret_cast<const XPUType*>(NULL);
const XPUType* b_1 = reinterpret_cast<const XPUType*>(NULL);
const XPUType* a_2 = reinterpret_cast<const XPUType*>(NULL);
const XPUType* b_2 = reinterpret_cast<const XPUType*>(NULL);
XPUType* c_1 = (dx == NULL) ? reinterpret_cast<XPUType*>(NULL)
: reinterpret_cast<XPUType*>(dx->data<T>());
XPUType* c_2 = (dy == NULL) ? reinterpret_cast<XPUType*>(NULL)
: reinterpret_cast<XPUType*>(dy->data<T>());
XpuFcInfo info_dx;
XpuFcInfo info_dy;
std::tuple<XpuFcInfo,
XpuFcInfo,
const XPUType*,
const XPUType*,
const XPUType*,
const XPUType*>
fc_info = MatmulGradFcInfo(xpu_ctx,
&RAII_GUARD,
info_forward,
transpose_x,
transpose_y,
x_ptr,
y_ptr,
dout_ptr);
std::tie(info_dx, info_dy, a_1, b_1, a_2, b_2) = fc_info;
if (dx) { if (dx) {
if (dx_dims != x.dims()) { MatMulXPUFunction<XPUType>(xpu_ctx, a_1, b_1, c_1, info_dx, 1.0f);
dx->Resize(dx_dims);
} }
}
if (dy) { if (dy) {
if (dy_dims != y.dims()) { MatMulXPUFunction<XPUType>(xpu_ctx, a_2, b_2, c_2, info_dy, 1.0f);
dy->Resize(dy_dims);
}
} }
} }
}; };
......
...@@ -49,50 +49,23 @@ class MulXPUKernel : public framework::OpKernel<T> { ...@@ -49,50 +49,23 @@ class MulXPUKernel : public framework::OpKernel<T> {
*y, context.template Attr<int>("y_num_col_dims")) *y, context.template Attr<int>("y_num_col_dims"))
: *y; : *y;
z->mutable_data<T>(context.GetPlace()); z->mutable_data<T>(context.GetPlace());
auto z_dim = z->dims();
if (z_dim.size() != 2) { const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x_matrix.data<T>());
z->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y_matrix.data<T>());
} XPUType* out_ptr = reinterpret_cast<XPUType*>(z->data<T>());
bool trans_a = false; bool trans_a = false;
bool trans_b = false; bool trans_b = false;
int m = x_matrix.dims()[0]; auto x_dims = x_matrix.dims();
int k = x_matrix.dims()[1]; auto y_dims = y_matrix.dims();
int k1 = y_matrix.dims()[0];
int n = y_matrix.dims()[1]; XpuFcInfo fc_info;
PADDLE_ENFORCE_EQ( GetFCInfo(x_dims, y_dims, trans_a, trans_b, &fc_info);
k, k1, platform::errors::InvalidArgument("Shape mistake in mul_op")); auto& dev_ctx =
T alpha = static_cast<T>(1.0); context.template device_context<paddle::platform::XPUDeviceContext>();
T beta = static_cast<T>(0.0); xpu::Context* xpu_ctx = dev_ctx.x_context();
const T* data_a = x_matrix.data<T>();
const T* data_b = y_matrix.data<T>(); MatMulXPUFunction<XPUType>(xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, 1.0f);
T* data_c = z->data<T>();
auto& dev_ctx = context.template device_context<DeviceContext>();
int ret = xpu_fc_wrapper<XPUType, int16_t>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(data_a),
reinterpret_cast<const XPUType*>(data_b),
reinterpret_cast<XPUType*>(data_c),
m,
n,
k,
trans_a,
trans_b,
nullptr,
nullptr,
nullptr,
k,
n,
n,
alpha,
beta,
nullptr,
xpu::Activation_t::LINEAR);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "xpu_fc_wrapper");
if (z_dim.size() != 2) {
z->Resize(z_dim);
}
} }
}; };
...@@ -125,98 +98,51 @@ class MulGradXPUKernel : public framework::OpKernel<T> { ...@@ -125,98 +98,51 @@ class MulGradXPUKernel : public framework::OpKernel<T> {
dy->set_lod(y->lod()); dy->set_lod(y->lod());
} }
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
XpuFcInfo info_forward;
GetFCInfo(x_matrix.dims(), y_matrix.dims(), false, false, &info_forward);
const XPUType* dout_ptr = reinterpret_cast<const XPUType*>(dout->data<T>());
const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x->data<T>());
const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y->data<T>());
xpu::Context* xpu_ctx = dev_ctx.x_context();
xpu::ctx_guard RAII_GUARD(xpu_ctx);
// begin calculate
const XPUType* a_1 = reinterpret_cast<const XPUType*>(NULL);
const XPUType* b_1 = reinterpret_cast<const XPUType*>(NULL);
const XPUType* a_2 = reinterpret_cast<const XPUType*>(NULL);
const XPUType* b_2 = reinterpret_cast<const XPUType*>(NULL);
XPUType* c_1 =
(dx == NULL)
? reinterpret_cast<XPUType*>(NULL)
: reinterpret_cast<XPUType*>(dx->mutable_data<T>(ctx.GetPlace()));
XPUType* c_2 =
(dy == NULL)
? reinterpret_cast<XPUType*>(NULL)
: reinterpret_cast<XPUType*>(dy->mutable_data<T>(ctx.GetPlace()));
XpuFcInfo info_dx;
XpuFcInfo info_dy;
std::tuple<XpuFcInfo,
XpuFcInfo,
const XPUType*,
const XPUType*,
const XPUType*,
const XPUType*>
fc_info = MatmulGradFcInfo(xpu_ctx,
&RAII_GUARD,
info_forward,
false,
false,
x_ptr,
y_ptr,
dout_ptr);
std::tie(info_dx, info_dy, a_1, b_1, a_2, b_2) = fc_info;
if (dx) { if (dx) {
dx->mutable_data<T>(ctx.GetPlace()); MatMulXPUFunction<XPUType>(xpu_ctx, a_1, b_1, c_1, info_dx, 1.0f);
Tensor dx_matrix = dx->dims().size() > 2
? framework::ReshapeToMatrix(*dx, x_num_col_dims)
: *dx;
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
// blas.MatMul(dout_mat, false, y_matrix, true, &dx_matrix);
bool trans_a = false;
bool trans_b = true;
int m = dout_mat.dims()[0];
int k = dout_mat.dims()[1];
int n = y_matrix.dims()[0];
int k1 = y_matrix.dims()[1];
PADDLE_ENFORCE_EQ(
k, k1, platform::errors::InvalidArgument("Shape mistake in mul_op"));
int lda = (!trans_a) ? k : m;
int ldb = (!trans_b) ? n : k;
int ldc = n;
T alpha = static_cast<T>(1.0);
T beta = static_cast<T>(0.0);
const T* data_a = dout->data<T>();
const T* data_b = y_matrix.data<T>();
T* data_c = dx_matrix.data<T>();
int ret = xpu_fc_wrapper<XPUType, int16_t>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(data_a),
reinterpret_cast<const XPUType*>(data_b),
reinterpret_cast<XPUType*>(data_c),
m,
n,
k,
trans_a,
trans_b,
nullptr,
nullptr,
nullptr,
lda,
ldb,
ldc,
alpha,
beta,
nullptr,
xpu::Activation_t::LINEAR);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "xpu_fc_wrapper");
} }
if (dy) { if (dy) {
dy->mutable_data<T>(ctx.GetPlace()); MatMulXPUFunction<XPUType>(xpu_ctx, a_2, b_2, c_2, info_dy, 1.0f);
Tensor dy_matrix = dy->dims().size() > 2
? framework::ReshapeToMatrix(*dy, y_num_col_dims)
: *dy;
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
// blas.MatMul(x_matrix, true, dout_mat, false, &dy_matrix);
bool trans_a = true;
bool trans_b = false;
int k = x_matrix.dims()[0];
int m = x_matrix.dims()[1];
int k1 = dout_mat.dims()[0];
int n = dout_mat.dims()[1];
PADDLE_ENFORCE_EQ(
k, k1, platform::errors::InvalidArgument("Shape mistake in mul_op"));
int lda = (!trans_a) ? k : m;
int ldb = (!trans_b) ? n : k;
int ldc = n;
T alpha = static_cast<T>(1.0);
T beta = static_cast<T>(0.0);
const T* data_a = x_matrix.data<T>();
const T* data_b = dout->data<T>();
T* data_c = dy_matrix.data<T>();
int ret = xpu_fc_wrapper<XPUType, int16_t>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(data_a),
reinterpret_cast<const XPUType*>(data_b),
reinterpret_cast<XPUType*>(data_c),
m,
n,
k,
trans_a,
trans_b,
nullptr,
nullptr,
nullptr,
lda,
ldb,
ldc,
alpha,
beta,
nullptr,
xpu::Activation_t::LINEAR);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "xpu_fc_wrapper");
} }
} }
}; };
......
...@@ -12,12 +12,176 @@ limitations under the License. */ ...@@ -12,12 +12,176 @@ limitations under the License. */
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
#include <vector> #include <vector>
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using float16 = typename XPUTypeTrait<paddle::platform::float16>::Type;
enum XPUFCCalcType {
FC_INT16 = 0,
FC_INT32,
FC_FLOAT,
};
template <typename T>
XPUFCCalcType FCCalcType() {
if (std::is_same<paddle::platform::float16, T>::value ||
std::is_same<float16, T>::value) {
return XPUFCCalcType::FC_INT16;
} else if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) {
return XPUFCCalcType::FC_INT32;
} else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) {
return XPUFCCalcType::FC_FLOAT;
}
return XPUFCCalcType::FC_INT16;
}
struct XpuFcInfo {
int bs;
int m;
int n;
int k;
bool trans_x;
bool trans_y;
int stride_x;
int stride_y;
int stride_out;
float* max_x;
float* max_y;
float* max_out;
XpuFcInfo()
: bs(0),
m(0),
n(0),
k(0),
trans_x(false),
trans_y(false),
stride_x(0),
stride_y(0),
stride_out(0),
max_x(nullptr),
max_y(nullptr),
max_out(nullptr) {}
void InitFcInfo(int bs,
int m,
int n,
int k,
bool trans_x,
bool trans_y,
float* max_x,
float* max_y,
float* max_out) {
this->bs = bs;
this->m = m;
this->n = n;
this->k = k;
this->trans_x = trans_x;
this->trans_y = trans_y;
this->max_x = max_x;
this->max_y = max_y;
this->max_out = max_out;
if (this->bs <= 1) {
this->stride_x = trans_x ? m : k;
this->stride_y = trans_y ? k : n;
this->stride_out = n;
} else {
this->stride_x = m * k;
this->stride_y = k * n;
this->stride_out = m * n;
}
}
};
static std::ostream& operator<<(std::ostream& os, const XpuFcInfo& fc_inf) {
os << "fc_inf[ bs, m, n, k, trans_x, trans_y, stride_x, stride_y, "
"stride_out] = "
<< "[" << fc_inf.bs << ", " << fc_inf.m << ", " << fc_inf.n << ", "
<< fc_inf.k << ", " << fc_inf.trans_x << ", " << fc_inf.trans_y << ", "
<< fc_inf.stride_x << ", " << fc_inf.stride_y << ", " << fc_inf.stride_out;
return os;
}
static void GetFCInfo(const phi::DDim& x_dims,
const phi::DDim& y_dims,
bool trans_x,
bool trans_y,
XpuFcInfo* info) {
framework::DDim new_x_dims =
(x_dims.size() > 1) ? x_dims : phi::make_ddim({1, x_dims[0]});
framework::DDim new_y_dims =
(y_dims.size() > 1) ? y_dims : phi::make_ddim({y_dims[0], 1});
auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(new_x_dims, 0, trans_x);
auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(new_y_dims, 0, trans_y);
if (x_dims.size() >= 3 && y_dims.size() <= 2) {
if (!trans_x) {
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
} else {
mat_dim_b.batch_size_ = mat_dim_a.batch_size_;
mat_dim_b.height_ = mat_dim_b.height_ / mat_dim_b.batch_size_;
}
}
if (y_dims.size() >= 3 && x_dims.size() <= 2) {
PADDLE_ENFORCE_EQ(
mat_dim_b.trans_,
false,
platform::errors::InvalidArgument(
"xpu not support this Shape in matmul_op xdims = %s ydims = %s "
"x_trans = %d y_trans = %d",
x_dims.to_str(),
y_dims.to_str(),
mat_dim_a.trans_,
mat_dim_b.trans_));
mat_dim_b.height_ *= mat_dim_b.batch_size_;
mat_dim_b.batch_size_ = 0;
}
if (mat_dim_a.width_ == mat_dim_b.height_) {
if (mat_dim_a.batch_size_ == 0 && mat_dim_b.batch_size_ == 1) {
mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0;
}
if (mat_dim_a.batch_size_ == 1 && mat_dim_b.batch_size_ == 0) {
mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0;
}
}
PADDLE_ENFORCE_EQ(mat_dim_a.width_,
mat_dim_b.height_,
platform::errors::InvalidArgument(
"Shape mistake in matmul_op xdims = %s ydims = %s "
"x_trans = %d y_trans = %d",
x_dims.to_str(),
y_dims.to_str(),
mat_dim_a.trans_,
mat_dim_b.trans_));
info->m = mat_dim_a.height_;
info->n = mat_dim_b.width_;
info->k = mat_dim_a.width_;
info->bs = mat_dim_a.batch_size_;
info->trans_x = trans_x;
info->trans_y = trans_y;
if (info->bs <= 1) {
info->stride_x = trans_x ? info->m : info->k;
info->stride_y = trans_y ? info->k : info->n;
info->stride_out = info->n;
} else {
info->stride_x = info->m * info->k;
info->stride_y = info->k * info->n;
info->stride_out = info->m * info->n;
}
}
template <typename XPUType, typename FCT> template <typename XPUType, typename FCT>
int xpu_fc_wrapper(xpu::Context* ctx, static void xpu_fc_wrapper(xpu::Context* ctx,
const XPUType* x, const XPUType* x,
const XPUType* w, const XPUType* w,
XPUType* y, XPUType* y,
...@@ -42,12 +206,12 @@ int xpu_fc_wrapper(xpu::Context* ctx, ...@@ -42,12 +206,12 @@ int xpu_fc_wrapper(xpu::Context* ctx,
XPUType* l3_addr = nullptr; XPUType* l3_addr = nullptr;
xpu::ctx_guard RAII_GUARD(ctx); xpu::ctx_guard RAII_GUARD(ctx);
l3_addr = RAII_GUARD.alloc_l3_or_gm<XPUType>(m * k); l3_addr = RAII_GUARD.alloc_l3_or_gm<XPUType>(m * k);
if (l3_addr == nullptr) return XPUERR_NOMEM; PADDLE_ENFORCE_XDNN_NOT_NULL(l3_addr);
std::vector<int> shape = {k, m}; std::vector<int> shape = {k, m};
std::vector<int> axis = {1, 0}; std::vector<int> axis = {1, 0};
r = xpu::transpose<XPUType>(ctx, x, l3_addr, shape, axis); r = xpu::transpose<XPUType>(ctx, x, l3_addr, shape, axis);
if (r != XPU_SUCCESS) return r; PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
r = xpu::fc_fusion<XPUType, XPUType, XPUType, FCT>(ctx, r = xpu::fc_fusion<XPUType, XPUType, XPUType, FCT>(ctx,
l3_addr, l3_addr,
...@@ -68,7 +232,7 @@ int xpu_fc_wrapper(xpu::Context* ctx, ...@@ -68,7 +232,7 @@ int xpu_fc_wrapper(xpu::Context* ctx,
beta, beta,
bias, bias,
act); act);
if (r != XPU_SUCCESS) return r; PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_fusion");
} else { } else {
r = xpu::fc_fusion<XPUType, XPUType, XPUType, FCT>(ctx, r = xpu::fc_fusion<XPUType, XPUType, XPUType, FCT>(ctx,
x, x,
...@@ -89,8 +253,356 @@ int xpu_fc_wrapper(xpu::Context* ctx, ...@@ -89,8 +253,356 @@ int xpu_fc_wrapper(xpu::Context* ctx,
beta, beta,
bias, bias,
act); act);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_fusion");
} }
return r; }
template <>
void xpu_fc_wrapper<float16, int32_t>(xpu::Context* ctx,
const float16* x,
const float16* w,
float16* y,
int m,
int n,
int k,
bool x_trans,
bool w_trans,
const float* x_maxptr,
const float* w_maxptr,
float* y_maxptr,
int ldx,
int ldw,
int ldy,
float alpha,
float beta,
const float* bias,
const xpu::Activation_t& act) {
int r = xpu::Error_t::INVALID_PARAM;
PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_fc_wrapper");
}
template <typename XPUType, typename FCT>
static void xpu_fc_batch_wrapper(xpu::Context* xpu_ctx,
int bs,
bool trans_x,
bool trans_w,
int m,
int n,
int k,
float alpha,
const XPUType* x,
int stride_x,
const XPUType* w,
int stride_w,
float beta,
XPUType* y,
int stride_y,
const float* x_maxptr,
const float* w_maxptr) {
int r = xpu::fc_batched<XPUType, XPUType, XPUType, FCT>(
xpu_ctx, // Context* ctx,
bs, // int batch_size,
trans_x, // bool x_trans,
trans_w, // bool w_trans,
m, // int m,
n, // int n,
k, // int k,
alpha, // float alpha,
reinterpret_cast<const XPUType*>(x), // const TX* x,
stride_x, // int stride_a,
reinterpret_cast<const XPUType*>(w), // const TW* w,
stride_w, // int stride_b,
0.0, // float beta,
reinterpret_cast<XPUType*>(y), // TY* y,
stride_y, // int stride_c,
x_maxptr, // const float* x_maxptr,
w_maxptr); // const float* w_maxptr
PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_batched");
}
template <>
void xpu_fc_batch_wrapper<float16, int32_t>(xpu::Context* xpu_ctx,
int bs,
bool trans_x,
bool trans_w,
int m,
int n,
int k,
float alpha,
const float16* x,
int stride_x,
const float16* w,
int stride_w,
float beta,
float16* y,
int stride_y,
const float* x_maxptr,
const float* w_maxptr) {
int r = xpu::Error_t::INVALID_PARAM;
PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_fc_batch_wrapper");
}
template <>
void xpu_fc_batch_wrapper<float16, float>(xpu::Context* xpu_ctx,
int bs,
bool trans_x,
bool trans_w,
int m,
int n,
int k,
float alpha,
const float16* x,
int stride_x,
const float16* w,
int stride_w,
float beta,
float16* y,
int stride_y,
const float* x_maxptr,
const float* w_maxptr) {
int r = xpu::Error_t::INVALID_PARAM;
PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_fc_batch_wrapper");
}
template <typename T>
static void MatMulXPUFunction(xpu::Context* xpu_ctx,
const T* x,
const T* y,
T* out,
const XpuFcInfo& fcinfo,
float alpha) {
using XPUType = typename XPUTypeTrait<T>::Type;
using float16 = typename XPUTypeTrait<paddle::platform::float16>::Type;
int fccal_type = FCCalcType<XPUType>();
decltype(&paddle::operators::xpu_fc_wrapper<XPUType, int16_t>)
fc_api_list[3] = {
&paddle::operators::xpu_fc_wrapper<XPUType, int16_t>,
&paddle::operators::xpu_fc_wrapper<XPUType, int32_t>,
&paddle::operators::xpu_fc_wrapper<XPUType, float>,
};
decltype(&paddle::operators::xpu_fc_batch_wrapper<XPUType, int16_t>)
fc_batch_api_list[3] = {
&paddle::operators::xpu_fc_batch_wrapper<XPUType, int16_t>,
&paddle::operators::xpu_fc_batch_wrapper<XPUType, int32_t>,
&paddle::operators::xpu_fc_batch_wrapper<XPUType, float>,
};
auto fc_api = fc_api_list[fccal_type];
auto fc_batch_api = fc_batch_api_list[fccal_type];
int m = fcinfo.m;
int n = fcinfo.n;
int k = fcinfo.k;
int batch_size = fcinfo.bs;
int ldx = fcinfo.stride_x;
int ldy = fcinfo.stride_y;
int ldout = fcinfo.stride_out;
bool trans_x = fcinfo.trans_x;
bool trans_y = fcinfo.trans_y;
float* max_x = fcinfo.max_x;
float* max_y = fcinfo.max_y;
float* max_out = fcinfo.max_out;
if (batch_size <= 1) {
fc_api(xpu_ctx,
reinterpret_cast<const XPUType*>(x),
reinterpret_cast<const XPUType*>(y),
reinterpret_cast<XPUType*>(out),
m,
n,
k,
trans_x,
trans_y,
max_x,
max_y,
max_out,
ldx,
ldy,
ldout,
alpha,
0,
nullptr,
xpu::Activation_t::LINEAR);
} else {
// batch matmul
fc_batch_api(xpu_ctx, // Context* ctx,
batch_size, // int batch_size,
trans_x, // bool x_trans,
trans_y, // bool w_trans,
m, // int m,
n, // int n,
k, // int k,
alpha, // float alpha,
reinterpret_cast<const XPUType*>(x), // const TX* x,
ldx, // int stride_a,
reinterpret_cast<const XPUType*>(y), // const TW* w,
ldy, // int stride_b,
0.0, // float beta,
reinterpret_cast<XPUType*>(out), // TY* y,
ldout, // int stride_c,
max_x, // const float* x_maxptr,
max_y); // const float* w_maxptr
}
}
template <typename T>
static std::tuple<XpuFcInfo, XpuFcInfo, const T*, const T*, const T*, const T*>
MatmulGradFcInfo(xpu::Context* xpu_ctx,
xpu::ctx_guard* RAII_GUARD,
const XpuFcInfo& dout_shape,
bool trans_x,
bool trans_y,
const T* x,
const T* y,
const T* dout) {
XpuFcInfo dx_shape, dy_shape;
const T* dx_a = NULL;
const T* dx_b = NULL;
const T* dy_a = NULL;
const T* dy_b = NULL;
bool copy_to_l3 = false;
float* max_dout = NULL;
int maxptr_size = xpu_ctx->max_ptr_size();
uint64_t l3_size = uint64_t(xpu_ctx->_l3_mgr.get_size());
int bs = (dout_shape.bs <= 1) ? (1) : (dout_shape.bs);
int dx_size = bs * dout_shape.m * dout_shape.k;
int dy_size = bs * dout_shape.k * dout_shape.n;
int dout_size = bs * dout_shape.m * dout_shape.n;
if (trans_x && trans_y) {
copy_to_l3 = l3_size >= (dout_size * 2 + dy_size) * sizeof(T);
} else if (trans_x) {
copy_to_l3 = l3_size >= dout_size * sizeof(T);
} else if (trans_y) {
copy_to_l3 = l3_size >= dout_size * 2 * sizeof(T);
} else {
copy_to_l3 = l3_size >= (dout_size + dx_size) * sizeof(T);
}
const T* dout_new = dout;
int r = 0;
if (copy_to_l3) {
T* dout_l3 = RAII_GUARD->alloc_l3<T>(dout_size);
PADDLE_ENFORCE_XDNN_NOT_NULL(dout_l3);
if ((dout_shape.bs > 1) || ((dout_shape.bs <= 1) &&
(FCCalcType<T>() == XPUFCCalcType::FC_FLOAT))) {
r = xpu::copy(xpu_ctx, dout, dout_l3, dout_size);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
dout_new = dout_l3;
} else {
max_dout = RAII_GUARD->alloc_l3_or_gm<float>(maxptr_size);
PADDLE_ENFORCE_XDNN_NOT_NULL(max_dout);
r = xpu::findmax_copy_fusion(xpu_ctx, dout, max_dout, dout_l3, dout_size);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "findmax_copy_fusion");
dout_new = dout_l3;
}
} else if (((dout_shape.bs <= 1) &&
(FCCalcType<T>() != XPUFCCalcType::FC_FLOAT))) {
max_dout = RAII_GUARD->alloc_l3_or_gm<float>(maxptr_size);
PADDLE_ENFORCE_XDNN_NOT_NULL(max_dout);
r = xpu::findmax_copy_fusion(
xpu_ctx, dout, max_dout, reinterpret_cast<T*>(NULL), dout_size);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "findmax_copy_fusion");
}
if (trans_x && trans_y) {
// dx = T(y) * T(dout)
dx_shape.InitFcInfo(dout_shape.bs,
dout_shape.k,
dout_shape.m,
dout_shape.n,
true,
true,
nullptr,
max_dout,
nullptr);
dx_a = y, dx_b = dout_new;
// dy = T(dout) * T(x)
dy_shape.InitFcInfo(dout_shape.bs,
dout_shape.n,
dout_shape.k,
dout_shape.m,
true,
true,
max_dout,
nullptr,
nullptr);
dy_a = dout_new, dy_b = x;
} else if (trans_x) {
// dx = y * T(dout)
dx_shape.InitFcInfo(dout_shape.bs,
dout_shape.k,
dout_shape.m,
dout_shape.n,
false,
true,
nullptr,
max_dout,
nullptr);
dx_a = y, dx_b = dout_new;
// dy = x * dout
dy_shape.InitFcInfo(dout_shape.bs,
dout_shape.k,
dout_shape.n,
dout_shape.m,
false,
false,
nullptr,
max_dout,
nullptr);
dy_a = x, dy_b = dout_new;
} else if (trans_y) {
// dx = dout * y
dx_shape.InitFcInfo(dout_shape.bs,
dout_shape.m,
dout_shape.k,
dout_shape.n,
false,
false,
max_dout,
nullptr,
nullptr);
dx_a = dout_new, dx_b = y;
// dy = T(dout) * x
dy_shape.InitFcInfo(dout_shape.bs,
dout_shape.n,
dout_shape.k,
dout_shape.m,
true,
false,
max_dout,
nullptr,
nullptr);
dy_a = dout_new, dy_b = x;
} else {
// dx = dout * T(y)
dx_shape.InitFcInfo(dout_shape.bs,
dout_shape.m,
dout_shape.k,
dout_shape.n,
false,
true,
max_dout,
nullptr,
nullptr);
dx_a = dout_new, dx_b = y;
// dy = T(x) * dout
dy_shape.InitFcInfo(dout_shape.bs,
dout_shape.k,
dout_shape.n,
dout_shape.m,
true,
false,
nullptr,
max_dout,
nullptr);
dy_a = x, dy_b = dout_new;
}
std::tuple<XpuFcInfo, XpuFcInfo, const T*, const T*, const T*, const T*>
result = std::make_tuple(dx_shape, dy_shape, dx_a, dx_b, dy_a, dy_b);
return result;
} }
} // namespace operators } // namespace operators
......
...@@ -281,11 +281,18 @@ XPUOpMap& get_kl2_ops() { ...@@ -281,11 +281,18 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})}, pOpKernelType(vartype::FP32, XPUPlace())})},
{"matmul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"matmul_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"matmul_v2_grad", {"matmul_v2_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
{"matmul_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, pOpKernelType(vartype::FP16, XPUPlace())})},
{"matmul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"matmul_v2",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"matmul",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"mean_grad", {"mean_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})}, pOpKernelType(vartype::FP16, XPUPlace())})},
......
...@@ -84,7 +84,9 @@ type_dict_str_to_numpy = { ...@@ -84,7 +84,9 @@ type_dict_str_to_numpy = {
xpu_test_op_white_list = [] xpu_test_op_white_list = []
xpu_test_device_type_white_list = ['xpu1_float64'] xpu_test_device_type_white_list = ['xpu1_float64']
xpu_test_op_type_white_list = ['dropout_float16', 'dropout_grad_float16'] xpu_test_op_type_white_list = [
'dropout_float16', 'dropout_grad_float16', 'matmul_v2_float16'
]
xpu_test_device_op_white_list = [] xpu_test_device_op_white_list = []
xpu_test_device_op_type_white_list = [] xpu_test_device_op_type_white_list = []
......
...@@ -303,7 +303,8 @@ class TestMatmulBaseGenerator(XPUOpTest): ...@@ -303,7 +303,8 @@ class TestMatmulBaseGenerator(XPUOpTest):
X = np.random.random(shape_X).astype(self.dtype) X = np.random.random(shape_X).astype(self.dtype)
Y = np.random.random(shape_Y).astype(self.dtype) Y = np.random.random(shape_Y).astype(self.dtype)
Out = reference_matmul(X, Y, transpose_X, transpose_Y) Out = reference_matmul(X, Y, transpose_X,
transpose_Y).astype(self.dtype)
self.inputs = {'X': X, 'Y': Y} self.inputs = {'X': X, 'Y': Y}
self.attrs = {'transpose_X': transpose_X, 'transpose_Y': transpose_Y} self.attrs = {'transpose_X': transpose_X, 'transpose_Y': transpose_Y}
self.outputs = {'Out': Out} self.outputs = {'Out': Out}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册