提交 d0c7ece6 编写于 作者: J jiangjinsheng

fix HistogramFixedWidth

上级 68731921
......@@ -1371,11 +1371,9 @@ class UnsortedSegmentMin(PrimitiveWithInfer):
"""
Computes the minimum along segments of a tensor.
If the given segment_ids is negative, the value will be ignored.
Inputs:
- **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
- **segment_ids** (Tensor) - A `1-D` tensor whose shape is :math:`(x_1)`.
- **segment_ids** (Tensor) - A `1-D` tensor whose shape is :math:`(x_1)`, the value should be >= 0.
- **num_segments** (int) - The value spcifies the number of distinct `segment_ids`.
Outputs:
......
......@@ -1356,7 +1356,7 @@ class HistogramFixedWidth(PrimitiveWithInfer):
Args:
dtype (string): An optional attribute. Must be one of the following types: "int32", "int64". Default: "int32".
nbins (Tensor): Number of histogram bins, the type is int32.
nbins (int): Number of histogram bins, the type is positive integer.
Inputs:
- **x** (Tensor) - Numeric Tensor. Must be one of the following types: int32, float32, float16.
......@@ -1377,6 +1377,7 @@ class HistogramFixedWidth(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, nbins, dtype='int32'):
self.nbins = validator.check_value_type("nbins", nbins, [int], self.name)
validator.check_integer("nbins", nbins, 1, Rel.GE, self.name)
valid_values = ['int32', 'int64']
self.dtype = validator.check_string("dtype", dtype, valid_values, self.name)
self.init_prim_io_names(inputs=['x', 'range'], outputs=['y'])
......
......@@ -1738,6 +1738,8 @@ class SGD(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, dampening=0.0, weight_decay=0.0, nesterov=False):
validator.check_value_type("nesterov", nesterov, [bool], self.name)
if nesterov and dampening != 0:
raise ValueError(f"Nesterov need zero dampening!")
self.init_prim_io_names(inputs=['parameters', 'gradient', 'learning_rate', 'accum', 'momentum', 'stat'],
outputs=['output'])
......@@ -2151,7 +2153,8 @@ class ResizeBilinear(PrimitiveWithInfer):
rescale by `new_height / height`. Default: False.
Inputs:
- **input** (Tensor) - Image to be resized. Tensor of shape `(N_i, ..., N_n, height, width)`.
- **input** (Tensor) - Image to be resized. Tensor of shape `(N_i, ..., N_n, height, width)`,
with data type of float32 or float16.
Outputs:
Tensor, resized image. Tensor of shape `(N_i, ..., N_n, new_height, new_width)` in `float32`.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册