提交 9cb71441 编写于 作者: B buxue

fix bugs of Acosh, TopK, ResizeNearestNeighbor, DepthwiseConv2dNative

上级 90dfbab3
...@@ -171,20 +171,17 @@ GeTensorPtr TransformUtil::ConvertTensor(const MeTensorPtr &tensor, const std::s ...@@ -171,20 +171,17 @@ GeTensorPtr TransformUtil::ConvertTensor(const MeTensorPtr &tensor, const std::s
MS_LOG(ERROR) << "The Me Tensor data type size is wrong, type size is: " << type_size; MS_LOG(ERROR) << "The Me Tensor data type size is wrong, type size is: " << type_size;
return nullptr; return nullptr;
} }
// get tensor buff size
size_t data_buff_size = 0;
size_t elements_num = IntToSize(tensor->ElementsNum()); size_t elements_num = IntToSize(tensor->ElementsNum());
if (elements_num > 0 && type_size > 0 && UINT_MAX / type_size >= elements_num) { if (UINT_MAX / type_size < elements_num) {
data_buff_size = elements_num * type_size; MS_LOG(ERROR) << "The required Me Tensor data buff size " << elements_num << " x " << type_size
<< " overflowed UINT_MAX: " << UINT_MAX << ".";
return nullptr;
} }
// get tensor buff size
size_t data_buff_size = elements_num * type_size;
if (data_buff_size == 0) { if (data_buff_size == 0) {
if (elements_num > 0 && type_size > 0 && UINT_MAX / type_size < elements_num) { MS_LOG(INFO) << "The Me Tensor data buff size is 0.";
MS_LOG(ERROR) << "The required Me Tensor data buff size " << elements_num << " x " << type_size
<< " overflowed UINT_MAX: " << UINT_MAX << ".";
} else {
MS_LOG(ERROR) << "The Me Tensor data buff size is 0.";
}
return nullptr;
} }
// create ge tensor // create ge tensor
auto desc = GetGeTensorDesc(tensor->shape_c(), tensor->data_type(), format); auto desc = GetGeTensorDesc(tensor->shape_c(), tensor->data_type(), format);
......
...@@ -56,7 +56,7 @@ class Momentum(Optimizer): ...@@ -56,7 +56,7 @@ class Momentum(Optimizer):
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
Outputs: Outputs:
Tensor[bool], the value is True. tuple[bool], all elements are True.
Raises: Raises:
ValueError: If the momentum is less than 0.0. ValueError: If the momentum is less than 0.0.
......
...@@ -1885,6 +1885,11 @@ class ResizeNearestNeighbor(PrimitiveWithInfer): ...@@ -1885,6 +1885,11 @@ class ResizeNearestNeighbor(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, size, align_corners=False): def __init__(self, size, align_corners=False):
"""Init ResizeNearestNeighbor""" """Init ResizeNearestNeighbor"""
validator.check_value_type("size", size, [tuple, list], self.name)
validator.check_value_type("align_corners", align_corners, [bool], self.name)
validator.check_integer("length of size", len(size), 2, Rel.EQ, self.name)
for i, value in enumerate(size):
validator.check_integer(f'{i}th value of size', value, 0, Rel.GE, self.name)
self.init_prim_io_names(inputs=['image_in'], outputs=['image_out']) self.init_prim_io_names(inputs=['image_in'], outputs=['image_out'])
def infer_shape(self, x): def infer_shape(self, x):
......
...@@ -1251,7 +1251,8 @@ class Acosh(PrimitiveWithInfer): ...@@ -1251,7 +1251,8 @@ class Acosh(PrimitiveWithInfer):
Compute inverse hyperbolic cosine of x element-wise. Compute inverse hyperbolic cosine of x element-wise.
Inputs: Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`,
and the data type of 'input_x' is number, the element in 'input_x' should be greater than or equal to 1.
Outputs: Outputs:
Tensor, has the same shape as `input_x`. Tensor, has the same shape as `input_x`.
......
...@@ -753,8 +753,15 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): ...@@ -753,8 +753,15 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x', 'w'], outputs=['output']) self.init_prim_io_names(inputs=['x', 'w'], outputs=['output'])
self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
self.stride = _check_positive_int_or_tuple('stride', stride, self.name) self.stride = _check_positive_int_or_tuple('stride', stride, self.name)
if self.stride[0] != self.stride[1]:
raise ValueError("The height and width of stride should be equal,"
f"but got height:{self.stride[0]}, width:{self.stride[1]}")
self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1])) self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1]))
self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name) self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name)
if self.dilation[0] != self.dilation[1]:
raise ValueError("The height and width of dilation should be equal,"
f"but got height:{self.dilation[0]}, width:{self.dilation[1]}")
self.add_prim_attr('dilation', (1, 1, self.dilation[0], self.dilation[1])) self.add_prim_attr('dilation', (1, 1, self.dilation[0], self.dilation[1]))
validator.check_value_type('pad', pad, (int,), self.name) validator.check_value_type('pad', pad, (int,), self.name)
self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name) self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
...@@ -771,13 +778,11 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): ...@@ -771,13 +778,11 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
validator.check("x_shape[1]", x_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name) validator.check("x_shape[1]", x_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name) validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name)
kernel_size_h = w_shape[2] kernel_size_n, _, kernel_size_h, kernel_size_w = w_shape
kernel_size_w = w_shape[3] _, _, stride_h, stride_w = self.stride
stride_h = self.stride[2] _, _, dilation_h, dilation_w = self.dilation
stride_w = self.stride[3] if kernel_size_n != 1:
dilation_h = self.dilation[2] raise ValueError(f"The batch of input weight should be 1, but got {kernel_size_n}")
dilation_w = self.dilation[3]
if self.pad_mode == "valid": if self.pad_mode == "valid":
h_out = math.ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h) h_out = math.ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h)
w_out = math.ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w) w_out = math.ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w)
...@@ -1198,8 +1203,8 @@ class TopK(PrimitiveWithInfer): ...@@ -1198,8 +1203,8 @@ class TopK(PrimitiveWithInfer):
>>> input_x = Tensor([1, 2, 3, 4, 5], mindspore.float16) >>> input_x = Tensor([1, 2, 3, 4, 5], mindspore.float16)
>>> k = 3 >>> k = 3
>>> values, indices = topk(input_x, k) >>> values, indices = topk(input_x, k)
>>> assert values == Tensor(np.array([5, 4, 3])) >>> assert values == Tensor(np.array([5, 4, 3]), mstype.float16)
>>> assert indices == Tensor(np.array([4, 3, 2])) >>> assert indices == Tensor(np.array([4, 3, 2]), mstype.int32)
""" """
@prim_attr_register @prim_attr_register
......
...@@ -793,8 +793,8 @@ test_case_nn_ops = [ ...@@ -793,8 +793,8 @@ test_case_nn_ops = [
'desc_bprop': [[5, 5]]}), 'desc_bprop': [[5, 5]]}),
('DepthwiseConv2dNative_1', { ('DepthwiseConv2dNative_1', {
'block': P.DepthwiseConv2dNative(3, (3, 3), pad_mode="pad", pad=1, stride=2), 'block': P.DepthwiseConv2dNative(3, (3, 3), pad_mode="pad", pad=1, stride=2),
'desc_inputs': [[10, 32, 32, 32], [3, 32, 3, 3]], 'desc_inputs': [[10, 32, 32, 32], [1, 32, 3, 3]],
'desc_bprop': [[10, 30, 16, 16]]}), 'desc_bprop': [[10, 32, 16, 16]]}),
('DepthwiseConv2dNative_2', { ('DepthwiseConv2dNative_2', {
'block': P.DepthwiseConv2dNative(1, (3, 3), pad_mode="same", pad=0, stride=1), 'block': P.DepthwiseConv2dNative(1, (3, 3), pad_mode="same", pad=0, stride=1),
'desc_inputs': [[2592, 2048, 4, 4], [1, 2048, 3, 3]], 'desc_inputs': [[2592, 2048, 4, 4], [1, 2048, 3, 3]],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册