未验证 提交 b04c78ef 编写于 作者: J joejiong 提交者: GitHub

Update pow (#29000)

Simple code clean up
上级 9479961d
......@@ -23,20 +23,17 @@ namespace operators {
template <typename T>
struct PowFunctor {
inline HOSTDEVICE T operator()(T a, T b) const {
// TODO(wujionghao): A potential speed improvement is supporting different
// types in C++.
// #ifdef __CUDA_ARCH__
// // On CUDAPlace, std::pow(3, 1) calls pow(float, float), and
// // it will return a float number like 2.99... , which floor to 2
// // when cast to int by default and it is wrong.
// // Use llrint to cast it to the nearest integer, which is 3.
// if (std::is_integral<T>::value) {
// return std::llrint(std::pow(a, b));
// }
// #endif
// TODO(wujionghao): A potential speed improvement is supporting different
// types in C++.
#ifdef __CUDA_ARCH__
// On CUDAPlace, std::pow(3, 1) calls pow(float, float), and
// it will return a float number like 2.99... , which floor to 2
// when cast to int by default and it is wrong.
// Use llrint to cast it to the nearest integer, which is 3.
if (std::is_integral<T>::value) {
return std::llrint(std::pow(a, b));
}
#endif
return std::pow(a, b);
}
};
......
......@@ -172,12 +172,12 @@ def pow(x, y, name=None):
x = paddle.to_tensor([1, 2, 3])
y = 2
res = paddle.pow(x, y)
print(res.numpy()) # [1 4 9]
print(res) # [1 4 9]
# example 2: y is a Tensor
y = paddle.full(shape=[1], fill_value=2, dtype='float32')
res = paddle.pow(x, y)
print(res.numpy()) # [1 4 9]
print(res) # [1 4 9]
"""
# in dynamic graph mode
......@@ -185,14 +185,9 @@ def pow(x, y, name=None):
if isinstance(y, (int, float)):
return core.ops.pow(x, 'factor', y)
elif isinstance(y, (paddle.Tensor, Variable)):
if x.dtype != y.dtype:
y = cast(y, dtype='float64')
x = cast(x, dtype='float64')
out_dygraph = _elementwise_op_in_dygraph(
x, y, axis=-1, act=None, op_name='elementwise_pow')
return out_dygraph
return _elementwise_op_in_dygraph(
x, y, axis=-1, act=None, op_name='elementwise_pow')
else:
......@@ -213,9 +208,7 @@ def pow(x, y, name=None):
if x.dtype != y.dtype:
y = cast(y, dtype='float64')
x = cast(x, dtype='float64')
out = helper.create_variable_for_type_inference(dtype=x.dtype)
else:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
out = helper.create_variable_for_type_inference(dtype=x.dtype)
return _elementwise_op(LayerHelper('elementwise_pow', **locals()))
else:
raise TypeError('y must be scalar or tensor type, but received: %s '% (type(y)))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册