提交 42795ac5 编写于 作者: G grinchcoder 提交者: Will Zhang

Remove middle blobs in boxing_op if it is not a concat-clone box (#175)

boxing op need midlle blob
上级 df6af422
......@@ -10,7 +10,10 @@ void BoxingOp::InitFromOpConf(const OperatorConf& op_conf) {
for (int64_t i = 0; i < op_conf.boxing_conf().in_num(); ++i) {
EnrollInputBn("in_" + std::to_string(i));
}
EnrollDataTmpBn("middle");
if (op_conf.boxing_conf().in_box_case() == BoxingOpConf::kConcatBox
&& op_conf.boxing_conf().out_box_case() == BoxingOpConf::kCloneBox) {
EnrollDataTmpBn("middle");
}
for (int64_t i = 0; i < op_conf.boxing_conf().out_num(); ++i) {
EnrollOutputBn("out_" + std::to_string(i));
}
......@@ -36,7 +39,7 @@ void BoxingOp::InferShape4FwBlobs(
std::vector<int64_t> data_tmp_blob_shape_vec =
GetShapePtr4BnInOp(input_bns().at(0))->dim_vec();
// if it is a concat-box, concat input blob shape to middle blob shape
// if it is a concat-box, accumulate the dimensions on concat-axis.
// otherwise only check all boxes are in the same shape.
int32_t concat_axis = 0;
if (in_box_case == BoxingOpConf::kConcatBox) {
......@@ -53,8 +56,13 @@ void BoxingOp::InferShape4FwBlobs(
}
}
}
*GetShapePtr4BnInOp(SoleDtbn()) = Shape(data_tmp_blob_shape_vec);
auto out_box_case = boxing_conf.out_box_case();
// Although the shape of data_tmp is caculated in all kinds of concat boxes,
// it is stored back if and only if this is a concat-clone box
if (in_box_case == BoxingOpConf::kConcatBox
&& out_box_case == BoxingOpConf::kCloneBox) {
*GetShapePtr4BnInOp(SoleDtbn()) = Shape(data_tmp_blob_shape_vec);
}
CHECK_NE(out_box_case, BoxingOpConf::OUT_BOX_NOT_SET);
if (out_box_case == BoxingOpConf::kDataSplitBox) {
int32_t out_num = output_bns().size();
......
......@@ -26,7 +26,7 @@ TEST(BoxingOp, box_4_10x5x6x6) {
{boxing_op->input_bns()[1], new Shape(input_shape_vec2)},
{boxing_op->input_bns()[2], new Shape(input_shape_vec2)},
{boxing_op->input_bns()[3], new Shape(input_shape_vec1)},
{boxing_op->SoleDtbn(), new Shape},
{"middle", new Shape},
{boxing_op->output_bns()[0], new Shape},
{boxing_op->output_bns()[1], new Shape},
{boxing_op->output_bns()[2], new Shape},
......@@ -39,10 +39,6 @@ TEST(BoxingOp, box_4_10x5x6x6) {
boxing_op->InferShape4FwBlobs(fp, kModelParallel, 0, 1);
// test results
// data_tmp_shape is {10, 17, 6, 6}, and the 17 = 5 + 4 + 4 + 4
Shape* data_tmp_shape_ptr = bn2shape_ptr.at(boxing_op->SoleDtbn());
std::vector<int64_t> data_temp_shape_vec = {10, 17, 6, 6};
ASSERT_EQ(*data_tmp_shape_ptr, Shape(data_temp_shape_vec));
// output_shape should be:
// out1 {4, 17, 6, 6}
// out2 {3, 17, 6, 6}
......@@ -54,7 +50,7 @@ TEST(BoxingOp, box_4_10x5x6x6) {
ASSERT_EQ(*output_shape_ptr, Shape(output_shape_vec));
}
// Test add_box shape function
// Test add clone box shape function
boxing_conf->set_in_num(3);
boxing_conf->set_out_num(1);
boxing_conf->mutable_add_box();
......@@ -66,9 +62,33 @@ TEST(BoxingOp, box_4_10x5x6x6) {
// test results
// output shape should be the same as input
for (const std::string& bn : boxing_op->output_bns()) {
Shape* output_shape_ptr = bn2shape_ptr.at(bn);
ASSERT_EQ(*output_shape_ptr, Shape(input_shape_vec2));
}
// Test concat clone shape function, this box has data_tmp_shape
boxing_conf->set_in_num(4);
boxing_conf->set_out_num(1);
boxing_conf->mutable_concat_box()->set_axis(1);
boxing_conf->mutable_clone_box();
boxing_op = OpMgr::Singleton()->ConstructOp(op_conf);
// do infer shape
boxing_op->InferShape4FwBlobs(fp, kModelParallel, 0, 1);
// data_tmp_shape is {10, 17, 6, 6}, and the 17 = 4 + 4 + 4 + 5
Shape* data_tmp_shape_ptr = bn2shape_ptr.at(boxing_op->SoleDtbn());
std::vector<int64_t> data_temp_shape_vec = {10, 17, 6, 6};
ASSERT_EQ(*data_tmp_shape_ptr, Shape(data_temp_shape_vec));
// test results
// output shape should be the same as data_tmp_shape
data_tmp_shape_ptr = bn2shape_ptr.at(boxing_op->SoleDtbn());
Shape* output_shape_ptr = bn2shape_ptr.at(boxing_op->output_bns()[0]);
ASSERT_EQ(*output_shape_ptr, Shape(input_shape_vec2));
for (const std::string& bn : boxing_op->output_bns()) {
Shape* output_shape_ptr = bn2shape_ptr.at(bn);
ASSERT_EQ(*output_shape_ptr, *data_tmp_shape_ptr);
}
}
} // namespace oneflow
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册