提交 fff4d5fd 编写于 作者: L liuqi

Fix the bug activation parameters' type not match.

上级 08533627
......@@ -25,6 +25,7 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
auto runtime = OpenCLRuntime::Global();
std::string tuning_key_prefix;
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("activation");
built_options.emplace("-Dactivation=" + kernel_name);
......@@ -33,18 +34,23 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
switch (activation_) {
case RELU:
tuning_key_prefix = "relu_opencl_kernel_";
built_options.emplace("-DUSE_RELU");
break;
case RELUX:
tuning_key_prefix = "relux_opencl_kernel_";
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
tuning_key_prefix = "prelu_opencl_kernel_";
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
tuning_key_prefix = "tanh_opencl_kernel_";
built_options.emplace("-DUSE_TANH");
break;
case SIGMOID:
tuning_key_prefix = "sigmoid_opencl_kernel_";
built_options.emplace("-DUSE_SIGMOID");
break;
defeult:
......@@ -55,8 +61,8 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
int idx = 0;
activation_kernel.setArg(
idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
activation_kernel.setArg(idx++, relux_max_limit_);
activation_kernel.setArg(idx++, prelu_alpha_);
activation_kernel.setArg(idx++, static_cast<float>(relux_max_limit_));
activation_kernel.setArg(idx++, static_cast<float>(prelu_alpha_));
activation_kernel.setArg(idx++,
*(static_cast<cl::Image2D *>(output->buffer())));
......@@ -65,7 +71,7 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
static_cast<uint32_t>(height * batch)};
const std::vector<uint32_t> lws = {8, 16, 8, 1};
std::string tuning_key =
Concat("relu_opencl_kernel_", activation_, output->dim(0), output->dim(1),
Concat(tuning_key_prefix, output->dim(0), output->dim(1),
output->dim(2), output->dim(3));
TuningOrRun3DKernel(activation_kernel, tuning_key, gws, lws, future);
}
......
......@@ -121,18 +121,24 @@ std::vector<index_t> CalWinogradShape(const std::vector<index_t> &shape,
std::string DtToCLDt(const DataType dt) {
switch (dt) {
case DT_FLOAT:return "float";
case DT_HALF:return "half";
default:LOG(FATAL) << "Unsupported data type";
case DT_FLOAT:
return "float";
case DT_HALF:
return "half";
default:
LOG(FATAL) << "Unsupported data type";
return "";
}
}
std::string DtToCLCMDDt(const DataType dt) {
switch (dt) {
case DT_FLOAT:return "f";
case DT_HALF:return "h";
default:LOG(FATAL) << "Not supported data type for opencl cmd data type";
case DT_FLOAT:
return "f";
case DT_HALF:
return "h";
default:
LOG(FATAL) << "Not supported data type for opencl cmd data type";
return "";
}
}
......@@ -140,8 +146,10 @@ std::string DtToCLCMDDt(const DataType dt) {
std::string DtToUpstreamCLDt(const DataType dt) {
switch (dt) {
case DT_FLOAT:
case DT_HALF:return "float";
default:LOG(FATAL) << "Unsupported data type";
case DT_HALF:
return "float";
default:
LOG(FATAL) << "Unsupported data type";
return "";
}
}
......@@ -149,8 +157,10 @@ std::string DtToUpstreamCLDt(const DataType dt) {
std::string DtToUpstreamCLCMDDt(const DataType dt) {
switch (dt) {
case DT_FLOAT:
case DT_HALF:return "f";
default:LOG(FATAL) << "Not supported data type for opencl cmd data type";
case DT_HALF:
return "f";
default:
LOG(FATAL) << "Not supported data type for opencl cmd data type";
return "";
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册