Unverified commit 71e28b12, authored by Tian Zheng, committed by GitHub

Add fused_scale_bias_relu_conv_bnstats OP (#55026)

* Add fused_scale_bias_relu_conv_bnstats op

* Review changes

* Fix no CUDNN Frontend build

* Fix PADDLE_ENFORCE format

* Fix PADDLE_ENFORCE CI error

* Rename kernel filename

* Refactor unittest to use paddle eager_op_test

* Fix padding bugs

* Review changes

* test=cuda117

* test=cuda117
Parent bb2310a6
......@@ -160,6 +160,16 @@
backward: fused_rotary_position_embedding_grad
support_dygraph_mode : true
- op : fused_scale_bias_relu_conv_bnstats
args : (Tensor x, Tensor w, Tensor scale, Tensor bias, Tensor bn_scale, Tensor bn_bias, Tensor input_running_mean, Tensor input_running_var, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, str data_format, float momentum, float epsilon, bool fuse_prologue, bool exhaustive_search, int64_t accumulation_count = 0)
optional : scale, bias
output : Tensor(out), Tensor(out_running_mean), Tensor(out_running_var), Tensor(saved_mean), Tensor(saved_var), Tensor(eq_scale), Tensor(eq_bias)
infer_meta :
func : FusedScaleBiasReluConvBnstatsInferMeta
kernel :
func : fused_scale_bias_relu_conv_bnstats
data_type : x
- op : generate_sequence_xpu
args : (Tensor x, DataType dtype)
output : Tensor
......
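For orientation, here is a minimal sketch of the unfused computation this op covers, matching the signature above and the unit test at the end of this diff. conv2d_nhwc and reference_impl are illustrative helpers written for this note, not part of the patch, and the sketch assumes zero padding, unit dilation and groups == 1:

import numpy as np


def conv2d_nhwc(x, w, stride=1):
    # Minimal NHWC convolution with no padding; w is laid out [O, I, KH, KW].
    n, h, wd, _ = x.shape
    o, _, kh, kw = w.shape
    oh = (h - kh) // stride + 1
    ow = (wd - kw) // stride + 1
    out = np.zeros((n, oh, ow, o), dtype=np.float32)
    for i in range(oh):
        for j in range(ow):
            patch = x[:, i * stride:i * stride + kh, j * stride:j * stride + kw, :]
            # contract patch [N, KH, KW, I] with w transposed to [KH, KW, I, O]
            out[:, i, j, :] = np.tensordot(
                patch.astype(np.float32), w.transpose(2, 3, 1, 0), axes=3)
    return out


def reference_impl(x, w, scale, bias, epsilon=1e-5, fuse_prologue=True):
    # Optional prologue: per-channel scale + bias on the input, then ReLU.
    if fuse_prologue:
        x = np.maximum(x * scale + bias, 0.0)
    y = conv2d_nhwc(x, w)
    # Per-channel statistics over N*H*W, the "bnstats" part of the fused op;
    # the unit test below checks saved_var against the inverse standard deviation.
    y2d = y.reshape(-1, y.shape[-1])
    saved_mean = y2d.mean(axis=0)
    saved_invstd = 1.0 / np.sqrt(y2d.var(axis=0) + epsilon)
    return y, saved_mean, saved_invstd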
......@@ -821,4 +821,138 @@ void FastLayernormXPUInferMeta(const MetaTensor& x,
out->set_layout(x.layout());
}
void FusedScaleBiasReluConvBnstatsInferMeta(
const MetaTensor& x,
const MetaTensor& w,
const MetaTensor& scale,
const MetaTensor& bias,
const MetaTensor& bn_scale,
const MetaTensor& bn_bias,
const MetaTensor& input_running_mean,
const MetaTensor& input_running_var,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::string& padding_algorithm,
int groups,
const std::string& data_format,
float momentum,
float epsilon,
bool fuse_prologue,
bool exhaustive_search,
int64_t accumulation_count,
MetaTensor* out,
MetaTensor* out_running_mean,
MetaTensor* out_running_var,
MetaTensor* saved_mean,
MetaTensor* saved_var,
MetaTensor* eq_scale,
MetaTensor* eq_bias) {
auto in_dims = x.dims();
auto filter_dims = w.dims();
  // sanity-check the input and filter dimensions
PADDLE_ENFORCE_EQ(
in_dims.size(),
4,
phi::errors::InvalidArgument(
"The input of Op(FusedScaleBiasReluConvBnstats) should be a 4-D "
"Tensor. But "
"received: input's dimension is %u, input's shape is [%s].",
in_dims.size(),
in_dims));
PADDLE_ENFORCE_EQ(
in_dims.size(),
filter_dims.size(),
phi::errors::InvalidArgument(
"The input's dimension and filter's dimension of "
"Op(FusedScaleBiasReluConvBnstats) should be equal. But received: "
"the input's"
" shape is [%s], "
"the input's dimension is %d; the filter's shape is [%s], "
"the filter's dimension is %d.",
in_dims,
in_dims.size(),
filter_dims,
filter_dims.size()));
// Check if data format is NHWC
PADDLE_ENFORCE_EQ(
data_format,
"NHWC",
phi::errors::InvalidArgument(
"Operator(FusedScaleBiasReluConvBnstats) only supports data format "
"of "
"channel last (NHWC) now. But recieved: data_format = '%s'.",
data_format));
PADDLE_ENFORCE_EQ(
groups,
1,
phi::errors::InvalidArgument("Expect group to be 1, got %d.", groups));
const auto input_channels = in_dims[in_dims.size() - 1];
int dilation_size = dilations.size();
for (int i = 0; i < dilation_size; ++i) {
PADDLE_ENFORCE_GT(
dilations[i],
0,
phi::errors::InvalidArgument(
"The dilation of Op(Conv) should be larget than 0, but received "
"dilation is %d.",
dilations[i]));
}
PADDLE_ENFORCE_EQ(
input_channels,
filter_dims[1] * groups,
phi::errors::InvalidArgument(
"The number of input's channels should be equal to filter's channels "
"* groups for Op(FusedScaleBiasReluConvBnstats). But received: the "
"input's"
" channels is %d, "
"the input's shape is [%s]; the filter's channels is %d, the "
"filter's shape is [%s]; the groups is %d. ",
input_channels,
in_dims,
filter_dims[1],
filter_dims,
groups));
  // update paddings and dilations according to padding_algorithm
std::vector<int> paddings_vec = paddings;
std::vector<int> dilations_vec = dilations;
// get "HW" from "NHWC"
DDim in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
DDim filter_data_dims = phi::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
phi::UpdatePaddingAndDilation(&paddings_vec,
&dilations_vec,
padding_algorithm,
in_data_dims,
strides,
ksize);
std::vector<int64_t> out_shape({in_dims[0]});
for (size_t i = 0; i < strides.size(); ++i) {
out_shape.push_back(ConvOutSize(in_dims[i + 1],
filter_dims[i + 2],
dilations[i],
paddings_vec[i * 2],
paddings_vec[i * 2 + 1],
strides[i]));
}
out_shape.push_back(filter_dims[0]);
// make shape for other outputs
auto c_dims = phi::make_ddim({filter_dims[0]});
  // set dims for the output and the per-channel outputs
out->set_dims(DDim(out_shape.data(), out_shape.size()));
out_running_mean->set_dims(c_dims);
out_running_var->set_dims(c_dims);
saved_mean->set_dims(c_dims);
saved_var->set_dims(c_dims);
eq_scale->set_dims(c_dims);
eq_bias->set_dims(c_dims);
}
} // namespace phi
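The spatial sizes appended to out_shape above come from ConvOutSize. As a quick reference, a sketch of the standard convolution output-size formula it implements (the helper name here is illustrative):

def conv_out_size(in_size, filter_size, dilation, pad_0, pad_1, stride):
    # Effective kernel extent once dilation is applied.
    dkernel = dilation * (filter_size - 1) + 1
    return (in_size + pad_0 + pad_1 - dkernel) // stride + 1


# The default unit test case: a 16x16 NHWC input with 1x1 filters,
# zero padding and stride 1 keeps its 16x16 spatial shape.
assert conv_out_size(16, 1, 1, 0, 0, 1) == 16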
......@@ -201,4 +201,32 @@ void FastLayernormXPUInferMeta(const MetaTensor& x,
float epsilon,
MetaTensor* out);
void FusedScaleBiasReluConvBnstatsInferMeta(
const MetaTensor& x,
const MetaTensor& w,
const MetaTensor& scale,
const MetaTensor& bias,
const MetaTensor& bn_scale,
const MetaTensor& bn_bias,
const MetaTensor& input_running_mean,
const MetaTensor& input_running_var,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::string& padding_algorithm,
int groups,
const std::string& data_format,
float momentum,
float epsilon,
bool fuse_prologue,
bool exhaustive_search,
int64_t accumulation_count,
MetaTensor* out,
MetaTensor* out_running_mean,
MetaTensor* out_running_var,
MetaTensor* saved_mean,
MetaTensor* saved_var,
MetaTensor* eq_scale,
MetaTensor* eq_bias);
} // namespace phi
......@@ -94,6 +94,11 @@ if(WITH_CUTLASS)
list(APPEND kernel_cu ${cutlass_cu})
endif()
if(NOT WITH_CUDNN_FRONTEND)
list(REMOVE_ITEM kernel_cu
"fusion/gpu/fused_scale_bias_relu_conv_bnstats_kernel.cu")
endif()
set(cc_search_pattern
"*.cc"
"cpu/*.cc"
......
......@@ -47,6 +47,11 @@ std::string AlgorithmTypeString(int64_t algo_type) {
} else if (algo_type ==
static_cast<int64_t>(AlgorithmType::kConvBackwardFilterV8)) {
return "conv_backward_filter_v8";
} else if (algo_type ==
static_cast<int64_t>(AlgorithmType::kScaleBiasReluConvBNstats)) {
return "scale_bias_relu_conv_bnstats";
} else if (algo_type == static_cast<int64_t>(AlgorithmType::kBNFinalize)) {
return "bn_finalize";
}
#endif
return std::to_string(algo_type);
......
......@@ -55,7 +55,9 @@ enum class AlgorithmType {
kConvForwardV8 = 10,
kConvBackwardDataV8 = 11,
kConvBackwardFilterV8 = 12,
kAlgorithmCount = 13
kScaleBiasReluConvBNstats = 13,
kBNFinalize = 14,
kAlgorithmCount = 15
#endif
};
......@@ -178,9 +180,8 @@ class AutoTuneCache {
conv_auto_tune_map_[key] = cache;
}
#ifdef PADDLE_WITH_CUDNN_FRONTEND
} else if (algo_type == AlgorithmType::kConvForwardV8 ||
algo_type == AlgorithmType::kConvBackwardDataV8 ||
algo_type == AlgorithmType::kConvBackwardFilterV8) {
} else if (algo_type >= AlgorithmType::kConvForwardV8 &&
algo_type <= AlgorithmType::kBNFinalize) {
int64_t key = static_cast<int64_t>(algo_type);
if (cudnn_v8_auto_tune_map_.find(key) == cudnn_v8_auto_tune_map_.end()) {
CudnnFrontendPlanCache cache;
......
......@@ -79,10 +79,10 @@ class CudnnFrontendPlanCache {
return ret;
}
void GetPlan(const cudnn_frontend::feature_vector_t &feature,
const cudnn_frontend::ExecutionPlan **plan,
int64_t *workspace_size,
cudnnHandle_t handle) {
void GetPlanAndWorkspaceSize(const cudnn_frontend::feature_vector_t &feature,
const cudnn_frontend::ExecutionPlan **plan,
int64_t *workspace_size,
cudnnHandle_t handle) {
// Note(tizheng): CUDNNv8 execution plan is not thread-safe.
// A shared plan being executed by different threads is
// generally not safe (for now).
......@@ -90,11 +90,11 @@ class CudnnFrontendPlanCache {
auto &local_map = map_[hasher(std::this_thread::get_id())];
auto it = local_map.find(GetExtendedFeature(feature, handle));
if (it == local_map.end()) {
PADDLE_THROW(phi::errors::InvalidArgument(
"[cudnn_frontend] Cached Plan Not Found."));
return;
}
PADDLE_ENFORCE_NE(it,
local_map.end(),
phi::errors::InvalidArgument(
"[cudnn_frontend] Cached Plan Not Found."));
*plan = &(it->second);
*workspace_size = (*plan)->getWorkspaceSize();
VLOG(4) << "Cached execution plan found." << (*plan)->getTag()
......@@ -133,11 +133,12 @@ class CudnnFrontendPlanCache {
return FindPlan(op_graph.getFeatureVector(), handle);
}
void GetPlan(const cudnn_frontend::OperationGraph &op_graph,
const cudnn_frontend::ExecutionPlan **plan,
int64_t *workspace_size,
cudnnHandle_t handle) {
GetPlan(op_graph.getFeatureVector(), plan, workspace_size, handle);
void GetPlanAndWorkspaceSize(const cudnn_frontend::OperationGraph &op_graph,
const cudnn_frontend::ExecutionPlan **plan,
int64_t *workspace_size,
cudnnHandle_t handle) {
GetPlanAndWorkspaceSize(
op_graph.getFeatureVector(), plan, workspace_size, handle);
}
void InsertPlan(const cudnn_frontend::OperationGraph &op_graph,
......@@ -176,5 +177,49 @@ class CudnnFrontendPlanCache {
int64_t cache_misses_{0};
}; // class CudnnFrontendPlanCache
template <typename T>
inline void BuildFeatureVectorSingle(cudnn_frontend::feature_vector_t *v,
const T &value) {
v->push_back(static_cast<int64_t>(value));
}
template <>
inline void BuildFeatureVectorSingle(cudnn_frontend::feature_vector_t *v,
const float &value) {
int64_t val = 0;
memcpy(&val, &value, sizeof(float));
v->push_back(val);
}
template <>
inline void BuildFeatureVectorSingle<std::vector<int64_t>>(
cudnn_frontend::feature_vector_t *v, const std::vector<int64_t> &value) {
v->insert(v->end(), value.begin(), value.end());
}
template <>
inline void BuildFeatureVectorSingle<std::vector<int>>(
cudnn_frontend::feature_vector_t *v, const std::vector<int> &value) {
for (auto &val : value) {
v->push_back(static_cast<int64_t>(val));
}
}
template <>
inline void BuildFeatureVectorSingle<std::string>(
cudnn_frontend::feature_vector_t *v, const std::string &value) {
v->push_back(std::hash<std::string>()(value));
}
inline void BuildFeatureVector(cudnn_frontend::feature_vector_t *v) { return; }
template <typename T, typename... Args>
inline void BuildFeatureVector(cudnn_frontend::feature_vector_t *v,
const T &value,
Args... args) {
BuildFeatureVectorSingle(v, value);
BuildFeatureVector(v, args...);
}
} // namespace autotune
} // namespace phi
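As a reading aid, a Python analogue (not part of the patch) of what the BuildFeatureVector helpers above do: every attribute is flattened into the int64 feature vector used as a plan-cache key, with floats bit-cast rather than truncated, vectors appended elementwise, and strings hashed (Python's hash() stands in for std::hash<std::string>):

import struct


def build_feature_vector(*args):
    # Flatten attributes of mixed types into a flat list of int64-like values.
    v = []
    for a in args:
        if isinstance(a, (bool, int)):
            v.append(int(a))
        elif isinstance(a, float):
            # bit-cast the 4 float bytes into the low half of an int64
            v.append(struct.unpack('<q', struct.pack('<f', a) + b'\x00' * 4)[0])
        elif isinstance(a, str):
            v.append(hash(a))
        elif isinstance(a, (list, tuple)):
            v.extend(int(x) for x in a)
        else:
            raise TypeError(f"unsupported attribute type: {type(a)}")
    return v


# e.g. a key built from paddings, strides, the fuse_prologue flag and epsilon
key = build_feature_vector([0, 0, 0, 0], [1, 1], True, 1e-5)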
......@@ -367,6 +367,48 @@ class CudnnFrontendConvHelper {
plan_cache);
}
static void QueryCacheAndExecute(
cudnnHandle_t handle,
phi::DnnWorkspaceHandle* workspace_handle,
cudnn_frontend::OperationGraph* op_graph_pointer,
std::vector<void*>* data_ptrs,
std::vector<int64_t>* uids,
bool exhaustive_search,
bool deterministic,
const cudnn_frontend::feature_vector_t& feature_vector,
phi::autotune::CudnnFrontendPlanCache* plan_cache) {
if (plan_cache->FindPlan(feature_vector, handle)) {
const cudnn_frontend::ExecutionPlan* cached_plan = nullptr;
int64_t workspace_size = 0;
plan_cache->GetPlanAndWorkspaceSize(
feature_vector, &cached_plan, &workspace_size, handle);
ExecutePlan(handle,
workspace_handle,
data_ptrs,
uids,
cached_plan->get_raw_desc(),
workspace_size);
return;
}
auto plans = FindExecutionPlans(op_graph_pointer,
exhaustive_search,
deterministic,
data_ptrs,
uids,
handle,
workspace_handle);
ExecutePlansAndCache(handle,
workspace_handle,
data_ptrs,
uids,
&plans,
exhaustive_search,
feature_vector,
plan_cache);
}
static cudnn_frontend::Operation MakePointwiseOp(
cudnnPointwiseMode_t mode,
cudnnDataType_t dtype,
......@@ -435,7 +477,7 @@ void CudnnConvBwdDataV8(const DenseTensor* dy_tensor,
if (plan_cache_bwd_data.FindPlan(op_graph, handle)) {
const cudnn_frontend::ExecutionPlan* cached_plan = nullptr;
int64_t workspace_size = 0;
plan_cache_bwd_data.GetPlan(
plan_cache_bwd_data.GetPlanAndWorkspaceSize(
op_graph, &cached_plan, &workspace_size, handle);
helper::ExecutePlan(handle,
workspace_handle,
......@@ -509,7 +551,7 @@ void CudnnConvBwdFilterV8(const DenseTensor* x_tensor,
if (plan_cache_bwd_filter.FindPlan(op_graph, handle)) {
const cudnn_frontend::ExecutionPlan* cached_plan = nullptr;
int64_t workspace_size = 0;
plan_cache_bwd_filter.GetPlan(
plan_cache_bwd_filter.GetPlanAndWorkspaceSize(
op_graph, &cached_plan, &workspace_size, handle);
helper::ExecutePlan(handle,
workspace_handle,
......
......@@ -264,7 +264,8 @@ void ConvCudnnKernelImplV8(const DenseTensor* input_tensor,
if (plan_cache.FindPlan(op_graph, handle)) {
const cudnn_frontend::ExecutionPlan* cached_plan = nullptr;
int64_t workspace_size = 0;
plan_cache.GetPlan(op_graph, &cached_plan, &workspace_size, handle);
plan_cache.GetPlanAndWorkspaceSize(
op_graph, &cached_plan, &workspace_size, handle);
helper::ExecutePlan(handle,
&workspace_handle,
input_data,
......
......@@ -503,6 +503,10 @@ if(NOT WITH_GPU
list(REMOVE_ITEM TEST_OPS test_build_strategy_fusion_group_pass)
endif()
if(NOT WITH_CUDNN_FRONTEND)
list(REMOVE_ITEM TEST_OPS test_fused_scale_bias_relu_conv_bnstats_op)
endif()
# Some ops need to check results when gc is enabled
# Currently, only ops that register NoNeedBufferVarsInference need to do this test
set(TEST_OPS_WITH_GC
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from eager_op_test import OpTest, skip_check_grad_ci
import paddle
from paddle import nn
from paddle.fluid import core
def skip_unit_test():
return (
not paddle.is_compiled_with_cuda()
or paddle.device.cuda.get_device_capability()[0] < 8
or paddle.get_cudnn_version() < 8800
)
skip_msg = (
"only support with cuda and CUDNN 8.8 or later,"
" and only Ampere or later devices are supported"
)
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(skip_unit_test(), skip_msg)
class TestFusedScaleBiasReluConvBnstatsOp(OpTest):
def setUp(self):
self.__class__.op_type = "fused_scale_bias_relu_conv_bnstats"
self.dtype = np.float16
self.outputs = None
self.padding_algorithm = "EXIPLICIT"
self.data_format = "NHWC"
self.groups = 1
self.init_attr()
self.init_test_case()
self.rtol = 1e-5
self.atol = 2e-2
self.attrs = {
'fuse_prologue': self.fuse_prologue,
'strides': self.stride,
'paddings': self.pad,
'dilations': self.dilations,
'data_format': self.data_format,
'padding_algorithm': self.padding_algorithm,
'accumulation_count': self.accumulation_count,
'momentum': self.momentum,
'epsilon': self.epsilon,
'exhaustive_search': self.exhaustive_search,
'groups': self.groups,
}
# prepare inputs
np.random.seed(0)
self.x_input = np.random.random(self.x_size).astype(self.dtype)
self.bias_input = np.random.random(self.in_channel_num).astype(
self.dtype
)
self.scale_input = np.random.random(self.in_channel_num).astype(
self.dtype
)
self.x_input_prologue = self.x_input.astype(np.float32)
if self.fuse_prologue:
self.x_input_prologue *= self.scale_input.reshape(
(1, 1, 1, self.in_channel_num)
).astype(
np.float32
) # scale
self.x_input_prologue += self.bias_input.reshape(
(1, 1, 1, self.in_channel_num)
).astype(
np.float32
) # bias
self.x_input_prologue = np.maximum(self.x_input_prologue, 0) # relu
self.x_input_prologue = self.x_input_prologue.astype(self.dtype)
paddle.disable_static()
paddle.seed(0)
paddle.set_default_dtype(self.dtype)
self.conv = nn.Conv2D(
in_channels=self.x_size[-1],
out_channels=self.filter_size[0],
kernel_size=self.filter_size[-1],
stride=self.stride,
padding=self.pad,
groups=self.groups,
bias_attr=False,
data_format=self.data_format,
)
self.bn = nn.BatchNorm(
self.filter_size[0],
momentum=self.momentum,
epsilon=self.epsilon,
data_layout=self.data_format,
)
self.w_input = self.conv.weight.numpy().astype(self.dtype)
self.bn_scale_input = self.bn.weight.numpy()
self.bn_bias_input = self.bn.bias.numpy()
self.bn_running_mean_input = self.bn._mean.numpy()
self.bn_running_var_input = self.bn._variance.numpy()
(
y_ref,
running_mean_out_ref,
running_var_out_ref,
saved_mean_out_ref,
saved_invvar_out_ref,
eqscale_ref,
eqbias_ref,
) = self.calc_ref()
self.inputs = {
'x': self.x_input,
'w': self.w_input,
'bn_scale': self.bn_scale_input,
'bn_bias': self.bn_bias_input,
'input_running_mean': self.bn_running_mean_input,
'input_running_var': self.bn_running_var_input,
}
if self.fuse_prologue:
extra_inputs = {
'bias': self.bias_input,
'scale': self.scale_input,
}
self.inputs.update(extra_inputs)
self.outputs = {
'out': y_ref,
'out_running_mean': running_mean_out_ref,
'out_running_var': running_var_out_ref,
'saved_mean': saved_mean_out_ref,
'saved_var': saved_invvar_out_ref,
'eq_scale': eqscale_ref,
'eq_bias': eqbias_ref,
}
def calc_ref(self):
        # Reference path: optional (scale + bias + relu) prologue, then Conv + BN
x_input_np = self.x_input
if self.fuse_prologue:
x_input_np = self.x_input_prologue
x_tensor = paddle.to_tensor(x_input_np, stop_gradient=False)
after_conv = self.conv(x_tensor)
after_bn = self.bn(after_conv)
# Calculate reference for saved_mean and saved_invvar
after_conv_np = (
after_conv.numpy()
.astype(np.float32)
.reshape((-1, after_conv.shape[-1]))
)
mean_np = after_conv_np.mean(axis=0)
var_np = after_conv_np.var(axis=0)
invstd_np = 1 / np.sqrt(var_np + self.epsilon)
# Calculate reference for eqscale and eqbias
eqscale_np = self.bn_scale_input * invstd_np
eqbias_np = (
self.bn_bias_input - self.bn_scale_input * mean_np * invstd_np
)
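        # Note on the algebra: BN's affine transform
        #   y = bn_scale * (x - mean) / sqrt(var + eps) + bn_bias
        # folds into a single per-channel affine
        #   y = eq_scale * x + eq_bias
        # with eq_scale = bn_scale * invstd and
        #      eq_bias = bn_bias - bn_scale * mean * invstd,
        # which is exactly how eqscale_np and eqbias_np are formed above.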
return (
after_conv.numpy().astype(self.dtype),
self.bn._mean.numpy(),
self.bn._variance.numpy(),
mean_np,
invstd_np,
eqscale_np,
eqbias_np,
)
def has_cuda(self):
return core.is_compiled_with_cuda()
def test_check_output(self):
if self.has_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(
place, atol=self.atol, rtol=self.rtol, check_dygraph=False
)
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.dilations = [1, 1]
self.x_size = [8, 16, 16, 32] # NHWC
self.filter_size = [64, 32, 1, 1]
self.y_size = [8, 16, 16, 64]
self.in_channel_num = self.x_size[-1]
self.out_channel_num = self.y_size[-1]
self.scale_size = [self.in_channel_num]
self.bn_size = [self.out_channel_num]
self.momentum = 0.9
self.epsilon = 1e-5
self.accumulation_count = (
self.y_size[0] * self.y_size[1] * self.y_size[2]
)
def init_attr(self):
self.fuse_prologue = True
self.exhaustive_search = False
class TestFusedScaleBiasReluConvBnstatsOpNoPrologue(
TestFusedScaleBiasReluConvBnstatsOp
):
def init_attr(self):
self.fuse_prologue = False
self.exhaustive_search = False
class TestFusedScaleBiasReluConvBnstatsOpExhaustive(
TestFusedScaleBiasReluConvBnstatsOp
):
def init_attr(self):
self.fuse_prologue = True
self.exhaustive_search = True
if __name__ == '__main__':
unittest.main()
......@@ -92,6 +92,7 @@ NO_FP16_CHECK_GRAD_OP_LIST = [
NO_FP16_COMPARED_WITH_FP32_OP_LIST = [
'fake_quantize_moving_average_abs_max',
'fused_scale_bias_relu_conv_bnstats',
'p_norm',
]
......