Unverified commit 18dd3aad, authored by ceci3, committed by GitHub

fix ofa (#1703)

Parent: adb69ed6
@@ -122,6 +122,14 @@ def _is_output_weight_ops(op, graph):
     return True


+def if_is_bias(op, graph):
+    pre_ops = sorted(graph.pre_ops(op))
+    if 'conv' in pre_ops[0].type() and pre_ops[1].type() == "reshape2":
+        if pre_ops[1].inputs('X')[0]._var.persistable == True:
+            return True
+    return False
+
+
 def check_search_space(graph):
     """ Find the shortcut in the model and set same config for this situation.
     """
@@ -139,7 +147,9 @@ def check_search_space(graph):
         if op.type() == 'elementwise_add' or op.type() == 'elementwise_mul':
             inp1, inp2 = op.all_inputs()[0], op.all_inputs()[1]
-            if (not inp1._var.persistable) and (not inp2._var.persistable):
+            is_bias = if_is_bias(op, graph)
+            if ((not inp1._var.persistable) and
+                (not inp2._var.persistable)) and not is_bias:
                 # if one of two vars comes from input,
                 # then the two vars in this elementwise op should be all fixed
                 if inp1.inputs() and inp2.inputs():
@@ -152,11 +162,11 @@ def check_search_space(graph):
                         fixed_by_input += pre_fixed_op_2
                     if not pre_fixed_op_2:
                         fixed_by_input += pre_fixed_op_1
-                elif (not inp1.inputs() and inp2.inputs()) or (
-                        inp1.inputs() and not inp2.inputs()):
+                elif (not inp1.inputs() and
+                      inp2.inputs()) or (inp1.inputs() and not inp2.inputs()):
                     pre_fixed_op = []
-                    inputs = inp1.inputs() if not inp2.inputs(
-                    ) else inp2.inputs()
+                    inputs = inp1.inputs(
+                    ) if not inp2.inputs() else inp2.inputs()
                     pre_fixed_op = _find_weight_ops(inputs[0], graph,
                                                     pre_fixed_op)
                     fixed_by_input += pre_fixed_op
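Note on the hunks above: the new `if_is_bias` helper lets `check_search_space` skip an elementwise op that merely applies a convolution bias (a conv op plus a `reshape2` of a persistable variable feeding the same add), so that op is no longer treated as a shortcut whose inputs must all be fixed. Below is a rough, self-contained sketch of the pattern the helper matches; `FakeOp` and `FakeGraph` are hypothetical stubs standing in for PaddleSlim's real graph and op wrappers, not part of this commit.

from types import SimpleNamespace

# if_is_bias(), copied from the hunk above so the sketch runs on its own.
def if_is_bias(op, graph):
    pre_ops = sorted(graph.pre_ops(op))
    if 'conv' in pre_ops[0].type() and pre_ops[1].type() == "reshape2":
        if pre_ops[1].inputs('X')[0]._var.persistable == True:
            return True
    return False

class FakeOp:
    """Hypothetical stand-in for an op wrapper; only what if_is_bias touches."""
    def __init__(self, idx, op_type, x_inputs=()):
        self._idx, self._type, self._x = idx, op_type, list(x_inputs)
    def type(self):
        return self._type
    def inputs(self, name):            # if_is_bias only asks for 'X'
        return self._x
    def __lt__(self, other):           # sorted() inside if_is_bias needs an order
        return self._idx < other._idx

class FakeGraph:
    """Hypothetical stand-in for a graph wrapper: pre_ops(op) -> predecessors."""
    def __init__(self, pre_ops):
        self._pre = pre_ops
    def pre_ops(self, op):
        return list(self._pre)

# The matched pattern: a conv op and a reshape2 of a persistable (parameter)
# variable both feed the same elementwise_add, i.e. the add is only a bias add.
bias_param = SimpleNamespace(_var=SimpleNamespace(persistable=True))
graph = FakeGraph([FakeOp(0, 'conv2d'), FakeOp(1, 'reshape2', [bias_param])])
print(if_is_bias(FakeOp(2, 'elementwise_add'), graph))   # True

With `is_bias` true, the elementwise op is excluded from the shortcut handling even though both of its direct inputs are non-persistable intermediates.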
@@ -213,11 +223,13 @@ def broadcast_search_space(same_search_space, param2key, origin_config):
             if key in origin_config:
                 if 'expand_ratio' in origin_config[pre_key]:
                     origin_config[key].update({
-                        'expand_ratio': origin_config[pre_key]['expand_ratio']
+                        'expand_ratio':
+                        origin_config[pre_key]['expand_ratio']
                     })
                 elif 'channel' in origin_config[pre_key]:
                     origin_config[key].update({
-                        'channel': origin_config[pre_key]['channel']
+                        'channel':
+                        origin_config[pre_key]['channel']
                     })
             else:
                 # if the pre_key is removed from config for some reasons
......
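Note on the `broadcast_search_space` hunk: it only re-wraps the dict literals (formatter output); the behaviour is unchanged. The follower key still inherits the leader's `expand_ratio` or `channel`, so layers tied together by a shortcut keep compatible settings. A toy illustration of that update, with hypothetical layer keys rather than real model output:

origin_config = {
    'conv1': {'expand_ratio': [0.5, 1.0]},   # pre_key: first layer in the group
    'conv3': {'kernel_size': [3, 5]},        # key: layer tied to it by a shortcut
}
pre_key, key = 'conv1', 'conv3'

if 'expand_ratio' in origin_config[pre_key]:
    origin_config[key].update(
        {'expand_ratio': origin_config[pre_key]['expand_ratio']})

print(origin_config['conv3'])
# {'kernel_size': [3, 5], 'expand_ratio': [0.5, 1.0]}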
@@ -1047,14 +1047,16 @@ class SuperBatchNorm2D(paddle.nn.BatchNorm2D):
             "Variance": [variance]
         }

-        helper = paddle.fluid.layer_helper.LayerHelper('batch_norm')
-        param_dtype = input.dtype if input.dtype != 'float16' else 'float32'
-        saved_mean = helper.create_variable_for_type_inference(
-            dtype=param_dtype, stop_gradient=True)
-        saved_variance = helper.create_variable_for_type_inference(
-            dtype=param_dtype, stop_gradient=True)
-        batch_norm_out = helper.create_variable_for_type_inference(input.dtype)
+        saved_mean = self._helper.create_variable_for_type_inference(
+            dtype=self._dtype, stop_gradient=True)
+        saved_variance = self._helper.create_variable_for_type_inference(
+            dtype=self._dtype, stop_gradient=True)
+        reserve_space = self._helper.create_variable_for_type_inference(
+            dtype=self._helper.input_dtype(input), stop_gradient=True)
+
+        batch_norm_out = (
+            input if self._in_place else
+            self._helper.create_variable_for_type_inference(self._dtype))

         outputs = {
             "Y": [batch_norm_out],
@@ -1064,13 +1066,10 @@ class SuperBatchNorm2D(paddle.nn.BatchNorm2D):
             "SavedVariance": [saved_variance]
         }

-        if self.training or trainable_statistics:
-            # reserve_space is only used for training.
-            reserve_space = helper.create_variable_for_type_inference(
-                dtype=input.dtype, stop_gradient=True)
+        if reserve_space is not None:
             outputs["ReserveSpace"] = [reserve_space]

-        helper.append_op(
+        self._helper.append_op(
             type="batch_norm", inputs=inputs, outputs=outputs, attrs=attrs)

         self.cur_config = {'prune_dim': feature_dim}
         return batch_norm_out
......
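Note on the `SuperBatchNorm2D` hunks: the forward now reuses the layer's own `self._helper` instead of building a fresh `paddle.fluid.layer_helper.LayerHelper('batch_norm')`, creates `reserve_space` up front rather than only during training, and aliases the output to the input when `self._in_place` is set. The layer appends a raw `batch_norm` op so it can feed in statistics sliced to the active channel count, which it records as `cur_config['prune_dim']`. The snippet below is a simplified, hypothetical sketch of that elastic-channel idea using `paddle.nn.functional.batch_norm`; it is not the code being patched here, and the class name `TinySuperBN` is made up.

import paddle
import paddle.nn.functional as F

class TinySuperBN(paddle.nn.Layer):
    """Hypothetical, simplified stand-in for the idea behind SuperBatchNorm2D:
    hold parameters for the largest channel count and slice them to the
    channels actually present in the input before normalizing."""

    def __init__(self, max_channels):
        super().__init__()
        self.weight = self.create_parameter(
            [max_channels],
            default_initializer=paddle.nn.initializer.Constant(1.0))
        self.bias = self.create_parameter([max_channels], is_bias=True)
        self.register_buffer('_mean', paddle.zeros([max_channels]))
        self.register_buffer('_variance', paddle.ones([max_channels]))

    def forward(self, x):
        c = x.shape[1]  # channels actually flowing through this sub-network
        return F.batch_norm(
            x,
            self._mean[:c],
            self._variance[:c],
            weight=self.weight[:c],
            bias=self.bias[:c],
            training=self.training)

bn = TinySuperBN(64)
bn.eval()                                   # use running statistics
y = bn(paddle.randn([4, 32, 16, 16]))       # only 32 of the 64 channels active
print(y.shape)                              # [4, 32, 16, 16]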