未验证 提交 8d2351c3 编写于 作者: W weihaoji 提交者: GitHub

[XPU] add resnet50-D fusion (#4276)

上级 7fb2261d
......@@ -62,6 +62,7 @@ USE_MIR_PASS(quantized_op_attributes_inference_pass);
USE_MIR_PASS(control_flow_op_unused_inputs_and_outputs_eliminate_pass)
USE_MIR_PASS(lite_scale_activation_fuse_pass);
USE_MIR_PASS(__xpu__resnet_fuse_pass);
USE_MIR_PASS(__xpu__resnet_d_fuse_pass);
USE_MIR_PASS(__xpu__resnet_cbam_fuse_pass);
USE_MIR_PASS(__xpu__multi_encoder_fuse_pass);
USE_MIR_PASS(__xpu__embedding_with_eltwise_add_fuse_pass);
......
......@@ -307,7 +307,7 @@ class XPUResNetBlock0Fuser : public FuseBase {
matched.at("right_bn1_variance")->arg()->name,
});
op_desc.SetOutput("Outputs", {matched.at("relu_out")->arg()->name});
// XXX: keep these to fool SubgraphOp::AttachImpl()
// keep these to fool SubgraphOp::AttachImpl()
op_desc.SetAttr<int>("sub_block", 0);
op_desc.SetAttr<std::vector<std::string>>("input_data_names", {});
op_desc.SetAttr<std::vector<std::string>>("output_data_names", {});
......@@ -570,7 +570,7 @@ class XPUResNetBlock1Fuser : public FuseBase {
matched.at("right_bn3_variance")->arg()->name,
});
op_desc.SetOutput("Outputs", {matched.at("relu_out")->arg()->name});
// XXX: keep these to fool SubgraphOp::AttachImpl()
// keep these to fool SubgraphOp::AttachImpl()
op_desc.SetAttr<int>("sub_block", 0);
op_desc.SetAttr<std::vector<std::string>>("input_data_names", {});
op_desc.SetAttr<std::vector<std::string>>("output_data_names", {});
......@@ -599,9 +599,658 @@ class XPUResNetBlock1Fuser : public FuseBase {
}
};
class XPUResNetDtypeBlock0Fuser : public FuseBase {
public:
XPUResNetDtypeBlock0Fuser() {}
void BuildPattern() override {
auto* input = VarNode("input")
->assert_is_op_input("conv2d", "Input")
->assert_is_op_input("pool2d", "X")
->AsInput();
auto* left_conv1_weight = VarNode("left_conv1_weight")
->assert_is_op_input("conv2d", "Filter")
->AsInput();
auto* left_conv1 = OpNode("left_conv1", "conv2d");
auto* left_conv1_out = VarNode("left_conv1_out")
->assert_is_op_output("conv2d", "Output")
->assert_is_op_input("batch_norm", "X")
->AsIntermediate();
auto* left_bn1_scale = VarNode("left_bn1_scale")
->assert_is_op_input("batch_norm", "Scale")
->AsIntermediate();
auto* left_bn1_bias = VarNode("left_bn1_bias")
->assert_is_op_input("batch_norm", "Bias")
->AsInput();
auto* left_bn1_mean = VarNode("left_bn1_mean")
->assert_is_op_input("batch_norm", "Mean")
->AsIntermediate();
auto* left_bn1_var = VarNode("left_bn1_variance")
->assert_is_op_input("batch_norm", "Variance")
->AsIntermediate();
auto* left_bn1 = OpNode("left_bn1", "batch_norm")->AsIntermediate();
auto* left_bn1_out = VarNode("left_bn1_out")
->assert_is_op_output("batch_norm", "Y")
->assert_is_op_input("relu", "X")
->AsIntermediate();
auto* left_bn1_mean_out = VarNode("left_bn1_mean_out")
->assert_is_op_output("batch_norm", "MeanOut")
->AsIntermediate();
auto* left_bn1_var_out =
VarNode("left_bn1_var_out")
->assert_is_op_output("batch_norm", "VarianceOut")
->AsIntermediate();
auto* left_bn1_saved_mean =
VarNode("left_bn1_saved_mean")
->assert_is_op_output("batch_norm", "SavedMean")
->AsIntermediate();
auto* left_bn1_saved_var =
VarNode("left_bn1_saved_var")
->assert_is_op_output("batch_norm", "SavedVariance")
->AsIntermediate();
auto* left_relu1 = OpNode("left_relu1", "relu")->AsIntermediate();
auto* left_relu1_out = VarNode("left_relu1_out")
->assert_is_op_output("relu", "Out")
->assert_is_op_input("conv2d", "Input")
->AsIntermediate();
auto* left_conv2_weight = VarNode("left_conv2_weight")
->assert_is_op_input("conv2d", "Filter")
->AsInput();
auto* left_conv2 = OpNode("left_conv2", "conv2d")->AsIntermediate();
auto* left_conv2_out = VarNode("left_conv2_out")
->assert_is_op_output("conv2d", "Output")
->assert_is_op_input("batch_norm", "X")
->AsIntermediate();
auto* left_bn2_scale = VarNode("left_bn2_scale")
->assert_is_op_input("batch_norm", "Scale")
->AsIntermediate();
auto* left_bn2_bias = VarNode("left_bn2_bias")
->assert_is_op_input("batch_norm", "Bias")
->AsInput();
auto* left_bn2_mean = VarNode("left_bn2_mean")
->assert_is_op_input("batch_norm", "Mean")
->AsIntermediate();
auto* left_bn2_var = VarNode("left_bn2_variance")
->assert_is_op_input("batch_norm", "Variance")
->AsIntermediate();
auto* left_bn2 = OpNode("left_bn2", "batch_norm")->AsIntermediate();
auto* left_bn2_out = VarNode("left_bn2_out")
->assert_is_op_output("batch_norm", "Y")
->assert_is_op_input("relu", "X")
->AsIntermediate();
auto* left_bn2_mean_out = VarNode("left_bn2_mean_out")
->assert_is_op_output("batch_norm", "MeanOut")
->AsIntermediate();
auto* left_bn2_var_out =
VarNode("left_bn2_var_out")
->assert_is_op_output("batch_norm", "VarianceOut")
->AsIntermediate();
auto* left_bn2_saved_mean =
VarNode("left_bn2_saved_mean")
->assert_is_op_output("batch_norm", "SavedMean")
->AsIntermediate();
auto* left_bn2_saved_var =
VarNode("left_bn2_saved_var")
->assert_is_op_output("batch_norm", "SavedVariance")
->AsIntermediate();
auto* left_relu2 = OpNode("left_relu2", "relu")->AsIntermediate();
auto* left_relu2_out = VarNode("left_relu2_out")
->assert_is_op_output("relu", "Out")
->assert_is_op_input("conv2d", "Input")
->AsIntermediate();
auto* left_conv3_weight = VarNode("left_conv3_weight")
->assert_is_op_input("conv2d", "Filter")
->AsInput();
auto* left_conv3 = OpNode("left_conv3", "conv2d")->AsIntermediate();
auto* left_conv3_out = VarNode("left_conv3_out")
->assert_is_op_output("conv2d", "Output")
->assert_is_op_input("batch_norm", "X")
->AsIntermediate();
auto* left_bn3_scale = VarNode("left_bn3_scale")
->assert_is_op_input("batch_norm", "Scale")
->AsIntermediate();
auto* left_bn3_bias = VarNode("left_bn3_bias")
->assert_is_op_input("batch_norm", "Bias")
->AsInput();
auto* left_bn3_mean = VarNode("left_bn3_mean")
->assert_is_op_input("batch_norm", "Mean")
->AsIntermediate();
auto* left_bn3_var = VarNode("left_bn3_variance")
->assert_is_op_input("batch_norm", "Variance")
->AsIntermediate();
auto* left_bn3 = OpNode("left_bn3", "batch_norm")->AsIntermediate();
auto* left_bn3_out = VarNode("left_bn3_out")
->assert_is_op_output("batch_norm", "Y")
->assert_is_op_input("elementwise_add", "Y")
->AsIntermediate();
auto* left_bn3_mean_out = VarNode("left_bn3_mean_out")
->assert_is_op_output("batch_norm", "MeanOut")
->AsIntermediate();
auto* left_bn3_var_out =
VarNode("left_bn3_var_out")
->assert_is_op_output("batch_norm", "VarianceOut")
->AsIntermediate();
auto* left_bn3_saved_mean =
VarNode("left_bn3_saved_mean")
->assert_is_op_output("batch_norm", "SavedMean")
->AsIntermediate();
auto* left_bn3_saved_var =
VarNode("left_bn3_saved_var")
->assert_is_op_output("batch_norm", "SavedVariance")
->AsIntermediate();
auto* right_pool = OpNode("right_pool", "pool2d")->AsIntermediate();
auto* right_pool_out = VarNode("right_pool_out")
->assert_is_op_output("pool2d", "Out")
->assert_is_op_input("conv2d", "Input")
->AsIntermediate();
auto* right_conv1_weight = VarNode("right_conv1_weight")
->assert_is_op_input("conv2d", "Filter")
->AsInput();
auto* right_conv1 = OpNode("right_conv1", "conv2d")->AsIntermediate();
auto* right_conv1_out = VarNode("right_conv1_out")
->assert_is_op_output("conv2d", "Output")
->assert_is_op_input("batch_norm", "X")
->AsIntermediate();
auto* right_bn1_scale = VarNode("right_bn1_scale")
->assert_is_op_input("batch_norm", "Scale")
->AsIntermediate();
auto* right_bn1_bias = VarNode("right_bn1_bias")
->assert_is_op_input("batch_norm", "Bias")
->AsInput();
auto* right_bn1_mean = VarNode("right_bn1_mean")
->assert_is_op_input("batch_norm", "Mean")
->AsIntermediate();
auto* right_bn1_var = VarNode("right_bn1_variance")
->assert_is_op_input("batch_norm", "Variance")
->AsIntermediate();
auto* right_bn1 = OpNode("right_bn1", "batch_norm")->AsIntermediate();
auto* right_bn1_out = VarNode("right_bn1_out")
->assert_is_op_output("batch_norm", "Y")
->assert_is_op_input("elementwise_add", "X")
->AsIntermediate();
auto* right_bn1_mean_out =
VarNode("right_bn1_mean_out")
->assert_is_op_output("batch_norm", "MeanOut")
->AsIntermediate();
auto* right_bn1_var_out =
VarNode("right_bn1_var_out")
->assert_is_op_output("batch_norm", "VarianceOut")
->AsIntermediate();
auto* right_bn1_saved_mean =
VarNode("right_bn1_saved_mean")
->assert_is_op_output("batch_norm", "SavedMean")
->AsIntermediate();
auto* right_bn1_saved_var =
VarNode("right_bn1_saved_var")
->assert_is_op_output("batch_norm", "SavedVariance")
->AsIntermediate();
auto* add = OpNode("add", "elementwise_add")->AsIntermediate();
auto* add_out = VarNode("add_out")
->assert_is_op_output("elementwise_add", "Out")
->assert_is_op_input("relu", "X")
->AsIntermediate();
auto* relu = OpNode("relu", "relu")->AsIntermediate();
auto* relu_out =
VarNode("relu_out")->assert_is_op_output("relu", "Out")->AsOutput();
*input >> *left_conv1 >> *left_conv1_out >> *left_bn1 >> *left_bn1_out >>
*left_relu1 >> *left_relu1_out >> *left_conv2 >> *left_conv2_out >>
*left_bn2 >> *left_bn2_out >> *left_relu2 >> *left_relu2_out >>
*left_conv3 >> *left_conv3_out >> *left_bn3 >> *left_bn3_out >> *add;
*left_conv1_weight >> *left_conv1;
*left_bn1_scale >> *left_bn1;
*left_bn1_bias >> *left_bn1;
*left_bn1_mean >> *left_bn1;
*left_bn1_var >> *left_bn1;
*left_bn1 >> *left_bn1_mean_out;
*left_bn1 >> *left_bn1_var_out;
*left_bn1 >> *left_bn1_saved_mean;
*left_bn1 >> *left_bn1_saved_var;
*left_conv2_weight >> *left_conv2;
*left_bn2_scale >> *left_bn2;
*left_bn2_bias >> *left_bn2;
*left_bn2_mean >> *left_bn2;
*left_bn2_var >> *left_bn2;
*left_bn2 >> *left_bn2_mean_out;
*left_bn2 >> *left_bn2_var_out;
*left_bn2 >> *left_bn2_saved_mean;
*left_bn2 >> *left_bn2_saved_var;
*left_conv3_weight >> *left_conv3;
*left_bn3_scale >> *left_bn3;
*left_bn3_bias >> *left_bn3;
*left_bn3_mean >> *left_bn3;
*left_bn3_var >> *left_bn3;
*left_bn3 >> *left_bn3_mean_out;
*left_bn3 >> *left_bn3_var_out;
*left_bn3 >> *left_bn3_saved_mean;
*left_bn3 >> *left_bn3_saved_var;
*input >> *right_pool >> *right_pool_out >> *right_conv1 >>
*right_conv1_out >> *right_bn1 >> *right_bn1_out >> *add;
*right_conv1_weight >> *right_conv1;
*right_bn1_scale >> *right_bn1;
*right_bn1_bias >> *right_bn1;
*right_bn1_mean >> *right_bn1;
*right_bn1_var >> *right_bn1;
*right_bn1 >> *right_bn1_mean_out;
*right_bn1 >> *right_bn1_var_out;
*right_bn1 >> *right_bn1_saved_mean;
*right_bn1 >> *right_bn1_saved_var;
*add >> *add_out >> *relu >> *relu_out;
}
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
cpp::OpDesc op_desc;
op_desc.SetType("resnet_block0_d");
op_desc.SetInput("Inputs", {matched.at("input")->arg()->name});
op_desc.SetInput("Filter",
{
matched.at("left_conv1_weight")->arg()->name,
matched.at("left_conv2_weight")->arg()->name,
matched.at("left_conv3_weight")->arg()->name,
matched.at("right_conv1_weight")->arg()->name,
});
op_desc.SetInput("Scale",
{
matched.at("left_bn1_scale")->arg()->name,
matched.at("left_bn2_scale")->arg()->name,
matched.at("left_bn3_scale")->arg()->name,
matched.at("right_bn1_scale")->arg()->name,
});
op_desc.SetInput("Bias",
{
matched.at("left_bn1_bias")->arg()->name,
matched.at("left_bn2_bias")->arg()->name,
matched.at("left_bn3_bias")->arg()->name,
matched.at("right_bn1_bias")->arg()->name,
});
op_desc.SetInput("Mean",
{
matched.at("left_bn1_mean")->arg()->name,
matched.at("left_bn2_mean")->arg()->name,
matched.at("left_bn3_mean")->arg()->name,
matched.at("right_bn1_mean")->arg()->name,
});
op_desc.SetInput("Var",
{
matched.at("left_bn1_variance")->arg()->name,
matched.at("left_bn2_variance")->arg()->name,
matched.at("left_bn3_variance")->arg()->name,
matched.at("right_bn1_variance")->arg()->name,
});
op_desc.SetOutput("Outputs", {matched.at("relu_out")->arg()->name});
// keep these to fool SubgraphOp::AttachImpl()
op_desc.SetAttr<int>("sub_block", 0);
op_desc.SetAttr<std::vector<std::string>>("input_data_names", {});
op_desc.SetAttr<std::vector<std::string>>("output_data_names", {});
auto block0_stmt = matched.at("left_conv1")->stmt();
// block0_stmt->ResetOp(op_desc, graph->valid_places());
auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph");
auto sub_program_desc = std::make_shared<cpp::ProgramDesc>();
sub_program_desc->AddBlock<cpp::BlockDesc>();
static_cast<operators::SubgraphOp*>(fake_subgraph_op.get())
->SetProgramDesc(sub_program_desc);
fake_subgraph_op->Attach(op_desc, block0_stmt->op()->scope());
fake_subgraph_op->SetValidPlaces(block0_stmt->op()->valid_places());
block0_stmt->SetOp(fake_subgraph_op);
std::vector<std::string> froms = {
"left_conv2_weight",
"left_conv3_weight",
"right_conv1_weight",
"left_bn1_bias",
"left_bn2_bias",
"left_bn3_bias",
"right_bn1_bias",
};
for (auto& from : froms) {
IR_NODE_LINK_TO(matched.at(from), matched.at("left_conv1"));
}
IR_OP_VAR_LINK(matched.at("left_conv1"), matched.at("relu_out"));
}
};
class XPUResNet50Fuser : public xpu::XPUFuseBase {
public:
XPUResNet50Fuser() {}
XPUResNet50Fuser() {}
void BuildPattern() override {
auto* input =
VarNode("input")->assert_is_op_input("conv2d", "Input")->AsInput();
auto* top_conv_weight = VarNode("top_conv_weight")
->assert_is_op_input("conv2d", "Filter")
->AsInput();
auto* top_conv = OpNode("top_conv", "conv2d");
auto* top_conv_out = VarNode("top_conv_out")
->assert_is_op_output("conv2d", "Output")
->assert_is_op_input("batch_norm", "X")
->AsIntermediate();
auto* top_bn_scale = VarNode("top_bn_scale")
->assert_is_op_input("batch_norm", "Scale")
->AsIntermediate();
auto* top_bn_bias = VarNode("top_bn_bias")
->assert_is_op_input("batch_norm", "Bias")
->AsInput();
auto* top_bn_mean = VarNode("top_bn_mean")
->assert_is_op_input("batch_norm", "Mean")
->AsIntermediate();
auto* top_bn_var = VarNode("top_bn_variance")
->assert_is_op_input("batch_norm", "Variance")
->AsIntermediate();
auto* top_bn = OpNode("top_bn", "batch_norm")->AsIntermediate();
auto* top_bn_out = VarNode("top_bn_out")
->assert_is_op_output("batch_norm", "Y")
->assert_is_op_input("relu", "X")
->AsIntermediate();
auto* top_bn_mean_out = VarNode("top_bn_mean_out")
->assert_is_op_output("batch_norm", "MeanOut")
->AsIntermediate();
auto* top_bn_var_out =
VarNode("top_bn_var_out")
->assert_is_op_output("batch_norm", "VarianceOut")
->AsIntermediate();
auto* top_bn_saved_mean =
VarNode("top_bn_saved_mean")
->assert_is_op_output("batch_norm", "SavedMean")
->AsIntermediate();
auto* top_bn_saved_var =
VarNode("top_bn_saved_var")
->assert_is_op_output("batch_norm", "SavedVariance")
->AsIntermediate();
auto* top_relu = OpNode("top_relu", "relu")->AsIntermediate();
auto* top_relu_out = VarNode("top_relu_out")
->assert_is_op_output("relu", "Out")
->assert_is_op_input("pool2d", "X")
->AsIntermediate();
auto* top_pool = OpNode("top_pool", "pool2d")->AsIntermediate();
auto* top_pool_out = VarNode("top_pool_out")
->assert_is_op_output("pool2d", "Out")
->assert_is_op_input("resnet_block0", "Inputs")
->AsIntermediate();
// args are left out
auto* resnet_block0_1 =
OpNode("resnet_block0_1", "resnet_block0")->AsIntermediate();
auto* resnet_block0_1_out =
VarNode("resnet_block0_1_out")
->assert_is_op_output("resnet_block0", "Outputs")
->AsIntermediate();
auto* resnet_block1_1_1 =
OpNode("resnet_block1_1_1", "resnet_block1")->AsIntermediate();
auto* resnet_block1_1_1_out =
VarNode("resnet_block1_1_1_out")
->assert_is_op_output("resnet_block1", "Outputs")
->AsIntermediate();
auto* resnet_block1_1_2 =
OpNode("resnet_block1_1_2", "resnet_block1")->AsIntermediate();
auto* resnet_block1_1_2_out =
VarNode("resnet_block1_1_2_out")
->assert_is_op_output("resnet_block1", "Outputs")
->AsIntermediate();
auto* resnet_block0_2 =
OpNode("resnet_block0_2", "resnet_block0")->AsIntermediate();
auto* resnet_block0_2_out =
VarNode("resnet_block0_2_out")
->assert_is_op_output("resnet_block0", "Outputs")
->AsIntermediate();
auto* resnet_block1_2_1 =
OpNode("resnet_block1_2_1", "resnet_block1")->AsIntermediate();
auto* resnet_block1_2_1_out =
VarNode("resnet_block1_2_1_out")
->assert_is_op_output("resnet_block1", "Outputs")
->AsIntermediate();
auto* resnet_block1_2_2 =
OpNode("resnet_block1_2_2", "resnet_block1")->AsIntermediate();
auto* resnet_block1_2_2_out =
VarNode("resnet_block1_2_2_out")
->assert_is_op_output("resnet_block1", "Outputs")
->AsIntermediate();
auto* resnet_block1_2_3 =
OpNode("resnet_block1_2_3", "resnet_block1")->AsIntermediate();
auto* resnet_block1_2_3_out =
VarNode("resnet_block1_2_3_out")
->assert_is_op_output("resnet_block1", "Outputs")
->AsIntermediate();
auto* resnet_block0_3 =
OpNode("resnet_block0_3", "resnet_block0")->AsIntermediate();
auto* resnet_block0_3_out =
VarNode("resnet_block0_3_out")
->assert_is_op_output("resnet_block0", "Outputs")
->AsIntermediate();
auto* resnet_block1_3_1 =
OpNode("resnet_block1_3_1", "resnet_block1")->AsIntermediate();
auto* resnet_block1_3_1_out =
VarNode("resnet_block1_3_1_out")
->assert_is_op_output("resnet_block1", "Outputs")
->AsIntermediate();
auto* resnet_block1_3_2 =
OpNode("resnet_block1_3_2", "resnet_block1")->AsIntermediate();
auto* resnet_block1_3_2_out =
VarNode("resnet_block1_3_2_out")
->assert_is_op_output("resnet_block1", "Outputs")
->AsIntermediate();
auto* resnet_block1_3_3 =
OpNode("resnet_block1_3_3", "resnet_block1")->AsIntermediate();
auto* resnet_block1_3_3_out =
VarNode("resnet_block1_3_3_out")
->assert_is_op_output("resnet_block1", "Outputs")
->AsIntermediate();
auto* resnet_block1_3_4 =
OpNode("resnet_block1_3_4", "resnet_block1")->AsIntermediate();
auto* resnet_block1_3_4_out =
VarNode("resnet_block1_3_4_out")
->assert_is_op_output("resnet_block1", "Outputs")
->AsIntermediate();
auto* resnet_block1_3_5 =
OpNode("resnet_block1_3_5", "resnet_block1")->AsIntermediate();
auto* resnet_block1_3_5_out =
VarNode("resnet_block1_3_5_out")
->assert_is_op_output("resnet_block1", "Outputs")
->AsIntermediate();
auto* resnet_block0_4 =
OpNode("resnet_block0_4", "resnet_block0")->AsIntermediate();
auto* resnet_block0_4_out =
VarNode("resnet_block0_4_out")
->assert_is_op_output("resnet_block0", "Outputs")
->AsIntermediate();
auto* resnet_block1_4_1 =
OpNode("resnet_block1_4_1", "resnet_block1")->AsIntermediate();
auto* resnet_block1_4_1_out =
VarNode("resnet_block1_4_1_out")
->assert_is_op_output("resnet_block1", "Outputs")
->AsIntermediate();
auto* resnet_block1_4_2 =
OpNode("resnet_block1_4_2", "resnet_block1")->AsIntermediate();
auto* resnet_block1_4_2_out =
VarNode("resnet_block1_4_2_out")
->assert_is_op_output("resnet_block1", "Outputs")
->AsIntermediate();
auto* bottom_pool = OpNode("bottom_pool", "pool2d")->AsIntermediate();
auto* bottom_pool_out = VarNode("bottom_pool_out")
->assert_is_op_output("pool2d", "Out")
->AsOutput();
*input >> *top_conv >> *top_conv_out >> *top_bn >> *top_bn_out >>
*top_relu >> *top_relu_out >> *top_pool >> *top_pool_out >>
*resnet_block0_1 >> *resnet_block0_1_out >> *resnet_block1_1_1 >>
*resnet_block1_1_1_out >> *resnet_block1_1_2 >>
*resnet_block1_1_2_out >> *resnet_block0_2 >> *resnet_block0_2_out >>
*resnet_block1_2_1 >> *resnet_block1_2_1_out >> *resnet_block1_2_2 >>
*resnet_block1_2_2_out >> *resnet_block1_2_3 >>
*resnet_block1_2_3_out >> *resnet_block0_3 >> *resnet_block0_3_out >>
*resnet_block1_3_1 >> *resnet_block1_3_1_out >> *resnet_block1_3_2 >>
*resnet_block1_3_2_out >> *resnet_block1_3_3 >>
*resnet_block1_3_3_out >> *resnet_block1_3_4 >>
*resnet_block1_3_4_out >> *resnet_block1_3_5 >>
*resnet_block1_3_5_out >> *resnet_block0_4 >> *resnet_block0_4_out >>
*resnet_block1_4_1 >> *resnet_block1_4_1_out >> *resnet_block1_4_2 >>
*resnet_block1_4_2_out >> *bottom_pool >> *bottom_pool_out;
*top_conv_weight >> *top_conv;
*top_bn_scale >> *top_bn;
*top_bn_bias >> *top_bn;
*top_bn_mean >> *top_bn;
*top_bn_var >> *top_bn;
*top_bn >> *top_bn_mean_out;
*top_bn >> *top_bn_var_out;
*top_bn >> *top_bn_saved_mean;
*top_bn >> *top_bn_saved_var;
}
void InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched,
const std::vector<Node*>& extra_input_vars) override {
cpp::OpDesc op_desc;
op_desc.SetType("__xpu__resnet50");
op_desc.SetInput("Input", {matched.at("input")->arg()->name});
std::vector<std::string> filter_name = {
matched.at("top_conv_weight")->arg()->name};
std::vector<std::string> scale_name = {
matched.at("top_bn_scale")->arg()->name};
std::vector<std::string> bias_name = {
matched.at("top_bn_bias")->arg()->name};
std::vector<std::string> mean_name = {
matched.at("top_bn_mean")->arg()->name};
std::vector<std::string> var_name = {
matched.at("top_bn_variance")->arg()->name};
std::vector<std::string> max_filter_name;
std::vector<std::string> resnet_block_vec = {
"resnet_block0_1",
"resnet_block1_1_1",
"resnet_block1_1_2",
"resnet_block0_2",
"resnet_block1_2_1",
"resnet_block1_2_2",
"resnet_block1_2_3",
"resnet_block0_3",
"resnet_block1_3_1",
"resnet_block1_3_2",
"resnet_block1_3_3",
"resnet_block1_3_4",
"resnet_block1_3_5",
"resnet_block0_4",
"resnet_block1_4_1",
"resnet_block1_4_2",
};
for (auto& block : resnet_block_vec) {
auto* block_op_info = matched.at(block)->stmt()->op_info();
auto block_filter_name = block_op_info->Input("Filter");
std::copy(block_filter_name.begin(),
block_filter_name.end(),
std::back_inserter(filter_name));
auto block_scale_name = block_op_info->Input("Scale");
std::copy(block_scale_name.begin(),
block_scale_name.end(),
std::back_inserter(scale_name));
auto block_bias_name = block_op_info->Input("Bias");
std::copy(block_bias_name.begin(),
block_bias_name.end(),
std::back_inserter(bias_name));
auto block_mean_name = block_op_info->Input("Mean");
std::copy(block_mean_name.begin(),
block_mean_name.end(),
std::back_inserter(mean_name));
auto block_var_name = block_op_info->Input("Var");
std::copy(block_var_name.begin(),
block_var_name.end(),
std::back_inserter(var_name));
}
op_desc.SetInput("Filter", filter_name);
op_desc.SetInput("Bias", bias_name);
op_desc.SetOutput("Output", {matched.at("bottom_pool_out")->arg()->name});
op_desc.SetAttr<int>("xpu", 1);
auto* resnet50_stmt = matched.at("top_conv")->stmt();
auto* scope = resnet50_stmt->op()->scope();
for (size_t i = 0; i < filter_name.size(); ++i) {
auto* filter_t = scope->FindMutableTensor(filter_name[i]);
auto* scale_t = scope->FindMutableTensor(scale_name[i]);
auto* bias_t = scope->FindMutableTensor(bias_name[i]);
auto* mean_t = scope->FindMutableTensor(mean_name[i]);
auto* var_t = scope->FindMutableTensor(var_name[i]);
int mean_len = mean_t->numel();
int filter_len = filter_t->numel();
int filter_stride = filter_len / mean_len;
float* filter_on_host = filter_t->mutable_data<float>();
float* scale_on_host = scale_t->mutable_data<float>();
float* bias_on_host = bias_t->mutable_data<float>();
float* mean_on_host = mean_t->mutable_data<float>();
float* var_on_host = var_t->mutable_data<float>();
// Perform preprocess
for (int i = 0; i < mean_len; ++i) {
scale_on_host[i] = scale_on_host[i] / sqrtf(var_on_host[i] + 0.00001f);
}
for (int i = 0; i < mean_len; ++i) {
for (int j = 0; j < filter_stride; ++j) {
filter_on_host[i * filter_stride + j] *= scale_on_host[i];
}
}
for (int i = 0; i < mean_len; ++i) {
bias_on_host[i] += -mean_on_host[i] * scale_on_host[i];
}
float max_f =
paddle::lite::xpu::math::FindMaxAbs(filter_on_host, filter_len);
std::unique_ptr<int16_t[]> filter_int16(new int16_t[filter_len]);
paddle::lite::xpu::math::ConvertFP32ToInt16(
filter_on_host, filter_int16.get(), max_f, filter_len);
memcpy(filter_on_host, filter_int16.get(), filter_len * sizeof(int16_t));
// create new arg in graph and scope
std::string max_name = filter_name[i] + "_max";
max_filter_name.push_back(max_name);
auto* max_filter_node = graph->NewArgumentNode(max_name);
max_filter_node->arg()->is_weight = true;
max_filter_node->arg()->type = LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
DirectedLink(max_filter_node, matched.at("top_conv"));
auto* max_filter_t = scope->NewTensor(max_name);
max_filter_t->Resize({4});
float* max_ptr = max_filter_t->mutable_data<float>();
max_ptr[0] = max_f;
max_ptr[1] = max_f;
max_ptr[2] = max_f;
max_ptr[3] = max_f;
}
op_desc.SetInput("MaxFilter", max_filter_name);
auto resnet50_op = LiteOpRegistry::Global().Create(op_desc.Type());
resnet50_op->Attach(op_desc, scope);
resnet50_op->SetValidPlaces(resnet50_stmt->op()->valid_places());
auto kernels = resnet50_op->CreateKernels(resnet50_op->valid_places());
resnet50_stmt->SetOp(resnet50_op);
resnet50_stmt->SetKernels(std::move(kernels));
IR_NODE_LINK_TO(matched.at("top_bn_bias"), matched.at("top_conv"));
for (auto* node : extra_input_vars) {
IR_NODE_LINK_TO(node, matched.at("top_conv"));
}
IR_OP_VAR_LINK(matched.at("top_conv"), matched.at("bottom_pool_out"));
}
};
class XPUResNet50DtypeFuser : public xpu::XPUFuseBase {
public:
XPUResNet50DtypeFuser() {}
void BuildPattern() override {
auto* input =
......@@ -650,8 +1299,102 @@ class XPUResNet50Fuser : public xpu::XPUFuseBase {
auto* top_relu = OpNode("top_relu", "relu")->AsIntermediate();
auto* top_relu_out = VarNode("top_relu_out")
->assert_is_op_output("relu", "Out")
->assert_is_op_input("pool2d", "X")
->assert_is_op_input("conv2d", "Input")
->AsIntermediate();
auto* second_conv_weight = VarNode("second_conv_weight")
->assert_is_op_input("conv2d", "Filter")
->AsInput();
auto* second_conv = OpNode("second_conv", "conv2d")->AsIntermediate();
auto* second_conv_out = VarNode("second_conv_out")
->assert_is_op_output("conv2d", "Output")
->assert_is_op_input("batch_norm", "X")
->AsIntermediate();
auto* second_bn_scale = VarNode("second_bn_scale")
->assert_is_op_input("batch_norm", "Scale")
->AsIntermediate();
auto* second_bn_bias = VarNode("second_bn_bias")
->assert_is_op_input("batch_norm", "Bias")
->AsInput();
auto* second_bn_mean = VarNode("second_bn_mean")
->assert_is_op_input("batch_norm", "Mean")
->AsIntermediate();
auto* second_bn_var = VarNode("second_bn_variance")
->assert_is_op_input("batch_norm", "Variance")
->AsIntermediate();
auto* second_bn = OpNode("second_bn", "batch_norm")->AsIntermediate();
auto* second_bn_out = VarNode("second_bn_out")
->assert_is_op_output("batch_norm", "Y")
->assert_is_op_input("relu", "X")
->AsIntermediate();
auto* second_bn_mean_out =
VarNode("second_bn_mean_out")
->assert_is_op_output("batch_norm", "MeanOut")
->AsIntermediate();
auto* second_bn_var_out =
VarNode("second_bn_var_out")
->assert_is_op_output("batch_norm", "VarianceOut")
->AsIntermediate();
auto* second_bn_saved_mean =
VarNode("second_bn_saved_mean")
->assert_is_op_output("batch_norm", "SavedMean")
->AsIntermediate();
auto* second_bn_saved_var =
VarNode("second_bn_saved_var")
->assert_is_op_output("batch_norm", "SavedVariance")
->AsIntermediate();
auto* second_relu = OpNode("second_relu", "relu")->AsIntermediate();
auto* second_relu_out = VarNode("second_relu_out")
->assert_is_op_output("relu", "Out")
->assert_is_op_input("conv2d", "Input")
->AsIntermediate();
auto* third_conv_weight = VarNode("third_conv_weight")
->assert_is_op_input("conv2d", "Filter")
->AsInput();
auto* third_conv = OpNode("third_conv", "conv2d")->AsIntermediate();
auto* third_conv_out = VarNode("third_conv_out")
->assert_is_op_output("conv2d", "Output")
->assert_is_op_input("batch_norm", "X")
->AsIntermediate();
auto* third_bn_scale = VarNode("third_bn_scale")
->assert_is_op_input("batch_norm", "Scale")
->AsIntermediate();
auto* third_bn_bias = VarNode("third_bn_bias")
->assert_is_op_input("batch_norm", "Bias")
->AsInput();
auto* third_bn_mean = VarNode("third_bn_mean")
->assert_is_op_input("batch_norm", "Mean")
->AsIntermediate();
auto* third_bn_var = VarNode("third_bn_variance")
->assert_is_op_input("batch_norm", "Variance")
->AsIntermediate();
auto* third_bn = OpNode("third_bn", "batch_norm")->AsIntermediate();
auto* third_bn_out = VarNode("third_bn_out")
->assert_is_op_output("batch_norm", "Y")
->assert_is_op_input("relu", "X")
->AsIntermediate();
auto* third_bn_mean_out = VarNode("third_bn_mean_out")
->assert_is_op_output("batch_norm", "MeanOut")
->AsIntermediate();
auto* third_bn_var_out =
VarNode("third_bn_var_out")
->assert_is_op_output("batch_norm", "VarianceOut")
->AsIntermediate();
auto* third_bn_saved_mean =
VarNode("third_bn_saved_mean")
->assert_is_op_output("batch_norm", "SavedMean")
->AsIntermediate();
auto* third_bn_saved_var =
VarNode("third_bn_saved_var")
->assert_is_op_output("batch_norm", "SavedVariance")
->AsIntermediate();
auto* third_relu = OpNode("third_relu", "relu")->AsIntermediate();
auto* third_relu_out = VarNode("third_relu_out")
->assert_is_op_output("relu", "Out")
->assert_is_op_input("pool2d", "X")
->AsIntermediate();
auto* top_pool = OpNode("top_pool", "pool2d")->AsIntermediate();
auto* top_pool_out = VarNode("top_pool_out")
->assert_is_op_output("pool2d", "Out")
......@@ -679,10 +1422,10 @@ class XPUResNet50Fuser : public xpu::XPUFuseBase {
->AsIntermediate();
auto* resnet_block0_2 =
OpNode("resnet_block0_2", "resnet_block0")->AsIntermediate();
OpNode("resnet_block0_2", "resnet_block0_d")->AsIntermediate();
auto* resnet_block0_2_out =
VarNode("resnet_block0_2_out")
->assert_is_op_output("resnet_block0", "Outputs")
->assert_is_op_output("resnet_block0_d", "Outputs")
->AsIntermediate();
auto* resnet_block1_2_1 =
OpNode("resnet_block1_2_1", "resnet_block1")->AsIntermediate();
......@@ -704,10 +1447,10 @@ class XPUResNet50Fuser : public xpu::XPUFuseBase {
->AsIntermediate();
auto* resnet_block0_3 =
OpNode("resnet_block0_3", "resnet_block0")->AsIntermediate();
OpNode("resnet_block0_3", "resnet_block0_d")->AsIntermediate();
auto* resnet_block0_3_out =
VarNode("resnet_block0_3_out")
->assert_is_op_output("resnet_block0", "Outputs")
->assert_is_op_output("resnet_block0_d", "Outputs")
->AsIntermediate();
auto* resnet_block1_3_1 =
OpNode("resnet_block1_3_1", "resnet_block1")->AsIntermediate();
......@@ -741,10 +1484,10 @@ class XPUResNet50Fuser : public xpu::XPUFuseBase {
->AsIntermediate();
auto* resnet_block0_4 =
OpNode("resnet_block0_4", "resnet_block0")->AsIntermediate();
OpNode("resnet_block0_4", "resnet_block0_d")->AsIntermediate();
auto* resnet_block0_4_out =
VarNode("resnet_block0_4_out")
->assert_is_op_output("resnet_block0", "Outputs")
->assert_is_op_output("resnet_block0_d", "Outputs")
->AsIntermediate();
auto* resnet_block1_4_1 =
OpNode("resnet_block1_4_1", "resnet_block1")->AsIntermediate();
......@@ -765,7 +1508,10 @@ class XPUResNet50Fuser : public xpu::XPUFuseBase {
->AsOutput();
*input >> *top_conv >> *top_conv_out >> *top_bn >> *top_bn_out >>
*top_relu >> *top_relu_out >> *top_pool >> *top_pool_out >>
*top_relu >> *top_relu_out >> *second_conv >> *second_conv_out >>
*second_bn >> *second_bn_out >> *second_relu >> *second_relu_out >>
*third_conv >> *third_conv_out >> *third_bn >> *third_bn_out >>
*third_relu >> *third_relu_out >> *top_pool >> *top_pool_out >>
*resnet_block0_1 >> *resnet_block0_1_out >> *resnet_block1_1_1 >>
*resnet_block1_1_1_out >> *resnet_block1_1_2 >>
*resnet_block1_1_2_out >> *resnet_block0_2 >> *resnet_block0_2_out >>
......@@ -789,24 +1535,59 @@ class XPUResNet50Fuser : public xpu::XPUFuseBase {
*top_bn >> *top_bn_var_out;
*top_bn >> *top_bn_saved_mean;
*top_bn >> *top_bn_saved_var;
*second_conv_weight >> *second_conv;
*second_bn_scale >> *second_bn;
*second_bn_bias >> *second_bn;
*second_bn_mean >> *second_bn;
*second_bn_var >> *second_bn;
*second_bn >> *second_bn_mean_out;
*second_bn >> *second_bn_var_out;
*second_bn >> *second_bn_saved_mean;
*second_bn >> *second_bn_saved_var;
*third_conv_weight >> *third_conv;
*third_bn_scale >> *third_bn;
*third_bn_bias >> *third_bn;
*third_bn_mean >> *third_bn;
*third_bn_var >> *third_bn;
*third_bn >> *third_bn_mean_out;
*third_bn >> *third_bn_var_out;
*third_bn >> *third_bn_saved_mean;
*third_bn >> *third_bn_saved_var;
}
void InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched,
const std::vector<Node*>& extra_input_vars) override {
cpp::OpDesc op_desc;
op_desc.SetType("__xpu__resnet50");
op_desc.SetType("__xpu__resnet50_d");
op_desc.SetInput("Input", {matched.at("input")->arg()->name});
std::vector<std::string> filter_name = {
matched.at("top_conv_weight")->arg()->name};
matched.at("top_conv_weight")->arg()->name,
matched.at("second_conv_weight")->arg()->name,
matched.at("third_conv_weight")->arg()->name};
std::vector<std::string> scale_name = {
matched.at("top_bn_scale")->arg()->name};
matched.at("top_bn_scale")->arg()->name,
matched.at("second_bn_scale")->arg()->name,
matched.at("third_bn_scale")->arg()->name};
std::vector<std::string> bias_name = {
matched.at("top_bn_bias")->arg()->name};
matched.at("top_bn_bias")->arg()->name,
matched.at("second_bn_bias")->arg()->name,
matched.at("third_bn_bias")->arg()->name};
std::vector<std::string> mean_name = {
matched.at("top_bn_mean")->arg()->name};
matched.at("top_bn_mean")->arg()->name,
matched.at("second_bn_mean")->arg()->name,
matched.at("third_bn_mean")->arg()->name};
std::vector<std::string> var_name = {
matched.at("top_bn_variance")->arg()->name};
matched.at("top_bn_variance")->arg()->name,
matched.at("second_bn_variance")->arg()->name,
matched.at("third_bn_variance")->arg()->name};
std::vector<std::string> max_filter_name;
std::vector<std::string> resnet_block_vec = {
"resnet_block0_1",
......@@ -900,7 +1681,9 @@ class XPUResNet50Fuser : public xpu::XPUFuseBase {
max_filter_node->arg()->is_weight = true;
max_filter_node->arg()->type = LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
DirectedLink(max_filter_node, matched.at("top_conv"));
auto* max_filter_t = scope->NewTensor(max_name);
max_filter_t->Resize({4});
float* max_ptr = max_filter_t->mutable_data<float>();
......@@ -919,6 +1702,11 @@ class XPUResNet50Fuser : public xpu::XPUFuseBase {
resnet50_stmt->SetKernels(std::move(kernels));
IR_NODE_LINK_TO(matched.at("top_bn_bias"), matched.at("top_conv"));
IR_NODE_LINK_TO(matched.at("second_conv_weight"), matched.at("top_conv"));
IR_NODE_LINK_TO(matched.at("second_bn_bias"), matched.at("top_conv"));
IR_NODE_LINK_TO(matched.at("third_conv_weight"), matched.at("top_conv"));
IR_NODE_LINK_TO(matched.at("third_bn_bias"), matched.at("top_conv"));
for (auto* node : extra_input_vars) {
IR_NODE_LINK_TO(node, matched.at("top_conv"));
}
......@@ -951,6 +1739,31 @@ class XPUResNet50FusePass : public ProgramPass {
}
};
class XPUResNet50DtypeFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override {
if (GetBoolFromEnv("XPU_ENABLE_XTCL")) return;
bool changed = false;
SSAGraph backup;
backup.CloneFrom(*graph);
fusion::XPUResNetBlock0Fuser block0_fuser;
changed |= block0_fuser(graph.get());
fusion::XPUResNetDtypeBlock0Fuser d_type_block0_fuser;
changed |= d_type_block0_fuser(graph.get());
fusion::XPUResNetBlock1Fuser block1_fuser;
changed |= block1_fuser(graph.get());
fusion::XPUResNet50DtypeFuser resnet50_d_fuser;
size_t n_matches = resnet50_d_fuser(graph.get());
if (changed && !n_matches) {
// Restore graph from backuped one if no whole ResNet50 graph was found
graph->CloneFrom(backup);
}
}
};
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -959,3 +1772,8 @@ REGISTER_MIR_PASS(__xpu__resnet_fuse_pass,
paddle::lite::mir::XPUResNet50FusePass)
.BindTargets({TARGET(kXPU)})
.BindKernel("__xpu__resnet50");
REGISTER_MIR_PASS(__xpu__resnet_d_fuse_pass,
paddle::lite::mir::XPUResNet50DtypeFusePass)
.BindTargets({TARGET(kXPU)})
.BindKernel("__xpu__resnet50_d");
......@@ -108,6 +108,7 @@ class Optimizer {
#endif
"identity_dropout_eliminate_pass",
"__xpu__resnet_fuse_pass",
"__xpu__resnet_d_fuse_pass",
"__xpu__resnet_cbam_fuse_pass",
"__xpu__conv2d_fuse_pass",
"__xpu__conv2d_link_previous_out_max_pass",
......
......@@ -34,6 +34,21 @@ void XPUResNet50Compute::PrepareForRun() {
}
}
void XPUResNet50DtypeCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
for (auto* filter : param.filter) {
arg_filter_.push_back(
reinterpret_cast<const int16_t*>(filter->data<float>()));
}
for (auto* bias : param.bias) {
arg_bias_.push_back(bias->data<float>());
}
for (auto* max_filter : param.max_filter) {
arg_max_filter_.push_back(max_filter->data<float>());
}
}
void XPUResNet50Compute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
......@@ -50,6 +65,22 @@ void XPUResNet50Compute::Run() {
CHECK_EQ(r, 0);
}
void XPUResNet50DtypeCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
int batch_size = param.input->dims()[0];
int r = xdnn::conv2d_int16_resnet_d<float, int16_t>(
ctx.GetRawContext(), /* context */
batch_size, /* num */
param.input->data<float>(), /* bottom */
&arg_filter_[0], /* weight_list */
param.output->mutable_data<float>(TARGET(kXPU)), /* top */
&arg_bias_[0], /* bias_list */
&arg_max_filter_[0] /* max_filter_list */);
CHECK_EQ(r, 0);
}
} // namespace xpu
} // namespace kernels
} // namespace lite
......@@ -67,3 +98,16 @@ REGISTER_LITE_KERNEL(__xpu__resnet50,
.BindInput("MaxFilter", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
REGISTER_LITE_KERNEL(__xpu__resnet50_d,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::XPUResNet50DtypeCompute,
def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("MaxFilter", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
......@@ -38,6 +38,21 @@ class XPUResNet50Compute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
std::vector<const float *> arg_bias_;
};
class XPUResNet50DtypeCompute
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::XPUResNet50Param;
virtual void PrepareForRun();
virtual void Run();
private:
std::vector<const int16_t *> arg_filter_;
std::vector<const float *> arg_max_filter_;
std::vector<const float *> arg_bias_;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
......
......@@ -62,3 +62,4 @@ bool XPUResNet50Op::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
} // namespace paddle
REGISTER_LITE_OP(__xpu__resnet50, paddle::lite::operators::XPUResNet50Op);
REGISTER_LITE_OP(__xpu__resnet50_d, paddle::lite::operators::XPUResNet50Op);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册