diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h index cea2a45c5db15891a4de679265a9c2cd2779d0fb..a4ea030cbf3ae7ead5836f02638ff440335f89fe 100644 --- a/lite/api/paddle_use_passes.h +++ b/lite/api/paddle_use_passes.h @@ -62,6 +62,7 @@ USE_MIR_PASS(quantized_op_attributes_inference_pass); USE_MIR_PASS(control_flow_op_unused_inputs_and_outputs_eliminate_pass) USE_MIR_PASS(lite_scale_activation_fuse_pass); USE_MIR_PASS(__xpu__resnet_fuse_pass); +USE_MIR_PASS(__xpu__resnet_d_fuse_pass); USE_MIR_PASS(__xpu__resnet_cbam_fuse_pass); USE_MIR_PASS(__xpu__multi_encoder_fuse_pass); USE_MIR_PASS(__xpu__embedding_with_eltwise_add_fuse_pass); diff --git a/lite/core/mir/fusion/__xpu__resnet_fuse_pass.cc b/lite/core/mir/fusion/__xpu__resnet_fuse_pass.cc index 39773e272a3345454c00c4da4b7e7c69617afd69..0692928dd212dd6bfc61f7a53e6321ac93439993 100644 --- a/lite/core/mir/fusion/__xpu__resnet_fuse_pass.cc +++ b/lite/core/mir/fusion/__xpu__resnet_fuse_pass.cc @@ -307,7 +307,7 @@ class XPUResNetBlock0Fuser : public FuseBase { matched.at("right_bn1_variance")->arg()->name, }); op_desc.SetOutput("Outputs", {matched.at("relu_out")->arg()->name}); - // XXX: keep these to fool SubgraphOp::AttachImpl() + // keep these to fool SubgraphOp::AttachImpl() op_desc.SetAttr("sub_block", 0); op_desc.SetAttr>("input_data_names", {}); op_desc.SetAttr>("output_data_names", {}); @@ -570,7 +570,7 @@ class XPUResNetBlock1Fuser : public FuseBase { matched.at("right_bn3_variance")->arg()->name, }); op_desc.SetOutput("Outputs", {matched.at("relu_out")->arg()->name}); - // XXX: keep these to fool SubgraphOp::AttachImpl() + // keep these to fool SubgraphOp::AttachImpl() op_desc.SetAttr("sub_block", 0); op_desc.SetAttr>("input_data_names", {}); op_desc.SetAttr>("output_data_names", {}); @@ -599,9 +599,658 @@ class XPUResNetBlock1Fuser : public FuseBase { } }; +class XPUResNetDtypeBlock0Fuser : public FuseBase { + public: + XPUResNetDtypeBlock0Fuser() {} + + void BuildPattern() override { + auto* input = VarNode("input") + ->assert_is_op_input("conv2d", "Input") + ->assert_is_op_input("pool2d", "X") + ->AsInput(); + + auto* left_conv1_weight = VarNode("left_conv1_weight") + ->assert_is_op_input("conv2d", "Filter") + ->AsInput(); + auto* left_conv1 = OpNode("left_conv1", "conv2d"); + auto* left_conv1_out = VarNode("left_conv1_out") + ->assert_is_op_output("conv2d", "Output") + ->assert_is_op_input("batch_norm", "X") + ->AsIntermediate(); + auto* left_bn1_scale = VarNode("left_bn1_scale") + ->assert_is_op_input("batch_norm", "Scale") + ->AsIntermediate(); + auto* left_bn1_bias = VarNode("left_bn1_bias") + ->assert_is_op_input("batch_norm", "Bias") + ->AsInput(); + auto* left_bn1_mean = VarNode("left_bn1_mean") + ->assert_is_op_input("batch_norm", "Mean") + ->AsIntermediate(); + auto* left_bn1_var = VarNode("left_bn1_variance") + ->assert_is_op_input("batch_norm", "Variance") + ->AsIntermediate(); + auto* left_bn1 = OpNode("left_bn1", "batch_norm")->AsIntermediate(); + auto* left_bn1_out = VarNode("left_bn1_out") + ->assert_is_op_output("batch_norm", "Y") + ->assert_is_op_input("relu", "X") + ->AsIntermediate(); + auto* left_bn1_mean_out = VarNode("left_bn1_mean_out") + ->assert_is_op_output("batch_norm", "MeanOut") + ->AsIntermediate(); + auto* left_bn1_var_out = + VarNode("left_bn1_var_out") + ->assert_is_op_output("batch_norm", "VarianceOut") + ->AsIntermediate(); + auto* left_bn1_saved_mean = + VarNode("left_bn1_saved_mean") + ->assert_is_op_output("batch_norm", "SavedMean") + ->AsIntermediate(); + auto* left_bn1_saved_var = + VarNode("left_bn1_saved_var") + ->assert_is_op_output("batch_norm", "SavedVariance") + ->AsIntermediate(); + auto* left_relu1 = OpNode("left_relu1", "relu")->AsIntermediate(); + auto* left_relu1_out = VarNode("left_relu1_out") + ->assert_is_op_output("relu", "Out") + ->assert_is_op_input("conv2d", "Input") + ->AsIntermediate(); + + auto* left_conv2_weight = VarNode("left_conv2_weight") + ->assert_is_op_input("conv2d", "Filter") + ->AsInput(); + auto* left_conv2 = OpNode("left_conv2", "conv2d")->AsIntermediate(); + auto* left_conv2_out = VarNode("left_conv2_out") + ->assert_is_op_output("conv2d", "Output") + ->assert_is_op_input("batch_norm", "X") + ->AsIntermediate(); + auto* left_bn2_scale = VarNode("left_bn2_scale") + ->assert_is_op_input("batch_norm", "Scale") + ->AsIntermediate(); + auto* left_bn2_bias = VarNode("left_bn2_bias") + ->assert_is_op_input("batch_norm", "Bias") + ->AsInput(); + auto* left_bn2_mean = VarNode("left_bn2_mean") + ->assert_is_op_input("batch_norm", "Mean") + ->AsIntermediate(); + auto* left_bn2_var = VarNode("left_bn2_variance") + ->assert_is_op_input("batch_norm", "Variance") + ->AsIntermediate(); + auto* left_bn2 = OpNode("left_bn2", "batch_norm")->AsIntermediate(); + auto* left_bn2_out = VarNode("left_bn2_out") + ->assert_is_op_output("batch_norm", "Y") + ->assert_is_op_input("relu", "X") + ->AsIntermediate(); + auto* left_bn2_mean_out = VarNode("left_bn2_mean_out") + ->assert_is_op_output("batch_norm", "MeanOut") + ->AsIntermediate(); + auto* left_bn2_var_out = + VarNode("left_bn2_var_out") + ->assert_is_op_output("batch_norm", "VarianceOut") + ->AsIntermediate(); + auto* left_bn2_saved_mean = + VarNode("left_bn2_saved_mean") + ->assert_is_op_output("batch_norm", "SavedMean") + ->AsIntermediate(); + auto* left_bn2_saved_var = + VarNode("left_bn2_saved_var") + ->assert_is_op_output("batch_norm", "SavedVariance") + ->AsIntermediate(); + auto* left_relu2 = OpNode("left_relu2", "relu")->AsIntermediate(); + auto* left_relu2_out = VarNode("left_relu2_out") + ->assert_is_op_output("relu", "Out") + ->assert_is_op_input("conv2d", "Input") + ->AsIntermediate(); + + auto* left_conv3_weight = VarNode("left_conv3_weight") + ->assert_is_op_input("conv2d", "Filter") + ->AsInput(); + auto* left_conv3 = OpNode("left_conv3", "conv2d")->AsIntermediate(); + auto* left_conv3_out = VarNode("left_conv3_out") + ->assert_is_op_output("conv2d", "Output") + ->assert_is_op_input("batch_norm", "X") + ->AsIntermediate(); + auto* left_bn3_scale = VarNode("left_bn3_scale") + ->assert_is_op_input("batch_norm", "Scale") + ->AsIntermediate(); + auto* left_bn3_bias = VarNode("left_bn3_bias") + ->assert_is_op_input("batch_norm", "Bias") + ->AsInput(); + auto* left_bn3_mean = VarNode("left_bn3_mean") + ->assert_is_op_input("batch_norm", "Mean") + ->AsIntermediate(); + auto* left_bn3_var = VarNode("left_bn3_variance") + ->assert_is_op_input("batch_norm", "Variance") + ->AsIntermediate(); + auto* left_bn3 = OpNode("left_bn3", "batch_norm")->AsIntermediate(); + auto* left_bn3_out = VarNode("left_bn3_out") + ->assert_is_op_output("batch_norm", "Y") + ->assert_is_op_input("elementwise_add", "Y") + ->AsIntermediate(); + auto* left_bn3_mean_out = VarNode("left_bn3_mean_out") + ->assert_is_op_output("batch_norm", "MeanOut") + ->AsIntermediate(); + auto* left_bn3_var_out = + VarNode("left_bn3_var_out") + ->assert_is_op_output("batch_norm", "VarianceOut") + ->AsIntermediate(); + auto* left_bn3_saved_mean = + VarNode("left_bn3_saved_mean") + ->assert_is_op_output("batch_norm", "SavedMean") + ->AsIntermediate(); + auto* left_bn3_saved_var = + VarNode("left_bn3_saved_var") + ->assert_is_op_output("batch_norm", "SavedVariance") + ->AsIntermediate(); + + auto* right_pool = OpNode("right_pool", "pool2d")->AsIntermediate(); + auto* right_pool_out = VarNode("right_pool_out") + ->assert_is_op_output("pool2d", "Out") + ->assert_is_op_input("conv2d", "Input") + ->AsIntermediate(); + auto* right_conv1_weight = VarNode("right_conv1_weight") + ->assert_is_op_input("conv2d", "Filter") + ->AsInput(); + auto* right_conv1 = OpNode("right_conv1", "conv2d")->AsIntermediate(); + auto* right_conv1_out = VarNode("right_conv1_out") + ->assert_is_op_output("conv2d", "Output") + ->assert_is_op_input("batch_norm", "X") + ->AsIntermediate(); + auto* right_bn1_scale = VarNode("right_bn1_scale") + ->assert_is_op_input("batch_norm", "Scale") + ->AsIntermediate(); + auto* right_bn1_bias = VarNode("right_bn1_bias") + ->assert_is_op_input("batch_norm", "Bias") + ->AsInput(); + auto* right_bn1_mean = VarNode("right_bn1_mean") + ->assert_is_op_input("batch_norm", "Mean") + ->AsIntermediate(); + auto* right_bn1_var = VarNode("right_bn1_variance") + ->assert_is_op_input("batch_norm", "Variance") + ->AsIntermediate(); + auto* right_bn1 = OpNode("right_bn1", "batch_norm")->AsIntermediate(); + auto* right_bn1_out = VarNode("right_bn1_out") + ->assert_is_op_output("batch_norm", "Y") + ->assert_is_op_input("elementwise_add", "X") + ->AsIntermediate(); + auto* right_bn1_mean_out = + VarNode("right_bn1_mean_out") + ->assert_is_op_output("batch_norm", "MeanOut") + ->AsIntermediate(); + auto* right_bn1_var_out = + VarNode("right_bn1_var_out") + ->assert_is_op_output("batch_norm", "VarianceOut") + ->AsIntermediate(); + auto* right_bn1_saved_mean = + VarNode("right_bn1_saved_mean") + ->assert_is_op_output("batch_norm", "SavedMean") + ->AsIntermediate(); + auto* right_bn1_saved_var = + VarNode("right_bn1_saved_var") + ->assert_is_op_output("batch_norm", "SavedVariance") + ->AsIntermediate(); + + auto* add = OpNode("add", "elementwise_add")->AsIntermediate(); + auto* add_out = VarNode("add_out") + ->assert_is_op_output("elementwise_add", "Out") + ->assert_is_op_input("relu", "X") + ->AsIntermediate(); + auto* relu = OpNode("relu", "relu")->AsIntermediate(); + auto* relu_out = + VarNode("relu_out")->assert_is_op_output("relu", "Out")->AsOutput(); + + *input >> *left_conv1 >> *left_conv1_out >> *left_bn1 >> *left_bn1_out >> + *left_relu1 >> *left_relu1_out >> *left_conv2 >> *left_conv2_out >> + *left_bn2 >> *left_bn2_out >> *left_relu2 >> *left_relu2_out >> + *left_conv3 >> *left_conv3_out >> *left_bn3 >> *left_bn3_out >> *add; + + *left_conv1_weight >> *left_conv1; + *left_bn1_scale >> *left_bn1; + *left_bn1_bias >> *left_bn1; + *left_bn1_mean >> *left_bn1; + *left_bn1_var >> *left_bn1; + *left_bn1 >> *left_bn1_mean_out; + *left_bn1 >> *left_bn1_var_out; + *left_bn1 >> *left_bn1_saved_mean; + *left_bn1 >> *left_bn1_saved_var; + + *left_conv2_weight >> *left_conv2; + *left_bn2_scale >> *left_bn2; + *left_bn2_bias >> *left_bn2; + *left_bn2_mean >> *left_bn2; + *left_bn2_var >> *left_bn2; + *left_bn2 >> *left_bn2_mean_out; + *left_bn2 >> *left_bn2_var_out; + *left_bn2 >> *left_bn2_saved_mean; + *left_bn2 >> *left_bn2_saved_var; + + *left_conv3_weight >> *left_conv3; + *left_bn3_scale >> *left_bn3; + *left_bn3_bias >> *left_bn3; + *left_bn3_mean >> *left_bn3; + *left_bn3_var >> *left_bn3; + *left_bn3 >> *left_bn3_mean_out; + *left_bn3 >> *left_bn3_var_out; + *left_bn3 >> *left_bn3_saved_mean; + *left_bn3 >> *left_bn3_saved_var; + + *input >> *right_pool >> *right_pool_out >> *right_conv1 >> + *right_conv1_out >> *right_bn1 >> *right_bn1_out >> *add; + + *right_conv1_weight >> *right_conv1; + *right_bn1_scale >> *right_bn1; + *right_bn1_bias >> *right_bn1; + *right_bn1_mean >> *right_bn1; + *right_bn1_var >> *right_bn1; + *right_bn1 >> *right_bn1_mean_out; + *right_bn1 >> *right_bn1_var_out; + *right_bn1 >> *right_bn1_saved_mean; + *right_bn1 >> *right_bn1_saved_var; + + *add >> *add_out >> *relu >> *relu_out; + } + + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override { + cpp::OpDesc op_desc; + op_desc.SetType("resnet_block0_d"); + op_desc.SetInput("Inputs", {matched.at("input")->arg()->name}); + op_desc.SetInput("Filter", + { + matched.at("left_conv1_weight")->arg()->name, + matched.at("left_conv2_weight")->arg()->name, + matched.at("left_conv3_weight")->arg()->name, + matched.at("right_conv1_weight")->arg()->name, + }); + op_desc.SetInput("Scale", + { + matched.at("left_bn1_scale")->arg()->name, + matched.at("left_bn2_scale")->arg()->name, + matched.at("left_bn3_scale")->arg()->name, + matched.at("right_bn1_scale")->arg()->name, + }); + op_desc.SetInput("Bias", + { + matched.at("left_bn1_bias")->arg()->name, + matched.at("left_bn2_bias")->arg()->name, + matched.at("left_bn3_bias")->arg()->name, + matched.at("right_bn1_bias")->arg()->name, + }); + op_desc.SetInput("Mean", + { + matched.at("left_bn1_mean")->arg()->name, + matched.at("left_bn2_mean")->arg()->name, + matched.at("left_bn3_mean")->arg()->name, + matched.at("right_bn1_mean")->arg()->name, + }); + op_desc.SetInput("Var", + { + matched.at("left_bn1_variance")->arg()->name, + matched.at("left_bn2_variance")->arg()->name, + matched.at("left_bn3_variance")->arg()->name, + matched.at("right_bn1_variance")->arg()->name, + }); + op_desc.SetOutput("Outputs", {matched.at("relu_out")->arg()->name}); + // keep these to fool SubgraphOp::AttachImpl() + op_desc.SetAttr("sub_block", 0); + op_desc.SetAttr>("input_data_names", {}); + op_desc.SetAttr>("output_data_names", {}); + + auto block0_stmt = matched.at("left_conv1")->stmt(); + // block0_stmt->ResetOp(op_desc, graph->valid_places()); + auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph"); + auto sub_program_desc = std::make_shared(); + sub_program_desc->AddBlock(); + static_cast(fake_subgraph_op.get()) + ->SetProgramDesc(sub_program_desc); + fake_subgraph_op->Attach(op_desc, block0_stmt->op()->scope()); + fake_subgraph_op->SetValidPlaces(block0_stmt->op()->valid_places()); + block0_stmt->SetOp(fake_subgraph_op); + + std::vector froms = { + "left_conv2_weight", + "left_conv3_weight", + "right_conv1_weight", + "left_bn1_bias", + "left_bn2_bias", + "left_bn3_bias", + "right_bn1_bias", + }; + for (auto& from : froms) { + IR_NODE_LINK_TO(matched.at(from), matched.at("left_conv1")); + } + IR_OP_VAR_LINK(matched.at("left_conv1"), matched.at("relu_out")); + } +}; + class XPUResNet50Fuser : public xpu::XPUFuseBase { public: - XPUResNet50Fuser() {} + XPUResNet50Fuser() {} + + void BuildPattern() override { + auto* input = + VarNode("input")->assert_is_op_input("conv2d", "Input")->AsInput(); + + auto* top_conv_weight = VarNode("top_conv_weight") + ->assert_is_op_input("conv2d", "Filter") + ->AsInput(); + auto* top_conv = OpNode("top_conv", "conv2d"); + auto* top_conv_out = VarNode("top_conv_out") + ->assert_is_op_output("conv2d", "Output") + ->assert_is_op_input("batch_norm", "X") + ->AsIntermediate(); + auto* top_bn_scale = VarNode("top_bn_scale") + ->assert_is_op_input("batch_norm", "Scale") + ->AsIntermediate(); + auto* top_bn_bias = VarNode("top_bn_bias") + ->assert_is_op_input("batch_norm", "Bias") + ->AsInput(); + auto* top_bn_mean = VarNode("top_bn_mean") + ->assert_is_op_input("batch_norm", "Mean") + ->AsIntermediate(); + auto* top_bn_var = VarNode("top_bn_variance") + ->assert_is_op_input("batch_norm", "Variance") + ->AsIntermediate(); + auto* top_bn = OpNode("top_bn", "batch_norm")->AsIntermediate(); + auto* top_bn_out = VarNode("top_bn_out") + ->assert_is_op_output("batch_norm", "Y") + ->assert_is_op_input("relu", "X") + ->AsIntermediate(); + auto* top_bn_mean_out = VarNode("top_bn_mean_out") + ->assert_is_op_output("batch_norm", "MeanOut") + ->AsIntermediate(); + auto* top_bn_var_out = + VarNode("top_bn_var_out") + ->assert_is_op_output("batch_norm", "VarianceOut") + ->AsIntermediate(); + auto* top_bn_saved_mean = + VarNode("top_bn_saved_mean") + ->assert_is_op_output("batch_norm", "SavedMean") + ->AsIntermediate(); + auto* top_bn_saved_var = + VarNode("top_bn_saved_var") + ->assert_is_op_output("batch_norm", "SavedVariance") + ->AsIntermediate(); + auto* top_relu = OpNode("top_relu", "relu")->AsIntermediate(); + auto* top_relu_out = VarNode("top_relu_out") + ->assert_is_op_output("relu", "Out") + ->assert_is_op_input("pool2d", "X") + ->AsIntermediate(); + auto* top_pool = OpNode("top_pool", "pool2d")->AsIntermediate(); + auto* top_pool_out = VarNode("top_pool_out") + ->assert_is_op_output("pool2d", "Out") + ->assert_is_op_input("resnet_block0", "Inputs") + ->AsIntermediate(); + + // args are left out + auto* resnet_block0_1 = + OpNode("resnet_block0_1", "resnet_block0")->AsIntermediate(); + auto* resnet_block0_1_out = + VarNode("resnet_block0_1_out") + ->assert_is_op_output("resnet_block0", "Outputs") + ->AsIntermediate(); + auto* resnet_block1_1_1 = + OpNode("resnet_block1_1_1", "resnet_block1")->AsIntermediate(); + auto* resnet_block1_1_1_out = + VarNode("resnet_block1_1_1_out") + ->assert_is_op_output("resnet_block1", "Outputs") + ->AsIntermediate(); + auto* resnet_block1_1_2 = + OpNode("resnet_block1_1_2", "resnet_block1")->AsIntermediate(); + auto* resnet_block1_1_2_out = + VarNode("resnet_block1_1_2_out") + ->assert_is_op_output("resnet_block1", "Outputs") + ->AsIntermediate(); + + auto* resnet_block0_2 = + OpNode("resnet_block0_2", "resnet_block0")->AsIntermediate(); + auto* resnet_block0_2_out = + VarNode("resnet_block0_2_out") + ->assert_is_op_output("resnet_block0", "Outputs") + ->AsIntermediate(); + auto* resnet_block1_2_1 = + OpNode("resnet_block1_2_1", "resnet_block1")->AsIntermediate(); + auto* resnet_block1_2_1_out = + VarNode("resnet_block1_2_1_out") + ->assert_is_op_output("resnet_block1", "Outputs") + ->AsIntermediate(); + auto* resnet_block1_2_2 = + OpNode("resnet_block1_2_2", "resnet_block1")->AsIntermediate(); + auto* resnet_block1_2_2_out = + VarNode("resnet_block1_2_2_out") + ->assert_is_op_output("resnet_block1", "Outputs") + ->AsIntermediate(); + auto* resnet_block1_2_3 = + OpNode("resnet_block1_2_3", "resnet_block1")->AsIntermediate(); + auto* resnet_block1_2_3_out = + VarNode("resnet_block1_2_3_out") + ->assert_is_op_output("resnet_block1", "Outputs") + ->AsIntermediate(); + + auto* resnet_block0_3 = + OpNode("resnet_block0_3", "resnet_block0")->AsIntermediate(); + auto* resnet_block0_3_out = + VarNode("resnet_block0_3_out") + ->assert_is_op_output("resnet_block0", "Outputs") + ->AsIntermediate(); + auto* resnet_block1_3_1 = + OpNode("resnet_block1_3_1", "resnet_block1")->AsIntermediate(); + auto* resnet_block1_3_1_out = + VarNode("resnet_block1_3_1_out") + ->assert_is_op_output("resnet_block1", "Outputs") + ->AsIntermediate(); + auto* resnet_block1_3_2 = + OpNode("resnet_block1_3_2", "resnet_block1")->AsIntermediate(); + auto* resnet_block1_3_2_out = + VarNode("resnet_block1_3_2_out") + ->assert_is_op_output("resnet_block1", "Outputs") + ->AsIntermediate(); + auto* resnet_block1_3_3 = + OpNode("resnet_block1_3_3", "resnet_block1")->AsIntermediate(); + auto* resnet_block1_3_3_out = + VarNode("resnet_block1_3_3_out") + ->assert_is_op_output("resnet_block1", "Outputs") + ->AsIntermediate(); + auto* resnet_block1_3_4 = + OpNode("resnet_block1_3_4", "resnet_block1")->AsIntermediate(); + auto* resnet_block1_3_4_out = + VarNode("resnet_block1_3_4_out") + ->assert_is_op_output("resnet_block1", "Outputs") + ->AsIntermediate(); + auto* resnet_block1_3_5 = + OpNode("resnet_block1_3_5", "resnet_block1")->AsIntermediate(); + auto* resnet_block1_3_5_out = + VarNode("resnet_block1_3_5_out") + ->assert_is_op_output("resnet_block1", "Outputs") + ->AsIntermediate(); + + auto* resnet_block0_4 = + OpNode("resnet_block0_4", "resnet_block0")->AsIntermediate(); + auto* resnet_block0_4_out = + VarNode("resnet_block0_4_out") + ->assert_is_op_output("resnet_block0", "Outputs") + ->AsIntermediate(); + auto* resnet_block1_4_1 = + OpNode("resnet_block1_4_1", "resnet_block1")->AsIntermediate(); + auto* resnet_block1_4_1_out = + VarNode("resnet_block1_4_1_out") + ->assert_is_op_output("resnet_block1", "Outputs") + ->AsIntermediate(); + auto* resnet_block1_4_2 = + OpNode("resnet_block1_4_2", "resnet_block1")->AsIntermediate(); + auto* resnet_block1_4_2_out = + VarNode("resnet_block1_4_2_out") + ->assert_is_op_output("resnet_block1", "Outputs") + ->AsIntermediate(); + + auto* bottom_pool = OpNode("bottom_pool", "pool2d")->AsIntermediate(); + auto* bottom_pool_out = VarNode("bottom_pool_out") + ->assert_is_op_output("pool2d", "Out") + ->AsOutput(); + + *input >> *top_conv >> *top_conv_out >> *top_bn >> *top_bn_out >> + *top_relu >> *top_relu_out >> *top_pool >> *top_pool_out >> + *resnet_block0_1 >> *resnet_block0_1_out >> *resnet_block1_1_1 >> + *resnet_block1_1_1_out >> *resnet_block1_1_2 >> + *resnet_block1_1_2_out >> *resnet_block0_2 >> *resnet_block0_2_out >> + *resnet_block1_2_1 >> *resnet_block1_2_1_out >> *resnet_block1_2_2 >> + *resnet_block1_2_2_out >> *resnet_block1_2_3 >> + *resnet_block1_2_3_out >> *resnet_block0_3 >> *resnet_block0_3_out >> + *resnet_block1_3_1 >> *resnet_block1_3_1_out >> *resnet_block1_3_2 >> + *resnet_block1_3_2_out >> *resnet_block1_3_3 >> + *resnet_block1_3_3_out >> *resnet_block1_3_4 >> + *resnet_block1_3_4_out >> *resnet_block1_3_5 >> + *resnet_block1_3_5_out >> *resnet_block0_4 >> *resnet_block0_4_out >> + *resnet_block1_4_1 >> *resnet_block1_4_1_out >> *resnet_block1_4_2 >> + *resnet_block1_4_2_out >> *bottom_pool >> *bottom_pool_out; + + *top_conv_weight >> *top_conv; + *top_bn_scale >> *top_bn; + *top_bn_bias >> *top_bn; + *top_bn_mean >> *top_bn; + *top_bn_var >> *top_bn; + *top_bn >> *top_bn_mean_out; + *top_bn >> *top_bn_var_out; + *top_bn >> *top_bn_saved_mean; + *top_bn >> *top_bn_saved_var; + } + + void InsertNewNode(SSAGraph* graph, + const key2nodes_t& matched, + const std::vector& extra_input_vars) override { + cpp::OpDesc op_desc; + op_desc.SetType("__xpu__resnet50"); + op_desc.SetInput("Input", {matched.at("input")->arg()->name}); + std::vector filter_name = { + matched.at("top_conv_weight")->arg()->name}; + std::vector scale_name = { + matched.at("top_bn_scale")->arg()->name}; + std::vector bias_name = { + matched.at("top_bn_bias")->arg()->name}; + std::vector mean_name = { + matched.at("top_bn_mean")->arg()->name}; + std::vector var_name = { + matched.at("top_bn_variance")->arg()->name}; + std::vector max_filter_name; + std::vector resnet_block_vec = { + "resnet_block0_1", + "resnet_block1_1_1", + "resnet_block1_1_2", + "resnet_block0_2", + "resnet_block1_2_1", + "resnet_block1_2_2", + "resnet_block1_2_3", + "resnet_block0_3", + "resnet_block1_3_1", + "resnet_block1_3_2", + "resnet_block1_3_3", + "resnet_block1_3_4", + "resnet_block1_3_5", + "resnet_block0_4", + "resnet_block1_4_1", + "resnet_block1_4_2", + }; + for (auto& block : resnet_block_vec) { + auto* block_op_info = matched.at(block)->stmt()->op_info(); + auto block_filter_name = block_op_info->Input("Filter"); + std::copy(block_filter_name.begin(), + block_filter_name.end(), + std::back_inserter(filter_name)); + auto block_scale_name = block_op_info->Input("Scale"); + std::copy(block_scale_name.begin(), + block_scale_name.end(), + std::back_inserter(scale_name)); + auto block_bias_name = block_op_info->Input("Bias"); + std::copy(block_bias_name.begin(), + block_bias_name.end(), + std::back_inserter(bias_name)); + auto block_mean_name = block_op_info->Input("Mean"); + std::copy(block_mean_name.begin(), + block_mean_name.end(), + std::back_inserter(mean_name)); + auto block_var_name = block_op_info->Input("Var"); + std::copy(block_var_name.begin(), + block_var_name.end(), + std::back_inserter(var_name)); + } + op_desc.SetInput("Filter", filter_name); + op_desc.SetInput("Bias", bias_name); + op_desc.SetOutput("Output", {matched.at("bottom_pool_out")->arg()->name}); + op_desc.SetAttr("xpu", 1); + + auto* resnet50_stmt = matched.at("top_conv")->stmt(); + auto* scope = resnet50_stmt->op()->scope(); + for (size_t i = 0; i < filter_name.size(); ++i) { + auto* filter_t = scope->FindMutableTensor(filter_name[i]); + auto* scale_t = scope->FindMutableTensor(scale_name[i]); + auto* bias_t = scope->FindMutableTensor(bias_name[i]); + auto* mean_t = scope->FindMutableTensor(mean_name[i]); + auto* var_t = scope->FindMutableTensor(var_name[i]); + + int mean_len = mean_t->numel(); + int filter_len = filter_t->numel(); + int filter_stride = filter_len / mean_len; + + float* filter_on_host = filter_t->mutable_data(); + float* scale_on_host = scale_t->mutable_data(); + float* bias_on_host = bias_t->mutable_data(); + float* mean_on_host = mean_t->mutable_data(); + float* var_on_host = var_t->mutable_data(); + + // Perform preprocess + for (int i = 0; i < mean_len; ++i) { + scale_on_host[i] = scale_on_host[i] / sqrtf(var_on_host[i] + 0.00001f); + } + for (int i = 0; i < mean_len; ++i) { + for (int j = 0; j < filter_stride; ++j) { + filter_on_host[i * filter_stride + j] *= scale_on_host[i]; + } + } + for (int i = 0; i < mean_len; ++i) { + bias_on_host[i] += -mean_on_host[i] * scale_on_host[i]; + } + + float max_f = + paddle::lite::xpu::math::FindMaxAbs(filter_on_host, filter_len); + std::unique_ptr filter_int16(new int16_t[filter_len]); + paddle::lite::xpu::math::ConvertFP32ToInt16( + filter_on_host, filter_int16.get(), max_f, filter_len); + memcpy(filter_on_host, filter_int16.get(), filter_len * sizeof(int16_t)); + + // create new arg in graph and scope + std::string max_name = filter_name[i] + "_max"; + max_filter_name.push_back(max_name); + auto* max_filter_node = graph->NewArgumentNode(max_name); + max_filter_node->arg()->is_weight = true; + max_filter_node->arg()->type = LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)); + DirectedLink(max_filter_node, matched.at("top_conv")); + auto* max_filter_t = scope->NewTensor(max_name); + max_filter_t->Resize({4}); + float* max_ptr = max_filter_t->mutable_data(); + max_ptr[0] = max_f; + max_ptr[1] = max_f; + max_ptr[2] = max_f; + max_ptr[3] = max_f; + } + op_desc.SetInput("MaxFilter", max_filter_name); + + auto resnet50_op = LiteOpRegistry::Global().Create(op_desc.Type()); + resnet50_op->Attach(op_desc, scope); + resnet50_op->SetValidPlaces(resnet50_stmt->op()->valid_places()); + auto kernels = resnet50_op->CreateKernels(resnet50_op->valid_places()); + resnet50_stmt->SetOp(resnet50_op); + resnet50_stmt->SetKernels(std::move(kernels)); + + IR_NODE_LINK_TO(matched.at("top_bn_bias"), matched.at("top_conv")); + for (auto* node : extra_input_vars) { + IR_NODE_LINK_TO(node, matched.at("top_conv")); + } + IR_OP_VAR_LINK(matched.at("top_conv"), matched.at("bottom_pool_out")); + } +}; + +class XPUResNet50DtypeFuser : public xpu::XPUFuseBase { + public: + XPUResNet50DtypeFuser() {} void BuildPattern() override { auto* input = @@ -650,8 +1299,102 @@ class XPUResNet50Fuser : public xpu::XPUFuseBase { auto* top_relu = OpNode("top_relu", "relu")->AsIntermediate(); auto* top_relu_out = VarNode("top_relu_out") ->assert_is_op_output("relu", "Out") - ->assert_is_op_input("pool2d", "X") + ->assert_is_op_input("conv2d", "Input") + ->AsIntermediate(); + + auto* second_conv_weight = VarNode("second_conv_weight") + ->assert_is_op_input("conv2d", "Filter") + ->AsInput(); + auto* second_conv = OpNode("second_conv", "conv2d")->AsIntermediate(); + auto* second_conv_out = VarNode("second_conv_out") + ->assert_is_op_output("conv2d", "Output") + ->assert_is_op_input("batch_norm", "X") + ->AsIntermediate(); + auto* second_bn_scale = VarNode("second_bn_scale") + ->assert_is_op_input("batch_norm", "Scale") + ->AsIntermediate(); + auto* second_bn_bias = VarNode("second_bn_bias") + ->assert_is_op_input("batch_norm", "Bias") + ->AsInput(); + auto* second_bn_mean = VarNode("second_bn_mean") + ->assert_is_op_input("batch_norm", "Mean") + ->AsIntermediate(); + auto* second_bn_var = VarNode("second_bn_variance") + ->assert_is_op_input("batch_norm", "Variance") + ->AsIntermediate(); + auto* second_bn = OpNode("second_bn", "batch_norm")->AsIntermediate(); + auto* second_bn_out = VarNode("second_bn_out") + ->assert_is_op_output("batch_norm", "Y") + ->assert_is_op_input("relu", "X") + ->AsIntermediate(); + auto* second_bn_mean_out = + VarNode("second_bn_mean_out") + ->assert_is_op_output("batch_norm", "MeanOut") + ->AsIntermediate(); + auto* second_bn_var_out = + VarNode("second_bn_var_out") + ->assert_is_op_output("batch_norm", "VarianceOut") + ->AsIntermediate(); + auto* second_bn_saved_mean = + VarNode("second_bn_saved_mean") + ->assert_is_op_output("batch_norm", "SavedMean") + ->AsIntermediate(); + auto* second_bn_saved_var = + VarNode("second_bn_saved_var") + ->assert_is_op_output("batch_norm", "SavedVariance") + ->AsIntermediate(); + auto* second_relu = OpNode("second_relu", "relu")->AsIntermediate(); + auto* second_relu_out = VarNode("second_relu_out") + ->assert_is_op_output("relu", "Out") + ->assert_is_op_input("conv2d", "Input") + ->AsIntermediate(); + + auto* third_conv_weight = VarNode("third_conv_weight") + ->assert_is_op_input("conv2d", "Filter") + ->AsInput(); + auto* third_conv = OpNode("third_conv", "conv2d")->AsIntermediate(); + auto* third_conv_out = VarNode("third_conv_out") + ->assert_is_op_output("conv2d", "Output") + ->assert_is_op_input("batch_norm", "X") + ->AsIntermediate(); + auto* third_bn_scale = VarNode("third_bn_scale") + ->assert_is_op_input("batch_norm", "Scale") + ->AsIntermediate(); + auto* third_bn_bias = VarNode("third_bn_bias") + ->assert_is_op_input("batch_norm", "Bias") + ->AsInput(); + auto* third_bn_mean = VarNode("third_bn_mean") + ->assert_is_op_input("batch_norm", "Mean") + ->AsIntermediate(); + auto* third_bn_var = VarNode("third_bn_variance") + ->assert_is_op_input("batch_norm", "Variance") + ->AsIntermediate(); + auto* third_bn = OpNode("third_bn", "batch_norm")->AsIntermediate(); + auto* third_bn_out = VarNode("third_bn_out") + ->assert_is_op_output("batch_norm", "Y") + ->assert_is_op_input("relu", "X") ->AsIntermediate(); + auto* third_bn_mean_out = VarNode("third_bn_mean_out") + ->assert_is_op_output("batch_norm", "MeanOut") + ->AsIntermediate(); + auto* third_bn_var_out = + VarNode("third_bn_var_out") + ->assert_is_op_output("batch_norm", "VarianceOut") + ->AsIntermediate(); + auto* third_bn_saved_mean = + VarNode("third_bn_saved_mean") + ->assert_is_op_output("batch_norm", "SavedMean") + ->AsIntermediate(); + auto* third_bn_saved_var = + VarNode("third_bn_saved_var") + ->assert_is_op_output("batch_norm", "SavedVariance") + ->AsIntermediate(); + auto* third_relu = OpNode("third_relu", "relu")->AsIntermediate(); + auto* third_relu_out = VarNode("third_relu_out") + ->assert_is_op_output("relu", "Out") + ->assert_is_op_input("pool2d", "X") + ->AsIntermediate(); + auto* top_pool = OpNode("top_pool", "pool2d")->AsIntermediate(); auto* top_pool_out = VarNode("top_pool_out") ->assert_is_op_output("pool2d", "Out") @@ -679,10 +1422,10 @@ class XPUResNet50Fuser : public xpu::XPUFuseBase { ->AsIntermediate(); auto* resnet_block0_2 = - OpNode("resnet_block0_2", "resnet_block0")->AsIntermediate(); + OpNode("resnet_block0_2", "resnet_block0_d")->AsIntermediate(); auto* resnet_block0_2_out = VarNode("resnet_block0_2_out") - ->assert_is_op_output("resnet_block0", "Outputs") + ->assert_is_op_output("resnet_block0_d", "Outputs") ->AsIntermediate(); auto* resnet_block1_2_1 = OpNode("resnet_block1_2_1", "resnet_block1")->AsIntermediate(); @@ -704,10 +1447,10 @@ class XPUResNet50Fuser : public xpu::XPUFuseBase { ->AsIntermediate(); auto* resnet_block0_3 = - OpNode("resnet_block0_3", "resnet_block0")->AsIntermediate(); + OpNode("resnet_block0_3", "resnet_block0_d")->AsIntermediate(); auto* resnet_block0_3_out = VarNode("resnet_block0_3_out") - ->assert_is_op_output("resnet_block0", "Outputs") + ->assert_is_op_output("resnet_block0_d", "Outputs") ->AsIntermediate(); auto* resnet_block1_3_1 = OpNode("resnet_block1_3_1", "resnet_block1")->AsIntermediate(); @@ -741,10 +1484,10 @@ class XPUResNet50Fuser : public xpu::XPUFuseBase { ->AsIntermediate(); auto* resnet_block0_4 = - OpNode("resnet_block0_4", "resnet_block0")->AsIntermediate(); + OpNode("resnet_block0_4", "resnet_block0_d")->AsIntermediate(); auto* resnet_block0_4_out = VarNode("resnet_block0_4_out") - ->assert_is_op_output("resnet_block0", "Outputs") + ->assert_is_op_output("resnet_block0_d", "Outputs") ->AsIntermediate(); auto* resnet_block1_4_1 = OpNode("resnet_block1_4_1", "resnet_block1")->AsIntermediate(); @@ -765,7 +1508,10 @@ class XPUResNet50Fuser : public xpu::XPUFuseBase { ->AsOutput(); *input >> *top_conv >> *top_conv_out >> *top_bn >> *top_bn_out >> - *top_relu >> *top_relu_out >> *top_pool >> *top_pool_out >> + *top_relu >> *top_relu_out >> *second_conv >> *second_conv_out >> + *second_bn >> *second_bn_out >> *second_relu >> *second_relu_out >> + *third_conv >> *third_conv_out >> *third_bn >> *third_bn_out >> + *third_relu >> *third_relu_out >> *top_pool >> *top_pool_out >> *resnet_block0_1 >> *resnet_block0_1_out >> *resnet_block1_1_1 >> *resnet_block1_1_1_out >> *resnet_block1_1_2 >> *resnet_block1_1_2_out >> *resnet_block0_2 >> *resnet_block0_2_out >> @@ -789,24 +1535,59 @@ class XPUResNet50Fuser : public xpu::XPUFuseBase { *top_bn >> *top_bn_var_out; *top_bn >> *top_bn_saved_mean; *top_bn >> *top_bn_saved_var; + + *second_conv_weight >> *second_conv; + *second_bn_scale >> *second_bn; + *second_bn_bias >> *second_bn; + *second_bn_mean >> *second_bn; + *second_bn_var >> *second_bn; + *second_bn >> *second_bn_mean_out; + *second_bn >> *second_bn_var_out; + *second_bn >> *second_bn_saved_mean; + *second_bn >> *second_bn_saved_var; + + *third_conv_weight >> *third_conv; + *third_bn_scale >> *third_bn; + *third_bn_bias >> *third_bn; + *third_bn_mean >> *third_bn; + *third_bn_var >> *third_bn; + *third_bn >> *third_bn_mean_out; + *third_bn >> *third_bn_var_out; + *third_bn >> *third_bn_saved_mean; + *third_bn >> *third_bn_saved_var; } void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched, const std::vector& extra_input_vars) override { cpp::OpDesc op_desc; - op_desc.SetType("__xpu__resnet50"); + op_desc.SetType("__xpu__resnet50_d"); op_desc.SetInput("Input", {matched.at("input")->arg()->name}); std::vector filter_name = { - matched.at("top_conv_weight")->arg()->name}; + matched.at("top_conv_weight")->arg()->name, + matched.at("second_conv_weight")->arg()->name, + matched.at("third_conv_weight")->arg()->name}; + std::vector scale_name = { - matched.at("top_bn_scale")->arg()->name}; + matched.at("top_bn_scale")->arg()->name, + matched.at("second_bn_scale")->arg()->name, + matched.at("third_bn_scale")->arg()->name}; + std::vector bias_name = { - matched.at("top_bn_bias")->arg()->name}; + matched.at("top_bn_bias")->arg()->name, + matched.at("second_bn_bias")->arg()->name, + matched.at("third_bn_bias")->arg()->name}; + std::vector mean_name = { - matched.at("top_bn_mean")->arg()->name}; + matched.at("top_bn_mean")->arg()->name, + matched.at("second_bn_mean")->arg()->name, + matched.at("third_bn_mean")->arg()->name}; + std::vector var_name = { - matched.at("top_bn_variance")->arg()->name}; + matched.at("top_bn_variance")->arg()->name, + matched.at("second_bn_variance")->arg()->name, + matched.at("third_bn_variance")->arg()->name}; + std::vector max_filter_name; std::vector resnet_block_vec = { "resnet_block0_1", @@ -900,7 +1681,9 @@ class XPUResNet50Fuser : public xpu::XPUFuseBase { max_filter_node->arg()->is_weight = true; max_filter_node->arg()->type = LiteType::GetTensorTy( TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)); + DirectedLink(max_filter_node, matched.at("top_conv")); + auto* max_filter_t = scope->NewTensor(max_name); max_filter_t->Resize({4}); float* max_ptr = max_filter_t->mutable_data(); @@ -919,6 +1702,11 @@ class XPUResNet50Fuser : public xpu::XPUFuseBase { resnet50_stmt->SetKernels(std::move(kernels)); IR_NODE_LINK_TO(matched.at("top_bn_bias"), matched.at("top_conv")); + IR_NODE_LINK_TO(matched.at("second_conv_weight"), matched.at("top_conv")); + IR_NODE_LINK_TO(matched.at("second_bn_bias"), matched.at("top_conv")); + IR_NODE_LINK_TO(matched.at("third_conv_weight"), matched.at("top_conv")); + IR_NODE_LINK_TO(matched.at("third_bn_bias"), matched.at("top_conv")); + for (auto* node : extra_input_vars) { IR_NODE_LINK_TO(node, matched.at("top_conv")); } @@ -951,6 +1739,31 @@ class XPUResNet50FusePass : public ProgramPass { } }; +class XPUResNet50DtypeFusePass : public ProgramPass { + public: + void Apply(const std::unique_ptr& graph) override { + if (GetBoolFromEnv("XPU_ENABLE_XTCL")) return; + + bool changed = false; + SSAGraph backup; + backup.CloneFrom(*graph); + + fusion::XPUResNetBlock0Fuser block0_fuser; + changed |= block0_fuser(graph.get()); + fusion::XPUResNetDtypeBlock0Fuser d_type_block0_fuser; + changed |= d_type_block0_fuser(graph.get()); + fusion::XPUResNetBlock1Fuser block1_fuser; + changed |= block1_fuser(graph.get()); + fusion::XPUResNet50DtypeFuser resnet50_d_fuser; + size_t n_matches = resnet50_d_fuser(graph.get()); + + if (changed && !n_matches) { + // Restore graph from backuped one if no whole ResNet50 graph was found + graph->CloneFrom(backup); + } + } +}; + } // namespace mir } // namespace lite } // namespace paddle @@ -959,3 +1772,8 @@ REGISTER_MIR_PASS(__xpu__resnet_fuse_pass, paddle::lite::mir::XPUResNet50FusePass) .BindTargets({TARGET(kXPU)}) .BindKernel("__xpu__resnet50"); + +REGISTER_MIR_PASS(__xpu__resnet_d_fuse_pass, + paddle::lite::mir::XPUResNet50DtypeFusePass) + .BindTargets({TARGET(kXPU)}) + .BindKernel("__xpu__resnet50_d"); diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h index 7b12b32b6917280164b44e91d4d882af4fad73a4..7709090c038cf81bee5a735b682ea0721ee30ec1 100644 --- a/lite/core/optimizer.h +++ b/lite/core/optimizer.h @@ -108,6 +108,7 @@ class Optimizer { #endif "identity_dropout_eliminate_pass", "__xpu__resnet_fuse_pass", + "__xpu__resnet_d_fuse_pass", "__xpu__resnet_cbam_fuse_pass", "__xpu__conv2d_fuse_pass", "__xpu__conv2d_link_previous_out_max_pass", diff --git a/lite/kernels/xpu/__xpu__resnet50_compute.cc b/lite/kernels/xpu/__xpu__resnet50_compute.cc index 2e63e03fc9c1d52be42a8ff9b1d6260b3396a2fe..baa97f86606918d0e44765680f5a01abe5113615 100644 --- a/lite/kernels/xpu/__xpu__resnet50_compute.cc +++ b/lite/kernels/xpu/__xpu__resnet50_compute.cc @@ -34,6 +34,21 @@ void XPUResNet50Compute::PrepareForRun() { } } +void XPUResNet50DtypeCompute::PrepareForRun() { + auto& param = this->Param(); + + for (auto* filter : param.filter) { + arg_filter_.push_back( + reinterpret_cast(filter->data())); + } + for (auto* bias : param.bias) { + arg_bias_.push_back(bias->data()); + } + for (auto* max_filter : param.max_filter) { + arg_max_filter_.push_back(max_filter->data()); + } +} + void XPUResNet50Compute::Run() { auto& param = this->Param(); auto& ctx = this->ctx_->As(); @@ -50,6 +65,22 @@ void XPUResNet50Compute::Run() { CHECK_EQ(r, 0); } +void XPUResNet50DtypeCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->As(); + + int batch_size = param.input->dims()[0]; + int r = xdnn::conv2d_int16_resnet_d( + ctx.GetRawContext(), /* context */ + batch_size, /* num */ + param.input->data(), /* bottom */ + &arg_filter_[0], /* weight_list */ + param.output->mutable_data(TARGET(kXPU)), /* top */ + &arg_bias_[0], /* bias_list */ + &arg_max_filter_[0] /* max_filter_list */); + CHECK_EQ(r, 0); +} + } // namespace xpu } // namespace kernels } // namespace lite @@ -67,3 +98,16 @@ REGISTER_LITE_KERNEL(__xpu__resnet50, .BindInput("MaxFilter", {LiteType::GetTensorTy(TARGET(kXPU))}) .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kXPU))}) .Finalize(); + +REGISTER_LITE_KERNEL(__xpu__resnet50_d, + kXPU, + kFloat, + kNCHW, + paddle::lite::kernels::xpu::XPUResNet50DtypeCompute, + def) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("MaxFilter", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kXPU))}) + .Finalize(); diff --git a/lite/kernels/xpu/__xpu__resnet50_compute.h b/lite/kernels/xpu/__xpu__resnet50_compute.h index 7ce8b1192ea9e85d83ddbeddc374378692866aa6..d12616ea89de9c15c671e8fc39e55ad334945d66 100644 --- a/lite/kernels/xpu/__xpu__resnet50_compute.h +++ b/lite/kernels/xpu/__xpu__resnet50_compute.h @@ -38,6 +38,21 @@ class XPUResNet50Compute : public KernelLite { std::vector arg_bias_; }; +class XPUResNet50DtypeCompute + : public KernelLite { + public: + using param_t = operators::XPUResNet50Param; + + virtual void PrepareForRun(); + + virtual void Run(); + + private: + std::vector arg_filter_; + std::vector arg_max_filter_; + std::vector arg_bias_; +}; + } // namespace xpu } // namespace kernels } // namespace lite diff --git a/lite/operators/__xpu__resnet50_op.cc b/lite/operators/__xpu__resnet50_op.cc index 02ea6dc1799baaab486b839a4d3137020a9f7a5c..cea8ba667d2d3a3a3794b6cce604c6c7a86de8e8 100644 --- a/lite/operators/__xpu__resnet50_op.cc +++ b/lite/operators/__xpu__resnet50_op.cc @@ -62,3 +62,4 @@ bool XPUResNet50Op::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) { } // namespace paddle REGISTER_LITE_OP(__xpu__resnet50, paddle::lite::operators::XPUResNet50Op); +REGISTER_LITE_OP(__xpu__resnet50_d, paddle::lite::operators::XPUResNet50Op);