未验证 提交 b04c78ef 编写于 作者: J joejiong 提交者: GitHub

Update pow (#29000)

Simple code clean up
上级 9479961d
...@@ -23,20 +23,17 @@ namespace operators { ...@@ -23,20 +23,17 @@ namespace operators {
template <typename T> template <typename T>
struct PowFunctor { struct PowFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { inline HOSTDEVICE T operator()(T a, T b) const {
// TODO(wujionghao): A potential speed improvement is supporting different // TODO(wujionghao): A potential speed improvement is supporting different
// types in C++. // types in C++.
// #ifdef __CUDA_ARCH__ #ifdef __CUDA_ARCH__
// // On CUDAPlace, std::pow(3, 1) calls pow(float, float), and // On CUDAPlace, std::pow(3, 1) calls pow(float, float), and
// // it will return a float number like 2.99... , which floor to 2 // it will return a float number like 2.99... , which floor to 2
// // when cast to int by default and it is wrong. // when cast to int by default and it is wrong.
// // Use llrint to cast it to the nearest integer, which is 3. // Use llrint to cast it to the nearest integer, which is 3.
// if (std::is_integral<T>::value) {
// return std::llrint(std::pow(a, b));
// }
// #endif
if (std::is_integral<T>::value) { if (std::is_integral<T>::value) {
return std::llrint(std::pow(a, b)); return std::llrint(std::pow(a, b));
} }
#endif
return std::pow(a, b); return std::pow(a, b);
} }
}; };
......
...@@ -172,12 +172,12 @@ def pow(x, y, name=None): ...@@ -172,12 +172,12 @@ def pow(x, y, name=None):
x = paddle.to_tensor([1, 2, 3]) x = paddle.to_tensor([1, 2, 3])
y = 2 y = 2
res = paddle.pow(x, y) res = paddle.pow(x, y)
print(res.numpy()) # [1 4 9] print(res) # [1 4 9]
# example 2: y is a Tensor # example 2: y is a Tensor
y = paddle.full(shape=[1], fill_value=2, dtype='float32') y = paddle.full(shape=[1], fill_value=2, dtype='float32')
res = paddle.pow(x, y) res = paddle.pow(x, y)
print(res.numpy()) # [1 4 9] print(res) # [1 4 9]
""" """
# in dynamic graph mode # in dynamic graph mode
...@@ -185,14 +185,9 @@ def pow(x, y, name=None): ...@@ -185,14 +185,9 @@ def pow(x, y, name=None):
if isinstance(y, (int, float)): if isinstance(y, (int, float)):
return core.ops.pow(x, 'factor', y) return core.ops.pow(x, 'factor', y)
elif isinstance(y, (paddle.Tensor, Variable)): elif isinstance(y, (paddle.Tensor, Variable)):
if x.dtype != y.dtype: if x.dtype != y.dtype:
y = cast(y, dtype='float64') y = cast(y, dtype='float64')
x = cast(x, dtype='float64') x = cast(x, dtype='float64')
out_dygraph = _elementwise_op_in_dygraph(
x, y, axis=-1, act=None, op_name='elementwise_pow')
return out_dygraph
return _elementwise_op_in_dygraph( return _elementwise_op_in_dygraph(
x, y, axis=-1, act=None, op_name='elementwise_pow') x, y, axis=-1, act=None, op_name='elementwise_pow')
else: else:
...@@ -213,9 +208,7 @@ def pow(x, y, name=None): ...@@ -213,9 +208,7 @@ def pow(x, y, name=None):
if x.dtype != y.dtype: if x.dtype != y.dtype:
y = cast(y, dtype='float64') y = cast(y, dtype='float64')
x = cast(x, dtype='float64') x = cast(x, dtype='float64')
out = helper.create_variable_for_type_inference(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
else:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
return _elementwise_op(LayerHelper('elementwise_pow', **locals())) return _elementwise_op(LayerHelper('elementwise_pow', **locals()))
else: else:
raise TypeError('y must be scalar or tensor type, but received: %s '% (type(y))) raise TypeError('y must be scalar or tensor type, but received: %s '% (type(y)))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册