提交 52a6c519 编写于 作者: S scxfjiang 提交者: Li Xinqi

refine CHECK in AllReduce (#1618)

* refine CHECK in AllReduce

* move ReduceConcatOpCtx definition to .cpp file


Former-commit-id: 5a50f692cb92c5a6a7074be2063cbc1ec325c1ca
上级 d72a21e2
......@@ -4,6 +4,11 @@
namespace oneflow {
// Op context that carries the output blob's element count from
// ReduceConcatOp::InferBlobDescs (where it is constructed and enrolled) to
// ReduceConcatOp::VirtualGenKernelConf (where it is compared against the
// element count recomputed from the accumulated input offset).
struct ReduceConcatOpCtx : public OpContext {
  // `explicit` prevents a bare int64_t from being silently converted into an
  // op context; construction must always be spelled out.
  explicit ReduceConcatOpCtx(const int64_t elem_cnt) : out_blob_elem_cnt(elem_cnt) {}

  // Number of elements in the output blob, as computed in InferBlobDescs.
  int64_t out_blob_elem_cnt;
};
void ReduceConcatOp::InitFromOpConf() {
CHECK(op_conf().has_reduce_concat_conf());
for (int32_t i = 0; i < op_conf().reduce_concat_conf().in_num(); ++i) {
......@@ -17,7 +22,8 @@ const PbMessage& ReduceConcatOp::GetCustomizedConf() const {
}
void ReduceConcatOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
const ParallelContext* parallel_ctx,
std::function<void(OpContext*)> EnrollOpCtx) const {
const BlobDesc* first_in_blob = GetBlobDesc4BnInOp(input_bns().Get(0));
const DataType data_type = first_in_blob->data_type();
for (int32_t i = 1; i < op_conf().reduce_concat_conf().in_num(); ++i) {
......@@ -37,11 +43,15 @@ void ReduceConcatOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)>
const int64_t out_blob_elem_cnt =
RoundUp(in_blob_body_size_sum / data_type_byte_size, parallel_ctx->parallel_num());
out_blob->mut_shape() = Shape({out_blob_elem_cnt});
// construct reduce_concat_op_ctx for later CHECK in ReduceConcatOp::VirtualGenKernelConf
ReduceConcatOpCtx* reduce_concat_op_ctx = new ReduceConcatOpCtx(out_blob_elem_cnt);
EnrollOpCtx(reduce_concat_op_ctx);
}
void ReduceConcatOp::VirtualGenKernelConf(
std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const {
const ParallelContext* parallel_ctx, KernelConf* kernel_conf, const OpContext* op_ctx) const {
ReduceConcatKernelConf* reduce_concat_conf = kernel_conf->mutable_reduce_concat_conf();
int64_t offset = 0;
for (int32_t i = 0; i < op_conf().reduce_concat_conf().in_num(); ++i) {
......@@ -50,8 +60,11 @@ void ReduceConcatOp::VirtualGenKernelConf(
}
const int64_t data_type_byte_size =
static_cast<int64_t>(GetSizeOfDataType(GetBlobDesc4BnInOp(input_bns().Get(0))->data_type()));
CHECK_EQ(RoundUp(offset, parallel_ctx->parallel_num() * data_type_byte_size),
RtBlobDesc(*GetBlobDesc4BnInOp(SoleObn())).ByteSizeOfBlobBody());
CHECK_EQ(offset % data_type_byte_size, 0);
const int64_t out_blob_elem_cnt =
RoundUp(offset / data_type_byte_size, parallel_ctx->parallel_num());
const ReduceConcatOpCtx* reduce_concat_op_ctx = static_cast<const ReduceConcatOpCtx*>(op_ctx);
CHECK_EQ(reduce_concat_op_ctx->out_blob_elem_cnt, out_blob_elem_cnt);
}
LogicalBlobId ReduceConcatOp::obn2lbi(const std::string& output_bn) const {
......
......@@ -15,11 +15,13 @@ class ReduceConcatOp final : public Operator {
const PbMessage& GetCustomizedConf() const override;
void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
const ParallelContext* parallel_ctx,
std::function<void(OpContext*)> EnrollOpCtx) const override;
private:
void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext*, KernelConf*) const override;
const ParallelContext*, KernelConf*,
const OpContext* op_ctx) const override;
LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); }
LogicalBlobId obn2lbi(const std::string& output_bn) const override;
};
......
......@@ -25,8 +25,11 @@ void ReduceSplitOp::VirtualGenKernelConf(
}
const int64_t data_type_byte_size =
static_cast<int64_t>(GetSizeOfDataType(GetBlobDesc4BnInOp(SoleIbn())->data_type()));
CHECK_EQ(RoundUp(offset, parallel_ctx->parallel_num() * data_type_byte_size),
RtBlobDesc(*GetBlobDesc4BnInOp(SoleIbn())).ByteSizeOfBlobBody());
CHECK_EQ(offset % data_type_byte_size, 0);
const int64_t out_blob_elem_cnt_sum =
RoundUp(offset / data_type_byte_size, parallel_ctx->parallel_num());
const int64_t in_blob_elem_cnt = GetBlobDesc4BnInOp(SoleIbn())->shape().elem_cnt();
CHECK_EQ(out_blob_elem_cnt_sum, in_blob_elem_cnt);
}
REGISTER_OP(OperatorConf::kReduceSplitConf, ReduceSplitOp);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册