//   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifdef PADDLE_WITH_XPU

#include <string>
#include <vector>

#include "paddle/fluid/operators/matmul_v2_op.h"
#include "paddle/fluid/operators/xpu_api_wrapper.h"

namespace paddle {
namespace operators {

template <typename T, typename FCT>
static void MatMulXPUFunction(const Tensor* x,
                              const Tensor* y,
                              Tensor* out,
                              bool trans_x,
                              bool trans_y,
                              const paddle::framework::ExecutionContext& ctx) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  const auto& x_dims = x->dims();
  const auto& y_dims = y->dims();
  auto& dev_ctx =
      ctx.template device_context<paddle::platform::XPUDeviceContext>();

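  // Describe both operands as (batched) matrices; a 1-D x is treated as a row
  // matrix and a 1-D y as a column matrix.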
  auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(
      RowMatrixFromVector(x_dims), 0, trans_x);
  auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(
      ColumnMatrixFromVector(y_dims), 0, trans_y);

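  // If x is a stack of matrices while y is a single matrix, try to fold the
  // batch dimension away so the non-batched GEMM path below can be used.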
  if (x_dims.size() >= 3 && y_dims.size() <= 2) {
    // When x is not transposed, fold its batch dimension into the height so a
    // single GEMM covers every batch; transposing x first would be too costly.
    if (!trans_x) {
      mat_dim_a.height_ *= mat_dim_a.batch_size_;
      mat_dim_a.batch_size_ = 0;
    } else {
      mat_dim_b.batch_size_ = mat_dim_a.batch_size_;
      mat_dim_b.height_ = mat_dim_b.height_ / mat_dim_b.batch_size_;
    }
  }

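  // When the contracted dims already match, collapse a degenerate batch size
  // (0 on one side, 1 on the other) so the single-GEMM path is taken.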
  if (mat_dim_a.width_ == mat_dim_b.height_) {
    if (mat_dim_a.batch_size_ == 0 && mat_dim_b.batch_size_ == 1) {
      mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0;
    }
    if (mat_dim_a.batch_size_ == 1 && mat_dim_b.batch_size_ == 0) {
      mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0;
    }
  }

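  // After the adjustments above, the contracted dims and batch sizes of the
  // two operands must agree.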
  PADDLE_ENFORCE_EQ(mat_dim_a.width_,
                    mat_dim_b.height_,
                    platform::errors::InvalidArgument(
                        "Shape mistake in matmul_v2_op xdims = %s ydims = %s "
                        "x_trans = %d y_trans = %d",
                        x_dims.to_str(),
                        y_dims.to_str(),
                        mat_dim_a.trans_,
                        mat_dim_b.trans_));
  PADDLE_ENFORCE_EQ(mat_dim_a.batch_size_,
                    mat_dim_b.batch_size_,
                    platform::errors::InvalidArgument(
                        "Shape mistake in matmul_v2_op xdims = %s ydims = %s "
                        "x_trans = %d y_trans = %d",
                        x_dims.to_str(),
                        y_dims.to_str(),
                        mat_dim_a.trans_,
                        mat_dim_b.trans_));

  T* data_c = out->data<T>();
  int m = mat_dim_a.height_;
  int n = mat_dim_b.width_;
  int k = mat_dim_a.width_;
  int batch_size = mat_dim_a.batch_size_;
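  // Row-major leading dimensions of x, y and out as they are laid out in memory.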
  int ldx = mat_dim_a.trans_ ? m : k;
  int ldy = mat_dim_b.trans_ ? k : n;
  int ldout = n;
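  // A batch size of 0 or 1 maps to a single fc call; larger batches go through
  // the strided fc_batched kernel below.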
  if (batch_size <= 1) {
    int r = 0;
    r = xpu_fc_wrapper<XPUType, FCT>(
        dev_ctx.x_context(),
        reinterpret_cast<const XPUType*>(x->data<T>()),
        reinterpret_cast<const XPUType*>(y->data<T>()),
        reinterpret_cast<XPUType*>(data_c),
        m,
        n,
        k,
        mat_dim_a.trans_,
        mat_dim_b.trans_,
        nullptr,
        nullptr,
        nullptr,
        ldx,
        ldy,
        ldout,
        1.0,
        0,
        nullptr,
        xpu::Activation_t::LINEAR);
    PADDLE_ENFORCE_EQ(
        r,
        XPU_SUCCESS,
        platform::errors::External(
            "XPU fc kernel return wrong value[%d %s] , m = %d, n = "
118 119
            "%d, "
            "k = %d, a_tr = %d, b_tr = %d",
            r,
            XPUAPIErrorMsg[r],
            m,
            n,
            k,
            mat_dim_a.trans_,
            mat_dim_b.trans_));
  } else {
    // batch matmul
    int r = xpu::fc_batched<XPUType, XPUType, XPUType, FCT>(
        dev_ctx.x_context(),                             // Context* ctx,
        batch_size,                                      // int batch_size,
        mat_dim_a.trans_,                                // bool x_trans,
        mat_dim_b.trans_,                                // bool w_trans,
        m,                                               // int m,
        n,                                               // int n,
        k,                                               // int k,
        1.0,                                             // float alpha,
        reinterpret_cast<const XPUType*>(x->data<T>()),  // const TX* x,
        mat_dim_a.stride_,                               // int stride_a,
        reinterpret_cast<const XPUType*>(y->data<T>()),  // const TW* w,
        mat_dim_b.stride_,                               // int stride_b,
        0.0,                                             // float beta,
        reinterpret_cast<XPUType*>(data_c),              // TY* y,
        m * n,                                           // int stride_c,
        nullptr,   // const float* x_maxptr,
        nullptr);  // const float* w_maxptr

    PADDLE_ENFORCE_EQ(r,
                      XPU_SUCCESS,
                      platform::errors::External(
                          "XPU fc_batched kernel return wrong value[%d %s]",
                          r,
                          XPUAPIErrorMsg[r]));
  }
}

template <typename T>
class MatMulV2XPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* y = ctx.Input<Tensor>("Y");
    auto* out = ctx.Output<Tensor>("Out");
    bool trans_x = ctx.Attr<bool>("trans_x");
    bool trans_y = ctx.Attr<bool>("trans_y");
    out->mutable_data<T>(ctx.GetPlace());
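    // The FCT template argument picks the XPU FC kernel's internal compute
    // type: fp16 tensors always use the int16 path; for other dtypes,
    // XPU_PADDLE_FC_INT32 selects int32, XPU_PADDLE_FC_LOCAL_INT16 selects
    // float, and int16 is the default.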
    if (std::is_same<paddle::platform::float16, T>::value) {
      MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, ctx);
    } else {
      if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) {
        MatMulXPUFunction<T, int32_t>(x, y, out, trans_x, trans_y, ctx);
      } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) {
        MatMulXPUFunction<T, float>(x, y, out, trans_x, trans_y, ctx);
      } else {
        MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, ctx);
      }
    }
  }
};

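// Permutes a 3-D tensor [d0, d1, d2] into [d1, d0, d2] on the XPU and then
// flattens the result to [d1, d0 * d2]; non-3-D inputs are returned unchanged.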
template <typename DeviceContext, typename T>
static framework::Tensor XPUFoldHeadAndLastDims(
    const DeviceContext& context, const framework::Tensor& input) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  auto in_dims = input.dims();
  if (in_dims.size() != 3) {
    return input;
  }

  framework::Tensor output;
  output.Resize({in_dims[1], in_dims[0], in_dims[2]});
  output.mutable_data<T>(context.GetPlace());
  std::vector<int> in_shape_host = {static_cast<int>(in_dims[0]),
                                    static_cast<int>(in_dims[1]),
                                    static_cast<int>(in_dims[2])};
  std::vector<int> axis_host = {1, 0, 2};

  int r = xpu::transpose(context.x_context(),
                         reinterpret_cast<const XPUType*>(input.data<T>()),
                         reinterpret_cast<XPUType*>(output.data<T>()),
                         in_shape_host,
                         axis_host);
  PADDLE_ENFORCE_EQ(r,
                    XPU_SUCCESS,
                    platform::errors::External(
                        "XPU transpose kernel return wrong value[%d %s]",
                        r,
                        XPUAPIErrorMsg[r]));
  output.Resize({in_dims[1], in_dims[0] * in_dims[2]});

  return output;
}

template <typename T>
class MatMulV2XPUGradKernel : public framework::OpKernel<T> {
 public:
  void MatMul(const framework::ExecutionContext& ctx,
              const framework::Tensor& a,
              bool trans_a,
              const framework::Tensor& b,
              bool trans_b,
              framework::Tensor* out) const {
    out->mutable_data<T>(ctx.GetPlace());
    if (std::is_same<paddle::platform::float16, T>::value) {
      MatMulXPUFunction<T, int16_t>(&a, &b, out, trans_a, trans_b, ctx);
    } else {
      if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) {
        MatMulXPUFunction<T, int32_t>(&a, &b, out, trans_a, trans_b, ctx);
      } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) {
        MatMulXPUFunction<T, float>(&a, &b, out, trans_a, trans_b, ctx);
      } else {
        MatMulXPUFunction<T, int16_t>(&a, &b, out, trans_a, trans_b, ctx);
      }
    }
  }

  void CalcInputGrad(const framework::ExecutionContext& context,
                     const framework::Tensor& a,
                     bool trans_a,
                     bool is_fold_init_dims_a,
                     const framework::Tensor& b,
                     bool trans_b,
                     bool is_fold_init_dims_b,
                     framework::Tensor* out) const {
    if (out == nullptr) return;
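    // A 3-D operand that produces a 2-D gradient must be folded to 2-D first,
    // either by merging its leading dims or via XPUFoldHeadAndLastDims.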
    bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) &&
                        out->dims().size() == 2;
    if (!need_combine) {
      MatMul(context, a, trans_a, b, trans_b, out);
    } else {
      auto& dev_ctx =
          context.template device_context<paddle::platform::XPUDeviceContext>();
      MatMul(
          context,
          is_fold_init_dims_a
              ? FoldInitDims(a)
              : XPUFoldHeadAndLastDims<paddle::platform::XPUDeviceContext, T>(
                    dev_ctx, a),
          trans_a,
          is_fold_init_dims_b
              ? FoldInitDims(b)
              : XPUFoldHeadAndLastDims<paddle::platform::XPUDeviceContext, T>(
                    dev_ctx, b),
          trans_b,
          out);
    }
  }

  void Compute(const framework::ExecutionContext& context) const override {
    bool transpose_x = context.Attr<bool>("trans_x");
    bool transpose_y = context.Attr<bool>("trans_y");

    auto x = *context.Input<framework::Tensor>("X");
    auto y = *context.Input<framework::Tensor>("Y");
    auto dout =
        *context.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
    auto* dy = context.Output<framework::Tensor>(framework::GradVarName("Y"));
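    // View x, y and dout as (batched) matrices before computing the gradients.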
    ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);

    framework::DDim dx_dims;
    if (dx) {
      dx_dims = dx->dims();
      if (dx_dims != x.dims()) {
        dx->Resize(x.dims());
      }
    }

    framework::DDim dy_dims;
    if (dy) {
      dy_dims = dy->dims();
      if (dy_dims != y.dims()) {
        dy->Resize(y.dims());
      }
    }

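    // Select the gradient formulas matching the forward transposes, e.g. for
    // out = x * y: dx = dout * y^T and dy = x^T * dout.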
    if (transpose_x && transpose_y) {
      CalcInputGrad(context, y, true, true, dout, true, false, dx);
      CalcInputGrad(context, dout, true, true, x, true, false, dy);
    } else if (transpose_x) {
      CalcInputGrad(context, y, false, false, dout, true, false, dx);
      CalcInputGrad(context, x, false, false, dout, false, true, dy);
    } else if (transpose_y) {
      CalcInputGrad(context, dout, false, false, y, false, true, dx);
      CalcInputGrad(context, dout, true, true, x, false, true, dy);
    } else {
      CalcInputGrad(context, dout, false, false, y, true, false, dx);
      CalcInputGrad(context, x, true, true, dout, false, true, dy);
    }

    if (dx) {
      if (dx_dims != x.dims()) {
        dx->Resize(dx_dims);
      }
    }

    if (dy) {
      if (dy_dims != y.dims()) {
        dy->Resize(dy_dims);
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(matmul_v2,
                       ops::MatMulV2XPUKernel<float>,
                       ops::MatMulV2XPUKernel<plat::float16>);
REGISTER_OP_XPU_KERNEL(matmul_v2_grad,
                       ops::MatMulV2XPUGradKernel<float>,
                       ops::MatMulV2XPUGradKernel<plat::float16>);

#endif