未验证 提交 203fd120 编写于 作者: J Jason 提交者: GitHub

Merge pull request #725 from wjj19950828/add_conv3d_support

Add aten::_convolution for conv3d support
......@@ -1315,8 +1315,10 @@ def aten__convolution(mapper, graph, node):
weights = mapper.pytorch_params[inputs_name[1]]
if len(weights.shape) == 3:
op_name = name_generator("conv1d", mapper.nn_name2id)
else:
elif len(weights.shape) == 4:
op_name = name_generator("conv2d", mapper.nn_name2id)
else:
op_name = name_generator("conv3d", mapper.nn_name2id)
output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [op_name, output_name]
layer_inputs = {}
......@@ -1364,7 +1366,22 @@ def aten__convolution(mapper, graph, node):
else:
layer_attrs['in_channels'] = weights.shape[1] * mapper.attrs[
inputs_name[8]]
if len(weights.shape) == 4:
if len(weights.shape) == 3:
if mapper.attrs[inputs_name[6]]:
graph.add_layer(
"paddle.nn.Conv1DTranspose",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
**layer_attrs)
else:
graph.add_layer(
"paddle.nn.Conv1D",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
**layer_attrs)
elif len(weights.shape) == 4:
if mapper.attrs[inputs_name[6]]:
graph.add_layer(
"paddle.nn.Conv2DTranspose",
......@@ -1382,14 +1399,14 @@ def aten__convolution(mapper, graph, node):
else:
if mapper.attrs[inputs_name[6]]:
graph.add_layer(
"paddle.nn.Conv1DTranspose",
"paddle.nn.Conv3DTranspose",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
**layer_attrs)
else:
graph.add_layer(
"paddle.nn.Conv1D",
"paddle.nn.Conv3D",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册