/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef PADDLE_FLUID_OPERATORS_BMM_OP_H_ #define PADDLE_FLUID_OPERATORS_BMM_OP_H_ #include #include #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; static void ReshapeTensorIntoMatrixSequence( framework::Tensor *x, const phi::funcs::MatDescriptor &descriptor) { int64_t h, w; h = descriptor.height_; w = descriptor.width_; if (descriptor.trans_) { std::swap(w, h); } x->Resize({descriptor.batch_size_, h, w}); } static void ReshapeXYOutIntoMatrixSequence(framework::Tensor *x, framework::Tensor *y, framework::Tensor *out, bool trans_x, bool trans_y) { auto x_dim = x->dims(); auto y_dim = y->dims(); auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(x_dim, 0, false); auto mat_dim_y = phi::funcs::CreateMatrixDescriptor(y_dim, 0, false); out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_), mat_dim_x.height_, mat_dim_y.width_}); ReshapeTensorIntoMatrixSequence(x, mat_dim_x); ReshapeTensorIntoMatrixSequence(y, mat_dim_y); } template class BmmKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { const Tensor &x = *context.Input("X"); const Tensor &y = *context.Input("Y"); Tensor *out = context.Output("Out"); out->mutable_data(context.GetPlace()); if (x.numel() == 0 || y.numel() == 0) { return; } auto blas = phi::funcs::GetBlas(context); auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(x.dims(), 0, false); auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(y.dims(), 0, false); // auto scale = static_cast(context.Attr("alpha")); blas.MatMul(x, mat_dim_a, y, mat_dim_b, T(1), out, T(0)); } }; template class BmmGradKernel : public framework::OpKernel { public: void MatMul(const framework::ExecutionContext &context, const framework::Tensor &a, bool trans_a, const framework::Tensor &b, bool trans_b, framework::Tensor *out) const { out->mutable_data(context.GetPlace()); auto blas = phi::funcs::GetBlas(context); auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a); auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b); blas.MatMul(a, mat_dim_a, b, mat_dim_b, T(1), out, T(0)); } void CalcInputGrad(const framework::ExecutionContext &context, const framework::Tensor &a, bool trans_a, const framework::Tensor &b, bool trans_b, framework::Tensor *out) const { if (out == nullptr) return; MatMul(context, a, trans_a, b, trans_b, out); } void Compute(const framework::ExecutionContext &context) const override { auto x = *context.Input("X"); auto y = *context.Input("Y"); auto dout = *context.Input(framework::GradVarName("Out")); auto *dx = context.Output(framework::GradVarName("X")); auto *dy = context.Output(framework::GradVarName("Y")); ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, false, false); framework::DDim dx_dims; if (dx) { dx_dims = dx->dims(); if (dx_dims != x.dims()) { dx->Resize(x.dims()); } } framework::DDim dy_dims; if (dy) { dy_dims = dy->dims(); if (dy_dims != y.dims()) { dy->Resize(y.dims()); } } CalcInputGrad(context, dout, false, y, true, dx); CalcInputGrad(context, x, true, dout, false, dy); if (dx) { if (dx_dims != x.dims()) { dx->Resize(dx_dims); } } if (dy) { if (dy_dims != y.dims()) { dy->Resize(dy_dims); } } } }; } // namespace operators } // namespace paddle #endif // PADDLE_FLUID_OPERATORS_BMM_OP_H_