未验证 提交 0c3e4cf6 编写于 作者: D duanyanhui 提交者: GitHub

[DCU] support cum & multinomial for dcu (#56612)

* support cum & multinomial for dcu

* rm commt
上级 76b328bc
......@@ -70,6 +70,7 @@ PD_REGISTER_KERNEL(cumsum_grad,
phi::CumsumGradKernel,
float,
double,
phi::dtype::float16,
int16_t,
int,
int64_t) {}
......
......@@ -435,6 +435,7 @@ PD_REGISTER_KERNEL(cumsum,
ALL_LAYOUT,
phi::CumsumKernel,
float,
phi::dtype::float16,
double,
int16_t,
int,
......
......@@ -12,10 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_WITH_HIP
// To-do(qili93): fix this after issue resolved
// https://github.com/ROCmSoftwarePlatform/rocPRIM/issues/202
#include "paddle/phi/kernels/multinomial_kernel.h"
#ifdef __NVCC__
......@@ -107,14 +103,22 @@ __global__ void sampleMultinomialWithReplacement(
size_t idx = gridDim.x * blockDim.x * blockIdx.y + blockDim.x * blockIdx.x +
threadIdx.x;
#if defined(__NVCC__)
curandStatePhilox4_32_10_t state;
curand_init(seed, idx, offset, &state);
#else
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed, idx, offset, &state);
#endif
int sample = blockIdx.x * blockDim.x + threadIdx.x;
for (int dist = blockIdx.y; dist < num_distributions; dist += gridDim.y) {
if (sample < num_samples) {
#if defined(__NVCC__)
T rng_number = static_cast<T>(curand_uniform4(&state).x);
// Find the bucket that a uniform random number lies in
#else
T rng_number = static_cast<T>(hiprand_uniform4(&state).x);
#endif
int selected_category =
binarySearchFunctor<T>(cumulative_probs_data + dist * num_categories,
norm_probs_data + dist * num_categories,
......@@ -283,7 +287,7 @@ void MultinomialKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(multinomial, // cuda_only
PD_REGISTER_KERNEL(multinomial,
GPU,
ALL_LAYOUT,
phi::MultinomialKernel,
......@@ -293,5 +297,3 @@ PD_REGISTER_KERNEL(multinomial, // cuda_only
double) {
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}
#endif
......@@ -183,10 +183,6 @@ def multinomial(x, num_samples=1, replacement=False, name=None):
"""
assert (
not core.is_compiled_with_rocm()
), "multinomial op is not supported on ROCM yet."
if in_dynamic_mode():
return _C_ops.multinomial(x, num_samples, replacement)
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册