#include "oneflow/core/operator/convolution_op.h"
#include "oneflow/core/common/balanced_splitter.h"

namespace oneflow {

void ConvolutionOp::InitFromOpConf() {
  CHECK(op_conf().has_convolution_conf());
W
willzhang4a58 已提交
8

W
willzhang4a58 已提交
9 10 11
  EnrollInputBn("in");
  EnrollOutputBn("out");
  EnrollDataTmpBn("col_buf");
W
willzhang4a58 已提交
12

W
willzhang4a58 已提交
13
  EnrollModelBn("weight");
14 15 16 17
  if (GetBoolFromSpecialConf("has_bias_term")) {
    EnrollModelBn("bias");
    EnrollModelTmpBn("bias_multiplier");
  }
W
willzhang4a58 已提交
18 19
}

// Returns the convolution-specific sub-message of this operator's conf.
// This is presumably the message that the Get*FromSpecialConf helpers (used
// for "out_num" in this file) read fields from; confirm against Operator.
const PbMessage& ConvolutionOp::GetSpecialConf() const {
  return op_conf().convolution_conf();
}

void ConvolutionOp::InferBlobDescs(
W
willzhang4a58 已提交
25
    std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
W
Will Zhang 已提交
26
    const ParallelContext* parallel_ctx) {
W
Will Zhang 已提交
27 28 29
  const ConvolutionOpConf& conf = op_conf().convolution_conf();
  const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp(SoleIbn());
  CHECK_EQ(in_blob_desc->shape().NumAxes(), 4);
duduscript's avatar
duduscript 已提交
30 31
  CHECK_EQ(in_blob_desc->data_type(),
           JobDesc::Singleton()->default_data_type());
W
Will Zhang 已提交
32 33
  int64_t data_num = in_blob_desc->shape().At(0);
  int64_t c_i = in_blob_desc->shape().At(1);
34 35

  int32_t out_num = GetInt32FromSpecialConf("out_num");
W
Will Zhang 已提交
36 37 38
  if (parallel_ctx->policy() == kModelParallel) {
    BalancedSplitter splitter(out_num, parallel_ctx->parallel_num());
    out_num = splitter.At(parallel_ctx->parallel_id()).size();
39 40 41
  }
  int64_t c_o = out_num;

W
willzhang4a58 已提交
42
  int64_t h_len =
W
proto2  
willzhang4a58 已提交
43
      (in_blob_desc->shape().At(2) + 2 * conf.pad_h() - conf.kernel_h())
W
Will Zhang 已提交
44
          / conf.stride_h()
W
willzhang4a58 已提交
45 46
      + 1;
  int64_t w_len =
W
proto2  
willzhang4a58 已提交
47
      (in_blob_desc->shape().At(3) + 2 * conf.pad_w() - conf.kernel_w())
W
Will Zhang 已提交
48
          / conf.stride_w()
W
willzhang4a58 已提交
49
      + 1;
W
Will Zhang 已提交
50
  int64_t output_size = h_len * w_len;
W
proto2  
willzhang4a58 已提交
51
  int64_t kernel = conf.kernel_h() * conf.kernel_w();
52

W
Will Zhang 已提交
53 54 55
  // out
  BlobDesc* out_blob_desc = GetBlobDesc4BnInOp(SoleObn());
  out_blob_desc->mut_shape() = Shape({data_num, c_o, h_len, w_len});
duduscript's avatar
duduscript 已提交
56
  out_blob_desc->set_data_type(JobDesc::Singleton()->default_data_type());
W
Will Zhang 已提交
57
  out_blob_desc->set_has_data_id(in_blob_desc->has_data_id());
58

W
Will Zhang 已提交
59 60
  // col_buf
  BlobDesc* col_buf_blob_desc = GetBlobDesc4BnInOp("col_buf");
W
proto2  
willzhang4a58 已提交
61
  col_buf_blob_desc->mut_shape() = Shape({data_num, output_size, c_i * kernel});
W
Will Zhang 已提交
62 63 64 65 66
  col_buf_blob_desc->set_data_type(JobDesc::Singleton()->default_data_type());
  col_buf_blob_desc->set_has_data_id(false);

  // weight
  BlobDesc* weight_blob_desc = GetBlobDesc4BnInOp("weight");
W
proto2  
willzhang4a58 已提交
67
  weight_blob_desc->mut_shape() = Shape({c_o, c_i * kernel});
W
Will Zhang 已提交
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
  weight_blob_desc->set_data_type(JobDesc::Singleton()->default_data_type());
  weight_blob_desc->set_has_data_id(false);

  if (conf.has_bias_term()) {
    // bias
    BlobDesc* bias_blob_desc = GetBlobDesc4BnInOp("bias");
    bias_blob_desc->mut_shape() = Shape({c_o});
    bias_blob_desc->set_data_type(JobDesc::Singleton()->default_data_type());
    bias_blob_desc->set_has_data_id(false);

    // bias multiplier
    BlobDesc* bias_multiplier_blob_desc = GetBlobDesc4BnInOp("bias_multiplier");
    bias_multiplier_blob_desc->mut_shape() = Shape({output_size});
    bias_multiplier_blob_desc->set_data_type(
        JobDesc::Singleton()->default_data_type());
    bias_multiplier_blob_desc->set_has_data_id(false);
84
  }
85 86
}

// Registers ConvolutionOp under the OperatorConf::kConvolutionConf case.
// Presumably this lets the operator factory construct a ConvolutionOp for any
// OperatorConf that sets convolution_conf; confirm against REGISTER_OP's
// definition.
REGISTER_OP(OperatorConf::kConvolutionConf, ConvolutionOp);

}  // namespace oneflow