From 899e26f49e58496b440788d123f7bab40274e1ab Mon Sep 17 00:00:00 2001 From: liuqi Date: Tue, 6 Mar 2018 17:27:23 +0800 Subject: [PATCH] Fix prelu benchmark bug. --- mace/kernels/activation.h | 3 ++- mace/ops/activation_benchmark.cc | 7 +++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/mace/kernels/activation.h b/mace/kernels/activation.h index b768eb28..c9f8dac5 100644 --- a/mace/kernels/activation.h +++ b/mace/kernels/activation.h @@ -116,7 +116,8 @@ class ActivationFunctor { const T *input_ptr = input->data(); T *output_ptr = output->mutable_data(); if (activation_ == PRELU) { - const T *alpha_ptr = alpha == nullptr ? nullptr : alpha->data(); + MACE_CHECK(alpha != nullptr) << "PReLU's alpha parameter shouldn't be null"; + const T *alpha_ptr = alpha->data(); PReLUActivation(input_ptr, output->size(), input->dim(3), alpha_ptr, output_ptr); } else { DoActivation(input_ptr, output_ptr, output->size(), activation_, relux_max_limit_); diff --git a/mace/ops/activation_benchmark.cc b/mace/ops/activation_benchmark.cc index 1037bdcb..4e904fce 100644 --- a/mace/ops/activation_benchmark.cc +++ b/mace/ops/activation_benchmark.cc @@ -139,23 +139,26 @@ static void PreluBenchmark( // Add input data net.AddRandomInput("Input", {batch, height, width, channels}); + net.AddRandomInput("Alpha", {channels}); if (D == DeviceType::OPENCL) { BufferToImage(net, "Input", "InputImage", kernels::BufferType::IN_OUT_CHANNEL); + BufferToImage(net, "Alpha", "AlphaImage", + kernels::BufferType::ARGUMENT); OpDefBuilder("Activation", "PreluBM") .Input("InputImage") + .Input("AlphaImage") .Output("Output") .AddStringArg("activation", "PRELU") - .AddFloatArg("alpha", 2.0) .Finalize(net.NewOperatorDef()); } else { OpDefBuilder("Activation", "PreluBM") .Input("Input") + .Input("Alpha") .Output("Output") .AddStringArg("activation", "PRELU") - .AddFloatArg("alpha", 2.0) .Finalize(net.NewOperatorDef()); } -- GitLab