Unverified · Commit d8314ff5 · authored by Lin Manhui, committed by GitHub

[Fix] Fix paddle.pow() Gets Incorrect Result When Broadcasting Is Triggered (#47307)

* Fix paddle.pow() bugs

* Add unittest cases

* Fix ut cases

* Add ut cases on multiple devices
Parent 2534ca7e
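
The failure mode this commit targets shows up when paddle.pow broadcasts a base of lower rank against a higher-rank exponent: the elementwise kernel could then hand the functor its operands in swapped order, effectively computing y ** x instead of x ** y. A minimal repro sketch (the concrete values are illustrative, not taken from the PR):

import numpy as np
import paddle

x = paddle.to_tensor(np.full((3,), 2.0))    # base, rank 1
y = paddle.to_tensor(np.full((2, 3), 3.0))  # exponent, rank 2 -> broadcasting triggered

out = paddle.pow(x, y)
# With this fix every element is 2 ** 3 = 8; the broken kernel yielded 3 ** 2 = 9.
np.testing.assert_allclose(out.numpy(), np.full((2, 3), 8.0))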
@@ -91,8 +91,15 @@ void ElementwisePowRawKernel(const Context& dev_ctx,
                              DenseTensor* out) {
   // allocate memory for out
   dev_ctx.template Alloc<T>(out);
-  funcs::ElementwiseCompute<funcs::ElementwisePowFunctor<T>, T>(
-      dev_ctx, x, y, axis, funcs::ElementwisePowFunctor<T>(), out);
+  auto x_dims = x.dims();
+  auto y_dims = y.dims();
+  if (x_dims.size() >= y_dims.size()) {
+    funcs::ElementwiseCompute<funcs::ElementwisePowFunctor<T>, T>(
+        dev_ctx, x, y, axis, funcs::ElementwisePowFunctor<T>(), out);
+  } else {
+    funcs::ElementwiseCompute<funcs::ElementwiseInversePowFunctor<T>, T>(
+        dev_ctx, x, y, axis, funcs::ElementwiseInversePowFunctor<T>(), out);
+  }
 }

 template <typename T, typename Context>
...
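
The branch above reflects how funcs::ElementwiseCompute handles rank-mismatched inputs: the higher-rank operand is processed first, so when y outranks x the functor receives its two arguments swapped, and ElementwiseInversePowFunctor compensates by computing pow(b, a). A rough numpy sketch of the same dispatch (broadcast_apply and elementwise_pow are hypothetical stand-ins, not Paddle APIs):

import numpy as np

def broadcast_apply(hi, lo, fn):
    # Stand-in for funcs::ElementwiseCompute: broadcast the lower-rank
    # operand to the higher-rank one's shape, then apply fn elementwise
    # with the higher-rank operand as the first argument.
    return fn(hi, np.broadcast_to(lo, hi.shape))

def elementwise_pow(x, y):
    if x.ndim >= y.ndim:
        # ElementwisePowFunctor: arguments arrive in natural order.
        return broadcast_apply(x, y, lambda a, b: a ** b)
    # ElementwiseInversePowFunctor: arguments arrive swapped (y first),
    # so restore x ** y by computing pow(b, a).
    return broadcast_apply(y, x, lambda a, b: b ** a)

x = np.full((3,), 2.0)
y = np.full((2, 3), 3.0)
assert np.allclose(elementwise_pow(x, y), np.power(x, y))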
@@ -606,6 +606,28 @@ struct ElementwisePowFunctor {
   }
 };

+template <typename T>
+struct ElementwiseInversePowFunctor {
+  inline HOSTDEVICE T operator()(const T a, const T b) const {
+// TODO(wujionghao): A potential speed improvement is supporting different
+// types in C++.
+#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
+    // On CUDAPlace, std::pow(3, 1) calls pow(float, float) and returns a
+    // float such as 2.99..., which floors to 2 when cast to int by
+    // default, giving a wrong result.
+    // Use llrint to round to the nearest integer, which is 3.
+    if (std::is_integral<T>::value) {
+      return std::llrint(
+          std::pow(static_cast<double>(b), static_cast<double>(a)));
+    }
+#endif
+#ifdef PADDLE_WITH_XPU_KP
+    return pow(b, a);
+#endif
+    return std::pow(b, a);
+  }
+};
+
 template <>
 struct ElementwisePowFunctor<dtype::float16> {
   inline HOSTDEVICE dtype::float16 operator()(const dtype::float16 a,
@@ -616,5 +638,15 @@ struct ElementwisePowFunctor<dtype::float16> {
   }
 };

+template <>
+struct ElementwiseInversePowFunctor<dtype::float16> {
+  inline HOSTDEVICE dtype::float16 operator()(const dtype::float16 a,
+                                              const dtype::float16 b) const {
+    float f_a = static_cast<float>(a);
+    float f_b = static_cast<float>(b);
+    return static_cast<dtype::float16>(std::pow(f_b, f_a));
+  }
+};
+
 }  // namespace funcs
 }  // namespace phi
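
The std::llrint branch in the new functor mirrors the existing ElementwisePowFunctor: device pow is evaluated in floating point, so an integral power such as 3 ** 1 can come back as 2.99...; truncating that to an integer yields 2, while rounding to the nearest integer recovers 3. The effect is easy to mimic in Python:

# A float pow result can land just below the true integral value.
approx = 2.9999999999999996  # e.g. what pow(float, float) might return for 3 ** 1

print(int(approx))    # truncation, like a plain C++ cast -> 2 (wrong)
print(round(approx))  # round-to-nearest, like std::llrint -> 3 (correct)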
@@ -18,15 +18,18 @@ import numpy as np

 import paddle
 from paddle.static import Program, program_guard
+import paddle.fluid.core as core

 DYNAMIC = 1
 STATIC = 2


-def _run_power(mode, x, y):
+def _run_power(mode, x, y, device='cpu'):
     # dynamic mode
     if mode == DYNAMIC:
         paddle.disable_static()
+        # Set device
+        paddle.set_device(device)
         # y is scalar
         if isinstance(y, (int, float)):
             x_ = paddle.to_tensor(x)
@@ -48,7 +51,11 @@ def _run_power(mode, x, y):
                 x_ = paddle.static.data(name="x", shape=x.shape, dtype=x.dtype)
                 y_ = y
                 res = paddle.pow(x_, y_)
-                place = paddle.CPUPlace()
+                place = (
+                    paddle.CPUPlace()
+                    if device == 'cpu'
+                    else paddle.CUDAPlace(0)
+                )
                 exe = paddle.static.Executor(place)
                 outs = exe.run(feed={'x': x}, fetch_list=[res])
                 return outs[0]
@@ -58,7 +65,11 @@ def _run_power(mode, x, y):
                 x_ = paddle.static.data(name="x", shape=x.shape, dtype=x.dtype)
                 y_ = paddle.static.data(name="y", shape=y.shape, dtype=y.dtype)
                 res = paddle.pow(x_, y_)
-                place = paddle.CPUPlace()
+                place = (
+                    paddle.CPUPlace()
+                    if device == 'cpu'
+                    else paddle.CUDAPlace(0)
+                )
                 exe = paddle.static.Executor(place)
                 outs = exe.run(feed={'x': x, 'y': y}, fetch_list=[res])
                 return outs[0]
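
The device plumbing added to _run_power follows the usual Paddle split: dynamic mode selects the device with paddle.set_device, while static mode needs an explicit place for the Executor. A condensed, self-contained sketch of the static path (run_static_pow is a hypothetical helper; it assumes x and y are same-shaped numpy arrays):

import numpy as np
import paddle

def run_static_pow(x, y, device='cpu'):
    paddle.enable_static()
    with paddle.static.program_guard(paddle.static.Program()):
        x_ = paddle.static.data(name="x", shape=x.shape, dtype=x.dtype)
        y_ = paddle.static.data(name="y", shape=y.shape, dtype=y.dtype)
        res = paddle.pow(x_, y_)
        # Pick the execution place the same way the test does.
        place = paddle.CPUPlace() if device == 'cpu' else paddle.CUDAPlace(0)
        exe = paddle.static.Executor(place)
        (out,) = exe.run(feed={'x': x, 'y': y}, fetch_list=[res])
    paddle.disable_static()
    return out

x = np.random.rand(4).astype(np.float64)
y = np.random.rand(4).astype(np.float64)
np.testing.assert_allclose(run_static_pow(x, y), np.power(x, y), rtol=1e-05)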
@@ -67,82 +78,104 @@ def _run_power(mode, x, y):
 class TestPowerAPI(unittest.TestCase):
     """TestPowerAPI."""

+    def setUp(self):
+        self.places = ['cpu']
+        if core.is_compiled_with_cuda():
+            self.places.append('gpu')
+
     def test_power(self):
         """test_power."""
         np.random.seed(7)
-        # test 1-d float tensor ** float scalar
-        dims = (np.random.randint(200, 300),)
-        x = (np.random.rand(*dims) * 10).astype(np.float64)
-        y = np.random.rand() * 10
-        res = _run_power(DYNAMIC, x, y)
-        np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
-        res = _run_power(STATIC, x, y)
-        np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
-
-        # test 1-d float tensor ** int scalar
-        dims = (np.random.randint(200, 300),)
-        x = (np.random.rand(*dims) * 10).astype(np.float64)
-        y = int(np.random.rand() * 10)
-        res = _run_power(DYNAMIC, x, y)
-        np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
-        res = _run_power(STATIC, x, y)
-        np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
-
-        x = (np.random.rand(*dims) * 10).astype(np.int64)
-        y = int(np.random.rand() * 10)
-        res = _run_power(DYNAMIC, x, y)
-        np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
-        res = _run_power(STATIC, x, y)
-        np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
-
-        # test 1-d float tensor ** 1-d float tensor
-        dims = (np.random.randint(200, 300),)
-        x = (np.random.rand(*dims) * 10).astype(np.float64)
-        y = (np.random.rand(*dims) * 10).astype(np.float64)
-        res = _run_power(DYNAMIC, x, y)
-        np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
-        res = _run_power(STATIC, x, y)
-        np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
-
-        # test 1-d int tensor ** 1-d int tensor
-        dims = (np.random.randint(200, 300),)
-        x = (np.random.rand(*dims) * 10).astype(np.int64)
-        y = (np.random.rand(*dims) * 10).astype(np.int64)
-        res = _run_power(DYNAMIC, x, y)
-        np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
-        res = _run_power(STATIC, x, y)
-        np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
-
-        # test 1-d int tensor ** 1-d int tensor
-        dims = (np.random.randint(200, 300),)
-        x = (np.random.rand(*dims) * 10).astype(np.int32)
-        y = (np.random.rand(*dims) * 10).astype(np.int32)
-        res = _run_power(DYNAMIC, x, y)
-        np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
-        res = _run_power(STATIC, x, y)
-        np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
-
-        # test 1-d int tensor ** 1-d int tensor
-        dims = (np.random.randint(200, 300),)
-        x = (np.random.rand(*dims) * 10).astype(np.float32)
-        y = (np.random.rand(*dims) * 10).astype(np.float32)
-        res = _run_power(DYNAMIC, x, y)
-        np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
-        res = _run_power(STATIC, x, y)
-        np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
-
-        # test broadcast
-        dims = (
-            np.random.randint(1, 10),
-            np.random.randint(5, 10),
-            np.random.randint(5, 10),
-        )
-        x = (np.random.rand(*dims) * 10).astype(np.float64)
-        y = (np.random.rand(dims[-1]) * 10).astype(np.float64)
-        res = _run_power(DYNAMIC, x, y)
-        np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
-        res = _run_power(STATIC, x, y)
-        np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
+        for place in self.places:
+            # test 1-d float tensor ** float scalar
+            dims = (np.random.randint(200, 300),)
+            x = (np.random.rand(*dims) * 10).astype(np.float64)
+            y = np.random.rand() * 10
+            res = _run_power(DYNAMIC, x, y, place)
+            np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
+            res = _run_power(STATIC, x, y, place)
+            np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
+
+            # test 1-d float tensor ** int scalar
+            dims = (np.random.randint(200, 300),)
+            x = (np.random.rand(*dims) * 10).astype(np.float64)
+            y = int(np.random.rand() * 10)
+            res = _run_power(DYNAMIC, x, y, place)
+            np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
+            res = _run_power(STATIC, x, y, place)
+            np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
+
+            x = (np.random.rand(*dims) * 10).astype(np.int64)
+            y = int(np.random.rand() * 10)
+            res = _run_power(DYNAMIC, x, y, place)
+            np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
+            res = _run_power(STATIC, x, y, place)
+            np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
+
+            # test 1-d float tensor ** 1-d float tensor
+            dims = (np.random.randint(200, 300),)
+            x = (np.random.rand(*dims) * 10).astype(np.float64)
+            y = (np.random.rand(*dims) * 10).astype(np.float64)
+            res = _run_power(DYNAMIC, x, y, place)
+            np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
+            res = _run_power(STATIC, x, y, place)
+            np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
+
+            # test 1-d int tensor ** 1-d int tensor
+            dims = (np.random.randint(200, 300),)
+            x = (np.random.rand(*dims) * 10).astype(np.int64)
+            y = (np.random.rand(*dims) * 10).astype(np.int64)
+            res = _run_power(DYNAMIC, x, y, place)
+            np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
+            res = _run_power(STATIC, x, y, place)
+            np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
+
+            # test 1-d int tensor ** 1-d int tensor
+            dims = (np.random.randint(200, 300),)
+            x = (np.random.rand(*dims) * 10).astype(np.int32)
+            y = (np.random.rand(*dims) * 10).astype(np.int32)
+            res = _run_power(DYNAMIC, x, y, place)
+            np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
+            res = _run_power(STATIC, x, y, place)
+            np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
+
+            # test 1-d int tensor ** 1-d int tensor
+            dims = (np.random.randint(200, 300),)
+            x = (np.random.rand(*dims) * 10).astype(np.float32)
+            y = (np.random.rand(*dims) * 10).astype(np.float32)
+            res = _run_power(DYNAMIC, x, y, place)
+            np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
+            res = _run_power(STATIC, x, y, place)
+            np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
+
+            # test float scalar ** 2-d float tensor
+            dims = (np.random.randint(2, 10), np.random.randint(5, 10))
+            x = np.random.rand() * 10
+            y = (np.random.rand(*dims) * 10).astype(np.float32)
+            res = _run_power(DYNAMIC, x, y, place)
+            np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
+
+            # test 2-d float tensor ** float scalar
+            dims = (np.random.randint(2, 10), np.random.randint(5, 10))
+            x = (np.random.rand(*dims) * 10).astype(np.float32)
+            y = np.random.rand() * 10
+            res = _run_power(DYNAMIC, x, y, place)
+            np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
+            res = _run_power(STATIC, x, y, place)
+            np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
+
+            # test broadcast
+            dims = (
+                np.random.randint(1, 10),
+                np.random.randint(5, 10),
+                np.random.randint(5, 10),
+            )
+            x = (np.random.rand(*dims) * 10).astype(np.float64)
+            y = (np.random.rand(dims[-1]) * 10).astype(np.float64)
+            res = _run_power(DYNAMIC, x, y)
+            np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)
+            res = _run_power(STATIC, x, y)
+            np.testing.assert_allclose(res, np.power(x, y), rtol=1e-05)


 class TestPowerError(unittest.TestCase):
...