提交 157ee1ca 编写于 作者: J jiangjinsheng

fix nn.PReLU example

上级 304dbfaa
......@@ -378,7 +378,7 @@ class PReLU(Cell):
Tensor, with the same type and shape as the `input_data`.
Examples:
>>> input_x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float32)
>>> input_x = Tensor(np.random.rand(1, 10, 4, 4), mindspore.float32)
>>> prelu = nn.PReLU()
>>> prelu(input_x)
......
......@@ -1066,6 +1066,8 @@ class StridedSliceGrad(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output'])
def __infer__(self, dy, shapex, begin, end, strides):
args = {"shapex": shapex['dtype'],"begin": begin['dtype'],"end": end['dtype'],"strides": strides['dtype']}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
return {'shape': shapex['value'],
'dtype': dy['dtype'],
'value': None}
......
......@@ -2602,6 +2602,8 @@ class SpaceToBatchND(PrimitiveWithInfer):
for elem in block_shape:
validator.check('block_shape element', elem, '', 1, Rel.GE, self.name)
validator.check_value_type('block_shape element', elem, [int], self.name)
self.block_shape = block_shape
validator.check('paddings shape', np.array(paddings).shape, '', (block_rank, 2), Rel.EQ, self.name)
......@@ -2644,7 +2646,7 @@ class BatchToSpaceND(PrimitiveWithInfer):
The length of block_shape is M correspoding to the number of spatial dimensions.
crops (list): The crop value for H and W dimension, containing 2 sub list, each containing 2 int value.
All values must be >= 0. crops[i] specifies the crop values for spatial dimension i, which corresponds to
input dimension i+2. It is required that input_shape[i+2]*block_size[i] > crops[i][0]+crops[i][1].
input dimension i+2. It is required that input_shape[i+2]*block_shape[i] > crops[i][0]+crops[i][1].
Inputs:
- **input_x** (Tensor) - The input tensor.
......@@ -2680,6 +2682,8 @@ class BatchToSpaceND(PrimitiveWithInfer):
for elem in block_shape:
validator.check('block_shape element', elem, '', 1, Rel.GE, self.name)
validator.check_value_type('block_shape element', elem, [int], self.name)
self.block_shape = block_shape
validator.check('crops shape', np.array(crops).shape, '', (block_rank, 2), Rel.EQ, self.name)
......
......@@ -2157,10 +2157,10 @@ class ResizeBilinear(PrimitiveWithInfer):
Tensor, resized image. Tensor of shape `(N_i, ..., N_n, new_height, new_width)` in `float32`.
Examples:
>>> tensor = Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mindspore.int32)
>>> tensor = Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mindspore.float32)
>>> resize_bilinear = P.ResizeBilinear((5, 5))
>>> result = resize_bilinear(tensor)
>>> assert result.shape == (5, 5)
>>> assert result.shape == (1, 1, 5, 5)
"""
@prim_attr_register
......@@ -2176,6 +2176,7 @@ class ResizeBilinear(PrimitiveWithInfer):
return out_shape
def infer_dtype(self, input_dtype):
validator.check_tensor_type_same({'input_dtype': input_dtype}, [mstype.float16, mstype.float32], self.name)
return mstype.tensor_type(mstype.float32)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册