Commit 9e6ee11c authored by 李超

Merge branch 'elu' into 'master'

feat: add ELU for TensorFlow and Caffe and fix bugs in the reverse op and net adapter

See merge request applied-machine-learning/sysml/mace!1309
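For reference (not part of the MR description): the ELU wired through the CPU and OpenCL paths below follows the standard exponential-linear-unit definition, where alpha is the coefficient this diff renames from leakyrelu_coefficient to the generic activation_coefficient:

    f(x) = \begin{cases} x, & x \ge 0 \\ \alpha \, (e^{x} - 1), & x < 0 \end{cases}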
......@@ -199,11 +199,11 @@ MaceStatus NetDefAdapter::AdaptNetDef(
input_data_format, input_shape, -1));
}
OpConditionContext context(ws_, &tensor_shape_map);
DataFormat op_output_data_format;
MemoryType op_output_mem_type;
for (int idx = 0; idx < net_def->op_size(); ++idx) {
OperatorDef op_def(net_def->op(idx));
OpConditionContext context(ws_, &tensor_shape_map);
context.set_operator_def(&op_def);
// Select device
MACE_RETURN_IF_ERROR(this->AdaptDevice(&context,
......
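The net-adapter hunk above is the "fix bug in ... net adapter" part of this MR: the OpConditionContext now appears to be constructed inside the per-op loop rather than once before it, so each OperatorDef gets a fresh context instead of reusing state left over from the previous op. A minimal sketch of the resulting loop shape, using the names from the hunk (surrounding adaptation steps elided):

    for (int idx = 0; idx < net_def->op_size(); ++idx) {
      OperatorDef op_def(net_def->op(idx));
      OpConditionContext context(ws_, &tensor_shape_map);  // rebuilt per op
      context.set_operator_def(&op_def);
      // device selection and data-format / memory-type adaptation follow
    }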
......@@ -47,14 +47,14 @@ class ActivationOp<DeviceType::CPU, T> : public Operation {
activation_type_,
Operation::GetOptionalArg<float>("max_limit", 0.f),
Operation::GetOptionalArg<float>(
"leakyrelu_coefficient", 0.f)))) {}
"activation_coefficient", 0.f)))) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
if (activation_type_ == PRELU || activation_type_ == ELU) {
if (activation_type_ == PRELU) {
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
const T *input_ptr = input->data<T>();
T *output_ptr = output->mutable_data<T>();
......@@ -63,8 +63,8 @@ class ActivationOp<DeviceType::CPU, T> : public Operation {
const T *alpha_ptr = alpha->data<T>();
const index_t outer_size = output->dim(0);
const index_t inner_size = output->dim(2) * output->dim(3);
ActivationWithAlpha(context, input_ptr, outer_size, input->dim(1),
inner_size, alpha_ptr, activation_type_, output_ptr);
PReLUActivation(context, input_ptr, outer_size, input->dim(1),
inner_size, alpha_ptr, output_ptr);
} else {
activation_delegator_->Compute(context, input, output);
}
......@@ -86,17 +86,17 @@ class ActivationOp<DeviceType::GPU, float> : public Operation {
Operation::GetOptionalArg<std::string>("activation",
"NOOP"));
auto relux_max_limit = Operation::GetOptionalArg<float>("max_limit", 0.0f);
auto leakyrelu_coefficient =
Operation::GetOptionalArg<float>("leakyrelu_coefficient", 0.0f);
auto activation_coefficient =
Operation::GetOptionalArg<float>("activation_coefficient", 0.0f);
MemoryType mem_type;
if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
mem_type = MemoryType::GPU_IMAGE;
kernel_ = make_unique<opencl::image::ActivationKernel>(
type, relux_max_limit, leakyrelu_coefficient);
type, relux_max_limit, activation_coefficient);
} else {
MACE_NOT_IMPLEMENTED;
}
if (type == ActivationType::PRELU || type == ActivationType::ELU) {
if (type == ActivationType::PRELU) {
MACE_CHECK(TransformFilter(
context, operator_def_.get(), 1, OpenCLBufferType::ARGUMENT, mem_type)
== MaceStatus::MACE_SUCCESS);
......
......@@ -51,14 +51,13 @@ inline ActivationType StringToActivationType(const std::string type) {
}
template<typename T>
void ActivationWithAlpha(const OpContext *context,
const T *input_ptr,
const index_t outer_size,
const index_t input_chan,
const index_t inner_size,
const T *alpha_ptr,
const index_t activation_type,
T *output_ptr) {
void PReLUActivation(const OpContext *context,
const T *input_ptr,
const index_t outer_size,
const index_t input_chan,
const index_t inner_size,
const T *alpha_ptr,
T *output_ptr) {
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
......@@ -69,12 +68,7 @@ void ActivationWithAlpha(const OpContext *context,
for (index_t j = 0; j < inner_size; ++j) {
index_t idx = i * input_chan * inner_size + chan_idx * inner_size + j;
if (input_ptr[idx] < 0) {
if (activation_type == ActivationType::PRELU) {
output_ptr[idx] = input_ptr[idx] * alpha_ptr[chan_idx];
} else if (activation_type == ActivationType::ELU) {
output_ptr[idx] =
(std::exp(input_ptr[idx]) - 1) * alpha_ptr[chan_idx];
}
output_ptr[idx] = input_ptr[idx] * alpha_ptr[chan_idx];
} else {
output_ptr[idx] = input_ptr[idx];
}
......
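With ELU removed from this branch, the helper returns to being PReLU-only (CPU ELU is handled by the activation delegator added further down). A hypothetical standalone reference for the per-channel rule PReLUActivation keeps applying, with alpha indexed by channel exactly as alpha_ptr[chan_idx] is in the loop above:

    using index_t = int64_t;  // assumed to match mace::index_t

    template <typename T>
    T prelu_ref(T x, const T *alpha, index_t chan_idx) {
      return x < 0 ? x * alpha[chan_idx] : x;
    }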
......@@ -81,6 +81,11 @@ void Activation<T>::DoActivation(const OpContext *context,
break;
}
case ELU: {
ActivateElu(&thread_pool, input, output);
break;
}
case NOOP: {
break;
}
......@@ -164,7 +169,7 @@ void Activation<T>::ActivateLeakyRelu(utils::ThreadPool *thread_pool,
auto output_data = output->mutable_data<T>();
const index_t input_size = input->size();
const float32x4_t vzero = vdupq_n_f32(0.f);
const float32x4_t valpha = vdupq_n_f32(leakyrelu_coefficient_);
const float32x4_t valpha = vdupq_n_f32(activation_coefficient_);
const index_t block_count = input_size / 4;
thread_pool->Compute1D(
......@@ -188,7 +193,7 @@ void Activation<T>::ActivateLeakyRelu(utils::ThreadPool *thread_pool,
// remain
for (index_t i = block_count * 4; i < input_size; ++i) {
output_data[i] = std::max(input_data[i], 0.f) +
std::min(input_data[i], 0.f) * leakyrelu_coefficient_;
std::min(input_data[i], 0.f) * activation_coefficient_;
}
}
......@@ -226,6 +231,28 @@ void Activation<T>::ActivateSigmoid(utils::ThreadPool *thread_pool,
0, input_size, 1);
}
template<typename T>
void Activation<T>::ActivateElu(utils::ThreadPool *thread_pool,
const Tensor *input,
Tensor *output) {
const auto *input_data = input->data<T>();
auto *output_data = output->mutable_data<T>();
const index_t input_size = input->size();
thread_pool->Compute1D(
[=](index_t start, index_t end, index_t step) {
for (index_t i = start; i < end; i += step) {
const auto in_val = input_data[i];
if (in_val < 0) {
output_data[i] = (std::exp(in_val) - 1) * activation_coefficient_;
} else {
output_data[i] = in_val;
}
}
},
0, input_size, 1);
}
void RegisterActivationDelegator(OpDelegatorRegistry *registry) {
MACE_REGISTER_DELEGATOR(
registry, Activation<float>, delegator::ActivationParam,
......@@ -240,7 +267,7 @@ void RegisterActivationDelegator(OpDelegatorRegistry *registry) {
MACE_REGISTER_BF16_DELEGATOR(
registry, Activation<BFloat16>, delegator::ActivationParam,
MACE_DELEGATOR_KEY(Activation, DeviceType::CPU, BFloat16,
ImplType::NEON));
ImplType::NEON));
}
} // namespace arm
......
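Note that the new ActivateElu uses a plain scalar loop, unlike ActivateLeakyRelu above, which has a NEON float32x4 fast path. A minimal, hypothetical reference for its loop body, handy for spot-checking values:

    #include <cmath>

    // Standalone reference for the ActivateElu loop body added above.
    inline float elu_ref(float x, float coefficient) {
      return x < 0.f ? (std::exp(x) - 1.f) * coefficient : x;
    }
    // e.g. with coefficient = 1.0f: elu_ref(2.0f, 1.0f) == 2.0f,
    // elu_ref(-1.0f, 1.0f) = exp(-1) - 1 ≈ -0.632f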
......@@ -45,6 +45,8 @@ class Activation : public delegator::Activation {
Tensor *output);
void ActivateSigmoid(utils::ThreadPool *thread_pool, const Tensor *input,
Tensor *output);
void ActivateElu(utils::ThreadPool *thread_pool, const Tensor *input,
Tensor *output);
};
} // namespace arm
......
......@@ -123,6 +123,16 @@ void Activation<uint8_t>::ActivateSigmoid(utils::ThreadPool *thread_pool,
MACE_NOT_IMPLEMENTED;
}
template<>
void Activation<uint8_t>::ActivateElu(utils::ThreadPool *thread_pool,
const Tensor *input,
Tensor *output) {
MACE_UNUSED(thread_pool);
MACE_UNUSED(input);
MACE_UNUSED(output);
MACE_NOT_IMPLEMENTED;
}
} // namespace arm
} // namespace ops
} // namespace mace
......@@ -50,7 +50,7 @@ class BatchNormOp<DeviceType::CPU, T> : public Operation {
Operation::GetOptionalArg<std::string>("activation",
"NOOP")),
Operation::GetOptionalArg<float>("max_limit", 0.0f),
Operation::GetOptionalArg<float>("leakyrelu_coefficient",
Operation::GetOptionalArg<float>("activation_coefficient",
0.0f)))) {}
MaceStatus Run(OpContext *context) override {
......@@ -168,13 +168,13 @@ class BatchNormOp<DeviceType::GPU, float> : public Operation {
ActivationType activation = ops::StringToActivationType(
Operation::GetOptionalArg<std::string>("activation", "NOOP"));
float relux_max_limit = Operation::GetOptionalArg<float>("max_limit", 0.0f);
float leakyrelu_coefficient = Operation::GetOptionalArg<float>(
"leakyrelu_coefficient", 0.0f);
float activation_coefficient = Operation::GetOptionalArg<float>(
"activation_coefficient", 0.0f);
MemoryType mem_type;
if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
mem_type = MemoryType::GPU_IMAGE;
kernel_ = make_unique<opencl::image::BatchNormKernel>(
epsilon, activation, relux_max_limit, leakyrelu_coefficient);
epsilon, activation, relux_max_limit, activation_coefficient);
} else {
MACE_NOT_IMPLEMENTED;
}
......
......@@ -68,7 +68,7 @@ class Conv2dOp<DeviceType::CPU, T> : public ConvPool2dOpBase {
Operation::GetOptionalArg<std::string>("activation",
"NOOP")),
Operation::GetOptionalArg<float>("max_limit", 0.0f),
Operation::GetOptionalArg<float>("leakyrelu_coefficient",
Operation::GetOptionalArg<float>("activation_coefficient",
0.0f)))),
bias_add_delegator_(delegator::BiasAdd::Create(
context->workspace(),
......@@ -190,8 +190,8 @@ class Conv2dOp<DeviceType::CPU, uint8_t> : public ConvPool2dOpBase {
Operation::GetOptionalArg<std::string>("activation",
"NOOP"))),
relux_max_limit_(Operation::GetOptionalArg<float>("max_limit", 0.0f)),
leakyrelu_coefficient_(Operation::GetOptionalArg<float>(
"leakyrelu_coefficient", 0.0f)) {}
activation_coefficient_(Operation::GetOptionalArg<float>(
"activation_coefficient", 0.0f)) {}
MaceStatus Run(OpContext *context) override {
const Tensor *input = this->Input(INPUT);
......@@ -414,7 +414,7 @@ class Conv2dOp<DeviceType::CPU, uint8_t> : public ConvPool2dOpBase {
private:
const ActivationType activation_;
const float relux_max_limit_;
const float leakyrelu_coefficient_;
const float activation_coefficient_;
std::vector<int32_t> bias_;
private:
......@@ -433,8 +433,8 @@ class Conv2dOp<DeviceType::GPU, float> : public ConvPool2dOpBase {
Operation::GetOptionalArg<std::string>("activation",
"NOOP"))),
relux_max_limit_(Operation::GetOptionalArg<float>("max_limit", 0.0f)),
leakyrelu_coefficient_(Operation::GetOptionalArg<float>(
"leakyrelu_coefficient", 0.0f)),
activation_coefficient_(Operation::GetOptionalArg<float>(
"activation_coefficient", 0.0f)),
wino_block_size_(Operation::GetOptionalArg<int>("wino_block_size", 0)) {
MemoryType mem_type;
if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
......@@ -488,13 +488,13 @@ class Conv2dOp<DeviceType::GPU, float> : public ConvPool2dOpBase {
return kernel_->Compute(context, input, filter, bias,
strides_.data(), padding_type_, paddings_,
dilations_.data(), activation_, relux_max_limit_,
leakyrelu_coefficient_, wino_block_size_, output);
activation_coefficient_, wino_block_size_, output);
}
private:
const ActivationType activation_;
const float relux_max_limit_;
const float leakyrelu_coefficient_;
const float activation_coefficient_;
std::unique_ptr<OpenCLConv2dKernel> kernel_;
int wino_block_size_;
......
......@@ -56,8 +56,8 @@ class Deconv2dOp<DeviceType::CPU, T> : public Deconv2dOpBase {
context->workspace(),
MACE_DELEGATOR_KEY(Activation, DeviceType::CPU,
T, kCpuImplType),
delegator::ActivationParam(activation_, relux_max_limit_,
leakyrelu_coefficient_))),
delegator::ActivationParam(
activation_, relux_max_limit_, activation_coefficient_))),
bias_add_delegator_(delegator::BiasAdd::Create(
context->workspace(),
MACE_DELEGATOR_KEY(BiasAdd, DeviceType::CPU, T, kCpuImplType),
......@@ -228,7 +228,7 @@ class Deconv2dOp<DeviceType::GPU, float> : public Deconv2dOpBase {
return kernel_->Compute(context, input, filter, bias,
strides_.data(), in_paddings.data(), activation_,
relux_max_limit_, leakyrelu_coefficient_,
relux_max_limit_, activation_coefficient_,
out_shape, output);
}
......
......@@ -43,8 +43,8 @@ class Deconv2dOpBase : public Operation {
"NOOP"))),
relux_max_limit_(
Operation::GetOptionalArg<float>("max_limit", 0.0f)),
leakyrelu_coefficient_(
Operation::GetOptionalArg<float>("leakyrelu_coefficient", 0.0f)) {}
activation_coefficient_(
Operation::GetOptionalArg<float>("activation_coefficient", 0.0f)) {}
protected:
std::vector<int> strides_; // [stride_h, stride_w]
......@@ -54,7 +54,7 @@ class Deconv2dOpBase : public Operation {
const FrameworkType model_type_;
const ActivationType activation_;
const float relux_max_limit_;
const float leakyrelu_coefficient_;
const float activation_coefficient_;
};
} // namespace ops
......
......@@ -26,20 +26,20 @@ namespace delegator {
struct ActivationParam : public DelegatorParam {
explicit ActivationParam(ActivationType type, const float limit,
const float leakyrelu_coefficient)
const float activation_coefficient)
: type_(type), limit_(limit),
leakyrelu_coefficient_(leakyrelu_coefficient) {}
activation_coefficient_(activation_coefficient) {}
ActivationType type_;
const float limit_;
const float leakyrelu_coefficient_;
const float activation_coefficient_;
};
class Activation : public OpDelegator {
public:
explicit Activation(const ActivationParam &param)
: OpDelegator(param), type_(param.type_), limit_(param.limit_),
leakyrelu_coefficient_(param.leakyrelu_coefficient_) {}
activation_coefficient_(param.activation_coefficient_) {}
virtual ~Activation() = default;
MACE_DEFINE_DELEGATOR_CREATOR(Activation)
......@@ -51,7 +51,7 @@ class Activation : public OpDelegator {
protected:
ActivationType type_;
const float limit_;
const float leakyrelu_coefficient_;
const float activation_coefficient_;
};
} // namespace delegator
......
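Because ActivationParam's third field is now the generic activation_coefficient, every CPU op in this diff builds its delegator the same way regardless of whether the fused activation is LEAKYRELU or ELU. The construction pattern repeated in the conv, deconv, depthwise, and fully-connected ops looks like:

    activation_delegator_(delegator::Activation::Create(
        context->workspace(),
        MACE_DELEGATOR_KEY(Activation, DeviceType::CPU, T, kCpuImplType),
        delegator::ActivationParam(
            activation_, relux_max_limit_, activation_coefficient_)))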
......@@ -52,12 +52,12 @@ class DepthwiseConv2dOpBase : public ConvPool2dOpBase {
Operation::GetOptionalArg<std::string>("activation",
"NOOP"))),
relux_max_limit_(Operation::GetOptionalArg<float>("max_limit", 0.0f)),
leakyrelu_coefficient_(Operation::GetOptionalArg<float>(
"leakyrelu_coefficient", 0.0f)) {}
activation_coefficient_(Operation::GetOptionalArg<float>(
"activation_coefficient", 0.0f)) {}
protected:
const ActivationType activation_;
const float relux_max_limit_;
const float leakyrelu_coefficient_;
const float activation_coefficient_;
};
template<DeviceType D, class T>
......@@ -73,8 +73,8 @@ class DepthwiseConv2dOp<DeviceType::CPU, T> : public DepthwiseConv2dOpBase {
context->workspace(),
MACE_DELEGATOR_KEY(Activation, DeviceType::CPU,
T, kCpuImplType),
delegator::ActivationParam(activation_, relux_max_limit_,
leakyrelu_coefficient_))),
delegator::ActivationParam(
activation_, relux_max_limit_, activation_coefficient_))),
bias_add_delegator_(delegator::BiasAdd::Create(
context->workspace(),
MACE_DELEGATOR_KEY(BiasAdd, DeviceType::CPU, T, kCpuImplType),
......@@ -389,7 +389,7 @@ class DepthwiseConv2dOp<DeviceType::GPU, float> :
return kernel_->Compute(context, input, filter, bias,
strides_.data(), padding_type_, paddings_,
dilations_.data(), activation_, relux_max_limit_,
leakyrelu_coefficient_, output);
activation_coefficient_, output);
}
private:
......
......@@ -55,8 +55,8 @@ class DepthwiseDeconv2dOp<DeviceType::CPU, T>
context->workspace(),
MACE_DELEGATOR_KEY(Activation, DeviceType::CPU,
T, kCpuImplType),
delegator::ActivationParam(activation_, relux_max_limit_,
leakyrelu_coefficient_))),
delegator::ActivationParam(
activation_, relux_max_limit_, activation_coefficient_))),
bias_add_delegator_(delegator::BiasAdd::Create(
context->workspace(),
MACE_DELEGATOR_KEY(BiasAdd, DeviceType::CPU, T, kCpuImplType),
......@@ -209,7 +209,7 @@ class DepthwiseDeconv2dOp<DeviceType::GPU, float> : public Deconv2dOpBase {
group_,
activation_,
relux_max_limit_,
leakyrelu_coefficient_,
activation_coefficient_,
out_shape,
output);
}
......
......@@ -42,12 +42,12 @@ class FullyConnectedOpBase : public Operation {
Operation::GetOptionalArg<std::string>("activation",
"NOOP"))),
relux_max_limit_(Operation::GetOptionalArg<float>("max_limit", 0.0f)),
leakyrelu_coefficient_(Operation::GetOptionalArg<float>(
"leakyrelu_coefficient", 0.0f)) {}
activation_coefficient_(Operation::GetOptionalArg<float>(
"activation_coefficient", 0.0f)) {}
protected:
const ActivationType activation_;
const float relux_max_limit_;
const float leakyrelu_coefficient_;
const float activation_coefficient_;
MACE_OP_INPUT_TAGS(INPUT, WEIGHT, BIAS);
MACE_OP_OUTPUT_TAGS(OUTPUT);
......@@ -64,9 +64,8 @@ class FullyConnectedOp<DeviceType::CPU, T> : public FullyConnectedOpBase {
activation_delegator_(delegator::Activation::Create(
context->workspace(),
MACE_DELEGATOR_KEY(Activation, DeviceType::CPU, T, kCpuImplType),
delegator::ActivationParam(activation_,
relux_max_limit_,
leakyrelu_coefficient_))),
delegator::ActivationParam(
activation_, relux_max_limit_, activation_coefficient_))),
gemv_(delegator::Gemv::Create(
context->workspace(),
MACE_DELEGATOR_KEY(Gemv, DeviceType::CPU, T, kCpuImplType),
......@@ -215,7 +214,7 @@ class FullyConnectedOp<DeviceType::GPU, float> : public FullyConnectedOpBase {
" don't match.");
return kernel_->Compute(
context, input, weight, bias, activation_, relux_max_limit_,
leakyrelu_coefficient_, output);
activation_coefficient_, output);
}
private:
......
......@@ -46,7 +46,7 @@ MaceStatus Conv2dKernel::Compute(
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
const int winograd_blk_size,
Tensor *output) {
MACE_UNUSED(winograd_blk_size);
......@@ -148,14 +148,14 @@ MaceStatus Conv2dKernel::Compute(
return conv2d::Conv2d1x1(
context, &kernels_[1], pad_input, filter, bias, strides,
activation, relux_max_limit,
leakyrelu_coefficient, input_changed, output, &conv_future);
activation_coefficient, input_changed, output, &conv_future);
};
} else {
conv_func = [&](const Tensor *pad_input, Tensor *output) -> MaceStatus {
return conv2d::Conv2dGeneral(
context, &kernels_[1], pad_input, filter, bias, strides, dilations,
activation, relux_max_limit,
leakyrelu_coefficient, input_changed, output, &conv_future);
activation_coefficient, input_changed, output, &conv_future);
};
}
MACE_RETURN_IF_ERROR(conv_func(padded_input_ptr, output));
......
......@@ -38,7 +38,7 @@ extern MaceStatus Conv2d1x1(OpContext *context,
const int *strides,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
const bool input_changed,
Tensor *output,
StatsFuture *future);
......@@ -52,7 +52,7 @@ extern MaceStatus Conv2dGeneral(OpContext *context,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
const bool input_changed,
Tensor *output,
StatsFuture *future);
......@@ -81,7 +81,7 @@ class Conv2dKernel : public OpenCLConv2dKernel {
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
const int winograd_blk_size,
Tensor *output) override;
......
......@@ -31,7 +31,7 @@ MaceStatus Conv2d1x1(OpContext *context,
const int *strides,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
const bool input_changed,
Tensor *output,
StatsFuture *future) {
......@@ -75,6 +75,9 @@ MaceStatus Conv2d1x1(OpContext *context,
case LEAKYRELU:
built_options.emplace("-DUSE_LEAKYRELU");
break;
case ELU:
built_options.emplace("-DUSE_ELU");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation;
}
......@@ -110,7 +113,7 @@ MaceStatus Conv2d1x1(OpContext *context,
kernel->setArg(idx++, strides[0]);
kernel->setArg(idx++, strides[1]);
kernel->setArg(idx++, relux_max_limit);
kernel->setArg(idx++, leakyrelu_coefficient);
kernel->setArg(idx++, activation_coefficient);
kernel->setArg(idx++, *(output->opencl_buffer()));
}
......
......@@ -32,7 +32,7 @@ MaceStatus Conv2dGeneral(OpContext *context,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
const bool input_changed,
Tensor *output,
StatsFuture *future) {
......@@ -81,6 +81,9 @@ MaceStatus Conv2dGeneral(OpContext *context,
case LEAKYRELU:
built_options.emplace("-DUSE_LEAKYRELU");
break;
case ELU:
built_options.emplace("-DUSE_ELU");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation;
}
......@@ -125,7 +128,7 @@ MaceStatus Conv2dGeneral(OpContext *context,
kernel->setArg(idx++, static_cast<int32_t>(
dilations[1] * in_channel));
kernel->setArg(idx++, relux_max_limit);
kernel->setArg(idx++, leakyrelu_coefficient);
kernel->setArg(idx++, activation_coefficient);
kernel->setArg(idx++, *(output->opencl_buffer()));
}
......
......@@ -32,7 +32,7 @@ MaceStatus DepthwiseConv2d(OpContext *context,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
const bool input_changed,
Tensor *output,
StatsFuture *future) {
......@@ -79,6 +79,9 @@ MaceStatus DepthwiseConv2d(OpContext *context,
case LEAKYRELU:
built_options.emplace("-DUSE_LEAKYRELU");
break;
case ELU:
built_options.emplace("-DUSE_ELU");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation;
}
......@@ -119,7 +122,7 @@ MaceStatus DepthwiseConv2d(OpContext *context,
kernel->setArg(idx++, static_cast<int32_t>(
dilations[1] * in_channel));
kernel->setArg(idx++, relux_max_limit);
kernel->setArg(idx++, leakyrelu_coefficient);
kernel->setArg(idx++, activation_coefficient);
kernel->setArg(idx++, *(output->opencl_buffer()));
}
......@@ -147,7 +150,7 @@ MaceStatus DepthwiseConv2dKernel::Compute(
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
Tensor *output) {
StatsFuture pad_future, dw_conv_future;
index_t filter_w = filter->dim(3);
......@@ -242,7 +245,7 @@ MaceStatus DepthwiseConv2dKernel::Compute(
depthwise::DepthwiseConv2d(
context, &kernels_[1], padded_input_ptr, filter, bias, strides,
dilations, activation, relux_max_limit,
leakyrelu_coefficient, input_changed, output, &dw_conv_future));
activation_coefficient, input_changed, output, &dw_conv_future));
MergeMultipleFutureWaitFn({pad_future, dw_conv_future}, context->future());
return MaceStatus::MACE_SUCCESS;
}
......
......@@ -39,7 +39,7 @@ MaceStatus DepthwiseConv2d(OpContext *context,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
const bool input_changed,
Tensor *output,
StatsFuture *future);
......@@ -59,7 +59,7 @@ class DepthwiseConv2dKernel : public OpenCLDepthwiseConv2dKernel {
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
Tensor *output) override;
private:
......
......@@ -3,11 +3,11 @@
__kernel void activation(OUT_OF_RANGE_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
#if defined (USE_PRELU) || defined (USE_ELU)
#ifdef USE_PRELU
__read_only image2d_t alpha,
#endif
__private const float relux_max_limit,
__private const float leakyrelu_coefficient,
__private const float coefficient,
__write_only image2d_t output) {
const int ch_blk = get_global_id(0);
const int w = get_global_id(1);
......@@ -23,11 +23,11 @@ __kernel void activation(OUT_OF_RANGE_PARAMS
const int pos = mad24(ch_blk, width, w);
DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb));
#if defined (USE_PRELU) || defined (USE_ELU)
#ifdef USE_PRELU
DATA_TYPE4 activation_alpha = READ_IMAGET(alpha, SAMPLER, (int2)(ch_blk, 0));
DATA_TYPE4 out = do_activation(in, activation_alpha, relux_max_limit, leakyrelu_coefficient);
DATA_TYPE4 out = do_activation(in, activation_alpha, relux_max_limit, coefficient);
#else
DATA_TYPE4 out = do_activation(in, relux_max_limit, leakyrelu_coefficient);
DATA_TYPE4 out = do_activation(in, relux_max_limit, coefficient);
#endif
WRITE_IMAGET(output, (int2)(pos, hb), out);
......
......@@ -12,7 +12,7 @@ __kernel void batch_norm(OUT_OF_RANGE_PARAMS
#endif
__write_only image2d_t output,
__private const float relux_max_limit,
__private const float leakyrelu_coefficient) {
__private const float activation_coefficient) {
const int ch_blk = get_global_id(0);
const int w = get_global_id(1);
const int hb = get_global_id(2);
......@@ -44,8 +44,8 @@ __kernel void batch_norm(OUT_OF_RANGE_PARAMS
DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb));
DATA_TYPE4 out = mad(in, bn_scale, bn_offset);
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
out = do_activation(out, relux_max_limit, leakyrelu_coefficient);
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) || defined(USE_ELU)
out = do_activation(out, relux_max_limit, activation_coefficient);
#endif
WRITE_IMAGET(output, (int2)(pos, hb), out);
......
......@@ -83,11 +83,11 @@ inline float4 do_sigmoid(float4 in) {
#ifdef DATA_TYPE
inline DATA_TYPE4 do_activation(DATA_TYPE4 in,
#if defined (USE_PRELU) || defined (USE_ELU)
#if defined (USE_PRELU)
DATA_TYPE4 alpha,
#endif
__private const float relux_max_limit,
__private const float leakyrelu_coefficient) {
__private const float activation_coefficient) {
DATA_TYPE4 out;
#ifdef USE_RELU
out = fmax(in, (DATA_TYPE)0);
......@@ -99,7 +99,8 @@ inline DATA_TYPE4 do_activation(DATA_TYPE4 in,
out = select(alpha * in, in, in >= (DATA_TYPE)0);
#endif
#ifdef USE_ELU
out = select(alpha * (native_exp(in) - 1.0f), in, in >= (DATA_TYPE)0);
out = select(activation_coefficient * (native_exp(in) - 1.0f),
in, in >= (DATA_TYPE)0);
#endif
#ifdef USE_TANH
out = tanh(in);
......@@ -108,7 +109,7 @@ inline DATA_TYPE4 do_activation(DATA_TYPE4 in,
out = do_sigmoid(in);
#endif
#ifdef USE_LEAKYRELU
out = select(leakyrelu_coefficient * in, in, in >= (DATA_TYPE)0);
out = select(activation_coefficient * in, in, in >= (DATA_TYPE)0);
#endif
return out;
}
......
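On the OpenCL paths, ELU shares the scalar activation_coefficient kernel argument instead of reading a per-channel alpha image, and it is switched on at kernel-build time with the new -DUSE_ELU define. The host-side pattern repeated across the conv2d, deconv2d, depthwise, and fully_connected kernels in this diff is, as a sketch:

    switch (activation) {
      // NOOP / RELU / RELUX / TANH / SIGMOID / LEAKYRELU cases elided
      case ELU:
        built_options.emplace("-DUSE_ELU");
        break;
      default:
        LOG(FATAL) << "Unknown activation type: " << activation;
    }
    // later, when the kernel arguments are set:
    kernel->setArg(idx++, relux_max_limit);
    kernel->setArg(idx++, activation_coefficient);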
......@@ -9,7 +9,7 @@ __kernel void conv_2d(OUT_OF_RANGE_PARAMS
#endif
__write_only image2d_t output,
__private const float relux_max_limit,
__private const float leakyrelu_coefficient,
__private const float activation_coefficient,
__private const int in_height,
__private const int in_width,
__private const int in_ch_blks,
......@@ -125,11 +125,11 @@ __kernel void conv_2d(OUT_OF_RANGE_PARAMS
}
}
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
out0 = do_activation(out0, relux_max_limit, leakyrelu_coefficient);
out1 = do_activation(out1, relux_max_limit, leakyrelu_coefficient);
out2 = do_activation(out2, relux_max_limit, leakyrelu_coefficient);
out3 = do_activation(out3, relux_max_limit, leakyrelu_coefficient);
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) || defined(USE_ELU)
out0 = do_activation(out0, relux_max_limit, activation_coefficient);
out1 = do_activation(out1, relux_max_limit, activation_coefficient);
out2 = do_activation(out2, relux_max_limit, activation_coefficient);
out3 = do_activation(out3, relux_max_limit, activation_coefficient);
#endif
const int out_x_base = mul24(out_ch_blk, out_width);
......
......@@ -9,7 +9,7 @@ __kernel void conv_2d_1x1(OUT_OF_RANGE_PARAMS
#endif
__write_only image2d_t output,
__private const float relux_max_limit,
__private const float leakyrelu_coefficient,
__private const float activation_coefficient,
__private const int in_height,
__private const int in_width,
__private const int in_ch_blks,
......@@ -98,11 +98,11 @@ __kernel void conv_2d_1x1(OUT_OF_RANGE_PARAMS
filter_x_base += 4;
}
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
out0 = do_activation(out0, relux_max_limit, leakyrelu_coefficient);
out1 = do_activation(out1, relux_max_limit, leakyrelu_coefficient);
out2 = do_activation(out2, relux_max_limit, leakyrelu_coefficient);
out3 = do_activation(out3, relux_max_limit, leakyrelu_coefficient);
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) || defined(USE_ELU)
out0 = do_activation(out0, relux_max_limit, activation_coefficient);
out1 = do_activation(out1, relux_max_limit, activation_coefficient);
out2 = do_activation(out2, relux_max_limit, activation_coefficient);
out3 = do_activation(out3, relux_max_limit, activation_coefficient);
#endif
const int out_x_base = mul24(out_ch_blk, width);
......
......@@ -17,7 +17,7 @@ __kernel void conv2d(BUFFER_OUT_OF_RANGE_PARAMS
__private const int stride_h,
__private const int stride_w,
__private const float relux_max_limit,
__private const float leakyrelu_coefficient,
__private const float activation_coefficient,
__global OUT_DATA_TYPE *output) {
const int out_wc_blk_idx = get_global_id(0);
const int out_hb_idx = get_global_id(1);
......@@ -80,9 +80,9 @@ __kernel void conv2d(BUFFER_OUT_OF_RANGE_PARAMS
in_offset += 4;
}
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
out0 = do_activation(out0, relux_max_limit, leakyrelu_coefficient);
out1 = do_activation(out1, relux_max_limit, leakyrelu_coefficient);
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) || defined(USE_ELU)
out0 = do_activation(out0, relux_max_limit, activation_coefficient);
out1 = do_activation(out1, relux_max_limit, activation_coefficient);
#endif
int out_offset = mad24(mad24(mad24(batch_idx, out_height, out_height_idx),
......
......@@ -9,7 +9,7 @@ __kernel void conv_2d_3x3(OUT_OF_RANGE_PARAMS
#endif
__write_only image2d_t output,
__private const float relux_max_limit,
__private const float leakyrelu_coefficient,
__private const float activation_coefficient,
__private const int in_height,
__private const int in_width,
__private const int in_ch_blks,
......@@ -130,12 +130,12 @@ __kernel void conv_2d_3x3(OUT_OF_RANGE_PARAMS
}
}
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
out0 = do_activation(out0, relux_max_limit, leakyrelu_coefficient);
out1 = do_activation(out1, relux_max_limit, leakyrelu_coefficient);
out2 = do_activation(out2, relux_max_limit, leakyrelu_coefficient);
out3 = do_activation(out3, relux_max_limit, leakyrelu_coefficient);
out4 = do_activation(out4, relux_max_limit, leakyrelu_coefficient);
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) || defined(USE_ELU)
out0 = do_activation(out0, relux_max_limit, activation_coefficient);
out1 = do_activation(out1, relux_max_limit, activation_coefficient);
out2 = do_activation(out2, relux_max_limit, activation_coefficient);
out3 = do_activation(out3, relux_max_limit, activation_coefficient);
out4 = do_activation(out4, relux_max_limit, activation_coefficient);
#endif
const int out_x_base = mul24(out_ch_blk, out_width);
......
......@@ -22,7 +22,7 @@ __kernel void conv2d(BUFFER_OUT_OF_RANGE_PARAMS
__private const int dilated_h_offset,
__private const int dilated_w_offset,
__private const float relux_max_limit,
__private const float leakyrelu_coefficient,
__private const float activation_coefficient,
__global OUT_DATA_TYPE *output) {
const int out_wc_blk_idx = get_global_id(0);
const int out_hb_idx = get_global_id(1);
......@@ -108,11 +108,11 @@ __kernel void conv2d(BUFFER_OUT_OF_RANGE_PARAMS
}
}
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
out0 = do_activation(out0, relux_max_limit, leakyrelu_coefficient);
out1 = do_activation(out1, relux_max_limit, leakyrelu_coefficient);
out2 = do_activation(out2, relux_max_limit, leakyrelu_coefficient);
out3 = do_activation(out3, relux_max_limit, leakyrelu_coefficient);
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) || defined(USE_ELU)
out0 = do_activation(out0, relux_max_limit, activation_coefficient);
out1 = do_activation(out1, relux_max_limit, activation_coefficient);
out2 = do_activation(out2, relux_max_limit, activation_coefficient);
out3 = do_activation(out3, relux_max_limit, activation_coefficient);
#endif
int out_offset = mad24(mad24(mad24(batch_idx, out_height, out_height_idx),
......
......@@ -9,7 +9,7 @@ __kernel void deconv_2d(OUT_OF_RANGE_PARAMS
#endif
__write_only image2d_t output,
__private const float relux_max_limit,
__private const float leakyrelu_coefficient,
__private const float activation_coefficient,
__private const int in_height,
__private const int in_width,
__private const int in_channels,
......@@ -129,12 +129,12 @@ __kernel void deconv_2d(OUT_OF_RANGE_PARAMS
}
}
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
out0 = do_activation(out0, relux_max_limit, leakyrelu_coefficient);
out1 = do_activation(out1, relux_max_limit, leakyrelu_coefficient);
out2 = do_activation(out2, relux_max_limit, leakyrelu_coefficient);
out3 = do_activation(out3, relux_max_limit, leakyrelu_coefficient);
out4 = do_activation(out4, relux_max_limit, leakyrelu_coefficient);
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) || defined(USE_ELU)
out0 = do_activation(out0, relux_max_limit, activation_coefficient);
out1 = do_activation(out1, relux_max_limit, activation_coefficient);
out2 = do_activation(out2, relux_max_limit, activation_coefficient);
out3 = do_activation(out3, relux_max_limit, activation_coefficient);
out4 = do_activation(out4, relux_max_limit, activation_coefficient);
#endif
int2 out_pos;
......
......@@ -10,7 +10,7 @@ __kernel void depthwise_conv2d(OUT_OF_RANGE_PARAMS
#endif
__write_only image2d_t output,
__private const float relux_max_limit,
__private const float leakyrelu_coefficient,
__private const float activation_coefficient,
__private const short in_height,
__private const short in_width,
__private const short in_ch_blks,
......@@ -113,11 +113,11 @@ __kernel void depthwise_conv2d(OUT_OF_RANGE_PARAMS
in_hb_idx += dilation_h;
}
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
out0 = do_activation(out0, relux_max_limit, leakyrelu_coefficient);
out1 = do_activation(out1, relux_max_limit, leakyrelu_coefficient);
out2 = do_activation(out2, relux_max_limit, leakyrelu_coefficient);
out3 = do_activation(out3, relux_max_limit, leakyrelu_coefficient);
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) || defined(USE_ELU)
out0 = do_activation(out0, relux_max_limit, activation_coefficient);
out1 = do_activation(out1, relux_max_limit, activation_coefficient);
out2 = do_activation(out2, relux_max_limit, activation_coefficient);
out3 = do_activation(out3, relux_max_limit, activation_coefficient);
#endif
const short out_x_base = mul24(out_ch_blk, out_width);
......@@ -146,7 +146,7 @@ __kernel void depthwise_conv2d_s1(OUT_OF_RANGE_PARAMS
#endif
__write_only image2d_t output,
__private const DATA_TYPE relux_max_limit,
__private const DATA_TYPE leakyrelu_coefficient,
__private const DATA_TYPE activation_coefficient,
__private const short in_height,
__private const short in_width,
__private const short in_ch_blks,
......@@ -240,11 +240,11 @@ __kernel void depthwise_conv2d_s1(OUT_OF_RANGE_PARAMS
in_hb_idx += 1;
}
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
out0 = do_activation(out0, relux_max_limit, leakyrelu_coefficient);
out1 = do_activation(out1, relux_max_limit, leakyrelu_coefficient);
out2 = do_activation(out2, relux_max_limit, leakyrelu_coefficient);
out3 = do_activation(out3, relux_max_limit, leakyrelu_coefficient);
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) || defined(USE_ELU)
out0 = do_activation(out0, relux_max_limit, activation_coefficient);
out1 = do_activation(out1, relux_max_limit, activation_coefficient);
out2 = do_activation(out2, relux_max_limit, activation_coefficient);
out3 = do_activation(out3, relux_max_limit, activation_coefficient);
#endif
const short out_x_base = mul24(out_ch_blk, out_width);
......
......@@ -22,7 +22,7 @@ __kernel void depthwise_conv2d(BUFFER_OUT_OF_RANGE_PARAMS
__private const int dilated_h_offset,
__private const int dilated_w_offset,
__private const float relux_max_limit,
__private const float leakyrelu_coefficient,
__private const float activation_coefficient,
__global OUT_DATA_TYPE *output) {
const int out_wc_blk_idx = get_global_id(0);
const int out_hb_idx = get_global_id(1);
......@@ -86,11 +86,11 @@ __kernel void depthwise_conv2d(BUFFER_OUT_OF_RANGE_PARAMS
}
}
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
out0 = do_activation(out0, relux_max_limit, leakyrelu_coefficient);
out1 = do_activation(out1, relux_max_limit, leakyrelu_coefficient);
out2 = do_activation(out2, relux_max_limit, leakyrelu_coefficient);
out3 = do_activation(out3, relux_max_limit, leakyrelu_coefficient);
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) || defined(USE_ELU)
out0 = do_activation(out0, relux_max_limit, activation_coefficient);
out1 = do_activation(out1, relux_max_limit, activation_coefficient);
out2 = do_activation(out2, relux_max_limit, activation_coefficient);
out3 = do_activation(out3, relux_max_limit, activation_coefficient);
#endif
int out_offset = mad24(mad24(mad24(batch_idx, out_height, out_height_idx),
......
......@@ -9,7 +9,7 @@ __kernel void depthwise_deconv2d(OUT_OF_RANGE_PARAMS
#endif
__write_only image2d_t output,
__private const float relux_max_limit,
__private const float leakyrelu_coefficient,
__private const float activation_coefficient,
__private const int in_height,
__private const int in_width,
__private const int out_height,
......@@ -109,12 +109,12 @@ __kernel void depthwise_deconv2d(OUT_OF_RANGE_PARAMS
}
}
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
out0 = do_activation(out0, relux_max_limit, leakyrelu_coefficient);
out1 = do_activation(out1, relux_max_limit, leakyrelu_coefficient);
out2 = do_activation(out2, relux_max_limit, leakyrelu_coefficient);
out3 = do_activation(out3, relux_max_limit, leakyrelu_coefficient);
out4 = do_activation(out4, relux_max_limit, leakyrelu_coefficient);
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) || defined(USE_ELU)
out0 = do_activation(out0, relux_max_limit, activation_coefficient);
out1 = do_activation(out1, relux_max_limit, activation_coefficient);
out2 = do_activation(out2, relux_max_limit, activation_coefficient);
out3 = do_activation(out3, relux_max_limit, activation_coefficient);
out4 = do_activation(out4, relux_max_limit, activation_coefficient);
#endif
......
......@@ -13,7 +13,7 @@ __kernel void fully_connected(OUT_OF_RANGE_PARAMS
__private const int input_width,
__private const int input_channel,
__private const float relux_max_limit,
__private const float leakyrelu_coefficient) {
__private const float activation_coefficient) {
const int batch_idx = get_global_id(0);
const int out_blk_idx = get_global_id(1);
const int input_chan_blk = (input_channel + 3) >> 2;
......@@ -57,8 +57,8 @@ __kernel void fully_connected(OUT_OF_RANGE_PARAMS
input_coord.y++;
}
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
result = do_activation(result, relux_max_limit, leakyrelu_coefficient);
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) || defined(USE_ELU)
result = do_activation(result, relux_max_limit, activation_coefficient);
#endif
WRITE_IMAGET(output, (int2)(out_blk_idx, batch_idx), result);
......@@ -79,7 +79,7 @@ __kernel void fully_connected_width(OUT_OF_RANGE_PARAMS
__private const int in_chan_blks,
__private const int out_blks,
__private const float relux_max_limit,
__private const float leakyrelu_coefficient) {
__private const float activation_coefficient) {
const int inter_out_idx = get_global_id(0);
const int width_blk_idx = get_global_id(1);
const int width_blk_count = global_size_dim1;
......@@ -149,8 +149,8 @@ __kernel void fully_connected_width(OUT_OF_RANGE_PARAMS
inter_idx += 4;
}
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
result = do_activation(result, relux_max_limit, leakyrelu_coefficient);
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) || defined(USE_ELU)
result = do_activation(result, relux_max_limit, activation_coefficient);
#endif
WRITE_IMAGET(output, (int2)(out_blk_idx, batch_idx), result);
......
......@@ -128,7 +128,7 @@ __kernel void winograd_inverse_transform_2x2(OUT_OF_RANGE_PARAMS
__private const int round_hw,
__private const int round_w,
__private const float relux_max_limit,
__private const float leakyrelu_coefficient) {
__private const float activation_coefficient) {
const int width_idx = get_global_id(0);
const int height_idx = get_global_id(1);
......@@ -204,11 +204,11 @@ __kernel void winograd_inverse_transform_2x2(OUT_OF_RANGE_PARAMS
#endif
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
in0[0] = do_activation(in0[0], relux_max_limit, leakyrelu_coefficient);
in0[1] = do_activation(in0[1], relux_max_limit, leakyrelu_coefficient);
in1[0] = do_activation(in1[0], relux_max_limit, leakyrelu_coefficient);
in1[1] = do_activation(in1[1], relux_max_limit, leakyrelu_coefficient);
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) || defined(USE_ELU)
in0[0] = do_activation(in0[0], relux_max_limit, activation_coefficient);
in0[1] = do_activation(in0[1], relux_max_limit, activation_coefficient);
in1[0] = do_activation(in1[0], relux_max_limit, activation_coefficient);
in1[1] = do_activation(in1[1], relux_max_limit, activation_coefficient);
#endif
WRITE_IMAGET(output, (int2)(coord_x, coord_y), in0[0]);
......@@ -397,7 +397,7 @@ __kernel void winograd_inverse_transform_4x4(OUT_OF_RANGE_PARAMS
__private const int round_hw,
__private const int round_w,
__private const float relux_max_limit,
__private const float leakyrelu_coefficient) {
__private const float activation_coefficient) {
const int width_idx = get_global_id(0);
const int height_idx = get_global_id(1);
......@@ -517,23 +517,23 @@ __kernel void winograd_inverse_transform_4x4(OUT_OF_RANGE_PARAMS
out3[3] += bias_value;
#endif
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
out0[0] = do_activation(out0[0], relux_max_limit, leakyrelu_coefficient);
out0[1] = do_activation(out0[1], relux_max_limit, leakyrelu_coefficient);
out0[2] = do_activation(out0[2], relux_max_limit, leakyrelu_coefficient);
out0[3] = do_activation(out0[3], relux_max_limit, leakyrelu_coefficient);
out1[0] = do_activation(out1[0], relux_max_limit, leakyrelu_coefficient);
out1[1] = do_activation(out1[1], relux_max_limit, leakyrelu_coefficient);
out1[2] = do_activation(out1[2], relux_max_limit, leakyrelu_coefficient);
out1[3] = do_activation(out1[3], relux_max_limit, leakyrelu_coefficient);
out2[0] = do_activation(out2[0], relux_max_limit, leakyrelu_coefficient);
out2[1] = do_activation(out2[1], relux_max_limit, leakyrelu_coefficient);
out2[2] = do_activation(out2[2], relux_max_limit, leakyrelu_coefficient);
out2[3] = do_activation(out2[3], relux_max_limit, leakyrelu_coefficient);
out3[0] = do_activation(out3[0], relux_max_limit, leakyrelu_coefficient);
out3[1] = do_activation(out3[1], relux_max_limit, leakyrelu_coefficient);
out3[2] = do_activation(out3[2], relux_max_limit, leakyrelu_coefficient);
out3[3] = do_activation(out3[3], relux_max_limit, leakyrelu_coefficient);
#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) || defined(USE_ELU)
out0[0] = do_activation(out0[0], relux_max_limit, activation_coefficient);
out0[1] = do_activation(out0[1], relux_max_limit, activation_coefficient);
out0[2] = do_activation(out0[2], relux_max_limit, activation_coefficient);
out0[3] = do_activation(out0[3], relux_max_limit, activation_coefficient);
out1[0] = do_activation(out1[0], relux_max_limit, activation_coefficient);
out1[1] = do_activation(out1[1], relux_max_limit, activation_coefficient);
out1[2] = do_activation(out1[2], relux_max_limit, activation_coefficient);
out1[3] = do_activation(out1[3], relux_max_limit, activation_coefficient);
out2[0] = do_activation(out2[0], relux_max_limit, activation_coefficient);
out2[1] = do_activation(out2[1], relux_max_limit, activation_coefficient);
out2[2] = do_activation(out2[2], relux_max_limit, activation_coefficient);
out2[3] = do_activation(out2[3], relux_max_limit, activation_coefficient);
out3[0] = do_activation(out3[0], relux_max_limit, activation_coefficient);
out3[1] = do_activation(out3[1], relux_max_limit, activation_coefficient);
out3[2] = do_activation(out3[2], relux_max_limit, activation_coefficient);
out3[3] = do_activation(out3[3], relux_max_limit, activation_coefficient);
#endif
const int num = min(4, out_width - out_width_idx);
......
......@@ -46,7 +46,7 @@ class OpenCLConv2dKernel {
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
const int winograd_blk_size,
Tensor *output) = 0;
MACE_EMPTY_VIRTUAL_DESTRUCTOR(OpenCLConv2dKernel);
......
......@@ -39,7 +39,7 @@ class OpenCLDeconv2dKernel {
const int *padding_data,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
const std::vector<index_t> &output_shape,
Tensor *output) = 0;
MACE_EMPTY_VIRTUAL_DESTRUCTOR(OpenCLDeconv2dKernel);
......
......@@ -38,7 +38,7 @@ class OpenCLDepthwiseConv2dKernel {
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
Tensor *output) = 0;
MACE_EMPTY_VIRTUAL_DESTRUCTOR(OpenCLDepthwiseConv2dKernel);
};
......
......@@ -42,7 +42,7 @@ class OpenCLDepthwiseDeconv2dKernel {
const int group,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
const std::vector <index_t> &output_shape,
Tensor *output) = 0;
MACE_EMPTY_VIRTUAL_DESTRUCTOR(OpenCLDepthwiseDeconv2dKernel);
......
......@@ -34,7 +34,7 @@ class OpenCLFullyConnectedKernel {
const Tensor *bias,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
Tensor *output) = 0;
MACE_EMPTY_VIRTUAL_DESTRUCTOR(OpenCLFullyConnectedKernel);
};
......
......@@ -99,12 +99,12 @@ MaceStatus ActivationKernel::Compute(
MACE_OUT_OF_RANGE_SET_ARGS(kernel_);
MACE_SET_3D_GWS_ARGS(kernel_, gws);
kernel_.setArg(idx++, *(input->opencl_image()));
if (activation_ == PRELU || activation_ == ELU) {
if (activation_ == PRELU) {
MACE_CHECK_NOTNULL(alpha);
kernel_.setArg(idx++, *(alpha->opencl_image()));
}
kernel_.setArg(idx++, relux_max_limit_);
kernel_.setArg(idx++, leakyrelu_coefficient_);
kernel_.setArg(idx++, activation_coefficient_);
kernel_.setArg(idx++, *(output->opencl_image()));
input_shape_ = input->shape();
......
......@@ -35,9 +35,9 @@ class ActivationKernel : public OpenCLActivationKernel {
public:
ActivationKernel(ActivationType type,
float relux_max_limit,
float leakyrelu_coefficient)
float activation_coefficient)
: activation_(type), relux_max_limit_(relux_max_limit),
leakyrelu_coefficient_(leakyrelu_coefficient) {}
activation_coefficient_(activation_coefficient) {}
MaceStatus Compute(
OpContext *context,
......@@ -48,7 +48,8 @@ class ActivationKernel : public OpenCLActivationKernel {
private:
ActivationType activation_;
float relux_max_limit_;
float leakyrelu_coefficient_;
float activation_coefficient_;
cl::Kernel kernel_;
uint32_t kwg_size_;
std::vector<index_t> input_shape_;
......
......@@ -22,11 +22,11 @@ namespace image {
BatchNormKernel::BatchNormKernel(const float epsilon,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient)
const float activation_coefficient)
: epsilon_(epsilon),
activation_(activation),
relux_max_limit_(relux_max_limit),
leakyrelu_coefficient_(leakyrelu_coefficient) {}
activation_coefficient_(activation_coefficient) {}
MaceStatus BatchNormKernel::Compute(
OpContext *context,
......@@ -75,6 +75,8 @@ MaceStatus BatchNormKernel::Compute(
break;
case LEAKYRELU:built_options.emplace("-DUSE_LEAKYRELU");
break;
case ELU: built_options.emplace("-DUSE_ELU");
break;
default:LOG(FATAL) << "Unknown activation type: " << activation_;
}
......@@ -99,7 +101,7 @@ MaceStatus BatchNormKernel::Compute(
}
kernel_.setArg(idx++, *(output->opencl_image()));
kernel_.setArg(idx++, relux_max_limit_);
kernel_.setArg(idx++, leakyrelu_coefficient_);
kernel_.setArg(idx++, activation_coefficient_);
input_shape_ = input->shape();
}
......
......@@ -37,7 +37,7 @@ class BatchNormKernel : public OpenCLBatchNormKernel {
const float epsilon,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient);
const float activation_coefficient);
MaceStatus Compute(OpContext *context,
const Tensor *input,
const Tensor *scale,
......@@ -50,7 +50,7 @@ class BatchNormKernel : public OpenCLBatchNormKernel {
const float epsilon_;
const ActivationType activation_;
const float relux_max_limit_;
const float leakyrelu_coefficient_;
const float activation_coefficient_;
cl::Kernel kernel_;
uint32_t kwg_size_;
std::vector<index_t> input_shape_;
......
......@@ -68,7 +68,7 @@ MaceStatus Conv2dKernel::Compute(
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
const int wino_blk_size,
Tensor *output) {
index_t kernel_h = filter->dim(2);
......@@ -116,7 +116,7 @@ MaceStatus Conv2dKernel::Compute(
paddings.data(),
activation,
relux_max_limit,
leakyrelu_coefficient,
activation_coefficient,
wino_blk_size,
&input_shape_,
output,
......@@ -135,7 +135,7 @@ MaceStatus Conv2dKernel::Compute(
dilations,
activation,
relux_max_limit,
leakyrelu_coefficient,
activation_coefficient,
&input_shape_,
output,
&kwg_size_[0]);
......@@ -153,7 +153,7 @@ MaceStatus Conv2dKernel::Compute(
dilations,
activation,
relux_max_limit,
leakyrelu_coefficient,
activation_coefficient,
&input_shape_,
output,
&kwg_size_[0]);
......@@ -171,7 +171,7 @@ MaceStatus Conv2dKernel::Compute(
dilations,
activation,
relux_max_limit,
leakyrelu_coefficient,
activation_coefficient,
&input_shape_,
output,
&kwg_size_[0]);
......
......@@ -39,7 +39,7 @@ extern MaceStatus Conv2dK1x1(OpContext *context,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
std::vector<index_t> *prev_input_shape,
Tensor *output,
uint32_t *kwg_size);
......@@ -55,7 +55,7 @@ extern MaceStatus Conv2dK3x3(OpContext *context,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
std::vector<index_t> *prev_input_shape,
Tensor *output,
uint32_t *kwg_size);
......@@ -71,7 +71,7 @@ extern MaceStatus Conv2d(OpContext *context,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
std::vector<index_t> *prev_input_shape,
Tensor *output,
uint32_t *kwg_size);
......@@ -84,7 +84,7 @@ extern MaceStatus WinogradConv2dK3x3S1(OpContext *context,
const int *padding,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
const int wino_blk_size,
std::vector<index_t> *prev_input_shape,
Tensor *output,
......@@ -111,7 +111,7 @@ class Conv2dKernel : public OpenCLConv2dKernel {
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
const int wino_blk_size,
Tensor *output) override;
......
......@@ -77,7 +77,7 @@ MaceStatus Conv2dK1x1(OpContext *context,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
std::vector<index_t> *prev_input_shape,
Tensor *output,
uint32_t *kwg_size) {
......@@ -135,6 +135,10 @@ MaceStatus Conv2dK1x1(OpContext *context,
built_options.emplace("-DUSE_LEAKYRELU");
break;
}
case ELU: {
built_options.emplace("-DUSE_ELU");
break;
}
default: {
LOG(FATAL) << "Unknown activation type: " << activation;
}
......@@ -165,7 +169,7 @@ MaceStatus Conv2dK1x1(OpContext *context,
kernel->setArg(idx++, *(output->opencl_image()));
// FIXME handle flexible data type: half not supported
kernel->setArg(idx++, relux_max_limit);
kernel->setArg(idx++, leakyrelu_coefficient);
kernel->setArg(idx++, activation_coefficient);
kernel->setArg(idx++, static_cast<int>(input_height));
kernel->setArg(idx++, static_cast<int>(input_width));
kernel->setArg(idx++, static_cast<int>(input_channel_blocks));
......
......@@ -70,7 +70,7 @@ MaceStatus Conv2dK3x3(OpContext *context,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
std::vector<index_t> *prev_input_shape,
Tensor *output,
uint32_t *kwg_size) {
......@@ -120,6 +120,10 @@ MaceStatus Conv2dK3x3(OpContext *context,
built_options.emplace("-DUSE_LEAKYRELU");
break;
}
case ELU: {
built_options.emplace("-DUSE_ELU");
break;
}
default: {
LOG(FATAL) << "Unknown activation type: " << activation;
}
......@@ -149,7 +153,7 @@ MaceStatus Conv2dK3x3(OpContext *context,
}
kernel->setArg(idx++, *(output->opencl_image()));
kernel->setArg(idx++, relux_max_limit);
kernel->setArg(idx++, leakyrelu_coefficient);
kernel->setArg(idx++, activation_coefficient);
kernel->setArg(idx++, static_cast<int>(input->dim(1)));
kernel->setArg(idx++, static_cast<int>(input->dim(2)));
kernel->setArg(idx++, static_cast<int>(input_channel_blocks));
......
......@@ -78,7 +78,7 @@ MaceStatus Conv2d(OpContext *context,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
std::vector<index_t> *prev_input_shape,
Tensor *output,
uint32_t *kwg_size) {
......@@ -128,6 +128,9 @@ MaceStatus Conv2d(OpContext *context,
built_options.emplace("-DUSE_LEAKYRELU");
break;
}
case ELU:
built_options.emplace("-DUSE_ELU");
break;
default: {
LOG(FATAL) << "Unknown activation type: " << activation;
}
......@@ -157,7 +160,7 @@ MaceStatus Conv2d(OpContext *context,
}
kernel->setArg(idx++, *(output->opencl_image()));
kernel->setArg(idx++, relux_max_limit);
kernel->setArg(idx++, leakyrelu_coefficient);
kernel->setArg(idx++, activation_coefficient);
kernel->setArg(idx++, static_cast<uint32_t>(input->dim(1)));
kernel->setArg(idx++, static_cast<uint32_t>(input->dim(2)));
kernel->setArg(idx++, static_cast<uint32_t>(input_channel_blocks));
......
......@@ -29,7 +29,7 @@ MaceStatus Deconv2dKernel::Compute(
const int *padding_data,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
const std::vector<index_t> &output_shape,
Tensor *output) {
std::vector<size_t> output_image_shape;
......@@ -90,6 +90,9 @@ MaceStatus Deconv2dKernel::Compute(
case LEAKYRELU:
built_options.emplace("-DUSE_LEAKYRELU");
break;
case ELU:
built_options.emplace("-DUSE_ELU");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation;
}
......@@ -117,7 +120,7 @@ MaceStatus Deconv2dKernel::Compute(
}
kernel_.setArg(idx++, *(output->opencl_image()));
kernel_.setArg(idx++, relux_max_limit);
kernel_.setArg(idx++, leakyrelu_coefficient);
kernel_.setArg(idx++, activation_coefficient);
kernel_.setArg(idx++, static_cast<int32_t>(input->dim(1)));
kernel_.setArg(idx++, static_cast<int32_t>(input->dim(2)));
kernel_.setArg(idx++, static_cast<int32_t>(input->dim(3)));
......
......@@ -41,7 +41,7 @@ class Deconv2dKernel : public OpenCLDeconv2dKernel {
const int *padding_data,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
const std::vector<index_t> &output_shape,
Tensor *output) override;
......
......@@ -73,7 +73,7 @@ MaceStatus DepthwiseConv2d(OpContext *context,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
std::vector<index_t> *prev_input_shape,
Tensor *output,
uint32_t *kwg_size) {
......@@ -129,6 +129,9 @@ MaceStatus DepthwiseConv2d(OpContext *context,
case LEAKYRELU:
built_options.emplace("-DUSE_LEAKYRELU");
break;
case ELU:
built_options.emplace("-DUSE_ELU");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation;
}
......@@ -162,7 +165,7 @@ MaceStatus DepthwiseConv2d(OpContext *context,
}
kernel->setArg(idx++, *(output->opencl_image()));
kernel->setArg(idx++, relux_max_limit);
kernel->setArg(idx++, leakyrelu_coefficient);
kernel->setArg(idx++, activation_coefficient);
kernel->setArg(idx++, static_cast<int16_t>(input_height));
kernel->setArg(idx++, static_cast<int16_t>(input_width));
kernel->setArg(idx++, static_cast<int16_t>(input_channel_blocks));
......@@ -204,7 +207,7 @@ MaceStatus DepthwiseConv2dKernel::Compute(
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
Tensor *output) {
index_t kernel_h = filter->dim(2);
index_t kernel_w = filter->dim(3);
......@@ -243,7 +246,7 @@ MaceStatus DepthwiseConv2dKernel::Compute(
return depthwise::DepthwiseConv2d(
context, &kernel_, input, filter, bias, strides[0], paddings.data(),
dilations, activation, relux_max_limit, leakyrelu_coefficient,
dilations, activation, relux_max_limit, activation_coefficient,
&input_shape_, output, &kwg_size_);
}
......
......@@ -39,7 +39,7 @@ MaceStatus DepthwiseConv2d(OpContext *context,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
std::vector<index_t> *prev_input_shape,
Tensor *output,
uint32_t *kwg_size);
......@@ -58,7 +58,7 @@ class DepthwiseConv2dKernel : public OpenCLDepthwiseConv2dKernel {
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
Tensor *output) override;
private:
......
......@@ -30,7 +30,7 @@ MaceStatus DepthwiseDeconv2dKernel::Compute(
const int group,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
const std::vector<index_t> &output_shape,
Tensor *output) {
const index_t batch = output_shape[0];
......@@ -95,6 +95,9 @@ MaceStatus DepthwiseDeconv2dKernel::Compute(
case LEAKYRELU:
built_options.emplace("-DUSE_LEAKYRELU");
break;
case ELU:
built_options.emplace("-DUSE_ELU");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation;
}
......@@ -122,7 +125,7 @@ MaceStatus DepthwiseDeconv2dKernel::Compute(
}
kernel_.setArg(idx++, *(output->opencl_image()));
kernel_.setArg(idx++, relux_max_limit);
kernel_.setArg(idx++, leakyrelu_coefficient);
kernel_.setArg(idx++, activation_coefficient);
kernel_.setArg(idx++, static_cast<int32_t>(input->dim(1)));
kernel_.setArg(idx++, static_cast<int32_t>(input->dim(2)));
kernel_.setArg(idx++, static_cast<int32_t>(height));
......
......@@ -42,7 +42,7 @@ class DepthwiseDeconv2dKernel : public OpenCLDepthwiseDeconv2dKernel {
const int group,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
const std::vector<index_t> &output_shape,
Tensor *output) override;
......
......@@ -27,7 +27,7 @@ MaceStatus FullyConnectedKernel::Compute(
const Tensor *bias,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
Tensor *output) {
std::vector<index_t> output_shape = {input->dim(0), 1, 1, weight->dim(0)};
std::vector<size_t> output_image_shape;
......@@ -71,6 +71,9 @@ MaceStatus FullyConnectedKernel::Compute(
case LEAKYRELU:
built_options.emplace("-DUSE_LEAKYRELU");
break;
case ELU:
built_options.emplace("-DUSE_ELU");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation;
}
......@@ -121,7 +124,7 @@ MaceStatus FullyConnectedKernel::Compute(
kernel_.setArg(idx++, static_cast<int>(RoundUpDiv4(input->dim(3))));
kernel_.setArg(idx++, static_cast<int>(output_blocks));
kernel_.setArg(idx++, relux_max_limit);
kernel_.setArg(idx++, leakyrelu_coefficient);
kernel_.setArg(idx++, activation_coefficient);
input_shape_ = input->shape();
}
......
......@@ -40,7 +40,7 @@ class FullyConnectedKernel : public OpenCLFullyConnectedKernel {
const Tensor *bias,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
Tensor *output) override;
private:
......
......@@ -113,7 +113,7 @@ MaceStatus WinogradOutputTransform(OpContext *context,
const int wino_blk_size,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
const bool input_changed,
Tensor *output_tensor,
uint32_t *kwg_size,
......@@ -213,7 +213,7 @@ MaceStatus WinogradOutputTransform(OpContext *context,
kernel->setArg(idx++, static_cast<uint32_t>(round_h * round_w));
kernel->setArg(idx++, static_cast<uint32_t>(round_w));
kernel->setArg(idx++, relux_max_limit);
kernel->setArg(idx++, leakyrelu_coefficient);
kernel->setArg(idx++, activation_coefficient);
}
const std::vector<uint32_t> lws = {*kwg_size / 8, 8, 0};
std::string tuning_key =
......@@ -237,7 +237,7 @@ extern MaceStatus WinogradConv2dK3x3S1(OpContext *context,
const int *paddings,
const ActivationType activation,
const float relux_max_limit,
const float leakyrelu_coefficient,
const float activation_coefficient,
const int wino_blk_size,
std::vector<index_t> *prev_input_shape,
Tensor *output,
......@@ -355,7 +355,7 @@ extern MaceStatus WinogradConv2dK3x3S1(OpContext *context,
MACE_RETURN_IF_ERROR(WinogradOutputTransform(
context, kernels[2], mm_output.get(), bias,
round_h, round_w, wino_blk_size, activation, relux_max_limit,
leakyrelu_coefficient, input_changed, output, kwg_size[2],
activation_coefficient, input_changed, output, kwg_size[2],
&t_output_future))
MergeMultipleFutureWaitFn({t_input_future, mm_future, t_output_future},
......
......@@ -81,7 +81,7 @@ void Activation<T>::DoActivation(const OpContext *context,
for (index_t i = 0; i < size; ++i) {
*output_ptr =
std::max<float>(*input_ptr, 0.f)
+ std::min(*input_ptr, 0.f) * leakyrelu_coefficient_;
+ std::min(*input_ptr, 0.f) * activation_coefficient_;
++input_ptr;
++output_ptr;
}
......@@ -104,6 +104,19 @@ void Activation<T>::DoActivation(const OpContext *context,
break;
}
case ELU: {
for (index_t i = 0; i < input->size(); ++i) {
const auto in_val = *input_ptr++;
if (in_val < 0) {
*output_ptr = (std::exp(in_val) - 1) * activation_coefficient_;
} else {
*output_ptr = in_val;
}
output_ptr++;
}
break;
}
case NOOP:break;
default:MACE_NOT_IMPLEMENTED;
......
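For reference, the CPU branch added above implements the standard ELU definition, with activation_coefficient_ playing the role of alpha:

\mathrm{ELU}_{\alpha}(x) = \begin{cases} x, & x \ge 0 \\ \alpha \,(e^{x} - 1), & x < 0 \end{cases}

With alpha = 1.0 this reduces to the common ELU; the OpenCL kernels compiled with -DUSE_ELU are expected to evaluate the same expression on the GPU path.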
......@@ -32,6 +32,8 @@ class ReverseOp<DeviceType::CPU, T> : public Operation {
const Tensor *input = this->Input(INPUT);
const Tensor *axis = this->Input(AXIS);
Tensor *output = this->Output(OUTPUT);
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard axis_guard(axis);
MACE_CHECK(axis->dim_size() == 1, "Only support reverse in one axis now");
......
......@@ -31,17 +31,17 @@ MaceStatus Activation::Init(const framework::Operator *op) {
activation_type = "NOOP";
}
const float max_limit = op->GetArgByName("max_limit", 0.0f);
const float leakyrelu_coefficient =
op->GetArgByName("leakyrelu_coefficient", 0.0f);
const float activation_coefficient =
op->GetArgByName("activation_coefficient", 0.0f);
return Init(activation_type, max_limit, leakyrelu_coefficient);
return Init(activation_type, max_limit, activation_coefficient);
}
MaceStatus Activation::Init(const char *type, const float limit,
const float leakyrelu_coefficient) {
const float activation_coefficient) {
type_ = StringToActivationType(type);
limit_ = limit;
leakyrelu_coefficient_ = leakyrelu_coefficient;
activation_coefficient_ = activation_coefficient;
return MACE_SUCCESS;
}
......@@ -71,7 +71,7 @@ MaceStatus Activation::Compute(const mifloat *input_ptr,
for (int32_t i = 0; i < size; ++i) {
float input = *input_ptr;
*output_ptr = base::max(input, 0.f) +
base::min(input, 0.f) * leakyrelu_coefficient_; // NOLINT
base::min(input, 0.f) * activation_coefficient_; // NOLINT
++input_ptr;
++output_ptr;
}
......
......@@ -44,7 +44,7 @@ class Activation {
MaceStatus Init(const framework::Operator *op);
MaceStatus Init(const char *type, const float limit,
const float leakyrelu_coefficient);
const float activation_coefficient);
MaceStatus Compute(const mifloat *input_ptr,
const int32_t size, mifloat *output_ptr);
ActivationType GetActivationType();
......@@ -55,7 +55,7 @@ class Activation {
private:
ActivationType type_;
float limit_;
float leakyrelu_coefficient_;
float activation_coefficient_;
};
} // namespace ops
......
......@@ -67,7 +67,7 @@ void TestSimpleLeakyRelu() {
framework::SubstituteOp substitude_op;
substitude_op.AddInput(input, input_dims, 4)
.AddRepeatArg("activation", activation_type, arg_type_len)
.AddArg("leakyrelu_coefficient", 0.1f)
.AddArg("activation_coefficient", 0.1f)
.AddOutput(output, output_dims, 4);
activation_op.Init(NULL, reinterpret_cast<framework::OpContext *>(
......
......@@ -250,14 +250,14 @@ void EluBenchmark(int iters, int batch, int channels, int height, int width) {
} else {
MACE_NOT_IMPLEMENTED;
}
net.AddRandomInput<D, T>("Alpha", {channels}, true);
OpDefBuilder("Activation", "EluBM")
.Input("Input")
.Input("Alpha")
.Output("Output")
.AddStringArg("activation", "ELU")
.AddFloatArg("activation_coefficient", 1.0)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Warm-up
......
......@@ -66,7 +66,7 @@ void TestSimpleLeakyRelu() {
.Input("Input")
.Output("Output")
.AddStringArg("activation", "LEAKYRELU")
.AddFloatArg("leakyrelu_coefficient", 0.1)
.AddFloatArg("activation_coefficient", 0.1)
.Finalize(net.NewOperatorDef());
// Run
......@@ -243,15 +243,14 @@ void TestSimpleElu() {
// Add input data
net.AddInputFromArray<D, float>(
"Input", {2, 2, 2, 2},
{-7, 7, -6, 6, -5, -5, -4, -4, -3, 3, -2, 2, -1, -1, 0, 0});
net.AddInputFromArray<D, float>("Alpha", {2}, {2.0, 3.0}, true);
{-7, 7, -6, 6, -5, 5, -4, 4, -3, 3, -2, 2, -1, 1, 0, 0});
if (D == DeviceType::GPU) {
OpDefBuilder("Activation", "EluTest")
.Input("Input")
.Input("Alpha")
.Output("Output")
.AddStringArg("activation", "ELU")
.AddFloatArg("activation_coefficient", 2.0)
.Finalize(net.NewOperatorDef());
// Run
......@@ -261,9 +260,9 @@ void TestSimpleElu() {
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Activation", "EluTest")
.Input("InputNCHW")
.Input("Alpha")
.Output("OutputNCHW")
.AddStringArg("activation", "ELU")
.AddFloatArg("activation_coefficient", 2.0)
.Finalize(net.NewOperatorDef());
// Run
......@@ -275,9 +274,9 @@ void TestSimpleElu() {
auto expected = net.CreateTensor<float>(
{2, 2, 2, 2},
{-1.998176236068891, 7, -1.9950424956466672, 6, -1.986524106001829,
-2.9797861590027437, -1.9633687222225316, -2.9450530833337973,
5, -1.9633687222225316, 4,
-1.900425863264272, 3, -1.7293294335267746, 2, -1.2642411176571153,
-1.896361676485673, 0, 0});
1, 0, 0});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
} // namespace
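As a quick sanity check on the expected tensor above (activation_coefficient = 2.0): the first input element is -7, and 2 * (e^{-7} - 1) = 2 * (0.000911882 - 1) ≈ -1.998176236, which matches the first expected value, while positive inputs such as 7 pass through unchanged.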
......@@ -439,7 +438,7 @@ void TestBFloat16(const char *activation) {
.Input("Alpha")
.Output("Output")
.AddStringArg("activation", activation)
.AddFloatArg("leakyrelu_coefficient", 0.1)
.AddFloatArg("activation_coefficient", 0.1)
.AddFloatArg("max_limit", 6)
.AddIntArg("T", static_cast<int>(DT_FLOAT))
.Finalize(net.NewOperatorDef());
......@@ -450,7 +449,7 @@ void TestBFloat16(const char *activation) {
.Input("BF16Alpha")
.Output("BF16Output")
.AddStringArg("activation", activation)
.AddFloatArg("leakyrelu_coefficient", 0.1)
.AddFloatArg("activation_coefficient", 0.1)
.AddFloatArg("max_limit", 6)
.AddIntArg("T", static_cast<int>(DT_BFLOAT16))
.Finalize(net.NewOperatorDef());
......
......@@ -108,7 +108,7 @@ TEST_F(BatchNormOpTest, SimpleRandomOPENCL) {
.AddFloatArg("epsilon", 1e-3)
.Output("OutputNCHW")
.AddStringArg("activation", "LEAKYRELU")
.AddFloatArg("leakyrelu_coefficient", 0.1)
.AddFloatArg("activation_coefficient", 0.1)
.Finalize(net.NewOperatorDef());
// run cpu
......@@ -131,7 +131,7 @@ TEST_F(BatchNormOpTest, SimpleRandomOPENCL) {
.AddFloatArg("epsilon", 1e-3)
.Output("Output")
.AddStringArg("activation", "LEAKYRELU")
.AddFloatArg("leakyrelu_coefficient", 0.1)
.AddFloatArg("activation_coefficient", 0.1)
.Finalize(net.NewOperatorDef());
net.Setup(DeviceType::GPU);
......
......@@ -684,7 +684,7 @@ void TestComplexConvNxN(const std::vector<index_t> &shape,
.AddIntArg("padding", type)
.AddIntsArg("dilations", {1, 1})
.AddStringArg("activation", "LEAKYRELU")
.AddFloatArg("leakyrelu_coefficient", 0.1)
.AddFloatArg("activation_coefficient", 0.1)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
......@@ -709,7 +709,7 @@ void TestComplexConvNxN(const std::vector<index_t> &shape,
.AddIntArg("padding", type)
.AddIntsArg("dilations", {1, 1})
.AddStringArg("activation", "LEAKYRELU")
.AddFloatArg("leakyrelu_coefficient", 0.1)
.AddFloatArg("activation_coefficient", 0.1)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("wino_block_size", wino_blk_size)
.Finalize(net.NewOperatorDef());
......
......@@ -421,7 +421,7 @@ void TestComplexDeconvNxN(const int batch,
.AddIntsArg("padding_values", paddings)
.AddIntArg("framework_type", model_type)
.AddStringArg("activation", "LEAKYRELU")
.AddFloatArg("leakyrelu_coefficient", 0.1)
.AddFloatArg("activation_coefficient", 0.1)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
} else {
......@@ -459,7 +459,7 @@ void TestComplexDeconvNxN(const int batch,
.AddIntsArg("padding_values", paddings)
.AddIntArg("framework_type", model_type)
.AddStringArg("activation", "LEAKYRELU")
.AddFloatArg("leakyrelu_coefficient", 0.1)
.AddFloatArg("activation_coefficient", 0.1)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
} else {
......
......@@ -261,7 +261,7 @@ void TestNxNS12(const index_t height, const index_t width) {
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<float>::value))
.AddStringArg("activation", "LEAKYRELU")
.AddFloatArg("leakyrelu_coefficient", 0.1)
.AddFloatArg("activation_coefficient", 0.1)
.Finalize(net.NewOperatorDef());
// Run on cpu
......@@ -284,7 +284,7 @@ void TestNxNS12(const index_t height, const index_t width) {
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddStringArg("activation", "LEAKYRELU")
.AddFloatArg("leakyrelu_coefficient", 0.1)
.AddFloatArg("activation_coefficient", 0.1)
.Finalize(net.NewOperatorDef());
net.RunOp(DeviceType::GPU);
......
......@@ -206,7 +206,7 @@ void RandomTest(index_t batch,
.AddIntArg("group", channel)
.AddIntsArg("dilations", {1, 1})
.AddStringArg("activation", "LEAKYRELU")
.AddFloatArg("leakyrelu_coefficient", 0.1f)
.AddFloatArg("activation_coefficient", 0.1f)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<float>::value))
.Finalize(net.NewOperatorDef());
// Run
......@@ -229,7 +229,7 @@ void RandomTest(index_t batch,
.AddIntsArg("padding_values", {padding, padding})
.AddIntArg("group", channel)
.AddStringArg("activation", "LEAKYRELU")
.AddFloatArg("leakyrelu_coefficient", 0.1f)
.AddFloatArg("activation_coefficient", 0.1f)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
......
......@@ -138,7 +138,7 @@ void Random(const index_t batch,
.Input("Bias")
.Output("OutputNCHW")
.AddStringArg("activation", "LEAKYRELU")
.AddFloatArg("leakyrelu_coefficient", 0.1f)
.AddFloatArg("activation_coefficient", 0.1f)
.Finalize(net.NewOperatorDef());
// run cpu
......@@ -158,7 +158,7 @@ void Random(const index_t batch,
.Input("Bias")
.Output("Output")
.AddStringArg("activation", "LEAKYRELU")
.AddFloatArg("leakyrelu_coefficient", 0.1f)
.AddFloatArg("activation_coefficient", 0.1f)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
......
......@@ -225,7 +225,7 @@ class MaceKeyword(object):
mace_element_type_str = 'type'
mace_activation_type_str = 'activation'
mace_activation_max_limit_str = 'max_limit'
mace_activation_leakyrelu_coefficient_str = 'leakyrelu_coefficient'
mace_activation_coefficient_str = 'activation_coefficient'
mace_resize_size_str = 'size'
mace_batch_to_space_crops_str = 'crops'
mace_paddings_str = 'paddings'
......
......@@ -166,6 +166,7 @@ class CaffeConverter(base_converter.ConverterInterface):
'TanH': ActivationType.TANH,
'Sigmoid': ActivationType.SIGMOID,
'Clip': ActivationType.RELUX,
'ELU': ActivationType.ELU,
}
def __init__(self, option, src_model_file, src_weight_file):
......@@ -181,6 +182,7 @@ class CaffeConverter(base_converter.ConverterInterface):
'Sigmoid': self.convert_activation,
'PReLU': self.convert_activation,
'Clip': self.convert_activation,
'ELU': self.convert_activation,
'Pooling': self.convert_pooling,
'Concat': self.convert_concat,
'Slice': self.convert_slice,
......@@ -509,7 +511,7 @@ class CaffeConverter(base_converter.ConverterInterface):
negative_slope = caffe_op.layer.relu_param.negative_slope
if negative_slope != 0:
param_arg = op.arg.add()
param_arg.name = MaceKeyword.mace_activation_leakyrelu_coefficient_str # noqa
param_arg.name = MaceKeyword.mace_activation_coefficient_str
param_arg.f = caffe_op.layer.relu_param.negative_slope
type_arg.s = six.b(ActivationType.LEAKYRELU.name)
elif caffe_op.type == 'ReLU6':
......@@ -522,6 +524,11 @@ class CaffeConverter(base_converter.ConverterInterface):
limit_arg = op.arg.add()
limit_arg.name = MaceKeyword.mace_activation_max_limit_str
limit_arg.f = caffe_op.layer.clip_param.max
elif caffe_op.type == 'ELU':
# TODO(luxuhui): we have not verified ELU for Caffe yet
param_arg = op.arg.add()
param_arg.name = MaceKeyword.mace_activation_coefficient_str
param_arg.f = caffe_op.layer.elu_param.alpha
def convert_folded_batchnorm(self, caffe_op):
op = self.convert_general_op(caffe_op)
......
......@@ -629,18 +629,16 @@ class OnnxConverter(base_converter.ConverterInterface):
type_arg.s = six.b(self.activation_type[node.op_type].name)
if "alpha" in node.attrs:
alpha_tensor_name = node.name + '_alpha'
alpha_value = np.array([node.attrs["alpha"]])
self.add_tensor(alpha_tensor_name, alpha_value.reshape(-1).shape,
mace_pb2.DT_FLOAT, alpha_value)
op.input.extend([alpha_tensor_name])
alpha_value = node.attrs["alpha"]
else:
if node.op_type == OnnxOpType.LeakyRelu.name:
alpha_value = 0.01
elif node.op_type == OnnxOpType.Elu.name:
alpha_value = 1.0
else:
alpha_value = 0
alpha_arg = op.arg.add()
alpha_arg.name = MaceKeyword.mace_activation_leakyrelu_coefficient_str
alpha_arg.name = MaceKeyword.mace_activation_coefficient_str
alpha_arg.f = alpha_value
def convert_affine(self, node):
......
......@@ -70,6 +70,7 @@ TFSupportedOps = [
'DepthwiseConv2dNative',
'DepthToSpace',
'Div',
'Elu',
'Equal',
'ExpandDims',
'ExtractImagePatches',
......@@ -190,6 +191,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
}
activation_type = {
TFOpType.Elu.name: ActivationType.ELU,
TFOpType.Relu.name: ActivationType.RELU,
TFOpType.Relu6.name: ActivationType.RELUX,
TFOpType.Tanh.name: ActivationType.TANH,
......@@ -232,6 +234,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.DepthwiseConv2dNative.name: self.convert_conv2d,
TFOpType.DepthToSpace.name: self.convert_space_depth,
TFOpType.Div.name: self.convert_elementwise,
TFOpType.Elu.name: self.convert_activation,
TFOpType.Equal.name: self.convert_elementwise,
TFOpType.ExpandDims.name: self.convert_expand_dims,
TFOpType.ExtractImagePatches.name:
......@@ -668,11 +671,18 @@ class TensorflowConverter(base_converter.ConverterInterface):
limit_arg = op.arg.add()
limit_arg.name = MaceKeyword.mace_activation_max_limit_str
limit_arg.f = 6.0
elif tf_op.type == TFOpType.LeakyRelu.name:
elif tf_op.type == TFOpType.LeakyRelu.name or \
tf_op.type == TFOpType.Elu.name:
alpha_arg = op.arg.add()
alpha_arg.name = \
MaceKeyword.mace_activation_leakyrelu_coefficient_str
alpha_arg.f = tf_op.get_attr(tf_alpha_str)
MaceKeyword.mace_activation_coefficient_str
try:
alpha_arg.f = tf_op.get_attr(tf_alpha_str)
except ValueError:
if tf_op.type == TFOpType.LeakyRelu.name:
alpha_arg.f = 0.0
else:
alpha_arg.f = 1.0
def convert_fill(self, tf_op):
op = self.convert_general_op(tf_op)
......
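A minimal sketch of the fallback the try/except above introduces, assuming a tf.Operation handle (the helper name resolve_activation_coefficient is hypothetical, not part of the converter): tf.nn.elu produces an 'Elu' node without an 'alpha' attribute, so get_attr raises ValueError and the coefficient falls back to 1.0 for Elu and 0.0 for LeakyRelu, matching the values chosen in the change above.

def resolve_activation_coefficient(tf_op):
    # tf_op: a tf.Operation; get_attr raises ValueError when the node
    # carries no 'alpha' attribute (the usual case for tf.nn.elu).
    try:
        return tf_op.get_attr('alpha')
    except ValueError:
        return 1.0 if tf_op.type == 'Elu' else 0.0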
......@@ -985,10 +985,7 @@ class Transformer(base_converter.ConverterInterface):
[ActivationType.RELU.name,
ActivationType.RELUX.name])
else:
fold_consumer = (
act_type != ActivationType.PRELU.name
and act_type != ActivationType.ELU.name
)
fold_consumer = (act_type != ActivationType.PRELU.name)
# during quantization, only fold relu/relux
if (self._option.quantize_stat or self._option.quantize) \
and act_type not in [ActivationType.RELU.name,
......@@ -1002,7 +999,7 @@ class Transformer(base_converter.ConverterInterface):
if arg.name == MaceKeyword.mace_activation_type_str \
or arg.name == \
MaceKeyword.mace_activation_max_limit_str \
or arg.name == MaceKeyword.mace_activation_leakyrelu_coefficient_str: # noqa
or arg.name == MaceKeyword.mace_activation_coefficient_str: # noqa
op.arg.extend([arg])
self.replace_quantize_info(op, consumer_op)
......