提交 517e1533 编写于 作者: L Li Xinqi 提交者: Jinhui Yuan

Dev pb data type encode (#1241)

* patch protobuf encode/decode files

* patch EncodeConf

* patch common/preprocessor.h
上级 58f43ff5
......@@ -32,6 +32,8 @@
#define OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(macro, ...) \
OF_PP_INTERNAL_SEQ_PRODUCT_FOR_EACH_TUPLE(macro, __VA_ARGS__)
#define OF_PP_SEQ_PRODUCT(seq0, ...) OF_PP_INTERNAL_SEQ_PRODUCT(seq0, __VA_ARGS__)
#define OF_PP_SEQ_SIZE(seq) OF_PP_INTERNAL_SEQ_SIZE(seq)
#define OF_PP_FORCE(...) OF_PP_TUPLE2VARADIC(OF_PP_CAT((__VA_ARGS__), ))
......
......@@ -164,9 +164,6 @@ inline uint32_t NewRandomSeed() {
#define BOOL_SEQ (true)(false)
#define PARALLEL_POLICY_SEQ (ParallelPolicy::kModelParallel)(ParallelPolicy::kDataParallel)
#define ENCODE_CASE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(EncodeCase::kRaw) \
OF_PP_MAKE_TUPLE_SEQ(EncodeCase::kJpeg)
#define FOR_RANGE(type, i, begin, end) for (type i = (begin), __end = (end); i < __end; ++i)
#define FOR_EACH(it, container) for (auto it = container.begin(); it != container.end(); ++it)
......@@ -217,6 +214,11 @@ inline T MaxVal() {
return std::numeric_limits<T>::max();
}
// encode case
#define ENCODE_CASE_DATA_TYPE_SEQ_PRODUCT \
OF_PP_SEQ_PRODUCT((EncodeCase::kRaw)(EncodeCase::kJpeg), ARITHMETIC_DATA_TYPE_SEQ) \
OF_PP_SEQ_PRODUCT((EncodeCase::kProtobuf), PB_LIST_DATA_TYPE_SEQ)
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_UTIL_H_
......@@ -488,9 +488,13 @@ message EncodeConf {
oneof encode {
EncodeRaw raw = 1;
EncodeJpeg jpeg = 2;
EncodeProtobuf protobuf = 3;
}
}
message EncodeProtobuf {
}
message EncodeRaw {
}
......
#include "oneflow/core/record/ofrecord_decoder.h"
#include "oneflow/core/record/ofrecord_raw_decoder.h"
#include "oneflow/core/record/ofrecord_jpeg_decoder.h"
#include "oneflow/core/record/ofrecord_protobuf_decoder.h"
#include "oneflow/core/common/balanced_splitter.h"
#include "oneflow/core/thread/thread_manager.h"
......@@ -56,7 +57,14 @@ void DoScalePreprocess(const ScalePreprocessConf& conf, T* dptr, int64_t n) {
}
template<typename T>
void DoPreprocess(const PreprocessConf& conf, T* dptr, const Shape& shape) {
typename std::enable_if<IsPbType<T>::value>::type DoPreprocess(const PreprocessConf& conf, T* dptr,
const Shape& shape) {
UNIMPLEMENTED();
}
template<typename T>
typename std::enable_if<!IsPbType<T>::value>::type DoPreprocess(const PreprocessConf& conf, T* dptr,
const Shape& shape) {
int64_t n = shape.Count(1);
if (conf.has_subtract_conf()) {
DoSubtractPreprocess<T>(conf.subtract_conf(), dptr, n);
......@@ -188,13 +196,9 @@ void OFRecordDecoder<encode_case, T>::ReadPartDataContent(
OFRecordDecoderIf* GetOFRecordDecoder(EncodeCase encode_case, DataType data_type) {
static const HashMap<std::string, OFRecordDecoderIf*> obj = {
#define MAKE_ENTRY(et, dt) \
{GetHashKey(et, OF_PP_PAIR_SECOND(dt)), new OFRecordDecoderImpl<et, OF_PP_PAIR_FIRST(dt)>},
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_ENTRY, ENCODE_CASE_SEQ, ARITHMETIC_DATA_TYPE_SEQ)
};
OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, ENCODE_CASE_DATA_TYPE_SEQ_PRODUCT)};
return obj.at(GetHashKey(encode_case, data_type));
}
......
#include "oneflow/core/record/ofrecord_encoder.h"
#include "oneflow/core/record/ofrecord_raw_encoder.h"
#include "oneflow/core/record/ofrecord_jpeg_encoder.h"
#include "oneflow/core/record/ofrecord_protobuf_encoder.h"
namespace oneflow {
OFRecordEncoderIf* GetOFRecordEncoder(EncodeCase encode_case, DataType data_type) {
static const HashMap<std::string, OFRecordEncoderIf*> obj = {
#define MAKE_ENTRY(et, dt) \
{GetHashKey(et, OF_PP_PAIR_SECOND(dt)), new OFRecordEncoderImpl<et, OF_PP_PAIR_FIRST(dt)>},
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_ENTRY, ENCODE_CASE_SEQ, ARITHMETIC_DATA_TYPE_SEQ)
};
OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, ENCODE_CASE_DATA_TYPE_SEQ_PRODUCT)};
return obj.at(GetHashKey(encode_case, data_type));
}
......
#include "oneflow/core/record/ofrecord_protobuf_decoder.h"
namespace oneflow {
namespace {
template<typename T>
decltype(std::declval<T>().value()) GetFeatureDataList(const Feature& feature);
#define SPECIALIZE_GET_PB_LIST_DATA_LIST(T, type_proto, data_list) \
template<> \
decltype(std::declval<T>().value()) GetFeatureDataList<T>(const Feature& feature) { \
return feature.data_list(); \
}
OF_PP_FOR_EACH_TUPLE(SPECIALIZE_GET_PB_LIST_DATA_LIST, PB_LIST_DATA_TYPE_PB_LIST_FIELD_SEQ);
} // namespace
template<typename T>
int32_t OFRecordDecoderImpl<EncodeCase::kProtobuf, T>::GetColNumOfFeature(
const Feature& feature, int64_t one_col_elem_num) const {
return 1;
}
template<typename T>
void OFRecordDecoderImpl<EncodeCase::kProtobuf, T>::ReadOneCol(
DeviceCtx* ctx, const Feature& feature, const BlobConf& blob_conf, int32_t col_id, T* out_dptr,
int64_t one_col_elem_num, std::function<int32_t(void)> NextRandomInt) const {
*out_dptr->mutable_value() = GetFeatureDataList<T>(feature);
CheckPbListSize<T>(*out_dptr);
}
#define INSTANTIATE_OFRECORD_PROTOBUF_DECODER(type_cpp, type_proto) \
template class OFRecordDecoderImpl<EncodeCase::kProtobuf, type_cpp>;
OF_PP_FOR_EACH_TUPLE(INSTANTIATE_OFRECORD_PROTOBUF_DECODER, PB_LIST_DATA_TYPE_SEQ);
} // namespace oneflow
#ifndef ONEFLOW_CORE_RECORD_OFRECORD_PROTOBUF_DECODER_H_
#define ONEFLOW_CORE_RECORD_OFRECORD_PROTOBUF_DECODER_H_
#include "oneflow/core/record/ofrecord_decoder.h"
namespace oneflow {
template<typename T>
class OFRecordDecoderImpl<EncodeCase::kProtobuf, T> final
: public OFRecordDecoder<EncodeCase::kProtobuf, T> {
private:
int32_t GetColNumOfFeature(const Feature&, int64_t one_col_elem_num) const override;
void ReadOneCol(DeviceCtx*, const Feature&, const BlobConf& blob_conf, int32_t col_id,
T* out_dptr, int64_t one_col_elem_num,
std::function<int32_t(void)> NextRandomInt) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_RECORD_OFRECORD_PROTOBUF_DECODER_H_
#include "oneflow/core/record/ofrecord_protobuf_encoder.h"
namespace oneflow {
namespace {
template<typename T>
decltype(std::declval<T>().mutable_value()) GetMutFeatureDataList(Feature& feature);
#define SPECIALIZE_GET_MUT_PB_LIST_DATA_LIST(T, type_proto, data_list) \
template<> \
decltype(std::declval<T>().mutable_value()) GetMutFeatureDataList<T>(Feature & feature) { \
return feature.mutable_##data_list(); \
}
OF_PP_FOR_EACH_TUPLE(SPECIALIZE_GET_MUT_PB_LIST_DATA_LIST, PB_LIST_DATA_TYPE_PB_LIST_FIELD_SEQ);
} // namespace
template<typename T>
void OFRecordEncoderImpl<EncodeCase::kProtobuf, T>::EncodeOneCol(
DeviceCtx* ctx, const Blob* in_blob, int64_t in_offset, Feature& feature,
const std::string& field_name, int64_t one_col_elem_num) const {
const T& data = in_blob->dptr<T>()[in_offset];
CheckPbListSize<T>(data);
*GetMutFeatureDataList<T>(feature) = data.value();
}
#define INSTANTIATE_OFRECORD_PROTOBUF_ENCODER(type_cpp, type_proto) \
template class OFRecordEncoderImpl<EncodeCase::kProtobuf, type_cpp>;
OF_PP_FOR_EACH_TUPLE(INSTANTIATE_OFRECORD_PROTOBUF_ENCODER, PB_LIST_DATA_TYPE_SEQ);
} // namespace oneflow
#ifndef ONEFLOW_CORE_RECORD_OFRECORD_PROTOBUF_ENCODER_H_
#define ONEFLOW_CORE_RECORD_OFRECORD_PROTOBUF_ENCODER_H_
#include "oneflow/core/record/ofrecord_encoder.h"
namespace oneflow {
template<typename T>
class OFRecordEncoderImpl<EncodeCase::kProtobuf, T> final : public OFRecordEncoderIf {
private:
void EncodeOneCol(DeviceCtx*, const Blob* in_blob, int64_t in_offset, Feature&,
const std::string& field_name, int64_t one_col_elem_num) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_RECORD_OFRECORD_PROTOBUF_ENCODER_H_
......@@ -7,7 +7,6 @@ namespace oneflow {
template<typename T>
class OFRecordDecoderImpl<EncodeCase::kRaw, T> final : public OFRecordDecoder<EncodeCase::kRaw, T> {
public:
private:
int32_t GetColNumOfFeature(const Feature&, int64_t one_col_elem_num) const override;
void ReadOneCol(DeviceCtx*, const Feature&, const BlobConf& blob_conf, int32_t col_id,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册