未验证 提交 c5a6ae4c 编写于 作者: S Shang Zhizhou 提交者: GitHub

1, remove layernorm dynamic fp16; 2, let reshape out in dynamic shape (#33535)

* 1, remove layernorm dynamic fp16; 2, let reshape out in dynamic shape

* remove useless code
上级 1f8de080
......@@ -694,7 +694,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
return false;
// Paddle-TRT does not support the input tensors: Shape and ShapeTensor
} else if (desc.Input("Shape").size() >= 1 ||
desc.Input("ShapeTensor").size() >= 1 || with_dynamic_shape) {
desc.Input("ShapeTensor").size() >= 1) {
return false;
} else {
std::vector<int> shape =
......
......@@ -182,69 +182,9 @@ int LayerNormPluginDynamic::enqueue(
paddle::operators::LayerNormDirectCUDAFunctor<float> layer_norm;
layer_norm(stream, input, input_shape, bias_d, scale_d, output, mean_d,
variance_d, begin_norm_axis, eps);
} else if (input_type == nvinfer1::DataType::kHALF) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
VLOG(1) << "TRT Plugin DataType selected. LayerNorm-->fp16";
const half *input = reinterpret_cast<const half *>(inputs[0]);
half *output = static_cast<half *>(outputs[0]);
size_t mean_shape_product = 1;
for (auto s : mean_shape_) {
mean_shape_product *= s;
}
size_t variance_shape_product = 1;
for (auto s : variance_shape_) {
variance_shape_product *= s;
}
if (!scale_gpu_half_d_) {
cudaMalloc(&scale_gpu_half_d_, feature_size * sizeof(half));
}
if (!bias_gpu_half_d_) {
cudaMalloc(&bias_gpu_half_d_, feature_size * sizeof(half));
}
if (!mean_gpu_half_d_) {
cudaMalloc(&mean_gpu_half_d_, mean_shape_product * sizeof(half));
}
if (!variance_gpu_half_d_) {
cudaMalloc(&variance_gpu_half_d_, variance_shape_product * sizeof(half));
}
half *scale_cpu_half =
static_cast<half *>(malloc(feature_size * sizeof(half)));
half *bias_cpu_half =
static_cast<half *>(malloc(feature_size * sizeof(half)));
PADDLE_ENFORCE_EQ(
scale_cpu_half && bias_cpu_half, true,
platform::errors::Unavailable("Out of memory, malloc size %d.",
feature_size * sizeof(half)));
for (int i = 0; i < feature_size; i++) {
scale_cpu_half[i] = static_cast<half>(scale_[i]);
bias_cpu_half[i] = static_cast<half>(bias_[i]);
}
cudaMemcpyAsync(scale_gpu_half_d_, scale_cpu_half,
sizeof(half) * feature_size, cudaMemcpyHostToDevice,
stream);
cudaMemcpyAsync(bias_gpu_half_d_, bias_cpu_half,
sizeof(half) * feature_size, cudaMemcpyHostToDevice,
stream);
free(scale_cpu_half);
free(bias_cpu_half);
paddle::operators::LayerNormDirectCUDAFunctor<half> layer_norm;
layer_norm(stream, input, input_shape, bias_gpu_half_d_, scale_gpu_half_d_,
output, mean_gpu_half_d_, variance_gpu_half_d_, begin_norm_axis,
eps);
#else
PADDLE_THROW(platform::errors::Fatal(
"The layer_norm tensorRT plugin should be "
"complied with CUDA version >= 10.0 when running with fp16. "
"Please recomplie it or try to use fp32 by set "
"config.SetTRTDynamicShapeInfo(min_input_shape, "
"max_input_shape, opt_input_shape, true"));
#endif
} else {
PADDLE_THROW(platform::errors::Fatal(
"The LayerNorm TRT Plugin's input type should be float or half."));
"The LayerNorm TRT Plugin's input type should be float."));
}
return cudaGetLastError() != cudaSuccess;
}
......
......@@ -114,22 +114,14 @@ class LayerNormPluginDynamic : public DynamicPluginTensorRT {
: begin_norm_axis_(begin_norm_axis),
eps_(eps),
mean_shape_(mean_shape),
variance_shape_(variance_shape),
scale_gpu_half_d_(nullptr),
bias_gpu_half_d_(nullptr),
mean_gpu_half_d_(nullptr),
variance_gpu_half_d_(nullptr) {
variance_shape_(variance_shape) {
bias_.resize(bias_num);
scale_.resize(scale_num);
std::copy(bias, bias + bias_num, bias_.data());
std::copy(scale, scale + scale_num, scale_.data());
}
LayerNormPluginDynamic(void const* serialData, size_t serialLength)
: scale_gpu_half_d_(nullptr),
bias_gpu_half_d_(nullptr),
mean_gpu_half_d_(nullptr),
variance_gpu_half_d_(nullptr) {
LayerNormPluginDynamic(void const* serialData, size_t serialLength) {
DeserializeValue(&serialData, &serialLength, &bias_);
DeserializeValue(&serialData, &serialLength, &scale_);
DeserializeValue(&serialData, &serialLength, &begin_norm_axis_);
......@@ -190,21 +182,6 @@ class LayerNormPluginDynamic : public DynamicPluginTensorRT {
const nvinfer1::DataType* inputTypes,
int nbInputs) const override;
~LayerNormPluginDynamic() {
if (scale_gpu_half_d_) {
cudaFree(scale_gpu_half_d_);
}
if (bias_gpu_half_d_) {
cudaFree(bias_gpu_half_d_);
}
if (mean_gpu_half_d_) {
cudaFree(mean_gpu_half_d_);
}
if (variance_gpu_half_d_) {
cudaFree(variance_gpu_half_d_);
}
}
void destroy() override { delete this; }
private:
......@@ -218,10 +195,6 @@ class LayerNormPluginDynamic : public DynamicPluginTensorRT {
float eps_;
std::vector<int64_t> mean_shape_;
std::vector<int64_t> variance_shape_;
half* scale_gpu_half_d_;
half* bias_gpu_half_d_;
half* mean_gpu_half_d_;
half* variance_gpu_half_d_;
};
class LayerNormPluginDynamicCreator : public nvinfer1::IPluginCreator {
......
......@@ -243,73 +243,6 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
}
}
template <typename T, typename U, int BlockDim>
__global__ void LayerNormForwardFP16(const T *x, const U *scale, const U *bias,
T *y, U *mean, U *var, float epsilon,
int feature_size) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ U mean_share;
__shared__ U var_share;
int beg_idx = blockIdx.x * feature_size + threadIdx.x;
int end_idx = (blockIdx.x + 1) * feature_size;
// Step 1: Reduce to calculate mean and var
U mean_val = 0;
U var_val = 0;
for (int i = beg_idx; i < end_idx; i += BlockDim) {
U tmp = static_cast<U>(x[i]);
mean_val += tmp;
var_val += (tmp * tmp);
}
auto pair = BlockReduce(temp_storage)
.Reduce(PairForLayerNorm<U>(mean_val, var_val),
PairForLayerNormAddFunctor<U>());
if (threadIdx.x == 0) {
auto tmp = pair.first_ / static_cast<U>(feature_size);
mean[blockIdx.x] = mean_share = static_cast<U>(tmp);
var[blockIdx.x] = var_share =
static_cast<U>(pair.second_ / static_cast<U>(feature_size) - tmp * tmp);
}
__syncthreads();
mean_val = mean_share;
U invvar = rsqrt_<U>(var_share + static_cast<U>(epsilon));
// Step 2: Calculate y
if (scale != nullptr) {
if (bias != nullptr) {
for (int i = beg_idx, j = threadIdx.x; i < end_idx;
i += BlockDim, j += BlockDim) {
y[i] = static_cast<T>(
scale[j] * (static_cast<U>(x[i]) - mean_val) * invvar + bias[j]);
}
} else {
for (int i = beg_idx, j = threadIdx.x; i < end_idx;
i += BlockDim, j += BlockDim) {
y[i] = static_cast<T>(scale[j] * (static_cast<U>(x[i]) - mean_val) *
invvar);
}
}
} else { // scale == nullptr
if (bias != nullptr) {
for (int i = beg_idx, j = threadIdx.x; i < end_idx;
i += BlockDim, j += BlockDim) {
y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar +
bias[j]);
}
} else {
for (int i = beg_idx, j = threadIdx.x; i < end_idx;
i += BlockDim, j += BlockDim) {
y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar);
}
}
}
#endif
}
template <typename T, typename U, int VPT>
__inline__ __device__ void cuLoadAddStridedInputs(
const int64_t i1_block, const int thr_load_row_off,
......@@ -965,28 +898,6 @@ void LayerNormDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
}
}
template <>
void LayerNormDirectCUDAFunctor<half>::operator()(
gpuStream_t stream, const half *input, std::vector<int> input_shape,
const half *bias, const half *scale, half *output, half *mean,
half *variance, int begin_norm_axis, float eps) {
const auto x_dims = framework::make_ddim(input_shape);
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int batch_size = static_cast<int>(matrix_dim[0]);
int feature_size = static_cast<int>(matrix_dim[1]);
switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE(
LayerNormForwardFP16<half, half,
kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
input, scale, bias, output, mean, variance, eps, feature_size));
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Product from begin_norm_axis to end in layer_norm must be larger "
"than 1"));
break;
}
}
template <typename T>
class LayerNormKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
......@@ -1076,9 +987,6 @@ class LayerNormGradKernel<platform::CUDADeviceContext, T>
};
template class LayerNormDirectCUDAFunctor<float>;
#ifdef TRT_PLUGIN_FP16_AVALIABLE
template class LayerNormDirectCUDAFunctor<half>;
#endif
#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE
#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册