Unverified · Commit 65e3fa35 authored by Bo Zhang, committed by GitHub

dropout_nd_optimization (#51479)

* with printf

* add DropOutNdForwardKernel

* PR comment
Parent c74aaf67
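For context on what this change does: in dropout_nd the keep/drop mask is sampled over a reduced shape (size 1 along the broadcast axes), and applying it to the input means broadcasting each mask element over the dropped axes. Before this commit the driver materialized a broadcasted mask tensor and ran a separate elementwise kernel; the new DropOutNdForwardKernel fuses mask generation and the broadcasted apply into one launch. A minimal CPU sketch of the apply semantics, assuming a 2-D input with the mask shared along the last axis (names and shapes are illustrative, not the Paddle API):

```cpp
// Minimal CPU sketch (illustrative shapes/names, not the Paddle kernels):
// ND dropout with one mask element shared across each row, using
// upscale-in-train scaling by 1 / keep_prob.
#include <cstdint>
#include <vector>

void DropoutNdReference(const std::vector<float>& x,      // shape [rows, cols]
                        const std::vector<uint8_t>& mask,  // shape [rows, 1]
                        std::vector<float>* y,             // shape [rows, cols]
                        int rows, int cols, float keep_prob) {
  const float factor = 1.0f / keep_prob;
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      // the single mask value for row r is broadcast over every column
      (*y)[r * cols + c] = mask[r] ? x[r * cols + c] * factor : 0.0f;
    }
  }
}
```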
@@ -33,15 +33,68 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/primitive/compute_primitives.h"
#include "paddle/phi/kernels/primitive/datamover_primitives.h"

namespace phi {
namespace funcs {

template <typename T>
struct DstFunctor {
  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
  MT factor;

  HOSTDEVICE inline DstFunctor(const float retain_prob,
                               const bool is_upscale_in_train,
                               const int64_t num)
      : retain_prob_(retain_prob),
        is_upscale_in_train_(is_upscale_in_train),
        num_(num) {
    factor = static_cast<MT>(1.0f / retain_prob_);
  }

  HOSTDEVICE inline T operator()(const T src_val, const uint8_t mask) const {
    for (int i = 0; i < num_; i++) {
      if (mask == static_cast<uint8_t>(1)) {
        return is_upscale_in_train_
                   ? static_cast<T>(static_cast<MT>(src_val) * factor)
                   : static_cast<T>(src_val);
      } else {
        return static_cast<T>(0);
      }
    }
  }

 private:
  const float retain_prob_;
  const bool is_upscale_in_train_;
  const int64_t num_;
};

template <typename T>
struct MaskFunctor {
  const float retain_prob_;
  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
  MT factor;

  HOSTDEVICE inline MaskFunctor(const float retain_prob)
      : retain_prob_(retain_prob) {
    factor = static_cast<MT>(1.0f / retain_prob_);
  }

  HOSTDEVICE inline void operator()(T* dst, const float* rand, int num) const {
    static constexpr int kCount =
        phi::funcs::uniform_distribution<float>::kReturnsCount;
    // 0 ~ kCount - 1 is dst, kCount ~ 2 * kCount - 1 is mask
#pragma unroll
    for (int i = 0; i < kCount; i++) {
      dst[i] = rand[i] < retain_prob_ ? static_cast<T>(1) : static_cast<T>(0);
    }
  }
};
template <typename T>
struct DstMaskFunctor {
  const float retain_prob_;
  const bool is_upscale_in_train_;
  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
  MT factor;

  HOSTDEVICE inline DstMaskFunctor(const float retain_prob,
                                   const bool is_upscale_in_train)
@@ -49,34 +102,34 @@ struct DstMaskFunctor {
    factor = static_cast<MT>(1.0f / retain_prob_);
  }

  HOSTDEVICE inline void operator()(T* dst,
                                    const T* src_val,
                                    const float* rand,
                                    int num) const {
    static constexpr int kCount =
        phi::funcs::uniform_distribution<float>::kReturnsCount;
    // 0 ~ kCount - 1 is dst, kCount ~ 2 * kCount - 1 is mask
#pragma unroll
    for (int i = 0; i < kCount; i++) {
      if (rand[i] < retain_prob_) {
        dst[i] = is_upscale_in_train_
                     ? static_cast<T>(static_cast<MT>(src_val[i]) * factor)
                     : static_cast<T>(src_val[i]);
        dst[i + kCount] = static_cast<T>(1);
      } else {
        dst[i] = static_cast<T>(0);
        dst[i + kCount] = dst[i];
      }
    }
  }
};

template <typename T>
__global__ void VectorizedRandomGenerator(const size_t n,
                                          uint64_t seed,
                                          const float dropout_prob,
                                          const T* src,
                                          uint8_t* mask,
                                          T* dst,
                                          bool is_upscale_in_train,
                                          uint64_t increment,
@@ -94,9 +147,10 @@ __global__ void VectorizedRandomGenerator(const size_t n,
  curand_init(seed, idx + THREAD_ID_X, increment, &state);
  using SType = curandStatePhilox4_32_10_t;
#endif
  T dst_mask[kCount * 2];  // 0 ~ kCount - 1 : dst, kCount ~ 2 * kCount - 1 : mask
  float rands[kCount];
  uint8_t mask_result[kCount];
  using Rand = phi::funcs::uniform_distribution<float>;
  using Cast = kps::IdentityFunctor<T>;
  int deal_size = BLOCK_NUM_X * kCount;
@@ -104,19 +158,19 @@ __global__ void VectorizedRandomGenerator(const size_t n,
  size_t fix = idx * kCount;
  auto dst_functor =
      DstMaskFunctor<T>(1.0f - dropout_prob, is_upscale_in_train);
  for (; fix < main_offset; fix += stride) {
    kps::ReadData<T, kCount, 1, false>(&dst_mask[0], src + fix, deal_size);
    kps::ElementwiseRandom<SType, float, kCount, Rand>(
        &rands[0], Rand(), &state);
    // dst
    kps::OperatorTernary<T, float, T, DstMaskFunctor<T>>(
        &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
    kps::WriteData<T, kCount, 1, false>(dst + fix, &dst_mask[0], deal_size);
    // mask
    kps::ElementwiseUnary<T, uint8_t, kCount, 1, Cast>(
        &mask_result[0], &dst_mask[kCount], Cast());
    kps::WriteData<uint8_t, kCount, 1, false>(
        mask + fix, &mask_result[0], deal_size);
    if (fix > idx * kCount + 1) {
      __syncthreads();
@@ -128,82 +182,33 @@ __global__ void VectorizedRandomGenerator(const size_t n,
    kps::ElementwiseRandom<SType, float, kCount, Rand>(
        &rands[0], Rand(), &state);
    // dst
    kps::OperatorTernary<T, float, T, DstMaskFunctor<T>>(
        &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
    kps::WriteData<T, kCount, 1, true>(dst + fix, &dst_mask[0], remainder);
    // mask
    kps::ElementwiseUnary<T, uint8_t, kCount, 1, Cast>(
        &mask_result[0], &dst_mask[kCount], Cast());
    kps::WriteData<uint8_t, kCount, 1, true>(
        mask + fix, &mask_result[0], remainder);
    __syncthreads();
  }
}
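As a reading aid, the per-element effect of the fused generator above (DstMaskFunctor applied to one value and its uniform draw) reduces to the following scalar reference; the function name and signature are illustrative, not part of the kernel:

```cpp
// Scalar reference for one element (illustrative, not the vectorized kernel):
// keep the value (optionally upscaled by 1/keep_prob) when the draw falls
// below keep_prob, otherwise zero it, and record the decision in the mask.
#include <cstdint>

void DropoutElement(float x, float rand01, float dropout_prob,
                    bool is_upscale_in_train, float* dst, uint8_t* mask) {
  const float keep_prob = 1.0f - dropout_prob;
  if (rand01 < keep_prob) {
    *dst = is_upscale_in_train ? x * (1.0f / keep_prob) : x;
    *mask = 1;
  } else {
    *dst = 0.0f;
    *mask = 0;
  }
}
```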
template <typename T>
__global__ void DropOutNdForwardKernel(
    const size_t n,
    uint64_t seed,
    const float dropout_prob,
    const T* src,
    uint8_t* mask,
    uint64_t increment,
    size_t main_offset,
    DstFunctor<T> dst_functor,
    T* y,
    int64_t N,
    kps::details::BroadcastConfig broadcast_config) {
  // Vectorized mask generation.
  // kCount is 4 because curand_uniform4 is used.
  constexpr int kCount = phi::funcs::uniform_distribution<float>::kReturnsCount;
  size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
  size_t stride = BLOCK_NUM_X * GRID_NUM_X * kCount;
@@ -216,28 +221,28 @@ __global__ void VectorizedGeneratorMask(const size_t n,
  curand_init(seed, idx + THREAD_ID_X, increment, &state);
  using SType = curandStatePhilox4_32_10_t;
#endif
  T dst_mask[kCount];  // 0 ~ kCount - 1 : dst, kCount ~ 2 * kCount - 1 : mask
  float rands[kCount];
  uint8_t mask_result[kCount];
  using Rand = phi::funcs::uniform_distribution<float>;
  using Cast = kps::IdentityFunctor<T>;
  int deal_size = BLOCK_NUM_X * kCount;
  size_t fix = idx * kCount;
  auto mask_functor = MaskFunctor<T>(1.0f - dropout_prob);
  for (; fix < main_offset; fix += stride) {
    kps::ReadData<T, kCount, 1, false>(&dst_mask[0], src + fix, deal_size);
    kps::ElementwiseRandom<SType, float, kCount, Rand>(
        &rands[0], Rand(), &state);
    // dst
    kps::OperatorBinary<float, T, MaskFunctor<T>>(
        &dst_mask[0], &rands[0], mask_functor, kCount);
    // mask
    kps::ElementwiseUnary<T, uint8_t, kCount, 1, Cast>(
        &mask_result[0], &dst_mask[0], Cast());
    kps::WriteData<uint8_t, kCount, 1, false>(
        mask + fix, &mask_result[0], deal_size);
    if (fix > idx * kCount + 1) {
      __syncthreads();
@@ -249,28 +254,30 @@ __global__ void VectorizedGeneratorMask(const size_t n,
    kps::ElementwiseRandom<SType, float, kCount, Rand>(
        &rands[0], Rand(), &state);
    // dst
    kps::OperatorBinary<float, T, MaskFunctor<T>>(
        &dst_mask[0], &rands[0], mask_functor, kCount);
    // mask
    kps::ElementwiseUnary<T, uint8_t, kCount, 1, Cast>(
        &mask_result[0], &dst_mask[0], Cast());
    kps::WriteData<uint8_t, kCount, 1, true>(
        mask + fix, &mask_result[0], remainder);
    __syncthreads();
  }
  // Broadcast mask data and do the elementwise operation with DstFunctor.
  CUDA_KERNEL_LOOP(i, N) {
    uint32_t offset = 0u;
    uint32_t idx = i;
    // Use the (j < phi::DDim::kMaxRank) condition rather than
    // (j < broadcast_config.rank) so that #pragma unroll applies.
#pragma unroll
    for (int j = 0; j < phi::DDim::kMaxRank; ++j) {
      if (j == broadcast_config.rank) break;
      auto fast_divmoder = broadcast_config.divmoders[j].Divmod(idx);
      idx = fast_divmoder.val[0];
      offset += broadcast_config.strides[j] * fast_divmoder.val[1];
    }
    y[i] = dst_functor(src[i], mask[offset]);
  }
}
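The CUDA_KERNEL_LOOP at the end of DropOutNdForwardKernel maps each linear output index to an offset into the smaller mask tensor using the fast-divmod data stored in BroadcastConfig. The same index arithmetic can be written in plain C++ as below; dims are listed innermost-first (matching the reversed dims the driver passes in) and the strides of broadcast axes are zero. The helper name and signature are assumptions for illustration, not Paddle APIs:

```cpp
// Plain-C++ sketch of the broadcast offset computation (illustrative helper,
// not part of Paddle): peel off one coordinate per axis via div/mod and
// accumulate the input-side offset; axes the mask broadcasts over have
// stride 0 and therefore contribute nothing.
#include <cstdint>
#include <vector>

uint32_t BroadcastOffset(uint32_t linear_idx,
                         const std::vector<uint32_t>& out_dims,     // innermost first
                         const std::vector<uint32_t>& in_strides) {  // 0 on broadcast axes
  uint32_t offset = 0;
  for (size_t j = 0; j < out_dims.size(); ++j) {
    const uint32_t coord = linear_idx % out_dims[j];  // divmoders[j].Divmod -> val[1]
    linear_idx /= out_dims[j];                        // divmoders[j].Divmod -> val[0]
    offset += in_strides[j] * coord;
  }
  return offset;
}
```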
template <typename T, typename MT>
@@ -285,17 +292,19 @@ void ScaleByDropoutFactor(const phi::GPUContext& dev_ctx,
}

template <typename T>
void DropoutFwGPUKernelDriver(
    const phi::GPUContext& dev_ctx,
    bool is_test,
    float dropout_prob,
    bool upscale_in_train,
    bool is_fix_seed,
    int seed_val,
    const phi::DenseTensor& x,
    const phi::DenseTensor* seed,
    phi::DenseTensor* mask,
    phi::DenseTensor* y,
    bool is_dropout_nd = false,
    const std::vector<int>& axis = std::vector<int>()) {
  int64_t x_numel = x.numel();
  auto stream = dev_ctx.stream();
  auto* x_data = x.data<T>();
@@ -344,26 +353,32 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx,
        size / (block_size * kVecSize) * (block_size * kVecSize);
    if (is_dropout_nd) {
      auto dst_functor =
          DstFunctor<T>(1.0f - dropout_prob, upscale_in_train, x_numel);

      auto input_x_dims = x.dims();
      auto mask_dims = mask->dims();
      std::vector<int64_t> out_dims = phi::vectorize<int64_t>(input_x_dims);
      std::vector<int64_t> in_dims = phi::vectorize<int64_t>(mask_dims);
      reverse(out_dims.begin(), out_dims.end());
      reverse(in_dims.begin(), in_dims.end());
      kps::details::BroadcastConfig broadcast_config(
          out_dims, in_dims, x.dims().size());

      DropOutNdForwardKernel<T>
          <<<grid_size, block_size, 0, stream>>>(size,
                                                 seed_data,
                                                 dropout_prob,
                                                 x_data,
                                                 mask_data,
                                                 increment,
                                                 main_offset,
                                                 dst_functor,
                                                 y_data,
                                                 y->numel(),
                                                 broadcast_config);
    } else {
#define PD_DROPOUT_KERNEL_NAME VectorizedRandomGenerator<T>
      PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(!is_fix_seed,
                                         PD_DROPOUT_KERNEL_NAME,
                                         grid_size,
@@ -397,14 +412,14 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx,
  }
}
template <typename T>
struct CudaDropoutGradFunctor {
  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;

  explicit CudaDropoutGradFunctor(const MT factor) : factor_(factor) {}

  __device__ __forceinline__ T operator()(const T dout,
                                          const uint8_t mask) const {
    return static_cast<T>(static_cast<MT>(dout) * static_cast<MT>(mask) *
                          factor_);
  }
@@ -433,7 +448,17 @@ void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx,
  phi::DenseTensor broadcasted_mask;
  if (is_dropout_nd) {
    broadcasted_mask.Resize(grad_y.dims());
    dev_ctx.template Alloc<uint8_t>(&broadcasted_mask);

    std::vector<const phi::DenseTensor*> broadcast_ins = {&mask};
    std::vector<phi::DenseTensor*> broadcast_outs = {&broadcasted_mask};
    phi::funcs::BroadcastKernel<phi::ElementwiseType::kUnary,
                                uint8_t,
                                uint8_t>(dev_ctx,
                                         broadcast_ins,
                                         &broadcast_outs,
                                         -1,
                                         kps::IdentityFunctor<uint8_t>());
  }

  std::vector<const phi::DenseTensor*> ins = {
@@ -449,12 +474,12 @@ void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx,
    } else {
      MT factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
      phi::funcs::ElementwiseKernel<T>(
          dev_ctx, ins, &outs, CudaDropoutGradFunctor<T>(factor));
    }
  } else {
    MT factor = static_cast<MT>(1.0f);
    phi::funcs::ElementwiseKernel<T>(
        dev_ctx, ins, &outs, CudaDropoutGradFunctor<T>(factor));
  }
}
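For the training-mode backward path shown above, the mask saved by the forward pass is first broadcast to the gradient's shape and the gradient is then scaled elementwise. A minimal CPU sketch of that elementwise step, with illustrative names rather than the Paddle driver's signature:

```cpp
// Minimal CPU sketch of the training-mode gradient (illustrative names):
// dx = dout * mask * factor, where factor is 1/(1 - dropout_prob) for
// upscale_in_train and 1 otherwise; mask is already broadcast to dout's shape.
#include <cstdint>
#include <vector>

void DropoutGradReference(const std::vector<float>& dout,
                          const std::vector<uint8_t>& mask,
                          std::vector<float>* dx,
                          float dropout_prob, bool upscale_in_train) {
  const float factor =
      upscale_in_train ? 1.0f / (1.0f - dropout_prob) : 1.0f;
  for (size_t i = 0; i < dout.size(); ++i) {
    (*dx)[i] = dout[i] * static_cast<float>(mask[i]) * factor;
  }
}
```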
...
@@ -45,8 +45,7 @@ void DropoutRawKernel(const Context& dev_ctx,
      x,
      seed_tensor.get_ptr(),
      mask,
      out);
}

template <typename T, typename Context>
@@ -76,7 +75,8 @@ void DropoutNdKernel(const Context& dev_ctx,
      seed_tensor.get_ptr(),
      mask,
      out,
      true,
      axis);
}

}  // namespace phi
...