未验证 提交 ff3dc8ac 编写于 作者: S ShenLiang 提交者: GitHub

fix the remainder (#26995)

上级 352ac149
...@@ -31,6 +31,15 @@ struct ModFunctor { ...@@ -31,6 +31,15 @@ struct ModFunctor {
} }
}; };
template <typename T>
struct InverseModFunctor {
inline HOSTDEVICE T operator()(T a, T b) const {
T res = b % a;
if ((res != 0) && ((res < 0) != (a < 0))) res += a;
return res;
}
};
template <typename T> template <typename T>
struct ModFunctorFP { struct ModFunctorFP {
inline HOSTDEVICE T operator()(T a, T b) const { inline HOSTDEVICE T operator()(T a, T b) const {
...@@ -40,13 +49,29 @@ struct ModFunctorFP { ...@@ -40,13 +49,29 @@ struct ModFunctorFP {
} }
}; };
template <typename T>
struct InverseModFunctorFP {
inline HOSTDEVICE T operator()(T a, T b) const {
T res = fmod(b, a);
if ((res != 0) && ((a < 0) != (res < 0))) res += a;
return res;
}
};
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
void elementwise_mod(const framework::ExecutionContext &ctx, void elementwise_mod(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y, const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor *z) { framework::Tensor *z) {
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
auto x_dims = x->dims();
auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<ModFunctor<T>, DeviceContext, T>(ctx, x, y, axis, ElementwiseComputeEx<ModFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
ModFunctor<T>(), z); ModFunctor<T>(), z);
} else {
ElementwiseComputeEx<InverseModFunctor<T>, DeviceContext, T>(
ctx, x, y, axis, InverseModFunctor<T>(), z);
}
} }
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
...@@ -54,8 +79,15 @@ void elementwise_mod_fp(const framework::ExecutionContext &ctx, ...@@ -54,8 +79,15 @@ void elementwise_mod_fp(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y, const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor *z) { framework::Tensor *z) {
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<ModFunctorFP<T>, DeviceContext, T>(ctx, x, y, axis, auto x_dims = x->dims();
ModFunctorFP<T>(), z); auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<ModFunctorFP<T>, DeviceContext, T>(
ctx, x, y, axis, ModFunctorFP<T>(), z);
} else {
ElementwiseComputeEx<InverseModFunctorFP<T>, DeviceContext, T>(
ctx, x, y, axis, InverseModFunctorFP<T>(), z);
}
} }
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
......
...@@ -220,6 +220,14 @@ class TestRemainderAPI(unittest.TestCase): ...@@ -220,6 +220,14 @@ class TestRemainderAPI(unittest.TestCase):
z_expected = np.array([0, 1, 1, -1]) z_expected = np.array([0, 1, 1, -1])
self.assertEqual(np.allclose(z_expected, z.numpy()), True) self.assertEqual(np.allclose(z_expected, z.numpy()), True)
np_x = np.array([-3, 3])
np_y = np.array([[2, 3], [-2, -1]])
x = paddle.to_tensor(np_x, dtype="int64")
y = paddle.to_tensor(np_y, dtype="int64")
z = x % y
z_expected = np.array([[1, 0], [-1, 0]])
self.assertEqual(np.allclose(z_expected, z.numpy()), True)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册