Commit 4f6af627 authored by Liangliang He

Merge branch 'actvation1' into 'master'

Add relux/prelu/tanh/sigmoid

See merge request !218
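For reference, the activation types added here are elementwise functions; a sketch of their definitions, matching the CPU and OpenCL kernels in this diff: relux(x) = min(max(x, 0), max_limit); prelu(x) = x for x >= 0, otherwise alpha * x; tanh(x) = (1 - e^(-2x)) / (1 + e^(-2x)); sigmoid(x) = 1 / (1 + e^(-x)).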
......@@ -59,6 +59,7 @@ std::unique_ptr<OperatorBase> OperatorRegistry::CreateOperator(
}
}
extern void Register_Activation(OperatorRegistry *op_registry);
extern void Register_AddN(OperatorRegistry *op_registry);
extern void Register_BatchNorm(OperatorRegistry *op_registry);
extern void Register_BatchToSpaceND(OperatorRegistry *op_registry);
......@@ -68,17 +69,17 @@ extern void Register_ChannelShuffle(OperatorRegistry *op_registry);
extern void Register_Concat(OperatorRegistry *op_registry);
extern void Register_Conv2D(OperatorRegistry *op_registry);
extern void Register_DepthwiseConv2d(OperatorRegistry *op_registry);
extern void Register_FoldedBatchNorm(OperatorRegistry *op_registry);
extern void Register_FusedConv2D(OperatorRegistry *op_registry);
extern void Register_GlobalAvgPooling(OperatorRegistry *op_registry);
extern void Register_ImageToBuffer(OperatorRegistry *op_registry);
extern void Register_Pooling(OperatorRegistry *op_registry);
extern void Register_Relu(OperatorRegistry *op_registry);
extern void Register_ResizeBilinear(OperatorRegistry *op_registry);
extern void Register_SpaceToBatchND(OperatorRegistry *op_registry);
extern void Register_Softmax(OperatorRegistry *op_registry);
extern void Register_FoldedBatchNorm(OperatorRegistry *op_registry);
extern void Register_SpaceToBatchND(OperatorRegistry *op_registry);
OperatorRegistry::OperatorRegistry() {
Register_Activation(this);
Register_AddN(this);
Register_BatchNorm(this);
Register_BatchToSpaceND(this);
......@@ -88,15 +89,14 @@ OperatorRegistry::OperatorRegistry() {
Register_Concat(this);
Register_Conv2D(this);
Register_DepthwiseConv2d(this);
Register_FoldedBatchNorm(this);
Register_FusedConv2D(this);
Register_GlobalAvgPooling(this);
Register_ImageToBuffer(this);
Register_Pooling(this);
Register_Relu(this);
Register_ResizeBilinear(this);
Register_SpaceToBatchND(this);
Register_Softmax(this);
Register_FoldedBatchNorm(this);
Register_SpaceToBatchND(this);
}
} // namespace mace
......@@ -22,8 +22,9 @@ bool GetSourceOrBinaryProgram(const std::string &program_name,
return false;
}
cl::Program::Sources sources;
std::string kernel_source(it_source->second.begin(), it_source->second.end());
sources.push_back(ObfuscateString(kernel_source));
std::string content(it_source->second.begin(), it_source->second.end());
std::string kernel_source = ObfuscateString(content);
sources.push_back(kernel_source);
*program = cl::Program(context, sources);
return true;
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_ACTIVATION_H_
#define MACE_KERNELS_ACTIVATION_H_
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/core/types.h"
namespace mace {
namespace kernels {
enum ActivationType {
NOOP = 0,
RELU = 1,
RELUX = 2,
PRELU = 3,
TANH = 4,
SIGMOID = 5
};
inline ActivationType StringToActivationType(const std::string type) {
if (type == "RELU") {
return ActivationType::RELU;
} else if (type == "RELUX") {
return ActivationType::RELUX;
} else if (type == "PRELU") {
return ActivationType::PRELU;
} else if (type == "TANH") {
return ActivationType::TANH;
} else if (type == "SIGMOID") {
return ActivationType::SIGMOID;
} else if (type == "NOOP") {
return ActivationType::NOOP;
} else {
LOG(FATAL) << "Unknown activation type: " << type;
}
return ActivationType::NOOP;
}
template <typename T>
void DoActivation(const T *input_ptr,
T *output_ptr,
const index_t size,
const ActivationType type,
const float relux_max_limit,
const float prelu_alpha) {
MACE_CHECK(DataTypeToEnum<T>::value != DataType::DT_HALF);
switch (type) {
case NOOP:
break;
case RELU:
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = std::max(input_ptr[i], static_cast<T>(0));
}
break;
case RELUX:
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = std::min(std::max(input_ptr[i], static_cast<T>(0)),
static_cast<T>(relux_max_limit));
}
break;
case PRELU:
for (index_t i = 0; i < size; ++i) {
T in = input_ptr[i];
if (in < 0) {
output_ptr[i] = in * prelu_alpha;
} else {
output_ptr[i] = in;
}
}
break;
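// TANH uses the identity tanh(x) = (1 - e^(-2x)) / (1 + e^(-2x)), so only one exp per element is needed.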
case TANH:
for (index_t i = 0; i < size; ++i) {
T in_exp = std::exp(-2 * input_ptr[i]);
output_ptr[i] = (1 - in_exp) / (1 + in_exp);
}
break;
case SIGMOID:
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = 1 / (1 + std::exp(-input_ptr[i]));
}
break;
default:
LOG(FATAL) << "Unknown activation type: " << type;
}
}
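// Illustrative call (buffer names hypothetical): DoActivation(in, out, n, RELUX, 6.f, 0.f)
// clamps each of the n input elements into [0, 6] and writes the result to out.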
template <DeviceType D, typename T>
class ActivationFunctor {
public:
ActivationFunctor(ActivationType type, T relux_max_limit, T prelu_alpha)
: activation_(type),
relux_max_limit_(relux_max_limit),
prelu_alpha_(prelu_alpha) {}
void operator()(const Tensor *input, Tensor *output, StatsFuture *future) {
const T *input_ptr = input->data<T>();
T *output_ptr = output->mutable_data<T>();
DoActivation(input_ptr, output_ptr, output->size(), activation_, relux_max_limit_,
prelu_alpha_);
}
private:
ActivationType activation_;
T relux_max_limit_;
T prelu_alpha_;
};
template <>
void ActivationFunctor<DeviceType::NEON, float>::operator()(
const Tensor *input, Tensor *output, StatsFuture *future);
template <typename T>
class ActivationFunctor<DeviceType::OPENCL, T> {
public:
ActivationFunctor(ActivationType type, T relux_max_limit, T prelu_alpha)
: activation_(type),
relux_max_limit_(relux_max_limit),
prelu_alpha_(prelu_alpha) {}
void operator()(const Tensor *input, Tensor *output, StatsFuture *future);
private:
ActivationType activation_;
T relux_max_limit_;
T prelu_alpha_;
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_ACTIVATION_H_
......@@ -6,25 +6,37 @@
#define MACE_KERNELS_BATCH_NORM_H_
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/core/public/mace.h"
#include "mace/core/tensor.h"
#include "mace/kernels/activation.h"
namespace mace {
namespace kernels {
struct BatchNormFunctorBase {
BatchNormFunctorBase(bool folded_constant, bool fused_relu) :
folded_constant_(folded_constant),
fused_relu_(fused_relu){}
BatchNormFunctorBase(bool folded_constant,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha)
: folded_constant_(folded_constant),
activation_(activation),
relux_max_limit_(relux_max_limit),
prelu_alpha_(prelu_alpha) {}
const bool folded_constant_;
const bool fused_relu_;
const ActivationType activation_;
const float relux_max_limit_;
const float prelu_alpha_;
};
template <DeviceType D, typename T>
struct BatchNormFunctor : BatchNormFunctorBase{
BatchNormFunctor(const bool folded_constant, const bool fused_relu) :
BatchNormFunctorBase(folded_constant, fused_relu) {}
struct BatchNormFunctor : BatchNormFunctorBase {
BatchNormFunctor(const bool folded_constant,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha)
: BatchNormFunctorBase(
folded_constant, activation, relux_max_limit, prelu_alpha) {}
void operator()(const Tensor *input,
const Tensor *scale,
......@@ -85,32 +97,34 @@ struct BatchNormFunctor : BatchNormFunctorBase{
} else {
output_ptr[pos] = new_scale[c] * input_ptr[pos] + new_offset[c];
}
if (fused_relu_) {
output_ptr[pos] = std::max(output_ptr[pos], static_cast<T>(0));
}
++pos;
}
}
}
}
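// Apply the configured activation in place on the batch-norm output.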
DoActivation(output_ptr, output_ptr, output->NumElements(), activation_,
relux_max_limit_, prelu_alpha_);
}
};
template <>
void BatchNormFunctor<DeviceType::NEON, float>::operator()(
const Tensor *input,
const Tensor *scale,
const Tensor *offset,
const Tensor *mean,
const Tensor *var,
const float epsilon,
Tensor *output,
StatsFuture *future);
void BatchNormFunctor<DeviceType::NEON, float>::operator()(const Tensor *input,
const Tensor *scale,
const Tensor *offset,
const Tensor *mean,
const Tensor *var,
const float epsilon,
Tensor *output,
StatsFuture *future);
template <typename T>
struct BatchNormFunctor<DeviceType::OPENCL, T> : BatchNormFunctorBase {
BatchNormFunctor(const bool folded_constant, const bool fused_relu) :
BatchNormFunctorBase(folded_constant, fused_relu) {}
BatchNormFunctor(const bool folded_constant,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha)
: BatchNormFunctorBase(
folded_constant, activation, relux_max_limit, prelu_alpha) {}
void operator()(const Tensor *input,
const Tensor *scale,
const Tensor *offset,
......
......@@ -7,6 +7,7 @@
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/kernels/activation.h"
#include "mace/kernels/conv_pool_2d_util.h"
namespace mace {
......@@ -15,20 +16,39 @@ namespace kernels {
struct Conv2dFunctorBase {
Conv2dFunctorBase(const int *strides,
const Padding &paddings,
const int *dilations)
: strides_(strides), dilations_(dilations), paddings_(paddings) {}
const int *strides_; // [stride_h, stride_w]
const int *dilations_; // [dilation_h, dilation_w]
Padding paddings_;
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha)
: strides_(strides),
dilations_(dilations),
paddings_(paddings),
activation_(activation),
relux_max_limit_(relux_max_limit),
prelu_alpha_(prelu_alpha) {}
const int *strides_; // [stride_h, stride_w]
const int *dilations_; // [dilation_h, dilation_w]
const Padding paddings_;
const ActivationType activation_;
const float relux_max_limit_;
const float prelu_alpha_;
};
template<DeviceType D, typename T>
template <DeviceType D, typename T>
struct Conv2dFunctor : Conv2dFunctorBase {
Conv2dFunctor(const int *strides,
const Padding &paddings,
const int *dilations)
: Conv2dFunctorBase(strides, paddings, dilations) {}
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha)
: Conv2dFunctorBase(strides,
paddings,
dilations,
activation,
relux_max_limit,
prelu_alpha) {}
void operator()(const Tensor *input,
const Tensor *filter,
......@@ -42,8 +62,8 @@ struct Conv2dFunctor : Conv2dFunctorBase {
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
kernels::CalcNHWCPaddingAndOutputSize(
input->shape().data(), filter->shape().data(), dilations_,
strides_, paddings_, output_shape.data(), paddings.data());
input->shape().data(), filter->shape().data(), dilations_, strides_,
paddings_, output_shape.data(), paddings.data());
output->Resize(output_shape);
index_t batch = output->dim(0);
......@@ -101,7 +121,7 @@ struct Conv2dFunctor : Conv2dFunctorBase {
if (inh < 0 || inh >= input_height || inw < 0 ||
inw >= input_width) {
MACE_CHECK(inh >= padded_h_start && inh < padded_h_stop &&
inw >= padded_w_start && inw < padded_w_stop,
inw >= padded_w_start && inw < padded_w_stop,
"Out of range read from input: ", inh, ", ",
inw);
// else padding with 0:
......@@ -109,8 +129,8 @@ struct Conv2dFunctor : Conv2dFunctorBase {
} else {
index_t input_offset =
n * input_height * input_width * input_channels +
inh * input_width * input_channels + inw * input_channels +
inc;
inh * input_width * input_channels +
inw * input_channels + inc;
sum += input_data[input_offset] * *filter_ptr;
}
filter_ptr += channels;
......@@ -123,24 +143,33 @@ struct Conv2dFunctor : Conv2dFunctorBase {
}
}
}
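// Apply the configured activation in place on the convolution output.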
output_data = output->mutable_data<T>();
DoActivation(output_data, output_data, output->NumElements(), activation_,
relux_max_limit_, prelu_alpha_);
}
};
template<>
template <>
void Conv2dFunctor<DeviceType::NEON, float>::operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
StatsFuture *future);
template<typename T>
template <typename T>
struct Conv2dFunctor<DeviceType::OPENCL, T> : Conv2dFunctorBase {
Conv2dFunctor(const int *strides,
const Padding &paddings,
const int *dilations)
: Conv2dFunctorBase(strides, paddings, dilations) {}
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha)
: Conv2dFunctorBase(strides,
paddings,
dilations,
activation,
relux_max_limit,
prelu_alpha) {}
void operator()(const Tensor *input,
const Tensor *filter,
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_FUSED_CONV_2D_H_
#define MACE_KERNELS_FUSED_CONV_2D_H_
#include "mace/core/tensor.h"
#include "mace/kernels/conv_pool_2d_util.h"
#include "mace/kernels/conv_2d.h"
namespace mace {
namespace kernels {
struct FusedConv2dFunctorBase {
FusedConv2dFunctorBase(const int *strides,
const Padding &paddings,
const int *dilations)
: strides_(strides), dilations_(dilations), paddings_(paddings) {}
const int *strides_; // [stride_h, stride_w]
const int *dilations_; // [dilation_h, dilation_w]
Padding paddings_;
};
template<DeviceType D, typename T>
struct FusedConv2dFunctor : FusedConv2dFunctorBase {
FusedConv2dFunctor(const int *strides,
const Padding &paddings,
const int *dilations)
: FusedConv2dFunctorBase(strides, paddings, dilations) {}
void operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
StatsFuture *future) {
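// Run the plain convolution, then clamp negatives to zero (ReLU) in place;
// half outputs use an explicitly half-cast zero value.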
Conv2dFunctor<D, T>(strides_, paddings_, dilations_)(input, filter, bias,
output, future);
T *output_data = output->mutable_data<T>();
T zero_value;
if (DataTypeToEnum<T>::value == DataType::DT_HALF) {
zero_value = half_float::half_cast<half>(0.0f);
} else {
zero_value = 0;
}
auto output_size = output->size();
for (int n = 0; n < output_size; ++n) {
*output_data = *output_data < 0 ? zero_value : *output_data;
output_data++;
}
}
};
template<typename T>
struct FusedConv2dFunctor<DeviceType::OPENCL, T> : FusedConv2dFunctorBase {
FusedConv2dFunctor(const int *strides,
const Padding &paddings,
const int *dilations)
: FusedConv2dFunctorBase(strides, paddings, dilations) {}
void operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
StatsFuture *future);
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_FUSED_CONV_2D_H_
......@@ -9,7 +9,7 @@ namespace mace {
namespace kernels {
template <>
void ReluFunctor<DeviceType::NEON, float>::operator()(const Tensor *input_tensor,
void ActivationFunctor<DeviceType::NEON, float>::operator()(const Tensor *input_tensor,
Tensor *output_tensor,
StatsFuture *future) {
const float *input = input_tensor->data<float>();
......
......@@ -2,21 +2,20 @@
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/kernels/relu.h"
#include "mace/kernels/activation.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/utils.h"
#include "mace/utils/tuner.h"
#include "mace/utils/utils.h"
namespace mace {
namespace kernels {
template <typename T>
void ReluFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
Tensor *output,
StatsFuture *future) {
void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
Tensor *output,
StatsFuture *future) {
const index_t batch = input->dim(0);
const index_t height = input->dim(1);
const index_t width = input->dim(2);
......@@ -27,94 +26,92 @@ void ReluFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("activation");
built_options.emplace("-Dactivation=" + kernel_name);
auto dt = DataTypeToEnum<T>::value;
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
cl::Kernel relu_kernel;
if (max_limit_ < 0) {
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("relu");
built_options.emplace("-Drelu=" + kernel_name);
relu_kernel = runtime->BuildKernel("relu", kernel_name, built_options);
uint32_t idx = 0;
relu_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
relu_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
} else {
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("relux");
built_options.emplace("-Drelux=" + kernel_name);
relu_kernel = runtime->BuildKernel("relu", kernel_name, built_options);
uint32_t idx = 0;
relu_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
relu_kernel.setArg(idx++, max_limit_);
relu_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
switch (activation_) {
case RELU:
built_options.emplace("-DUSE_RELU");
break;
case RELUX:
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
built_options.emplace("-DUSE_TANH");
break;
case SIGMOID:
built_options.emplace("-DUSE_SIGMOID");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation_;
}
cl::Kernel activation_kernel =
runtime->BuildKernel("activation", kernel_name, built_options);
int idx = 0;
activation_kernel.setArg(
idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
activation_kernel.setArg(idx++, relux_max_limit_);
activation_kernel.setArg(idx++, prelu_alpha_);
activation_kernel.setArg(idx++,
*(static_cast<cl::Image2D *>(output->buffer())));
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width),
static_cast<uint32_t>(height * batch)};
const std::vector<uint32_t> lws = {8, 16, 8};
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(relu_kernel);
const uint32_t kwg_size =
runtime->GetKernelMaxWorkGroupSize(activation_kernel);
auto params_generator = [&]() -> std::vector<std::vector<uint32_t>> {
std::vector<uint32_t> local_ws(3, 0);
local_ws[0] = std::min<uint32_t>(channel_blocks, kwg_size);
local_ws[1] = std::min<uint32_t>(width, kwg_size / local_ws[0]);
local_ws[2] = std::min<uint32_t>(height * batch, kwg_size / (local_ws[0] * local_ws[1]));
return {{local_ws[0], local_ws[1], local_ws[2]},
{kwg_size / 16, 4, 4},
{kwg_size / 32, 4, 8},
{kwg_size / 32, 8, 4},
{kwg_size / 64, 8, 8},
{kwg_size / 64, 16, 4},
{kwg_size / 128, 8, 16},
{kwg_size / 128, 16, 8},
{kwg_size / 128, 32, 4},
{1, kwg_size / 32, 32},
{1, kwg_size / 64, 64},
{1, kwg_size / 128, 128},
{3, 15, 9},
{7, 15, 9},
{9, 7, 15},
{15, 7, 9},
{1, kwg_size, 1},
{4, 15, 8}, //SNPE size
local_ws[2] = std::min<uint32_t>(height * batch,
kwg_size / (local_ws[0] * local_ws[1]));
return {
{local_ws[0], local_ws[1], local_ws[2]},
{kwg_size / 16, 4, 4},
{kwg_size / 32, 4, 8},
{kwg_size / 32, 8, 4},
{kwg_size / 64, 8, 8},
{kwg_size / 64, 16, 4},
{kwg_size / 128, 8, 16},
{kwg_size / 128, 16, 8},
{kwg_size / 128, 32, 4},
{1, kwg_size / 32, 32},
{1, kwg_size / 64, 64},
{1, kwg_size / 128, 128},
{3, 15, 9},
{7, 15, 9},
{9, 7, 15},
{15, 7, 9},
{1, kwg_size, 1},
{4, 15, 8}, // SNPE size
};
};
cl::Event event;
auto func = [&](const std::vector<uint32_t> &params) -> cl_int {
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
relu_kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(params[0], params[1], params[2]),
nullptr, &event);
activation_kernel, cl::NullRange, cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(params[0], params[1], params[2]), nullptr, &event);
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
return error;
};
std::stringstream ss;
ss << "relu_opencl_kernel_"
<< output->dim(0) << "_"
<< output->dim(1) << "_"
<< output->dim(2) << "_"
<< output->dim(3);
std::string tuning_key =
Concat("relu_opencl_kernel_", activation_, output->dim(0), output->dim(1),
output->dim(2), output->dim(3));
OpenCLProfilingTimer timer(&event);
Tuner<uint32_t>::Get()->template TuneOrRun<cl_int>(ss.str(),
lws,
params_generator,
func,
&timer);
if (future != nullptr) {
future->wait_fn = [runtime, event](CallStats *stats) {
event.wait();
if (stats != nullptr) {
runtime->GetCallStats(event, stats);
}
};
}
Tuner<uint32_t>::Get()->template TuneOrRun<cl_int>(
tuning_key, lws, params_generator, func, &timer);
SetFuture(future, event);
}
template
struct ReluFunctor<DeviceType::OPENCL, float>;
template
struct ReluFunctor<DeviceType::OPENCL, half>;
template struct ActivationFunctor<DeviceType::OPENCL, float>;
template struct ActivationFunctor<DeviceType::OPENCL, half>;
} // namespace kernels
} // namespace mace
......@@ -5,23 +5,22 @@
#include "mace/kernels/batch_norm.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/tuner.h"
#include "mace/utils/utils.h"
#include "mace/kernels/opencl/helper.h"
namespace mace {
namespace kernels {
template<typename T>
void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(
const Tensor *input,
const Tensor *scale,
const Tensor *offset,
const Tensor *mean,
const Tensor *var,
const float epsilon,
Tensor *output,
StatsFuture *future) {
template <typename T>
void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
const Tensor *scale,
const Tensor *offset,
const Tensor *mean,
const Tensor *var,
const float epsilon,
Tensor *output,
StatsFuture *future) {
MACE_CHECK(folded_constant_ || (mean != nullptr && var != nullptr));
const index_t batch = input->dim(0);
......@@ -41,21 +40,45 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(
if (folded_constant_) {
built_options.emplace("-DFOLDED_CONSTANT");
}
if (fused_relu_) {
built_options.emplace("-DFUSED_RELU");
switch (activation_) {
case NOOP:
break;
case RELU:
built_options.emplace("-DUSE_RELU");
break;
case RELUX:
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
built_options.emplace("-DUSE_TANH");
break;
case SIGMOID:
built_options.emplace("-DUSE_SIGMOID");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation_;
}
auto bm_kernel = runtime->BuildKernel("batch_norm", kernel_name, built_options);
auto bm_kernel =
runtime->BuildKernel("batch_norm", kernel_name, built_options);
uint32_t idx = 0;
bm_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
bm_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(scale->buffer())));
bm_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(offset->buffer())));
bm_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(offset->buffer())));
if (!folded_constant_) {
bm_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(mean->buffer())));
bm_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(mean->buffer())));
bm_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(var->buffer())));
bm_kernel.setArg(idx++, epsilon);
}
bm_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
bm_kernel.setArg(idx++, relux_max_limit_);
bm_kernel.setArg(idx++, prelu_alpha_);
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width),
......@@ -66,64 +89,48 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(
std::vector<uint32_t> local_ws(3, 0);
local_ws[0] = std::min<uint32_t>(channel_blocks, kwg_size);
local_ws[1] = std::min<uint32_t>(width, kwg_size / local_ws[0]);
local_ws[2] = std::min<uint32_t>(height * batch, kwg_size / (local_ws[0] * local_ws[1]));
return {{local_ws[0], local_ws[1], local_ws[2]},
{kwg_size / 16, 4, 4},
{kwg_size / 32, 4, 8},
{kwg_size / 32, 8, 4},
{kwg_size / 64, 8, 8},
{kwg_size / 64, 16, 4},
{kwg_size / 128, 8, 16},
{kwg_size / 128, 16, 8},
{kwg_size / 128, 32, 4},
{1, kwg_size / 32, 32},
{1, kwg_size / 64, 64},
{1, kwg_size / 128, 128},
{3, 15, 9},
{7, 15, 9},
{9, 7, 15},
{15, 7, 9},
{1, kwg_size, 1},
{8, 128, 1}, //SNPE size
local_ws[2] = std::min<uint32_t>(height * batch,
kwg_size / (local_ws[0] * local_ws[1]));
return {
{local_ws[0], local_ws[1], local_ws[2]},
{kwg_size / 16, 4, 4},
{kwg_size / 32, 4, 8},
{kwg_size / 32, 8, 4},
{kwg_size / 64, 8, 8},
{kwg_size / 64, 16, 4},
{kwg_size / 128, 8, 16},
{kwg_size / 128, 16, 8},
{kwg_size / 128, 32, 4},
{1, kwg_size / 32, 32},
{1, kwg_size / 64, 64},
{1, kwg_size / 128, 128},
{3, 15, 9},
{7, 15, 9},
{9, 7, 15},
{15, 7, 9},
{1, kwg_size, 1},
{8, 128, 1}, // SNPE size
};
};
cl::Event event;
auto func = [&](const std::vector<uint32_t> &params) -> cl_int {
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
bm_kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(params[0], params[1], params[2]),
nullptr, &event);
bm_kernel, cl::NullRange, cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(params[0], params[1], params[2]), nullptr, &event);
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
return error;
};
std::stringstream ss;
ss << "batch_norm_opencl_kernel_"
<< output->dim(0) << "_"
<< output->dim(1) << "_"
<< output->dim(2) << "_"
<< output->dim(3) << "_"
<< folded_constant_;
std::string tuning_key =
Concat("batch_norm_opencl_kernel_", activation_, output->dim(0),
output->dim(1), output->dim(2), output->dim(3), folded_constant_);
OpenCLProfilingTimer timer(&event);
Tuner<uint32_t>::Get()->template TuneOrRun<cl_int>(ss.str(),
lws,
params_generator,
func,
&timer);
if (future != nullptr) {
future->wait_fn = [runtime, event](CallStats *stats) {
event.wait();
if (stats != nullptr) {
runtime->GetCallStats(event, stats);
}
};
}
Tuner<uint32_t>::Get()->template TuneOrRun<cl_int>(
tuning_key, lws, params_generator, func, &timer);
SetFuture(future, event);
}
template
struct BatchNormFunctor<DeviceType::OPENCL, float>;
template
struct BatchNormFunctor<DeviceType::OPENCL, half>;
template struct BatchNormFunctor<DeviceType::OPENCL, float>;
template struct BatchNormFunctor<DeviceType::OPENCL, half>;
} // namespace kernels
} // namespace mace
#include <common.h>
// Supported data type: half/float
__kernel void relu(__read_only image2d_t input,
__write_only image2d_t output) {
__kernel void activation(__read_only image2d_t input,
__private const DATA_TYPE relux_max_limit,
__private const DATA_TYPE prelu_alpha,
__write_only image2d_t output) {
const int ch_blk = get_global_id(0);
const int w = get_global_id(1);
const int hb = get_global_id(2);
......@@ -10,20 +11,7 @@ __kernel void relu(__read_only image2d_t input,
const int pos = mad24(ch_blk, width, w);
DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb));
DATA_TYPE4 out = fmax(in, 0);
DATA_TYPE4 out = do_activation(in, relux_max_limit, prelu_alpha);
WRITE_IMAGET(output, (int2)(pos, hb), out);
}
__kernel void relux(__read_only image2d_t input,
__private const DATA_TYPE max_limit,
__write_only image2d_t output) {
const int ch_blk = get_global_id(0);
const int w = get_global_id(1);
const int hb = get_global_id(2);
const int width = get_global_size(1);
const int pos = mad24(ch_blk, width, w);
DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb));
DATA_TYPE4 out = clamp(in, 0, max_limit);
WRITE_IMAGET(output, (int2)(pos, hb), out);
}
......@@ -8,7 +8,9 @@ __kernel void batch_norm(__read_only image2d_t input,
__read_only image2d_t var,
__private const float epsilon,
#endif
__write_only image2d_t output) {
__write_only image2d_t output,
__private const DATA_TYPE relux_max_limit,
__private const DATA_TYPE prelu_alpha) {
const int ch_blk = get_global_id(0);
const int w = get_global_id(1);
const int hb = get_global_id(2);
......@@ -33,8 +35,8 @@ __kernel void batch_norm(__read_only image2d_t input,
DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb));
DATA_TYPE4 out = mad(in, bn_scale, bn_offset);
#ifdef FUSED_RELU
out = fmax(out, 0);
#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID)
out = do_activation(out, relux_max_limit, prelu_alpha);
#endif
WRITE_IMAGET(output, (int2)(pos, hb), out);
......
......@@ -18,7 +18,29 @@
#define READ_IMAGET CMD_TYPE(read_image, CMD_DATA_TYPE)
#define WRITE_IMAGET CMD_TYPE(write_image, CMD_DATA_TYPE)
__constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
inline DATA_TYPE4 do_activation(DATA_TYPE4 in,
__private const DATA_TYPE relux_max_limit,
__private const DATA_TYPE prelu_alpha) {
DATA_TYPE4 out;
#ifdef USE_RELU
out = fmax(in, 0);
#endif
#ifdef USE_RELUX
out = clamp(in, 0, relux_max_limit);
#endif
#ifdef USE_PRELU
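// Per-lane select: where in >= 0 the mask is all ones, so `in` passes through; negative lanes take prelu_alpha * in.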
out = select(prelu_alpha * in, in, in >= 0);
#endif
#ifdef USE_TANH
out = tanh(in);
#endif
#ifdef USE_SIGMOID
out = native_recip(1.0 + native_exp(-in));
#endif
return out;
}
#endif // MACE_KERNELS_OPENCL_CL_COMMON_H_
......@@ -6,6 +6,8 @@ __kernel void conv_2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
__read_only image2d_t bias, /* cout%4 * cout/4 */
#endif
__write_only image2d_t output,
__private const DATA_TYPE relux_max_limit,
__private const DATA_TYPE prelu_alpha,
__private const int in_height,
__private const int in_width,
__private const int in_ch_blks,
......@@ -115,12 +117,11 @@ __kernel void conv_2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
}
}
#ifdef FUSED_RELU
// TODO relux
out0 = fmax(out0, 0);
out1 = fmax(out1, 0);
out2 = fmax(out2, 0);
out3 = fmax(out3, 0);
#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID)
out0 = do_activation(out0, relux_max_limit, prelu_alpha);
out1 = do_activation(out1, relux_max_limit, prelu_alpha);
out2 = do_activation(out2, relux_max_limit, prelu_alpha);
out3 = do_activation(out3, relux_max_limit, prelu_alpha);
#endif
const int out_x_base = mul24(out_ch_blk, out_width);
......
......@@ -6,6 +6,8 @@ __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
__read_only image2d_t bias, /* cout%4 * cout/4 */
#endif
__write_only image2d_t output,
__private const DATA_TYPE relux_max_limit,
__private const DATA_TYPE prelu_alpha,
__private const int in_height,
__private const int in_width,
__private const int in_ch_blks,
......@@ -90,12 +92,11 @@ __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
filter_x_base += 4;
}
#ifdef FUSED_RELU
// TODO relux
out0 = fmax(out0, 0);
out1 = fmax(out1, 0);
out2 = fmax(out2, 0);
out3 = fmax(out3, 0);
#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID)
out0 = do_activation(out0, relux_max_limit, prelu_alpha);
out1 = do_activation(out1, relux_max_limit, prelu_alpha);
out2 = do_activation(out2, relux_max_limit, prelu_alpha);
out3 = do_activation(out3, relux_max_limit, prelu_alpha);
#endif
const int out_x_base = mul24(out_ch_blk, width);
......
......@@ -6,6 +6,8 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
__read_only image2d_t bias, /* cout%4 * cout/4 */
#endif
__write_only image2d_t output,
__private const DATA_TYPE relux_max_limit,
__private const DATA_TYPE prelu_alpha,
__private const int in_height,
__private const int in_width,
__private const int in_ch_blks,
......@@ -122,13 +124,12 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
}
}
#ifdef FUSED_RELU
// TODO relux
out0 = fmax(out0, 0);
out1 = fmax(out1, 0);
out2 = fmax(out2, 0);
out3 = fmax(out3, 0);
out4 = fmax(out4, 0);
#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID)
out0 = do_activation(out0, relux_max_limit, prelu_alpha);
out1 = do_activation(out1, relux_max_limit, prelu_alpha);
out2 = do_activation(out2, relux_max_limit, prelu_alpha);
out3 = do_activation(out3, relux_max_limit, prelu_alpha);
out4 = do_activation(out4, relux_max_limit, prelu_alpha);
#endif
const int out_x_base = mul24(out_ch_blk, out_width);
......
......@@ -3,52 +3,84 @@
//
#include "mace/kernels/conv_2d.h"
#include "mace/kernels/activation.h"
#include "mace/kernels/opencl/helper.h"
namespace mace {
namespace kernels {
extern void Conv2dOpenclK1x1S1(const Tensor *input, const Tensor *filter,
const Tensor *bias, const bool fused_relu,
const int *padding, const int *dilations,
const DataType dt, Tensor *output,
extern void Conv2dOpenclK1x1S1(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const int *padding,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future);
extern void Conv2dOpenclK1x1S2(const Tensor *input, const Tensor *filter,
const Tensor *bias, const bool fused_relu,
const int *padding, const int *dilations,
const DataType dt, Tensor *output,
extern void Conv2dOpenclK1x1S2(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const int *padding,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future);
extern void Conv2dOpenclK3x3S1(const Tensor *input, const Tensor *filter,
const Tensor *bias, const bool fused_relu,
const int *padding, const int *dilations,
const DataType dt, Tensor *output,
extern void Conv2dOpenclK3x3S1(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const int *padding,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future);
extern void Conv2dOpenclK3x3S2(const Tensor *input, const Tensor *filter,
const Tensor *bias, const bool fused_relu,
const int *padding, const int *dilations,
const DataType dt, Tensor *output,
extern void Conv2dOpenclK3x3S2(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const int *padding,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future);
extern void Conv2dOpencl(const Tensor *input, const Tensor *filter,
const Tensor *bias, const bool fused_relu,
const uint32_t stride, const int *padding,
const int *dilations, const DataType dt,
Tensor *output, StatsFuture *future);
extern void Conv2dOpencl(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const uint32_t stride,
const int *padding,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future);
template<typename T>
template <typename T>
void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
StatsFuture *future) {
typedef void (*Conv2dOpenclFunction)(const Tensor *input, const Tensor *filter,
const Tensor *bias, const bool fused_relu,
const int *padding, const int *dilations,
const DataType dt, Tensor *output,
StatsFuture *future);
typedef void (*Conv2dOpenclFunction)(
const Tensor *input, const Tensor *filter, const Tensor *bias,
const int *padding, const int *dilations, const ActivationType activation,
const float relux_max_limit, const float prelu_alpha, const DataType dt,
Tensor *output, StatsFuture *future);
// Selection matrix: kernel_size x stride_size
static const Conv2dOpenclFunction selector[5][2] = {
{Conv2dOpenclK1x1S1, Conv2dOpenclK1x1S2},
......@@ -73,8 +105,8 @@ void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
kernels::CalcNHWCPaddingAndOutputSize(
input->shape().data(), filter->shape().data(), dilations_,
strides_, paddings_, output_shape.data(), paddings.data());
input->shape().data(), filter->shape().data(), dilations_, strides_,
paddings_, output_shape.data(), paddings.data());
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT, output_image_shape);
......@@ -83,20 +115,18 @@ void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
if (kernel_h == kernel_w && kernel_h <= 5 &&
selector[kernel_h - 1][strides_[0] - 1] != nullptr) {
auto conv2d_func = selector[kernel_h - 1][strides_[0] - 1];
conv2d_func(input, filter, bias, false, paddings.data(), dilations_,
DataTypeToEnum<T>::value, output, future);
conv2d_func(input, filter, bias, paddings.data(), dilations_, activation_,
relux_max_limit_, prelu_alpha_, DataTypeToEnum<T>::value,
output, future);
} else {
Conv2dOpencl(input, filter, bias, false, strides_[0],
paddings.data(), dilations_, DataTypeToEnum<T>::value,
output, future);
Conv2dOpencl(input, filter, bias, strides_[0], paddings.data(), dilations_,
activation_, relux_max_limit_, prelu_alpha_,
DataTypeToEnum<T>::value, output, future);
}
}
template
struct Conv2dFunctor<DeviceType::OPENCL, float>;
template
struct Conv2dFunctor<DeviceType::OPENCL, half>;
template struct Conv2dFunctor<DeviceType::OPENCL, float>;
template struct Conv2dFunctor<DeviceType::OPENCL, half>;
} // namespace kernels
} // namespace mace
......@@ -5,9 +5,10 @@
#include "mace/kernels/conv_2d.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/activation.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/utils.h"
#include "mace/utils/tuner.h"
#include "mace/utils/utils.h"
namespace mace {
namespace kernels {
......@@ -15,8 +16,10 @@ namespace kernels {
void Conv1x1(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const bool fused_relu,
const int stride,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future) {
......@@ -44,20 +47,46 @@ void Conv1x1(const Tensor *input,
if (bias != nullptr) {
built_options.emplace("-DBIAS");
}
if (fused_relu) {
built_options.emplace("-DFUSED_RELU");
switch (activation) {
case NOOP:
break;
case RELU:
built_options.emplace("-DUSE_RELU");
break;
case RELUX:
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
built_options.emplace("-DUSE_TANH");
break;
case SIGMOID:
built_options.emplace("-DUSE_SIGMOID");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation;
}
auto runtime = OpenCLRuntime::Global();
auto conv_2d_kernel = runtime->BuildKernel("conv_2d_1x1", kernel_name, built_options);
auto conv_2d_kernel =
runtime->BuildKernel("conv_2d_1x1", kernel_name, built_options);
uint32_t idx = 0;
conv_2d_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
conv_2d_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(filter->buffer())));
conv_2d_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(input->buffer())));
conv_2d_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(filter->buffer())));
if (bias != nullptr) {
conv_2d_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(bias->buffer())));
conv_2d_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(bias->buffer())));
}
conv_2d_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(output->buffer())));
conv_2d_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(output->buffer())));
// FIXME handle flexible data type: half not supported
conv_2d_kernel.setArg(idx++, relux_max_limit);
conv_2d_kernel.setArg(idx++, prelu_alpha);
conv_2d_kernel.setArg(idx++, static_cast<int>(input_height));
conv_2d_kernel.setArg(idx++, static_cast<int>(input_width));
conv_2d_kernel.setArg(idx++, static_cast<int>(input_channel_blocks));
......@@ -69,86 +98,79 @@ void Conv1x1(const Tensor *input,
static_cast<uint32_t>(height * batch)};
const std::vector<uint32_t> lws = {8, 15, 8};
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(conv_2d_kernel);
auto params_generator = [&]()->std::vector<std::vector<uint32_t>> {
auto params_generator = [&]() -> std::vector<std::vector<uint32_t>> {
std::vector<uint32_t> local_ws(3, 0);
local_ws[0] = std::min<uint32_t>(channel_blocks, kwg_size);
local_ws[1] = std::min<uint32_t>(width_blocks, kwg_size / local_ws[0]);
local_ws[2] = std::min<uint32_t>(height * batch, kwg_size / (local_ws[0] * local_ws[1]));
return {{local_ws[0], local_ws[1], local_ws[2]},
{kwg_size/16, 4, 4},
{kwg_size/32, 4, 8},
{kwg_size/32, 8, 4},
{kwg_size/64, 8, 8},
{kwg_size/64, 16, 4},
{kwg_size/128, 8, 16},
{kwg_size/128, 16, 8},
{kwg_size/128, 32, 4},
{1, kwg_size/32, 32},
{1, kwg_size/64, 64},
{1, kwg_size/128, 128},
{3, 15, 9},
{7, 15, 9},
{9, 7, 15},
{15, 7, 9},
{1, kwg_size, 1},
{4, 15, 8}, //SNPE size
local_ws[2] = std::min<uint32_t>(height * batch,
kwg_size / (local_ws[0] * local_ws[1]));
return {
{local_ws[0], local_ws[1], local_ws[2]},
{kwg_size / 16, 4, 4},
{kwg_size / 32, 4, 8},
{kwg_size / 32, 8, 4},
{kwg_size / 64, 8, 8},
{kwg_size / 64, 16, 4},
{kwg_size / 128, 8, 16},
{kwg_size / 128, 16, 8},
{kwg_size / 128, 32, 4},
{1, kwg_size / 32, 32},
{1, kwg_size / 64, 64},
{1, kwg_size / 128, 128},
{3, 15, 9},
{7, 15, 9},
{9, 7, 15},
{15, 7, 9},
{1, kwg_size, 1},
{4, 15, 8}, // SNPE size
};
};
cl::Event event;
auto func = [&](const std::vector<uint32_t>& params)->cl_int {
auto func = [&](const std::vector<uint32_t> &params) -> cl_int {
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
conv_2d_kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(params[0], params[1], params[2]),
nullptr, &event);
conv_2d_kernel, cl::NullRange, cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(params[0], params[1], params[2]), nullptr, &event);
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
return error;
};
std::stringstream ss;
ss << "conv2d_1x1_opencl_kernel_"
<< output->dim(0) << "_"
<< output->dim(1) << "_"
<< output->dim(2) << "_"
<< output->dim(3);
std::string tuning_key =
Concat("conv2d_1x1_opencl_kernel_", activation, output->dim(0),
output->dim(1), output->dim(2), output->dim(3));
OpenCLProfilingTimer timer(&event);
Tuner<uint32_t>::Get()->template TuneOrRun<cl_int>(ss.str(),
lws,
params_generator,
func,
&timer);
if (future != nullptr) {
future->wait_fn = [runtime, event](CallStats *stats) {
event.wait();
if (stats != nullptr) {
runtime->GetCallStats(event, stats);
}
};
}
Tuner<uint32_t>::Get()->template TuneOrRun<cl_int>(
tuning_key, lws, params_generator, func, &timer);
SetFuture(future, event);
}
extern void Conv2dOpenclK1x1S1(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const bool fused_relu,
const int *padding,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future) {
Conv1x1(input, filter, bias, fused_relu, 1, dt, output, future);
Conv1x1(input, filter, bias, 1, activation, relux_max_limit, prelu_alpha, dt,
output, future);
};
extern void Conv2dOpenclK1x1S2(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const bool fused_relu,
const int *padding,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future) {
Conv1x1(input, filter, bias, fused_relu, 2, dt, output, future);
Conv1x1(input, filter, bias, 2, activation, relux_max_limit, prelu_alpha, dt,
output, future);
};
} // namespace kernels
......
......@@ -2,21 +2,29 @@
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/kernels/conv_2d.h"
#include "mace/core/common.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/conv_2d.h"
#include "mace/kernels/activation.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/utils.h"
#include "mace/utils/tuner.h"
#include "mace/utils/utils.h"
namespace mace {
namespace kernels {
static void Conv2d3x3S12(const Tensor *input, const Tensor *filter,
const Tensor *bias, const bool fused_relu,
const uint32_t stride, const int *padding,
const int *dilations, const DataType dt,
Tensor *output, StatsFuture *future) {
static void Conv2d3x3S12(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const uint32_t stride,
const int *padding,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future) {
const index_t batch = output->dim(0);
const index_t height = output->dim(1);
const index_t width = output->dim(2);
......@@ -34,20 +42,45 @@ static void Conv2d3x3S12(const Tensor *input, const Tensor *filter,
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace(bias != nullptr ? "-DBIAS" : "");
built_options.emplace("-DSTRIDE=" + ToString(stride));
if (fused_relu) {
built_options.emplace("-DFUSED_RELU");
switch (activation) {
case NOOP:
break;
case RELU:
built_options.emplace("-DUSE_RELU");
break;
case RELUX:
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
built_options.emplace("-DUSE_TANH");
break;
case SIGMOID:
built_options.emplace("-DUSE_SIGMOID");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation;
}
auto runtime = OpenCLRuntime::Global();
auto conv_2d_kernel = runtime->BuildKernel("conv_2d_3x3", kernel_name, built_options);
auto conv_2d_kernel =
runtime->BuildKernel("conv_2d_3x3", kernel_name, built_options);
uint32_t idx = 0;
conv_2d_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
conv_2d_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(filter->buffer())));
conv_2d_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(input->buffer())));
conv_2d_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(filter->buffer())));
if (bias != nullptr) {
conv_2d_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(bias->buffer())));
conv_2d_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(bias->buffer())));
}
conv_2d_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(output->buffer())));
conv_2d_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(output->buffer())));
conv_2d_kernel.setArg(idx++, relux_max_limit);
conv_2d_kernel.setArg(idx++, prelu_alpha);
conv_2d_kernel.setArg(idx++, static_cast<int>(input->dim(1)));
conv_2d_kernel.setArg(idx++, static_cast<int>(input->dim(2)));
conv_2d_kernel.setArg(idx++, static_cast<int>(input_channel_blocks));
......@@ -67,83 +100,75 @@ static void Conv2d3x3S12(const Tensor *input, const Tensor *filter,
std::vector<uint32_t> local_ws(3, 0);
local_ws[0] = std::min<uint32_t>(channel_blocks, kwg_size);
local_ws[1] = std::min<uint32_t>(width_blocks, kwg_size / local_ws[0]);
local_ws[2] = std::min<uint32_t>(height * batch, kwg_size / (local_ws[0] * local_ws[1]));
return {{local_ws[0], local_ws[1], local_ws[2]},
{local_ws[2], local_ws[1], local_ws[0]},
{kwg_size / 16, 4, 4},
{kwg_size / 32, 4, 8},
{kwg_size / 32, 8, 4},
{kwg_size / 64, 8, 8},
{kwg_size / 64, 16, 4},
{kwg_size / 128, 8, 16},
{kwg_size / 128, 16, 8},
{kwg_size / 128, 32, 4},
{1, kwg_size / 32, 32},
{1, kwg_size / 64, 64},
{1, kwg_size / 128, 128},
{3, 15, 9},
{7, 15, 9},
{9, 7, 15},
{15, 7, 9},
{1, kwg_size, 1},
{4, 15, 8}, //SNPE size
local_ws[2] = std::min<uint32_t>(height * batch,
kwg_size / (local_ws[0] * local_ws[1]));
return {
{local_ws[0], local_ws[1], local_ws[2]},
{local_ws[2], local_ws[1], local_ws[0]},
{kwg_size / 16, 4, 4},
{kwg_size / 32, 4, 8},
{kwg_size / 32, 8, 4},
{kwg_size / 64, 8, 8},
{kwg_size / 64, 16, 4},
{kwg_size / 128, 8, 16},
{kwg_size / 128, 16, 8},
{kwg_size / 128, 32, 4},
{1, kwg_size / 32, 32},
{1, kwg_size / 64, 64},
{1, kwg_size / 128, 128},
{3, 15, 9},
{7, 15, 9},
{9, 7, 15},
{15, 7, 9},
{1, kwg_size, 1},
{4, 15, 8}, // SNPE size
};
};
cl::Event event;
auto func = [&](const std::vector<uint32_t> &params) -> cl_int {
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
conv_2d_kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(params[0], params[1], params[2]),
nullptr, &event);
conv_2d_kernel, cl::NullRange, cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(params[0], params[1], params[2]), nullptr, &event);
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
return error;
};
std::stringstream ss;
ss << "conv2d_3x3_opencl_kernel_"
<< output->dim(0) << "_"
<< output->dim(1) << "_"
<< output->dim(2) << "_"
<< output->dim(3);
std::string tuning_key =
Concat("conv2d_3x3_opencl_kernel_", activation, output->dim(0),
output->dim(1), output->dim(2), output->dim(3));
OpenCLProfilingTimer timer(&event);
Tuner<uint32_t>::Get()->template TuneOrRun<cl_int>(ss.str(),
lws,
params_generator,
func,
&timer);
if (future != nullptr) {
future->wait_fn = [runtime, event](CallStats *stats) {
event.wait();
if (stats != nullptr) {
runtime->GetCallStats(event, stats);
}
};
}
Tuner<uint32_t>::Get()->template TuneOrRun<cl_int>(
tuning_key, lws, params_generator, func, &timer);
SetFuture(future, event);
}
void Conv2dOpenclK3x3S1(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const bool fused_relu,
const int *padding,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future) {
Conv2d3x3S12(input, filter, bias, fused_relu, 1, padding, dilations, dt, output, future);
Conv2d3x3S12(input, filter, bias, 1, padding, dilations, activation,
relux_max_limit, prelu_alpha, dt, output, future);
};
void Conv2dOpenclK3x3S2(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const bool fused_relu,
const int *padding,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future) {
Conv2d3x3S12(input, filter, bias, fused_relu, 2, padding, dilations, dt, output, future);
Conv2d3x3S12(input, filter, bias, 2, padding, dilations, activation,
relux_max_limit, prelu_alpha, dt, output, future);
};
} // namespace kernels
......
......@@ -2,21 +2,29 @@
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/kernels/conv_2d.h"
#include "mace/core/common.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/conv_2d.h"
#include "mace/kernels/activation.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/utils.h"
#include "mace/utils/tuner.h"
#include "mace/utils/utils.h"
namespace mace {
namespace kernels {
void Conv2dOpencl(const Tensor *input, const Tensor *filter,
const Tensor *bias, const bool fused_relu,
const uint32_t stride, const int *padding,
const int *dilations, const DataType dt,
Tensor *output, StatsFuture *future) {
void Conv2dOpencl(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const uint32_t stride,
const int *padding,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future) {
const index_t batch = output->dim(0);
const index_t height = output->dim(1);
const index_t width = output->dim(2);
......@@ -34,20 +42,45 @@ void Conv2dOpencl(const Tensor *input, const Tensor *filter,
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace(bias != nullptr ? "-DBIAS" : "");
built_options.emplace("-DSTRIDE=" + ToString(stride));
if (fused_relu) {
built_options.emplace("-DFUSED_RELU");
switch (activation) {
case NOOP:
break;
case RELU:
built_options.emplace("-DUSE_RELU");
break;
case RELUX:
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
built_options.emplace("-DUSE_TANH");
break;
case SIGMOID:
built_options.emplace("-DUSE_SIGMOID");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation;
}
auto runtime = OpenCLRuntime::Global();
auto conv_2d_kernel = runtime->BuildKernel("conv_2d", kernel_name, built_options);
auto conv_2d_kernel =
runtime->BuildKernel("conv_2d", kernel_name, built_options);
uint32_t idx = 0;
conv_2d_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
conv_2d_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(filter->buffer())));
conv_2d_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(input->buffer())));
conv_2d_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(filter->buffer())));
if (bias != nullptr) {
conv_2d_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(bias->buffer())));
conv_2d_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(bias->buffer())));
}
conv_2d_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(output->buffer())));
conv_2d_kernel.setArg(idx++,
*(static_cast<const cl::Image2D *>(output->buffer())));
conv_2d_kernel.setArg(idx++, relux_max_limit);
conv_2d_kernel.setArg(idx++, prelu_alpha);
conv_2d_kernel.setArg(idx++, static_cast<int>(input->dim(1)));
conv_2d_kernel.setArg(idx++, static_cast<int>(input->dim(2)));
conv_2d_kernel.setArg(idx++, static_cast<int>(input_channel_blocks));
......@@ -69,60 +102,46 @@ void Conv2dOpencl(const Tensor *input, const Tensor *filter,
std::vector<uint32_t> local_ws(3, 0);
local_ws[0] = std::min<uint32_t>(channel_blocks, kwg_size);
local_ws[1] = std::min<uint32_t>(width_blocks, kwg_size / local_ws[0]);
local_ws[2] = std::min<uint32_t>(height * batch, kwg_size / (local_ws[0] * local_ws[1]));
return {{local_ws[0], local_ws[1], local_ws[2]},
{local_ws[2], local_ws[1], local_ws[0]},
{kwg_size / 16, 4, 4},
{kwg_size / 32, 4, 8},
{kwg_size / 32, 8, 4},
{kwg_size / 64, 8, 8},
{kwg_size / 64, 16, 4},
{kwg_size / 128, 8, 16},
{kwg_size / 128, 16, 8},
{kwg_size / 128, 32, 4},
{1, kwg_size / 32, 32},
{1, kwg_size / 64, 64},
{1, kwg_size / 128, 128},
{3, 15, 9},
{7, 15, 9},
{9, 7, 15},
{15, 7, 9},
{1, kwg_size, 1},
{4, 15, 8}, //SNPE size
local_ws[2] = std::min<uint32_t>(height * batch,
kwg_size / (local_ws[0] * local_ws[1]));
return {
{local_ws[0], local_ws[1], local_ws[2]},
{local_ws[2], local_ws[1], local_ws[0]},
{kwg_size / 16, 4, 4},
{kwg_size / 32, 4, 8},
{kwg_size / 32, 8, 4},
{kwg_size / 64, 8, 8},
{kwg_size / 64, 16, 4},
{kwg_size / 128, 8, 16},
{kwg_size / 128, 16, 8},
{kwg_size / 128, 32, 4},
{1, kwg_size / 32, 32},
{1, kwg_size / 64, 64},
{1, kwg_size / 128, 128},
{3, 15, 9},
{7, 15, 9},
{9, 7, 15},
{15, 7, 9},
{1, kwg_size, 1},
{4, 15, 8}, // SNPE size
};
};
cl::Event event;
auto func = [&](const std::vector<uint32_t> &params) -> cl_int {
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
conv_2d_kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(params[0], params[1], params[2]),
nullptr, &event);
conv_2d_kernel, cl::NullRange, cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(params[0], params[1], params[2]), nullptr, &event);
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
return error;
};
std::stringstream ss;
ss << "conv2d_general_opencl_kernel_"
<< output->dim(0) << "_"
<< output->dim(1) << "_"
<< output->dim(2) << "_"
<< output->dim(3);
std::string tuning_key =
Concat("conv2d_general_opencl_kernel_", activation, output->dim(0),
output->dim(1), output->dim(2), output->dim(3));
OpenCLProfilingTimer timer(&event);
Tuner<uint32_t>::Get()->template TuneOrRun<cl_int>(ss.str(),
lws,
params_generator,
func,
&timer);
if (future != nullptr) {
future->wait_fn = [runtime, event](CallStats *stats) {
event.wait();
if (stats != nullptr) {
runtime->GetCallStats(event, stats);
}
};
}
Tuner<uint32_t>::Get()->template TuneOrRun<cl_int>(
tuning_key, lws, params_generator, func, &timer);
SetFuture(future, event);
}
} // namespace kernels
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/kernels/fused_conv_2d.h"
#include "mace/kernels/opencl/helper.h"
namespace mace {
namespace kernels {
extern void Conv2dOpenclK1x1S1(const Tensor *input, const Tensor *filter,
const Tensor *bias, const bool fused_relu,
const int *padding, const int *dilations,
const DataType dt, Tensor *output,
StatsFuture *future);
extern void Conv2dOpenclK1x1S2(const Tensor *input, const Tensor *filter,
const Tensor *bias, const bool fused_relu,
const int *padding, const int *dilations,
const DataType dt, Tensor *output,
StatsFuture *future);
extern void Conv2dOpenclK3x3S1(const Tensor *input, const Tensor *filter,
const Tensor *bias, const bool fused_relu,
const int *padding, const int *dilations,
const DataType dt, Tensor *output,
StatsFuture *future);
extern void Conv2dOpenclK3x3S2(const Tensor *input, const Tensor *filter,
const Tensor *bias, const bool fused_relu,
const int *padding, const int *dilations,
const DataType dt, Tensor *output,
StatsFuture *future);
extern void Conv2dOpencl(const Tensor *input, const Tensor *filter,
const Tensor *bias, const bool fused_relu,
const uint32_t stride, const int *padding,
const int *dilations, const DataType dt,
Tensor *output, StatsFuture *future);
template<typename T>
void FusedConv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
StatsFuture *future) {
typedef void (*Conv2dOpenclFunction)(const Tensor *input, const Tensor *filter,
const Tensor *bias, const bool fused_relu,
const int *padding, const int *dilations,
const DataType dt, Tensor *output,
StatsFuture *future);
// Selection matrix: kernel_size x stride_size
static const Conv2dOpenclFunction selector[5][2] = {
{Conv2dOpenclK1x1S1, Conv2dOpenclK1x1S2},
{nullptr, nullptr},
{Conv2dOpenclK3x3S1, Conv2dOpenclK3x3S2},
{nullptr, nullptr},
{nullptr, nullptr}};
index_t kernel_h = filter->dim(0);
index_t kernel_w = filter->dim(1);
if (!input->is_image() || strides_[0] != strides_[1] || strides_[0] > 2 ||
(dilations_[0] > 1 && (strides_[0] > 1 || kernel_h == 1))) {
LOG(WARNING) << "OpenCL conv2d kernel with "
<< "filter" << kernel_h << "x" << kernel_w << ","
<< " stride " << strides_[0] << "x" << strides_[1]
<< ",dilations " << dilations_[0] << "x" << dilations_[1]
<< " and input image: " << input->is_image()
<< " is not implemented yet.";
MACE_NOT_IMPLEMENTED;
}
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
kernels::CalcNHWCPaddingAndOutputSize(
input->shape().data(), filter->shape().data(), dilations_,
strides_, paddings_, output_shape.data(), paddings.data());
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT, output_image_shape);
output->ResizeImage(output_shape, output_image_shape);
if (kernel_h == kernel_w && kernel_h <= 5 &&
selector[kernel_h - 1][strides_[0] - 1] != nullptr) {
auto conv2d_func = selector[kernel_h - 1][strides_[0] - 1];
conv2d_func(input, filter, bias, true, paddings.data(), dilations_,
DataTypeToEnum<T>::value, output, future);
} else {
Conv2dOpencl(input, filter, bias, true, strides_[0], paddings.data(),
dilations_, DataTypeToEnum<T>::value, output, future);
}
}
template
struct FusedConv2dFunctor<DeviceType::OPENCL, float>;
template
struct FusedConv2dFunctor<DeviceType::OPENCL, half>;
} // namespace kernels
} // namespace mace
......@@ -4,8 +4,12 @@
#ifndef MACE_KERNELS_OPENCL_HELPER_H_
#define MACE_KERNELS_OPENCL_HELPER_H_
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/types.h"
#include "mace/utils/utils.h"
#include "mace/core/future.h"
namespace mace {
namespace kernels {
......@@ -28,6 +32,40 @@ std::string DtToCLDt(const DataType dt);
std::string DtToUpstreamCLDt(const DataType dt);
inline void SetFuture(StatsFuture *future, const cl::Event &event) {
if (future != nullptr) {
future->wait_fn = [event](CallStats *stats) {
event.wait();
if (stats != nullptr) {
OpenCLRuntime::Global()->GetCallStats(event, stats);
}
};
}
}
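// A minimal usage sketch for SetFuture (illustrative only; "kernel", "gws" and
// "lws" are placeholder names, not part of this header): enqueue a kernel,
// then hand the resulting event to SetFuture so callers can wait on it and
// collect profiling stats later.
//
//   cl::Event event;
//   runtime->command_queue().enqueueNDRangeKernel(
//       kernel, cl::NullRange, cl::NDRange(gws[0], gws[1], gws[2]),
//       cl::NDRange(lws[0], lws[1], lws[2]), nullptr, &event);
//   SetFuture(future, event);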
namespace {
template<typename T>
void AppendToStream(std::stringstream *ss, const std::string &delimiter, T v) {
(*ss) << v;
}
template<typename T, typename... Args>
void AppendToStream(std::stringstream *ss,
const std::string &delimiter,
T first,
Args... args) {
(*ss) << first << delimiter;
AppendToStream(ss, delimiter, args...);
}
} // namespace
template<typename... Args>
std::string Concat(Args... args) {
std::stringstream ss;
AppendToStream(&ss, "_", args...);
return ss.str();
}
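// For example (illustrative values only), Concat("conv2d", 1, 224, 224, 32)
// produces "conv2d_1_224_224_32"; the conv2d kernels above use this to build
// per-shape tuning keys.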
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_OPENCL_HELPER_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_RELU_H_
#define MACE_KERNELS_RELU_H_
#include "mace/core/future.h"
#include "mace/core/tensor.h"
namespace mace {
namespace kernels {
template <DeviceType D, typename T>
struct ReluFunctor {
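  // A negative max_limit_ selects plain ReLU; otherwise the output is clamped
  // to [0, max_limit_] (ReLU-X).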
T max_limit_;
void operator()(const Tensor *input, Tensor *output, StatsFuture *future) {
const T *input_ptr = input->data<T>();
T *output_ptr = output->mutable_data<T>();
index_t size = input->size();
if (max_limit_ < 0) {
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = std::max(input_ptr[i], static_cast<T>(0));
}
} else {
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = std::min(std::max(input_ptr[i], static_cast<T>(0)), max_limit_);
}
}
}
};
template <>
void ReluFunctor<DeviceType::NEON, float>::operator()(const Tensor *input,
Tensor *output,
StatsFuture *future);
template <typename T>
struct ReluFunctor<DeviceType::OPENCL, T> {
T max_limit_;
void operator()(const Tensor *input, Tensor *output, StatsFuture *future);
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_RELU_H_
......@@ -2,36 +2,36 @@
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/ops/relu.h"
#include "mace/ops/activation.h"
namespace mace {
void Register_Relu(OperatorRegistry *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Relu")
void Register_Activation(OperatorRegistry *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Activation")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
ReluOp<DeviceType::CPU, float>);
ActivationOp<DeviceType::CPU, float>);
#if MACE_ENABLE_NEON
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Relu")
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Activation")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
ReluOp<DeviceType::NEON, float>);
ActivationOp<DeviceType::NEON, float>);
#endif // MACE_ENABLE_NEON
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Relu")
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Activation")
.Device(DeviceType::OPENCL)
.TypeConstraint<float>("T")
.Build(),
ReluOp<DeviceType::OPENCL, float>);
ActivationOp<DeviceType::OPENCL, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Relu")
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Activation")
.Device(DeviceType::OPENCL)
.TypeConstraint<half>("T")
.Build(),
ReluOp<DeviceType::OPENCL, half>);
ActivationOp<DeviceType::OPENCL, half>);
}
} // namespace mace
......@@ -2,22 +2,25 @@
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_OPS_RELU_H_
#define MACE_OPS_RELU_H_
#ifndef MACE_OPS_ACTIVATION_H_
#define MACE_OPS_ACTIVATION_H_
#include "mace/core/operator.h"
#include "mace/kernels/relu.h"
#include "mace/kernels/activation.h"
namespace mace {
template <DeviceType D, class T>
class ReluOp : public Operator<D, T> {
class ActivationOp : public Operator<D, T> {
public:
ReluOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws) {
functor_.max_limit_ =
OperatorBase::GetSingleArgument<float>("max_limit", static_cast<float>(-1));
}
ActivationOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws),
functor_(kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation",
"NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f),
OperatorBase::GetSingleArgument<float>("alpha", 0.0f)) {}
bool Run(StatsFuture *future) override {
const Tensor *input_tensor = this->inputs_[0];
Tensor *output_tensor = this->outputs_[0];
......@@ -28,9 +31,9 @@ class ReluOp : public Operator<D, T> {
}
private:
kernels::ReluFunctor<D, T> functor_;
kernels::ActivationFunctor<D, T> functor_;
};
} // namespace mace
#endif // MACE_OPS_RELU_H_
#endif // MACE_OPS_ACTIVATION_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <string>
#include "mace/core/operator.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
template <DeviceType D, typename T>
static void ReluBenchmark(
int iters, int batch, int channels, int height, int width) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
OpDefBuilder("Activation", "ReluBM")
.Input("InputImage")
.Output("Output")
.AddStringArg("activation", "RELU")
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("Activation", "ReluBM")
.Input("Input")
.Output("Output")
.AddStringArg("activation", "RELU")
.Finalize(net.NewOperatorDef());
}
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
#define BM_RELU_MACRO(N, C, H, W, TYPE, DEVICE) \
static void BM_RELU_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE(int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
ReluBenchmark<DEVICE, TYPE>(iters, N, C, H, W); \
} \
BENCHMARK(BM_RELU_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
#define BM_RELU(N, C, H, W, TYPE) \
BM_RELU_MACRO(N, C, H, W, TYPE, CPU); \
BM_RELU_MACRO(N, C, H, W, TYPE, OPENCL);
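// For example, BM_RELU(1, 1, 512, 512, float) expands to benchmark functions
// named BM_RELU_1_1_512_512_float_CPU and BM_RELU_1_1_512_512_float_OPENCL and
// registers both with BENCHMARK.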
BM_RELU(1, 1, 512, 512, float);
BM_RELU(1, 3, 128, 128, float);
BM_RELU(1, 3, 512, 512, float);
BM_RELU(1, 32, 112, 112, float);
BM_RELU(1, 64, 256, 256, float);
template <DeviceType D, typename T>
static void ReluxBenchmark(
int iters, int batch, int channels, int height, int width) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
OpDefBuilder("Activation", "ReluxBM")
.Input("InputImage")
.Output("Output")
.AddStringArg("activation", "RELUX")
.AddFloatArg("max_limit", 6.0)
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("Activation", "ReluxBM")
.Input("Input")
.Output("Output")
.AddStringArg("activation", "RELUX")
.AddFloatArg("max_limit", 6.0)
.Finalize(net.NewOperatorDef());
}
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
#define BM_RELUX_MACRO(N, C, H, W, TYPE, DEVICE) \
static void BM_RELUX_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE(int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
ReluxBenchmark<DEVICE, TYPE>(iters, N, C, H, W); \
} \
BENCHMARK(BM_RELUX_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
#define BM_RELUX(N, C, H, W, TYPE) \
BM_RELUX_MACRO(N, C, H, W, TYPE, CPU); \
BM_RELUX_MACRO(N, C, H, W, TYPE, OPENCL);
BM_RELUX(1, 1, 512, 512, float);
BM_RELUX(1, 3, 128, 128, float);
BM_RELUX(1, 3, 512, 512, float);
BM_RELUX(1, 32, 112, 112, float);
BM_RELUX(1, 64, 256, 256, float);
template <DeviceType D, typename T>
static void PreluBenchmark(
int iters, int batch, int channels, int height, int width) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
OpDefBuilder("Activation", "PreluBM")
.Input("InputImage")
.Output("Output")
.AddStringArg("activation", "PRELU")
.AddFloatArg("alpha", 2.0)
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("Activation", "PreluBM")
.Input("Input")
.Output("Output")
.AddStringArg("activation", "PRELU")
.AddFloatArg("alpha", 2.0)
.Finalize(net.NewOperatorDef());
}
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
#define BM_PRELU_MACRO(N, C, H, W, TYPE, DEVICE) \
static void BM_PRELU_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE(int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
PreluBenchmark<DEVICE, TYPE>(iters, N, C, H, W); \
} \
BENCHMARK(BM_PRELU_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
#define BM_PRELU(N, C, H, W, TYPE) \
BM_PRELU_MACRO(N, C, H, W, TYPE, CPU); \
BM_PRELU_MACRO(N, C, H, W, TYPE, OPENCL);
BM_PRELU(1, 1, 512, 512, float);
BM_PRELU(1, 3, 128, 128, float);
BM_PRELU(1, 3, 512, 512, float);
BM_PRELU(1, 32, 112, 112, float);
BM_PRELU(1, 64, 256, 256, float);
template <DeviceType D, typename T>
static void TanhBenchmark(
int iters, int batch, int channels, int height, int width) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
OpDefBuilder("Activation", "TanhBM")
.Input("InputImage")
.Output("Output")
.AddStringArg("activation", "TANH")
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("Activation", "TanhBM")
.Input("Input")
.Output("Output")
.AddStringArg("activation", "TANH")
.Finalize(net.NewOperatorDef());
}
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
#define BM_TANH_MACRO(N, C, H, W, TYPE, DEVICE) \
static void BM_TANH_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE(int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
TanhBenchmark<DEVICE, TYPE>(iters, N, C, H, W); \
} \
BENCHMARK(BM_TANH_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
#define BM_TANH(N, C, H, W, TYPE) \
BM_TANH_MACRO(N, C, H, W, TYPE, CPU); \
BM_TANH_MACRO(N, C, H, W, TYPE, OPENCL);
BM_TANH(1, 1, 512, 512, float);
BM_TANH(1, 3, 128, 128, float);
BM_TANH(1, 3, 512, 512, float);
BM_TANH(1, 32, 112, 112, float);
BM_TANH(1, 64, 256, 256, float);
template <DeviceType D, typename T>
static void SigmoidBenchmark(
int iters, int batch, int channels, int height, int width) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
OpDefBuilder("Activation", "SigmoidBM")
.Input("InputImage")
.Output("Output")
.AddStringArg("activation", "SIGMOID")
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("Activation", "SigmoidBM")
.Input("Input")
.Output("Output")
.AddStringArg("activation", "SIGMOID")
.Finalize(net.NewOperatorDef());
}
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
#define BM_SIGMOID_MACRO(N, C, H, W, TYPE, DEVICE) \
static void BM_SIGMOID_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
SigmoidBenchmark<DEVICE, TYPE>(iters, N, C, H, W); \
} \
BENCHMARK(BM_SIGMOID_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
#define BM_SIGMOID(N, C, H, W, TYPE) \
BM_SIGMOID_MACRO(N, C, H, W, TYPE, CPU); \
BM_SIGMOID_MACRO(N, C, H, W, TYPE, OPENCL);
BM_SIGMOID(1, 1, 512, 512, float);
BM_SIGMOID(1, 3, 128, 128, float);
BM_SIGMOID(1, 3, 512, 512, float);
BM_SIGMOID(1, 32, 112, 112, float);
BM_SIGMOID(1, 64, 256, 256, float);
} // namespace mace
......@@ -7,10 +7,10 @@
namespace mace {
class ReluOpTest : public OpsTestBase {};
class ActivationOpTest : public OpsTestBase {};
template <DeviceType D>
void TestSimple() {
void TestSimpleRelu() {
OpsTestNet net;
// Add input data
......@@ -22,9 +22,10 @@ void TestSimple() {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
OpDefBuilder("Relu", "ReluTest")
OpDefBuilder("Activation", "ReluTest")
.Input("InputImage")
.Output("OutputImage")
.AddStringArg("activation", "RELU")
.Finalize(net.NewOperatorDef());
// Run
......@@ -34,9 +35,10 @@ void TestSimple() {
ImageToBuffer<D, float>(net, "OutputImage", "Output",
kernels::BufferType::IN_OUT);
} else {
OpDefBuilder("Relu", "ReluTest")
OpDefBuilder("Activation", "ReluTest")
.Input("Input")
.Output("Output")
.AddStringArg("activation", "RELU")
.Finalize(net.NewOperatorDef());
// Run
......@@ -49,16 +51,18 @@ void TestSimple() {
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
TEST_F(ReluOpTest, CPUSimple) { TestSimple<DeviceType::CPU>(); }
TEST_F(ActivationOpTest, CPUSimpleRelu) { TestSimpleRelu<DeviceType::CPU>(); }
#if __ARM_NEON
TEST_F(ReluOpTest, NEONSimple) { TestSimple<DeviceType::NEON>(); }
TEST_F(ActivationOpTest, NEONSimpleRelu) { TestSimpleRelu<DeviceType::NEON>(); }
#endif
TEST_F(ReluOpTest, OPENCLSimple) { TestSimple<DeviceType::OPENCL>(); }
TEST_F(ActivationOpTest, OPENCLSimpleRelu) {
TestSimpleRelu<DeviceType::OPENCL>();
}
template <DeviceType D>
void TestUnalignedSimple() {
void TestUnalignedSimpleRelu() {
OpsTestNet net;
// Add input data
......@@ -68,9 +72,10 @@ void TestUnalignedSimple() {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
OpDefBuilder("Relu", "ReluTest")
OpDefBuilder("Activation", "ReluTest")
.Input("InputImage")
.Output("OutputImage")
.AddStringArg("activation", "RELU")
.Finalize(net.NewOperatorDef());
// Run
......@@ -80,9 +85,10 @@ void TestUnalignedSimple() {
ImageToBuffer<D, float>(net, "OutputImage", "Output",
kernels::BufferType::IN_OUT);
} else {
OpDefBuilder("Relu", "ReluTest")
OpDefBuilder("Activation", "ReluTest")
.Input("Input")
.Output("Output")
.AddStringArg("activation", "RELU")
.Finalize(net.NewOperatorDef());
// Run
......@@ -94,22 +100,22 @@ void TestUnalignedSimple() {
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
TEST_F(ReluOpTest, CPUUnalignedSimple) {
TestUnalignedSimple<DeviceType::CPU>();
TEST_F(ActivationOpTest, CPUUnalignedSimpleRelu) {
TestUnalignedSimpleRelu<DeviceType::CPU>();
}
#if __ARM_NEON
TEST_F(ReluOpTest, NEONUnalignedSimple) {
TestUnalignedSimple<DeviceType::NEON>();
TEST_F(ActivationOpTest, NEONUnalignedSimpleRelu) {
TestUnalignedSimpleRelu<DeviceType::NEON>();
}
#endif
TEST_F(ReluOpTest, OPENCLUnalignedSimple) {
TestUnalignedSimple<DeviceType::OPENCL>();
TEST_F(ActivationOpTest, OPENCLUnalignedSimpleRelu) {
TestUnalignedSimpleRelu<DeviceType::OPENCL>();
}
template <DeviceType D>
void TestSimpleReluX() {
void TestSimpleRelux() {
OpsTestNet net;
// Add input data
......@@ -121,9 +127,10 @@ void TestSimpleReluX() {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
OpDefBuilder("Relu", "ReluTest")
OpDefBuilder("Activation", "ReluxTest")
.Input("InputImage")
.Output("OutputImage")
.AddStringArg("activation", "RELUX")
.AddFloatArg("max_limit", 6)
.Finalize(net.NewOperatorDef());
......@@ -134,9 +141,10 @@ void TestSimpleReluX() {
ImageToBuffer<D, float>(net, "OutputImage", "Output",
kernels::BufferType::IN_OUT);
} else {
OpDefBuilder("Relu", "ReluTest")
OpDefBuilder("Activation", "ReluxTest")
.Input("Input")
.Output("Output")
.AddStringArg("activation", "RELUX")
.AddFloatArg("max_limit", 6)
.Finalize(net.NewOperatorDef());
......@@ -150,29 +158,33 @@ void TestSimpleReluX() {
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
TEST_F(ReluOpTest, CPUSimpleReluX) { TestSimpleReluX<DeviceType::CPU>(); }
TEST_F(ActivationOpTest, CPUSimple) { TestSimpleRelux<DeviceType::CPU>(); }
#if __ARM_NEON
TEST_F(ReluOpTest, NEONSimpleReluX) { TestSimpleReluX<DeviceType::NEON>(); }
TEST_F(ActivationOpTest, NEONSimple) { TestSimpleRelux<DeviceType::NEON>(); }
#endif
TEST_F(ReluOpTest, OPENCLSimpleReluX) { TestSimpleReluX<DeviceType::OPENCL>(); }
TEST_F(ActivationOpTest, OPENCLSimple) {
TestSimpleRelux<DeviceType::OPENCL>();
}
template <DeviceType D>
void TestUnalignedSimpleReluX() {
void TestSimpleReluRelux() {
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>("Input", {1, 1, 7, 1},
{-7, 7, -6, 6, -5, 5, -4});
net.AddInputFromArray<D, float>(
"Input", {2, 2, 2, 2},
{-7, 7, -6, 6, -5, 5, -4, 4, -3, 3, -2, 2, -1, 1, 0, 0});
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
OpDefBuilder("Relu", "ReluTest")
OpDefBuilder("Activation", "ReluxTest")
.Input("InputImage")
.Output("OutputImage")
.AddStringArg("activation", "RELUX")
.AddFloatArg("max_limit", 6)
.Finalize(net.NewOperatorDef());
......@@ -183,9 +195,10 @@ void TestUnalignedSimpleReluX() {
ImageToBuffer<D, float>(net, "OutputImage", "Output",
kernels::BufferType::IN_OUT);
} else {
OpDefBuilder("Relu", "ReluTest")
OpDefBuilder("Activation", "ReluxTest")
.Input("Input")
.Output("Output")
.AddStringArg("activation", "RELUX")
.AddFloatArg("max_limit", 6)
.Finalize(net.NewOperatorDef());
......@@ -193,23 +206,195 @@ void TestUnalignedSimpleReluX() {
net.RunOp(D);
}
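  // RELUX with max_limit = 6 clamps each element to [0, 6]: negative inputs
  // become 0 and 7 becomes 6.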
auto expected = CreateTensor<float>({1, 1, 7, 1}, {0, 6, 0, 6, 0, 5, 0});
auto expected = CreateTensor<float>(
{2, 2, 2, 2}, {0, 6, 0, 6, 0, 5, 0, 4, 0, 3, 0, 2, 0, 1, 0, 0});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
TEST_F(ActivationOpTest, CPUSimpleRelux) {
TestSimpleReluRelux<DeviceType::CPU>();
}
#if __ARM_NEON
TEST_F(ActivationOpTest, NEONSimpleRelux) {
TestSimpleReluRelux<DeviceType::NEON>();
}
#endif
TEST_F(ActivationOpTest, OPENCLSimpleRelux) {
TestSimpleReluRelux<DeviceType::OPENCL>();
}
template <DeviceType D>
void TestSimplePrelu() {
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>(
"Input", {2, 2, 2, 2},
{-7, 7, -6, 6, -5, 5, -4, 4, -3, 3, -2, 2, -1, 1, 0, 0});
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
OpDefBuilder("Activation", "PreluTest")
.Input("InputImage")
.Output("OutputImage")
.AddStringArg("activation", "PRELU")
.AddFloatArg("alpha", 2.0)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
// Transfer output
ImageToBuffer<D, float>(net, "OutputImage", "Output",
kernels::BufferType::IN_OUT);
} else {
OpDefBuilder("Activation", "PreluTest")
.Input("Input")
.Output("Output")
.AddStringArg("activation", "PRELU")
.AddFloatArg("alpha", 2.0)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
}
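  // With alpha = 2, PRELU maps each negative input x to alpha * x (e.g. -7 ->
  // -14) and passes non-negative inputs through unchanged.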
auto expected = CreateTensor<float>(
{2, 2, 2, 2}, {-14, 7, -12, 6, -10, 5, -8, 4, -6, 3, -4, 2, -2, 1, 0, 0});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
TEST_F(ActivationOpTest, CPUSimplePrelu) { TestSimplePrelu<DeviceType::CPU>(); }
#if __ARM_NEON
TEST_F(ActivationOpTest, NEONSimplePrelu) {
TestSimplePrelu<DeviceType::NEON>();
}
#endif
TEST_F(ActivationOpTest, OPENCLSimplePrelu) {
TestSimplePrelu<DeviceType::OPENCL>();
}
template <DeviceType D>
void TestSimpleTanh() {
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>(
"Input", {2, 2, 2, 2},
{-7, 7, -6, 6, -5, 5, -4, 4, -3, 3, -2, 2, -1, 1, 0, 0});
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
OpDefBuilder("Activation", "TanhTest")
.Input("InputImage")
.Output("OutputImage")
.AddStringArg("activation", "TANH")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
// Transfer output
ImageToBuffer<D, float>(net, "OutputImage", "Output",
kernels::BufferType::IN_OUT);
} else {
OpDefBuilder("Activation", "TanhTest")
.Input("Input")
.Output("Output")
.AddStringArg("activation", "TANH")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
}
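  // The expected values are tanh(x) applied elementwise, e.g. tanh(-1) ~=
  // -0.76159416 and tanh(7) ~= 0.99999834.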
auto expected = CreateTensor<float>(
{2, 2, 2, 2},
{-0.99999834, 0.99999834, -0.99998771, 0.99998771, -0.9999092, 0.9999092,
-0.9993293, 0.9993293, -0.99505475, 0.99505475, -0.96402758, 0.96402758,
-0.76159416, 0.76159416, 0., 0.});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
TEST_F(ActivationOpTest, CPUSimpleTanh) { TestSimpleTanh<DeviceType::CPU>(); }
#if __ARM_NEON
TEST_F(ActivationOpTest, NEONSimpleTanh) { TestSimpleTanh<DeviceType::NEON>(); }
#endif
TEST_F(ActivationOpTest, OPENCLSimpleTanh) {
TestSimpleTanh<DeviceType::OPENCL>();
}
template <DeviceType D>
void TestSimpleSigmoid() {
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>(
"Input", {2, 2, 2, 2},
{-7, 7, -6, 6, -5, 5, -4, 4, -3, 3, -2, 2, -1, 1, 0, 0});
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
OpDefBuilder("Activation", "SigmoidTest")
.Input("InputImage")
.Output("OutputImage")
.AddStringArg("activation", "SIGMOID")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
// Transfer output
ImageToBuffer<D, float>(net, "OutputImage", "Output",
kernels::BufferType::IN_OUT);
} else {
OpDefBuilder("Activation", "SigmoidTest")
.Input("Input")
.Output("Output")
.AddStringArg("activation", "SIGMOID")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
}
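  // The expected values are sigmoid(x) = 1 / (1 + exp(-x)) applied
  // elementwise, e.g. sigmoid(-7) ~= 9.1105e-04 and sigmoid(0) = 0.5.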
auto expected = CreateTensor<float>(
{2, 2, 2, 2},
{9.11051194e-04, 9.99088949e-01, 2.47262316e-03, 9.97527377e-01,
6.69285092e-03, 9.93307149e-01, 1.79862100e-02, 9.82013790e-01,
4.74258732e-02, 9.52574127e-01, 1.19202922e-01, 8.80797078e-01,
2.68941421e-01, 7.31058579e-01, 5.00000000e-01, 5.00000000e-01});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
TEST_F(ReluOpTest, CPUUnalignedSimpleReluX) {
TestUnalignedSimpleReluX<DeviceType::CPU>();
TEST_F(ActivationOpTest, CPUSimpleSigmoid) {
TestSimpleSigmoid<DeviceType::CPU>();
}
#if __ARM_NEON
TEST_F(ReluOpTest, NEONUnalignedSimpleReluX) {
TestUnalignedSimpleReluX<DeviceType::NEON>();
TEST_F(ActivationOpTest, NEONSimpleSigmoid) {
TestSimpleSigmoid<DeviceType::NEON>();
}
#endif
TEST_F(ReluOpTest, OPENCLUnalignedSimpleReluX) {
TestUnalignedSimpleReluX<DeviceType::OPENCL>();
TEST_F(ActivationOpTest, OPENCLSimpleSigmoid) {
TestSimpleSigmoid<DeviceType::OPENCL>();
}
} // namespace mace
......@@ -6,6 +6,7 @@
#define MACE_OPS_BATCH_NORM_H_
#include "mace/core/operator.h"
#include "mace/kernels/activation.h"
#include "mace/kernels/batch_norm.h"
namespace mace {
......@@ -14,9 +15,10 @@ template <DeviceType D, class T>
class BatchNormOp : public Operator<D, T> {
public:
BatchNormOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws), functor_(false, false) {
epsilon_ =
OperatorBase::GetSingleArgument<float>("epsilon", static_cast<float>(1e-4));
: Operator<D, T>(operator_def, ws),
functor_(false, kernels::ActivationType::NOOP, 0.0f, 0.0f) {
epsilon_ = OperatorBase::GetSingleArgument<float>("epsilon",
static_cast<float>(1e-4));
}
bool Run(StatsFuture *future) override {
......
......@@ -18,9 +18,12 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> {
public:
Conv2dOp(const OperatorDef &op_def, Workspace *ws)
: ConvPool2dOpBase<D, T>(op_def, ws),
functor_(this->strides_.data(), this->padding_,
this->dilations_.data()) {
}
functor_(this->strides_.data(),
this->padding_,
this->dilations_.data(),
kernels::ActivationType::NOOP,
0.0f,
0.0f) {}
bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
......
......@@ -15,8 +15,12 @@ class FoldedBatchNormOp : public Operator<D, T> {
public:
FoldedBatchNormOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws),
functor_(true, OperatorBase::GetSingleArgument<bool>("fused_relu", false)) {
}
functor_(true,
kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation",
"NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f),
OperatorBase::GetSingleArgument<float>("alpha", 0.0f)) {}
bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
......
......@@ -8,7 +8,7 @@
#include <memory>
#include "mace/core/operator.h"
#include "mace/kernels/fused_conv_2d.h"
#include "mace/kernels/conv_2d.h"
#include "mace/ops/conv_pool_2d_base.h"
namespace mace {
......@@ -18,9 +18,14 @@ class FusedConv2dOp : public ConvPool2dOpBase<D, T> {
public:
FusedConv2dOp(const OperatorDef &op_def, Workspace *ws)
: ConvPool2dOpBase<D, T>(op_def, ws),
functor_(this->strides_.data(), this->padding_,
this->dilations_.data()) {
}
functor_(this->strides_.data(),
this->padding_,
this->dilations_.data(),
kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation",
"NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f),
OperatorBase::GetSingleArgument<float>("alpha", 0.0f)) {}
bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
......@@ -34,7 +39,7 @@ class FusedConv2dOp : public ConvPool2dOpBase<D, T> {
}
private:
kernels::FusedConv2dFunctor<D, T> functor_;
kernels::Conv2dFunctor<D, T> functor_;
protected:
OP_INPUT_TAGS(INPUT, FILTER, BIAS);
......
......@@ -47,9 +47,9 @@ static void GlobalAvgPooling(
BENCHMARK(BM_GLOBAL_AVG_POOLING_##N##_##C##_##H##_##W##_##DEVICE)
#define BM_GLOBAL_AVG_POOLING(N, C, H, W) \
BM_GLOBAL_AVG_POOLING_MACRO(N, C, H, W, CPU); \
BM_GLOBAL_AVG_POOLING_MACRO(N, C, H, W, NEON);
BM_GLOBAL_AVG_POOLING_MACRO(N, C, H, W, CPU);
// BM_GLOBAL_AVG_POOLING_MACRO(N, C, H, W, NEON);
BM_GLOBAL_AVG_POOLING(1, 3, 7, 7);
BM_GLOBAL_AVG_POOLING(1, 3, 64, 64);
BM_GLOBAL_AVG_POOLING(1, 3, 256, 256);
\ No newline at end of file
BM_GLOBAL_AVG_POOLING(1, 3, 256, 256);
......@@ -63,8 +63,8 @@ static void Pooling(int iters,
BM_POOLING_##N##_##C##_##H##_##W##_K##KE##S##STRIDE##_##PA##_##PO##_##DEVICE)
#define BM_POOLING(N, C, H, W, K, S, PA, PO) \
BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, CPU); \
BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, NEON);
BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, CPU);
// BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, NEON);
BM_POOLING(1, 3, 129, 129, 2, 2, SAME, MAX);
BM_POOLING(1, 3, 257, 257, 2, 2, SAME, MAX);
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <string>
#include "mace/core/operator.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
template <DeviceType D, typename T>
static void ReluBenchmark(
int iters, int batch, int channels, int height, int width) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
OpDefBuilder("Relu", "ReluBM")
.Input("InputImage")
.Output("Output")
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("Relu", "ReluBM")
.Input("Input")
.Output("Output")
.Finalize(net.NewOperatorDef());
}
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
#define BM_RELU_MACRO(N, C, H, W, TYPE, DEVICE) \
static void BM_RELU_##N##C##H##W##_##TYPE##_##DEVICE(int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
ReluBenchmark<DEVICE, TYPE>(iters, N, C, H, W); \
} \
BENCHMARK(BM_RELU_##N##C##H##W##_##TYPE##_##DEVICE)
#define BM_RELU(N, C, H, W, TYPE) \
BM_RELU_MACRO(N, C, H, W, TYPE, CPU); \
BM_RELU_MACRO(N, C, H, W, TYPE, NEON); \
BM_RELU_MACRO(N, C, H, W, TYPE, OPENCL);
BM_RELU(1, 1, 512, 512, float);
BM_RELU(1, 3, 128, 128, float);
BM_RELU(1, 3, 512, 512, float);
BM_RELU(1, 32, 112, 112, float);
BM_RELU(1, 64, 256, 256, float);
} // namespace mace