未验证 提交 e0007f31 编写于 作者: D denglianbin 提交者: GitHub

【Hackathon No.46】为 Paddle gumbel_softmax 算子实现 float16 数据类型支持 (#50923)

* finish task

* fix some question.

* fix error

* change unittest:zeroDim.
上级 b94fe95a
......@@ -21,5 +21,6 @@ PD_REGISTER_KERNEL(gumbel_softmax_grad,
GPU,
ALL_LAYOUT,
phi::GumbelSoftmaxGradKernel,
phi::dtype::float16,
float,
double) {}
......@@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/phi/kernels/gumbel_softmax_kernel.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h"
......@@ -116,17 +116,18 @@ struct OneHotGenerator<GPUContext, T> {
}
};
template <typename T>
template <typename T, typename MPType>
__global__ void AddGumbelNoiseCUDAKernel(const T* input_data,
T* output_data,
T* noise,
MPType* noise,
const float temperature,
int64_t n) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
int step = blockDim.x * gridDim.x;
for (int64_t i = index; i < n; i += step) {
T gumbel_noise = -log(-log(noise[i]));
output_data[i] = (gumbel_noise + input_data[i]) / temperature;
MPType gumbel_noise = -log(-log(noise[i]));
output_data[i] = static_cast<T>(
(gumbel_noise + static_cast<MPType>(input_data[i])) / temperature);
}
}
......@@ -141,7 +142,8 @@ struct GumbleNoiseGenerator<GPUContext, T> {
DenseTensor random_tensor;
int64_t size = size_to_axis * size_from_axis;
random_tensor.Resize(make_ddim({size}));
T* random_data = ctx.template Alloc<T>(&random_tensor);
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType* random_data = ctx.template Alloc<MPType>(&random_tensor);
// generate gumbel noise
int device_id = ctx.GetPlace().GetDeviceId();
......@@ -152,10 +154,11 @@ struct GumbleNoiseGenerator<GPUContext, T> {
uint64_t offset = seed_offset.second;
thrust::counting_iterator<int64_t> index_sequence_begin(0);
thrust::transform(index_sequence_begin,
index_sequence_begin + size,
thrust::device_ptr<T>(random_data),
UniformCUDAGenerator<T>(0.00001, 1, seed, size * offset));
thrust::transform(
index_sequence_begin,
index_sequence_begin + size,
thrust::device_ptr<MPType>(random_data),
UniformCUDAGenerator<MPType>(0.00001, 1, seed, size * offset));
// add gumbel noise to X
const int thread_size = 512;
......@@ -168,5 +171,10 @@ struct GumbleNoiseGenerator<GPUContext, T> {
} // namespace phi
#endif
PD_REGISTER_KERNEL(
gumbel_softmax, GPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {}
PD_REGISTER_KERNEL(gumbel_softmax,
GPU,
ALL_LAYOUT,
phi::GumbelSoftmaxKernel,
phi::dtype::float16,
float,
double) {}
......@@ -53,10 +53,13 @@ class TestGumbelSoftmaxOp(OpTest):
class TestGumbelSoftmax_ZeroDim(OpTest):
def init_attrs(self):
self.dtype = "float64"
def setUp(self):
self.op_type = "gumbel_softmax"
self.python_api = F.gumbel_softmax
self.dtype = "float64"
self.init_attrs()
x = np.random.uniform(0.1, 1, []).astype(self.dtype)
out = np.array(1.0).astype(self.dtype)
......@@ -103,6 +106,43 @@ class TestGumbelSoftmaxOp5(TestGumbelSoftmaxOp):
self.dtype = "float64"
class TestGumbelSoftmax_ZeroDim_FP16OP(TestGumbelSoftmax_ZeroDim):
def init_attrs(self):
self.dtype = np.float16
class TestGumbelSoftmaxFP16OP2(TestGumbelSoftmaxOp):
def init_attrs(self):
self.shape = [20, 10]
self.attrs = {"hard": True, "axis": 0}
self.count_expected = 10
self.dtype = np.float16
class TestGumbelSoftmaxFP16OP3(TestGumbelSoftmaxOp):
def init_attrs(self):
self.shape = [100]
self.attrs = {"hard": True, "axis": -1}
self.count_expected = 1
self.dtype = np.float16
class TestGumbelSoftmaxFP16OP4(TestGumbelSoftmaxOp):
def init_attrs(self):
self.shape = [20, 10, 5]
self.attrs = {"hard": True, "axis": -1}
self.count_expected = 200
self.dtype = np.float16
class TestGumbelSoftmaxFP16OP5(TestGumbelSoftmaxOp):
def init_attrs(self):
self.shape = [20, 10, 5]
self.attrs = {"hard": True, "axis": 1}
self.count_expected = 100
self.dtype = np.float16
class TestGumbelSoftmaxOpSampleDistribution(OpTest):
def softmax(self, x):
x_row_max = x.max(axis=-1)
......
......@@ -1664,7 +1664,7 @@ def gumbel_softmax(x, temperature=1.0, hard=False, axis=-1, name=None):
Parameters:
x (Tensor): An N-D Tensor, the first N - 1 dimensions index into a batch
of independent distributions and the last dimension represents
a vector of probabilities with datatype float32, float64.
a vector of probabilities with datatype float16, float32, float64.
temperature (float, optional): non-negative scalar temperature.
Default is 1.0.
hard (bool, optional): if True, the returned samples will be discretized as
......@@ -1705,7 +1705,9 @@ def gumbel_softmax(x, temperature=1.0, hard=False, axis=-1, name=None):
)
helper = LayerHelper("gumbel_softmax", **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'gumbel_softmax')
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'gumbel_softmax'
)
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(
type='gumbel_softmax',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册