提交 7ab3f5c3 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1086 fix some bug in quant debug

Merge pull request !1086 from SanjayChan/05showcase
......@@ -15,7 +15,7 @@
"""grad impl."""
from . import grad_array_ops, grad_comm_ops, grad_debug_ops, grad_implementations, \
grad_math_ops, grad_nn_ops, grad_other_ops
grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops
from .grad_base import get_bprop_fn
__all__ = ['get_bprop_fn']
......@@ -223,8 +223,8 @@ class BatchNormFold(PrimitiveWithInfer):
Args:
momentum (float): Momentum value should be [0, 1]. Default: 0.1.
epsilon (float): A small float number to avoid dividing by 0. 1e-12 if dtype in
float32 else 1e-3. Default: 1e-12.
epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in
float32 else 1e-3. Default: 1e-5.
is_training (bool): In training mode set True, else set False. Default: True.
freeze_bn (int): Delay in steps at which computation switches from regular batch
norm to frozen mean and std. Default: 0.
......@@ -247,7 +247,7 @@ class BatchNormFold(PrimitiveWithInfer):
channel = 1
@prim_attr_register
def __init__(self, momentum=0.1, epsilon=1e-12, is_training=True, freeze_bn=0):
def __init__(self, momentum=0.1, epsilon=1e-5, is_training=True, freeze_bn=0):
"""init batch norm fold layer"""
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
......@@ -277,7 +277,7 @@ class BatchNormFoldGrad(PrimitiveWithInfer):
channel = 1
@prim_attr_register
def __init__(self, epsilon=1e-12, is_training=True, freeze_bn=0):
def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0):
"""init BatchNormGrad layer"""
self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
......
......@@ -32,6 +32,7 @@ __all__ = ["build_train_network"]
class OutputTo16(nn.Cell):
"Wrap cell for amp. Cast network output back to float16"
def __init__(self, op):
super(OutputTo16, self).__init__(auto_prefix=False)
self._op = op
......@@ -53,7 +54,7 @@ def _do_keep_batchnorm_fp32(network):
change = True
else:
_do_keep_batchnorm_fp32(subcell)
if isinstance(network, nn.SequentialCell) and change:
if isinstance(network, nn.SequentialCell) and change:
network.cell_list = list(network.cells())
......@@ -72,7 +73,7 @@ def _check_kwargs(key_words):
"""Check kwargs."""
for arg in key_words:
if arg not in ['cast_model_type', 'keep_batchnorm_fp32', 'loss_scale_manager']:
raise ValueError(f"Unsupported arg '{arg}'")
raise ValueError(f"Unsupported arg '{arg}'")
if 'cast_model_type' in key_words:
validator.check_type_name('cast_model_type', key_words['cast_model_type'],
......
......@@ -18,4 +18,16 @@ set -e
# Usage : get_shape_from_ir.sh ir_file
cat "$1" | perl -p -e 's/\n/NEWLINE/' | sed 's/NEWLINE :/:/g' | sed 's/Tensor NEWLINEshape//g' | perl -p -e 's/NEWLINE/\n/g' | perl -p -e 's/<Array\[([\d\w]+)\]x\[[\w ]+\](\[[\d, ]*\])>/\2/g' | perl -p -e 's/<Tuple\[([\[\]\d\w\.\*]*)\]>/Tuple/g' | perl -p -e 's/ \%(\d+)\(.*= /\1\t/g' | perl -p -e 's/\(.*\)( \{.*\})*:/\t\1\t/g' | tr -d '()' | awk '/subgraph/{p=1;next}{if(p){print}}'| awk '/return/{p=1;next}{if(!p){print}}' | sed '/^$/d' | awk -F'\t' '{print $1"\t"$2"\t"$4"\t"$3}'
cat "$1" | perl -p -e 's/\n/NEWLINE/' \
| sed 's/NEWLINE :/:/g' \
| sed 's/Tensor NEWLINEshape//g' \
| perl -p -e 's/NEWLINE/\n/g' \
| perl -p -e 's/<Array\[([\d\w]+)\]x\[[\w ]+\](\[[\d, ]*\])>/\2/g' \
| perl -p -e 's/<Tuple\[([\[\]\d\w\.\*]*)\]>/Tuple/g' \
| perl -p -e 's/ \%(\d+)\(.*= /\1\t/g' \
| perl -p -e 's/\(.*\)( \{.*\})*:/\t\1\t/g' \
| tr -d '()' \
| awk '/subgraph/{p=1;next}{if(p){print}}'\
| awk '/return/{p=1;next}{if(!p){print}}' \
| sed '/^$/d' \
| awk -F'\t' '{print $1"\t"$2"\t"$4}'
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册