//   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifdef PADDLE_WITH_XPU

#include <string>
#include <vector>
#include "paddle/fluid/operators/matmul_v2_op.h"

#include "paddle/fluid/operators/xpu_api_wrapper.h"

namespace paddle {
namespace operators {

template <typename T>
class MatMulV2XPUKernel : public framework::OpKernel<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;

 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* y = ctx.Input<Tensor>("Y");
    auto* out = ctx.Output<Tensor>("Out");
    bool trans_x = ctx.Attr<bool>("trans_x");
    bool trans_y = ctx.Attr<bool>("trans_y");
    out->mutable_data<T>(ctx.GetPlace());
    const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x->data<T>());
    const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y->data<T>());
    XPUType* out_ptr = reinterpret_cast<XPUType*>(out->data<T>());
    auto x_dims = x->dims();
    auto y_dims = y->dims();

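    // Collapse the (possibly batched) matmul into a single GEMM description
    // derived from the input dims and the trans_x / trans_y flags, then
    // dispatch it to the XPU GEMM wrapper with alpha = 1.0.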
    XpuFcInfo fc_info;
    GetFCInfo(x_dims, y_dims, trans_x, trans_y, &fc_info);
    auto& dev_ctx =
        ctx.template device_context<paddle::platform::XPUDeviceContext>();
    xpu::Context* xpu_ctx = dev_ctx.x_context();
    MatMulXPUFunction<XPUType>(xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, 1.0f);
  }
};

template <typename T>
class MatMulV2XPUGradKernel : public framework::OpKernel<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;

 public:
  void Compute(const framework::ExecutionContext& context) const override {
    bool transpose_x = context.Attr<bool>("trans_x");
    bool transpose_y = context.Attr<bool>("trans_y");
    auto x = *context.Input<framework::Tensor>("X");
    auto y = *context.Input<framework::Tensor>("Y");
    auto dout =
        *context.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
    auto* dy = context.Output<framework::Tensor>(framework::GradVarName("Y"));
    if (dx) {
      dx->mutable_data<T>(context.GetPlace());
    }
    if (dy) {
      dy->mutable_data<T>(context.GetPlace());
    }
    auto& dev_ctx =
        context.template device_context<paddle::platform::XPUDeviceContext>();

    const XPUType* dout_ptr = reinterpret_cast<const XPUType*>(dout.data<T>());
    const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x.data<T>());
    const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y.data<T>());

    xpu::Context* xpu_ctx = dev_ctx.x_context();

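    // Recover the GEMM description of the forward product Out = op(X) * op(Y);
    // the two backward GEMMs below are derived from it.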
    XpuFcInfo info_forward;
    GetFCInfo(x.dims(), y.dims(), transpose_x, transpose_y, &info_forward);
    xpu::ctx_guard RAII_GUARD(xpu_ctx);
    // Backward of Out = op(X) * op(Y): in the non-transposed case
    // dX = dOut * Y^T and dY = X^T * dOut; trans_x / trans_y permute the
    // operands accordingly.
    const XPUType* a_1 = nullptr;
    const XPUType* b_1 = nullptr;
    const XPUType* a_2 = nullptr;
    const XPUType* b_2 = nullptr;
    XPUType* c_1 = (dx == nullptr) ? nullptr
                                   : reinterpret_cast<XPUType*>(dx->data<T>());
    XPUType* c_2 = (dy == nullptr) ? nullptr
                                   : reinterpret_cast<XPUType*>(dy->data<T>());
    XpuFcInfo info_dx;
    XpuFcInfo info_dy;
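    // Select the operand pointers (a_*, b_*) and GEMM descriptors for the two
    // backward products; any scratch buffers that are needed are allocated
    // through RAII_GUARD.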
    std::tuple<XpuFcInfo,
               XpuFcInfo,
               const XPUType*,
               const XPUType*,
               const XPUType*,
               const XPUType*>
        fc_info = MatmulGradFcInfo(xpu_ctx,
                                   &RAII_GUARD,
                                   info_forward,
                                   transpose_x,
                                   transpose_y,
                                   x_ptr,
                                   y_ptr,
                                   dout_ptr);
    std::tie(info_dx, info_dy, a_1, b_1, a_2, b_2) = fc_info;
    if (dx) {
      MatMulXPUFunction<XPUType>(xpu_ctx, a_1, b_1, c_1, info_dx, 1.0f);
    }
    if (dy) {
      MatMulXPUFunction<XPUType>(xpu_ctx, a_2, b_2, c_2, info_dy, 1.0f);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(matmul_v2,
                       ops::MatMulV2XPUKernel<float>,
                       ops::MatMulV2XPUKernel<plat::float16>);
REGISTER_OP_XPU_KERNEL(matmul_v2_grad,
                       ops::MatMulV2XPUGradKernel<float>,
                       ops::MatMulV2XPUGradKernel<plat::float16>);

#endif