// fused_conv_2d.h
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#ifndef MACE_KERNELS_FUSED_CONV_2D_H_
#define MACE_KERNELS_FUSED_CONV_2D_H_

#include "mace/core/tensor.h"
#include "mace/kernels/conv_pool_2d_util.h"
#include "mace/kernels/conv_2d.h"

namespace mace {
namespace kernels {

// Parameter holder shared by every FusedConv2dFunctor variant.
//
// The stride and dilation arrays are kept as raw pointers and are not copied,
// so the arrays passed to the constructor must stay valid for as long as the
// functor is used. The padding argument is copied into the object.
struct FusedConv2dFunctorBase {
  FusedConv2dFunctorBase(const int *strides,
                         const Padding &paddings,
                         const int *dilations)
      : strides_(strides),
        dilations_(dilations),
        paddings_(paddings) {}

  // Layout: [stride_h, stride_w]
  const int *strides_;
  // Layout: [dilation_h, dilation_w]
  const int *dilations_;
  // Copy of the padding argument given to the constructor.
  Padding paddings_;
};

// Generic fused convolution functor.
//
// Runs Conv2dFunctor<D, T> with the stored strides, paddings and dilations,
// then walks the output buffer once and replaces every negative element with
// zero, in place.
template<DeviceType D, typename T>
struct FusedConv2dFunctor : FusedConv2dFunctorBase {
  FusedConv2dFunctor(const int *strides,
                     const Padding &paddings,
                     const int *dilations)
      : FusedConv2dFunctorBase(strides, paddings, dilations) {}

  // Computes the convolution of `input` with `filter` (plus `bias`) into
  // `output`, then clamps negative output values to zero.
  //
  // All four pointers are forwarded unchanged to Conv2dFunctor; `output` is
  // additionally written through mutable_data<T>().
  void operator()(const Tensor *input,
                  const Tensor *filter,
                  const Tensor *bias,
                  Tensor *output) {
    Conv2dFunctor<D, T>(strides_, paddings_, dilations_)(input, filter, bias, output);
    T *output_data = output->mutable_data<T>();

    // The zero constant is built through half_cast when T maps to DT_HALF,
    // and from the literal 0 for every other element type.
    T zero_value;
    if (DataTypeToEnum<T>::value == DataType::DT_HALF) {
      zero_value = half_float::half_cast<half>(0.0f);
    } else {
      zero_value = 0;
    }

    // The counter takes the exact type returned by size(): a plain `int`
    // could overflow (or be compared across signedness) for large tensors.
    auto output_size = output->size();
    for (decltype(output_size) i = 0; i < output_size; ++i) {
      if (output_data[i] < 0) {
        output_data[i] = zero_value;
      }
    }
  }

};

// Partial specialization of FusedConv2dFunctor for DeviceType::OPENCL.
//
// Only the call operator's declaration appears here; its definition is not
// in the visible part of this header and is expected to be provided
// elsewhere.
template<typename T>
struct FusedConv2dFunctor<DeviceType::OPENCL, T>
    : public FusedConv2dFunctorBase {
  FusedConv2dFunctor(const int *strides,
                     const Padding &paddings,
                     const int *dilations)
      : FusedConv2dFunctorBase(strides, paddings, dilations) {}

  // Same signature as the generic functor's call operator.
  void operator()(const Tensor *input,
                  const Tensor *filter,
                  const Tensor *bias,
                  Tensor *output);
};

}  // namespace kernels
}  // namespace mace

#endif  // MACE_KERNELS_FUSED_CONV_2D_H_