diff --git a/paddle/fluid/operators/matmul_op_xpu.cc b/paddle/fluid/operators/matmul_op_xpu.cc index 945546a502aebb10072542b8b7008486c077d38a..efad516cdbfe50f52a45195d59f315ad7eee67c5 100644 --- a/paddle/fluid/operators/matmul_op_xpu.cc +++ b/paddle/fluid/operators/matmul_op_xpu.cc @@ -20,276 +20,40 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/xpu_api_wrapper.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { using framework::Tensor; -static framework::DDim RowMatrixFromVector(const framework::DDim &x_dim) { - if (x_dim.size() > 1) { - return x_dim; - } - return phi::make_ddim({1, x_dim[0]}); -} - -static framework::Tensor FoldInitDims(const framework::Tensor &input) { - auto output = input; - auto in_dims = input.dims(); - if (in_dims.size() == 3) { - output.Resize({in_dims[0] * in_dims[1], in_dims[2]}); - } - return output; -} -/** - * Get column matrix shape from a vector shape. If the ran of y_dim > 1, the - * original y_dim is returned. - */ -static framework::DDim ColumnMatrixFromVector(const framework::DDim &y_dim) { - if (y_dim.size() > 1) { - return y_dim; - } - return phi::make_ddim({y_dim[0], 1}); -} - -static void ReshapeTensorIntoMatrixSequence( - framework::Tensor *x, const phi::funcs::MatDescriptor &descriptor) { - int64_t h, w; - h = descriptor.height_; - w = descriptor.width_; - if (descriptor.trans_) { - std::swap(w, h); - } - if (descriptor.batch_size_) { - x->Resize({descriptor.batch_size_, h, w}); - } else { - x->Resize({h, w}); - } -} -/** - * Reshape the x,y,out tensor to 3-D or 2-D tensor by matrix descriptor - * Out = matmul(x, y) - * - * This method will first calculate X,Y matrix sequence, and then calculate - * the out shape. - * - * Assume X = [BatchSize, H1, W1], Y = [BatchSize, H2, W2] - * The out = [BatchSize, H1, W2] - * - * If there is no batch size in `X` and `Y`, the out will be [H1, W2] - * If any of `X` and `Y` has batch size BatchSize, the out will have the - * BatchSize. 
- */ -static void ReshapeXYOutIntoMatrixSequence(framework::Tensor *x, - framework::Tensor *y, - framework::Tensor *out, - bool trans_x, - bool trans_y) { - auto x_dim = RowMatrixFromVector(x->dims()); - auto y_dim = ColumnMatrixFromVector(y->dims()); - auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(x_dim, 0, trans_x); - auto mat_dim_y = phi::funcs::CreateMatrixDescriptor(y_dim, 0, trans_y); - if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) { - out->Resize({mat_dim_x.height_, mat_dim_y.width_}); - } else { - out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_), - mat_dim_x.height_, - mat_dim_y.width_}); - } - - ReshapeTensorIntoMatrixSequence(x, mat_dim_x); - ReshapeTensorIntoMatrixSequence(y, mat_dim_y); -} - -template -static void MatMulXPUFunction(const Tensor *x, - const Tensor *y, - Tensor *out, - bool trans_x, - bool trans_y, - const paddle::framework::ExecutionContext &ctx) { - using XPUType = typename XPUTypeTrait::Type; - const auto &x_dims = x->dims(); - const auto &y_dims = y->dims(); - auto &dev_ctx = - ctx.template device_context(); - - auto mat_dim_a = phi::funcs::CreateMatrixDescriptor( - RowMatrixFromVector(x_dims), 0, trans_x); - auto mat_dim_b = phi::funcs::CreateMatrixDescriptor( - ColumnMatrixFromVector(y_dims), 0, trans_y); - - if (x_dims.size() == 3 && y_dims.size() <= 2) { - // if transpose_X is true, the transpose cost much time - if (!trans_x) { - mat_dim_a.height_ *= mat_dim_a.batch_size_; - mat_dim_a.batch_size_ = 0; - } else { - mat_dim_b.batch_size_ = mat_dim_a.batch_size_; - mat_dim_b.height_ = mat_dim_b.height_ / mat_dim_b.batch_size_; - } - } - - if (mat_dim_a.width_ == mat_dim_b.height_) { - if (mat_dim_a.batch_size_ == 0 && mat_dim_b.batch_size_ == 1) { - mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0; - } - if (mat_dim_a.batch_size_ == 1 && mat_dim_b.batch_size_ == 0) { - mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0; - } - } - - PADDLE_ENFORCE_EQ(mat_dim_a.width_, - mat_dim_b.height_, - platform::errors::InvalidArgument( - "Shape mistake in matmul_op, the " - "first tensor width must be same as " - "second tensor height, but received " - "width:%d, height:%d x_dims = %s , y_dims = %s", - mat_dim_a.width_, - mat_dim_b.height_, - x_dims.to_str().c_str(), - y_dims.to_str().c_str())); - PADDLE_ENFORCE_EQ(mat_dim_a.batch_size_, - mat_dim_b.batch_size_, - platform::errors::InvalidArgument( - "Shape mistake in matmul_op, the two input" - "tensor batch_size must be same, but received first " - "tensor batch_size:%d, second " - "tensor batch_size:%d, x_dims = %s , y_dims = %s", - mat_dim_a.batch_size_, - mat_dim_b.batch_size_, - x_dims.to_str().c_str(), - y_dims.to_str().c_str())); - - float alpha = static_cast(ctx.Attr("alpha")); - T *data_c = out->data(); - int m = mat_dim_a.height_; - int n = mat_dim_b.width_; - int k = mat_dim_a.width_; - int batch_size = mat_dim_a.batch_size_; - int ldx = mat_dim_a.trans_ ? m : k; - int ldy = mat_dim_b.trans_ ? 
k : n; - int ldout = n; - if (batch_size <= 1) { - int r = 0; - r = xpu_fc_wrapper( - dev_ctx.x_context(), - reinterpret_cast(x->data()), - reinterpret_cast(y->data()), - reinterpret_cast(data_c), - m, - n, - k, - mat_dim_a.trans_, - mat_dim_b.trans_, - nullptr, - nullptr, - nullptr, - ldx, - ldy, - ldout, - alpha, - 0, - nullptr, - xpu::Activation_t::LINEAR); - PADDLE_ENFORCE_EQ( - r, - XPU_SUCCESS, - platform::errors::External( - "XPU fc kernel return wrong value[%d %s]", r, XPUAPIErrorMsg[r])); - } else { - // batch matmul - int r = xpu::fc_batched( - dev_ctx.x_context(), // Context* ctx, - batch_size, // int batch_size, - mat_dim_a.trans_, // bool x_trans, - mat_dim_b.trans_, // bool w_trans, - m, // int m, - n, // int n, - k, // int k, - alpha, // float alpha, - reinterpret_cast(x->data()), // const TX* x, - mat_dim_a.stride_, // int stride_a, - reinterpret_cast(y->data()), // const TW* w, - mat_dim_b.stride_, // int stride_b, - 0.0, // float beta, - reinterpret_cast(data_c), // TY* y, - m * n, // int stride_c, - nullptr, // const float* x_maxptr, - nullptr); // const float* w_maxptr - - PADDLE_ENFORCE_EQ(r, - XPU_SUCCESS, - platform::errors::External( - "XPU fc_batched kernel return wrong value[%d %s] " - "x_dims = %s , y_dims = %s", - r, - XPUAPIErrorMsg[r], - x_dims.to_str().c_str(), - y_dims.to_str().c_str())); - } -} - template class MatMulXPUKernel : public framework::OpKernel { + using XPUType = typename XPUTypeTrait::Type; + public: - void Compute(const framework::ExecutionContext &context) const override { - auto *x = context.Input("X"); - auto *y = context.Input("Y"); - auto *out = context.Output("Out"); + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* y = context.Input("Y"); + auto* out = context.Output("Out"); out->mutable_data(context.GetPlace()); bool trans_x = context.Attr("transpose_X"); bool trans_y = context.Attr("transpose_Y"); - if (std::is_same::value) { - MatMulXPUFunction(x, y, out, trans_x, trans_y, context); - } else { - if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) { - MatMulXPUFunction(x, y, out, trans_x, trans_y, context); - } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) { - MatMulXPUFunction(x, y, out, trans_x, trans_y, context); - } else { - MatMulXPUFunction(x, y, out, trans_x, trans_y, context); - } - } + float alpha = static_cast(context.Attr("alpha")); + const XPUType* x_ptr = reinterpret_cast(x->data()); + const XPUType* y_ptr = reinterpret_cast(y->data()); + XPUType* out_ptr = reinterpret_cast(out->data()); + auto x_dims = x->dims(); + auto y_dims = y->dims(); + + XpuFcInfo fc_info; + GetFCInfo(x_dims, y_dims, trans_x, trans_y, &fc_info); + auto& dev_ctx = + context.template device_context(); + xpu::Context* xpu_ctx = dev_ctx.x_context(); + + MatMulXPUFunction(xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, alpha); } }; -// Reshape a rank-3 tensor from P x M x N to M x (P * N). -// (Warning: This requires transposing data and writes into new memory.) -// Identity op if the tensor is not of rank 3. 
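
Reviewer note: the per-kernel environment-variable branching deleted above (XPU_PADDLE_FC_INT32 / XPU_PADDLE_FC_LOCAL_INT16) is not dropped; it is centralized in the FCCalcType<T>() helper that this diff adds to paddle/fluid/operators/xpu_api_wrapper.h. Below is a minimal sketch of that selection logic, not part of the change itself: the function name FCCalcTypeSketch and the single float16 check are illustrative (the template arguments are stripped in the rendered diff, so the exact type checks in the real helper may differ slightly), and the authoritative version is the one added to the header.

#include <cstdlib>
#include <type_traits>
#include "paddle/fluid/platform/float16.h"

enum XPUFCCalcTypeSketch { FC_INT16 = 0, FC_INT32, FC_FLOAT };

// Sketch of the dispatch that replaces the deleted env-var branches.
template <typename T>
XPUFCCalcTypeSketch FCCalcTypeSketch() {
  if (std::is_same<paddle::platform::float16, T>::value) {
    // fp16 kernels only get the int16 GEMM path; the fp16 wrapper
    // specializations added later in this diff fail with INVALID_PARAM
    // for the int32 and float calc types.
    return FC_INT16;
  }
  if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) {
    return FC_INT32;  // opt-in int32 accumulation, as before
  }
  if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) {
    return FC_FLOAT;  // "local int16" requests are served by the float GEMM
  }
  return FC_INT16;  // default
}

This is also why the kl2 op list below only gains FP16 entries for the matmul ops touched here: fp16 inputs always resolve to the int16 calc type regardless of the environment variables.
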
-template -static framework::Tensor XPUFoldHeadAndLastDims( - const DeviceContext &context, const framework::Tensor &input) { - using XPUType = typename XPUTypeTrait::Type; - auto in_dims = input.dims(); - if (in_dims.size() != 3) { - return input; - } - - framework::Tensor output; - output.Resize({in_dims[1], in_dims[0], in_dims[2]}); - output.mutable_data(context.GetPlace()); - std::vector in_shape_host = {static_cast(in_dims[0]), - static_cast(in_dims[1]), - static_cast(in_dims[2])}; - std::vector axis_host = {1, 0, 2}; - int r = xpu::transpose(context.x_context(), - reinterpret_cast(input.data()), - reinterpret_cast(output.data()), - in_shape_host, - axis_host); - PADDLE_ENFORCE_EQ(r, - XPU_SUCCESS, - platform::errors::External( - "XPU transpose kernel return wrong value[%d %s]", - r, - XPUAPIErrorMsg[r])); - output.Resize({in_dims[1], in_dims[0] * in_dims[2]}); - - return output; -} - // Using dimensional constraints on matrix multiplication, it is // straight-forward to check the following table for when X and Y // are both matrices. @@ -317,107 +81,68 @@ static framework::Tensor XPUFoldHeadAndLastDims( // to X: (P * M) x K, dOut: (P * M) x N. template class MatMulGradXPUKernel : public framework::OpKernel { - public: - void MatMul(const framework::ExecutionContext &context, - const framework::Tensor &a, - bool trans_a, - const framework::Tensor &b, - bool trans_b, - framework::Tensor *out) const { - out->mutable_data(context.GetPlace()); - if (std::is_same::value) { - MatMulXPUFunction(&a, &b, out, trans_a, trans_b, context); - } else { - if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) { - MatMulXPUFunction(&a, &b, out, trans_a, trans_b, context); - } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) { - MatMulXPUFunction(&a, &b, out, trans_a, trans_b, context); - } else { - MatMulXPUFunction(&a, &b, out, trans_a, trans_b, context); - } - } - } - - void CalcInputGrad(const framework::ExecutionContext &context, - const framework::Tensor &a, - bool trans_a, - bool is_fold_init_dims_a, - const framework::Tensor &b, - bool trans_b, - bool is_fold_init_dims_b, - framework::Tensor *out) const { - if (out == nullptr) return; - bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) && - out->dims().size() == 2; - if (!need_combine) { - MatMul(context, a, trans_a, b, trans_b, out); - } else { - auto &dev_ctx = context.template device_context(); - MatMul(context, - is_fold_init_dims_a - ? FoldInitDims(a) - : XPUFoldHeadAndLastDims(dev_ctx, a), - trans_a, - is_fold_init_dims_b - ? 
FoldInitDims(b) - : XPUFoldHeadAndLastDims(dev_ctx, b), - trans_b, - out); - } - } + using XPUType = typename XPUTypeTrait::Type; - void Compute(const framework::ExecutionContext &context) const override { + public: + void Compute(const framework::ExecutionContext& context) const override { auto x = *context.Input("X"); auto y = *context.Input("Y"); auto dout = *context.Input(framework::GradVarName("Out")); - auto *dx = context.Output(framework::GradVarName("X")); - auto *dy = context.Output(framework::GradVarName("Y")); + auto* dx = context.Output(framework::GradVarName("X")); + auto* dy = context.Output(framework::GradVarName("Y")); bool transpose_x = context.Attr("transpose_X"); bool transpose_y = context.Attr("transpose_Y"); - - ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y); - - framework::DDim dx_dims; + float alpha = static_cast(context.Attr("alpha")); if (dx) { - dx_dims = dx->dims(); - if (dx_dims != x.dims()) { - dx->Resize(x.dims()); - } + dx->mutable_data(context.GetPlace()); } - - framework::DDim dy_dims; if (dy) { - dy_dims = dy->dims(); - if (dy_dims != y.dims()) { - dy->Resize(y.dims()); - } - } - - if (transpose_x && transpose_y) { - CalcInputGrad(context, y, true, true, dout, true, false, dx); - CalcInputGrad(context, dout, true, true, x, true, false, dy); - } else if (transpose_x) { - CalcInputGrad(context, y, false, false, dout, true, false, dx); - CalcInputGrad(context, x, false, false, dout, false, true, dy); - } else if (transpose_y) { - CalcInputGrad(context, dout, false, false, y, false, true, dx); - CalcInputGrad(context, dout, true, true, x, false, true, dy); - } else { - CalcInputGrad(context, dout, false, false, y, true, false, dx); - CalcInputGrad(context, x, true, true, dout, false, true, dy); + dy->mutable_data(context.GetPlace()); } - + auto& dev_ctx = + context.template device_context(); + + const XPUType* dout_ptr = reinterpret_cast(dout.data()); + const XPUType* x_ptr = reinterpret_cast(x.data()); + const XPUType* y_ptr = reinterpret_cast(y.data()); + + xpu::Context* xpu_ctx = dev_ctx.x_context(); + + XpuFcInfo info_forward; + GetFCInfo(x.dims(), y.dims(), transpose_x, transpose_y, &info_forward); + xpu::ctx_guard RAII_GUARD(xpu_ctx); + // begin calculate + const XPUType* a_1 = reinterpret_cast(NULL); + const XPUType* b_1 = reinterpret_cast(NULL); + const XPUType* a_2 = reinterpret_cast(NULL); + const XPUType* b_2 = reinterpret_cast(NULL); + XPUType* c_1 = (dx == NULL) ? reinterpret_cast(NULL) + : reinterpret_cast(dx->data()); + XPUType* c_2 = (dy == NULL) ? 
reinterpret_cast(NULL) + : reinterpret_cast(dy->data()); + XpuFcInfo info_dx; + XpuFcInfo info_dy; + std::tuple + fc_info = MatmulGradFcInfo(xpu_ctx, + &RAII_GUARD, + info_forward, + transpose_x, + transpose_y, + x_ptr, + y_ptr, + dout_ptr); + std::tie(info_dx, info_dy, a_1, b_1, a_2, b_2) = fc_info; if (dx) { - if (dx_dims != x.dims()) { - dx->Resize(dx_dims); - } + MatMulXPUFunction(xpu_ctx, a_1, b_1, c_1, info_dx, alpha); } - if (dy) { - if (dy_dims != y.dims()) { - dy->Resize(dy_dims); - } + MatMulXPUFunction(xpu_ctx, a_2, b_2, c_2, info_dy, alpha); } } }; diff --git a/paddle/fluid/operators/matmul_v2_op_xpu.cc b/paddle/fluid/operators/matmul_v2_op_xpu.cc index e9cb665f4fc0e5ef6762a2a76a3508fe1dc2dc0d..7b4195c1c19fa2a7a489c84af7826dec1ff5cd46 100644 --- a/paddle/fluid/operators/matmul_v2_op_xpu.cc +++ b/paddle/fluid/operators/matmul_v2_op_xpu.cc @@ -16,146 +16,17 @@ #include #include - #include "paddle/fluid/operators/matmul_v2_op.h" + #include "paddle/fluid/operators/xpu_api_wrapper.h" namespace paddle { namespace operators { -template -static void MatMulXPUFunction(const Tensor* x, - const Tensor* y, - Tensor* out, - bool trans_x, - bool trans_y, - const paddle::framework::ExecutionContext& ctx) { - using XPUType = typename XPUTypeTrait::Type; - const auto& x_dims = x->dims(); - const auto& y_dims = y->dims(); - auto& dev_ctx = - ctx.template device_context(); - - auto mat_dim_a = phi::funcs::CreateMatrixDescriptor( - RowMatrixFromVector(x_dims), 0, trans_x); - auto mat_dim_b = phi::funcs::CreateMatrixDescriptor( - ColumnMatrixFromVector(y_dims), 0, trans_y); - - if (x_dims.size() >= 3 && y_dims.size() <= 2) { - // if transpose_X is true, the transpose cost much time - if (!trans_x) { - mat_dim_a.height_ *= mat_dim_a.batch_size_; - mat_dim_a.batch_size_ = 0; - } else { - mat_dim_b.batch_size_ = mat_dim_a.batch_size_; - mat_dim_b.height_ = mat_dim_b.height_ / mat_dim_b.batch_size_; - } - } - - if (mat_dim_a.width_ == mat_dim_b.height_) { - if (mat_dim_a.batch_size_ == 0 && mat_dim_b.batch_size_ == 1) { - mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0; - } - if (mat_dim_a.batch_size_ == 1 && mat_dim_b.batch_size_ == 0) { - mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0; - } - } - - PADDLE_ENFORCE_EQ(mat_dim_a.width_, - mat_dim_b.height_, - platform::errors::InvalidArgument( - "Shape mistake in matmul_v2_op xdims = %s ydims = %s " - "x_trans = %d y_trans = %d", - x_dims.to_str(), - y_dims.to_str(), - mat_dim_a.trans_, - mat_dim_b.trans_)); - PADDLE_ENFORCE_EQ(mat_dim_a.batch_size_, - mat_dim_b.batch_size_, - platform::errors::InvalidArgument( - "Shape mistake in matmul_v2_op xdims = %s ydims = %s " - "x_trans = %d y_trans = %d", - x_dims.to_str(), - y_dims.to_str(), - mat_dim_a.trans_, - mat_dim_b.trans_)); - - T* data_c = out->data(); - int m = mat_dim_a.height_; - int n = mat_dim_b.width_; - int k = mat_dim_a.width_; - int batch_size = mat_dim_a.batch_size_; - int ldx = mat_dim_a.trans_ ? m : k; - int ldy = mat_dim_b.trans_ ? 
k : n; - int ldout = n; - if (batch_size <= 1) { - int r = 0; - r = xpu_fc_wrapper( - dev_ctx.x_context(), - reinterpret_cast(x->data()), - reinterpret_cast(y->data()), - reinterpret_cast(data_c), - m, - n, - k, - mat_dim_a.trans_, - mat_dim_b.trans_, - nullptr, - nullptr, - nullptr, - ldx, - ldy, - ldout, - 1.0, - 0, - nullptr, - xpu::Activation_t::LINEAR); - PADDLE_ENFORCE_EQ( - r, - XPU_SUCCESS, - platform::errors::External( - "XPU fc kernel return wrong value[%d %s] , m = %d, n = " - "%d, " - "k = %d, a_tr = %d, b_tr = %d", - r, - XPUAPIErrorMsg[r], - m, - n, - k, - mat_dim_a.trans_, - mat_dim_b.trans_)); - } else { - // batch matmul - int r = xpu::fc_batched( - dev_ctx.x_context(), // Context* ctx, - batch_size, // int batch_size, - mat_dim_a.trans_, // bool x_trans, - mat_dim_b.trans_, // bool w_trans, - m, // int m, - n, // int n, - k, // int k, - 1.0, // float alpha, - reinterpret_cast(x->data()), // const TX* x, - mat_dim_a.stride_, // int stride_a, - reinterpret_cast(y->data()), // const TW* w, - mat_dim_b.stride_, // int stride_b, - 0.0, // float beta, - reinterpret_cast(data_c), // TY* y, - m * n, // int stride_c, - nullptr, // const float* x_maxptr, - nullptr); // const float* w_maxptr - - PADDLE_ENFORCE_EQ(r, - XPU_SUCCESS, - platform::errors::External( - "XPU fc_batched kernel return wrong value[%d %s]", - r, - XPUAPIErrorMsg[r])); - } -} - template class MatMulV2XPUKernel : public framework::OpKernel { + using XPUType = typename XPUTypeTrait::Type; + public: void Compute(const paddle::framework::ExecutionContext& ctx) const override { auto* x = ctx.Input("X"); @@ -164,160 +35,84 @@ class MatMulV2XPUKernel : public framework::OpKernel { bool trans_x = ctx.Attr("trans_x"); bool trans_y = ctx.Attr("trans_y"); out->mutable_data(ctx.GetPlace()); - if (std::is_same::value) { - MatMulXPUFunction(x, y, out, trans_x, trans_y, ctx); - } else { - if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) { - MatMulXPUFunction(x, y, out, trans_x, trans_y, ctx); - } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) { - MatMulXPUFunction(x, y, out, trans_x, trans_y, ctx); - } else { - MatMulXPUFunction(x, y, out, trans_x, trans_y, ctx); - } - } + const XPUType* x_ptr = reinterpret_cast(x->data()); + const XPUType* y_ptr = reinterpret_cast(y->data()); + XPUType* out_ptr = reinterpret_cast(out->data()); + auto x_dims = x->dims(); + auto y_dims = y->dims(); + + XpuFcInfo fc_info; + GetFCInfo(x_dims, y_dims, trans_x, trans_y, &fc_info); + auto& dev_ctx = + ctx.template device_context(); + xpu::Context* xpu_ctx = dev_ctx.x_context(); + MatMulXPUFunction(xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, 1.0f); } }; -template -static framework::Tensor XPUFoldHeadAndLastDims( - const DeviceContext& context, const framework::Tensor& input) { - using XPUType = typename XPUTypeTrait::Type; - auto in_dims = input.dims(); - if (in_dims.size() != 3) { - return input; - } - - framework::Tensor output; - output.Resize({in_dims[1], in_dims[0], in_dims[2]}); - output.mutable_data(context.GetPlace()); - std::vector in_shape_host = {static_cast(in_dims[0]), - static_cast(in_dims[1]), - static_cast(in_dims[2])}; - std::vector axis_host = {1, 0, 2}; - - int r = xpu::transpose(context.x_context(), - reinterpret_cast(input.data()), - reinterpret_cast(output.data()), - in_shape_host, - axis_host); - PADDLE_ENFORCE_EQ(r, - XPU_SUCCESS, - platform::errors::External( - "XPU transpose kernel return wrong value[%d %s]", - r, - XPUAPIErrorMsg[r])); - output.Resize({in_dims[1], in_dims[0] * in_dims[2]}); - - return output; 
-} - template class MatMulV2XPUGradKernel : public framework::OpKernel { - public: - void MatMul(const framework::ExecutionContext& ctx, - const framework::Tensor& a, - bool trans_a, - const framework::Tensor& b, - bool trans_b, - framework::Tensor* out) const { - out->mutable_data(ctx.GetPlace()); - if (std::is_same::value) { - MatMulXPUFunction(&a, &b, out, trans_a, trans_b, ctx); - } else { - if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) { - MatMulXPUFunction(&a, &b, out, trans_a, trans_b, ctx); - } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) { - MatMulXPUFunction(&a, &b, out, trans_a, trans_b, ctx); - } else { - MatMulXPUFunction(&a, &b, out, trans_a, trans_b, ctx); - } - } - } - - void CalcInputGrad(const framework::ExecutionContext& context, - const framework::Tensor& a, - bool trans_a, - bool is_fold_init_dims_a, - const framework::Tensor& b, - bool trans_b, - bool is_fold_init_dims_b, - framework::Tensor* out) const { - if (out == nullptr) return; - bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) && - out->dims().size() == 2; - if (!need_combine) { - MatMul(context, a, trans_a, b, trans_b, out); - } else { - auto& dev_ctx = - context.template device_context(); - MatMul( - context, - is_fold_init_dims_a - ? FoldInitDims(a) - : XPUFoldHeadAndLastDims( - dev_ctx, a), - trans_a, - is_fold_init_dims_b - ? FoldInitDims(b) - : XPUFoldHeadAndLastDims( - dev_ctx, b), - trans_b, - out); - } - } + using XPUType = typename XPUTypeTrait::Type; + public: void Compute(const framework::ExecutionContext& context) const override { bool transpose_x = context.Attr("trans_x"); bool transpose_y = context.Attr("trans_y"); - auto x = *context.Input("X"); auto y = *context.Input("Y"); auto dout = *context.Input(framework::GradVarName("Out")); auto* dx = context.Output(framework::GradVarName("X")); auto* dy = context.Output(framework::GradVarName("Y")); - ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y); - - framework::DDim dx_dims; if (dx) { - dx_dims = dx->dims(); - if (dx_dims != x.dims()) { - dx->Resize(x.dims()); - } + dx->mutable_data(context.GetPlace()); } - - framework::DDim dy_dims; if (dy) { - dy_dims = dy->dims(); - if (dy_dims != y.dims()) { - dy->Resize(y.dims()); - } - } - - if (transpose_x && transpose_y) { - CalcInputGrad(context, y, true, true, dout, true, false, dx); - CalcInputGrad(context, dout, true, true, x, true, false, dy); - } else if (transpose_x) { - CalcInputGrad(context, y, false, false, dout, true, false, dx); - CalcInputGrad(context, x, false, false, dout, false, true, dy); - } else if (transpose_y) { - CalcInputGrad(context, dout, false, false, y, false, true, dx); - CalcInputGrad(context, dout, true, true, x, false, true, dy); - } else { - CalcInputGrad(context, dout, false, false, y, true, false, dx); - CalcInputGrad(context, x, true, true, dout, false, true, dy); + dy->mutable_data(context.GetPlace()); } - + auto& dev_ctx = + context.template device_context(); + + const XPUType* dout_ptr = reinterpret_cast(dout.data()); + const XPUType* x_ptr = reinterpret_cast(x.data()); + const XPUType* y_ptr = reinterpret_cast(y.data()); + + xpu::Context* xpu_ctx = dev_ctx.x_context(); + + XpuFcInfo info_forward; + GetFCInfo(x.dims(), y.dims(), transpose_x, transpose_y, &info_forward); + xpu::ctx_guard RAII_GUARD(xpu_ctx); + // begin calculate + const XPUType* a_1 = reinterpret_cast(NULL); + const XPUType* b_1 = reinterpret_cast(NULL); + const XPUType* a_2 = reinterpret_cast(NULL); + const XPUType* b_2 = 
reinterpret_cast(NULL); + XPUType* c_1 = (dx == NULL) ? reinterpret_cast(NULL) + : reinterpret_cast(dx->data()); + XPUType* c_2 = (dy == NULL) ? reinterpret_cast(NULL) + : reinterpret_cast(dy->data()); + XpuFcInfo info_dx; + XpuFcInfo info_dy; + std::tuple + fc_info = MatmulGradFcInfo(xpu_ctx, + &RAII_GUARD, + info_forward, + transpose_x, + transpose_y, + x_ptr, + y_ptr, + dout_ptr); + std::tie(info_dx, info_dy, a_1, b_1, a_2, b_2) = fc_info; if (dx) { - if (dx_dims != x.dims()) { - dx->Resize(dx_dims); - } + MatMulXPUFunction(xpu_ctx, a_1, b_1, c_1, info_dx, 1.0f); } - if (dy) { - if (dy_dims != y.dims()) { - dy->Resize(dy_dims); - } + MatMulXPUFunction(xpu_ctx, a_2, b_2, c_2, info_dy, 1.0f); } } }; diff --git a/paddle/fluid/operators/mul_op_xpu.cc b/paddle/fluid/operators/mul_op_xpu.cc index 706af96d1a6c4f890c9e2a0c32e87f66c08c9483..727a7c0f6e52ccaee85b2674e7511559561f1436 100644 --- a/paddle/fluid/operators/mul_op_xpu.cc +++ b/paddle/fluid/operators/mul_op_xpu.cc @@ -49,50 +49,23 @@ class MulXPUKernel : public framework::OpKernel { *y, context.template Attr("y_num_col_dims")) : *y; z->mutable_data(context.GetPlace()); - auto z_dim = z->dims(); - if (z_dim.size() != 2) { - z->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); - } + + const XPUType* x_ptr = reinterpret_cast(x_matrix.data()); + const XPUType* y_ptr = reinterpret_cast(y_matrix.data()); + XPUType* out_ptr = reinterpret_cast(z->data()); + bool trans_a = false; bool trans_b = false; - int m = x_matrix.dims()[0]; - int k = x_matrix.dims()[1]; - int k1 = y_matrix.dims()[0]; - int n = y_matrix.dims()[1]; - PADDLE_ENFORCE_EQ( - k, k1, platform::errors::InvalidArgument("Shape mistake in mul_op")); - T alpha = static_cast(1.0); - T beta = static_cast(0.0); - const T* data_a = x_matrix.data(); - const T* data_b = y_matrix.data(); - T* data_c = z->data(); - auto& dev_ctx = context.template device_context(); - - int ret = xpu_fc_wrapper( - dev_ctx.x_context(), - reinterpret_cast(data_a), - reinterpret_cast(data_b), - reinterpret_cast(data_c), - m, - n, - k, - trans_a, - trans_b, - nullptr, - nullptr, - nullptr, - k, - n, - n, - alpha, - beta, - nullptr, - xpu::Activation_t::LINEAR); - PADDLE_ENFORCE_XDNN_SUCCESS(ret, "xpu_fc_wrapper"); - - if (z_dim.size() != 2) { - z->Resize(z_dim); - } + auto x_dims = x_matrix.dims(); + auto y_dims = y_matrix.dims(); + + XpuFcInfo fc_info; + GetFCInfo(x_dims, y_dims, trans_a, trans_b, &fc_info); + auto& dev_ctx = + context.template device_context(); + xpu::Context* xpu_ctx = dev_ctx.x_context(); + + MatMulXPUFunction(xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, 1.0f); } }; @@ -125,98 +98,51 @@ class MulGradXPUKernel : public framework::OpKernel { dy->set_lod(y->lod()); } auto& dev_ctx = ctx.template device_context(); + + XpuFcInfo info_forward; + GetFCInfo(x_matrix.dims(), y_matrix.dims(), false, false, &info_forward); + + const XPUType* dout_ptr = reinterpret_cast(dout->data()); + const XPUType* x_ptr = reinterpret_cast(x->data()); + const XPUType* y_ptr = reinterpret_cast(y->data()); + + xpu::Context* xpu_ctx = dev_ctx.x_context(); + xpu::ctx_guard RAII_GUARD(xpu_ctx); + // begin calculate + const XPUType* a_1 = reinterpret_cast(NULL); + const XPUType* b_1 = reinterpret_cast(NULL); + const XPUType* a_2 = reinterpret_cast(NULL); + const XPUType* b_2 = reinterpret_cast(NULL); + XPUType* c_1 = + (dx == NULL) + ? reinterpret_cast(NULL) + : reinterpret_cast(dx->mutable_data(ctx.GetPlace())); + XPUType* c_2 = + (dy == NULL) + ? 
reinterpret_cast(NULL) + : reinterpret_cast(dy->mutable_data(ctx.GetPlace())); + XpuFcInfo info_dx; + XpuFcInfo info_dy; + std::tuple + fc_info = MatmulGradFcInfo(xpu_ctx, + &RAII_GUARD, + info_forward, + false, + false, + x_ptr, + y_ptr, + dout_ptr); + std::tie(info_dx, info_dy, a_1, b_1, a_2, b_2) = fc_info; if (dx) { - dx->mutable_data(ctx.GetPlace()); - Tensor dx_matrix = dx->dims().size() > 2 - ? framework::ReshapeToMatrix(*dx, x_num_col_dims) - : *dx; - // dx = dout * y'. dx: M x K, dout : M x N, y : K x N - // blas.MatMul(dout_mat, false, y_matrix, true, &dx_matrix); - bool trans_a = false; - bool trans_b = true; - int m = dout_mat.dims()[0]; - int k = dout_mat.dims()[1]; - int n = y_matrix.dims()[0]; - int k1 = y_matrix.dims()[1]; - PADDLE_ENFORCE_EQ( - k, k1, platform::errors::InvalidArgument("Shape mistake in mul_op")); - int lda = (!trans_a) ? k : m; - int ldb = (!trans_b) ? n : k; - int ldc = n; - T alpha = static_cast(1.0); - T beta = static_cast(0.0); - const T* data_a = dout->data(); - const T* data_b = y_matrix.data(); - T* data_c = dx_matrix.data(); - - int ret = xpu_fc_wrapper( - dev_ctx.x_context(), - reinterpret_cast(data_a), - reinterpret_cast(data_b), - reinterpret_cast(data_c), - m, - n, - k, - trans_a, - trans_b, - nullptr, - nullptr, - nullptr, - lda, - ldb, - ldc, - alpha, - beta, - nullptr, - xpu::Activation_t::LINEAR); - PADDLE_ENFORCE_XDNN_SUCCESS(ret, "xpu_fc_wrapper"); + MatMulXPUFunction(xpu_ctx, a_1, b_1, c_1, info_dx, 1.0f); } - if (dy) { - dy->mutable_data(ctx.GetPlace()); - Tensor dy_matrix = dy->dims().size() > 2 - ? framework::ReshapeToMatrix(*dy, y_num_col_dims) - : *dy; - // dy = x' * dout. dy K x N, dout : M x N, x : M x K - // blas.MatMul(x_matrix, true, dout_mat, false, &dy_matrix); - bool trans_a = true; - bool trans_b = false; - int k = x_matrix.dims()[0]; - int m = x_matrix.dims()[1]; - int k1 = dout_mat.dims()[0]; - int n = dout_mat.dims()[1]; - PADDLE_ENFORCE_EQ( - k, k1, platform::errors::InvalidArgument("Shape mistake in mul_op")); - int lda = (!trans_a) ? k : m; - int ldb = (!trans_b) ? n : k; - int ldc = n; - T alpha = static_cast(1.0); - T beta = static_cast(0.0); - const T* data_a = x_matrix.data(); - const T* data_b = dout->data(); - T* data_c = dy_matrix.data(); - - int ret = xpu_fc_wrapper( - dev_ctx.x_context(), - reinterpret_cast(data_a), - reinterpret_cast(data_b), - reinterpret_cast(data_c), - m, - n, - k, - trans_a, - trans_b, - nullptr, - nullptr, - nullptr, - lda, - ldb, - ldc, - alpha, - beta, - nullptr, - xpu::Activation_t::LINEAR); - PADDLE_ENFORCE_XDNN_SUCCESS(ret, "xpu_fc_wrapper"); + MatMulXPUFunction(xpu_ctx, a_2, b_2, c_2, info_dy, 1.0f); } } }; diff --git a/paddle/fluid/operators/xpu_api_wrapper.h b/paddle/fluid/operators/xpu_api_wrapper.h index 8d51e53e8b394980bf7726f41736b13e410bc75b..c85a765f3b6fd31e0f84e4da99b02fb03adf5914 100644 --- a/paddle/fluid/operators/xpu_api_wrapper.h +++ b/paddle/fluid/operators/xpu_api_wrapper.h @@ -12,42 +12,206 @@ limitations under the License. 
*/ #ifdef PADDLE_WITH_XPU #include +#include "paddle/fluid/platform/device/device_wrapper.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { +using float16 = typename XPUTypeTrait::Type; + +enum XPUFCCalcType { + FC_INT16 = 0, + FC_INT32, + FC_FLOAT, +}; + +template +XPUFCCalcType FCCalcType() { + if (std::is_same::value || + std::is_same::value) { + return XPUFCCalcType::FC_INT16; + } else if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) { + return XPUFCCalcType::FC_INT32; + } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) { + return XPUFCCalcType::FC_FLOAT; + } + return XPUFCCalcType::FC_INT16; +} + +struct XpuFcInfo { + int bs; + int m; + int n; + int k; + bool trans_x; + bool trans_y; + int stride_x; + int stride_y; + int stride_out; + float* max_x; + float* max_y; + float* max_out; + XpuFcInfo() + : bs(0), + m(0), + n(0), + k(0), + trans_x(false), + trans_y(false), + stride_x(0), + stride_y(0), + stride_out(0), + max_x(nullptr), + max_y(nullptr), + max_out(nullptr) {} + void InitFcInfo(int bs, + int m, + int n, + int k, + bool trans_x, + bool trans_y, + float* max_x, + float* max_y, + float* max_out) { + this->bs = bs; + this->m = m; + this->n = n; + this->k = k; + this->trans_x = trans_x; + this->trans_y = trans_y; + this->max_x = max_x; + this->max_y = max_y; + this->max_out = max_out; + + if (this->bs <= 1) { + this->stride_x = trans_x ? m : k; + this->stride_y = trans_y ? k : n; + this->stride_out = n; + } else { + this->stride_x = m * k; + this->stride_y = k * n; + this->stride_out = m * n; + } + } +}; + +static std::ostream& operator<<(std::ostream& os, const XpuFcInfo& fc_inf) { + os << "fc_inf[ bs, m, n, k, trans_x, trans_y, stride_x, stride_y, " + "stride_out] = " + << "[" << fc_inf.bs << ", " << fc_inf.m << ", " << fc_inf.n << ", " + << fc_inf.k << ", " << fc_inf.trans_x << ", " << fc_inf.trans_y << ", " + << fc_inf.stride_x << ", " << fc_inf.stride_y << ", " << fc_inf.stride_out; + return os; +} + +static void GetFCInfo(const phi::DDim& x_dims, + const phi::DDim& y_dims, + bool trans_x, + bool trans_y, + XpuFcInfo* info) { + framework::DDim new_x_dims = + (x_dims.size() > 1) ? x_dims : phi::make_ddim({1, x_dims[0]}); + framework::DDim new_y_dims = + (y_dims.size() > 1) ? 
y_dims : phi::make_ddim({y_dims[0], 1}); + + auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(new_x_dims, 0, trans_x); + auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(new_y_dims, 0, trans_y); + + if (x_dims.size() >= 3 && y_dims.size() <= 2) { + if (!trans_x) { + mat_dim_a.height_ *= mat_dim_a.batch_size_; + mat_dim_a.batch_size_ = 0; + } else { + mat_dim_b.batch_size_ = mat_dim_a.batch_size_; + mat_dim_b.height_ = mat_dim_b.height_ / mat_dim_b.batch_size_; + } + } + + if (y_dims.size() >= 3 && x_dims.size() <= 2) { + PADDLE_ENFORCE_EQ( + mat_dim_b.trans_, + false, + platform::errors::InvalidArgument( + "xpu not support this Shape in matmul_op xdims = %s ydims = %s " + "x_trans = %d y_trans = %d", + x_dims.to_str(), + y_dims.to_str(), + mat_dim_a.trans_, + mat_dim_b.trans_)); + mat_dim_b.height_ *= mat_dim_b.batch_size_; + mat_dim_b.batch_size_ = 0; + } + + if (mat_dim_a.width_ == mat_dim_b.height_) { + if (mat_dim_a.batch_size_ == 0 && mat_dim_b.batch_size_ == 1) { + mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0; + } + if (mat_dim_a.batch_size_ == 1 && mat_dim_b.batch_size_ == 0) { + mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0; + } + } + + PADDLE_ENFORCE_EQ(mat_dim_a.width_, + mat_dim_b.height_, + platform::errors::InvalidArgument( + "Shape mistake in matmul_op xdims = %s ydims = %s " + "x_trans = %d y_trans = %d", + x_dims.to_str(), + y_dims.to_str(), + mat_dim_a.trans_, + mat_dim_b.trans_)); + + info->m = mat_dim_a.height_; + info->n = mat_dim_b.width_; + info->k = mat_dim_a.width_; + info->bs = mat_dim_a.batch_size_; + info->trans_x = trans_x; + info->trans_y = trans_y; + + if (info->bs <= 1) { + info->stride_x = trans_x ? info->m : info->k; + info->stride_y = trans_y ? info->k : info->n; + info->stride_out = info->n; + } else { + info->stride_x = info->m * info->k; + info->stride_y = info->k * info->n; + info->stride_out = info->m * info->n; + } +} + template -int xpu_fc_wrapper(xpu::Context* ctx, - const XPUType* x, - const XPUType* w, - XPUType* y, - int m, - int n, - int k, - bool x_trans, - bool w_trans, - const float* x_maxptr, - const float* w_maxptr, - float* y_maxptr, - int ldx, - int ldw, - int ldy, - float alpha, - float beta, - const float* bias, - const xpu::Activation_t& act) { +static void xpu_fc_wrapper(xpu::Context* ctx, + const XPUType* x, + const XPUType* w, + XPUType* y, + int m, + int n, + int k, + bool x_trans, + bool w_trans, + const float* x_maxptr, + const float* w_maxptr, + float* y_maxptr, + int ldx, + int ldw, + int ldy, + float alpha, + float beta, + const float* bias, + const xpu::Activation_t& act) { int r = 0; if (x_trans && std::getenv("XPU_PADDLE_FC_TRANS_A") != nullptr && std::is_same::value) { XPUType* l3_addr = nullptr; xpu::ctx_guard RAII_GUARD(ctx); l3_addr = RAII_GUARD.alloc_l3_or_gm(m * k); - if (l3_addr == nullptr) return XPUERR_NOMEM; + PADDLE_ENFORCE_XDNN_NOT_NULL(l3_addr); std::vector shape = {k, m}; std::vector axis = {1, 0}; r = xpu::transpose(ctx, x, l3_addr, shape, axis); - if (r != XPU_SUCCESS) return r; + PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); r = xpu::fc_fusion(ctx, l3_addr, @@ -68,7 +232,7 @@ int xpu_fc_wrapper(xpu::Context* ctx, beta, bias, act); - if (r != XPU_SUCCESS) return r; + PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_fusion"); } else { r = xpu::fc_fusion(ctx, x, @@ -89,8 +253,356 @@ int xpu_fc_wrapper(xpu::Context* ctx, beta, bias, act); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_fusion"); } - return r; +} + +template <> +void xpu_fc_wrapper(xpu::Context* ctx, + const float16* x, + const float16* w, + float16* y, + 
int m, + int n, + int k, + bool x_trans, + bool w_trans, + const float* x_maxptr, + const float* w_maxptr, + float* y_maxptr, + int ldx, + int ldw, + int ldy, + float alpha, + float beta, + const float* bias, + const xpu::Activation_t& act) { + int r = xpu::Error_t::INVALID_PARAM; + PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_fc_wrapper"); +} + +template +static void xpu_fc_batch_wrapper(xpu::Context* xpu_ctx, + int bs, + bool trans_x, + bool trans_w, + int m, + int n, + int k, + float alpha, + const XPUType* x, + int stride_x, + const XPUType* w, + int stride_w, + float beta, + XPUType* y, + int stride_y, + const float* x_maxptr, + const float* w_maxptr) { + int r = xpu::fc_batched( + xpu_ctx, // Context* ctx, + bs, // int batch_size, + trans_x, // bool x_trans, + trans_w, // bool w_trans, + m, // int m, + n, // int n, + k, // int k, + alpha, // float alpha, + reinterpret_cast(x), // const TX* x, + stride_x, // int stride_a, + reinterpret_cast(w), // const TW* w, + stride_w, // int stride_b, + 0.0, // float beta, + reinterpret_cast(y), // TY* y, + stride_y, // int stride_c, + x_maxptr, // const float* x_maxptr, + w_maxptr); // const float* w_maxptr + PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_batched"); +} + +template <> +void xpu_fc_batch_wrapper(xpu::Context* xpu_ctx, + int bs, + bool trans_x, + bool trans_w, + int m, + int n, + int k, + float alpha, + const float16* x, + int stride_x, + const float16* w, + int stride_w, + float beta, + float16* y, + int stride_y, + const float* x_maxptr, + const float* w_maxptr) { + int r = xpu::Error_t::INVALID_PARAM; + PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_fc_batch_wrapper"); +} + +template <> +void xpu_fc_batch_wrapper(xpu::Context* xpu_ctx, + int bs, + bool trans_x, + bool trans_w, + int m, + int n, + int k, + float alpha, + const float16* x, + int stride_x, + const float16* w, + int stride_w, + float beta, + float16* y, + int stride_y, + const float* x_maxptr, + const float* w_maxptr) { + int r = xpu::Error_t::INVALID_PARAM; + PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_fc_batch_wrapper"); +} + +template +static void MatMulXPUFunction(xpu::Context* xpu_ctx, + const T* x, + const T* y, + T* out, + const XpuFcInfo& fcinfo, + float alpha) { + using XPUType = typename XPUTypeTrait::Type; + using float16 = typename XPUTypeTrait::Type; + int fccal_type = FCCalcType(); + + decltype(&paddle::operators::xpu_fc_wrapper) + fc_api_list[3] = { + &paddle::operators::xpu_fc_wrapper, + &paddle::operators::xpu_fc_wrapper, + &paddle::operators::xpu_fc_wrapper, + }; + decltype(&paddle::operators::xpu_fc_batch_wrapper) + fc_batch_api_list[3] = { + &paddle::operators::xpu_fc_batch_wrapper, + &paddle::operators::xpu_fc_batch_wrapper, + &paddle::operators::xpu_fc_batch_wrapper, + }; + + auto fc_api = fc_api_list[fccal_type]; + auto fc_batch_api = fc_batch_api_list[fccal_type]; + + int m = fcinfo.m; + int n = fcinfo.n; + int k = fcinfo.k; + int batch_size = fcinfo.bs; + int ldx = fcinfo.stride_x; + int ldy = fcinfo.stride_y; + int ldout = fcinfo.stride_out; + bool trans_x = fcinfo.trans_x; + bool trans_y = fcinfo.trans_y; + float* max_x = fcinfo.max_x; + float* max_y = fcinfo.max_y; + float* max_out = fcinfo.max_out; + + if (batch_size <= 1) { + fc_api(xpu_ctx, + reinterpret_cast(x), + reinterpret_cast(y), + reinterpret_cast(out), + m, + n, + k, + trans_x, + trans_y, + max_x, + max_y, + max_out, + ldx, + ldy, + ldout, + alpha, + 0, + nullptr, + xpu::Activation_t::LINEAR); + } else { + // batch matmul + fc_batch_api(xpu_ctx, // Context* ctx, + batch_size, // int batch_size, + trans_x, // bool 
x_trans, + trans_y, // bool w_trans, + m, // int m, + n, // int n, + k, // int k, + alpha, // float alpha, + reinterpret_cast(x), // const TX* x, + ldx, // int stride_a, + reinterpret_cast(y), // const TW* w, + ldy, // int stride_b, + 0.0, // float beta, + reinterpret_cast(out), // TY* y, + ldout, // int stride_c, + max_x, // const float* x_maxptr, + max_y); // const float* w_maxptr + } +} + +template +static std::tuple +MatmulGradFcInfo(xpu::Context* xpu_ctx, + xpu::ctx_guard* RAII_GUARD, + const XpuFcInfo& dout_shape, + bool trans_x, + bool trans_y, + const T* x, + const T* y, + const T* dout) { + XpuFcInfo dx_shape, dy_shape; + const T* dx_a = NULL; + const T* dx_b = NULL; + const T* dy_a = NULL; + const T* dy_b = NULL; + bool copy_to_l3 = false; + float* max_dout = NULL; + int maxptr_size = xpu_ctx->max_ptr_size(); + uint64_t l3_size = uint64_t(xpu_ctx->_l3_mgr.get_size()); + int bs = (dout_shape.bs <= 1) ? (1) : (dout_shape.bs); + int dx_size = bs * dout_shape.m * dout_shape.k; + int dy_size = bs * dout_shape.k * dout_shape.n; + int dout_size = bs * dout_shape.m * dout_shape.n; + if (trans_x && trans_y) { + copy_to_l3 = l3_size >= (dout_size * 2 + dy_size) * sizeof(T); + } else if (trans_x) { + copy_to_l3 = l3_size >= dout_size * sizeof(T); + } else if (trans_y) { + copy_to_l3 = l3_size >= dout_size * 2 * sizeof(T); + } else { + copy_to_l3 = l3_size >= (dout_size + dx_size) * sizeof(T); + } + + const T* dout_new = dout; + int r = 0; + if (copy_to_l3) { + T* dout_l3 = RAII_GUARD->alloc_l3(dout_size); + PADDLE_ENFORCE_XDNN_NOT_NULL(dout_l3); + if ((dout_shape.bs > 1) || ((dout_shape.bs <= 1) && + (FCCalcType() == XPUFCCalcType::FC_FLOAT))) { + r = xpu::copy(xpu_ctx, dout, dout_l3, dout_size); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy"); + dout_new = dout_l3; + } else { + max_dout = RAII_GUARD->alloc_l3_or_gm(maxptr_size); + PADDLE_ENFORCE_XDNN_NOT_NULL(max_dout); + + r = xpu::findmax_copy_fusion(xpu_ctx, dout, max_dout, dout_l3, dout_size); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "findmax_copy_fusion"); + dout_new = dout_l3; + } + } else if (((dout_shape.bs <= 1) && + (FCCalcType() != XPUFCCalcType::FC_FLOAT))) { + max_dout = RAII_GUARD->alloc_l3_or_gm(maxptr_size); + PADDLE_ENFORCE_XDNN_NOT_NULL(max_dout); + r = xpu::findmax_copy_fusion( + xpu_ctx, dout, max_dout, reinterpret_cast(NULL), dout_size); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "findmax_copy_fusion"); + } + + if (trans_x && trans_y) { + // dx = T(y) * T(dout) + dx_shape.InitFcInfo(dout_shape.bs, + dout_shape.k, + dout_shape.m, + dout_shape.n, + true, + true, + nullptr, + max_dout, + nullptr); + dx_a = y, dx_b = dout_new; + // dy = T(dout) * T(x) + dy_shape.InitFcInfo(dout_shape.bs, + dout_shape.n, + dout_shape.k, + dout_shape.m, + true, + true, + max_dout, + nullptr, + nullptr); + dy_a = dout_new, dy_b = x; + } else if (trans_x) { + // dx = y * T(dout) + dx_shape.InitFcInfo(dout_shape.bs, + dout_shape.k, + dout_shape.m, + dout_shape.n, + false, + true, + nullptr, + max_dout, + nullptr); + dx_a = y, dx_b = dout_new; + // dy = x * dout + dy_shape.InitFcInfo(dout_shape.bs, + dout_shape.k, + dout_shape.n, + dout_shape.m, + false, + false, + nullptr, + max_dout, + nullptr); + dy_a = x, dy_b = dout_new; + } else if (trans_y) { + // dx = dout * y + dx_shape.InitFcInfo(dout_shape.bs, + dout_shape.m, + dout_shape.k, + dout_shape.n, + false, + false, + max_dout, + nullptr, + nullptr); + dx_a = dout_new, dx_b = y; + // dy = T(dout) * x + dy_shape.InitFcInfo(dout_shape.bs, + dout_shape.n, + dout_shape.k, + dout_shape.m, + true, + false, + max_dout, + 
nullptr, + nullptr); + dy_a = dout_new, dy_b = x; + } else { + // dx = dout * T(y) + dx_shape.InitFcInfo(dout_shape.bs, + dout_shape.m, + dout_shape.k, + dout_shape.n, + false, + true, + max_dout, + nullptr, + nullptr); + dx_a = dout_new, dx_b = y; + // dy = T(x) * dout + dy_shape.InitFcInfo(dout_shape.bs, + dout_shape.k, + dout_shape.n, + dout_shape.m, + true, + false, + nullptr, + max_dout, + nullptr); + dy_a = x, dy_b = dout_new; + } + std::tuple + result = std::make_tuple(dx_shape, dy_shape, dx_a, dx_b, dy_a, dy_b); + + return result; } } // namespace operators diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index 27339a0f25a8abbaaeefe585a6a73e6995cee412..2b80396cc3138f47d15026cd5c9c167e3988470b 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -281,11 +281,18 @@ XPUOpMap& get_kl2_ops() { pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, - {"matmul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"matmul_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, {"matmul_v2_grad", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"matmul_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"matmul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, + {"matmul_v2", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, + {"matmul", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, {"mean_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, diff --git a/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py b/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py index d92378f60f5789cec5776660f45d92990e8a16e0..b4032f2dcb67e0eb0b7143ee1cf8a71abc6c62c9 100644 --- a/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py +++ b/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py @@ -84,7 +84,9 @@ type_dict_str_to_numpy = { xpu_test_op_white_list = [] xpu_test_device_type_white_list = ['xpu1_float64'] -xpu_test_op_type_white_list = ['dropout_float16', 'dropout_grad_float16'] +xpu_test_op_type_white_list = [ + 'dropout_float16', 'dropout_grad_float16', 'matmul_v2_float16' +] xpu_test_device_op_white_list = [] xpu_test_device_op_type_white_list = [] diff --git a/python/paddle/fluid/tests/unittests/xpu/test_matmul_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_matmul_op_xpu.py index bc6fa19a35444c1283eade3bce5d270b61e933ac..1c68f8fb6bf1693a907ccc349a032373644ba77e 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_matmul_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_matmul_op_xpu.py @@ -303,7 +303,8 @@ class TestMatmulBaseGenerator(XPUOpTest): X = np.random.random(shape_X).astype(self.dtype) Y = np.random.random(shape_Y).astype(self.dtype) - Out = reference_matmul(X, Y, transpose_X, transpose_Y) + Out = reference_matmul(X, Y, transpose_X, + transpose_Y).astype(self.dtype) self.inputs = {'X': X, 'Y': Y} self.attrs = {'transpose_X': transpose_X, 'transpose_Y': transpose_Y} self.outputs = {'Out': Out}
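
Reviewer note: to make the shape handling in GetFCInfo easier to review, here is a minimal sketch (not part of the change) of the XpuFcInfo values one would expect for a few representative input shapes, assuming the folding rules added in xpu_api_wrapper.h above: a rank-3 X against a rank-2 Y is folded into a single GEMM when trans_x is false, and two rank-3 inputs take the batched fc_batched path. The check routine name is hypothetical.

#include <cassert>
#include "paddle/fluid/operators/xpu_api_wrapper.h"

// Hypothetical sanity checks; expected values follow the rules in GetFCInfo.
void CheckGetFCInfoExamples() {
  using paddle::operators::GetFCInfo;
  using paddle::operators::XpuFcInfo;

  // [2, 3] x [3, 4]: plain GEMM, batch size stays 0.
  XpuFcInfo a;
  GetFCInfo(phi::make_ddim({2, 3}), phi::make_ddim({3, 4}), false, false, &a);
  assert(a.bs == 0 && a.m == 2 && a.n == 4 && a.k == 3);

  // [8, 2, 3] x [3, 4] with trans_x == false: the batch dim of X is folded
  // into m, so this still runs as one GEMM with m = 16.
  XpuFcInfo b;
  GetFCInfo(phi::make_ddim({8, 2, 3}), phi::make_ddim({3, 4}), false, false,
            &b);
  assert(b.bs == 0 && b.m == 16 && b.n == 4 && b.k == 3);

  // [8, 2, 3] x [8, 3, 4]: both sides carry the batch, so bs = 8 and the
  // batched path uses strides m*k, k*n and m*n.
  XpuFcInfo c;
  GetFCInfo(phi::make_ddim({8, 2, 3}), phi::make_ddim({8, 3, 4}), false,
            false, &c);
  assert(c.bs == 8 && c.m == 2 && c.n == 4 && c.k == 3 &&
         c.stride_x == 6 && c.stride_y == 12 && c.stride_out == 8);
}

The same XpuFcInfo is reused by MatmulGradFcInfo to derive the dx and dy GEMMs in the grad kernels, which is what lets the three operators (matmul, matmul_v2, mul) share one forward and one backward code path.
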