未验证 提交 03dbdbd1 编写于 作者: Z Zhang Jun 提交者: GitHub

[inference] conv_fusion: support a bias whose rank equals the input's rank (#54477)

* Support a bias whose rank equals the input's rank
上级 9f924b03
...@@ -413,15 +413,15 @@ void ConvFusionKernel(const Context& ctx, ...@@ -413,15 +413,15 @@ void ConvFusionKernel(const Context& ctx,
compute_format); compute_format);
DenseTensor transformed_input; DenseTensor transformed_input;
const int input_rank = input.dims().size();
auto unsys_pad_process = [&](const std::vector<int>& new_input_shape_vec, auto unsys_pad_process = [&](const std::vector<int>& new_input_shape_vec,
const std::vector<int>& input_pad) { const std::vector<int>& input_pad) {
DDim new_input_shape(make_ddim(new_input_shape_vec)); DDim new_input_shape(make_ddim(new_input_shape_vec));
transformed_input.Resize(new_input_shape); transformed_input.Resize(new_input_shape);
ctx.template Alloc<T>(&transformed_input); ctx.template Alloc<T>(&transformed_input);
const int rank = input.dims().size();
T pad_value(0.0); T pad_value(0.0);
switch (rank) { switch (input_rank) {
case 4: { case 4: {
funcs::PadFunction<Context, T, 4>( funcs::PadFunction<Context, T, 4>(
ctx, input_pad, input, pad_value, &transformed_input); ctx, input_pad, input, pad_value, &transformed_input);
...@@ -442,11 +442,16 @@ void ConvFusionKernel(const Context& ctx, ...@@ -442,11 +442,16 @@ void ConvFusionKernel(const Context& ctx,
conv_attr_cache->input_pad); conv_attr_cache->input_pad);
} }
std::vector<int> b_dims(input.dims().size(), 1); std::vector<int> b_dims(input_rank, 1);
if (compute_format == CUDNN_TENSOR_NCHW) { if (compute_format == CUDNN_TENSOR_NCHW) {
auto bias_rank = bias.dims().size();
if (input_rank == bias_rank) {
b_dims[1] = static_cast<int>(bias.dims()[1]);
} else {
b_dims[1] = static_cast<int>(bias.dims()[0]); b_dims[1] = static_cast<int>(bias.dims()[0]);
}
} else { } else {
b_dims[input.dims().size() - 1] = static_cast<int>(bias.dims()[0]); b_dims[input_rank - 1] = static_cast<int>(bias.dims()[0]);
} }
auto search_func = [&](cudnnConvolutionFwdAlgo_t* cudnn_algo, auto search_func = [&](cudnnConvolutionFwdAlgo_t* cudnn_algo,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册