提交 c5310e4c 编写于 作者: S scxfjiang 提交者: Li Xinqi

Fix jxf reduce concat bug (#1606)

* refine the logic for inferring the elem_cnt of reduce_concat_op's out blob; still has bugs...

* add RoundUp in reduce_concat

* CHECK_LE -> CHECK_EQ

* add CHECK


Former-commit-id: 962817e2a322ba6452c9966bae87fb5da9d4a86a
上级 f9bab665
...@@ -18,27 +18,40 @@ const PbMessage& ReduceConcatOp::GetCustomizedConf() const { ...@@ -18,27 +18,40 @@ const PbMessage& ReduceConcatOp::GetCustomizedConf() const {
void ReduceConcatOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, void ReduceConcatOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const { const ParallelContext* parallel_ctx) const {
int32_t in_num = op_conf().reduce_concat_conf().in_num(); const BlobDesc* first_in_blob = GetBlobDesc4BnInOp(input_bns().Get(0));
BlobDesc* first_in_blob = GetBlobDesc4BnInOp(input_bns().Get(0)); const DataType data_type = first_in_blob->data_type();
for (int32_t i = 1; i < op_conf().reduce_concat_conf().in_num(); ++i) {
CHECK_EQ(data_type, GetBlobDesc4BnInOp(input_bns().Get(i))->data_type());
}
BlobDesc* out_blob = GetBlobDesc4BnInOp(SoleObn()); BlobDesc* out_blob = GetBlobDesc4BnInOp(SoleObn());
*out_blob = *first_in_blob; *out_blob = *first_in_blob;
int64_t out_blob_elem_cnt = first_in_blob->shape().elem_cnt(); int64_t in_blob_body_size_sum = 0;
for (int32_t i = 1; i < in_num; ++i) { for (int32_t i = 0; i < op_conf().reduce_concat_conf().in_num(); ++i) {
out_blob_elem_cnt += GetBlobDesc4BnInOp(input_bns().Get(i))->shape().elem_cnt(); in_blob_body_size_sum +=
RtBlobDesc(*(GetBlobDesc4BnInOp(input_bns().Get(i)))).ByteSizeOfBlobBody();
} }
const int64_t data_type_byte_size =
static_cast<int64_t>(GetSizeOfDataType(first_in_blob->data_type()));
CHECK_EQ(in_blob_body_size_sum % data_type_byte_size, 0);
const int64_t out_blob_elem_cnt =
RoundUp(in_blob_body_size_sum / data_type_byte_size, parallel_ctx->parallel_num());
out_blob->mut_shape() = Shape({out_blob_elem_cnt}); out_blob->mut_shape() = Shape({out_blob_elem_cnt});
} }
void ReduceConcatOp::VirtualGenKernelConf( void ReduceConcatOp::VirtualGenKernelConf(
std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*, std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
KernelConf* kernel_conf) const { const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const {
ReduceConcatKernelConf* reduce_concat_conf = kernel_conf->mutable_reduce_concat_conf(); ReduceConcatKernelConf* reduce_concat_conf = kernel_conf->mutable_reduce_concat_conf();
int64_t offset = 0; int64_t offset = 0;
for (int32_t i = 0; i < op_conf().reduce_concat_conf().in_num(); ++i) { for (int32_t i = 0; i < op_conf().reduce_concat_conf().in_num(); ++i) {
reduce_concat_conf->mutable_data_offset()->Add(offset); reduce_concat_conf->mutable_data_offset()->Add(offset);
offset += RtBlobDesc(*(GetBlobDesc4BnInOp(input_bns().Get(i)))).ByteSizeOfBlobBody(); offset += RtBlobDesc(*(GetBlobDesc4BnInOp(input_bns().Get(i)))).ByteSizeOfBlobBody();
} }
CHECK_EQ(offset, RtBlobDesc(*GetBlobDesc4BnInOp(SoleObn())).ByteSizeOfBlobBody()); const int64_t data_type_byte_size =
static_cast<int64_t>(GetSizeOfDataType(GetBlobDesc4BnInOp(input_bns().Get(0))->data_type()));
CHECK_EQ(RoundUp(offset, parallel_ctx->parallel_num() * data_type_byte_size),
RtBlobDesc(*GetBlobDesc4BnInOp(SoleObn())).ByteSizeOfBlobBody());
} }
LogicalBlobId ReduceConcatOp::obn2lbi(const std::string& output_bn) const { LogicalBlobId ReduceConcatOp::obn2lbi(const std::string& output_bn) const {
......
...@@ -15,15 +15,18 @@ void ReduceSplitOp::InitFromOpConf() { ...@@ -15,15 +15,18 @@ void ReduceSplitOp::InitFromOpConf() {
const PbMessage& ReduceSplitOp::GetCustomizedConf() const { return op_conf().reduce_split_conf(); } const PbMessage& ReduceSplitOp::GetCustomizedConf() const { return op_conf().reduce_split_conf(); }
void ReduceSplitOp::VirtualGenKernelConf( void ReduceSplitOp::VirtualGenKernelConf(
std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*, std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
KernelConf* kernel_conf) const { const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const {
ReduceSplitKernelConf* reduce_split_conf = kernel_conf->mutable_reduce_split_conf(); ReduceSplitKernelConf* reduce_split_conf = kernel_conf->mutable_reduce_split_conf();
int64_t offset = 0; int64_t offset = 0;
for (int32_t i = 0; i < op_conf().reduce_split_conf().out_num(); ++i) { for (int32_t i = 0; i < op_conf().reduce_split_conf().out_num(); ++i) {
reduce_split_conf->mutable_data_offset()->Add(offset); reduce_split_conf->mutable_data_offset()->Add(offset);
offset += RtBlobDesc(*(GetBlobDesc4BnInOp(output_bns().Get(i)))).ByteSizeOfBlobBody(); offset += RtBlobDesc(*(GetBlobDesc4BnInOp(output_bns().Get(i)))).ByteSizeOfBlobBody();
} }
CHECK_EQ(offset, RtBlobDesc(*GetBlobDesc4BnInOp(SoleIbn())).ByteSizeOfBlobBody()); const int64_t data_type_byte_size =
static_cast<int64_t>(GetSizeOfDataType(GetBlobDesc4BnInOp(SoleIbn())->data_type()));
CHECK_EQ(RoundUp(offset, parallel_ctx->parallel_num() * data_type_byte_size),
RtBlobDesc(*GetBlobDesc4BnInOp(SoleIbn())).ByteSizeOfBlobBody());
} }
REGISTER_OP(OperatorConf::kReduceSplitConf, ReduceSplitOp); REGISTER_OP(OperatorConf::kReduceSplitConf, ReduceSplitOp);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册