未验证 提交 18eda6c3 编写于 作者: Y YuanRisheng 提交者: GitHub

Add New OP: gumbel_softmax (#35506)

* Add New Op: gumbel_softmax

* Add New Op: gumbel_softmax

* Add New Op: gumbel_softmax (amend)

* add __main__ function in unit test

* fix bugs when test in windows ci

* update en docs

* delete reletive error in unit test

* delete relative error in unit test

* set hard=True in unit test
上级 3218075d
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/gumbel_softmax_op.h"
#include <string>
#include <unordered_map>
#include "paddle/fluid/operators/common_infer_shape_functions.h"
namespace paddle {
namespace operators {
class GumbelSoftmaxOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
return UnaryOpUnchangedInferShapeCheckAxis(ctx);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
class GumbelSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor) An N-D Tensor, N >= 1,"
"The first N - 1 dimensions index into a batch of independent "
"distributions "
"and the last dimension represents a vector of probabilities for "
"each class.");
AddOutput("Out", "The sampled tensor with the same shape as X.");
AddAttr<float>("temperature",
"(float, default 1.0) non-negative scalar temperature.")
.SetDefault(1.0);
AddAttr<bool>(
"hard",
"(bool, default false) "
"if True, the returned samples will be discretized as one-hot vectors, "
"but will be differentiated as if it is the soft sample in autograd.")
.SetDefault(false);
AddAttr<int>("axis",
"(int, default -1)"
"The dimension index of Input(x) to perform gumbel_softmax.")
.SetDefault(-1);
AddComment(R"DOC(
GumbelSoftmax Operator.
Samples from the Gumbel-Softmax distribution and optionally discretizes.
)DOC");
}
};
class GumbelSoftmaxGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "gumbel_softmax_grad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "gumbel_softmax_grad");
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("Out"),
ctx->GetInputDim(framework::GradVarName("Out")),
platform::errors::InvalidArgument("Input(Out) and its gradients "
"should have the same shape."));
ctx->SetOutputDim(framework::GradVarName("X"),
ctx->GetInputDim(framework::GradVarName("Out")));
}
};
template <typename T>
class GumbelSoftmaxGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("gumbel_softmax_grad");
op->SetInput("Out", this->Output("Out"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(gumbel_softmax, ops::GumbelSoftmaxOp,
ops::GumbelSoftmaxOpMaker,
ops::GumbelSoftmaxGradOpMaker<paddle::framework::OpDesc>,
ops::GumbelSoftmaxGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(gumbel_softmax_grad, ops::GumbelSoftmaxGradOp);
REGISTER_OP_CPU_KERNEL(
gumbel_softmax,
ops::GumbelSoftmaxKernel<paddle::platform::CPUDeviceContext, float>,
ops::GumbelSoftmaxKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
gumbel_softmax_grad,
ops::GumbelSoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GumbelSoftmaxGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/gumbel_softmax_op.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/memory/memcpy.h"
namespace paddle {
namespace operators {
template <typename K, typename V>
using KeyValuePair = cub::KeyValuePair<K, V>;
template <typename T>
struct UniformCUDAGenerator {
T min_, max_;
unsigned int seed_;
unsigned int offset_ = 0;
HOSTDEVICE UniformCUDAGenerator(T min, T max, unsigned int seed)
: min_(min), max_(max), seed_(seed) {}
HOSTDEVICE UniformCUDAGenerator(T min, T max, unsigned int seed,
unsigned int offset)
: min_(min), max_(max), seed_(seed), offset_(offset) {}
HOSTDEVICE T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(min_, max_);
rng.discard(n + offset_);
return dist(rng);
}
};
template <typename T, size_t BlockDim>
__global__ void OneHotCUDAKernel(const int64_t height, const int64_t width,
const int64_t size_out_axis, const T init,
const T* in, T* out) {
typedef cub::BlockReduce<KeyValuePair<int, T>, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
for (int64_t idx = blockIdx.x; idx < height; idx += gridDim.x) {
KeyValuePair<int, T> kv_pair = {-1, init};
int h = idx / size_out_axis;
int w = idx % size_out_axis;
cub::ArgMax reducer;
for (int k = threadIdx.x; k < width; k += blockDim.x) {
kv_pair = reducer(
{k, in[h * width * size_out_axis + k * size_out_axis + w]}, kv_pair);
}
kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, reducer);
if (threadIdx.x == 0) {
int index = static_cast<int>(kv_pair.key);
out[h * width * size_out_axis + index * size_out_axis + w] = 1;
}
__syncthreads();
}
}
template <typename T>
struct OneHotGenerator<platform::CUDADeviceContext, T> {
static void Transform(const platform::CUDADeviceContext& context,
const Tensor& X, Tensor* Out, int axis) {
const int size_to_axis = SizeToAxis(axis, X.dims());
const int size_from_axis = SizeFromAxis(axis, X.dims());
const int size_out_axis = SizeOutAxis(axis, X.dims());
constexpr int thread_size = 512;
int64_t max_grid_dimx = context.GetCUDAMaxGridDimSize().x;
int64_t height = size_to_axis * size_out_axis;
int block_size = height < max_grid_dimx ? height : max_grid_dimx;
Tensor input_tensor;
input_tensor.mutable_data<T>(Out->dims(), platform::CUDAPlace());
TensorCopy(*Out, context.GetPlace(), &input_tensor);
math::set_constant(context, Out, 0.0);
OneHotCUDAKernel<
T, thread_size><<<block_size, thread_size, 0, context.stream()>>>(
height, size_from_axis / size_out_axis, size_out_axis,
std::numeric_limits<T>::lowest(), input_tensor.data<T>(),
Out->data<T>());
}
};
template <typename T>
__global__ void AddGumbelNoiseCUDAKernel(const T* input_data, T* output_data,
T* noise, const float temperature,
int64_t n) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
int step = blockDim.x * gridDim.x;
for (int64_t i = index; i < n; i += step) {
T gumbel_noise = -log(-log(noise[i]));
output_data[i] = (gumbel_noise + input_data[i]) / temperature;
}
}
template <typename T>
struct GumbleNoiseGenerator<platform::CUDADeviceContext, T> {
static void Transform(const platform::CUDADeviceContext& context,
const T* input_data, T* output_data, int size_to_axis,
int size_from_axis, const float temperature) {
Tensor random_tensor;
int64_t size = size_to_axis * size_from_axis;
T* random_data =
random_tensor.mutable_data<T>({size}, platform::CUDAPlace());
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
const unsigned int seed = std::random_device()();
// generate gumbel noise
int device_id =
BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()).GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy()) {
auto seed_offset = gen_cuda->IncrementOffset(1);
int gen_offset = size * seed_offset.second;
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(random_data),
UniformCUDAGenerator<T>(0.00001, 1, seed_offset.first, gen_offset));
} else {
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(random_data),
UniformCUDAGenerator<T>(0.00001, 1, seed));
}
// add gumbel noise to X
const int thread_size = 512;
int64_t block_size = (size + thread_size) / thread_size;
AddGumbelNoiseCUDAKernel<
T><<<block_size, thread_size, 0, context.stream()>>>(
input_data, output_data, random_data, temperature, size);
}
};
#endif
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
gumbel_softmax, ops::GumbelSoftmaxKernel<plat::CUDADeviceContext, float>,
ops::GumbelSoftmaxKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
gumbel_softmax_grad,
ops::GumbelSoftmaxGradKernel<plat::CUDADeviceContext, float>,
ops::GumbelSoftmaxGradKernel<plat::CUDADeviceContext, double>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/softmax.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D>;
static inline int CanonicalAxis(const int axis, const int rank) {
if (axis < 0) {
return axis + rank;
}
return axis;
}
static inline int SizeToAxis(const int axis, DDim dims) {
int size = 1;
for (int i = 0; i < axis; i++) {
size *= dims[i];
}
return size;
}
static inline int SizeFromAxis(const int axis, DDim dims) {
int size = 1;
for (int i = axis; i < dims.size(); i++) {
size *= dims[i];
}
return size;
}
static inline int SizeOutAxis(const int axis, DDim dims) {
int size = 1;
for (int i = axis + 1; i < dims.size(); i++) {
size *= dims[i];
}
return size;
}
template <typename DeviceContext, typename T, int64_t Rank>
struct ArgMaxFunctor {
void operator()(const DeviceContext& ctx, const Tensor& in,
Tensor* index_tensor, const int64_t& axis) {
auto in_eigen = EigenTensor<T, Rank>::From(in, in.dims());
auto index_eigen = EigenTensor<int, Rank - 1>::From(*index_tensor);
index_eigen = in_eigen.argmax(axis).template cast<int>();
}
};
template <typename DeviceContext, typename T>
struct GumbleNoiseGenerator;
template <typename DeviceContext, typename T>
struct OneHotGenerator;
template <typename T>
struct GumbleNoiseGenerator<platform::CPUDeviceContext, T> {
static void Transform(const platform::CPUDeviceContext& context,
const T* input_data, T* output_data, int size_to_axis,
int size_from_axis, const float temperature) {
// generate uniform random number
const int size = size_to_axis * size_from_axis;
std::uniform_real_distribution<T> dist(0.00001, 1);
const int seed = std::random_device()();
auto engine = paddle::framework::GetCPURandomEngine(seed);
Tensor random_tensor;
auto* random_data =
random_tensor.mutable_data<T>({size}, platform::CPUPlace());
for (int64_t i = 0; i < size; ++i) {
random_data[i] = dist(*engine);
}
// generate gumbel noise
framework::DDim dim_2d{size_to_axis, size_from_axis};
auto gumbel_noise_eigen = EigenMatrix<T>::From(random_tensor, dim_2d);
gumbel_noise_eigen = -(((-(gumbel_noise_eigen.log())).log()));
// add noise
for (int64_t i = 0; i < size_to_axis * size_from_axis; i++) {
output_data[i] = (input_data[i] + random_data[i]) / temperature;
}
}
};
template <typename T>
struct OneHotGenerator<platform::CPUDeviceContext, T> {
static void Transform(const platform::CPUDeviceContext& context,
const Tensor& X, Tensor* Out, int axis) {
Tensor index;
std::vector<int> index_dim;
const auto rank = X.dims().size();
const int size_to_axis = SizeToAxis(axis, X.dims());
const int size_from_axis = SizeFromAxis(axis, X.dims());
const int size_out_axis = SizeOutAxis(axis, X.dims());
for (int i = 0; i < X.dims().size(); i++) {
if (i != axis) index_dim.push_back(X.dims().Get()[i]);
}
DDim index_ddim(index_dim.data(), rank - 1);
index.Resize(index_ddim);
auto* index_data = index.mutable_data<int>(context.GetPlace());
#define CALL_ARG_MINMAX_FUNCTOR(rank) \
ArgMaxFunctor<platform::CPUDeviceContext, T, rank> functor##rank; \
functor##rank(context, *Out, &index, axis);
switch (Out->dims().size()) {
case 1:
CALL_ARG_MINMAX_FUNCTOR(1);
break;
case 2:
CALL_ARG_MINMAX_FUNCTOR(2);
break;
case 3:
CALL_ARG_MINMAX_FUNCTOR(3);
break;
case 4:
CALL_ARG_MINMAX_FUNCTOR(4);
break;
case 5:
CALL_ARG_MINMAX_FUNCTOR(5);
break;
case 6:
CALL_ARG_MINMAX_FUNCTOR(6);
break;
default:
PADDLE_ENFORCE_LE(Out->dims().size(), 6,
platform::errors::InvalidArgument(
"gumbel_softmax operator doesn't supports "
"tensors whose ranks are greater "
"than 6 in CPU mode."));
break;
#undef CALL_ARG_MINMAX_FUNCTOR
}
math::set_constant(context, Out, 0.0);
for (int i = 0; i < size_to_axis; i++) {
for (int j = 0; j < size_out_axis; j++) {
*(Out->data<T>() + i * size_from_axis + j +
index_data[i * size_out_axis + j] * size_out_axis) = 1.0;
}
}
}
};
template <typename DeviceContext, typename T>
class GumbelSoftmaxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<Tensor>("X");
auto* Out = context.Output<Tensor>("Out");
const int rank = X->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = X->dims()[axis];
const bool is_hard = context.Attr<bool>("hard");
const float temperature = context.Attr<float>("temperature");
PADDLE_ENFORCE_GT(temperature, 0,
platform::errors::InvalidArgument(
"The temperature must be greater than 0. But "
"received temperature = %f",
temperature));
// allocate memory on device.
Out->mutable_data<T>(context.GetPlace());
if (Out->numel() == 0) {
return;
}
const int size_to_axis = SizeToAxis(axis, X->dims());
const int size_from_axis = SizeFromAxis(axis, X->dims());
Tensor X_noise_2d, Out_2d;
X_noise_2d.Resize({size_to_axis, size_from_axis});
Out_2d.ShareDataWith(*Out).Resize({size_to_axis, size_from_axis});
// generate gumbel noise and add it to X
auto* x_noise_data = X_noise_2d.mutable_data<T>(context.GetPlace());
GumbleNoiseGenerator<DeviceContext, T>::Transform(
context.template device_context<DeviceContext>(), X->data<T>(),
x_noise_data, size_to_axis, size_from_axis, temperature);
#ifdef PADDLE_ON_INFERENCE
math::SoftmaxFunctor<DeviceContext, T, true>()(
context.template device_context<DeviceContext>(), axis_dim, &X_noise_2d,
&Out_2d);
#else
math::SoftmaxFunctor<DeviceContext, T, false>()(
context.template device_context<DeviceContext>(), axis_dim, &X_noise_2d,
&Out_2d);
#endif
if (is_hard) {
OneHotGenerator<DeviceContext, T>::Transform(
context.template device_context<DeviceContext>(), *X, Out, axis);
}
}
};
template <typename DeviceContext, typename T>
class GumbelSoftmaxGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* Out = context.Input<Tensor>("Out");
auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
const int rank = dX->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = dX->dims()[axis];
// allocate memory on device.
dX->mutable_data<T>(context.GetPlace());
if (dX->numel() == 0) {
return;
}
const int size_to_axis = SizeToAxis(axis, dX->dims());
const int size_from_axis = SizeFromAxis(axis, dX->dims());
Tensor dX_2d, Out_2d, dOut_2d;
dX_2d.ShareDataWith(*dX).Resize({size_to_axis, size_from_axis});
Out_2d.ShareDataWith(*Out).Resize({size_to_axis, size_from_axis});
dOut_2d.ShareDataWith(*dOut).Resize({size_to_axis, size_from_axis});
math::SoftmaxGradFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), axis_dim, &Out_2d,
&dOut_2d, &dX_2d);
}
};
} // namespace operators
} // namespace paddle
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
paddle.enable_static()
class TestGumbelSoftmaxOp(OpTest):
def init_attrs(self):
self.shape = [20, 10]
self.attrs = {"hard": True, "axis": -1}
self.count_expected = 20
self.dtype = "float64"
def verify_output(self, outs):
out_np = np.array(outs[0])
out_np.shape = self.shape
self.assertTrue(list(out_np.shape) == self.shape)
self.assertEqual(out_np.sum(), self.count_expected)
def setUp(self):
self.op_type = "gumbel_softmax"
self.init_attrs()
np.random.seed(0)
x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
out = np.zeros(self.shape).astype(self.dtype)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output_customized(self.verify_output)
def test_check_grad(self):
self.check_grad(["X"], "Out")
class TestGumbelSoftmaxOp2(TestGumbelSoftmaxOp):
def init_attrs(self):
self.shape = [20, 10]
self.attrs = {"hard": True, "axis": 0}
self.count_expected = 10
self.dtype = "float64"
class TestGumbelSoftmaxOp3(TestGumbelSoftmaxOp):
def init_attrs(self):
self.shape = [100]
self.attrs = {"hard": True, "axis": -1}
self.count_expected = 1
self.dtype = "float64"
class TestGumbelSoftmaxOp4(TestGumbelSoftmaxOp):
def init_attrs(self):
self.shape = [20, 10, 5]
self.attrs = {"hard": True, "axis": -1}
self.count_expected = 200
self.dtype = "float64"
class TestGumbelSoftmaxOp5(TestGumbelSoftmaxOp):
def init_attrs(self):
self.shape = [20, 10, 5]
self.attrs = {"hard": True, "axis": 1}
self.count_expected = 100
self.dtype = "float64"
class TestGumbelSoftmaxOpSampleDistribution(OpTest):
def softmax(self, x):
x_row_max = x.max(axis=-1)
x_row_max = x_row_max.reshape(list(x.shape)[:-1] + [1])
x = x - x_row_max
x_exp = np.exp(x)
x_exp_row_sum = x_exp.sum(axis=-1).reshape(list(x.shape)[:-1] + [1])
softmax = x_exp / x_exp_row_sum
return softmax
def init_attrs(self):
self.shape = [100, 3]
self.attrs = {"hard": True, "axis": -1}
self.counts = np.zeros(self.shape).astype(self.dtype)
self._cpu_only = True
def accumulate_output(self, outs):
out_np = np.array(outs)
out_np = out_np.reshape(self.shape)
self.counts = np.sum(out_np, axis=0)
def setUp(self):
self.op_type = "gumbel_softmax"
self.init_attrs()
single_x = np.array([0.2, 0.3, 0.5])
batch_x = np.ones(self.shape) * single_x
out = np.zeros(self.shape).astype(self.dtype)
self.probs = self.softmax(single_x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(batch_x)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output_customized(self.accumulate_output)
# Experiment should result in batch num .
self.assertEqual(self.counts.sum(), self.shape[0])
# Treat the probability from softmax as
# the probability of binomial distribution.
# Samples from gumbel softmax meet this binomial distribution.
# Construct statistics z for samples and
# z is approximately N(0,1) for unbiased count
expected = self.probs * self.shape[0]
z = (self.counts - expected) / np.sqrt((expected * (1 - self.probs)))
# A (lazy) approximate 99% two-sided test:
# occurs with prob alpha~>=0.01 if unbiased
self.assertLess(np.max(np.abs(z)).item(), 2.58)
def test_check_grad(self):
self.check_grad(["X"], "Out")
class TestGumbelSoftmaxOpGrad(unittest.TestCase):
def init_attrs(self):
self.shape = [20, 10]
self.dtype = "float64"
def setUp(self):
self.init_attrs()
np.random.seed(0)
self.x_np = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
def test_dygraph_check(self):
paddle.disable_static()
x_hard = paddle.to_tensor(self.x_np, stop_gradient=False)
x_soft = paddle.to_tensor(self.x_np, stop_gradient=False)
out_hard = paddle.nn.functional.gumbel_softmax(x_hard, hard=True)
out_soft = paddle.nn.functional.gumbel_softmax(x_soft, hard=False)
out_hard.sum().backward()
out_soft.sum().backward()
self.assertEqual(
np.allclose(x_hard.grad.numpy(), x_soft.grad.numpy()), True)
paddle.enable_static()
class TestGumbelSoftmaxAPI(unittest.TestCase):
def setUp(self):
self.x_shape = [2, 3, 4, 5]
self.x = np.random.uniform(-1., 1., self.x_shape).astype(np.float32)
self.count_expected = 24
self.place = paddle.CUDAPlace(0) \
if paddle.fluid.core.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_check_api(self):
# test static api
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data(name='x', shape=self.x_shape)
y = paddle.nn.functional.gumbel_softmax(x, hard=True)
exe = paddle.static.Executor(self.place)
out = exe.run(feed={'x': self.x}, fetch_list=[y])
out_np = np.array(out[0])
self.assertEqual(out_np.sum(), self.count_expected)
# test dygrapg api
paddle.disable_static()
x = paddle.to_tensor(self.x)
y = paddle.nn.functional.gumbel_softmax(x, hard=True)
out_np = np.array(y)
self.assertEqual(out_np.sum(), self.count_expected)
paddle.enable_static()
class TestGumbelSoftmaxOpError(unittest.TestCase):
def test_errors(self):
paddle.disable_static()
def test_Variable():
x1 = fluid.create_lod_tensor(
np.zeros((100, 784)), [[10, 10, 10, 70]], fluid.CPUPlace())
paddle.nn.functional.gumbel_softmax(x1)
self.assertRaises(ValueError, test_Variable)
def test_Variable2():
x1 = np.zeros((100, 784))
paddle.nn.functional.gumbel_softmax(x1)
self.assertRaises(ValueError, test_Variable2)
def test_argument1():
x = paddle.to_tensor([0.2, 0.3, 0.4])
paddle.nn.functional.gumbel_softmax(x, temperature=-1)
self.assertRaises(ValueError, test_argument1)
def test_argument2():
x = paddle.to_tensor([0.2, 0.3, 0.4])
paddle.nn.functional.gumbel_softmax(x, axis=1.1)
self.assertRaises(ValueError, test_argument2)
paddle.enable_static()
def test_dtype():
with paddle.static.program_guard(paddle.static.Program()):
x_int32 = paddle.fluid.data(
name='x_int32', shape=[2, 3], dtype='int32')
paddle.nn.functional.gumbel_softmax(x_int32)
self.assertRaises(TypeError, test_dtype)
if __name__ == '__main__':
unittest.main()
......@@ -44,6 +44,7 @@ from .activation import tanhshrink # noqa: F401
from .activation import thresholded_relu # noqa: F401
from .activation import log_softmax # noqa: F401
from .activation import glu # noqa: F401
from .activation import gumbel_softmax # noqa: F401
from .common import dropout # noqa: F401
from .common import dropout2d # noqa: F401
from .common import dropout3d # noqa: F401
......@@ -147,6 +148,7 @@ __all__ = [ #noqa
'thresholded_relu',
'log_softmax',
'glu',
'gumbel_softmax',
'diag_embed',
'sequence_mask',
'dropout',
......
......@@ -1329,3 +1329,78 @@ def glu(x, axis=-1, name=None):
gate = sigmoid(b, name=name)
out = paddle.multiply(a, gate, name=name)
return out
def gumbel_softmax(x, temperature=1.0, hard=False, axis=-1, name=None):
r"""
Samples from the Gumbel-Softmax distribution and optionally discretizes.
temperature is denoted by t. The calculation process is as follows:
First, generate gumbel noise:
.. math::
G_i = -log(-log(U_i)), U_i \sim U(0,1)
Second, add noise to ``x``:
.. math::
v = [x_1 + G_1,...,x_n + G_n]
Finally, calculate gumbel_softmax and generate samples:
.. math::
gumbel\_softmax(v_i)=\frac{e^{v_i/t}}{\sum_{j=1}^n{e^{v_j/t}}},i=1,2,3...n
Parameters:
x (Tensor): An N-D Tensor, the first N - 1 dimensions index into a batch
of independent distributions and the last dimension represents
a vector of probabilities with datatype float32, float64.
temperature (float, optional): non-negative scalar temperature.
Default is 1.0.
hard (bool, optional): if True, the returned samples will be discretized as
one-hot vectors, but will be differentiated as if it is the soft sample
in autograd. Default is False.
axis (int, optional): The axis along will be calculated softmax value.
Default is -1.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
Sampled tensor of same shape as ``x`` from the Gumbel-Softmax distribution.
If ``hard = True``, the returned samples will be one-hot, otherwise they will be
probability distributions that sum to 1 across ``axis``.
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
logits = paddle.randn([4, 6])
temperature = 0.01
gumbel_softmax = F.gumbel_softmax(logits, temperature)
print(gumbel_softmax)
# out's value is as follows:
# [[0.00000001, 1. , 0.00000000, 0.00000000, 0.00000006, 0.00000000],
# [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 1. ],
# [0.00000062, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.99999940],
# [0.00000000, 0.00000000, 0.00000000, 0.00001258, 0.99998736, 0.00000000]]
"""
if in_dygraph_mode():
return _C_ops.gumbel_softmax(x, 'temperature', temperature, 'hard',
hard, 'axis', axis)
helper = LayerHelper("gumbel_softmax", **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'gumbel_softmax')
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(
type='gumbel_softmax',
inputs={'X': x},
outputs={'Out': out},
attrs={'temperature': temperature,
'hard': hard,
'axis': axis})
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册