// deconv_2d.cc
// Copyright 2018 Xiaomi, Inc.  All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mace/kernels/deconv_2d.h"
#include "mace/kernels/opencl/image/deconv_2d.h"

namespace mace {
namespace kernels {

21 22 23 24 25 26
// Constructs the GPU deconvolution functor.
//
// All configuration arguments are forwarded unchanged to Deconv2dFunctorBase.
// The OpenCL kernel object is created only when the runtime reports that image
// memory is in use; any other memory mode hits MACE_NOT_IMPLEMENTED.
template <typename T>
Deconv2dFunctor<DeviceType::GPU, T>::Deconv2dFunctor(
    OpKernelContext *context,
    const std::vector<int> &strides,
    const Padding &padding_type,
    const std::vector<int> &paddings,
    const FrameworkType model_type,
    const ActivationType activation,
    const float relux_max_limit)
    : Deconv2dFunctorBase(context, strides, padding_type, paddings,
                          model_type, activation, relux_max_limit) {
  if (!context->device()->opencl_runtime()->UseImageMemory()) {
    // Only the image-memory kernel is wired up in this file.
    MACE_NOT_IMPLEMENTED;
    return;
  }
  kernel_.reset(new opencl::image::Deconv2dKernel<T>);
}

template <typename T>
李寅 已提交
45 46 47 48
MaceStatus Deconv2dFunctor<DeviceType::GPU, T>::operator()(
    const Tensor *input,
    const Tensor *filter,
    const Tensor *bias,
L
liutuo 已提交
49
    const Tensor *output_shape_tensor,
李寅 已提交
50 51
    Tensor *output,
    StatsFuture *future) {
L
liutuo 已提交
52 53 54
  MACE_CHECK_NOTNULL(input);
  MACE_CHECK_NOTNULL(filter);
  MACE_CHECK_NOTNULL(output);
L
liutuo 已提交
55
  std::vector<int> paddings(2);
L
liutuo 已提交
56
  std::vector<int> out_paddings(2);
L
liutuo 已提交
57
  std::vector<index_t> output_shape(4);
L
liutuo 已提交
58
  if (model_type_ == FrameworkType::TENSORFLOW) {
L
liutuo 已提交
59
    paddings = std::vector<int>(2, 0);
L
liutuo 已提交
60 61 62 63 64 65 66
    MACE_CHECK_NOTNULL(output_shape_tensor);
    MACE_CHECK(output_shape_tensor->size() == 4);
    Tensor::MappingGuard output_shape_mapper(output_shape_tensor);
    auto output_shape_data =
        output_shape_tensor->data<int32_t>();
    output_shape =
        std::vector<index_t>(output_shape_data, output_shape_data + 4);
L
liutuo 已提交
67 68 69 70 71 72
    CalcDeconvPaddingAndInputSize(input->shape().data(),
                                  filter->shape().data(),
                                  strides_.data(),
                                  padding_type_,
                                  output_shape.data(),
                                  paddings.data());
L
liutuo 已提交
73
  } else {
L
liutuo 已提交
74 75
    out_paddings = paddings_;
    paddings = std::vector<int>(2, 0);
L
liutuo 已提交
76 77 78 79 80
    output_shape = std::vector<index_t>(4, 0);
    CalcDeconvOutputSize(input->shape().data(),
                         filter->shape().data(),
                         strides_.data(),
                         output_shape.data(),
L
liutuo 已提交
81
                         out_paddings.data(),
L
liutuo 已提交
82
                         paddings.data());
L
liutuo 已提交
83 84
  }

85 86 87
  return kernel_->Compute(context_, input, filter, bias,
                          strides_.data(), paddings.data(), activation_,
                          relux_max_limit_, output_shape, output, future);
L
liutuo 已提交
88 89 90 91 92 93 94
}

// Explicit instantiations of the GPU functor for the float and half element
// types, so both specializations are emitted from this translation unit.
template struct Deconv2dFunctor<DeviceType::GPU, float>;
template struct Deconv2dFunctor<DeviceType::GPU, half>;

}  // namespace kernels
}  // namespace mace