提交 517e1533 编写于 作者: L Li Xinqi 提交者: Jinhui Yuan

Dev pb data type encode (#1241)

* patch protobuf encode/decode files

* patch EncodeConf

* patch common/preprocessor.h
上级 58f43ff5
...@@ -32,6 +32,8 @@ ...@@ -32,6 +32,8 @@
#define OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(macro, ...) \ #define OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(macro, ...) \
OF_PP_INTERNAL_SEQ_PRODUCT_FOR_EACH_TUPLE(macro, __VA_ARGS__) OF_PP_INTERNAL_SEQ_PRODUCT_FOR_EACH_TUPLE(macro, __VA_ARGS__)
#define OF_PP_SEQ_PRODUCT(seq0, ...) OF_PP_INTERNAL_SEQ_PRODUCT(seq0, __VA_ARGS__)
#define OF_PP_SEQ_SIZE(seq) OF_PP_INTERNAL_SEQ_SIZE(seq) #define OF_PP_SEQ_SIZE(seq) OF_PP_INTERNAL_SEQ_SIZE(seq)
#define OF_PP_FORCE(...) OF_PP_TUPLE2VARADIC(OF_PP_CAT((__VA_ARGS__), )) #define OF_PP_FORCE(...) OF_PP_TUPLE2VARADIC(OF_PP_CAT((__VA_ARGS__), ))
......
...@@ -164,9 +164,6 @@ inline uint32_t NewRandomSeed() { ...@@ -164,9 +164,6 @@ inline uint32_t NewRandomSeed() {
#define BOOL_SEQ (true)(false) #define BOOL_SEQ (true)(false)
#define PARALLEL_POLICY_SEQ (ParallelPolicy::kModelParallel)(ParallelPolicy::kDataParallel) #define PARALLEL_POLICY_SEQ (ParallelPolicy::kModelParallel)(ParallelPolicy::kDataParallel)
#define ENCODE_CASE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(EncodeCase::kRaw) \
OF_PP_MAKE_TUPLE_SEQ(EncodeCase::kJpeg)
#define FOR_RANGE(type, i, begin, end) for (type i = (begin), __end = (end); i < __end; ++i) #define FOR_RANGE(type, i, begin, end) for (type i = (begin), __end = (end); i < __end; ++i)
#define FOR_EACH(it, container) for (auto it = container.begin(); it != container.end(); ++it) #define FOR_EACH(it, container) for (auto it = container.begin(); it != container.end(); ++it)
...@@ -217,6 +214,11 @@ inline T MaxVal() { ...@@ -217,6 +214,11 @@ inline T MaxVal() {
return std::numeric_limits<T>::max(); return std::numeric_limits<T>::max();
} }
// Concatenation of two OF_PP_SEQ_PRODUCT expansions describing the supported
// (encode case, data type) combinations:
//   - EncodeCase::kRaw and EncodeCase::kJpeg, each paired with every entry of
//     ARITHMETIC_DATA_TYPE_SEQ;
//   - EncodeCase::kProtobuf, paired with every entry of PB_LIST_DATA_TYPE_SEQ.
#define ENCODE_CASE_DATA_TYPE_SEQ_PRODUCT \
OF_PP_SEQ_PRODUCT((EncodeCase::kRaw)(EncodeCase::kJpeg), ARITHMETIC_DATA_TYPE_SEQ) \
OF_PP_SEQ_PRODUCT((EncodeCase::kProtobuf), PB_LIST_DATA_TYPE_SEQ)
} // namespace oneflow } // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_UTIL_H_ #endif // ONEFLOW_CORE_COMMON_UTIL_H_
...@@ -488,9 +488,13 @@ message EncodeConf { ...@@ -488,9 +488,13 @@ message EncodeConf {
oneof encode { oneof encode {
EncodeRaw raw = 1; EncodeRaw raw = 1;
EncodeJpeg jpeg = 2; EncodeJpeg jpeg = 2;
EncodeProtobuf protobuf = 3;
} }
} }
// Option message for the `protobuf` member of EncodeConf's `encode` oneof.
// It declares no fields.
message EncodeProtobuf {
}
message EncodeRaw { message EncodeRaw {
} }
......
#include "oneflow/core/record/ofrecord_decoder.h" #include "oneflow/core/record/ofrecord_decoder.h"
#include "oneflow/core/record/ofrecord_raw_decoder.h" #include "oneflow/core/record/ofrecord_raw_decoder.h"
#include "oneflow/core/record/ofrecord_jpeg_decoder.h" #include "oneflow/core/record/ofrecord_jpeg_decoder.h"
#include "oneflow/core/record/ofrecord_protobuf_decoder.h"
#include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/balanced_splitter.h"
#include "oneflow/core/thread/thread_manager.h" #include "oneflow/core/thread/thread_manager.h"
...@@ -56,7 +57,14 @@ void DoScalePreprocess(const ScalePreprocessConf& conf, T* dptr, int64_t n) { ...@@ -56,7 +57,14 @@ void DoScalePreprocess(const ScalePreprocessConf& conf, T* dptr, int64_t n) {
} }
template<typename T> template<typename T>
void DoPreprocess(const PreprocessConf& conf, T* dptr, const Shape& shape) { typename std::enable_if<IsPbType<T>::value>::type DoPreprocess(const PreprocessConf& conf, T* dptr,
const Shape& shape) {
UNIMPLEMENTED();
}
template<typename T>
typename std::enable_if<!IsPbType<T>::value>::type DoPreprocess(const PreprocessConf& conf, T* dptr,
const Shape& shape) {
int64_t n = shape.Count(1); int64_t n = shape.Count(1);
if (conf.has_subtract_conf()) { if (conf.has_subtract_conf()) {
DoSubtractPreprocess<T>(conf.subtract_conf(), dptr, n); DoSubtractPreprocess<T>(conf.subtract_conf(), dptr, n);
...@@ -188,13 +196,9 @@ void OFRecordDecoder<encode_case, T>::ReadPartDataContent( ...@@ -188,13 +196,9 @@ void OFRecordDecoder<encode_case, T>::ReadPartDataContent(
// Returns the decoder registered for the given (encode_case, data_type) pair.
// The lookup table is a function-local static keyed by GetHashKey(encode case,
// data type tag); each value is an OFRecordDecoderImpl created with `new`, and
// no matching delete appears in this function. The lookup uses `obj.at`, so an
// unregistered pair is not given a default entry.
// NOTE(review): the exact failure mode of `at` depends on HashMap - confirm.
OFRecordDecoderIf* GetOFRecordDecoder(EncodeCase encode_case, DataType data_type) { OFRecordDecoderIf* GetOFRecordDecoder(EncodeCase encode_case, DataType data_type) {
static const HashMap<std::string, OFRecordDecoderIf*> obj = { static const HashMap<std::string, OFRecordDecoderIf*> obj = {
#define MAKE_ENTRY(et, dt) \ #define MAKE_ENTRY(et, dt) \
{GetHashKey(et, OF_PP_PAIR_SECOND(dt)), new OFRecordDecoderImpl<et, OF_PP_PAIR_FIRST(dt)>}, {GetHashKey(et, OF_PP_PAIR_SECOND(dt)), new OFRecordDecoderImpl<et, OF_PP_PAIR_FIRST(dt)>},
OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, ENCODE_CASE_DATA_TYPE_SEQ_PRODUCT)};
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_ENTRY, ENCODE_CASE_SEQ, ARITHMETIC_DATA_TYPE_SEQ)
};
return obj.at(GetHashKey(encode_case, data_type)); return obj.at(GetHashKey(encode_case, data_type));
} }
......
#include "oneflow/core/record/ofrecord_encoder.h" #include "oneflow/core/record/ofrecord_encoder.h"
#include "oneflow/core/record/ofrecord_raw_encoder.h" #include "oneflow/core/record/ofrecord_raw_encoder.h"
#include "oneflow/core/record/ofrecord_jpeg_encoder.h" #include "oneflow/core/record/ofrecord_jpeg_encoder.h"
#include "oneflow/core/record/ofrecord_protobuf_encoder.h"
namespace oneflow { namespace oneflow {
// Returns the encoder registered for the given (encode_case, data_type) pair.
// The lookup table is a function-local static keyed by GetHashKey(encode case,
// data type tag); each value is an OFRecordEncoderImpl created with `new`, and
// no matching delete appears in this function. The lookup uses `obj.at`, so an
// unregistered pair is not given a default entry.
// NOTE(review): the exact failure mode of `at` depends on HashMap - confirm.
OFRecordEncoderIf* GetOFRecordEncoder(EncodeCase encode_case, DataType data_type) { OFRecordEncoderIf* GetOFRecordEncoder(EncodeCase encode_case, DataType data_type) {
static const HashMap<std::string, OFRecordEncoderIf*> obj = { static const HashMap<std::string, OFRecordEncoderIf*> obj = {
#define MAKE_ENTRY(et, dt) \ #define MAKE_ENTRY(et, dt) \
{GetHashKey(et, OF_PP_PAIR_SECOND(dt)), new OFRecordEncoderImpl<et, OF_PP_PAIR_FIRST(dt)>}, {GetHashKey(et, OF_PP_PAIR_SECOND(dt)), new OFRecordEncoderImpl<et, OF_PP_PAIR_FIRST(dt)>},
OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, ENCODE_CASE_DATA_TYPE_SEQ_PRODUCT)};
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_ENTRY, ENCODE_CASE_SEQ, ARITHMETIC_DATA_TYPE_SEQ)
};
return obj.at(GetHashKey(encode_case, data_type)); return obj.at(GetHashKey(encode_case, data_type));
} }
......
#include "oneflow/core/record/ofrecord_protobuf_decoder.h"
namespace oneflow {
namespace {
// Read-only accessor that maps a list wrapper type T to the corresponding
// list stored in a Feature. The return type is exactly what `T::value()`
// yields. The primary template is declared only; callable versions come
// from the explicit specializations generated below.
template<typename T>
auto GetFeatureDataList(const Feature& feature) -> decltype(std::declval<T>().value());

// Generates one explicit specialization of GetFeatureDataList.
//   pb_list_type  - wrapper type the specialization is keyed on
//   data_type_tag - second tuple element; not used by this macro
//   field_getter  - Feature accessor whose result is returned
#define SPECIALIZE_GET_PB_LIST_DATA_LIST(pb_list_type, data_type_tag, field_getter) \
  template<>                                                                        \
  auto GetFeatureDataList<pb_list_type>(const Feature& feature)                     \
      ->decltype(std::declval<pb_list_type>().value()) {                            \
    return feature.field_getter();                                                  \
  }
// One specialization per tuple of PB_LIST_DATA_TYPE_PB_LIST_FIELD_SEQ.
OF_PP_FOR_EACH_TUPLE(SPECIALIZE_GET_PB_LIST_DATA_LIST, PB_LIST_DATA_TYPE_PB_LIST_FIELD_SEQ);
} // namespace
// Column count of a protobuf-encoded feature. Neither argument is consulted:
// the result is always a single column.
template<typename T>
int32_t OFRecordDecoderImpl<EncodeCase::kProtobuf, T>::GetColNumOfFeature(
    const Feature& feature, int64_t one_col_elem_num) const {
  constexpr int32_t kColNumPerFeature = 1;
  return kColNumPerFeature;
}
// Decodes one column by copying the feature's list (selected through
// GetFeatureDataList<T>) into the `value` field of *out_dptr, then passing
// the filled element to CheckPbListSize<T>.
// ctx, blob_conf, col_id, one_col_elem_num and NextRandomInt are not used.
template<typename T>
void OFRecordDecoderImpl<EncodeCase::kProtobuf, T>::ReadOneCol(
    DeviceCtx* ctx, const Feature& feature, const BlobConf& blob_conf, int32_t col_id, T* out_dptr,
    int64_t one_col_elem_num, std::function<int32_t(void)> NextRandomInt) const {
  T& decoded = *out_dptr;
  const auto& feature_values = GetFeatureDataList<T>(feature);
  *decoded.mutable_value() = feature_values;
  CheckPbListSize<T>(decoded);
}
// Explicitly instantiates the protobuf decoder once per tuple of
// PB_LIST_DATA_TYPE_SEQ. Only the first tuple element (the C++ type) is
// used; the second one is ignored.
#define INSTANTIATE_OFRECORD_PROTOBUF_DECODER(pb_list_type, data_type_tag) \
  template class OFRecordDecoderImpl<EncodeCase::kProtobuf, pb_list_type>;
OF_PP_FOR_EACH_TUPLE(INSTANTIATE_OFRECORD_PROTOBUF_DECODER, PB_LIST_DATA_TYPE_SEQ);
} // namespace oneflow
#ifndef ONEFLOW_CORE_RECORD_OFRECORD_PROTOBUF_DECODER_H_
#define ONEFLOW_CORE_RECORD_OFRECORD_PROTOBUF_DECODER_H_
#include "oneflow/core/record/ofrecord_decoder.h"
namespace oneflow {
// Partial specialization of OFRecordDecoderImpl for EncodeCase::kProtobuf.
// It derives from OFRecordDecoder<EncodeCase::kProtobuf, T> and overrides two
// hooks, both private; this class body only declares them.
template<typename T>
class OFRecordDecoderImpl<EncodeCase::kProtobuf, T> final
    : public OFRecordDecoder<EncodeCase::kProtobuf, T> {
 private:
  // Number of columns the given feature decodes to.
  int32_t GetColNumOfFeature(const Feature& feature, int64_t one_col_elem_num) const override;

  // Decodes column `col_id` of `feature` into `out_dptr`.
  void ReadOneCol(DeviceCtx* ctx, const Feature& feature, const BlobConf& blob_conf,
                  int32_t col_id, T* out_dptr, int64_t one_col_elem_num,
                  std::function<int32_t()> NextRandomInt) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_RECORD_OFRECORD_PROTOBUF_DECODER_H_
#include "oneflow/core/record/ofrecord_protobuf_encoder.h"
namespace oneflow {
namespace {
template<typename T>
decltype(std::declval<T>().mutable_value()) GetMutFeatureDataList(Feature& feature);
#define SPECIALIZE_GET_MUT_PB_LIST_DATA_LIST(T, type_proto, data_list) \
template<> \
decltype(std::declval<T>().mutable_value()) GetMutFeatureDataList<T>(Feature & feature) { \
return feature.mutable_##data_list(); \
}
OF_PP_FOR_EACH_TUPLE(SPECIALIZE_GET_MUT_PB_LIST_DATA_LIST, PB_LIST_DATA_TYPE_PB_LIST_FIELD_SEQ);
} // namespace
// Encodes one column: takes the element at index `in_offset` of the blob's
// T array, passes it to CheckPbListSize<T>, and copies its `value` list into
// the feature list selected by GetMutFeatureDataList<T>.
// ctx, field_name and one_col_elem_num are not used.
template<typename T>
void OFRecordEncoderImpl<EncodeCase::kProtobuf, T>::EncodeOneCol(
    DeviceCtx* ctx, const Blob* in_blob, int64_t in_offset, Feature& feature,
    const std::string& field_name, int64_t one_col_elem_num) const {
  const T* blob_elems = in_blob->dptr<T>();
  const T& source = blob_elems[in_offset];
  CheckPbListSize<T>(source);
  auto&& feature_values = GetMutFeatureDataList<T>(feature);
  *feature_values = source.value();
}
// Explicitly instantiates the protobuf encoder once per tuple of
// PB_LIST_DATA_TYPE_SEQ. Only the first tuple element (the C++ type) is
// used; the second one is ignored.
#define INSTANTIATE_OFRECORD_PROTOBUF_ENCODER(pb_list_type, data_type_tag) \
  template class OFRecordEncoderImpl<EncodeCase::kProtobuf, pb_list_type>;
OF_PP_FOR_EACH_TUPLE(INSTANTIATE_OFRECORD_PROTOBUF_ENCODER, PB_LIST_DATA_TYPE_SEQ);
} // namespace oneflow
#ifndef ONEFLOW_CORE_RECORD_OFRECORD_PROTOBUF_ENCODER_H_
#define ONEFLOW_CORE_RECORD_OFRECORD_PROTOBUF_ENCODER_H_
#include "oneflow/core/record/ofrecord_encoder.h"
namespace oneflow {
// Partial specialization of OFRecordEncoderImpl for EncodeCase::kProtobuf.
// It derives from OFRecordEncoderIf and overrides a single private hook,
// which this class body only declares.
template<typename T>
class OFRecordEncoderImpl<EncodeCase::kProtobuf, T> final : public OFRecordEncoderIf {
 private:
  // Writes the blob element at `in_offset` into `feature`.
  void EncodeOneCol(DeviceCtx* ctx, const Blob* in_blob, int64_t in_offset, Feature& feature,
                    const std::string& field_name, int64_t one_col_elem_num) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_RECORD_OFRECORD_PROTOBUF_ENCODER_H_
...@@ -7,7 +7,6 @@ namespace oneflow { ...@@ -7,7 +7,6 @@ namespace oneflow {
template<typename T> template<typename T>
class OFRecordDecoderImpl<EncodeCase::kRaw, T> final : public OFRecordDecoder<EncodeCase::kRaw, T> { class OFRecordDecoderImpl<EncodeCase::kRaw, T> final : public OFRecordDecoder<EncodeCase::kRaw, T> {
public:
private: private:
int32_t GetColNumOfFeature(const Feature&, int64_t one_col_elem_num) const override; int32_t GetColNumOfFeature(const Feature&, int64_t one_col_elem_num) const override;
void ReadOneCol(DeviceCtx*, const Feature&, const BlobConf& blob_conf, int32_t col_id, void ReadOneCol(DeviceCtx*, const Feature&, const BlobConf& blob_conf, int32_t col_id,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册