From c66eec75c9de7b78eb594a35580c23bad4172e50 Mon Sep 17 00:00:00 2001 From: pangyoki Date: Fri, 25 Sep 2020 08:45:34 +0000 Subject: [PATCH] support num_distribution different multinomial distributions --- paddle/fluid/operators/multinomial_op.cc | 9 +- paddle/fluid/operators/multinomial_op.cu | 185 +++++++++++++----- .../tests/unittests/test_multinomial_op.py | 11 +- 3 files changed, 150 insertions(+), 55 deletions(-) diff --git a/paddle/fluid/operators/multinomial_op.cc b/paddle/fluid/operators/multinomial_op.cc index b1d2fa205a4..f631210795f 100644 --- a/paddle/fluid/operators/multinomial_op.cc +++ b/paddle/fluid/operators/multinomial_op.cc @@ -30,6 +30,7 @@ class MultinomialOpMaker : public framework::OpProtoAndCheckerMaker { void Make() override { AddInput("X", "A tensor contains probabilities of categories"); AddOutput("Out", "The output tensor of multinomial op"); + // AddOutput("yokiOut", "yoki"); AddAttr("num_samples", "number of the generated samples") .SetDefault(1); AddAttr("replacement", "can a category be sampled more than once") @@ -49,7 +50,7 @@ class MultinomialOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext *ctx) const override { OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Multinomial"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Multinomial"); + // OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Multinomial"); auto x_dim = ctx->GetInputDim("X"); int64_t x_rank = x_dim.size(); @@ -62,6 +63,7 @@ class MultinomialOp : public framework::OperatorWithKernel { out_dims[x_rank - 1] = num_samples; ctx->SetOutputDim("Out", framework::make_ddim(out_dims)); + // ctx->SetOutputDim("yokiOut", x_dim); } }; @@ -72,11 +74,16 @@ class MultinomialOpKernel void Compute(const framework::ExecutionContext &ctx) const override { const auto x = ctx.Input("X"); auto out = ctx.Output("Out"); + // auto yokiout = ctx.Output("yokiOut"); const int64_t num_samples = ctx.Attr("num_samples"); const bool replacement = ctx.Attr("replacement"); auto *in_data = x->data(); auto *out_data = out->mutable_data(ctx.GetPlace()); + /*auto *yokiout_data = yokiout->mutable_data(ctx.GetPlace()); + for (int i = 0; i < x->numel(); i++) { + yokiout_data[i] = in_data[i]; + }*/ auto in_dims = x->dims(); int64_t in_rank = in_dims.size(); diff --git a/paddle/fluid/operators/multinomial_op.cu b/paddle/fluid/operators/multinomial_op.cu index f83f2ea8c4c..51399f7116c 100644 --- a/paddle/fluid/operators/multinomial_op.cu +++ b/paddle/fluid/operators/multinomial_op.cu @@ -70,8 +70,31 @@ template __global__ void NormalizeProbability(T* norm_probs, const T* in_data, T* sum_rows) { // int id = blockIdx.x * blockDim.x + threadIdx.x; - int id = threadIdx.x; - norm_probs[id] = in_data[id] / sum_rows[0]; + // int id = threadIdx.x; + int id = threadIdx.x + blockIdx.x * blockDim.x + + blockIdx.y * gridDim.x * blockDim.x; + norm_probs[id] = in_data[id] / sum_rows[blockIdx.y]; +} + +template +__global__ void yokiFunc(const T* in_data, T* out) { + // int id = blockIdx.x * blockDim.x + threadIdx.x; + // int id = threadIdx.x; + int id = threadIdx.x + blockIdx.x * blockDim.x + + blockIdx.y * gridDim.x * blockDim.x; + out[id] = in_data[id]; +} + +template +__global__ void Cumsum(T* norm_probs_data, int64_t num_distributions, + int64_t num_categories, T* cumulative_probs) { + // int id = blockIdx.x; + for (int id = blockIdx.x; id < num_distributions; id += gridDim.x) { + thrust::inclusive_scan(thrust::device, + norm_probs_data + id * num_categories, + norm_probs_data + (id + 1) * num_categories, + cumulative_probs + id * num_categories); + } } template @@ -141,21 +164,29 @@ __global__ void sampleMultinomialWithReplacement( // global index formula for 2D grid of 1D blocks // int idx = blockIdx.y * gridDim.x * blockDim.x + blockIdx.x * blockDim.x + // threadIdx.x; - int idx = blockIdx.x * blockDim.x + threadIdx.x; - for (int sample = blockIdx.x * blockDim.x + threadIdx.x; - sample < totalSamples; sample += blockDim.x * gridDim.x) { - // we are losing 3 out of 4 generated numbers but it's ok - // this kernel is not very efficient anyway + // int idx = blockIdx.x * blockDim.x + threadIdx.x; - // T uniform_random = dist(rng); - T uniform_random = rng[sample]; + int idx = threadIdx.x + blockIdx.x * blockDim.x + + blockIdx.y * gridDim.x * blockDim.x; - // Find the bucket that a uniform sample lies in - int choice = binarySearchForMultinomial(normDistPrefixSum, normDist, - categories, uniform_random); + for (int curDist = blockIdx.y; curDist < distributions; + curDist += gridDim.y) { + for (int sample = blockIdx.x * blockDim.x + threadIdx.x; + sample < totalSamples; sample += blockDim.x * gridDim.x) { + // we are losing 3 out of 4 generated numbers but it's ok + // this kernel is not very efficient anyway - dest[sample] = choice; + // T uniform_random = dist(rng); + T uniform_random = rng[sample + curDist * totalSamples]; + + // Find the bucket that a uniform sample lies in + int choice = binarySearchForMultinomial( + normDistPrefixSum + curDist * categories, + normDist + curDist * categories, categories, uniform_random); + + dest[sample + curDist * totalSamples] = choice; + } } } @@ -167,17 +198,48 @@ class MultinomialOpKernel const auto x = ctx.Input("X"); auto out = ctx.Output("Out"); + // auto yokiout = ctx.Output("yokiOut"); + const int64_t num_samples = ctx.Attr("num_samples"); const bool replacement = ctx.Attr("replacement"); auto* in_data = x->data(); auto* out_data = out->mutable_data(ctx.GetPlace()); + // auto* yokiout_data = yokiout->mutable_data(ctx.GetPlace()); auto in_dims = x->dims(); int64_t in_rank = in_dims.size(); const int64_t num_categories = in_dims[in_rank - 1]; const int64_t num_distributions = in_rank > 1 ? in_dims[in_rank - 2] : 1; + if (!replacement) { + int in_data_numel = x->numel(); + int out_data_numel = out->numel(); + // std::vector cpu_in_data(in_data_numel); + // std::vector cpu_out_data(out_data_numel); + // T cpu_in_data[in_data_numel]; + // T cpu_out_data[out_data_numel]; + + T* cpu_in_data = new T[in_data_numel]; + T* cpu_out_data = new T[out_data_numel]; + + cudaMemcpy(cpu_in_data, in_data, in_data_numel * sizeof(T), + cudaMemcpyDeviceToHost); + + VLOG(3) << "Print cpu_in_data " << cpu_in_data[0] << "\n"; + VLOG(3) << "Print in_data_numel " << in_data_numel << "\n"; + VLOG(3) << "Print out_data_numel " << out_data_numel << "\n"; + + MultinomialFunctor(cpu_out_data, cpu_in_data, num_samples, replacement, + num_categories, num_distributions); + cudaMemcpy(out_data, cpu_out_data, out_data_numel * sizeof(T), + cudaMemcpyHostToDevice); + + delete[] cpu_in_data; + delete[] cpu_out_data; + return; + } + // std::vector sum_rows(num_distributions); // SumArrayCUDAKernel(in_data, sum_rows,) @@ -188,30 +250,44 @@ class MultinomialOpKernel VLOG(3) << "Print in_rank " << in_rank << "\n"; framework::Tensor sum_rows_t; - auto* sum_rows_data = sum_rows_t.mutable_data({1}, ctx.GetPlace()); + auto* sum_rows_data = + sum_rows_t.mutable_data({num_distributions}, ctx.GetPlace()); // auto* sum_rows_data = - // sum_rows_t->mutable_data(framework::make_ddim({1}), ctx.GetPlace()); + // sum_rows_t->mutable_data(framework::make_ddim({num_distributions}), + // ctx.GetPlace()); auto& place = *ctx.template device_context() .eigen_device(); - auto eigen_input = framework::EigenVector::Flatten(*x); - // auto eigen_sum_rows = framework::EigenVector::From(sum_rows_t); - auto eigen_sum_rows = framework::EigenScalar::From(sum_rows_t); - eigen_sum_rows.device(place) = - eigen_input.sum(Eigen::DSizes(0)) - .eval() - .reshape(Eigen::DSizes(sum_rows_t.dims()[0])); - // eigen_sum_rows.device(place) = - // eigen_input.sum().eval().reshape(Eigen::DSizes(1)); - - dim3 grid(num_distributions); - dim3 block(num_categories); + if (num_distributions == 1) { + auto eigen_input = framework::EigenVector::Flatten(*x); + auto eigen_sum_rows = framework::EigenVector::From(sum_rows_t); + // auto eigen_sum_rows = framework::EigenScalar::From(sum_rows_t); + eigen_sum_rows.device(place) = + eigen_input.sum(Eigen::DSizes(1)) + .eval() + .reshape(Eigen::DSizes(sum_rows_t.dims()[0])); + } else { + auto eigen_input = framework::EigenMatrix::From(*x); + // auto eigen_sum_rows = framework::EigenVector::From(sum_rows_t); + auto eigen_sum_rows = framework::EigenVector::From(sum_rows_t); + eigen_sum_rows.device(place) = eigen_input.sum(Eigen::DSizes(1)); + // .eval() + // .reshape(Eigen::DSizes(sum_rows_t.dims()[0])); + // eigen_sum_rows.device(place) = + // eigen_input.sum().eval().reshape(Eigen::DSizes(1)); + } // std::vector in_data_norm(num_categories); framework::Tensor norm_probs_t; - auto* norm_probs_data = - norm_probs_t.mutable_data({num_categories}, ctx.GetPlace()); + auto* norm_probs_data = norm_probs_t.mutable_data( + {num_distributions, num_categories}, ctx.GetPlace()); + + // dim3 grid(num_distributions); + // dim3 block(num_categories); + + dim3 block(num_categories < 512 ? num_categories : 512); + dim3 grid((num_categories - 1) / block.x + 1, num_distributions); NormalizeProbability< T><<>>( norm_probs_data, in_data, sum_rows_data); @@ -219,43 +295,46 @@ class MultinomialOpKernel // num_distributions can only be 1. // std::vector cumulative_probs(num_categories); framework::Tensor cumulative_probs_t; - auto* cumulative_probs = - cumulative_probs_t.mutable_data({num_categories}, ctx.GetPlace()); + auto* cumulative_probs = cumulative_probs_t.mutable_data( + {num_distributions, num_categories}, ctx.GetPlace()); // T cumulative_probs[num_categories]; - int64_t size = num_categories; - thrust::inclusive_scan(thrust::device, norm_probs_data, - norm_probs_data + num_categories, cumulative_probs); + dim3 block1(1); + dim3 grid1(num_distributions); + Cumsum<<>>( + norm_probs_data, num_distributions, num_categories, cumulative_probs); + + /* + dim3 block2(num_categories < 512 ? num_categories : 512); + dim3 grid2((num_categories-1)/block2.x+1, num_distributions); + yokiFunc<<>>( + cumulative_probs, yokiout_data);*/ + + // int64_t size = num_categories; + // thrust::inclusive_scan(thrust::device, norm_probs_data, + // norm_probs_data + num_categories, + // cumulative_probs); + + VLOG(3) << "Print cumsum " << cumulative_probs << "\n"; if (replacement) { dim3 block(128); // int grid_y = 1; - dim3 grid((num_samples - 1) / block.x + 1); - - /* - // std::vector rng(num_samples); - T rng[num_samples]; - std::uniform_real_distribution dist(0, 1); - auto gen_ptr = framework::DefaultCPUGenerator(); - auto engine = gen_ptr->GetCPUEngine(); - - for (int s = 0; s < num_samples; s++) { - rng[s] = dist(*engine); - } - */ + dim3 grid((num_samples - 1) / block.x + 1, num_distributions); std::random_device rd; auto seed = rd(); framework::Tensor rng_data_t; - auto* rng_data = - rng_data_t.mutable_data({num_samples}, ctx.GetPlace()); + auto* rng_data = rng_data_t.mutable_data( + {num_distributions, num_samples}, ctx.GetPlace()); thrust::counting_iterator index_sequence_begin(0); platform::Transform trans; auto* context = static_cast( &ctx.device_context()); - trans(*context, index_sequence_begin, index_sequence_begin + num_samples, - rng_data, RandomGeneratorCudaFunctor(seed)); + trans(*context, index_sequence_begin, + index_sequence_begin + num_distributions * num_samples, rng_data, + RandomGeneratorCudaFunctor(seed)); VLOG(3) << "Print enter\n"; // VLOG(3) << "Print size in_data " << @@ -267,8 +346,12 @@ class MultinomialOpKernel T><<>>( rng_data, num_samples, out_data, num_distributions, num_categories, cumulative_probs, norm_probs_data); + + VLOG(3) << "Print end\n" << out_data; } + VLOG(3) << "Print final end\n"; + // MultinomialCudaFunctor(out_data, in_data, num_samples, replacement, // num_categories, num_distributions); } diff --git a/python/paddle/fluid/tests/unittests/test_multinomial_op.py b/python/paddle/fluid/tests/unittests/test_multinomial_op.py index d4f9a176831..32d438eabf8 100644 --- a/python/paddle/fluid/tests/unittests/test_multinomial_op.py +++ b/python/paddle/fluid/tests/unittests/test_multinomial_op.py @@ -38,6 +38,7 @@ class TestMultinomialOp(OpTest): # input probability is a vector, and replacement is True self.input_np = np.random.rand(4) self.outputs = {"Out": np.zeros(100000).astype("int64")} + # self.outputs = {"yokiOut": np.zeros(4).astype("int64")} self.attrs = {"num_samples": 100000, "replacement": True} def test_check_output(self): @@ -53,19 +54,21 @@ class TestMultinomialOp(OpTest): # normalize the input to get the probability prob = self.input_np / self.input_np.sum(axis=-1, keepdims=True) sample_prob = self.sample_output(np.array(outs[0])) - print("sample_prob: " + str(sample_prob) + "\nprob: " + str(prob)) + # sample_prob = np.array(outs[0]) + # print("input", self.input_np) + # print("sample_prob: " + str(sample_prob) + "\nprob: " + str(prob)) self.assertTrue( np.allclose( sample_prob, prob, rtol=0, atol=0.01), "sample_prob: " + str(sample_prob) + "\nprob: " + str(prob)) -""" class TestMultinomialOp2(TestMultinomialOp): def init_data(self): # input probability is a matrix self.input_np = np.random.rand(3, 4) self.outputs = {"Out": np.zeros((3, 100000)).astype("int64")} + # self.outputs = {"yokiOut": np.zeros((3, 4)).astype("int64")} self.attrs = {"num_samples": 100000, "replacement": True} def sample_output(self, out): @@ -88,11 +91,13 @@ class TestMultinomialOp3(TestMultinomialOp): def verify_output(self, outs): out = np.array(outs[0]) + # print("op3out", out) unique_out = np.unique(out) self.assertEqual( len(unique_out), 100, "replacement is False. categories can't be sampled repeatedly") -""" + + """ class TestReplacementError(unittest.TestCase): def init_data(self): -- GitLab