提交 97624cf2 编写于 作者: like15

fix: Extract Tensors from `params` instead of `state_dict` to be compatible with PyTorch 1.4

上级 9e6ee11c
......@@ -26,15 +26,15 @@ from torch.onnx.utils import _node_getitem
from py_proto import mace_pb2
from transform import base_converter
from transform.transformer import Transformer
from transform.base_converter import PoolingType
from transform.base_converter import ActivationType
from transform.base_converter import EltwiseType
from transform.base_converter import FrameworkType
from transform.base_converter import MaceOp
from transform.base_converter import MaceKeyword
from transform.base_converter import ConverterUtil
from transform.base_converter import RoundMode
from transform.base_converter import DataFormat
from transform.base_converter import MaceKeyword
from transform.base_converter import MaceOp
from transform.base_converter import PoolingType
from transform.base_converter import RoundMode
from utils.util import mace_check
......@@ -48,7 +48,11 @@ def _model_to_graph(model, args):
in_vars, in_desc = torch.jit._flatten(tuple(args) + tuple(params))
graph = _propagate_and_assign_input_shapes(
method_graph, tuple(in_vars), False, propagate)
return graph
input_and_param_names = [val.debugName() for val in graph.inputs()]
param_names = input_and_param_names[-len(params):]
params = [elem.detach() for elem in params]
params_dict = dict(zip(param_names, params))
return graph, params_dict
class ValueType(object):
......@@ -172,26 +176,8 @@ class PytorchConverter(base_converter.ConverterInterface):
else:
dummy_input = dummy_input + (torch.randn(in_node.shape),)
graph = _model_to_graph(self._loaded_model, dummy_input)
state_dict = self._loaded_model.state_dict()
'''
num_batches_tracked in state_dict for BN layer is not used by any node,
delete them to avoid mistake name.
Maybe there are more unused keys in the future.
'''
unneeded_keys = []
for key in state_dict.keys():
if re.match(r'.*\.num_batches_tracked', key):
unneeded_keys.append(key)
for key in unneeded_keys:
del state_dict[key]
graph_inputs = list(graph.inputs())
user_input_num = len(graph_inputs) - len(state_dict)
param_names = list(state_dict.keys())
for i, inp in enumerate(graph.inputs()):
if i >= user_input_num:
inp.setDebugName(param_names[i - user_input_num])
return graph
graph, params_dict = _model_to_graph(self._loaded_model, dummy_input)
return graph, params_dict
def init_output_shape_cache(self):
self._output_shape_cache = {}
......@@ -240,16 +226,13 @@ class PytorchConverter(base_converter.ConverterInterface):
}
self._loaded_model = torch.jit.load(src_model_file)
self._loaded_model.eval()
self._graph = self.model_to_graph()
self._graph, self._params_dict = self.model_to_graph()
self._output_node_name = list(self._graph.outputs())[0].debugName()
self._output_value_type = list(self._graph.outputs())[0].type()
if not isinstance(
self._output_value_type,
(ValueType.TensorType, ValueType.ListType,
ValueType.TupleType)):
print('return type {} not supported'.format(
self._output_value_type))
sys.exit(1)
mace_check(isinstance(self._output_value_type, (ValueType.TensorType,
ValueType.ListType, ValueType.TupleType)),
'return type {} not supported'.format(
self._output_value_type))
self._node_map = {}
self.init_output_shape_cache()
......@@ -405,7 +388,7 @@ class PytorchConverter(base_converter.ConverterInterface):
# OIHW
key = inputs_vals[1].debugName()
filter_shape = self._loaded_model.state_dict()[key].shape
filter_shape = self._params_dict[key].shape
filter_shape = [int(elem) for elem in filter_shape] # Size -> list
mace_check(len(filter_shape) == 4,
'MACE only supports 2D Conv, current Conv is {}D'.format(
......@@ -446,7 +429,7 @@ class PytorchConverter(base_converter.ConverterInterface):
dilation_arg.ints.extend(mace_dilations)
filter_tensor_name = inputs_vals[ConvParamIdx.weight_idx].debugName()
filter_data = self._loaded_model.state_dict()[filter_tensor_name]
filter_data = self._params_dict[filter_tensor_name]
if is_depthwise:
# C1HW => 1CHW
filter_data = filter_data.permute((1, 0, 2, 3))
......@@ -458,7 +441,7 @@ class PytorchConverter(base_converter.ConverterInterface):
has_bias = (not isinstance(bias_val.type(), ValueType.NoneType))
if has_bias:
bias_tensor_name = inputs_vals[ConvParamIdx.bias_idx].debugName()
bias_data = self._loaded_model.state_dict()[bias_tensor_name]
bias_data = self._params_dict[bias_tensor_name]
bias_data = bias_data.numpy()
self.add_tensor_and_shape(bias_tensor_name, bias_data.shape,
mace_pb2.DT_FLOAT, bias_data)
......@@ -476,7 +459,7 @@ class PytorchConverter(base_converter.ConverterInterface):
mace_check(is_training == 0,
"Only support batch normalization with is_training = 0,"
" but got {}".format(is_training))
state_dict = self._loaded_model.state_dict()
state_dict = self._params_dict
gamma_key = inputs_vals[BNParamIdx.weight_idx].debugName()
gamma_value = state_dict[gamma_key].numpy().astype(np.float32)
beta_key = inputs_vals[BNParamIdx.bias_idx].debugName()
......@@ -515,15 +498,13 @@ class PytorchConverter(base_converter.ConverterInterface):
type_arg = op.arg.add()
type_arg.name = MaceKeyword.mace_activation_type_str
if (abs(max_val - 6.) < 1e-8):
type_arg.s = six.b(self.activation_type['ReLU6'].name)
mace_check(abs(max_val - 6.) < 1e-8,
'only support converting hardtanh_ to ReLU6 yet')
type_arg.s = six.b(self.activation_type['ReLU6'].name)
limit_arg = op.arg.add()
limit_arg.name = MaceKeyword.mace_activation_max_limit_str
limit_arg.f = 6.0
else:
print('only support converting hardtanh_ to ReLU6 yet')
sys.exit(1)
limit_arg = op.arg.add()
limit_arg.name = MaceKeyword.mace_activation_max_limit_str
limit_arg.f = 6.0
self.infer_shape_general(op)
def convert_add(self, node, inputs_vals, outputs_vals):
......@@ -632,14 +613,14 @@ class PytorchConverter(base_converter.ConverterInterface):
def get_weight_from_node(self, node):
input_list = list(node.inputs())
key = input_list[0].debugName()
return self._loaded_model.state_dict()[key]
return self._params_dict[key]
def is_trans_fc_w(self, node):
in_vals = list(node.inputs())
mace_check(len(in_vals) == 1, 't() must have 1 input')
in_name = in_vals[0].debugName()
if in_name in self._loaded_model.state_dict() and \
len(self._loaded_model.state_dict()[in_name].shape) == 2:
if in_name in self._params_dict and \
len(self._params_dict[in_name].shape) == 2:
return True
return False
......@@ -662,7 +643,7 @@ class PytorchConverter(base_converter.ConverterInterface):
alpha_type = inputs_vals[AddmmParamIdx.alpha_idx].type()
is_alpha_fc = isinstance(alpha_type, ValueType.IntType) and alpha == 1
is_bias_w = inputs_vals[AddmmParamIdx.bias_idx].debugName() in \
self._loaded_model.state_dict()
self._params_dict
beta = inputs_vals[AddmmParamIdx.beta_idx].node()['value']
beta_type = inputs_vals[AddmmParamIdx.beta_idx].type()
is_beta_fc = isinstance(beta_type, ValueType.IntType) and beta == 1
......@@ -703,7 +684,7 @@ class PytorchConverter(base_converter.ConverterInterface):
opb.type = MaceOp.BiasAdd.name
bias_tensor_name = opb.name + '_bias'
key = inputs_vals[AddmmParamIdx.bias_idx].debugName()
bias_data = self._loaded_model.state_dict()[key]
bias_data = self._params_dict[key]
bias_data = bias_data.numpy()
self.add_tensor_and_shape(bias_tensor_name, bias_data.reshape(
-1).shape, mace_pb2.DT_FLOAT, bias_data)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册