未验证 提交 72a41e50 编写于 作者: T Tao Luo 提交者: GitHub

reduce compile time of amax and amin (#38534)

上级 4853ab0a
...@@ -125,13 +125,14 @@ struct AMaxOrAMinGradFunctor { ...@@ -125,13 +125,14 @@ struct AMaxOrAMinGradFunctor {
HANDLE_AXIS_DIM(3, 2); HANDLE_AXIS_DIM(3, 2);
HANDLE_AXIS_DIM(4, 2); HANDLE_AXIS_DIM(4, 2);
HANDLE_AXIS_DIM(4, 3); HANDLE_AXIS_DIM(4, 3);
HANDLE_AXIS_DIM(5, 2); // comments for accelerating compiling temporarily.
HANDLE_AXIS_DIM(5, 3); // HANDLE_AXIS_DIM(5, 2);
HANDLE_AXIS_DIM(5, 4); // HANDLE_AXIS_DIM(5, 3);
HANDLE_AXIS_DIM(6, 2); // HANDLE_AXIS_DIM(5, 4);
HANDLE_AXIS_DIM(6, 3); // HANDLE_AXIS_DIM(6, 2);
HANDLE_AXIS_DIM(6, 4); // HANDLE_AXIS_DIM(6, 3);
HANDLE_AXIS_DIM(6, 5); // HANDLE_AXIS_DIM(6, 4);
// HANDLE_AXIS_DIM(6, 5);
} }
}; };
......
...@@ -1775,7 +1775,8 @@ def amax(x, axis=None, keepdim=False, name=None): ...@@ -1775,7 +1775,8 @@ def amax(x, axis=None, keepdim=False, name=None):
while max propagates gradient to all of them. while max propagates gradient to all of them.
Args: Args:
x(Tensor): A tensor, the data type is float32, float64, int32, int64. x(Tensor): A tensor, the data type is float32, float64, int32, int64,
the dimension is no more than 4.
axis(int|list|tuple, optional): The axis along which the maximum is computed. axis(int|list|tuple, optional): The axis along which the maximum is computed.
If :attr:`None`, compute the maximum over all elements of If :attr:`None`, compute the maximum over all elements of
`x` and return a Tensor with a single element, `x` and return a Tensor with a single element,
...@@ -1887,7 +1888,8 @@ def amin(x, axis=None, keepdim=False, name=None): ...@@ -1887,7 +1888,8 @@ def amin(x, axis=None, keepdim=False, name=None):
while min propagates gradient to all of them. while min propagates gradient to all of them.
Args: Args:
x(Tensor): A tensor, the data type is float32, float64, int32, int64. x(Tensor): A tensor, the data type is float32, float64, int32, int64,
the dimension is no more than 4.
axis(int|list|tuple, optional): The axis along which the minimum is computed. axis(int|list|tuple, optional): The axis along which the minimum is computed.
If :attr:`None`, compute the minimum over all elements of If :attr:`None`, compute the minimum over all elements of
`x` and return a Tensor with a single element, `x` and return a Tensor with a single element,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册