// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

// Computes the gradients of the dot product: the gradient w.r.t. X is dOut
// broadcast along the last dimension times Y, and the gradient w.r.t. Y is
// dOut broadcast along the last dimension times X.
template <typename DeviceContext, typename T>
void DotGradFunction(const Tensor* tensor_x, const Tensor* tensor_y,
                     const Tensor* tensor_dout, Tensor* tensor_dx,
                     Tensor* tensor_dy,
                     const paddle::framework::ExecutionContext& ctx) {
#ifdef __NVCC__
  if (1 == tensor_dout->dims().size()) {
    // 1-D case: dOut holds a single value; broadcast it over all elements.
    auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);

    if (tensor_dx) {
      auto y = framework::EigenVector<T>::Flatten(*tensor_y);
      auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
      auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
      Eigen::DSizes<int, 1> size(tensor_dx->numel());
      dx.device(dev) = y * dout.broadcast(size);
    }

    if (tensor_dy) {
      auto x = framework::EigenVector<T>::Flatten(*tensor_x);
      auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
      auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
      Eigen::DSizes<int, 1> size(tensor_dy->numel());
      dy.device(dev) = x * dout.broadcast(size);
    }
  } else {
    // Batched case: dOut has one value per row; broadcast it across columns.
    auto dout = EigenMatrix<T>::From(*tensor_dout);

    if (tensor_dx) {
      tensor_dx->mutable_data<T>(ctx.GetPlace());
      auto y = EigenMatrix<T>::From(*tensor_y);
      auto dx = EigenMatrix<T>::From(*tensor_dx);
      auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
      Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
      dx.device(dev) = y * dout.broadcast(size);
    }

    if (tensor_dy) {
      tensor_dy->mutable_data<T>(ctx.GetPlace());
      auto x = EigenMatrix<T>::From(*tensor_x);
      auto dy = EigenMatrix<T>::From(*tensor_dy);
      auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
      Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
      dy.device(dev) = x * dout.broadcast(size);
    }
  }
#else
  // CPU path: `step` is the length of the last dimension; `s` indexes the
  // dOut element shared by each contiguous run of `step` input elements.
  const auto* data_dout = tensor_dout->data<T>();

  if (tensor_dx) {
    auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
    const auto* data_y = tensor_y->data<T>();
    const framework::DDim& dim = tensor_x->dims();
    size_t N = static_cast<size_t>(framework::product(dim));

    auto step = dim[dim.size() - 1];

    int s = -1;
    for (size_t i = 0; i < N; ++i) {
      if (0 == i % step) ++s;
      data_dx[i] = data_y[i] * data_dout[s];
    }
  }

  if (tensor_dy) {
    auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
    const auto* data_x = tensor_x->data<T>();
    const framework::DDim& dim = tensor_y->dims();
    size_t N = static_cast<size_t>(framework::product(dim));

    auto step = dim[dim.size() - 1];

    int s = -1;
    for (size_t i = 0; i < N; ++i) {
      if (0 == i % step) ++s;
      data_dy[i] = data_x[i] * data_dout[s];
    }
  }
#endif
}

// Computes Out = sum(X * Y) over the last dimension, e.g. for
// X = [1, 2, 3] and Y = [4, 5, 6], Out = 1*4 + 2*5 + 3*6 = 32.
template <typename DeviceContext, typename T>
class DotKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* tensor_x = ctx.Input<Tensor>("X");
    auto* tensor_y = ctx.Input<Tensor>("Y");
    auto* tensor_out = ctx.Output<Tensor>("Out");
    tensor_out->mutable_data<T>(ctx.GetPlace());

#ifdef __NVCC__
    if (1 == tensor_out->dims().size()) {
      auto out = framework::EigenScalar<T>::From(*tensor_out);
      auto x = framework::EigenVector<T>::Flatten(*tensor_x);
      auto y = framework::EigenVector<T>::Flatten(*tensor_y);
      auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
      out.device(dev) = (x * y).sum();
    } else {
      auto out = EigenMatrix<T>::From(*tensor_out);
      auto x = EigenMatrix<T>::From(*tensor_x);
      auto y = EigenMatrix<T>::From(*tensor_y);
      auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
      out.device(dev) = (x * y).sum(Eigen::DSizes<int, 1>(1));
    }
#else
    const auto* data_x = tensor_x->data<T>();
    const auto* data_y = tensor_y->data<T>();
    auto* data_out = tensor_out->data<T>();

    auto x_dims = tensor_x->dims();
    auto step = x_dims[x_dims.size() - 1];
    int size = static_cast<int>(framework::product(x_dims));

    // Accumulate each run of `step` products into one output element.
    for (int ind = -1, j = 0; j < size; ++j) {
      if (j % step == 0) {
        ++ind;
        data_out[ind] = data_x[j] * data_y[j];
      } else {
        data_out[ind] += data_x[j] * data_y[j];
      }
    }
#endif
  }
};

template <typename DeviceContext, typename T>
class DotGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* tensor_x = ctx.Input<Tensor>("X");
    auto* tensor_y = ctx.Input<Tensor>("Y");
    auto* tensor_dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* tensor_dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* tensor_dy = ctx.Output<Tensor>(framework::GradVarName("Y"));

    if (tensor_dx) tensor_dx->mutable_data<T>(ctx.GetPlace());
    if (tensor_dy) tensor_dy->mutable_data<T>(ctx.GetPlace());

    DotGradFunction<DeviceContext, T>(tensor_x, tensor_y, tensor_dout,
                                      tensor_dx, tensor_dy, ctx);
  }
};

}  // namespace operators
}  // namespace paddle
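
// A minimal usage sketch (an assumption for illustration, not part of this
// header): kernels like these are normally instantiated through Paddle's
// registration macros in the corresponding dot_op.cc / dot_op.cu sources,
// along these lines:
//
//   namespace ops = paddle::operators;
//   namespace plat = paddle::platform;
//   REGISTER_OP_CPU_KERNEL(dot,
//                          ops::DotKernel<plat::CPUDeviceContext, float>,
//                          ops::DotKernel<plat::CPUDeviceContext, double>);
//   REGISTER_OP_CPU_KERNEL(dot_grad,
//                          ops::DotGradKernel<plat::CPUDeviceContext, float>,
//                          ops::DotGradKernel<plat::CPUDeviceContext, double>);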