Commit dfc8d2b7 authored by Liangliang He

Merge branch 'refactor-opencl-kernel' into 'master'

Refactor OpenCL kernel: move the OpenCL kernel to a member variable.

See merge request !230
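The pattern this merge request applies to every functor is worth spelling out once. The cl::Kernel becomes a member that is built lazily on the first invocation, so program compilation and argument binding happen once per functor instance instead of once per run. A minimal sketch, assuming a simplified functor; "example" is a hypothetical program/kernel name, and the MACE-specific build options are abbreviated:

#include <set>
#include <string>
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"

namespace mace {
namespace kernels {

struct ExampleFunctor {
  cl::Kernel kernel_;  // default-constructed null handle, cached across calls

  void operator()(/* input/output tensors, StatsFuture *future */) {
    if (kernel_.get() == nullptr) {  // first invocation only
      auto runtime = OpenCLRuntime::Global();
      std::set<std::string> built_options;
      built_options.emplace("-DDATA_TYPE=float");
      kernel_ = runtime->BuildKernel("example", "example", built_options);
      // The setArg(...) calls move inside this guard as well, so the argument
      // bindings are cached together with the compiled kernel.
    }
    // Every invocation: enqueue (or autotune) using the cached kernel_, e.g.
    // TuningOrRun3DKernel(kernel_, tuning_key, gws, lws, future);
  }
};

}  // namespace kernels
}  // namespace mace

One consequence visible throughout the hunks below: because the setArg calls are cached too, a functor instance implicitly assumes it is re-run with the same input and output images.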
......@@ -8,6 +8,7 @@
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/core/types.h"
#include "mace/core/runtime/opencl/cl2_header.h"
namespace mace {
namespace kernels {
......@@ -128,6 +129,7 @@ class ActivationFunctor<DeviceType::OPENCL, T> {
ActivationType activation_;
T relux_max_limit_;
T prelu_alpha_;
cl::Kernel kernel_;
};
} // namespace kernels
......
......@@ -7,6 +7,7 @@
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/cl2_header.h"
namespace mace {
namespace kernels {
......@@ -42,6 +43,8 @@ template <typename T>
struct AddNFunctor<DeviceType::OPENCL, T> {
void operator()(const std::vector<const Tensor *> &input_tensors,
Tensor *output_tensor, StatsFuture *future);
cl::Kernel kernel_;
};
} // namespace kernels
......
......@@ -9,6 +9,7 @@
#include "mace/core/public/mace.h"
#include "mace/core/tensor.h"
#include "mace/kernels/activation.h"
#include "mace/core/runtime/opencl/cl2_header.h"
namespace mace {
namespace kernels {
......@@ -133,6 +134,7 @@ struct BatchNormFunctor<DeviceType::OPENCL, T> : BatchNormFunctorBase {
const float epsilon,
Tensor *output,
StatsFuture *future);
cl::Kernel kernel_;
};
} // namespace kernels
......
......@@ -8,6 +8,7 @@
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/core/public/mace.h"
#include "mace/core/runtime/opencl/cl2_header.h"
namespace mace {
namespace kernels {
......@@ -63,6 +64,7 @@ struct BiasAddFunctor<DeviceType::OPENCL, T> {
const Tensor *bias,
Tensor *output,
StatsFuture *future);
cl::Kernel kernel_;
};
} // namespace kernels
......
......@@ -10,6 +10,7 @@
#include "mace/core/types.h"
#include "mace/core/public/mace.h"
#include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/cl2_header.h"
namespace mace {
namespace kernels {
......@@ -81,6 +82,7 @@ struct ConcatFunctor<DeviceType::OPENCL, T> : ConcatFunctorBase{
void operator()(const std::vector<const Tensor *> &input_list,
Tensor *output, StatsFuture *future);
cl::Kernel kernel_;
};
......
......@@ -9,6 +9,7 @@
#include "mace/core/tensor.h"
#include "mace/kernels/activation.h"
#include "mace/kernels/conv_pool_2d_util.h"
#include "mace/core/runtime/opencl/cl2_header.h"
namespace mace {
namespace kernels {
......@@ -176,6 +177,8 @@ struct Conv2dFunctor<DeviceType::OPENCL, T> : Conv2dFunctorBase {
const Tensor *bias,
Tensor *output,
StatsFuture *future);
cl::Kernel kernel_;
};
} // namespace kernels
......
......@@ -9,6 +9,7 @@
#include "mace/core/future.h"
#include "mace/core/public/mace.h"
#include "mace/kernels/conv_pool_2d_util.h"
#include "mace/core/runtime/opencl/cl2_header.h"
namespace mace {
namespace kernels {
......@@ -57,7 +58,6 @@ struct DepthwiseConv2dFunctor : public DepthwiseConv2dFunctorBase {
StatsFuture *future) {
MACE_CHECK_NOTNULL(input);
MACE_CHECK_NOTNULL(filter);
MACE_CHECK_NOTNULL(bias);
MACE_CHECK_NOTNULL(output);
// Create a fake conv_2d filter to calculate the paddings and output size
......@@ -113,7 +113,7 @@ struct DepthwiseConv2dFunctor : public DepthwiseConv2dFunctorBase {
Tensor::MappingGuard output_mapper(output);
const T *input_ptr = input->data<T>();
const T *filter_ptr = filter->data<T>();
const T *bias_ptr = bias->data<T>();
const T *bias_ptr = bias == nullptr ? nullptr : bias->data<T>();
T *output_ptr = output->mutable_data<T>();
#pragma omp parallel for collapse(2)
......@@ -153,6 +153,10 @@ struct DepthwiseConv2dFunctor : public DepthwiseConv2dFunctorBase {
}
}
}
output_ptr = output->mutable_data<T>();
DoActivation(output_ptr, output_ptr, output->NumElements(), activation_,
relux_max_limit_, prelu_alpha_);
}
};
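Two things change in the CPU path above: the bias becomes optional (mirroring the #ifdef BIAS argument of the OpenCL kernels), and the activation is applied in place over the finished output buffer via DoActivation. A self-contained sketch of the optional-bias initialization, with Tensor reduced to a raw pointer and the accumulation itself elided:

#include <cstddef>

// Each output element starts from the bias when present, otherwise from zero;
// `bias` may now legitimately be nullptr.
void InitOutputWithOptionalBias(const float *bias, float *output,
                                std::size_t channels) {
  for (std::size_t c = 0; c < channels; ++c) {
    output[c] = (bias == nullptr) ? 0.0f : bias[c];
    // ... filter taps are accumulated into output[c] afterwards ...
  }
}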
......@@ -178,13 +182,15 @@ struct DepthwiseConv2dFunctor<DeviceType::OPENCL, T>
dilations,
activation,
relux_max_limit,
prelu_alpha) {}
void operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
StatsFuture *future);
cl::Kernel kernel_;
};
} // namespace kernels
......
......@@ -7,6 +7,7 @@
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/cl2_header.h"
namespace mace {
namespace kernels {
......@@ -58,6 +59,8 @@ struct MatMulFunctor<DeviceType::OPENCL, T> {
const Tensor *B,
Tensor *C,
StatsFuture *future);
cl::Kernel kernel_;
};
} // namespace kernels
......
......@@ -22,52 +22,60 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
const index_t channels = input->dim(3);
const index_t channel_blocks = RoundUpDiv4(channels);
std::string tuning_key_prefix;
auto runtime = OpenCLRuntime::Global();
if (kernel_.get() == nullptr) {
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("activation");
built_options.emplace("-Dactivation=" + kernel_name);
auto dt = DataTypeToEnum<T>::value;
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
switch (activation_) {
case RELU:
built_options.emplace("-DUSE_RELU");
break;
case RELUX:
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
built_options.emplace("-DUSE_TANH");
break;
case SIGMOID:
built_options.emplace("-DUSE_SIGMOID");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation_;
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("activation");
built_options.emplace("-Dactivation=" + kernel_name);
auto dt = DataTypeToEnum<T>::value;
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
switch (activation_) {
case RELU:
tuning_key_prefix = "relu_opencl_kernel_";
built_options.emplace("-DUSE_RELU");
break;
case RELUX:
tuning_key_prefix = "relux_opencl_kernel_";
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
tuning_key_prefix = "prelu_opencl_kernel_";
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
tuning_key_prefix = "tanh_opencl_kernel_";
built_options.emplace("-DUSE_TANH");
break;
case SIGMOID:
tuning_key_prefix = "sigmoid_opencl_kernel_";
built_options.emplace("-DUSE_SIGMOID");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation_;
}
kernel_ =
runtime->BuildKernel("activation", kernel_name, built_options);
int idx = 0;
kernel_.setArg(
idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
kernel_.setArg(idx++, static_cast<float>(relux_max_limit_));
kernel_.setArg(idx++, static_cast<float>(prelu_alpha_));
kernel_.setArg(idx++,
*(static_cast<cl::Image2D *>(output->buffer())));
}
cl::Kernel activation_kernel =
runtime->BuildKernel("activation", kernel_name, built_options);
int idx = 0;
activation_kernel.setArg(
idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
activation_kernel.setArg(idx++, relux_max_limit_);
activation_kernel.setArg(idx++, prelu_alpha_);
activation_kernel.setArg(idx++,
*(static_cast<cl::Image2D *>(output->buffer())));
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width),
static_cast<uint32_t>(height * batch)};
const std::vector<uint32_t> lws = {8, 16, 8, 1};
std::string tuning_key =
Concat("relu_opencl_kernel_", activation_, output->dim(0), output->dim(1),
Concat(tuning_key_prefix, output->dim(0), output->dim(1),
output->dim(2), output->dim(3));
TuningOrRun3DKernel(activation_kernel, tuning_key, gws, lws, future);
TuningOrRun3DKernel(kernel_, tuning_key, gws, lws, future);
}
template struct ActivationFunctor<DeviceType::OPENCL, float>;
......
......@@ -11,53 +11,6 @@
namespace mace {
namespace kernels {
template <typename T>
static void AddN(const std::vector<const Tensor *> &input_tensors,
Tensor *output, StatsFuture *future) {
if (input_tensors.size() > 4) {
MACE_NOT_IMPLEMENTED;
}
const index_t batch = output->dim(0);
const index_t height = output->dim(1);
const index_t width = output->dim(2);
const index_t channels = output->dim(3);
const index_t channel_blocks = RoundUpDiv4(channels);
const index_t width_pixels = channel_blocks * width;
const index_t batch_height_pixels = batch * height;
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("addn");
built_options.emplace("-Daddn=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace("-DINPUT_NUM=" + ToString(input_tensors.size()));
auto addn_kernel = runtime->BuildKernel("addn", kernel_name, built_options);
uint32_t idx = 0;
for (auto input : input_tensors) {
addn_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(input->buffer())));
}
addn_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
const uint32_t gws[2] = {
static_cast<uint32_t>(width_pixels),
static_cast<uint32_t>(batch_height_pixels)
};
const std::vector<uint32_t> lws = {64, 16, 1};
std::stringstream ss;
ss << "addn_opencl_kernel_"
<< output->dim(0) << "_"
<< output->dim(1) << "_"
<< output->dim(2) << "_"
<< output->dim(3);
TuningOrRun2DKernel(addn_kernel, ss.str(), gws, lws, future);
}
template <typename T>
void AddNFunctor<DeviceType::OPENCL, T>::operator()(
const std::vector<const Tensor *> &input_tensors,
......@@ -84,7 +37,44 @@ void AddNFunctor<DeviceType::OPENCL, T>::operator()(
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, output_image_shape);
output_tensor->ResizeImage(output_shape, output_image_shape);
AddN<T>(input_tensors, output_tensor, future);
const index_t channel_blocks = RoundUpDiv4(channels);
const index_t width_pixels = channel_blocks * width;
const index_t batch_height_pixels = batch * height;
if (kernel_.get() == nullptr) {
if (input_tensors.size() > 4) {
MACE_NOT_IMPLEMENTED;
}
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("addn");
built_options.emplace("-Daddn=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace("-DINPUT_NUM=" + ToString(input_tensors.size()));
kernel_ = runtime->BuildKernel("addn", kernel_name, built_options);
uint32_t idx = 0;
for (auto input : input_tensors) {
kernel_.setArg(idx++,
*(static_cast<const cl::Image2D *>(input->buffer())));
}
kernel_.setArg(idx++, *(static_cast<cl::Image2D *>(output_tensor->buffer())));
}
const uint32_t gws[2] = {
static_cast<uint32_t>(width_pixels),
static_cast<uint32_t>(batch_height_pixels)
};
const std::vector<uint32_t> lws = {64, 16, 1};
std::stringstream ss;
ss << "addn_opencl_kernel_"
<< output_shape[0] << "_"
<< output_shape[1] << "_"
<< output_shape[2] << "_"
<< output_shape[3];
TuningOrRun2DKernel(kernel_, ss.str(), gws, lws, future);
};
template
......
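AddN dispatches as a 2D kernel over the channel-packed image: the x dimension spans channel-block columns times width, the y dimension spans batch times height. A worked example with hypothetical dimensions; RoundUpDiv4 is re-implemented locally so the snippet stands alone:

#include <cstdint>
#include <cstdio>

// ceil(v / 4): channels are packed four-per-pixel in the image2d layout.
static int64_t RoundUpDiv4(int64_t v) { return (v + 3) / 4; }

int main() {
  // Hypothetical NHWC output shape: 1 x 8 x 8 x 10.
  const int64_t batch = 1, height = 8, width = 8, channels = 10;
  const int64_t channel_blocks = RoundUpDiv4(channels);  // 3
  const uint32_t gws[2] = {
      static_cast<uint32_t>(channel_blocks * width),  // width_pixels = 24
      static_cast<uint32_t>(batch * height)};         // batch_height_pixels = 8
  std::printf("gws = {%u, %u}\n", gws[0], gws[1]);    // prints: gws = {24, 8}
  return 0;
}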
......@@ -30,55 +30,57 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
const index_t channel_blocks = RoundUpDiv4(channels);
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("batch_norm");
built_options.emplace("-Dbatch_norm=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
if (folded_constant_) {
built_options.emplace("-DFOLDED_CONSTANT");
}
switch (activation_) {
case NOOP:
break;
case RELU:
built_options.emplace("-DUSE_RELU");
break;
case RELUX:
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
built_options.emplace("-DUSE_TANH");
break;
case SIGMOID:
built_options.emplace("-DUSE_SIGMOID");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation_;
}
if (kernel_.get() == nullptr) {
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("batch_norm");
built_options.emplace("-Dbatch_norm=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
if (folded_constant_) {
built_options.emplace("-DFOLDED_CONSTANT");
}
switch (activation_) {
case NOOP:
break;
case RELU:
built_options.emplace("-DUSE_RELU");
break;
case RELUX:
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
built_options.emplace("-DUSE_TANH");
break;
case SIGMOID:
built_options.emplace("-DUSE_SIGMOID");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation_;
}
auto bm_kernel =
runtime->BuildKernel("batch_norm", kernel_name, built_options);
kernel_ =
runtime->BuildKernel("batch_norm", kernel_name, built_options);
uint32_t idx = 0;
bm_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
bm_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(scale->buffer())));
bm_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(offset->buffer())));
if (!folded_constant_) {
bm_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(mean->buffer())));
bm_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(var->buffer())));
bm_kernel.setArg(idx++, epsilon);
uint32_t idx = 0;
kernel_.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
kernel_.setArg(idx++, *(static_cast<const cl::Image2D *>(scale->buffer())));
kernel_.setArg(idx++,
*(static_cast<const cl::Image2D *>(offset->buffer())));
if (!folded_constant_) {
kernel_.setArg(idx++,
*(static_cast<const cl::Image2D *>(mean->buffer())));
kernel_.setArg(idx++, *(static_cast<const cl::Image2D *>(var->buffer())));
kernel_.setArg(idx++, epsilon);
}
kernel_.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
kernel_.setArg(idx++, relux_max_limit_);
kernel_.setArg(idx++, prelu_alpha_);
}
bm_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
bm_kernel.setArg(idx++, relux_max_limit_);
bm_kernel.setArg(idx++, prelu_alpha_);
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width),
......@@ -87,7 +89,7 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
std::string tuning_key =
Concat("batch_norm_opencl_kernel_", activation_, output->dim(0),
output->dim(1), output->dim(2), output->dim(3), folded_constant_);
TuningOrRun3DKernel(bm_kernel, tuning_key, gws, lws, future);
TuningOrRun3DKernel(kernel_, tuning_key, gws, lws, future);
}
template struct BatchNormFunctor<DeviceType::OPENCL, float>;
......
......@@ -24,30 +24,30 @@ void BiasAddFunctor<DeviceType::OPENCL, T>::operator()(
const index_t channel_blocks = RoundUpDiv4(channels);
auto runtime = OpenCLRuntime::Global();
if (kernel_.get() == nullptr) {
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("bias_add");
built_options.emplace("-Dbias_add=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
kernel_ = runtime->BuildKernel("bias_add", kernel_name, built_options);
uint32_t idx = 0;
kernel_.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
kernel_.setArg(idx++, *(static_cast<const cl::Image2D *>(bias->buffer())));
kernel_.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
}
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width),
static_cast<uint32_t>(height * batch)};
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("bias_add");
built_options.emplace("-Dbias_add=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
auto bias_kernel = runtime->BuildKernel("bias_add", kernel_name, built_options);
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(bias_kernel);
const std::vector<uint32_t> lws = {8, 16, 8};
uint32_t idx = 0;
bias_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
bias_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(bias->buffer())));
bias_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
cl::Event event;
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
bias_kernel, cl::NullRange,
kernel_, cl::NullRange,
cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(lws[0], lws[1], lws[2]),
nullptr, &event);
......
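Unlike the tuned kernels, BiasAdd enqueues directly with a fixed 8x16x8 local work size. A minimal sketch of that dispatch path, assuming an already-built kernel and a valid command queue; the real code wraps the event into a StatsFuture:

#include <cstdint>
#include "mace/core/runtime/opencl/cl2_header.h"

// Untuned dispatch: fixed local size, completion event returned to the caller.
cl_int EnqueueWithFixedLws(cl::CommandQueue &queue, cl::Kernel &kernel,
                           const uint32_t (&gws)[3], cl::Event *event) {
  return queue.enqueueNDRangeKernel(
      kernel,
      cl::NullRange,                        // no global offset
      cl::NDRange(gws[0], gws[1], gws[2]),  // global work size
      cl::NDRange(8, 16, 8),                // fixed local work size
      nullptr, event);                      // event for async wait/profiling
}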
#include <common.h>
__kernel void activation(__read_only image2d_t input,
__private const DATA_TYPE relux_max_limit,
__private const DATA_TYPE prelu_alpha,
__private const float relux_max_limit,
__private const float prelu_alpha,
__write_only image2d_t output) {
const int ch_blk = get_global_id(0);
const int w = get_global_id(1);
......
......@@ -2,24 +2,24 @@
// Only multiplier = 1 is supported
__kernel void depthwise_conv2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
__read_only image2d_t filter, /* cout%4 * kh * kw * m, cin/4 */
__read_only image2d_t filter, /* cout%4 * kh * kw * m, cin/4 */
#ifdef BIAS
__read_only image2d_t bias, /* cout%4 * cout/4 */
__read_only image2d_t bias, /* cout%4 * cout/4 */
#endif
__write_only image2d_t output,
__private const DATA_TYPE relux_max_limit,
__private const DATA_TYPE prelu_alpha,
__private const short in_height,
__private const short in_width,
__private const short in_ch_blks,
__private const short out_height,
__private const short out_width,
__private const short filter_height,
__private const short filter_width,
__private const short padding_top,
__private const short padding_left,
__private const short dilation_h,
__private const short dilation_w) {
__write_only image2d_t output,
__private const DATA_TYPE relux_max_limit,
__private const DATA_TYPE prelu_alpha,
__private const short in_height,
__private const short in_width,
__private const short in_ch_blks,
__private const short out_height,
__private const short out_width,
__private const short filter_height,
__private const short filter_width,
__private const short padding_top,
__private const short padding_left,
__private const short dilation_h,
__private const short dilation_w) {
const short out_ch_blk = get_global_id(0);
const short out_w_blk = get_global_id(1);
const short out_w_blks = get_global_size(1);
......@@ -52,7 +52,6 @@ __kernel void depthwise_conv2d(__read_only image2d_t input, /* [c%4 * w * c/4, h
int in_width1 = ((out_w_blk + out_w_blks) << 1) - padding_left;
int in_width2 = ((out_w_blk + (out_w_blks << 1)) << 1) - padding_left;
int in_width3 = ((out_w_blk + (out_w_blks << 1) + out_w_blks) << 1) - padding_left;
int in_width4 = ((out_w_blk + (out_w_blks << 2)) << 1) - padding_left;
const int height_idx = (out_h << 1) - padding_top;
#else
const short in_width_stride = mul24(out_w_blks, STRIDE);
......@@ -90,7 +89,7 @@ __kernel void depthwise_conv2d(__read_only image2d_t input, /* [c%4 * w * c/4, h
READ_INPUT(3);
#undef READ_INPUT
DATA_TYPE4 weights = READ_IMAGET(filter, SAMPLER,
(int2)(filter_idx, in_ch_blk));
......@@ -127,3 +126,120 @@ __kernel void depthwise_conv2d(__read_only image2d_t input, /* [c%4 * w * c/4, h
if (w >= out_width) return;
WRITE_IMAGET(output, (int2)(out_x_base + w, out_hb), out3);
}
__kernel void depthwise_conv2d_s1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
__read_only image2d_t filter, /* cout%4 * kh * kw * m, cin/4 */
#ifdef BIAS
__read_only image2d_t bias, /* cout%4 * cout/4 */
#endif
__write_only image2d_t output,
__private const DATA_TYPE relux_max_limit,
__private const DATA_TYPE prelu_alpha,
__private const short in_height,
__private const short in_width,
__private const short in_ch_blks,
__private const short out_height,
__private const short out_width,
__private const short filter_height,
__private const short filter_width,
__private const short padding_top,
__private const short padding_left) {
const short out_ch_blk = get_global_id(0);
const short out_w_blk = get_global_id(1) << 2;
const short out_hb = get_global_id(2);
const short rounded_in_ch = in_ch_blks << 2;
const short in_ch_blk = out_ch_blk; // multiplier = 1
#ifdef BIAS
DATA_TYPE4 out0 =
READ_IMAGET(bias, SAMPLER, (int2)(out_ch_blk, 0));
DATA_TYPE4 out1 = out0;
DATA_TYPE4 out2 = out0;
DATA_TYPE4 out3 = out0;
#else
DATA_TYPE4 out0 = 0;
DATA_TYPE4 out1 = 0;
DATA_TYPE4 out2 = 0;
DATA_TYPE4 out3 = 0;
#endif
const short out_h = out_hb % out_height;
const short in_width0 = out_w_blk - padding_left;
const short in_width1 = in_width0 + 1;
const short in_width2 = in_width1 + 1;
const short in_width3 = in_width2 + 1;
const short height_idx = out_h - padding_top;
const short batch_idx = mul24((out_hb / out_height), in_height);
const short rounded_in_ch_x_filter_width = mul24(rounded_in_ch, filter_width);
const short in_idx = mul24(in_ch_blk, in_width);
short filter_idx = 0;
short in_hb_idx = height_idx;
const short in_w_idx0 = select(in_idx + in_width0,
-1,
(in_width0 < 0 || in_width0 >= in_width));
const short in_w_idx1 = select(in_idx + in_width1,
-1,
(in_width1 < 0 || in_width1 >= in_width));
const short in_w_idx2 = select(in_idx + in_width2,
-1,
(in_width2 < 0 || in_width2 >= in_width));
short in_w;
DATA_TYPE4 in0, in1, in2, in3;
for (short filter_h_idx = 0; filter_h_idx < filter_height; ++filter_h_idx) {
short in_hb = select(in_hb_idx + batch_idx,
-1,
(in_hb_idx < 0 || in_hb_idx >= in_height));
in1 = READ_IMAGET(input, SAMPLER, (int2)(in_w_idx0, in_hb));
in2 = READ_IMAGET(input, SAMPLER, (int2)(in_w_idx1, in_hb));
in3 = READ_IMAGET(input, SAMPLER, (int2)(in_w_idx2, in_hb));
for (short filter_w_idx = 0; filter_w_idx < filter_width; ++filter_w_idx) {
in0 = in1;
in1 = in2;
in2 = in3;
in_w = in_width3 + filter_w_idx;
in_w = select(in_idx + in_w,
-1,
(in_w < 0 || in_w >= in_width));
in3 = READ_IMAGET(input, SAMPLER, (int2)(in_w, in_hb));
DATA_TYPE4 weights = READ_IMAGET(filter, SAMPLER,
(int2)(filter_idx, in_ch_blk));
out0 = mad(in0, weights, out0);
out1 = mad(in1, weights, out1);
out2 = mad(in2, weights, out2);
out3 = mad(in3, weights, out3);
++filter_idx;
}
in_hb_idx += 1;
}
#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID)
out0 = do_activation(out0, relux_max_limit, prelu_alpha);
out1 = do_activation(out1, relux_max_limit, prelu_alpha);
out2 = do_activation(out2, relux_max_limit, prelu_alpha);
out3 = do_activation(out3, relux_max_limit, prelu_alpha);
#endif
const short out_x_base = mul24(out_ch_blk, out_width);
short w = out_w_blk;
WRITE_IMAGET(output, (int2)(out_x_base + w, out_hb), out0);
w += 1;
if (w >= out_width) return;
WRITE_IMAGET(output, (int2)(out_x_base + w, out_hb), out1);
w += 1;
if (w >= out_width) return;
WRITE_IMAGET(output, (int2)(out_x_base + w, out_hb), out2);
w += 1;
if (w >= out_width) return;
WRITE_IMAGET(output, (int2)(out_x_base + w, out_hb), out3);
}
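The new depthwise_conv2d_s1 kernel exploits the fact that, at stride 1, four horizontally adjacent outputs overlap in all but one input column: in0/in1/in2/in3 form a register window that shifts left each filter column, so only one new image read is issued per iteration. A scalar C++ analogue of that inner loop, with DATA_TYPE4 vectors reduced to floats, the boundary select() elided, and load/weight as hypothetical accessors:

#include <array>
#include <functional>

std::array<float, 4> SlidingWindowRow(
    int filter_width,
    const std::function<float(int)> &load,      // input pixel at column x
    const std::function<float(int)> &weight) {  // filter tap at column fw
  std::array<float, 4> acc = {0.0f, 0.0f, 0.0f, 0.0f};
  // Preload three columns; the fourth is fetched inside the loop.
  float in0 = 0.0f, in1 = load(0), in2 = load(1), in3 = load(2);
  for (int fw = 0; fw < filter_width; ++fw) {
    in0 = in1; in1 = in2; in2 = in3;  // shift the register window left
    in3 = load(3 + fw);               // the single new read per filter column
    const float w = weight(fw);
    acc[0] += in0 * w;  // output column x + 0 sees input x + fw
    acc[1] += in1 * w;  // x + 1
    acc[2] += in2 * w;  // x + 2
    acc[3] += in3 * w;  // x + 3
  }
  return acc;
}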
......@@ -11,7 +11,8 @@
namespace mace {
namespace kernels {
static void Concat2(const Tensor *input0,
static void Concat2(cl::Kernel *kernel,
const Tensor *input0,
const Tensor *input1,
const DataType dt,
Tensor *output,
......@@ -23,27 +24,29 @@ static void Concat2(const Tensor *input0,
const int channel_blk = RoundUpDiv4(channel);
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("concat_channel");
built_options.emplace("-Dconcat_channel=" + kernel_name);
if (input0->dtype() == output->dtype()) {
built_options.emplace("-DDATA_TYPE=" + DtToCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(dt));
} else {
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
}
if (input0->dim(3) % 4 == 0) {
built_options.emplace("-DDIVISIBLE_FOUR");
}
auto concat_kernel = runtime->BuildKernel("concat", kernel_name, built_options);
if (kernel->get() == nullptr) {
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("concat_channel");
built_options.emplace("-Dconcat_channel=" + kernel_name);
if (input0->dtype() == output->dtype()) {
built_options.emplace("-DDATA_TYPE=" + DtToCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(dt));
} else {
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
}
if (input0->dim(3) % 4 == 0) {
built_options.emplace("-DDIVISIBLE_FOUR");
}
*kernel = runtime->BuildKernel("concat", kernel_name, built_options);
uint32_t idx = 0;
concat_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input0->buffer())));
concat_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input1->buffer())));
concat_kernel.setArg(idx++, static_cast<int32_t>(input0->dim(3)));
concat_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
uint32_t idx = 0;
kernel->setArg(idx++, *(static_cast<const cl::Image2D *>(input0->buffer())));
kernel->setArg(idx++, *(static_cast<const cl::Image2D *>(input1->buffer())));
kernel->setArg(idx++, static_cast<int32_t>(input0->dim(3)));
kernel->setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
}
const uint32_t gws[3] = {
static_cast<uint32_t>(channel_blk),
......@@ -57,7 +60,7 @@ static void Concat2(const Tensor *input0,
<< output->dim(1) << "_"
<< output->dim(2) << "_"
<< output->dim(3);
TuningOrRun3DKernel(concat_kernel, ss.str(), gws, lws, future);
TuningOrRun3DKernel(*kernel, ss.str(), gws, lws, future);
}
template<typename T>
......@@ -90,7 +93,7 @@ void ConcatFunctor<DeviceType::OPENCL, T>::operator()(const std::vector<const Te
switch (inputs_count) {
case 2:
Concat2(input_list[0], input_list[1], DataTypeToEnum<T>::value,
Concat2(&kernel_, input_list[0], input_list[1], DataTypeToEnum<T>::value,
output, future);
break;
default: MACE_NOT_IMPLEMENTED;
......
......@@ -3,64 +3,44 @@
//
#include "mace/kernels/conv_2d.h"
#include "mace/kernels/activation.h"
#include "mace/kernels/opencl/helper.h"
namespace mace {
namespace kernels {
extern void Conv2dOpenclK1x1S1(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const int *padding,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future);
extern void Conv2dOpenclK1x1(cl::Kernel *kernel,
const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const int stride,
const int *padding,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future);
extern void Conv2dOpenclK1x1S2(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const int *padding,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future);
extern void Conv2dOpenclK3x3(cl::Kernel *kernel,
const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const int stride,
const int *padding,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future);
extern void Conv2dOpenclK3x3S1(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const int *padding,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future);
extern void Conv2dOpenclK3x3S2(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const int *padding,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future);
extern void Conv2dOpencl(const Tensor *input,
extern void Conv2dOpencl(cl::Kernel *kernel,
const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const uint32_t stride,
const int stride,
const int *padding,
const int *dilations,
const ActivationType activation,
......@@ -70,24 +50,21 @@ extern void Conv2dOpencl(const Tensor *input,
Tensor *output,
StatsFuture *future);
template <typename T>
template<typename T>
void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
StatsFuture *future) {
typedef void (*Conv2dOpenclFunction)(
const Tensor *input, const Tensor *filter, const Tensor *bias,
cl::Kernel *kernel,
const Tensor *input, const Tensor *filter, const Tensor *bias, const int stride,
const int *padding, const int *dilations, const ActivationType activation,
const float relux_max_limit, const float prelu_alpha, const DataType dt,
Tensor *output, StatsFuture *future);
// Selection matrix: kernel_size x stride_size
static const Conv2dOpenclFunction selector[5][2] = {
{Conv2dOpenclK1x1S1, Conv2dOpenclK1x1S2},
{nullptr, nullptr},
{Conv2dOpenclK3x3S1, Conv2dOpenclK3x3S2},
{nullptr, nullptr},
{nullptr, nullptr}};
static const Conv2dOpenclFunction selector[5] =
{Conv2dOpenclK1x1, nullptr, Conv2dOpenclK3x3, nullptr, nullptr};
index_t kernel_h = filter->dim(0);
index_t kernel_w = filter->dim(1);
......@@ -113,20 +90,23 @@ void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
output->ResizeImage(output_shape, output_image_shape);
if (kernel_h == kernel_w && kernel_h <= 5 &&
selector[kernel_h - 1][strides_[0] - 1] != nullptr) {
auto conv2d_func = selector[kernel_h - 1][strides_[0] - 1];
conv2d_func(input, filter, bias, paddings.data(), dilations_, activation_,
selector[kernel_h - 1] != nullptr &&
0 < strides_[0] && strides_[0] < 3) {
auto conv2d_func = selector[kernel_h - 1];
conv2d_func(&kernel_, input, filter, bias, strides_[0], paddings.data(), dilations_, activation_,
relux_max_limit_, prelu_alpha_, DataTypeToEnum<T>::value,
output, future);
} else {
Conv2dOpencl(input, filter, bias, strides_[0], paddings.data(), dilations_,
Conv2dOpencl(&kernel_, input, filter, bias, strides_[0], paddings.data(), dilations_,
activation_, relux_max_limit_, prelu_alpha_,
DataTypeToEnum<T>::value, output, future);
}
}
template struct Conv2dFunctor<DeviceType::OPENCL, float>;
template struct Conv2dFunctor<DeviceType::OPENCL, half>;
template
struct Conv2dFunctor<DeviceType::OPENCL, float>;
template
struct Conv2dFunctor<DeviceType::OPENCL, half>;
} // namespace kernels
} // namespace mace
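The selection table collapses from kernel_size x stride to kernel_size alone, with stride passed through as a runtime argument (it still reaches the OpenCL build as -DSTRIDE). A self-contained sketch of the new dispatch, with placeholder functions standing in for Conv2dOpenclK1x1 / Conv2dOpenclK3x3 and the generic path:

using ConvFn = void (*)(int stride);

void K1x1(int stride) { /* build/enqueue the 1x1 kernel with -DSTRIDE=stride */ }
void K3x3(int stride) { /* build/enqueue the 3x3 kernel with -DSTRIDE=stride */ }

// Index = kernel_h - 1; nullptr entries fall back to the generic path.
static const ConvFn selector[5] = {K1x1, nullptr, K3x3, nullptr, nullptr};

void Dispatch(int kernel_h, int kernel_w, int stride) {
  const bool specialized = kernel_h == kernel_w && kernel_h <= 5 &&
                           selector[kernel_h - 1] != nullptr &&
                           0 < stride && stride < 3;  // only strides 1 and 2
  if (specialized) {
    selector[kernel_h - 1](stride);
  } else {
    /* generic Conv2dOpencl path */
  }
}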
......@@ -3,26 +3,26 @@
//
#include "mace/kernels/conv_2d.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/activation.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/tuner.h"
#include "mace/utils/utils.h"
namespace mace {
namespace kernels {
void Conv1x1(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const int stride,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future) {
extern void Conv2dOpenclK1x1(cl::Kernel *kernel,
const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const int stride,
const int *padding,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future) {
const index_t batch = output->dim(0);
const index_t height = output->dim(1);
const index_t width = output->dim(2);
......@@ -36,62 +36,64 @@ void Conv1x1(const Tensor *input,
const index_t width_blocks = RoundUpDiv4(width);
const index_t input_channel_blocks = RoundUpDiv4(input_channels);
MACE_CHECK(input_batch == batch);
if (kernel->get() == nullptr) {
MACE_CHECK(input_batch == batch);
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("conv_2d_1x1");
built_options.emplace("-Dconv_2d_1x1=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace("-DSTRIDE=" + ToString(stride));
if (bias != nullptr) {
built_options.emplace("-DBIAS");
}
switch (activation) {
case NOOP:
break;
case RELU:
built_options.emplace("-DUSE_RELU");
break;
case RELUX:
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
built_options.emplace("-DUSE_TANH");
break;
case SIGMOID:
built_options.emplace("-DUSE_SIGMOID");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation;
}
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("conv_2d_1x1");
built_options.emplace("-Dconv_2d_1x1=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace("-DSTRIDE=" + ToString(stride));
if (bias != nullptr) {
built_options.emplace("-DBIAS");
}
switch (activation) {
case NOOP:
break;
case RELU:
built_options.emplace("-DUSE_RELU");
break;
case RELUX:
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
built_options.emplace("-DUSE_TANH");
break;
case SIGMOID:
built_options.emplace("-DUSE_SIGMOID");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation;
}
auto runtime = OpenCLRuntime::Global();
auto conv_2d_kernel =
runtime->BuildKernel("conv_2d_1x1", kernel_name, built_options);
auto runtime = OpenCLRuntime::Global();
*kernel =
runtime->BuildKernel("conv_2d_1x1", kernel_name, built_options);
uint32_t idx = 0;
conv_2d_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(input->buffer())));
conv_2d_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(filter->buffer())));
if (bias != nullptr) {
conv_2d_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(bias->buffer())));
uint32_t idx = 0;
kernel->setArg(idx++,
*(static_cast<const cl::Image2D *>(input->buffer())));
kernel->setArg(idx++,
*(static_cast<const cl::Image2D *>(filter->buffer())));
if (bias != nullptr) {
kernel->setArg(idx++,
*(static_cast<const cl::Image2D *>(bias->buffer())));
}
kernel->setArg(idx++,
*(static_cast<const cl::Image2D *>(output->buffer())));
// FIXME handle flexible data type: half not supported
kernel->setArg(idx++, relux_max_limit);
kernel->setArg(idx++, prelu_alpha);
kernel->setArg(idx++, static_cast<int>(input_height));
kernel->setArg(idx++, static_cast<int>(input_width));
kernel->setArg(idx++, static_cast<int>(input_channel_blocks));
kernel->setArg(idx++, static_cast<int>(height));
kernel->setArg(idx++, static_cast<int>(width));
}
conv_2d_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(output->buffer())));
// FIXME handle flexible data type: half not supported
conv_2d_kernel.setArg(idx++, relux_max_limit);
conv_2d_kernel.setArg(idx++, prelu_alpha);
conv_2d_kernel.setArg(idx++, static_cast<int>(input_height));
conv_2d_kernel.setArg(idx++, static_cast<int>(input_width));
conv_2d_kernel.setArg(idx++, static_cast<int>(input_channel_blocks));
conv_2d_kernel.setArg(idx++, static_cast<int>(height));
conv_2d_kernel.setArg(idx++, static_cast<int>(width));
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width_blocks),
......@@ -100,38 +102,9 @@ void Conv1x1(const Tensor *input,
std::string tuning_key =
Concat("conv2d_1x1_opencl_kernel_", activation, output->dim(0),
output->dim(1), output->dim(2), output->dim(3));
TuningOrRun3DKernel(conv_2d_kernel, tuning_key, gws, lws, future);
TuningOrRun3DKernel(*kernel, tuning_key, gws, lws, future);
}
extern void Conv2dOpenclK1x1S1(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const int *padding,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future) {
Conv1x1(input, filter, bias, 1, activation, relux_max_limit, prelu_alpha, dt,
output, future);
};
extern void Conv2dOpenclK1x1S2(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const int *padding,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future) {
Conv1x1(input, filter, bias, 2, activation, relux_max_limit, prelu_alpha, dt,
output, future);
};
} // namespace kernels
} // namespace mace
......@@ -13,18 +13,19 @@
namespace mace {
namespace kernels {
static void Conv2d3x3S12(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const uint32_t stride,
const int *padding,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future) {
extern void Conv2dOpenclK3x3(cl::Kernel *kernel,
const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const int stride,
const int *padding,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future) {
const index_t batch = output->dim(0);
const index_t height = output->dim(1);
const index_t width = output->dim(2);
......@@ -35,61 +36,63 @@ static void Conv2d3x3S12(const Tensor *input,
const index_t input_channel_blocks = RoundUpDiv4(input_channels);
const index_t width_blocks = RoundUpDiv<index_t, 5>(width);
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("conv_2d_3x3");
built_options.emplace("-Dconv_2d_3x3=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace(bias != nullptr ? "-DBIAS" : "");
built_options.emplace("-DSTRIDE=" + ToString(stride));
switch (activation) {
case NOOP:
break;
case RELU:
built_options.emplace("-DUSE_RELU");
break;
case RELUX:
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
built_options.emplace("-DUSE_TANH");
break;
case SIGMOID:
built_options.emplace("-DUSE_SIGMOID");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation;
}
if (kernel->get() == nullptr) {
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("conv_2d_3x3");
built_options.emplace("-Dconv_2d_3x3=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace(bias != nullptr ? "-DBIAS" : "");
built_options.emplace("-DSTRIDE=" + ToString(stride));
switch (activation) {
case NOOP:
break;
case RELU:
built_options.emplace("-DUSE_RELU");
break;
case RELUX:
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
built_options.emplace("-DUSE_TANH");
break;
case SIGMOID:
built_options.emplace("-DUSE_SIGMOID");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation;
}
auto runtime = OpenCLRuntime::Global();
auto conv_2d_kernel =
runtime->BuildKernel("conv_2d_3x3", kernel_name, built_options);
auto runtime = OpenCLRuntime::Global();
*kernel =
runtime->BuildKernel("conv_2d_3x3", kernel_name, built_options);
uint32_t idx = 0;
conv_2d_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(input->buffer())));
conv_2d_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(filter->buffer())));
if (bias != nullptr) {
conv_2d_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(bias->buffer())));
uint32_t idx = 0;
kernel->setArg(idx++,
*(static_cast<const cl::Image2D *>(input->buffer())));
kernel->setArg(idx++,
*(static_cast<const cl::Image2D *>(filter->buffer())));
if (bias != nullptr) {
kernel->setArg(idx++,
*(static_cast<const cl::Image2D *>(bias->buffer())));
}
kernel->setArg(idx++,
*(static_cast<const cl::Image2D *>(output->buffer())));
kernel->setArg(idx++, relux_max_limit);
kernel->setArg(idx++, prelu_alpha);
kernel->setArg(idx++, static_cast<int>(input->dim(1)));
kernel->setArg(idx++, static_cast<int>(input->dim(2)));
kernel->setArg(idx++, static_cast<int>(input_channel_blocks));
kernel->setArg(idx++, static_cast<int>(height));
kernel->setArg(idx++, static_cast<int>(width));
kernel->setArg(idx++, padding[0] / 2);
kernel->setArg(idx++, padding[1] / 2);
kernel->setArg(idx++, dilations[0]);
kernel->setArg(idx++, dilations[1]);
}
conv_2d_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(output->buffer())));
conv_2d_kernel.setArg(idx++, relux_max_limit);
conv_2d_kernel.setArg(idx++, prelu_alpha);
conv_2d_kernel.setArg(idx++, static_cast<int>(input->dim(1)));
conv_2d_kernel.setArg(idx++, static_cast<int>(input->dim(2)));
conv_2d_kernel.setArg(idx++, static_cast<int>(input_channel_blocks));
conv_2d_kernel.setArg(idx++, static_cast<int>(height));
conv_2d_kernel.setArg(idx++, static_cast<int>(width));
conv_2d_kernel.setArg(idx++, padding[0] / 2);
conv_2d_kernel.setArg(idx++, padding[1] / 2);
conv_2d_kernel.setArg(idx++, dilations[0]);
conv_2d_kernel.setArg(idx++, dilations[1]);
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width_blocks),
......@@ -98,37 +101,8 @@ static void Conv2d3x3S12(const Tensor *input,
std::string tuning_key =
Concat("conv2d_3x3_opencl_kernel_", activation, output->dim(0),
output->dim(1), output->dim(2), output->dim(3));
TuningOrRun3DKernel(conv_2d_kernel, tuning_key, gws, lws, future);
TuningOrRun3DKernel(*kernel, tuning_key, gws, lws, future);
}
void Conv2dOpenclK3x3S1(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const int *padding,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future) {
Conv2d3x3S12(input, filter, bias, 1, padding, dilations, activation,
relux_max_limit, prelu_alpha, dt, output, future);
};
void Conv2dOpenclK3x3S2(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const int *padding,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future) {
Conv2d3x3S12(input, filter, bias, 2, padding, dilations, activation,
relux_max_limit, prelu_alpha, dt, output, future);
};
} // namespace kernels
} // namespace mace
......@@ -13,18 +13,19 @@
namespace mace {
namespace kernels {
void Conv2dOpencl(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const uint32_t stride,
const int *padding,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future) {
extern void Conv2dOpencl(cl::Kernel *kernel,
const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const int stride,
const int *padding,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future) {
const index_t batch = output->dim(0);
const index_t height = output->dim(1);
const index_t width = output->dim(2);
......@@ -35,64 +36,66 @@ void Conv2dOpencl(const Tensor *input,
const index_t input_channel_blocks = RoundUpDiv4(input_channels);
const index_t width_blocks = RoundUpDiv4(width);
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("conv_2d");
built_options.emplace("-Dconv_2d=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace(bias != nullptr ? "-DBIAS" : "");
built_options.emplace("-DSTRIDE=" + ToString(stride));
switch (activation) {
case NOOP:
break;
case RELU:
built_options.emplace("-DUSE_RELU");
break;
case RELUX:
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
built_options.emplace("-DUSE_TANH");
break;
case SIGMOID:
built_options.emplace("-DUSE_SIGMOID");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation;
}
if (kernel->get() == nullptr) {
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("conv_2d");
built_options.emplace("-Dconv_2d=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace(bias != nullptr ? "-DBIAS" : "");
built_options.emplace("-DSTRIDE=" + ToString(stride));
switch (activation) {
case NOOP:
break;
case RELU:
built_options.emplace("-DUSE_RELU");
break;
case RELUX:
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
built_options.emplace("-DUSE_TANH");
break;
case SIGMOID:
built_options.emplace("-DUSE_SIGMOID");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation;
}
auto runtime = OpenCLRuntime::Global();
auto conv_2d_kernel =
runtime->BuildKernel("conv_2d", kernel_name, built_options);
auto runtime = OpenCLRuntime::Global();
*kernel =
runtime->BuildKernel("conv_2d", kernel_name, built_options);
uint32_t idx = 0;
conv_2d_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(input->buffer())));
conv_2d_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(filter->buffer())));
if (bias != nullptr) {
conv_2d_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(bias->buffer())));
uint32_t idx = 0;
kernel->setArg(idx++,
*(static_cast<const cl::Image2D *>(input->buffer())));
kernel->setArg(idx++,
*(static_cast<const cl::Image2D *>(filter->buffer())));
if (bias != nullptr) {
kernel->setArg(idx++,
*(static_cast<const cl::Image2D *>(bias->buffer())));
}
kernel->setArg(idx++,
*(static_cast<const cl::Image2D *>(output->buffer())));
kernel->setArg(idx++, relux_max_limit);
kernel->setArg(idx++, prelu_alpha);
kernel->setArg(idx++, static_cast<uint32_t>(input->dim(1)));
kernel->setArg(idx++, static_cast<uint32_t>(input->dim(2)));
kernel->setArg(idx++, static_cast<uint32_t>(input_channel_blocks));
kernel->setArg(idx++, static_cast<uint32_t>(height));
kernel->setArg(idx++, static_cast<uint32_t>(width));
kernel->setArg(idx++, static_cast<uint32_t>(filter->dim(0)));
kernel->setArg(idx++, static_cast<uint32_t>(filter->dim(1)));
kernel->setArg(idx++, static_cast<uint32_t>(stride));
kernel->setArg(idx++, padding[0] / 2);
kernel->setArg(idx++, padding[1] / 2);
kernel->setArg(idx++, dilations[0]);
kernel->setArg(idx++, dilations[1]);
}
conv_2d_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(output->buffer())));
conv_2d_kernel.setArg(idx++, relux_max_limit);
conv_2d_kernel.setArg(idx++, prelu_alpha);
conv_2d_kernel.setArg(idx++, static_cast<uint32_t>(input->dim(1)));
conv_2d_kernel.setArg(idx++, static_cast<uint32_t>(input->dim(2)));
conv_2d_kernel.setArg(idx++, static_cast<uint32_t>(input_channel_blocks));
conv_2d_kernel.setArg(idx++, static_cast<uint32_t>(height));
conv_2d_kernel.setArg(idx++, static_cast<uint32_t>(width));
conv_2d_kernel.setArg(idx++, static_cast<uint32_t>(filter->dim(0)));
conv_2d_kernel.setArg(idx++, static_cast<uint32_t>(filter->dim(1)));
conv_2d_kernel.setArg(idx++, static_cast<uint32_t>(stride));
conv_2d_kernel.setArg(idx++, padding[0] / 2);
conv_2d_kernel.setArg(idx++, padding[1] / 2);
conv_2d_kernel.setArg(idx++, dilations[0]);
conv_2d_kernel.setArg(idx++, dilations[1]);
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width_blocks),
......@@ -101,7 +104,7 @@ void Conv2dOpencl(const Tensor *input,
std::string tuning_key =
Concat("conv2d_general_opencl_kernel_", activation, output->dim(0),
output->dim(1), output->dim(2), output->dim(3));
TuningOrRun3DKernel(conv_2d_kernel, tuning_key, gws, lws, future);
TuningOrRun3DKernel(*kernel, tuning_key, gws, lws, future);
}
} // namespace kernels
......
......@@ -11,7 +11,8 @@
namespace mace {
namespace kernels {
void DepthwiseConv2d(const Tensor *input, // NHWC
void DepthwiseConv2d(cl::Kernel *kernel,
const Tensor *input, // NHWC
const Tensor *filter, // HWIM
const Tensor *bias,
const int stride,
......@@ -28,80 +29,88 @@ void DepthwiseConv2d(const Tensor *input, // NHWC
const index_t width = output->dim(2);
const index_t channels = output->dim(3);
const index_t input_batch = input->dim(0);
const index_t input_height = input->dim(1);
const index_t input_width = input->dim(2);
const index_t input_channels = input->dim(3);
const index_t filter_height = filter->dim(0);
const index_t filter_width = filter->dim(1);
const index_t multiplier = filter->dim(3);
MACE_CHECK(multiplier == 1, "Multiplier > 1 not supported");
MACE_CHECK(multiplier * input_channels == channels);
MACE_CHECK(filter->dim(2) == input_channels, filter->dim(2), "!=",
input_channels);
const index_t channel_blocks = RoundUpDiv4(channels);
const index_t input_channel_blocks = RoundUpDiv4(input_channels);
const index_t width_blocks = RoundUpDiv4(width);
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("depthwise_conv2d");
built_options.emplace("-Ddepthwise_conv2d=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace(bias != nullptr ? "-DBIAS" : "");
built_options.emplace("-DSTRIDE=" + ToString(stride));
switch (activation) {
case NOOP:
break;
case RELU:
built_options.emplace("-DUSE_RELU");
break;
case RELUX:
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
built_options.emplace("-DUSE_TANH");
break;
case SIGMOID:
built_options.emplace("-DUSE_SIGMOID");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation;
}
auto dw_conv2d_kernel =
runtime->BuildKernel("depthwise_conv2d", kernel_name, built_options);
uint32_t idx = 0;
dw_conv2d_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(input->buffer())));
dw_conv2d_kernel.setArg(
idx++, *(static_cast<const cl::Image2D *>(filter->buffer())));
if (bias != nullptr) {
dw_conv2d_kernel.setArg(
idx++, *(static_cast<const cl::Image2D *>(bias->buffer())));
if (kernel->get() == nullptr) {
const index_t input_batch = input->dim(0);
const index_t input_height = input->dim(1);
const index_t input_width = input->dim(2);
const index_t filter_height = filter->dim(0);
const index_t filter_width = filter->dim(1);
MACE_CHECK(multiplier == 1, "Multiplier > 1 not supported");
MACE_CHECK(multiplier * input_channels == channels);
MACE_CHECK(filter->dim(2) == input_channels, filter->dim(2), "!=",
input_channels);
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("depthwise_conv2d");
if (stride == 1 && dilations[0] == 1 && dilations[1] == 1) {
kernel_name = MACE_OBFUSCATE_SYMBOL("depthwise_conv2d_s1");
built_options.emplace("-Ddepthwise_conv2d_s1=" + kernel_name);
} else {
built_options.emplace("-Ddepthwise_conv2d=" + kernel_name);
}
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace(bias != nullptr ? "-DBIAS" : "");
built_options.emplace("-DSTRIDE=" + ToString(stride));
switch (activation) {
case NOOP:
break;
case RELU:
built_options.emplace("-DUSE_RELU");
break;
case RELUX:
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
built_options.emplace("-DUSE_TANH");
break;
case SIGMOID:
built_options.emplace("-DUSE_SIGMOID");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation;
}
*kernel = runtime->BuildKernel("depthwise_conv2d", kernel_name, built_options);
uint32_t idx = 0;
kernel->setArg(idx++,
*(static_cast<const cl::Image2D *>(input->buffer())));
kernel->setArg(
idx++, *(static_cast<const cl::Image2D *>(filter->buffer())));
if (bias != nullptr) {
kernel->setArg(
idx++, *(static_cast<const cl::Image2D *>(bias->buffer())));
}
kernel->setArg(
idx++, *(static_cast<const cl::Image2D *>(output->buffer())));
kernel->setArg(idx++, relux_max_limit);
kernel->setArg(idx++, prelu_alpha);
kernel->setArg(idx++, static_cast<short>(input_height));
kernel->setArg(idx++, static_cast<short>(input_width));
kernel->setArg(idx++, static_cast<short>(input_channel_blocks));
kernel->setArg(idx++, static_cast<short>(height));
kernel->setArg(idx++, static_cast<short>(width));
kernel->setArg(idx++, static_cast<short>(filter_height));
kernel->setArg(idx++, static_cast<short>(filter_width));
kernel->setArg(idx++, static_cast<short>(paddings[0] / 2));
kernel->setArg(idx++, static_cast<short>(paddings[1] / 2));
if (stride != 1 || dilations[0] != 1 || dilations[1] != 1) {
kernel->setArg(idx++, static_cast<short>(dilations[0]));
kernel->setArg(idx++, static_cast<short>(dilations[1]));
}
}
dw_conv2d_kernel.setArg(
idx++, *(static_cast<const cl::Image2D *>(output->buffer())));
dw_conv2d_kernel.setArg(idx++, relux_max_limit);
dw_conv2d_kernel.setArg(idx++, prelu_alpha);
dw_conv2d_kernel.setArg(idx++, static_cast<short>(input_height));
dw_conv2d_kernel.setArg(idx++, static_cast<short>(input_width));
dw_conv2d_kernel.setArg(idx++, static_cast<short>(input_channel_blocks));
dw_conv2d_kernel.setArg(idx++, static_cast<short>(height));
dw_conv2d_kernel.setArg(idx++, static_cast<short>(width));
dw_conv2d_kernel.setArg(idx++, static_cast<short>(filter_height));
dw_conv2d_kernel.setArg(idx++, static_cast<short>(filter_width));
dw_conv2d_kernel.setArg(idx++, static_cast<short>(paddings[0] / 2));
dw_conv2d_kernel.setArg(idx++, static_cast<short>(paddings[1] / 2));
dw_conv2d_kernel.setArg(idx++, static_cast<short>(dilations[0]));
dw_conv2d_kernel.setArg(idx++, static_cast<short>(dilations[1]));
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width_blocks),
......@@ -109,7 +118,7 @@ void DepthwiseConv2d(const Tensor *input, // NHWC
const std::vector<uint32_t> lws = {8, 16, 8, 1};
std::string tuning_key = Concat("depthwise_conv2d_ocl_kernel_", activation,
batch, height, width, channels, multiplier);
TuningOrRun3DKernel(dw_conv2d_kernel, tuning_key, gws, lws, future);
TuningOrRun3DKernel(*kernel, tuning_key, gws, lws, future);
}
template <typename T>
......@@ -153,7 +162,7 @@ void DepthwiseConv2dFunctor<DeviceType::OPENCL, T>::operator()(
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, output_image_shape);
output->ResizeImage(output_shape, output_image_shape);
DepthwiseConv2d(input, filter, bias, strides_[0], paddings.data(), dilations_,
DepthwiseConv2d(&kernel_, input, filter, bias, strides_[0], paddings.data(), dilations_,
activation_, relux_max_limit_, prelu_alpha_,
DataTypeToEnum<T>::value, output, future);
}
......
......@@ -121,18 +121,24 @@ std::vector<index_t> CalWinogradShape(const std::vector<index_t> &shape,
std::string DtToCLDt(const DataType dt) {
switch (dt) {
case DT_FLOAT:return "float";
case DT_HALF:return "half";
default:LOG(FATAL) << "Unsupported data type";
case DT_FLOAT:
return "float";
case DT_HALF:
return "half";
default:
LOG(FATAL) << "Unsupported data type";
return "";
}
}
std::string DtToCLCMDDt(const DataType dt) {
switch (dt) {
case DT_FLOAT:return "f";
case DT_HALF:return "h";
default:LOG(FATAL) << "Not supported data type for opencl cmd data type";
case DT_FLOAT:
return "f";
case DT_HALF:
return "h";
default:
LOG(FATAL) << "Not supported data type for opencl cmd data type";
return "";
}
}
......@@ -140,8 +146,10 @@ std::string DtToCLCMDDt(const DataType dt) {
std::string DtToUpstreamCLDt(const DataType dt) {
switch (dt) {
case DT_FLOAT:
case DT_HALF:return "float";
default:LOG(FATAL) << "Unsupported data type";
case DT_HALF:
return "float";
default:
LOG(FATAL) << "Unsupported data type";
return "";
}
}
......@@ -149,8 +157,10 @@ std::string DtToUpstreamCLDt(const DataType dt) {
std::string DtToUpstreamCLCMDDt(const DataType dt) {
switch (dt) {
case DT_FLOAT:
case DT_HALF:return "f";
default:LOG(FATAL) << "Not supported data type for opencl cmd data type";
case DT_HALF:
return "f";
default:
LOG(FATAL) << "Not supported data type for opencl cmd data type";
return "";
}
}
......@@ -161,8 +171,8 @@ void TuningOrRun3DKernel(cl::Kernel &kernel,
const std::vector<uint32_t> &lws,
StatsFuture *future) {
auto runtime = OpenCLRuntime::Global();
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(kernel);
auto params_generator = [&]() -> std::vector<std::vector<uint32_t>> {
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(kernel);
std::vector<uint32_t> local_ws(3, 0);
local_ws[0] = std::min<uint32_t>(gws[0], kwg_size);
local_ws[1] = std::min<uint32_t>(gws[1], kwg_size / local_ws[0]);
......@@ -258,8 +268,8 @@ void TuningOrRun2DKernel(cl::Kernel &kernel,
const std::vector<uint32_t> &lws,
StatsFuture *future) {
auto runtime = OpenCLRuntime::Global();
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(kernel);
auto params_generator = [&]() -> std::vector<std::vector<uint32_t>> {
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(kernel);
uint32_t local_ws[2];
local_ws[0] = std::min<uint32_t>(gws[0], kwg_size);
local_ws[1] = std::min<uint32_t>(gws[1], kwg_size / local_ws[0]);
......
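Both tuning helpers seed their search with the same greedy default: each dimension takes as much of the kernel's work-group budget as its global size allows, and later dimensions divide by what is already spent. For example, with kwg_size = 256 and gws = {64, 16, 100}: local_ws[0] = min(64, 256) = 64, local_ws[1] = min(16, 256 / 64) = 4, and (in the elided third step, which presumably follows the same rule) local_ws[2] = min(100, 256 / (64 * 4)) = 1.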
......@@ -29,26 +29,28 @@ void MatMulFunctor<DeviceType::OPENCL, T>::operator()(
const index_t height_blocks = RoundUpDiv4(height);
const index_t width_blocks = RoundUpDiv4(width);
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("matmul");
built_options.emplace("-Dmatmul=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
auto matmul_kernel = runtime->BuildKernel("matmul", kernel_name, built_options);
if (kernel_.get() == nullptr) {
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("matmul");
built_options.emplace("-Dmatmul=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
kernel_ = runtime->BuildKernel("matmul", kernel_name, built_options);
uint32_t idx = 0;
matmul_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(A->buffer())));
matmul_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(B->buffer())));
matmul_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(C->buffer())));
matmul_kernel.setArg(idx++, static_cast<int>(height));
matmul_kernel.setArg(idx++, static_cast<int>(width));
matmul_kernel.setArg(idx++, static_cast<int>(A->dim(2)));
matmul_kernel.setArg(idx++, static_cast<int>(height_blocks));
matmul_kernel.setArg(idx++, static_cast<int>(RoundUpDiv4(A->dim(2))));
uint32_t idx = 0;
kernel_.setArg(idx++,
*(static_cast<const cl::Image2D *>(A->buffer())));
kernel_.setArg(idx++,
*(static_cast<const cl::Image2D *>(B->buffer())));
kernel_.setArg(idx++, *(static_cast<cl::Image2D *>(C->buffer())));
kernel_.setArg(idx++, static_cast<int>(height));
kernel_.setArg(idx++, static_cast<int>(width));
kernel_.setArg(idx++, static_cast<int>(A->dim(2)));
kernel_.setArg(idx++, static_cast<int>(height_blocks));
kernel_.setArg(idx++, static_cast<int>(RoundUpDiv4(A->dim(2))));
}
const uint32_t gws[2] = {
static_cast<uint32_t>(width_blocks),
......@@ -61,7 +63,7 @@ void MatMulFunctor<DeviceType::OPENCL, T>::operator()(
<< C->dim(1) << "_"
<< C->dim(2) << "_"
<< C->dim(3);
TuningOrRun2DKernel(matmul_kernel, ss.str(), gws, lws, future);
TuningOrRun2DKernel(kernel_, ss.str(), gws, lws, future);
};
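Note that moving the setArg calls inside the kernel_.get() == nullptr branch binds the input and output images exactly once per functor instance. That is sound for a fixed graph, where each op instance always sees the same tensors and shapes; a functor that had to accept varying buffers would instead re-bind on every call, roughly:
// Sketch: per-call argument binding, if A/B/C could change between calls.
uint32_t idx = 0;
kernel_.setArg(idx++, *(static_cast<const cl::Image2D *>(A->buffer())));
kernel_.setArg(idx++, *(static_cast<const cl::Image2D *>(B->buffer())));
kernel_.setArg(idx++, *(static_cast<cl::Image2D *>(C->buffer())));
// ... shape arguments as above ...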
......
......@@ -11,68 +11,6 @@
namespace mace {
namespace kernels {
static void Pooling(const Tensor *input,
const int *stride,
const int *paddings,
const int pooling_size,
const PoolingType type,
const DataType dt,
Tensor *output,
StatsFuture *future) {
index_t batch = output->dim(0);
index_t out_height = output->dim(1);
index_t out_width = output->dim(2);
index_t channels = output->dim(3);
index_t channel_blocks = (channels + 3) / 4;
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("pooling");
built_options.emplace("-Dpooling=" + kernel_name);
if (type == MAX && input->dtype() == output->dtype()) {
built_options.emplace("-DDATA_TYPE=" + DtToCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(dt));
built_options.emplace(dt == DT_HALF ? "-DFP16" : "");
} else {
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
}
if (type == AVG) {
built_options.emplace("-DPOOL_AVG");
}
auto pooling_kernel = runtime->BuildKernel("pooling", kernel_name, built_options);
uint32_t idx = 0;
pooling_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
pooling_kernel.setArg(idx++, static_cast<int32_t>(input->dim(1)));
pooling_kernel.setArg(idx++, static_cast<int32_t>(input->dim(2)));
pooling_kernel.setArg(idx++, static_cast<int32_t>(out_height));
pooling_kernel.setArg(idx++, paddings[0] / 2);
pooling_kernel.setArg(idx++, paddings[1] / 2);
pooling_kernel.setArg(idx++, stride[0]);
pooling_kernel.setArg(idx++, pooling_size);
pooling_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
const uint32_t gws[3] = {
static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(out_width),
static_cast<uint32_t>(batch * out_height),
};
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(pooling_kernel);
std::vector<uint32_t> lws(4, 1);
lws[0] = std::min<uint32_t>(channel_blocks, kwg_size);
lws[1] = std::min<uint32_t>(out_width, kwg_size / lws[0]);
lws[2] = std::min<uint32_t>(out_height * batch, kwg_size / (lws[0] * lws[1]));
std::stringstream ss;
ss << "pooling_opencl_kernel_"
<< output->dim(0) << "_"
<< output->dim(1) << "_"
<< output->dim(2) << "_"
<< output->dim(3);
TuningOrRun3DKernel(pooling_kernel, ss.str(), gws, lws, future);
}
template<typename T>
void PoolingFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
Tensor *output,
......@@ -95,8 +33,57 @@ void PoolingFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, output_image_shape);
output->ResizeImage(output_shape, output_image_shape);
Pooling(input, strides_, paddings.data(), kernels_[0], pooling_type_,
DataTypeToEnum<T>::value, output, future);
index_t batch = output->dim(0);
index_t out_height = output->dim(1);
index_t out_width = output->dim(2);
index_t channels = output->dim(3);
index_t channel_blocks = (channels + 3) / 4;
if (kernel_.get() == nullptr) {
const DataType dt = DataTypeToEnum<T>::value;
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("pooling");
built_options.emplace("-Dpooling=" + kernel_name);
if (pooling_type_ == MAX && input->dtype() == output->dtype()) {
built_options.emplace("-DDATA_TYPE=" + DtToCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(dt));
built_options.emplace(dt == DT_HALF ? "-DFP16" : "");
} else {
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
}
if (pooling_type_ == AVG) {
built_options.emplace("-DPOOL_AVG");
}
kernel_ = runtime->BuildKernel("pooling", kernel_name, built_options);
uint32_t idx = 0;
kernel_.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
kernel_.setArg(idx++, static_cast<int32_t>(input->dim(1)));
kernel_.setArg(idx++, static_cast<int32_t>(input->dim(2)));
kernel_.setArg(idx++, static_cast<int32_t>(out_height));
kernel_.setArg(idx++, paddings[0] / 2);
kernel_.setArg(idx++, paddings[1] / 2);
kernel_.setArg(idx++, strides_[0]);
kernel_.setArg(idx++, kernels_[0]);
kernel_.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
}
const uint32_t gws[3] = {
static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(out_width),
static_cast<uint32_t>(batch * out_height),
};
std::vector<uint32_t> lws = {8, 16, 8, 1};
std::stringstream ss;
ss << "pooling_opencl_kernel_"
<< output->dim(0) << "_"
<< output->dim(1) << "_"
<< output->dim(2) << "_"
<< output->dim(3);
TuningOrRun3DKernel(kernel_, ss.str(), gws, lws, future);
}
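The tuning key concatenates the kernel name with the four output dimensions, so TuningOrRun3DKernel caches one tuned work-group configuration per distinct output shape and reuses it across runs. For an output of 1x112x112x64 (an illustrative shape), the stringstream above expands to "pooling_opencl_kernel_1_112_112_64".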
......
......@@ -21,40 +21,42 @@ void ResizeBilinearFunctor<DeviceType::OPENCL, T>::operator()(
const index_t channels = input->dim(3);
const index_t channel_blocks = RoundUpDiv4(channels);
const index_t out_height = out_height_;
const index_t out_width = out_width_;
index_t out_height = out_height_;
index_t out_width = out_width_;
MACE_CHECK(out_height > 0 && out_width > 0);
std::vector<index_t> output_shape {batch, out_height, out_width, channels};
if (input->is_image()) {
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, output_image_shape);
output->ResizeImage(output_shape, output_image_shape);
} else {
output->Resize(output_shape);
}
if (kernel_.get() == nullptr) {
MACE_CHECK(out_height > 0 && out_width > 0);
std::vector<index_t> output_shape{batch, out_height, out_width, channels};
if (input->is_image()) {
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, output_image_shape);
output->ResizeImage(output_shape, output_image_shape);
} else {
output->Resize(output_shape);
}
float height_scale =
CalculateResizeScale(in_height, out_height, align_corners_);
float width_scale = CalculateResizeScale(in_width, out_width, align_corners_);
float height_scale =
CalculateResizeScale(in_height, out_height, align_corners_);
float width_scale = CalculateResizeScale(in_width, out_width, align_corners_);
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("resize_bilinear_nocache");
built_options.emplace("-Dresize_bilinear_nocache=" + kernel_name);
auto dt = DataTypeToEnum<T>::value;
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
auto rb_kernel = runtime->BuildKernel("resize_bilinear", kernel_name, built_options);
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("resize_bilinear_nocache");
built_options.emplace("-Dresize_bilinear_nocache=" + kernel_name);
auto dt = DataTypeToEnum<T>::value;
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
kernel_ = runtime->BuildKernel("resize_bilinear", kernel_name, built_options);
uint32_t idx = 0;
rb_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
rb_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
rb_kernel.setArg(idx++, height_scale);
rb_kernel.setArg(idx++, width_scale);
rb_kernel.setArg(idx++, static_cast<int32_t>(in_height));
rb_kernel.setArg(idx++, static_cast<int32_t>(in_width));
rb_kernel.setArg(idx++, static_cast<int32_t>(out_height));
uint32_t idx = 0;
kernel_.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
kernel_.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
kernel_.setArg(idx++, height_scale);
kernel_.setArg(idx++, width_scale);
kernel_.setArg(idx++, static_cast<int32_t>(in_height));
kernel_.setArg(idx++, static_cast<int32_t>(in_width));
kernel_.setArg(idx++, static_cast<int32_t>(out_height));
}
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(out_width),
......@@ -66,7 +68,7 @@ void ResizeBilinearFunctor<DeviceType::OPENCL, T>::operator()(
<< output->dim(1) << "_"
<< output->dim(2) << "_"
<< output->dim(3);
TuningOrRun3DKernel(rb_kernel, ss.str(), gws, lws, future);
TuningOrRun3DKernel(kernel_, ss.str(), gws, lws, future);
}
template struct ResizeBilinearFunctor<DeviceType::OPENCL, float>;
......
......@@ -23,21 +23,23 @@ void SoftmaxFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *logits,
const index_t channel_blocks = RoundUpDiv4(channels);
const int remain_channels = channel_blocks * 4 - channels;
auto runtime = OpenCLRuntime::Global();
if (kernel_.get() == nullptr) {
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("softmax");
built_options.emplace("-Dsoftmax=" + kernel_name);
auto dt = DataTypeToEnum<T>::value;
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
cl::Kernel softmax_kernel = runtime->BuildKernel("softmax", kernel_name, built_options);
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("softmax");
built_options.emplace("-Dsoftmax=" + kernel_name);
auto dt = DataTypeToEnum<T>::value;
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
kernel_ = runtime->BuildKernel("softmax", kernel_name, built_options);
uint32_t idx = 0;
softmax_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(logits->buffer())));
softmax_kernel.setArg(idx++, static_cast<int>(channels));
softmax_kernel.setArg(idx++, remain_channels);
softmax_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
uint32_t idx = 0;
kernel_.setArg(idx++, *(static_cast<const cl::Image2D *>(logits->buffer())));
kernel_.setArg(idx++, static_cast<int>(channels));
kernel_.setArg(idx++, remain_channels);
kernel_.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
}
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width),
static_cast<uint32_t>(height * batch)};
......@@ -48,7 +50,7 @@ void SoftmaxFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *logits,
<< output->dim(1) << "_"
<< output->dim(2) << "_"
<< output->dim(3);
TuningOrRun3DKernel(softmax_kernel, ss.str(), gws, lws, future);
TuningOrRun3DKernel(kernel_, ss.str(), gws, lws, future);
}
template
......
......@@ -20,9 +20,9 @@ void SpaceToBatchFunctor<DeviceType::OPENCL, T>::operator()(Tensor *space_tensor
const std::vector<index_t> &output_shape,
Tensor *batch_tensor,
StatsFuture *future) {
const char *kernel_name = nullptr;
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, output_image_shape);
const char *kernel_name = nullptr;
if (b2s_) {
space_tensor->ResizeImage(output_shape, output_image_shape);
kernel_name = "batch_to_space";
......@@ -30,32 +30,34 @@ void SpaceToBatchFunctor<DeviceType::OPENCL, T>::operator()(Tensor *space_tensor
batch_tensor->ResizeImage(output_shape, output_image_shape);
kernel_name = "space_to_batch";
}
std::string obfuscated_kernel_name = MACE_OBFUSCATE_SYMBOL(kernel_name);
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
std::stringstream kernel_name_ss;
kernel_name_ss << "-D" << kernel_name << "=" << obfuscated_kernel_name;
built_options.emplace(kernel_name_ss.str());
built_options.emplace("-DDATA_TYPE=" + DtToCLDt(DataTypeToEnum<T>::value));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(DataTypeToEnum<T>::value));
auto s2b_kernel = runtime->BuildKernel("space_to_batch", kernel_name, built_options);
if (kernel_.get() == nullptr) {
std::string obfuscated_kernel_name = MACE_OBFUSCATE_SYMBOL(kernel_name);
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
std::stringstream kernel_name_ss;
kernel_name_ss << "-D" << kernel_name << "=" << obfuscated_kernel_name;
built_options.emplace(kernel_name_ss.str());
built_options.emplace("-DDATA_TYPE=" + DtToCLDt(DataTypeToEnum<T>::value));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(DataTypeToEnum<T>::value));
kernel_ = runtime->BuildKernel("space_to_batch", kernel_name, built_options);
uint32_t idx = 0;
if (b2s_) {
s2b_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(batch_tensor->buffer())));
s2b_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(space_tensor->buffer())));
} else {
s2b_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(space_tensor->buffer())));
s2b_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(batch_tensor->buffer())));
uint32_t idx = 0;
if (b2s_) {
kernel_.setArg(idx++, *(static_cast<const cl::Image2D *>(batch_tensor->buffer())));
kernel_.setArg(idx++, *(static_cast<cl::Image2D *>(space_tensor->buffer())));
} else {
kernel_.setArg(idx++, *(static_cast<const cl::Image2D *>(space_tensor->buffer())));
kernel_.setArg(idx++, *(static_cast<cl::Image2D *>(batch_tensor->buffer())));
}
kernel_.setArg(idx++, block_shape_[0]);
kernel_.setArg(idx++, block_shape_[1]);
kernel_.setArg(idx++, paddings_[0]);
kernel_.setArg(idx++, paddings_[2]);
kernel_.setArg(idx++, static_cast<int32_t>(space_tensor->dim(1)));
kernel_.setArg(idx++, static_cast<int32_t>(space_tensor->dim(2)));
kernel_.setArg(idx++, static_cast<int32_t>(batch_tensor->dim(1)));
kernel_.setArg(idx++, static_cast<int32_t>(batch_tensor->dim(2)));
}
s2b_kernel.setArg(idx++, block_shape_[0]);
s2b_kernel.setArg(idx++, block_shape_[1]);
s2b_kernel.setArg(idx++, paddings_[0]);
s2b_kernel.setArg(idx++, paddings_[2]);
s2b_kernel.setArg(idx++, static_cast<int32_t>(space_tensor->dim(1)));
s2b_kernel.setArg(idx++, static_cast<int32_t>(space_tensor->dim(2)));
s2b_kernel.setArg(idx++, static_cast<int32_t>(batch_tensor->dim(1)));
s2b_kernel.setArg(idx++, static_cast<int32_t>(batch_tensor->dim(2)));
const uint32_t chan_blk = RoundUpDiv4<uint32_t>(batch_tensor->dim(3));
const uint32_t gws[3] = {chan_blk,
......@@ -68,7 +70,7 @@ void SpaceToBatchFunctor<DeviceType::OPENCL, T>::operator()(Tensor *space_tensor
<< batch_tensor->dim(1) << "_"
<< batch_tensor->dim(2) << "_"
<< batch_tensor->dim(3);
TuningOrRun3DKernel(s2b_kernel, ss.str(), gws, lws, future);
TuningOrRun3DKernel(kernel_, ss.str(), gws, lws, future);
}
template struct SpaceToBatchFunctor<DeviceType::OPENCL, float>;
......
......@@ -25,31 +25,34 @@ void WinogradTransformFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *i
const index_t round_h = (output_shape[1] + 1) / 2;
const index_t round_w = (output_shape[2] + 1) / 2;
const index_t out_width = input_tensor->dim(0) * round_h * round_w;
output_shape = {16, input_tensor->dim(3), out_width, 1};
std::vector<size_t> image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_HEIGHT, image_shape);
output_tensor->ResizeImage(output_shape, image_shape);
string obfuscated_kernel_name = MACE_OBFUSCATE_SYMBOL("winograd_transform_2x2");
std::set<std::string> built_options;
built_options.emplace("-Dwinograd_transform_2x2=" + obfuscated_kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(DataTypeToEnum<T>::value));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(DataTypeToEnum<T>::value));
auto runtime = OpenCLRuntime::Global();
auto wino_kernel = runtime->BuildKernel("winograd_transform",
obfuscated_kernel_name,
built_options);
uint32_t idx = 0;
wino_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input_tensor->buffer())));
wino_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output_tensor->buffer())));
wino_kernel.setArg(idx++, static_cast<uint32_t>(input_tensor->dim(1)));
wino_kernel.setArg(idx++, static_cast<uint32_t>(input_tensor->dim(2)));
wino_kernel.setArg(idx++, static_cast<uint32_t>(input_tensor->dim(3)));
wino_kernel.setArg(idx++, static_cast<uint32_t>(round_h * round_w));
wino_kernel.setArg(idx++, static_cast<uint32_t>(round_w));
wino_kernel.setArg(idx++, static_cast<uint32_t>(paddings[0] / 2));
wino_kernel.setArg(idx++, static_cast<uint32_t>(paddings[1] / 2));
if (kernel_.get() == nullptr) {
output_shape = {16, input_tensor->dim(3), out_width, 1};
std::vector<size_t> image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_HEIGHT, image_shape);
output_tensor->ResizeImage(output_shape, image_shape);
string obfuscated_kernel_name = MACE_OBFUSCATE_SYMBOL("winograd_transform_2x2");
std::set<std::string> built_options;
built_options.emplace("-Dwinograd_transform_2x2=" + obfuscated_kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(DataTypeToEnum<T>::value));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(DataTypeToEnum<T>::value));
auto runtime = OpenCLRuntime::Global();
kernel_ = runtime->BuildKernel("winograd_transform",
obfuscated_kernel_name,
built_options);
uint32_t idx = 0;
kernel_.setArg(idx++, *(static_cast<const cl::Image2D *>(input_tensor->buffer())));
kernel_.setArg(idx++, *(static_cast<cl::Image2D *>(output_tensor->buffer())));
kernel_.setArg(idx++, static_cast<uint32_t>(input_tensor->dim(1)));
kernel_.setArg(idx++, static_cast<uint32_t>(input_tensor->dim(2)));
kernel_.setArg(idx++, static_cast<uint32_t>(input_tensor->dim(3)));
kernel_.setArg(idx++, static_cast<uint32_t>(round_h * round_w));
kernel_.setArg(idx++, static_cast<uint32_t>(round_w));
kernel_.setArg(idx++, static_cast<uint32_t>(paddings[0] / 2));
kernel_.setArg(idx++, static_cast<uint32_t>(paddings[1] / 2));
}
const uint32_t gws[2] = {static_cast<size_t>(out_width),
static_cast<size_t>(RoundUpDiv4(input_tensor->dim(3)))};
......@@ -60,7 +63,7 @@ void WinogradTransformFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *i
<< input_tensor->dim(1) << "_"
<< input_tensor->dim(2) << "_"
<< input_tensor->dim(3);
TuningOrRun2DKernel(wino_kernel, ss.str(), gws, lws, future);
TuningOrRun2DKernel(kernel_, ss.str(), gws, lws, future);
}
template<typename T>
......@@ -73,53 +76,55 @@ void WinogradInverseTransformFunctor<DeviceType::OPENCL, T>::operator()(const Te
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, image_shape);
output_tensor->ResizeImage(output_shape, image_shape);
string obfuscated_kernel_name = MACE_OBFUSCATE_SYMBOL("winograd_inverse_transform_2x2");
std::set<std::string> built_options;
built_options.emplace("-Dwinograd_inverse_transform_2x2=" + obfuscated_kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(DataTypeToEnum<T>::value));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(DataTypeToEnum<T>::value));
built_options.emplace(bias != nullptr ? "-DBIAS" : "");
switch (activation_) {
case NOOP:
break;
case RELU:
built_options.emplace("-DUSE_RELU");
break;
case RELUX:
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
built_options.emplace("-DUSE_TANH");
break;
case SIGMOID:
built_options.emplace("-DUSE_SIGMOID");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation_;
}
auto runtime = OpenCLRuntime::Global();
auto wino_kernel = runtime->BuildKernel("winograd_transform",
obfuscated_kernel_name,
built_options);
const uint32_t round_h = (height_ + 1) / 2;
const uint32_t round_w = (width_ + 1) / 2;
uint32_t idx = 0;
wino_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input_tensor->buffer())));
if (bias != nullptr) {
wino_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(bias->buffer())));
if (kernel_.get() == nullptr) {
string obfuscated_kernel_name = MACE_OBFUSCATE_SYMBOL("winograd_inverse_transform_2x2");
std::set<std::string> built_options;
built_options.emplace("-Dwinograd_inverse_transform_2x2=" + obfuscated_kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(DataTypeToEnum<T>::value));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(DataTypeToEnum<T>::value));
built_options.emplace(bias != nullptr ? "-DBIAS" : "");
switch (activation_) {
case NOOP:
break;
case RELU:
built_options.emplace("-DUSE_RELU");
break;
case RELUX:
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
built_options.emplace("-DUSE_TANH");
break;
case SIGMOID:
built_options.emplace("-DUSE_SIGMOID");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation_;
}
auto runtime = OpenCLRuntime::Global();
kernel_ = runtime->BuildKernel("winograd_transform",
obfuscated_kernel_name,
built_options);
const uint32_t round_h = (height_ + 1) / 2;
const uint32_t round_w = (width_ + 1) / 2;
uint32_t idx = 0;
kernel_.setArg(idx++, *(static_cast<const cl::Image2D *>(input_tensor->buffer())));
if (bias != nullptr) {
kernel_.setArg(idx++, *(static_cast<const cl::Image2D *>(bias->buffer())));
}
kernel_.setArg(idx++, *(static_cast<cl::Image2D *>(output_tensor->buffer())));
kernel_.setArg(idx++, static_cast<uint32_t>(output_shape[1]));
kernel_.setArg(idx++, static_cast<uint32_t>(output_shape[2]));
kernel_.setArg(idx++, static_cast<uint32_t>(round_h * round_w));
kernel_.setArg(idx++, static_cast<uint32_t>(round_w));
kernel_.setArg(idx++, relux_max_limit_);
kernel_.setArg(idx++, prelu_alpha_);
}
wino_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output_tensor->buffer())));
wino_kernel.setArg(idx++, static_cast<uint32_t>(output_shape[1]));
wino_kernel.setArg(idx++, static_cast<uint32_t>(output_shape[2]));
wino_kernel.setArg(idx++, static_cast<uint32_t>(round_h * round_w));
wino_kernel.setArg(idx++, static_cast<uint32_t>(round_w));
wino_kernel.setArg(idx++, relux_max_limit_);
wino_kernel.setArg(idx++, prelu_alpha_);
const uint32_t gws[2] = {static_cast<size_t>(input_tensor->dim(2)),
static_cast<size_t>(RoundUpDiv4(input_tensor->dim(1)))};
......@@ -131,7 +136,7 @@ void WinogradInverseTransformFunctor<DeviceType::OPENCL, T>::operator()(const Te
<< input_tensor->dim(1) << "_"
<< input_tensor->dim(2) << "_"
<< input_tensor->dim(3);
TuningOrRun2DKernel(wino_kernel, ss.str(), gws, lws, future);
TuningOrRun2DKernel(kernel_, ss.str(), gws, lws, future);
}
template
......
......@@ -9,6 +9,7 @@
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/kernels/conv_pool_2d_util.h"
#include "mace/core/runtime/opencl/cl2_header.h"
namespace mace {
......@@ -171,6 +172,8 @@ struct PoolingFunctor<DeviceType::OPENCL, T> : PoolingFunctorBase {
void operator()(const Tensor *input_tensor,
Tensor *output_tensor,
StatsFuture *future);
cl::Kernel kernel_;
};
} // namespace kernels
......
......@@ -6,6 +6,7 @@
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/cl2_header.h"
namespace mace {
namespace kernels {
......@@ -169,6 +170,8 @@ struct ResizeBilinearFunctor<DeviceType::OPENCL, T> : ResizeBilinearFunctorBase
: ResizeBilinearFunctorBase(size, align_corners) {}
void operator()(const Tensor *input, Tensor *output, StatsFuture *future);
cl::Kernel kernel_;
};
} // namespace kernels
......
......@@ -8,6 +8,7 @@
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/core/public/mace.h"
#include "mace/core/runtime/opencl/cl2_header.h"
namespace mace {
namespace kernels {
......@@ -55,6 +56,8 @@ struct SoftmaxFunctor<DeviceType::OPENCL, T> {
void operator()(const Tensor *logits,
Tensor *output,
StatsFuture *future);
cl::Kernel kernel_;
};
} // namespace kernels
......
......@@ -8,6 +8,7 @@
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/core/public/mace.h"
#include "mace/core/runtime/opencl/cl2_header.h"
namespace mace {
namespace kernels {
......@@ -51,6 +52,8 @@ struct SpaceToBatchFunctor<DeviceType::OPENCL, T>: SpaceToBatchFunctorBase{
Tensor *batch_tensor,
StatsFuture *future);
cl::Kernel kernel_;
};
} // namespace kernels
......
......@@ -9,6 +9,7 @@
#include "mace/core/tensor.h"
#include "mace/kernels/conv_pool_2d_util.h"
#include "mace/kernels/activation.h"
#include "mace/core/runtime/opencl/cl2_header.h"
namespace mace {
namespace kernels {
......@@ -43,6 +44,8 @@ struct WinogradTransformFunctor<DeviceType::OPENCL, T> : WinogradTransformFuncto
void operator()(const Tensor *input,
Tensor *output,
StatsFuture *future);
cl::Kernel kernel_;
};
struct WinogradInverseTransformFunctorBase {
......@@ -100,6 +103,8 @@ struct WinogradInverseTransformFunctor<DeviceType::OPENCL, T> : WinogradInverseT
const Tensor *bias,
Tensor *output,
StatsFuture *future);
cl::Kernel kernel_;
};
} // namespace kernels
......
......@@ -22,9 +22,11 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> {
functor_(this->strides_.data(),
this->padding_,
this->dilations_.data(),
kernels::ActivationType::NOOP,
0.0f,
0.0f) {}
kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation",
"NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f),
OperatorBase::GetSingleArgument<float>("alpha", 0.0f)) {}
bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
......
......@@ -94,16 +94,16 @@ static void DepthwiseConv2d(int iters,
BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, OPENCL);
BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 1, float);
BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 1, float);
//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 1, float);
BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 1, float);
BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 1, SAME, 1, float);
BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 1, VALID, 1, float);
BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 1, SAME, 1, float);
BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 2, VALID, 1, float);
BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 2, VALID, 1, float);
BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 2, SAME, 1, float);
BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 2, SAME, 1, float);
BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 2, VALID, 1, float);
BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 2, SAME, 1, float);
//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 1, SAME, 1, float);
//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 1, VALID, 1, float);
//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 1, SAME, 1, float);
//BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 2, VALID, 1, float);
//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 2, VALID, 1, float);
//BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 2, SAME, 1, float);
//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 2, SAME, 1, float);
//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 2, VALID, 1, float);
//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 2, SAME, 1, float);
} // namespace mace