Commit fff4d5fd authored by liuqi

Fix the bug where the activation parameters' types do not match.

Parent 08533627
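For context on the first change below: cl::Kernel::setArg is templated on the C++ type of its argument and forwards sizeof(T) bytes to clSetKernelArg, so the host-side value must have exactly the type the kernel parameter declares. relux_max_limit_ and prelu_alpha_ were previously passed as-is; when they are not plain float (for example, when stored as the functor's template type T), the argument size disagrees with the kernel's float parameter and the call fails at runtime, typically with CL_INVALID_ARG_SIZE. A minimal sketch of the failure mode and the fix, assuming a kernel that declares the parameter as float (illustrative code, not MACE's):

    #include <CL/cl2.hpp>

    // Hypothetical kernel signature (not MACE's actual kernel):
    //   __kernel void activation(..., const float relux_max_limit, ...)
    void SetLimitArg(cl::Kernel &kernel, cl_uint idx, double relux_max_limit) {
      // Wrong: setArg deduces double and passes 8 bytes for a 4-byte
      // float parameter; clSetKernelArg reports CL_INVALID_ARG_SIZE.
      // kernel.setArg(idx, relux_max_limit);

      // Right: cast on the host so the argument size matches the kernel.
      kernel.setArg(idx, static_cast<float>(relux_max_limit));
    }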
@@ -25,6 +25,7 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
   auto runtime = OpenCLRuntime::Global();
+  std::string tuning_key_prefix;
   std::set<std::string> built_options;
   std::string kernel_name = MACE_OBFUSCATE_SYMBOL("activation");
   built_options.emplace("-Dactivation=" + kernel_name);
@@ -33,18 +34,23 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
   built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
   switch (activation_) {
     case RELU:
+      tuning_key_prefix = "relu_opencl_kernel_";
      built_options.emplace("-DUSE_RELU");
       break;
     case RELUX:
+      tuning_key_prefix = "relux_opencl_kernel_";
       built_options.emplace("-DUSE_RELUX");
       break;
     case PRELU:
+      tuning_key_prefix = "prelu_opencl_kernel_";
       built_options.emplace("-DUSE_PRELU");
       break;
     case TANH:
+      tuning_key_prefix = "tanh_opencl_kernel_";
       built_options.emplace("-DUSE_TANH");
       break;
     case SIGMOID:
+      tuning_key_prefix = "sigmoid_opencl_kernel_";
       built_options.emplace("-DUSE_SIGMOID");
       break;
     defeult:
@@ -55,8 +61,8 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
   int idx = 0;
   activation_kernel.setArg(
       idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
-  activation_kernel.setArg(idx++, relux_max_limit_);
-  activation_kernel.setArg(idx++, prelu_alpha_);
+  activation_kernel.setArg(idx++, static_cast<float>(relux_max_limit_));
+  activation_kernel.setArg(idx++, static_cast<float>(prelu_alpha_));
   activation_kernel.setArg(idx++,
                            *(static_cast<cl::Image2D *>(output->buffer())));
@@ -65,7 +71,7 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
                            static_cast<uint32_t>(height * batch)};
   const std::vector<uint32_t> lws = {8, 16, 8, 1};
   std::string tuning_key =
-      Concat("relu_opencl_kernel_", activation_, output->dim(0), output->dim(1),
+      Concat(tuning_key_prefix, output->dim(0), output->dim(1),
              output->dim(2), output->dim(3));
   TuningOrRun3DKernel(activation_kernel, tuning_key, gws, lws, future);
 }
...
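The tuning-key change in the hunks above replaces a hard-coded "relu_opencl_kernel_" prefix (plus the raw activation_ enum value) with a per-activation prefix set inside the switch, so autotuned work-group sizes are cached under a readable, distinct key for each activation type. Below is a sketch of a variadic Concat-style helper consistent with that call site; the separator behavior is a guess, and MACE's actual implementation may differ:

    #include <initializer_list>
    #include <sstream>
    #include <string>

    // Joins its arguments with '_' into one string, e.g.
    // Concat("relu_opencl_kernel", 1, 224, 224, 32)
    //   -> "relu_opencl_kernel_1_224_224_32"
    template <typename... Args>
    std::string Concat(Args... args) {
      std::ostringstream ss;
      bool first = true;
      // C++11 pack expansion: stream each argument in order.
      (void)std::initializer_list<int>{
          ((ss << (first ? "" : "_") << args, first = false), 0)...};
      return ss.str();
    }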
@@ -121,18 +121,24 @@ std::vector<index_t> CalWinogradShape(const std::vector<index_t> &shape,
 std::string DtToCLDt(const DataType dt) {
   switch (dt) {
-    case DT_FLOAT:return "float";
-    case DT_HALF:return "half";
-    default:LOG(FATAL) << "Unsupported data type";
+    case DT_FLOAT:
+      return "float";
+    case DT_HALF:
+      return "half";
+    default:
+      LOG(FATAL) << "Unsupported data type";
       return "";
   }
 }
 
 std::string DtToCLCMDDt(const DataType dt) {
   switch (dt) {
-    case DT_FLOAT:return "f";
-    case DT_HALF:return "h";
-    default:LOG(FATAL) << "Not supported data type for opencl cmd data type";
+    case DT_FLOAT:
+      return "f";
+    case DT_HALF:
+      return "h";
+    default:
+      LOG(FATAL) << "Not supported data type for opencl cmd data type";
       return "";
   }
 }
@@ -140,8 +146,10 @@ std::string DtToCLCMDDt(const DataType dt) {
 std::string DtToUpstreamCLDt(const DataType dt) {
   switch (dt) {
     case DT_FLOAT:
-    case DT_HALF:return "float";
-    default:LOG(FATAL) << "Unsupported data type";
+    case DT_HALF:
+      return "float";
+    default:
+      LOG(FATAL) << "Unsupported data type";
       return "";
   }
 }
@@ -149,8 +157,10 @@ std::string DtToUpstreamCLDt(const DataType dt) {
 std::string DtToUpstreamCLCMDDt(const DataType dt) {
   switch (dt) {
     case DT_FLOAT:
-    case DT_HALF:return "f";
-    default:LOG(FATAL) << "Not supported data type for opencl cmd data type";
+    case DT_HALF:
+      return "f";
+    default:
+      LOG(FATAL) << "Not supported data type for opencl cmd data type";
       return "";
   }
 }
...
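The helper reformatting above is behavior-preserving. For context, these four functions map a MACE DataType to tokens pasted into the OpenCL build options: DtToCLDt and DtToCLCMDDt yield the storage form ("half"/"h" for DT_HALF), while the Upstream variants always yield the float form, so kernels compute in full precision even when tensors are stored as half. A sketch of how they feed the compiler flags, mirroring the -DCMD_DATA_TYPE line in the first file's diff (the -DDATA_TYPE option name here is an assumption):

    #include <set>
    #include <string>

    std::set<std::string> built_options;
    const DataType dt = DT_HALF;
    // Storage type of the image data.
    built_options.emplace("-DDATA_TYPE=" + DtToCLDt(dt));                 // "half"
    // Compute/command suffix: promoted back to float precision.
    built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));  // "f"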