diff --git a/oneflow/core/kernel/broadcast_floor_mod_kernel.cpp b/oneflow/core/kernel/broadcast_floor_mod_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..150a8bc6c885d645820f4a97bfb0d060880db424
--- /dev/null
+++ b/oneflow/core/kernel/broadcast_floor_mod_kernel.cpp
@@ -0,0 +1,33 @@
+#include "oneflow/core/kernel/kernel.h"
+#include "oneflow/core/ndarray/ndarray_util.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+class BroadcastFloorModKernel final : public KernelIf<device_type> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(BroadcastFloorModKernel);
+  BroadcastFloorModKernel() = default;
+  ~BroadcastFloorModKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx&,
+                          std::function<Blob*(const std::string&)>) const override;
+};
+
+template<DeviceType device_type, typename T>
+void BroadcastFloorModKernel<device_type, T>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* a = BnInOp2Blob("a");
+  const Blob* b = BnInOp2Blob("b");
+  Blob* out = BnInOp2Blob("out");
+
+  size_t num_axes = out->shape().NumAxes();
+  NdarrayUtil<device_type, T>::BroadcastFloorMod(ctx.device_ctx, XpuVarNdarray<T>(out, num_axes),
+                                                 XpuVarNdarray<const T>(a, num_axes),
+                                                 XpuVarNdarray<const T>(b, num_axes));
+}
+
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kBroadcastFloorModConf, BroadcastFloorModKernel,
+                           ARITHMETIC_DATA_TYPE_SEQ);
+}  // namespace oneflow
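For reference, the kernel above computes a C-style remainder under numpy-like broadcasting. Assuming OneFlow's XpuVarNdarray broadcast rules match numpy's (the Python tests below rely on exactly that), a minimal host-side sketch of the intended semantics is:

# Reference sketch only: a numpy model of what BroadcastFloorMod computes.
# broadcast_floor_mod_ref is a name introduced here, not OneFlow API.
import numpy as np

def broadcast_floor_mod_ref(a, b):
    # Shapes are first broadcast against each other, then the binary functor
    # is applied elementwise; np.fmod mirrors the C/CUDA fmod (truncated
    # remainder) used by BinaryFuncFloorMod below.
    out_shape = np.broadcast(a, b).shape
    return np.fmod(np.broadcast_to(a, out_shape), np.broadcast_to(b, out_shape))

a = np.arange(6.0).reshape(3, 2)  # shape (3, 2)
b = np.array([[2.0, 3.0]])        # shape (1, 2), broadcast along rows
print(broadcast_floor_mod_ref(a, b))  # [[0. 1.] [0. 0.] [0. 2.]]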
diff --git a/oneflow/core/ndarray/binary_func.h b/oneflow/core/ndarray/binary_func.h
index 1720f9c96dbdbb96d6e597ab9d9b57ee6827bdb1..26e55fe38bbbff4b4b9eaf361153223baa5600fe
--- a/oneflow/core/ndarray/binary_func.h
+++ b/oneflow/core/ndarray/binary_func.h
@@ -12,7 +12,7 @@
 #include "oneflow/core/common/util.h"
 
 namespace oneflow {
 
-#define ARITHMETIC_BINARY_FUNC_NAME_SEQ (Add)(Sub)(Mul)(Div)(Min)(Max)
+#define ARITHMETIC_BINARY_FUNC_NAME_SEQ (Add)(Sub)(Mul)(Div)(Min)(Max)(FloorMod)
 #define LOGICAL_BINARY_FUNC_NAME_SEQ (EQ)(NE)(GT)(GE)(LT)(LE)(AND)
 #define PREPEND_PREFIX_BINARY_FUNC(name) OF_PP_CAT(BinaryFunc, name)
@@ -73,6 +73,12 @@ struct BinaryFuncDiv final {
 };
 SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncDiv);
 
+template<typename T>
+struct BinaryFuncFloorMod final {
+  static OF_DEVICE_FUNC const T Invoke(const T x, const T y) { return x % y; }
+};
+SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncFloorMod);
+
 template<typename T>
 struct BinaryFuncMax final {
   static OF_DEVICE_FUNC const T Invoke(const T x, const T y) { return x > y ? x : y; }
@@ -198,6 +204,58 @@ struct BinaryFuncMin final {
 
 #endif  // defined(__CUDACC__)
 
+#if defined(__CUDACC__)
+
+template<>
+struct BinaryFuncFloorMod<float> final {
+  static __device__ __forceinline__ const float Invoke(const float x, const float y) {
+    return fmodf(x, y);
+  }
+};
+
+template<>
+struct BinaryFuncFloorMod<double> final {
+  static __device__ __forceinline__ const double Invoke(const double x, const double y) {
+    return fmod(x, y);
+  }
+};
+
+template<>
+struct BinaryFuncFloorMod<half> final {
+  static __device__ __forceinline__ const half Invoke(const half x, const half y) {
+#if __CUDA_ARCH__ >= 530
+    return __float2half(fmodf(__half2float(x), __half2float(y)));
+#else
+    NO_HALF_UTIL_FOUND;
+#endif
+  }
+};
+
+#else
+
+template<>
+struct BinaryFuncFloorMod<float> final {
+  static inline const float Invoke(const float x, const float y) {
+    return std::fmod(x, y);
+  }
+};
+
+template<>
+struct BinaryFuncFloorMod<double> final {
+  static inline const double Invoke(const double x, const double y) {
+    return std::fmod(x, y);
+  }
+};
+
+template<>
+struct BinaryFuncFloorMod<float16> final {
+  static inline const float16 Invoke(const float16 x, const float16 y) {
+    return static_cast<float16>(std::fmod(static_cast<float>(x), static_cast<float>(y)));
+  }
+};
+
+#endif  // defined(__CUDACC__)
+
 template<typename T, template<typename> class binary_func>
 struct UnitOfBinaryFunc;
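A caveat on the functor family above: despite the FloorMod name, integer `%` in C++ and `fmod`/`fmodf` both compute a truncated remainder (the result takes the sign of the dividend), while Python's and numpy's `%` compute a floored remainder (the sign of the divisor). The two conventions agree only when both operands are non-negative, which is why the tests below draw strictly positive inputs. A quick demonstration:

# The two remainder conventions only differ when operand signs are mixed.
import math

print(math.fmod(-7.0, 3.0))  # -1.0  (truncated: sign of the dividend, like C fmod)
print(-7.0 % 3.0)            #  2.0  (floored: sign of the divisor, like numpy %)
print(math.fmod(7.0, 3.0))   #  1.0  identical when both operands
print(7.0 % 3.0)             #  1.0  are non-negative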
diff --git a/oneflow/core/operator/broadcast_floor_mod_op.cpp b/oneflow/core/operator/broadcast_floor_mod_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..706cfcdb1a5f3bef64dd8692eccdf8989884300a
--- /dev/null
+++ b/oneflow/core/operator/broadcast_floor_mod_op.cpp
@@ -0,0 +1,32 @@
+#include "oneflow/core/operator/broadcast_binary_op.h"
+
+namespace oneflow {
+
+class BroadcastFloorModOp final : public BroadcastBinaryOp {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(BroadcastFloorModOp);
+  BroadcastFloorModOp() = default;
+  ~BroadcastFloorModOp() override = default;
+
+ private:
+  const PbMessage& GetCustomizedConf() const override;
+  Maybe<void> VirtualGetSbpSignatures(
+      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
+      SbpSignatureList* sbp_sig_list) const override;
+};
+
+const PbMessage& BroadcastFloorModOp::GetCustomizedConf() const {
+  return op_conf().broadcast_floor_mod_conf();
+}
+
+Maybe<void> BroadcastFloorModOp::VirtualGetSbpSignatures(
+    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
+    SbpSignatureList* sbp_sig_list) const {
+  SbpSignatureBuilder().PartialSum("a").Broadcast("b").PartialSum("out").Build(
+      sbp_sig_list->mutable_sbp_signature()->Add());
+  return Maybe<void>::Ok();
+}
+
+REGISTER_OP(OperatorConf::kBroadcastFloorModConf, BroadcastFloorModOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/op_conf.proto b/oneflow/core/operator/op_conf.proto
index 01992bbee566632d4e98d6e3f5f9dc53559c8d87..7b036fded7cce7ad759b0a9665ea192a3e747372
--- a/oneflow/core/operator/op_conf.proto
+++ b/oneflow/core/operator/op_conf.proto
@@ -1206,6 +1206,13 @@ message BroadcastDivOpConf {
   optional bool is_const = 4 [default = false];
 }
 
+message BroadcastFloorModOpConf {
+  required string a = 1;
+  required string b = 2;
+  required string out = 3;
+  optional bool is_const = 4 [default = false];
+}
+
 message BroadcastLikeOpConf {
   required string x = 1;
   required string like = 2;
@@ -1718,6 +1725,7 @@ message OperatorConf {
     BroadcastSubOpConf broadcast_sub_conf = 501;
     BroadcastMulOpConf broadcast_mul_conf = 502;
     BroadcastDivOpConf broadcast_div_conf = 503;
+    BroadcastFloorModOpConf broadcast_floor_mod_conf = 504;
    SquareOpConf square_conf = 513;
    SqrtOpConf sqrt_conf = 514;
    RsqrtOpConf rsqrt_conf = 515;
diff --git a/oneflow/python/framework/remote_blob.py b/oneflow/python/framework/remote_blob.py
index f9e51cb7cdb70adc7ff846328badbf914696c819..40e5d4142a9faebe69daa0ae9f44b9e0cedc3ce0
--- a/oneflow/python/framework/remote_blob.py
+++ b/oneflow/python/framework/remote_blob.py
@@ -42,11 +42,11 @@ class BlobDef(blob_desc.BlobDesc):
     @property
     def is_dynamic(self):
         raise NotImplementedError
-    
+
     @property
     def disable_boxing(self):
         raise NotImplementedError
-    
+
     @property
     def is_tensor_list(self):
         raise NotImplementedError
@@ -54,7 +54,7 @@ class BlobDef(blob_desc.BlobDesc):
     @property
     def disable_boxing(self):
         raise NotImplementedError
-    
+
     @property
     def parallel_conf(self):
         raise NotImplementedError
@@ -101,6 +101,9 @@ class BlobDef(blob_desc.BlobDesc):
     def __div__(self, rhs):
         return oneflow.math.divide(self, rhs)
 
+    def __mod__(self, rhs):
+        return oneflow.math.mod(self, rhs)
+
     def __eq__(self, rhs):
         return oneflow.math.equal(self, rhs)
 
@@ -145,11 +148,11 @@ class ConsistentBlob(BlobDef):
     @property
     def is_dynamic(self):
         return c_api_util.JobBuildAndInferCtx_IsDynamic(self.job_name_, self.lbn_)
-    
+
     @property
     def disable_boxing(self):
         return c_api_util.JobBuildAndInferCtx_DisableBoxing(self.job_name_, self.lbn_)
-    
+
     @property
     def is_tensor_list(self):
         return c_api_util.JobBuildAndInferCtx_IsTensorList(self.job_name_, self.lbn_)
@@ -157,7 +160,7 @@ class ConsistentBlob(BlobDef):
     @property
     def disable_boxing(self):
         return c_api_util.JobBuildAndInferCtx_DisableBoxing(self.job_name_, self.lbn_)
-    
+
     @property
     def parallel_conf(self):
         return c_api_util.JobBuildAndInferCtx_GetParallelConfFromProducerView(self.job_name_,
@@ -175,10 +178,10 @@ class MirroredBlob(BlobDef):
             consistent_blob = ConsistentBlob(sub_lbi, auto_watched_within_scope=False)
             self.sub_consistent_blob_list_.append(consistent_blob)
         watch_scope_util.TryWatchOnce(self)
-    
-    @property 
+
+    @property
     def sub_consistent_blob_list(self):
         return self.sub_consistent_blob_list_
-    
+
     @property
     def static_shape(self):
         return c_api_util.JobBuildAndInferCtx_MirroredBlobGetStaticShape(self.job_name_, self.lbn_)
@@ -199,11 +202,11 @@ class MirroredBlob(BlobDef):
     @property
     def is_dynamic(self):
         return c_api_util.JobBuildAndInferCtx_MirroredBlobIsDynamic(self.job_name_, self.lbn_)
-    
+
     @property
     def disable_boxing(self):
         return c_api_util.JobBuildAndInferCtx_MirroredBlobDisableBoxing(self.job_name_, self.lbn_)
-    
+
     @property
     def is_tensor_list(self):
         return c_api_util.JobBuildAndInferCtx_MirroredBlobIsTensorList(self.job_name_, self.lbn_)
@@ -211,7 +214,7 @@ class MirroredBlob(BlobDef):
     @property
     def disable_boxing(self):
         return c_api_util.JobBuildAndInferCtx_MirroredBlobDisableBoxing(self.job_name_, self.lbn_)
-    
+
     @property
     def parallel_conf(self):
         return c_api_util.JobBuildAndInferCtx_MirroredBlobGetParallelConfFromProducerView(
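With BlobDef.__mod__ defined, infix `%` on blobs becomes sugar for the exported math.mod op added in math_ops.py below. A minimal sketch (shapes are arbitrary):

import oneflow as flow

func_config = flow.FunctionConfig()
func_config.default_data_type(flow.float)

@flow.function(func_config)
def equivalent_jobs(a=flow.FixedTensorDef((3, 4)), b=flow.FixedTensorDef((3, 4))):
    # Both lines build the same broadcast_floor_mod op: the infix form
    # dispatches through BlobDef.__mod__ to oneflow.math.mod.
    via_operator = a % b
    via_export = flow.math.mod(a, b)
    return via_operator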
diff --git a/oneflow/python/ops/math_ops.py b/oneflow/python/ops/math_ops.py
index 3aaa9dc51afd0c227a2ee1da1406d3069d37ddab..4d7d0d3313525971202692916a419ea7981e5fbe
--- a/oneflow/python/ops/math_ops.py
+++ b/oneflow/python/ops/math_ops.py
@@ -59,6 +59,19 @@ def divide(x, y, name=None):
     return broadcast_div(x, y, name)
 
 
+@oneflow_export("math.mod")
+def floor_mod(x, y, name=None):
+    if isinstance(x, (int, float)):
+        raise NotImplementedError
+    elif isinstance(y, (int, float)):
+        raise NotImplementedError
+    elif x.static_shape == y.static_shape:
+        # TODO: add element-wise op
+        return broadcast_floor_mod(x, y, name)
+    else:
+        return broadcast_floor_mod(x, y, name)
+
+
 def scalar_add(x, operand, name=None):
     op_conf = op_conf_util.OperatorConf()
     setattr(
@@ -197,6 +210,23 @@ def broadcast_div(x, y, name=None):
     return remote_blob_util.RemoteBlob(lbi)
 
 
+def broadcast_floor_mod(x, y, name=None):
+    op_conf = op_conf_util.OperatorConf()
+    setattr(
+        op_conf,
+        "name",
+        name if name is not None else id_util.UniqueStr("BroadcastMod_"),
+    )
+    op_conf.broadcast_floor_mod_conf.a = x.logical_blob_name
+    op_conf.broadcast_floor_mod_conf.b = y.logical_blob_name
+    op_conf.broadcast_floor_mod_conf.out = "out"
+    compile_context.CurJobAddOp(op_conf)
+    lbi = logical_blob_id_util.LogicalBlobId()
+    lbi.op_name = op_conf.name
+    lbi.blob_name = "out"
+    return remote_blob_util.RemoteBlob(lbi)
+
+
 @oneflow_export("math.tanh")
 def tanh(x, name=None):
     op_conf = op_conf_util.OperatorConf()
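A short usage sketch of the new export, mirroring the style of the tests that follow (the shapes and the nonzero-divisor shift are illustrative choices, not part of the API):

import oneflow as flow
import numpy as np

func_config = flow.FunctionConfig()
func_config.default_data_type(flow.float)

@flow.function(func_config)
def mod_job(a=flow.FixedTensorDef((4, 4)), b=flow.FixedTensorDef((1, 4))):
    # Explicit call; `a % b` would dispatch to the same op via BlobDef.__mod__.
    return flow.math.mod(a, b)

x = np.random.rand(4, 4).astype(np.float32)
y = np.random.rand(1, 4).astype(np.float32) + 1.0  # keep divisors nonzero
print(mod_job(x, y).get().ndarray())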
diff --git a/oneflow/python/test/ops/test_mod.py b/oneflow/python/test/ops/test_mod.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa8e94c9514bdfbb5edcbba9d9ddee32c0b0afe0
--- /dev/null
+++ b/oneflow/python/test/ops/test_mod.py
@@ -0,0 +1,47 @@
+import oneflow as flow
+import numpy as np
+
+func_config = flow.FunctionConfig()
+func_config.default_data_type(flow.float)
+
+def test_naive(test_case):
+    @flow.function(func_config)
+    def ModJob(a=flow.FixedTensorDef((5, 2)), b=flow.FixedTensorDef((5, 2))):
+        return a % b
+
+    x = np.random.rand(5, 2).astype(np.float32)
+    y = np.random.rand(5, 2).astype(np.float32)
+    z = ModJob(x, y).get().ndarray()
+    test_case.assertTrue(np.array_equal(z, x % y))
+
+def test_broadcast(test_case):
+    @flow.function(func_config)
+    def ModJob(a=flow.FixedTensorDef((5, 2)), b=flow.FixedTensorDef((1, 2))):
+        return a % b
+
+    x = np.random.rand(5, 2).astype(np.float32)
+    y = np.random.rand(1, 2).astype(np.float32)
+    z = ModJob(x, y).get().ndarray()
+    test_case.assertTrue(np.array_equal(z, x % y))
+
+def test_xy_mod_x1(test_case):
+    GenerateTest(test_case, (64, 64), (64, 1))
+
+def test_xy_mod_1y(test_case):
+    GenerateTest(test_case, (64, 64), (1, 64))
+
+def test_xyz_mod_x1z(test_case):
+    GenerateTest(test_case, (64, 64, 64), (64, 1, 64))
+
+def test_xyz_mod_1y1(test_case):
+    GenerateTest(test_case, (64, 64, 64), (1, 64, 1))
+
+def GenerateTest(test_case, a_shape, b_shape):
+    @flow.function(func_config)
+    def ModJob(a=flow.FixedTensorDef(a_shape), b=flow.FixedTensorDef(b_shape)):
+        return a % b
+
+    a = np.random.rand(*a_shape).astype(np.float32)
+    b = np.random.rand(*b_shape).astype(np.float32)
+    y = ModJob(a, b).get().ndarray()
+    test_case.assertTrue(np.array_equal(y, a % b))
diff --git a/oneflow/python/test/ops/test_mod_int.py b/oneflow/python/test/ops/test_mod_int.py
new file mode 100644
index 0000000000000000000000000000000000000000..7086341b525fd7129a1764bcb41ee746c1dfce00
--- /dev/null
+++ b/oneflow/python/test/ops/test_mod_int.py
@@ -0,0 +1,47 @@
+import oneflow as flow
+import numpy as np
+
+func_config = flow.FunctionConfig()
+func_config.default_data_type(flow.int32)
+
+def test_naive(test_case):
+    @flow.function(func_config)
+    def ModJob(a=flow.FixedTensorDef((5, 2), dtype=flow.int32), b=flow.FixedTensorDef((5, 2), dtype=flow.int32)):
+        return a % b
+
+    x = (np.random.rand(5, 2) * 1000).astype(np.int32) + 1
+    y = (np.random.rand(5, 2) * 1000).astype(np.int32) + 1
+    z = ModJob(x, y).get().ndarray()
+    test_case.assertTrue(np.array_equal(z, x % y))
+
+def test_broadcast(test_case):
+    @flow.function(func_config)
+    def ModJob(a=flow.FixedTensorDef((5, 2), dtype=flow.int32), b=flow.FixedTensorDef((1, 2), dtype=flow.int32)):
+        return a % b
+
+    x = (np.random.rand(5, 2) * 1000).astype(np.int32) + 1
+    y = (np.random.rand(1, 2) * 1000).astype(np.int32) + 1
+    z = ModJob(x, y).get().ndarray()
+    test_case.assertTrue(np.array_equal(z, x % y))
+
+def test_xy_mod_x1(test_case):
+    GenerateTest(test_case, (64, 64), (64, 1))
+
+def test_xy_mod_1y(test_case):
+    GenerateTest(test_case, (64, 64), (1, 64))
+
+def test_xyz_mod_x1z(test_case):
+    GenerateTest(test_case, (64, 64, 64), (64, 1, 64))
+
+def test_xyz_mod_1y1(test_case):
+    GenerateTest(test_case, (64, 64, 64), (1, 64, 1))
+
+def GenerateTest(test_case, a_shape, b_shape):
+    @flow.function(func_config)
+    def ModJob(a=flow.FixedTensorDef(a_shape, dtype=flow.int32), b=flow.FixedTensorDef(b_shape, dtype=flow.int32)):
+        return a % b
+
+    a = (np.random.rand(*a_shape) * 1000).astype(np.int32) + 1
+    b = (np.random.rand(*b_shape) * 1000).astype(np.int32) + 1
+    y = ModJob(a, b).get().ndarray()
+    test_case.assertTrue(np.array_equal(y, a % b))
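The two test files differ only in dtype and input generation, so they could be collapsed into one parameterized helper. A sketch under that assumption; GenerateModTest, flow_dtype, and np_dtype are names introduced here, not OneFlow API:

# Sketch: one helper covering both the float and int32 cases above.
import oneflow as flow
import numpy as np

def GenerateModTest(test_case, a_shape, b_shape, flow_dtype, np_dtype):
    func_config = flow.FunctionConfig()
    func_config.default_data_type(flow_dtype)

    @flow.function(func_config)
    def ModJob(a=flow.FixedTensorDef(a_shape, dtype=flow_dtype),
               b=flow.FixedTensorDef(b_shape, dtype=flow_dtype)):
        return a % b

    # Positive inputs only: the C-style truncated remainder computed by the
    # kernel and numpy's floored remainder agree on non-negative operands.
    a = (np.random.rand(*a_shape) * 1000 + 1).astype(np_dtype)
    b = (np.random.rand(*b_shape) * 1000 + 1).astype(np_dtype)
    y = ModJob(a, b).get().ndarray()
    test_case.assertTrue(np.array_equal(y, a % b))

# e.g. GenerateModTest(test_case, (64, 64), (1, 64), flow.int32, np.int32)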