未验证 提交 72a41e50 编写于 作者: T Tao Luo 提交者: GitHub

reduce compile time of amax and amin (#38534)

上级 4853ab0a
......@@ -125,13 +125,14 @@ struct AMaxOrAMinGradFunctor {
HANDLE_AXIS_DIM(3, 2);
HANDLE_AXIS_DIM(4, 2);
HANDLE_AXIS_DIM(4, 3);
HANDLE_AXIS_DIM(5, 2);
HANDLE_AXIS_DIM(5, 3);
HANDLE_AXIS_DIM(5, 4);
HANDLE_AXIS_DIM(6, 2);
HANDLE_AXIS_DIM(6, 3);
HANDLE_AXIS_DIM(6, 4);
HANDLE_AXIS_DIM(6, 5);
// comments for accelerating compiling temporarily.
// HANDLE_AXIS_DIM(5, 2);
// HANDLE_AXIS_DIM(5, 3);
// HANDLE_AXIS_DIM(5, 4);
// HANDLE_AXIS_DIM(6, 2);
// HANDLE_AXIS_DIM(6, 3);
// HANDLE_AXIS_DIM(6, 4);
// HANDLE_AXIS_DIM(6, 5);
}
};
......
......@@ -1775,7 +1775,8 @@ def amax(x, axis=None, keepdim=False, name=None):
while max propagates gradient to all of them.
Args:
x(Tensor): A tensor, the data type is float32, float64, int32, int64.
x(Tensor): A tensor, the data type is float32, float64, int32, int64,
the dimension is no more than 4.
axis(int|list|tuple, optional): The axis along which the maximum is computed.
If :attr:`None`, compute the maximum over all elements of
`x` and return a Tensor with a single element,
......@@ -1887,7 +1888,8 @@ def amin(x, axis=None, keepdim=False, name=None):
while min propagates gradient to all of them.
Args:
x(Tensor): A tensor, the data type is float32, float64, int32, int64.
x(Tensor): A tensor, the data type is float32, float64, int32, int64,
the dimension is no more than 4.
axis(int|list|tuple, optional): The axis along which the minimum is computed.
If :attr:`None`, compute the minimum over all elements of
`x` and return a Tensor with a single element,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册