提交 0c19fdbc 编写于 作者: D Daniel Sun 提交者: Will Zhang

add buffer blob_desc in dataloader op (#498)

* add dataloader buffer blob

* debug

* change the comment

* add col_num and max_col_num

* change the bug
上级 160367a8
......@@ -7,6 +7,9 @@ void BasicDataLoaderOp::InitFromOpConf() {
CHECK(op_conf().has_basic_data_loader_conf());
EnrollOutputBn("out", false);
if (op_conf().basic_data_loader_conf().max_sequence_size() > 1) {
EnrollDataTmpBn("buffer");
}
}
const PbMessage& BasicDataLoaderOp::GetSpecialConf() const {
......@@ -17,6 +20,8 @@ void BasicDataLoaderOp::InferBlobDescs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
const BasicDataLoaderOpConf& conf = op_conf().basic_data_loader_conf();
// out
BlobDesc* out = GetBlobDesc4BnInOp("out");
std::vector<int64_t> dim_vec(1 + conf.shape().dim_size());
dim_vec[0] = JobDesc::Singleton()->SinglePieceSize();
......@@ -26,6 +31,19 @@ void BasicDataLoaderOp::InferBlobDescs(
out->mut_shape() = Shape(dim_vec);
out->set_data_type(conf.data_type());
out->set_has_data_id_field(JobDesc::Singleton()->SizeOfOneDataId() > 0);
out->set_has_col_num_field(conf.max_sequence_size() > 1);
out->set_max_col_num(conf.max_sequence_size());
if (conf.max_sequence_size() > 1) {
// buffer
BlobDesc* buffer = GetBlobDesc4BnInOp("buffer");
dim_vec.insert(dim_vec.begin() + 1, conf.max_sequence_size());
buffer->mut_shape() = Shape(dim_vec);
buffer->set_data_type(conf.data_type());
buffer->set_has_data_id_field(JobDesc::Singleton()->SizeOfOneDataId() > 0);
buffer->set_has_col_num_field(true);
buffer->set_max_col_num(conf.max_sequence_size());
}
}
REGISTER_OP(OperatorConf::kBasicDataLoaderConf, BasicDataLoaderOp);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册