Unverified commit 19592d2b authored by cc, committed by GitHub

Refine dygraph qat, test=develop (#31680)

Parent 4c0c55bb
......@@ -25,101 +25,99 @@ from paddle.fluid.executor import Executor
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Constant
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
from paddle.nn import Linear, Conv2D, Conv2DTranspose, MaxPool2D, MaxPool1D, BatchNorm1D, BatchNorm2D, BatchNorm3D, SyncBatchNorm
from paddle.nn import Linear, Conv2D, Conv2DTranspose, MaxPool2D, MaxPool1D
from paddle.nn import BatchNorm1D, BatchNorm2D, BatchNorm3D, SyncBatchNorm
from paddle.fluid.dygraph.nn import BatchNorm, Pool2D
from paddle.fluid.io import load_inference_model, save_inference_model
from paddle.nn.layer.activation import ReLU, LeakyReLU, Sigmoid, ReLU6, Tanh, Softmax, PReLU, Swish
from paddle.nn.layer.activation import ReLU, LeakyReLU, Sigmoid, ReLU6
from paddle.nn.layer.activation import Tanh, Softmax, PReLU, Swish
from paddle.fluid.log_helper import get_logger
from . import quant_nn
from .. import quantization_pass
from . import utils
__all__ = ['ImperativeQuantAware', 'ImperativeCalcOutScale']
__all__ = ['ImperativeQuantAware']
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
_op_real_in_out_name = {
"conv2d": [["Input", "Filter"], ["Output"]],
"depthwise_conv2d": [["Input", "Filter"], ["Output"]],
"pool2d": [["X"], ["Out"]],
"elementwise_add": [["X", "Y"], ["Out"]],
"softmax": [["X"], ["Out"]],
"relu": [["X"], ["Out"]],
"relu6": [["X"], ["Out"]],
"leaky_relu": [["X"], ["Out"]],
"prelu": [["X"], ["Out"]],
"tanh": [["X"], ["Out"]],
"batch_norm": [["X"], ["Y"]],
"sigmoid": [["X"], ["Out"]],
"swish": [["X"], ["Out"]],
}
class ImperativeQuantAware(object):
"""
Add the fake quant logic for given quantizable layers, namely add the quant_dequant
computational logic both for activation inputs and weight inputs.
Applying quantization aware training (QAT) to the dygraph model.
"""
def __init__(self,
weight_bits=8,
activation_bits=8,
quantizable_layer_type=['Conv2D', 'Linear'],
weight_quantize_type='abs_max',
activation_quantize_type='moving_average_abs_max',
weight_bits=8,
activation_bits=8,
moving_rate=0.9,
quantizable_layer_type=['Conv2D', 'Linear'],
weight_preprocess_layer=None,
act_preprocess_layer=None,
weight_quantize_layer=None,
act_quantize_layer=None):
r"""
"""
The constructor for ImperativeQuantAware.
Args:
weight_bits(int): quantization bit number for weights,
whereas the bias is not quantized.
activation_bits(int): quantization bit number for activations.
quantizable_layer_type(list[str]): List the type of layers that
will be quantized. Default is ['Conv2D', 'Linear'].
The quantizable_op_type in QuantizationFreezePass and
ConvertToInt8Pass must be the same as this.
weight_quantize_type(str): quantization type for weights,
which supports 'abs_max' now. The 'moving_average_abs_max'
usually is not used for weights, since weights are fixed once the
model is well trained.
usually is not used for weights, since weights are fixed
once the model is well trained.
activation_quantize_type(str): quantization type for activations,
which supports 'abs_max' and 'moving_average_abs_max' now.
If using 'abs_max' mode, the quantization scale will be calculated
dynamically each step in both training and testing period. If using
'moving_average_abs_max', the static quantization scale will be calculated
during training and used in inference.
moving_rate(float): the parameter for 'moving_average_abs_max' quantization.
quantizable_layer_type(list[str]): List the type of layers that will be quantized.
Default is ['Conv2D', 'Linear']. The quantizable_op_type in
QuantizationFreezePass and ConvertToInt8Pass must be the same as this.
weight_preprocess_layer(paddle.nn.Layer, optional): A paddle Layer that defines how to preprocess
weight before quantization. Using this can quickly test if user's
preprocess method works or not. The input is non-quantized
weight and function returns processed weight to be quantized.
If None, the weight will be quantized directly. Default is None.
act_preprocess_layer(paddle.nn.Layer, optional): A paddle Layer that defines how to preprocess
activation before quantization. Using this can quickly test if user's
preprocess method works or not. The input is non-quantized
activation and function returns processed activation to be quantized.
If None, the activation will be quantized directly. Default is None.
weight_quantize_layer(paddle.nn.Layer, optional): A paddle Layer that defines how to quantize weight.
If using 'abs_max' mode, the quantization scale will be
calculated dynamically each step in both training and testing
period. If using 'moving_average_abs_max', the static
quantization scale will be calculated during training and
used in inference.
weight_bits(int): quantization bit number for weights,
whereas the bias is not quantized.
activation_bits(int): quantization bit number for activations.
moving_rate(float): the parameter for 'moving_average_abs_max'
quantization.
weight_preprocess_layer(paddle.nn.Layer, optional): A paddle
Layer that defines how to preprocess weight before quantization.
Using this can quickly test if user's preprocess method works
or not. The input is non-quantized weight and function returns
processed weight to be quantized.
If None, the weight will be quantized directly.
Default is None.
act_preprocess_layer(paddle.nn.Layer, optional): A paddle Layer
that defines how to preprocess activation before quantization.
Using this can quickly test if user's preprocess method works
or not. The input is non-quantized activation and function returns
processed activation to be quantized.
If None, the activation will be quantized directly.
Default is None.
weight_quantize_layer(paddle.nn.Layer, optional): A paddle Layer that
defines how to quantize weight.
Using this can quickly test if user's quantization method works or not.
In this layer, user should both define quantization method and
dequantization method, that is, the function's input is non-quantized
weight and returns dequantized weight. If None, will use
quantization op defined by 'weight_quantize_type'. Default is None.
act_quantize_layer(paddle.nn.Layer, optional): A paddle Layer that defines how to quantize activation.
weight and returns dequantized weight.
If None, will use the quantization op defined by 'weight_quantize_type'.
Default is None.
act_quantize_layer(paddle.nn.Layer, optional): A paddle Layer that defines
how to quantize activation.
Using this can quickly test if user's quantization method works or not.
In this layer, user should both define quantization method and
dequantization method, that is, the function's input is non-quantized
activation and returns dequantized activation. If None, will use
quantization op defined by 'activation_quantize_type'. Default is None.
activation and returns dequantized activation.
If None, will use quantization op defined by 'activation_quantize_type'.
Default is None.
Note:
If the user sets the attribute 'skip_quant' of a Layer that supports
dynamic quantization to True, the layer will not be quantized during
training. If this attribute is not set or is False, the layer will be
quantized during training.
Examples 1:
.. code-block:: python
......@@ -196,141 +194,175 @@ class ImperativeQuantAware(object):
model_path="./imperative_model_qat")
"""
super(ImperativeQuantAware, self).__init__()
self._weight_bits = weight_bits
self._activation_bits = activation_bits
self._moving_rate = moving_rate
self._activation_quantize_type = activation_quantize_type
self._weight_quantize_type = weight_quantize_type
self._weight_pre_layer = weight_preprocess_layer
self._act_pre_layer = act_preprocess_layer
self._weight_quant_layer = weight_quantize_layer
self._act_quant_layer = act_quantize_layer
self._out_scale = ImperativeCalcOutScale()
t_check = lambda method: method is None or issubclass(method, dygraph.layers.Layer)
assert t_check(
self._weight_pre_layer), "weight_preprocess should be nn.Layer"
assert t_check(self._act_pre_layer), "act_preprocess should be nn.Layer"
assert t_check(
self._weight_quant_layer), "weight_quantize should be nn.Layer"
assert t_check(self._act_quant_layer), "act_quantize should be nn.Layer"
quant_type = {
'abs_max', 'moving_average_abs_max', 'channel_wise_abs_max'
}
assert activation_quantize_type != 'channel_wise_abs_max', \
"The activation quantization type does not support 'channel_wise_abs_max'."
if activation_quantize_type not in quant_type:
raise ValueError(
"Unknown activation_quantize_type : '%s'. It can only be "
"'abs_max' or 'moving_average_abs_max' now." %
(str(activation_quantize_type)))
if weight_quantize_type not in quant_type:
raise ValueError(
"Unknown weight_quantize_type: '%s'. It can only be "
"'abs_max' or 'moving_average_abs_max' or 'channel_wise_abs_max' now."
% (str(weight_quantize_type)))
self._quant_layers_map = {
'Conv2D': Conv2D,
'Linear': Linear,
'Pool2D': Pool2D,
'ReLU': ReLU,
'LeakyReLU': LeakyReLU,
'ReLU6': ReLU6,
'Softmax': Softmax,
'Tanh': Tanh,
'Swish': Swish
kwargs = {
"quantizable_layer_type": quantizable_layer_type,
"weight_quantize_type": weight_quantize_type,
"activation_quantize_type": activation_quantize_type,
"weight_bits": weight_bits,
"activation_bits": activation_bits,
"moving_rate": moving_rate,
"weight_preprocess_layer": weight_preprocess_layer,
"act_preprocess_layer": act_preprocess_layer,
"weight_quantize_layer": weight_quantize_layer,
"act_quantize_layer": act_quantize_layer
}
self._quantizable_layer_type = tuple(
self._quant_layers_map[layer]
if layer in self._quant_layers_map else layer
for layer in quantizable_layer_type)
for layer in self._quantizable_layer_type:
assert not isinstance(
layer, str), "{} is not supported for quantization.".format(layer)
self._quantize_inputs = ImperativeQuantizeInputs(**kwargs)
self._calc_output_scale = ImperativeCalcOutputScale()
def quantize(self, model):
"""
According to the quantization types of weights and activations, fake
quant ops such as fake_quantize_dequantize_moving_average_abs_max,
fake_quantize_dequantize_abs_max and so on will be added to the model.
At the same time, the out_scale values of the outputs will be calculated.
Args:
model(fluid.dygraph.Layer): the model to be quantized.
Returns:
None
"""
assert isinstance(model, dygraph.Layer), \
"The model must be the instance of dygraph.Layer."
self._quantize_inputs.apply(model)
self._calc_output_scale.apply(model)
def save_quantized_model(self, layer, path, input_spec=None, **config):
self._calc_output_scale.save_quantized_model(layer, path, input_spec,
**config)
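A minimal end-to-end sketch of how the two helpers composed above are meant to be driven. The model class is hypothetical, and the import path and InputSpec usage are assumptions based on the public Paddle 2.0 API, not part of this patch:

    import paddle
    from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware

    class MyModel(paddle.nn.Layer):
        # tiny stand-in model with the two quantizable layer types
        def __init__(self):
            super(MyModel, self).__init__()
            self.conv = paddle.nn.Conv2D(3, 8, 3)
            self.fc = paddle.nn.Linear(8, 10)

        def forward(self, x):
            x = self.conv(x)
            x = paddle.nn.functional.adaptive_avg_pool2d(x, 1)
            return self.fc(paddle.flatten(x, 1))

    model = MyModel()
    qat = ImperativeQuantAware(
        weight_quantize_type='abs_max',
        activation_quantize_type='moving_average_abs_max')
    qat.quantize(model)  # inserts fake quant/dequant layers and out-scale hooks
    # ... run the usual dygraph training loop on the quantized model ...
    qat.save_quantized_model(
        layer=model,
        path="./imperative_model_qat",
        input_spec=[
            paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype='float32')
        ])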
class ImperativeQuantizeInputs(object):
"""
Based on the input params, add the quant_dequant computational
logic both for activation inputs and weight inputs.
"""
def __init__(self,
quantizable_layer_type=['Conv2D', 'Linear'],
weight_quantize_type='abs_max',
activation_quantize_type='moving_average_abs_max',
weight_bits=8,
activation_bits=8,
moving_rate=0.9,
weight_preprocess_layer=None,
act_preprocess_layer=None,
weight_quantize_layer=None,
act_quantize_layer=None):
"""
The constructor for ImperativeQuantizeInputs.
Please refer to the args of ImperativeQuantAware.
"""
super(ImperativeQuantizeInputs, self).__init__()
self._quantizable_layer_type = tuple(
utils._quant_layers_map[layer]
if layer in utils._quant_layers_map else layer
for layer in quantizable_layer_type)
for layer in self._quantizable_layer_type:
assert not isinstance(layer, str), \
"%s is unspported to be quantized." % layer
quantize_type = {
'abs_max', 'moving_average_abs_max', 'channel_wise_abs_max'
}
assert weight_quantize_type in quantize_type, \
"Unsupported weight_quantize_type: %s. It can only " \
"be abs_max or moving_average_abs_max or " \
"channel_wise_abs_max." % weight_quantize_type
assert activation_quantize_type != 'channel_wise_abs_max' \
and activation_quantize_type in quantize_type, \
"Unsupported activation_quantize_type: %s. It can " \
"only be abs_max or moving_average_abs_max now." \
% activation_quantize_type
bits_check = lambda bits: isinstance(bits, int) \
and 1 <= bits <= 16
assert bits_check(weight_bits), \
"weight_bits should be an integer from 1 to 16."
assert bits_check(activation_bits), \
"activation_bits should be an integer from 1 to 16."
layer_check = lambda method: method is None or \
issubclass(method, dygraph.layers.Layer)
assert layer_check(weight_preprocess_layer), \
"weight_preprocess should be nn.Layer."
assert layer_check(act_preprocess_layer), \
"act_preprocess should be nn.Layer."
assert layer_check(weight_quantize_layer), \
"weight_quantize should be nn.Layer."
assert layer_check(act_quantize_layer), \
"act_quantize should be nn.Layer."
self._kwargs = {
"weight_quantize_type": weight_quantize_type,
"activation_quantize_type": activation_quantize_type,
"weight_bits": weight_bits,
"activation_bits": activation_bits,
"moving_rate": moving_rate,
"weight_pre_layer": weight_preprocess_layer,
"act_pre_layer": act_preprocess_layer,
"weight_quant_layer": weight_quantize_layer,
"act_quant_layer": act_quantize_layer
}
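To make the weight/activation preprocess and quantize hooks validated above concrete: the expected contract is a paddle.nn.Layer subclass (passed as the class itself, matching the issubclass check) whose forward maps the raw tensor to the tensor that should be quantized. A hypothetical activation-preprocess layer, for illustration only and not part of this patch:

    import paddle

    class ClipActPreprocess(paddle.nn.Layer):
        """Clip activations before their quantization scale is estimated."""

        def __init__(self, threshold=6.0):
            super(ClipActPreprocess, self).__init__()
            self.threshold = threshold

        def forward(self, x):
            # return the processed activation that will be fed to the fake-quant op
            return paddle.clip(x, min=-self.threshold, max=self.threshold)

    # Passed as the class, e.g. act_preprocess_layer=ClipActPreprocess, when
    # constructing ImperativeQuantAware / ImperativeQuantizeInputs.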
def apply(self, model):
assert isinstance(model, dygraph.Layer), \
"The model must be the instance of dygraph.Layer."
for name, layer in model.named_sublayers():
if not isinstance(layer, self._quantizable_layer_type):
continue
if hasattr(layer, "skip_quant") and layer.skip_quant == True:
if not isinstance(layer, self._quantizable_layer_type) \
or (hasattr(layer, "skip_quant") \
and layer.skip_quant == True):
continue
# TODO(jc): optimize this module
last_idx = 0
idx = 0
obj = model
parent = model
while idx < len(name):
if (name[idx] == '.'):
if hasattr(parent, name[last_idx:idx]):
if hasattr(obj, name[last_idx:idx]):
obj = getattr(obj, name[last_idx:idx])
parent = obj
last_idx = idx + 1
idx += 1
target = name[last_idx:idx]
quant_layer = self._get_quantized_counterpart(layer)
quant_layer = self._get_quantized_layer(layer)
setattr(quant_layer, "layer_name", layer.full_name())
setattr(obj, target, quant_layer)
self._out_scale.calc_out_scale(model)
def _get_quantized_counterpart(self, layer):
quant_layers = tuple(self._quant_layers_map.values())
quantized_counterpart = tuple('Quantized' + k
for k in self._quant_layers_map.keys())
predicate = lambda value: isinstance(layer, value)
index_generator = (i for i, v in enumerate(quant_layers)
if predicate(v))
try:
index = next(index_generator)
except StopIteration:
_logger.fatal("The layer {} is unsupported to be quantized.".format(
layer.full_name()))
sys.exit(-1)
def _get_quantized_layer(self, layer):
quant_layer_name = None
for key, value in utils._quant_layers_map.items():
if isinstance(layer, value):
quant_layer_name = 'Quantized' + key
break
assert quant_layer_name is not None, \
"The layer %s is unsupported to be quantized." \
% layer.full_name()
layer_with_weight = ['QuantizedConv2D', 'QuantizedLinear']
if quantized_counterpart[index] not in layer_with_weight:
quant_layer_class_name = 'QuantizedNoweightLayer'
else:
quant_layer_class_name = quantized_counterpart[index]
quantized_layer = quant_nn.__dict__[quant_layer_class_name](
layer, self._weight_bits, self._activation_bits, self._moving_rate,
self._weight_quantize_type, self._activation_quantize_type,
self._weight_pre_layer, self._act_pre_layer,
self._weight_quant_layer, self._act_quant_layer)
return quantized_layer
if quant_layer_name not in layer_with_weight:
quant_layer_name = 'QuantizedNoweightLayer'
def save_quantized_model(self, layer, path, input_spec=None, **config):
self._out_scale.save_quantized_model(layer, path, input_spec, **config)
return quant_nn.__dict__[quant_layer_name](layer, **self._kwargs)
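For reference, the TODO(jc) in apply() above concerns the manual character-by-character walk used to find the parent of each quantizable sublayer. A compact sketch of the same idea, purely illustrative and assuming every segment of the dotted name is a plain attribute (the original loop additionally guards each step with hasattr):

    def _swap_sublayer(model, dotted_name, new_layer):
        # 'block0.conv1' -> descend to model.block0, then setattr('conv1', ...)
        parent = model
        *scopes, target = dotted_name.split('.')
        for scope in scopes:
            parent = getattr(parent, scope)
        setattr(parent, target, new_layer)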
class ImperativeCalcOutScale(object):
class ImperativeCalcOutputScale(object):
def __init__(self, moving_rate=0.9):
"""
Add the logic of calculating and setting output quantization scales of some layers.
These output quantization scales may be used by tensorRT or some other inference engines.
Add the logic of calculating and setting output scales of some layers.
Args:
moving_rate(float): The decay coefficient of moving average. The default value is 0.9.
moving_rate(float): The decay coefficient of moving average.
The default value is 0.9.
"""
super(ImperativeCalcOutScale, self).__init__()
super(ImperativeCalcOutputScale, self).__init__()
self._moving_rate = moving_rate
self._out_scale_layer_type_list = (
BatchNorm, BatchNorm1D, BatchNorm2D, BatchNorm3D, Conv2D, LeakyReLU,
......@@ -339,83 +371,22 @@ class ImperativeCalcOutScale(object):
self._register_hook_handle_list = []
self._out_scale_dict = collections.OrderedDict()
# Determine whether layer supports calculation out_scale
def _is_matched_layer(self, layer):
if not isinstance(layer, self._out_scale_layer_type_list):
if 'quantized_' not in layer.full_name():
return False
return True
# When the inference model is saved, the logic in the hook would not be
# executed in program translation, so some parameters can not be created
# in __init__, which would cause the model to fail to save. Therefore,
# the parameter creation in the hook is moved ahead and executed outside the hook.
def _add_new_parameters(self, layer, name=None):
dtype = layer._dtype if layer._dtype is not None else "float32"
if dtype not in ["float32", "float64"]:
return
scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale'
scale_name = unique_name.generate(scale_prefix)
scale_attr = ParamAttr(
name=scale_name, initializer=Constant(1), trainable=False)
layer._quant_out_scale = layer.create_parameter(
shape=[1], attr=scale_attr, dtype=dtype)
layer._quant_out_scale.stop_gradient = True
state_prefix = "{}.state".format(name) if name else 'outscale.state'
state_attr = ParamAttr(
name=unique_name.generate(state_prefix),
initializer=Constant(1),
trainable=False)
layer._quant_out_state = layer.create_parameter(
shape=[1], attr=state_attr, dtype=dtype)
layer._quant_out_state.stop_gradient = True
accum_prefix = "{}.accum".format(name) if name else 'outscale.accum'
accum_attr = ParamAttr(
name=unique_name.generate(accum_prefix),
initializer=Constant(1),
trainable=False)
layer._quant_out_accum = layer.create_parameter(
shape=[1], attr=accum_attr, dtype=dtype)
layer._quant_out_accum.stop_gradient = True
# Judge whether the op in program matches the Layer in dynamic model
def _is_op_matched(self, layer_name, op, block):
output_var_names = quantization_pass._get_op_output_var_names(op)
for output_var_name in output_var_names:
output_var_tensor = block.var(output_var_name)
if output_var_tensor.dtype not in [
core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32
]:
return False
# Because the naming styles of static and dynamic graph are different,
# in order to avoid mistakes, we unify the name here.
op_type = output_var_names[0].split(".")[0]
op_type = op_type.rsplit("_", 1)[0]
if op_type == 'depthwise_conv2d':
op_type = 'conv2d'
if 'prelu' in op_type:
op_type = op_type.replace('prelu', 'p_re_lu')
if 'relu' in op_type:
op_type = op_type.replace('relu', 're_lu')
return op_type in layer_name
def calc_out_scale(self, model):
def apply(self, model):
"""
Insert the `moving_average_abs_max_scale` op to calculate the output
scale of specific layers in the model.
Args:
model(fluid.dygraph.Layer): The target model for which the output
quantization scales will be calculated.
Returns:
None
"""
assert isinstance(
model, dygraph.Layer), "model must be the instance of dygraph.Layer"
assert isinstance(model, dygraph.Layer), \
"The model must be the instance of dygraph.Layer."
for _, layer in model.named_sublayers():
if self._is_matched_layer(layer):
if self._is_target_layer(layer):
self._add_new_parameters(layer)
forward_post_hook_handle = layer.register_forward_post_hook(
self._forward_post_hook)
......@@ -459,7 +430,7 @@ class ImperativeCalcOutScale(object):
.numpy())
else:
for _, sub_layer in self._layer.named_sublayers():
if self._is_matched_layer(sub_layer):
if self._is_target_layer(sub_layer):
layer_name = sub_layer.full_name()
if hasattr(sub_layer, "layer_name"):
layer_name = sub_layer.layer_name
......@@ -510,7 +481,7 @@ class ImperativeCalcOutScale(object):
forward_op = None
for block in inference_program.blocks:
for op in block.ops:
if op.type in _op_real_in_out_name:
if op.type in utils._op_real_in_out_name:
if op_count > len(ops_list):
warnings.warn(
"The number of Layer which has out_threshold attribute should be bigger than the op in inference model"
......@@ -567,6 +538,66 @@ class ImperativeCalcOutScale(object):
if is_dynamic_mode:
paddle.disable_static()
def _is_target_layer(self, layer):
return isinstance(layer, self._out_scale_layer_type_list) \
or 'quantized_' in layer.full_name()
# When the inference model is saved, the logic in the hook would not be
# executed in program translation, so some parameters can not be created
# in __init__, which would cause the model to fail to save. Therefore,
# the parameter creation in the hook is moved ahead and executed outside the hook.
def _add_new_parameters(self, layer, name=None):
dtype = layer._dtype if layer._dtype is not None else "float32"
if dtype not in ["float32", "float64"]:
return
scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale'
scale_name = unique_name.generate(scale_prefix)
scale_attr = ParamAttr(
name=scale_name, initializer=Constant(1), trainable=False)
layer._quant_out_scale = layer.create_parameter(
shape=[1], attr=scale_attr, dtype=dtype)
layer._quant_out_scale.stop_gradient = True
state_prefix = "{}.state".format(name) if name else 'outscale.state'
state_attr = ParamAttr(
name=unique_name.generate(state_prefix),
initializer=Constant(1),
trainable=False)
layer._quant_out_state = layer.create_parameter(
shape=[1], attr=state_attr, dtype=dtype)
layer._quant_out_state.stop_gradient = True
accum_prefix = "{}.accum".format(name) if name else 'outscale.accum'
accum_attr = ParamAttr(
name=unique_name.generate(accum_prefix),
initializer=Constant(1),
trainable=False)
layer._quant_out_accum = layer.create_parameter(
shape=[1], attr=accum_attr, dtype=dtype)
layer._quant_out_accum.stop_gradient = True
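The scale/state/accum parameters created above hold the state of a moving-average abs-max estimator. A rough NumPy sketch of the per-step update that the `moving_average_abs_max_scale` op performs during training (my reading of the scheme; the real op is implemented in C++ and not part of this patch):

    import numpy as np

    def moving_average_abs_max_update(x, accum, state, moving_rate=0.9):
        """One decayed-running-mean update of the output scale.

        accum and state both start at 1.0 (the Constant(1) initializer above).
        """
        accum = moving_rate * accum + np.abs(x).max()
        state = moving_rate * state + 1.0
        scale = accum / state
        return scale, accum, state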
# Check whether the op in the program matches the layer in the dynamic model
def _is_op_matched(self, layer_name, op, block):
output_var_names = quantization_pass._get_op_output_var_names(op)
for output_var_name in output_var_names:
output_var_tensor = block.var(output_var_name)
if output_var_tensor.dtype not in [
core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32
]:
return False
# Because the naming styles of static and dynamic graph are different,
# in order to avoid mistakes, we unify the name here.
op_type = output_var_names[0].split(".")[0]
op_type = op_type.rsplit("_", 1)[0]
if op_type == 'depthwise_conv2d':
op_type = 'conv2d'
if 'prelu' in op_type:
op_type = op_type.replace('prelu', 'p_re_lu')
if 'relu' in op_type:
op_type = op_type.replace('relu', 're_lu')
return op_type in layer_name
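To make the name matching above concrete: a hypothetical static-graph output variable name such as 'relu_1.tmp_0' is reduced to its op stem and rewritten into the dygraph spelling before being searched for inside the layer's full name. A standalone sketch of that normalization, with illustrative names only:

    def normalize_op_type(output_var_name):
        # 'relu_1.tmp_0' -> 'relu_1' -> 'relu' -> 're_lu'
        op_type = output_var_name.split(".")[0].rsplit("_", 1)[0]
        if op_type == 'depthwise_conv2d':
            op_type = 'conv2d'
        if 'prelu' in op_type:
            op_type = op_type.replace('prelu', 'p_re_lu')
        if 'relu' in op_type:
            op_type = op_type.replace('relu', 're_lu')
        return op_type

    assert normalize_op_type('relu_1.tmp_0') == 're_lu'
    assert normalize_op_type('conv2d_3.tmp_1') == 'conv2d'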
def _forward_post_hook(self, layer, input, output):
assert isinstance(
output, (core.VarBase, framework.Variable)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.nn import Linear, Conv2D
from paddle.fluid.dygraph.nn import Pool2D
from paddle.nn.layer.activation import ReLU, LeakyReLU, Sigmoid, ReLU6
from paddle.nn.layer.activation import Tanh, Softmax, PReLU, Swish
_op_real_in_out_name = {
"conv2d": [["Input", "Filter"], ["Output"]],
"depthwise_conv2d": [["Input", "Filter"], ["Output"]],
"pool2d": [["X"], ["Out"]],
"elementwise_add": [["X", "Y"], ["Out"]],
"softmax": [["X"], ["Out"]],
"relu": [["X"], ["Out"]],
"relu6": [["X"], ["Out"]],
"leaky_relu": [["X"], ["Out"]],
"prelu": [["X"], ["Out"]],
"tanh": [["X"], ["Out"]],
"batch_norm": [["X"], ["Y"]],
"sigmoid": [["X"], ["Out"]],
"swish": [["X"], ["Out"]],
}
_quant_layers_map = {
'Conv2D': Conv2D,
'Linear': Linear,
'Pool2D': Pool2D,
'ReLU': ReLU,
'LeakyReLU': LeakyReLU,
'ReLU6': ReLU6,
'Softmax': Softmax,
'Tanh': Tanh,
'Swish': Swish
}
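A small illustration, not part of the patch, of how these two tables are consumed by the imperative passes above: op input/output slot names are looked up in _op_real_in_out_name, and the user-facing strings in quantizable_layer_type are resolved through _quant_layers_map.

    def op_output_names(op_type):
        # e.g. 'conv2d' -> ['Output'], 'batch_norm' -> ['Y']
        return _op_real_in_out_name[op_type][1]

    def resolve_layer_types(requested):
        # e.g. ['Conv2D', 'Linear'] -> (paddle.nn.Conv2D, paddle.nn.Linear);
        # unknown strings are left as-is and rejected by the caller's assert.
        return tuple(_quant_layers_map.get(name, name) for name in requested)

    assert op_output_names('batch_norm') == ['Y']
    assert resolve_layer_types(['Conv2D'])[0] is Conv2D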