Unverified commit c2173f91, authored by zhangshijin, committed by GitHub

Merge pull request #62 from Cambricon/fix_conv_quant

fix(quant): fix quant compute error
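Why the fix works: `cnmlCreateQuantizedParam` takes an integer power-of-two position plus a residual scale. The old call rounded `-log2(scale)` to an integer position and hard-coded the residual scale to 1, so any quantization step that is not an exact power of two was approximated, introducing compute error. The new call decomposes the step as `2^-pos * cnml_scale` with `cnml_scale = 2^pos * scale`, which reproduces the step exactly. Below is a minimal standalone sketch of that decomposition; the reconstruction formula `real = q * 2^-pos * scale` is inferred from the call sites in this diff, not from CNML documentation.

```cpp
#include <cmath>
#include <iostream>

// Same decomposition as the fixed Graph::SetComputingDataType:
// pos = floor(-log2(scale)) and cnml_scale = 2^pos * scale, so that
// 2^-pos * cnml_scale == scale exactly and cnml_scale lies in (0.5, 1].
inline int scale2position(float scale) {
  return std::floor(-std::log2(scale));
}

int main() {
  float scale = 0.3f;  // example quantization step, not a power of two
  int pos = scale2position(scale);              // floor(1.737...) = 1
  float cnml_scale = std::pow(2, pos) * scale;  // 2^1 * 0.3 = 0.6

  // Old behaviour: real ~= q * 2^-pos * 1   -> step 0.5, a ~67% error.
  // New behaviour: real  = q * 2^-pos * 0.6 -> step 0.3, exact.
  std::cout << "pos: " << pos << ", cnml_scale: " << cnml_scale << '\n';
  return 0;
}
```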
@@ -249,6 +249,10 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
     CNML_CALL(cnmlDestroyConvOpParam(&conv_param));
   } else {
     cnmlConvOpParam_t conv_param;
+    VLOG(5) << "conv param (" << input_var_name << ")"
+            << "stride: " << strides[0] << ',' << strides[1] << '\t'
+            << "dilations: " << dilations[0] << ',' << dilations[1] << '\t'
+            << "paddings: " << paddings[0] << ',' << paddings[2] << std::endl;
     CNML_CALL(cnmlCreateConvOpParam(&conv_param,
                                     strides[0],
                                     strides[1],
@@ -272,7 +276,7 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
     graph->SetComputingDataType(
         conv_op,
         filter_tensor->mlu_tensor(),
-        1 / *min_element(weight_scale.begin(), weight_scale.end()));
+        1 / *max_element(weight_scale.begin(), weight_scale.end()));
   }
   CNML_CALL(cnmlSetOperationComputingLayout(conv_op, CNML_NHWC));
   if (HasInputArg(op_info, scope, "Bias")) {
......
@@ -157,7 +157,7 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   graph->SetComputingDataType(
       fc_op,
       w_tensor->mlu_tensor(),
-      1 / *min_element(weight_scale.begin(), weight_scale.end()));
+      1 / *max_element(weight_scale.begin(), weight_scale.end()));
   graph->FuseOp(fc_op);
   CNML_CALL(cnmlDestroyBaseOp(&fc_op));
......
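Both the conv and fc converters above switch the computing-data-type scale from the minimum to the maximum per-channel weight scale. Deriving a single op-level multiplier from the smallest scale would push weights from channels with larger scales past the int8 range and clip them; the maximum scale keeps every channel representable. A small illustration with made-up scale values:

```cpp
#include <algorithm>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical per-channel scales, e.g. |w|_max / 127 for each channel.
  std::vector<float> weight_scale = {0.1f, 0.5f};
  float w = 63.5f;  // a weight from the channel with scale 0.5

  float s_min = *std::min_element(weight_scale.begin(), weight_scale.end());
  float s_max = *std::max_element(weight_scale.begin(), weight_scale.end());

  std::cout << w / s_min << '\n';  // 635 -> clipped to 127, large error
  std::cout << w / s_max << '\n';  // 127 -> fits int8 exactly
  return 0;
}
```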
@@ -215,8 +215,11 @@ class Graph {
                float scale,
                cnmlDataType_t data_type = CNML_DATA_INT8) {
     cnmlQuantizedParam_t quant_param;
-    CNML_CALL(
-        cnmlCreateQuantizedParam(&quant_param, scale2position(scale), 1, 0.0));
+    int pos = scale2position(scale);
+    auto cnml_scale = pow(2, pos) * scale;
+    VLOG(5) << "[cnml quantized param] pos: " << pos
+            << "\tscale: " << cnml_scale << std::endl;
+    CNML_CALL(cnmlCreateQuantizedParam(&quant_param, pos, cnml_scale, 0.0));
     CNML_CALL(
         cnmlSetOperationComputingDataType(op, tensor, data_type, quant_param));
     CNML_CALL(cnmlDestroyQuantizedParam(&quant_param));
......
@@ -60,8 +60,6 @@ void transpose(float* input_data,
   }
 }
-int scale2position(float scale) { return static_cast<int>(-std::log2(scale)); }
-
 void dequant(float* dst, int8_t* src, size_t size, float scale) {
   for (size_t i = 0; i < size; ++i) {
     dst[i] = static_cast<float>(src[i]) * scale;
......
@@ -36,7 +36,9 @@ void transpose(float* input_data,
                float* output_data,
                std::vector<int> input_shape,
                std::vector<int> axis);
-int scale2position(float scale);
+inline int scale2position(float scale) {
+  return std::floor(-std::log2(scale));
+}
 void dequant(float* dst, int8_t* src, size_t size, float scale);
 void dequant(float* dst,
......
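The header change also swaps `static_cast<int>` truncation for `std::floor`. The two agree for scales below 1, where `-log2(scale)` is positive, but truncation rounds toward zero, so for a scale above 1 it yields a position one too high and lets the residual scale `2^pos * scale` exceed 1, while `floor` keeps it in (0.5, 1]. A minimal check with an illustrative value:

```cpp
#include <cmath>
#include <iostream>

int main() {
  // For scale < 1 the two forms agree; for scale > 1 they differ.
  float scale = 1.5f;                             // illustrative value
  float x = -std::log2(scale);                    // -0.585...
  int truncated = static_cast<int>(x);            // 0  (rounds toward zero)
  int floored = static_cast<int>(std::floor(x));  // -1 (rounds down)

  // Residual scale 2^pos * scale for each choice of position:
  std::cout << std::pow(2, truncated) * scale << '\n';  // 1.5  (> 1)
  std::cout << std::pow(2, floored) * scale << '\n';    // 0.75 (in (0.5, 1])
  return 0;
}
```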