Unverified commit 8d2351c3, authored by weihaoji, committed by GitHub

[XPU] add resnet50-D fusion (#4276)

Parent 7fb2261d
......@@ -62,6 +62,7 @@ USE_MIR_PASS(quantized_op_attributes_inference_pass);
USE_MIR_PASS(control_flow_op_unused_inputs_and_outputs_eliminate_pass);
USE_MIR_PASS(lite_scale_activation_fuse_pass);
USE_MIR_PASS(__xpu__resnet_fuse_pass);
USE_MIR_PASS(__xpu__resnet_d_fuse_pass);
USE_MIR_PASS(__xpu__resnet_cbam_fuse_pass);
USE_MIR_PASS(__xpu__multi_encoder_fuse_pass);
USE_MIR_PASS(__xpu__embedding_with_eltwise_add_fuse_pass);
......
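For context: USE_MIR_PASS only pulls the pass's registration object into the final binary; the pass itself is declared and registered elsewhere with REGISTER_MIR_PASS. A minimal sketch of what the new pass's implementation file plausibly looks like; the class name, the matching body, and the BindTargets argument are assumptions, since that file is not part of the hunks shown here:

// Hypothetical sketch, modeled on the existing __xpu__resnet_fuse_pass;
// not taken from this diff.
#include <memory>
#include "lite/core/mir/pass_registry.h"

namespace paddle {
namespace lite {
namespace mir {

class XPUResNetDFusePass : public ProgramPass {
 public:
  void Apply(const std::unique_ptr<SSAGraph>& graph) override {
    // Match the ResNet50-D block pattern in the graph and collapse it into
    // a single __xpu__resnet50_d op (matching logic omitted in this sketch).
  }
};

}  // namespace mir
}  // namespace lite
}  // namespace paddle

REGISTER_MIR_PASS(__xpu__resnet_d_fuse_pass,
                  paddle::lite::mir::XPUResNetDFusePass)
    .BindTargets({TARGET(kXPU)});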
......@@ -108,6 +108,7 @@ class Optimizer {
#endif
        "identity_dropout_eliminate_pass",
        "__xpu__resnet_fuse_pass",
        "__xpu__resnet_d_fuse_pass",
        "__xpu__resnet_cbam_fuse_pass",
        "__xpu__conv2d_fuse_pass",
        "__xpu__conv2d_link_previous_out_max_pass",
......
......@@ -34,6 +34,21 @@ void XPUResNet50Compute::PrepareForRun() {
  }
}

void XPUResNet50DtypeCompute::PrepareForRun() {
  auto& param = this->Param<param_t>();
  // The filter tensors hold int16 weight bits in float storage (presumably
  // quantized upstream by the fuse pass), so the buffer is reinterpreted
  // rather than numerically converted.
  for (auto* filter : param.filter) {
    arg_filter_.push_back(
        reinterpret_cast<const int16_t*>(filter->data<float>()));
  }
  for (auto* bias : param.bias) {
    arg_bias_.push_back(bias->data<float>());
  }
  for (auto* max_filter : param.max_filter) {
    arg_max_filter_.push_back(max_filter->data<float>());
  }
}

void XPUResNet50Compute::Run() {
  auto& param = this->Param<param_t>();
  auto& ctx = this->ctx_->As<XPUContext>();
......@@ -50,6 +65,22 @@ void XPUResNet50Compute::Run() {
  CHECK_EQ(r, 0);
}

void XPUResNet50DtypeCompute::Run() {
  auto& param = this->Param<param_t>();
  auto& ctx = this->ctx_->As<XPUContext>();

  int batch_size = param.input->dims()[0];
  int r = xdnn::conv2d_int16_resnet_d<float, int16_t>(
      ctx.GetRawContext(),                              /* context */
      batch_size,                                       /* num */
      param.input->data<float>(),                       /* bottom */
      &arg_filter_[0],                                  /* weight_list */
      param.output->mutable_data<float>(TARGET(kXPU)),  /* top */
      &arg_bias_[0],                                    /* bias_list */
      &arg_max_filter_[0]                               /* max_filter_list */);
  CHECK_EQ(r, 0);
}

}  // namespace xpu
}  // namespace kernels
}  // namespace lite
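The kernel reads everything through operators::XPUResNet50Param, whose definition is not in this diff. From the accesses above (param.input, the filter/bias/max_filter vectors, param.output), it presumably looks roughly like the following sketch; the base class and field defaults are assumptions:

// Inferred sketch, not the actual definition from the operators directory.
struct XPUResNet50Param : ParamBase {
  lite::Tensor* input{nullptr};           // fp32 NCHW activation
  std::vector<lite::Tensor*> filter;      // per-conv weights, int16 bits in float storage
  std::vector<lite::Tensor*> bias;        // per-conv fp32 biases
  std::vector<lite::Tensor*> max_filter;  // per-filter maxima used for dequantization
  lite::Tensor* output{nullptr};          // fp32 NCHW result
};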
......@@ -67,3 +98,16 @@ REGISTER_LITE_KERNEL(__xpu__resnet50,
.BindInput("MaxFilter", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
REGISTER_LITE_KERNEL(__xpu__resnet50_d,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::XPUResNet50DtypeCompute,
def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("MaxFilter", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
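The registration keys the kernel by op type plus (kXPU, kFloat, kNCHW), and the Bind* names fix the argument slots. A hedged sketch of the emission step the fuse pass would perform follows; all variable names are invented placeholders:

// Hypothetical emission step inside the fuse pass; the string arguments
// must match the Bind* names in the kernel registration above.
std::string input_name, output_name;
std::vector<std::string> filter_names, bias_names, max_filter_names;

cpp::OpDesc op_desc;
op_desc.SetType("__xpu__resnet50_d");
op_desc.SetInput("Input", {input_name});
op_desc.SetInput("Filter", filter_names);  // all conv weights, in layer order
op_desc.SetInput("Bias", bias_names);
op_desc.SetInput("MaxFilter", max_filter_names);
op_desc.SetOutput("Output", {output_name});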
......@@ -38,6 +38,21 @@ class XPUResNet50Compute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
std::vector<const float *> arg_bias_;
};
class XPUResNet50DtypeCompute
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::XPUResNet50Param;
virtual void PrepareForRun();
virtual void Run();
private:
std::vector<const int16_t *> arg_filter_;
std::vector<const float *> arg_max_filter_;
std::vector<const float *> arg_bias_;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
......
......@@ -62,3 +62,4 @@ bool XPUResNet50Op::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
}  // namespace paddle

REGISTER_LITE_OP(__xpu__resnet50, paddle::lite::operators::XPUResNet50Op);
REGISTER_LITE_OP(__xpu__resnet50_d, paddle::lite::operators::XPUResNet50Op);
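Both op types map onto the same XPUResNet50Op, so AttachImpl, whose body is collapsed out of this diff, serves them identically. A minimal sketch of what it presumably does, assuming it only resolves the bound argument names into the param struct:

// Sketch only; the real AttachImpl body is not shown in this diff.
bool XPUResNet50Op::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
  param_.input = scope->FindMutableTensor(op_desc.Input("Input").front());
  param_.output = scope->FindMutableTensor(op_desc.Output("Output").front());
  for (const auto& name : op_desc.Input("Filter")) {
    param_.filter.push_back(scope->FindMutableTensor(name));
  }
  for (const auto& name : op_desc.Input("Bias")) {
    param_.bias.push_back(scope->FindMutableTensor(name));
  }
  for (const auto& name : op_desc.Input("MaxFilter")) {
    param_.max_filter.push_back(scope->FindMutableTensor(name));
  }
  return true;
}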