/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/softmax.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using DDim = framework::DDim;

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

template <typename T, size_t D, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;

static inline int CanonicalAxis(const int axis, const int rank) {
  if (axis < 0) {
    return axis + rank;
  }
  return axis;
}

// product of the dimensions before `axis`
static inline int SizeToAxis(const int axis, DDim dims) {
  int size = 1;
  for (int i = 0; i < axis; i++) {
    size *= dims[i];
  }
  return size;
}

// product of the dimensions from `axis` to the last one
static inline int SizeFromAxis(const int axis, DDim dims) {
  int size = 1;
  for (int i = axis; i < dims.size(); i++) {
    size *= dims[i];
  }
  return size;
}

// product of the dimensions after `axis`
static inline int SizeOutAxis(const int axis, DDim dims) {
  int size = 1;
  for (int i = axis + 1; i < dims.size(); i++) {
    size *= dims[i];
  }
  return size;
}

template <typename DeviceContext, typename T, int64_t Rank>
struct ArgMaxFunctor {
  void operator()(const DeviceContext& ctx, const Tensor& in,
                  Tensor* index_tensor, const int64_t& axis) {
    auto in_eigen = EigenTensor<T, Rank>::From(in, in.dims());
    auto index_eigen = EigenTensor<int, Rank - 1>::From(*index_tensor);
    index_eigen = in_eigen.argmax(axis).template cast<int>();
  }
};

template <typename DeviceContext, typename T>
struct GumbleNoiseGenerator;

template <typename DeviceContext, typename T>
struct OneHotGenerator;

template <typename T>
struct GumbleNoiseGenerator<platform::CPUDeviceContext, T> {
  static void Transform(const platform::CPUDeviceContext& context,
                        const T* input_data, T* output_data, int size_to_axis,
                        int size_from_axis, const float temperature) {
    // generate uniform random numbers u ~ Uniform(0, 1)
    const int size = size_to_axis * size_from_axis;
    std::uniform_real_distribution<T> dist(0.00001, 1);
    auto engine = paddle::framework::GetCPURandomEngine(0);
    Tensor random_tensor;
    auto* random_data =
        random_tensor.mutable_data<T>({size}, platform::CPUPlace());
    for (int64_t i = 0; i < size; ++i) {
      random_data[i] = dist(*engine);
    }

    // generate gumbel noise g = -log(-log(u))
    framework::DDim dim_2d{size_to_axis, size_from_axis};
    auto gumbel_noise_eigen = EigenMatrix<T>::From(random_tensor, dim_2d);
    gumbel_noise_eigen = -(((-(gumbel_noise_eigen.log())).log()));

    // add noise: the Eigen map above writes the noise back into
    // random_tensor, so random_data now holds g
    for (int64_t i = 0; i < size_to_axis * size_from_axis; i++) {
      output_data[i] = (input_data[i] + random_data[i]) / temperature;
    }
  }
};

template <typename T>
struct OneHotGenerator<platform::CPUDeviceContext, T> {
  static void Transform(const platform::CPUDeviceContext& context,
                        const Tensor& X, Tensor* Out, int axis) {
    Tensor index;
    std::vector<int> index_dim;
    const auto rank = X.dims().size();
    const int size_to_axis = SizeToAxis(axis, X.dims());
    const int size_from_axis = SizeFromAxis(axis, X.dims());
    const int size_out_axis = SizeOutAxis(axis, X.dims());

    for (int i = 0; i < X.dims().size(); i++) {
      if (i != axis) index_dim.push_back(X.dims().Get()[i]);
    }
    DDim index_ddim(index_dim.data(), rank - 1);
    index.Resize(index_ddim);
    auto* index_data = index.mutable_data<int>(context.GetPlace());

#define CALL_ARG_MINMAX_FUNCTOR(rank)                                \
  ArgMaxFunctor<platform::CPUDeviceContext, T, rank> functor##rank;  \
  functor##rank(context, *Out, &index, axis);
    switch (Out->dims().size()) {
      case 1:
        CALL_ARG_MINMAX_FUNCTOR(1);
        break;
      case 2:
        CALL_ARG_MINMAX_FUNCTOR(2);
        break;
      case 3:
        CALL_ARG_MINMAX_FUNCTOR(3);
        break;
      case 4:
        CALL_ARG_MINMAX_FUNCTOR(4);
        break;
      case 5:
        CALL_ARG_MINMAX_FUNCTOR(5);
        break;
      case 6:
        CALL_ARG_MINMAX_FUNCTOR(6);
        break;
      default:
        PADDLE_ENFORCE_LE(Out->dims().size(), 6,
                          platform::errors::InvalidArgument(
                              "gumbel_softmax operator does not support "
                              "tensors whose rank is greater than 6 in "
                              "CPU mode."));
        break;
#undef CALL_ARG_MINMAX_FUNCTOR
    }

    // build the hard one-hot output: zero everything, then scatter 1.0 at the
    // argmax position along `axis`
    math::set_constant(context, Out, 0.0);
    for (int i = 0; i < size_to_axis; i++) {
      for (int j = 0; j < size_out_axis; j++) {
        *(Out->data<T>() + i * size_from_axis + j +
          index_data[i * size_out_axis + j] * size_out_axis) = 1.0;
      }
    }
  }
};

// Forward pass: Out = softmax((X + g) / temperature) with gumbel noise g;
// when `hard` is set, the soft sample is replaced by a one-hot of its argmax.
template <typename DeviceContext, typename T>
class GumbelSoftmaxKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* X = context.Input<Tensor>("X");
    auto* Out = context.Output<Tensor>("Out");
    const int rank = X->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
    int axis_dim = X->dims()[axis];
    const bool is_hard = context.Attr<bool>("hard");
    const float temperature = context.Attr<float>("temperature");
    PADDLE_ENFORCE_GT(temperature, 0,
                      platform::errors::InvalidArgument(
                          "The temperature must be greater than 0, but "
                          "received temperature = %f.",
                          temperature));

    // allocate memory on device.
    Out->mutable_data<T>(context.GetPlace());
    if (Out->numel() == 0) {
      return;
    }

    const int size_to_axis = SizeToAxis(axis, X->dims());
    const int size_from_axis = SizeFromAxis(axis, X->dims());
    Tensor X_noise_2d, Out_2d;
    X_noise_2d.Resize({size_to_axis, size_from_axis});
    Out_2d.ShareDataWith(*Out).Resize({size_to_axis, size_from_axis});

    // generate gumbel noise and add it to X
    auto* x_noise_data = X_noise_2d.mutable_data<T>(context.GetPlace());
    GumbleNoiseGenerator<DeviceContext, T>::Transform(
        context.template device_context<DeviceContext>(), X->data<T>(),
        x_noise_data, size_to_axis, size_from_axis, temperature);

#ifdef PADDLE_ON_INFERENCE
    math::SoftmaxFunctor<DeviceContext, T, true>()(
        context.template device_context<DeviceContext>(), axis_dim,
        &X_noise_2d, &Out_2d);
#else
    math::SoftmaxFunctor<DeviceContext, T, false>()(
        context.template device_context<DeviceContext>(), axis_dim,
        &X_noise_2d, &Out_2d);
#endif

    if (is_hard) {
      OneHotGenerator<DeviceContext, T>::Transform(
          context.template device_context<DeviceContext>(), *X, Out, axis);
    }
  }
};

// Backward pass: reuses the softmax gradient on the flattened 2-D views.
template <typename DeviceContext, typename T>
class GumbelSoftmaxGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* Out = context.Input<Tensor>("Out");
    auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
    const int rank = dX->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
    int axis_dim = dX->dims()[axis];

    // allocate memory on device.
    dX->mutable_data<T>(context.GetPlace());
    if (dX->numel() == 0) {
      return;
    }

    const int size_to_axis = SizeToAxis(axis, dX->dims());
    const int size_from_axis = SizeFromAxis(axis, dX->dims());
    Tensor dX_2d, Out_2d, dOut_2d;
    dX_2d.ShareDataWith(*dX).Resize({size_to_axis, size_from_axis});
    Out_2d.ShareDataWith(*Out).Resize({size_to_axis, size_from_axis});
    dOut_2d.ShareDataWith(*dOut).Resize({size_to_axis, size_from_axis});
    math::SoftmaxGradFunctor<DeviceContext, T>()(
        context.template device_context<DeviceContext>(), axis_dim, &Out_2d,
        &dOut_2d, &dX_2d);
  }
};

}  // namespace operators
}  // namespace paddle
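
// A minimal registration sketch for the kernels above. This normally lives in
// the operator's .cc file rather than in this header, and the exact dtype list
// shown here is an assumption for illustration only:
//
//   namespace ops = paddle::operators;
//   namespace plat = paddle::platform;
//   REGISTER_OP_CPU_KERNEL(
//       gumbel_softmax,
//       ops::GumbelSoftmaxKernel<plat::CPUDeviceContext, float>,
//       ops::GumbelSoftmaxKernel<plat::CPUDeviceContext, double>);
//   REGISTER_OP_CPU_KERNEL(
//       gumbel_softmax_grad,
//       ops::GumbelSoftmaxGradKernel<plat::CPUDeviceContext, float>,
//       ops::GumbelSoftmaxGradKernel<plat::CPUDeviceContext, double>);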