提交 5337daea 编写于 作者: H hjchen2

Optimize 5x5 depthwise conv for speedup 6x

上级 8bae119c
......@@ -64,9 +64,10 @@ void OperatorBase<Dtype>::Run() {
for (const auto key : input_keys) {
auto var_vec_in = inputs_.at(key);
for (int i = 0; i < var_vec_in.size(); ++i) {
auto vari = scope_->FindVar(var_vec_in[i]);
DLOG << var_vec_in[i];
auto vari = this->scope_->FindVar("input");
if (vari->IsInitialized()) {
Tensor *tensor = vari->template GetMutable<framework::LoDTensor>();
const Tensor *tensor = vari->template Get<framework::LoDTensor>();
if (tensor) DLOG << type_ << " input- " << key << "=" << *tensor;
}
}
......@@ -76,7 +77,7 @@ void OperatorBase<Dtype>::Run() {
for (int i = 0; i < var_vec_out.size(); ++i) {
auto vari = scope_->FindVar(var_vec_out[i]);
if (vari->IsInitialized()) {
Tensor *tensor = vari->template GetMutable<framework::LoDTensor>();
const Tensor *tensor = vari->template Get<framework::LoDTensor>();
if (tensor) DLOG << type_ << " output- " << key << "=" << *tensor;
}
}
......@@ -97,10 +98,10 @@ void OperatorBase<GPU_CL>::Run() {
auto vari = scope_->FindVar(var_vec_in[i]);
if (vari->IsInitialized()) {
if (type_ == "feed") {
Tensor *tensor = vari->template GetMutable<framework::LoDTensor>();
const Tensor *tensor = vari->template Get<framework::LoDTensor>();
if (tensor) DLOG << type_ << " input- " << key << "=" << *tensor;
} else {
CLImage *cl_image = vari->template GetMutable<framework::CLImage>();
const CLImage *cl_image = vari->template Get<framework::CLImage>();
if (cl_image) {
DLOG << type_ << " input- " << key << "=" << *cl_image;
}
......@@ -114,12 +115,12 @@ void OperatorBase<GPU_CL>::Run() {
auto vari = scope_->FindVar(var_vec_out[i]);
if (vari->IsInitialized()) {
if (type_ == "fetch") {
Tensor *tensor = vari->template GetMutable<framework::LoDTensor>();
const Tensor *tensor = vari->template Get<framework::LoDTensor>();
if (tensor) {
DLOG << type_ << " output- " << key << "=" << *tensor;
}
} else {
CLImage *cl_image = vari->template GetMutable<framework::CLImage>();
const CLImage *cl_image = vari->template Get<framework::CLImage>();
if (cl_image) {
DLOG << type_ << " output- " << key << "=" << *cl_image;
}
......
......@@ -14,6 +14,7 @@
#include "io/api_paddle_mobile.h"
#include <vector>
#include "common/enforce.h"
#include "framework/tensor.h"
namespace paddle_mobile {
......
......@@ -12,19 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
* This file contains the implementation of inference API with Anakin engine
* embeded, this API can only support Anakin models.
*/
#pragma once
#include <vector>
#include "io/paddle_inference_api.h"
// from paddle_mobile
#include "common/enforce.h"
#include "common/types.h"
#include "io/paddle_inference_api.h"
#include "io/paddle_mobile.h"
namespace paddle_mobile {
......
......@@ -104,6 +104,8 @@ class PaddlePredictor {
// The common configs for all the predictors.
struct Config {
std::string model_dir; // path to the model directory.
std::string prog_file;
std::string param_file;
};
protected:
......@@ -128,9 +130,8 @@ struct PaddleMobileConfig : public PaddlePredictor::Config {
int batch_size = 1;
bool optimize = true;
bool quantification = false;
bool lod_mode = false;
int thread_num = 1;
std::string prog_file;
std::string param_file;
std::string cl_path;
struct PaddleModelMemoryPack memory_pack;
};
......
......@@ -15,6 +15,9 @@ limitations under the License. */
#include "io/paddle_mobile.h"
#include <utility>
#include "common/common.h"
#ifdef _OPENMP
#include <omp.h>
#endif // _OPENMP
#ifdef PADDLE_MOBILE_CL
#include <CL/cl.h>
#include "framework/cl/cl_tensor.h"
......@@ -33,7 +36,7 @@ void PaddleMobile<Device, T>::SetThreadNum(int num) {
template <typename Device, typename T>
PMStatus PaddleMobile<Device, T>::Load(const std::string &dirname,
bool optimize, bool quantification,
int batch_size, bool loddable) {
int batch_size, bool lod_mode) {
if (loader_.get() == nullptr) {
loader_ = std::make_shared<framework::Loader<Device, T>>();
} else {
......@@ -43,7 +46,7 @@ PMStatus PaddleMobile<Device, T>::Load(const std::string &dirname,
if (executor_.get() == nullptr) {
executor_ = std::make_shared<framework::Executor<Device, T>>(
loader_->Load(dirname, optimize, quantification), batch_size, optimize,
loddable);
lod_mode);
} else {
LOG(kLOG_INFO) << "executor inited";
}
......@@ -55,7 +58,7 @@ template <typename Device, typename T>
PMStatus PaddleMobile<Device, T>::Load(const std::string &model_path,
const std::string &para_path,
bool optimize, bool quantification,
int batch_size, bool loddable) {
int batch_size, bool lod_mode) {
if (loader_.get() == nullptr) {
loader_ = std::make_shared<framework::Loader<Device, T>>();
} else {
......@@ -65,7 +68,7 @@ PMStatus PaddleMobile<Device, T>::Load(const std::string &model_path,
if (executor_.get() == nullptr) {
executor_ = std::make_shared<framework::Executor<Device, T>>(
loader_->Load(model_path, para_path, optimize, quantification),
batch_size, optimize, loddable);
batch_size, optimize, lod_mode);
} else {
LOG(kLOG_INFO) << "executor inited";
}
......@@ -73,6 +76,21 @@ PMStatus PaddleMobile<Device, T>::Load(const std::string &model_path,
return PMSuccess;
}
template <typename Device, typename T>
PMStatus PaddleMobile<Device, T>::Load(const PaddleMobileConfig &config) {
if (!config.model_dir.empty()) {
return this->Load(config.model_dir, config.optimize, config.quantification,
config.batch_size, config.lod_mode);
} else if (!config.prog_file.empty() && !config.param_file.empty()) {
return this->Load(config.prog_file, config.param_file, config.optimize,
config.quantification, config.batch_size,
config.lod_mode);
} else {
LOG(kLOG_ERROR) << "Failed to load inference model";
return PMNotInitialized;
}
}
template <typename Device, typename T>
bool PaddleMobile<Device, T>::LoadCombinedMemory(size_t model_len,
const uint8_t *model_buf,
......
......@@ -18,15 +18,12 @@ limitations under the License. */
#include <string>
#include <utility>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif // _OPENMP
#include "common/types.h"
#include "framework/executor.h"
#include "framework/load_ops.h"
#include "framework/loader.h"
#include "framework/tensor.h"
#include "io/paddle_inference_api.h"
#ifdef PADDLE_MOBILE_CL
#include "framework/cl/cl_engine.h"
#endif
......@@ -46,10 +43,12 @@ class PaddleMobile {
PMStatus Load(const std::string &dirname, const bool optimize = false,
const bool quantification = false, const int batch_size = 1,
const bool lod = false);
const bool lod_mode = false);
PMStatus Load(const std::string &model_path, const std::string &para_path,
const bool optimize = false, const bool quantification = false,
const int batch_size = 1, const bool lod = false);
const int batch_size = 1, const bool lod_mode = false);
PMStatus Load(const PaddleMobileConfig &config);
PMStatus Predict(const framework::Tensor &input);
PMStatus Predict(const framework::LoDTensor &input);
......
......@@ -24,8 +24,12 @@ template <>
bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
bool conv3x3 = param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Filter()->dims()[2] == 3;
bool conv5x5 = param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Filter()->dims()[2] == 5;
bool depth3x3 = conv3x3 && param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1];
bool depth5x5 = conv5x5 && param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1];
if (param->Filter()->type() == typeid(int8_t)) {
if (depth3x3 && param->Strides()[0] < 3 &&
param->Strides()[0] == param->Strides()[1]) {
......@@ -46,6 +50,9 @@ bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
param->Strides()[0] == 2 && param->Paddings()[0] == 1 &&
param->Paddings()[0] == param->Paddings()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT;
} else if (depth5x5 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 1) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE5x5S1_FLOAT;
#ifndef __aarch64__
} else if (conv3x3 && param->Strides()[0] == param->Strides()[1] &&
param->Dilations()[0] == param->Dilations()[1] &&
......@@ -87,6 +94,10 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
nullptr, false);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE5x5S1_FLOAT:
math::DepthwiseConv5x5S1<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
break;
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
break;
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <vector>
#include "operators/math/conv_func.h"
#include "operators/math/depthwise_conv3x3.h"
#include "operators/math/depthwise_conv5x5.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/pad.h"
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#if defined(__ARM_NEON__) && !defined(__aarch64__)
#include "operators/math/depthwise_conv5x5.h"
#include <arm_neon.h>
#include <iostream>
namespace paddle_mobile {
namespace operators {
namespace math {
#ifndef __aarch64__
inline float32x4_t vpaddq_f32(float32x4_t r0, float32x4_t r1) {
float32x2_t sum0 = vpadd_f32(vget_low_f32(r0), vget_high_f32(r0));
float32x2_t sum1 = vpadd_f32(vget_low_f32(r1), vget_high_f32(r1));
return vcombine_f32(sum0, sum1);
}
#endif
template <int Stride = 1>
inline void Depth5x5NormalRowLoadInput(const float *input, float32x4_t *y) {
y[0] = vld1q_f32(input);
y[4] = vld1q_f32(input + 4);
y[1] = vextq_f32(y[0], y[4], 1);
y[2] = vextq_f32(y[0], y[4], 2);
y[3] = vextq_f32(y[0], y[4], 3);
}
template <>
inline void Depth5x5NormalRowLoadInput<2>(const float *input, float32x4_t *y) {
float32x4x2_t x = vld2q_f32(input);
y[0] = x.val[0];
y[1] = x.val[1];
y[2] = vextq_f32(y[0], y[0], 1);
y[3] = vextq_f32(y[1], y[1], 1);
y[4] = vextq_f32(y[0], y[0], 2);
}
#define DEPTHWISE_CONV_NORMAL_BORDER(start, end) \
for (int w = start; w < end; ++w) { \
const int w_in_start = -padding_w + w * Stride_w; \
const int w_in_end = w_in_start + 5; \
const int w_start = w_in_start > 0 ? w_in_start : 0; \
const int w_end = w_in_end < input_w ? w_in_end : input_w; \
float value = 0; \
for (int h_in = h_start; h_in < h_end; ++h_in) { \
for (int w_in = w_start; w_in < w_end; ++w_in) { \
value += filter[(h_in - h_in_start) * 5 + (w_in - w_in_start)] * \
input[h_in * input_w + w_in]; \
} \
} \
output_ptr[w] = value; \
}
template <int Stride_h, int Stride_w>
inline void DepthwiseConv5x5NormalRow(const float *input, const float *filter,
const int h_output, const int input_h,
const int input_w, const int padding_h,
const int padding_w, const int output_w,
float *output, float32x4_t *ker,
float32_t *ker1) {
const int h_in_start = -padding_h + h_output * Stride_h;
const int h_in_end = h_in_start + 5;
const int h_start = h_in_start > 0 ? h_in_start : 0;
const int h_end = h_in_end < input_h ? h_in_end : input_h;
int valid_w_start = (padding_w + Stride_w - 1) / Stride_w;
int valid_w_end = output_w - valid_w_start;
float *output_ptr = output + h_output * output_w;
// border left
DEPTHWISE_CONV_NORMAL_BORDER(0, valid_w_start)
// middle
int output_tiles = (valid_w_end - valid_w_start) >> 2;
float32x4_t _sum, _x[5];
// valid w
for (int w = 0; w < output_tiles * 4; w += 4) {
_sum = vdupq_n_f32(0.f);
int output_offset = valid_w_start + w;
int input_w_offset = output_offset * Stride_w - padding_w;
for (int h_in = h_start; h_in < h_end; ++h_in) {
int index = h_in - h_in_start;
Depth5x5NormalRowLoadInput<Stride_w>(
input + h_in * input_w + input_w_offset, _x);
_sum = vmlaq_n_f32(_sum, _x[0], ker1[index]);
_sum = vmlaq_lane_f32(_sum, _x[1], vget_low_f32(ker[index]), 0);
_sum = vmlaq_lane_f32(_sum, _x[2], vget_low_f32(ker[index]), 1);
_sum = vmlaq_lane_f32(_sum, _x[3], vget_high_f32(ker[index]), 0);
_sum = vmlaq_lane_f32(_sum, _x[4], vget_high_f32(ker[index]), 1);
}
vst1q_f32(output_ptr + output_offset, _sum);
}
// remain valid w
int remain = (valid_w_end - valid_w_start) & 0x3;
if (remain > 0) {
_sum = vdupq_n_f32(0.f);
int remain_start = valid_w_start + (output_tiles << 2);
int input_w_offset = remain_start * Stride_w - padding_w;
for (int h_in = h_start; h_in < h_end; ++h_in) {
int index = h_in - h_in_start;
Depth5x5NormalRowLoadInput<Stride_w>(
input + h_in * input_w + input_w_offset, _x);
_sum = vmlaq_n_f32(_sum, _x[0], ker1[index]);
_sum = vmlaq_lane_f32(_sum, _x[1], vget_low_f32(ker[index]), 0);
_sum = vmlaq_lane_f32(_sum, _x[2], vget_low_f32(ker[index]), 1);
_sum = vmlaq_lane_f32(_sum, _x[3], vget_high_f32(ker[index]), 0);
_sum = vmlaq_lane_f32(_sum, _x[4], vget_high_f32(ker[index]), 1);
}
switch (remain) {
case 1:
vst1_lane_f32(output_ptr + remain_start, vget_low_f32(_sum), 0);
break;
case 2:
vst1_f32(output_ptr + remain_start, vget_low_f32(_sum));
break;
case 3:
vst1_f32(output_ptr + remain_start, vget_low_f32(_sum));
vst1_lane_f32(output_ptr + remain_start + 2, vget_high_f32(_sum), 0);
break;
}
}
// border right
DEPTHWISE_CONV_NORMAL_BORDER(valid_w_end, output_w)
}
template <>
void DepthwiseConv5x5S1<float, float>(const framework::Tensor &input,
const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output) {
const float *input_data = input.data<float>();
const float *filter_data = filter.data<float>();
float *out_data = output->mutable_data<float>();
int input_h = input.dims()[2];
int input_w = input.dims()[3];
int output_h = output->dims()[2];
int output_w = output->dims()[3];
int padding_h = paddings[0];
int padding_w = paddings[1];
int image_size = input_h * input_w;
int out_image_size = output_h * output_w;
int valid_h_start = padding_h;
int valid_h_end = output_h - valid_h_start;
int valid_h = valid_h_end - valid_h_start;
int valid_w_start = padding_w;
int valid_w_end = output_w - valid_w_start;
int valid_w = valid_w_end - valid_w_start;
DLOG << "valid_h_start: " << valid_h_start;
DLOG << "valid_h_end: " << valid_h_end;
DLOG << "valid_w_start: " << valid_w_start;
DLOG << "valid_w_end: " << valid_w_end;
for (int g = 0; g < input.dims()[1]; ++g) {
const float *input_ptr = input_data + g * image_size;
const float *filter_ptr = filter_data + g * 25;
float *output_ptr = out_data + g * out_image_size;
const float *filter_ptr0 = filter_ptr;
const float *filter_ptr1 = filter_ptr0 + 5;
const float *filter_ptr2 = filter_ptr1 + 5;
const float *filter_ptr3 = filter_ptr2 + 5;
const float *filter_ptr4 = filter_ptr3 + 5;
float32x4_t _ker[7];
float32_t _ker1[5] = {*filter_ptr0, *filter_ptr1, *filter_ptr2,
*filter_ptr3, *filter_ptr4};
_ker[0] = vld1q_f32(filter_ptr0 + 1);
_ker[1] = vld1q_f32(filter_ptr1 + 1);
_ker[2] = vld1q_f32(filter_ptr2 + 1);
_ker[3] = vld1q_f32(filter_ptr3 + 1);
_ker[4] = vld1q_f32(filter_ptr4 + 1);
_ker[5] = vld1q_f32(_ker1);
_ker[6] = vld1q_f32(_ker1 + 4);
// pad top
for (int h = 0; h < valid_h_start; ++h) {
DepthwiseConv5x5NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w,
output_ptr, _ker, _ker1);
}
// output 4x4
int output_w_tiles = valid_w / 4;
int output_w_remain = valid_w - output_w_tiles * 4;
for (int h = valid_h_start; h < valid_h_end - 1; h += 2) {
const float *input_ptr0 = input_ptr + (h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w;
const float *input_ptr2 = input_ptr1 + input_w;
const float *input_ptr3 = input_ptr2 + input_w;
const float *input_ptr4 = input_ptr3 + input_w;
const float *input_ptr5 = input_ptr4 + input_w;
float *output_ptr0 = output_ptr + h * output_w;
float *output_ptr1 = output_ptr0 + output_w;
// pad left
if (padding_w) {
float32x4_t row0 = vld1q_f32(input_ptr0);
float32x4_t row1 = vld1q_f32(input_ptr1);
float32x4_t row2 = vld1q_f32(input_ptr2);
float32x4_t row3 = vld1q_f32(input_ptr3);
float32x4_t row4 = vld1q_f32(input_ptr4);
float32x4_t row5 = vld1q_f32(input_ptr5);
float32x4_t zero = vdupq_n_f32(0.f);
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - w;
if (padding >= 5) {
output_ptr0[w] = 0.f;
output_ptr1[w] = 0.f;
} else {
row0 = vmulq_f32(row0, _ker[0]);
row0 = vmlaq_f32(row0, row1, _ker[1]);
row0 = vmlaq_f32(row0, row2, _ker[2]);
row0 = vmlaq_f32(row0, row3, _ker[3]);
row0 = vmlaq_f32(row0, row4, _ker[4]);
row1 = vmulq_f32(row1, _ker[0]);
row1 = vmlaq_f32(row1, row2, _ker[1]);
row1 = vmlaq_f32(row1, row3, _ker[2]);
row1 = vmlaq_f32(row1, row4, _ker[3]);
row1 = vmlaq_f32(row1, row5, _ker[4]);
row0 = vpaddq_f32(row0, row1);
float32x2_t sum =
vpadd_f32(vget_low_f32(row0), vget_high_f32(row0));
vst1_lane_f32(output_ptr0 + w, sum, 0);
vst1_lane_f32(output_ptr1 + w, sum, 1);
row0 = vextq_f32(zero, row0, 3);
row1 = vextq_f32(zero, row1, 3);
row2 = vextq_f32(zero, row2, 3);
row3 = vextq_f32(zero, row3, 3);
row4 = vextq_f32(zero, row4, 3);
row5 = vextq_f32(zero, row5, 3);
}
}
output_ptr0 += valid_w_start;
output_ptr1 += valid_w_start;
}
// valid
int loop = output_w_tiles;
asm volatile(
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
"mov r0, #16 \n"
"loop_2h4w_%=: \n"
"vld1.32 {d14-d17}, [%[input_ptr0]], r0 \n"
"vld1.32 {d18-d21}, [%[input_ptr1]], r0 \n"
"vld1.32 {d22-d25}, [%[input_ptr2]], r0 \n"
"vmul.f32 q14, q7, %e[ker0][0] \n"
"vext.32 q13, q7, q8, #1 \n"
"vmla.f32 q14, q13, %e[kr0][0] \n"
"vext.32 q13, q7, q8, #2 \n"
"vmla.f32 q14, q13, %e[kr0][1] \n"
"vext.32 q13, q7, q8, #3 \n"
"vmla.f32 q14, q13, %f[kr0][0] \n"
"vmla.f32 q14, q8, %f[kr0][1] \n"
"vmla.f32 q14, q9, %e[ker0][1] \n"
"vmul.f32 q15, q9, %e[ker0][0] \n"
"vext.32 q13, q9, q10, #1 \n"
"vmla.f32 q14, q13, %e[kr1][0] \n"
"vmla.f32 q15, q13, %e[kr0][0] \n"
"vext.32 q13, q9, q10, #2 \n"
"vmla.f32 q14, q13, %e[kr1][1] \n"
"vmla.f32 q15, q13, %e[kr0][1] \n"
"vext.32 q13, q9, q10, #3 \n"
"vmla.f32 q14, q13, %f[kr1][0] \n"
"vmla.f32 q15, q13, %f[kr0][0] \n"
"vmla.f32 q14, q10, %f[kr1][1] \n"
"vmla.f32 q15, q10, %f[kr0][1] \n"
"vmla.f32 q14, q11, %f[ker0][0] \n"
"vmla.f32 q15, q11, %e[ker0][1] \n"
"vext.32 q13, q11, q12, #1 \n"
"vmla.f32 q14, q13, %e[kr2][0] \n"
"vmla.f32 q15, q13, %e[kr1][0] \n"
"vext.32 q13, q11, q12, #2 \n"
"vmla.f32 q14, q13, %e[kr2][1] \n"
"vmla.f32 q15, q13, %e[kr1][1] \n"
"vext.32 q13, q11, q12, #3 \n"
"vmla.f32 q14, q13, %f[kr2][0] \n"
"vmla.f32 q15, q13, %f[kr1][0] \n"
"vmla.f32 q14, q12, %f[kr2][1] \n"
"vmla.f32 q15, q12, %f[kr1][1] \n"
"vld1.32 {d14-d17}, [%[input_ptr3]], r0 \n"
"vld1.32 {d18-d21}, [%[input_ptr4]], r0 \n"
"vld1.32 {d22-d25}, [%[input_ptr5]], r0 \n"
"vmla.f32 q14, q7, %f[ker0][1] \n"
"vmla.f32 q15, q7, %f[ker0][0] \n"
"vext.32 q13, q7, q8, #1 \n"
"vmla.f32 q14, q13, %e[kr3][0] \n"
"vmla.f32 q15, q13, %e[kr2][0] \n"
"vext.32 q13, q7, q8, #2 \n"
"vmla.f32 q14, q13, %e[kr3][1] \n"
"vmla.f32 q15, q13, %e[kr2][1] \n"
"vext.32 q13, q7, q8, #3 \n"
"vmla.f32 q14, q13, %f[kr3][0] \n"
"vmla.f32 q15, q13, %f[kr2][0] \n"
"vmla.f32 q14, q8, %f[kr3][1] \n"
"vmla.f32 q15, q8, %f[kr2][1] \n"
"vmla.f32 q14, q9, %e[ker1][0] \n"
"vmla.f32 q15, q9, %f[ker0][1] \n"
"vext.32 q13, q9, q10, #1 \n"
"vmla.f32 q14, q13, %e[kr4][0] \n"
"vmla.f32 q15, q13, %e[kr3][0] \n"
"vext.32 q13, q9, q10, #2 \n"
"vmla.f32 q14, q13, %e[kr4][1] \n"
"vmla.f32 q15, q13, %e[kr3][1] \n"
"vext.32 q13, q9, q10, #3 \n"
"vmla.f32 q14, q13, %f[kr4][0] \n"
"vmla.f32 q15, q13, %f[kr3][0] \n"
"vmla.f32 q14, q10, %f[kr4][1] \n"
"vmla.f32 q15, q10, %f[kr3][1] \n"
"vmla.f32 q15, q11, %e[ker1][0] \n"
"vext.32 q13, q11, q12, #1 \n"
"vmla.f32 q15, q13, %e[kr4][0] \n"
"vext.32 q13, q11, q12, #2 \n"
"vmla.f32 q15, q13, %e[kr4][1] \n"
"vext.32 q13, q11, q12, #3 \n"
"vmla.f32 q15, q13, %f[kr4][0] \n"
"vmla.f32 q15, q12, %f[kr4][1] \n"
// restore output
"vst1.32 {q14}, [%[output_ptr0]]! \n"
"vst1.32 {q15}, [%[output_ptr1]]! \n"
"subs %[loop], #1 \n"
"bne loop_2h4w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"mov r0, %[remain], lsl #2 \n"
"vld1.32 {d14-d17}, [%[input_ptr0]], r0 \n"
"vld1.32 {d18-d21}, [%[input_ptr1]], r0 \n"
"vld1.32 {d22-d25}, [%[input_ptr2]], r0 \n"
"vmul.f32 q14, q7, %e[ker0][0] \n"
"vext.32 q13, q7, q8, #1 \n"
"vmla.f32 q14, q13, %e[kr0][0] \n"
"vext.32 q13, q7, q8, #2 \n"
"vmla.f32 q14, q13, %e[kr0][1] \n"
"vext.32 q13, q7, q8, #3 \n"
"vmla.f32 q14, q13, %f[kr0][0] \n"
"vmla.f32 q14, q8, %f[kr0][1] \n"
"vmla.f32 q14, q9, %e[ker0][1] \n"
"vmul.f32 q15, q9, %e[ker0][0] \n"
"vext.32 q13, q9, q10, #1 \n"
"vmla.f32 q14, q13, %e[kr1][0] \n"
"vmla.f32 q15, q13, %e[kr0][0] \n"
"vext.32 q13, q9, q10, #2 \n"
"vmla.f32 q14, q13, %e[kr1][1] \n"
"vmla.f32 q15, q13, %e[kr0][1] \n"
"vext.32 q13, q9, q10, #3 \n"
"vmla.f32 q14, q13, %f[kr1][0] \n"
"vmla.f32 q15, q13, %f[kr0][0] \n"
"vmla.f32 q14, q10, %f[kr1][1] \n"
"vmla.f32 q15, q10, %f[kr0][1] \n"
"vmla.f32 q14, q11, %f[ker0][0] \n"
"vmla.f32 q15, q11, %e[ker0][1] \n"
"vext.32 q13, q11, q12, #1 \n"
"vmla.f32 q14, q13, %e[kr2][0] \n"
"vmla.f32 q15, q13, %e[kr1][0] \n"
"vext.32 q13, q11, q12, #2 \n"
"vmla.f32 q14, q13, %e[kr2][1] \n"
"vmla.f32 q15, q13, %e[kr1][1] \n"
"vext.32 q13, q11, q12, #3 \n"
"vmla.f32 q14, q13, %f[kr2][0] \n"
"vmla.f32 q15, q13, %f[kr1][0] \n"
"vmla.f32 q14, q12, %f[kr2][1] \n"
"vmla.f32 q15, q12, %f[kr1][1] \n"
"vld1.32 {d14-d17}, [%[input_ptr3]], r0 \n"
"vld1.32 {d18-d21}, [%[input_ptr4]], r0 \n"
"vld1.32 {d22-d25}, [%[input_ptr5]], r0 \n"
"vmla.f32 q14, q7, %f[ker0][1] \n"
"vmla.f32 q15, q7, %f[ker0][0] \n"
"vext.32 q13, q7, q8, #1 \n"
"vmla.f32 q14, q13, %e[kr3][0] \n"
"vmla.f32 q15, q13, %e[kr2][0] \n"
"vext.32 q13, q7, q8, #2 \n"
"vmla.f32 q14, q13, %e[kr3][1] \n"
"vmla.f32 q15, q13, %e[kr2][1] \n"
"vext.32 q13, q7, q8, #3 \n"
"vmla.f32 q14, q13, %f[kr3][0] \n"
"vmla.f32 q15, q13, %f[kr2][0] \n"
"vmla.f32 q14, q8, %f[kr3][1] \n"
"vmla.f32 q15, q8, %f[kr2][1] \n"
"vmla.f32 q14, q9, %e[ker1][0] \n"
"vmla.f32 q15, q9, %f[ker0][1] \n"
"vext.32 q13, q9, q10, #1 \n"
"vmla.f32 q14, q13, %e[kr4][0] \n"
"vmla.f32 q15, q13, %e[kr3][0] \n"
"vext.32 q13, q9, q10, #2 \n"
"vmla.f32 q14, q13, %e[kr4][1] \n"
"vmla.f32 q15, q13, %e[kr3][1] \n"
"vext.32 q13, q9, q10, #3 \n"
"vmla.f32 q14, q13, %f[kr4][0] \n"
"vmla.f32 q15, q13, %f[kr3][0] \n"
"vmla.f32 q14, q10, %f[kr4][1] \n"
"vmla.f32 q15, q10, %f[kr3][1] \n"
"vmla.f32 q15, q11, %e[ker1][0] \n"
"vext.32 q13, q11, q12, #1 \n"
"vmla.f32 q15, q13, %e[kr4][0] \n"
"vext.32 q13, q11, q12, #2 \n"
"vmla.f32 q15, q13, %e[kr4][1] \n"
"vext.32 q13, q11, q12, #3 \n"
"vmla.f32 q15, q13, %f[kr4][0] \n"
"vmla.f32 q15, q12, %f[kr4][1] \n"
"cmp %[remain], #2 \n"
"blt store_2h1w_%= \n"
"vst1.32 {d28}, [%[output_ptr0]]! \n"
"vst1.32 {d30}, [%[output_ptr1]]! \n"
"cmp %[remain], #3 \n"
"blt end_%= \n"
"vst1.32 {d29[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d31[0]}, [%[output_ptr1]]! \n"
"b end_%= \n"
"store_2h1w_%=: \n"
"vst1.32 {d28[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d30[0]}, [%[output_ptr1]]! \n"
"end_%=: \n"
: [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1),
[input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3),
[input_ptr4] "+r"(input_ptr4), [input_ptr5] "+r"(input_ptr5),
[output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1),
[loop] "+r"(loop)
: [remain] "r"(output_w_remain), [kr0] "w"(_ker[0]),
[kr1] "w"(_ker[1]), [kr2] "w"(_ker[2]), [kr3] "w"(_ker[3]),
[kr4] "w"(_ker[4]), [ker0] "w"(_ker[5]), [ker1] "w"(_ker[6])
: "cc", "memory", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14",
"q15", "r0");
// pad right
if (padding_w) {
float32x4_t row0 = vld1q_f32(input_ptr0);
float32x4_t row1 = vld1q_f32(input_ptr1);
float32x4_t row2 = vld1q_f32(input_ptr2);
float32x4_t row3 = vld1q_f32(input_ptr3);
float32x4_t row4 = vld1q_f32(input_ptr4);
float32x4_t row5 = vld1q_f32(input_ptr5);
float32x4_t zero = vdupq_n_f32(0.f);
for (int w = valid_w_end; w < output_w; ++w) {
int padding = w + 5 - (padding_w + input_w);
if (padding >= 5) {
*output_ptr0 = 0.f;
*output_ptr1 = 0.f;
} else {
int iw = w - valid_w_end;
float sum0 = input_ptr0[iw] * filter_ptr0[0] +
input_ptr1[iw] * filter_ptr1[0] +
input_ptr2[iw] * filter_ptr2[0] +
input_ptr3[iw] * filter_ptr3[0] +
input_ptr4[iw] * filter_ptr4[0];
float sum1 = input_ptr1[iw] * filter_ptr0[0] +
input_ptr2[iw] * filter_ptr1[0] +
input_ptr3[iw] * filter_ptr2[0] +
input_ptr4[iw] * filter_ptr3[0] +
input_ptr5[iw] * filter_ptr4[0];
row0 = vextq_f32(row0, zero, 1);
row1 = vextq_f32(row1, zero, 1);
row2 = vextq_f32(row2, zero, 1);
row3 = vextq_f32(row3, zero, 1);
row4 = vextq_f32(row4, zero, 1);
row5 = vextq_f32(row5, zero, 1);
row0 = vmulq_f32(row0, _ker[0]);
row0 = vmlaq_f32(row0, row1, _ker[1]);
row0 = vmlaq_f32(row0, row2, _ker[2]);
row0 = vmlaq_f32(row0, row3, _ker[3]);
row0 = vmlaq_f32(row0, row4, _ker[4]);
row1 = vmulq_f32(row1, _ker[0]);
row1 = vmlaq_f32(row1, row2, _ker[1]);
row1 = vmlaq_f32(row1, row3, _ker[2]);
row1 = vmlaq_f32(row1, row4, _ker[3]);
row1 = vmlaq_f32(row1, row5, _ker[4]);
row0 = vpaddq_f32(row0, row1);
float32x2_t sum =
vpadd_f32(vget_low_f32(row0), vget_high_f32(row0));
sum0 += vget_lane_f32(sum, 0);
sum1 += vget_lane_f32(sum, 1);
*output_ptr0 = sum0;
*output_ptr1 = sum1;
}
output_ptr0++;
output_ptr1++;
}
}
}
// remain height
int start_h = valid_h_start + (valid_h & 0xfffe);
if (start_h < valid_h_end) {
const float *input_ptr0 = input_ptr + (start_h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w;
const float *input_ptr2 = input_ptr1 + input_w;
const float *input_ptr3 = input_ptr2 + input_w;
const float *input_ptr4 = input_ptr3 + input_w;
float *output_ptr0 = output_ptr + start_h * output_w;
// pad left
if (padding_w) {
float32x4_t row0 = vld1q_f32(input_ptr0);
float32x4_t row1 = vld1q_f32(input_ptr1);
float32x4_t row2 = vld1q_f32(input_ptr2);
float32x4_t row3 = vld1q_f32(input_ptr3);
float32x4_t row4 = vld1q_f32(input_ptr4);
float32x4_t zero = vdupq_n_f32(0.f);
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - w;
if (padding >= 5) {
output_ptr0[w] = 0.f;
} else {
row0 = vmulq_f32(row0, _ker[0]);
row0 = vmlaq_f32(row0, row1, _ker[1]);
row0 = vmlaq_f32(row0, row2, _ker[2]);
row0 = vmlaq_f32(row0, row3, _ker[3]);
row0 = vmlaq_f32(row0, row4, _ker[4]);
float32x2_t sum =
vpadd_f32(vget_low_f32(row0), vget_high_f32(row0));
sum = vpadd_f32(sum, sum);
vst1_lane_f32(output_ptr0 + w, sum, 0);
row0 = vextq_f32(zero, row0, 3);
row1 = vextq_f32(zero, row1, 3);
row2 = vextq_f32(zero, row2, 3);
row3 = vextq_f32(zero, row3, 3);
row4 = vextq_f32(zero, row4, 3);
}
}
output_ptr0 += valid_w_start;
}
// valid
int loop = output_w_tiles;
asm volatile(
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
"mov r0, #16 \n"
"loop_1h4w_%=: \n"
"vld1.32 {d14-d17}, [%[input_ptr0]], r0 \n"
"vld1.32 {d18-d21}, [%[input_ptr1]], r0 \n"
"vld1.32 {d22-d25}, [%[input_ptr2]], r0 \n"
"vmul.f32 q14, q7, %e[ker0][0] \n"
"vext.32 q13, q7, q8, #1 \n"
"vmla.f32 q14, q13, %e[kr0][0] \n"
"vext.32 q13, q7, q8, #2 \n"
"vmla.f32 q14, q13, %e[kr0][1] \n"
"vext.32 q13, q7, q8, #3 \n"
"vmla.f32 q14, q13, %f[kr0][0] \n"
"vmla.f32 q14, q8, %f[kr0][1] \n"
"vmla.f32 q14, q9, %e[ker0][1] \n"
"vext.32 q13, q9, q10, #1 \n"
"vmla.f32 q14, q13, %e[kr1][0] \n"
"vext.32 q13, q9, q10, #2 \n"
"vmla.f32 q14, q13, %e[kr1][1] \n"
"vext.32 q13, q9, q10, #3 \n"
"vmla.f32 q14, q13, %f[kr1][0] \n"
"vmla.f32 q14, q10, %f[kr1][1] \n"
"vmla.f32 q14, q11, %f[ker0][0] \n"
"vext.32 q13, q11, q12, #1 \n"
"vmla.f32 q14, q13, %e[kr2][0] \n"
"vext.32 q13, q11, q12, #2 \n"
"vmla.f32 q14, q13, %e[kr2][1] \n"
"vext.32 q13, q11, q12, #3 \n"
"vmla.f32 q14, q13, %f[kr2][0] \n"
"vmla.f32 q14, q12, %f[kr2][1] \n"
"vld1.32 {d14-d17}, [%[input_ptr3]], r0 \n"
"vld1.32 {d18-d21}, [%[input_ptr4]], r0 \n"
"vmla.f32 q14, q7, %f[ker0][1] \n"
"vext.32 q13, q7, q8, #1 \n"
"vmla.f32 q14, q13, %e[kr3][0] \n"
"vext.32 q13, q7, q8, #2 \n"
"vmla.f32 q14, q13, %e[kr3][1] \n"
"vext.32 q13, q7, q8, #3 \n"
"vmla.f32 q14, q13, %f[kr3][0] \n"
"vmla.f32 q14, q8, %f[kr3][1] \n"
"vmla.f32 q14, q9, %e[ker1][0] \n"
"vext.32 q13, q9, q10, #1 \n"
"vmla.f32 q14, q13, %e[kr4][0] \n"
"vext.32 q13, q9, q10, #2 \n"
"vmla.f32 q14, q13, %e[kr4][1] \n"
"vext.32 q13, q9, q10, #3 \n"
"vmla.f32 q14, q13, %f[kr4][0] \n"
"vmla.f32 q14, q10, %f[kr4][1] \n"
// restore output
"vst1.32 {q14}, [%[output_ptr0]]! \n"
"subs %[loop], #1 \n"
"bne loop_1h4w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"mov r0, %[remain], lsl #2 \n"
"vld1.32 {d14-d17}, [%[input_ptr0]], r0 \n"
"vld1.32 {d18-d21}, [%[input_ptr1]], r0 \n"
"vld1.32 {d22-d25}, [%[input_ptr2]], r0 \n"
"vmul.f32 q14, q7, %e[ker0][0] \n"
"vext.32 q13, q7, q8, #1 \n"
"vmla.f32 q14, q13, %e[kr0][0] \n"
"vext.32 q13, q7, q8, #2 \n"
"vmla.f32 q14, q13, %e[kr0][1] \n"
"vext.32 q13, q7, q8, #3 \n"
"vmla.f32 q14, q13, %f[kr0][0] \n"
"vmla.f32 q14, q8, %f[kr0][1] \n"
"vmla.f32 q14, q9, %e[ker0][1] \n"
"vext.32 q13, q9, q10, #1 \n"
"vmla.f32 q14, q13, %e[kr1][0] \n"
"vext.32 q13, q9, q10, #2 \n"
"vmla.f32 q14, q13, %e[kr1][1] \n"
"vext.32 q13, q9, q10, #3 \n"
"vmla.f32 q14, q13, %f[kr1][0] \n"
"vmla.f32 q14, q10, %f[kr1][1] \n"
"vmla.f32 q14, q11, %f[ker0][0] \n"
"vext.32 q13, q11, q12, #1 \n"
"vmla.f32 q14, q13, %e[kr2][0] \n"
"vext.32 q13, q11, q12, #2 \n"
"vmla.f32 q14, q13, %e[kr2][1] \n"
"vext.32 q13, q11, q12, #3 \n"
"vmla.f32 q14, q13, %f[kr2][0] \n"
"vmla.f32 q14, q12, %f[kr2][1] \n"
"vld1.32 {d14-d17}, [%[input_ptr3]], r0 \n"
"vld1.32 {d18-d21}, [%[input_ptr4]], r0 \n"
"vmla.f32 q14, q7, %f[ker0][1] \n"
"vext.32 q13, q7, q8, #1 \n"
"vmla.f32 q14, q13, %e[kr3][0] \n"
"vext.32 q13, q7, q8, #2 \n"
"vmla.f32 q14, q13, %e[kr3][1] \n"
"vext.32 q13, q7, q8, #3 \n"
"vmla.f32 q14, q13, %f[kr3][0] \n"
"vmla.f32 q14, q8, %f[kr3][1] \n"
"vmla.f32 q14, q9, %e[ker1][0] \n"
"vext.32 q13, q9, q10, #1 \n"
"vmla.f32 q14, q13, %e[kr4][0] \n"
"vext.32 q13, q9, q10, #2 \n"
"vmla.f32 q14, q13, %e[kr4][1] \n"
"vext.32 q13, q9, q10, #3 \n"
"vmla.f32 q14, q13, %f[kr4][0] \n"
"vmla.f32 q14, q10, %f[kr4][1] \n"
"cmp %[remain], #2 \n"
"blt store_1h1w_%= \n"
"vst1.32 {d28}, [%[output_ptr0]]! \n"
"cmp %[remain], #3 \n"
"blt end_%= \n"
"vst1.32 {d29[0]}, [%[output_ptr0]]! \n"
"b end_%= \n"
"store_1h1w_%=: \n"
"vst1.32 {d28[0]}, [%[output_ptr0]]! \n"
"end_%=: \n"
: [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1),
[input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3),
[input_ptr4] "+r"(input_ptr4), [output_ptr0] "+r"(output_ptr0),
[loop] "+r"(loop)
: [remain] "r"(output_w_remain), [kr0] "w"(_ker[0]),
[kr1] "w"(_ker[1]), [kr2] "w"(_ker[2]), [kr3] "w"(_ker[3]),
[kr4] "w"(_ker[4]), [ker0] "w"(_ker[5]), [ker1] "w"(_ker[6])
: "cc", "memory", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14",
"q15", "r0");
// pad right
if (padding_w) {
float32x4_t row0 = vld1q_f32(input_ptr0);
float32x4_t row1 = vld1q_f32(input_ptr1);
float32x4_t row2 = vld1q_f32(input_ptr2);
float32x4_t row3 = vld1q_f32(input_ptr3);
float32x4_t row4 = vld1q_f32(input_ptr4);
float32x4_t zero = vdupq_n_f32(0.f);
for (int w = valid_w_end; w < output_w; ++w) {
int padding = w + 5 - (padding_w + input_w);
if (padding >= 5) {
*output_ptr0 = 0.f;
} else {
int iw = w - valid_w_end;
float sum0 = input_ptr0[iw] * filter_ptr0[0] +
input_ptr1[iw] * filter_ptr1[0] +
input_ptr2[iw] * filter_ptr2[0] +
input_ptr3[iw] * filter_ptr3[0] +
input_ptr4[iw] * filter_ptr4[0];
row0 = vextq_f32(row0, zero, 1);
row1 = vextq_f32(row1, zero, 1);
row2 = vextq_f32(row2, zero, 1);
row3 = vextq_f32(row3, zero, 1);
row4 = vextq_f32(row4, zero, 1);
row0 = vmulq_f32(row0, _ker[0]);
row0 = vmlaq_f32(row0, row1, _ker[1]);
row0 = vmlaq_f32(row0, row2, _ker[2]);
row0 = vmlaq_f32(row0, row3, _ker[3]);
row0 = vmlaq_f32(row0, row4, _ker[4]);
float32x2_t sum =
vpadd_f32(vget_low_f32(row0), vget_high_f32(row0));
sum = vpadd_f32(sum, sum);
sum0 += vget_lane_f32(sum, 0);
*output_ptr0 = sum0;
}
output_ptr0++;
}
}
}
// pad bottom
for (int h = valid_h_end; h < output_h; ++h) {
DepthwiseConv5x5NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w,
output_ptr, _ker, _ker1);
}
}
}
template <>
void DepthwiseConv5x5S2<float, float>(const framework::Tensor &input,
const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output) {}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif // __ARM_NEON__
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "framework/tensor.h"
#include "operators/math/conv_func.h"
namespace paddle_mobile {
namespace operators {
namespace math {
// TODO(hjchen2) need to be implemented
// template<typename Itype, typename Otype>
// void DepthwiseConv5x5(const framework::Tensor *input,
// const framework::Tensor *filter,
// const std::vector<int> &strides,
// const std::vector<int> &paddings,
// framework::Tensor *output);
template <typename Itype, typename Otype>
void DepthwiseConv5x5S1(const framework::Tensor &input,
const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output);
template <typename Itype, typename Otype>
void DepthwiseConv5x5S2(const framework::Tensor &input,
const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......@@ -424,6 +424,8 @@ class ConvParam : public OpParam {
EXEC_DEPTHWISE3x3_FLOAT,
EXEC_WINOGRAD3X3_FLOAT,
EXEC_WINOGRAD5X5_FLOAT,
EXEC_DEPTHWISE5x5S1_FLOAT,
EXEC_DEPTHWISE5x5S2_FLOAT,
EXEC_GEMM_INT8,
EXEC_DEPTHWISE3x3_INT8,
};
......@@ -2598,8 +2600,8 @@ class QuantizeParam : public OpParam {
// if offine scale or not
bool offline_ = false;
// round method type
// RoundType round_type_ = ROUND_NEAREST_AWAY_ZERO;
RoundType round_type_ = ROUND_NEAREST_TOWARDS_ZERO;
RoundType round_type_ = ROUND_NEAREST_AWAY_ZERO;
// RoundType round_type_ = ROUND_NEAREST_TOWARDS_ZERO;
};
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册