/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/generator/get_expected_kernel_func.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/kernels/cpu/conv_util.h"

namespace paddle {
namespace operators {

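// Computes the size of one spatial output dimension of the convolution,
// allowing asymmetric padding (padding_1 before and padding_2 after the
// input extent). For example, with input_size = 5, filter_size = 3,
// dilation = 1, paddings = (1, 1) and stride = 2, the result is
// (5 + 1 + 1 - 3) / 2 + 1 = 3.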
inline int ConvOutputSize(int input_size,
                          int filter_size,
                          int dilation,
                          int padding_1,
                          int padding_2,
                          int stride) {
  const int dkernel = dilation * (filter_size - 1) + 1;
  int output_size = (input_size + padding_1 + padding_2 - dkernel) / stride + 1;
  PADDLE_ENFORCE_GT(
      output_size,
      0,
      platform::errors::InvalidArgument(
          "The output's size is expected to be greater than 0. "
          "But received: output's size is %d. The output's size is computed by "
          "((input_size + padding_1 + padding_2 - (dilation * (filter_size - "
          "1) + 1)) / stride + 1), where input_size is %d, padding is "
          "(%d, %d), filter_size is %d, dilation is %d, stride is %d.",
          output_size,
          input_size,
          padding_1,
          padding_2,
          filter_size,
          dilation,
          stride));

  return output_size;
}

// This fused conv follows the equation:
//   y = act ( alpha1 * conv(x) + alpha2 * z + bias ).
//   Here, y is Output,
//         x is Input,
//         z is ResidualData,
//         bias is Bias.
// When `split_channels` is set, y will be split into multiple outputs,
// each with split_channels[i] channels.
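// For example (hypothetical values): with split_channels = {8, 8} and an
// NCHW output of shape [N, 16, H, W], Outputs[0] and Outputs[1] each have
// shape [N, 8, H, W].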
class Conv2DFusionOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Input", "(Tensor), input 0 of conv2d op.");
    AddInput("Filter", "(Tensor), input 1 of conv2d op.");
    AddOutput("Output", "(Tensor), output 0 of conv2d op.");
    AddAttr<std::vector<int>>("strides",
                              "(std::vector<int>), attribute 0 for conv2d op.")
        .SetDefault({1, 1});
    AddAttr<std::vector<int>>("paddings",
                              "(std::vector<int>), attribute 1 for conv2d op.")
        .SetDefault({0, 0});
    AddAttr<std::string>("padding_algorithm",
                         "(std::string), attribute 2 for conv2d op.")
        .SetDefault("EXPLICIT");
    AddAttr<std::vector<int>>("dilations",
                              "(std::vector<int>), attribute 3 for conv2d op.")
        .SetDefault({1, 1});
    AddAttr<int>("groups", "(int), attribute 4 for conv2d op.").SetDefault(1);
    AddAttr<std::string>("data_format",
                         "(std::string), attribute 5 for conv2d op.")
        .SetDefault("NCHW");
    AddComment(R"DOC(
Conv2DFusion Operator.

Computes y = act(alpha1 * conv(x) + alpha2 * z + bias), where x is Input,
z is ResidualData, bias is Bias, and y is Output. When `split_channels` is
set, y is further split along the channel axis (see Outputs).
)DOC");
    Apply();
  }

 protected:
  void Apply() {
    AddInput("Bias",
             "(Tensor) Bias to be added to each output of filter application. "
             "The format of the bias tensor is X (one-dimensional), with size "
             "equal to the number of output channels. Only used with MKL-DNN.")
        .AsDispensable();
    AddInput("ResidualData",
             "(Tensor) Tensor with residual data "
             "to which the convolution output will be added. "
             "Used with the fuse_residual_connection fusion.")
        .AsDispensable();
    AddAttr<std::string>(
        "activation",
        "The activation type can be 'identity', 'sigmoid', 'relu', 'relu6', "
        "'relux', 'tanh', or 'band_pass'.")
        .SetDefault("relu");
    AddAttr<std::vector<int>>(
        "split_channels",
        "When `split_channels` is set, there will be multiple outputs; the "
        "number of outputs equals the length of `split_channels`.")
        .SetDefault({});
    AddOutput("Outputs",
              "These outputs are used when `split_channels` is set. "
              "Usually used to fuse convolutions that share the same input "
              "and the same filter size, padding, stride, and dilation.")
        .AsDuplicable()
        .AsDispensable();
    AddInput("AlgoCache",
             "The cache of convolution algorithm, a RAW type variable.")
        .AsDispensable();
    AddAttr<int>(
        "search_times",
        "The number of exhaustive search times for convolution algorithm.")
        .SetDefault(-1);
    AddAttr<bool>(
        "use_cudnn",
        "(bool, default true) Only used in the cuDNN kernel; requires cuDNN "
        "to be installed.")
        .SetDefault(true);
  }
};

class Conv2DFusionOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Conv2DFusion");
    OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "Conv2DFusion");

    // In some cases, the attribute data_format is "AnyLayout".
    std::string data_format = ctx->Attrs().Get<std::string>("data_format");
    // MKL-DNN kernels use the NCHW order of dimensions, so we ignore the
    // data_format attribute for the MKL-DNN kernel.
    const bool channel_last = (ctx->IsRunMKLDNNKernel() == false) &&
                              (data_format == "NHWC" || data_format == "NDHWC");
    std::vector<int64_t> output_shape =
        ComputeOutputShape(ctx, data_format, channel_last);
    ctx->SetOutputDim("Output", phi::make_ddim(output_shape));
    ctx->ShareLoD("Input", "Output");

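    // When `split_channels` is set, the fused output is split along the
    // channel axis into `Outputs`; validate the per-split shapes below.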
    std::vector<int> split_channels =
        ctx->Attrs().Get<std::vector<int>>("split_channels");
    if (!split_channels.empty()) {
      OP_INOUT_CHECK(
          ctx->HasOutputs("Outputs"), "Output", "Outputs", "Conv2DFusion");
      PADDLE_ENFORCE_EQ(
          ctx->Outputs("Outputs").size(),
          split_channels.size(),
          platform::errors::InvalidArgument(
              "The number of Output(Outputs) of operator 'Conv2DFusion' is "
              "expected to be equal to the length of Attr(split_channels). But "
              "reiceved: the number of Output(Outputs) = %u; the length of "
              "Attr(split_channels) = %u, the content = [%s].",
              ctx->Outputs("Outputs").size(),
              split_channels.size(),
              phi::make_ddim(split_channels)));

      int split_channels_sum = 0;
      std::vector<framework::DDim> output_shapes(split_channels.size());
      for (size_t i = 0; i < split_channels.size(); ++i) {
        split_channels_sum += split_channels[i];
        if (channel_last) {
          output_shapes[i] = phi::make_ddim({output_shape[0],
                                             output_shape[1],
                                             output_shape[2],
                                             split_channels[i]});
        } else {
          output_shapes[i] = phi::make_ddim({output_shape[0],
                                             split_channels[i],
                                             output_shape[2],
                                             output_shape[3]});
        }
      }
      int output_channels = output_shape[1];
      // for NHWC
      if (channel_last) output_channels = output_shape[3];
      PADDLE_ENFORCE_EQ(
          split_channels_sum,
          output_channels,
          platform::errors::InvalidArgument(
              "The sum of Attr(split_channels) is expected to be equal to "
              "the "
194
              "total output channels. But received: the sum of "
              "Attr(split_channels) = %d, the total output channels = %d.",
              split_channels_sum,
              output_channels));
      ctx->SetOutputsDim("Outputs", output_shapes);
    }
  }

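  // Infers the output shape from the Input/Filter dims and the conv
  // attributes, mirroring the shape checks of the standalone conv2d op.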
  std::vector<int64_t> ComputeOutputShape(framework::InferShapeContext* ctx,
                                          const std::string& data_format,
                                          bool channel_last) const {
    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Conv");
    OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter", "Conv");

    auto in_dims = ctx->GetInputDim("Input");
    auto filter_dims = ctx->GetInputDim("Filter");

    std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
    std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
    std::string padding_algorithm =
        ctx->Attrs().Get<std::string>("padding_algorithm");
    int groups = ctx->Attrs().Get<int>("groups");
    std::vector<int> dilations =
        ctx->Attrs().Get<std::vector<int>>("dilations");
    int dilation_size = dilations.size();
    for (int i = 0; i < dilation_size; ++i) {
      PADDLE_ENFORCE_GT(
          dilations[i],
          0,
          platform::errors::InvalidArgument(
              "The dilation of Op(Conv) should be larget than 0, but received "
              "dilation is %d.",
              dilations[i]));
    }

    PADDLE_ENFORCE_EQ(
        in_dims.size() == 4 || in_dims.size() == 5,
        true,
        platform::errors::InvalidArgument(
            "The input of Op(Conv) should be a 4-D or 5-D Tensor. But "
            "received: input's dimension is %u, input's shape is [%s].",
            in_dims.size(),
            in_dims));

    PADDLE_ENFORCE_EQ(
        in_dims.size(),
        filter_dims.size(),
        platform::errors::InvalidArgument(
            "The input's dimension and filter's dimension of "
            "Op(Conv) should be equal. But received: the input's shape is "
            "[%s], "
            "the input's dimension is %d; the filter's shape is [%s],  "
            "the filter's dimension is %d.",
            in_dims,
            in_dims.size(),
            filter_dims,
            filter_dims.size()));

    int stride_size = strides.size();
    for (int i = 0; i < stride_size; ++i) {
      PADDLE_ENFORCE_GT(
          strides[i],
          0,
          platform::errors::InvalidArgument(
              "The stride of Op(Conv) should be larget than 0, but received "
              "stride is %d.",
              strides[i]));
    }

    PADDLE_ENFORCE_EQ(
        in_dims.size(),
        strides.size() + 2U,
        platform::errors::InvalidArgument(
            "The difference of input's dimension and Attr(strides)'s "
            "length must be euqal to 2 for Op(Conv). "
            "But received: input's dimension is %d, input's shape is [%s]; "
            "Attr(stride)'s length is %d, Attr(stride) is [%s]; "
            "difference of input's dimention and Attr(strides)'s length = %u.",
            in_dims.size(),
            in_dims,
            strides.size(),
            phi::make_ddim(strides),
            in_dims.size() - stride_size));

    const auto input_channels =
        channel_last ? in_dims[in_dims.size() - 1] : in_dims[1];

    PADDLE_ENFORCE_EQ(
        input_channels,
        (channel_last ? filter_dims[filter_dims.size() - 1] : filter_dims[1]) *
            groups,
        platform::errors::InvalidArgument(
            "The number of input's channels should be equal to filter's "
            "channels "
            "* groups for Op(Conv). But received: the input's channels is %d, "
            "the input's shape is [%s]; the filter's channels is %d, the "
            "filter's shape is [%s]; the groups is %d, the data_format is %s. "
            "The error may come from wrong data_format setting.",
            input_channels,
            in_dims,
            channel_last ? filter_dims[filter_dims.size() - 1] : filter_dims[1],
            filter_dims,
            groups,
            data_format));
    PADDLE_ENFORCE_EQ(
        filter_dims[0] % groups,
        0,
        platform::errors::InvalidArgument(
            "The number of output's channels (filter's first dimension) of "
            "Op(Conv) should be divided by groups. But received: "
            "the output channels is %d, the filter's shape is [%s], "
            "the groups is %d.",
            filter_dims[0],
            filter_dims,
            groups));

    if (ctx->IsRuntime()) {
      PADDLE_ENFORCE_GT(
          filter_dims[0],
          0,
          platform::errors::InvalidArgument(
              "the size of filter at axis 0 should be greater than 0"));
    }

    framework::DDim in_data_dims;
    if (channel_last) {
      in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
    } else {
      in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
    }

    framework::DDim filter_data_dims;
    if (channel_last) {
      filter_data_dims =
          phi::slice_ddim(filter_dims, 1, filter_dims.size() - 1);
    } else {
      filter_data_dims = phi::slice_ddim(filter_dims, 2, filter_dims.size());
    }

    std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
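    // Resolve padding_algorithm ("SAME"/"VALID"/"EXPLICIT") into explicit
    // per-side paddings, adjusting dilations where required.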
    phi::UpdatePaddingAndDilation(
        &paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize);

    std::vector<int64_t> output_shape({in_dims[0]});
    if (!channel_last) {
      output_shape.push_back(filter_dims[0]);
    }
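    // During compile-time inference (IsRuntime() == false), unknown extents
    // are kept symbolic as -1 rather than computed.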
    for (int i = 0; i < in_data_dims.size(); ++i) {
      if ((!ctx->IsRuntime()) &&
          (in_data_dims[i] <= 0 || filter_dims[i + 2] <= 0)) {
        output_shape.push_back(-1);
      } else {
        output_shape.push_back(ConvOutputSize(in_data_dims[i],
                                              filter_data_dims[i],
                                              dilations[i],
                                              paddings[2 * i],
                                              paddings[2 * i + 1],
                                              strides[i]));
      }
    }
    if (channel_last) {
      output_shape.push_back(filter_dims[0]);
    }

    return output_shape;
  }

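  // Data type and layout selection is delegated to the shared helper in
  // get_expected_kernel_func.h.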
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetConvExpectedKernelType(ctx, this);
  }
};

// TODO(qingqing): add gradient operator for conv2d_fusion

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
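// No gradient op is registered: EmptyGradOpMaker marks conv2d_fusion as
// having no grad op (see the TODO above).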
REGISTER_OPERATOR(
    conv2d_fusion,
    ops::Conv2DFusionOp,
    ops::Conv2DFusionOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);

// This op is used by CUTLASS; conv2d_fusion_cutlass is an intermediate op
// produced by conv2d_fusion_layout_transfer_pass.
REGISTER_OPERATOR(
    conv2d_fusion_cutlass,
    ops::Conv2DFusionOp,
    ops::Conv2DFusionOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);