未验证 提交 a2020d0c 编写于 作者: S sneaxiy 提交者: GitHub

fix dropout (#43234)

上级 d9f8636c
...@@ -198,11 +198,13 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test, ...@@ -198,11 +198,13 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test,
size_t main_offset = size_t main_offset =
size / (block_size * kVecSize) * (block_size * kVecSize); size / (block_size * kVecSize) * (block_size * kVecSize);
#define PD_DROPOUT_KERNEL_NAME VectorizedRandomGenerator<T, uint8_t>
PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL( PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(
!is_fix_seed, (VectorizedRandomGenerator<T, uint8_t>), grid_size, !is_fix_seed, PD_DROPOUT_KERNEL_NAME, grid_size, block_size, 0, stream,
block_size, 0, stream, offset, KERNEL_PARAMS.As<uint64_t>(1), offset, KERNEL_PARAMS.As<uint64_t>(1), KERNEL_PARAMS.As<uint64_t>(7),
KERNEL_PARAMS.As<uint64_t>(7), size, seed_data, dropout_prob, x_data, size, seed_data, dropout_prob, x_data, mask_data, y_data,
mask_data, y_data, upscale_in_train, increment, main_offset); upscale_in_train, increment, main_offset);
#undef PD_DROPOUT_KERNEL_NAME
} else { } else {
if (upscale_in_train) { if (upscale_in_train) {
// todo: can y share with data with x directly? // todo: can y share with data with x directly?
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册