Commit 4e99ffcf authored by lixinqi

Merge branch 'dev_python' of https://github.com/Oneflow-Inc/oneflow into dev_python

#include "oneflow/core/kernel/kernel.h"
#include "oneflow/core/ndarray/ndarray_util.h"
namespace oneflow {
template<DeviceType device_type, typename T>
class BroadcastFloorModKernel final : public KernelIf<device_type> {
public:
OF_DISALLOW_COPY_AND_MOVE(BroadcastFloorModKernel);
BroadcastFloorModKernel() = default;
~BroadcastFloorModKernel() = default;
private:
void ForwardDataContent(const KernelCtx&,
std::function<Blob*(const std::string&)>) const override;
};
template<DeviceType device_type, typename T>
void BroadcastFloorModKernel<device_type, T>::ForwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* a = BnInOp2Blob("a");
const Blob* b = BnInOp2Blob("b");
Blob* out = BnInOp2Blob("out");
size_t num_axes = out->shape().NumAxes();
NdarrayUtil<device_type, T>::BroadcastFloorMod(ctx.device_ctx, XpuVarNdarray<T>(out, num_axes),
XpuVarNdarray<const T>(a, num_axes),
XpuVarNdarray<const T>(b, num_axes));
}
ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kBroadcastFloorModConf, BroadcastFloorModKernel,
ARITHMETIC_DATA_TYPE_SEQ);
} // namespace oneflow
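The kernel above only wires the "a", "b", and "out" blobs into NdarrayUtil<device_type, T>::BroadcastFloorMod, which broadcasts both inputs to the output shape and applies the binary functor element-wise. A minimal NumPy sketch of the intended semantics (illustrative only, not OneFlow code):

# --- Illustrative example (not part of the commit) ---
import numpy as np

# Shapes mirror the broadcast test below: "b" is broadcast along the first axis.
a = np.random.rand(5, 2).astype(np.float32)
b = np.random.rand(1, 2).astype(np.float32)

# Broadcast both inputs to the output shape, then apply the element-wise remainder.
# np.fmod is the C-style truncated remainder, matching the fmodf / % used by the functor;
# for the non-negative inputs used in the tests it coincides with a % b.
out = np.fmod(a, b)
assert out.shape == (5, 2)
assert np.array_equal(out, a % b)
# --- end example ---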
@@ -12,7 +12,7 @@
#include "oneflow/core/common/util.h"
namespace oneflow {
-#define ARITHMETIC_BINARY_FUNC_NAME_SEQ (Add)(Sub)(Mul)(Div)(Min)(Max)
+#define ARITHMETIC_BINARY_FUNC_NAME_SEQ (Add)(Sub)(Mul)(Div)(Min)(Max)(FloorMod)
#define LOGICAL_BINARY_FUNC_NAME_SEQ (EQ)(NE)(GT)(GE)(LT)(LE)(AND)
#define PREPEND_PREFIX_BINARY_FUNC(name) OF_PP_CAT(BinaryFunc, name)
@@ -73,6 +73,12 @@ struct BinaryFuncDiv final {
};
SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncDiv);
template<typename T>
struct BinaryFuncFloorMod final {
static OF_DEVICE_FUNC const T Invoke(const T x, const T y) { return x % y; }
};
SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncFloorMod);
template<typename T>
struct BinaryFuncMax final {
static OF_DEVICE_FUNC const T Invoke(const T x, const T y) { return x > y ? x : y; }
@@ -198,6 +204,58 @@ struct BinaryFuncMin<half> final {
#endif // defined(__CUDACC__)
#if defined(__CUDACC__)
template<>
struct BinaryFuncFloorMod<float> final {
static __device__ __forceinline__ const float Invoke(const float x, const float y) {
return fmodf(x, y);
}
};
template<>
struct BinaryFuncFloorMod<double> final {
static __device__ __forceinline__ const double Invoke(const double x, const double y) {
return fmod(x, y);
}
};
template<>
struct BinaryFuncFloorMod<half> final {
static __device__ __forceinline__ const half Invoke(const half x, const half y) {
#if __CUDA_ARCH__ >= 530
return __float2half(fmodf(__half2float(x), __half2float(y)));
#else
NO_HALF_UTIL_FOUND;
#endif
}
};
#else
template<>
struct BinaryFuncFloorMod<float> final {
static inline const float Invoke(const float x, const float y) {
return std::fmod(x, y);
}
};
template<>
struct BinaryFuncFloorMod<double> final {
static inline const double Invoke(const double x, const double y) {
return std::fmod(x, y);
}
};
template<>
struct BinaryFuncFloorMod<float16> final {
static inline const float16 Invoke(const float16 x, const float16 y) {
return static_cast<float16>(std::fmod(static_cast<float>(x), static_cast<float>(y)));
}
};
#endif // defined(__CUDACC__)
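Note that both the integral % used by the generic template and the fmod/fmodf used by the floating-point specializations compute the C-style truncated remainder (the result takes the sign of the dividend), which differs from a floored modulo when the operands have mixed signs; the tests below use only non-negative inputs, where the two conventions agree. A small NumPy illustration of the distinction (illustrative only, not OneFlow code):

# --- Illustrative example (not part of the commit) ---
import numpy as np

x = np.array([7, -7, 7, -7], dtype=np.int32)
y = np.array([3, 3, -3, -3], dtype=np.int32)

print(np.fmod(x, y))  # truncated remainder, like C's % / fmod: [ 1 -1  1 -1]
print(np.mod(x, y))   # floored modulo, like Python's %:        [ 1  2 -2 -1]
# --- end example ---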
template<typename T, template<typename> class binary_func>
struct UnitOfBinaryFunc;
......
#include "oneflow/core/operator/broadcast_binary_op.h"
namespace oneflow {
class BroadcastFloorModOp final : public BroadcastBinaryOp {
public:
OF_DISALLOW_COPY_AND_MOVE(BroadcastFloorModOp);
BroadcastFloorModOp() = default;
~BroadcastFloorModOp() override = default;
private:
const PbMessage& GetCustomizedConf() const override;
Maybe<void> VirtualGetSbpSignatures(
const std::function<Maybe<const BlobDesc*>(const std::string&)>& LogicalBlobDesc4Ibn,
SbpSignatureList* sbp_sig_list) const override;
};
const PbMessage& BroadcastFloorModOp::GetCustomizedConf() const {
return op_conf().broadcast_floor_mod_conf();
}
Maybe<void> BroadcastFloorModOp::VirtualGetSbpSignatures(
const std::function<Maybe<const BlobDesc*>(const std::string&)>& LogicalBlobDesc4Ibn,
SbpSignatureList* sbp_sig_list) const {
SbpSignatureBuilder().PartialSum("a").Broadcast("b").PartialSum("out").Build(
sbp_sig_list->mutable_sbp_signature()->Add());
return Maybe<void>::Ok();
}
REGISTER_OP(OperatorConf::kBroadcastFloorModConf, BroadcastFloorModOp);
} // namespace oneflow
@@ -1206,6 +1206,13 @@ message BroadcastDivOpConf {
optional bool is_const = 4 [default = false];
}
message BroadcastFloorModOpConf {
required string a = 1;
required string b = 2;
required string out = 3;
optional bool is_const = 4 [default = false];
}
message BroadcastLikeOpConf {
required string x = 1;
required string like = 2;
@@ -1718,6 +1725,7 @@ message OperatorConf {
BroadcastSubOpConf broadcast_sub_conf = 501;
BroadcastMulOpConf broadcast_mul_conf = 502;
BroadcastDivOpConf broadcast_div_conf = 503;
BroadcastFloorModOpConf broadcast_floor_mod_conf = 504;
SquareOpConf square_conf = 513;
SqrtOpConf sqrt_conf = 514;
RsqrtOpConf rsqrt_conf = 515;
......
@@ -42,11 +42,11 @@ class BlobDef(blob_desc.BlobDesc):
    @property
    def is_dynamic(self):
        raise NotImplementedError

    @property
    def disable_boxing(self):
        raise NotImplementedError

    @property
    def is_tensor_list(self):
        raise NotImplementedError
@@ -54,7 +54,7 @@ class BlobDef(blob_desc.BlobDesc):
    @property
    def disable_boxing(self):
        raise NotImplementedError

    @property
    def parallel_conf(self):
        raise NotImplementedError
@@ -101,6 +101,9 @@ class BlobDef(blob_desc.BlobDesc):
    def __div__(self, rhs):
        return oneflow.math.divide(self, rhs)

    def __mod__(self, rhs):
        return oneflow.math.mod(self, rhs)

    def __eq__(self, rhs):
        return oneflow.math.equal(self, rhs)
@@ -145,11 +148,11 @@ class ConsistentBlob(BlobDef):
    @property
    def is_dynamic(self):
        return c_api_util.JobBuildAndInferCtx_IsDynamic(self.job_name_, self.lbn_)

    @property
    def disable_boxing(self):
        return c_api_util.JobBuildAndInferCtx_DisableBoxing(self.job_name_, self.lbn_)

    @property
    def is_tensor_list(self):
        return c_api_util.JobBuildAndInferCtx_IsTensorList(self.job_name_, self.lbn_)
@@ -157,7 +160,7 @@ class ConsistentBlob(BlobDef):
    @property
    def disable_boxing(self):
        return c_api_util.JobBuildAndInferCtx_DisableBoxing(self.job_name_, self.lbn_)

    @property
    def parallel_conf(self):
        return c_api_util.JobBuildAndInferCtx_GetParallelConfFromProducerView(self.job_name_,
@@ -175,10 +178,10 @@ class MirroredBlob(BlobDef):
            consistent_blob = ConsistentBlob(sub_lbi, auto_watched_within_scope=False)
            self.sub_consistent_blob_list_.append(consistent_blob)
        watch_scope_util.TryWatchOnce(self)

    @property
    def sub_consistent_blob_list(self): return self.sub_consistent_blob_list_

    @property
    def static_shape(self):
        return c_api_util.JobBuildAndInferCtx_MirroredBlobGetStaticShape(self.job_name_, self.lbn_)
@@ -199,11 +202,11 @@ class MirroredBlob(BlobDef):
    @property
    def is_dynamic(self):
        return c_api_util.JobBuildAndInferCtx_MirroredBlobIsDynamic(self.job_name_, self.lbn_)

    @property
    def disable_boxing(self):
        return c_api_util.JobBuildAndInferCtx_MirroredBlobDisableBoxing(self.job_name_, self.lbn_)

    @property
    def is_tensor_list(self):
        return c_api_util.JobBuildAndInferCtx_MirroredBlobIsTensorList(self.job_name_, self.lbn_)
@@ -211,7 +214,7 @@ class MirroredBlob(BlobDef):
    @property
    def disable_boxing(self):
        return c_api_util.JobBuildAndInferCtx_MirroredBlobDisableBoxing(self.job_name_, self.lbn_)

    @property
    def parallel_conf(self):
        return c_api_util.JobBuildAndInferCtx_MirroredBlobGetParallelConfFromProducerView(
......
@@ -59,6 +59,19 @@ def divide(x, y, name=None):
    return broadcast_div(x, y, name)


@oneflow_export("math.mod")
def floor_mod(x, y, name=None):
    if isinstance(x, (int, float)):
        raise NotImplementedError
    elif isinstance(y, (int, float)):
        raise NotImplementedError
    elif x.static_shape == y.static_shape:
        # TODO: add element-wise op
        return broadcast_floor_mod(x, y, name)
    else:
        return broadcast_floor_mod(x, y, name)


def scalar_add(x, operand, name=None):
    op_conf = op_conf_util.OperatorConf()
    setattr(
@@ -197,6 +210,23 @@ def broadcast_div(x, y, name=None):
    return remote_blob_util.RemoteBlob(lbi)


def broadcast_floor_mod(x, y, name=None):
    op_conf = op_conf_util.OperatorConf()
    setattr(
        op_conf,
        "name",
        name if name is not None else id_util.UniqueStr("BroadcastMod_"),
    )
    op_conf.broadcast_floor_mod_conf.a = x.logical_blob_name
    op_conf.broadcast_floor_mod_conf.b = y.logical_blob_name
    op_conf.broadcast_floor_mod_conf.out = "out"
    compile_context.CurJobAddOp(op_conf)
    lbi = logical_blob_id_util.LogicalBlobId()
    lbi.op_name = op_conf.name
    lbi.blob_name = "out"
    return remote_blob_util.RemoteBlob(lbi)
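Together with the __mod__ hook added to BlobDef above, blobs can now be combined either with the % operator or by calling the exported oneflow.math.mod directly. A minimal usage sketch mirroring the tests below (job and variable names are made up for illustration):

# --- Illustrative example (not part of the commit) ---
import numpy as np
import oneflow as flow

func_config = flow.FunctionConfig()
func_config.default_data_type(flow.float)

@flow.function(func_config)
def MyModJob(a=flow.FixedTensorDef((5, 2)), b=flow.FixedTensorDef((5, 2))):
    return flow.math.mod(a, b)  # equivalent to `a % b` via BlobDef.__mod__

x = np.random.rand(5, 2).astype(np.float32)
y = np.random.rand(5, 2).astype(np.float32)
assert np.allclose(MyModJob(x, y).get().ndarray(), np.fmod(x, y))
# --- end example ---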
@oneflow_export("math.tanh")
def tanh(x, name=None):
op_conf = op_conf_util.OperatorConf()
......
import oneflow as flow
import numpy as np
func_config = flow.FunctionConfig()
func_config.default_data_type(flow.float)
def test_naive(test_case):
    @flow.function(func_config)
    def ModJob(a=flow.FixedTensorDef((5, 2)), b=flow.FixedTensorDef((5, 2))):
        return a % b

    x = np.random.rand(5, 2).astype(np.float32)
    y = np.random.rand(5, 2).astype(np.float32)
    z = ModJob(x, y).get().ndarray()
    test_case.assertTrue(np.array_equal(z, x % y))

def test_broadcast(test_case):
    @flow.function(func_config)
    def ModJob(a=flow.FixedTensorDef((5, 2)), b=flow.FixedTensorDef((1, 2))):
        return a % b

    x = np.random.rand(5, 2).astype(np.float32)
    y = np.random.rand(1, 2).astype(np.float32)
    z = ModJob(x, y).get().ndarray()
    test_case.assertTrue(np.array_equal(z, x % y))

def test_xy_mod_x1(test_case):
    GenerateTest(test_case, (64, 64), (64, 1))

def test_xy_mod_1y(test_case):
    GenerateTest(test_case, (64, 64), (1, 64))

def test_xyz_mod_x1z(test_case):
    GenerateTest(test_case, (64, 64, 64), (64, 1, 64))

def test_xyz_mod_1y1(test_case):
    GenerateTest(test_case, (64, 64, 64), (1, 64, 1))

def GenerateTest(test_case, a_shape, b_shape):
    @flow.function(func_config)
    def ModJob(a=flow.FixedTensorDef(a_shape), b=flow.FixedTensorDef(b_shape)):
        return a % b

    a = np.random.rand(*a_shape).astype(np.float32)
    b = np.random.rand(*b_shape).astype(np.float32)
    y = ModJob(a, b).get().ndarray()
    test_case.assertTrue(np.array_equal(y, a % b))
import oneflow as flow
import numpy as np
func_config = flow.FunctionConfig()
func_config.default_data_type(flow.int32)
def test_naive(test_case):
    @flow.function(func_config)
    def ModJob(a=flow.FixedTensorDef((5, 2), dtype=flow.int32), b=flow.FixedTensorDef((5, 2), dtype=flow.int32)):
        return a % b

    x = (np.random.rand(5, 2) * 1000).astype(np.int32) + 1
    y = (np.random.rand(5, 2) * 1000).astype(np.int32) + 1
    z = ModJob(x, y).get().ndarray()
    test_case.assertTrue(np.array_equal(z, x % y))

def test_broadcast(test_case):
    @flow.function(func_config)
    def ModJob(a=flow.FixedTensorDef((5, 2), dtype=flow.int32), b=flow.FixedTensorDef((1, 2), dtype=flow.int32)):
        return a % b

    x = (np.random.rand(5, 2) * 1000).astype(np.int32) + 1
    y = (np.random.rand(1, 2) * 1000).astype(np.int32) + 1
    z = ModJob(x, y).get().ndarray()
    test_case.assertTrue(np.array_equal(z, x % y))

def test_xy_mod_x1(test_case):
    GenerateTest(test_case, (64, 64), (64, 1))

def test_xy_mod_1y(test_case):
    GenerateTest(test_case, (64, 64), (1, 64))

def test_xyz_mod_x1z(test_case):
    GenerateTest(test_case, (64, 64, 64), (64, 1, 64))

def test_xyz_mod_1y1(test_case):
    GenerateTest(test_case, (64, 64, 64), (1, 64, 1))

def GenerateTest(test_case, a_shape, b_shape):
    @flow.function(func_config)
    def ModJob(a=flow.FixedTensorDef(a_shape, dtype=flow.int32), b=flow.FixedTensorDef(b_shape, dtype=flow.int32)):
        return a % b

    a = (np.random.rand(*a_shape) * 1000).astype(np.int32) + 1
    b = (np.random.rand(*b_shape) * 1000).astype(np.int32) + 1
    y = ModJob(a, b).get().ndarray()
    test_case.assertTrue(np.array_equal(y, a % b))