Unverified commit b47478ef, authored by cc, committed by GitHub

[dygraph qat] Use layer to calculate output scale (#31861)

* Use layer to calculate output scale
* Add backward for moving_average_abs_max_scale and save output scales to the op's attributes
Parent commit: c3974d0e
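In user code, the reworked dygraph QAT flow is driven the same way as before; a minimal sketch follows (the import path, model, and save path are illustrative, not taken from this commit):

import paddle
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware

# Illustrative model; any dygraph.Layer-based network would do.
model = paddle.vision.models.LeNet()

quanter = ImperativeQuantAware()
# Replace supported layers with fake-quant wrappers for inputs and, after
# this commit, with QuantizedOutputLayer wrappers that track output scales.
quanter.quantize(model)

# ... train the quantized model as usual ...

# Export an inference model; the collected output scales are written into
# the ops' "out_threshold" attribute while saving.
quanter.save_quantized_model(
    layer=model,
    path="./quantized_lenet/lenet",
    input_spec=[
        paddle.static.InputSpec(shape=[None, 1, 28, 28], dtype='float32')
    ])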
@@ -649,13 +649,18 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel {
                    "MovingAverageAbsMaxScale");
     OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
                    "MovingAverageAbsMaxScale");
     if (ctx->HasOutput("OutState")) {
       ctx->SetOutputDim("OutState", {1});
     }
     if (ctx->HasOutput("OutAccum")) {
       ctx->SetOutputDim("OutAccum", {1});
     }
-    ctx->SetOutputDim("OutScale", {1});
+    if (ctx->HasOutput("Out")) {
+      ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+      ctx->SetOutputDim("OutScale", {1});
+      ctx->ShareLoD("X", /*->*/ "Out");
+    }
   }

  protected:
@@ -673,6 +678,9 @@ class MovingAverageAbsMaxScaleOpMaker
     AddInput("X", "(Tensor) Input is float data type.");
     AddInput("InAccum", "Last accum.").AsDispensable();
     AddInput("InState", "Last state.").AsDispensable();
+    AddOutput("Out",
+              "(Tensor) Output tensor is just equivalent to the input tensor.")
+        .AsDispensable();
     AddOutput("OutScale", " Current scale");
     AddOutput("OutState", "(Tensor) state buffer.").AsDispensable();
     AddOutput("OutAccum", "(Tensor) accum buffer.").AsDispensable();
@@ -693,7 +701,7 @@ $$Out = X$$
   }
 };

-class FakeQuantDequantGradOp : public framework::OperatorWithKernel {
+class StrightThroughEstimatorGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -701,9 +709,9 @@ class FakeQuantDequantGradOp : public framework::OperatorWithKernel {
     auto out_grad_name = framework::GradVarName("Out");
     auto x_grad_name = framework::GradVarName("X");
     OP_INOUT_CHECK(ctx->HasInput(out_grad_name), "Input", out_grad_name,
-                   "FakeQuantDequantGradOp");
+                   "StrightThroughEstimatorGradOp");
     OP_INOUT_CHECK(ctx->HasOutput(x_grad_name), "Output", x_grad_name,
-                   "FakeQuantDequantGradOp");
+                   "StrightThroughEstimatorGradOp");
     ctx->SetOutputDim(x_grad_name, ctx->GetInputDim(out_grad_name));
   }
@@ -717,13 +725,13 @@ class FakeQuantDequantGradOp : public framework::OperatorWithKernel {
 };

 template <typename T>
-class FakeQuantDequantGradMaker : public framework::SingleGradOpMaker<T> {
+class StrightThroughEstimatorMaker : public framework::SingleGradOpMaker<T> {
  public:
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  protected:
   void Apply(GradOpPtr<T> grad_op) const override {
-    grad_op->SetType("fake_quantize_dequantize_grad");
+    grad_op->SetType("stright_throuth_estimator_grad");
     grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
     grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     grad_op->SetAttrMap(this->Attrs());
@@ -744,11 +752,11 @@ REGISTER_OPERATOR(
 REGISTER_OP_CPU_KERNEL(fake_quantize_abs_max,
                        ops::FakeQuantizeAbsMaxKernel<CPU, float>);

-REGISTER_OPERATOR(fake_quantize_dequantize_abs_max,
-                  ops::FakeQuantOrWithDequantAbsMaxOp,
-                  ops::FakeQuantOrWithDequantAbsMaxOpMaker,
-                  ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
-                  ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(
+    fake_quantize_dequantize_abs_max, ops::FakeQuantOrWithDequantAbsMaxOp,
+    ops::FakeQuantOrWithDequantAbsMaxOpMaker,
+    ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
+    ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_abs_max,
                        ops::FakeQuantizeDequantizeAbsMaxKernel<CPU, float>);
@@ -769,11 +777,12 @@ REGISTER_OPERATOR(
 REGISTER_OP_CPU_KERNEL(fake_quantize_moving_average_abs_max,
                        ops::FakeQuantizeMovingAverageAbsMaxKernel<CPU, float>);

-REGISTER_OPERATOR(fake_quantize_dequantize_moving_average_abs_max,
-                  ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp,
-                  ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
-                  ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
-                  ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(
+    fake_quantize_dequantize_moving_average_abs_max,
+    ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp,
+    ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
+    ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
+    ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(
     fake_quantize_dequantize_moving_average_abs_max,
     ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CPU, float>);
@@ -789,20 +798,22 @@ REGISTER_OP_CPU_KERNEL(fake_channel_wise_quantize_abs_max,
 REGISTER_OPERATOR(
     moving_average_abs_max_scale, ops::MovingAverageAbsMaxScaleOp,
     ops::MovingAverageAbsMaxScaleOpMaker,
-    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+    ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
+    ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(moving_average_abs_max_scale,
                        ops::MovingAverageAbsMaxScaleKernel<CPU, float>);

-REGISTER_OPERATOR(fake_quantize_dequantize_grad, ops::FakeQuantDequantGradOp);
-REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_grad,
-                       ops::FakeQuantDequantGradKernel<CPU, float>);
+REGISTER_OPERATOR(stright_throuth_estimator_grad,
+                  ops::StrightThroughEstimatorGradOp);
+REGISTER_OP_CPU_KERNEL(stright_throuth_estimator_grad,
+                       ops::StrightThroughEstimatorGradKernel<CPU, float>);

-REGISTER_OPERATOR(fake_channel_wise_quantize_dequantize_abs_max,
-                  ops::FakeChannelWiseQuantizeDequantizeAbsMaxOp,
-                  ops::FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker,
-                  ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
-                  ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(
+    fake_channel_wise_quantize_dequantize_abs_max,
+    ops::FakeChannelWiseQuantizeDequantizeAbsMaxOp,
+    ops::FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker,
+    ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
+    ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(
     fake_channel_wise_quantize_dequantize_abs_max,
     ops::FakeChannelWiseQuantizeDequantizeAbsMaxKernel<CPU, float>);
@@ -820,4 +831,8 @@ REGISTER_OP_VERSION(moving_average_abs_max_scale)
             "Out",
             "Delete output in order to make the inference model not "
             "save moving_average_abs_max_scale operator. This will "
-            "make the quantitative model be correctly applied in inference."));
+            "make the quantitative model be correctly applied in inference."))
+    .AddCheckpoint(
+        R"ROC(Incompatible upgrade of output [Out])ROC",
+        paddle::framework::compatible::OpVersionDesc().NewOutput(
+            "Out", "In order to support dygraph qat, add output again."));
@@ -543,8 +543,8 @@ REGISTER_OP_CUDA_KERNEL(moving_average_abs_max_scale,
 REGISTER_OP_CUDA_KERNEL(
     fake_quantize_dequantize_moving_average_abs_max,
     ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CUDA, float>);
-REGISTER_OP_CUDA_KERNEL(fake_quantize_dequantize_grad,
-                        ops::FakeQuantDequantGradKernel<CUDA, float>);
+REGISTER_OP_CUDA_KERNEL(stright_throuth_estimator_grad,
+                        ops::StrightThroughEstimatorGradKernel<CUDA, float>);
 REGISTER_OP_CUDA_KERNEL(
     fake_channel_wise_quantize_dequantize_abs_max,
     ops::FakeChannelWiseQuantizeDequantizeAbsMaxKernel<CUDA, float>);
@@ -314,6 +314,12 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
     auto* in = context.Input<framework::Tensor>("X");
     auto& dev_ctx = context.template device_context<DeviceContext>();

+    if (context.HasOutput("Out")) {
+      auto* out = context.Output<framework::Tensor>("Out");
+      out->mutable_data<T>(context.GetPlace());
+      framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out);
+    }
+
     bool is_test = context.Attr<bool>("is_test");
     // testing
     if (is_test) {
@@ -344,17 +350,17 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
 };

 template <typename DeviceContext, typename T>
-class FakeQuantDequantGradKernel : public framework::OpKernel<T> {
+class StrightThroughEstimatorGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* d_out =
         context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
     auto x_grad_name = framework::GradVarName("X");
     auto* d_x = context.Output<framework::LoDTensor>(x_grad_name);
-    PADDLE_ENFORCE_NOT_NULL(
-        d_x, platform::errors::PreconditionNotMet(
-                 "FakeQuantDequantGradOp doesn't have the output named %s.",
-                 x_grad_name));
+    PADDLE_ENFORCE_NOT_NULL(d_x, platform::errors::PreconditionNotMet(
+                                     "StrightThroughEstimatorGradKernel "
+                                     "doesn't have the output named %s.",
+                                     x_grad_name));

     // Initialize dx as same as d_out
     d_x->mutable_data<T>(context.GetPlace());
......
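The pair of kernels above implements a straight-through estimator: the forward output of moving_average_abs_max_scale is just a copy of X, and the backward passes the output gradient through to X unchanged. A NumPy sketch of that behaviour (a reference model for illustration, not the actual kernels):

import numpy as np

def moving_average_abs_max_scale_forward(x):
    # Forward: Out is simply a copy of X; the scale statistics are tracked
    # in separate accum/state tensors.
    return x.copy()

def straight_through_backward(d_out):
    # Backward (straight-through estimator): the gradient w.r.t. X is the
    # gradient w.r.t. Out, unchanged -- the same d_x = d_out copy the new
    # grad kernel performs.
    return d_out.copy()

x = np.random.randn(4, 3).astype("float32")
assert np.array_equal(moving_average_abs_max_scale_forward(x), x)
d_out = np.ones_like(x)
assert np.array_equal(straight_through_backward(d_out), d_out)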
@@ -84,7 +84,8 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
     {"matrix_nms", {"Out", "Index", "RoisNum"}},
     {"distribute_fpn_proposals",
      {"MultiFpnRois", "RestoreIndex", "MultiLevelRoIsNum"}},
-    {"moving_average_abs_max_scale", {"OutScale", "OutAccum", "OutState"}},
+    {"moving_average_abs_max_scale",
+     {"Out", "OutScale", "OutAccum", "OutState"}},
     {"multiclass_nms3", {"Out", "NmsRoisNum"}},
     {"generate_proposals_v2", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}},
     {"momentum", {"ParamOut", "VelocityOut"}},
@@ -137,7 +138,8 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
     {"check_finite_and_unscale", {"Out", "FoundInfinite"}},
     {"update_loss_scaling",
      {"Out", "LossScaling", "OutGoodSteps", "OutBadSteps"}},
-    {"moving_average_abs_max_scale", {"OutScale", "OutAccum", "OutState"}},
+    {"moving_average_abs_max_scale",
+     {"Out", "OutScale", "OutAccum", "OutState"}},
     {"lamb",
      {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
     {"rnn", {"DropoutState"}},
......
@@ -21,14 +21,14 @@ import warnings

 import paddle
 from paddle.fluid import dygraph, core, framework, unique_name
-from paddle.fluid.executor import Executor
+from paddle.fluid.executor import Executor, global_scope
 from paddle.fluid.param_attr import ParamAttr
 from paddle.fluid.initializer import Constant
 from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
 from paddle.fluid.io import load_inference_model, save_inference_model
 from paddle.fluid.log_helper import get_logger
-from . import quant_nn
 from .. import quantization_pass
+from . import quant_nn
 from . import utils

 __all__ = ['ImperativeQuantAware']
@@ -201,7 +201,7 @@
         self._quantize_inputs = ImperativeQuantizeInputs(**kwargs)
-        self._calc_output_scale = ImperativeCalcOutputScale()
+        self._quantize_outputs = ImperativeQuantizeOutputs()

     def quantize(self, model):
         """
@@ -219,11 +219,11 @@
         assert isinstance(model, dygraph.Layer), \
             "The model must be the instance of dygraph.Layer."

         self._quantize_inputs.apply(model)
-        self._calc_output_scale.apply(model)
+        self._quantize_outputs.apply(model)

     def save_quantized_model(self, layer, path, input_spec=None, **config):
-        self._calc_output_scale.save_quantized_model(layer, path, input_spec,
-                                                     **config)
+        self._quantize_outputs.save_quantized_model(layer, path, input_spec,
+                                                    **config)


 class ImperativeQuantizeInputs(object):
@@ -323,10 +323,10 @@
             idx += 1
             target = name[last_idx:idx]

-            quant_layer = self._get_quantized_layer(layer)
+            quant_layer = self._get_input_quantized_layer(layer)
             setattr(obj, target, quant_layer)

-    def _get_quantized_layer(self, layer):
+    def _get_input_quantized_layer(self, layer):
         quant_layer_name = None
         for key, value in utils.quant_input_layers_map.items():
             if isinstance(layer, value):
@@ -343,24 +343,26 @@
         return quant_nn.__dict__[quant_layer_name](layer, **self._kwargs)
-class ImperativeCalcOutputScale(object):
+class ImperativeQuantizeOutputs(object):
+    """
+    Calculate the output scales for some layers.
+    """
+
     def __init__(self, moving_rate=0.9):
         """
-        Add the logic of calculating and setting output scales of some layers.
+        The constructor for ImperativeQuantizeOutputs.

         Args:
             moving_rate(float): The decay coefficient of moving average.
                 The default value is 0.9.
         """
-        super(ImperativeCalcOutputScale, self).__init__()
+        super(ImperativeQuantizeOutputs, self).__init__()
         self._moving_rate = moving_rate
-        self._register_hook_handle_list = []
-        self._out_scale_dict = collections.OrderedDict()

     def apply(self, model):
         """
-        Insert the `moving_average_abs_max_scale` op to calculate output
-        scale of specific layers in model.
+        Insert the `moving_average_abs_max_scale` layers to calculate the
+        output scales for specific layers in the dygraph model.

         Args:
             model(fluid.dygraph.Layer): The target model which would be
@@ -372,14 +374,25 @@
         assert isinstance(model, dygraph.Layer), \
             "The model must be the instance of dygraph.Layer."

-        # Calculate the target ops's output scale, and don't consider
-        # the skip_quant attr
-        for _, layer in model.named_sublayers():
-            if self._is_target_layer(layer):
-                self._init_scale_params(layer)
-                hook_handle = layer.register_forward_post_hook(
-                    self._calc_output_scale_hook)
-                self._register_hook_handle_list.append(hook_handle)
+        for name, layer in model.named_sublayers():
+            if not self._is_target_layer(layer):
+                continue
+
+            # TODO(jc): optimize this module
+            last_idx = 0
+            idx = 0
+            obj = model
+            while idx < len(name):
+                if (name[idx] == '.'):
+                    if hasattr(obj, name[last_idx:idx]):
+                        obj = getattr(obj, name[last_idx:idx])
+                        last_idx = idx + 1
+                idx += 1
+            target = name[last_idx:idx]
+
+            quant_layer = quant_nn.__dict__["QuantizedOutputLayer"](
+                layer, self._moving_rate)
+            setattr(obj, target, quant_layer)
     def save_quantized_model(self, layer, path, input_spec=None, **config):
         """
@@ -409,33 +422,18 @@
         Returns:
             None
         """
         assert isinstance(layer, dygraph.Layer), \
             "The model must be the instance of dygraph.Layer."

-        self._gather_output_scale(layer)
-
-        with dygraph.guard():
-            layer.eval()
-            for handle in self._register_hook_handle_list:
-                handle.remove()
         paddle.jit.save(layer=layer, path=path, input_spec=input_spec, **config)

-        if len(self._out_scale_dict) == 0:
-            warnings.warn("Warning: No Layer of the model while to be " \
-                "saved contains the out_threshold attribute, so the " \
-                "generated inference model would not contain the " \
-                "out_threshold.")
-            return
-
-        # load static model
         is_dynamic_mode = False
         if paddle.in_dynamic_mode():
             is_dynamic_mode = True
             paddle.enable_static()

-        place = core.CUDAPlace(0) if core.is_compiled_with_cuda() \
-            else core.CPUPlace()
+        place = core.CPUPlace()
+        scope = global_scope()
         exe = Executor(place)

         dirname = os.path.dirname(path)
@@ -450,20 +448,10 @@
                 model_filename=model_filename,
                 params_filename=params_filename))

-        # TODO(jc): analyse whether the dygraph model has
-        # several blocks before applying qat
-        assert infer_program.num_blocks == 1, \
-            "Quantization aware training (QAT) requires the program " \
-            "only has a block for now. When the model has if-else or " \
-            "while, the program will have several blocks."
-
-        # set output scales to the static model
-        self._save_output_scale(infer_program)
+        self._save_output_scale(infer_program, scope)

-        # process skip quant
         self._set_skip_quant_attr(infer_program)

-        # save the final quantized model that has output scales
         save_inference_model(
             dirname=dirname,
             feeded_var_names=feed_target_names,
@@ -476,144 +464,42 @@
         if is_dynamic_mode:
             paddle.disable_static()
-    def _gather_output_scale(self, layer):
-        """
-        Gather all output scales to self._out_scale_dict
-        """
-        with dygraph.guard():
-            layer.eval()
-            for _, sub_layer in layer.named_sublayers():
-                if self._is_target_layer(sub_layer):
-                    layer_name = sub_layer.full_name()
-                    if hasattr(sub_layer, "_quant_out_scale"):
-                        self._out_scale_dict[layer_name] = float(
-                            sub_layer._quant_out_scale)
-
-    def _save_output_scale(self, infer_program):
+    def _is_target_layer(self, layer):
         """
-        Save all output scales to the corresponding ops in static
-        inference program.
-        Because the Layer in dygraph may correspond to multiple ops
-        in static program after being saved. To ensure correctness,
-        the outscale collected for output of dygraph Layer can only
-        be set to the last op in the corresponding ops in static program.
+        Whether the layer needs to calculate output scales.
         """
-        assert infer_program.num_blocks == 1, \
-            "The inference program should only have a block."
-
-        global_block = infer_program.global_block()
-        target_ops = global_block.ops
-
-        scale_idx = 0
-        op_idx = 0
-        attr_name = "out_threshold"
-
-        for scale_name, scale_value in self._out_scale_dict.items():
-            while True:
-                if op_idx >= len(target_ops):
-                    break
-
-                op = target_ops[op_idx]
-                if not self._is_scale_op_matched(scale_name, op, global_block):
-                    op_idx += 1
-                else:
-                    if op.type in utils.weight_op_types \
-                        and op_idx + 1 < len(target_ops) \
-                        and target_ops[op_idx+1].type == "elementwise_add":
-                        target_ops[op_idx + 1]._set_attr(attr_name, scale_value)
-                        op_idx += 2
-                    else:
-                        op._set_attr(attr_name, scale_value)
-                        op_idx += 1
-                    scale_idx += 1
-                    break
-
-        if scale_idx != len(self._out_scale_dict):
-            _logger.warning("Warning: the model have %s output scales, "\
-                "but it only saves %s output scales." \
-                % (len(self._out_scale_dict), scale_idx))
-
-    def _is_target_layer(self, layer):
         return isinstance(layer, tuple(utils.quant_output_layers_map.values())) \
-            or ('quantized_' in layer.full_name() and \
+            or ('quantized' in layer.full_name() and \
                 'quantized_noweight' not in layer.full_name())

-    def _init_scale_params(self, layer, name=None):
+    def _save_output_scale(self, program, scope):
         """
-        Init the scale params for calculating output scales and save them in the
-        target layer.
-        After the users define the dygraph model, the hooks for calculating output
-        scales will not execute immediately. If the users load parameters form
-        checkpoint and save the quantized inference model immediately, the inference
-        model would not be saved successfully. Beacuse the dygraph_to_static requires
-        that the parameters created in __init__, but the uniqueness of hook make it
-        impossible to create parameters in __init__. To avoid this mistake, we define
-        the scale parameters in the beginning instead of hook.
+        Save all output scales to the corresponding ops in static
+        inference program and delete 'moving_average_abs_max_scale' ops.
         """
+        for block in program.blocks:
+            for op in block.ops:
+                if op.type == "moving_average_abs_max_scale":
+                    in_var_name = op.input('X')[0]
+                    out_var_name = op.output('Out')[0]
+                    out_scale_name = op.output('OutScale')[0]

-        def _create_param(in_layer, first_name, last_name, dtype):
-            prefix = '{}.{}'.format(first_name, last_name) \
-                if first_name else 'outscale.{}'.format(last_name)
-            attr = ParamAttr(
-                name=unique_name.generate(prefix),
-                initializer=Constant(1),
-                trainable=False)
-            param = in_layer.create_parameter(shape=[1], attr=attr, dtype=dtype)
-            return param
-
-        dtype = layer._dtype if layer._dtype is not None else "float32"
-        if dtype not in ["float32", "float64"]:
-            return
-
-        layer._quant_out_scale = _create_param(layer, name, "scale", dtype)
-        layer._quant_out_scale.stop_gradient = True
-
-        layer._quant_out_state = _create_param(layer, name, "state", dtype)
-        layer._quant_out_state.stop_gradient = True
-
-        layer._quant_out_accum = _create_param(layer, name, "accum", dtype)
-        layer._quant_out_accum.stop_gradient = True
+                    out_scale = utils.load_variable_data(scope, out_scale_name)
+                    previous_op = utils.find_previous_op(block, in_var_name)
+                    previous_op._set_attr("out_threshold", float(out_scale))

-    def _is_scale_op_matched(self, scale_name, op, block):
+                    next_ops = utils.find_next_ops(block, out_var_name)
+                    for next_op in next_ops:
+                        next_op._rename_input(out_var_name, in_var_name)
+
+    def _set_skip_quant_attr(self, program):
         """
-        Based on the op name and attrs to judge whether the op in
-        program matches the scale_name. We must know the corresponding
-        name between dgraph and static model.
+        Label the skip quantized ops.
         """
-        fp_type = [core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]
-        if op.type in quantization_pass._op_real_in_out_name.keys():
-            output_var_names = quantization_pass._get_op_output_var_names(op)
-            for output_var_name in output_var_names:
-                output_var_tensor = block.var(output_var_name)
-                if output_var_tensor.dtype not in fp_type:
-                    return False
-
-        # corresponding_map: [name, op_types, function]
-        # Note that, the items have priority in corresponding_map
-        corresponding_map = [
-            ['conv2d_tranpose', ['conv2d_transpose', \
-                'depthwise_conv2d_transpose'], None],
-            ['conv2d', ['conv2d', 'depthwise_conv2d'], None],
-            ['linear', ['matmul'], None],
-            ['re_lu6', ['relu6'], None],
-            ['p_re_lu', ['prelu'], None],
-            ['leaky_re_lu', ['leaky_relu'], None],
-            ['re_lu', ['relu'], None],
-        ]
-
-        for item in corresponding_map:
-            if item[0] in scale_name:
-                return (op.type in item[1]) and \
-                    (len(item) == 2 or item[2] is None or item[2](op))
-
-        return op.type in scale_name
-
-    def _set_skip_quant_attr(self, program):
-        block = program.global_block()
-        for op in block.ops:
-            if self._is_skip_quant_op(block, op):
-                op._set_attr("skip_quant", True)
+        for block in program.blocks:
+            for op in block.ops:
+                if self._is_skip_quant_op(block, op):
+                    op._set_attr("skip_quant", True)

     def _is_skip_quant_op(self, block, in_op):
         """
@@ -621,33 +507,11 @@
         1. the type of input op should be conv2d, depthwise_conv2d or matmul
         2. the previous ops of the input op are not fake_quantize_dequantize ops
         """
-        def _find_previous_op(block, var_name):
-            for op in block.ops:
-                if var_name in op.output_arg_names:
-                    return op
-
         target_op_types = ["conv2d", "depthwise_conv2d", "matmul"]
         if in_op.type not in target_op_types:
             return False

-        previous_ops = [_find_previous_op(block, arg_name) \
+        previous_ops = [utils.find_previous_op(block, arg_name) \
             for arg_name in in_op.input_arg_names]
-        return any(op is not None and op.type not in utils.fake_quantize_dequantize_types \
-            for op in previous_ops )
-
-    def _calc_output_scale_hook(self, layer, input, output):
-        """
-        Create the MovingAverageAbsMaxScale layer for the target layer if needed.
-        Execute MovingAverageAbsMaxScale layer to calculate the output scale.
-        """
-        assert isinstance(output, (core.VarBase, framework.Variable)), \
-            "Multiple outputs are not currently supported in ImperativeOutScale."
-
-        fp_types = [core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP64]
-        if output.dtype in fp_types:
-            if not hasattr(layer, "_out_scale"):
-                self._out_scale = quant_nn.MovingAverageAbsMaxScale(
-                    layer, output.name, self._moving_rate, output.dtype)
-            # TODO (jc): consider the ops that have several outputs
-            self._out_scale(output)
+        return any(op is not None and op.type not in \
+            utils.fake_quantize_dequantize_types for op in previous_ops)
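The index/hasattr loop in the new apply() only resolves a dotted sublayer name (for example "features.0") down to its parent object so the sublayer can be rebound to its QuantizedOutputLayer wrapper. A simplified sketch of the same idea; the helper name and the filter are illustrative, and it ignores the hasattr guard the real code uses for container layers:

def replace_sublayer(model, dotted_name, new_layer):
    # Walk the dotted sublayer path down to the parent layer, then rebind
    # the final attribute to the wrapper layer.
    parts = dotted_name.split('.')
    obj = model
    for part in parts[:-1]:
        obj = getattr(obj, part)
    setattr(obj, parts[-1], new_layer)

# Illustrative use, mirroring ImperativeQuantizeOutputs.apply():
# for name, layer in model.named_sublayers():
#     if is_target(layer):
#         replace_sublayer(model, name, QuantizedOutputLayer(layer))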
@@ -507,59 +507,42 @@ class QuantizedNoweightLayer(layers.Layer):

 class MovingAverageAbsMaxScale(layers.Layer):
-    def __init__(self, layer=None, name=None, moving_rate=0.9, dtype='float32'):
+    def __init__(self, name=None, moving_rate=0.9, dtype='float32'):
         r"""
-        MovingAverageMaxScale layer is used to calculating the output quantization scale of Layer.
-        Its computational formula is described as below:
+        MovingAverageMaxScale layer is used to calculating the output quantization
+        scale of Layer. Its computational formula is described as below:

         :math:`scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)`
         :math:`Out = X`
         """
         super(MovingAverageAbsMaxScale, self).__init__()
         self._moving_rate = moving_rate
-        self._dtype = dtype
-        self._layer = layer

-        if self._layer is None or not hasattr(self._layer, "_quant_out_scale"):
-            scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale'
-            scale_name = unique_name.generate(scale_prefix)
-            scale_attr = ParamAttr(
-                name=scale_name, initializer=Constant(1), trainable=False)
-            self._scale = self.create_parameter(
-                shape=[1], attr=scale_attr, dtype=self._dtype)
-            self._scale.stop_gradient = True
-            if self._layer is not None:
-                setattr(self._layer, "_quant_out_scale", self._scale)
-        else:
-            self._scale = self._layer._quant_out_scale
+        scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale'
+        scale_name = unique_name.generate(scale_prefix)
+        scale_attr = ParamAttr(
+            name=scale_name, initializer=Constant(1), trainable=False)
+        self._scale = self.create_parameter(
+            shape=[1], attr=scale_attr, dtype=dtype)
+        self._scale.stop_gradient = True

-        if self._layer is None or not hasattr(self._layer, "_quant_out_state"):
-            state_prefix = "{}.state".format(name) if name else 'outscale.state'
-            state_attr = ParamAttr(
-                name=unique_name.generate(state_prefix),
-                initializer=Constant(1),
-                trainable=False)
-            self._state = self.create_parameter(
-                shape=[1], attr=state_attr, dtype=self._dtype)
-            self._state.stop_gradient = True
-            if self._layer is not None:
-                setattr(self._layer, "_quant_out_state", self._state)
-        else:
-            self._state = self._layer._quant_out_state
+        state_prefix = "{}.state".format(name) if name else 'outscale.state'
+        state_attr = ParamAttr(
+            name=unique_name.generate(state_prefix),
+            initializer=Constant(1),
+            trainable=False)
+        self._state = self.create_parameter(
+            shape=[1], attr=state_attr, dtype=dtype)
+        self._state.stop_gradient = True

-        if self._layer is None or not hasattr(self._layer, "_quant_out_accum"):
-            accum_prefix = "{}.accum".format(name) if name else 'outscale.accum'
-            accum_attr = ParamAttr(
-                name=unique_name.generate(accum_prefix),
-                initializer=Constant(1),
-                trainable=False)
-            self._accum = self.create_parameter(
-                shape=[1], attr=accum_attr, dtype=self._dtype)
-            self._accum.stop_gradient = True
-            if self._layer is not None:
-                setattr(self._layer, "_quant_out_accum", self._accum)
-        else:
-            self._accum = self._layer._quant_out_accum
+        accum_prefix = "{}.accum".format(name) if name else 'outscale.accum'
+        accum_attr = ParamAttr(
+            name=unique_name.generate(accum_prefix),
+            initializer=Constant(1),
+            trainable=False)
+        self._accum = self.create_parameter(
+            shape=[1], attr=accum_attr, dtype=dtype)
+        self._accum.stop_gradient = True

     def forward(self, input):
         if in_dygraph_mode():
@@ -567,18 +550,30 @@ class MovingAverageAbsMaxScale(layers.Layer):
                      not self.training)
             state = self._state if self.training else None
             accum = self._accum if self.training else None

-            self._scale, _, _ = core.ops.moving_average_abs_max_scale(
-                input, accum, state, self._scale, state, accum, *attrs)
-            return self._scale
+            quant_out = _varbase_creator(
+                type=input.type,
+                name="{}.tmp".format(input.name),
+                shape=input.shape,
+                dtype=input.dtype,
+                persistable=False)
+
+            out, _, _, _ = core.ops.moving_average_abs_max_scale(
+                input, accum, state, quant_out, self._scale, state, accum,
+                *attrs)
+            return out

         check_variable_and_dtype(input, 'input', ['float32', 'float64'],
                                  'MovingAverageAbsMaxScale')

         attrs = {'moving_rate': self._moving_rate, 'is_test': not self.training}
         inputs = {"X": [input]}
-        outputs = {"OutScale": [self._scale]}
+        quant_out = self._helper.create_variable(
+            name="{}.tmp".format(input.name),
+            dtype=input.dtype,
+            type=core.VarDesc.VarType.LOD_TENSOR,
+            persistable=False,
+            stop_gradient=False)
+        outputs = {"Out": [quant_out], "OutScale": [self._scale]}

         if self.training:
             inputs['InState'] = [self._state]
@@ -592,4 +587,22 @@ class MovingAverageAbsMaxScale(layers.Layer):
             outputs=outputs,
             attrs=attrs)

-        return self._scale
+        return quant_out
+
+
+class QuantizedOutputLayer(layers.Layer):
+    def __init__(self, layer=None, moving_rate=0.9, dtype='float32'):
+        r"""
+        Add MovingAverageMaxScale layer to the behind of the input layer.
+        """
+        super(QuantizedOutputLayer, self).__init__()
+        self._layer = layer
+        self._moving_average_abs_max_scale = \
+            MovingAverageAbsMaxScale(layer.full_name(), moving_rate, dtype)
+
+    def forward(self, input):
+        if isinstance(input, list):
+            assert len(input) == 1, \
+                "The QuantizedOutputLayer should only have one input."
+
+        out = self._layer(input)
+        return self._moving_average_abs_max_scale(out)
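For reference, the docstring formula scale = (moving_rate*accum + max(abs(x))) / (moving_rate*state + 1) corresponds to the following NumPy sketch of one update step (a simplified model of the op, matching the expected values computed in the unit test further below):

import numpy as np

def moving_average_abs_max_scale(x, accum, state, moving_rate=0.9):
    # accum keeps a decayed running sum of per-batch abs-max values and
    # state keeps the decayed count of updates; their ratio is the scale.
    accum = moving_rate * accum + np.max(np.abs(x))
    state = moving_rate * state + 1.0
    scale = accum / state
    return scale, accum, state

accum, state = 1.0, 1.0
for _ in range(3):
    x = np.random.randn(8, 16).astype("float32")
    scale, accum, state = moving_average_abs_max_scale(x, accum, state)
    print(scale)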
@@ -13,22 +13,7 @@
 # limitations under the License.

 import paddle
+import numpy as np

-op_real_in_out_name = {
-    "conv2d": [["Input", "Filter"], ["Output"]],
-    "depthwise_conv2d": [["Input", "Filter"], ["Output"]],
-    "pool2d": [["X"], ["Out"]],
-    "elementwise_add": [["X", "Y"], ["Out"]],
-    "softmax": [["X"], ["Out"]],
-    "relu": [["X"], ["Out"]],
-    "relu6": [["X"], ["Out"]],
-    "leaky_relu": [["X"], ["Out"]],
-    "prelu": [["X"], ["Out"]],
-    "tanh": [["X"], ["Out"]],
-    "batch_norm": [["X"], ["Y"]],
-    "sigmoid": [["X"], ["Out"]],
-    "swish": [["X"], ["Out"]],
-}
-
 quant_input_layers_map = {
     'Conv2D': paddle.nn.Conv2D,
@@ -85,3 +70,33 @@ weight_op_types = [
     "conv2d", "depthwise_conv2d", "matmul", "conv2d_transpose",
     "depthwise_conv2d_transpose"
 ]
+
+
+def load_variable_data(scope, var_name):
+    '''
+    Load variable value from scope
+    '''
+    var_node = scope.find_var(var_name)
+    assert var_node is not None, \
+        "Can not find " + var_name + " in the scope."
+    return np.array(var_node.get_tensor())
+
+
+def find_previous_op(block, var_name):
+    """
+    Find the previous op for the input variable.
+    """
+    for op in block.ops:
+        if var_name in op.output_arg_names:
+            return op
+
+
+def find_next_ops(block, var_name):
+    """
+    Find all followed ops for the input variable.
+    """
+    res_ops = []
+    for op in block.ops:
+        if var_name in op.input_arg_names:
+            res_ops.append(op)
+    return res_ops
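A small self-contained usage sketch of the two graph helpers above, assuming a Paddle 2.x static-graph program (the toy network is illustrative; find_previous_op and find_next_ops are the functions defined in this module):

import paddle

paddle.enable_static()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name='x', shape=[None, 4], dtype='float32')
    y = paddle.nn.functional.relu(x)
    z = paddle.nn.functional.sigmoid(y)

block = main_prog.global_block()
producer = find_previous_op(block, y.name)    # the relu op that wrote y
consumers = find_next_ops(block, y.name)      # [the sigmoid op that reads y]
print(producer.type, [op.type for op in consumers])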
@@ -478,30 +478,5 @@ class TestSaveQuanztizedModelFromCheckPoint(unittest.TestCase):
         self.assertTrue(op_count == 14)


-class TestSaveQuantizedModel_Warning(unittest.TestCase):
-    def test_warning(self):
-        path = "./dynamic_outscale_infer_model_with_warnings/lenet"
-        imperative_out_scale = ImperativeQuantAware()
-        with fluid.dygraph.guard():
-            lenet = ImperativeLenet()
-
-        with warnings.catch_warnings(record=True) as w:
-            warnings.simplefilter("always")
-            imperative_out_scale.save_quantized_model(
-                layer=lenet,
-                path=path,
-                input_spec=[
-                    paddle.static.InputSpec(
-                        shape=[None, 1, 28, 28], dtype='float32')
-                ])
-
-        warning_message = "Warning: No Layer of the model while to be " \
-                          "saved contains the out_threshold attribute, so the " \
-                          "generated inference model would not contain the " \
-                          "out_threshold."
-        num = get_vaild_warning_num(warning_message, w)
-        assert num == 1
-
-
 if __name__ == '__main__':
     unittest.main()
@@ -166,12 +166,14 @@ class TestMovingAverageAbsMaxScaleOp(OpTest):
         accum[0] = 1
         state = np.zeros(1).astype("float32")
         state[0] = 1
+        x = np.random.random((8, 16, 7, 7)).astype("float32")

         self.inputs = {
-            'X': np.random.random((8, 16, 7, 7)).astype("float32"),
+            'X': x,
             'InAccum': accum,
             'InState': state,
         }

+        out = x
         out_accum = np.zeros(1).astype("float32")
         out_state = np.zeros(1).astype("float32")
         out_scale = np.zeros(1).astype("float32")
@@ -180,6 +182,7 @@ class TestMovingAverageAbsMaxScaleOp(OpTest):
         out_state[0] = self.attrs['moving_rate'] * state[0] + 1
         out_scale = out_accum / out_state
         self.outputs = {
+            'Out': out,
             'OutAccum': out_accum,
             'OutState': out_state,
             'OutScale': out_scale,
......