Note: the text this patch was recovered from had lost its line breaks and its
angle-bracketed template argument lists. The argument lists shown below
(<value_type>, <PrecisionType Ptype_out>, <Ptype_out>, <param_t>,
<ARMContext>) are a reconstruction -- confirm them against the target files
before applying. The index hashes are those of the original patch.

diff --git a/paddle/fluid/lite/core/lite_tensor.h b/paddle/fluid/lite/core/lite_tensor.h
index 9860265bbb342e91cfd8031eef6eb1062c98920f..ecc8b0629c070588a8c1bf0c3f30dd34c3a957a7 100644
--- a/paddle/fluid/lite/core/lite_tensor.h
+++ b/paddle/fluid/lite/core/lite_tensor.h
@@ -47,6 +47,24 @@ class DDimLite : public DDimBase {
                            std::multiplies<value_type>());
   }
   const std::vector<value_type> &data() const { return data_; }
+  // Product of the extents on axes [start, end). Bounds outside
+  // [0, data_.size()] are clamped; an empty or inverted range gives 1.
+  value_type count(int start, int end) const {
+    // Compare as signed ints so a negative `end` is not promoted to a huge
+    // unsigned value by the comparison against the container size.
+    const int rank = static_cast<int>(data_.size());
+    if (start < 0) {
+      start = 0;
+    }
+    if (end > rank) {
+      end = rank;
+    }
+    value_type product = 1;
+    for (int i = start; i < end; ++i) {
+      product *= data_[i];
+    }
+    return product;
+  }
 
  private:
   std::vector<value_type> data_;
diff --git a/paddle/fluid/lite/kernels/arm/conv_compute.cc b/paddle/fluid/lite/kernels/arm/conv_compute.cc
index 44223ee37184146d56732484f0cedfdb1bf44619..811a4341f16bf72a3ec76bb88e2259ee9ca546bc 100644
--- a/paddle/fluid/lite/kernels/arm/conv_compute.cc
+++ b/paddle/fluid/lite/kernels/arm/conv_compute.cc
@@ -95,8 +95,42 @@ void ConvCompute::Run() {
 template <PrecisionType Ptype_out>
 void ConvComputeInt8<Ptype_out>::PrepareForRun() {
   auto& param = this->Param<param_t>();
+  // Choose the int8 convolution implementation that matches this op's
+  // shape, then let it set itself up through create().
+  auto x_dims = param.x->dims();
+  auto w_dims = param.filter->dims();
+  auto o_dims = param.output->dims();
+
   auto& ctx = this->ctx_->template As<ARMContext>();
-  impl_ = new lite::arm::math::GemmLikeConvInt8<Ptype_out>;
+
+  const int ic = x_dims[1];  // input is nchw
+  const int oc = o_dims[1];
+  const int kh = w_dims[2];  // filter is oihw
+  const int kw = w_dims[3];
+  const int ph = param.paddings[1];
+  const int pw = param.paddings[0];
+  const int sh = param.strides[1];
+  const int sw = param.strides[0];
+
+  const bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh);
+  const bool no_dilation =
+      (param.dilations[0] == 1) && (param.dilations[1] == 1);
+  const bool flag_dw_3x3 = (kw == 3) && (ph == 1) && (sw == 1 || sw == 2);
+  const bool flag_dw_5x5 = (kw == 5 && sw == 1 && ph == 2);
+  const bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
+
+  // weight is int8 and bias is int32 so do not need trans
+  if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) {
+    impl_ = new lite::arm::math::DepthwiseConvInt8<Ptype_out>;
+    VLOG(3) << "DepthwiseConv Int8";
+  } else {
+    // No dedicated int8 direct-convolution kernel is hooked up here, so
+    // groups == 1 / 3x3 / stride 1-2 shapes take the GEMM-like path as well.
+    // impl_ must be assigned on every path: it is dereferenced just below.
+    VLOG(3) << "GemmLikeConvInt8";
+    impl_ = new lite::arm::math::GemmLikeConvInt8<Ptype_out>;
+  }
+
   CHECK(this->impl_->create(param, &ctx));
 }
 