diff --git a/paddle/fluid/operators/bmm_op_xpu.cc b/paddle/fluid/operators/bmm_op_xpu.cc new file mode 100644 index 0000000000000000000000000000000000000000..cc18558027982b7e496c442333c7e8399b4abbe3 --- /dev/null +++ b/paddle/fluid/operators/bmm_op_xpu.cc @@ -0,0 +1,211 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifdef PADDLE_WITH_XPU + +#include +#include +#include "paddle/fluid/operators/matmul_v2_op.h" + +#include "paddle/fluid/operators/xpu_api_wrapper.h" +#include "paddle/fluid/platform/device/device_wrapper.h" + +namespace paddle { +namespace operators { + +template +static void MatMulXPUFunction(const Tensor* x, const Tensor* y, Tensor* out, + bool trans_x, bool trans_y, + const paddle::framework::ExecutionContext& ctx) { + using XPUType = typename XPUTypeTrait::Type; + const auto& x_dims = x->dims(); + const auto& y_dims = y->dims(); + auto& dev_ctx = + ctx.template device_context(); + + auto mat_dim_a = phi::funcs::CreateMatrixDescriptor( + RowMatrixFromVector(x_dims), 0, trans_x); + auto mat_dim_b = phi::funcs::CreateMatrixDescriptor( + ColumnMatrixFromVector(y_dims), 0, trans_y); + + T* data_c = out->data(); + int m = mat_dim_a.height_; + int n = mat_dim_b.width_; + int k = mat_dim_a.width_; + int batch_size = mat_dim_a.batch_size_; + // batch matmul + int r = xpu::fc_batched( + dev_ctx.x_context(), // Context* ctx, + batch_size, // int batch_size, + mat_dim_a.trans_, // bool x_trans, + mat_dim_b.trans_, // bool w_trans, + m, // int m, + n, // int n, + k, // int k, + 1.0, // float alpha, + reinterpret_cast(x->data()), // const TX* x, + mat_dim_a.stride_, // int stride_a, + reinterpret_cast(y->data()), // const TW* w, + mat_dim_b.stride_, // int stride_b, + 0.0, // float beta, + reinterpret_cast(data_c), // TY* y, + m * n, // int stride_c, + nullptr, // const float* x_maxptr, + nullptr); // const float* w_maxptr + + PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_batched"); +} + +template +class BmmXPUKernel : public framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + + if (x->numel() == 0 || y->numel() == 0) { + return; + } + bool trans_x = false; + bool trans_y = false; + + auto x_dims = x->dims(); + auto y_dims = y->dims(); + + PADDLE_ENFORCE_EQ(x_dims.size(), 3, + platform::errors::InvalidArgument( + "Input(X) of BmmOp must be 3-dimensional in BmmOp, " + "but received X's shape: [%s].", + x_dims)); + PADDLE_ENFORCE_EQ(y_dims.size(), 3, + platform::errors::InvalidArgument( + "Input(Y) of BmmOp must be 3-dimensional in BmmOp, " + "but received Y's shape: [%s].", + y_dims)); + PADDLE_ENFORCE_EQ( + x_dims[0], y_dims[0], + platform::errors::InvalidArgument( + "Input(X) and Input(Y) must have the same batch size in BmmOp, " + "but received X's batch size: [%s]," + "Y's batch size [%s]", + x_dims[0], y_dims[0])); + PADDLE_ENFORCE_EQ( + x_dims[2], y_dims[1], + platform::errors::InvalidArgument( + "Input(X)'s width must be equal with Input(Y)'s height in BmmOp," + "but receive X's width: [%s]," + "Y's height: [%s].", + x_dims[2], y_dims[1])); + + if (std::is_same::value) { + MatMulXPUFunction(x, y, out, trans_x, trans_y, ctx); + } else { + if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) { + MatMulXPUFunction(x, y, out, trans_x, trans_y, ctx); + } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) { + MatMulXPUFunction(x, y, out, trans_x, trans_y, ctx); + } else { + MatMulXPUFunction(x, y, out, trans_x, trans_y, ctx); + } + } + } +}; + +template +class BmmXPUGradKernel : public framework::OpKernel { + public: + void MatMul(const framework::ExecutionContext& ctx, + const framework::Tensor& a, bool trans_a, + const framework::Tensor& b, bool trans_b, + framework::Tensor* out) const { + out->mutable_data(ctx.GetPlace()); + if (std::is_same::value) { + MatMulXPUFunction(&a, &b, out, trans_a, trans_b, ctx); + } else { + if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) { + MatMulXPUFunction(&a, &b, out, trans_a, trans_b, ctx); + } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) { + MatMulXPUFunction(&a, &b, out, trans_a, trans_b, ctx); + } else { + MatMulXPUFunction(&a, &b, out, trans_a, trans_b, ctx); + } + } + } + + void CalcInputGrad(const framework::ExecutionContext& context, + const framework::Tensor& a, bool trans_a, + const framework::Tensor& b, bool trans_b, + framework::Tensor* out) const { + if (out == nullptr) return; + MatMul(context, a, trans_a, b, trans_b, out); + } + + void Compute(const framework::ExecutionContext& context) const override { + auto x = *context.Input("X"); + auto y = *context.Input("Y"); + auto dout = + *context.Input(framework::GradVarName("Out")); + auto* dx = context.Output(framework::GradVarName("X")); + auto* dy = context.Output(framework::GradVarName("Y")); + ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, false, false); + + framework::DDim dx_dims; + if (dx) { + dx_dims = dx->dims(); + if (dx_dims != x.dims()) { + dx->Resize(x.dims()); + } + } + + framework::DDim dy_dims; + if (dy) { + dy_dims = dy->dims(); + if (dy_dims != y.dims()) { + dy->Resize(y.dims()); + } + } + + CalcInputGrad(context, dout, false, y, true, dx); + CalcInputGrad(context, x, true, dout, false, dy); + + // CalcInputGrad(context, dout, false, false, y, true, false, dx); + // CalcInputGrad(context, x, true, true, dout, false, true, dy); + + if (dx) { + if (dx_dims != x.dims()) { + dx->Resize(dx_dims); + } + } + + if (dy) { + if (dy_dims != y.dims()) { + dy->Resize(dy_dims); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_XPU_KERNEL(bmm, ops::BmmXPUKernel, + ops::BmmXPUKernel); +REGISTER_OP_XPU_KERNEL(bmm_grad, ops::BmmXPUGradKernel, + ops::BmmXPUGradKernel); + +#endif diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index 7b88f261d5a4fc3fd9a0ab135beb41dd31e4aa8b..357644b62d3ed3a37584ffdda611dceeb1d43404 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -43,6 +43,8 @@ XPUOpMap& get_kl2_ops() { {"batch_norm_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"batch_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"bmm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"bmm_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"bce_loss_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"bce_loss", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, diff --git a/python/paddle/fluid/tests/unittests/xpu/test_bmm_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_bmm_op_xpu.py new file mode 100644 index 0000000000000000000000000000000000000000..f6893150c9e615c5a054e7f398cf4ca0589a3be5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_bmm_op_xpu.py @@ -0,0 +1,104 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# # Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import sys +sys.path.append("..") + +import paddle +import paddle.fluid.core as core +import paddle.fluid as fluid +import paddle.tensor as tensor +import unittest +import numpy as np +from op_test import OpTest +from op_test_xpu import XPUOpTest +from paddle.fluid.framework import Program, program_guard +from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper + +paddle.enable_static() + + +class XPUTestBmmOp(XPUOpTestWrapper): + """ + func desc:: https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/bmm_cn.html#bmm + """ + + def __init__(self): + self.op_name = 'bmm' + self.use_dynamic_create_class = False + + class TestBmmOp(XPUOpTest): + def setUp(self): + self.init_dtype() + self.set_xpu() + self.op_type = "bmm" + self.place = paddle.XPUPlace(0) + self.set_shape() + X = np.random.random(self.Xshape).astype(self.dtype) + Y = np.random.random(self.Yshape).astype(self.dtype) + self.inputs = {'X': X, 'Y': Y} + + Out = np.matmul(X, Y) + self.outputs = {'Out': Out} + + def init_dtype(self): + self.dtype = self.in_type + + def set_shape(self): + self.Xshape = (10, 3, 4) + self.Yshape = (10, 4, 5) + + def set_xpu(self): + self.__class__.use_xpu = True + self.__class__.no_need_check_grad = False + self.__class__.op_type = self.in_type + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad_normal(self): + self.check_grad_with_place(self.place, ['X', 'Y'], 'Out') + + class TestBmmOp1(TestBmmOp): + def set_shape(self): + self.Xshape = (3, 3, 3) + self.Yshape = (3, 3, 3) + + class TestBmmOp2(TestBmmOp): + def set_shape(self): + self.Xshape = (128, 3, 16) + self.Yshape = (128, 16, 3) + + class TestBmmOp3(TestBmmOp): + def set_shape(self): + self.Xshape = (2048, 16, 27) + self.Yshape = (2048, 27, 16) + + class TestBmmOp4(TestBmmOp): + def set_shape(self): + self.Xshape = (2, 27, 27) + self.Yshape = (2, 27, 27) + + class TestBmmOp5(TestBmmOp): + def set_shape(self): + self.Xshape = (2, 1, 1) + self.Yshape = (2, 1, 1) + + +support_types = get_xpu_op_support_types('bmm') +for stype in support_types: + create_test_class(globals(), XPUTestBmmOp, stype) + +if __name__ == '__main__': + unittest.main()