未验证 提交 14625ffe 编写于 作者: K Kaipeng Deng 提交者: GitHub

add elementwise mod support float/double. test=develop (#19570)

上级 5b07ca9c
......@@ -33,4 +33,6 @@ REGISTER_OP_WITHOUT_GRADIENT(elementwise_mod, ops::ElementwiseOp,
REGISTER_OP_CPU_KERNEL(
elementwise_mod,
ops::ElementwiseModKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseModKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::ElementwiseModKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseModFPKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseModFPKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -19,4 +19,6 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
elementwise_mod, ops::ElementwiseModKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseModKernel<plat::CUDADeviceContext, int64_t>);
ops::ElementwiseModKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseModFPKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseModFPKernel<plat::CUDADeviceContext, double>);
......@@ -27,6 +27,11 @@ struct ModFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a % b; }
};
template <typename T>
struct ModFunctorFP {
inline HOSTDEVICE T operator()(T a, T b) const { return std::fmod(a, b); }
};
template <typename DeviceContext, typename T>
void elementwise_mod(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y,
......@@ -36,6 +41,15 @@ void elementwise_mod(const framework::ExecutionContext &ctx,
ModFunctor<T>(), z);
}
template <typename DeviceContext, typename T>
void elementwise_mod_fp(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor *z) {
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<ModFunctorFP<T>, DeviceContext, T>(ctx, x, y, axis,
ModFunctorFP<T>(), z);
}
template <typename DeviceContext, typename T>
class ElementwiseModKernel : public framework::OpKernel<T> {
public:
......@@ -51,5 +65,20 @@ class ElementwiseModKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class ElementwiseModFPKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *x = ctx.Input<framework::LoDTensor>("X");
auto *y = ctx.Input<framework::LoDTensor>("Y");
auto *z = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
// dtype of x and y is float or double
elementwise_mod_fp<DeviceContext, T>(ctx, x, y, z);
}
};
} // namespace operators
} // namespace paddle
......@@ -27,7 +27,6 @@ class TestElementwiseModOp(OpTest):
def setUp(self):
self.op_type = "elementwise_mod"
self.dtype = np.int32
self.axis = -1
self.init_dtype()
self.init_input_output()
......@@ -50,7 +49,7 @@ class TestElementwiseModOp(OpTest):
self.out = np.mod(self.x, self.y)
def init_dtype(self):
pass
self.dtype = np.int32
def init_axis(self):
pass
......@@ -65,5 +64,23 @@ class TestElementwiseModOp_scalar(TestElementwiseModOp):
self.out = np.mod(self.x, self.y)
class TestElementwiseModOpFloat(TestElementwiseModOp):
def init_dtype(self):
self.dtype = np.float32
def init_input_output(self):
self.x = np.random.uniform(-1000, 1000, [10, 10]).astype(self.dtype)
self.y = np.random.uniform(-100, 100, [10, 10]).astype(self.dtype)
self.out = np.fmod(self.x, self.y)
def test_check_output(self):
self.check_output(atol=2e-5)
class TestElementwiseModOpDouble(TestElementwiseModOpFloat):
def init_dtype(self):
self.dtype = np.float64
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册