Commit 1f499285 authored by 卢旭辉

Merge branch 'elu' into 'master'

Add ONNX Elu operator

See merge request applied-machine-learning/sysml/mace!1294
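For reference, ELU is the identity for non-negative inputs and alpha * (exp(x) - 1) for negative inputs, which is what the CPU ActivationWithAlpha branch and the OpenCL do_activation path in this diff compute. A minimal NumPy sketch of that semantics (the function name elu_reference is illustrative, not part of MACE):

import numpy as np

def elu_reference(x, alpha=1.0):
    # Identity for x >= 0, alpha * (exp(x) - 1) otherwise -- the same branch
    # added to ActivationWithAlpha (CPU) and do_activation (OpenCL) below.
    x = np.asarray(x, dtype=np.float32)
    return np.where(x >= 0, x, alpha * (np.exp(x) - 1.0))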
......@@ -21,6 +21,7 @@ Operator lists
"DEPTH_TO_SPACE","Y",""
"DEQUANTIZE","Y","Model quantization will be supported later."
"ELEMENT_WISE","Y","ADD/MUL/DIV/MIN/MAX/NEG/ABS/SQR_DIFF/POW/RSQRT/SQRT/EQUAL/FLOOR_DIV"
"ELU","Y",""
"EMBEDDING_LOOKUP","Y",""
"EXPANDDIMS","Y","Only CPU and TensorFlow is supported."
"FILL","Y","Only CPU and TensorFlow is supported."
......
......@@ -54,7 +54,7 @@ class ActivationOp<DeviceType::CPU, T> : public Operation {
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
- if (activation_type_ == PRELU) {
+ if (activation_type_ == PRELU || activation_type_ == ELU) {
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
const T *input_ptr = input->data<T>();
T *output_ptr = output->mutable_data<T>();
......@@ -63,8 +63,8 @@ class ActivationOp<DeviceType::CPU, T> : public Operation {
const T *alpha_ptr = alpha->data<T>();
const index_t outer_size = output->dim(0);
const index_t inner_size = output->dim(2) * output->dim(3);
- PReLUActivation(context, input_ptr, outer_size, input->dim(1), inner_size,
- alpha_ptr, output_ptr);
+ ActivationWithAlpha(context, input_ptr, outer_size, input->dim(1),
+ inner_size, alpha_ptr, activation_type_, output_ptr);
} else {
activation_delegator_->Compute(context, input, output);
}
......@@ -96,7 +96,7 @@ class ActivationOp<DeviceType::GPU, float> : public Operation {
} else {
MACE_NOT_IMPLEMENTED;
}
- if (type == ActivationType::PRELU) {
+ if (type == ActivationType::PRELU || type == ActivationType::ELU) {
MACE_CHECK(TransformFilter(
context, operator_def_.get(), 1, OpenCLBufferType::ARGUMENT, mem_type)
== MaceStatus::MACE_SUCCESS);
......
......@@ -42,6 +42,8 @@ inline ActivationType StringToActivationType(const std::string type) {
return ActivationType::NOOP;
} else if (type == "LEAKYRELU") {
return ActivationType::LEAKYRELU;
} else if (type == "ELU") {
return ActivationType::ELU;
} else {
LOG(FATAL) << "Unknown activation type: " << type;
}
......@@ -49,13 +51,14 @@ inline ActivationType StringToActivationType(const std::string type) {
}
template<typename T>
- void PReLUActivation(const OpContext *context,
-                      const T *input_ptr,
-                      const index_t outer_size,
-                      const index_t input_chan,
-                      const index_t inner_size,
-                      const T *alpha_ptr,
-                      T *output_ptr) {
+ void ActivationWithAlpha(const OpContext *context,
+                          const T *input_ptr,
+                          const index_t outer_size,
+                          const index_t input_chan,
+                          const index_t inner_size,
+                          const T *alpha_ptr,
+                          const index_t activation_type,
+                          T *output_ptr) {
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
......@@ -66,7 +69,12 @@ void PReLUActivation(const OpContext *context,
for (index_t j = 0; j < inner_size; ++j) {
index_t idx = i * input_chan * inner_size + chan_idx * inner_size + j;
if (input_ptr[idx] < 0) {
- output_ptr[idx] = input_ptr[idx] * alpha_ptr[chan_idx];
+ if (activation_type == ActivationType::PRELU) {
+   output_ptr[idx] = input_ptr[idx] * alpha_ptr[chan_idx];
+ } else if (activation_type == ActivationType::ELU) {
+   output_ptr[idx] =
+       (std::exp(input_ptr[idx]) - 1) * alpha_ptr[chan_idx];
+ }
} else {
output_ptr[idx] = input_ptr[idx];
}
......@@ -75,7 +83,6 @@ void PReLUActivation(const OpContext *context,
}
}, 0, outer_size, 1, 0, input_chan, 1);
}
} // namespace ops
} // namespace mace
......
......@@ -26,6 +26,7 @@ enum ActivationType {
TANH = 4,
SIGMOID = 5,
LEAKYRELU = 6,
ELU = 7,
};
} // namespace ops
......
......@@ -3,7 +3,7 @@
__kernel void activation(OUT_OF_RANGE_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
- #ifdef USE_PRELU
+ #if defined (USE_PRELU) || defined (USE_ELU)
__read_only image2d_t alpha,
#endif
__private const float relux_max_limit,
......@@ -23,9 +23,9 @@ __kernel void activation(OUT_OF_RANGE_PARAMS
const int pos = mad24(ch_blk, width, w);
DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb));
- #ifdef USE_PRELU
-   DATA_TYPE4 prelu_alpha = READ_IMAGET(alpha, SAMPLER, (int2)(ch_blk, 0));
-   DATA_TYPE4 out = do_activation(in, prelu_alpha, relux_max_limit, leakyrelu_coefficient);
+ #if defined (USE_PRELU) || defined (USE_ELU)
+   DATA_TYPE4 activation_alpha = READ_IMAGET(alpha, SAMPLER, (int2)(ch_blk, 0));
+   DATA_TYPE4 out = do_activation(in, activation_alpha, relux_max_limit, leakyrelu_coefficient);
#else
DATA_TYPE4 out = do_activation(in, relux_max_limit, leakyrelu_coefficient);
#endif
......
......@@ -83,8 +83,8 @@ inline float4 do_sigmoid(float4 in) {
#ifdef DATA_TYPE
inline DATA_TYPE4 do_activation(DATA_TYPE4 in,
- #ifdef USE_PRELU
-   DATA_TYPE4 prelu_alpha,
+ #if defined (USE_PRELU) || defined (USE_ELU)
+   DATA_TYPE4 alpha,
#endif
__private const float relux_max_limit,
__private const float leakyrelu_coefficient) {
......@@ -96,7 +96,10 @@ inline DATA_TYPE4 do_activation(DATA_TYPE4 in,
out = clamp(in, (DATA_TYPE4)0, relux_max_limit);
#endif
#ifdef USE_PRELU
- out = select(prelu_alpha * in, in, in >= (DATA_TYPE)0);
+ out = select(alpha * in, in, in >= (DATA_TYPE)0);
#endif
#ifdef USE_ELU
out = select(alpha * (native_exp(in) - 1.0f), in, in >= (DATA_TYPE)0);
#endif
#ifdef USE_TANH
out = tanh(in);
......
......@@ -58,6 +58,11 @@ MaceStatus ActivationKernel::Compute(
built_options.emplace("-DUSE_PRELU");
break;
}
case ELU: {
tuning_key_prefix_ = "elu_opencl_kernel";
built_options.emplace("-DUSE_ELU");
break;
}
case TANH: {
tuning_key_prefix_ = "tanh_opencl_kernel";
built_options.emplace("-DUSE_TANH");
......@@ -94,7 +99,7 @@ MaceStatus ActivationKernel::Compute(
MACE_OUT_OF_RANGE_SET_ARGS(kernel_);
MACE_SET_3D_GWS_ARGS(kernel_, gws);
kernel_.setArg(idx++, *(input->opencl_image()));
- if (activation_ == PRELU) {
+ if (activation_ == PRELU || activation_ == ELU) {
MACE_CHECK_NOTNULL(alpha);
kernel_.setArg(idx++, *(alpha->opencl_image()));
}
......
......@@ -161,6 +161,10 @@ MaceStatus WinogradOutputTransform(OpContext *context,
built_options.emplace("-DUSE_PRELU");
break;
}
case ELU: {
built_options.emplace("-DUSE_ELU");
break;
}
case TANH: {
built_options.emplace("-DUSE_TANH");
break;
......
......@@ -208,6 +208,70 @@ MACE_BM_PRELU(1, 3, 512, 512);
MACE_BM_PRELU(1, 32, 112, 112);
MACE_BM_PRELU(1, 64, 256, 256);
namespace {
template <DeviceType D, typename T>
void EluBenchmark(int iters, int batch, int channels, int height, int width) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
if (D == DeviceType::CPU) {
net.AddRandomInput<D, T>("Input", {batch, channels, height, width});
} else if (D == DeviceType::GPU) {
net.AddRandomInput<D, T>("Input", {batch, height, width, channels});
} else {
MACE_NOT_IMPLEMENTED;
}
net.AddRandomInput<D, T>("Alpha", {channels}, true);
OpDefBuilder("Activation", "EluBM")
.Input("Input")
.Input("Alpha")
.Output("Output")
.AddStringArg("activation", "ELU")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
} // namespace
#define MACE_BM_ELU_MACRO(N, C, H, W, TYPE, DEVICE) \
static void MACE_BM_ELU_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
EluBenchmark<DEVICE, TYPE>(iters, N, C, H, W); \
} \
MACE_BENCHMARK(MACE_BM_ELU_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
#ifdef MACE_ENABLE_OPENCL
#define MACE_BM_ELU(N, C, H, W) \
MACE_BM_ELU_MACRO(N, C, H, W, float, CPU); \
MACE_BM_ELU_MACRO(N, C, H, W, float, GPU); \
MACE_BM_ELU_MACRO(N, C, H, W, half, GPU)
#else
#define MACE_BM_ELU(N, C, H, W) \
MACE_BM_ELU_MACRO(N, C, H, W, float, CPU)
#endif
MACE_BM_ELU(1, 1, 512, 512);
MACE_BM_ELU(1, 3, 128, 128);
MACE_BM_ELU(1, 3, 512, 512);
MACE_BM_ELU(1, 32, 112, 112);
MACE_BM_ELU(1, 64, 256, 256);
namespace {
template <DeviceType D, typename T>
void TanhBenchmark(int iters, int batch, int channels, int height, int width) {
......
......@@ -235,6 +235,59 @@ TEST_F(ActivationOpTest, OPENCLSimplePrelu) {
TestSimplePrelu<DeviceType::GPU>();
}
namespace {
template <DeviceType D>
void TestSimpleElu() {
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>(
"Input", {2, 2, 2, 2},
{-7, 7, -6, 6, -5, -5, -4, -4, -3, 3, -2, 2, -1, -1, 0, 0});
net.AddInputFromArray<D, float>("Alpha", {2}, {2.0, 3.0}, true);
if (D == DeviceType::GPU) {
OpDefBuilder("Activation", "EluTest")
.Input("Input")
.Input("Alpha")
.Output("Output")
.AddStringArg("activation", "ELU")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
} else {
net.TransformDataFormat<D, float>(
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Activation", "EluTest")
.Input("InputNCHW")
.Input("Alpha")
.Output("OutputNCHW")
.AddStringArg("activation", "ELU")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
net.TransformDataFormat<D, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
}
auto expected = net.CreateTensor<float>(
{2, 2, 2, 2},
{-1.998176236068891, 7, -1.9950424956466672, 6, -1.986524106001829,
-2.9797861590027437, -1.9633687222225316, -2.9450530833337973,
-1.900425863264272, 3, -1.7293294335267746, 2, -1.2642411176571153,
-1.896361676485673, 0, 0});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
} // namespace
TEST_F(ActivationOpTest, CPUSimpleElu) { TestSimpleElu<DeviceType::CPU>(); }
TEST_F(ActivationOpTest, OPENCLSimpleElu) {
TestSimpleElu<DeviceType::GPU>();
}
namespace {
template <DeviceType D>
void TestSimpleTanh() {
......
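The expected values in TestSimpleElu follow directly from the ELU formula with the per-channel alphas {2.0, 3.0}: for example, 2 * (exp(-7) - 1) ≈ -1.998176236068891 for the first element. A small NumPy sketch (illustrative only) that reproduces the expected tensor by broadcasting Alpha over the NHWC channel axis:

import numpy as np

x = np.array([-7, 7, -6, 6, -5, -5, -4, -4,
              -3, 3, -2, 2, -1, -1, 0, 0], dtype=np.float64).reshape(2, 2, 2, 2)
alpha = np.array([2.0, 3.0])  # one alpha per channel, broadcast over the last (NHWC) axis
expected = np.where(x >= 0, x, alpha * (np.exp(x) - 1.0))
# expected.flat[0] == 2 * (exp(-7) - 1) ≈ -1.998176236068891, as in the test above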
......@@ -49,6 +49,7 @@ class ActivationType(Enum):
TANH = 4
SIGMOID = 5
LEAKYRELU = 6
ELU = 7
class EltwiseType(Enum):
......
......@@ -85,7 +85,7 @@ OnnxSupportedOps = [
'Div',
'Dropout',
'DynamicLSTM',
- # 'Elu',
+ 'Elu',
'Equal',
# 'Exp',
# 'Expand',
......@@ -323,6 +323,7 @@ class OnnxConverter(base_converter.ConverterInterface):
OnnxOpType.Relu.name: ActivationType.RELU,
OnnxOpType.LeakyRelu.name: ActivationType.LEAKYRELU,
OnnxOpType.PRelu.name: ActivationType.PRELU,
OnnxOpType.Elu.name: ActivationType.ELU,
OnnxOpType.Tanh.name: ActivationType.TANH,
OnnxOpType.Sigmoid.name: ActivationType.SIGMOID,
}
......@@ -348,6 +349,7 @@ class OnnxConverter(base_converter.ConverterInterface):
OnnxOpType.Dropout.name: self.convert_dropout,
OnnxOpType.DimRange.name: self.convert_dim_range,
OnnxOpType.Div.name: self.convert_eltwise,
OnnxOpType.Elu.name: self.convert_activation,
OnnxOpType.Equal.name: self.convert_eltwise,
OnnxOpType.ExtractPooling.name: self.convert_extract_pooling,
OnnxOpType.Flatten.name: self.convert_flatten,
......@@ -627,7 +629,11 @@ class OnnxConverter(base_converter.ConverterInterface):
type_arg.s = six.b(self.activation_type[node.op_type].name)
if "alpha" in node.attrs:
- alpha_value = node.attrs["alpha"]
+ alpha_tensor_name = node.name + '_alpha'
+ alpha_value = np.array([node.attrs["alpha"]])
+ self.add_tensor(alpha_tensor_name, alpha_value.reshape(-1).shape,
+                 mace_pb2.DT_FLOAT, alpha_value)
+ op.input.extend([alpha_tensor_name])
else:
if node.op_type == OnnxOpType.LeakyRelu.name:
alpha_value = 0.01
......
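For context on the converter change above: ONNX's Elu carries alpha as a node attribute, while the MACE Activation op built here takes it as an additional constant input, so convert_activation materializes the attribute into a '<node name>_alpha' tensor and appends it to op.input. A hedged sketch of the kind of node this path consumes, built with onnx.helper (the tensor names and alpha value are illustrative):

from onnx import helper

# ONNX stores alpha as an attribute on the Elu node; the converter turns it
# into a constant "<node.name>_alpha" tensor used as the op's Alpha input.
elu_node = helper.make_node("Elu", inputs=["X"], outputs=["Y"],
                            name="elu0", alpha=2.0)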
......@@ -977,7 +977,10 @@ class Transformer(base_converter.ConverterInterface):
[ActivationType.RELU.name,
ActivationType.RELUX.name])
else:
- fold_consumer = (act_type != ActivationType.PRELU.name)
+ fold_consumer = (
+     act_type != ActivationType.PRELU.name
+     and act_type != ActivationType.ELU.name
+ )
# during quantization, only fold relu/relux
if (self._option.quantize_stat or self._option.quantize) \
and act_type not in [ActivationType.RELU.name,
......