Commit c22dee0a authored by liutuo

fix leakyrelu

Parent 5d980e20
......
@@ -96,8 +96,8 @@ void DoActivation(const T *input_ptr,
case LEAKYRELU:
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < size; ++i) {
-  output_ptr[i] = std::max(input_ptr[i],
-      static_cast<T>(0)) * relux_max_limit;
+  output_ptr[i] = std::max(input_ptr[i], static_cast<T>(0))
+      + std::min(input_ptr[i], static_cast<T>(0)) * relux_max_limit;
}
break;
default:
......
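Note on the change above: the old scalar path computed max(x, 0) * alpha, which zeroes every negative input and merely scales the positive ones, while LeakyReLU is defined as x for x > 0 and alpha * x otherwise (in this kernel the slope is carried in relux_max_limit). The corrected expression max(x, 0) + min(x, 0) * alpha is a branch-free form of that definition. A minimal standalone check of the equivalence, with illustrative names that are not part of the MACE sources:

#include <algorithm>
#include <cassert>
#include <initializer_list>

// Piecewise LeakyReLU definition: x for x > 0, alpha * x otherwise.
static float leaky_relu_reference(float x, float alpha) {
  return x > 0.f ? x : alpha * x;
}

// Branch-free form used by the fixed kernel: max(x, 0) + min(x, 0) * alpha.
static float leaky_relu_branch_free(float x, float alpha) {
  return std::max(x, 0.f) + std::min(x, 0.f) * alpha;
}

int main() {
  const float alpha = 0.01f;
  for (float x : {-2.f, -0.5f, 0.f, 0.5f, 2.f}) {
    assert(leaky_relu_reference(x, alpha) == leaky_relu_branch_free(x, alpha));
  }
  return 0;
}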
......
@@ -75,18 +75,20 @@ void LeakyReluNeon(const float *input, const float alpha,
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i <= size - 4; i += 4) {
float32x4_t v = vld1q_f32(input + i);
+  float32x4_t u = vminq_f32(v, vzero);
v = vmaxq_f32(v, vzero);
-  v = vmulq_f32(v, valpha);
+  v = vmlaq_f32(v, valpha, u);
vst1q_f32(output + i, v);
}
// remain
for (index_t i = (size >> 2) << 2; i < size; ++i) {
-  output[i] = std::max(input[i], 0.f) * alpha;
+  output[i] = std::max(input[i], 0.f) + std::min(input[i], 0.f) * alpha;
}
#else
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < size; ++i) {
-  output[i] = std::max(input[i], 0.f) * alpha;
+  output[i] = std::max(input[i], 0.f) + std::min(input[i], 0.f) * alpha;
}
#endif
}
......
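For readers less familiar with the intrinsics: vminq_f32 and vmaxq_f32 split each lane into its negative and positive parts, and vmlaq_f32(a, b, c) returns a + b * c per lane, so the fixed loop evaluates max(x, 0) + alpha * min(x, 0) four floats at a time; the removed vmulq_f32 call had scaled only the positive part. A scalar sketch of what one 4-wide iteration computes (illustrative only, not MACE code):

#include <algorithm>

// Scalar equivalent of one 4-wide iteration of the fixed LeakyReluNeon loop.
void leaky_relu_block4(const float *input, float alpha, float *output) {
  for (int lane = 0; lane < 4; ++lane) {
    float x = input[lane];       // vld1q_f32(input + i)
    float u = std::min(x, 0.f);  // vminq_f32(v, vzero)
    float v = std::max(x, 0.f);  // vmaxq_f32(v, vzero)
    v = v + alpha * u;           // vmlaq_f32(v, valpha, u)
    output[lane] = v;            // vst1q_f32(output + i, v)
  }
}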
......
@@ -309,7 +309,7 @@ class OnnxConverter(base_converter.ConverterInterface):
OnnxOpType.Div.name: self.convert_eltwise,
OnnxOpType.Equal.name: self.convert_eltwise,
OnnxOpType.Gather.name: self.convert_gather,
-  OnnxOpType.Gemm.name: self.convert_fully_connected,
+  OnnxOpType.Gemm.name: self.convert_gemm,
OnnxOpType.GlobalAveragePool.name: self.convert_reduce,
OnnxOpType.GlobalMaxPool.name: self.convert_reduce,
OnnxOpType.Identity.name: self.convert_identity,
......
@@ -415,6 +415,7 @@ class OnnxConverter(base_converter.ConverterInterface):
kernels_arg.name = MaceKeyword.mace_kernel_str
kernels_arg.ints.extend(kernel)
+  # TODO: Does not support AutoPad yet.
if 'pads' in attrs:
pads = attrs['pads']
if len(pads) == 4:
......
@@ -587,6 +588,9 @@ class OnnxConverter(base_converter.ConverterInterface):
if "alpha" in node.attrs:
alpha_value = node.attrs["alpha"]
else:
+    if node.op_type == OnnxOpType.LeakyRelu.name:
+        alpha_value = 0.01
+    else:
alpha_value = 0
alpha_arg = op.arg.add()
......
@@ -894,7 +898,8 @@ class OnnxConverter(base_converter.ConverterInterface):
tensor.float_data[:] = tensor_data.flat
tensor.dims[:] = tensor_data.shape
-  def convert_fully_connected(self, node):
+  def convert_gemm(self, node):
+      # only supports FullyConnected Style Gemm for now.
trans_a = node.attrs['transA'] if 'transA' in node.attrs else 0
trans_b = node.attrs['transB'] if 'transB' in node.attrs else 0
shape_a = self._graph_shapes_dict[node.inputs[0]]
......
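Taken together with the kernel change, the converter updates make the ONNX LeakyRelu path consistent end to end: Gemm is now routed through convert_gemm (which, per the new comment, only handles the fully-connected style of Gemm), and a LeakyRelu node without an explicit alpha attribute now defaults to a slope of 0.01, matching the ONNX operator specification, instead of 0. With that default, an input of -3 becomes max(-3, 0) + 0.01 * min(-3, 0) = -0.03 in the corrected kernels above, whereas the previous default of 0 would have degenerated LeakyRelu into a plain ReLU.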