Unverified commit 7788b65e, authored by yunyaoXYY, committed by GitHub

[AMP OP&Test] Register FP16 for multinomial. (#52107)

* add FP16 for multinomial

* fix input data

* update code

* fix FP16

* fix code
Parent e5a0dc31
@@ -27,6 +27,7 @@ namespace cub = hipcub;
 #endif
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/scalar.h"
 #include "paddle/phi/core/ddim.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -41,25 +42,25 @@ namespace cub = hipcub;

 namespace phi {

-template <typename T>
-__global__ void NormalizeProbability(T* norm_probs,
+template <typename T, typename MT>
+__global__ void NormalizeProbability(MT* norm_probs,
                                      const T* in_data,
-                                     T* sum_rows,
+                                     MT* sum_rows,
                                      int64_t num_distributions,
                                      int64_t num_categories) {
   int id = threadIdx.x + blockIdx.x * blockDim.x +
            blockIdx.y * gridDim.x * blockDim.x;
   if (id < num_distributions * num_categories) {
     PADDLE_ENFORCE(
-        in_data[id] >= 0.0,
+        static_cast<MT>(in_data[id]) >= 0.0,
         "The input of multinomial distribution should be >= 0, but got %f.",
-        in_data[id]);
+        static_cast<MT>(in_data[id]));
     int64_t row_id = id / num_categories;
     PADDLE_ENFORCE(sum_rows[row_id] > 0.0,
                    "The sum of one multinomial distribution probability should "
                    "be > 0, but got %f.",
                    sum_rows[row_id]);
-    norm_probs[id] = in_data[id] / sum_rows[row_id];
+    norm_probs[id] = static_cast<MT>(in_data[id]) / sum_rows[row_id];
   }
 }
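With FP16 inputs, `NormalizeProbability` now reads `T` (e.g. `float16`) values but checks, divides, and stores in the master type `MT` (`float` when `T` is FP16), so the per-row normalization runs in full precision. A NumPy sketch of the same computation (illustrative shapes and names, not Paddle code):

```python
import numpy as np

def normalize_probability(in_data):
    # in_data: (num_distributions, num_categories), possibly float16
    mt = np.float32 if in_data.dtype == np.float16 else in_data.dtype
    probs = in_data.astype(mt)                   # static_cast<MT>(in_data[id])
    assert (probs >= 0).all(), "input of multinomial distribution should be >= 0"
    sum_rows = probs.sum(axis=1, keepdims=True)  # MT* sum_rows
    assert (sum_rows > 0).all(), "each distribution must sum to > 0"
    return probs / sum_rows                      # norm_probs, in MT

norm = normalize_probability(np.random.rand(3, 4).astype(np.float16))
```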
@@ -131,6 +132,8 @@ void MultinomialKernel(const Context& dev_ctx,
                        const Scalar& num_samples,
                        bool replacement,
                        DenseTensor* out) {
+  using MT = typename kps::details::MPTypeTrait<T>::Type;
+
   auto int_num_samples = num_samples.to<int>();
   auto* in_data = x.data<T>();
   int64_t* out_data = dev_ctx.template Alloc<int64_t>(out);
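`kps::details::MPTypeTrait<T>::Type` is the usual AMP "master precision type" trait: it maps `float16` to `float` and leaves `float`/`double` unchanged, so the intermediate buffers below (`sum_rows`, `norm_probs`, the cumulative sums) are allocated and computed in at least single precision. The idea, sketched in NumPy terms (the mapping below illustrates the trait, it is not its C++ definition):

```python
import numpy as np

# illustrative stand-in for MPTypeTrait<T>::Type
MASTER_TYPE = {
    np.dtype(np.float16): np.dtype(np.float32),  # FP16 computes in FP32
    np.dtype(np.float32): np.dtype(np.float32),
    np.dtype(np.float64): np.dtype(np.float64),
}

x = np.random.rand(4).astype(np.float16)
mt = MASTER_TYPE[x.dtype]      # float32
row_sum = x.astype(mt).sum()   # accumulate in the master type
```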
@@ -138,7 +141,6 @@ void MultinomialKernel(const Context& dev_ctx,
   int64_t dim_size = in_dims.size();
   const int64_t num_categories = in_dims[dim_size - 1];
   const int64_t num_distributions = dim_size > 1 ? in_dims[dim_size - 2] : 1;
-
   // If replacement is False, it's not a replaceable sample. Every category
   // can be used only once.
   if (!replacement) {
@@ -153,11 +155,11 @@ void MultinomialKernel(const Context& dev_ctx,
       for (size_t j = 0; j < num_categories; ++j) {
         T weight = cpu_in_data[i * num_categories + j];
         PADDLE_ENFORCE_GE(
-            weight,
+            static_cast<MT>(weight),
             0,
             errors::InvalidArgument(
                 "Each element of multinomial'input must >= 0, but got %f.",
-                weight));
+                static_cast<MT>(weight)));
         if (weight == static_cast<T>(0)) {
           zero_num++;
         }
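The counting loop above feeds the feasibility check for `replacement=False`: each category can be drawn at most once, so the number of requested samples must not exceed the number of categories with positive weight. In NumPy terms (illustrative):

```python
import numpy as np

weights = np.array([0.2, 0.0, 0.5, 0.3], dtype=np.float16)
num_samples = 3
positive = int((weights > 0).sum())  # categories that can actually be drawn
assert num_samples <= positive, (
    "when replacement is False, num_samples must not exceed "
    "the number of positive-weight categories"
)
```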
@@ -174,8 +176,8 @@ void MultinomialKernel(const Context& dev_ctx,
     // Refer to [gumbel softmax algorithm]
     DenseTensor rand = EmptyLike<T, Context>(dev_ctx, x);
     T* rand_data = rand.data<T>();
-    funcs::uniform_distribution<T> dist;
-    funcs::exponential_transform<T> trans(1.0);
+    funcs::uniform_distribution<MT> dist;
+    funcs::exponential_transform<MT> trans(1.0);
     funcs::distribution_and_transform<T>(dev_ctx, &rand, dist, trans);
     funcs::ForRange<Context> for_range(dev_ctx, x.numel());
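For the no-replacement path the kernel does not draw samples one by one. Per the `[gumbel softmax algorithm]` comment, it perturbs the weights with Exp(1) noise (`uniform_distribution` fed through `exponential_transform`, now generated in `MT` precision) and takes a top-k. A NumPy sketch of the underlying exponential-race trick, under the assumption that this is the formulation the kernel relies on:

```python
import numpy as np

def sample_without_replacement(weights, k, rng=np.random.default_rng(0)):
    # E_i ~ Exp(1); the k largest w_i / E_i form a weighted sample
    # without replacement (equivalently, the k smallest E_i / w_i)
    e = rng.exponential(1.0, size=weights.shape)
    return np.argsort(-(weights / e))[:k]

idx = sample_without_replacement(np.array([0.1, 0.4, 0.3, 0.2]), k=2)
```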
@@ -200,61 +202,60 @@ void MultinomialKernel(const Context& dev_ctx,
   // sum_row_data: sum of each row
   DenseTensor sum_rows_tensor;
   sum_rows_tensor.Resize({num_distributions});
-  auto* sum_rows_data = dev_ctx.template Alloc<T>(&sum_rows_tensor);
+  auto* sum_rows_data = dev_ctx.template Alloc<MT>(&sum_rows_tensor);
   auto& place = *dev_ctx.eigen_device();
   if (num_distributions == 1) {
     auto eigen_input = EigenVector<T>::Flatten(x);
-    auto eigen_sum_rows = EigenVector<T>::Flatten(sum_rows_tensor);
+    auto eigen_sum_rows = EigenVector<MT>::Flatten(sum_rows_tensor);
     eigen_sum_rows.device(place) =
         eigen_input.sum(Eigen::DSizes<int, 1>(1))
-            .eval()
-            .reshape(Eigen::DSizes<int, 1>(sum_rows_tensor.dims()[0]));
+            .template cast<MT>()
+            .eval()
+            .reshape(Eigen::DSizes<int, 1>(sum_rows_tensor.dims()[0]))
+            .template cast<MT>();
   } else {
     auto eigen_input = EigenMatrix<T>::From(x);
-    auto eigen_sum_rows = EigenVector<T>::Flatten(sum_rows_tensor);
-    eigen_sum_rows.device(place) = eigen_input.sum(Eigen::DSizes<int, 1>(1));
+    auto eigen_sum_rows = EigenVector<MT>::Flatten(sum_rows_tensor);
+    eigen_sum_rows.device(place) =
+        eigen_input.sum(Eigen::DSizes<int, 1>(1)).template cast<MT>();
   }

   // Normalize row of each distribution to get the probability in range [0,
   // 1].
   // norm_probs_data: probability of the distribution
   DenseTensor norm_probs_tensor;
   norm_probs_tensor.Resize({num_distributions, num_categories});
-  auto* norm_probs_data = dev_ctx.template Alloc<T>(&norm_probs_tensor);
+  auto* norm_probs_data = dev_ctx.template Alloc<MT>(&norm_probs_tensor);

   // number of threads in a block is min(num_categories, 512)
   int block_size = num_categories < 512 ? num_categories : 512;
   dim3 block_norm(block_size);
   dim3 grid_norm((num_distributions * num_categories - 1) / block_norm.x + 1);
-  NormalizeProbability<T>
+  NormalizeProbability<T, MT>
       <<<grid_norm, block_norm, 0, dev_ctx.stream()>>>(norm_probs_data,
                                                        in_data,
                                                        sum_rows_data,
                                                        num_distributions,
                                                        num_categories);

   // Get cumulative probability of each distribution. It's the same function
   // of ``cumsum`` op.
   DenseTensor cumulative_probs_tensor;
   cumulative_probs_tensor.Resize({num_distributions, num_categories});
   auto* cumulative_probs_data =
-      dev_ctx.template Alloc<T>(&cumulative_probs_tensor);
+      dev_ctx.template Alloc<MT>(&cumulative_probs_tensor);
   // 'phi::funcs::InclusiveScan' has higher accuracy than
   // 'thrust::inclusive_scan'
-  funcs::InclusiveScan<T, std::plus<T>>(
+  funcs::InclusiveScan<MT, std::plus<MT>>(
       /*in*/ norm_probs_data,
       /*out*/ cumulative_probs_data,
       /*outer_dim*/ static_cast<size_t>(num_distributions),
       /*mid_dim*/ static_cast<size_t>(num_categories),
       /*inner_dim*/ static_cast<size_t>(1),
       /*init*/ static_cast<T>(0),
-      std::plus<T>(),
+      std::plus<MT>(),
       /*reverse=*/false,
       dev_ctx);

   // Sample the multinomial distributions.
   dim3 block(128);
   int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
@@ -269,7 +270,7 @@ void MultinomialKernel(const Context& dev_ctx,
   uint64_t increment = curand4_loop_times * 4;
   auto seed_offset = gen_cuda->IncrementOffset(increment);

-  sampleMultinomialWithReplacement<T>
+  sampleMultinomialWithReplacement<MT>
       <<<grid, block, 0, dev_ctx.stream()>>>(int_num_samples,
                                              out_data,
                                              num_distributions,
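The last two hunks cover the whole with-replacement pipeline in `MT`: `InclusiveScan` turns each row of normalized probabilities into a CDF (the same result as `cumsum`), and `sampleMultinomialWithReplacement` maps uniform draws through that CDF. A NumPy mirror of the idea (illustrative; the kernel's search over the CDF may differ in detail):

```python
import numpy as np

rng = np.random.default_rng(0)
probs = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32)  # normalized row
cdf = np.cumsum(probs)               # what InclusiveScan produces
u = rng.random(10)                   # one uniform draw per sample
samples = np.searchsorted(cdf, u)    # first index with cdf[i] >= u
samples = np.minimum(samples, probs.size - 1)  # guard: rounding can leave cdf[-1] < 1
```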
@@ -286,6 +287,7 @@ PD_REGISTER_KERNEL(multinomial,  // cuda_only
                    GPU,
                    ALL_LAYOUT,
                    phi::MultinomialKernel,
+                   phi::dtype::float16,
                    float,
                    double) {
   kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
...
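With `phi::dtype::float16` added to the registration, the GPU kernel now accepts FP16 inputs, while `kernel->OutputAt(0).SetDataType(phi::DataType::INT64)` keeps the output INT64 regardless of input dtype. A hedged usage sketch (assumes a CUDA build of Paddle that includes this commit):

```python
import paddle

if paddle.is_compiled_with_cuda():
    paddle.set_device("gpu")
    x = paddle.to_tensor([0.1, 0.2, 0.3, 0.4], dtype="float16")  # unnormalized weights
    out = paddle.multinomial(x, num_samples=5, replacement=True)
    print(out.dtype)  # paddle.int64, regardless of the input dtype
```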
@@ -104,6 +104,68 @@ class TestMultinomialOp3(TestMultinomialOp):
         )


+# FP16 OP
+class TestMultinomialFP16Op(OpTest):
+    def setUp(self):
+        paddle.enable_static()
+        self.op_type = "multinomial"
+        self.dtype = np.float16
+        self.init_data()
+        self.inputs = {"X": self.input_np}
+
+    def init_data(self):
+        # input probability is a vector, and replacement is True
+        self.input_np = np.random.rand(4).astype(self.dtype)
+        self.outputs = {"Out": np.zeros(100000).astype("int64")}
+        self.attrs = {"num_samples": 100000, "replacement": True}
+
+    def test_check_output(self):
+        self.check_output_customized(self.verify_output)
+
+    def sample_output(self, out):
+        return sample_output_one_dimension(out, 4)
+
+    def verify_output(self, outs):
+        # normalize the input to get the probability
+        prob = self.input_np / self.input_np.sum(axis=-1, keepdims=True)
+        sample_prob = self.sample_output(np.array(outs[0]))
+        np.testing.assert_allclose(
+            sample_prob,
+            prob,
+            rtol=0,
+            atol=0.01,
+            err_msg='sample_prob: ' + str(sample_prob) + '\nprob: ' + str(prob),
+        )
+
+
+class TestMultinomialFP16Op2(TestMultinomialFP16Op):
+    def init_data(self):
+        # input probability is a matrix
+        self.input_np = np.random.rand(3, 4).astype(self.dtype)
+        self.outputs = {"Out": np.zeros((3, 100000)).astype("int64")}
+        self.attrs = {"num_samples": 100000, "replacement": True}
+
+    def sample_output(self, out):
+        return sample_output_two_dimension(out, [3, 4])
+
+
+class TestMultinomialFP16Op3(TestMultinomialFP16Op):
+    def init_data(self):
+        # replacement is False. number of samples must be less than number of categories.
+        self.input_np = np.random.rand(1000).astype(self.dtype)
+        self.outputs = {"Out": np.zeros(100).astype("int64")}
+        self.attrs = {"num_samples": 100, "replacement": False}
+
+    def verify_output(self, outs):
+        out = np.array(outs[0])
+        unique_out = np.unique(out)
+        self.assertEqual(
+            len(unique_out),
+            100,
+            "replacement is False. categories can't be sampled repeatedly",
+        )
+
+
 class TestMultinomialApi(unittest.TestCase):
     def test_dygraph(self):
         # input probability is a vector, and replacement is True
...
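The new FP16 tests check the distribution rather than exact values: 100000 draws are reduced to empirical category frequencies (via the `sample_output_one_dimension` / `sample_output_two_dimension` helpers defined earlier in this file) and compared against the normalized input with `atol=0.01`. A self-contained sketch of that frequency check (the helper below is hypothetical, mirroring what those helpers presumably compute):

```python
import numpy as np

def empirical_freq(samples, num_categories):
    # fraction of draws that landed in each category
    return np.bincount(samples, minlength=num_categories) / samples.size

input_np = np.random.rand(4).astype(np.float16)
prob = input_np / input_np.sum(axis=-1, keepdims=True)
p64 = prob.astype(np.float64)
draws = np.random.default_rng(0).choice(4, size=100000, p=p64 / p64.sum())
np.testing.assert_allclose(empirical_freq(draws, 4), prob, rtol=0, atol=0.01)
```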
@@ -187,7 +187,9 @@ def multinomial(x, num_samples=1, replacement=False, name=None):
     if in_dygraph_mode():
         return _C_ops.multinomial(x, num_samples, replacement)
     else:
-        check_variable_and_dtype(x, "x", ["float32", "float64"], "multinomial")
+        check_variable_and_dtype(
+            x, "x", ["uint16", "float16", "float32", "float64"], "multinomial"
+        )
         helper = LayerHelper("multinomial", **locals())
         out = helper.create_variable_for_type_inference(
...
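In the static-graph branch, `"uint16"` is Paddle's conventional spelling for bfloat16 in `check_variable_and_dtype` (the legacy IR stores bf16 tensors with a uint16 NumPy dtype), so the whitelist now covers both half-precision formats alongside `float32`/`float64`. A minimal static-graph sketch of the path this change unblocks (assumes a GPU build with the kernel registered above):

```python
import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    # passes check_variable_and_dtype now that "float16" is whitelisted
    x = paddle.static.data(name="x", shape=[4], dtype="float16")
    out = paddle.multinomial(x, num_samples=5, replacement=True)
```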