提交 d0c7ece6 编写于 作者: J jiangjinsheng

fix HistogramFixedWidth

上级 68731921
...@@ -1371,11 +1371,9 @@ class UnsortedSegmentMin(PrimitiveWithInfer): ...@@ -1371,11 +1371,9 @@ class UnsortedSegmentMin(PrimitiveWithInfer):
""" """
Computes the minimum along segments of a tensor. Computes the minimum along segments of a tensor.
If the given segment_ids is negative, the value will be ignored.
Inputs: Inputs:
- **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`. - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
- **segment_ids** (Tensor) - A `1-D` tensor whose shape is :math:`(x_1)`. - **segment_ids** (Tensor) - A `1-D` tensor whose shape is :math:`(x_1)`, the value should be >= 0.
- **num_segments** (int) - The value spcifies the number of distinct `segment_ids`. - **num_segments** (int) - The value spcifies the number of distinct `segment_ids`.
Outputs: Outputs:
......
...@@ -1356,7 +1356,7 @@ class HistogramFixedWidth(PrimitiveWithInfer): ...@@ -1356,7 +1356,7 @@ class HistogramFixedWidth(PrimitiveWithInfer):
Args: Args:
dtype (string): An optional attribute. Must be one of the following types: "int32", "int64". Default: "int32". dtype (string): An optional attribute. Must be one of the following types: "int32", "int64". Default: "int32".
nbins (Tensor): Number of histogram bins, the type is int32. nbins (int): Number of histogram bins, the type is positive integer.
Inputs: Inputs:
- **x** (Tensor) - Numeric Tensor. Must be one of the following types: int32, float32, float16. - **x** (Tensor) - Numeric Tensor. Must be one of the following types: int32, float32, float16.
...@@ -1377,6 +1377,7 @@ class HistogramFixedWidth(PrimitiveWithInfer): ...@@ -1377,6 +1377,7 @@ class HistogramFixedWidth(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, nbins, dtype='int32'): def __init__(self, nbins, dtype='int32'):
self.nbins = validator.check_value_type("nbins", nbins, [int], self.name) self.nbins = validator.check_value_type("nbins", nbins, [int], self.name)
validator.check_integer("nbins", nbins, 1, Rel.GE, self.name)
valid_values = ['int32', 'int64'] valid_values = ['int32', 'int64']
self.dtype = validator.check_string("dtype", dtype, valid_values, self.name) self.dtype = validator.check_string("dtype", dtype, valid_values, self.name)
self.init_prim_io_names(inputs=['x', 'range'], outputs=['y']) self.init_prim_io_names(inputs=['x', 'range'], outputs=['y'])
......
...@@ -1738,6 +1738,8 @@ class SGD(PrimitiveWithInfer): ...@@ -1738,6 +1738,8 @@ class SGD(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, dampening=0.0, weight_decay=0.0, nesterov=False): def __init__(self, dampening=0.0, weight_decay=0.0, nesterov=False):
validator.check_value_type("nesterov", nesterov, [bool], self.name) validator.check_value_type("nesterov", nesterov, [bool], self.name)
if nesterov and dampening != 0:
raise ValueError(f"Nesterov need zero dampening!")
self.init_prim_io_names(inputs=['parameters', 'gradient', 'learning_rate', 'accum', 'momentum', 'stat'], self.init_prim_io_names(inputs=['parameters', 'gradient', 'learning_rate', 'accum', 'momentum', 'stat'],
outputs=['output']) outputs=['output'])
...@@ -2151,7 +2153,8 @@ class ResizeBilinear(PrimitiveWithInfer): ...@@ -2151,7 +2153,8 @@ class ResizeBilinear(PrimitiveWithInfer):
rescale by `new_height / height`. Default: False. rescale by `new_height / height`. Default: False.
Inputs: Inputs:
- **input** (Tensor) - Image to be resized. Tensor of shape `(N_i, ..., N_n, height, width)`. - **input** (Tensor) - Image to be resized. Tensor of shape `(N_i, ..., N_n, height, width)`,
with data type of float32 or float16.
Outputs: Outputs:
Tensor, resized image. Tensor of shape `(N_i, ..., N_n, new_height, new_width)` in `float32`. Tensor, resized image. Tensor of shape `(N_i, ..., N_n, new_height, new_width)` in `float32`.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册