Commit 7dbe9f70 authored by mindspore-ci-bot, committed by Gitee

!4986 optimize prelu

Merge pull request !4986 from fuzhiye/tmp
@@ -14,24 +14,200 @@
  * limitations under the License.
  */
 #include "nnacl/fp32/prelu.h"
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
 
-void DoPRelu(float *input, float *output, PReluParameter *prelu_param_, int task_id) {
-  int block = (int)(prelu_param_->input_num_ / prelu_param_->op_parameter_.thread_num_);
-  int start = task_id * block;
-  int end = start + block;
-  if (task_id == prelu_param_->op_parameter_.thread_num_ - 1) {
-    end = prelu_param_->input_num_;
-  }
-  for (int i = start; i < end; i++) {
-    if (input[i] > 0) {
-      output[i] = input[i];
-    } else {
-      if (!prelu_param_->channelShared) {
-        int temp = i % prelu_param_->channel_num_;
-        output[i] = input[i] * prelu_param_->slope_[temp];
-      } else {
-        output[i] = input[i] * prelu_param_->slope_[0];
-      }
-    }
-  }
-}
+void PRelu(float *input, float *output, PReluParameter *prelu_param_, int task_id) {
+  float *negetive_slope_value = prelu_param_->slope_;
+  int c4 = prelu_param_->channel_num_ / C4NUM;
+  int channel_num = prelu_param_->channel_num_;
+  for (int j = task_id; j < prelu_param_->tile_block_; j += prelu_param_->op_parameter_.thread_num_) {
+    float *input_ptr = input + j * TILE_NUM * channel_num;
+    float *output_ptr = input_ptr;
+#ifdef ENABLE_NEON
+    for (int i = 0; i < c4; i++) {
+      int c_offset = i * C4NUM;
+      float32x4_t slope_value = vld1q_f32(negetive_slope_value + c_offset);
+      float32x4_t v1 = vld1q_f32(input_ptr + c_offset);
+      float32x4_t v2 = vld1q_f32(input_ptr + c_offset + channel_num);
+      float32x4_t v3 = vld1q_f32(input_ptr + c_offset + 2 * channel_num);
+      float32x4_t v4 = vld1q_f32(input_ptr + c_offset + 3 * channel_num);
+      float32x4_t v5 = vld1q_f32(input_ptr + c_offset + 4 * channel_num);
+      float32x4_t v6 = vld1q_f32(input_ptr + c_offset + 5 * channel_num);
+      float32x4_t v7 = vld1q_f32(input_ptr + c_offset + 6 * channel_num);
+      float32x4_t v8 = vld1q_f32(input_ptr + c_offset + 7 * channel_num);
+      float32x4_t t1 = vmulq_f32(v1, slope_value);
+      float32x4_t t2 = vmulq_f32(v2, slope_value);
+      float32x4_t t3 = vmulq_f32(v3, slope_value);
+      float32x4_t t4 = vmulq_f32(v4, slope_value);
+      float32x4_t t5 = vmulq_f32(v5, slope_value);
+      float32x4_t t6 = vmulq_f32(v6, slope_value);
+      float32x4_t t7 = vmulq_f32(v7, slope_value);
+      float32x4_t t8 = vmulq_f32(v8, slope_value);
+      uint32x4_t flag1 = vclezq_f32(v1);
+      uint32x4_t flag2 = vclezq_f32(v2);
+      uint32x4_t flag3 = vclezq_f32(v3);
+      uint32x4_t flag4 = vclezq_f32(v4);
+      uint32x4_t flag5 = vclezq_f32(v5);
+      uint32x4_t flag6 = vclezq_f32(v6);
+      uint32x4_t flag7 = vclezq_f32(v7);
+      uint32x4_t flag8 = vclezq_f32(v8);
+      float32x4_t r1 = vbslq_f32(flag1, t1, v1);
+      float32x4_t r2 = vbslq_f32(flag2, t2, v2);
+      float32x4_t r3 = vbslq_f32(flag3, t3, v3);
+      float32x4_t r4 = vbslq_f32(flag4, t4, v4);
+      float32x4_t r5 = vbslq_f32(flag5, t5, v5);
+      float32x4_t r6 = vbslq_f32(flag6, t6, v6);
+      float32x4_t r7 = vbslq_f32(flag7, t7, v7);
+      float32x4_t r8 = vbslq_f32(flag8, t8, v8);
+      vst1q_f32(output_ptr + c_offset, r1);
+      vst1q_f32(output_ptr + c_offset + channel_num, r2);
+      vst1q_f32(output_ptr + c_offset + 2 * channel_num, r3);
+      vst1q_f32(output_ptr + c_offset + 3 * channel_num, r4);
+      vst1q_f32(output_ptr + c_offset + 4 * channel_num, r5);
+      vst1q_f32(output_ptr + c_offset + 5 * channel_num, r6);
+      vst1q_f32(output_ptr + c_offset + 6 * channel_num, r7);
+      vst1q_f32(output_ptr + c_offset + 7 * channel_num, r8);
+    }  // c4 -1 loop
+#else
+    for (int i = 0; i < TILE_NUM; ++i) {
+      int tile_offset = i * channel_num;
+      for (int k = 0; k < c4; ++k) {
+        int c4_offset = tile_offset + k * C4NUM;
+        int slope_offset = k * C4NUM;
+        for (int l = 0; l < C4NUM; ++l) {
+          float in_data = input_ptr[c4_offset + l];
+          output_ptr[c4_offset + l] =
+            (in_data < 0 ? in_data : 0) * negetive_slope_value[slope_offset + l] + (in_data > 0 ? in_data : 0);
+        }
+      }
+    }  // c4 - 1 loop
+#endif
+    int c_s = c4 * C4NUM;
+    for (int m = 0; m < TILE_NUM; ++m) {
+      int offset = m * channel_num;
+      for (int k = c_s; k < channel_num; ++k) {
+        int c4_offset = offset + k;
+        float in_data = input_ptr[c4_offset];
+        if (in_data >= 0) {
+          output_ptr[c4_offset] = in_data;
+        } else {
+          output_ptr[c4_offset] = in_data * negetive_slope_value[k];
+        }
+      }
+    }  // res loop
+  }
+}
+
+void PReluShareChannel(float *input, float *output, PReluParameter *prelu_param_, int task_id) {
+  for (int j = task_id; j < prelu_param_->tile_block_; j += prelu_param_->op_parameter_.thread_num_) {
+    int cal_index;
+    int cal_per_time;
+#ifdef ENABLE_NEON
+    float32x4_t slope_value = vdupq_n_f32(prelu_param_->slope_[0]);
+    float32x4_t zero_value = vdupq_n_f32(0);
+#endif
+#ifdef ENABLE_ARM64
+    cal_index = j * 64;
+    cal_per_time = 64;
+#elif ENABLE_ARM32
+    cal_index = j * 32;
+    cal_per_time = 32;
+#else
+    cal_index = j * 32;
+    cal_per_time = 32;
+#endif
+    float *input_ptr = input + cal_index;
+    float *output_ptr = input + cal_index;
+#ifdef ENABLE_ARM64
+    float32x4_t v1 = vld1q_f32(input_ptr);
+    float32x4_t v2 = vld1q_f32(input_ptr + 4);
+    float32x4_t v3 = vld1q_f32(input_ptr + 8);
+    float32x4_t v4 = vld1q_f32(input_ptr + 12);
+    float32x4_t v5 = vld1q_f32(input_ptr + 16);
+    float32x4_t v6 = vld1q_f32(input_ptr + 20);
+    float32x4_t v7 = vld1q_f32(input_ptr + 24);
+    float32x4_t v8 = vld1q_f32(input_ptr + 28);
+    float32x4_t v9 = vld1q_f32(input_ptr + 32);
+    float32x4_t v10 = vld1q_f32(input_ptr + 36);
+    float32x4_t v11 = vld1q_f32(input_ptr + 40);
+    float32x4_t v12 = vld1q_f32(input_ptr + 44);
+    float32x4_t v13 = vld1q_f32(input_ptr + 48);
+    float32x4_t v14 = vld1q_f32(input_ptr + 52);
+    float32x4_t v15 = vld1q_f32(input_ptr + 56);
+    float32x4_t v16 = vld1q_f32(input_ptr + 60);
+    float32x4_t t1 = vaddq_f32(vmulq_f32(vminq_f32(v1, zero_value), slope_value), vmaxq_f32(v1, zero_value));
+    float32x4_t t2 = vaddq_f32(vmulq_f32(vminq_f32(v2, zero_value), slope_value), vmaxq_f32(v2, zero_value));
+    float32x4_t t3 = vaddq_f32(vmulq_f32(vminq_f32(v3, zero_value), slope_value), vmaxq_f32(v3, zero_value));
+    float32x4_t t4 = vaddq_f32(vmulq_f32(vminq_f32(v4, zero_value), slope_value), vmaxq_f32(v4, zero_value));
+    float32x4_t t5 = vaddq_f32(vmulq_f32(vminq_f32(v5, zero_value), slope_value), vmaxq_f32(v5, zero_value));
+    float32x4_t t6 = vaddq_f32(vmulq_f32(vminq_f32(v6, zero_value), slope_value), vmaxq_f32(v6, zero_value));
+    float32x4_t t7 = vaddq_f32(vmulq_f32(vminq_f32(v7, zero_value), slope_value), vmaxq_f32(v7, zero_value));
+    float32x4_t t8 = vaddq_f32(vmulq_f32(vminq_f32(v8, zero_value), slope_value), vmaxq_f32(v8, zero_value));
+    float32x4_t t9 = vaddq_f32(vmulq_f32(vminq_f32(v9, zero_value), slope_value), vmaxq_f32(v9, zero_value));
+    float32x4_t t10 = vaddq_f32(vmulq_f32(vminq_f32(v10, zero_value), slope_value), vmaxq_f32(v10, zero_value));
+    float32x4_t t11 = vaddq_f32(vmulq_f32(vminq_f32(v11, zero_value), slope_value), vmaxq_f32(v11, zero_value));
+    float32x4_t t12 = vaddq_f32(vmulq_f32(vminq_f32(v12, zero_value), slope_value), vmaxq_f32(v12, zero_value));
+    float32x4_t t13 = vaddq_f32(vmulq_f32(vminq_f32(v13, zero_value), slope_value), vmaxq_f32(v13, zero_value));
+    float32x4_t t14 = vaddq_f32(vmulq_f32(vminq_f32(v14, zero_value), slope_value), vmaxq_f32(v14, zero_value));
+    float32x4_t t15 = vaddq_f32(vmulq_f32(vminq_f32(v15, zero_value), slope_value), vmaxq_f32(v15, zero_value));
+    float32x4_t t16 = vaddq_f32(vmulq_f32(vminq_f32(v16, zero_value), slope_value), vmaxq_f32(v16, zero_value));
+    vst1q_f32(output_ptr, t1);
+    vst1q_f32(output_ptr + 4, t2);
+    vst1q_f32(output_ptr + 8, t3);
+    vst1q_f32(output_ptr + 12, t4);
+    vst1q_f32(output_ptr + 16, t5);
+    vst1q_f32(output_ptr + 20, t6);
+    vst1q_f32(output_ptr + 24, t7);
+    vst1q_f32(output_ptr + 28, t8);
+    vst1q_f32(output_ptr + 32, t9);
+    vst1q_f32(output_ptr + 36, t10);
+    vst1q_f32(output_ptr + 40, t11);
+    vst1q_f32(output_ptr + 44, t12);
+    vst1q_f32(output_ptr + 48, t13);
+    vst1q_f32(output_ptr + 52, t14);
+    vst1q_f32(output_ptr + 56, t15);
+    vst1q_f32(output_ptr + 60, t16);
+#elif ENABLE_ARM32
+    float32x4_t v1 = vld1q_f32(input_ptr);
+    float32x4_t v2 = vld1q_f32(input_ptr + 4);
+    float32x4_t v3 = vld1q_f32(input_ptr + 8);
+    float32x4_t v4 = vld1q_f32(input_ptr + 12);
+    float32x4_t v5 = vld1q_f32(input_ptr + 16);
+    float32x4_t v6 = vld1q_f32(input_ptr + 20);
+    float32x4_t v7 = vld1q_f32(input_ptr + 24);
+    float32x4_t v8 = vld1q_f32(input_ptr + 28);
+    float32x4_t t1 = vaddq_f32(vmulq_f32(vminq_f32(v1, zero_value), slope_value), vmaxq_f32(v1, zero_value));
+    float32x4_t t2 = vaddq_f32(vmulq_f32(vminq_f32(v2, zero_value), slope_value), vmaxq_f32(v2, zero_value));
+    float32x4_t t3 = vaddq_f32(vmulq_f32(vminq_f32(v3, zero_value), slope_value), vmaxq_f32(v3, zero_value));
+    float32x4_t t4 = vaddq_f32(vmulq_f32(vminq_f32(v4, zero_value), slope_value), vmaxq_f32(v4, zero_value));
+    float32x4_t t5 = vaddq_f32(vmulq_f32(vminq_f32(v5, zero_value), slope_value), vmaxq_f32(v5, zero_value));
+    float32x4_t t6 = vaddq_f32(vmulq_f32(vminq_f32(v6, zero_value), slope_value), vmaxq_f32(v6, zero_value));
+    float32x4_t t7 = vaddq_f32(vmulq_f32(vminq_f32(v7, zero_value), slope_value), vmaxq_f32(v7, zero_value));
+    float32x4_t t8 = vaddq_f32(vmulq_f32(vminq_f32(v8, zero_value), slope_value), vmaxq_f32(v8, zero_value));
+    vst1q_f32(output_ptr, t1);
+    vst1q_f32(output_ptr + 4, t2);
+    vst1q_f32(output_ptr + 8, t3);
+    vst1q_f32(output_ptr + 12, t4);
+    vst1q_f32(output_ptr + 16, t5);
+    vst1q_f32(output_ptr + 20, t6);
+    vst1q_f32(output_ptr + 24, t7);
+    vst1q_f32(output_ptr + 28, t8);
+#else
+    for (int i = 0; i < cal_per_time; ++i) {
+      float data = input_ptr[i];
+      output_ptr[i] = (data < 0 ? data : 0) * prelu_param_->slope_[0] + (data > 0 ? data : 0);
+    }
+#endif
+  }
+}
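Both new kernels compute the same element-wise function: the per-channel path selects slope * x wherever x <= 0 via vclezq_f32/vbslq_f32, and the shared-slope path uses the equivalent min(x, 0) * slope + max(x, 0) form. For reference only (not part of the patch), a minimal scalar sketch of that function for a channel-last buffer:

// Scalar PReLU reference: out = x when x > 0, slope[c] * x otherwise.
// Illustration only; it ignores the TILE_NUM/C4NUM tiling used by the nnacl kernels.
static void PReluScalarReference(const float *in, float *out, const float *slope, int plane, int channel) {
  for (int p = 0; p < plane; ++p) {
    for (int c = 0; c < channel; ++c) {
      float x = in[p * channel + c];
      out[p * channel + c] = (x > 0) ? x : slope[c] * x;
    }
  }
}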
@@ -22,7 +22,9 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
-void DoPRelu(float *input, float *output, PReluParameter *prelu_param_, int task_id);
+void PRelu(float *input, float *output, PReluParameter *prelu_param_, int task_id);
+void PReluShareChannel(float *input, float *output, PReluParameter *prelu_param_, int task_id);
 #ifdef __cplusplus
 }
 #endif
@@ -52,7 +52,7 @@ void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const in
     int plane_c4_res = b % C4NUM;
     int src_plane_offset = src_tile_offset + plane_c4_block * tile_num * C4NUM * ic4 * C4NUM + plane_c4_res * C4NUM;
     int weight_plane_offset =
-      weight_oc4_offset + plane_c4_block * tile_num * C4NUM * ic4 * C4NUM + plane_c4_res * C4NUM;
+      weight_oc4_offset + plane_c4_block * C4NUM * C4NUM * ic4 * C4NUM + plane_c4_res * C4NUM;
     for (int i = 0; i < ic4; i++) {
       int src_ic4_offset = src_plane_offset + i * tile_num * C4NUM * C4NUM;
       int weight_ic4_offset = weight_plane_offset + i * C4NUM * C4NUM * C4NUM;
@@ -22,6 +22,7 @@ typedef struct PReluParameter {
   OpParameter op_parameter_;
   float *slope_;
   bool channelShared;
+  int tile_block_;
   int channel_num_;
   int input_num_;
 } PReluParameter;
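The new tile_block_ field is what the strided thread loops above iterate over (for (j = task_id; j < tile_block_; j += thread_num)). A small self-contained sketch of that bookkeeping, assuming UP_DIV(x, y) is the usual round-up division helper and taking TILE_NUM as 8 purely for the arithmetic:

#include <stdio.h>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y)) /* assumed round-up division */
#define TILE_NUM 8                           /* illustrative tile height */

int main(void) {
  int input_plane = 1 * 13 * 13;                  /* all dims except the channel dim */
  int tile_block = UP_DIV(input_plane, TILE_NUM); /* 169 elements -> 22 tiles of 8 rows */
  int thread_num = 2;
  for (int task_id = 0; task_id < thread_num; ++task_id) {
    int tiles = 0;
    for (int j = task_id; j < tile_block; j += thread_num) ++tiles;
    printf("thread %d handles %d tiles\n", task_id, tiles);
  }
  return 0;
}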
@@ -29,8 +29,8 @@ using mindspore::schema::PrimitiveType_CaffePReLU;
 namespace mindspore::kernel {
 namespace {
 int PReluRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
-  auto PReludata = reinterpret_cast<PReluCPUKernel *>(cdata);
-  auto ret = PReludata->DoExcute(task_id);
+  auto PRelu = reinterpret_cast<PReluCPUKernel *>(cdata);
+  auto ret = PRelu->DoExcute(task_id);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "PReluRun error task_id[" << task_id << "] error_code[" << ret << "]";
     return RET_ERROR;
@@ -42,7 +42,67 @@ int PReluRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 int PReluCPUKernel::Init() { return RET_OK; }
 
 int PReluCPUKernel::DoExcute(int task_id) {
-  DoPRelu(input_data, output_data, prelu_param_, task_id);
+  if (prelu_param_->channelShared) {
+    PReluShareChannel(input_data_, output_data_, prelu_param_, task_id);
+  } else {
+    PRelu(input_data_, output_data_, prelu_param_, task_id);
+  }
+  return RET_OK;
+}
+
+int PReluCPUKernel::ProcessInput() {
+  // input tensor
+  auto input_tensor = in_tensors_[0];
+  auto in_shape = input_tensor->shape();
+  auto n_dim = in_shape.size();
+  auto channel_num = in_shape.at(n_dim - 1);
+  int input_plane = 1;
+  for (size_t i = 0; i < n_dim - 1; ++i) {
+    input_plane *= in_shape[i];
+  }
+  int tile_block = UP_DIV(input_plane, TILE_NUM);
+  prelu_param_->input_num_ = input_tensor->ElementsNum();
+  prelu_param_->tile_block_ = tile_block;
+  prelu_param_->channel_num_ = channel_num;
+  input_data_ =
+    reinterpret_cast<float *>(context_->allocator->Malloc(tile_block * TILE_NUM * channel_num * sizeof(float)));
+  if (input_data_ == nullptr) {
+    MS_LOG(ERROR) << "malloc input_data_ failed.";
+    return RET_ERROR;
+  }
+  memcpy(input_data_, ori_input_, tile_block * TILE_NUM * channel_num * sizeof(float));
+  return RET_OK;
+}
+
+int PReluCPUKernel::ProcessShareChannelInput() {
+  // input tensor
+  auto input_tensor = in_tensors_[0];
+  prelu_param_->input_num_ = input_tensor->ElementsNum();
+#ifdef ENABLE_ARM64
+  prelu_param_->tile_block_ = UP_DIV(prelu_param_->input_num_, 64);
+  input_data_ = reinterpret_cast<float *>(context_->allocator->Malloc(prelu_param_->tile_block_ * 64 * sizeof(float)));
+  if (input_data_ == nullptr) {
+    MS_LOG(ERROR) << "malloc input_data_ failed.";
+    return RET_ERROR;
+  }
+  memcpy(input_data_, ori_input_, prelu_param_->tile_block_ * 64 * sizeof(float));
+#elif ENABLE_ARM32
+  prelu_param_->tile_block_ = UP_DIV(prelu_param_->input_num_, 32);
+  input_data_ = reinterpret_cast<float *>(context_->allocator->Malloc(prelu_param_->tile_block_ * 32 * sizeof(float)));
+  if (input_data_ == nullptr) {
+    MS_LOG(ERROR) << "malloc input_data_ failed.";
+    return RET_ERROR;
+  }
+  memcpy(input_data_, ori_input_, prelu_param_->tile_block_ * 32 * sizeof(float));
+#else
+  prelu_param_->tile_block_ = UP_DIV(prelu_param_->input_num_, 32);
+  input_data_ = reinterpret_cast<float *>(context_->allocator->Malloc(prelu_param_->tile_block_ * 32 * sizeof(float)));
+  if (input_data_ == nullptr) {
+    MS_LOG(ERROR) << "malloc input_data_ failed.";
+    return RET_ERROR;
+  }
+  memcpy(input_data_, ori_input_, prelu_param_->tile_block_ * 32 * sizeof(float));
+#endif
   return RET_OK;
 }
@@ -52,28 +112,44 @@ int PReluCPUKernel::Run() {
     MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
     return prepare_ret;
   }
-  auto input = in_tensors_[0];
-  auto input1 = in_tensors_[1];
-  prelu_param_->input_num_ = input->ElementsNum();
-  input_data = reinterpret_cast<float *>(input->Data());
-  output_data = reinterpret_cast<float *>(out_tensors_[0]->Data());
-  auto channels = input->shape();
-  prelu_param_->slope_ = reinterpret_cast<float *>(input1->Data());
-  prelu_param_->channel_num_ = channels.at(channels.size() - 1);
+  MS_ASSERT(in_shape.size() >= 2);
+  auto input_tensor = in_tensors_[0];
+  ori_input_ = reinterpret_cast<float *>(input_tensor->Data());
+  output_data_ = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->Data());
+  if (prelu_param_->channelShared) {
+    auto ret = ProcessShareChannelInput();
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "ProcessShareChannel failed.";
+      return ret;
+    }
+  } else {
+    auto ret = ProcessInput();
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "Process failed.";
+      return ret;
+    }
+  }
+  // negative slope tensor
+  auto negative_slope_tensor = in_tensors_.at(1);
+  prelu_param_->slope_ = reinterpret_cast<float *>(negative_slope_tensor->Data());
   auto ret = LiteBackendParallelLaunch(PReluRun, this, prelu_param_->op_parameter_.thread_num_);
   if (ret != RET_OK) {
-    MS_LOG(ERROR) << "PReluDwRun error: error_code[" << ret << "]";
+    MS_LOG(ERROR) << "PRelu Run error: error_code[" << ret << "]";
+    context_->allocator->Free(input_data_);
     return RET_ERROR;
   }
+  memcpy(output_data_, input_data_, prelu_param_->input_num_ * sizeof(float));
+  context_->allocator->Free(input_data_);
   return RET_OK;
 }
 
 kernel::LiteKernel *CpuPReluFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
-                                              const std::vector<lite::tensor::Tensor *> &outputs,
-                                              OpParameter *param, const lite::Context *ctx,
-                                              const kernel::KernelKey &desc,
+                                              const std::vector<lite::tensor::Tensor *> &outputs, OpParameter *param,
+                                              const lite::Context *ctx, const kernel::KernelKey &desc,
                                               const mindspore::lite::PrimitiveC *primitive) {
   if (param == nullptr) {
     MS_LOG(ERROR) << "input param is nullptr!";
@@ -87,8 +163,8 @@ kernel::LiteKernel *CpuPReluFp32KernelCreator(const std::vector<lite::tensor::Te
   }
   auto ret = kernel->Init();
   if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Init kernel failed, name: " << param->name_ << ", type: "
-                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(param->type_));
+    MS_LOG(ERROR) << "Init kernel failed, name: " << param->name_
+                  << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(param->type_));
     delete kernel;
     return nullptr;
   }
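In the shared-slope path, ProcessShareChannelInput above pads the working copy of the input up to a whole number of per-iteration blocks (64 floats per loop pass on ARM64, 32 otherwise) so PReluShareChannel never touches a partial block. A worked sketch of that sizing, with the same assumed UP_DIV semantics as in the earlier example:

#include <stdio.h>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y)) /* assumed round-up division */

int main(void) {
  int input_num = 100; /* element count of the input tensor */
  int block = 64;      /* ARM64 block size; 32 on the other paths */
  int tile_block = UP_DIV(input_num, block);
  printf("%d elements -> %d blocks, %d floats allocated\n", input_num, tile_block, tile_block * block);
  return 0;
}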
@@ -25,8 +25,8 @@ namespace mindspore::kernel {
 class PReluCPUKernel : public LiteKernel {
  public:
   PReluCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                  const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
                  const mindspore::lite::PrimitiveC *primitive)
       : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
     prelu_param_ = reinterpret_cast<PReluParameter *>(op_parameter_);
   }
@@ -36,11 +36,14 @@ class PReluCPUKernel : public LiteKernel {
   int ReSize() override { return 0; }
   int Run() override;
   int DoExcute(int task_id);
+  int ProcessShareChannelInput();
+  int ProcessInput();
 
  private:
   PReluParameter *prelu_param_;
-  float *input_data = nullptr;
-  float *output_data = nullptr;
+  float *ori_input_ = nullptr;
+  float *input_data_ = nullptr;
+  float *output_data_ = nullptr;
 };
 }  // namespace mindspore::kernel
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_PRELU_H_