/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_XPU

#include <algorithm>
#include <tuple>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/xpu_api_wrapper.h"

namespace paddle {
namespace operators {

using framework::Tensor;

// Forward XPU kernel for the `matmul` op.
//
// Reads inputs "X" and "Y", the `transpose_X` / `transpose_Y` flags and the
// `alpha` attribute, describes the problem with GetFCInfo, and hands the
// actual multiplication (including the alpha scale) to MatMulXPUFunction,
// writing the result into "Out".
template <typename DeviceContext, typename T>
class MatMulXPUKernel : public framework::OpKernel<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;

 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const auto* x = context.Input<Tensor>("X");
    const auto* y = context.Input<Tensor>("Y");
    auto* out = context.Output<Tensor>("Out");

    // mutable_data allocates Out on the kernel's place and already returns
    // the buffer, so no second data<T>() lookup is needed.
    auto* out_ptr =
        reinterpret_cast<XPUType*>(out->mutable_data<T>(context.GetPlace()));

    const bool trans_x = context.Attr<bool>("transpose_X");
    const bool trans_y = context.Attr<bool>("transpose_Y");
    // NOTE(review): alpha is round-tripped through T before being stored as
    // float, so for float16 kernels it is rounded to half precision first.
    // Kept as-is to preserve numerics; confirm this is intended.
    const float alpha = static_cast<T>(context.Attr<float>("alpha"));

    // T and XPUType are assumed to share a layout (XPUTypeTrait maps the
    // framework element type to the XPU library's equivalent).
    const auto* x_ptr = reinterpret_cast<const XPUType*>(x->data<T>());
    const auto* y_ptr = reinterpret_cast<const XPUType*>(y->data<T>());

    // Translate the input shapes and transpose flags into the descriptor
    // consumed by MatMulXPUFunction.
    XpuFcInfo fc_info;
    GetFCInfo(x->dims(), y->dims(), trans_x, trans_y, &fc_info);

    auto& dev_ctx = context.template device_context<DeviceContext>();
    xpu::Context* xpu_ctx = dev_ctx.x_context();

    MatMulXPUFunction<XPUType>(xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, alpha);
  }
};

// Using dimensional constraints on matrix multiplication, it is
// straightforward to check the following table for when X and Y
// are both matrices.
// // transpose_X | False | True | False | True // transpose_Y | False | False | True | True // -----------+----------+----------+----------+----------- // dX = | dOut Y^T | Y dOut^T | dOut Y | Y^T dOut^T // dY = | X^T dOut | X dOut | dOut^T X | dOut^T X^T // // When X is a vector of size K, we treat it instead as a matrix of shape // (1, K). Similarly, when Y is a vector of size K, we treat it instead as // a matrix of shape (K, 1). // // When X and Y are both 3-dimensional tensors, then the first dimension // the batch dimension can be ignored and the exact same formulas apply // as for two matrices. // // Finally, when, e.g., X is a 3-dimensional tensor but Y is a matrix, we end // up with formulas like // // dY_{ij} = \sum_{p, m} X_{pmi} dOut_{pmj} // // To handle this sort of scenario, we reshape X : P x M x K, dOut: P x M x N // to X: (P * M) x K, dOut: (P * M) x N. template class MatMulGradXPUKernel : public framework::OpKernel { using XPUType = typename XPUTypeTrait::Type; public: void Compute(const framework::ExecutionContext& context) const override { auto x = *context.Input("X"); auto y = *context.Input("Y"); auto dout = *context.Input(framework::GradVarName("Out")); auto* dx = context.Output(framework::GradVarName("X")); auto* dy = context.Output(framework::GradVarName("Y")); bool transpose_x = context.Attr("transpose_X"); bool transpose_y = context.Attr("transpose_Y"); float alpha = static_cast(context.Attr("alpha")); if (dx) { dx->mutable_data(context.GetPlace()); } if (dy) { dy->mutable_data(context.GetPlace()); } auto& dev_ctx = context.template device_context(); const XPUType* dout_ptr = reinterpret_cast(dout.data()); const XPUType* x_ptr = reinterpret_cast(x.data()); const XPUType* y_ptr = reinterpret_cast(y.data()); xpu::Context* xpu_ctx = dev_ctx.x_context(); XpuFcInfo info_forward; GetFCInfo(x.dims(), y.dims(), transpose_x, transpose_y, &info_forward); xpu::ctx_guard RAII_GUARD(xpu_ctx); // begin calculate const XPUType* a_1 = 
reinterpret_cast(NULL); const XPUType* b_1 = reinterpret_cast(NULL); const XPUType* a_2 = reinterpret_cast(NULL); const XPUType* b_2 = reinterpret_cast(NULL); XPUType* c_1 = (dx == NULL) ? reinterpret_cast(NULL) : reinterpret_cast(dx->data()); XPUType* c_2 = (dy == NULL) ? reinterpret_cast(NULL) : reinterpret_cast(dy->data()); XpuFcInfo info_dx; XpuFcInfo info_dy; std::tuple fc_info = MatmulGradFcInfo(xpu_ctx, &RAII_GUARD, info_forward, transpose_x, transpose_y, x_ptr, y_ptr, dout_ptr); std::tie(info_dx, info_dy, a_1, b_1, a_2, b_2) = fc_info; if (dx) { MatMulXPUFunction(xpu_ctx, a_1, b_1, c_1, info_dx, alpha); } if (dy) { MatMulXPUFunction(xpu_ctx, a_2, b_2, c_2, info_dy, alpha); } } }; } // namespace operators } // namespace paddle namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_XPU_KERNEL( matmul, ops::MatMulXPUKernel, ops::MatMulXPUKernel); REGISTER_OP_XPU_KERNEL( matmul_grad, ops::MatMulGradXPUKernel, ops::MatMulGradXPUKernel); #endif