提交 7ad87bee 编写于 作者: L liuqi

Fix activation opencl tuning key bug.

上级 5501056d
...@@ -135,6 +135,7 @@ class ActivationFunctor<DeviceType::OPENCL, T> { ...@@ -135,6 +135,7 @@ class ActivationFunctor<DeviceType::OPENCL, T> {
T relux_max_limit_; T relux_max_limit_;
T prelu_alpha_; T prelu_alpha_;
cl::Kernel kernel_; cl::Kernel kernel_;
std::string tuning_key_prefix_;
}; };
} // namespace kernels } // namespace kernels
......
...@@ -22,7 +22,6 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input, ...@@ -22,7 +22,6 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
const index_t channels = input->dim(3); const index_t channels = input->dim(3);
const index_t channel_blocks = RoundUpDiv4(channels); const index_t channel_blocks = RoundUpDiv4(channels);
std::string tuning_key_prefix;
if (kernel_.get() == nullptr) { if (kernel_.get() == nullptr) {
auto runtime = OpenCLRuntime::Global(); auto runtime = OpenCLRuntime::Global();
...@@ -35,23 +34,23 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input, ...@@ -35,23 +34,23 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
switch (activation_) { switch (activation_) {
case RELU: case RELU:
tuning_key_prefix = "relu_opencl_kernel_"; tuning_key_prefix_ = "relu_opencl_kernel_";
built_options.emplace("-DUSE_RELU"); built_options.emplace("-DUSE_RELU");
break; break;
case RELUX: case RELUX:
tuning_key_prefix = "relux_opencl_kernel_"; tuning_key_prefix_ = "relux_opencl_kernel_";
built_options.emplace("-DUSE_RELUX"); built_options.emplace("-DUSE_RELUX");
break; break;
case PRELU: case PRELU:
tuning_key_prefix = "prelu_opencl_kernel_"; tuning_key_prefix_ = "prelu_opencl_kernel_";
built_options.emplace("-DUSE_PRELU"); built_options.emplace("-DUSE_PRELU");
break; break;
case TANH: case TANH:
tuning_key_prefix = "tanh_opencl_kernel_"; tuning_key_prefix_ = "tanh_opencl_kernel_";
built_options.emplace("-DUSE_TANH"); built_options.emplace("-DUSE_TANH");
break; break;
case SIGMOID: case SIGMOID:
tuning_key_prefix = "sigmoid_opencl_kernel_"; tuning_key_prefix_ = "sigmoid_opencl_kernel_";
built_options.emplace("-DUSE_SIGMOID"); built_options.emplace("-DUSE_SIGMOID");
break; break;
default: default:
...@@ -60,12 +59,10 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input, ...@@ -60,12 +59,10 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
kernel_ = kernel_ =
runtime->BuildKernel("activation", kernel_name, built_options); runtime->BuildKernel("activation", kernel_name, built_options);
int idx = 0; int idx = 0;
kernel_.setArg( kernel_.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
kernel_.setArg(idx++, static_cast<float>(relux_max_limit_)); kernel_.setArg(idx++, static_cast<float>(relux_max_limit_));
kernel_.setArg(idx++, static_cast<float>(prelu_alpha_)); kernel_.setArg(idx++, static_cast<float>(prelu_alpha_));
kernel_.setArg(idx++, kernel_.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
*(static_cast<cl::Image2D *>(output->buffer())));
} }
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks), const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
...@@ -73,7 +70,7 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input, ...@@ -73,7 +70,7 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
static_cast<uint32_t>(height * batch)}; static_cast<uint32_t>(height * batch)};
const std::vector<uint32_t> lws = {8, 16, 8, 1}; const std::vector<uint32_t> lws = {8, 16, 8, 1};
std::string tuning_key = std::string tuning_key =
Concat(tuning_key_prefix, output->dim(0), output->dim(1), Concat(tuning_key_prefix_, output->dim(0), output->dim(1),
output->dim(2), output->dim(3)); output->dim(2), output->dim(3));
TuningOrRun3DKernel(kernel_, tuning_key, gws, lws, future); TuningOrRun3DKernel(kernel_, tuning_key, gws, lws, future);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册