提交 7ad87bee 编写于 作者: L liuqi

Fix activation opencl tuning key bug.

上级 5501056d
......@@ -135,6 +135,7 @@ class ActivationFunctor<DeviceType::OPENCL, T> {
T relux_max_limit_;
T prelu_alpha_;
cl::Kernel kernel_;
std::string tuning_key_prefix_;
};
} // namespace kernels
......
......@@ -22,7 +22,6 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
const index_t channels = input->dim(3);
const index_t channel_blocks = RoundUpDiv4(channels);
std::string tuning_key_prefix;
if (kernel_.get() == nullptr) {
auto runtime = OpenCLRuntime::Global();
......@@ -35,23 +34,23 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
switch (activation_) {
case RELU:
tuning_key_prefix = "relu_opencl_kernel_";
tuning_key_prefix_ = "relu_opencl_kernel_";
built_options.emplace("-DUSE_RELU");
break;
case RELUX:
tuning_key_prefix = "relux_opencl_kernel_";
tuning_key_prefix_ = "relux_opencl_kernel_";
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
tuning_key_prefix = "prelu_opencl_kernel_";
tuning_key_prefix_ = "prelu_opencl_kernel_";
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
tuning_key_prefix = "tanh_opencl_kernel_";
tuning_key_prefix_ = "tanh_opencl_kernel_";
built_options.emplace("-DUSE_TANH");
break;
case SIGMOID:
tuning_key_prefix = "sigmoid_opencl_kernel_";
tuning_key_prefix_ = "sigmoid_opencl_kernel_";
built_options.emplace("-DUSE_SIGMOID");
break;
default:
......@@ -60,12 +59,10 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
kernel_ =
runtime->BuildKernel("activation", kernel_name, built_options);
int idx = 0;
kernel_.setArg(
idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
kernel_.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
kernel_.setArg(idx++, static_cast<float>(relux_max_limit_));
kernel_.setArg(idx++, static_cast<float>(prelu_alpha_));
kernel_.setArg(idx++,
*(static_cast<cl::Image2D *>(output->buffer())));
kernel_.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
}
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
......@@ -73,7 +70,7 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
static_cast<uint32_t>(height * batch)};
const std::vector<uint32_t> lws = {8, 16, 8, 1};
std::string tuning_key =
Concat(tuning_key_prefix, output->dim(0), output->dim(1),
Concat(tuning_key_prefix_, output->dim(0), output->dim(1),
output->dim(2), output->dim(3));
TuningOrRun3DKernel(kernel_, tuning_key, gws, lws, future);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册