未验证 提交 0c3e4cf6 编写于 作者: D duanyanhui 提交者: GitHub

[DCU] support cum & multinomial for dcu (#56612)

* support cumsum & multinomial for dcu

* rm comment
上级 76b328bc
...@@ -70,6 +70,7 @@ PD_REGISTER_KERNEL(cumsum_grad, ...@@ -70,6 +70,7 @@ PD_REGISTER_KERNEL(cumsum_grad,
phi::CumsumGradKernel, phi::CumsumGradKernel,
float, float,
double, double,
phi::dtype::float16,
int16_t, int16_t,
int, int,
int64_t) {} int64_t) {}
......
...@@ -435,6 +435,7 @@ PD_REGISTER_KERNEL(cumsum, ...@@ -435,6 +435,7 @@ PD_REGISTER_KERNEL(cumsum,
ALL_LAYOUT, ALL_LAYOUT,
phi::CumsumKernel, phi::CumsumKernel,
float, float,
phi::dtype::float16,
double, double,
int16_t, int16_t,
int, int,
......
...@@ -12,10 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,10 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifndef PADDLE_WITH_HIP
// To-do(qili93): fix this after issue resolved
// https://github.com/ROCmSoftwarePlatform/rocPRIM/issues/202
#include "paddle/phi/kernels/multinomial_kernel.h" #include "paddle/phi/kernels/multinomial_kernel.h"
#ifdef __NVCC__ #ifdef __NVCC__
...@@ -107,14 +103,22 @@ __global__ void sampleMultinomialWithReplacement( ...@@ -107,14 +103,22 @@ __global__ void sampleMultinomialWithReplacement(
size_t idx = gridDim.x * blockDim.x * blockIdx.y + blockDim.x * blockIdx.x + size_t idx = gridDim.x * blockDim.x * blockIdx.y + blockDim.x * blockIdx.x +
threadIdx.x; threadIdx.x;
#if defined(__NVCC__)
curandStatePhilox4_32_10_t state; curandStatePhilox4_32_10_t state;
curand_init(seed, idx, offset, &state); curand_init(seed, idx, offset, &state);
#else
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed, idx, offset, &state);
#endif
int sample = blockIdx.x * blockDim.x + threadIdx.x; int sample = blockIdx.x * blockDim.x + threadIdx.x;
for (int dist = blockIdx.y; dist < num_distributions; dist += gridDim.y) { for (int dist = blockIdx.y; dist < num_distributions; dist += gridDim.y) {
if (sample < num_samples) { if (sample < num_samples) {
#if defined(__NVCC__)
T rng_number = static_cast<T>(curand_uniform4(&state).x); T rng_number = static_cast<T>(curand_uniform4(&state).x);
// Find the bucket that a uniform random number lies in #else
T rng_number = static_cast<T>(hiprand_uniform4(&state).x);
#endif
int selected_category = int selected_category =
binarySearchFunctor<T>(cumulative_probs_data + dist * num_categories, binarySearchFunctor<T>(cumulative_probs_data + dist * num_categories,
norm_probs_data + dist * num_categories, norm_probs_data + dist * num_categories,
...@@ -283,7 +287,7 @@ void MultinomialKernel(const Context& dev_ctx, ...@@ -283,7 +287,7 @@ void MultinomialKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(multinomial, // cuda_only PD_REGISTER_KERNEL(multinomial,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::MultinomialKernel, phi::MultinomialKernel,
...@@ -293,5 +297,3 @@ PD_REGISTER_KERNEL(multinomial, // cuda_only ...@@ -293,5 +297,3 @@ PD_REGISTER_KERNEL(multinomial, // cuda_only
double) { double) {
kernel->OutputAt(0).SetDataType(phi::DataType::INT64); kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
} }
#endif
...@@ -183,10 +183,6 @@ def multinomial(x, num_samples=1, replacement=False, name=None): ...@@ -183,10 +183,6 @@ def multinomial(x, num_samples=1, replacement=False, name=None):
""" """
assert (
not core.is_compiled_with_rocm()
), "multinomial op is not supported on ROCM yet."
if in_dynamic_mode(): if in_dynamic_mode():
return _C_ops.multinomial(x, num_samples, replacement) return _C_ops.multinomial(x, num_samples, replacement)
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册