Commit 864622bd authored by mindspore-ci-bot, committed by Gitee

!574 Add parameter configuration

Merge pull request !574 from liubuyu/master
......@@ -34,7 +34,7 @@ CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &mul, const
auto prim = std::make_shared<Primitive>(kFusedMulAddNOpName);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
inputs.push_back(mul->input(kMulInputNum - lossscale_input_index));
- inputs.push_back(addn->input(1));
+ inputs.push_back(addn->input(2));
// scalar input should be 3rd input
inputs.push_back(mul->input(lossscale_input_index));
auto fusion_node = graph->NewCNode(inputs);
......@@ -51,7 +51,7 @@ const BaseRef MulAddNFusion::DefinePattern() const {
VarPtr Z = std::make_shared<Var>();
VectorRef mul({prim::kPrimMul, X, Z});
- VectorRef addn({prim::kPrimAddN, Y, mul});
+ VectorRef addn({prim::kPrimAddN, mul, Y});
return addn;
}
......@@ -65,7 +65,7 @@ const AnfNodePtr MulAddNFusion::Process(const FuncGraphPtr &graph, const AnfNode
if (addn == nullptr || addn->inputs().size() != kAddNInputNum) {
return nullptr;
}
- auto mul_anf = addn->input(2);
+ auto mul_anf = addn->input(1);
if (mul_anf == nullptr) {
return nullptr;
}
......
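The three hunks above flip the operand order that MulAddNFusion expects: the pattern now matches AddN((mul_result, y)) instead of AddN((y, mul_result)), so Process takes the Mul node from addn->input(1), the fused node takes the remaining data operand from addn->input(2), and the scalar stays as the third input. Below is a minimal, hypothetical MindSpore cell that produces the Mul + AddN shape this pass targets; the class and argument names are illustrative only and are not part of this commit.

    import mindspore.nn as nn
    from mindspore.ops import operations as P

    class MulAddN(nn.Cell):
        """Illustrative cell whose graph contains the Mul + AddN shape targeted by MulAddNFusion."""
        def __init__(self):
            super(MulAddN, self).__init__()
            self.mul = P.Mul()
            self.addn = P.AddN()

        def construct(self, x, y, scalar):
            res = self.mul(scalar, x)      # loss-scale style multiply by a scalar
            return self.addn((res, y))     # mul result is the first AddN operand, matching the new pattern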
......@@ -177,7 +177,7 @@ apply_decay = C.MultitypeFuncGraph("apply_decay")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
"""Get grad with weight_decay."""
if if_apply:
- return op_add((gradient, weight * weight_decay))
+ return op_add((weight * weight_decay, gradient))
return gradient
......
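The change to _tensor_apply_decay only swaps the operands passed to op_add, so the decay term is added as weight * weight_decay + gradient. A plain-Python sketch of the same logic, with illustrative names and no MindSpore types, reads:

    def apply_decay(weight_decay, if_apply, weight, gradient):
        """Sketch of _tensor_apply_decay: add the weight-decay term to the gradient when enabled."""
        if if_apply:
            return weight * weight_decay + gradient   # same operands as op_add((weight * weight_decay, gradient))
        return gradient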
......@@ -62,6 +62,7 @@ class Model:
loss_scale_manager (Union[None, LossScaleManager]): If None, the loss is not scaled; otherwise the
    loss is scaled by the LossScaleManager. If set, it overwrites the level setting. It is a keyword argument,
    e.g. use `loss_scale_manager=None` to set the value.
+ keep_batchnorm_fp32 (bool): Keep BatchNorm running in `float32`. If set, it overwrites the level setting. Default: True.
Examples:
>>> class Net(nn.Cell):
......@@ -96,7 +97,10 @@ class Model:
self._optimizer = optimizer
self._loss_scale_manager = None
self._loss_scale_manager_set = False
+ self._keep_bn_fp32 = True
self._check_kwargs(kwargs)
+ if 'keep_batchnorm_fp32' in kwargs:
+     self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
if 'loss_scale_manager' in kwargs:
self._loss_scale_manager = kwargs['loss_scale_manager']
self._loss_scale_manager_set = True
......@@ -112,7 +116,7 @@ class Model:
def _check_kwargs(self, kwargs):
for arg in kwargs:
- if arg not in ['loss_scale_manager']:
+ if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']:
raise ValueError(f"Unsupported arg '{arg}'")
def _build_train_network(self):
......@@ -124,12 +128,14 @@ class Model:
self._optimizer,
self._loss_fn,
level=self._amp_level,
- loss_scale_manager=self._loss_scale_manager)
+ loss_scale_manager=self._loss_scale_manager,
+ keep_batchnorm_fp32=self._keep_bn_fp32)
else:
network = amp.build_train_network(network,
self._optimizer,
self._loss_fn,
- level=self._amp_level)
+ level=self._amp_level,
+ keep_batchnorm_fp32=self._keep_bn_fp32)
elif self._loss_fn:
network = nn.WithLossCell(network, self._loss_fn)
# If need to check if loss_fn is not None, but optimizer is None
......
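Taken together, the Model changes accept keep_batchnorm_fp32 as an additional keyword argument, validate it in _check_kwargs, and forward it to amp.build_train_network next to the optional loss_scale_manager. The usage sketch below is an assumption for illustration only; the Dense network, loss, optimizer, and FixedLossScaleManager choice do not appear in this diff.

    import mindspore.nn as nn
    from mindspore.train.model import Model
    from mindspore.train.loss_scale_manager import FixedLossScaleManager

    # Minimal illustrative pieces; any real network, loss cell and optimizer would do.
    net = nn.Dense(16, 10)
    loss = nn.SoftmaxCrossEntropyWithLogits()
    opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

    model = Model(net, loss_fn=loss, optimizer=opt, amp_level="O2",
                  loss_scale_manager=FixedLossScaleManager(),
                  keep_batchnorm_fp32=True)   # both options are plain keyword arguments on Model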
......@@ -42,7 +42,7 @@ def test_mul_addn_fusion(tag):
@fns
def before(a, b):
res = mul(scalar, a)
- res = addn((b, res))
+ res = addn((res, b))
return res
@fns
......