From 71e28b12815018a8420962fc6a20a3526085938d Mon Sep 17 00:00:00 2001 From: Tian Zheng Date: Thu, 31 Aug 2023 11:08:54 +0800 Subject: [PATCH] Add fused_scale_bias_relu_conv_bnstats OP (#55026) * Add fused_scale_bias_relu_conv_bnstats op * Review changes * Fix no CUDNN Frontend build * Fix PADDLE_ENFORCE format * Fix PADDLE_ENFORCE CI error * Rename kernel filename * Refactor unittest to use paddle eager_op_test * Fix padding bugs * Review changes * test=cuda117 * test=cuda117 --- paddle/phi/api/yaml/fused_ops.yaml | 10 + paddle/phi/infermeta/fusion.cc | 134 ++++ paddle/phi/infermeta/fusion.h | 28 + paddle/phi/kernels/CMakeLists.txt | 5 + paddle/phi/kernels/autotune/cache.cc | 5 + paddle/phi/kernels/autotune/cache.h | 9 +- .../kernels/autotune/cache_cudnn_frontend.h | 73 ++- ...sed_scale_bias_relu_conv_bnstats_kernel.cu | 618 ++++++++++++++++++ .../phi/kernels/gpudnn/conv_cudnn_frontend.h | 46 +- paddle/phi/kernels/gpudnn/conv_kernel.cu | 3 +- test/legacy_test/CMakeLists.txt | 4 + ...t_fused_scale_bias_relu_conv_bnstats_op.py | 239 +++++++ test/white_list/op_accuracy_white_list.py | 1 + 13 files changed, 1154 insertions(+), 21 deletions(-) create mode 100644 paddle/phi/kernels/fusion/gpu/fused_scale_bias_relu_conv_bnstats_kernel.cu create mode 100644 test/legacy_test/test_fused_scale_bias_relu_conv_bnstats_op.py diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index 648384422ca..09ccd2fe7d8 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -160,6 +160,16 @@ backward: fused_rotary_position_embedding_grad support_dygraph_mode : true +- op : fused_scale_bias_relu_conv_bnstats + args : (Tensor x, Tensor w, Tensor scale, Tensor bias, Tensor bn_scale, Tensor bn_bias, Tensor input_running_mean, Tensor input_running_var, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, str data_format, float momentum, float epsilon, bool fuse_prologue, bool exhaustive_search, int64_t accumulation_count = 0) + optional : scale, bias + output : Tensor(out), Tensor(out_running_mean), Tensor(out_running_var), Tensor(saved_mean), Tensor(saved_var), Tensor(eq_scale), Tensor(eq_bias) + infer_meta : + func : FusedScaleBiasReluConvBnstatsInferMeta + kernel : + func : fused_scale_bias_relu_conv_bnstats + data_type : x + - op : generate_sequence_xpu args : (Tensor x, DataType dtype) output : Tensor diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index 3143c5cde2e..993fb5d5887 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -821,4 +821,138 @@ void FastLayernormXPUInferMeta(const MetaTensor& x, out->set_layout(x.layout()); } +void FusedScaleBiasReluConvBnstatsInferMeta( + const MetaTensor& x, + const MetaTensor& w, + const MetaTensor& scale, + const MetaTensor& bias, + const MetaTensor& bn_scale, + const MetaTensor& bn_bias, + const MetaTensor& input_running_mean, + const MetaTensor& input_running_var, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const std::string& padding_algorithm, + int groups, + const std::string& data_format, + float momentum, + float epsilon, + bool fuse_prologue, + bool exhaustive_search, + int64_t accumulation_count, + MetaTensor* out, + MetaTensor* out_running_mean, + MetaTensor* out_running_var, + MetaTensor* saved_mean, + MetaTensor* saved_var, + MetaTensor* eq_scale, + MetaTensor* eq_bias) { + auto in_dims = x.dims(); + auto filter_dims = w.dims(); + // do some checks + 
PADDLE_ENFORCE_EQ(
+      in_dims.size(),
+      4,
+      phi::errors::InvalidArgument(
+          "The input of Op(FusedScaleBiasReluConvBnstats) should be a 4-D "
+          "Tensor. But "
+          "received: input's dimension is %u, input's shape is [%s].",
+          in_dims.size(),
+          in_dims));
+
+  PADDLE_ENFORCE_EQ(
+      in_dims.size(),
+      filter_dims.size(),
+      phi::errors::InvalidArgument(
+          "The input's dimension and filter's dimension of "
+          "Op(FusedScaleBiasReluConvBnstats) should be equal. But received: "
+          "the input's"
+          " shape is [%s], "
+          "the input's dimension is %d; the filter's shape is [%s], "
+          "the filter's dimension is %d.",
+          in_dims,
+          in_dims.size(),
+          filter_dims,
+          filter_dims.size()));
+
+  // Check if data format is NHWC
+  PADDLE_ENFORCE_EQ(
+      data_format,
+      "NHWC",
+      phi::errors::InvalidArgument(
+          "Operator(FusedScaleBiasReluConvBnstats) only supports data format "
+          "of "
+          "channel last (NHWC) now. But received: data_format = '%s'.",
+          data_format));
+
+  PADDLE_ENFORCE_EQ(
+      groups,
+      1,
+      phi::errors::InvalidArgument("Expect group to be 1, got %d.", groups));
+
+  const auto input_channels = in_dims[in_dims.size() - 1];
+  int dilation_size = dilations.size();
+  for (int i = 0; i < dilation_size; ++i) {
+    PADDLE_ENFORCE_GT(
+        dilations[i],
+        0,
+        phi::errors::InvalidArgument(
+            "The dilation of Op(Conv) should be larger than 0, but received "
+            "dilation is %d.",
+            dilations[i]));
+  }
+
+  PADDLE_ENFORCE_EQ(
+      input_channels,
+      filter_dims[1] * groups,
+      phi::errors::InvalidArgument(
+          "The number of input's channels should be equal to filter's channels "
+          "* groups for Op(FusedScaleBiasReluConvBnstats). But received: the "
+          "input's"
+          " channels is %d, "
+          "the input's shape is [%s]; the filter's channels is %d, the "
+          "filter's shape is [%s]; the groups is %d. ",
+          input_channels,
+          in_dims,
+          filter_dims[1],
+          filter_dims,
+          groups));
+
+  // update paddings and dilations according to padding_algorithm
+  std::vector<int> paddings_vec = paddings;
+  std::vector<int> dilations_vec = dilations;
+  // get "HW" from "NHWC"
+  DDim in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
+  DDim filter_data_dims = phi::slice_ddim(filter_dims, 2, filter_dims.size());
+  std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
+  phi::UpdatePaddingAndDilation(&paddings_vec,
+                                &dilations_vec,
+                                padding_algorithm,
+                                in_data_dims,
+                                strides,
+                                ksize);
+
+  std::vector<int64_t> out_shape({in_dims[0]});
+  for (size_t i = 0; i < strides.size(); ++i) {
+    out_shape.push_back(ConvOutSize(in_dims[i + 1],
+                                    filter_dims[i + 2],
+                                    dilations[i],
+                                    paddings_vec[i * 2],
+                                    paddings_vec[i * 2 + 1],
+                                    strides[i]));
+  }
+  out_shape.push_back(filter_dims[0]);
+  // make shape for other outputs
+  auto c_dims = phi::make_ddim({filter_dims[0]});
+  // set output and output max dims
+  out->set_dims(DDim(out_shape.data(), out_shape.size()));
+  out_running_mean->set_dims(c_dims);
+  out_running_var->set_dims(c_dims);
+  saved_mean->set_dims(c_dims);
+  saved_var->set_dims(c_dims);
+  eq_scale->set_dims(c_dims);
+  eq_bias->set_dims(c_dims);
+}
+
 }  // namespace phi
diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h
index 25c27bdd406..3d7ba19c4ec 100644
--- a/paddle/phi/infermeta/fusion.h
+++ b/paddle/phi/infermeta/fusion.h
@@ -201,4 +201,32 @@ void FastLayernormXPUInferMeta(const MetaTensor& x,
                                float epsilon,
                                MetaTensor* out);
 
+void FusedScaleBiasReluConvBnstatsInferMeta(
+    const MetaTensor& x,
+    const MetaTensor& w,
+    const MetaTensor& scale,
+    const MetaTensor& bias,
+    const MetaTensor& bn_scale,
+    const MetaTensor& bn_bias,
+    const
MetaTensor& input_running_mean, + const MetaTensor& input_running_var, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const std::string& padding_algorithm, + int groups, + const std::string& data_format, + float momentum, + float epsilon, + bool fuse_prologue, + bool exhaustive_search, + int64_t accumulation_count, + MetaTensor* out, + MetaTensor* out_running_mean, + MetaTensor* out_running_var, + MetaTensor* saved_mean, + MetaTensor* saved_var, + MetaTensor* eq_scale, + MetaTensor* eq_bias); + } // namespace phi diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 9a917004c83..cc8df692a12 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -94,6 +94,11 @@ if(WITH_CUTLASS) list(APPEND kernel_cu ${cutlass_cu}) endif() +if(NOT WITH_CUDNN_FRONTEND) + list(REMOVE_ITEM kernel_cu + "fusion/gpu/fused_scale_bias_relu_conv_bnstats_kernel.cu") +endif() + set(cc_search_pattern "*.cc" "cpu/*.cc" diff --git a/paddle/phi/kernels/autotune/cache.cc b/paddle/phi/kernels/autotune/cache.cc index 6ff1296b513..ba48e2e00ce 100644 --- a/paddle/phi/kernels/autotune/cache.cc +++ b/paddle/phi/kernels/autotune/cache.cc @@ -47,6 +47,11 @@ std::string AlgorithmTypeString(int64_t algo_type) { } else if (algo_type == static_cast(AlgorithmType::kConvBackwardFilterV8)) { return "conv_backward_filter_v8"; + } else if (algo_type == + static_cast(AlgorithmType::kScaleBiasReluConvBNstats)) { + return "scale_bias_relu_conv_bnstats"; + } else if (algo_type == static_cast(AlgorithmType::kBNFinalize)) { + return "bn_finalize"; } #endif return std::to_string(algo_type); diff --git a/paddle/phi/kernels/autotune/cache.h b/paddle/phi/kernels/autotune/cache.h index 188faaed71b..34b98e28f50 100644 --- a/paddle/phi/kernels/autotune/cache.h +++ b/paddle/phi/kernels/autotune/cache.h @@ -55,7 +55,9 @@ enum class AlgorithmType { kConvForwardV8 = 10, kConvBackwardDataV8 = 11, kConvBackwardFilterV8 = 12, - kAlgorithmCount = 13 + kScaleBiasReluConvBNstats = 13, + kBNFinalize = 14, + kAlgorithmCount = 15 #endif }; @@ -178,9 +180,8 @@ class AutoTuneCache { conv_auto_tune_map_[key] = cache; } #ifdef PADDLE_WITH_CUDNN_FRONTEND - } else if (algo_type == AlgorithmType::kConvForwardV8 || - algo_type == AlgorithmType::kConvBackwardDataV8 || - algo_type == AlgorithmType::kConvBackwardFilterV8) { + } else if (algo_type >= AlgorithmType::kConvForwardV8 && + algo_type <= AlgorithmType::kBNFinalize) { int64_t key = static_cast(algo_type); if (cudnn_v8_auto_tune_map_.find(key) == cudnn_v8_auto_tune_map_.end()) { CudnnFrontendPlanCache cache; diff --git a/paddle/phi/kernels/autotune/cache_cudnn_frontend.h b/paddle/phi/kernels/autotune/cache_cudnn_frontend.h index 4715efa1f77..cfd16e51433 100644 --- a/paddle/phi/kernels/autotune/cache_cudnn_frontend.h +++ b/paddle/phi/kernels/autotune/cache_cudnn_frontend.h @@ -79,10 +79,10 @@ class CudnnFrontendPlanCache { return ret; } - void GetPlan(const cudnn_frontend::feature_vector_t &feature, - const cudnn_frontend::ExecutionPlan **plan, - int64_t *workspace_size, - cudnnHandle_t handle) { + void GetPlanAndWorkspaceSize(const cudnn_frontend::feature_vector_t &feature, + const cudnn_frontend::ExecutionPlan **plan, + int64_t *workspace_size, + cudnnHandle_t handle) { // Note(tizheng): CUDNNv8 execution plan is not thread-safe. // A shared plan being executed by different threads is // generally not safe (for now). 
@@ -90,11 +90,11 @@ class CudnnFrontendPlanCache {
     auto &local_map = map_[hasher(std::this_thread::get_id())];
 
     auto it = local_map.find(GetExtendedFeature(feature, handle));
-    if (it == local_map.end()) {
-      PADDLE_THROW(phi::errors::InvalidArgument(
-          "[cudnn_frontend] Cached Plan Not Found."));
-      return;
-    }
+    PADDLE_ENFORCE_NE(it,
+                      local_map.end(),
+                      phi::errors::InvalidArgument(
+                          "[cudnn_frontend] Cached Plan Not Found."));
+
     *plan = &(it->second);
     *workspace_size = (*plan)->getWorkspaceSize();
     VLOG(4) << "Cached execution plan found." << (*plan)->getTag()
@@ -133,11 +133,12 @@ class CudnnFrontendPlanCache {
     return FindPlan(op_graph.getFeatureVector(), handle);
   }
 
-  void GetPlan(const cudnn_frontend::OperationGraph &op_graph,
-               const cudnn_frontend::ExecutionPlan **plan,
-               int64_t *workspace_size,
-               cudnnHandle_t handle) {
-    GetPlan(op_graph.getFeatureVector(), plan, workspace_size, handle);
+  void GetPlanAndWorkspaceSize(const cudnn_frontend::OperationGraph &op_graph,
+                               const cudnn_frontend::ExecutionPlan **plan,
+                               int64_t *workspace_size,
+                               cudnnHandle_t handle) {
+    GetPlanAndWorkspaceSize(
+        op_graph.getFeatureVector(), plan, workspace_size, handle);
   }
 
   void InsertPlan(const cudnn_frontend::OperationGraph &op_graph,
@@ -176,5 +177,49 @@ class CudnnFrontendPlanCache {
   int64_t cache_misses_{0};
 };  // class CudnnFrontendPlanCache
 
+template <typename T>
+inline void BuildFeatureVectorSingle(cudnn_frontend::feature_vector_t *v,
+                                     const T &value) {
+  v->push_back(static_cast<int64_t>(value));
+}
+
+template <>
+inline void BuildFeatureVectorSingle<float>(cudnn_frontend::feature_vector_t *v,
+                                            const float &value) {
+  int64_t val = 0;
+  memcpy(&val, &value, sizeof(float));
+  v->push_back(val);
+}
+
+template <>
+inline void BuildFeatureVectorSingle<std::vector<int64_t>>(
+    cudnn_frontend::feature_vector_t *v, const std::vector<int64_t> &value) {
+  v->insert(v->end(), value.begin(), value.end());
+}
+
+template <>
+inline void BuildFeatureVectorSingle<std::vector<int>>(
+    cudnn_frontend::feature_vector_t *v, const std::vector<int> &value) {
+  for (auto &val : value) {
+    v->push_back(static_cast<int64_t>(val));
+  }
+}
+
+template <>
+inline void BuildFeatureVectorSingle<std::string>(
+    cudnn_frontend::feature_vector_t *v, const std::string &value) {
+  v->push_back(std::hash<std::string>()(value));
+}
+
+inline void BuildFeatureVector(cudnn_frontend::feature_vector_t *v) { return; }
+
+template <typename T, typename... Args>
+inline void BuildFeatureVector(cudnn_frontend::feature_vector_t *v,
+                               const T &value,
+                               Args... args) {
+  BuildFeatureVectorSingle(v, value);
+  BuildFeatureVector(v, args...);
+}
+
 }  // namespace autotune
 }  // namespace phi
diff --git a/paddle/phi/kernels/fusion/gpu/fused_scale_bias_relu_conv_bnstats_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_scale_bias_relu_conv_bnstats_kernel.cu
new file mode 100644
index 00000000000..e19996d63c7
--- /dev/null
+++ b/paddle/phi/kernels/fusion/gpu/fused_scale_bias_relu_conv_bnstats_kernel.cu
@@ -0,0 +1,618 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/ + +#include + +#include "paddle/phi/backends/gpu/cuda/cudnn_helper.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" +#include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/core/flags.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/autotune/cache.h" +#include "paddle/phi/kernels/cpu/conv_util.h" +#include "paddle/phi/kernels/funcs/batch_norm_utils.h" +#include "paddle/phi/kernels/gpudnn/conv_cudnn_frontend.h" + +DECLARE_bool(cudnn_deterministic); +DECLARE_bool(cudnn_exhaustive_search); + +namespace phi { +namespace fusion { + +using helper = phi::CudnnFrontendConvHelper; +template +using CudnnDataType = phi::backends::gpu::CudnnDataType; + +/* + * Implements Scale + Bias + ReLU + Conv + BNStats fusion pattern. + * Same as the following (x and output are in NHWC format): + * ``` + * output = conv2d(relu(x * scale + bias), w) + * sum_output, sqsum_output = bnstats(output) + * ``` + * Here, bnstats generates per-channel statistics, same as: + * ``` + * sum_output = output.sum(axis=[0,1,2]) + * sqsum_output = (output ** 2).sum(axis=[0,1,2]) + * ``` + * More details: + * https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#genstats-runtime-fusion-engine + */ +template +void FusedScaleBiasReluConvBnstatsImpl( + const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& w, + const paddle::optional& scale, + const paddle::optional& bias, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const std::string& padding_algorithm, + bool fuse_prologue, + bool exhaustive_search, + bool deterministic, + DenseTensor* output, + DenseTensor* sum_output, + DenseTensor* sqsum_output) { + auto& plan_cache = phi::autotune::AutoTuneCache::Instance().GetConvV8( + phi::autotune::AlgorithmType::kScaleBiasReluConvBNstats); + + // transformed tensor + DenseTensor w_transformed(w.dtype()); + // Assume input and output already in NHWC. + // No transformation is needed for them. 
+ VLOG(3) << "Transform filter tensor from NCHW to NHWC."; + ResizeToChannelLast(dev_ctx, &w, &w_transformed); + TransToChannelLast(dev_ctx, &w, &w_transformed); + + // update padding and dilation + std::vector paddings_vec = paddings; + std::vector dilations_vec = dilations; + auto in_dims = x.dims(); + auto filter_dims = w_transformed.dims(); + DDim in_data_dims = slice_ddim(in_dims, 1, in_dims.size() - 1); + DDim filter_data_dims = slice_ddim(filter_dims, 1, filter_dims.size() - 1); + std::vector ksize = phi::vectorize(filter_data_dims); + phi::UpdatePaddingAndDilation(&paddings_vec, + &dilations_vec, + padding_algorithm, + in_data_dims, + strides, + ksize); + + int data_dim = strides.size(); // 2d only + + std::vector pre_padding(data_dim, 0); + std::vector post_padding(data_dim, 0); + for (size_t i = 0; i < data_dim; ++i) { + pre_padding[i] = static_cast(paddings_vec[2 * i]); + post_padding[i] = static_cast(paddings_vec[2 * i + 1]); + } + + // input pointers + T* input_data = const_cast(x.data()); + T* filter_data = w_transformed.data(); + + // output pointers + T* output_data = output->data(); + float* sum_output_data = sum_output->data(); + float* sqsum_output_data = sqsum_output->data(); + + auto handle = dev_ctx.cudnn_handle(); + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); + + // build tensors + cudnnTensorFormat_t layout_format = CUDNN_TENSOR_NHWC; + auto tensor_format = phi::backends::gpu::ToCudnnDataType(x.dtype()); + + auto tensor_format_math = CUDNN_DATA_FLOAT; + auto compute_dtype = CUDNN_DATA_FLOAT; + + // get dims in CUDNN manner: [N, C, H, W] + auto dim_x = + phi::backends::gpu::TransformDimOrder(phi::vectorize(in_dims)); + auto dim_filt = phi::backends::gpu::TransformDimOrder( + phi::vectorize(filter_dims)); + auto dim_y = phi::backends::gpu::TransformDimOrder( + phi::vectorize(output->dims())); + std::vector dim_scale(dim_x.size(), 1); + dim_scale[1] = dim_x[1]; // [1, C, 1, 1] + std::vector dim_sum(dim_x.size(), 1); // [1, K, 1, 1] + dim_sum[1] = dim_filt[0]; + + std::vector data_ptrs; + std::vector uids; + int64_t uid = 100; + + // inputs + auto input_desc = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format); + data_ptrs.push_back(input_data); + uids.push_back(uid); + + auto filter_desc = helper::GetGeneralTensorDescriptor( + dim_filt, layout_format, ++uid, 16, tensor_format); + data_ptrs.push_back(filter_data); + uids.push_back(uid); + + // dispensable inputs + auto scale_desc = helper::GetGeneralTensorDescriptor( + dim_scale, layout_format, ++uid, 16, tensor_format); + if (fuse_prologue) { + data_ptrs.push_back(const_cast(scale->data())); + uids.push_back(uid); + } + + auto bias_desc = helper::GetGeneralTensorDescriptor( + dim_scale, layout_format, ++uid, 16, tensor_format); + if (fuse_prologue) { + data_ptrs.push_back(const_cast(bias->data())); + uids.push_back(uid); + } + + // outputs + auto output_desc = helper::GetGeneralTensorDescriptor( + dim_y, layout_format, ++uid, 16, tensor_format); + data_ptrs.push_back(output_data); + uids.push_back(uid); + + auto sum_output_desc = helper::GetGeneralTensorDescriptor( + dim_sum, layout_format, ++uid, 16, tensor_format_math); + data_ptrs.push_back(sum_output_data); + uids.push_back(uid); + + auto sqsum_output_desc = helper::GetGeneralTensorDescriptor( + dim_sum, layout_format, ++uid, 16, tensor_format_math); + data_ptrs.push_back(sqsum_output_data); + uids.push_back(uid); + + // virtual outputs + auto after_scale = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, 
++uid, 16, tensor_format_math, true); + auto after_bias = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format_math, true); + auto after_relu = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format_math, true); + + // create ops + auto scale_op = helper::MakePointwiseOp( + CUDNN_POINTWISE_MUL, compute_dtype, input_desc, scale_desc, after_scale); + + auto bias_op = helper::MakePointwiseOp( + CUDNN_POINTWISE_ADD, compute_dtype, after_scale, bias_desc, after_bias); + + auto relu_desc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_RELU_FWD) + .setComputeType(compute_dtype) + .build(); + + auto relu_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(after_bias) + .setyDesc(after_relu) + .setpwDesc(relu_desc) + .build(); + VLOG(6) << relu_op.describe(); + + std::vector stride_int64 = helper::GetInt64Array(strides); + std::vector dilation_int64 = helper::GetInt64Array(dilations_vec); + auto conv_desc = cudnn_frontend::ConvDescBuilder() + .setComputeType(compute_dtype) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setSpatialDimCount(data_dim) + .setSpatialStride(data_dim, stride_int64.data()) + .setPrePadding(data_dim, pre_padding.data()) + .setPostPadding(data_dim, post_padding.data()) + .setDilation(data_dim, dilation_int64.data()) + .build(); + + float alpha = 1.0f; + float beta = 0.0f; + cudnn_frontend::Tensor* input_to_conv = + fuse_prologue ? &after_relu : &input_desc; + auto conv_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) + .setxDesc(*input_to_conv) + .setwDesc(filter_desc) + .setyDesc(output_desc) + .setcDesc(conv_desc) + .setAlpha(alpha) + .setBeta(beta) + .build(); + VLOG(6) << conv_op.describe(); + + auto genstat_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_GEN_STATS_DESCRIPTOR) + .setxDesc(output_desc) + .setComputeType(compute_dtype) + .setGenStatsMode(CUDNN_GENSTATS_SUM_SQSUM) + .setSumDesc(sum_output_desc) + .setSqSumDesc(sqsum_output_desc) + .build(); + VLOG(6) << genstat_op.describe(); + + // build op graph + std::vector ops; + if (fuse_prologue) { + ops = std::vector( + {&scale_op, &bias_op, &relu_op, &conv_op, &genstat_op}); + } else { + ops = + std::vector({&conv_op, &genstat_op}); + } + + auto op_graph = cudnn_frontend::OperationGraphBuilder() + .setHandle(handle) + .setOperationGraph(ops.size(), ops.data()) + .build(); + VLOG(6) << op_graph.describe(); + + cudnn_frontend::feature_vector_t feature_vector; + phi::autotune::BuildFeatureVector(&feature_vector, + dim_x, + dim_filt, + strides, + paddings, + dilations, + pre_padding, + post_padding, + fuse_prologue); + + helper::QueryCacheAndExecute(handle, + &workspace_handle, + &op_graph, + &data_ptrs, + &uids, + exhaustive_search, + deterministic, + feature_vector, + &plan_cache); +} + +/* + * Implements BNFinalize pattern. 
It works with aforementioned bnstats node: + * ``` + * y = bn_finalize(genstats(conv_out)) + * ``` + * is the same as: + * ``` + * y = batchnorm2d(conv_out) + * ``` + */ +template +void BNFinalizeImpl(const Context& dev_ctx, + const DenseTensor& sum_tensor, + const DenseTensor& sqsum_tensor, + const DenseTensor& bn_scale, + const DenseTensor& bn_bias, + const DenseTensor& input_running_mean, + const DenseTensor& input_running_var, + int64_t accumulation_count, + float exp_decay, + float epsilon, + bool exhaustive_search, + bool deterministic, + DenseTensor* out_running_mean, + DenseTensor* out_running_var, + DenseTensor* saved_mean, + DenseTensor* saved_var, + DenseTensor* eq_scale, + DenseTensor* eq_bias) { + auto& plan_cache = phi::autotune::AutoTuneCache::Instance().GetConvV8( + phi::autotune::AlgorithmType::kBNFinalize); + + auto handle = dev_ctx.cudnn_handle(); + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); + // set dtypes + cudnnTensorFormat_t layout_format = CUDNN_TENSOR_NHWC; + auto tensor_format_bn = + phi::backends::gpu::ToCudnnDataType(sum_tensor.dtype()); + auto tensor_format = phi::backends::gpu::ToCudnnDataType(eq_scale->dtype()); + auto compute_dtype = CUDNN_DATA_FLOAT; + // create tensor descriptors + auto dim_input = phi::vectorize(sum_tensor.dims()); + std::vector dim_c = {1, dim_input[0], 1, 1}; // [1, C, 1, 1] + std::vector dim_scalar = {1, 1, 1, 1}; + std::vector stride_scalar = {1, 1, 1, 1}; + + std::vector data_ptrs; + std::vector uids; + int64_t uid = 100; + + // inputs + auto sum_desc = helper::GetGeneralTensorDescriptor( + dim_c, layout_format, ++uid, 16, tensor_format_bn); + data_ptrs.push_back(const_cast(sum_tensor.data())); + uids.push_back(uid); + + auto sqsum_desc = helper::GetGeneralTensorDescriptor( + dim_c, layout_format, ++uid, 16, tensor_format_bn); + data_ptrs.push_back(const_cast(sqsum_tensor.data())); + uids.push_back(uid); + + auto scale_desc = helper::GetGeneralTensorDescriptor( + dim_c, layout_format, ++uid, 16, tensor_format_bn); + data_ptrs.push_back(const_cast(bn_scale.data())); + uids.push_back(uid); + + auto bias_desc = helper::GetGeneralTensorDescriptor( + dim_c, layout_format, ++uid, 16, tensor_format_bn); + data_ptrs.push_back(const_cast(bn_bias.data())); + uids.push_back(uid); + + auto input_running_mean_desc = helper::GetGeneralTensorDescriptor( + dim_c, layout_format, ++uid, 16, tensor_format_bn); + data_ptrs.push_back(const_cast(input_running_mean.data())); + uids.push_back(uid); + + auto input_running_var_desc = helper::GetGeneralTensorDescriptor( + dim_c, layout_format, ++uid, 16, tensor_format_bn); + data_ptrs.push_back(const_cast(input_running_var.data())); + uids.push_back(uid); + + // outputs + auto updated_running_mean_desc = helper::GetGeneralTensorDescriptor( + dim_c, layout_format, ++uid, 16, tensor_format_bn); + data_ptrs.push_back(out_running_mean->data()); + uids.push_back(uid); + + auto updated_running_var_desc = helper::GetGeneralTensorDescriptor( + dim_c, layout_format, ++uid, 16, tensor_format_bn); + data_ptrs.push_back(out_running_var->data()); + uids.push_back(uid); + + auto saved_mean_desc = helper::GetGeneralTensorDescriptor( + dim_c, layout_format, ++uid, 16, tensor_format_bn); + data_ptrs.push_back(saved_mean->data()); + uids.push_back(uid); + + auto saved_inv_var_desc = helper::GetGeneralTensorDescriptor( + dim_c, layout_format, ++uid, 16, tensor_format_bn); + data_ptrs.push_back(saved_var->data()); + uids.push_back(uid); + + auto eq_scale_desc = helper::GetGeneralTensorDescriptor( + dim_c, 
layout_format, ++uid, 16, tensor_format); + data_ptrs.push_back(eq_scale->data()); + uids.push_back(uid); + + auto eq_bias_desc = helper::GetGeneralTensorDescriptor( + dim_c, layout_format, ++uid, 16, tensor_format); + data_ptrs.push_back(eq_bias->data()); + uids.push_back(uid); + + // scalar descriptors + auto epsilon_desc = cudnn_frontend::TensorBuilder() + .setDim(dim_scalar.size(), dim_scalar.data()) + .setStride(stride_scalar.size(), stride_scalar.data()) + .setId(++uid) + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .setByValue(true) + .build(); + data_ptrs.push_back(&epsilon); + uids.push_back(uid); + + auto exp_decay_desc = + cudnn_frontend::TensorBuilder() + .setDim(dim_scalar.size(), dim_scalar.data()) + .setStride(stride_scalar.size(), stride_scalar.data()) + .setId(++uid) + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .setByValue(true) + .build(); + data_ptrs.push_back(&exp_decay); + uids.push_back(uid); + + auto accum_count_desc = + cudnn_frontend::TensorBuilder() + .setDim(dim_scalar.size(), dim_scalar.data()) + .setStride(stride_scalar.size(), stride_scalar.data()) + .setId(++uid) + .setAlignment(16) + .setDataType(CUDNN_DATA_INT64) + .setByValue(true) + .build(); + data_ptrs.push_back(&accumulation_count); + uids.push_back(uid); + + // build ops + auto finalize_stat_op = + cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_BN_FINALIZE_STATISTICS_DESCRIPTOR) + .setComputeType(compute_dtype) + .setBNFinalizeMode(CUDNN_BN_FINALIZE_STATISTICS_TRAINING) + .setSumDesc(sum_desc) + .setSqSumDesc(sqsum_desc) + .setScaleAndBias(scale_desc, bias_desc) + .setEqScaleAndBias(eq_scale_desc, eq_bias_desc) + .setPrevRunningMeanAndVar(input_running_mean_desc, + input_running_var_desc) + .setNextRunningMeanAndVar(updated_running_mean_desc, + updated_running_var_desc) + .setSavedMeanAndInvVar(saved_mean_desc, saved_inv_var_desc) + .setEpsilonTensor(epsilon_desc) + .setAccumCountTensor(accum_count_desc) + .setExpDecayFactorTensor(exp_decay_desc) + .build(); + + std::array ops = {&finalize_stat_op}; + auto op_graph = cudnn_frontend::OperationGraphBuilder() + .setHandle(handle) + .setOperationGraph(ops.size(), ops.data()) + .build(); + VLOG(6) << op_graph.describe(); + + cudnn_frontend::feature_vector_t feature_vector; + phi::autotune::BuildFeatureVector( + &feature_vector, dim_input, accumulation_count, exp_decay, epsilon); + + helper::QueryCacheAndExecute(handle, + &workspace_handle, + &op_graph, + &data_ptrs, + &uids, + exhaustive_search, + deterministic, + feature_vector, + &plan_cache); +} + +template +void FusedScaleBiasReluConvBnstatsKernel( + const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& w, + const paddle::optional& scale, + const paddle::optional& bias, + const DenseTensor& bn_scale, + const DenseTensor& bn_bias, + const DenseTensor& input_running_mean, + const DenseTensor& input_running_var, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const std::string& padding_algorithm, + int groups, + const std::string& data_format, + float momentum, + float epsilon, + bool fuse_prologue, + bool exhaustive_search, + int64_t accumulation_count, + DenseTensor* out, + DenseTensor* out_running_mean, + DenseTensor* out_running_var, + DenseTensor* saved_mean, + DenseTensor* saved_var, + DenseTensor* eq_scale, + DenseTensor* eq_bias) { + auto cudnn_version = phi::backends::gpu::DnnVersion(); + PADDLE_ENFORCE_GE(cudnn_version, + 8800, + phi::errors::PreconditionNotMet( + "This op only supports CUDNN version >= 
8800, "
+                        "but got %d.",
+                        cudnn_version));
+  PADDLE_ENFORCE_GE(dev_ctx.GetComputeCapability(),
+                    80,
+                    phi::errors::PreconditionNotMet(
+                        "This op only supports Ampere and later devices, "
+                        "but got compute capability: %d.",
+                        dev_ctx.GetComputeCapability()));
+  // attr
+  float exp_decay = 1. - momentum;
+  if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
+    LOG(ERROR) << "Provided epsilon is smaller than "
+               << "CUDNN_BN_MIN_EPSILON. Setting it to "
+               << "CUDNN_BN_MIN_EPSILON instead.";
+  }
+  epsilon =
+      std::max(epsilon, static_cast<float>(CUDNN_BN_MIN_EPSILON + FLT_EPSILON));
+  // exhaustive search
+  exhaustive_search = exhaustive_search || FLAGS_cudnn_exhaustive_search;
+  bool deterministic = FLAGS_cudnn_deterministic;
+  PADDLE_ENFORCE_EQ(exhaustive_search && deterministic,
+                    false,
+                    phi::errors::InvalidArgument(
+                        "Can't set exhaustive_search True and "
+                        "FLAGS_cudnn_deterministic True at the same time."));
+  // check optional inputs
+  if (fuse_prologue) {
+    PADDLE_ENFORCE_EQ(
+        scale && bias,
+        true,
+        phi::errors::InvalidArgument(
+            "\"scale\" and \"bias\" must be provided "
+            "when fuse_prologue = true. Got scale = %d; bias = %d.",
+            scale,
+            bias));
+  }
+
+  // alloc output variables
+  dev_ctx.template Alloc<T>(out);
+  dev_ctx.template Alloc<float>(out_running_mean);
+  dev_ctx.template Alloc<float>(out_running_var);
+  dev_ctx.template Alloc<float>(saved_mean);
+  dev_ctx.template Alloc<float>(saved_var);
+  dev_ctx.template Alloc<T>(eq_scale);
+  dev_ctx.template Alloc<T>(eq_bias);
+
+  // deal with strides, dilations and paddings
+  if (accumulation_count == 0) {
+    // dim_out = [N, H, W, C]
+    // accumulation_count = N * H * W
+    auto dim_out = phi::vectorize(out->dims());
+    accumulation_count = dim_out[0] * dim_out[1] * dim_out[2];
+  }
+
+  // Step 1: Scale Bias ReLU Conv BNStats
+  auto bn_dims = bn_scale.dims();
+  DenseTensor sum_tensor(bn_scale.dtype());
+  DenseTensor sqsum_tensor(bn_scale.dtype());
+  sum_tensor.Resize(bn_dims);
+  sqsum_tensor.Resize(bn_dims);
+  dev_ctx.template Alloc<float>(&sum_tensor);
+  dev_ctx.template Alloc<float>(&sqsum_tensor);
+  FusedScaleBiasReluConvBnstatsImpl<T, Context>(dev_ctx,
+                                                x,
+                                                w,
+                                                scale,
+                                                bias,
+                                                paddings,
+                                                dilations,
+                                                strides,
+                                                padding_algorithm,
+                                                fuse_prologue,
+                                                exhaustive_search,
+                                                deterministic,
+                                                out,
+                                                &sum_tensor,
+                                                &sqsum_tensor);
+  // Step 2: BN Finalize
+  BNFinalizeImpl<T, Context>(dev_ctx,
+                             sum_tensor,
+                             sqsum_tensor,
+                             bn_scale,
+                             bn_bias,
+                             input_running_mean,
+                             input_running_var,
+                             accumulation_count,
+                             exp_decay,
+                             epsilon,
+                             exhaustive_search,
+                             deterministic,
+                             out_running_mean,
+                             out_running_var,
+                             saved_mean,
+                             saved_var,
+                             eq_scale,
+                             eq_bias);
+}
+
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(fused_scale_bias_relu_conv_bnstats,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::fusion::FusedScaleBiasReluConvBnstatsKernel,
+                   phi::dtype::float16) {
+  kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
+  kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
+  kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
+  kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
+}
diff --git a/paddle/phi/kernels/gpudnn/conv_cudnn_frontend.h b/paddle/phi/kernels/gpudnn/conv_cudnn_frontend.h
index ef8e606e547..d0bdcc10bea 100644
--- a/paddle/phi/kernels/gpudnn/conv_cudnn_frontend.h
+++ b/paddle/phi/kernels/gpudnn/conv_cudnn_frontend.h
@@ -367,6 +367,48 @@ class CudnnFrontendConvHelper {
                          plan_cache);
   }
 
+  static void QueryCacheAndExecute(
+      cudnnHandle_t handle,
+      phi::DnnWorkspaceHandle* workspace_handle,
+      cudnn_frontend::OperationGraph* op_graph_pointer,
+      std::vector<void*>* data_ptrs,
+      std::vector<int64_t>* uids,
+ bool exhaustive_search, + bool deterministic, + const cudnn_frontend::feature_vector_t& feature_vector, + phi::autotune::CudnnFrontendPlanCache* plan_cache) { + if (plan_cache->FindPlan(feature_vector, handle)) { + const cudnn_frontend::ExecutionPlan* cached_plan = nullptr; + int64_t workspace_size = 0; + plan_cache->GetPlanAndWorkspaceSize( + feature_vector, &cached_plan, &workspace_size, handle); + ExecutePlan(handle, + workspace_handle, + data_ptrs, + uids, + cached_plan->get_raw_desc(), + workspace_size); + return; + } + + auto plans = FindExecutionPlans(op_graph_pointer, + exhaustive_search, + deterministic, + data_ptrs, + uids, + handle, + workspace_handle); + + ExecutePlansAndCache(handle, + workspace_handle, + data_ptrs, + uids, + &plans, + exhaustive_search, + feature_vector, + plan_cache); + } + static cudnn_frontend::Operation MakePointwiseOp( cudnnPointwiseMode_t mode, cudnnDataType_t dtype, @@ -435,7 +477,7 @@ void CudnnConvBwdDataV8(const DenseTensor* dy_tensor, if (plan_cache_bwd_data.FindPlan(op_graph, handle)) { const cudnn_frontend::ExecutionPlan* cached_plan = nullptr; int64_t workspace_size = 0; - plan_cache_bwd_data.GetPlan( + plan_cache_bwd_data.GetPlanAndWorkspaceSize( op_graph, &cached_plan, &workspace_size, handle); helper::ExecutePlan(handle, workspace_handle, @@ -509,7 +551,7 @@ void CudnnConvBwdFilterV8(const DenseTensor* x_tensor, if (plan_cache_bwd_filter.FindPlan(op_graph, handle)) { const cudnn_frontend::ExecutionPlan* cached_plan = nullptr; int64_t workspace_size = 0; - plan_cache_bwd_filter.GetPlan( + plan_cache_bwd_filter.GetPlanAndWorkspaceSize( op_graph, &cached_plan, &workspace_size, handle); helper::ExecutePlan(handle, workspace_handle, diff --git a/paddle/phi/kernels/gpudnn/conv_kernel.cu b/paddle/phi/kernels/gpudnn/conv_kernel.cu index 6dc7fc9e613..65418673827 100644 --- a/paddle/phi/kernels/gpudnn/conv_kernel.cu +++ b/paddle/phi/kernels/gpudnn/conv_kernel.cu @@ -264,7 +264,8 @@ void ConvCudnnKernelImplV8(const DenseTensor* input_tensor, if (plan_cache.FindPlan(op_graph, handle)) { const cudnn_frontend::ExecutionPlan* cached_plan = nullptr; int64_t workspace_size = 0; - plan_cache.GetPlan(op_graph, &cached_plan, &workspace_size, handle); + plan_cache.GetPlanAndWorkspaceSize( + op_graph, &cached_plan, &workspace_size, handle); helper::ExecutePlan(handle, &workspace_handle, input_data, diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index 079233a9c16..86443907bd9 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -503,6 +503,10 @@ if(NOT WITH_GPU list(REMOVE_ITEM TEST_OPS test_build_strategy_fusion_group_pass) endif() +if(NOT WITH_CUDNN_FRONTEND) + list(REMOVE_ITEM TEST_OPS test_fused_scale_bias_relu_conv_bnstats_op) +endif() + # Some ops need to check results when gc is enabled # Currently, only ops that register NoNeedBufferVarsInference need to do this test set(TEST_OPS_WITH_GC diff --git a/test/legacy_test/test_fused_scale_bias_relu_conv_bnstats_op.py b/test/legacy_test/test_fused_scale_bias_relu_conv_bnstats_op.py new file mode 100644 index 00000000000..cbed4e5b33f --- /dev/null +++ b/test/legacy_test/test_fused_scale_bias_relu_conv_bnstats_op.py @@ -0,0 +1,239 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+
+import numpy as np
+from eager_op_test import OpTest, skip_check_grad_ci
+
+import paddle
+from paddle import nn
+from paddle.fluid import core
+
+
+def skip_unit_test():
+    return (
+        not paddle.is_compiled_with_cuda()
+        or paddle.device.cuda.get_device_capability()[0] < 8
+        or paddle.get_cudnn_version() < 8800
+    )
+
+
+skip_msg = (
+    "only supported with CUDA and cuDNN 8.8 or later,"
+    " and only Ampere or later devices are supported"
+)
+
+
+@skip_check_grad_ci(reason="no grad op")
+@unittest.skipIf(skip_unit_test(), skip_msg)
+class TestFusedScaleBiasReluConvBnstatsOp(OpTest):
+    def setUp(self):
+        self.__class__.op_type = "fused_scale_bias_relu_conv_bnstats"
+        self.dtype = np.float16
+        self.outputs = None
+        self.padding_algorithm = "EXPLICIT"
+        self.data_format = "NHWC"
+        self.groups = 1
+        self.init_attr()
+        self.init_test_case()
+        self.rtol = 1e-5
+        self.atol = 2e-2
+
+        self.attrs = {
+            'fuse_prologue': self.fuse_prologue,
+            'strides': self.stride,
+            'paddings': self.pad,
+            'dilations': self.dilations,
+            'data_format': self.data_format,
+            'padding_algorithm': self.padding_algorithm,
+            'accumulation_count': self.accumulation_count,
+            'momentum': self.momentum,
+            'epsilon': self.epsilon,
+            'exhaustive_search': self.exhaustive_search,
+            'groups': self.groups,
+        }
+
+        # prepare inputs
+        np.random.seed(0)
+        self.x_input = np.random.random(self.x_size).astype(self.dtype)
+        self.bias_input = np.random.random(self.in_channel_num).astype(
+            self.dtype
+        )
+        self.scale_input = np.random.random(self.in_channel_num).astype(
+            self.dtype
+        )
+
+        self.x_input_prologue = self.x_input.astype(np.float32)
+        if self.fuse_prologue:
+            self.x_input_prologue *= self.scale_input.reshape(
+                (1, 1, 1, self.in_channel_num)
+            ).astype(
+                np.float32
+            )  # scale
+            self.x_input_prologue += self.bias_input.reshape(
+                (1, 1, 1, self.in_channel_num)
+            ).astype(
+                np.float32
+            )  # bias
+            self.x_input_prologue = np.maximum(self.x_input_prologue, 0)  # relu
+        self.x_input_prologue = self.x_input_prologue.astype(self.dtype)
+
+        paddle.disable_static()
+        paddle.seed(0)
+        paddle.set_default_dtype(self.dtype)
+
+        self.conv = nn.Conv2D(
+            in_channels=self.x_size[-1],
+            out_channels=self.filter_size[0],
+            kernel_size=self.filter_size[-1],
+            stride=self.stride,
+            padding=self.pad,
+            groups=self.groups,
+            bias_attr=False,
+            data_format=self.data_format,
+        )
+
+        self.bn = nn.BatchNorm(
+            self.filter_size[0],
+            momentum=self.momentum,
+            epsilon=self.epsilon,
+            data_layout=self.data_format,
+        )
+
+        self.w_input = self.conv.weight.numpy().astype(self.dtype)
+        self.bn_scale_input = self.bn.weight.numpy()
+        self.bn_bias_input = self.bn.bias.numpy()
+        self.bn_running_mean_input = self.bn._mean.numpy()
+        self.bn_running_var_input = self.bn._variance.numpy()
+
+        (
+            y_ref,
+            running_mean_out_ref,
+            running_var_out_ref,
+            saved_mean_out_ref,
+            saved_invvar_out_ref,
+            eqscale_ref,
+            eqbias_ref,
+        ) = self.calc_ref()
+
+        self.inputs = {
+            'x': self.x_input,
+            'w': self.w_input,
+            'bn_scale': self.bn_scale_input,
+            'bn_bias': self.bn_bias_input,
+            'input_running_mean': self.bn_running_mean_input,
+
'input_running_var': self.bn_running_var_input, + } + if self.fuse_prologue: + extra_inputs = { + 'bias': self.bias_input, + 'scale': self.scale_input, + } + self.inputs.update(extra_inputs) + + self.outputs = { + 'out': y_ref, + 'out_running_mean': running_mean_out_ref, + 'out_running_var': running_var_out_ref, + 'saved_mean': saved_mean_out_ref, + 'saved_var': saved_invvar_out_ref, + 'eq_scale': eqscale_ref, + 'eq_bias': eqbias_ref, + } + + def calc_ref(self): + # Calculate normal (scale + bias + relu +) Conv + BN + x_input_np = self.x_input + if self.fuse_prologue: + x_input_np = self.x_input_prologue + x_tensor = paddle.to_tensor(x_input_np, stop_gradient=False) + after_conv = self.conv(x_tensor) + after_bn = self.bn(after_conv) + # Calculate reference for saved_mean and saved_invvar + after_conv_np = ( + after_conv.numpy() + .astype(np.float32) + .reshape((-1, after_conv.shape[-1])) + ) + mean_np = after_conv_np.mean(axis=0) + var_np = after_conv_np.var(axis=0) + invstd_np = 1 / np.sqrt(var_np + self.epsilon) + # Calculate reference for eqscale and eqbias + eqscale_np = self.bn_scale_input * invstd_np + eqbias_np = ( + self.bn_bias_input - self.bn_scale_input * mean_np * invstd_np + ) + return ( + after_conv.numpy().astype(self.dtype), + self.bn._mean.numpy(), + self.bn._variance.numpy(), + mean_np, + invstd_np, + eqscale_np, + eqbias_np, + ) + + def has_cuda(self): + return core.is_compiled_with_cuda() + + def test_check_output(self): + if self.has_cuda(): + place = core.CUDAPlace(0) + self.check_output_with_place( + place, atol=self.atol, rtol=self.rtol, check_dygraph=False + ) + + def init_test_case(self): + self.pad = [0, 0] + self.stride = [1, 1] + self.dilations = [1, 1] + + self.x_size = [8, 16, 16, 32] # NHWC + self.filter_size = [64, 32, 1, 1] + self.y_size = [8, 16, 16, 64] + self.in_channel_num = self.x_size[-1] + self.out_channel_num = self.y_size[-1] + self.scale_size = [self.in_channel_num] + self.bn_size = [self.out_channel_num] + self.momentum = 0.9 + self.epsilon = 1e-5 + self.accumulation_count = ( + self.y_size[0] * self.y_size[1] * self.y_size[2] + ) + + def init_attr(self): + self.fuse_prologue = True + self.exhaustive_search = False + + +class TestFusedScaleBiasReluConvBnstatsOpNoPrologue( + TestFusedScaleBiasReluConvBnstatsOp +): + def init_attr(self): + self.fuse_prologue = False + self.exhaustive_search = False + + +class TestFusedScaleBiasReluConvBnstatsOpExhaustive( + TestFusedScaleBiasReluConvBnstatsOp +): + def init_attr(self): + self.fuse_prologue = True + self.exhaustive_search = True + + +if __name__ == '__main__': + unittest.main() diff --git a/test/white_list/op_accuracy_white_list.py b/test/white_list/op_accuracy_white_list.py index 12a97a160aa..49b501e765b 100644 --- a/test/white_list/op_accuracy_white_list.py +++ b/test/white_list/op_accuracy_white_list.py @@ -92,6 +92,7 @@ NO_FP16_CHECK_GRAD_OP_LIST = [ NO_FP16_COMPARED_WITH_FP32_OP_LIST = [ 'fake_quantize_moving_average_abs_max', + 'fused_scale_bias_relu_conv_bnstats', 'p_norm', ] -- GitLab
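
For reference, below is a minimal NumPy sketch of the arithmetic this fused op performs, mirroring the kernel comments above (scale + bias + relu + conv + GenStats, then BNFinalize) and the test's calc_ref. It is restricted to the 1x1-kernel, stride-1, zero-padding configuration the unit test uses; the helper name fused_reference and the shapes are illustrative assumptions, not part of the patch.

```python
# Illustrative only, not part of the patch: NumPy model of what
# fused_scale_bias_relu_conv_bnstats computes in one cuDNN runtime-fusion graph,
# assuming a 1x1 kernel, stride 1 and no padding (as in the unit test above).
import numpy as np


def fused_reference(x, w, scale, bias, bn_scale, bn_bias, eps=1e-5):
    # x: [N, H, W, C] (NHWC), w: [K, C, 1, 1], scale/bias: [C], bn_scale/bn_bias: [K]
    act = np.maximum(x * scale + bias, 0.0)           # scale + bias + relu (prologue)
    y = np.einsum('nhwc,kcrs->nhwk', act, w)          # 1x1 convolution, stride 1
    n = y.shape[0] * y.shape[1] * y.shape[2]          # accumulation_count = N * H * W
    flat = y.reshape(-1, y.shape[-1])
    sum_y = flat.sum(axis=0)                          # GenStats: per-channel sum
    sqsum_y = (flat ** 2).sum(axis=0)                 # GenStats: per-channel sum of squares
    mean = sum_y / n                                  # BNFinalize: saved_mean
    inv_std = 1.0 / np.sqrt(sqsum_y / n - mean ** 2 + eps)  # saved_var (inverse std)
    eq_scale = bn_scale * inv_std                     # folded batch-norm scale
    eq_bias = bn_bias - bn_scale * mean * inv_std     # folded batch-norm bias
    # Applying eq_scale/eq_bias elementwise reproduces training-mode batch norm of y.
    assert np.allclose(y * eq_scale + eq_bias,
                       (y - mean) * inv_std * bn_scale + bn_bias, atol=1e-5)
    return y, sum_y, sqsum_y, mean, inv_std, eq_scale, eq_bias
```

With inputs shaped like the test above (x of [8, 16, 16, 32], w of [64, 32, 1, 1]), the returned values correspond, up to fp16 tolerances, to the op's out, saved_mean, saved_var, eq_scale and eq_bias outputs.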