// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "paddle/fluid/lite/core/kernel.h" #include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/types.h" #include "paddle/fluid/operators/math/blas.h" namespace paddle { namespace lite { namespace kernels { namespace x86 { using Tensor = framework::Tensor; template class MulCompute : public KernelLite { public: using param_t = operators::MulParam; void Run() override { auto& context = ctx_->As(); auto& param = *param_.get_mutable(); CHECK(context.x86_device_context()); param.output->template mutable_data(); auto* x = ¶m.x->raw_tensor(); auto* y = ¶m.y->raw_tensor(); const Tensor x_matrix = x->dims().size() > 2 ? framework::ReshapeToMatrix( *x, param.x_num_col_dims) : *x; const Tensor y_matrix = y->dims().size() > 2 ? framework::ReshapeToMatrix( *y, param.y_num_col_dims) : *y; auto* z = ¶m.output->raw_tensor(); auto z_dim = z->dims(); if (z_dim.size() != 2) { z->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); } auto blas = paddle::operators::math::GetBlas( *context.x86_device_context()); blas.MatMul(x_matrix, y_matrix, z); if (z_dim.size() != 2) { z->Resize(z_dim); } } virtual ~MulCompute() = default; }; template class MulGradCompute : public KernelLite { public: void Run() override { auto& context = ctx_->As(); auto& param = *param_.get_mutable(); CHECK(context.x86_device_context()); auto* x = ¶m.x->raw_tensor(); auto* y = ¶m.y->raw_tensor(); auto x_matrix = x->dims().size() > 2 ? framework::ReshapeToMatrix(*x, param.x_num_col_dims) : static_cast(*x); auto y_matrix = y->dims().size() > 2 ? framework::ReshapeToMatrix(*y, param.y_num_col_dims) : static_cast(*y); auto* dout = ¶m.output_grad->raw_tensor(); Tensor dout_mat; dout_mat.ShareDataWith(*dout); dout_mat.Resize( {framework::flatten_to_2d(x->dims(), param.x_num_col_dims)[0], framework::flatten_to_2d(y->dims(), param.y_num_col_dims)[1]}); auto* dx = ¶m.x_grad->raw_tensor(); auto* dy = ¶m.y_grad->raw_tensor(); if (dx != nullptr) { dx->set_lod(x->lod()); } if (dy != nullptr) { dy->set_lod(y->lod()); } auto blas = paddle::operators::math::GetBlas( *context.x86_device_context()); if (dx) { // dx->mutable_data(context.x86_device_context->GetPlace()); param.x_grad->template mutable_data(); Tensor dx_matrix = dx->dims().size() > 2 ? framework::ReshapeToMatrix( *dx, param.x_num_col_dims) : *dx; // dx = dout * y'. dx: M x K, dout : M x N, y : K x N blas.MatMul(dout_mat, false, y_matrix, true, &dx_matrix); } if (dy) { // dy->yutable_data(context.x86_device_context->GetPlace()); param.y_grad->template mutable_data(); Tensor dy_matrix = dy->dims().size() > 2 ? framework::ReshapeToMatrix( *dy, param.y_num_col_dims) : *dy; // dy = x' * dout. dy K x N, dout : M x N, x : M x K blas.MatMul(x_matrix, true, dout_mat, false, &dy_matrix); } } virtual ~MulGradCompute() = default; }; } // namespace x86 } // namespace kernels } // namespace lite } // namespace paddle