未验证 提交 98b59cb8 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix elementwise_mod float point kernel. test=develop (#21183)

上级 835119c7
...@@ -29,7 +29,9 @@ struct ModFunctor { ...@@ -29,7 +29,9 @@ struct ModFunctor {
template <typename T> template <typename T>
struct ModFunctorFP { struct ModFunctorFP {
inline HOSTDEVICE T operator()(T a, T b) const { return std::fmod(a, b); } inline HOSTDEVICE T operator()(T a, T b) const {
return fmod(b + fmod(a, b), b);
}
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
......
...@@ -71,7 +71,7 @@ class TestElementwiseModOpFloat(TestElementwiseModOp): ...@@ -71,7 +71,7 @@ class TestElementwiseModOpFloat(TestElementwiseModOp):
def init_input_output(self): def init_input_output(self):
self.x = np.random.uniform(-1000, 1000, [10, 10]).astype(self.dtype) self.x = np.random.uniform(-1000, 1000, [10, 10]).astype(self.dtype)
self.y = np.random.uniform(-100, 100, [10, 10]).astype(self.dtype) self.y = np.random.uniform(-100, 100, [10, 10]).astype(self.dtype)
self.out = np.fmod(self.x, self.y) self.out = np.fmod(self.y + np.fmod(self.x, self.y), self.y)
def test_check_output(self): def test_check_output(self):
self.check_output(atol=2e-5) self.check_output(atol=2e-5)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册