提交 19e3a65c 编写于 作者: 刘托

Merge branch 'fix-leakyrelu' into 'master'

fix leakyrelu

See merge request !917
...@@ -96,8 +96,8 @@ void DoActivation(const T *input_ptr, ...@@ -96,8 +96,8 @@ void DoActivation(const T *input_ptr,
case LEAKYRELU: case LEAKYRELU:
#pragma omp parallel for schedule(runtime) #pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < size; ++i) { for (index_t i = 0; i < size; ++i) {
output_ptr[i] = std::max(input_ptr[i], output_ptr[i] = std::max(input_ptr[i], static_cast<T>(0))
static_cast<T>(0)) * relux_max_limit; + std::min(input_ptr[i], static_cast<T>(0)) * relux_max_limit;
} }
break; break;
default: default:
......
...@@ -75,18 +75,20 @@ void LeakyReluNeon(const float *input, const float alpha, ...@@ -75,18 +75,20 @@ void LeakyReluNeon(const float *input, const float alpha,
#pragma omp parallel for schedule(runtime) #pragma omp parallel for schedule(runtime)
for (index_t i = 0; i <= size - 4; i += 4) { for (index_t i = 0; i <= size - 4; i += 4) {
float32x4_t v = vld1q_f32(input + i); float32x4_t v = vld1q_f32(input + i);
float32x4_t u = vminq_f32(v, vzero);;
v = vmaxq_f32(v, vzero); v = vmaxq_f32(v, vzero);
v = vmulq_f32(v, valpha); v = vmlaq_f32(v, valpha, u);
vst1q_f32(output + i, v); vst1q_f32(output + i, v);
} }
// remain // remain
for (index_t i = (size >> 2) << 2; i < size; ++i) { for (index_t i = (size >> 2) << 2; i < size; ++i) {
output[i] = std::max(input[i], 0.f) * alpha; output[i] = std::max(input[i], 0.f) + std::min(input[i], 0.f) * alpha;
} }
#else #else
#pragma omp parallel for schedule(runtime) #pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < size; ++i) { for (index_t i = 0; i < size; ++i) {
output[i] = std::max(input[i], 0.f) * alpha; output[i] = std::max(input[i], 0.f) + std::min(input[i], 0.f) * alpha;
} }
#endif #endif
} }
......
...@@ -309,7 +309,7 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -309,7 +309,7 @@ class OnnxConverter(base_converter.ConverterInterface):
OnnxOpType.Div.name: self.convert_eltwise, OnnxOpType.Div.name: self.convert_eltwise,
OnnxOpType.Equal.name: self.convert_eltwise, OnnxOpType.Equal.name: self.convert_eltwise,
OnnxOpType.Gather.name: self.convert_gather, OnnxOpType.Gather.name: self.convert_gather,
OnnxOpType.Gemm.name: self.convert_fully_connected, OnnxOpType.Gemm.name: self.convert_gemm,
OnnxOpType.GlobalAveragePool.name: self.convert_reduce, OnnxOpType.GlobalAveragePool.name: self.convert_reduce,
OnnxOpType.GlobalMaxPool.name: self.convert_reduce, OnnxOpType.GlobalMaxPool.name: self.convert_reduce,
OnnxOpType.Identity.name: self.convert_identity, OnnxOpType.Identity.name: self.convert_identity,
...@@ -415,6 +415,7 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -415,6 +415,7 @@ class OnnxConverter(base_converter.ConverterInterface):
kernels_arg.name = MaceKeyword.mace_kernel_str kernels_arg.name = MaceKeyword.mace_kernel_str
kernels_arg.ints.extend(kernel) kernels_arg.ints.extend(kernel)
# TODO: Does not support AutoPad yet.
if 'pads' in attrs: if 'pads' in attrs:
pads = attrs['pads'] pads = attrs['pads']
if len(pads) == 4: if len(pads) == 4:
...@@ -588,7 +589,10 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -588,7 +589,10 @@ class OnnxConverter(base_converter.ConverterInterface):
if "alpha" in node.attrs: if "alpha" in node.attrs:
alpha_value = node.attrs["alpha"] alpha_value = node.attrs["alpha"]
else: else:
alpha_value = 0 if node.op_type == OnnxOpType.LeakyRelu.name:
alpha_value = 0.01
else:
alpha_value = 0
alpha_arg = op.arg.add() alpha_arg = op.arg.add()
alpha_arg.name = MaceKeyword.mace_activation_max_limit_str alpha_arg.name = MaceKeyword.mace_activation_max_limit_str
alpha_arg.f = alpha_value alpha_arg.f = alpha_value
...@@ -894,7 +898,8 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -894,7 +898,8 @@ class OnnxConverter(base_converter.ConverterInterface):
tensor.float_data[:] = tensor_data.flat tensor.float_data[:] = tensor_data.flat
tensor.dims[:] = tensor_data.shape tensor.dims[:] = tensor_data.shape
def convert_fully_connected(self, node): def convert_gemm(self, node):
# only supports FullyConnected Style Gemm for now.
trans_a = node.attrs['transA'] if 'transA' in node.attrs else 0 trans_a = node.attrs['transA'] if 'transA' in node.attrs else 0
trans_b = node.attrs['transB'] if 'transB' in node.attrs else 0 trans_b = node.attrs['transB'] if 'transB' in node.attrs else 0
shape_a = self._graph_shapes_dict[node.inputs[0]] shape_a = self._graph_shapes_dict[node.inputs[0]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册