Unverified · Commit 8fabb1c3 · Authored by: cc · Committed by: GitHub

Add test attribute in channelwise_quant op, test=develop (#27742)

* Add test attribute in channelwise_quant op, test=develop
Parent: 81d3992c
@@ -404,6 +404,10 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker
                               "the received is %d",
                               bit_length));
         });
+    AddAttr<bool>("is_test",
+                  "(bool, default false) Set to true for inference only, false "
+                  "for training. Some layers may run faster when this is true.")
+        .SetDefault(false);
     AddComment(R"DOC(
 The scale of FakeChannelWiseQuantize operator is a vector.
 In detail, each channel of the input X has a scale value.
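For reference, the attribute registered above is set like any other op attribute when the op is inserted from Python. A minimal sketch, assuming the fluid static-graph API of this era; the variable names here are hypothetical, while the op type, input/output names, and attribute names come from this diff:

import paddle.fluid as fluid

prog = fluid.Program()
block = prog.global_block()
x = block.create_var(name='x', shape=[16, 32], dtype='float32')
out = block.create_var(name='x.quantized', shape=[16, 32], dtype='float32')
scale = block.create_var(name='x.scale', shape=[16], dtype='float32')
# 'is_test' defaults to False; True selects the inference path that reuses
# the stored OutScale instead of recomputing it (see the kernel hunk below).
block.append_op(
    type='fake_channel_wise_quantize_abs_max',
    inputs={'X': x},
    outputs={'Out': out, 'OutScale': scale},
    attrs={'bit_length': 8, 'quant_axis': 0, 'is_test': True})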
@@ -146,16 +146,19 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
     auto* out = context.Output<framework::Tensor>("Out");
     auto* out_scale = context.Output<framework::Tensor>("OutScale");
-    T* out_scale_data = out_scale->mutable_data<T>(context.GetPlace());
     out->mutable_data<T>(context.GetPlace());

     int bit_length = context.Attr<int>("bit_length");
     int bin_cnt = std::pow(2, bit_length - 1) - 1;
     int quant_axis = context.Attr<int>("quant_axis");
+    bool is_test = context.Attr<bool>("is_test");

     auto& dev_ctx = context.template device_context<DeviceContext>();
-    FindChannelAbsMaxFunctor<DeviceContext, T>()(dev_ctx, *in, quant_axis,
-                                                 out_scale_data);
+    if (!is_test) {
+      T* out_scale_data = out_scale->mutable_data<T>(context.GetPlace());
+      FindChannelAbsMaxFunctor<DeviceContext, T>()(dev_ctx, *in, quant_axis,
+                                                   out_scale_data);
+    }
     ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
         dev_ctx, *in, *out_scale, bin_cnt, quant_axis, out);
   }
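The kernel change reads as follows: in training mode the per-channel scale is recomputed from the current input via FindChannelAbsMaxFunctor; with is_test=true the op reuses whatever OutScale already holds (the scale found during training). A plain-NumPy sketch of that semantics, not Paddle code — the function and argument names are made up for illustration:

import numpy as np

def fake_channel_wise_quant(x, bit_length=8, quant_axis=0,
                            is_test=False, saved_scale=None):
    bnt = (1 << (bit_length - 1)) - 1             # bin_cnt in the kernel
    if is_test:
        scale = saved_scale                       # reuse the trained scale
    else:
        reduce_axes = tuple(i for i in range(x.ndim) if i != quant_axis)
        scale = np.abs(x).max(axis=reduce_axes)   # per-channel abs-max
    shape = [1] * x.ndim
    shape[quant_axis] = -1
    s = scale.reshape(shape)                      # broadcast over channels
    out = np.round(np.clip(x, -s, s) / s * bnt)   # clip, then fake-quantize
    return out, scale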
@@ -758,6 +758,7 @@ class QuantizationTransformPass(object):
             attrs={
                 'bit_length': quant_bits,
                 'quant_axis': quant_axis,
+                'is_test': self._is_test,
                 'op_role': core.op_proto_and_checker_maker.OpRole.Forward
             },
             inputs={'X': var_node},
@@ -1125,7 +1126,7 @@ class QuantizationFreezePass(object):
                     self._restore_var(input_arg_name, quantized_param_v)
                 self._remove_fake_quant_and_dequant_op(graph, op_node)
-            # Remove all fake dequant op
+        # Remove all fake dequant op
         ops = graph.all_op_nodes()
         for op_node in ops:
             op_name = op_node.name()
@@ -1331,16 +1332,25 @@ class QuantizationFreezePass(object):
     def _quant(self, x, scale, num_bits, quant_axis):
         assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.'
+        bnt = (1 << (num_bits - 1)) - 1
+
+        def _clip(x, scale):
+            x[x > scale] = scale
+            x[x < -scale] = -scale
+            return x
+
         if isinstance(scale, list):
             for i, s in enumerate(scale):
                 if quant_axis == 0:
-                    x[i] = np.round(x[i] / s * ((1 << (num_bits - 1)) - 1))
+                    x[i] = _clip(x[i], s)
+                    x[i] = np.round(x[i] / s * bnt)
                 else:
-                    x[:, i] = np.round(x[:, i] / s * (
-                        (1 << (num_bits - 1)) - 1))
+                    x[:, i] = _clip(x[:, i], s)
+                    x[:, i] = np.round(x[:, i] / s * bnt)
             return x
         else:
-            return np.round(x / scale * ((1 << (num_bits - 1)) - 1))
+            x = _clip(x, scale)
+            x = np.round(x / scale * bnt)
+            return x

 class ConvertToInt8Pass(object):
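A quick check of the fixed _quant above, with illustrative numbers: before this commit, an input that exceeds its scale would be mapped outside the signed num_bits range; the added _clip keeps results in [-bnt, bnt].

import numpy as np

def _quant(x, scale, num_bits=8):
    # scalar-scale branch of the fixed method; np.clip matches _clip here
    bnt = (1 << (num_bits - 1)) - 1
    x = np.clip(x, -scale, scale)
    return np.round(x / scale * bnt)

x = np.array([0.5, 1.0, 1.3])        # 1.3 exceeds the scale of 1.0
print(_quant(x, scale=1.0))          # [ 64. 127. 127.], no longer 165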