提交 e3d4dc20 编写于 作者: W wjj19950828

Add conv3d support

上级 9f27c2cc
...@@ -1315,8 +1315,10 @@ def aten__convolution(mapper, graph, node): ...@@ -1315,8 +1315,10 @@ def aten__convolution(mapper, graph, node):
weights = mapper.pytorch_params[inputs_name[1]] weights = mapper.pytorch_params[inputs_name[1]]
if len(weights.shape) == 3: if len(weights.shape) == 3:
op_name = name_generator("conv1d", mapper.nn_name2id) op_name = name_generator("conv1d", mapper.nn_name2id)
else: elif len(weights.shape) == 4:
op_name = name_generator("conv2d", mapper.nn_name2id) op_name = name_generator("conv2d", mapper.nn_name2id)
else:
op_name = name_generator("conv3d", mapper.nn_name2id)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [op_name, output_name] layer_outputs = [op_name, output_name]
layer_inputs = {} layer_inputs = {}
...@@ -1364,7 +1366,22 @@ def aten__convolution(mapper, graph, node): ...@@ -1364,7 +1366,22 @@ def aten__convolution(mapper, graph, node):
else: else:
layer_attrs['in_channels'] = weights.shape[1] * mapper.attrs[ layer_attrs['in_channels'] = weights.shape[1] * mapper.attrs[
inputs_name[8]] inputs_name[8]]
if len(weights.shape) == 4: if len(weights.shape) == 3:
if mapper.attrs[inputs_name[6]]:
graph.add_layer(
"paddle.nn.Conv1DTranspose",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
**layer_attrs)
else:
graph.add_layer(
"paddle.nn.Conv1D",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
**layer_attrs)
elif len(weights.shape) == 4:
if mapper.attrs[inputs_name[6]]: if mapper.attrs[inputs_name[6]]:
graph.add_layer( graph.add_layer(
"paddle.nn.Conv2DTranspose", "paddle.nn.Conv2DTranspose",
...@@ -1382,14 +1399,14 @@ def aten__convolution(mapper, graph, node): ...@@ -1382,14 +1399,14 @@ def aten__convolution(mapper, graph, node):
else: else:
if mapper.attrs[inputs_name[6]]: if mapper.attrs[inputs_name[6]]:
graph.add_layer( graph.add_layer(
"paddle.nn.Conv1DTranspose", "paddle.nn.Conv3DTranspose",
inputs=layer_inputs, inputs=layer_inputs,
outputs=layer_outputs, outputs=layer_outputs,
scope_name=scope_name, scope_name=scope_name,
**layer_attrs) **layer_attrs)
else: else:
graph.add_layer( graph.add_layer(
"paddle.nn.Conv1D", "paddle.nn.Conv3D",
inputs=layer_inputs, inputs=layer_inputs,
outputs=layer_outputs, outputs=layer_outputs,
scope_name=scope_name, scope_name=scope_name,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册