提交 fcaa378c 编写于 作者: L liuxiao93

fix bug for con1d with 3d input.

上级 7fbed0ce
......@@ -13,10 +13,13 @@
# limitations under the License.
# ============================================================================
"""conv"""
import numpy as np
from mindspore import log as logger
from mindspore.ops import operations as P
from mindspore.ops.primitive import constexpr
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
from mindspore._checkparam import ParamValidator as validator, Rel
from mindspore._checkparam import Validator
from mindspore._checkparam import check_bool, twice, check_int_positive
......@@ -254,6 +257,11 @@ class Conv2d(_Conv):
return s
@constexpr
def _check_input_3d(input_shape):
if len(input_shape) != 3:
raise ValueError(f"Input should be 3d, but got shape {input_shape}")
class Conv1d(_Conv):
r"""
1D convolution layer.
......@@ -359,6 +367,15 @@ class Conv1d(_Conv):
kernel_size = (1, kernel_size)
stride = (1, stride)
dilation = (1, dilation)
get_shape = P.Shape()
get_dtype = P.DType()
if isinstance(weight_init, Tensor):
weight_init_shape = get_shape(weight_init)
Validator.check_integer('weight_init_shape', len(weight_init_shape), 3, Rel.EQ, self.cls_name)
weight_init_dtype = get_dtype(weight_init)
weight_init_value = weight_init.asnumpy()
weight_init_value = np.expand_dims(weight_init_value, 2)
weight_init = Tensor(weight_init_value, weight_init_dtype)
super(Conv1d, self).__init__(
in_channels,
......@@ -391,13 +408,13 @@ class Conv1d(_Conv):
def construct(self, x):
x_shape = self.shape(x)
if len(x_shape) == 3:
x = self.expand_dims(x, 2)
_check_input_3d(x_shape)
x = self.expand_dims(x, 2)
output = self.conv2d(x, self.weight)
if self.has_bias:
output = self.bias_add(output, self.bias)
if len(x_shape) == 3:
output = self.squeeze(output)
output = self.squeeze(output)
return output
def extend_repr(self):
......@@ -669,6 +686,15 @@ class Conv1dTranspose(_Conv):
kernel_size = (1, kernel_size)
stride = (1, stride)
dilation = (1, dilation)
get_shape = P.Shape()
get_dtype = P.DType()
if isinstance(weight_init, Tensor):
weight_init_shape = get_shape(weight_init)
Validator.check_integer('weight_init_shape', len(weight_init_shape), 3, Rel.EQ, self.cls_name)
weight_init_dtype = get_dtype(weight_init)
weight_init_value = weight_init.asnumpy()
weight_init_value = np.expand_dims(weight_init_value, 2)
weight_init = Tensor(weight_init_value, weight_init_dtype)
# out_channels and in_channels swap.
# cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel,
# then Conv1dTranspose's out_channel refers to Conv2DBackpropInput's in_channel.
......@@ -733,8 +759,8 @@ class Conv1dTranspose(_Conv):
def construct(self, x):
x_shape = self.shape(x)
if len(x_shape) == 3:
x = self.expand_dims(x, 2)
_check_input_3d(x_shape)
x = self.expand_dims(x, 2)
n, _, h, w = self.shape(x)
......@@ -746,8 +772,7 @@ class Conv1dTranspose(_Conv):
if self.has_bias:
output = self.bias_add(output, self.bias)
if len(x_shape) == 3:
output = self.squeeze(output)
output = self.squeeze(output)
return output
def extend_repr(self):
......
......@@ -1690,7 +1690,9 @@ class L2Loss(PrimitiveWithInfer):
Set `input_x` as x and output as loss.
.. math::
loss = sum(x ** 2) / 2
loss = sum(x ** 2) / nelement(x)
:math:`nelement(x)` represents the number of `input_x`.
Inputs:
- **input_x** (Tensor) - A input Tensor.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册