提交 405630c7 编写于 作者: H hjchen2

Fix quantize kernel when pad != 0

上级 ee79fcf4
...@@ -379,8 +379,8 @@ static void quantize_round_to_zero(const Tensor *input, const float scale, ...@@ -379,8 +379,8 @@ static void quantize_round_to_zero(const Tensor *input, const float scale,
const float *x3 = input3 + h * input_w; const float *x3 = input3 + h * input_w;
int loop = input_w >> 4; int loop = input_w >> 4;
int remain = input_w & 0xF; int remain = input_w & 0xF;
int pad_loop = paddings[1] >> 1; int pad_loop = paddings[1] >> 1; // (paddings[1] << 1) >> 2
int pad_remain = paddings[1] & 0x1; int pad_remain = (paddings[1] << 1) & 0x3;
int remain_steps = remain; int remain_steps = remain;
asm volatile( asm volatile(
"vdup.f32 q0, %[scale] \n" "vdup.f32 q0, %[scale] \n"
...@@ -596,7 +596,7 @@ static void quantize_round_to_zero(const Tensor *input, const float scale, ...@@ -596,7 +596,7 @@ static void quantize_round_to_zero(const Tensor *input, const float scale,
"store_pad_2w_%=: \n" "store_pad_2w_%=: \n"
"cmp %[pad_remain], #2 \n" "cmp %[pad_remain], #2 \n"
"ble store_pad_1w_%= \n" "blt store_pad_1w_%= \n"
"vst1.16 {d0[0]}, [%[y0]]! \n" "vst1.16 {d0[0]}, [%[y0]]! \n"
"vst1.16 {d0[0]}, [%[y1]]! \n" "vst1.16 {d0[0]}, [%[y1]]! \n"
"vst1.16 {d0[0]}, [%[y2]]! \n" "vst1.16 {d0[0]}, [%[y2]]! \n"
...@@ -605,7 +605,7 @@ static void quantize_round_to_zero(const Tensor *input, const float scale, ...@@ -605,7 +605,7 @@ static void quantize_round_to_zero(const Tensor *input, const float scale,
"store_pad_1w_%=: \n" "store_pad_1w_%=: \n"
"cmp %[pad_remain], #1 \n" "cmp %[pad_remain], #1 \n"
"ble end_%= \n" "blt end_%= \n"
"vst1.8 {d0[0]}, [%[y0]]! \n" "vst1.8 {d0[0]}, [%[y0]]! \n"
"vst1.8 {d0[0]}, [%[y1]]! \n" "vst1.8 {d0[0]}, [%[y1]]! \n"
"vst1.8 {d0[0]}, [%[y2]]! \n" "vst1.8 {d0[0]}, [%[y2]]! \n"
...@@ -669,8 +669,8 @@ static void quantize_round_to_zero(const Tensor *input, const float scale, ...@@ -669,8 +669,8 @@ static void quantize_round_to_zero(const Tensor *input, const float scale,
const float *x0 = input0 + h * input_w; const float *x0 = input0 + h * input_w;
int loop = input_w >> 4; int loop = input_w >> 4;
int remain = input_w & 0xF; int remain = input_w & 0xF;
int pad_loop = paddings[1] >> 1; int pad_loop = paddings[1] >> 1; // (paddings[1] << 1) >> 2
int pad_remain = paddings[1] & 0x1; int pad_remain = (paddings[1] << 1) & 0x3;
asm volatile( asm volatile(
"vdup.f32 q0, %[scale] \n" "vdup.f32 q0, %[scale] \n"
"cmp %[loop], #0 \n" "cmp %[loop], #0 \n"
...@@ -754,13 +754,13 @@ static void quantize_round_to_zero(const Tensor *input, const float scale, ...@@ -754,13 +754,13 @@ static void quantize_round_to_zero(const Tensor *input, const float scale,
"pad_remain_%=: \n" "pad_remain_%=: \n"
"cmp %[pad_remain], #2 \n" "cmp %[pad_remain], #2 \n"
"ble store_pad_1w_%= \n" "blt store_pad_1w_%= \n"
"vst1.16 {d0[0]}, [%[y0]]! \n" "vst1.16 {d0[0]}, [%[y0]]! \n"
"sub %[pad_remain], #2 \n" "sub %[pad_remain], #2 \n"
"store_pad_1w_%=: \n" "store_pad_1w_%=: \n"
"cmp %[pad_remain], #1 \n" "cmp %[pad_remain], #1 \n"
"ble end_%= \n" "blt end_%= \n"
"vst1.8 {d0[0]}, [%[y0]]! \n" "vst1.8 {d0[0]}, [%[y0]]! \n"
"end_%=: \n" "end_%=: \n"
: [x0] "+r"(x0), [y0] "+r"(y0), [loop] "+r"(loop), : [x0] "+r"(x0), [y0] "+r"(y0), [loop] "+r"(loop),
...@@ -795,10 +795,10 @@ void QuantizeKernel<CPU, float>::Compute(const QuantizeParam<CPU> &param) { ...@@ -795,10 +795,10 @@ void QuantizeKernel<CPU, float>::Compute(const QuantizeParam<CPU> &param) {
// only support int8 currently // only support int8 currently
float scale = 127 / max_abs; float scale = 127 / max_abs;
param.online_scale_->mutable_data<float>()[0] = max_abs; param.online_scale_->mutable_data<float>()[0] = max_abs;
// const auto &paddings = param.paddings_; const auto &paddings = param.paddings_;
std::vector<int> paddings = {0, 0}; // std::vector<int> paddings = {0, 0};
// const auto padding_val = param.padding_val_; // const auto padding_val = param.padding_val_;
int8_t padding_val = 127; int8_t padding_val = 0;
switch (param.round_type_) { switch (param.round_type_) {
case ROUND_NEAREST_TO_EVEN: case ROUND_NEAREST_TO_EVEN:
quantize_round_to_even(input, scale, paddings, padding_val, output); quantize_round_to_even(input, scale, paddings, padding_val, output);
......
...@@ -2536,6 +2536,11 @@ class QuantizeParam : public OpParam { ...@@ -2536,6 +2536,11 @@ class QuantizeParam : public OpParam {
if (HasAttr("round_type", attrs)) { if (HasAttr("round_type", attrs)) {
round_type_ = GetAttr<RoundType>("round_type", attrs); round_type_ = GetAttr<RoundType>("round_type", attrs);
} }
// get paddings
paddings_ = std::vector<int>({0, 0});
if (HasAttr("paddings", attrs)) {
paddings_ = GetAttr<vector<int>>("paddings", attrs);
}
} }
public: public:
......
...@@ -22,7 +22,10 @@ namespace operators { ...@@ -22,7 +22,10 @@ namespace operators {
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
void QuantizeOp<DeviceType, T>::InferShape() const { void QuantizeOp<DeviceType, T>::InferShape() const {
const auto &input_dims = this->param_.input_->dims(); auto input_dims = this->param_.input_->dims();
const std::vector<int> &paddings = this->param_.paddings_;
input_dims[2] += 2 * paddings[0];
input_dims[3] += 2 * paddings[1];
this->param_.output_->Resize(input_dims); this->param_.output_->Resize(input_dims);
auto scale_dims = framework::make_ddim(std::vector<int>{1}); auto scale_dims = framework::make_ddim(std::vector<int>{1});
this->param_.online_scale_->Resize(scale_dims); this->param_.online_scale_->Resize(scale_dims);
......
...@@ -12,58 +12,128 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,58 +12,128 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <iostream>
#include "../test_helper.h" #include "../test_helper.h"
#include "../test_include.h" #include "../test_include.h"
#include "operators/quantize_op.h" #include "operators/quantize_op.h"
namespace paddle_mobile { namespace paddle_mobile {
// Rounding modes used by the reference (test-side) quantizer. Each enumerator
// selects one Round<T> specialization, which quantize<T>() applies to every
// scaled input value before storing it as int8.
namespace round {
enum RoundType {
// Nearest integer; halfway cases are meant to go to the even neighbour.
RoundToEven = 0,
// Nearest integer via std::round, i.e. halfway cases go away from zero.
RoundAwayZero = 1,
// Plain int8_t conversion of the float, which drops the fractional part.
RoundTowardsZero = 2,
};
}  // namespace round
static float find_abs_max(const Tensor *input) { template <round::RoundType T>
float max_abs = 0.f; static int8_t Round(float x);
const float *x = input->data<const float>();
size_t size = input->numel(); template <>
for (size_t i = 0; i < size; ++i) { static int8_t Round<round::RoundAwayZero>(float x) {
float value = std::abs(x[i]); return std::round(x);
if (value > max_abs) {
max_abs = value;
}
}
return max_abs;
} }
static void quantize_round_to_even(const Tensor *input, const float scale, template <>
Tensor *output) { static int8_t Round<round::RoundTowardsZero>(float x) {
const float *x = input->data<const float>(); return int8_t(x);
int8_t *y = output->mutable_data<int8_t>(); }
size_t size = input->numel();
for (size_t i = 0; i < size; ++i) { template <>
float value = x[i] * scale; static int8_t Round<round::RoundToEven>(float x) {
float v = round(value); int8_t ret = 0;
float v = std::round(x);
int32_t q = (int32_t)v; int32_t q = (int32_t)v;
if (abs(abs(q - value) - 0.5) > 0) { if (abs(abs(q - x) - 0.5) > 0) {
y[i] = q; ret = q;
} else { } else {
if (abs(q) % 2 == 0) { if (abs(q) % 2 == 0) {
y[i] = q; ret = q;
} else { } else {
y[i] = q + ((q > 0) ? -1 : 1); ret = q + ((q > 0) ? -1 : 1);
}
} }
} }
return ret;
} }
static void quantize_round_to_nearest(const Tensor *input, const float scale, template <round::RoundType T>
Tensor *output) { static void quantize(const Tensor *input, const float scale, const int pad,
const int8_t pad_val, Tensor *output) {
int batch_size = input->dims()[0];
int channels = input->dims()[1];
int input_h = input->dims()[2];
int input_w = input->dims()[3];
int output_h = output->dims()[2];
int output_w = output->dims()[3];
size_t input_spatial = input_h * input_w;
size_t output_spatial = output_h * output_w;
const float *x = input->data<const float>(); const float *x = input->data<const float>();
int8_t *y = output->mutable_data<int8_t>(); int8_t *y = output->mutable_data<int8_t>();
std::cout << "pad: " << pad << ", pad_val: " << int(pad_val) << std::endl;
for (int nc = 0; nc < batch_size * channels; ++nc) {
const float *xh = x + nc * input_spatial;
int8_t *yh = y + nc * output_spatial;
// pad top
for (int h = 0; h < pad; ++h, yh += output_w) {
for (int w = 0; w < output_w; ++w) {
yh[w] = pad_val;
}
}
for (int h = 0; h < input_h; ++h, yh += output_w, xh += input_w) {
// pad left
for (int w = 0; w < pad; ++w) {
yh[w] = pad_val;
}
for (int w = 0; w < input_w; ++w) {
yh[w + pad] = Round<T>(xh[w] * scale);
}
// pad right
for (int w = 0; w < pad; ++w) {
yh[pad + input_w + w] = pad_val;
}
}
// pad bottom
for (int h = 0; h < pad; ++h, yh += output_w) {
for (int w = 0; w < output_w; ++w) {
yh[w] = pad_val;
}
}
}
}
static float find_abs_max(const Tensor *input) {
float max_abs = 0.f;
const float *x = input->data<const float>();
size_t size = input->numel(); size_t size = input->numel();
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
y[i] = round(x[i] * scale); float value = std::abs(x[i]);
if (value > max_abs) {
max_abs = value;
} }
}
return max_abs;
} }
int TestQuqntizeOp() { int TestQuqntizeOp(int argc, char *argv[]) {
framework::DDim dim = framework::make_ddim({1, 3, 224, 224}); if (argc < 5) {
std::cout
<< "Usage: ./test-quantize-op batch_size channel height width [pad]"
<< std::endl;
return 1;
}
int pad = 0;
int batch_size = atoi(argv[1]);
int channel = atoi(argv[2]);
int height = atoi(argv[3]);
int width = atoi(argv[4]);
if (argc == 6) {
pad = atoi(argv[5]);
}
std::cout << "batch_size: " << batch_size << ", channel: " << channel
<< ", height: " << height << ", width: " << width << std::endl;
framework::DDim dim =
framework::make_ddim({batch_size, channel, height, width});
VariableNameMap inputs; VariableNameMap inputs;
VariableNameMap outputs; VariableNameMap outputs;
...@@ -80,6 +150,7 @@ int TestQuqntizeOp() { ...@@ -80,6 +150,7 @@ int TestQuqntizeOp() {
auto output_scale_var = scope.get()->Var("output_scale"); auto output_scale_var = scope.get()->Var("output_scale");
framework::AttributeMap attrs; framework::AttributeMap attrs;
attrs["paddings"].Set<vector<int>>(std::vector<int>({pad, pad}));
auto *op = new operators::QuantizeOp<CPU, float>("quantize", inputs, outputs, auto *op = new operators::QuantizeOp<CPU, float>("quantize", inputs, outputs,
attrs, scope); attrs, scope);
op->InferShape(); op->InferShape();
...@@ -96,10 +167,11 @@ int TestQuqntizeOp() { ...@@ -96,10 +167,11 @@ int TestQuqntizeOp() {
output_scale_cmp, output_scale_data[0]); output_scale_cmp, output_scale_data[0]);
framework::Tensor output_cmp; framework::Tensor output_cmp;
output_cmp.Resize(dim); output_cmp.Resize(output->dims());
float scale = 127 / output_scale_cmp; float scale = 127 / output_scale_cmp;
// quantize_round_to_even(input, scale, &output_cmp); // quantize<round::RoundToEven>(input, scale, pad, 0, &output_cmp);
quantize_round_to_nearest(input, scale, &output_cmp); // quantize<round::RoundAwayZero>(input, scale, pad, 0, &output_cmp);
quantize<round::RoundTowardsZero>(input, scale, pad, 0, &output_cmp);
int8_t *output_cmp_data = output_cmp.data<int8_t>(); int8_t *output_cmp_data = output_cmp.data<int8_t>();
for (int i = 0; i < output->numel(); ++i) { for (int i = 0; i < output->numel(); ++i) {
PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i], PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i],
...@@ -113,4 +185,6 @@ int TestQuqntizeOp() { ...@@ -113,4 +185,6 @@ int TestQuqntizeOp() {
} // namespace paddle_mobile } // namespace paddle_mobile
int main() { return paddle_mobile::TestQuqntizeOp(); } int main(int argc, char *argv[]) {
return paddle_mobile::TestQuqntizeOp(argc, argv);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册