// convolution_op_test.cpp
#include "oneflow/core/operator/convolution_op.h"

#include <memory>
#include <string>
#include <vector>

namespace oneflow {

// Builds a convolution op (16 output channels, 20x20 kernel, pad 4, stride 3),
// runs forward shape inference on a {100, 64, 256, 256} input, and checks the
// inferred shapes of the output blob, the data-tmp blob, both model blobs and
// the model-tmp blob.
TEST(ConvolutionOp, TestForInferShape4FwBlobs) {
  // Configure the convolution op. pad / kernel_size / stride each receive two
  // entries (presumably one per spatial dimension, height then width).
  OperatorConf op_conf;
  op_conf.set_name("convolution_test");
  auto* conv_conf = op_conf.mutable_convolution_conf();
  conv_conf->set_in("convolution/in");
  conv_conf->set_out("convolution/out");
  conv_conf->set_out_num(16);
  conv_conf->add_pad(4);
  conv_conf->add_pad(4);
  conv_conf->add_kernel_size(20);
  conv_conf->add_kernel_size(20);
  conv_conf->add_stride(3);
  conv_conf->add_stride(3);
  auto convolution_op = OpMgr::Singleton()->ConstructOp(op_conf);

  // One Shape per blob name the op touches. Only the input shape is set up
  // front; the others start default-constructed and are expected to be
  // filled in by InferShape4FwBlobs. unique_ptr owns every Shape, so nothing
  // leaks, including when an ASSERT_* below fails and returns early.
  std::vector<int64_t> input_vec = {100, 64, 256, 256};
  HashMap<std::string, std::unique_ptr<Shape>> bn2shape_ptr;
  bn2shape_ptr[convolution_op->SoleIbn()].reset(new Shape(input_vec));
  bn2shape_ptr[convolution_op->SoleObn()].reset(new Shape);
  bn2shape_ptr[convolution_op->data_tmp_bns().at(0)].reset(new Shape);
  bn2shape_ptr[convolution_op->model_bns().at(0)].reset(new Shape);
  bn2shape_ptr[convolution_op->model_bns().at(1)].reset(new Shape);
  bn2shape_ptr[convolution_op->model_tmp_bns().at(0)].reset(new Shape);

  // Blob-name -> Shape* lookup handed to the op. at() throws on an unknown
  // name, so an unexpected blob name surfaces as a test failure.
  auto fp = [&bn2shape_ptr](const std::string& bn) -> Shape* {
    return bn2shape_ptr.at(bn).get();
  };

  // Run forward shape inference. NOTE(review): the trailing (0, 1) arguments
  // look like parallel_id = 0 of parallel_num = 1 — confirm against
  // Operator::InferShape4FwBlobs.
  convolution_op->InferShape4FwBlobs(fp, kDataParallel, 0, 1);

  Shape* output_shape_ptr = fp(convolution_op->SoleObn());
  Shape* colbuf_shape_ptr = fp(convolution_op->data_tmp_bns().at(0));
  Shape* weight_shape_ptr = fp(convolution_op->model_bns().at(0));
  Shape* bias_shape_ptr = fp(convolution_op->model_bns().at(1));
  Shape* biasmult_shape_ptr = fp(convolution_op->model_tmp_bns().at(0));

  // Expected spatial output extent: (256 + 2 * 4 - 20) / 3 + 1 = 82.
  ASSERT_EQ(*output_shape_ptr, Shape({100, 16, 82, 82}));
  // Data-tmp buffer: 82 * 82 output positions by 64 * 20 * 20 input values
  // per position, for each of the 100 inputs.
  ASSERT_EQ(*colbuf_shape_ptr, Shape({100, 82 * 82, 64 * 20 * 20}));
  ASSERT_EQ(*weight_shape_ptr, Shape({16, 64 * 20 * 20}));
  ASSERT_EQ(*bias_shape_ptr, Shape({16}));
  ASSERT_EQ(*biasmult_shape_ptr, Shape({82 * 82}));
}

}  // namespace oneflow