dirichlet_kernel_impl.h 4.4 KB
Newer Older
1
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13 14 15
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
16

17 18
#include <cmath>
#include <random>
19
#include "paddle/phi/kernels/dirichlet_kernel.h"
20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43

// Math-function aliases used by the sampling code in this header.
// When PADDLE_WITH_CUDA is defined they expand to the unqualified names
// (exp, log, ...); in every other build they expand to the std:: versions.
// Original rationale: ROCM hcc doesn't work well with using std:: in kernel
// functions.
// NOTE(review): the guard only tests PADDLE_WITH_CUDA, so a build without it
// gets the std:: variants - confirm this is what ROCm builds are meant to use.
#ifdef PADDLE_WITH_CUDA
#define COMPAT_ABS abs
#define COMPAT_CEIL ceil
#define COMPAT_EXP exp
#define COMPAT_FLOOR floor
#define COMPAT_LOG log
#define COMPAT_LOG1P log1p
#define COMPAT_POW pow
#define COMPAT_SQRT sqrt
#define COMPAT_TAN tan
#else  // !PADDLE_WITH_CUDA
#define COMPAT_ABS std::abs
#define COMPAT_CEIL std::ceil
#define COMPAT_EXP std::exp
#define COMPAT_FLOOR std::floor
#define COMPAT_LOG std::log
#define COMPAT_LOG1P std::log1p
#define COMPAT_POW std::pow
#define COMPAT_SQRT std::sqrt
#define COMPAT_TAN std::tan
#endif  // PADDLE_WITH_CUDA

44
namespace phi {
45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77

// Thin adapter that stores a nullary callable and exposes it as `sample()`.
//
// ScalarT  - type returned by `sample()`.
// SamplerT - callable type; a copy is kept in `sampler_` and invoked with no
//            arguments on every `sample()` call.
template <typename ScalarT, typename SamplerT>
struct BaseSampler {
  // Copies `source` into the wrapper.
  HOSTDEVICE BaseSampler(const SamplerT& source) : sampler_(source) {}

  // Calls the stored callable once and returns its result as ScalarT.
  HOSTDEVICE ScalarT sample() {
    return sampler_();
  }

  // The wrapped callable.
  SamplerT sampler_;
};

// `sample_gamma` is adapted from Numpy's distributions.c, with added support
//  for paddle data types and code style.
//  Source MIT licensed:
/* Copyright 2005 Robert Kern (robert.kern@gmail.com)
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the
 * "Software"), to deal in the Software without restriction, including
 * without limitation the rights to use, copy, modify, merge, publish,
 * distribute, sublicense, and/or sell copies of the Software, and to
 * permit persons to whom the Software is furnished to do so, subject to
 * the following conditions:
 *
 * The above copyright notice and this permission notice shall be included
 * in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */

78 79 80
template <typename ScalarT,
          typename AccscalarT,
          typename UniformSamplerT,
81
          typename NormalSamplerT>
82 83 84 85
HOSTDEVICE ScalarT
sample_gamma(ScalarT alpha,
             BaseSampler<AccscalarT, UniformSamplerT> standard_uniform,
             BaseSampler<AccscalarT, NormalSamplerT> standard_normal) {
86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
  AccscalarT scale = 1.0f;

  // Boost alpha for higher acceptance probability.
  if (alpha < 1.0f) {
    if (alpha == 0.f) return 0.f;
    scale *= COMPAT_POW(1 - standard_uniform.sample(), 1.0f / alpha);
    alpha += 1.0f;
  }

  // This implements the acceptance-rejection method of Marsaglia and Tsang
  // (2000)
  // doi:10.1145/358407.358414
  const AccscalarT d = alpha - 1.0f / 3.0f;
  const AccscalarT c = 1.0f / COMPAT_SQRT(9.0f * d);
  for (;;) {
    AccscalarT x, y;
    do {
      x = standard_normal.sample();
      y = 1.0f + c * x;
    } while (y <= 0);
    const AccscalarT v = y * y * y;
    const AccscalarT u = 1 - standard_uniform.sample();
    const AccscalarT xx = x * x;
    if (u < 1.0f - 0.0331f * xx * xx)
      return static_cast<ScalarT>(scale * d * v);
    if (COMPAT_LOG(u) < 0.5f * xx + d * (1.0f - v + COMPAT_LOG(v)))
      return static_cast<ScalarT>(scale * d * v);
  }
}

116 117 118 119 120
// Functor invoked by the kernel entry point in this header with the device
// context, the `alpha` tensor and the output tensor.
// Only the declaration of operator() is visible here; its definition is
// expected to be supplied elsewhere (presumably per-backend) - not shown in
// this file.
//
// Context - device context type.
// T       - element type.
template <typename Context, typename T>
struct DirichletSampler {
  // dev_ctx - device context the sampler runs with.
  // alpha   - input tensor (presumably the concentration parameters).
  // out     - output tensor, passed by pointer.
  void operator()(
      const Context& dev_ctx, const DenseTensor& alpha, DenseTensor* out);
};
122 123 124 125 126 127 128 129 130 131

template <typename T, typename Context>
void Dirichletkernel(const Context& dev_ctx,
                     const DenseTensor& alpha,
                     DenseTensor* out) {
  dev_ctx.template Alloc<T>(out);
  DirichletSampler<Context, T> sampler;
  sampler(dev_ctx, alpha, out);
}
}  // namespace phi