Unverified commit 34fd65cf, authored by Wang Bojun, committed by GitHub

Group norm fp16 support (#48222)

* group norm fp16 support
Parent commit: 9a227ee7
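This patch lets the TensorRT group_norm plugin and the phi GroupNorm CUDA kernels accept half-precision tensors while keeping the per-group mean/variance accumulation in float (the new AccT template parameter). The standalone CUDA sketch below is not part of the patch and is not Paddle code; it is a minimal illustration, with hypothetical names (GroupNormHalfKernel, gamma, beta), of the numeric pattern the diff implements: fp16 input/output, fp32 statistics, NCHW layout, one block per (sample, group), channels assumed divisible by the group count.

    // Sketch only: GroupNorm forward with fp16 I/O and fp32 accumulation.
    #include <cuda_fp16.h>

    __global__ void GroupNormHalfKernel(const __half* x, const float* gamma,
                                        const float* beta, __half* y, int C,
                                        int HW, int groups, float eps) {
      const int n = blockIdx.x;          // sample index
      const int g = blockIdx.y;          // group index
      const int cpg = C / groups;        // channels per group
      const int group_elems = cpg * HW;  // elements reduced per (n, g)
      const __half* xg = x + (n * C + g * cpg) * HW;
      __half* yg = y + (n * C + g * cpg) * HW;

      // Accumulate sum and sum of squares in fp32 (the role of AccT in the PR).
      float sum = 0.f, sqsum = 0.f;
      for (int i = threadIdx.x; i < group_elems; i += blockDim.x) {
        float v = __half2float(xg[i]);
        sum += v;
        sqsum += v * v;
      }
      __shared__ float s_sum[256], s_sqsum[256];
      s_sum[threadIdx.x] = sum;
      s_sqsum[threadIdx.x] = sqsum;
      __syncthreads();
      // Tree reduction; assumes a power-of-two block size of at most 256.
      for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
        if (threadIdx.x < stride) {
          s_sum[threadIdx.x] += s_sum[threadIdx.x + stride];
          s_sqsum[threadIdx.x] += s_sqsum[threadIdx.x + stride];
        }
        __syncthreads();
      }
      const float mean = s_sum[0] / group_elems;
      const float var = s_sqsum[0] / group_elems - mean * mean;
      const float inv_std = rsqrtf(var + eps);

      // Normalize, apply per-channel scale/bias (kept in fp32 here), store fp16.
      for (int i = threadIdx.x; i < group_elems; i += blockDim.x) {
        int c = g * cpg + i / HW;  // absolute channel index
        float v = (__half2float(xg[i]) - mean) * inv_std;
        yg[i] = __float2half(v * gamma[c] + beta[c]);
      }
    }

A launch such as GroupNormHalfKernel<<<dim3(N, groups), 256>>>(x, gamma, beta, y, C, H * W, groups, 1e-5f) follows the same idea; the real plugin below reaches the same result through phi::GroupNormDirectCUDAFunctor<half, float>, with its mean/variance buffers carved out of the TensorRT workspace.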
@@ -34,7 +34,7 @@ class GroupNormOpConverter : public OpConverter {
   void operator()(const framework::proto::OpDesc& op,
                   const framework::Scope& scope,
                   bool test_mode) override {
-    VLOG(3) << "convert a fluid group_norm op";
+    VLOG(4) << "convert a fluid group_norm op to tensorrt group_norm plugin";
     framework::OpDesc op_desc(op, nullptr);
@@ -61,6 +61,8 @@ class GroupNormOpConverter : public OpConverter {
     framework::DDim bias_dims;
     auto scale_weights = GetWeight(scale_name, &scale_dims);
     auto bias_weights = GetWeight(bias_name, &bias_dims);
+    bool with_fp16 =
+        engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
     if (engine_->with_dynamic_shape()) {
       int gn_num = groups;
       std::vector<int64_t> mean_shape({gn_num});
@@ -74,7 +76,8 @@ class GroupNormOpConverter : public OpConverter {
               epsilon,
               groups,
               mean_shape,
-              variance_shape);
+              variance_shape,
+              with_fp16);
       nvinfer1::ILayer* groupnorm_layer =
           engine_->AddDynamicPlugin(&input_itensor, 1, plugin);
       auto output_name = op_desc.Output("Y")[0];
@@ -92,7 +95,8 @@ class GroupNormOpConverter : public OpConverter {
               epsilon,
               groups,
              mean_shape,
-              variance_shape);
+              variance_shape,
+              with_fp16);
       nvinfer1::ILayer* groupnorm_layer =
           engine_->AddPlugin(&input_itensor, 1, plugin);
       auto output_name = op_desc.Output("Y")[0];
......
@@ -415,15 +415,6 @@ struct SimpleOpTypeSetTeller : public Teller {
                 << layout_str;
         return false;
       }
-      auto* block = desc.Block();
-      if (block == nullptr) return false;
-      auto x_var_name = desc.Input("X")[0];
-      auto* x_var_desc = block->FindVar(x_var_name);
-      auto dtype = x_var_desc->GetDataType();
-      if (dtype != 5) {
-        VLOG(3) << "Group norm trt plugin only support float32";
-        return false;
-      }
     }
     if (op_type == "concat") {
       if (!desc.HasAttr("axis")) {
......
@@ -25,7 +25,53 @@ namespace tensorrt {
 namespace plugin {
 using DataLayout = phi::DataLayout;
-int GroupNormPlugin::initialize() TRT_NOEXCEPT { return 0; }
+int GroupNormPlugin::initialize() TRT_NOEXCEPT {
+  if (!with_fp16_) {
+    // if use fp32
+    cudaMalloc(&scale_gpu_, sizeof(float) * scale_.size());
+    cudaMalloc(&bias_gpu_, sizeof(float) * bias_.size());
+    cudaMemcpy(scale_gpu_,
+               scale_.data(),
+               scale_.size() * sizeof(float),
+               cudaMemcpyHostToDevice);
+    cudaMemcpy(bias_gpu_,
+               bias_.data(),
+               bias_.size() * sizeof(float),
+               cudaMemcpyHostToDevice);
+  } else {
+    // if use fp16
+    std::vector<half> scale_half(scale_.size());
+    std::vector<half> bias_half(bias_.size());
+    for (int i = 0; i < scale_.size(); ++i) {
+      scale_half[i] = static_cast<half>(scale_[i]);
+    }
+    for (int i = 0; i < bias_.size(); ++i) {
+      bias_half[i] = static_cast<half>(bias_[i]);
+    }
+    cudaMalloc(&scale_gpu_, sizeof(half) * scale_half.size());
+    cudaMalloc(&bias_gpu_, sizeof(half) * bias_half.size());
+    cudaMemcpy(scale_gpu_,
+               scale_half.data(),
+               scale_half.size() * sizeof(half),
+               cudaMemcpyHostToDevice);
+    cudaMemcpy(bias_gpu_,
+               bias_half.data(),
+               bias_half.size() * sizeof(half),
+               cudaMemcpyHostToDevice);
+  }
+  return 0;
+}
+
+bool GroupNormPlugin::supportsFormat(
+    nvinfer1::DataType type, nvinfer1::PluginFormat format) const TRT_NOEXCEPT {
+  if (with_fp16_) {
+    return ((type == nvinfer1::DataType::kHALF) &&
+            (format == nvinfer1::PluginFormat::kLINEAR));
+  } else {
+    return ((type == nvinfer1::DataType::kFLOAT) &&
+            (format == nvinfer1::PluginFormat::kLINEAR));
+  }
+}
+
 nvinfer1::Dims GroupNormPlugin::getOutputDimensions(
     int index, const nvinfer1::Dims *inputDims, int nbInputs) TRT_NOEXCEPT {
@@ -70,48 +116,48 @@ int GroupNormPlugin::enqueue(int batch_size,
                         "but got channel number:%d, bias's size:%d.",
                         C,
                         bias_.size()));
-  int device_id;
-  cudaGetDevice(&device_id);
-  const float *input = static_cast<const float *>(inputs[0]);
-  float *output = static_cast<float *>(outputs[0]);
-  scale_t.Resize(phi::make_ddim({C}));
-  bias_t.Resize(phi::make_ddim({C}));
-  mean_t.Resize(phi::make_ddim(mean_shape_));
-  variance_t.Resize(phi::make_ddim(variance_shape_));
-  float *scale_d = scale_t.mutable_data<float>(platform::CUDAPlace(device_id));
-  float *bias_d = bias_t.mutable_data<float>(platform::CUDAPlace(device_id));
-  float *mean_d = mean_t.mutable_data<float>(platform::CUDAPlace(device_id));
-  float *variance_d =
-      variance_t.mutable_data<float>(platform::CUDAPlace(device_id));
-  phi::DenseTensor temp_variance_t;
-  temp_variance_t.Resize(phi::make_ddim(variance_shape_));
-  float *temp_variance_d =
-      temp_variance_t.mutable_data<float>(platform::CUDAPlace(device_id));
-  cudaMemcpyAsync(scale_d,
-                  scale_.data(),
-                  sizeof(float) * C,
-                  cudaMemcpyHostToDevice,
-                  stream);
-  cudaMemcpyAsync(
-      bias_d, bias_.data(), sizeof(float) * C, cudaMemcpyHostToDevice, stream);
-  phi::GroupNormDirectCUDAFunctor<float> group_norm;
-  group_norm(stream,
-             input,
-             input_shape,
-             bias_d,
-             scale_d,
-             mean_d,
-             temp_variance_d,
-             groups_,
-             eps_,
-             output,
-             mean_d,
-             variance_d,
-             DataLayout::kNCHW);
+  float *mean_d = static_cast<float *>(workspace);
+  float *variance_d = mean_d + input_shape[0] * groups_;
+  float *temp_variance_d = variance_d + input_shape[0] * groups_;
+  auto input_type = getDataType();
+  if (input_type == nvinfer1::DataType::kFLOAT) {
+    VLOG(1) << "TRT Plugin DataType selected. GroupNorm-->fp32";
+    const float *input = static_cast<const float *>(inputs[0]);
+    float *output = static_cast<float *>(outputs[0]);
+    phi::GroupNormDirectCUDAFunctor<float> group_norm;
+    group_norm(stream,
+               input,
+               input_shape,
+               reinterpret_cast<float *>(bias_gpu_),
+               reinterpret_cast<float *>(scale_gpu_),
+               temp_variance_d,
+               groups_,
+               eps_,
+               output,
+               mean_d,
+               variance_d,
+               DataLayout::kNCHW);
+  } else if (input_type == nvinfer1::DataType::kHALF) {
+    VLOG(1) << "TRT Plugin DataType selected. GroupNorm-->fp16";
+    const half *input = static_cast<const half *>(inputs[0]);
+    half *output = static_cast<half *>(outputs[0]);
+    phi::GroupNormDirectCUDAFunctor<half, float> group_norm;
+    group_norm(stream,
+               input,
+               input_shape,
+               reinterpret_cast<const half *>(bias_gpu_),
+               reinterpret_cast<const half *>(scale_gpu_),
+               temp_variance_d,
+               groups_,
+               eps_,
+               output,
+               mean_d,
+               variance_d,
+               DataLayout::kNCHW);
+  } else {
+    PADDLE_THROW(platform::errors::Fatal(
+        "The GroupNorm TRT Plugin's input type should be float or half."));
+  }
   return cudaGetLastError() != cudaSuccess;
 }
 nvinfer1::DimsExprs GroupNormPluginDynamic::getOutputDimensions(
@@ -140,8 +186,13 @@ bool GroupNormPluginDynamic::supportsFormatCombination(
                         nb_inputs + nb_outputs));
   const nvinfer1::PluginTensorDesc &in = in_out[pos];
   if (pos == 0) {
-    return (in.type == nvinfer1::DataType::kFLOAT) &&
-           (in.format == nvinfer1::TensorFormat::kLINEAR);
+    if (with_fp16_) {
+      return ((in.type == nvinfer1::DataType::kHALF) &&
+              (in.format == nvinfer1::PluginFormat::kLINEAR));
+    } else {
+      return (in.type == nvinfer1::DataType::kFLOAT) &&
+             (in.format == nvinfer1::TensorFormat::kLINEAR);
+    }
   }
   const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
   // output
@@ -158,8 +209,50 @@ nvinfer1::DataType GroupNormPluginDynamic::getOutputDataType(
                         "The groupnorm Plugin only has one input, so the "
                         "index value should be 0, but get %d.",
                         index));
+  PADDLE_ENFORCE_EQ((input_types[0] == nvinfer1::DataType::kFLOAT ||
+                     input_types[0] == nvinfer1::DataType::kHALF),
+                    true,
+                    platform::errors::InvalidArgument(
+                        "The input type should be half or float"));
   return input_types[0];
 }
+
+int GroupNormPluginDynamic::initialize() TRT_NOEXCEPT {
+  if (with_fp16_ == false) {
+    // if use fp32
+    cudaMalloc(&scale_gpu_, sizeof(float) * scale_.size());
+    cudaMalloc(&bias_gpu_, sizeof(float) * bias_.size());
+    cudaMemcpy(scale_gpu_,
+               scale_.data(),
+               scale_.size() * sizeof(float),
+               cudaMemcpyHostToDevice);
+    cudaMemcpy(bias_gpu_,
+               bias_.data(),
+               bias_.size() * sizeof(float),
+               cudaMemcpyHostToDevice);
+  } else {
+    // if use fp16
+    std::vector<half> scale_half(scale_.size());
+    std::vector<half> bias_half(bias_.size());
+    for (int i = 0; i < scale_.size(); ++i) {
+      scale_half[i] = static_cast<half>(scale_[i]);
+    }
+    for (int i = 0; i < bias_.size(); ++i) {
+      bias_half[i] = static_cast<half>(bias_[i]);
+    }
+    cudaMalloc(&scale_gpu_, sizeof(half) * scale_.size());
+    cudaMalloc(&bias_gpu_, sizeof(half) * bias_.size());
+    cudaMemcpy(scale_gpu_,
+               scale_half.data(),
+               scale_half.size() * sizeof(half),
+               cudaMemcpyHostToDevice);
+    cudaMemcpy(bias_gpu_,
+               bias_half.data(),
+               bias_half.size() * sizeof(half),
+               cudaMemcpyHostToDevice);
+  }
+  return 0;
+}
+
 int GroupNormPluginDynamic::enqueue(
     const nvinfer1::PluginTensorDesc *input_desc,
@@ -202,46 +295,38 @@ int GroupNormPluginDynamic::enqueue(
                         C,
                         bias_.size()));
-  int device_id;
-  cudaGetDevice(&device_id);
-  auto input_type = input_desc[0].type;
-  if (input_type == nvinfer1::DataType::kFLOAT) {
-    const float *input = static_cast<const float *>(inputs[0]);
-    float *output = static_cast<float *>(outputs[0]);
-    scale_t.Resize(phi::make_ddim({C}));
-    bias_t.Resize(phi::make_ddim({C}));
-    mean_t.Resize(phi::make_ddim(batched_mean_shape));
-    variance_t.Resize(phi::make_ddim(batched_variance_shape));
-    float *scale_d =
-        scale_t.mutable_data<float>(platform::CUDAPlace(device_id));
-    float *bias_d = bias_t.mutable_data<float>(platform::CUDAPlace(device_id));
-    float *mean_d = mean_t.mutable_data<float>(platform::CUDAPlace(device_id));
-    float *variance_d =
-        variance_t.mutable_data<float>(platform::CUDAPlace(device_id));
-    phi::DenseTensor temp_variance_t;
-    temp_variance_t.Resize(phi::make_ddim(batched_variance_shape));
-    float *temp_variance_d =
-        temp_variance_t.mutable_data<float>(platform::CUDAPlace(device_id));
-    cudaMemcpyAsync(scale_d,
-                    scale_.data(),
-                    sizeof(float) * C,
-                    cudaMemcpyHostToDevice,
-                    stream);
-    cudaMemcpyAsync(bias_d,
-                    bias_.data(),
-                    sizeof(float) * C,
-                    cudaMemcpyHostToDevice,
-                    stream);
-    phi::GroupNormDirectCUDAFunctor<float> group_norm;
-    group_norm(stream,
-               input,
-               input_shape,
-               bias_d,
-               scale_d,
-               mean_d,
-               temp_variance_d,
-               groups,
-               eps,
+  float *mean_d = static_cast<float *>(workspace);
+  float *variance_d = mean_d + input_shape[0] * groups_;
+  float *temp_variance_d = variance_d + input_shape[0] * groups_;
+  auto input_type = input_desc[0].type;
+  if (input_type == nvinfer1::DataType::kFLOAT) {
+    VLOG(1) << "TRT Plugin DataType selected. GroupNorm-->fp32";
+    const float *input = reinterpret_cast<const float *>(inputs[0]);
+    float *output = static_cast<float *>(outputs[0]);
+    phi::GroupNormDirectCUDAFunctor<float, float> group_norm;
+    group_norm(stream,
+               input,
+               input_shape,
+               reinterpret_cast<float *>(bias_gpu_),
+               reinterpret_cast<float *>(scale_gpu_),
+               temp_variance_d,
+               groups,
+               eps,
+               output,
+               mean_d,
+               variance_d,
+               DataLayout::kNCHW);
+  } else if (input_type == nvinfer1::DataType::kHALF) {
+    VLOG(1) << "TRT Plugin DataType selected. GroupNorm-->fp16";
+    const half *input = reinterpret_cast<const half *>(inputs[0]);
+    half *output = static_cast<half *>(outputs[0]);
+    phi::GroupNormDirectCUDAFunctor<half, float> group_norm;
+    group_norm(stream,
+               input,
+               input_shape,
+               reinterpret_cast<half *>(bias_gpu_),
+               reinterpret_cast<half *>(scale_gpu_),
+               temp_variance_d,
+               groups,
+               eps,
......
@@ -32,7 +32,7 @@ class GroupNormPlugin : public PluginTensorRT {
     return getBaseSerializationSize() + SerializedSize(scale_) +
            SerializedSize(bias_) + SerializedSize(eps_) +
            SerializedSize(groups_) + SerializedSize(mean_shape_) +
-           SerializedSize(variance_shape_);
+           SerializedSize(variance_shape_) + SerializedSize(with_fp16_);
   }
   void serialize(void* buffer) const TRT_NOEXCEPT override {
     serializeBase(buffer);
@@ -42,6 +42,7 @@ class GroupNormPlugin : public PluginTensorRT {
     SerializeValue(&buffer, groups_);
     SerializeValue(&buffer, mean_shape_);
     SerializeValue(&buffer, variance_shape_);
+    SerializeValue(&buffer, with_fp16_);
   }
   GroupNormPlugin(const float* scale,
@@ -51,11 +52,13 @@ class GroupNormPlugin : public PluginTensorRT {
                   float eps,
                   int groups,
                   std::vector<int64_t> mean_shape,
-                  std::vector<int64_t> variance_shape)
+                  std::vector<int64_t> variance_shape,
+                  bool with_fp16)
       : groups_(groups),
         eps_(eps),
        mean_shape_(mean_shape),
-        variance_shape_(variance_shape) {
+        variance_shape_(variance_shape),
+        with_fp16_(with_fp16) {
     scale_.resize(scale_num);
     bias_.resize(bias_num);
     std::copy(scale, scale + scale_num, scale_.data());
@@ -69,22 +72,33 @@ class GroupNormPlugin : public PluginTensorRT {
     DeserializeValue(&serialData, &serialLength, &groups_);
     DeserializeValue(&serialData, &serialLength, &mean_shape_);
     DeserializeValue(&serialData, &serialLength, &variance_shape_);
+    DeserializeValue(&serialData, &serialLength, &with_fp16_);
   }
   ~GroupNormPlugin() {}
   int initialize() TRT_NOEXCEPT override;
   GroupNormPlugin* clone() const TRT_NOEXCEPT override {
-    return new GroupNormPlugin(scale_.data(),
-                               scale_.size(),
-                               bias_.data(),
-                               bias_.size(),
-                               eps_,
-                               groups_,
-                               mean_shape_,
-                               variance_shape_);
+    auto* ptr = new GroupNormPlugin(scale_.data(),
+                                    scale_.size(),
+                                    bias_.data(),
+                                    bias_.size(),
+                                    eps_,
+                                    groups_,
+                                    mean_shape_,
+                                    variance_shape_,
+                                    with_fp16_);
+    ptr->scale_gpu_ = scale_gpu_;
+    ptr->bias_gpu_ = bias_gpu_;
+    return ptr;
   }
   const char* getPluginType() const TRT_NOEXCEPT override {
     return "groupnorm_plugin";
   }
+  size_t getWorkspaceSize(int max_batch_size) const TRT_NOEXCEPT {
+    return 3 * max_batch_size * groups_;
+  }
+  bool supportsFormat(nvinfer1::DataType type, nvinfer1::PluginFormat format)
+      const TRT_NOEXCEPT override;
   int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
   nvinfer1::Dims getOutputDimensions(int index,
                                      const nvinfer1::Dims* inputs,
@@ -101,18 +115,27 @@ class GroupNormPlugin : public PluginTensorRT {
 #endif
               void* workspace,
               cudaStream_t stream) TRT_NOEXCEPT override;
+  void terminate() TRT_NOEXCEPT override {
+    if (bias_gpu_) {
+      cudaFree(bias_gpu_);
+      bias_gpu_ = nullptr;
+    }
+    if (scale_gpu_) {
+      cudaFree(scale_gpu_);
+      scale_gpu_ = nullptr;
+    }
+  };
 
  private:
   std::vector<float> scale_;
   std::vector<float> bias_;
-  phi::DenseTensor scale_t;
-  phi::DenseTensor bias_t;
-  phi::DenseTensor mean_t;
-  phi::DenseTensor variance_t;
+  void* scale_gpu_;
+  void* bias_gpu_;
   int groups_;
   float eps_;
   std::vector<int64_t> mean_shape_;
   std::vector<int64_t> variance_shape_;
+  bool with_fp16_;
 };
 class GroupNormPluginCreator : public TensorRTPluginCreator {
  public:
@@ -138,11 +161,13 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
                          float eps,
                          int groups,
                          std::vector<int64_t> mean_shape,
-                         std::vector<int64_t> variance_shape)
+                         std::vector<int64_t> variance_shape,
+                         bool with_fp16)
       : groups_(groups),
         eps_(eps),
         mean_shape_(mean_shape),
-        variance_shape_(variance_shape) {
+        variance_shape_(variance_shape),
+        with_fp16_(with_fp16) {
     scale_.resize(scale_num);
     bias_.resize(bias_num);
     std::copy(scale, scale + scale_num, scale_.data());
@@ -156,28 +181,34 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
     DeserializeValue(&serialData, &serialLength, &groups_);
     DeserializeValue(&serialData, &serialLength, &mean_shape_);
     DeserializeValue(&serialData, &serialLength, &variance_shape_);
+    DeserializeValue(&serialData, &serialLength, &with_fp16_);
   }
   nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
-    return new GroupNormPluginDynamic(scale_.data(),
-                                      scale_.size(),
-                                      bias_.data(),
-                                      bias_.size(),
-                                      eps_,
-                                      groups_,
-                                      mean_shape_,
-                                      variance_shape_);
+    auto* ptr = new GroupNormPluginDynamic(scale_.data(),
+                                           scale_.size(),
+                                           bias_.data(),
+                                           bias_.size(),
+                                           eps_,
+                                           groups_,
+                                           mean_shape_,
+                                           variance_shape_,
+                                           with_fp16_);
+    ptr->scale_gpu_ = scale_gpu_;
+    ptr->bias_gpu_ = bias_gpu_;
+    return ptr;
   }
   const char* getPluginType() const TRT_NOEXCEPT override {
     return "groupnorm_plugin_dynamic";
   }
   int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
-  int initialize() TRT_NOEXCEPT override { return 0; }
+  int initialize() TRT_NOEXCEPT override;
   size_t getSerializationSize() const TRT_NOEXCEPT override {
     return SerializedSize(scale_) + SerializedSize(bias_) +
            SerializedSize(eps_) + SerializedSize(groups_) +
-           SerializedSize(mean_shape_) + SerializedSize(variance_shape_);
+           SerializedSize(mean_shape_) + SerializedSize(variance_shape_) +
+           SerializedSize(with_fp16_);
   }
   void serialize(void* buffer) const TRT_NOEXCEPT override {
     SerializeValue(&buffer, scale_);
@@ -186,6 +217,7 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
     SerializeValue(&buffer, groups_);
     SerializeValue(&buffer, mean_shape_);
     SerializeValue(&buffer, variance_shape_);
+    SerializeValue(&buffer, with_fp16_);
   }
   nvinfer1::DimsExprs getOutputDimensions(
       int output_index,
@@ -208,7 +240,7 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
                           int nbInputs,
                           const nvinfer1::PluginTensorDesc* outputs,
                           int nbOutputs) const TRT_NOEXCEPT override {
-    return 0;
+    return 3 * inputs[0].dims.d[0] * groups_ * sizeof(float);
   }
   int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
               const nvinfer1::PluginTensorDesc* outputDesc,
@@ -222,19 +254,27 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
               TRT_NOEXCEPT override;
   void destroy() TRT_NOEXCEPT override { delete this; }
-  // void terminate() TRT_NOEXCEPT override;
+  void terminate() TRT_NOEXCEPT override {
+    if (bias_gpu_) {
+      cudaFree(bias_gpu_);
+      bias_gpu_ = nullptr;
+    }
+    if (scale_gpu_) {
+      cudaFree(scale_gpu_);
+      scale_gpu_ = nullptr;
+    }
+  };
 
  private:
   std::vector<float> scale_;
   std::vector<float> bias_;
-  phi::DenseTensor scale_t;
-  phi::DenseTensor bias_t;
-  phi::DenseTensor mean_t;
-  phi::DenseTensor variance_t;
+  void* scale_gpu_ = nullptr;
+  void* bias_gpu_ = nullptr;
   int groups_;
   float eps_;
   std::vector<int64_t> mean_shape_;
   std::vector<int64_t> variance_shape_;
+  bool with_fp16_;
 };
 class GroupNormPluginDynamicCreator : public TensorRTPluginCreator {
  public:
......
@@ -22,7 +22,7 @@
 namespace phi {
 
-template <typename T, int flags>
+template <typename T, typename AccT, int flags>
 __global__ void GroupNormBackwardGetMeanAndVar(const T* x,
                                                const T* scale,
                                                const T* bias,
@@ -33,9 +33,9 @@ __global__ void GroupNormBackwardGetMeanAndVar(const T* x,
                                                int imsize,
                                                int groups,
                                                int group_size,
-                                               T epsilon,
-                                               T* d_mean,
-                                               T* d_var,
+                                               float epsilon,
+                                               AccT* d_mean,
+                                               AccT* d_var,
                                                T* d_scale,
                                                T* d_bias) {
   int gid = blockIdx.y;
@@ -45,29 +45,35 @@ __global__ void GroupNormBackwardGetMeanAndVar(const T* x,
   int number = min(group_size, static_cast<int>(C - gid * group_size));
   int ccid = gid * group_size + cid;
   if (ccid >= C) return;
-  T x_scale = (flags & kHasScale) ? scale[ccid] : 1;
-  T x_bias = (flags & kHasBias) ? bias[ccid] : 0;
-  T x_scale_inv = 0;
-  if (x_scale != 0) x_scale_inv = 1.0 / x_scale;
-  T d_mean_data = 0, d_var_data = 0, d_scale_data = 0, d_bias_data = 0;
+  T x_scale = (flags & kHasScale) ? scale[ccid] : static_cast<T>(1);
+  T x_bias = (flags & kHasBias) ? bias[ccid] : static_cast<T>(0);
+  T x_scale_inv = static_cast<T>(0);
+  if (x_scale != static_cast<T>(0)) x_scale_inv = static_cast<T>(1.0) / x_scale;
+  AccT d_mean_data = static_cast<AccT>(0);
+  AccT d_var_data = static_cast<AccT>(0);
+  T d_scale_data = static_cast<T>(0);
+  T d_bias_data = static_cast<T>(0);
   for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
-    T val, dval;
+    AccT val, dval;
     int hid = imid / W;
     int wid = imid % W;
-    val = x[(bid * H + hid) * W * C + wid * C + ccid] - x_bias;
-    dval = d_y[(bid * H + hid) * W * C + wid * C + ccid];
+    val = static_cast<AccT>(x[(bid * H + hid) * W * C + wid * C + ccid]) -
+          static_cast<AccT>(x_bias);
+    dval = static_cast<AccT>(d_y[(bid * H + hid) * W * C + wid * C + ccid]);
     d_var_data += val * dval;
-    d_mean_data += dval * x_scale;
-    val = val * x_scale_inv;
-    d_bias_data += dval;
-    d_scale_data += val * dval;
+    d_mean_data += dval * static_cast<AccT>(x_scale);
+    val = val * static_cast<AccT>(x_scale_inv);
+    d_bias_data += static_cast<T>(dval);
+    d_scale_data += static_cast<T>(val * dval);
   }
-  CudaAtomicAddWithWarp(&(d_mean[bid * groups + gid]), d_mean_data);
-  CudaAtomicAddWithWarp(&(d_var[bid * groups + gid]), d_var_data);
+  CudaAtomicAddWithWarp(&(d_mean[bid * groups + gid]),
+                        static_cast<AccT>(d_mean_data));
+  CudaAtomicAddWithWarp(&(d_var[bid * groups + gid]),
+                        static_cast<AccT>(d_var_data));
   if (flags & kHasScale) {
 #if CUDA_VERSION >= 11070
@@ -85,22 +91,24 @@ __global__ void GroupNormBackwardGetMeanAndVar(const T* x,
   }
 }
 
-template <typename T, int flags>
+template <typename T, typename AccT, int flags>
 __global__ void GroupNormBackward(const T* x,
                                   const T* d_y,
                                   const T* scale,
                                   const T* bias,
-                                  const T* var,
-                                  const T* d_mean,
-                                  const T* d_var,
+                                  const AccT* var,
+                                  const AccT* d_mean,
+                                  const AccT* d_var,
                                   int N,
                                   int C,
                                   int W,
                                   int imsize,
                                   int groups,
                                   int group_size,
-                                  T epsilon,
+                                  float epsilon,
                                   T* d_x) {
+  // using AccT = typename kps::details::MPTypeTrait<T>::Type;
   int gid = blockIdx.y;
   int cid = blockIdx.x;
   int bid = blockIdx.z;
@@ -108,132 +116,138 @@ __global__ void GroupNormBackward(const T* x,
   int number = min(group_size, static_cast<int>(C - gid * group_size));
   int ccid = gid * group_size + cid;
   if (ccid >= C) return;
-  T x_var = var[bid * groups + gid];
-  T d_x_mean = d_mean[bid * groups + gid];
-  T d_x_var = d_var[bid * groups + gid];
-  T x_var_inv = 1.0 / sqrt(x_var + epsilon);
-  T number_inv = 1.0 / (number * imsize);
-  T x_scale = (flags & kHasScale) ? scale[ccid] : 1;
-  T x_bias = (flags & kHasBias) ? bias[ccid] : 0;
-  T x_scale_inv = 0;
-  if (x_scale != 0) x_scale_inv = 1.0 / x_scale;
+  AccT x_var = var[bid * groups + gid];
+  AccT d_x_mean = static_cast<AccT>(d_mean[bid * groups + gid]);
+  AccT d_x_var = static_cast<AccT>(d_var[bid * groups + gid]);
+  AccT x_var_inv = static_cast<AccT>(1.0) / sqrt((x_var) + epsilon);
+  AccT number_inv =
+      static_cast<AccT>(1.0) / static_cast<AccT>((number * imsize));
+  AccT x_scale = (flags & kHasScale) ? static_cast<AccT>(scale[ccid])
+                                     : static_cast<AccT>(1);
+  AccT x_bias =
+      (flags & kHasBias) ? static_cast<AccT>(bias[ccid]) : static_cast<AccT>(0);
+  AccT x_scale_inv = static_cast<T>(0);
+  if (x_scale != static_cast<AccT>(0))
+    x_scale_inv = static_cast<AccT>(1.0) / x_scale;
   for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
     int hid = imid / W;
     int wid = imid % W;
-    T tmp = x[(bid * H + hid) * W * C + wid * C + ccid];
-    T v_y = (tmp - x_bias) * x_scale_inv;
-    T dly = d_y[(bid * H + hid) * W * C + wid * C + ccid];
+    AccT tmp = static_cast<AccT>(x[(bid * H + hid) * W * C + wid * C + ccid]);
+    AccT v_y = (tmp - x_bias) * x_scale_inv;
+    AccT dly = static_cast<AccT>(d_y[(bid * H + hid) * W * C + wid * C + ccid]);
     d_x[(bid * H + hid) * W * C + wid * C + ccid] =
-        x_var_inv *
-        (dly * x_scale - number_inv * d_x_var * v_y - number_inv * d_x_mean);
+        static_cast<T>(x_var_inv * ((dly) * (x_scale)-number_inv * d_x_var *
                                    (v_y)-number_inv * d_x_mean));
   }
 }
 
-template <typename T>
+template <typename T, typename AccT>
 __global__ void ScalarGetDsDbCUDAKernel(
-    int imsize, const T* x, const T* dy, T* ds, T* db) {
+    int imsize, const T* x, const T* dy, AccT* ds, AccT* db) {
   const int nc = blockIdx.x;
-  T ds_sum = 0;
-  T db_sum = 0;
+  AccT ds_sum = 0;
+  AccT db_sum = 0;
   for (int i = threadIdx.x; i < imsize; i += blockDim.x) {
     const int index = nc * imsize + i;
-    ds_sum += dy[index] * x[index];
-    db_sum += dy[index];
+    ds_sum += static_cast<AccT>(dy[index]) * static_cast<AccT>(x[index]);
+    db_sum += static_cast<AccT>(dy[index]);
   }
-  ReduceMeanAndVar<T>(db, ds, db_sum, ds_sum, 1);
+  ReduceMeanAndVar<AccT>(db, ds, db_sum, ds_sum, 1);
 }
 
-template <typename T>
+template <typename T, typename AccT>
 __global__ void GetScaleBiasGradientCUDAKernel(int N,
                                                int C,
                                                int group,
-                                               T epsilon,
-                                               const T* mean,
-                                               const T* var,
-                                               const T* ds,
-                                               const T* db,
+                                               float epsilon,
+                                               const AccT* mean,
+                                               const AccT* var,
+                                               const AccT* ds,
+                                               const AccT* db,
                                                T* d_scale,
                                                T* d_bias) {
   const int c = blockIdx.x * blockDim.x + threadIdx.x;
   if (c < C) {
     const int G = group;
     const int D = C / G;
-    T sum1 = 0;
-    T sum2 = 0;
+    AccT sum1 = static_cast<AccT>(0);
+    AccT sum2 = static_cast<AccT>(0);
     for (int n = 0; n < N; ++n) {
       const int nc = n * C + c;
       const int ng = n * G + c / D;
-      sum1 += (d_scale == nullptr)
-                  ? T(0)
-                  : ((ds[nc] - db[nc] * static_cast<T>(mean[ng])) *
-                     static_cast<T>(rsqrt(var[ng] + epsilon)));
-      sum2 += (d_bias == nullptr) ? T(0) : db[nc];
+      sum1 +=
+          (d_scale == nullptr)
+              ? AccT(0)
+              : ((ds[nc] - db[nc] * (mean[ng])) * (rsqrt((var[ng]) + epsilon)));
+      sum2 += (d_bias == nullptr) ? AccT(0) : db[nc];
     }
     if (d_scale != nullptr) {
-      d_scale[c] = sum1;
+      d_scale[c] = static_cast<T>(sum1);
     }
     if (d_bias != nullptr) {
-      d_bias[c] = sum2;
+      d_bias[c] = static_cast<T>(sum2);
     }
   }
 }
 
-template <typename T, int BlockDim>
+template <typename T, typename AccT, int BlockDim>
 __global__ void GetBackwardParamsCUDAKernel(int imsize,
                                             int groups,
                                             int group_size,
-                                            T epsilon,
-                                            const T* mean,
-                                            const T* var,
+                                            float epsilon,
+                                            const AccT* mean,
+                                            const AccT* var,
                                             const T* scale,
-                                            const T* ds,
-                                            const T* db,
-                                            T* p1,
-                                            T* p2,
-                                            T* p3) {
+                                            const AccT* ds,
+                                            const AccT* db,
+                                            AccT* p1,
+                                            AccT* p2,
+                                            AccT* p3) {
   const int n = blockIdx.x;
   const int g = blockIdx.y;
   const int ng = n * groups + g;
-  T sum1 = 0;
-  T sum2 = 0;
-  T var_inv = rsqrt(var[ng] + epsilon);
+  AccT sum1 = 0;
+  AccT sum2 = 0;
+  AccT var_inv = rsqrt(static_cast<AccT>(var[ng]) + epsilon);
   for (int64_t i = threadIdx.x; i < group_size; i += blockDim.x) {
     const int64_t index = ng * group_size + i;
     const int64_t c = g * group_size + i;
-    const T scale_v = scale == nullptr ? T(1) : static_cast<T>(scale[c]);
-    sum1 += ds[index] * scale_v;
-    sum2 += db[index] * scale_v;
-    const T scale_c = scale == nullptr ? T(0) : static_cast<T>(scale[c]);
-    p1[index] = scale_c * var_inv;
+    const AccT scale_v =
+        scale == nullptr ? static_cast<AccT>(1) : static_cast<AccT>(scale[c]);
+    sum1 += static_cast<AccT>(ds[index]) * scale_v;
+    sum2 += static_cast<AccT>(db[index]) * scale_v;
+    const AccT scale_c =
+        scale == nullptr ? static_cast<AccT>(0) : static_cast<T>(scale[c]);
+    p1[index] = static_cast<AccT>(scale_c) * var_inv;
   }
 
-  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
+  typedef cub::BlockReduce<AccT, BlockDim> BlockReduce;
   __shared__ typename BlockReduce::TempStorage ds_storage;
   __shared__ typename BlockReduce::TempStorage db_storage;
   sum1 = BlockReduce(ds_storage).Reduce(sum1, cub::Sum());
   sum2 = BlockReduce(db_storage).Reduce(sum2, cub::Sum());
 
   if (threadIdx.x == 0) {
-    const T s = T(1) / static_cast<T>(group_size * imsize);
-    const T x = (sum2 * static_cast<T>(mean[ng]) - sum1) *
-                static_cast<T>(var_inv) * static_cast<T>(var_inv) *
-                static_cast<T>(var_inv) * s;
+    const AccT s =
+        static_cast<AccT>(1) / static_cast<AccT>(group_size * imsize);
+    const AccT x = (sum2 * static_cast<AccT>(mean[ng]) - sum1) * (var_inv) *
+                   (var_inv) * (var_inv)*s;
     p2[ng] = x;
-    p3[ng] = -x * static_cast<T>(mean[ng]) - sum2 * static_cast<T>(var_inv) * s;
+    p3[ng] = -x * (mean[ng]) - (sum2 * var_inv) * s;
   }
 }
 
-template <typename T>
+template <typename T, typename AccT>
 __global__ void GetXGradientCUDAKernel(int imsize,
                                        int C,
                                        int group_size,
                                        int groups,
-                                       T* p1,
-                                       T* p2,
-                                       T* p3,
+                                       AccT* p1,
+                                       AccT* p2,
+                                       AccT* p3,
                                        const T* x,
                                        const T* dy,
                                        T* dx) {
@@ -245,7 +259,8 @@ __global__ void GetXGradientCUDAKernel(int imsize,
   int nc = gid * group_size + cid;
   for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
     int index = (bid * C + nc) * imsize + imid;
-    dx[index] = p1[ccid] * dy[index] + p2[ng] * x[index] + p3[ng];
+    dx[index] = static_cast<T>(p1[ccid] * static_cast<AccT>(dy[index]) +
+                               p2[ng] * static_cast<AccT>(x[index]) + p3[ng]);
   }
 }
 
@@ -264,6 +279,7 @@ void GroupNormGradKernel(const Context& dev_ctx,
                          DenseTensor* d_x,
                          DenseTensor* d_scale,
                          DenseTensor* d_bias) {
+  using AccT = typename kps::details::MPTypeTrait<T>::Type;
   const DataLayout data_layout = phi::StringToDataLayout(data_layout_str);
   const auto scale_ptr = scale.get_ptr();
   const auto bias_ptr = bias.get_ptr();
@@ -277,20 +293,20 @@ void GroupNormGradKernel(const Context& dev_ctx,
   dev_ctx.template Alloc<T>(d_x);
   phi::funcs::SetConstant<GPUContext, T> set_zero;
+  phi::funcs::SetConstant<GPUContext, AccT> set_zero_AccT;
 
   DenseTensor ds, db;
   ds.Resize({x_dims[0], C});
-  T* ds_data = dev_ctx.template Alloc<T>(&ds);
+  AccT* ds_data = dev_ctx.template Alloc<AccT>(&ds);
   db.Resize({x_dims[0], C});
-  T* db_data = dev_ctx.template Alloc<T>(&db);
+  AccT* db_data = dev_ctx.template Alloc<AccT>(&db);
 
   auto* y_data = y.data<T>();
   auto* x_data = x.data<T>();
   T* d_x_data = nullptr;
   if (d_x) d_x_data = d_x->data<T>();
   auto* dy_data = d_y.data<T>();
-  auto* var_data = var.data<T>();
-  auto* mean_data = mean.data<T>();
+  auto* var_data = var.data<AccT>();
+  auto* mean_data = mean.data<AccT>();
   T* d_scale_data = nullptr;
   if (d_scale) {
     dev_ctx.template Alloc<T>(d_scale);
@@ -338,12 +354,13 @@ void GroupNormGradKernel(const Context& dev_ctx,
     }
     block_size_nchw = std::max(block_size_nchw, kps::details::kWarpSize);
     dim3 blocks(block_size_nchw);
-    ScalarGetDsDbCUDAKernel<T><<<x_dims[0] * C, blocks, 0, dev_ctx.stream()>>>(
-        imsize, x_data, dy_data, ds_data, db_data);
+    ScalarGetDsDbCUDAKernel<T, AccT>
+        <<<x_dims[0] * C, blocks, 0, dev_ctx.stream()>>>(
+            imsize, x_data, dy_data, ds_data, db_data);
 
     if (d_scale || d_bias) {
       const int block = 256;
-      GetScaleBiasGradientCUDAKernel<T>
+      GetScaleBiasGradientCUDAKernel<T, AccT>
           <<<(C + block - 1) / block, block, 0, dev_ctx.stream()>>>(
               x_dims[0],
               C,
@@ -365,13 +382,13 @@ void GroupNormGradKernel(const Context& dev_ctx,
     // p3 = -p2 * mean[ng] - db * scale * var_inv * (1/n);
     DenseTensor p1, p2, p3;
     p1.Resize({x_dims[0] * C});
-    T* p1_data = dev_ctx.template Alloc<T>(&p1);
+    AccT* p1_data = dev_ctx.template Alloc<AccT>(&p1);
     p2.Resize({x_dims[0], groups});
-    T* p2_data = dev_ctx.template Alloc<T>(&p2);
+    AccT* p2_data = dev_ctx.template Alloc<AccT>(&p2);
     p3.Resize({x_dims[0], groups});
-    T* p3_data = dev_ctx.template Alloc<T>(&p3);
+    AccT* p3_data = dev_ctx.template Alloc<AccT>(&p3);
 
-    GetBackwardParamsCUDAKernel<T, block_dims>
+    GetBackwardParamsCUDAKernel<T, AccT, block_dims>
         <<<dim3(x_dims[0], groups), block_dims, 0, dev_ctx.stream()>>>(
             imsize,
             groups,
@@ -408,14 +425,14 @@ void GroupNormGradKernel(const Context& dev_ctx,
     DenseTensor temp_var;
     temp_var.Resize(var.dims());
     dev_ctx.template Alloc<T>(&temp_var);
-    set_zero(dev_ctx, &temp_var, static_cast<T>(0));
-    T* temp_var_data = temp_var.data<T>();
+    set_zero_AccT(dev_ctx, &temp_var, static_cast<AccT>(0));
+    auto* temp_var_data = temp_var.data<AccT>();
 
     DenseTensor temp_mean;
     temp_mean.Resize(var.dims());
-    dev_ctx.template Alloc<T>(&temp_mean);
-    set_zero(dev_ctx, &temp_mean, static_cast<T>(0));
-    T* temp_mean_data = temp_mean.data<T>();
+    dev_ctx.template Alloc<AccT>(&temp_mean);
+    set_zero_AccT(dev_ctx, &temp_mean, static_cast<AccT>(0));
+    auto* temp_mean_data = temp_mean.data<AccT>();
 
     int flags =
         (scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias;
@@ -460,6 +477,10 @@ void GroupNormGradKernel(const Context& dev_ctx,
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    group_norm_grad, GPU, ALL_LAYOUT, phi::GroupNormGradKernel, float, double) {
-}
+PD_REGISTER_KERNEL(group_norm_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::GroupNormGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
@@ -22,7 +22,7 @@
 namespace phi {
 
-template <typename T>
+template <typename T, typename AccT>
 __global__ void GroupNormForwardGetMeanAndVar(const T* x,
                                               int N,
                                               int C,
@@ -30,8 +30,8 @@ __global__ void GroupNormForwardGetMeanAndVar(const T* x,
                                               int imsize,
                                               int groups,
                                               int group_size,
-                                              T* mean,
-                                              T* var) {
+                                              AccT* mean,
+                                              AccT* var) {
   int gid = blockIdx.y;
   int cid = blockIdx.x;
   int bid = blockIdx.z;
@@ -39,12 +39,13 @@ __global__ void GroupNormForwardGetMeanAndVar(const T* x,
   int number = min(group_size, static_cast<int>(C - gid * group_size));
   int ccid = gid * group_size + cid;
   if (ccid >= C) return;
-  T x_mean = 0, x_var = 0;
+  AccT x_mean = static_cast<AccT>(0);
+  AccT x_var = static_cast<AccT>(0);
   for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
-    T val;
+    AccT val;
     int hid = imid / W;
     int wid = imid % W;
-    val = x[(bid * H + hid) * W * C + wid * C + ccid];
+    val = static_cast<AccT>(x[(bid * H + hid) * W * C + wid * C + ccid]);
     x_mean += val;
     x_var += val * val;
@@ -55,10 +56,10 @@ __global__ void GroupNormForwardGetMeanAndVar(const T* x,
   CudaAtomicAddWithWarp(&var[bid * groups + gid], x_var);
 }
 
-template <typename T, int flags>
+template <typename T, typename AccT, int flags>
 __global__ void GroupNormForward(const T* x,
-                                 const T* mean,
-                                 const T* var,
+                                 const AccT* mean,
+                                 const AccT* var,
                                  const T* scale,
                                  const T* bias,
                                  int N,
@@ -67,9 +68,9 @@ __global__ void GroupNormForward(const T* x,
                                  int imsize,
                                  int groups,
                                  int group_size,
-                                 T epsilon,
+                                 AccT epsilon,
                                  T* y,
-                                 T* real_var,
+                                 AccT* real_var,
                                  const DataLayout data_layout) {
   int gid = blockIdx.y;
   int cid = blockIdx.x;
@@ -78,35 +79,36 @@ __global__ void GroupNormForward(const T* x,
   int ccid = gid * group_size + cid;
   if (ccid >= C) return;
   auto ng = bid * groups + gid;
-  T x_mean = mean[ng];
-  T x_var = var[ng];
+  AccT x_mean = mean[ng];
+  AccT x_var = var[ng];
   x_var = x_var - x_mean * x_mean;
-  T var_inv = rsqrt(x_var + epsilon);
+
+  AccT var_inv = rsqrt(x_var + epsilon);
   if (cid == 0 && threadIdx.x == 0) {
     real_var[ng] = x_var;
   }
   for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
-    T val;
+    AccT val;
     int hid, wid;
     int index = (bid * C + ccid) * imsize + imid;
     if (data_layout == DataLayout::kNCHW) {
-      val = x[index];
+      val = static_cast<AccT>(x[index]);
     } else {
       hid = imid / W;
      wid = imid % W;
-      val = x[(bid * H + hid) * W * C + wid * C + ccid];
+      val = static_cast<AccT>(x[(bid * H + hid) * W * C + wid * C + ccid]);
     }
     val = (val - x_mean) * var_inv;
     if (flags & kHasScale) {
-      val *= scale[ccid];
+      val *= static_cast<AccT>(scale[ccid]);
     }
     if (flags & kHasBias) {
-      val += bias[ccid];
+      val += static_cast<AccT>(bias[ccid]);
     }
     if (data_layout == DataLayout::kNCHW) {
-      y[index] = val;
+      y[index] = static_cast<T>(val);
     } else {
-      y[(bid * H + hid) * W * C + wid * C + ccid] = val;
+      y[(bid * H + hid) * W * C + wid * C + ccid] = static_cast<T>(val);
     }
   }
 }
@@ -122,6 +124,7 @@ void GroupNormKernel(const Context& dev_ctx,
                      DenseTensor* y,
                      DenseTensor* mean,
                      DenseTensor* var) {
+  using AccT = typename kps::details::MPTypeTrait<T>::Type;
   const DataLayout data_layout = phi::StringToDataLayout(data_layout_str);
   const auto scale_ptr = scale.get_ptr();
   const auto bias_ptr = bias.get_ptr();
@@ -135,17 +138,19 @@ void GroupNormKernel(const Context& dev_ctx,
                            : x_dims[x_dims.size() - 2]);
   dev_ctx.template Alloc<T>(y);
-  dev_ctx.template Alloc<T>(mean);
-  dev_ctx.template Alloc<T>(var);
-  phi::funcs::SetConstant<GPUContext, T> set_zero;
+  dev_ctx.template Alloc<AccT>(mean);
+  dev_ctx.template Alloc<AccT>(var);
+  // temp_var is used to calculate the mean^2
   DenseTensor temp_var;
   temp_var.Resize(var->dims());
-  dev_ctx.template Alloc<T>(&temp_var);
+  dev_ctx.template Alloc<AccT>(&temp_var);
+  phi::funcs::SetConstant<GPUContext, T> set_zero;
+  phi::funcs::SetConstant<GPUContext, AccT> set_zero_AccT;
 
   auto* x_data = x.data<T>();
   auto* y_data = y->data<T>();
-  auto* mean_data = mean->data<T>();
-  auto* var_data = var->data<T>();
-  auto* temp_var_data = temp_var.data<T>();
+  auto* mean_data = mean->data<AccT>();
+  auto* var_data = var->data<AccT>();
+  auto* temp_var_data = temp_var.data<AccT>();
 
   const T* scale_data = nullptr;
   if (scale_ptr) scale_data = scale_ptr->data<T>();
@@ -172,7 +177,6 @@ void GroupNormKernel(const Context& dev_ctx,
   dim3 grid(group_size, groups, x_dims[0]);
   dim3 threads(block_size, 1, 1);
   if (data_layout == DataLayout::kNCHW) {
-    using AccT = typename kps::details::MPTypeTrait<T>::Type;
     constexpr int vec_size = sizeof(float4) / sizeof(T);
     int size = group_size * imsize;
     const int max_num_threads = 1024;
@@ -185,7 +189,7 @@ void GroupNormKernel(const Context& dev_ctx,
     dim3 grids(x_dims[0] * groups);
     dim3 blocks(block_size_nchw);
     if (size < vec_size * block_size_nchw) {
-      ScalarGetMeanAndVarNCHW<T><<<grids, blocks, 0, dev_ctx.stream()>>>(
+      ScalarGetMeanAndVarNCHW<T, AccT><<<grids, blocks, 0, dev_ctx.stream()>>>(
          x_data, mean_data, temp_var_data, size);
     } else {
       VectorizedGetMeanAndVarNCHW<T, AccT, vec_size>
@@ -193,9 +197,9 @@ void GroupNormKernel(const Context& dev_ctx,
               x_data, mean_data, temp_var_data, size);
     }
   } else {
-    set_zero(dev_ctx, mean, static_cast<T>(0));
-    set_zero(dev_ctx, &temp_var, static_cast<T>(0));
-    GroupNormForwardGetMeanAndVar<T>
+    set_zero_AccT(dev_ctx, mean, static_cast<AccT>(0));
+    set_zero_AccT(dev_ctx, &temp_var, static_cast<AccT>(0));
+    GroupNormForwardGetMeanAndVar<T, AccT>
         <<<grid, threads, 0, dev_ctx.stream()>>>(x_data,
                                                  x_dims[0],
                                                  C,
@@ -221,26 +225,26 @@ void GroupNormKernel(const Context& dev_ctx,
                 imsize,
                 groups,
                 group_size,
-                epsilon,
+                static_cast<AccT>(epsilon),
                 y_data,
                 var_data,
                 data_layout);
 }
 
-template <typename T>
-void GroupNormDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
-                                               const T* input,
-                                               std::vector<int> input_shape,
-                                               const T* bias,
-                                               const T* scale,
-                                               T* temp_mean,
-                                               T* temp_variance,
-                                               int groups,
-                                               float eps,
-                                               T* output,
-                                               T* mean,
-                                               T* variance,
-                                               const DataLayout data_layout) {
+template <typename T, typename AccT>
+void GroupNormDirectCUDAFunctor<T, AccT>::operator()(
+    gpuStream_t stream,
+    const T* input,
+    std::vector<int> input_shape,
+    const T* bias,
+    const T* scale,
+    AccT* temp_variance,
+    int groups,
+    float eps,
+    T* output,
+    AccT* mean,
+    AccT* variance,
+    const DataLayout data_layout) {
   const auto input_ddim = phi::make_ddim(input_shape);
   const int C =
       (data_layout == DataLayout::kNCHW ? input_ddim[1]
@@ -268,8 +272,7 @@ void GroupNormDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
   dim3 grid(group_size, groups, input_ddim[0]);
   dim3 threads(block_size, 1, 1);
   if (data_layout == DataLayout::kNCHW) {
-    using AccT = typename phi::kps::details::MPTypeTrait<float>::Type;
-    constexpr int vec_size = sizeof(float4) / sizeof(float);
+    constexpr int vec_size = sizeof(float4) / sizeof(T);
     int size = group_size * image_size;  // group element size
     const int max_num_threads = 1024;
     int max_block_size = std::min(size / vec_size, max_num_threads);
@@ -283,14 +286,22 @@ void GroupNormDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
     dim3 blocks(block_size_nchw);
     if (size < vec_size * block_size_nchw) {
-      phi::ScalarGetMeanAndVarNCHW<T>
-          <<<grids, blocks, 0, stream>>>(input, temp_mean, temp_variance, size);
+      phi::ScalarGetMeanAndVarNCHW<T, AccT>
+          <<<grids, blocks, 0, stream>>>(input, mean, temp_variance, size);
     } else {
       phi::VectorizedGetMeanAndVarNCHW<T, AccT, vec_size>
-          <<<grids, blocks, 0, stream>>>(input, temp_mean, temp_variance, size);
+          <<<grids, blocks, 0, stream>>>(input, mean, temp_variance, size);
     }
   } else {
-    phi::GroupNormForwardGetMeanAndVar<T>
+#ifdef PADDLE_WITH_HIP
+    hipMemset(mean, 0, sizeof(AccT) * input_ddim[0] * groups);
+    hipMemset(temp_variance, 0, sizeof(AccT) * input_ddim[0] * groups);
+#else
+    cudaMemset(mean, 0, sizeof(AccT) * input_ddim[0] * groups);
+    cudaMemset(temp_variance, 0, sizeof(AccT) * input_ddim[0] * groups);
+#endif
+    phi::GroupNormForwardGetMeanAndVar<T, AccT>
         <<<grid, threads, 0, stream>>>(input,
                                        input_ddim[0],
                                        C,
@@ -298,28 +309,37 @@ void GroupNormDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
                                        image_size,
                                        groups,
                                        group_size,
-                                       temp_mean,
+                                       mean,
                                        temp_variance);
   }
-  GroupNormForward<T, 3><<<grid, threads, 0, stream>>>(
-      input,
-      temp_mean,
-      temp_variance,
-      scale,
-      bias,
-      input_ddim[0],
-      C,
-      W,
-      image_size,
-      groups,
-      group_size,
-      eps,
-      output,
-      variance,
-      data_layout);  // for now, we only support nchw for group norm
+  GroupNormForward<T, AccT, 3>
+      <<<grid, threads, 0, stream>>>(input,
+                                     mean,
+                                     temp_variance,
+                                     scale,
+                                     bias,
+                                     input_ddim[0],
+                                     C,
+                                     W,
+                                     image_size,
+                                     groups,
+                                     group_size,
+                                     static_cast<AccT>(eps),
+                                     output,
+                                     variance,
+                                     data_layout);
 }
 
-template class GroupNormDirectCUDAFunctor<float>;
+template class GroupNormDirectCUDAFunctor<float, float>;
+#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
+template class GroupNormDirectCUDAFunctor<half, float>;
+#endif
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    group_norm, GPU, ALL_LAYOUT, phi::GroupNormKernel, float, double) {}
+PD_REGISTER_KERNEL(group_norm,
                   GPU,
                   ALL_LAYOUT,
                   phi::GroupNormKernel,
                   float,
                   double,
                   phi::dtype::float16) {}
...@@ -31,9 +31,10 @@ namespace phi { ...@@ -31,9 +31,10 @@ namespace phi {
enum GroupNormKernelFlags { kHasScale = 1, kHasBias = 2 }; enum GroupNormKernelFlags { kHasScale = 1, kHasBias = 2 };
#define ALIGN_BYTES 16 #define ALIGN_BYTES 16
#define CHECK_CASE(i, flags, kernel_name, ...) \ #define CHECK_CASE(i, flags, kernel_name, ...) \
if (i == flags) { \ if (i == flags) { \
kernel_name<T, i><<<grid, threads, 0, dev_ctx.stream()>>>(__VA_ARGS__); \ kernel_name<T, AccT, i> \
<<<grid, threads, 0, dev_ctx.stream()>>>(__VA_ARGS__); \
} }
// 0 for no scale, no bias // 0 for no scale, no bias
...@@ -75,11 +76,14 @@ __device__ __forceinline__ void ThreadReduce(phi::Array<const T*, Num> arrs, ...@@ -75,11 +76,14 @@ __device__ __forceinline__ void ThreadReduce(phi::Array<const T*, Num> arrs,
size += offset; size += offset;
if (tid >= offset) { if (tid >= offset) {
if (Num == 1) { if (Num == 1) {
*out_mean += x[tid]; AccT x_acc = static_cast<AccT>(x[tid]);
*out_var += x[tid] * x[tid]; *out_mean += x_acc;
*out_var += x_acc * x_acc;
} else if (Num == 2) { } else if (Num == 2) {
*out_mean += y[tid]; AccT x_acc = static_cast<AccT>(x[tid]);
*out_var += y[tid] * x[tid]; AccT y_acc = static_cast<AccT>(y[tid]);
*out_mean += y_acc;
*out_var += y_acc * x_acc;
} }
} }
size -= blockDim.x; size -= blockDim.x;
...@@ -105,11 +109,14 @@ __device__ __forceinline__ void ThreadReduce(phi::Array<const T*, Num> arrs, ...@@ -105,11 +109,14 @@ __device__ __forceinline__ void ThreadReduce(phi::Array<const T*, Num> arrs,
#pragma unroll #pragma unroll
for (int i = 0; i < VecSize; ++i) { for (int i = 0; i < VecSize; ++i) {
if (Num == 1) { if (Num == 1) {
*out_mean += ins_x[i]; AccT ins_x_acc = static_cast<AccT>(ins_x[i]);
*out_var += ins_x[i] * ins_x[i]; *out_mean += ins_x_acc;
*out_var += ins_x_acc * ins_x_acc;
} else if (Num == 2) { } else if (Num == 2) {
*out_mean += ins_y[i]; AccT ins_x_acc = static_cast<AccT>(ins_x[i]);
*out_var += ins_y[i] * ins_x[i]; AccT ins_y_acc = static_cast<AccT>(ins_y[i]);
*out_mean += ins_y_acc;
*out_var += ins_y_acc * ins_x_acc;
} }
} }
} }
...@@ -118,11 +125,14 @@ __device__ __forceinline__ void ThreadReduce(phi::Array<const T*, Num> arrs, ...@@ -118,11 +125,14 @@ __device__ __forceinline__ void ThreadReduce(phi::Array<const T*, Num> arrs,
tid = size - remain + threadIdx.x; tid = size - remain + threadIdx.x;
for (; tid < size; tid += blockDim.x) { for (; tid < size; tid += blockDim.x) {
if (Num == 1) { if (Num == 1) {
*out_mean += x[tid]; AccT x_acc = static_cast<AccT>(x[tid]);
*out_var += x[tid] * x[tid]; *out_mean += x_acc;
*out_var += x_acc * x_acc;
} else if (Num == 2) { } else if (Num == 2) {
*out_mean += y[tid]; AccT x_acc = static_cast<AccT>(x[tid]);
*out_var += y[tid] * x[tid]; AccT y_acc = static_cast<AccT>(y[tid]);
*out_mean += y_acc;
*out_var += y_acc * x_acc;
} }
} }
} }
...@@ -137,28 +147,32 @@ __device__ __forceinline__ void ReduceMeanAndVar( ...@@ -137,28 +147,32 @@ __device__ __forceinline__ void ReduceMeanAndVar(
x_var, kps::AddFunctor<T>()); x_var, kps::AddFunctor<T>());
__syncthreads(); __syncthreads();
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
mean[nc] = static_cast<T>(x_mean / size); mean[nc] = x_mean / size;
var[nc] = static_cast<T>(x_var / size); var[nc] = x_var / size;
} }
} }
template <typename T> template <typename T, typename AccT>
__global__ void ScalarGetMeanAndVarNCHW(const T* x, T* mean, T* var, int size) { __global__ void ScalarGetMeanAndVarNCHW(const T* x,
AccT* mean,
AccT* var,
int size) {
int i = blockIdx.x; int i = blockIdx.x;
T x_mean = 0, x_var = 0; AccT x_mean = static_cast<AccT>(0);
AccT x_var = static_cast<AccT>(0);
for (int j = threadIdx.x; j < size; j += blockDim.x) { for (int j = threadIdx.x; j < size; j += blockDim.x) {
T val; AccT val;
val = x[i * size + j]; val = static_cast<AccT>(x[i * size + j]);
x_mean += val; x_mean += val;
x_var += val * val; x_var += val * val;
} }
ReduceMeanAndVar<T>(mean, var, x_mean, x_var, size); ReduceMeanAndVar<AccT>(mean, var, x_mean, x_var, size);
} }
template <typename T, typename AccT, int VecSize> template <typename T, typename AccT, int VecSize>
__global__ void VectorizedGetMeanAndVarNCHW(const T* x, __global__ void VectorizedGetMeanAndVarNCHW(const T* x,
T* mean, AccT* mean,
T* var, AccT* var,
int size) { int size) {
int i = blockIdx.x; int i = blockIdx.x;
AccT x_mean = static_cast<AccT>(0); AccT x_mean = static_cast<AccT>(0);
......
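Note that the `var` buffer filled by these statistics kernels holds E[x^2]; the variance used for normalization is recovered later as E[x^2] - E[x]^2. A hedged sketch of that final elementwise step, computing in float and casting back to half (illustrative only; the indexing assumes NCHW with contiguous groups, and scale/bias are omitted):

```cuda
#include <cuda_fp16.h>

// group_size = (C / groups) * H * W elements share one (batch, group) statistic.
// mean[ng] holds E[x], ex2[ng] holds E[x^2] for that group. Hypothetical kernel.
__global__ void GroupNormApply(const __half* x, const float* mean,
                               const float* ex2, float eps, int group_size,
                               int total, __half* y) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= total) return;
  int ng = idx / group_size;            // which (batch, group) this element uses
  float m = mean[ng];
  float var = ex2[ng] - m * m;          // variance = E[x^2] - E[x]^2
  float inv_std = rsqrtf(var + eps);
  float val = (__half2float(x[idx]) - m) * inv_std;
  y[idx] = __float2half(val);           // only the final result goes back to fp16
}
```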
...@@ -34,7 +34,7 @@ void GroupNormKernel(const Context& dev_ctx, ...@@ -34,7 +34,7 @@ void GroupNormKernel(const Context& dev_ctx,
DenseTensor* variance); DenseTensor* variance);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <typename T> template <typename T, typename AccT = T>
class GroupNormDirectCUDAFunctor { class GroupNormDirectCUDAFunctor {
public: public:
void operator()(gpuStream_t stream, void operator()(gpuStream_t stream,
...@@ -42,13 +42,12 @@ class GroupNormDirectCUDAFunctor { ...@@ -42,13 +42,12 @@ class GroupNormDirectCUDAFunctor {
std::vector<int> input_shape, std::vector<int> input_shape,
const T* bias, const T* bias,
const T* scale, const T* scale,
T* temp_mean, AccT* temp_variance,
T* temp_variance,
int groups, int groups,
float eps, float eps,
T* output, T* output,
T* mean, AccT* mean,
T* variance, AccT* variance,
const DataLayout data_layout); const DataLayout data_layout);
}; };
#endif #endif
......
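With the functor template now carrying AccT (defaulting to T), a caller on the fp16 path keeps the mean/variance buffers in float while input and output stay half, and the separate temp_mean argument is gone. A minimal sketch of the buffer setup such a caller would need; sizes and names are illustrative, and the functor invocation itself is not reproduced because part of its parameter list is elided in the hunk above:

```cuda
#include <cuda_fp16.h>
#include <cuda_runtime.h>

int main() {
  const int n = 2, c = 32, h = 16, w = 16, groups = 4;
  const int elems = n * c * h * w;

  __half *d_in = nullptr, *d_out = nullptr;  // data stays fp16
  float *d_mean = nullptr, *d_var = nullptr, *d_temp_var = nullptr;  // stats in fp32
  cudaMalloc(&d_in, elems * sizeof(__half));
  cudaMalloc(&d_out, elems * sizeof(__half));
  cudaMalloc(&d_mean, n * groups * sizeof(float));
  cudaMalloc(&d_var, n * groups * sizeof(float));
  cudaMalloc(&d_temp_var, n * groups * sizeof(float));

  // ... a GroupNormDirectCUDAFunctor<__half, float> instance would run here ...

  cudaFree(d_in);
  cudaFree(d_out);
  cudaFree(d_mean);
  cudaFree(d_var);
  cudaFree(d_temp_var);
  return 0;
}
```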
...@@ -23,8 +23,6 @@ import unittest ...@@ -23,8 +23,6 @@ import unittest
class TrtConvertGroupNormTest(TrtLayerAutoScanTest): class TrtConvertGroupNormTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool: def is_program_valid(self, program_config: ProgramConfig) -> bool:
inputs = program_config.inputs
weights = program_config.weights
attrs = [ attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops)) program_config.ops[i].attrs for i in range(len(program_config.ops))
] ]
...@@ -49,14 +47,15 @@ class TrtConvertGroupNormTest(TrtLayerAutoScanTest): ...@@ -49,14 +47,15 @@ class TrtConvertGroupNormTest(TrtLayerAutoScanTest):
for batch in [1, 2, 4]: for batch in [1, 2, 4]:
for group in [1, 4, 32, -1]: for group in [1, 4, 32, -1]:
for epsilon in [0.0001, 0.0007, -1, 1]: for epsilon in [0.00001, 0.00005]:
for data_layout in ['NCHW']: for data_layout in ['NCHW']:
dics = [ dics = [
{ {
"epsilon": epsilon, "epsilon": epsilon,
"groups": group, "groups": group,
"data_layout": data_layout, "data_layout": data_layout,
} },
{},
] ]
ops_config = [ ops_config = [
{ {
...@@ -122,31 +121,31 @@ class TrtConvertGroupNormTest(TrtLayerAutoScanTest): ...@@ -122,31 +121,31 @@ class TrtConvertGroupNormTest(TrtLayerAutoScanTest):
# for static_shape # for static_shape
clear_dynamic_shape() clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32 self.trt_param.workspace_size = 2013265920
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False
), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num( yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False attrs, False
), (1e-3, 1e-3) ), 1e-2
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32 self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num( yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True attrs, False
), 1e-5 ), 1e-5
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.workspace_size = 2013265920
self.trt_param.precision = paddle_infer.PrecisionType.Half self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num( yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True attrs, True
), (1e-3, 1e-3) ), 1e-2
def add_skip_trt_case(self): self.trt_param.precision = paddle_infer.PrecisionType.Float32
pass yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True
), 1e-5
def test(self): def test(self):
self.add_skip_trt_case()
self.run_test() self.run_test()
......
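The test now checks the Half engine against a 1e-2 tolerance while the Float32 engine keeps 1e-5, which matches what half precision can represent: with 10 explicit mantissa bits the per-value rounding error is on the order of 1e-4, before any accumulation error in the reduction. A tiny host-side illustration of that rounding granularity, assuming only the conversion helpers from cuda_fp16.h:

```cuda
#include <cuda_fp16.h>
#include <cstdio>

int main() {
  float x = 0.12345678f;
  float roundtrip = __half2float(__float2half(x));  // store as fp16, read back
  printf("fp32: %.8f  fp16 round-trip: %.8f  rel err: %.2e\n",
         x, roundtrip, (x - roundtrip) / x);
  return 0;
}
```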
...@@ -317,5 +317,65 @@ class TestGroupNormEager(unittest.TestCase): ...@@ -317,5 +317,65 @@ class TestGroupNormEager(unittest.TestCase):
) )
class TestGroupNormEager_fp32(unittest.TestCase):
def test_dygraph_api(self):
self.dtype = np.float32
self.shape = (8, 32, 32)
input = np.random.random(self.shape).astype(self.dtype)
with fluid.dygraph.guard():
tensor_1 = fluid.dygraph.to_variable(input)
tensor_1.stop_gradient = False
groupNorm = fluid.dygraph.nn.GroupNorm(
channels=32, groups=4, dtype='float32'
)
ret1 = groupNorm(tensor_1)
ret1.backward()
with _test_eager_guard():
tensor_eager_1 = fluid.dygraph.to_variable(input)
tensor_eager_1.stop_gradient = False
groupNorm_eager = fluid.dygraph.nn.GroupNorm(
channels=32, groups=4
)
ret2 = groupNorm_eager(tensor_eager_1)
ret2.backward()
self.assertEqual(
(
tensor_1.grad.numpy() == tensor_eager_1.grad.numpy()
).all(),
True,
)
class TestGroupNormEager_fp16(unittest.TestCase):
def test_dygraph_api(self):
self.dtype = np.float32
self.shape = (8, 32, 32)
input = np.random.random(self.shape).astype(self.dtype)
with fluid.dygraph.guard():
tensor_1 = fluid.dygraph.to_variable(input)
tensor_1.stop_gradient = False
groupNorm = fluid.dygraph.nn.GroupNorm(
channels=32, groups=4, dtype='float16'
)
ret1 = groupNorm(tensor_1)
ret1.backward()
with _test_eager_guard():
tensor_eager_1 = fluid.dygraph.to_variable(input)
tensor_eager_1.stop_gradient = False
groupNorm_eager = fluid.dygraph.nn.GroupNorm(
channels=32, groups=4
)
ret2 = groupNorm_eager(tensor_eager_1)
ret2.backward()
self.assertEqual(
(
tensor_1.grad.numpy() == tensor_eager_1.grad.numpy()
).all(),
True,
)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -178,6 +178,58 @@ class TestGroupNormAPIV2_With_General_Dimensions(unittest.TestCase): ...@@ -178,6 +178,58 @@ class TestGroupNormAPIV2_With_General_Dimensions(unittest.TestCase):
self.test_numerical_accuracy() self.test_numerical_accuracy()
class TestGroupNormAPIV2_With_General_Dimensions_fp16(unittest.TestCase):
def test_numerical_accuracy(self):
# fp16 only supported in cuda
if not core.is_compiled_with_cuda():
return
paddle.disable_static()
shapes = [
(2, 6, 4),
(2, 6, 4, 4),
(2, 6, 6, 6, 2),
(2, 6, 6, 6, 2, 3),
(2, 6, 6, 6, 256, 3),
]
np.random.seed(10)
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"):
places.append(fluid.CUDAPlace(0))
for place in places:
for shape in shapes:
scale = np.array([1]).astype("float32")
bias = np.array([0]).astype("float32")
data = np.random.random(shape).astype("float32")
expect_res1 = group_norm_naive_for_general_dimension(
data, scale, bias, epsilon=1e-5, groups=6
)
expect_res2 = group_norm_naive_for_general_dimension(
data, scale, bias, epsilon=1e-5, groups=2
)
gn1 = paddle.nn.GroupNorm(num_channels=6, num_groups=6)
gn2 = paddle.nn.GroupNorm(num_channels=6, num_groups=2)
paddle.assign(paddle.cast(gn1.weight, 'float16'), gn1.weight)
paddle.assign(paddle.cast(gn1.bias, 'float16'), gn1.bias)
paddle.assign(paddle.cast(gn2.weight, 'float16'), gn2.weight)
paddle.assign(paddle.cast(gn2.bias, 'float16'), gn2.bias)
data_pd = paddle.to_tensor(data.astype('float16'))
result1 = gn1(data_pd).numpy()
result2 = gn2(data_pd).numpy()
np.testing.assert_allclose(
result1, expect_res1, rtol=1e-2, atol=1e-3
)
np.testing.assert_allclose(
result2, expect_res2, rtol=1e-2, atol=1e-3
)
def test_eager_api(self):
with _test_eager_guard():
self.test_numerical_accuracy()
class TestGroupNormDimException(unittest.TestCase): class TestGroupNormDimException(unittest.TestCase):
def test_exception(self): def test_exception(self):
def test_empty_input_static_API(): def test_empty_input_static_API():
......