未验证 提交 eddf1ad6 编写于 作者: W wz1qqx 提交者: GitHub

[XPU]add conv_fuse pass && kernel (#52247)

上级 d8081f22
......@@ -224,6 +224,7 @@ if(WITH_XPU)
SRCS xpu/pass_utils.cc
DEPS pass xpu_quant_utils)
set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils)
pass_library(conv2d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(fc_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
......
此差异已折叠。
......@@ -532,6 +532,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"stack_fuse_pass",
"fused_multi_transformer_xpu_quant_pass",
"fc_xpu_fuse_pass",
"conv2d_xpu_fuse_pass",
"link_xpu_op_max_pass",
"inplace_op_var_pass",
"delete_isolated_node_pass",
......
......@@ -4,6 +4,16 @@
# if one operator have "support_dygraph_mode : true", it supports dygraph mode,
# otherwise the operator only could be used in static mode.
- op : conv2d_xpu
args : (Tensor input, Tensor input_max, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, bool has_bias, bool has_branch, int act_type, float act_param)
output : Tensor(output), Tensor(output_max)
infer_meta :
func : Conv2dXPUInferMeta
kernel :
func : conv2d_xpu
data_type : input
optional : bias, branch, input_max
- op : embedding_with_eltwise_add_xpu
args : (Tensor[] ids, Tensor[] tables, int64_t padding_idx)
output: Tensor
......
......@@ -58,6 +58,7 @@ XPUOpMap& get_kl1_ops() {
{"concat_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"conv2d", XPUKernelSet({phi::DataType::FLOAT32})},
{"conv2d_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"conv2d_xpu", XPUKernelSet({phi::DataType::FLOAT32})},
{"deformable_conv", XPUKernelSet({phi::DataType::FLOAT32})},
{"deformable_conv_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"depthwise_conv2d", XPUKernelSet({phi::DataType::FLOAT32})},
......
......@@ -151,6 +151,8 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"conv2d",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"conv2d_xpu",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"conv3d_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"conv3d",
......
......@@ -18,9 +18,149 @@ limitations under the License. */
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
namespace phi {
inline int ConvOutSize(int input_size,
int filter_size,
int dilation,
int pad_left,
int pad_right,
int stride) {
const int dkernel = dilation * (filter_size - 1) + 1;
int output_size =
(input_size + (pad_left + pad_right) - dkernel) / stride + 1;
return output_size;
}
void Conv2dXPUInferMeta(const MetaTensor& input,
const MetaTensor& input_max,
const MetaTensor& filter,
const MetaTensor& filter_max,
const MetaTensor& bias,
const MetaTensor& branch,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::string& padding_algorithm,
int groups,
bool has_bias,
bool has_branch,
int act_type,
float act_param,
MetaTensor* output,
MetaTensor* output_max) {
auto in_dims = input.dims();
auto filter_dims = filter.dims();
// do some checks
PADDLE_ENFORCE_EQ(
in_dims.size(),
4,
phi::errors::InvalidArgument(
"The input of Op(Conv_xpu) should be a 4-D Tensor. But "
"received: input's dimension is %u, input's shape is [%s].",
in_dims.size(),
in_dims));
PADDLE_ENFORCE_EQ(
in_dims.size(),
filter_dims.size(),
phi::errors::InvalidArgument(
"The input's dimension and filter's dimension of "
"Op(Conv_xpu) should be equal. But received: the input's shape is "
"[%s], "
"the input's dimension is %d; the filter's shape is [%s], "
"the filter's dimension is %d.",
in_dims,
in_dims.size(),
filter_dims,
filter_dims.size()));
const auto input_channels = in_dims[1];
int stride_size = strides.size();
int in_sub_stride_size = in_dims.size() - stride_size;
int dilation_size = dilations.size();
PADDLE_ENFORCE_EQ(
in_dims.size(),
strides.size() + 2U,
phi::errors::InvalidArgument(
"The difference of input's dimension and Attr(strides)'s "
"length must be euqal to 2 for Op(Conv_xpu). "
"But received: input's dimension is %d, input's shape is [%s]; "
"Attr(stride)'s length is %d, Attr(stride) is [%s]; "
"difference of input's dimention and Attr(strides)'s length = %u.",
in_dims.size(),
in_dims,
strides.size(),
phi::make_ddim(strides),
in_sub_stride_size));
for (int i = 0; i < dilation_size; ++i) {
PADDLE_ENFORCE_GT(
dilations[i],
0,
phi::errors::InvalidArgument(
"The dilation of Op(Conv) should be larget than 0, but received "
"dilation is %d.",
dilations[i]));
}
PADDLE_ENFORCE_EQ(
input_channels,
filter_dims[1] * groups,
phi::errors::InvalidArgument(
"The number of input's channels should be equal to filter's channels "
"* groups for Op(Conv_xpu). But received: the input's channels is "
"%d, "
"the input's shape is [%s]; the filter's channels is %d, the "
"filter's shape is [%s]; the groups is %d. ",
input_channels,
in_dims,
filter_dims[1],
filter_dims,
groups));
PADDLE_ENFORCE_EQ(
filter_dims[0] % groups,
0,
phi::errors::InvalidArgument(
"The number of output's channels (filter's first dimension) of "
"Op(Conv) should be divided by groups. But received: "
"the output channels is %d, the filter's shape is [%s], "
"the groups is %d.",
filter_dims[0],
filter_dims,
groups));
// update paddings and dilations accoring to padding_algorithm
std::vector<int> paddings_vec = paddings;
std::vector<int> dilations_vec = dilations;
DDim in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
DDim filter_data_dims = phi::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
phi::UpdatePaddingAndDilation(&paddings_vec,
&dilations_vec,
padding_algorithm,
in_data_dims,
strides,
ksize);
std::vector<int64_t> out_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < strides.size(); ++i) {
out_shape.push_back(ConvOutSize(in_dims[i + 2],
filter_dims[i + 2],
dilations[i],
paddings_vec[i * 2],
paddings_vec[i * 2 + 1],
strides[i]));
}
// set output and output max dims
output->set_dims(DDim(out_shape.data(), out_shape.size()));
output_max->set_dims(phi::make_ddim({4}));
}
void EmbeddingWithEltwiseAddXPUInferMeta(
const std::vector<const MetaTensor*>& ids,
const std::vector<const MetaTensor*>& tables,
......
......@@ -22,6 +22,24 @@ namespace phi {
// Common InferMeta Functions for fusion operators.
// NOTE: The InferMeta Functions in this file are arranged in alphabetic order.
void Conv2dXPUInferMeta(const MetaTensor& input,
const MetaTensor& input_max,
const MetaTensor& filter,
const MetaTensor& filter_max,
const MetaTensor& bias,
const MetaTensor& branch,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::string& padding_algorithm,
int groups,
bool has_bias,
bool has_branch,
int act_type,
float act_param,
MetaTensor* output,
MetaTensor* output_max);
void EmbeddingWithEltwiseAddXPUInferMeta(
const std::vector<const MetaTensor*>& ids,
const std::vector<const MetaTensor*>& tables,
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
namespace phi {
namespace fusion {
template <typename T, typename Context>
void Conv2dXPUKernel(const Context& ctx,
const DenseTensor& input,
const paddle::optional<DenseTensor>& input_max,
const DenseTensor& filter,
const DenseTensor& filter_max,
const paddle::optional<DenseTensor>& bias,
const paddle::optional<DenseTensor>& branch,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::string& padding_algorithm,
int groups,
bool has_bias,
bool has_branch,
int act_type,
float act_param,
DenseTensor* output,
DenseTensor* output_max) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto input_dims = input.dims();
auto filter_dims = filter.dims();
// update paddings and dilations accoring to padding_algorithm
std::vector<int> paddings_vec = paddings;
std::vector<int> dilations_vec = dilations;
DDim in_data_dims = phi::slice_ddim(input_dims, 2, input_dims.size());
DDim filter_data_dims = phi::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
phi::UpdatePaddingAndDilation(&paddings_vec,
&dilations_vec,
padding_algorithm,
in_data_dims,
strides,
ksize);
int batch = static_cast<int>(input_dims[0]);
int in_c = static_cast<int>(input_dims[1]);
int in_h = static_cast<int>(input_dims[2]);
int in_w = static_cast<int>(input_dims[3]);
int out_c = static_cast<int>(filter_dims[0]);
int win_h = static_cast<int>(filter_dims[2]);
int win_w = static_cast<int>(filter_dims[3]);
auto* input_data = reinterpret_cast<const XPUType*>(input.data<T>());
const float* input_max_data = input_max.get_ptr() == nullptr
? nullptr
: input_max.get_ptr()->data<float>();
auto* branch_data =
branch.get_ptr() == nullptr
? nullptr
: reinterpret_cast<const XPUType*>(branch.get_ptr()->data<T>());
const float* bias_data =
bias.get_ptr() == nullptr ? nullptr : bias.get_ptr()->data<float>();
auto* out_data = reinterpret_cast<XPUType*>(ctx.template Alloc<T>(output));
xpu::Activation_t act(static_cast<xpu::Activation_t::act_enum>(act_type));
if (act_type == xpu::Activation_t::LEAKY_RELU) {
act.leaky_alpha = act_param;
} else if (act_type == xpu::Activation_t::HARD_SIGMOID) {
act.hard_sigmoid_slope = act_param;
}
int r =
xpu::conv2d_fusion<XPUType, int16_t, XPUType, int16_t>( // TX/TW/TY/TGEMM
/* baidu::xpu::api::Context* ctx */ ctx.x_context(),
/* const TX* input */ input_data,
/* const TW* filter */ filter.data<int16_t>(),
/* TY* output */ out_data,
/* int64_t n */ batch,
/* int64_t ic */ in_c,
/* int64_t h */ in_h,
/* int64_t w */ in_w,
/* int64_t oc */ out_c,
/* const std::vector<int>& ksize */ std::vector<int>{win_h, win_w},
/* const std::vector<int>& strides */ strides,
/* const std::vector<int>& paddings */ paddings_vec,
/* const std::vector<int>& dilations */ dilations_vec,
/* int64_t groups */ groups,
/* const float* in_maxptr */ input_max_data,
/* const float* filter_maxptr */ filter_max.data<float>(),
/* float* out_maxptr */ ctx.template Alloc<float>(output_max),
/* bool is_nchw */ true,
/* const float* bias */ bias_data,
/* const TY* branch */ branch_data,
/* const baidu::xpu::api::Activation_t& act */ act,
/* const float* branch_maxptr */ nullptr);
// /* const float* scale */ nullptr);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_xpu");
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(conv2d_xpu,
XPU,
ALL_LAYOUT,
phi::fusion::Conv2dXPUKernel,
float,
phi::dtype::float16) {}
......@@ -45,9 +45,9 @@ void FcXPUKernel(const Context& ctx,
bias.get_ptr() == nullptr ? nullptr : bias.get_ptr()->data<float>();
auto* out_data = reinterpret_cast<XPUType*>(ctx.template Alloc<T>(out));
xpu::Activation_t act(static_cast<xpu::Activation_t::act_enum>(act_type));
if (act_type == 5) {
if (act_type == xpu::Activation_t::LEAKY_RELU) {
act.leaky_alpha = act_alpha;
} else if (act_type == 15) {
} else if (act_type == xpu::Activation_t::HARD_SIGMOID) {
act.hard_sigmoid_slope = act_alpha;
}
int r =
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
class TestConv2dXPUFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True)
yield config, ["conv2d_xpu"], (1e-3, 1e-3)
def is_program_valid(self, prog_config):
paddings = prog_config.ops[0].attrs["paddings"]
strides = prog_config.ops[0].attrs["strides"]
groups = prog_config.ops[0].attrs["groups"]
padding_algorithm = prog_config.ops[0].attrs["padding_algorithm"]
dilations = prog_config.ops[0].attrs["dilations"]
data_format = prog_config.ops[0].attrs["data_format"]
filter_shape = prog_config.weights["conv2d_weight"].shape
input_shape = prog_config.inputs["conv2d_input"].shape
if data_format != "NCHW":
return False
if padding_algorithm == "VALID":
if (
(input_shape[2] - (dilations[0] * (filter_shape[2] - 1) + 1))
/ strides[0]
+ 1
) <= 1 or (
(input_shape[3] - (dilations[1] * (filter_shape[3] - 1) + 1))
/ strides[1]
+ 1
) <= 1:
return False
if padding_algorithm == "EXPLICIT":
if (
(
input_shape[2]
+ paddings[0]
+ paddings[1]
- (dilations[0] * (filter_shape[2] - 1) + 1)
)
/ strides[0]
+ 1
) <= 1 or (
(
input_shape[3]
+ paddings[2]
+ paddings[3]
- (dilations[1] * (filter_shape[3] - 1) + 1)
)
/ strides[1]
+ 1
) <= 1:
return False
if data_format == "NCHW":
if input_shape[1] != filter_shape[1] * groups:
return False
if filter_shape[0] % groups != 0:
return False
return True
def sample_program_config(self, draw):
data_format = draw(st.sampled_from(["NCHW"]))
x_shape = draw(
st.lists(
st.integers(min_value=12, max_value=12), min_size=4, max_size=4
)
)
x_shape[1] = draw(st.integers(min_value=1, max_value=10))
# 3. Generate legal shape of input:Y of conv2d
w_shape = draw(
st.lists(
st.integers(min_value=3, max_value=3), min_size=4, max_size=4
)
)
if data_format == "NCHW":
w_shape[1] = x_shape[1]
padding_algorithm = draw(st.sampled_from(["SAME", "VALID"]))
groups = draw(st.integers(min_value=1, max_value=1))
dilations = draw(
st.lists(
st.integers(min_value=1, max_value=1), min_size=2, max_size=2
)
)
paddings = draw(
st.lists(
st.integers(min_value=1, max_value=1), min_size=2, max_size=2
)
)
strides = draw(
st.lists(
st.integers(min_value=1, max_value=1), min_size=2, max_size=2
)
)
axis = 1
ew_bias_shape = [w_shape[0]]
# Random choose if add a relu operator
has_relu = True
def generate_data(shape):
return np.random.random(shape).astype(np.float32)
# Here we will compose a program
# Still has some risks that the program is invalid or cause bug while running
# Use function `is_program_valid` to filter the invalid programs before running
# Use function `add_skip_pass_case` to ignore the programs even if they cause bug while runing
conv2d_op = OpConfig(
"conv2d",
inputs={
"Input": ["conv2d_input"],
"Filter": ["conv2d_weight"],
},
outputs={"Output": ["conv2d_out"]},
data_format=data_format,
dilations=dilations,
padding_algorithm=padding_algorithm,
groups=groups,
paddings=paddings,
strides=strides,
has_bias=False,
)
ew_bias_op = OpConfig(
"elementwise_add",
inputs={"X": ["conv2d_out"], "Y": ["ew_bias"]},
outputs={"Out": ["add_out"]},
axis=axis,
)
ops = [conv2d_op, ew_bias_op]
# 3. activation
if has_relu:
relu_op = OpConfig(
"relu", inputs={"X": ["add_out"]}, outputs={"Out": ["relu_out"]}
)
ops.append(relu_op)
program_config = ProgramConfig(
ops=ops,
inputs={
"conv2d_input": TensorConfig(
data_gen=partial(generate_data, x_shape)
),
},
weights={
"conv2d_weight": TensorConfig(
data_gen=partial(generate_data, w_shape)
),
"ew_bias": TensorConfig(
data_gen=partial(generate_data, ew_bias_shape)
),
},
outputs=ops[-1].outputs["Out"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=25,
passes=["conv2d_xpu_fuse_pass"],
)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册