未验证 提交 a2020d0c 编写于 作者: S sneaxiy 提交者: GitHub

fix dropout (#43234)

上级 d9f8636c
......@@ -198,11 +198,13 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test,
size_t main_offset =
size / (block_size * kVecSize) * (block_size * kVecSize);
#define PD_DROPOUT_KERNEL_NAME VectorizedRandomGenerator<T, uint8_t>
PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(
!is_fix_seed, (VectorizedRandomGenerator<T, uint8_t>), grid_size,
block_size, 0, stream, offset, KERNEL_PARAMS.As<uint64_t>(1),
KERNEL_PARAMS.As<uint64_t>(7), size, seed_data, dropout_prob, x_data,
mask_data, y_data, upscale_in_train, increment, main_offset);
!is_fix_seed, PD_DROPOUT_KERNEL_NAME, grid_size, block_size, 0, stream,
offset, KERNEL_PARAMS.As<uint64_t>(1), KERNEL_PARAMS.As<uint64_t>(7),
size, seed_data, dropout_prob, x_data, mask_data, y_data,
upscale_in_train, increment, main_offset);
#undef PD_DROPOUT_KERNEL_NAME
} else {
if (upscale_in_train) {
// todo: can y share with data with x directly?
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册