提交 1ce127cb 编写于 作者: S sangoly 提交者: GitHub

fix conv_bias_relu fuse bug & x86 conv kernel bug (#18098)

fix conv_bias_relu fuse bug,x86 conv kernel bug
上级 d1f539c8
......@@ -24,8 +24,11 @@ namespace mir {
void ConvElementwiseAddReLUFusePass::Apply(
const std::unique_ptr<SSAGraph>& graph) {
fusion::ConvElementwiseAddReLUFuser fuser;
fusion::ConvElementwiseAddReLUFuser fuser("conv2d");
fuser(graph.get());
fusion::ConvElementwiseAddReLUFuser depthwise_fuser("depthwise_conv2d");
depthwise_fuser(graph.get());
}
} // namespace mir
......
......@@ -85,7 +85,7 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
add_1->SetAttr("axis", 1);
relu_1->SetType("relu");
relu_1->SetInput("Input", {"add_1_out"});
relu_1->SetInput("X", {"add_1_out"});
relu_1->SetOutput("Out", {"relu_1_out"});
conv2d_2->SetType("conv2d");
......@@ -105,7 +105,7 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
add_2->SetAttr("axis", 1);
relu_2->SetType("relu");
relu_2->SetInput("Input", {"add_2_out"});
relu_2->SetInput("X", {"add_2_out"});
relu_2->SetOutput("Out", {"out"});
program_desc->Flush();
......
......@@ -23,21 +23,33 @@ namespace fusion {
void ConvElementwiseAddReLUFuser::BuildPattern() {
// create input nodes.
auto* input = VarNode("input");
auto* filter = VarNode("filter");
auto* bias = VarNode("bias");
auto* input =
VarNode("input")->assert_is_op_input(conv_type_, "Input")->AsInput();
auto* filter =
VarNode("filter")->assert_is_op_input(conv_type_, "Filter")->AsInput();
auto* bias =
VarNode("bias")->assert_is_op_input("elementwise_add", "Y")->AsInput();
// create op nodes
auto* conv2d = OpNode("conv2d", "conv2d");
auto* add = OpNode("add", "elementwise_add");
auto* relu = OpNode("relu", "relu");
auto* conv2d =
OpNode("conv2d", conv_type_)->assert_is_op(conv_type_)->AsIntermediate();
auto* add = OpNode("add", "elementwise_add")
->assert_is_op("elementwise_add")
->AsIntermediate();
auto* relu = OpNode("relu", "relu")->assert_is_op("relu")->AsIntermediate();
// create intermediate nodes
auto* conv2d_out = VarNode("conv2d_out");
auto* add_out = VarNode("add_out");
auto* conv2d_out = VarNode("conv2d_out")
->assert_is_op_output(conv_type_, "Output")
->assert_is_op_input("elementwise_add", "X")
->AsIntermediate();
auto* add_out = VarNode("add_out")
->assert_is_op_output("elementwise_add", "Out")
->assert_is_op_input("relu", "X")
->AsIntermediate();
// create output node
auto* out = VarNode("output");
auto* out = VarNode("output")->assert_is_op_output("relu", "Out")->AsOutput();
// create topology.
std::vector<PMNode*> conv2d_inputs{filter, input};
......@@ -45,19 +57,12 @@ void ConvElementwiseAddReLUFuser::BuildPattern() {
conv2d_inputs >> *conv2d >> *conv2d_out;
add_inputs >> *add >> *add_out;
*add_out >> *relu >> *out;
// Some op specialities.
conv2d_out->AsIntermediate();
add_out->AsIntermediate();
conv2d->AsIntermediate();
add->AsIntermediate();
relu->AsIntermediate();
}
void ConvElementwiseAddReLUFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto conv_op = LiteOpRegistry::Global().Create("conv2d");
auto conv_op = LiteOpRegistry::Global().Create(conv_type_);
auto conv_old = matched.at("conv2d")->stmt()->op;
auto* scope = conv_old->scope();
auto& valid_places = conv_old->valid_places();
......@@ -75,7 +80,7 @@ cpp::OpDesc ConvElementwiseAddReLUFuser::GenOpDesc(const key2nodes_t& matched) {
auto* desc = matched.at("conv2d")->stmt()->op_info();
cpp::OpDesc op_desc;
op_desc.SetType("conv2d");
op_desc.SetType(conv_type_);
op_desc.SetInput("Input", {matched.at("input")->arg()->name});
op_desc.SetInput("Filter", {matched.at("filter")->arg()->name});
op_desc.SetInput("Bias", {matched.at("bias")->arg()->name});
......
......@@ -25,11 +25,14 @@ namespace fusion {
class ConvElementwiseAddReLUFuser : public FuseBase {
public:
explicit ConvElementwiseAddReLUFuser(const std::string& conv_type)
: conv_type_(conv_type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
std::string conv_type_;
};
} // namespace fusion
......
......@@ -32,3 +32,4 @@ USE_MIR_PASS(io_copy_kernel_pick_pass);
USE_MIR_PASS(argument_type_display_pass);
USE_MIR_PASS(runtime_context_assign_pass);
USE_MIR_PASS(lite_conv_bn_fuse_pass);
USE_MIR_PASS(graph_visualze);
......@@ -105,7 +105,7 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
param.x->raw_tensor().Slice(i, i + 1).Resize(input_shape.data()));
lite::Tensor out_batch;
out_batch.ShareDataWith(param.output->raw_tensor().Slice(i, i + 1).Resize(
input_shape.data()));
output_matrix_shape.data()));
for (int g = 0; g < param.groups; g++) {
lite::Tensor in_slice;
......
......@@ -51,6 +51,6 @@ class ReluCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
REGISTER_LITE_KERNEL(relu, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::ReluCompute<float>, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
......@@ -73,19 +73,22 @@ class ConvOpLite : public OpLite {
std::vector<std::string> input_arg_names = op_desc.InputArgumentNames();
if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") !=
input_arg_names.end()) {
auto bias_var = scope->FindVar(op_desc.Input("Bias").front());
if (bias_var != nullptr) {
param_.bias =
const_cast<lite::Tensor*>(&(bias_var->Get<lite::Tensor>()));
auto bias_arguments = op_desc.Input("Bias");
if (bias_arguments.size() != 0) {
auto bias_var = scope->FindVar(bias_arguments.front());
if (bias_var != nullptr) {
param_.bias = &bias_var->Get<lite::Tensor>();
}
}
}
if (std::find(input_arg_names.begin(), input_arg_names.end(),
"ResidualData") != input_arg_names.end()) {
auto residual_data_var =
scope->FindVar(op_desc.Input("ResidualData").front());
if (residual_data_var != nullptr) {
param_.residualData = const_cast<lite::Tensor*>(
&(residual_data_var->Get<lite::Tensor>()));
auto res_argument = op_desc.Input("ResidualData");
if (res_argument.size() != 0) {
auto residual_data_var = scope->FindVar(res_argument.front());
if (residual_data_var != nullptr) {
param_.residualData = &residual_data_var->Get<lite::Tensor>();
}
}
}
......
......@@ -124,8 +124,8 @@ struct ConcatParam {
struct ConvParam {
lite::Tensor* x{};
lite::Tensor* filter{};
lite::Tensor* bias{};
lite::Tensor* residualData{};
const lite::Tensor* bias{};
const lite::Tensor* residualData{};
lite::Tensor* output{};
std::vector<int> strides{1, 1};
std::vector<int> paddings{0, 0};
......
......@@ -32,7 +32,7 @@ bool ReluOp::InferShape() const {
bool ReluOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.input = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("Input").front())->Get<lite::Tensor>());
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
param_.output =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
CHECK(param_.input);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册