提交 48f9ae36 编写于 作者: T tensor-tang

Merge branch 'tangjian/lite/int8' into 'incubate/lite'

Tangjian/lite/int8

See merge request inference/paddlelite!61
...@@ -47,6 +47,22 @@ class DDimLite : public DDimBase<DDimLite> { ...@@ -47,6 +47,22 @@ class DDimLite : public DDimBase<DDimLite> {
std::multiplies<value_type>()); std::multiplies<value_type>());
} }
const std::vector<value_type> &data() const { return data_; } const std::vector<value_type> &data() const { return data_; }
// Returns the product of the dimension extents over the half-open axis
// range [start, end).
//
// Out-of-range arguments are clamped rather than rejected: `start` is raised
// to 0, `end` is lowered to the number of stored dimensions, and an inverted
// range collapses to an empty one. An empty range yields 1 (the
// multiplicative identity).
//
// The rank is converted to a signed int before comparing so that a negative
// `end` is treated as an empty range instead of being promoted to a huge
// unsigned value and silently clamped to the full rank.
value_type count(int start, int end) const {
  const int rank = static_cast<int>(data_.size());
  if (start < 0) {
    start = 0;
  }
  if (end > rank) {
    end = rank;
  }
  if (end < start) {
    end = start;
  }
  value_type product = 1;
  for (int i = start; i < end; ++i) {
    product *= data_[i];
  }
  return product;
}
private: private:
std::vector<value_type> data_; std::vector<value_type> data_;
......
...@@ -95,8 +95,44 @@ void ConvCompute::Run() { ...@@ -95,8 +95,44 @@ void ConvCompute::Run() {
template <PrecisionType Ptype_out> template <PrecisionType Ptype_out>
void ConvComputeInt8<Ptype_out>::PrepareForRun() { void ConvComputeInt8<Ptype_out>::PrepareForRun() {
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
auto& ctx = this->ctx_->template As<ARMContext>(); auto& ctx = this->ctx_->template As<ARMContext>();
int win = x_dims[3]; // nchw
int hin = x_dims[2];
int ic = x_dims[1];
int bs = x_dims[0];
int ow = o_dims[3];
int oh = o_dims[2];
int oc = o_dims[1];
int kh = w_dims[2]; // oihw
int kw = w_dims[3];
int ph = param.paddings[1];
int pw = param.paddings[0];
int sh = param.strides[1];
int sw = param.strides[0];
bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh);
bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1);
bool flag_dw_3x3 = (kw == 3) && (ph == 1) && (sw == 1 || sw == 2);
bool flag_dw_5x5 = (kw == 5 && sw == 1 && ph == 2);
bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
// weigth is int8 and bias is int32 so do not need trans
if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) {
impl_ = new lite::arm::math::DepthwiseConvInt8<Ptype_out>;
VLOG(3) << "DepthwiseConv Int8";
} else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) &&
kps_equal && no_dilation) {
// impl_ = new lite::arm::math::DirectConv<Ptype_out>;
} else {
VLOG(3) << "GemmLikeConvInt8";
impl_ = new lite::arm::math::GemmLikeConvInt8<Ptype_out>; impl_ = new lite::arm::math::GemmLikeConvInt8<Ptype_out>;
}
CHECK(this->impl_->create(param, &ctx)); CHECK(this->impl_->create(param, &ctx));
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册