From 18eda6c371675f6d85b31eef77608fac34904588 Mon Sep 17 00:00:00 2001
From: YuanRisheng
Date: Wed, 15 Sep 2021 13:09:02 +0800
Subject: [PATCH] Add New OP: gumbel_softmax (#35506)

* Add New Op: gumbel_softmax

* Add New Op: gumbel_softmax

* Add New Op: gumbel_softmax (amend)

* add __main__ function in unit test

* fix bugs when test in windows ci

* update en docs

* delete reletive error in unit test

* delete relative error in unit test

* set hard=True in unit test
---
 paddle/fluid/operators/gumbel_softmax_op.cc   | 123 +++++++++
 paddle/fluid/operators/gumbel_softmax_op.cu   | 173 ++++++++++++
 paddle/fluid/operators/gumbel_softmax_op.h    | 250 ++++++++++++++++++
 .../tests/unittests/test_gumbel_softmax_op.py | 229 ++++++++++++++++
 python/paddle/nn/functional/__init__.py       |   2 +
 python/paddle/nn/functional/activation.py     |  75 ++++++
 6 files changed, 852 insertions(+)
 create mode 100644 paddle/fluid/operators/gumbel_softmax_op.cc
 create mode 100644 paddle/fluid/operators/gumbel_softmax_op.cu
 create mode 100644 paddle/fluid/operators/gumbel_softmax_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_gumbel_softmax_op.py

diff --git a/paddle/fluid/operators/gumbel_softmax_op.cc b/paddle/fluid/operators/gumbel_softmax_op.cc
new file mode 100644
index 0000000000..95c6ed6690
--- /dev/null
+++ b/paddle/fluid/operators/gumbel_softmax_op.cc
@@ -0,0 +1,123 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/gumbel_softmax_op.h"
+#include <memory>
+#include <string>
+#include "paddle/fluid/operators/common_infer_shape_functions.h"
+
+namespace paddle {
+namespace operators {
+class GumbelSoftmaxOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    return UnaryOpUnchangedInferShapeCheckAxis(ctx);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
+  }
+};
+
+class GumbelSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "(Tensor) An N-D Tensor, N >= 1. "
+             "The first N - 1 dimensions index into a batch of independent "
+             "distributions "
+             "and the last dimension represents a vector of probabilities for "
+             "each class.");
+    AddOutput("Out", "The sampled tensor with the same shape as X.");
+    AddAttr<float>("temperature",
+                   "(float, default 1.0) non-negative scalar temperature.")
+        .SetDefault(1.0);
+    AddAttr<bool>(
+        "hard",
+        "(bool, default false) "
+        "if True, the returned samples will be discretized as one-hot vectors, "
+        "but will be differentiated as if it is the soft sample in autograd.")
+        .SetDefault(false);
+    AddAttr<int>("axis",
+                 "(int, default -1) "
+                 "The dimension index of Input(x) to perform gumbel_softmax.")
+        .SetDefault(-1);
+    AddComment(R"DOC(
+GumbelSoftmax Operator.
+
+Samples from the Gumbel-Softmax distribution and optionally discretizes.
+
+)DOC");
+  }
+};
+
+class GumbelSoftmaxGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "gumbel_softmax_grad");
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   "Out@GRAD", "gumbel_softmax_grad");
+    PADDLE_ENFORCE_EQ(
+        ctx->GetInputDim("Out"),
+        ctx->GetInputDim(framework::GradVarName("Out")),
+        platform::errors::InvalidArgument("Input(Out) and its gradients "
+                                          "should have the same shape."));
+
+    ctx->SetOutputDim(framework::GradVarName("X"),
+                      ctx->GetInputDim(framework::GradVarName("Out")));
+  }
+};
+
+template <typename T>
+class GumbelSoftmaxGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> op) const override {
+    op->SetType("gumbel_softmax_grad");
+    op->SetInput("Out", this->Output("Out"));
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetAttrMap(this->Attrs());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(gumbel_softmax, ops::GumbelSoftmaxOp,
+                  ops::GumbelSoftmaxOpMaker,
+                  ops::GumbelSoftmaxGradOpMaker<paddle::framework::OpDesc>,
+                  ops::GumbelSoftmaxGradOpMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(gumbel_softmax_grad, ops::GumbelSoftmaxGradOp);
+
+REGISTER_OP_CPU_KERNEL(
+    gumbel_softmax,
+    ops::GumbelSoftmaxKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::GumbelSoftmaxKernel<paddle::platform::CPUDeviceContext, double>);
+REGISTER_OP_CPU_KERNEL(
+    gumbel_softmax_grad,
+    ops::GumbelSoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::GumbelSoftmaxGradKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/gumbel_softmax_op.cu b/paddle/fluid/operators/gumbel_softmax_op.cu
new file mode 100644
index 0000000000..bf0ac66741
--- /dev/null
+++ b/paddle/fluid/operators/gumbel_softmax_op.cu
@@ -0,0 +1,173 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#pragma once
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/operators/gumbel_softmax_op.h"
+
+#if defined(__NVCC__) || defined(__HIPCC__)
+#ifdef __NVCC__
+#include "cub/cub.cuh"
+#endif
+#ifdef __HIPCC__
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
+
+#include <thrust/device_ptr.h>
+#include <thrust/iterator/counting_iterator.h>
+#include <thrust/random.h>
+#include <thrust/transform.h>
+#include "paddle/fluid/framework/generator.h"
+#include "paddle/fluid/memory/memcpy.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+using KeyValuePair = cub::KeyValuePair<int, T>;
+
+template <typename T>
+struct UniformCUDAGenerator {
+  T min_, max_;
+  unsigned int seed_;
+  unsigned int offset_ = 0;
+  HOSTDEVICE UniformCUDAGenerator(T min, T max, unsigned int seed)
+      : min_(min), max_(max), seed_(seed) {}
+  HOSTDEVICE UniformCUDAGenerator(T min, T max, unsigned int seed,
+                                  unsigned int offset)
+      : min_(min), max_(max), seed_(seed), offset_(offset) {}
+
+  HOSTDEVICE T operator()(const unsigned int n) const {
+    thrust::minstd_rand rng;
+    rng.seed(seed_);
+    thrust::uniform_real_distribution<T> dist(min_, max_);
+    rng.discard(n + offset_);
+    return dist(rng);
+  }
+};
+
+template <typename T, size_t BlockDim>
+__global__ void OneHotCUDAKernel(const int64_t height, const int64_t width,
+                                 const int64_t size_out_axis, const T init,
+                                 const T* in, T* out) {
+  typedef cub::BlockReduce<KeyValuePair<T>, BlockDim> BlockReduce;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+
+  for (int64_t idx = blockIdx.x; idx < height; idx += gridDim.x) {
+    KeyValuePair<T> kv_pair = {-1, init};
+    int h = idx / size_out_axis;
+    int w = idx % size_out_axis;
+    cub::ArgMax reducer;
+    for (int k = threadIdx.x; k < width; k += blockDim.x) {
+      kv_pair = reducer(
+          {k, in[h * width * size_out_axis + k * size_out_axis + w]}, kv_pair);
+    }
+    kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, reducer);
+    if (threadIdx.x == 0) {
+      int index = static_cast<int>(kv_pair.key);
+      out[h * width * size_out_axis + index * size_out_axis + w] = 1;
+    }
+    __syncthreads();
+  }
+}
+
+template <typename T>
+struct OneHotGenerator<platform::CUDADeviceContext, T> {
+  static void Transform(const platform::CUDADeviceContext& context,
+                        const Tensor& X, Tensor* Out, int axis) {
+    const int size_to_axis = SizeToAxis(axis, X.dims());
+    const int size_from_axis = SizeFromAxis(axis, X.dims());
+    const int size_out_axis = SizeOutAxis(axis, X.dims());
+    constexpr int thread_size = 512;
+    int64_t max_grid_dimx = context.GetCUDAMaxGridDimSize().x;
+    int64_t height = size_to_axis * size_out_axis;
+    int block_size = height < max_grid_dimx ? height : max_grid_dimx;
+
+    Tensor input_tensor;
+    input_tensor.mutable_data<T>(Out->dims(), platform::CUDAPlace());
+    TensorCopy(*Out, context.GetPlace(), &input_tensor);
+    math::set_constant(context, Out, 0.0);
+    OneHotCUDAKernel<
+        T, thread_size><<<block_size, thread_size, 0, context.stream()>>>(
+        height, size_from_axis / size_out_axis, size_out_axis,
+        std::numeric_limits<T>::lowest(), input_tensor.data<T>(),
+        Out->data<T>());
+  }
+};
+
+template <typename T>
+__global__ void AddGumbelNoiseCUDAKernel(const T* input_data, T* output_data,
+                                         T* noise, const float temperature,
+                                         int64_t n) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  int step = blockDim.x * gridDim.x;
+  for (int64_t i = index; i < n; i += step) {
+    T gumbel_noise = -log(-log(noise[i]));
+    output_data[i] = (gumbel_noise + input_data[i]) / temperature;
+  }
+}
+
+template <typename T>
+struct GumbleNoiseGenerator<platform::CUDADeviceContext, T> {
+  static void Transform(const platform::CUDADeviceContext& context,
+                        const T* input_data, T* output_data, int size_to_axis,
+                        int size_from_axis, const float temperature) {
+    Tensor random_tensor;
+    int64_t size = size_to_axis * size_from_axis;
+    T* random_data =
+        random_tensor.mutable_data<T>({size}, platform::CUDAPlace());
+    thrust::counting_iterator<unsigned int> index_sequence_begin(0);
+    const unsigned int seed = std::random_device()();
+
+    // generate gumbel noise
+    int device_id =
+        BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()).GetDeviceId();
+    auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
+    if (gen_cuda->GetIsInitPy()) {
+      auto seed_offset = gen_cuda->IncrementOffset(1);
+      int gen_offset = size * seed_offset.second;
+      thrust::transform(
+          index_sequence_begin, index_sequence_begin + size,
+          thrust::device_ptr<T>(random_data),
+          UniformCUDAGenerator<T>(0.00001, 1, seed_offset.first, gen_offset));
+    } else {
+      thrust::transform(index_sequence_begin, index_sequence_begin + size,
+                        thrust::device_ptr<T>(random_data),
+                        UniformCUDAGenerator<T>(0.00001, 1, seed));
+    }
+
+    // add gumbel noise to X
+    const int thread_size = 512;
+    int64_t block_size = (size + thread_size) / thread_size;
+    AddGumbelNoiseCUDAKernel<
+        T><<<block_size, thread_size, 0, context.stream()>>>(
+        input_data, output_data, random_data, temperature, size);
+  }
+};
+
+#endif
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+REGISTER_OP_CUDA_KERNEL(
+    gumbel_softmax,
+    ops::GumbelSoftmaxKernel<plat::CUDADeviceContext, float>,
+    ops::GumbelSoftmaxKernel<plat::CUDADeviceContext, double>);
+REGISTER_OP_CUDA_KERNEL(
+    gumbel_softmax_grad,
+    ops::GumbelSoftmaxGradKernel<plat::CUDADeviceContext, float>,
+    ops::GumbelSoftmaxGradKernel<plat::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/gumbel_softmax_op.h b/paddle/fluid/operators/gumbel_softmax_op.h
new file mode 100644
index 0000000000..c224cc7ca1
--- /dev/null
+++ b/paddle/fluid/operators/gumbel_softmax_op.h
@@ -0,0 +1,250 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/fluid/framework/generator.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/operators/math/softmax.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using DDim = framework::DDim;
+
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
+template <typename T, size_t D, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
+
+static inline int CanonicalAxis(const int axis, const int rank) {
+  if (axis < 0) {
+    return axis + rank;
+  }
+  return axis;
+}
+
+static inline int SizeToAxis(const int axis, DDim dims) {
+  int size = 1;
+  for (int i = 0; i < axis; i++) {
+    size *= dims[i];
+  }
+  return size;
+}
+
+static inline int SizeFromAxis(const int axis, DDim dims) {
+  int size = 1;
+  for (int i = axis; i < dims.size(); i++) {
+    size *= dims[i];
+  }
+  return size;
+}
+
+static inline int SizeOutAxis(const int axis, DDim dims) {
+  int size = 1;
+  for (int i = axis + 1; i < dims.size(); i++) {
+    size *= dims[i];
+  }
+  return size;
+}
+
+template <typename DeviceContext, typename T, int64_t Rank>
+struct ArgMaxFunctor {
+  void operator()(const DeviceContext& ctx, const Tensor& in,
+                  Tensor* index_tensor, const int64_t& axis) {
+    auto in_eigen = EigenTensor<T, Rank>::From(in, in.dims());
+    auto index_eigen = EigenTensor<int, Rank - 1>::From(*index_tensor);
+    index_eigen = in_eigen.argmax(axis).template cast<int>();
+  }
+};
+template <typename DeviceContext, typename T>
+struct GumbleNoiseGenerator;
+
+template <typename DeviceContext, typename T>
+struct OneHotGenerator;
+
+template <typename T>
+struct GumbleNoiseGenerator<platform::CPUDeviceContext, T> {
+  static void Transform(const platform::CPUDeviceContext& context,
+                        const T* input_data, T* output_data, int size_to_axis,
+                        int size_from_axis, const float temperature) {
+    // generate uniform random number
+    const int size = size_to_axis * size_from_axis;
+    std::uniform_real_distribution<T> dist(0.00001, 1);
+    const int seed = std::random_device()();
+    auto engine = paddle::framework::GetCPURandomEngine(seed);
+    Tensor random_tensor;
+    auto* random_data =
+        random_tensor.mutable_data<T>({size}, platform::CPUPlace());
+    for (int64_t i = 0; i < size; ++i) {
+      random_data[i] = dist(*engine);
+    }
+
+    // generate gumbel noise
+    framework::DDim dim_2d{size_to_axis, size_from_axis};
+    auto gumbel_noise_eigen = EigenMatrix<T>::From(random_tensor, dim_2d);
+    gumbel_noise_eigen = -(((-(gumbel_noise_eigen.log())).log()));
+
+    // add noise
+    for (int64_t i = 0; i < size_to_axis * size_from_axis; i++) {
+      output_data[i] = (input_data[i] + random_data[i]) / temperature;
+    }
+  }
+};
+template <typename T>
+struct OneHotGenerator<platform::CPUDeviceContext, T> {
+  static void Transform(const platform::CPUDeviceContext& context,
+                        const Tensor& X, Tensor* Out, int axis) {
+    Tensor index;
+    std::vector<int> index_dim;
+    const auto rank = X.dims().size();
+    const int size_to_axis = SizeToAxis(axis, X.dims());
+    const int size_from_axis = SizeFromAxis(axis, X.dims());
+    const int size_out_axis = SizeOutAxis(axis, X.dims());
+
+    for (int i = 0; i < X.dims().size(); i++) {
+      if (i != axis) index_dim.push_back(X.dims().Get()[i]);
+    }
+    DDim index_ddim(index_dim.data(), rank - 1);
+    index.Resize(index_ddim);
+    auto* index_data = index.mutable_data<int>(context.GetPlace());
+
+#define CALL_ARG_MINMAX_FUNCTOR(rank)                                 \
+  ArgMaxFunctor<platform::CPUDeviceContext, T, rank> functor##rank;   \
+  functor##rank(context, *Out, &index, axis);
+    switch (Out->dims().size()) {
+      case 1:
+        CALL_ARG_MINMAX_FUNCTOR(1);
+        break;
+      case 2:
+        CALL_ARG_MINMAX_FUNCTOR(2);
+        break;
+      case 3:
+        CALL_ARG_MINMAX_FUNCTOR(3);
+        break;
+      case 4:
+        CALL_ARG_MINMAX_FUNCTOR(4);
+        break;
+      case 5:
+        CALL_ARG_MINMAX_FUNCTOR(5);
+        break;
+      case 6:
+        CALL_ARG_MINMAX_FUNCTOR(6);
+        break;
+      default:
+        PADDLE_ENFORCE_LE(Out->dims().size(), 6,
+                          platform::errors::InvalidArgument(
+                              "gumbel_softmax operator doesn't support "
+                              "tensors whose ranks are greater "
+                              "than 6 in CPU mode."));
+        break;
+#undef CALL_ARG_MINMAX_FUNCTOR
+    }
+
+    math::set_constant(context, Out, 0.0);
+    for (int i = 0; i < size_to_axis; i++) {
+      for (int j = 0; j < size_out_axis; j++) {
+        *(Out->data<T>() + i * size_from_axis + j +
+          index_data[i * size_out_axis + j] * size_out_axis) = 1.0;
+      }
+    }
+  }
+};
+
+template <typename DeviceContext, typename T>
+class GumbelSoftmaxKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* X = context.Input<Tensor>("X");
+    auto* Out = context.Output<Tensor>("Out");
+    const int rank = X->dims().size();
+    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
+    int axis_dim = X->dims()[axis];
+    const bool is_hard = context.Attr<bool>("hard");
+    const float temperature = context.Attr<float>("temperature");
+    PADDLE_ENFORCE_GT(temperature, 0,
+                      platform::errors::InvalidArgument(
+                          "The temperature must be greater than 0. But "
+                          "received temperature = %f",
+                          temperature));
+
+    // allocate memory on device.
+    Out->mutable_data<T>(context.GetPlace());
+    if (Out->numel() == 0) {
+      return;
+    }
+
+    const int size_to_axis = SizeToAxis(axis, X->dims());
+    const int size_from_axis = SizeFromAxis(axis, X->dims());
+    Tensor X_noise_2d, Out_2d;
+    X_noise_2d.Resize({size_to_axis, size_from_axis});
+    Out_2d.ShareDataWith(*Out).Resize({size_to_axis, size_from_axis});
+
+    // generate gumbel noise and add it to X
+    auto* x_noise_data = X_noise_2d.mutable_data<T>(context.GetPlace());
+    GumbleNoiseGenerator<DeviceContext, T>::Transform(
+        context.template device_context<DeviceContext>(), X->data<T>(),
+        x_noise_data, size_to_axis, size_from_axis, temperature);
+
+#ifdef PADDLE_ON_INFERENCE
+    math::SoftmaxFunctor<DeviceContext, T, true>()(
+        context.template device_context<DeviceContext>(), axis_dim,
+        &X_noise_2d, &Out_2d);
+#else
+    math::SoftmaxFunctor<DeviceContext, T, false>()(
+        context.template device_context<DeviceContext>(), axis_dim,
+        &X_noise_2d, &Out_2d);
+#endif
+
+    if (is_hard) {
+      OneHotGenerator<DeviceContext, T>::Transform(
+          context.template device_context<DeviceContext>(), *X, Out, axis);
+    }
+  }
+};
+
+template <typename DeviceContext, typename T>
+class GumbelSoftmaxGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* Out = context.Input<Tensor>("Out");
+    auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
+    const int rank = dX->dims().size();
+    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
+    int axis_dim = dX->dims()[axis];
+    // allocate memory on device.
+    dX->mutable_data<T>(context.GetPlace());
+    if (dX->numel() == 0) {
+      return;
+    }
+
+    const int size_to_axis = SizeToAxis(axis, dX->dims());
+    const int size_from_axis = SizeFromAxis(axis, dX->dims());
+    Tensor dX_2d, Out_2d, dOut_2d;
+    dX_2d.ShareDataWith(*dX).Resize({size_to_axis, size_from_axis});
+    Out_2d.ShareDataWith(*Out).Resize({size_to_axis, size_from_axis});
+    dOut_2d.ShareDataWith(*dOut).Resize({size_to_axis, size_from_axis});
+    math::SoftmaxGradFunctor<DeviceContext, T>()(
+        context.template device_context<DeviceContext>(), axis_dim, &Out_2d,
+        &dOut_2d, &dX_2d);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/test_gumbel_softmax_op.py b/python/paddle/fluid/tests/unittests/test_gumbel_softmax_op.py
new file mode 100644
index 0000000000..e423404d07
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_gumbel_softmax_op.py
@@ -0,0 +1,229 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+import unittest
+import numpy as np
+from op_test import OpTest
+import paddle.fluid.core as core
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid import Program, program_guard
+paddle.enable_static()
+
+
+class TestGumbelSoftmaxOp(OpTest):
+    def init_attrs(self):
+        self.shape = [20, 10]
+        self.attrs = {"hard": True, "axis": -1}
+        self.count_expected = 20
+        self.dtype = "float64"
+
+    def verify_output(self, outs):
+        out_np = np.array(outs[0])
+        out_np.shape = self.shape
+        self.assertTrue(list(out_np.shape) == self.shape)
+        self.assertEqual(out_np.sum(), self.count_expected)
+
+    def setUp(self):
+        self.op_type = "gumbel_softmax"
+        self.init_attrs()
+        np.random.seed(0)
+        x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
+        out = np.zeros(self.shape).astype(self.dtype)
+        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
+        self.outputs = {'Out': out}
+
+    def test_check_output(self):
+        self.check_output_customized(self.verify_output)
+
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
+
+
+class TestGumbelSoftmaxOp2(TestGumbelSoftmaxOp):
+    def init_attrs(self):
+        self.shape = [20, 10]
+        self.attrs = {"hard": True, "axis": 0}
+        self.count_expected = 10
+        self.dtype = "float64"
+
+
+class TestGumbelSoftmaxOp3(TestGumbelSoftmaxOp):
+    def init_attrs(self):
+        self.shape = [100]
+        self.attrs = {"hard": True, "axis": -1}
+        self.count_expected = 1
+        self.dtype = "float64"
+
+
+class TestGumbelSoftmaxOp4(TestGumbelSoftmaxOp):
+    def init_attrs(self):
+        self.shape = [20, 10, 5]
+        self.attrs = {"hard": True, "axis": -1}
+        self.count_expected = 200
+        self.dtype = "float64"
+
+
+class TestGumbelSoftmaxOp5(TestGumbelSoftmaxOp):
+    def init_attrs(self):
+        self.shape = [20, 10, 5]
+        self.attrs = {"hard": True, "axis": 1}
+        self.count_expected = 100
+        self.dtype = "float64"
+
+
+class TestGumbelSoftmaxOpSampleDistribution(OpTest):
+    def softmax(self, x):
+        x_row_max = x.max(axis=-1)
+        x_row_max = x_row_max.reshape(list(x.shape)[:-1] + [1])
+        x = x - x_row_max
+        x_exp = np.exp(x)
+        x_exp_row_sum = x_exp.sum(axis=-1).reshape(list(x.shape)[:-1] + [1])
+        softmax = x_exp / x_exp_row_sum
+        return softmax
+
+    def init_attrs(self):
+        self.shape = [100, 3]
+        self.attrs = {"hard": True, "axis": -1}
+        self.counts = np.zeros(self.shape).astype(self.dtype)
+        self._cpu_only = True
+
+    def accumulate_output(self, outs):
+        out_np = np.array(outs)
+        out_np = out_np.reshape(self.shape)
+        self.counts = np.sum(out_np, axis=0)
+
+    def setUp(self):
+        self.op_type = "gumbel_softmax"
+        self.init_attrs()
+        single_x = np.array([0.2, 0.3, 0.5])
+        batch_x = np.ones(self.shape) * single_x
+        out = np.zeros(self.shape).astype(self.dtype)
+        self.probs = self.softmax(single_x)
+        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(batch_x)}
+        self.outputs = {'Out': out}
+
+    def test_check_output(self):
+        self.check_output_customized(self.accumulate_output)
+        # Experiment should result in batch num.
+        self.assertEqual(self.counts.sum(), self.shape[0])
+
+        # Treat the probability from softmax as
+        # the probability of a binomial distribution.
+        # Samples from gumbel_softmax follow this binomial distribution.
+        # Construct the statistic z for the samples;
+        # z is approximately N(0,1) for an unbiased count.
+        expected = self.probs * self.shape[0]
+        z = (self.counts - expected) / np.sqrt((expected * (1 - self.probs)))
+        # A (lazy) approximate 99% two-sided test:
+        # failures occur with probability alpha ~ 0.01 if unbiased.
+        self.assertLess(np.max(np.abs(z)).item(), 2.58)
+
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
+
+
+class TestGumbelSoftmaxOpGrad(unittest.TestCase):
+    def init_attrs(self):
+        self.shape = [20, 10]
+        self.dtype = "float64"
+
+    def setUp(self):
+        self.init_attrs()
+        np.random.seed(0)
+        self.x_np = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
+
+    def test_dygraph_check(self):
+        paddle.disable_static()
+        x_hard = paddle.to_tensor(self.x_np, stop_gradient=False)
+        x_soft = paddle.to_tensor(self.x_np, stop_gradient=False)
+        out_hard = paddle.nn.functional.gumbel_softmax(x_hard, hard=True)
+        out_soft = paddle.nn.functional.gumbel_softmax(x_soft, hard=False)
+
+        out_hard.sum().backward()
+        out_soft.sum().backward()
+
+        self.assertEqual(
+            np.allclose(x_hard.grad.numpy(), x_soft.grad.numpy()), True)
+        paddle.enable_static()
+
+
+class TestGumbelSoftmaxAPI(unittest.TestCase):
+    def setUp(self):
+        self.x_shape = [2, 3, 4, 5]
+        self.x = np.random.uniform(-1., 1., self.x_shape).astype(np.float32)
+        self.count_expected = 24
+        self.place = paddle.CUDAPlace(0) \
+            if paddle.fluid.core.is_compiled_with_cuda() \
+            else paddle.CPUPlace()
+
+    def test_check_api(self):
+        # test static api
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.fluid.data(name='x', shape=self.x_shape)
+            y = paddle.nn.functional.gumbel_softmax(x, hard=True)
+            exe = paddle.static.Executor(self.place)
+            out = exe.run(feed={'x': self.x}, fetch_list=[y])
+            out_np = np.array(out[0])
+            self.assertEqual(out_np.sum(), self.count_expected)
+
+        # test dygraph api
+        paddle.disable_static()
+        x = paddle.to_tensor(self.x)
+        y = paddle.nn.functional.gumbel_softmax(x, hard=True)
+        out_np = np.array(y)
+        self.assertEqual(out_np.sum(), self.count_expected)
+        paddle.enable_static()
+
+
+class TestGumbelSoftmaxOpError(unittest.TestCase):
+    def test_errors(self):
+        paddle.disable_static()
+
+        def test_Variable():
+            x1 = fluid.create_lod_tensor(
+                np.zeros((100, 784)), [[10, 10, 10, 70]], fluid.CPUPlace())
+            paddle.nn.functional.gumbel_softmax(x1)
+
+        self.assertRaises(ValueError, test_Variable)
+
+        def test_Variable2():
+            x1 = np.zeros((100, 784))
+            paddle.nn.functional.gumbel_softmax(x1)
+
+        self.assertRaises(ValueError, test_Variable2)
+
+        def test_argument1():
+            x = paddle.to_tensor([0.2, 0.3, 0.4])
+            paddle.nn.functional.gumbel_softmax(x, temperature=-1)
+
+        self.assertRaises(ValueError, test_argument1)
+
+        def test_argument2():
+            x = paddle.to_tensor([0.2, 0.3, 0.4])
+            paddle.nn.functional.gumbel_softmax(x, axis=1.1)
+
+        self.assertRaises(ValueError, test_argument2)
+
+        paddle.enable_static()
+
+        def test_dtype():
+            with paddle.static.program_guard(paddle.static.Program()):
+                x_int32 = paddle.fluid.data(
+                    name='x_int32', shape=[2, 3], dtype='int32')
+                paddle.nn.functional.gumbel_softmax(x_int32)
+
+        self.assertRaises(TypeError, test_dtype)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py
index feacbeeea7..7965b362b9 100644
--- a/python/paddle/nn/functional/__init__.py
+++ b/python/paddle/nn/functional/__init__.py
@@ -44,6 +44,7 @@ from .activation import tanhshrink  # noqa: F401
 from .activation import thresholded_relu  # noqa: F401
 from .activation import log_softmax  # noqa: F401
 from .activation import glu  # noqa: F401
+from .activation import gumbel_softmax  # noqa: F401
 from .common import dropout  # noqa: F401
 from .common import dropout2d  # noqa: F401
 from .common import dropout3d  # noqa: F401
@@ -147,6 +148,7 @@ __all__ = [  #noqa
     'thresholded_relu',
     'log_softmax',
     'glu',
+    'gumbel_softmax',
     'diag_embed',
     'sequence_mask',
     'dropout',
diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py
index 7228c903d6..67be64c01c 100644
--- a/python/paddle/nn/functional/activation.py
+++ b/python/paddle/nn/functional/activation.py
@@ -1329,3 +1329,78 @@ def glu(x, axis=-1, name=None):
     gate = sigmoid(b, name=name)
     out = paddle.multiply(a, gate, name=name)
     return out
+
+
+def gumbel_softmax(x, temperature=1.0, hard=False, axis=-1, name=None):
+    r"""
+    Samples from the Gumbel-Softmax distribution and optionally discretizes.
+    The temperature is denoted by t. The calculation process is as follows:
+
+    First, generate gumbel noise:
+
+    .. math::
+
+        G_i = -log(-log(U_i)), U_i \sim U(0,1)
+
+    Second, add noise to ``x``:
+
+    .. math::
+
+        v = [x_1 + G_1,...,x_n + G_n]
+
+    Finally, calculate gumbel_softmax and generate samples:
+
+    .. math::
+        gumbel\_softmax(v_i)=\frac{e^{v_i/t}}{\sum_{j=1}^n{e^{v_j/t}}},i=1,2,3...n
+
+    Parameters:
+        x (Tensor): An N-D Tensor, the first N - 1 dimensions index into a batch
+            of independent distributions and the last dimension represents
+            a vector of probabilities with datatype float32, float64.
+        temperature (float, optional): non-negative scalar temperature.
+            Default is 1.0.
+        hard (bool, optional): if True, the returned samples will be discretized as
+            one-hot vectors, but will be differentiated as if it is the soft sample
+            in autograd. Default is False.
+        axis (int, optional): The axis along which softmax will be calculated.
+            Default is -1.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Sampled tensor of same shape as ``x`` from the Gumbel-Softmax distribution.
+        If ``hard = True``, the returned samples will be one-hot, otherwise they will be
+        probability distributions that sum to 1 across ``axis``.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import paddle.nn.functional as F
+
+            logits = paddle.randn([4, 6])
+            temperature = 0.01
+            gumbel_softmax = F.gumbel_softmax(logits, temperature)
+            print(gumbel_softmax)
+            # out's value is as follows:
+            # [[0.00000001, 1.        , 0.00000000, 0.00000000, 0.00000006, 0.00000000],
+            #  [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 1.        ],
+            #  [0.00000062, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.99999940],
+            #  [0.00000000, 0.00000000, 0.00000000, 0.00001258, 0.99998736, 0.00000000]]
+
+    """
+    if in_dygraph_mode():
+        return _C_ops.gumbel_softmax(x, 'temperature', temperature, 'hard',
+                                     hard, 'axis', axis)
+
+    helper = LayerHelper("gumbel_softmax", **locals())
+    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'gumbel_softmax')
+    out = helper.create_variable_for_type_inference(x.dtype)
+    helper.append_op(
+        type='gumbel_softmax',
+        inputs={'X': x},
+        outputs={'Out': out},
+        attrs={'temperature': temperature,
+               'hard': hard,
+               'axis': axis})
+    return out
-- 
GitLab
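
Reviewer note (not part of the patch): the docstring above spells out the sampling math, and a minimal NumPy sketch of that path can help sanity-check the operator's behavior outside Paddle. The helper name gumbel_softmax_np, its arguments, and the RNG handling below are assumptions made purely for illustration. The hard branch only shows the forward one-hot value; the real operator additionally routes gradients through the soft sample (the straight-through behavior exercised by TestGumbelSoftmaxOpGrad, where hard and soft gradients are expected to match).

import numpy as np


def gumbel_softmax_np(x, temperature=1.0, hard=False, axis=-1, rng=None):
    """Illustrative NumPy sketch of the gumbel_softmax math (assumption, not the op itself)."""
    rng = np.random.default_rng() if rng is None else rng
    # G_i = -log(-log(U_i)), U_i ~ U(0, 1); sample U away from 0 to avoid log(0),
    # mirroring the 0.00001 lower bound used by the CPU/CUDA noise generators above.
    u = rng.uniform(low=1e-5, high=1.0, size=x.shape)
    g = -np.log(-np.log(u))
    # Temperature-scaled softmax over `axis` of (x + G), with the usual max-shift for stability.
    v = (x + g) / temperature
    v = v - v.max(axis=axis, keepdims=True)
    soft = np.exp(v) / np.exp(v).sum(axis=axis, keepdims=True)
    if not hard:
        return soft
    # hard=True: one-hot of the argmax along `axis` (forward value only; no autograd in NumPy).
    one_hot = np.zeros_like(soft)
    idx = np.expand_dims(soft.argmax(axis=axis), axis)
    np.put_along_axis(one_hot, idx, 1.0, axis=axis)
    return one_hot


# Example: 4 independent distributions over 6 classes, discretized samples.
print(gumbel_softmax_np(np.random.randn(4, 6), temperature=0.5, hard=True))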