提交 de37013f 编写于 作者: H hjchen2

Support padding in 8bit depthwise conv, so remove padding from dequantize kernel

上级 7b5a6c39
......@@ -55,10 +55,10 @@ bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
param->Input()->dims()[2] <= 140 /* refered from ncnn */) {
param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT;
// transform weight
framework::Tensor *transformed_weight = new framework::Tensor;
framework::Tensor transformed_weight;
operators::math::winograd_transform_weight<8, 3>(*param->Filter(),
transformed_weight);
param->Filter() = transformed_weight;
&transformed_weight);
framework::TensorCopy(transformed_weight, param->Filter());
#endif
} else {
param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_FLOAT;
......
......@@ -170,31 +170,21 @@ template <typename Itype, typename Otype>
inline void DepthwiseConv3x3(const ConvParam<CPU> &param) {
const Tensor *input = param.Input();
const Tensor *filter = param.Filter();
const std::vector<int> &paddings = param.Paddings();
const std::vector<int> &strides = param.Strides();
const int batch_size = input->dims()[0];
Tensor *output = param.Output();
output->mutable_data<Otype>();
const std::vector<int> &paddings = param.Paddings();
const std::vector<int> &strides = param.Strides();
const int batch_size = static_cast<int>(input->dims()[0]);
Tensor input_pad;
math::PadFunctor<CPU, Itype> pad;
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1);
Tensor out_batch = output->Slice(i, i + 1);
if (paddings[0] || paddings[1]) {
framework::DDim pad_shape = in_batch.dims();
pad_shape[2] += 2 * paddings[0];
pad_shape[3] += 2 * paddings[1];
input_pad.mutable_data<float>(pad_shape);
pad(in_batch, paddings[0], paddings[0], paddings[1], paddings[1],
&input_pad);
} else {
input_pad = in_batch;
}
if (strides[0] == 1) {
math::DepthwiseConv3x3s1<Itype, Otype>(input_pad, *filter, &out_batch);
math::DepthwiseConv3x3s1<Itype, Otype>(in_batch, *filter, paddings,
&out_batch);
} else if (strides[0] == 2) {
math::DepthwiseConv3x3s2<Itype, Otype>(input_pad, *filter, &out_batch);
math::DepthwiseConv3x3s2<Itype, Otype>(in_batch, *filter, paddings,
&out_batch);
} else {
// math::DepthwiseConv3x3<Itype, Otype>(input_pad, *filter,
// &out_batch);
......
......@@ -1278,7 +1278,10 @@ void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
float *output_data = output->data<float>();
const float *bias_data = bias->data<float>();
const float *bias_data;
if (if_bias) {
bias_data = bias->data<float>();
}
const int in_h = static_cast<int>(input->dims()[2]);
const int in_w = static_cast<int>(input->dims()[3]);
......
......@@ -70,16 +70,19 @@ void DepthwiseConv3x3s2p0(const framework::Tensor *input,
// void DepthwiseConv3x3(const framework::Tensor *input,
// const framework::Tensor *filter,
// const std::vector<int> &strides,
// const std::vector<int> &paddings,
// framework::Tensor *output);
template <typename Itype, typename Otype>
void DepthwiseConv3x3s1(const framework::Tensor &input,
const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output);
template <typename Itype, typename Otype>
void DepthwiseConv3x3s2(const framework::Tensor &input,
const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output);
} // namespace math
......
......@@ -29,6 +29,7 @@ namespace math {
template <>
void DepthwiseConv3x3s1<int8_t, int32_t>(const framework::Tensor &input,
const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output) {
const int8_t *input_data = input.data<int8_t>();
const int8_t *filter_data = filter.data<int8_t>();
......@@ -751,6 +752,7 @@ void DepthwiseConv3x3s1<int8_t, int32_t>(const framework::Tensor &input,
template <>
void DepthwiseConv3x3s2<int8_t, int32_t>(const framework::Tensor &input,
const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output) {
const int8_t *input_data = input.data<int8_t>();
const int8_t *filter_data = filter.data<int8_t>();
......
......@@ -405,9 +405,9 @@ class ConvParam : public OpParam {
const RType *Input() const { return input_; }
RType *&Filter() const { return filter_; }
RType *Filter() const { return filter_; }
RType *&Output() const { return output_; }
RType *Output() const { return output_; }
const vector<int> &Strides() const { return strides_; }
......@@ -441,8 +441,8 @@ class ConvParam : public OpParam {
private:
RType *input_;
mutable RType *output_;
mutable RType *filter_;
RType *output_;
RType *filter_;
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
......
......@@ -44,25 +44,19 @@ struct Round<round::RoundTowardsZero> {
template <>
struct Round<round::RoundToEven> {
int8_t operator()(float x) {
int8_t ret = 0;
float v = std::round(x);
int32_t q = (int32_t)v;
if (abs(abs(q - x) - 0.5) > 0) {
ret = q;
} else {
if (abs(q) % 2 == 0) {
ret = q;
} else {
ret = q + ((q > 0) ? -1 : 1);
int32_t q = static_cast<int32_t>(v);
if (abs(abs(q - v) - 0.5) <= 0) {
if (abs(q) % 2 != 0) {
q = q + ((q > 0) ? -1 : 1);
}
}
return ret;
return static_cast<int8_t>(q);
}
};
template <round::RoundType T>
static void quantize(const Tensor *input, const float scale, const int pad,
const int8_t pad_val, Tensor *output) {
static void quantize(const Tensor *input, const float scale, Tensor *output) {
int batch_size = input->dims()[0];
int channels = input->dims()[1];
int input_h = input->dims()[2];
......@@ -77,29 +71,9 @@ static void quantize(const Tensor *input, const float scale, const int pad,
for (int nc = 0; nc < batch_size * channels; ++nc) {
const float *xh = x + nc * input_spatial;
int8_t *yh = y + nc * output_spatial;
// pad top
for (int h = 0; h < pad; ++h, yh += output_w) {
for (int w = 0; w < output_w; ++w) {
yh[w] = pad_val;
}
}
for (int h = 0; h < input_h; ++h, yh += output_w, xh += input_w) {
// pad left
for (int w = 0; w < pad; ++w) {
yh[w] = pad_val;
}
for (int w = 0; w < input_w; ++w) {
yh[w + pad] = Round<T>()(xh[w] * scale);
}
// pad right
for (int w = 0; w < pad; ++w) {
yh[pad + input_w + w] = pad_val;
}
}
// pad bottom
for (int h = 0; h < pad; ++h, yh += output_w) {
for (int w = 0; w < output_w; ++w) {
yh[w] = pad_val;
yh[w] = Round<T>()(xh[w] * scale);
}
}
}
......@@ -120,19 +94,14 @@ static float find_abs_max(const Tensor *input) {
int TestQuqntizeOp(int argc, char *argv[]) {
if (argc < 5) {
std::cout
<< "Usage: ./test-quantize-op batch_size channel height width [pad]"
<< std::endl;
std::cout << "Usage: ./test-quantize-op batch_size channel height width"
<< std::endl;
return 1;
}
int pad = 0;
int batch_size = atoi(argv[1]);
int channel = atoi(argv[2]);
int height = atoi(argv[3]);
int width = atoi(argv[4]);
if (argc == 6) {
pad = atoi(argv[5]);
}
std::cout << "batch_size: " << batch_size << ", channel: " << channel
<< ", height: " << height << ", width: " << width << std::endl;
framework::DDim dim =
......@@ -153,7 +122,6 @@ int TestQuqntizeOp(int argc, char *argv[]) {
auto output_scale_var = scope.get()->Var("output_scale");
framework::AttributeMap attrs;
attrs["paddings"].Set<vector<int>>(std::vector<int>({pad, pad}));
auto *op = new operators::QuantizeOp<CPU, float>("quantize", inputs, outputs,
attrs, scope);
op->InferShape();
......@@ -172,9 +140,9 @@ int TestQuqntizeOp(int argc, char *argv[]) {
framework::Tensor output_cmp;
output_cmp.Resize(output->dims());
float scale = 127 / output_scale_cmp;
// quantize<round::RoundToEven>(input, scale, pad, 0, &output_cmp);
// quantize<round::RoundAwayZero>(input, scale, pad, 0, &output_cmp);
quantize<round::RoundTowardsZero>(input, scale, pad, 0, &output_cmp);
// quantize<round::RoundToEven>(input, scale, &output_cmp);
// quantize<round::RoundAwayZero>(input, scale, &output_cmp);
quantize<round::RoundTowardsZero>(input, scale, &output_cmp);
int8_t *output_cmp_data = output_cmp.data<int8_t>();
for (int i = 0; i < output->numel(); ++i) {
PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册