convolution_op_test.cpp 3.2 KB
Newer Older
J
jiyuan 已提交
1
#include "oneflow/core/operator/convolution_op.h"
2 3 4

namespace oneflow {

5 6
namespace {

W
Will Zhang 已提交
7
std::shared_ptr<Operator> GetTestConvolutionOp() {
8
  OperatorConf op_conf;
W
Will Zhang 已提交
9
  op_conf.set_name("convolution_test");
W
willzhang4a58 已提交
10 11
  op_conf.mutable_convolution_conf()->set_in("convolution/in");
  op_conf.mutable_convolution_conf()->set_out("convolution/out");
W
Will Zhang 已提交
12 13 14 15 16 17 18 19
  op_conf.mutable_convolution_conf()->set_has_bias_term(true);
  op_conf.mutable_convolution_conf()->set_out_num(16);
  op_conf.mutable_convolution_conf()->set_pad_h(4);
  op_conf.mutable_convolution_conf()->set_pad_w(4);
  op_conf.mutable_convolution_conf()->set_kernel_size_h(20);
  op_conf.mutable_convolution_conf()->set_kernel_size_w(20);
  op_conf.mutable_convolution_conf()->set_stride_h(3);
  op_conf.mutable_convolution_conf()->set_stride_w(3);
W
willzhang4a58 已提交
20
  auto convolution_op = ConstructOp(op_conf);
W
Will Zhang 已提交
21 22 23 24
  JobConf job_conf;
  job_conf.set_default_data_type(DataType::kFloat);
  JobDesc::Singleton()->InitFromJobConf(job_conf);
  return convolution_op;
25 26 27 28 29
}

}  // namespace

// Data-parallel inference on a single device (parallel_id 0 of 1).
//
// The whole model lives on this device, so "out", "weight" and "bias" carry
// all out_num = 16 channels. The expected spatial extent 82 matches
// (256 + 2 * 4 - 20) / 3 + 1 for the 256x256 input, padding 4, kernel 20 and
// stride 3 configured by GetTestConvolutionOp().
TEST(ConvolutionOp, dataparallel_convolution) {
  auto convolution_op = GetTestConvolutionOp();

  // The blob descriptors are owned by this scope, so they are released when
  // the test ends. The map below only holds non-owning pointers to them.
  BlobDesc in_blob_desc(Shape({100, 64, 256, 256}), DataType::kFloat, false);
  BlobDesc out_blob_desc;
  BlobDesc col_buf_blob_desc;
  BlobDesc weight_blob_desc;
  BlobDesc bias_blob_desc;
  BlobDesc bias_multiplier_blob_desc;
  HashMap<std::string, BlobDesc*> bn2blob_desc_map{
      {"in", &in_blob_desc},
      {"out", &out_blob_desc},
      {"col_buf", &col_buf_blob_desc},
      {"weight", &weight_blob_desc},
      {"bias", &bias_blob_desc},
      {"bias_multiplier", &bias_multiplier_blob_desc}};
  auto Bn2BlobDescFunc = [&bn2blob_desc_map](const std::string& bn) {
    return bn2blob_desc_map.at(bn);
  };

  convolution_op->InferBlobDescs(Bn2BlobDescFunc, kDataParallel, 0, 1);

  ASSERT_EQ(*Bn2BlobDescFunc("out"),
            BlobDesc(Shape({100, 16, 82, 82}), DataType::kFloat, false));
  ASSERT_EQ(
      *Bn2BlobDescFunc("col_buf"),
      BlobDesc(Shape({100, 82 * 82, 64 * 20 * 20}), DataType::kFloat, false));
  ASSERT_EQ(*Bn2BlobDescFunc("weight"),
            BlobDesc(Shape({16, 64 * 20 * 20}), DataType::kFloat, false));
  ASSERT_EQ(*Bn2BlobDescFunc("bias"),
            BlobDesc(Shape({16}), DataType::kFloat, false));
  ASSERT_EQ(*Bn2BlobDescFunc("bias_multiplier"),
            BlobDesc(Shape({82 * 82}), DataType::kFloat, false));
}

// Model-parallel inference for parallel_id 3 of 8 devices.
//
// The expected "out", "weight" and "bias" shapes hold 2 of the 16 (out_num)
// channels, i.e. one eighth of the model per device. "col_buf" and
// "bias_multiplier" keep the same shapes as in the data-parallel case.
TEST(ConvolutionOp, modelparallel_convolution) {
  auto convolution_op = GetTestConvolutionOp();

  // The blob descriptors are owned by this scope, so they are released when
  // the test ends. The map below only holds non-owning pointers to them.
  BlobDesc in_blob_desc(Shape({100, 64, 256, 256}), DataType::kFloat, false);
  BlobDesc out_blob_desc;
  BlobDesc col_buf_blob_desc;
  BlobDesc weight_blob_desc;
  BlobDesc bias_blob_desc;
  BlobDesc bias_multiplier_blob_desc;
  HashMap<std::string, BlobDesc*> bn2blob_desc_map{
      {"in", &in_blob_desc},
      {"out", &out_blob_desc},
      {"col_buf", &col_buf_blob_desc},
      {"weight", &weight_blob_desc},
      {"bias", &bias_blob_desc},
      {"bias_multiplier", &bias_multiplier_blob_desc}};
  auto Bn2BlobDescFunc = [&bn2blob_desc_map](const std::string& bn) {
    return bn2blob_desc_map.at(bn);
  };

  convolution_op->InferBlobDescs(Bn2BlobDescFunc, kModelParallel, 3, 8);

  ASSERT_EQ(*Bn2BlobDescFunc("out"),
            BlobDesc(Shape({100, 2, 82, 82}), DataType::kFloat, false));
  ASSERT_EQ(
      *Bn2BlobDescFunc("col_buf"),
      BlobDesc(Shape({100, 82 * 82, 64 * 20 * 20}), DataType::kFloat, false));
  ASSERT_EQ(*Bn2BlobDescFunc("weight"),
            BlobDesc(Shape({2, 64 * 20 * 20}), DataType::kFloat, false));
  ASSERT_EQ(*Bn2BlobDescFunc("bias"),
            BlobDesc(Shape({2}), DataType::kFloat, false));
  ASSERT_EQ(*Bn2BlobDescFunc("bias_multiplier"),
            BlobDesc(Shape({82 * 82}), DataType::kFloat, false));
}

W
willzhang4a58 已提交
81
}  // namespace oneflow