/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/platform/cudnn_workspace_helper.h"
#ifdef PADDLE_WITH_XPU
namespace paddle {
namespace operators {

// Forward kernel for conv2d on XPU (Baidu Kunlun) devices. The whole
// convolution is delegated to a single call into the XPU runtime's
// xpu::conv2d primitive; this class only marshals tensors and attributes.
template <typename DeviceContext, typename T>
class GemmConvXPUKernel : public framework::OpKernel<T> {
  // Device-side element type corresponding to the framework type T
  // (e.g. platform::float16 maps to the XPU half type).
  using XPUT = typename XPUTypeTrait<T>::Type;

 public:
  void Compute(const framework::ExecutionContext &context) const override {
    const Tensor *input = context.Input<Tensor>("Input");
    // The filter will be reshaped in the calculations,
    // so here use an assignment operation,
    // that avoids modifying the variable in the Scope.
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor *output = context.Output<Tensor>("Output");
    output->mutable_data<T>(context.GetPlace());
    int groups = context.Attr<int>("groups");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
    const std::string data_format = context.Attr<std::string>("data_format");
    const std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");

    // Only 4-D layouts (NCHW / NHWC) are handled below; reject 5-D NDHWC.
    PADDLE_ENFORCE_EQ(
        data_format == "NDHWC", false,
        platform::errors::InvalidArgument(
            ("XPU does not support data_format is NDHWC in conv op.")));

    // Spatial dims only: everything after the leading (batch, channel) axes.
    framework::DDim in_data_dims =
        phi::slice_ddim(input->dims(), 2, input->dims().size());
    framework::DDim filter_data_dims =
        phi::slice_ddim(filter.dims(), 2, filter.dims().size());
    std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
    // Resolves "SAME"/"VALID" padding_algorithm and normalizes
    // paddings/dilations in place to match the spatial rank.
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    int batch_size = static_cast<int>(input->dims()[0]);
    int img_c = static_cast<int>(input->dims()[1]);
    int img_h = static_cast<int>(input->dims()[2]);
    int img_w = static_cast<int>(input->dims()[3]);
    int f = static_cast<int>(filter.dims()[0]);
    bool is_nchw = true;
    if (data_format == "NHWC") {
      // Channel-last layout: reinterpret the input dims accordingly.
      img_c = static_cast<int>(input->dims()[3]);
      img_h = static_cast<int>(input->dims()[1]);
      img_w = static_cast<int>(input->dims()[2]);
      is_nchw = false;
    }

    const XPUT *input_data = reinterpret_cast<const XPUT *>(input->data<T>());
    const XPUT *filter_data = reinterpret_cast<const XPUT *>(filter.data<T>());
    XPUT *output_data = reinterpret_cast<XPUT *>(output->data<T>());

    auto &dev_ctx = context.template device_context<DeviceContext>();
    // int16_t selects the on-device accumulation/compute type; the trailing
    // nullptrs skip the optional max-value buffers of the XPU API.
    int r = xpu::conv2d<XPUT, XPUT, XPUT, int16_t>(
        dev_ctx.x_context(), input_data, filter_data, output_data, batch_size,
        img_c, img_h, img_w, f, ksize, strides, paddings, dilations, groups,
        nullptr, nullptr, nullptr, is_nchw);
    PADDLE_ENFORCE_EQ(
        r, XPU_SUCCESS,
        platform::errors::External("XPU conv kernel return wrong value[%d %s]",
                                   r, XPUAPIErrorMsg[r]));
  }
};
83

X
xiaoting 已提交
84 85
// Backward kernel for conv2d on XPU devices. Computes Input@GRAD and/or
// Filter@GRAD from Output@GRAD with one call into xpu::conv2d_grad; a null
// data pointer tells the library to skip that gradient.
template <typename DeviceContext, typename T>
class GemmConvGradXPUKernel : public framework::OpKernel<T> {
  // Device-side element type corresponding to the framework type T.
  using XPUT = typename XPUTypeTrait<T>::Type;

 public:
  void Compute(const framework::ExecutionContext &context) const override {
    const Tensor *input = context.Input<Tensor>("Input");
    const Tensor *output_grad =
        context.Input<Tensor>(framework::GradVarName("Output"));
    Tensor *input_grad =
        context.Output<Tensor>(framework::GradVarName("Input"));
    Tensor *filter_grad =
        context.Output<Tensor>(framework::GradVarName("Filter"));
    // The filter and filter_grad will be reshaped in the calculations,
    // so here use an assignment operation,
    // that avoids modifying the variable in the Scope.
    Tensor filter = *context.Input<Tensor>("Filter");
    // Neither gradient requested: nothing to do.
    if (!input_grad && !filter_grad) return;
    int groups = context.Attr<int>("groups");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
    const std::string data_format = context.Attr<std::string>("data_format");
    const std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");

    // Only 4-D layouts (NCHW / NHWC) are handled below; reject 5-D NDHWC.
    PADDLE_ENFORCE_EQ(
        data_format == "NDHWC", false,
        platform::errors::InvalidArgument(
            ("XPU doesn't support data_format is NDHWC in conv grad op.")));

    // Spatial dims only: everything after the leading (batch, channel) axes.
    framework::DDim in_data_dims =
        phi::slice_ddim(input->dims(), 2, input->dims().size());
    framework::DDim filter_data_dims =
        phi::slice_ddim(filter.dims(), 2, filter.dims().size());
    std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
    // Resolves "SAME"/"VALID" padding_algorithm and normalizes
    // paddings/dilations in place, mirroring the forward kernel.
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    int batch_size = static_cast<int>(input->dims()[0]);
    int img_c = static_cast<int>(input->dims()[1]);
    int img_h = static_cast<int>(input->dims()[2]);
    int img_w = static_cast<int>(input->dims()[3]);
    int f = static_cast<int>(filter.dims()[0]);
    bool is_nchw = true;
    if (data_format == "NHWC") {
      // Channel-last layout: reinterpret the input dims accordingly.
      img_c = static_cast<int>(input->dims()[3]);
      img_h = static_cast<int>(input->dims()[1]);
      img_w = static_cast<int>(input->dims()[2]);
      is_nchw = false;
    }

    const XPUT *input_data = reinterpret_cast<const XPUT *>(input->data<T>());
    const XPUT *filter_data = reinterpret_cast<const XPUT *>(filter.data<T>());
    const XPUT *output_grad_data =
        reinterpret_cast<const XPUT *>(output_grad->data<T>());
    // Allocate only the gradients that were actually requested; a nullptr
    // passed to conv2d_grad disables that output.
    XPUT *input_grad_data = nullptr;
    if (input_grad) {
      input_grad->mutable_data<T>(context.GetPlace());
      input_grad_data = reinterpret_cast<XPUT *>(input_grad->data<T>());
    }
    XPUT *filter_grad_data = nullptr;
    if (filter_grad) {
      filter_grad->mutable_data<T>(context.GetPlace());
      filter_grad_data = reinterpret_cast<XPUT *>(filter_grad->data<T>());
    }
    auto &dev_ctx = context.template device_context<DeviceContext>();
    // int16_t selects the on-device accumulation/compute type; the trailing
    // nullptrs skip the optional max-value buffers of the XPU API.
    int r = xpu::conv2d_grad<XPUT, XPUT, XPUT, int16_t>(
        dev_ctx.x_context(), input_data, filter_data, output_grad_data,
        input_grad_data, filter_grad_data, batch_size, img_c, img_h, img_w, f,
        ksize, strides, paddings, dilations, groups, nullptr, nullptr, nullptr,
        nullptr, nullptr, is_nchw);
    PADDLE_ENFORCE_EQ(
        r, XPU_SUCCESS,
        platform::errors::External("XPU conv kernel return wrong value[%d %s]",
                                   r, XPUAPIErrorMsg[r]));
  }
};
}  // namespace operators
}  // namespace paddle
namespace ops = paddle::operators;
// Register float and float16 kernels for conv2d; depthwise_conv2d reuses the
// same implementations (grouped conv handles the depthwise case) but only
// registers float.
REGISTER_OP_XPU_KERNEL(
    conv2d, ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>,
    ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext,
                           paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(
    conv2d_grad,
    ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
    ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext,
                               paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(
    depthwise_conv2d,
    ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(
    depthwise_conv2d_grad,
    ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
#endif