From 581b2c64b18fd5968cde2d158dca4327ca20b27a Mon Sep 17 00:00:00 2001 From: From00 Date: Sat, 26 Feb 2022 14:07:12 +0800 Subject: [PATCH] Move GumbelSoftmax OP to phi (#39873) * Move GumbelSoftmax OP to phi * platform::errors -> phi::errors; GumbelSoftmaxGradInferMeta -> backend.h/cc * Use axis util in kernel impl * Remove namespace platform::errors * Use GetCPUEngine in Device Context --- paddle/fluid/operators/gumbel_softmax_op.cc | 48 ++-- paddle/fluid/operators/gumbel_softmax_op.cu | 172 ------------ paddle/fluid/operators/gumbel_softmax_op.h | 249 ------------------ paddle/phi/infermeta/backward.cc | 12 + paddle/phi/infermeta/backward.h | 4 + paddle/phi/infermeta/unary.cc | 32 +++ paddle/phi/infermeta/unary.h | 11 + paddle/phi/kernels/CMakeLists.txt | 2 +- .../kernels/cpu/gumbel_softmax_grad_kernel.cc | 25 ++ .../phi/kernels/cpu/gumbel_softmax_kernel.cc | 121 +++++++++ .../kernels/gpu/gumbel_softmax_grad_kernel.cu | 25 ++ .../phi/kernels/gpu/gumbel_softmax_kernel.cu | 181 +++++++++++++ .../phi/kernels/gumbel_softmax_grad_kernel.h | 27 ++ paddle/phi/kernels/gumbel_softmax_kernel.h | 28 ++ .../impl/gumbel_softmax_grad_kernel_impl.h | 50 ++++ .../kernels/impl/gumbel_softmax_kernel_impl.h | 96 +++++++ paddle/phi/ops/compat/gumbel_softmax_sig.cc | 30 +++ 17 files changed, 658 insertions(+), 455 deletions(-) delete mode 100644 paddle/fluid/operators/gumbel_softmax_op.cu delete mode 100644 paddle/fluid/operators/gumbel_softmax_op.h create mode 100644 paddle/phi/kernels/cpu/gumbel_softmax_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/gumbel_softmax_kernel.cc create mode 100644 paddle/phi/kernels/gpu/gumbel_softmax_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu create mode 100644 paddle/phi/kernels/gumbel_softmax_grad_kernel.h create mode 100644 paddle/phi/kernels/gumbel_softmax_kernel.h create mode 100644 paddle/phi/kernels/impl/gumbel_softmax_grad_kernel_impl.h create mode 100644 
paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h create mode 100644 paddle/phi/ops/compat/gumbel_softmax_sig.cc diff --git a/paddle/fluid/operators/gumbel_softmax_op.cc b/paddle/fluid/operators/gumbel_softmax_op.cc index 95c6ed6690..f8f8f3fd78 100644 --- a/paddle/fluid/operators/gumbel_softmax_op.cc +++ b/paddle/fluid/operators/gumbel_softmax_op.cc @@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/gumbel_softmax_op.h" -#include -#include -#include "paddle/fluid/operators/common_infer_shape_functions.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -23,10 +24,6 @@ class GumbelSoftmaxOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - return UnaryOpUnchangedInferShapeCheckAxis(ctx); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -71,20 +68,6 @@ Samples from the Gumbel-Softmax distribution and optionally discretizes. 
class GumbelSoftmaxGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "gumbel_softmax_grad"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", - "Out@GRAD", "gumbel_softmax_grad"); - PADDLE_ENFORCE_EQ( - ctx->GetInputDim("Out"), - ctx->GetInputDim(framework::GradVarName("Out")), - platform::errors::InvalidArgument("Input(Out) and its gradients " - "should have the same shape.")); - - ctx->SetOutputDim(framework::GradVarName("X"), - ctx->GetInputDim(framework::GradVarName("Out"))); - } }; template @@ -107,17 +90,16 @@ class GumbelSoftmaxGradOpMaker : public framework::SingleGradOpMaker { namespace ops = paddle::operators; +DELCARE_INFER_SHAPE_FUNCTOR(gumbel_softmax, GumbelSoftmaxInferShapeFunctor, + PT_INFER_META(phi::GumbelSoftmaxInferMeta)); +DELCARE_INFER_SHAPE_FUNCTOR(gumbel_softmax_grad, + GumbelSoftmaxGradInferShapeFunctor, + PT_INFER_META(phi::GumbelSoftmaxGradInferMeta)); + REGISTER_OPERATOR(gumbel_softmax, ops::GumbelSoftmaxOp, ops::GumbelSoftmaxOpMaker, ops::GumbelSoftmaxGradOpMaker, - ops::GumbelSoftmaxGradOpMaker); -REGISTER_OPERATOR(gumbel_softmax_grad, ops::GumbelSoftmaxGradOp); - -REGISTER_OP_CPU_KERNEL( - gumbel_softmax, - ops::GumbelSoftmaxKernel, - ops::GumbelSoftmaxKernel); -REGISTER_OP_CPU_KERNEL( - gumbel_softmax_grad, - ops::GumbelSoftmaxGradKernel, - ops::GumbelSoftmaxGradKernel); + ops::GumbelSoftmaxGradOpMaker, + GumbelSoftmaxInferShapeFunctor); +REGISTER_OPERATOR(gumbel_softmax_grad, ops::GumbelSoftmaxGradOp, + GumbelSoftmaxGradInferShapeFunctor); diff --git a/paddle/fluid/operators/gumbel_softmax_op.cu b/paddle/fluid/operators/gumbel_softmax_op.cu deleted file mode 100644 index 880e3eb9f3..0000000000 --- a/paddle/fluid/operators/gumbel_softmax_op.cu +++ /dev/null @@ -1,172 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. 
All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#pragma once - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/operators/gumbel_softmax_op.h" - -#if defined(__NVCC__) || defined(__HIPCC__) -#ifdef __NVCC__ -#include "cub/cub.cuh" -#endif -#ifdef __HIPCC__ -#include -namespace cub = hipcub; -#endif - -#include -#include -#include -#include -#include "paddle/fluid/framework/generator.h" -#include "paddle/fluid/memory/memcpy.h" - -namespace paddle { -namespace operators { - -template -using KeyValuePair = cub::KeyValuePair; - -template -struct UniformCUDAGenerator { - T min_, max_; - unsigned int seed_; - unsigned int offset_ = 0; - HOSTDEVICE UniformCUDAGenerator(T min, T max, unsigned int seed) - : min_(min), max_(max), seed_(seed) {} - HOSTDEVICE UniformCUDAGenerator(T min, T max, unsigned int seed, - unsigned int offset) - : min_(min), max_(max), seed_(seed), offset_(offset) {} - - HOSTDEVICE T operator()(const unsigned int n) const { - thrust::minstd_rand rng; - rng.seed(seed_); - thrust::uniform_real_distribution dist(min_, max_); - rng.discard(n + offset_); - return dist(rng); - } -}; - -template -__global__ void OneHotCUDAKernel(const int64_t height, const int64_t width, - const int64_t size_out_axis, const T init, - const T* in, T* out) { - typedef cub::BlockReduce, BlockDim> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - - for (int64_t idx = blockIdx.x; idx < 
height; idx += gridDim.x) { - KeyValuePair kv_pair = {-1, init}; - int h = idx / size_out_axis; - int w = idx % size_out_axis; - cub::ArgMax reducer; - for (int k = threadIdx.x; k < width; k += blockDim.x) { - kv_pair = reducer( - {k, in[h * width * size_out_axis + k * size_out_axis + w]}, kv_pair); - } - kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, reducer); - if (threadIdx.x == 0) { - int index = static_cast(kv_pair.key); - out[h * width * size_out_axis + index * size_out_axis + w] = 1; - } - __syncthreads(); - } -} - -template -struct OneHotGenerator { - static void Transform(const platform::CUDADeviceContext& context, - const Tensor& X, Tensor* Out, int axis) { - const int size_to_axis = SizeToAxis(axis, X.dims()); - const int size_from_axis = SizeFromAxis(axis, X.dims()); - const int size_out_axis = SizeOutAxis(axis, X.dims()); - constexpr int thread_size = 512; - int64_t max_grid_dimx = context.GetCUDAMaxGridDimSize()[0]; - int64_t height = size_to_axis * size_out_axis; - int block_size = height < max_grid_dimx ? 
height : max_grid_dimx; - - Tensor input_tensor; - input_tensor.mutable_data(Out->dims(), platform::CUDAPlace()); - paddle::framework::TensorCopy(*Out, context.GetPlace(), &input_tensor); - phi::funcs::set_constant(context, Out, 0.0); - OneHotCUDAKernel< - T, thread_size><<>>( - height, size_from_axis / size_out_axis, size_out_axis, - std::numeric_limits::lowest(), input_tensor.data(), - Out->data()); - } -}; - -template -__global__ void AddGumbelNoiseCUDAKernel(const T* input_data, T* output_data, - T* noise, const float temperature, - int64_t n) { - int index = threadIdx.x + blockIdx.x * blockDim.x; - int step = blockDim.x * gridDim.x; - for (int64_t i = index; i < n; i += step) { - T gumbel_noise = -log(-log(noise[i])); - output_data[i] = (gumbel_noise + input_data[i]) / temperature; - } -} - -template -struct GumbleNoiseGenerator { - static void Transform(const platform::CUDADeviceContext& context, - const T* input_data, T* output_data, int size_to_axis, - int size_from_axis, const float temperature) { - Tensor random_tensor; - int64_t size = size_to_axis * size_from_axis; - T* random_data = - random_tensor.mutable_data({size}, platform::CUDAPlace()); - thrust::counting_iterator index_sequence_begin(0); - - // generate gumbel noise - int device_id = context.GetPlace().GetDeviceId(); - auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id); - if (gen_cuda->GetIsInitPy()) { - auto seed_offset = gen_cuda->IncrementOffset(1); - int64_t gen_offset = size * seed_offset.second; - thrust::transform( - index_sequence_begin, index_sequence_begin + size, - thrust::device_ptr(random_data), - UniformCUDAGenerator(0.00001, 1, seed_offset.first, gen_offset)); - } else { - const unsigned int seed = std::random_device()(); - thrust::transform(index_sequence_begin, index_sequence_begin + size, - thrust::device_ptr(random_data), - UniformCUDAGenerator(0.00001, 1, seed)); - } - - // add gumbel noise to X - const int thread_size = 512; - int64_t block_size = (size + 
thread_size) / thread_size; - AddGumbelNoiseCUDAKernel< - T><<>>( - input_data, output_data, random_data, temperature, size); - } -}; - -#endif -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; -REGISTER_OP_CUDA_KERNEL( - gumbel_softmax, ops::GumbelSoftmaxKernel, - ops::GumbelSoftmaxKernel); -REGISTER_OP_CUDA_KERNEL( - gumbel_softmax_grad, - ops::GumbelSoftmaxGradKernel, - ops::GumbelSoftmaxGradKernel); diff --git a/paddle/fluid/operators/gumbel_softmax_op.h b/paddle/fluid/operators/gumbel_softmax_op.h deleted file mode 100644 index daddd13d7b..0000000000 --- a/paddle/fluid/operators/gumbel_softmax_op.h +++ /dev/null @@ -1,249 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once -#include "paddle/fluid/framework/generator.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/softmax.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using DDim = framework::DDim; - -template -using EigenMatrix = framework::EigenMatrix; - -template -using EigenTensor = framework::EigenTensor; - -static inline int CanonicalAxis(const int axis, const int rank) { - if (axis < 0) { - return axis + rank; - } - return axis; -} - -static inline int SizeToAxis(const int axis, DDim dims) { - int size = 1; - for (int i = 0; i < axis; i++) { - size *= dims[i]; - } - return size; -} - -static inline int SizeFromAxis(const int axis, DDim dims) { - int size = 1; - for (int i = axis; i < dims.size(); i++) { - size *= dims[i]; - } - return size; -} - -static inline int SizeOutAxis(const int axis, DDim dims) { - int size = 1; - for (int i = axis + 1; i < dims.size(); i++) { - size *= dims[i]; - } - return size; -} - -template -struct ArgMaxFunctor { - void operator()(const DeviceContext& ctx, const Tensor& in, - Tensor* index_tensor, const int64_t& axis) { - auto in_eigen = EigenTensor::From(in, in.dims()); - auto index_eigen = EigenTensor::From(*index_tensor); - index_eigen = in_eigen.argmax(axis).template cast(); - } -}; -template -struct GumbleNoiseGenerator; - -template -struct OneHotGenerator; - -template -struct GumbleNoiseGenerator { - static void Transform(const platform::CPUDeviceContext& context, - const T* input_data, T* output_data, int size_to_axis, - int size_from_axis, const float temperature) { - // generate uniform random number - const int size = size_to_axis * size_from_axis; - std::uniform_real_distribution dist(0.00001, 1); - auto engine = paddle::framework::GetCPURandomEngine(0); - Tensor random_tensor; - auto* random_data = - random_tensor.mutable_data({size}, platform::CPUPlace()); - for (int64_t i = 0; i < 
size; ++i) { - random_data[i] = dist(*engine); - } - - // generate gumbel noise - framework::DDim dim_2d{size_to_axis, size_from_axis}; - auto gumbel_noise_eigen = EigenMatrix::From(random_tensor, dim_2d); - gumbel_noise_eigen = -(((-(gumbel_noise_eigen.log())).log())); - - // add noise - for (int64_t i = 0; i < size_to_axis * size_from_axis; i++) { - output_data[i] = (input_data[i] + random_data[i]) / temperature; - } - } -}; -template -struct OneHotGenerator { - static void Transform(const platform::CPUDeviceContext& context, - const Tensor& X, Tensor* Out, int axis) { - Tensor index; - std::vector index_dim; - const auto rank = X.dims().size(); - const int size_to_axis = SizeToAxis(axis, X.dims()); - const int size_from_axis = SizeFromAxis(axis, X.dims()); - const int size_out_axis = SizeOutAxis(axis, X.dims()); - - for (int i = 0; i < X.dims().size(); i++) { - if (i != axis) index_dim.push_back(X.dims().Get()[i]); - } - DDim index_ddim(index_dim.data(), rank - 1); - index.Resize(index_ddim); - auto* index_data = index.mutable_data(context.GetPlace()); - -#define CALL_ARG_MINMAX_FUNCTOR(rank) \ - ArgMaxFunctor functor##rank; \ - functor##rank(context, *Out, &index, axis); - switch (Out->dims().size()) { - case 1: - CALL_ARG_MINMAX_FUNCTOR(1); - break; - case 2: - CALL_ARG_MINMAX_FUNCTOR(2); - break; - case 3: - CALL_ARG_MINMAX_FUNCTOR(3); - break; - case 4: - CALL_ARG_MINMAX_FUNCTOR(4); - break; - case 5: - CALL_ARG_MINMAX_FUNCTOR(5); - break; - case 6: - CALL_ARG_MINMAX_FUNCTOR(6); - break; - default: - PADDLE_ENFORCE_LE(Out->dims().size(), 6, - platform::errors::InvalidArgument( - "gumbel_softmax operator doesn't supports " - "tensors whose ranks are greater " - "than 6 in CPU mode.")); - break; -#undef CALL_ARG_MINMAX_FUNCTOR - } - - phi::funcs::set_constant(context, Out, 0.0); - for (int i = 0; i < size_to_axis; i++) { - for (int j = 0; j < size_out_axis; j++) { - *(Out->data() + i * size_from_axis + j + - index_data[i * size_out_axis + j] * size_out_axis) = 
1.0; - } - } - } -}; - -template -class GumbelSoftmaxKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* X = context.Input("X"); - auto* Out = context.Output("Out"); - const int rank = X->dims().size(); - const int axis = CanonicalAxis(context.Attr("axis"), rank); - int axis_dim = X->dims()[axis]; - const bool is_hard = context.Attr("hard"); - const float temperature = context.Attr("temperature"); - PADDLE_ENFORCE_GT(temperature, 0, - platform::errors::InvalidArgument( - "The temperature must be greater than 0. But " - "received temperature = %f", - temperature)); - - // allocate memory on device. - Out->mutable_data(context.GetPlace()); - if (Out->numel() == 0) { - return; - } - - const int size_to_axis = SizeToAxis(axis, X->dims()); - const int size_from_axis = SizeFromAxis(axis, X->dims()); - Tensor X_noise_2d, Out_2d; - X_noise_2d.Resize({size_to_axis, size_from_axis}); - Out_2d.ShareDataWith(*Out).Resize({size_to_axis, size_from_axis}); - - // generate gumbel noise and add it to X - auto* x_noise_data = X_noise_2d.mutable_data(context.GetPlace()); - GumbleNoiseGenerator::Transform( - context.template device_context(), X->data(), - x_noise_data, size_to_axis, size_from_axis, temperature); - -#ifdef PADDLE_ON_INFERENCE - math::SoftmaxFunctor()( - context.template device_context(), axis_dim, &X_noise_2d, - &Out_2d); -#else - math::SoftmaxFunctor()( - context.template device_context(), axis_dim, &X_noise_2d, - &Out_2d); -#endif - - if (is_hard) { - OneHotGenerator::Transform( - context.template device_context(), *X, Out, axis); - } - } -}; - -template -class GumbelSoftmaxGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* Out = context.Input("Out"); - auto* dOut = context.Input(framework::GradVarName("Out")); - auto* dX = context.Output(framework::GradVarName("X")); - const int rank = dX->dims().size(); 
- const int axis = CanonicalAxis(context.Attr("axis"), rank); - int axis_dim = dX->dims()[axis]; - // allocate memory on device. - dX->mutable_data(context.GetPlace()); - if (dX->numel() == 0) { - return; - } - - const int size_to_axis = SizeToAxis(axis, dX->dims()); - const int size_from_axis = SizeFromAxis(axis, dX->dims()); - Tensor dX_2d, Out_2d, dOut_2d; - dX_2d.ShareDataWith(*dX).Resize({size_to_axis, size_from_axis}); - Out_2d.ShareDataWith(*Out).Resize({size_to_axis, size_from_axis}); - dOut_2d.ShareDataWith(*dOut).Resize({size_to_axis, size_from_axis}); - math::SoftmaxGradFunctor()( - context.template device_context(), axis_dim, &Out_2d, - &dOut_2d, &dX_2d); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index e08eae0fc6..643a6dc9dd 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -76,4 +76,16 @@ void GeneralBinaryGradInferMeta(const MetaTensor& x, } } +void GumbelSoftmaxGradInferMeta(const MetaTensor& out, + const MetaTensor& dout, + int axis, + MetaTensor* dx) { + PADDLE_ENFORCE_EQ( + out.dims(), + dout.dims(), + errors::InvalidArgument( + "Input(Out) and its gradients should have the same shape.")); + dx->share_meta(dout); +} + } // namespace phi diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 35f988bbc0..5afa678dda 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -34,4 +34,8 @@ void GeneralBinaryGradInferMeta(const MetaTensor& x, MetaTensor* dx, MetaTensor* dy); +void GumbelSoftmaxGradInferMeta(const MetaTensor& out, + const MetaTensor& dout, + int axis, + MetaTensor* dx); } // namespace phi diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 9b2f310e85..1a9dbf90dd 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -27,6 +27,30 @@ void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out) 
{ out->share_meta(x); } +// meta x -> out without change, check if axis in range [-Rank(x), Rank(x)-1] +void UnchangedInferMetaCheckAxis(const MetaTensor& x, + int axis, + MetaTensor* out) { + auto rank = x.dims().size(); + PADDLE_ENFORCE_GE( + axis, + -rank, + errors::InvalidArgument( + "Attr(axis) value should be in range [-R, R-1], " + "R is the rank of Input(X). But received axis: %d, R: %d.", + axis, + rank)); + PADDLE_ENFORCE_LT( + axis, + rank, + phi::errors::InvalidArgument( + "Attr(axis) value should be in range [-R, R-1], " + "R is the rank of Input(X). But received axis: %d, R: %d.", + axis, + rank)); + out->share_meta(x); +} + void FlattenInferMeta(const MetaTensor& x, int start_axis, int stop_axis, @@ -75,6 +99,14 @@ void FlattenInferMeta(const MetaTensor& x, } } +void GumbelSoftmaxInferMeta(const MetaTensor& x, + float temperature, + bool hard, + int axis, + MetaTensor* out) { + UnchangedInferMetaCheckAxis(x, axis, out); +} + void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out) { out->set_dims(x.dims()); out->set_dtype(out_dtype); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 40bf4e3335..172ea2a565 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -34,11 +34,22 @@ class MetaConfig; void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out); +// meta x -> out without change, check if axis in range [-Rank(x), Rank(x)-1] +void UnchangedInferMetaCheckAxis(const MetaTensor& x, + int axis, + MetaTensor* out); + void FlattenInferMeta(const MetaTensor& x, int start_axis, int stop_axis, MetaTensor* out); +void GumbelSoftmaxInferMeta(const MetaTensor& x, + float temperature, + bool hard, + int axis, + MetaTensor* out); + void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out); void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out); diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 
1523401d19..ef51d6daf6 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -10,7 +10,7 @@ add_subdirectory(funcs) set_property(GLOBAL PROPERTY PHI_KERNELS "") set(COMMON_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils) -set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col concat_and_split_functor) +set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col concat_and_split_functor softmax) # remove this dep after removing fluid deps on tensor creation set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} phi_api_utils) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta) diff --git a/paddle/phi/kernels/cpu/gumbel_softmax_grad_kernel.cc b/paddle/phi/kernels/cpu/gumbel_softmax_grad_kernel.cc new file mode 100644 index 0000000000..a4c131e72b --- /dev/null +++ b/paddle/phi/kernels/cpu/gumbel_softmax_grad_kernel.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/gumbel_softmax_grad_kernel.h" +#include "paddle/phi/kernels/impl/gumbel_softmax_grad_kernel_impl.h" + +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(gumbel_softmax_grad, + CPU, + ALL_LAYOUT, + phi::GumbelSoftmaxGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/cpu/gumbel_softmax_kernel.cc b/paddle/phi/kernels/cpu/gumbel_softmax_kernel.cc new file mode 100644 index 0000000000..eb406665c5 --- /dev/null +++ b/paddle/phi/kernels/cpu/gumbel_softmax_kernel.cc @@ -0,0 +1,121 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/gumbel_softmax_kernel.h" +#include "paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h" + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/axis_utils.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +struct GumbleNoiseGenerator { + static void Transform(const CPUContext& ctx, + const T* input_data, + T* output_data, + int size_to_axis, + int size_from_axis, + const float temperature) { + // generate uniform random number + const int size = size_to_axis * size_from_axis; + std::uniform_real_distribution dist(0.00001, 1); + auto engine = ctx.GetGenerator()->GetCPUEngine(); + DenseTensor random_tensor; + random_tensor.Resize(make_ddim({size})); + auto* random_data = ctx.template Alloc(&random_tensor); + for (int64_t i = 0; i < size; ++i) { + random_data[i] = dist(*engine); + } + + // generate gumbel noise + DDim dim_2d{size_to_axis, size_from_axis}; + auto gumbel_noise_eigen = EigenMatrix::From(random_tensor, dim_2d); + gumbel_noise_eigen = -(((-(gumbel_noise_eigen.log())).log())); + + // add noise + for (int64_t i = 0; i < size_to_axis * size_from_axis; i++) { + output_data[i] = (input_data[i] + random_data[i]) / temperature; + } + } +}; + +template +struct OneHotGenerator { + static void Transform(const CPUContext& ctx, + const DenseTensor& x, + DenseTensor* out, + int axis) { + DenseTensor index; + std::vector index_dim; + const auto rank = x.dims().size(); + const int size_to_axis = funcs::SizeToAxis(axis, x.dims()); + const int size_from_axis = funcs::SizeFromAxis(axis, x.dims()); + const int size_out_axis = funcs::SizeOutAxis(axis, x.dims()); + + for (int i = 0; i < x.dims().size(); i++) { + if (i != axis) index_dim.push_back(x.dims().Get()[i]); + } + DDim index_ddim(index_dim.data(), rank - 1); + index.Resize(index_ddim); + auto* index_data = ctx.template Alloc(&index); + +#define CALL_ARG_MINMAX_FUNCTOR(rank) \ + ArgMaxFunctor functor##rank; \ + 
functor##rank(ctx, *out, &index, axis); + switch (out->dims().size()) { + case 1: + CALL_ARG_MINMAX_FUNCTOR(1); + break; + case 2: + CALL_ARG_MINMAX_FUNCTOR(2); + break; + case 3: + CALL_ARG_MINMAX_FUNCTOR(3); + break; + case 4: + CALL_ARG_MINMAX_FUNCTOR(4); + break; + case 5: + CALL_ARG_MINMAX_FUNCTOR(5); + break; + case 6: + CALL_ARG_MINMAX_FUNCTOR(6); + break; + default: + PADDLE_ENFORCE_LE( + out->dims().size(), + 6, + errors::InvalidArgument("gumbel_softmax operator doesn't supports " + "tensors whose ranks are greater " + "than 6 in CPU mode.")); + break; +#undef CALL_ARG_MINMAX_FUNCTOR + } + + funcs::set_constant(ctx, out, 0.0); + for (int i = 0; i < size_to_axis; i++) { + for (int j = 0; j < size_out_axis; j++) { + *(out->data() + i * size_from_axis + j + + index_data[i * size_out_axis + j] * size_out_axis) = 1.0; + } + } + } +}; + +} // namespace phi + +PD_REGISTER_KERNEL( + gumbel_softmax, CPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/gumbel_softmax_grad_kernel.cu b/paddle/phi/kernels/gpu/gumbel_softmax_grad_kernel.cu new file mode 100644 index 0000000000..a28a7512f4 --- /dev/null +++ b/paddle/phi/kernels/gpu/gumbel_softmax_grad_kernel.cu @@ -0,0 +1,25 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/gumbel_softmax_grad_kernel.h" +#include "paddle/phi/kernels/impl/gumbel_softmax_grad_kernel_impl.h" + +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(gumbel_softmax_grad, + GPU, + ALL_LAYOUT, + phi::GumbelSoftmaxGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu b/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu new file mode 100644 index 0000000000..6b1e58981b --- /dev/null +++ b/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu @@ -0,0 +1,181 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "paddle/phi/kernels/gumbel_softmax_kernel.h"
+#include "paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h"
+
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/axis_utils.h"
+
+// NOTE(review): in this copy of the patch every `<...>` span is missing:
+// system #include targets, template parameter/argument lists and the
+// `<<<...>>>` kernel launch configurations. Those places are left exactly as
+// found below; restore them from the original commit before applying.
+#if defined(__NVCC__) || defined(__HIPCC__)
+#ifdef __NVCC__
+#include "cub/cub.cuh"
+#endif
+#ifdef __HIPCC__
+#include
+namespace cub = hipcub;
+#endif
+
+#include
+#include
+#include
+#include
+#include "paddle/fluid/framework/generator.h"
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+
+namespace phi {
+// Local shorthand for cub::KeyValuePair, the element type of the arg-max
+// reduction in OneHotCUDAKernel.
+template
+using KeyValuePair = cub::KeyValuePair;
+
+// Functor that returns one uniform random draw per index.
+//
+// Every call builds a fresh thrust::minstd_rand engine, seeds it with
+// `seed_`, discards `n + offset_` values and returns the next draw of a
+// thrust::uniform_real_distribution constructed from (min_, max_). The
+// result therefore depends only on (seed_, offset_, n).
+template
+struct UniformCUDAGenerator {
+  T min_, max_;
+  unsigned int seed_;
+  unsigned int offset_ = 0;
+  HOSTDEVICE UniformCUDAGenerator(T min, T max, unsigned int seed)
+      : min_(min), max_(max), seed_(seed) {}
+  HOSTDEVICE UniformCUDAGenerator(T min,
+                                  T max,
+                                  unsigned int seed,
+                                  unsigned int offset)
+      : min_(min), max_(max), seed_(seed), offset_(offset) {}
+
+  HOSTDEVICE T operator()(const unsigned int n) const {
+    thrust::minstd_rand rng;
+    rng.seed(seed_);
+    thrust::uniform_real_distribution dist(min_, max_);
+    rng.discard(n + offset_);
+    return dist(rng);
+  }
+};
+
+// Writes a 1 at the arg-max position of each slice.
+//
+// Blocks stride over `height` slices (idx = blockIdx.x, step gridDim.x).
+// For slice idx, h = idx / size_out_axis and w = idx % size_out_axis, and
+// the `width` candidates are in[h * width * size_out_axis +
+// k * size_out_axis + w] for k in [0, width). Threads fold their candidates
+// with cub::ArgMax, the block combines them with cub::BlockReduce, and
+// thread 0 stores 1 at the winning position of `out`. No other element of
+// `out` is written here.
+//   init - starting value of the running maximum (paired with key -1).
+template
+__global__ void OneHotCUDAKernel(const int64_t height,
+                                 const int64_t width,
+                                 const int64_t size_out_axis,
+                                 const T init,
+                                 const T* in,
+                                 T* out) {
+  typedef cub::BlockReduce, BlockDim> BlockReduce;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+
+  for (int64_t idx = blockIdx.x; idx < height; idx += gridDim.x) {
+    KeyValuePair kv_pair = {-1, init};
+    int h = idx / size_out_axis;
+    int w = idx % size_out_axis;
+    cub::ArgMax reducer;
+    for (int k = threadIdx.x; k < width; k += blockDim.x) {
+      kv_pair = reducer(
+          {k, in[h * width * size_out_axis + k * size_out_axis + w]}, kv_pair);
+    }
+    kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, reducer);
+    if (threadIdx.x == 0) {
+      int index = static_cast(kv_pair.key);
+      out[h * width * size_out_axis + index * size_out_axis + w] = 1;
+    }
+    // temp_storage is shared and used again by the next loop iteration.
+    __syncthreads();
+  }
+}
+
+// GPU one-hot step: overwrites `out` with zeros everywhere except a 1 at
+// the maximum of each slice along `axis`.
+//
+// `X` is consulted for its dims only; the values that are compared are the
+// current contents of `out`, which are first copied into a scratch tensor.
+template
+struct OneHotGenerator {
+  static void Transform(const GPUContext& ctx,
+                        const DenseTensor& X,
+                        DenseTensor* out,
+                        int axis) {
+    const int size_to_axis = funcs::SizeToAxis(axis, X.dims());
+    const int size_from_axis = funcs::SizeFromAxis(axis, X.dims());
+    const int size_out_axis = funcs::SizeOutAxis(axis, X.dims());
+    constexpr int thread_size = 512;
+    int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0];
+    // Widen before multiplying: both factors are int, so the product would
+    // otherwise be computed (and could overflow) in 32 bits.
+    int64_t height = static_cast<int64_t>(size_to_axis) * size_out_axis;
+    // One block per slice, capped at the device's grid-x limit; the kernel
+    // strides over any remaining slices.
+    int block_size = height < max_grid_dimx ? height : max_grid_dimx;
+
+    DenseTensor input_tensor;
+    input_tensor.Resize(out->dims());
+    ctx.template Alloc(&input_tensor);
+    // Snapshot `out`, then clear it so the kernel only has to write the 1s.
+    paddle::framework::TensorCopy(*out, ctx.GetPlace(), &input_tensor);
+    funcs::set_constant(ctx, out, 0.0);
+    OneHotCUDAKernel<<>>(
+        height,
+        size_from_axis / size_out_axis,
+        size_out_axis,
+        std::numeric_limits::lowest(),
+        input_tensor.data(),
+        out->data());
+  }
+};
+
+// For every i in [0, n):
+//   output_data[i] = (-log(-log(noise[i])) + input_data[i]) / temperature
+// Threads start at their global index and advance by the total number of
+// launched threads, so any grid size covers all n elements.
+template
+__global__ void AddGumbelNoiseCUDAKernel(const T* input_data,
+                                         T* output_data,
+                                         T* noise,
+                                         const float temperature,
+                                         int64_t n) {
+  // 64-bit index and stride: n is int64_t, and a 32-bit thread index would
+  // wrap once the launch covers more than INT_MAX elements.
+  const int64_t index =
+      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+  const int64_t step = static_cast<int64_t>(blockDim.x) * gridDim.x;
+  for (int64_t i = index; i < n; i += step) {
+    T gumbel_noise = -log(-log(noise[i]));
+    output_data[i] = (gumbel_noise + input_data[i]) / temperature;
+  }
+}
+
+// GPU noise step: fills a scratch tensor with uniform draws and writes
+// (gumbel_noise + input) / temperature into `output_data`.
+//   size_to_axis, size_from_axis - their product is the number of elements
+//                                  processed.
+template
+struct GumbleNoiseGenerator {
+  static void Transform(const GPUContext& ctx,
+                        const T* input_data,
+                        T* output_data,
+                        int size_to_axis,
+                        int size_from_axis,
+                        const float temperature) {
+    DenseTensor random_tensor;
+    // Widen before multiplying: both factors are int, so the product would
+    // otherwise be computed (and could overflow) in 32 bits.
+    int64_t size = static_cast<int64_t>(size_to_axis) * size_from_axis;
+    random_tensor.Resize(make_ddim({size}));
+    auto* random_data = ctx.template Alloc(&random_tensor);
+    thrust::counting_iterator index_sequence_begin(0);
+
+    // generate gumbel noise
+    int device_id = ctx.GetPlace().GetDeviceId();
+    auto gen_cuda = paddle::framework::GetDefaultCUDAGenerator(device_id);
+    if (gen_cuda->GetIsInitPy()) {
+      // Seed and offset come from the default CUDA generator.
+      auto seed_offset = gen_cuda->IncrementOffset(1);
+      // NOTE(review): gen_offset is int64_t, but UniformCUDAGenerator takes
+      // `unsigned int offset`, so the value is narrowed at the call below.
+      int64_t gen_offset = size * seed_offset.second;
+      thrust::transform(
+          index_sequence_begin,
+          index_sequence_begin + size,
+          thrust::device_ptr(random_data),
+          UniformCUDAGenerator(0.00001, 1, seed_offset.first, gen_offset));
+    } else {
+      // Otherwise draw a fresh seed from std::random_device.
+      const unsigned int seed = std::random_device()();
+      thrust::transform(index_sequence_begin,
+                        index_sequence_begin + size,
+                        thrust::device_ptr(random_data),
+                        UniformCUDAGenerator(0.00001, 1, seed));
+    }
+
+    // add gumbel noise to X
+    const int thread_size = 512;
+    // Always >= 1 and never fewer than ceil(size / thread_size) blocks; the
+    // kernel's `i < n` bound turns any spare threads into no-ops.
+    int64_t block_size = (size + thread_size) / thread_size;
+    AddGumbelNoiseCUDAKernel<<>>(
+        input_data, output_data, random_data, temperature, size);
+  }
+};
+
+}  // namespace phi
+#endif
+
+// Registers phi::GumbelSoftmaxKernel as the GPU `gumbel_softmax` kernel for
+// float and double.
+PD_REGISTER_KERNEL(
+    gumbel_softmax, GPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {}
diff --git a/paddle/phi/kernels/gumbel_softmax_grad_kernel.h b/paddle/phi/kernels/gumbel_softmax_grad_kernel.h
new file mode 100644
index 0000000000..e3f02d90fc
--- /dev/null
+++ b/paddle/phi/kernels/gumbel_softmax_grad_kernel.h
@@ -0,0 +1,27 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+namespace phi {
+
+// Backward kernel of gumbel_softmax (declaration; the definition lives in
+// impl/gumbel_softmax_grad_kernel_impl.h).
+//   dev_ctx - device context.
+//   out     - output of the forward kernel.
+//   dout    - gradient with respect to `out`.
+//   axis    - axis along which the softmax was taken.
+//   dx      - receives the gradient with respect to the forward input.
+template
+void GumbelSoftmaxGradKernel(const Context& dev_ctx,
+                             const DenseTensor& out,
+                             const DenseTensor& dout,
+                             int axis,
+                             DenseTensor* dx);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/gumbel_softmax_kernel.h b/paddle/phi/kernels/gumbel_softmax_kernel.h
new file mode 100644
index 0000000000..46edb9750d
--- /dev/null
+++ b/paddle/phi/kernels/gumbel_softmax_kernel.h
@@ -0,0 +1,28 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+namespace phi {
+
+// Forward kernel of gumbel_softmax (declaration; the definition lives in
+// impl/gumbel_softmax_kernel_impl.h).
+//   dev_ctx     - device context.
+//   x           - input tensor.
+//   temperature - divisor applied after the noise is added; the definition
+//                 enforces temperature > 0.
+//   hard        - when true the definition additionally runs the one-hot
+//                 step on `out`.
+//   axis        - axis along which the softmax is taken.
+//   out         - result tensor.
+template
+void GumbelSoftmaxKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         float temperature,
+                         bool hard,
+                         int axis,
+                         DenseTensor* out);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/impl/gumbel_softmax_grad_kernel_impl.h b/paddle/phi/kernels/impl/gumbel_softmax_grad_kernel_impl.h
new file mode 100644
index 0000000000..3d57dd1002
--- /dev/null
+++ b/paddle/phi/kernels/impl/gumbel_softmax_grad_kernel_impl.h
@@ -0,0 +1,50 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/operators/math/softmax.h"
+#include "paddle/fluid/operators/math/softmax_impl.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/funcs/axis_utils.h"
+
+namespace phi {
+
+// Backward pass of gumbel_softmax: hands `out` and `dout` to
+// SoftmaxGradFunctor and stores the result in `dx`.
+//   ctx  - device context; allocates `dx` and is passed to the functor.
+//   out  - output of the forward kernel.
+//   dout - gradient with respect to `out`.
+//   axis - softmax axis; passed through funcs::CanonicalAxis with dx's rank.
+//   dx   - result. Its dims are read before anything else, so they must
+//          already be set; nothing is computed when it has no elements.
+template
+void GumbelSoftmaxGradKernel(const Context& ctx,
+                             const DenseTensor& out,
+                             const DenseTensor& dout,
+                             int axis,
+                             DenseTensor* dx) {
+  const int rank = dx->dims().size();
+  axis = funcs::CanonicalAxis(axis, rank);
+  // Extent of the softmax axis, forwarded to the functor below.
+  int axis_dim = dx->dims()[axis];
+  // allocate memory on device.
+
+  ctx.template Alloc(dx);
+  if (dx->numel() == 0) {
+    return;
+  }
+
+  const int size_to_axis = funcs::SizeToAxis(axis, dx->dims());
+  const int size_from_axis = funcs::SizeFromAxis(axis, dx->dims());
+  // Local DenseTensor copies resized to {size_to_axis, size_from_axis}.
+  // NOTE(review): presumably these copies share storage with the originals,
+  // since dx_2d is the only place the result is written -- confirm.
+  DenseTensor dx_2d(*dx), out_2d(out), dout_2d(dout);
+  dx_2d.Resize({size_to_axis, size_from_axis});
+  out_2d.Resize({size_to_axis, size_from_axis});
+  dout_2d.Resize({size_to_axis, size_from_axis});
+  paddle::operators::math::SoftmaxGradFunctor()(
+      ctx, axis_dim, &out_2d, &dout_2d, &dx_2d);
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h b/paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h
new file mode 100644
index 0000000000..2517d84898
--- /dev/null
+++ b/paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h
@@ -0,0 +1,96 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include
+#include "paddle/fluid/operators/math/softmax.h"
+#include "paddle/fluid/operators/math/softmax_impl.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/funcs/axis_utils.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+
+namespace phi {
+
+// Stores in `index_tensor` the arg-max of `in` along `axis`, computed with
+// Eigen's argmax and a cast. `ctx` is not used, and nothing in this header
+// calls the functor.
+template
+struct ArgMaxFunctor {
+  void operator()(const Context& ctx,
+                  const DenseTensor& in,
+                  DenseTensor* index_tensor,
+                  const int64_t& axis) {
+    auto in_eigen = EigenTensor::From(in, in.dims());
+    auto index_eigen = EigenTensor::From(*index_tensor);
+    index_eigen = in_eigen.argmax(axis).template cast();
+  }
+};
+
+// Noise step used by GumbelSoftmaxKernel; only declared here. The GPU
+// definition in this patch takes a GPUContext.
+template
+struct GumbleNoiseGenerator;
+
+// One-hot step used by GumbelSoftmaxKernel when `hard` is set; only declared
+// here. The GPU definition in this patch takes a GPUContext.
+template
+struct OneHotGenerator;
+
+// Forward pass of gumbel_softmax.
+//
+// Steps: GumbleNoiseGenerator::Transform fills a scratch tensor from `x`,
+// SoftmaxFunctor maps that scratch tensor to `out`, and, if `hard` is true,
+// OneHotGenerator::Transform post-processes `out`.
+//   ctx         - device context; allocates `out` and the scratch tensor.
+//   x           - input tensor; its dims drive every size computed here.
+//   temperature - must be > 0, otherwise InvalidArgument is raised.
+//   hard        - enables the one-hot step.
+//   axis        - softmax axis; passed through funcs::CanonicalAxis with
+//                 x's rank.
+//   out         - result tensor; nothing is computed when it has no
+//                 elements.
+template
+void GumbelSoftmaxKernel(const Context& ctx,
+                         const DenseTensor& x,
+                         float temperature,
+                         bool hard,
+                         int axis,
+                         DenseTensor* out) {
+  const int rank = x.dims().size();
+  axis = funcs::CanonicalAxis(axis, rank);
+  // Extent of the softmax axis, forwarded to SoftmaxFunctor below.
+  int axis_dim = x.dims()[axis];
+
+  PADDLE_ENFORCE_GT(temperature,
+                    0,
+                    phi::errors::InvalidArgument(
+                        "The temperature must be greater than 0. But "
+                        "received temperature = %f",
+                        temperature));
+
+  // allocate memory on device.
+  ctx.template Alloc(out);
+  if (out->numel() == 0) {
+    return;
+  }
+
+  const int size_to_axis = funcs::SizeToAxis(axis, x.dims());
+  const int size_from_axis = funcs::SizeFromAxis(axis, x.dims());
+  // x_noise_2d is a fresh scratch tensor; out_2d is a local DenseTensor
+  // copy of *out. Both are resized to {size_to_axis, size_from_axis}.
+  DenseTensor x_noise_2d, out_2d(*out);
+  x_noise_2d.Resize({size_to_axis, size_from_axis});
+  out_2d.Resize({size_to_axis, size_from_axis});
+
+  // generate gumbel noise and add it to X
+  auto* x_noise_data = ctx.template Alloc(&x_noise_2d);
+  GumbleNoiseGenerator::Transform(ctx,
+                                  x.data(),
+                                  x_noise_data,
+                                  size_to_axis,
+                                  size_from_axis,
+                                  temperature);
+
+  // NOTE(review): the two branches below read identically in this text;
+  // presumably they differed in template arguments that are not present
+  // here -- confirm against the original commit.
+#ifdef PADDLE_ON_INFERENCE
+  paddle::operators::math::SoftmaxFunctor()(
+      ctx, axis_dim, &x_noise_2d, &out_2d);
+#else
+  paddle::operators::math::SoftmaxFunctor()(
+      ctx, axis_dim, &x_noise_2d, &out_2d);
+#endif
+
+  if (hard) {
+    // `x` is passed along with `out`; the GPU definition in this patch
+    // reads only x's dims and takes the values from `out`.
+    OneHotGenerator::Transform(ctx, x, out, axis);
+  }
+}
+
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/gumbel_softmax_sig.cc b/paddle/phi/ops/compat/gumbel_softmax_sig.cc
new file mode 100644
index 0000000000..c7585a4e5f
--- /dev/null
+++ b/paddle/phi/ops/compat/gumbel_softmax_sig.cc
@@ -0,0 +1,30 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+// Builds the KernelSignature for the "gumbel_softmax_grad" kernel.
+//
+// The three name lists line up with the parameters of
+// GumbelSoftmaxGradKernel in this patch: {"Out", GradVarName("Out")} with
+// (out, dout), {"axis"} with (axis), and {GradVarName("X")} with (dx).
+// `ctx` is not consulted; the same signature is returned for every call.
+KernelSignature GumbelSoftmaxGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("gumbel_softmax_grad",
+                         {"Out", GradVarName("Out")},
+                         {"axis"},
+                         {GradVarName("X")});
+}
+
+}  // namespace phi
+
+// Registers the mapping function above under the name gumbel_softmax_grad.
+PD_REGISTER_ARG_MAPPING_FN(gumbel_softmax_grad,
+                           phi::GumbelSoftmaxGradOpArgumentMapping);
-- GitLab