未验证 提交 092d45c3 编写于 作者: L Li Min 提交者: GitHub

Add fused_dropout wrapper to ease use. (#36185)

上级 5e1d0b5c
develop Ligoml-patch-1 ZHUI-patch-1 add_some_yaml_config dingjiaweiww-patch-1 dy2static enable_eager_model_test final_state_gen_python_c final_state_intermediate fix-numpy-issue fix_concat_slice fix_op_flops fix_rnn_docs fix_tensor_type incubate/infrt inplace_addto move_embedding_to_phi move_histogram_to_pten move_sgd_to_phi move_slice_to_pten move_temporal_shift_to_phi move_yolo_box_to_phi npu_fix_alloc preln_ernie prv-md-even-more prv-onednn-2.5 pten_tensor_refactor release/2.3 release/2.3-fc-ernie-fix release/2.4 revert-36201-refine_fast_threaded_ssa_graph_executor revert-36985-add_license revert-37318-refactor_dygraph_to_eager revert-37926-eager_coreops_500 revert-37956-revert-37727-pylayer_support_tuple revert-38100-mingdong revert-38301-allocation_rearrange_pr revert-38703-numpy_bf16_package_reupload revert-38732-remove_useless_header_in_elementwise_mul_grad revert-38959-Reduce_Grad revert-39143-adjust_empty revert-39227-move_trace_op_to_pten revert-39268-dev/remove_concat_fluid_kernel revert-40170-support_partial_grad revert-41056-revert-40727-move_some_activaion_to_phi revert-41065-revert-40993-mv_ele_floordiv_pow revert-41068-revert-40790-phi_new revert-41944-smaller_inference_api_test revert-42149-do-not-reset-default-stream-for-stream-safe-cuda-allocator revert-43155-fix_ut_tempfile revert-43882-revert-41944-smaller_inference_api_test revert-45808-phi/simplify_size_op revert-46827-deform_comment support_weight_transpose zhiqiu-patch-1 v2.4.0-rc0 v2.3.2 v2.3.1 v2.3.0 v2.3.0-rc0
无相关合并请求
......@@ -30,6 +30,7 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/dropout_impl_util.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
......@@ -196,31 +197,9 @@ void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
config.thread_per_block.x * vec_size) +
1) *
vec_size;
int device_id =
BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()).GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if ((seed) && platform::is_gpu_place(seed->place())) {
framework::Tensor seed_cpu_tensor;
TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor);
seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
increment = offset;
} else if (seed && platform::is_cpu_place(seed->place())) {
seed_data = *(seed->data<int>());
increment = offset;
} else if (gen_cuda->GetIsInitPy() && (!is_fix_seed)) {
auto seed_offset = gen_cuda->IncrementOffset(offset);
seed_data = seed_offset.first;
increment = seed_offset.second;
} else {
if (seed) {
seed_data = *(seed->data<int>());
} else {
std::random_device rnd;
seed_data = is_fix_seed ? seed_val : rnd();
}
increment = offset;
}
GetSeedDataAndIncrement(dev_ctx, seed, is_fix_seed, seed_val, offset,
&seed_data, &increment);
#ifdef __HIPCC__
if (vec_size == 4 && size % 4 == 0) {
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/tensor_util.h"
namespace paddle {
namespace operators {
inline void GetSeedDataAndIncrement(const platform::CUDADeviceContext& dev_ctx,
const framework::Tensor* seed,
const bool is_fix_seed, const int seed_val,
const int offset, uint64_t* seed_data,
uint64_t* increment) {
int device_id =
BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()).GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if ((seed) && platform::is_gpu_place(seed->place())) {
framework::Tensor seed_cpu_tensor;
TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor);
*seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
*increment = offset;
} else if (gen_cuda->GetIsInitPy() && (!is_fix_seed)) {
auto seed_offset = gen_cuda->IncrementOffset(offset);
*seed_data = seed_offset.first;
*increment = seed_offset.second;
} else {
if (seed) {
*seed_data = *(seed->data<int>());
} else {
std::random_device rnd;
*seed_data = is_fix_seed ? seed_val : rnd();
}
*increment = offset;
}
}
} // namespace operators
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/dropout_impl_util.h"
#include "paddle/fluid/operators/fused/fused_dropout_act_bias.h"
#include "paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h"
#include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h"
#include "paddle/fluid/operators/math/functors.h"
namespace paddle {
namespace operators {
/**
* Support two Dropouts in the use senarieo.
* This warpper can be used in FFN op.
* The DropoutParam will be used in the fused_dropout_act_bias,
* fused_residual_dropout_bias(pre_layer_norm=ture) or
* fused_layernorm_residual_dropout_bias(pre_layer_norm=false).
*/
struct DropoutParam {
uint64_t seed;
float dropout_prob;
bool is_upscale_in_train;
bool is_test;
bool fix_seed;
int increment;
const framework::Tensor* tensor_seed;
int seed_val;
DropoutParam() {
fix_seed = false;
seed = 0;
is_test = false;
is_upscale_in_train = false;
dropout_prob = 0.5;
tensor_seed = nullptr;
seed_val = 0;
}
/**
* dropout_index: can be 0, 1, 2. 0 means there is only one dropout,
* 1 and 2 represent two dropout, the parameter name of dropout
* will be "dropout" + dropout_index + param name, such as dropout1_seed,
* dropout1_is_test.
*/
DropoutParam(const framework::ExecutionContext& context,
const int dropout_index) {
std::string pre_fix = "dropout";
std::string str_index = std::to_string(dropout_index);
if (dropout_index > 0) {
pre_fix = pre_fix + str_index + "_";
} else {
pre_fix = pre_fix + "_";
}
dropout_prob = context.Attr<float>(pre_fix + "prob");
auto& dropout_implementation =
context.Attr<std::string>(pre_fix + "implementation");
is_upscale_in_train = (dropout_implementation == "upscale_in_train");
is_test = context.Attr<bool>(pre_fix + "is_test");
fix_seed = context.Attr<bool>(pre_fix + "fix_seed");
std::string str_seed = "Dropout";
if (dropout_index > 0) {
str_seed = str_seed + str_index + "Seed";
} else {
str_seed = str_seed + "Seed";
}
tensor_seed =
context.HasInput(str_seed) ? context.Input<Tensor>(str_seed) : nullptr;
seed_val = context.Attr<int>(pre_fix + "seed");
}
int UpdateSeedAndIncrement(const platform::CUDADeviceContext& ctx,
const int offset) {
uint64_t tmp_increment;
GetSeedDataAndIncrement(ctx, tensor_seed, fix_seed, seed_val, offset, &seed,
&tmp_increment);
increment = static_cast<int>(tmp_increment);
return increment;
}
};
template <typename T, typename MaskType>
class FusedDropoutHelper {
private:
int GetIncrement(const platform::CUDADeviceContext& ctx) {
const int VecSize = MAX_CACHE_BYTES / sizeof(T);
const int real_vec_size = cols_ % VecSize == 0 ? VecSize : 1;
auto config =
Get1DBlocksAnd2DGrids(ctx, static_cast<uint64_t>(rows_),
static_cast<uint64_t>(cols_), real_vec_size);
int increment = ((cols_ - 1) / (config.thread_per_block.x *
config.block_per_grid.x * real_vec_size) +
1) *
real_vec_size;
increment = dropout_param_.UpdateSeedAndIncrement(ctx, increment);
return increment;
}
public:
FusedDropoutHelper() {}
FusedDropoutHelper(const platform::CUDADeviceContext& ctx, const int rows,
const int cols, const DropoutParam& dropout_param) {
rows_ = rows;
cols_ = cols;
dropout_param_ = dropout_param;
}
// out = residual + dropout( src + bias )
void ResidualDropoutBias(const platform::CUDADeviceContext& ctx, const T* src,
const T* residual, const T* bias, T* out,
MaskType* mask) {
auto increment = GetIncrement(ctx);
LaunchResidualDropoutBias<T, MaskType>(
rows_, cols_, increment, dropout_param_.seed,
dropout_param_.dropout_prob, dropout_param_.is_test,
dropout_param_.is_upscale_in_train, src, residual, bias, mask, out,
ctx);
}
void ResidualDropoutBiasGrad(const platform::CUDADeviceContext& ctx,
const T* d_out, const MaskType* mask, T* d_src,
T* d_residual, T* d_bias) {
LaunchResidualDropoutBiasGrad<T, uint8_t>(
d_out, mask, dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train, rows_, cols_, d_src, d_bias, ctx);
auto cuda_place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
memory::Copy(cuda_place, d_residual, cuda_place, d_out,
rows_ * cols_ * sizeof(T), ctx.stream());
}
// out = dropout(activation(src + bias))
void DropoutActBias(const platform::CUDADeviceContext& ctx, const T* src,
const T* bias, const std::string& act_method, T* out,
MaskType* mask) {
auto increment = GetIncrement(ctx);
if (act_method == "gelu") {
GeluFunctor<T> gelu;
LaunchDropoutActBias<T, MaskType, GeluFunctor<T>>(
gelu, dropout_param_.seed, rows_, cols_, dropout_param_.increment,
dropout_param_.dropout_prob, dropout_param_.is_upscale_in_train,
dropout_param_.is_test, src, bias, out, mask, ctx);
} else if (act_method == "relu") {
math::ReluFunctor<T> relu;
LaunchDropoutActBias<T, MaskType, math::ReluFunctor<T>>(
relu, dropout_param_.seed, rows_, cols_, increment,
dropout_param_.dropout_prob, dropout_param_.is_upscale_in_train,
dropout_param_.is_test, src, bias, out, mask, ctx);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Currently only supports gelu or relu activation functions!"));
}
}
void DropoutActBiasGrad(const platform::CUDADeviceContext& ctx, const T* dout,
const T* src, const T* bias, const MaskType* mask,
T* d_src, T* d_bias, const std::string& act_method) {
if (act_method == "gelu") {
GeluGradFunctor<T> gelu_grad;
LaunchDropoutActBiasGrad<T, MaskType, GeluGradFunctor<T>>(
gelu_grad, dout, mask, src, bias, dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train, rows_, cols_, d_src, d_bias, ctx);
} else if (act_method == "relu") {
math::ReluGradFunctor<T> relu_grad;
LaunchDropoutActBiasGrad<T, MaskType, math::ReluGradFunctor<T>>(
relu_grad, dout, mask, src, bias, dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train, rows_, cols_, d_src, d_bias, ctx);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Currently only supports gelu or relu activation functions!"));
}
}
protected:
int rows_;
int cols_;
DropoutParam dropout_param_;
};
template <typename T, typename MaskType>
class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, MaskType> {
public:
FusedDropoutLayerNormHelper() {}
FusedDropoutLayerNormHelper(const int rows, const int cols,
const float epsilon) {
using U = LayerNormParamType<T>;
this->rows_ = rows;
this->cols_ = cols;
epsilon_ = epsilon;
}
FusedDropoutLayerNormHelper(const platform::CUDADeviceContext& ctx,
const int rows, const int cols,
const DropoutParam& dropout_param,
const float epsilon)
: FusedDropoutHelper<T, MaskType>(ctx, rows, cols, dropout_param) {
using U = LayerNormParamType<T>;
epsilon_ = epsilon;
}
// call layer_norm
void LayerNorm(const platform::CUDADeviceContext& ctx, const T* src,
const LayerNormParamType<T>* gamma,
const LayerNormParamType<T>* beta, T* out,
LayerNormParamType<T>* mean, LayerNormParamType<T>* variance) {
using U = LayerNormParamType<T>;
switch (GetDesiredBlockDim(this->cols_)) {
FIXED_BLOCK_DIM_CASE(
LayerNormForward<
T, U, kBlockDim><<<this->rows_, kBlockDim, 0, ctx.stream()>>>(
src, gamma, beta, out, mean, variance, epsilon_, this->cols_));
}
}
void LayerNormGrad(const platform::CUDADeviceContext& ctx, const T* dout,
const T* src, const LayerNormParamType<T>* gamma,
const LayerNormParamType<T>* mean,
const LayerNormParamType<T>* variance, T* d_src,
LayerNormParamType<T>* d_scale,
LayerNormParamType<T>* d_bias) {
using U = LayerNormParamType<T>;
LayerNormBackward<T, U>(src, dout, gamma, mean, variance, d_src, d_scale,
d_bias, epsilon_, this->rows_, this->cols_, ctx);
}
// out = layernorm(residual + dropout(src + bias))
void LayernormResidualDropoutBias(
const platform::CUDADeviceContext& ctx, const T* src, const T* residual,
const T* bias, const LayerNormParamType<T>* gamma,
const LayerNormParamType<T>* beta, T* dropout_out, MaskType* mask, T* out,
LayerNormParamType<T>* mean, LayerNormParamType<T>* variance) {
using U = LayerNormParamType<T>;
int vec_size = MAX_CACHE_BYTES / sizeof(T);
if (this->cols_ % vec_size != 0) {
vec_size = 1;
}
int threads = GetDesiredBlockDim(this->cols_ / vec_size);
int increment = ((this->cols_ - 1) / (threads * vec_size) + 1) * vec_size;
increment = this->dropout_param_.UpdateSeedAndIncrement(ctx, increment);
LaunchLayernormResidualDropoutBias<T, MaskType>(
this->rows_, this->cols_, increment, this->dropout_param_.seed,
this->dropout_param_.dropout_prob, epsilon_,
this->dropout_param_.is_upscale_in_train, this->dropout_param_.is_test,
src, residual, bias, gamma, beta, mask, dropout_out, out, mean,
variance, ctx);
}
void LayernormResidualDropoutBiasGrad(
const platform::CUDADeviceContext& ctx, const T* d_out,
const T* layernorm_src, const MaskType* mask,
const LayerNormParamType<T>* gamma, const LayerNormParamType<T>* mean,
const LayerNormParamType<T>* variance, T* d_layernorm_src,
LayerNormParamType<T>* d_scale, LayerNormParamType<T>* d_layernorm_bias,
T* d_dropout_src, T* d_bias, T* d_residual) {
using U = LayerNormParamType<T>;
LayerNormBackward<T, U>(layernorm_src, d_out, gamma, mean, variance,
d_layernorm_src, d_scale, d_layernorm_bias,
epsilon_, this->rows_, this->cols_, ctx);
this->ResidualDropoutBiasGrad(ctx, d_layernorm_src, mask, d_dropout_src,
d_residual, d_bias);
}
protected:
float epsilon_;
};
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
反馈
建议
客服 返回
顶部