From fff4d5fdf9c7bd34952ad6813f91bd3cb91670a8 Mon Sep 17 00:00:00 2001 From: liuqi Date: Fri, 26 Jan 2018 15:51:05 +0800 Subject: [PATCH] Fix the bug activation parameters' type not match. --- mace/kernels/opencl/activation_opencl.cc | 12 +++++++--- mace/kernels/opencl/helper.cc | 30 ++++++++++++++++-------- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/mace/kernels/opencl/activation_opencl.cc b/mace/kernels/opencl/activation_opencl.cc index 17688e9c..480d718d 100644 --- a/mace/kernels/opencl/activation_opencl.cc +++ b/mace/kernels/opencl/activation_opencl.cc @@ -25,6 +25,7 @@ void ActivationFunctor::operator()(const Tensor *input, auto runtime = OpenCLRuntime::Global(); + std::string tuning_key_prefix; std::set built_options; std::string kernel_name = MACE_OBFUSCATE_SYMBOL("activation"); built_options.emplace("-Dactivation=" + kernel_name); @@ -33,18 +34,23 @@ void ActivationFunctor::operator()(const Tensor *input, built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); switch (activation_) { case RELU: + tuning_key_prefix = "relu_opencl_kernel_"; built_options.emplace("-DUSE_RELU"); break; case RELUX: + tuning_key_prefix = "relux_opencl_kernel_"; built_options.emplace("-DUSE_RELUX"); break; case PRELU: + tuning_key_prefix = "prelu_opencl_kernel_"; built_options.emplace("-DUSE_PRELU"); break; case TANH: + tuning_key_prefix = "tanh_opencl_kernel_"; built_options.emplace("-DUSE_TANH"); break; case SIGMOID: + tuning_key_prefix = "sigmoid_opencl_kernel_"; built_options.emplace("-DUSE_SIGMOID"); break; defeult: @@ -55,8 +61,8 @@ void ActivationFunctor::operator()(const Tensor *input, int idx = 0; activation_kernel.setArg( idx++, *(static_cast(input->buffer()))); - activation_kernel.setArg(idx++, relux_max_limit_); - activation_kernel.setArg(idx++, prelu_alpha_); + activation_kernel.setArg(idx++, static_cast(relux_max_limit_)); + activation_kernel.setArg(idx++, static_cast(prelu_alpha_)); activation_kernel.setArg(idx++, *(static_cast(output->buffer()))); @@ -65,7 +71,7 @@ void ActivationFunctor::operator()(const Tensor *input, static_cast(height * batch)}; const std::vector lws = {8, 16, 8, 1}; std::string tuning_key = - Concat("relu_opencl_kernel_", activation_, output->dim(0), output->dim(1), + Concat(tuning_key_prefix, output->dim(0), output->dim(1), output->dim(2), output->dim(3)); TuningOrRun3DKernel(activation_kernel, tuning_key, gws, lws, future); } diff --git a/mace/kernels/opencl/helper.cc b/mace/kernels/opencl/helper.cc index 84e102e0..e3f91549 100644 --- a/mace/kernels/opencl/helper.cc +++ b/mace/kernels/opencl/helper.cc @@ -121,18 +121,24 @@ std::vector CalWinogradShape(const std::vector &shape, std::string DtToCLDt(const DataType dt) { switch (dt) { - case DT_FLOAT:return "float"; - case DT_HALF:return "half"; - default:LOG(FATAL) << "Unsupported data type"; + case DT_FLOAT: + return "float"; + case DT_HALF: + return "half"; + default: + LOG(FATAL) << "Unsupported data type"; return ""; } } std::string DtToCLCMDDt(const DataType dt) { switch (dt) { - case DT_FLOAT:return "f"; - case DT_HALF:return "h"; - default:LOG(FATAL) << "Not supported data type for opencl cmd data type"; + case DT_FLOAT: + return "f"; + case DT_HALF: + return "h"; + default: + LOG(FATAL) << "Not supported data type for opencl cmd data type"; return ""; } } @@ -140,8 +146,10 @@ std::string DtToCLCMDDt(const DataType dt) { std::string DtToUpstreamCLDt(const DataType dt) { switch (dt) { case DT_FLOAT: - case DT_HALF:return "float"; - default:LOG(FATAL) << "Unsupported data type"; + case DT_HALF: + return "float"; + default: + LOG(FATAL) << "Unsupported data type"; return ""; } } @@ -149,8 +157,10 @@ std::string DtToUpstreamCLDt(const DataType dt) { std::string DtToUpstreamCLCMDDt(const DataType dt) { switch (dt) { case DT_FLOAT: - case DT_HALF:return "f"; - default:LOG(FATAL) << "Not supported data type for opencl cmd data type"; + case DT_HALF: + return "f"; + default: + LOG(FATAL) << "Not supported data type for opencl cmd data type"; return ""; } } -- GitLab