diff --git a/x2paddle/core/fluid_code.py b/x2paddle/core/fluid_code.py index 1e307767c7351f2f33b3ef8e591737933d295b85..d341585ad55722317e92e47a3bb6b07e5fbe2501 100644 --- a/x2paddle/core/fluid_code.py +++ b/x2paddle/core/fluid_code.py @@ -35,25 +35,30 @@ class Layer(object): if isinstance(self.inputs, list): in_list = "[" for input in self.inputs: - assert isinstance( - input, GraphNode), "Type of input should be GraphNode" - if hasattr(input, "index"): - in_list += (input.layer_name + "[{}]".format(input.index) + - ", ") + if isinstance(input, GraphNode): + if hasattr(input, "index"): + in_list += (input.layer_name + "[{}]".format(input.index) + ", ") + else: + in_list += (input.layer_name + ", ") + elif isinstance(input, str): + in_list += (input + ", ") else: - in_list += (input.layer_name + ", ") + raise Exception("Element of inputs should GraphNode or String") in_list = in_list.strip(", ") + "], " layer_code += in_list elif isinstance(self.inputs, dict): for key, input in self.inputs.items(): - assert isinstance( - input, GraphNode), "Type of input should be GraphNode" - if hasattr(input, "index"): - layer_code = layer_code + key + "={}, ".format( - input.layer_name + "[{}]".format(input.index)) + if isinstance(input, GraphNode): + if hasattr(input, "index"): + layer_code = layer_code + key + "={}, ".format( + input.layer_name + "[{}]".format(input.index)) + else: + layer_code = layer_code + key + "={}, ".format( + input.layer_name) + elif isinstance(input, str): + layer_code = layer_code + key + "={}, ".format(input) else: - layer_code = layer_code + key + "={}, ".format( - input.layer_name) + raise Exception("Element of inputs should GraphNode or String") elif isinstance(self.inputs, GraphNode): if hasattr(self.inputs, "index"): layer_code += (self.inputs.layer_name +