// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

#ifdef _WIN32
#ifndef NOMINMAX
#define NOMINMAX  // msvc max/min macros conflict with std::min/max
#endif
#endif

namespace paddle {
namespace operators {

// Left-pads in_dims with size-1 dimensions so the result has exactly `rank`
// dimensions, e.g. [3, 4] extended to rank 4 becomes [1, 1, 3, 4].
static framework::DDim ExtendDims2Rank(const framework::DDim& in_dims,
                                       int rank) {
  if (in_dims.size() == rank) {
    return in_dims;
  }
  std::vector<int64_t> shapes(rank, 1);
  for (int i = in_dims.size() - 1, j = rank - 1; i >= 0; --i, --j) {
    shapes[j] = in_dims[i];
  }
  return framework::make_ddim(shapes);
}

// Computes, per axis, the replication factor needed to broadcast in_dims to
// out_dims (1 where the two already match).
template <size_t D>
static void GetBroadcastDims(const framework::DDim& in_dims,
                             const framework::DDim& out_dims,
                             Eigen::DSizes<int, D>* bcast_dims) {
  for (size_t i = 0; i < D; ++i) {
    if (in_dims[i] == out_dims[i]) {
      (*bcast_dims)[i] = 1;
    } else {
      (*bcast_dims)[i] = std::max(in_dims[i], out_dims[i]);
    }
  }
}

// Forward: out = x + w * (y - x), with x, y, and w broadcast to out's shape.
template <typename DeviceContext, typename T, size_t D>
static void LerpFunction(const framework::ExecutionContext& ctx) {
  auto x = ctx.Input<framework::Tensor>("X");
  auto y = ctx.Input<framework::Tensor>("Y");
  auto w = ctx.Input<framework::Tensor>("Weight");
  auto out = ctx.Output<framework::Tensor>("Out");
  out->mutable_data<T>(ctx.GetPlace());

  auto out_dims = out->dims();
  auto x_dims = ExtendDims2Rank(x->dims(), D);
  auto y_dims = ExtendDims2Rank(y->dims(), D);
  auto w_dims = ExtendDims2Rank(w->dims(), D);
  Eigen::DSizes<int, D> x_bcast_dims;
  Eigen::DSizes<int, D> y_bcast_dims;
  Eigen::DSizes<int, D> w_bcast_dims;
  GetBroadcastDims<D>(x_dims, out_dims, &x_bcast_dims);
  GetBroadcastDims<D>(y_dims, out_dims, &y_bcast_dims);
  GetBroadcastDims<D>(w_dims, out_dims, &w_bcast_dims);

  auto eigen_x = framework::EigenTensor<T, D>::From(*x, x_dims);
  auto eigen_y = framework::EigenTensor<T, D>::From(*y, y_dims);
  auto eigen_w = framework::EigenTensor<T, D>::From(*w, w_dims);
  auto eigen_out = framework::EigenTensor<T, D>::From(*out);

  auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
  eigen_out.device(place) =
      eigen_x.broadcast(x_bcast_dims) +
      eigen_w.broadcast(w_bcast_dims) *
          (eigen_y.broadcast(y_bcast_dims) - eigen_x.broadcast(x_bcast_dims));
}

// Backward: dx = (1 - w) * dout and dy = w * dout, each summed back over the
// axes that were broadcast in the forward pass.
template <typename DeviceContext, typename T, size_t D>
static void LerpGradFunction(const framework::ExecutionContext& ctx) {
  auto w = ctx.Input<framework::Tensor>("Weight");
  auto dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
  auto dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
  auto dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));

  auto dout_dims = dout->dims();
  // dx/dy may be null when a gradient is not required; fall back to dout's
  // dims so the shape bookkeeping below stays well defined.
  auto dx_dims = dx ? ExtendDims2Rank(dx->dims(), D) : dout_dims;
  auto dy_dims = dy ? ExtendDims2Rank(dy->dims(), D) : dout_dims;
  auto w_dims = ExtendDims2Rank(w->dims(), D);
  Eigen::DSizes<int, D> dx_bcast_dims;
  Eigen::DSizes<int, D> dy_bcast_dims;
  Eigen::DSizes<int, D> w_bcast_dims;
  GetBroadcastDims<D>(dx_dims, dout_dims, &dx_bcast_dims);
  GetBroadcastDims<D>(dy_dims, dout_dims, &dy_bcast_dims);
  GetBroadcastDims<D>(w_dims, dout_dims, &w_bcast_dims);

  auto eigen_w = framework::EigenTensor<T, D>::From(*w, w_dims);
  auto eigen_dout = framework::EigenTensor<T, D>::From(*dout);

  // To reduce a broadcast gradient back to the input shape, view each axis
  // as a (bcast_factor, original_size) pair and sum over the bcast factors.
  Eigen::DSizes<int, D * 2> dx_reshape_dims;
  Eigen::DSizes<int, D * 2> dy_reshape_dims;
  Eigen::DSizes<int, D> reduce_dims;
  for (int i = 0; i < dout_dims.size(); ++i) {
    dx_reshape_dims[2 * i] = dx_bcast_dims[i];
    dx_reshape_dims[2 * i + 1] = dx_dims[i];
    dy_reshape_dims[2 * i] = dy_bcast_dims[i];
    dy_reshape_dims[2 * i + 1] = dy_dims[i];
    reduce_dims[i] = 2 * i;
  }

  auto& place = *ctx.template device_context<DeviceContext>().eigen_device();

  if (dx) {
    dx->mutable_data<T>(ctx.GetPlace());
    auto eigen_dx = framework::EigenTensor<T, D>::From(*dx, dx_dims);
    auto eigen_expr = (1 - eigen_w.broadcast(w_bcast_dims)) * eigen_dout;
    eigen_dx.device(place) = eigen_expr.reshape(dx_reshape_dims)
                                 .sum(reduce_dims)
                                 .reshape(eigen_dx.dimensions());
  }
  if (dy) {
    dy->mutable_data<T>(ctx.GetPlace());
    auto eigen_dy = framework::EigenTensor<T, D>::From(*dy, dy_dims);
    auto eigen_expr = eigen_w.broadcast(w_bcast_dims) * eigen_dout;
    eigen_dy.device(place) = eigen_expr.reshape(dy_reshape_dims)
                                 .sum(reduce_dims)
                                 .reshape(eigen_dy.dimensions());
  }
}

template <typename DeviceContext, typename T>
class LerpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    int rank = ctx.Output<framework::Tensor>("Out")->dims().size();
    PADDLE_ENFORCE_GE(
        rank, 1,
        platform::errors::InvalidArgument(
            "The number of dimensions for LerpOp must be "
            "greater than or equal to 1, but the value received is %d.",
            rank));
    PADDLE_ENFORCE_LE(
        rank, 6,
        platform::errors::InvalidArgument(
            "The number of dimensions for LerpOp must be "
            "less than or equal to 6, but the value received is %d.",
            rank));
    // Dispatch to the rank-specialized implementation.
    switch (rank) {
      case 1:
        LerpFunction<DeviceContext, T, 1>(ctx);
        break;
      case 2:
        LerpFunction<DeviceContext, T, 2>(ctx);
        break;
      case 3:
        LerpFunction<DeviceContext, T, 3>(ctx);
        break;
      case 4:
        LerpFunction<DeviceContext, T, 4>(ctx);
        break;
      case 5:
        LerpFunction<DeviceContext, T, 5>(ctx);
        break;
      case 6:
        LerpFunction<DeviceContext, T, 6>(ctx);
        break;
    }
  }
};

template <typename DeviceContext, typename T>
class LerpGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    int rank = ctx.Input<framework::Tensor>(framework::GradVarName("Out"))
                   ->dims()
                   .size();
    PADDLE_ENFORCE_GE(
        rank, 1,
        platform::errors::InvalidArgument(
            "The number of dimensions for LerpGradOp must be "
            "greater than or equal to 1, but the value received is %d.",
            rank));
    PADDLE_ENFORCE_LE(
        rank, 6,
        platform::errors::InvalidArgument(
            "The number of dimensions for LerpGradOp must be "
            "less than or equal to 6, but the value received is %d.",
            rank));
    // Dispatch to the rank-specialized implementation.
    switch (rank) {
      case 1:
        LerpGradFunction<DeviceContext, T, 1>(ctx);
        break;
      case 2:
        LerpGradFunction<DeviceContext, T, 2>(ctx);
        break;
      case 3:
        LerpGradFunction<DeviceContext, T, 3>(ctx);
        break;
      case 4:
        LerpGradFunction<DeviceContext, T, 4>(ctx);
        break;
      case 5:
        LerpGradFunction<DeviceContext, T, 5>(ctx);
        break;
      case 6:
        LerpGradFunction<DeviceContext, T, 6>(ctx);
        break;
    }
  }
};

}  // namespace operators
}  // namespace paddle
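
// Usage note: the kernels above are header-only templates and take effect
// only once instantiated and registered from a .cc/.cu file. Below is a
// minimal sketch of a CPU registration; the actual registration in
// lerp_op.cc may differ (e.g. in the exact dtype list):
//
//   namespace ops = paddle::operators;
//   namespace plat = paddle::platform;
//   REGISTER_OP_CPU_KERNEL(
//       lerp, ops::LerpKernel<plat::CPUDeviceContext, float>,
//       ops::LerpKernel<plat::CPUDeviceContext, double>);
//   REGISTER_OP_CPU_KERNEL(
//       lerp_grad, ops::LerpGradKernel<plat::CPUDeviceContext, float>,
//       ops::LerpGradKernel<plat::CPUDeviceContext, double>);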