Commit 250c5d7b authored by lixinqi

Merge branch 'dev_python' into dev_eager

#ifndef ONEFLOW_CORE_COMMON_STRUCT_MACRO_TRAITS_H_
#define ONEFLOW_CORE_COMMON_STRUCT_MACRO_TRAITS_H_
namespace oneflow {
#define STRUCT_FIELD(T, field) StructField<T, STRUCT_FIELD_OFFSET(T, field)>
#define DEFINE_STRUCT_FIELD(T, field) \
template<> \
struct StructField<T, STRUCT_FIELD_OFFSET(T, field)> final \
: public StructFieldImpl<T, STRUCT_FIELD_TYPE(T, field), STRUCT_FIELD_OFFSET(T, field)> {};
template<typename T, int offset>
struct StructField {};
template<typename T, typename F, int offset>
struct StructFieldImpl {
using struct_type = T;
using field_type = F;
static const int offset_value = offset;
static T* StructPtr4FieldPtr(const F* field_ptr) {
return (T*)(((char*)field_ptr) - offset_value);
}
static F* FieldPtr4StructPtr(const T* struct_ptr) {
return (F*)(((char*)struct_ptr) + offset_value);
}
};
#define STRUCT_FIELD_TYPE(T, field) decltype(((T*)nullptr)->field)
#define STRUCT_FIELD_OFFSET(T, field) ((int)(long long)&((T*)nullptr)->field)
}
#endif // ONEFLOW_CORE_COMMON_STRUCT_MACRO_TRAITS_H_
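The traits above let code step between a struct pointer and a pointer to one of its fields by adding or subtracting the field's byte offset, which is what StructPtr4FieldPtr and FieldPtr4StructPtr do. A minimal standalone sketch of the same round trip, using offsetof instead of the OneFlow macros (Foo and the variable names here are illustrative only):
#include <cassert>
#include <cstddef>
struct Foo {
  int x;
  int bar;
};
int main() {
  Foo foo{};
  // field pointer -> struct pointer: step back by the field's byte offset
  int* bar_ptr = &foo.bar;
  Foo* struct_ptr =
      reinterpret_cast<Foo*>(reinterpret_cast<char*>(bar_ptr) - offsetof(Foo, bar));
  assert(struct_ptr == &foo);
  // struct pointer -> field pointer: step forward by the same offset
  int* field_ptr =
      reinterpret_cast<int*>(reinterpret_cast<char*>(&foo) + offsetof(Foo, bar));
  assert(field_ptr == &foo.bar);
  return 0;
}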
#include "oneflow/core/common/struct_traits.h"
#include "oneflow/core/common/util.h"
namespace oneflow {
struct OneflowTestNamespaceFoo {
OneflowTestNamespaceFoo() : x(0), bar(0), const_bar(0) {}
int x;
int bar;
const int const_bar;
};
DEFINE_STRUCT_FIELD(OneflowTestNamespaceFoo, bar);
DEFINE_STRUCT_FIELD(OneflowTestNamespaceFoo, const_bar);
TEST(StructField, mutable_struct_mutable_field) {
OneflowTestNamespaceFoo foo;
auto* bar = &foo.bar;
auto* struct_ptr = STRUCT_FIELD(OneflowTestNamespaceFoo, bar)::StructPtr4FieldPtr(bar);
auto* field_ptr = STRUCT_FIELD(OneflowTestNamespaceFoo, bar)::FieldPtr4StructPtr(&foo);
ASSERT_EQ(struct_ptr, &foo);
ASSERT_EQ(field_ptr, bar);
}
TEST(StructField, mutable_struct_const_field) {
OneflowTestNamespaceFoo foo;
auto* bar = &foo.const_bar;
auto* struct_ptr = STRUCT_FIELD(OneflowTestNamespaceFoo, const_bar)::StructPtr4FieldPtr(bar);
auto* field_ptr = STRUCT_FIELD(OneflowTestNamespaceFoo, const_bar)::FieldPtr4StructPtr(&foo);
ASSERT_EQ(struct_ptr, &foo);
ASSERT_EQ(field_ptr, bar);
}
TEST(StructField, const_struct_mutable_field) {
const OneflowTestNamespaceFoo foo;
auto* bar = &foo.bar;
auto* struct_ptr = STRUCT_FIELD(OneflowTestNamespaceFoo, bar)::StructPtr4FieldPtr(bar);
auto* field_ptr = STRUCT_FIELD(OneflowTestNamespaceFoo, bar)::FieldPtr4StructPtr(&foo);
ASSERT_EQ(struct_ptr, &foo);
ASSERT_EQ(field_ptr, bar);
}
TEST(StructField, const_struct_const_field) {
const OneflowTestNamespaceFoo foo;
auto* bar = &foo.const_bar;
auto* struct_ptr = STRUCT_FIELD(OneflowTestNamespaceFoo, const_bar)::StructPtr4FieldPtr(bar);
auto* field_ptr = STRUCT_FIELD(OneflowTestNamespaceFoo, const_bar)::FieldPtr4StructPtr(&foo);
ASSERT_EQ(struct_ptr, &foo);
ASSERT_EQ(field_ptr, bar);
}
} // namespace oneflow
@@ -31,6 +31,6 @@ void GatherKernel<device_type, T>::ForwardDataContent(
this->kernel_conf().gather_conf().axis(), out);
}
ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kGatherConf, GatherKernel, FLOATING_DATA_TYPE_SEQ);
ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kGatherConf, GatherKernel, GATHER_DATA_TYPE_SEQ);
} // namespace oneflow
@@ -20,15 +20,6 @@ void GatherForward(DeviceCtx* ctx, const Blob* indices, const Blob* in, int64_t
flat_in_shape, out->mut_dptr<T>(), offset);
}
template<DeviceType device_type, typename T, typename K>
void GatherBackward(DeviceCtx* ctx, const Blob* indices, const Blob* out_diff, int64_t axis,
Blob* in_diff, const int64_t offset) {
const Shape& flat_in_shape = GetFlatShape(in_diff->shape(), axis);
GatherKernelUtilImpl<device_type, T, K>::Backward(
ctx, indices->dptr<K>(), indices->shape().elem_cnt(), out_diff->dptr<T>(), flat_in_shape,
in_diff->mut_dptr<T>(), offset);
}
template<DeviceType device_type, typename T>
struct GatherSwitchUtil final {
#define MAKE_GATHER_SWITCH_ENTRY(func_name, K) func_name<device_type, T, K>
@@ -36,7 +27,6 @@ struct GatherSwitchUtil final {
DEFINE_STATIC_SWITCH_FUNC(void, func_name, MAKE_GATHER_SWITCH_ENTRY, \
MAKE_DATA_TYPE_CTRV_SEQ(INDEX_DATA_TYPE_SEQ));
DEFINE_GATHER_STATIC_SWITCH_FUNC(GatherForward);
DEFINE_GATHER_STATIC_SWITCH_FUNC(GatherBackward);
#undef DEFINE_GATHER_STATIC_SWITCH_FUNC
#undef MAKE_GATHER_SWITCH_ENTRY
};
@@ -49,13 +39,6 @@ void GatherKernelUtil<device_type, T>::Forward(DeviceCtx* ctx, const Blob* indic
GatherKernelUtil<device_type, T>::Forward(ctx, indices, in, axis, out, 0);
}
template<DeviceType device_type, typename T>
void GatherKernelUtil<device_type, T>::Backward(DeviceCtx* ctx, const Blob* indices,
const Blob* out_diff, const int64_t axis,
Blob* in_diff) {
GatherKernelUtil<device_type, T>::Backward(ctx, indices, out_diff, axis, in_diff, 0);
}
template<DeviceType device_type, typename T>
void GatherKernelUtil<device_type, T>::Forward(DeviceCtx* ctx, const Blob* indices, const Blob* in,
const int64_t axis, Blob* out,
@@ -64,20 +47,10 @@ void GatherKernelUtil<device_type, T>::Forward(DeviceCtx* ctx, const Blob* indic
indices, in, axis, out, offset);
}
template<DeviceType device_type, typename T>
void GatherKernelUtil<device_type, T>::Backward(DeviceCtx* ctx, const Blob* indices,
const Blob* out_diff, const int64_t axis,
Blob* in_diff, const int64_t offset) {
GatherSwitchUtil<device_type, T>::SwitchGatherBackward(SwitchCase(indices->data_type()), ctx,
indices, out_diff, axis, in_diff, offset);
}
template<typename T, typename K>
struct GatherKernelUtilImpl<DeviceType::kCPU, T, K> final {
static void Forward(DeviceCtx* ctx, const K* indices, int64_t num_indices, const T* in,
const Shape& flat_in_shape, T* out, const int64_t offset);
static void Backward(DeviceCtx* ctx, const K* indices, int64_t num_indices, const T* out_diff,
const Shape& flat_in_shape, T* in_diff, const int64_t offset);
};
template<typename T, typename K>
@@ -103,42 +76,17 @@ void GatherKernelUtilImpl<DeviceType::kCPU, T, K>::Forward(DeviceCtx* ctx, const
}
}
template<typename T, typename K>
void GatherKernelUtilImpl<DeviceType::kCPU, T, K>::Backward(DeviceCtx* ctx, const K* indices,
int64_t num_indices, const T* out_diff,
const Shape& flat_in_shape, T* in_diff,
const int64_t offset) {
const int64_t outer_dim_size = flat_in_shape.At(0);
const int64_t gather_dim_size = flat_in_shape.At(1);
const int64_t inner_dim_size = flat_in_shape.At(2);
FOR_RANGE(int64_t, outer_idx, 0, outer_dim_size) {
FOR_RANGE(int64_t, i, 0, num_indices) {
CHECK_GE(indices[i], 0);
const int64_t idx = indices[i] - offset;
T* to = in_diff + outer_idx * gather_dim_size * inner_dim_size + idx * inner_dim_size;
if (idx >= 0 && idx < gather_dim_size) {
const T* from = out_diff + outer_idx * num_indices * inner_dim_size + i * inner_dim_size;
std::transform(from, from + inner_dim_size, to, to, std::plus<T>());
}
}
}
}
#define INITIATE_GATHER_KERNEL_UTIL_CPU_IMPL(in_type_pair, index_type_pair) \
template struct GatherKernelUtilImpl<DeviceType::kCPU, OF_PP_PAIR_FIRST(in_type_pair), \
OF_PP_PAIR_FIRST(index_type_pair)>;
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_GATHER_KERNEL_UTIL_CPU_IMPL, FLOATING_DATA_TYPE_SEQ,
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_GATHER_KERNEL_UTIL_CPU_IMPL, GATHER_DATA_TYPE_SEQ,
INDEX_DATA_TYPE_SEQ);
#undef INITIATE_GATHER_KERNEL_UTIL_CPU_IMPL
#define INITIATE_GATHER_KERNEL_UTIL(device_type, in_type_pair) \
template struct GatherKernelUtil<device_type, OF_PP_PAIR_FIRST(in_type_pair)>;
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_GATHER_KERNEL_UTIL, DEVICE_TYPE_SEQ,
FLOATING_DATA_TYPE_SEQ);
template struct GatherKernelUtil<DeviceType::kGPU, int32_t>;
GATHER_DATA_TYPE_SEQ);
#undef INITIATE_GATHER_KERNEL_UTIL
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 && CUDA_VERSION >= 10000
template struct GatherKernelUtil<DeviceType::kGPU, float16>;
#endif
} // namespace oneflow
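The forward path kept above treats the input as a flattened (outer, gather, inner) shape around the gather axis and copies the rows selected by the (optionally offset) indices. A small standalone sketch of that indexing, with hypothetical names and no OneFlow dependencies; out-of-range shifted indices produce zero rows here, which is the intent of the offset-aware (model-split) variant:
#include <cstdint>
#include <cstring>
// in has shape (outer, gather, inner); out has shape (outer, num_indices, inner).
template<typename T, typename K>
void GatherCpuSketch(const K* indices, int64_t num_indices, const T* in, int64_t outer,
                     int64_t gather, int64_t inner, T* out, int64_t offset) {
  for (int64_t o = 0; o < outer; ++o) {
    for (int64_t i = 0; i < num_indices; ++i) {
      const int64_t idx = static_cast<int64_t>(indices[i]) - offset;
      T* to = out + (o * num_indices + i) * inner;
      if (idx >= 0 && idx < gather) {
        const T* from = in + (o * gather + idx) * inner;
        std::memcpy(to, from, inner * sizeof(T));
      } else {
        std::memset(to, 0, inner * sizeof(T));  // index owned by another shard
      }
    }
  }
}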
@@ -45,20 +45,6 @@ __global__ void GatherForwardGpu(const IDX elem_cnt, const K* indices, const IDX
}
}
template<typename T, typename K, typename IDX>
__global__ void GatherBackwardGpu(const IDX elem_cnt, const K* indices, const IDX num_indices,
const T* out_diff, const IDX gather_dim_size,
const IDX inner_dim_size, T* in_diff, const IDX offset) {
CUDA_1D_KERNEL_LOOP_T(IDX, i, elem_cnt) {
const T diff_val = out_diff[i];
if (diff_val != static_cast<T>(0)) {
const int64_t in_offset =
GetInOffset<K, IDX>(i, indices, num_indices, gather_dim_size, inner_dim_size, offset);
if (in_offset >= 0) { gpu_atomic_add(in_diff + in_offset, diff_val); }
}
}
}
bool IsSafeUseIndex32(const Shape& flat_in_shape, const int64_t num_indices) {
const int64_t in_elem_cnt = flat_in_shape.elem_cnt();
const int64_t out_elem_cnt = flat_in_shape.At(0) * num_indices * flat_in_shape.At(2);
@@ -84,21 +70,6 @@ struct GatherKernelUtilImpl<DeviceType::kGPU, T, K> final {
offset);
}
}
static void Backward(DeviceCtx* ctx, const K* indices, int64_t num_indices, const T* out_diff,
const Shape& flat_in_shape, T* in_diff, const int64_t offset) {
const int64_t elem_cnt = flat_in_shape.At(0) * num_indices * flat_in_shape.At(2);
if (IsSafeUseIndex32(flat_in_shape, num_indices)) {
GatherBackwardGpu<T, K, int32_t>
<<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
elem_cnt, indices, num_indices, out_diff, flat_in_shape.At(1), flat_in_shape.At(2),
in_diff, offset);
} else {
GatherBackwardGpu<T, K, int64_t>
<<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
elem_cnt, indices, num_indices, out_diff, flat_in_shape.At(1), flat_in_shape.At(2),
in_diff, offset);
}
}
};
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 && CUDA_VERSION >= 10000
@@ -110,26 +81,13 @@ struct GatherKernelUtilImpl<DeviceType::kGPU, float16, K> final {
ctx, indices, num_indices, reinterpret_cast<const half*>(in), flat_in_shape,
reinterpret_cast<half*>(out), offset);
}
static void Backward(DeviceCtx* ctx, const K* indices, int64_t num_indices,
const float16* out_diff, const Shape& flat_in_shape, float16* in_diff,
const int64_t offset) {
GatherKernelUtilImpl<DeviceType::kGPU, half, K>::Backward(
ctx, indices, num_indices, reinterpret_cast<const half*>(out_diff), flat_in_shape,
reinterpret_cast<half*>(in_diff), offset);
}
};
#endif
#define INITIATE_GATHER_KERNEL_UTIL_GPU_IMPL(in_type_pair, index_type_pair) \
template struct GatherKernelUtilImpl<DeviceType::kGPU, OF_PP_PAIR_FIRST(in_type_pair), \
OF_PP_PAIR_FIRST(index_type_pair)>;
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_GATHER_KERNEL_UTIL_GPU_IMPL,
FLOATING_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int32_t,
DataType::kInt32)
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 && CUDA_VERSION >= 10000
FLOAT16_DATA_TYPE_SEQ
#endif
,
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_GATHER_KERNEL_UTIL_GPU_IMPL, GATHER_DATA_TYPE_SEQ,
INDEX_DATA_TYPE_SEQ);
#undef INITIATE_GATHER_KERNEL_UTIL_GPU_IMPL
......
@@ -8,22 +8,22 @@ namespace oneflow {
template<DeviceType device_type, typename T>
struct GatherKernelUtil final {
static void Forward(DeviceCtx* ctx, const Blob* indices, const Blob* in, int64_t axis, Blob* out);
static void Backward(DeviceCtx* ctx, const Blob* indices, const Blob* out_diff, int64_t axis,
Blob* in_diff);
static void Forward(DeviceCtx* ctx, const Blob* indices, const Blob* in, int64_t axis, Blob* out,
int64_t offset);
static void Backward(DeviceCtx* ctx, const Blob* indices, const Blob* out_diff, int64_t axis,
Blob* in_diff, int64_t offset);
};
template<DeviceType device_type, typename T, typename K>
struct GatherKernelUtilImpl final {
static void Forward(DeviceCtx* ctx, const K* indices, int64_t num_indices, const T* in,
const Shape& flat_in_shape, T* out, int64_t offset);
static void Backward(DeviceCtx* ctx, const K* indices, int64_t num_indices, const T* out_diff,
const Shape& flat_in_shape, T* in_diff, int64_t offset);
};
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 && CUDA_VERSION >= 10000
#define GATHER_DATA_TYPE_SEQ ARITHMETIC_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ
#else
#define GATHER_DATA_TYPE_SEQ ARITHMETIC_DATA_TYPE_SEQ
#endif
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_GATHER_KERNEL_UTIL_H_
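GATHER_DATA_TYPE_SEQ and INDEX_DATA_TYPE_SEQ feed OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE, which emits one explicit template instantiation per (value type, index type) pair. A self-contained illustration of the shape of that expansion, with Dummy standing in for GatherKernelUtilImpl and the type lists assumed to contain at least float, double, int32_t and int64_t:
#include <cstdint>
enum class DeviceType { kCPU, kGPU };
template<DeviceType device, typename T, typename K>
struct Dummy {};
// Roughly what the preprocessor product expands to, one line per pair:
template struct Dummy<DeviceType::kCPU, float, int32_t>;
template struct Dummy<DeviceType::kCPU, float, int64_t>;
template struct Dummy<DeviceType::kCPU, double, int32_t>;
template struct Dummy<DeviceType::kCPU, double, int64_t>;
// ...and so on for every remaining (value type, index type) pair in the sequences.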
#include "oneflow/core/kernel/kernel.h"
#include "oneflow/core/kernel/gather_kernel_util.h"
#include "oneflow/core/kernel/unsorted_segment_sum_kernel_util.h"
namespace oneflow {
template<DeviceType device_type, typename T>
template<DeviceType device_type, typename T, typename K>
class GatherMs0GradKernel final : public KernelIf<device_type> {
public:
OF_DISALLOW_COPY_AND_MOVE(GatherMs0GradKernel);
@@ -16,38 +16,46 @@ class GatherMs0GradKernel final : public KernelIf<device_type> {
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
};
template<DeviceType device_type, typename T>
const PbMessage& GatherMs0GradKernel<device_type, T>::GetCustomizedOpConf() const {
template<DeviceType device_type, typename T, typename K>
const PbMessage& GatherMs0GradKernel<device_type, T, K>::GetCustomizedOpConf() const {
return this->op_conf().gather_ms0_grad_conf();
}
template<DeviceType device_type, typename T>
void GatherMs0GradKernel<device_type, T>::ForwardDataContent(
template<DeviceType device_type, typename T, typename K>
void GatherMs0GradKernel<device_type, T, K>::ForwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* indices = BnInOp2Blob("indices");
const Blob* out_diff = BnInOp2Blob("out_diff");
Blob* in_diff = BnInOp2Blob("in_diff");
const int64_t offset = this->kernel_conf().gather_ms0_grad_conf().offset();
Memset<device_type>(ctx.device_ctx, in_diff->mut_dptr<T>(), 0, in_diff->ByteSizeOfBlobBody());
GatherKernelUtil<device_type, T>::Backward(ctx.device_ctx, indices, out_diff, 0, in_diff, offset);
const int64_t num_segment_ids = indices->shape().elem_cnt();
const ShapeView& in_diff_shape = in_diff->shape();
const int64_t inner_dim_size = in_diff_shape.Count(1);
const int64_t num_segments = in_diff_shape.At(0);
CHECK_EQ(out_diff->shape().elem_cnt(), num_segment_ids * inner_dim_size);
UnsortedSegmentSumKernelUtil<device_type, T, K>::UnsortedSegmentSum(
ctx.device_ctx, indices->dptr<K>(), out_diff->dptr<T>(), num_segment_ids, num_segments, 1,
inner_dim_size, offset, in_diff->mut_dptr<T>());
}
namespace {
Kernel* CreateGatherGradKernel(const KernelConf& kernel_conf) {
static const HashMap<std::string, std::function<Kernel*()>> creators = {
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_KERNEL_CREATOR_ENTRY, (GatherMs0GradKernel),
DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ)
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 && CUDA_VERSION >= 10000
MAKE_KERNEL_CREATOR_ENTRY(GatherMs0GradKernel, DeviceType::kGPU,
(float16, DataType::kFloat16))
#endif
};
return creators.at(
GetHashKey(kernel_conf.op_attribute().op_conf().device_type(), kernel_conf.data_type()))();
}
#define MAKE_GATHER_MS0_GRAD_KERNEL_ENTRY(device_type_v, data_type_pair, indices_type_pair) \
NEW_REGISTER_KERNEL(OperatorConf::kGatherMs0GradConf, \
GatherMs0GradKernel<device_type_v, OF_PP_PAIR_FIRST(data_type_pair), \
OF_PP_PAIR_FIRST(indices_type_pair)>) \
.SetIsMatchedPred([](const KernelConf& kernel_conf) -> bool { \
return ((kernel_conf.op_attribute().op_conf().device_type() == device_type_v) \
&& ((OF_PP_PAIR_SECOND(data_type_pair)) == kernel_conf.data_type()) \
&& (OF_PP_PAIR_SECOND(indices_type_pair) \
== kernel_conf.gather_ms0_grad_conf().indices_data_type())); \
});
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_GATHER_MS0_GRAD_KERNEL_ENTRY, DEVICE_TYPE_SEQ,
UNSORTED_SEGMENT_SUM_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)
#undef MAKE_GATHER_MS0_GRAD_KERNEL_ENTRY
REGISTER_KERNEL_CREATOR(OperatorConf::kGatherMs0GradConf, CreateGatherGradKernel);
} // namespace
} // namespace oneflow
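The block above replaces the hash-keyed creator map with NEW_REGISTER_KERNEL plus a SetIsMatchedPred predicate, so the kernel can be chosen on the indices data type as well as the device and value type. A stripped-down sketch of that selection pattern (hypothetical structs, not the OneFlow registry API):
#include <functional>
#include <vector>
// Hypothetical mini-registry: each entry pairs a match predicate with a factory,
// and the first entry whose predicate accepts the conf is used to build the kernel.
struct FakeKernelConf {
  int device_type;
  int data_type;
  int indices_data_type;
};
struct KernelEntry {
  std::function<bool(const FakeKernelConf&)> is_matched;
  std::function<void*()> create;
};
void* CreateMatchedKernel(const std::vector<KernelEntry>& registry, const FakeKernelConf& conf) {
  for (const auto& entry : registry) {
    if (entry.is_matched(conf)) { return entry.create(); }
  }
  return nullptr;  // no registered kernel matches this configuration
}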
@@ -31,6 +31,6 @@ void GatherMs0Kernel<device_type, T>::ForwardDataContent(
GatherKernelUtil<device_type, T>::Forward(ctx.device_ctx, indices, in, 0, out, offset);
}
ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kGatherMs0Conf, GatherMs0Kernel, FLOATING_DATA_TYPE_SEQ);
ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kGatherMs0Conf, GatherMs0Kernel, GATHER_DATA_TYPE_SEQ);
} // namespace oneflow
#include "oneflow/core/kernel/indexed_slices_reduce_sum_kernel_util.h"
#include "oneflow/core/kernel/unique_kernel_util.h"
#include "oneflow/core/kernel/gather_kernel_util.h"
#include "oneflow/core/kernel/unsorted_segment_sum_kernel_util.h"
namespace oneflow {
@@ -24,8 +24,9 @@ void IndexedSlicesReduceSumKernelUtil<device_type, K, T, IDX>::ReduceSum(
unique_workspace_size);
const Shape flat_in_shape({1, n, m});
Memset<device_type>(ctx, values_out, 0, n * m * sizeof(T));
GatherKernelUtilImpl<device_type, T, IDX>::Backward(ctx, unique_idx_ptr, n, values, flat_in_shape,
values_out, 0);
UnsortedSegmentSumKernelUtil<device_type, T, IDX>::UnsortedSegmentSum(ctx, unique_idx_ptr, values,
n, n, 1, m, 0, values_out);
}
template<DeviceType device_type, typename K, typename T, typename IDX>
......
@@ -151,6 +151,7 @@ message GatherMs0KernelConf {
message GatherMs0GradKernelConf {
required int64 offset = 1;
required DataType indices_data_type = 2;
}
message NcclTupleReduceConf {
@@ -178,6 +179,10 @@ message IndexedSlicesReduceSumKernelConf {
required DataType indices_data_type = 1;
}
message UnsortedSegmentSumKernelConf {
required DataType indices_data_type = 1;
}
message KernelConf {
required OpAttribute op_attribute = 1;
required DataType data_type = 2;
@@ -204,6 +209,7 @@ message KernelConf {
XrtLaunchKernelConf xrt_launch_conf = 353;
UniqueWithCountsKernelConf unique_with_counts_conf = 354;
IndexedSlicesReduceSumKernelConf indexed_slices_reduce_sum_conf = 355;
UnsortedSegmentSumKernelConf unsorted_segment_sum_conf = 356;
AccuracyKernelConf accuracy_conf = 401;
SliceKernelConf slice_conf = 402;
......
@@ -138,6 +138,18 @@ double LinearCosineDecayedLearningRate(const LinearCosineDecayConf& conf, double
return lr * decayed;
}
double PiecewiseScalingLearningRate(const PiecewiseScalingConf& conf, double lr,
int64_t cur_batch_num) {
const PbRf<int64_t>& boundaries = conf.boundaries();
const PbRf<double>& scales = conf.scales();
CHECK_EQ(boundaries.size() + 1, scales.size());
size_t i = 0;
for (; i < boundaries.size(); ++i) {
if (cur_batch_num <= boundaries[i]) { break; }
}
return scales[i] * lr;
}
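As a worked example of the piecewise scaling added above: with n boundaries there are n + 1 scales, and the scale chosen is the one for the first boundary that cur_batch_num has not yet passed. A quick standalone check with hypothetical values:
#include <cassert>
#include <cstdint>
#include <vector>
double PiecewiseScale(const std::vector<int64_t>& boundaries, const std::vector<double>& scales,
                      double lr, int64_t cur_batch_num) {
  size_t i = 0;
  for (; i < boundaries.size(); ++i) {
    if (cur_batch_num <= boundaries[i]) { break; }
  }
  return scales[i] * lr;
}
int main() {
  const std::vector<int64_t> boundaries = {1000, 2000};
  const std::vector<double> scales = {1.0, 0.5, 0.25};
  assert(PiecewiseScale(boundaries, scales, 4.0, 500) == 4.0);   // before the first boundary
  assert(PiecewiseScale(boundaries, scales, 4.0, 1500) == 2.0);  // between the boundaries
  assert(PiecewiseScale(boundaries, scales, 4.0, 2500) == 1.0);  // past the last boundary
  return 0;
}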
double GetDecayedLearningRate(const LearningRateDecayConf& conf, double lr, int64_t cur_batch_num) {
if (conf.has_exponential_conf()) {
return ExponentialDecayedLearningRate(conf.exponential_conf(), lr, cur_batch_num);
@@ -153,6 +165,8 @@ double GetDecayedLearningRate(const LearningRateDecayConf& conf, double lr, int6
return CosineDecayedLearningRate(conf.cosine_conf(), lr, cur_batch_num);
} else if (conf.has_linear_cosine_conf()) {
return LinearCosineDecayedLearningRate(conf.linear_cosine_conf(), lr, cur_batch_num);
} else if (conf.has_piecewise_scaling_conf()) {
return PiecewiseScalingLearningRate(conf.piecewise_scaling_conf(), lr, cur_batch_num);
} else {
UNIMPLEMENTED();
}
......
#include "oneflow/core/kernel/unsorted_segment_sum_kernel.h"
#include "oneflow/core/kernel/gather_kernel_util.h"
#include "oneflow/core/kernel/unsorted_segment_sum_kernel_util.h"
namespace oneflow {
template<DeviceType device_type, typename T>
const PbMessage& UnsortedSegmentSumKernel<device_type, T>::GetCustomizedOpConf() const {
template<DeviceType device_type, typename T, typename K>
const PbMessage& UnsortedSegmentSumKernel<device_type, T, K>::GetCustomizedOpConf() const {
return this->op_conf().unsorted_segment_sum_conf();
}
template<DeviceType device_type, typename T>
void UnsortedSegmentSumKernel<device_type, T>::ForwardDataContent(
template<DeviceType device_type, typename T, typename K>
void UnsortedSegmentSumKernel<device_type, T, K>::ForwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* segment_ids = BnInOp2Blob("segment_ids");
const Blob* data = BnInOp2Blob("data");
Blob* out = BnInOp2Blob("out");
Memset<device_type>(ctx.device_ctx, out->mut_dptr<T>(), 0, out->ByteSizeOfBlobBody());
if (segment_ids->IsBodyEmpty() || data->IsBodyEmpty()) { return; }
GatherKernelUtil<device_type, T>::Backward(
ctx.device_ctx, segment_ids, data, this->op_conf().unsorted_segment_sum_conf().axis(), out);
}
const ShapeView& out_shape = out->shape();
const int64_t axis = this->op_conf().unsorted_segment_sum_conf().axis();
const int64_t outer_dim_size = out_shape.Count(0, axis);
const int64_t num_segments = this->op_conf().unsorted_segment_sum_conf().num_segments();
CHECK_EQ(out_shape.At(axis), num_segments);
const int64_t inner_dim_size = out_shape.Count(axis + 1);
const int64_t num_segment_ids = segment_ids->shape().elem_cnt();
CHECK_EQ(inner_dim_size * num_segment_ids * outer_dim_size, data->shape().elem_cnt());
UnsortedSegmentSumKernelUtil<device_type, T, K>::UnsortedSegmentSum(
ctx.device_ctx, segment_ids->dptr<K>(), data->dptr<T>(), num_segment_ids, num_segments,
outer_dim_size, inner_dim_size, 0, out->mut_dptr<T>());
}
namespace {
Kernel* CreateUnsortedSegmentSumKernel(const KernelConf& kernel_conf) {
static const HashMap<std::string, std::function<Kernel*()>> creators = {
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_KERNEL_CREATOR_ENTRY, (UnsortedSegmentSumKernel),
DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ)
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 && CUDA_VERSION >= 10000
MAKE_KERNEL_CREATOR_ENTRY(UnsortedSegmentSumKernel, DeviceType::kGPU,
(float16, DataType::kFloat16))
#endif
};
return creators.at(
GetHashKey(kernel_conf.op_attribute().op_conf().device_type(), kernel_conf.data_type()))();
}
#define MAKE_UNSORTED_SEGMENT_SUM_KERNEL_ENTRY(device_type_v, data_type_pair, indices_type_pair) \
NEW_REGISTER_KERNEL(OperatorConf::kUnsortedSegmentSumConf, \
UnsortedSegmentSumKernel<device_type_v, OF_PP_PAIR_FIRST(data_type_pair), \
OF_PP_PAIR_FIRST(indices_type_pair)>) \
.SetIsMatchedPred([](const KernelConf& kernel_conf) -> bool { \
return ((kernel_conf.op_attribute().op_conf().device_type() == device_type_v) \
&& ((OF_PP_PAIR_SECOND(data_type_pair)) == kernel_conf.data_type()) \
&& (OF_PP_PAIR_SECOND(indices_type_pair) \
== kernel_conf.unsorted_segment_sum_conf().indices_data_type())); \
});
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_UNSORTED_SEGMENT_SUM_KERNEL_ENTRY, DEVICE_TYPE_SEQ,
UNSORTED_SEGMENT_SUM_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)
#undef MAKE_UNSORTED_SEGMENT_SUM_KERNEL_ENTRY
REGISTER_KERNEL_CREATOR(OperatorConf::kUnsortedSegmentSumConf, CreateUnsortedSegmentSumKernel);
} // namespace
} // namespace oneflow
@@ -5,7 +5,7 @@
namespace oneflow {
template<DeviceType device_type, typename T>
template<DeviceType device_type, typename T, typename K>
class UnsortedSegmentSumKernel final : public KernelIf<device_type> {
public:
OF_DISALLOW_COPY_AND_MOVE(UnsortedSegmentSumKernel);
......
#include "oneflow/core/kernel/unsorted_segment_sum_kernel_util.h"
namespace oneflow {
template<typename T, typename K>
struct UnsortedSegmentSumKernelUtil<DeviceType::kCPU, T, K> final {
static void UnsortedSegmentSum(DeviceCtx* ctx, const K* segment_ids, const T* data,
int64_t num_segment_ids, int64_t num_segments,
int64_t outer_dim_size, int64_t inner_dim_size,
int64_t segment_id_offset, T* out);
};
template<typename T, typename K>
void UnsortedSegmentSumKernelUtil<DeviceType::kCPU, T, K>::UnsortedSegmentSum(
DeviceCtx* ctx, const K* segment_ids, const T* data, int64_t num_segment_ids,
int64_t num_segments, int64_t outer_dim_size, int64_t inner_dim_size, int64_t segment_id_offset,
T* out) {
FOR_RANGE(int64_t, outer_idx, 0, outer_dim_size) {
FOR_RANGE(int64_t, i, 0, num_segment_ids) {
CHECK_GE(segment_ids[i], 0);
const int64_t idx = segment_ids[i] - segment_id_offset;
T* to = out + outer_idx * num_segments * inner_dim_size + idx * inner_dim_size;
if (idx >= 0 && idx < num_segments) {
const T* from = data + outer_idx * num_segment_ids * inner_dim_size + i * inner_dim_size;
std::transform(from, from + inner_dim_size, to, to, std::plus<T>());
}
}
}
}
#define INITIATE_UNSORTED_SEGMENT_SUM_KERNEL_UTIL_CPU(in_type_pair, index_type_pair) \
template struct UnsortedSegmentSumKernelUtil<DeviceType::kCPU, OF_PP_PAIR_FIRST(in_type_pair), \
OF_PP_PAIR_FIRST(index_type_pair)>;
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_UNSORTED_SEGMENT_SUM_KERNEL_UTIL_CPU,
UNSORTED_SEGMENT_SUM_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ);
#undef INITIATE_UNSORTED_SEGMENT_SUM_KERNEL_UTIL_CPU
} // namespace oneflow
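Semantically, UnsortedSegmentSum scatters each data row into the output row named by its segment id and accumulates collisions, which is why it can take over from the removed gather Backward. A tiny runnable sketch restricted to outer_dim_size == 1 and inner_dim_size == 1 (hypothetical helper, not the OneFlow utility):
#include <cassert>
#include <vector>
std::vector<int> UnsortedSegmentSum1D(const std::vector<int>& segment_ids,
                                      const std::vector<int>& data, int num_segments) {
  std::vector<int> out(num_segments, 0);
  for (size_t i = 0; i < segment_ids.size(); ++i) {
    const int idx = segment_ids[i];
    if (idx >= 0 && idx < num_segments) { out[idx] += data[i]; }  // colliding ids accumulate
  }
  return out;
}
int main() {
  // ids 0 and 0 collide, so their values add: out = {1 + 3, 0, 2}
  assert((UnsortedSegmentSum1D({0, 2, 0}, {1, 2, 3}, 3) == std::vector<int>{4, 0, 2}));
  return 0;
}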
#include "oneflow/core/kernel/unsorted_segment_sum_kernel_util.h"
#include "oneflow/core/kernel/kernel_util.cuh"
#include "oneflow/core/kernel/kernel.h"
#include <assert.h>
namespace oneflow {
namespace {
template<typename K, typename IDX>
__device__ IDX GetOutOffset(const IDX data_offset, const K* segment_ids, const IDX num_segment_ids,
const IDX num_segments, const IDX inner_dim_size,
const IDX segment_id_offset) {
const IDX outer_dim_elem_cnt = num_segment_ids * inner_dim_size;
const IDX outer_idx = data_offset / outer_dim_elem_cnt;
const IDX segment_id_idx = data_offset % outer_dim_elem_cnt / inner_dim_size;
const IDX inner_idx = data_offset % inner_dim_size;
const K origin_idx = segment_ids[segment_id_idx];
assert(origin_idx >= 0);
const IDX idx = origin_idx - segment_id_offset;
if (idx >= 0 && idx < num_segments) {
return outer_idx * num_segments * inner_dim_size + idx * inner_dim_size + inner_idx;
} else {
return -1;
}
}
template<typename T, typename K, typename IDX>
__global__ void UnsortedSegmentSumGpu(const IDX data_elem_cnt, const K* segment_ids,
const IDX num_segment_ids, const T* data,
const IDX num_segments, const IDX inner_dim_size, T* out,
const IDX segment_id_offset) {
CUDA_1D_KERNEL_LOOP_T(IDX, i, data_elem_cnt) {
const T val = data[i];
if (val != static_cast<T>(0)) {
const int64_t out_offset = GetOutOffset<K, IDX>(i, segment_ids, num_segment_ids, num_segments,
inner_dim_size, segment_id_offset);
if (out_offset >= 0) { gpu_atomic_add(out + out_offset, val); }
}
}
}
bool IsSafeUseIndex32(const int64_t num_segment_ids, const int64_t num_segments,
const int64_t outer_dim_size, const int64_t inner_dim_size) {
const int64_t data_elem_cnt = outer_dim_size * num_segment_ids * inner_dim_size;
const int64_t out_elem_cnt = outer_dim_size * num_segments * inner_dim_size;
return std::max(out_elem_cnt, data_elem_cnt) < GetMaxVal<int32_t>() / 2;
}
} // namespace
template<typename T, typename K>
struct UnsortedSegmentSumKernelUtil<DeviceType::kGPU, T, K> final {
static void UnsortedSegmentSum(DeviceCtx* ctx, const K* segment_ids, const T* data,
int64_t num_segment_ids, int64_t num_segments,
int64_t outer_dim_size, int64_t inner_dim_size,
int64_t segment_id_offset, T* out) {
const int64_t data_elem_cnt = outer_dim_size * num_segment_ids * inner_dim_size;
if (IsSafeUseIndex32(num_segment_ids, num_segments, outer_dim_size, inner_dim_size)) {
UnsortedSegmentSumGpu<T, K, int32_t>
<<<BlocksNum4ThreadsNum(data_elem_cnt), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
data_elem_cnt, segment_ids, num_segment_ids, data, num_segments, inner_dim_size, out,
segment_id_offset);
} else {
UnsortedSegmentSumGpu<T, K, int64_t>
<<<BlocksNum4ThreadsNum(data_elem_cnt), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
data_elem_cnt, segment_ids, num_segment_ids, data, num_segments, inner_dim_size, out,
segment_id_offset);
}
}
};
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 && CUDA_VERSION >= 10000
template<typename K>
struct UnsortedSegmentSumKernelUtil<DeviceType::kGPU, float16, K> final {
static void UnsortedSegmentSum(DeviceCtx* ctx, const K* segment_ids, const float16* data,
int64_t num_segment_ids, int64_t num_segments,
int64_t outer_dim_size, int64_t inner_dim_size,
int64_t segment_id_offset, float16* out) {
UnsortedSegmentSumKernelUtil<DeviceType::kGPU, half, K>::UnsortedSegmentSum(
ctx, segment_ids, reinterpret_cast<const half*>(data), num_segment_ids, num_segments,
outer_dim_size, inner_dim_size, segment_id_offset, reinterpret_cast<half*>(out));
}
};
#endif
#define INITIATE_UNSORTED_SEGMENT_SUM_KERNEL_UTIL_GPU(in_type_pair, index_type_pair) \
template struct UnsortedSegmentSumKernelUtil<DeviceType::kGPU, OF_PP_PAIR_FIRST(in_type_pair), \
OF_PP_PAIR_FIRST(index_type_pair)>;
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_UNSORTED_SEGMENT_SUM_KERNEL_UTIL_GPU,
UNSORTED_SEGMENT_SUM_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ);
#undef INITIATE_UNSORTED_SEGMENT_SUM_KERNEL_UTIL_GPU
} // namespace oneflow
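The GPU dispatch above instantiates the kernel with 32-bit index arithmetic whenever both the data and output element counts stay comfortably below the int32 range, and falls back to 64-bit otherwise; narrower indices keep the per-element address math cheaper on the device. A standalone restatement of that guard (illustrative only, mirroring IsSafeUseIndex32 in spirit):
#include <algorithm>
#include <cstdint>
#include <limits>
bool FitsInt32Indexing(int64_t data_elem_cnt, int64_t out_elem_cnt) {
  return std::max(data_elem_cnt, out_elem_cnt) < std::numeric_limits<int32_t>::max() / 2;
}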
#ifndef ONEFLOW_CORE_KERNEL_UNSORTED_SEGMENT_SUM_KERNEL_UTIL_H_
#define ONEFLOW_CORE_KERNEL_UNSORTED_SEGMENT_SUM_KERNEL_UTIL_H_
#include "oneflow/core/kernel/kernel_util.h"
namespace oneflow {
template<DeviceType device_type, typename T, typename K>
struct UnsortedSegmentSumKernelUtil final {
static void UnsortedSegmentSum(DeviceCtx* ctx, const K* segment_ids, const T* data,
int64_t num_segment_ids, int64_t num_segments,
int64_t outer_dim_size, int64_t inner_dim_size,
int64_t segment_id_offset, T* out);
};
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 && CUDA_VERSION >= 10000
#define UNSORTED_SEGMENT_SUM_DATA_TYPE_SEQ \
FLOATING_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) FLOAT16_DATA_TYPE_SEQ
#else
#define UNSORTED_SEGMENT_SUM_DATA_TYPE_SEQ \
FLOATING_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)
#endif
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_UNSORTED_SEGMENT_SUM_KERNEL_UTIL_H_
@@ -95,6 +95,8 @@ void GatherMs0GradOp::VirtualGenKernelConf(
BalancedSplitter bs(conf.gather_dim_size(), parallel_ctx->parallel_num());
int64_t offset = bs.At(parallel_ctx->parallel_id()).begin();
kernel_conf->mutable_gather_ms0_grad_conf()->set_offset(offset);
kernel_conf->mutable_gather_ms0_grad_conf()->set_indices_data_type(
GetBlobDesc4BnInOp("indices")->data_type());
}
REGISTER_OP(OperatorConf::kGatherMs0GradConf, GatherMs0GradOp);
......
@@ -543,6 +543,11 @@ message LinearCosineDecayConf {
optional double beta = 4 [default = 0.001];
}
message PiecewiseScalingConf {
repeated int64 boundaries = 1;
repeated double scales = 2;
}
message LearningRateDecayConf {
oneof type {
ExponentialDecayConf exponential_conf = 2000;
@@ -552,6 +557,7 @@ message LearningRateDecayConf {
PolynomialDecayConf polynomial_conf = 2004;
CosineDecayConf cosine_conf = 2005;
LinearCosineDecayConf linear_cosine_conf = 2006;
PiecewiseScalingConf piecewise_scaling_conf = 2007;
}
}
......
@@ -74,6 +74,13 @@ Maybe<void> UnsortedSegmentSumOp::InferBatchAxis(
return Maybe<void>::Ok();
}
void UnsortedSegmentSumOp::VirtualGenKernelConf(
std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const {
kernel_conf->mutable_unsorted_segment_sum_conf()->set_indices_data_type(
GetBlobDesc4BnInOp("segment_ids")->data_type());
}
REGISTER_OP(OperatorConf::kUnsortedSegmentSumConf, UnsortedSegmentSumOp);
} // namespace oneflow
@@ -22,6 +22,9 @@ class UnsortedSegmentSumOp final : public Operator {
Maybe<void> GetSbpSignatures(
const std::function<Maybe<const BlobDesc*>(const std::string&)>& LogicalBlobDesc4Ibn,
SbpSignatureList* sbp_sig_list) const override;
void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
KernelConf* kernel_conf) const override;
};
} // namespace oneflow
......
@@ -293,12 +293,12 @@ def softmax_grad(y, dy, axis=None, name=None):
dx = oneflow.transpose(dx, perm=permute)
return dx
@oneflow_export("nn.sparse_softmax_cross_entropy_with_logits")
def sparse_softmax_cross_entropy_with_logits(
labels=None, logits=None, name=None
@oneflow_export("nn.sparse_cross_entropy")
def sparse_cross_entropy(
labels=None, prediction=None, name=None
):
assert labels is not None
assert logits is not None
assert prediction is not None
op_conf = op_conf_util.OperatorConf()
setattr(
op_conf,
@@ -308,7 +308,7 @@ def sparse_softmax_cross_entropy_with_logits(
setattr(
op_conf.sparse_cross_entropy_conf,
"prediction",
softmax(logits).logical_blob_name,
prediction.logical_blob_name,
)
setattr(
op_conf.sparse_cross_entropy_conf, "label", labels.logical_blob_name
@@ -320,6 +320,14 @@ def sparse_softmax_cross_entropy_with_logits(
lbi.blob_name = "out"
return remote_blob_util.RemoteBlob(lbi)
@oneflow_export("nn.sparse_softmax_cross_entropy_with_logits")
def sparse_softmax_cross_entropy_with_logits(
labels=None, logits=None, name=None
):
assert labels is not None
assert logits is not None
return sparse_cross_entropy(labels=labels, prediction=softmax(logits))
@oneflow_export("nn.sigmoid_cross_entropy_with_logits")
def sigmoid_cross_entropy_with_logits(
labels=None, logits=None, name=None
......
@@ -31,8 +31,14 @@ def _run_test(test_case, indices, values, indices_dtype, values_dtype, device):
out_indices, out_values, num_unique = TestJob(indices, values).get()
_check(test_case, indices, values, out_indices.ndarray(), out_values.ndarray(), num_unique.ndarray())
def test_indexed_slices_reduce_sum(test_case):
def test_indexed_slices_reduce_sum_gpu(test_case):
indices = np.random.randint(0, 32, 1024).astype(np.int32)
values = np.random.rand(1024, 8).astype(np.float32)
_run_test(test_case, indices, values, flow.int32, flow.float32, 'gpu')
def test_indexed_slices_reduce_sum_cpu(test_case):
indices = np.random.randint(0, 32, 1024).astype(np.int32)
values = np.random.rand(1024, 8).astype(np.float32)
_run_test(test_case, indices, values, flow.int32, flow.float32, 'cpu')
import os
import numpy as np
import tensorflow as tf
import oneflow as flow
from collections import OrderedDict
from test_util import GenArgList
from test_util import GetSavePath
from test_util import Save
def compare_with_tensorflow(device_type, num_classes, batch_size):
assert device_type in ["gpu", "cpu"]
flow.clear_default_session()
func_config = flow.FunctionConfig()
func_config.default_data_type(flow.float)
func_config.train.primary_lr(1e-4)
func_config.train.model_update_conf(dict(naive_conf={}))
@flow.function(func_config)
def SparseSoftmaxCrossEntropyWithLogitsJob(
labels=flow.FixedTensorDef((batch_size, ), dtype=flow.int32)
):
with flow.device_prior_placement(device_type, "0:0"):
x = flow.get_variable(
"x",
shape=(batch_size, num_classes),
dtype=flow.float,
initializer=flow.random_uniform_initializer(minval=-10, maxval=10),
trainable=True,
)
prediction = flow.nn.softmax(logits=x)
loss = flow.nn.sparse_cross_entropy(labels=labels, prediction=prediction)
#loss = flow.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=x)
loss = flow.identity(loss)
flow.losses.add_loss(loss)
flow.watch(x, Save("x"))
flow.watch_diff(x, Save("x_diff"))
flow.watch(loss, Save("loss"))
flow.watch_diff(loss, Save("loss_diff"))
return loss
# fake labels
labels = np.random.randint(0, num_classes, size=(batch_size, )).astype(np.int32)
# OneFlow
check_point = flow.train.CheckPoint()
check_point.init()
of_out = SparseSoftmaxCrossEntropyWithLogitsJob(labels).get()
# TensorFlow
with tf.GradientTape(persistent=True) as tape:
x = tf.Variable(np.load(os.path.join(GetSavePath(), "x.npy")))
tf_out = tf.nn.sparse_softmax_cross_entropy_with_logits(labels, x)
loss_diff = np.load(os.path.join(GetSavePath(), "loss_diff.npy"))
tf_x_diff = tape.gradient(tf_out, x, loss_diff)
assert np.allclose(of_out.ndarray(), tf_out.numpy(), rtol=1e-5, atol=1e-5)
assert np.allclose(
np.load(os.path.join(GetSavePath(), "x_diff.npy")), tf_x_diff.numpy(), rtol=1e-5, atol=1e-5
)
flow.clear_default_session()
def test_sparse_softmax_cross_entropy_with_logits(test_case):
arg_dict = OrderedDict()
arg_dict["device_type"] = ["gpu"]
arg_dict["num_classes"] = [1000]
arg_dict["batch_size"] = [64]
for arg in GenArgList(arg_dict):
compare_with_tensorflow(*arg)
import oneflow as flow
import numpy as np
import tensorflow as tf
func_config = flow.FunctionConfig()
func_config.default_data_type(flow.float)
def _check(test_case, data, segment_ids, num_segments, out):
ref_out = tf.math.unsorted_segment_sum(data, segment_ids, num_segments).numpy()
test_case.assertTrue(np.allclose(ref_out, out))
def _run_test(test_case, data, segment_ids, num_segments, data_dtype, segment_id_dtype, device):
@flow.function(func_config)
def TestJob(
data=flow.FixedTensorDef(data.shape, dtype=data_dtype),
segment_ids=flow.FixedTensorDef(segment_ids.shape, dtype=segment_id_dtype)):
with flow.fixed_placement(device, "0:0"):
return flow.math.unsorted_segment_sum(data=data, segment_ids=segment_ids, num_segments=num_segments)
out = TestJob(data, segment_ids).get()
_check(test_case, data, segment_ids, num_segments, out.ndarray())
def test_unsorted_segment_sum_gpu(test_case):
data = np.random.rand(1024, 8).astype(np.float32)
segment_ids = np.random.randint(0, 32, 1024).astype(np.int32)
_run_test(test_case, data, segment_ids, 32, flow.float32, flow.int32, 'gpu')
def test_unsorted_segment_sum_cpu(test_case):
data = np.random.rand(1024, 8).astype(np.float32)
segment_ids = np.random.randint(0, 32, 1024).astype(np.int32)
_run_test(test_case, data, segment_ids, 32, flow.float32, flow.int32, 'cpu')
def test_unsorted_segment_sum_gpu_2d(test_case):
data = np.random.rand(1024, 8).astype(np.float32).reshape([4, 256, 8])
segment_ids = np.random.randint(0, 32, 1024).astype(np.int32).reshape([4, 256])
_run_test(test_case, data, segment_ids, 32, flow.float32, flow.int32, 'gpu')
def test_unsorted_segment_sum_cpu_2d(test_case):
data = np.random.rand(1024, 8).astype(np.float32).reshape([4, 256, 8])
segment_ids = np.random.randint(0, 32, 1024).astype(np.int32).reshape([4, 256])
_run_test(test_case, data, segment_ids, 32, flow.float32, flow.int32, 'cpu')