Unverified · Commit caa57498 authored by crystal, committed by GitHub

Implement dropout_nd operator to optimize dropout when axis is not None. (#42463)

Co-authored-by: Liu Yiqun <liuyiqun01@baidu.com>
Parent e61b25f9
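For context: when `axis` is given, dropout_nd draws the random mask only over the dimensions listed in `axis` (all other dimensions have size 1 in the mask) and broadcasts that mask over the input, so every position along the non-axis dimensions shares the same keep/drop decision and far fewer random numbers are needed. A rough NumPy sketch of the forward semantics in `upscale_in_train` mode (names and shapes are illustrative, not the operator's implementation):

import numpy as np

def dropout_nd_reference(x, p=0.5, axis=(1,), training=True):
    # NumPy sketch: the mask is drawn on a reduced shape and broadcast over x.
    if not training:
        return x, np.ones([1] * x.ndim, dtype="uint8")
    mask_shape = [x.shape[d] if d in axis else 1 for d in range(x.ndim)]
    mask = (np.random.uniform(size=mask_shape) >= p).astype("uint8")
    # upscale_in_train: scale the kept values by 1 / (1 - p).
    return x * mask / (1.0 - p), mask

x = np.random.rand(4, 32, 16).astype("float32")
y, mask = dropout_nd_reference(x, p=0.5, axis=(1,))  # mask shape: (1, 32, 1)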
......@@ -34,10 +34,9 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/dropout_impl_util.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/functors.h"
......@@ -142,15 +141,154 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
}
}
template <typename T1, typename T2 = T1, typename OutT = T1>
struct MaskFunctor {
const float retain_prob_;
using MT = typename details::MPTypeTrait<T1>::Type;
MT factor;
HOSTDEVICE inline MaskFunctor(const float retain_prob)
: retain_prob_(retain_prob) {
factor = static_cast<MT>(1.0f / retain_prob_);
}
HOSTDEVICE inline void operator()(OutT* dst, const T2* rand, int num) const {
static constexpr int kCount =
phi::funcs::uniform_distribution<T2>::kReturnsCount;
// 0 ~ kCount - 1 is dist, kCount ~ 2 * kCount - 1 is mask
#pragma unroll
for (int i = 0; i < kCount; i++) {
if (rand[i] < retain_prob_) {
dst[i] = static_cast<T1>(1);
} else {
dst[i] = static_cast<T1>(0);
}
}
}
};
template <typename T, typename MaskType>
struct DstFunctor {
using MT = typename details::MPTypeTrait<T>::Type;
MT factor;
HOSTDEVICE inline DstFunctor(const float retain_prob,
const bool is_upscale_in_train,
const int64_t num)
: retain_prob_(retain_prob),
is_upscale_in_train_(is_upscale_in_train),
num_(num) {
factor = static_cast<MT>(1.0f / retain_prob_);
}
HOSTDEVICE inline T operator()(const T src_val, const MaskType mask) const {
for (int i = 0; i < num_; i++) {
if (mask == static_cast<MaskType>(1)) {
return is_upscale_in_train_
? static_cast<T>(static_cast<MT>(src_val) * factor)
: static_cast<T>(src_val);
} else {
return static_cast<T>(0);
}
}
}
private:
const float retain_prob_;
const bool is_upscale_in_train_;
const int64_t num_;
};
template <typename T, typename MaskType>
__global__ void VectorizedGeneratorMask(const size_t n, uint64_t seed,
const float dropout_prob, const T* src,
MaskType* mask, uint64_t increment,
size_t main_offset) {
constexpr int kCount = phi::funcs::uniform_distribution<float>::kReturnsCount;
size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
size_t stride = BLOCK_NUM_X * GRID_NUM_X * kCount;
#ifdef PADDLE_WITH_HIP
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed, idx + THREAD_ID_X, increment, &state);
using SType = hiprandStatePhilox4_32_10_t;
#else
curandStatePhilox4_32_10_t state;
curand_init(seed, idx + THREAD_ID_X, increment, &state);
using SType = curandStatePhilox4_32_10_t;
#endif
T dst_mask[kCount];  // 0 ~ kCount - 1: dst; kCount ~ 2 * kCount - 1: mask
float rands[kCount];
MaskType mask_result[kCount];
using Rand = phi::funcs::uniform_distribution<float>;
using Cast = kps::IdentityFunctor<T>;
int deal_size = BLOCK_NUM_X * kCount;
size_t fix = idx * kCount;
auto mask_functor = MaskFunctor<T, float>(1.0f - dropout_prob);
for (; fix < main_offset; fix += stride) {
kps::ReadData<T, kCount, 1, 1, false>(&dst_mask[0], src + fix, deal_size);
kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
&state);
// dst
kps::OperatorBinary<float, T, MaskFunctor<T, float>>(
&dst_mask[0], &rands[0], mask_functor, kCount);
// mask
kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
&mask_result[0], &dst_mask[0], Cast());
kps::WriteData<MaskType, kCount, 1, 1, false>(mask + fix, &mask_result[0],
deal_size);
if (fix > idx * kCount + 1) {
__syncthreads();
}
}
int remainder = n - fix;
if (remainder > 0) {
kps::ReadData<T, kCount, 1, 1, true>(&dst_mask[0], src + fix, remainder);
kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
&state);
// dst
kps::OperatorBinary<float, T, MaskFunctor<T, float>>(
&dst_mask[0], &rands[0], mask_functor, kCount);
// mask
kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
&mask_result[0], &dst_mask[0], Cast());
kps::WriteData<MaskType, kCount, 1, 1, true>(mask + fix, &mask_result[0],
remainder);
__syncthreads();
}
}
inline void CalcBroadcastedMask(const phi::GPUContext& dev_ctx,
const framework::Tensor& mask,
framework::Tensor* broadcasted_mask) {
// The broadcast of mask can be combined to the following ElementwiseKernel
// when the BroadcastKernel supports different input types.
broadcasted_mask->mutable_data<uint8_t>(dev_ctx.GetPlace());
std::vector<const framework::Tensor*> ins = {&mask};
std::vector<framework::Tensor*> outs = {broadcasted_mask};
phi::funcs::BroadcastKernel<phi::ElementwiseType::kUnary, uint8_t, uint8_t>(
dev_ctx, ins, &outs, -1, kps::IdentityFunctor<uint8_t>());
}
template <typename T, typename MT>
void ScaleByDropoutFactor(const phi::GPUContext& dev_ctx,
const framework::Tensor& x, framework::Tensor* y,
MT factor) {
std::vector<const framework::Tensor*> ins = {&x};
std::vector<framework::Tensor*> outs = {y};
auto functor = phi::funcs::ScaleFunctor<T>(factor);
phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
}
template <typename T>
void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test,
const std::string dropout_implementation,
float dropout_prob, bool upscale_in_train,
bool is_fix_seed, int seed_val,
const framework::Tensor& x,
const framework::Tensor* seed,
framework::Tensor* mask, framework::Tensor* y) {
auto& place = *dev_ctx.eigen_device();
framework::Tensor* mask, framework::Tensor* y,
bool is_dropout_nd = false) {
int64_t x_numel = x.numel();
auto stream = dev_ctx.stream();
auto* x_data = x.data<T>();
......@@ -198,33 +336,38 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test,
size_t main_offset =
size / (block_size * kVecSize) * (block_size * kVecSize);
if (is_dropout_nd) {
VectorizedGeneratorMask<T, uint8_t><<<grid_size, block_size, 0, stream>>>(
size, seed_data, dropout_prob, x_data, mask_data, increment,
main_offset);
framework::Tensor broadcasted_mask;
broadcasted_mask.Resize(x.dims());
CalcBroadcastedMask(dev_ctx, *mask, &broadcasted_mask);
auto dst_functor = DstFunctor<T, uint8_t>(1.0f - dropout_prob,
upscale_in_train, x_numel);
std::vector<const framework::Tensor*> ins = {&x, &broadcasted_mask};
std::vector<framework::Tensor*> outs = {y};
phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, dst_functor);
} else {
#define PD_DROPOUT_KERNEL_NAME VectorizedRandomGenerator<T, uint8_t>
PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(
!is_fix_seed, PD_DROPOUT_KERNEL_NAME, grid_size, block_size, 0, stream,
offset, KERNEL_PARAMS.As<uint64_t>(1), KERNEL_PARAMS.As<uint64_t>(7),
size, seed_data, dropout_prob, x_data, mask_data, y_data,
upscale_in_train, increment, main_offset);
PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(
!is_fix_seed, PD_DROPOUT_KERNEL_NAME, grid_size, block_size, 0,
stream, offset, KERNEL_PARAMS.As<uint64_t>(1),
KERNEL_PARAMS.As<uint64_t>(7), size, seed_data, dropout_prob, x_data,
mask_data, y_data, upscale_in_train, increment, main_offset);
#undef PD_DROPOUT_KERNEL_NAME
}
} else {
if (upscale_in_train) {
// TODO: can y share data with x directly?
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
hipMemcpyAsync(y_data, x_data, sizeof(T) * x_numel,
hipMemcpyDeviceToDevice, stream));
#else
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemcpyAsync(y_data, x_data, sizeof(T) * x_numel,
cudaMemcpyDeviceToDevice, stream));
#endif
// y = x
framework::TensorCopy(x, dev_ctx.GetPlace(), dev_ctx, y);
} else {
using MT = typename details::MPTypeTrait<T>::Type;
MT factor = static_cast<MT>(1.0f - dropout_prob);
std::vector<const framework::Tensor*> ins = {&x};
std::vector<framework::Tensor*> outs = {y};
auto functor = phi::funcs::ScaleFunctor<T>(factor);
paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
&outs, functor);
// y = factor * x
ScaleByDropoutFactor<T, MT>(dev_ctx, x, y, factor);
}
}
}
......@@ -246,45 +389,44 @@ struct CudaDropoutGradFunctor {
};
template <typename T>
void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx,
const std::string dropout_implementation,
float dropout_prob,
void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test,
float dropout_prob, bool upscale_in_train,
const framework::Tensor& grad_y,
const framework::Tensor& mask, int64_t size,
const framework::Tensor& mask,
framework::Tensor* grad_x,
bool is_test = false) {
bool is_dropout_nd = false) {
using MT = typename details::MPTypeTrait<T>::Type;
auto stream = dev_ctx.stream();
MT factor;
if (is_test) {
if (dropout_implementation == "upscale_in_train") {
factor = static_cast<MT>(1.0f);
} else {
factor = static_cast<MT>(1.0f - dropout_prob);
}
std::vector<const framework::Tensor*> ins = {&grad_y};
std::vector<framework::Tensor*> outs = {grad_x};
auto functor = phi::funcs::ScaleFunctor<T>(factor);
paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
&outs, functor);
MT factor = static_cast<MT>(upscale_in_train ? 1.0f : 1.0f - dropout_prob);
// y = factor * x
ScaleByDropoutFactor<T, MT>(dev_ctx, grad_y, grad_x, factor);
} else {
std::vector<const framework::Tensor*> ins = {&grad_y, &mask};
framework::Tensor broadcasted_mask;
if (is_dropout_nd) {
broadcasted_mask.Resize(grad_y.dims());
CalcBroadcastedMask(dev_ctx, mask, &broadcasted_mask);
}
std::vector<const framework::Tensor*> ins = {
&grad_y, is_dropout_nd ? &broadcasted_mask : &mask};
std::vector<framework::Tensor*> outs = {grad_x};
if (dropout_implementation == "upscale_in_train") {
if (upscale_in_train) {
if (dropout_prob == 1.0f) {
#ifdef PADDLE_WITH_HIP
hipMemset(grad_x->data<T>(), 0, size * sizeof(T));
hipMemset(grad_x->data<T>(), 0, grad_x->numel() * sizeof(T));
#else
cudaMemset(grad_x->data<T>(), 0, size * sizeof(T));
cudaMemset(grad_x->data<T>(), 0, grad_x->numel() * sizeof(T));
#endif
} else {
factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
MT factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
phi::funcs::ElementwiseKernel<T>(
dev_ctx, ins, &outs, CudaDropoutGradFunctor<T, uint8_t>(factor));
}
} else {
factor = static_cast<MT>(1.0f);
paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
MT factor = static_cast<MT>(1.0f);
phi::funcs::ElementwiseKernel<T>(
dev_ctx, ins, &outs, CudaDropoutGradFunctor<T, uint8_t>(factor));
}
}
......
......@@ -161,15 +161,49 @@ class DropoutGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
class DropoutNdOpMaker : public DropoutOpMaker {
public:
void Make() override {
DropoutOpMaker::Make();
AddAttr<std::vector<int>>("axis",
"(std::vector<int>). List of integers,"
" indicating the dimensions to be dropout_nd.")
.SetDefault({});
}
};
template <typename T>
class DropoutNdGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("dropout_nd_grad");
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetInput("Mask", this->Output("Mask"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(dropout, DropoutInferShapeFunctor,
PD_INFER_META(phi::DropoutInferMeta));
REGISTER_OPERATOR(dropout, ops::DropoutOp, ops::DropoutOpMaker,
ops::DropoutGradOpMaker<paddle::framework::OpDesc>,
ops::DropoutGradOpMaker<paddle::imperative::OpBase>,
DropoutInferShapeFunctor);
REGISTER_OPERATOR(dropout_grad, ops::DropoutOpGrad);
DECLARE_INFER_SHAPE_FUNCTOR(dropout_nd, DropoutNdInferShapeFunctor,
PD_INFER_META(phi::DropoutNdInferMeta));
REGISTER_OPERATOR(dropout_nd, ops::DropoutOp, ops::DropoutNdOpMaker,
ops::DropoutNdGradOpMaker<paddle::framework::OpDesc>,
ops::DropoutNdGradOpMaker<paddle::imperative::OpBase>,
DropoutNdInferShapeFunctor);
REGISTER_OPERATOR(dropout_nd_grad, ops::DropoutOpGrad);
......@@ -131,8 +131,7 @@ class FMHARef {
auto functor = phi::funcs::ScaleFunctor<T>(alpha);
std::vector<const framework::Tensor*> ins = {&q_tensor};
std::vector<framework::Tensor*> outs = {&q_tensor};
paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx_, ins,
&outs, functor);
phi::funcs::ElementwiseKernel<T>(dev_ctx_, ins, &outs, functor);
}
// q*k^t, batched_gemm
......@@ -186,13 +185,11 @@ class FMHARef {
if (dropout_param_.dropout_prob_) {
DropoutFwGPUKernelDriver<T>(
static_cast<const phi::GPUContext&>(dev_ctx_),
dropout_param_.is_test_,
static_cast<const std::string>(
dropout_param_.dropout_implementation_),
dropout_param_.dropout_prob_, dropout_param_.is_upscale_in_train_,
dropout_param_.is_fix_seed_, dropout_param_.seed_val_,
dropout_param_.is_test_, dropout_param_.dropout_prob_,
dropout_param_.is_upscale_in_train_, dropout_param_.is_fix_seed_,
dropout_param_.seed_val_,
static_cast<const Tensor&>(*softmax_out_tensor), dropout_param_.seed_,
dropout_mask_out_tensor, dropout_out_tensor);
dropout_mask_out_tensor, dropout_out_tensor, false);
blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha,
dropout_out_data, v_ptr, beta, qktv_out_data,
gemm_batch_size, stride_a, stride_b);
......@@ -288,13 +285,10 @@ class FMHARef {
// dropout bw
if (dropout_param_.dropout_prob_) {
DropoutGradGPUKernelDriver<T>(
static_cast<const phi::GPUContext&>(dev_ctx_),
static_cast<const std::string>(
dropout_param_.dropout_implementation_),
dropout_param_.dropout_prob_,
static_cast<const phi::GPUContext&>(dev_ctx_), false,
dropout_param_.dropout_prob_, dropout_param_.is_upscale_in_train_,
static_cast<const Tensor&>(*dropout_out_grad_tensor),
dropout_mask_out_tensor, softmax_out_grad_tensor->numel(),
softmax_out_grad_tensor);
dropout_mask_out_tensor, softmax_out_grad_tensor, false);
}
if (src_mask_tensor != nullptr) {
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/hostdevice.h"
namespace paddle {
namespace platform {
// Aligned vector generates vectorized load/store on CUDA.
template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVector {
T val[Size];
HOSTDEVICE inline const T& operator[](int i) const { return val[i]; }
HOSTDEVICE inline T& operator[](int i) { return val[i]; }
};
template <typename T, int Size>
HOSTDEVICE inline void Load(const T* addr, AlignedVector<T, Size>* vec) {
const AlignedVector<T, Size>* addr_vec =
reinterpret_cast<const AlignedVector<T, Size>*>(addr);
*vec = *addr_vec;
}
template <typename T, int Size>
HOSTDEVICE inline void Store(const AlignedVector<T, Size>& vec, T* addr) {
AlignedVector<T, Size>* addr_vec =
reinterpret_cast<AlignedVector<T, Size>*>(addr);
*addr_vec = vec;
}
/*
* Vectorized load is possible only when the address of the input data is a
* multiple of the vector width (1, 2 or 4 elements), and a single vectorized
* load moves at most 128 bits. The valid vectorization length is therefore
* determined by both constraints.
*/
template <typename T>
int GetVectorizedSize(const T* pointer) {
constexpr int max_load_bits = 128;
int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T);
uint64_t address = reinterpret_cast<uint64_t>(pointer);
constexpr int vec8 = std::alignment_of<AlignedVector<T, 8>>::value; // NOLINT
constexpr int vec4 = std::alignment_of<AlignedVector<T, 4>>::value; // NOLINT
constexpr int vec2 = std::alignment_of<AlignedVector<T, 2>>::value; // NOLINT
if (address % vec8 == 0) {
/*
* Currently we vectorize at most 4 elements at a time. If a performance test
* shows that handling 8 elements per vectorized load/store is faster, the
* return below can be changed to "return std::min(8, valid_vec_size);".
*/
return std::min(4, valid_vec_size);
} else if (address % vec4 == 0) {
return std::min(4, valid_vec_size);
} else if (address % vec2 == 0) {
return std::min(2, valid_vec_size);
} else {
return 1;
}
}
} // namespace platform
} // namespace paddle
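To make the alignment rule above concrete, here is a rough Python sketch of the same decision: pick the widest vector whose byte size divides the data pointer's address, capped at 4 elements and 128 bits per load (illustrative only; the real check uses the alignment of AlignedVector<T, Size>):

def get_vectorized_size(address, elem_bytes, max_load_bits=128):
    # Widest vector that still fits in a single 128-bit load.
    valid_vec_size = max_load_bits // 8 // elem_bytes
    if address % (elem_bytes * 8) == 0:
        return min(4, valid_vec_size)  # 8-wide is possible but capped at 4 for now
    elif address % (elem_bytes * 4) == 0:
        return min(4, valid_vec_size)
    elif address % (elem_bytes * 2) == 0:
        return min(2, valid_vec_size)
    return 1

# A float pointer (4 bytes) aligned to 16 bytes allows 4-element vectorized loads.
assert get_vectorized_size(0x7f0000000010, 4) == 4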
......@@ -886,6 +886,58 @@ void DropoutInferMeta(const MetaTensor& x,
}
}
void DropoutNdInferMeta(const MetaTensor& x,
const MetaTensor& seed_tensor,
float p,
bool is_test,
const std::string& mode,
int seed,
bool fix_seed,
const std::vector<int>& axis,
MetaTensor* out,
MetaTensor* mask) {
auto x_dims = x.dims();
PADDLE_ENFORCE_LE(
axis.size(),
x_dims.size(),
phi::errors::InvalidArgument(
"The length of axis is expected to be less than or equal to the "
"dimension size of x. But recieved the length of axis is %d, the "
"dimension size of x is %d, x's shape is {%s}.",
axis.size(),
x_dims.size(),
x_dims));
for (size_t i = 0; i < axis.size(); ++i) {
PADDLE_ENFORCE_EQ(
axis[i] >= 0 && axis[i] <= x_dims.size() - 1,
true,
phi::errors::InvalidArgument(
"The %d-th value of axis is expected to be greater ot "
"equal to 0 and less than the dimensions of x. But "
"recieved axis is {%s}, the dimension size of x is %d.",
i,
phi::make_ddim(axis),
x_dims.size()));
}
out->set_dims(x_dims);
out->share_lod(x);
out->set_dtype(x.dtype());
if (mask != nullptr) {
std::vector<int64_t> mask_dims(x.dims().size(), 1);
std::for_each(
axis.begin(), axis.end(), [&mask_dims, &x_dims](const int64_t& t) {
mask_dims[t] = x_dims[t];
});
mask->set_dims(make_ddim(mask_dims));
mask->set_dtype(DataType::UINT8);
}
}
void DotInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
auto x_dims = x.dims();
auto x_rank = static_cast<size_t>(x_dims.size());
......
......@@ -145,6 +145,17 @@ void DropoutInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaTensor* mask);
void DropoutNdInferMeta(const MetaTensor& x,
const MetaTensor& seed_tensor,
float p,
bool is_test,
const std::string& mode,
int seed,
bool fix_seed,
const std::vector<int>& axis,
MetaTensor* out,
MetaTensor* mask);
void ElementwiseInferMeta(const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out);
......
......@@ -21,16 +21,17 @@
namespace phi {
template <typename T, typename Context>
void DropoutGradRawKernel(const Context& dev_ctx,
const DenseTensor& mask,
const DenseTensor& out_grad,
float p,
bool is_test,
const std::string& mode,
DenseTensor* x_grad) {
void DropoutNdGradKernel(const Context& dev_ctx,
const DenseTensor& mask,
const DenseTensor& out_grad,
float p,
bool is_test,
const std::string& mode,
const std::vector<int>& axis,
DenseTensor* x_grad) {
auto* grad_x = x_grad;
auto* grad_y = &out_grad;
grad_x->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(grad_x);
auto dX = EigenVector<T>::Flatten(*grad_x);
auto dY = EigenVector<T>::Flatten(*grad_y);
......@@ -44,19 +45,41 @@ void DropoutGradRawKernel(const Context& dev_ctx,
dX.device(place) = dY * static_cast<T>(1.0f - p);
}
} else {
std::vector<int64_t> out_dims = phi::vectorize(out_grad.dims());
auto M = EigenVector<uint8_t>::Flatten(mask);
if (dropout_implementation == "upscale_in_train") {
if (p == 1.0f) {
dX.device(place) = static_cast<T>(0) * dY;
} else {
dX.device(place) = dY * M.cast<T>() / static_cast<T>(1.0f - p);
if (axis.empty()) {
dX.device(place) = dY * M.cast<T>() / static_cast<T>(1.0f - p);
} else {
dX.device(place) =
dY * M.broadcast(out_dims).cast<T>() / static_cast<T>(1.0f - p);
}
}
} else {
dX.device(place) = dY * M.cast<T>();
if (axis.empty()) {
dX.device(place) = dY * M.cast<T>();
} else {
dX.device(place) = dY * M.broadcast(out_dims).cast<T>();
}
}
}
}
template <typename T, typename Context>
void DropoutGradRawKernel(const Context& dev_ctx,
const DenseTensor& mask,
const DenseTensor& out_grad,
float p,
bool is_test,
const std::string& mode,
DenseTensor* x_grad) {
DropoutNdGradKernel<T, Context>(
dev_ctx, mask, out_grad, p, is_test, mode, {}, x_grad);
}
} // namespace phi
PD_REGISTER_KERNEL(dropout_grad,
......@@ -66,3 +89,7 @@ PD_REGISTER_KERNEL(dropout_grad,
float,
double,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(
dropout_nd_grad, CPU, ALL_LAYOUT, phi::DropoutNdGradKernel, float, double) {
}
......@@ -17,10 +17,34 @@
#include "paddle/fluid/framework/generator.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/expand_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
template <typename T, typename Context>
void ComputeDropoutInference(const Context& ctx,
const DenseTensor& x,
float dropout_prob,
bool upscale_in_train,
DenseTensor* y) {
if (upscale_in_train) {
const auto* X_data = x.data<T>();
T* Y_data = ctx.template Alloc<T>(y);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int i = 0; i < x.numel(); i++) {
Y_data[i] = X_data[i];
}
} else {
auto X = EigenMatrix<T>::Reshape(x, 1);
auto Y = EigenMatrix<T>::Reshape(*y, 1);
auto& place = *ctx.eigen_device();
Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
}
}
template <typename T, typename Context>
void DropoutRawKernel(const Context& dev_ctx,
const DenseTensor& x,
......@@ -34,13 +58,13 @@ void DropoutRawKernel(const Context& dev_ctx,
DenseTensor* mask) {
auto* y = out;
const auto* x_data = x.data<T>();
auto* y_data = y->mutable_data<T>(dev_ctx.GetPlace());
T* y_data = dev_ctx.template Alloc<T>(y);
float dropout_prob = p;
auto& dropout_implementation = mode;
bool upscale_in_train = (dropout_implementation == "upscale_in_train");
if (!is_test) {
auto* mask_data = mask->mutable_data<uint8_t>(dev_ctx.GetPlace());
auto* mask_data = dev_ctx.template Alloc<uint8_t>(mask);
size_t size = phi::product(mask->dims());
// Special case when dropout_prob is 1.0
......@@ -76,21 +100,92 @@ void DropoutRawKernel(const Context& dev_ctx,
}
}
} else {
if (upscale_in_train) {
const auto* X_data = x.data<T>();
auto* Y_data = y->mutable_data<T>(dev_ctx.GetPlace());
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int i = 0; i < x.numel(); i++) {
Y_data[i] = X_data[i];
}
ComputeDropoutInference<T, Context>(
dev_ctx, x, dropout_prob, upscale_in_train, y);
}
}
template <typename T, typename Context>
void DropoutNdKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& seed_tensor,
float p,
bool is_test,
const std::string& mode,
int seed,
bool fix_seed,
const std::vector<int>& axis,
DenseTensor* out,
DenseTensor* mask) {
auto* y = out;
const auto* x_data = x.data<T>();
T* y_data = dev_ctx.template Alloc<T>(y);
float dropout_prob = p;
auto& dropout_implementation = mode;
bool upscale_in_train = (dropout_implementation == "upscale_in_train");
if (!is_test) {
DenseTensor t_mask;
t_mask.Resize(mask->dims());
T* t_mask_data = dev_ctx.template Alloc<T>(&t_mask);
auto* mask_data = dev_ctx.template Alloc<uint8_t>(mask);
size_t size = phi::product(mask->dims());
// Special case when dropout_prob is 1.0
if (dropout_prob == 1.0f) {
std::memset(y_data, 0, size * sizeof(*y_data)); // NOLINT
std::memset(t_mask_data, 0, size * sizeof(*t_mask_data)); // NOLINT
std::memset(mask_data, 0, size * sizeof(*mask_data)); // NOLINT
return;
}
// std::minstd_rand engine;
// NOTE: fixed seed should only be used in unittest or for debug.
// Guarantee to use random seed in training.
int seed_data = 0;
if (seed_tensor.get_ptr() != nullptr) {
seed_data = *(seed_tensor->data<int>());
} else {
auto X = EigenMatrix<T>::Reshape(x, 1);
auto Y = EigenMatrix<T>::Reshape(*y, 1);
auto& place = *dev_ctx.eigen_device();
Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
seed_data = fix_seed ? seed : 0;
}
auto engine = paddle::framework::GetCPURandomEngine(seed_data);
std::uniform_real_distribution<float> dist(0, 1);
for (size_t i = 0; i < size; ++i) {
if (dist(*engine) < dropout_prob) {
t_mask_data[i] = 0;
mask_data[i] = 0;
} else {
t_mask_data[i] = 1;
mask_data[i] = 1;
}
}
auto& x_dims = x.dims();
DenseTensor broadcast_mask;
broadcast_mask.Resize(x_dims);
T* broadcast_mask_data = dev_ctx.template Alloc<T>(&broadcast_mask);
std::vector<int64_t> mask_bst_dims_vec;
for (int i = 0; i < x_dims.size(); i++) {
mask_bst_dims_vec.emplace_back(x_dims[i]);
}
IntArray mask_bst_dims(mask_bst_dims_vec);
ExpandKernel<T, Context>(dev_ctx, t_mask, mask_bst_dims, &broadcast_mask);
for (auto i = 0; i < x.numel(); i++) {
if (broadcast_mask_data[i] == static_cast<T>(1)) {
if (upscale_in_train) {
y_data[i] = x_data[i] / static_cast<T>(1.0f - dropout_prob);
} else {
y_data[i] = x_data[i];
}
} else {
y_data[i] = 0;
}
}
} else {
ComputeDropoutInference<T, Context>(
dev_ctx, x, dropout_prob, upscale_in_train, y);
}
}
......@@ -103,3 +198,6 @@ PD_REGISTER_KERNEL(dropout,
float,
double,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(
dropout_nd, CPU, ALL_LAYOUT, phi::DropoutNdKernel, float, double) {}
......@@ -28,4 +28,14 @@ void DropoutGradRawKernel(const Context& dev_ctx,
const std::string& mode,
DenseTensor* x_grad);
template <typename T, typename Context>
void DropoutNdGradKernel(const Context& dev_ctx,
const DenseTensor& mask,
const DenseTensor& out_grad,
float p,
bool is_test,
const std::string& mode,
const std::vector<int>& axis,
DenseTensor* x_grad);
} // namespace phi
......@@ -14,6 +14,7 @@
#pragma once
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
......@@ -31,4 +32,17 @@ void DropoutRawKernel(const Context& dev_ctx,
DenseTensor* out,
DenseTensor* mask);
template <typename T, typename Context>
void DropoutNdKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& seed_tensor,
float p,
bool is_test,
const std::string& mode,
int seed,
bool fix_seed,
const std::vector<int>& axis,
DenseTensor* out,
DenseTensor* mask);
} // namespace phi
......@@ -27,10 +27,25 @@ void DropoutGradRawKernel(const Context& dev_ctx,
bool is_test,
const std::string& mode,
DenseTensor* x_grad) {
bool upscale_in_train = (mode == "upscale_in_train");
x_grad->mutable_data<T>(dev_ctx.GetPlace());
auto size = x_grad->numel();
paddle::operators::DropoutGradGPUKernelDriver<T>(
dev_ctx, mode, p, out_grad, mask, size, x_grad, is_test);
dev_ctx, is_test, p, upscale_in_train, out_grad, mask, x_grad, false);
}
template <typename T, typename Context>
void DropoutNdGradKernel(const Context& dev_ctx,
const DenseTensor& mask,
const DenseTensor& out_grad,
float p,
bool is_test,
const std::string& mode,
const std::vector<int>& axis,
DenseTensor* x_grad) {
bool upscale_in_train = (mode == "upscale_in_train");
dev_ctx.template Alloc<T>(x_grad);
paddle::operators::DropoutGradGPUKernelDriver<T>(
dev_ctx, is_test, p, upscale_in_train, out_grad, mask, x_grad, true);
}
} // namespace phi
......@@ -43,3 +58,12 @@ PD_REGISTER_KERNEL(dropout_grad,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(dropout_nd_grad,
GPU,
ALL_LAYOUT,
phi::DropoutNdGradKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {}
......@@ -30,22 +30,48 @@ void DropoutRawKernel(const Context& dev_ctx,
bool fix_seed,
DenseTensor* out,
DenseTensor* mask) {
out->mutable_data<T>(dev_ctx.GetPlace());
float dropout_prob = p;
bool upscale_in_train = (mode == "upscale_in_train");
out->mutable_data<T>(dev_ctx.GetPlace());
mask->mutable_data<uint8_t>(dev_ctx.GetPlace());
paddle::operators::DropoutFwGPUKernelDriver<T>(dev_ctx,
is_test,
p,
upscale_in_train,
fix_seed,
seed,
x,
seed_tensor.get_ptr(),
mask,
out,
false);
}
template <typename T, typename Context>
void DropoutNdKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& seed_tensor,
float p,
bool is_test,
const std::string& mode,
int seed,
bool fix_seed,
const std::vector<int>& axis,
DenseTensor* out,
DenseTensor* mask) {
bool upscale_in_train = (mode == "upscale_in_train");
dev_ctx.template Alloc<T>(out);
dev_ctx.template Alloc<uint8_t>(mask);
paddle::operators::DropoutFwGPUKernelDriver<T>(dev_ctx,
is_test,
mode,
dropout_prob,
p,
upscale_in_train,
fix_seed,
seed,
x,
seed_tensor.get_ptr(),
mask,
out);
out,
true);
}
} // namespace phi
......@@ -58,3 +84,12 @@ PD_REGISTER_KERNEL(dropout,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(dropout_nd,
GPU,
ALL_LAYOUT,
phi::DropoutNdKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {}
......@@ -32,7 +32,31 @@ KernelSignature DropoutGradOpArgumentMapping(
{"X@GRAD"});
}
KernelSignature DropoutNdOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("dropout_nd",
{"X", "Seed"},
{"dropout_prob",
"is_test",
"dropout_implementation",
"seed",
"fix_seed",
"axis"},
{"Out", "Mask"});
}
KernelSignature DropoutNdGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"dropout_nd_grad",
{"Mask", "Out@GRAD"},
{"dropout_prob", "is_test", "dropout_implementation", "axis"},
{"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(dropout, phi::DropoutOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(dropout_grad, phi::DropoutGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(dropout_nd, phi::DropoutNdOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(dropout_nd_grad,
phi::DropoutNdGradOpArgumentMapping);
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest, convert_float_to_uint16
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import _non_static_mode
from paddle import _C_ops
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.static import default_main_program
def dropout_nd(x,
p=0.5,
axis=None,
training=True,
mode="upscale_in_train",
name=None):
drop_axes = [axis] if isinstance(axis, int) else list(axis)
seed = None
mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode  # semantic transfer
if _non_static_mode():
if default_main_program().random_seed != 0:
seed = default_main_program().random_seed
out, mask = _C_ops.dropout_nd(x, 'dropout_prob', p, 'is_test',
not training, 'fix_seed', seed
is not None, 'seed',
seed if seed is not None else 0,
'dropout_implementation', mode, 'axis',
drop_axes)
return out
helper = LayerHelper('dropout_nd', **locals())
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'dropout')
out = helper.create_variable_for_type_inference(dtype=x.dtype)
mask = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)
def get_attrs(prog, dropout_prob, is_test, seed):
if (seed is None or seed == 0) and prog.random_seed != 0:
seed = prog.random_seed
attrs = {
'dropout_prob': dropout_prob,
'is_test': is_test,
'fix_seed': seed is not None,
'seed': seed if seed is not None else 0,
'dropout_implementation': mode,
'axis': drop_axes
}
return attrs
attrs = get_attrs(helper.main_program, p, not training, seed)
helper.append_op(type='dropout_nd',
inputs={'X': [x]},
outputs={
'Out': [out],
'Mask': [mask]
},
attrs=attrs)
return out
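# A minimal dygraph usage sketch of the dropout_nd helper above (illustrative
# only; the mask pattern depends on the random seed):
def _dropout_nd_usage_sketch():
    paddle.disable_static()
    x = paddle.rand([4, 32, 16])
    # One keep/drop decision per position along axis 1, shared over dims 0 and 2.
    y = dropout_nd(x, p=0.5, axis=[1])
    paddle.enable_static()
    return y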
paddle.enable_static()
class TestDropoutNdOp(OpTest):
def setUp(self):
self.op_type = "dropout_nd"
self.inputs = {'X': np.random.random((4, 32, 16)).astype("float64")}
self.attrs = {
'dropout_prob': 0.0,
'fix_seed': True,
'is_test': False,
'axis': [1]
}
self.outputs = {
'Out': self.inputs['X'],
'Mask': np.ones((1, 32, 1)).astype('uint8')
}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X'], 'Out')
class TestDropoutNdAPI(unittest.TestCase):
def setUp(self):
np.random.seed(123)
self.places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
self.places.append(fluid.CUDAPlace(0))
def test_dygraph(self):
paddle.disable_static()
for place in self.places:
with fluid.dygraph.guard(place):
in_np = np.random.random([4, 32, 16]).astype("float32")
input = paddle.to_tensor(in_np)
res1 = dropout_nd(x=input, p=0., axis=[0, 1])
res2 = dropout_nd(x=input, p=0.5, axis=[0, 1])
self.assertTrue(np.allclose(res1.numpy(), in_np))
paddle.enable_static()
if __name__ == '__main__':
unittest.main()