diff --git a/oneflow/core/common/preprocessor.h b/oneflow/core/common/preprocessor.h index 5a59bdc0c046e3aa819f1d696e2215d94971f308..a73ea5de01088d4634f159afde62105fa7dd89f7 100644 --- a/oneflow/core/common/preprocessor.h +++ b/oneflow/core/common/preprocessor.h @@ -32,6 +32,8 @@ #define OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(macro, ...) \ OF_PP_INTERNAL_SEQ_PRODUCT_FOR_EACH_TUPLE(macro, __VA_ARGS__) +#define OF_PP_SEQ_PRODUCT(seq0, ...) OF_PP_INTERNAL_SEQ_PRODUCT(seq0, __VA_ARGS__) + #define OF_PP_SEQ_SIZE(seq) OF_PP_INTERNAL_SEQ_SIZE(seq) #define OF_PP_FORCE(...) OF_PP_TUPLE2VARADIC(OF_PP_CAT((__VA_ARGS__), )) diff --git a/oneflow/core/common/util.h b/oneflow/core/common/util.h index 3f4bb279fbe3cd69fb92a9acbdbb2651d8aab489..79ba7e5e2a1181b176649b16f418315f32cb2cf8 100644 --- a/oneflow/core/common/util.h +++ b/oneflow/core/common/util.h @@ -164,9 +164,6 @@ inline uint32_t NewRandomSeed() { #define BOOL_SEQ (true)(false) #define PARALLEL_POLICY_SEQ (ParallelPolicy::kModelParallel)(ParallelPolicy::kDataParallel) -#define ENCODE_CASE_SEQ \ - OF_PP_MAKE_TUPLE_SEQ(EncodeCase::kRaw) \ - OF_PP_MAKE_TUPLE_SEQ(EncodeCase::kJpeg) #define FOR_RANGE(type, i, begin, end) for (type i = (begin), __end = (end); i < __end; ++i) #define FOR_EACH(it, container) for (auto it = container.begin(); it != container.end(); ++it) @@ -217,6 +214,11 @@ inline T MaxVal() { return std::numeric_limits::max(); } +// encode case +#define ENCODE_CASE_DATA_TYPE_SEQ_PRODUCT \ + OF_PP_SEQ_PRODUCT((EncodeCase::kRaw)(EncodeCase::kJpeg), ARITHMETIC_DATA_TYPE_SEQ) \ + OF_PP_SEQ_PRODUCT((EncodeCase::kProtobuf), PB_LIST_DATA_TYPE_SEQ) + } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_UTIL_H_ diff --git a/oneflow/core/operator/op_conf.proto b/oneflow/core/operator/op_conf.proto index 586225eca87e25a7ddc19f194d701fb1e2c8d6cb..d31b4e4672c5ffb7afc3ea1cf0959d869af3254d 100644 --- a/oneflow/core/operator/op_conf.proto +++ b/oneflow/core/operator/op_conf.proto @@ -488,9 +488,13 @@ message EncodeConf { oneof encode { EncodeRaw raw = 1; EncodeJpeg jpeg = 2; + EncodeProtobuf protobuf = 3; } } +message EncodeProtobuf { +} + message EncodeRaw { } diff --git a/oneflow/core/record/ofrecord_decoder.cpp b/oneflow/core/record/ofrecord_decoder.cpp index ce492316f45fc8ff2ce67876458b5bb6cc921251..9ae5e063badb7c9bc3599223318e37435d069567 100644 --- a/oneflow/core/record/ofrecord_decoder.cpp +++ b/oneflow/core/record/ofrecord_decoder.cpp @@ -1,6 +1,7 @@ #include "oneflow/core/record/ofrecord_decoder.h" #include "oneflow/core/record/ofrecord_raw_decoder.h" #include "oneflow/core/record/ofrecord_jpeg_decoder.h" +#include "oneflow/core/record/ofrecord_protobuf_decoder.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/thread/thread_manager.h" @@ -56,7 +57,14 @@ void DoScalePreprocess(const ScalePreprocessConf& conf, T* dptr, int64_t n) { } template -void DoPreprocess(const PreprocessConf& conf, T* dptr, const Shape& shape) { +typename std::enable_if::value>::type DoPreprocess(const PreprocessConf& conf, T* dptr, + const Shape& shape) { + UNIMPLEMENTED(); +} + +template +typename std::enable_if::value>::type DoPreprocess(const PreprocessConf& conf, T* dptr, + const Shape& shape) { int64_t n = shape.Count(1); if (conf.has_subtract_conf()) { DoSubtractPreprocess(conf.subtract_conf(), dptr, n); @@ -188,13 +196,9 @@ void OFRecordDecoder::ReadPartDataContent( OFRecordDecoderIf* GetOFRecordDecoder(EncodeCase encode_case, DataType data_type) { static const HashMap obj = { - #define MAKE_ENTRY(et, dt) \ {GetHashKey(et, OF_PP_PAIR_SECOND(dt)), new OFRecordDecoderImpl}, - - OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_ENTRY, ENCODE_CASE_SEQ, ARITHMETIC_DATA_TYPE_SEQ) - - }; + OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, ENCODE_CASE_DATA_TYPE_SEQ_PRODUCT)}; return obj.at(GetHashKey(encode_case, data_type)); } diff --git a/oneflow/core/record/ofrecord_encoder.cpp b/oneflow/core/record/ofrecord_encoder.cpp index 29f8fd24541016f3047cc519ca2cde400d127c52..0b624f041933cd1601a8de0e65fc0235bb9d42a0 100644 --- a/oneflow/core/record/ofrecord_encoder.cpp +++ b/oneflow/core/record/ofrecord_encoder.cpp @@ -1,18 +1,15 @@ #include "oneflow/core/record/ofrecord_encoder.h" #include "oneflow/core/record/ofrecord_raw_encoder.h" #include "oneflow/core/record/ofrecord_jpeg_encoder.h" +#include "oneflow/core/record/ofrecord_protobuf_encoder.h" namespace oneflow { OFRecordEncoderIf* GetOFRecordEncoder(EncodeCase encode_case, DataType data_type) { static const HashMap obj = { - #define MAKE_ENTRY(et, dt) \ {GetHashKey(et, OF_PP_PAIR_SECOND(dt)), new OFRecordEncoderImpl}, - - OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_ENTRY, ENCODE_CASE_SEQ, ARITHMETIC_DATA_TYPE_SEQ) - - }; + OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, ENCODE_CASE_DATA_TYPE_SEQ_PRODUCT)}; return obj.at(GetHashKey(encode_case, data_type)); } diff --git a/oneflow/core/record/ofrecord_protobuf_decoder.cpp b/oneflow/core/record/ofrecord_protobuf_decoder.cpp new file mode 100644 index 0000000000000000000000000000000000000000..347a4e4deb52ce0ec555cae9605358e665a6ab2e --- /dev/null +++ b/oneflow/core/record/ofrecord_protobuf_decoder.cpp @@ -0,0 +1,37 @@ +#include "oneflow/core/record/ofrecord_protobuf_decoder.h" + +namespace oneflow { + +namespace { + +template +decltype(std::declval().value()) GetFeatureDataList(const Feature& feature); + +#define SPECIALIZE_GET_PB_LIST_DATA_LIST(T, type_proto, data_list) \ + template<> \ + decltype(std::declval().value()) GetFeatureDataList(const Feature& feature) { \ + return feature.data_list(); \ + } +OF_PP_FOR_EACH_TUPLE(SPECIALIZE_GET_PB_LIST_DATA_LIST, PB_LIST_DATA_TYPE_PB_LIST_FIELD_SEQ); + +} // namespace + +template +int32_t OFRecordDecoderImpl::GetColNumOfFeature( + const Feature& feature, int64_t one_col_elem_num) const { + return 1; +} + +template +void OFRecordDecoderImpl::ReadOneCol( + DeviceCtx* ctx, const Feature& feature, const BlobConf& blob_conf, int32_t col_id, T* out_dptr, + int64_t one_col_elem_num, std::function NextRandomInt) const { + *out_dptr->mutable_value() = GetFeatureDataList(feature); + CheckPbListSize(*out_dptr); +} + +#define INSTANTIATE_OFRECORD_PROTOBUF_DECODER(type_cpp, type_proto) \ + template class OFRecordDecoderImpl; +OF_PP_FOR_EACH_TUPLE(INSTANTIATE_OFRECORD_PROTOBUF_DECODER, PB_LIST_DATA_TYPE_SEQ); + +} // namespace oneflow diff --git a/oneflow/core/record/ofrecord_protobuf_decoder.h b/oneflow/core/record/ofrecord_protobuf_decoder.h new file mode 100644 index 0000000000000000000000000000000000000000..488f6fcdd81f64f0e847e86b50fccf9d0993fc1f --- /dev/null +++ b/oneflow/core/record/ofrecord_protobuf_decoder.h @@ -0,0 +1,20 @@ +#ifndef ONEFLOW_CORE_RECORD_OFRECORD_PROTOBUF_DECODER_H_ +#define ONEFLOW_CORE_RECORD_OFRECORD_PROTOBUF_DECODER_H_ + +#include "oneflow/core/record/ofrecord_decoder.h" + +namespace oneflow { + +template +class OFRecordDecoderImpl final + : public OFRecordDecoder { + private: + int32_t GetColNumOfFeature(const Feature&, int64_t one_col_elem_num) const override; + void ReadOneCol(DeviceCtx*, const Feature&, const BlobConf& blob_conf, int32_t col_id, + T* out_dptr, int64_t one_col_elem_num, + std::function NextRandomInt) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_RECORD_OFRECORD_PROTOBUF_DECODER_H_ diff --git a/oneflow/core/record/ofrecord_protobuf_encoder.cpp b/oneflow/core/record/ofrecord_protobuf_encoder.cpp new file mode 100644 index 0000000000000000000000000000000000000000..50ea4015b768166390d4e94ed84763d1966f9ac3 --- /dev/null +++ b/oneflow/core/record/ofrecord_protobuf_encoder.cpp @@ -0,0 +1,32 @@ +#include "oneflow/core/record/ofrecord_protobuf_encoder.h" + +namespace oneflow { + +namespace { + +template +decltype(std::declval().mutable_value()) GetMutFeatureDataList(Feature& feature); + +#define SPECIALIZE_GET_MUT_PB_LIST_DATA_LIST(T, type_proto, data_list) \ + template<> \ + decltype(std::declval().mutable_value()) GetMutFeatureDataList(Feature & feature) { \ + return feature.mutable_##data_list(); \ + } +OF_PP_FOR_EACH_TUPLE(SPECIALIZE_GET_MUT_PB_LIST_DATA_LIST, PB_LIST_DATA_TYPE_PB_LIST_FIELD_SEQ); + +} // namespace + +template +void OFRecordEncoderImpl::EncodeOneCol( + DeviceCtx* ctx, const Blob* in_blob, int64_t in_offset, Feature& feature, + const std::string& field_name, int64_t one_col_elem_num) const { + const T& data = in_blob->dptr()[in_offset]; + CheckPbListSize(data); + *GetMutFeatureDataList(feature) = data.value(); +} + +#define INSTANTIATE_OFRECORD_PROTOBUF_ENCODER(type_cpp, type_proto) \ + template class OFRecordEncoderImpl; +OF_PP_FOR_EACH_TUPLE(INSTANTIATE_OFRECORD_PROTOBUF_ENCODER, PB_LIST_DATA_TYPE_SEQ); + +} // namespace oneflow diff --git a/oneflow/core/record/ofrecord_protobuf_encoder.h b/oneflow/core/record/ofrecord_protobuf_encoder.h new file mode 100644 index 0000000000000000000000000000000000000000..c3e7a1da5aa6697205296fceb199867dd2f59df0 --- /dev/null +++ b/oneflow/core/record/ofrecord_protobuf_encoder.h @@ -0,0 +1,17 @@ +#ifndef ONEFLOW_CORE_RECORD_OFRECORD_PROTOBUF_ENCODER_H_ +#define ONEFLOW_CORE_RECORD_OFRECORD_PROTOBUF_ENCODER_H_ + +#include "oneflow/core/record/ofrecord_encoder.h" + +namespace oneflow { + +template +class OFRecordEncoderImpl final : public OFRecordEncoderIf { + private: + void EncodeOneCol(DeviceCtx*, const Blob* in_blob, int64_t in_offset, Feature&, + const std::string& field_name, int64_t one_col_elem_num) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_RECORD_OFRECORD_PROTOBUF_ENCODER_H_ diff --git a/oneflow/core/record/ofrecord_raw_decoder.h b/oneflow/core/record/ofrecord_raw_decoder.h index ee135d3cec337239d24ca5e67a25523517cd36ff..071dbfb33f05c9a348985ccfa5fafc14482338e1 100644 --- a/oneflow/core/record/ofrecord_raw_decoder.h +++ b/oneflow/core/record/ofrecord_raw_decoder.h @@ -7,7 +7,6 @@ namespace oneflow { template class OFRecordDecoderImpl final : public OFRecordDecoder { - public: private: int32_t GetColNumOfFeature(const Feature&, int64_t one_col_elem_num) const override; void ReadOneCol(DeviceCtx*, const Feature&, const BlobConf& blob_conf, int32_t col_id,