From 97624cf28363b741e6cd9b7959b04f72a83f138f Mon Sep 17 00:00:00 2001
From: like15
Date: Tue, 27 Oct 2020 11:16:21 +0800
Subject: [PATCH] fix: Extract Tensors from `params` instead of `state_dict` to be compatible with PyTorch 1.4

---
 tools/python/transform/pytorch_converter.py | 81 ++++++++-------------
 1 file changed, 31 insertions(+), 50 deletions(-)

diff --git a/tools/python/transform/pytorch_converter.py b/tools/python/transform/pytorch_converter.py
index 6996d5f1..019cce83 100644
--- a/tools/python/transform/pytorch_converter.py
+++ b/tools/python/transform/pytorch_converter.py
@@ -26,15 +26,15 @@ from torch.onnx.utils import _node_getitem
 from py_proto import mace_pb2
 from transform import base_converter
 from transform.transformer import Transformer
-from transform.base_converter import PoolingType
 from transform.base_converter import ActivationType
 from transform.base_converter import EltwiseType
 from transform.base_converter import FrameworkType
-from transform.base_converter import MaceOp
-from transform.base_converter import MaceKeyword
 from transform.base_converter import ConverterUtil
-from transform.base_converter import RoundMode
 from transform.base_converter import DataFormat
+from transform.base_converter import MaceKeyword
+from transform.base_converter import MaceOp
+from transform.base_converter import PoolingType
+from transform.base_converter import RoundMode
 from utils.util import mace_check
 
 
@@ -48,7 +48,11 @@ def _model_to_graph(model, args):
     in_vars, in_desc = torch.jit._flatten(tuple(args) + tuple(params))
     graph = _propagate_and_assign_input_shapes(
         method_graph, tuple(in_vars), False, propagate)
-    return graph
+    input_and_param_names = [val.debugName() for val in graph.inputs()]
+    param_names = input_and_param_names[-len(params):]
+    params = [elem.detach() for elem in params]
+    params_dict = dict(zip(param_names, params))
+    return graph, params_dict
 
 
 class ValueType(object):
@@ -172,26 +176,8 @@ class PytorchConverter(base_converter.ConverterInterface):
             else:
                 dummy_input = dummy_input + (torch.randn(in_node.shape),)
 
-        graph = _model_to_graph(self._loaded_model, dummy_input)
-        state_dict = self._loaded_model.state_dict()
-        '''
-        num_batches_tracked in state_dict for BN layer is not used by any node,
-        delete them to avoid mistake name.
-        Maybe there are more unused keys in the future.
-        '''
-        unneeded_keys = []
-        for key in state_dict.keys():
-            if re.match(r'.*\.num_batches_tracked', key):
-                unneeded_keys.append(key)
-        for key in unneeded_keys:
-            del state_dict[key]
-        graph_inputs = list(graph.inputs())
-        user_input_num = len(graph_inputs) - len(state_dict)
-        param_names = list(state_dict.keys())
-        for i, inp in enumerate(graph.inputs()):
-            if i >= user_input_num:
-                inp.setDebugName(param_names[i - user_input_num])
-        return graph
+        graph, params_dict = _model_to_graph(self._loaded_model, dummy_input)
+        return graph, params_dict
 
     def init_output_shape_cache(self):
         self._output_shape_cache = {}
@@ -240,16 +226,13 @@ class PytorchConverter(base_converter.ConverterInterface):
         }
         self._loaded_model = torch.jit.load(src_model_file)
         self._loaded_model.eval()
-        self._graph = self.model_to_graph()
+        self._graph, self._params_dict = self.model_to_graph()
         self._output_node_name = list(self._graph.outputs())[0].debugName()
         self._output_value_type = list(self._graph.outputs())[0].type()
-        if not isinstance(
-                self._output_value_type,
-                (ValueType.TensorType, ValueType.ListType,
-                 ValueType.TupleType)):
-            print('return type {} not supported'.format(
-                self._output_value_type))
-            sys.exit(1)
+        mace_check(isinstance(self._output_value_type, (ValueType.TensorType,
+                   ValueType.ListType, ValueType.TupleType)),
+                   'return type {} not supported'.format(
+                       self._output_value_type))
         self._node_map = {}
         self.init_output_shape_cache()
 
@@ -405,7 +388,7 @@ class PytorchConverter(base_converter.ConverterInterface):
 
         # OIHW
         key = inputs_vals[1].debugName()
-        filter_shape = self._loaded_model.state_dict()[key].shape
+        filter_shape = self._params_dict[key].shape
         filter_shape = [int(elem) for elem in filter_shape]  # Size -> list
         mace_check(len(filter_shape) == 4,
                    'MACE only supports 2D Conv, current Conv is {}D'.format(
@@ -446,7 +429,7 @@ class PytorchConverter(base_converter.ConverterInterface):
             dilation_arg.ints.extend(mace_dilations)
 
         filter_tensor_name = inputs_vals[ConvParamIdx.weight_idx].debugName()
-        filter_data = self._loaded_model.state_dict()[filter_tensor_name]
+        filter_data = self._params_dict[filter_tensor_name]
         if is_depthwise:
             # C1HW => 1CHW
             filter_data = filter_data.permute((1, 0, 2, 3))
@@ -458,7 +441,7 @@ class PytorchConverter(base_converter.ConverterInterface):
         has_bias = (not isinstance(bias_val.type(), ValueType.NoneType))
         if has_bias:
             bias_tensor_name = inputs_vals[ConvParamIdx.bias_idx].debugName()
-            bias_data = self._loaded_model.state_dict()[bias_tensor_name]
+            bias_data = self._params_dict[bias_tensor_name]
             bias_data = bias_data.numpy()
             self.add_tensor_and_shape(bias_tensor_name, bias_data.shape,
                                       mace_pb2.DT_FLOAT, bias_data)
@@ -476,7 +459,7 @@ class PytorchConverter(base_converter.ConverterInterface):
         mace_check(is_training == 0,
                    "Only support batch normalization with is_training = 0,"
                    " but got {}".format(is_training))
-        state_dict = self._loaded_model.state_dict()
+        state_dict = self._params_dict
         gamma_key = inputs_vals[BNParamIdx.weight_idx].debugName()
         gamma_value = state_dict[gamma_key].numpy().astype(np.float32)
         beta_key = inputs_vals[BNParamIdx.bias_idx].debugName()
@@ -515,15 +498,13 @@ class PytorchConverter(base_converter.ConverterInterface):
 
         type_arg = op.arg.add()
         type_arg.name = MaceKeyword.mace_activation_type_str
-        if (abs(max_val - 6.) < 1e-8):
-            type_arg.s = six.b(self.activation_type['ReLU6'].name)
+        mace_check(abs(max_val - 6.) < 1e-8,
+                   'only support converting hardtanh_ to ReLU6 yet')
+        type_arg.s = six.b(self.activation_type['ReLU6'].name)
 
-            limit_arg = op.arg.add()
-            limit_arg.name = MaceKeyword.mace_activation_max_limit_str
-            limit_arg.f = 6.0
-        else:
-            print('only support converting hardtanh_ to ReLU6 yet')
-            sys.exit(1)
+        limit_arg = op.arg.add()
+        limit_arg.name = MaceKeyword.mace_activation_max_limit_str
+        limit_arg.f = 6.0
         self.infer_shape_general(op)
 
     def convert_add(self, node, inputs_vals, outputs_vals):
@@ -632,14 +613,14 @@ class PytorchConverter(base_converter.ConverterInterface):
     def get_weight_from_node(self, node):
         input_list = list(node.inputs())
         key = input_list[0].debugName()
-        return self._loaded_model.state_dict()[key]
+        return self._params_dict[key]
 
     def is_trans_fc_w(self, node):
         in_vals = list(node.inputs())
         mace_check(len(in_vals) == 1, 't() must have 1 input')
         in_name = in_vals[0].debugName()
-        if in_name in self._loaded_model.state_dict() and \
-                len(self._loaded_model.state_dict()[in_name].shape) == 2:
+        if in_name in self._params_dict and \
+                len(self._params_dict[in_name].shape) == 2:
             return True
         return False
 
@@ -662,7 +643,7 @@ class PytorchConverter(base_converter.ConverterInterface):
         alpha_type = inputs_vals[AddmmParamIdx.alpha_idx].type()
         is_alpha_fc = isinstance(alpha_type, ValueType.IntType) and alpha == 1
         is_bias_w = inputs_vals[AddmmParamIdx.bias_idx].debugName() in \
-            self._loaded_model.state_dict()
+            self._params_dict
         beta = inputs_vals[AddmmParamIdx.beta_idx].node()['value']
         beta_type = inputs_vals[AddmmParamIdx.beta_idx].type()
         is_beta_fc = isinstance(beta_type, ValueType.IntType) and beta == 1
@@ -703,7 +684,7 @@ class PytorchConverter(base_converter.ConverterInterface):
             opb.type = MaceOp.BiasAdd.name
             bias_tensor_name = opb.name + '_bias'
             key = inputs_vals[AddmmParamIdx.bias_idx].debugName()
-            bias_data = self._loaded_model.state_dict()[key]
+            bias_data = self._params_dict[key]
             bias_data = bias_data.numpy()
             self.add_tensor_and_shape(bias_tensor_name, bias_data.reshape(
                 -1).shape, mace_pb2.DT_FLOAT, bias_data)
--
GitLab
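
Background for reviewers (not part of the patch): the deleted code renamed the trailing graph inputs positionally from state_dict() keys, after first filtering out buffers such as num_batches_tracked that never appear as graph inputs; the patched _model_to_graph instead keys the parameter tensors returned alongside the lowered graph by the debugName of the trailing graph inputs, so no filtering is needed. The standalone sketch below is illustrative only: TinyNet and all names in it are hypothetical and use no MACE code, and the mapping restated in the closing comments mirrors the patch rather than running here.

    # Illustrative sketch, assuming a recent PyTorch install.
    import torch
    import torch.nn as nn

    class TinyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.bn = nn.BatchNorm2d(3)

        def forward(self, x):
            return self.bn(x)

    traced = torch.jit.trace(TinyNet().eval(), torch.randn(1, 3, 8, 8))

    # state_dict() still contains buffers like bn.num_batches_tracked that are
    # not used by any graph node, so its keys cannot be zipped 1:1 against the
    # graph inputs without the filtering the old code performed.
    print(list(traced.state_dict().keys()))

    # The patched _model_to_graph avoids that bookkeeping, roughly:
    #   param_names = [v.debugName() for v in graph.inputs()][-len(params):]
    #   params_dict = dict(zip(param_names, [p.detach() for p in params]))
    # and later lookups (Conv filters, BN gamma/beta, FC bias) read from
    # self._params_dict instead of self._loaded_model.state_dict().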