Commit 34f21226 authored by chenjianping

Rename CaffePReLU to PReLU

Parent 29070d60
@@ -80,7 +80,7 @@ union PrimitiveType {
     Pad,
     Maximum,
     Minimum,
-    CaffePReLU,
+    PReLU,
     LeakyReLU,
     ArgMax,
     ArgMin,
@@ -126,7 +126,6 @@ union PrimitiveType {
     Broadcast,
     BroadcastTo,
     Lrn,
-    Prelu,
     ZerosLike,
     TopK,
     SpaceToDepth,
......
@@ -540,7 +540,7 @@ table MatMul {
     transposeB : bool = false;
 }
 
-table CaffePReLU {
+table PReLU {
     channelShared : bool = false;
     slope: [float];
 }
@@ -650,10 +650,6 @@ table Reduce {
     mode: ReduceMode;
 }
 
-table Prelu {
-    slope: [float];
-}
-
 table Transpose {
     perm: [int];
     conjugate: bool = false;
......
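With the table renamed, the flatc-generated C++ helpers change name accordingly (CreateCaffePReLUDirect becomes CreatePReLUDirect, PrimitiveType_CaffePReLU becomes PrimitiveType_PReLU). A minimal sketch of building the renamed primitive, assuming the standard FlatBuffers codegen and only the generated names visible in this diff:

  // Sketch only; the schema:: names are taken from this commit's generated code.
  flatbuffers::FlatBufferBuilder fbb(1024);
  std::vector<float> slope = {0.25f};  // illustrative slope value
  auto val_offset = schema::CreatePReLUDirect(fbb, /*channelShared=*/false, &slope);
  auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_PReLU, val_offset.o);
  fbb.Finish(prim_offset);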
@@ -14,20 +14,20 @@
  * limitations under the License.
  */
-#include "src/ops/caffe_p_relu.h"
+#include "src/ops/p_relu.h"
 
 namespace mindspore {
 namespace lite {
 #ifdef PRIMITIVE_WRITEABLE
-bool CaffePReLU::GetChannelShared() const { return this->primitive_->value.AsCaffePReLU()->channelShared; }
+bool PReLU::GetChannelShared() const { return this->primitive_->value.AsPReLU()->channelShared; }
 
-void CaffePReLU::SetChannelShared(bool channel_shared) {
-  this->primitive_->value.AsCaffePReLU()->channelShared = channel_shared;
+void PReLU::SetChannelShared(bool channel_shared) {
+  this->primitive_->value.AsPReLU()->channelShared = channel_shared;
 }
 #else
-bool CaffePReLU::GetChannelShared() const { return this->primitive_->value_as_CaffePReLU()->channelShared(); }
+bool PReLU::GetChannelShared() const { return this->primitive_->value_as_PReLU()->channelShared(); }
 #endif
 } // namespace lite
......
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#ifndef LITE_MINDSPORE_LITE_C_OPS_CAFFE_P_RE_L_U_H_
-#define LITE_MINDSPORE_LITE_C_OPS_CAFFE_P_RE_L_U_H_
+#ifndef LITE_MINDSPORE_LITE_C_OPS_P_RELU_H_
+#define LITE_MINDSPORE_LITE_C_OPS_P_RELU_H_
 
 #include <vector>
 #include <set>
@@ -26,21 +26,21 @@
 namespace mindspore {
 namespace lite {
-class CaffePReLU : public Activation {
+class PReLU : public Activation {
  public:
 #ifdef PRIMITIVE_WRITEABLE
-  MS_DECLARE_PARENT(CaffePReLU, Activation);
-  CaffePReLU() = default;
-  explicit CaffePReLU(schema::PrimitiveT *primitive) : Activation(primitive) {}
+  MS_DECLARE_PARENT(PReLU, Activation);
+  PReLU() = default;
+  explicit PReLU(schema::PrimitiveT *primitive) : Activation(primitive) {}
   void SetChannelShared(bool channel_shared);
 #else
-  explicit CaffePReLU(schema::Primitive *primitive) : Activation(primitive) {}
+  explicit PReLU(schema::Primitive *primitive) : Activation(primitive) {}
 
   schema::Primitive *Init(schema::Primitive *primitive) {
     flatbuffers::FlatBufferBuilder fbb(1024);
-    auto attr = primitive->value_as_CaffePReLU();
+    auto attr = primitive->value_as_PReLU();
     MS_ASSERT(attr != nullptr);
 
     auto slope = std::make_unique<std::vector<float>>();
@@ -48,8 +48,8 @@ class CaffePReLU : public Activation {
       slope->push_back(attr->slope()->data()[i]);
     }
-    auto val_offset = schema::CreateCaffePReLUDirect(fbb, attr->channelShared(), slope.release());
-    auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_CaffePReLU, val_offset.o);
+    auto val_offset = schema::CreatePReLUDirect(fbb, attr->channelShared(), slope.release());
+    auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_PReLU, val_offset.o);
     fbb.Finish(prim_offset);
 
     auto buf = fbb.GetBufferPointer();
@@ -70,4 +70,4 @@ class CaffePReLU : public Activation {
 } // namespace lite
 } // namespace mindspore
-#endif  // LITE_MINDSPORE_LITE_C_OPS_CAFFE_P_RE_L_U_H_
+#endif  // LITE_MINDSPORE_LITE_C_OPS_P_RELU_H_
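For reference, a hedged usage sketch of the renamed class in the writeable build; the surrounding schema::PrimitiveT setup is assumed, and only names from this diff are used:

  // Assumes primitive_t is a schema::PrimitiveT* whose value holds a PReLUT.
  auto *prelu = new (std::nothrow) PReLU(primitive_t);
  if (prelu != nullptr) {
    prelu->SetChannelShared(true);
    bool shared = prelu->GetChannelShared();  // reads value.AsPReLU()->channelShared
    (void)shared;
  }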
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/prelu.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
std::vector<float> Prelu::GetSlope() const { return this->primitive_->value.AsPrelu()->slope; }
void Prelu::SetSlope(const std::vector<float> &slope) { this->primitive_->value.AsPrelu()->slope = slope; }
#else
std::vector<float> Prelu::GetSlope() const {
auto fb_vector = this->primitive_->value_as_Prelu()->slope();
return std::vector<float>(fb_vector->begin(), fb_vector->end());
}
#endif
} // namespace lite
} // namespace mindspore
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_PRELU_H_
#define LITE_MINDSPORE_LITE_C_OPS_PRELU_H_
#include <vector>
#include <set>
#include <cmath>
#include <memory>
#include "ir/dtype/type_id.h"
#include "src/ops/activation.h"
namespace mindspore {
namespace lite {
class Prelu : public Activation {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Prelu, PrimitiveC);
Prelu() = default;
explicit Prelu(schema::PrimitiveT *primitive) : Activation(primitive) {}
void SetSlope(const std::vector<float> &slope);
#else
explicit Prelu(schema::Primitive *primitive) : Activation(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_Prelu();
MS_ASSERT(attr != nullptr);
auto slope = std::make_unique<std::vector<float>>();
for (int i = 0; i < static_cast<int>(attr->slope()->size()); i++) {
slope->push_back(attr->slope()->data()[i]);
}
auto val_offset = schema::CreatePreluDirect(fbb, slope.release());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Prelu, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
delete[] buf_bak;
fbb.Clear();
return prim;
}
#endif
std::vector<float> GetSlope() const;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_PRELU_H_
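The removed Prelu op carried a slope vector and overlapped with both activations that survive this commit: the renamed PReLU (channel-wise learned slopes) and LeakyReLU (one scalar slope). An illustrative element-wise sketch of the distinction, not code from this repository:

  // Per-channel PReLU vs single-slope LeakyReLU (illustrative only).
  float PReluOne(float x, float channel_slope) { return x >= 0.0f ? x : channel_slope * x; }
  float LeakyReluOne(float x, float alpha) { return x >= 0.0f ? x : alpha * x; }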
@@ -72,8 +72,8 @@
 #include "src/ops/gather_nd.h"
 #include "src/ops/local_response_normalization.h"
 #include "src/ops/pad.h"
-#include "src/ops/prelu.h"
-#include "src/ops/caffe_p_relu.h"
+#include "src/ops/p_relu.h"
+#include "src/ops/leaky_relu.h"
 #include "src/ops/reverse_sequence.h"
 #include "src/ops/dedepthwise_conv2d.h"
 #include "src/ops/depthwise_conv2d.h"
@@ -346,10 +346,10 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitiveT(mindspore::schema::PrimitiveT
       return new Minimum(primitive);
     case schema::PrimitiveType_StridedSlice:
       return new StridedSlice(primitive);
-    case schema::PrimitiveType_Prelu:
-      return new Prelu(primitive);
-    case schema::PrimitiveType_CaffePReLU:
-      return new CaffePReLU(primitive);
+    case schema::PrimitiveType_LeakyReLU:
+      return new (std::nothrow) LeakyReLU(primitive);
+    case schema::PrimitiveType_PReLU:
+      return new (std::nothrow) PReLU(primitive);
     case schema::PrimitiveType_Round:
       return new Round(primitive);
     case schema::PrimitiveType_Reverse:
@@ -554,10 +554,10 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive *
       return new Minimum(const_cast<schema::Primitive *>(primitive));
     case schema::PrimitiveType_StridedSlice:
       return new StridedSlice(const_cast<schema::Primitive *>(primitive));
-    case schema::PrimitiveType_Prelu:
-      return new Prelu(const_cast<schema::Primitive *>(primitive));
-    case schema::PrimitiveType_CaffePReLU:
-      return new CaffePReLU(const_cast<schema::Primitive *>(primitive));
+    case schema::PrimitiveType_LeakyReLU:
+      return new (std::nothrow) LeakyReLU(const_cast<schema::Primitive *>(primitive));
+    case schema::PrimitiveType_PReLU:
+      return new (std::nothrow) PReLU(const_cast<schema::Primitive *>(primitive));
     case schema::PrimitiveType_Round:
       return new Round(const_cast<schema::Primitive *>(primitive));
     case schema::PrimitiveType_Reverse:
......
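The unpack paths also switch to new (std::nothrow), so an allocation failure yields nullptr instead of throwing. A hedged sketch of the caller-side check this enables; the callers themselves are not shown in this diff:

  // Assumed caller pattern; the null check is the point of std::nothrow here.
  auto *op = new (std::nothrow) PReLU(primitive);
  if (op == nullptr) {
    return nullptr;  // propagate the allocation failure instead of crashing
  }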
@@ -75,8 +75,8 @@
 #include "src/ops/gather_nd.h"
 #include "src/ops/local_response_normalization.h"
 #include "src/ops/pad.h"
-#include "src/ops/prelu.h"
-#include "src/ops/caffe_p_relu.h"
+#include "src/ops/leaky_relu.h"
+#include "src/ops/p_relu.h"
 #include "src/ops/reverse_sequence.h"
 #include "src/ops/dedepthwise_conv2d.h"
 #include "src/ops/depthwise_conv2d.h"
@@ -233,7 +233,7 @@ OpParameter *PopulateExpandDimsParameter(const mindspore::lite::PrimitiveC *prim
 }
 
 OpParameter *PopulatePReLUParameter(const mindspore::lite::PrimitiveC *primitive) {
-  auto param = dynamic_cast<const mindspore::lite::CaffePReLU *>(primitive);
+  auto param = dynamic_cast<const mindspore::lite::PReLU *>(primitive);
   PReluParameter *prelu_param = reinterpret_cast<PReluParameter *>(malloc(sizeof(PReluParameter)));
   if (prelu_param == nullptr) {
     MS_LOG(ERROR) << "malloc PReluParameter failed.";
@@ -246,7 +246,7 @@ OpParameter *PopulatePReLUParameter(const mindspore::lite::PrimitiveC *primitive
 }
 
 OpParameter *PopulateLeakyReluParameter(const mindspore::lite::PrimitiveC *primitive) {
-  auto param = dynamic_cast<const mindspore::lite::Prelu *>(primitive);
+  auto param = dynamic_cast<const mindspore::lite::LeakyReLU *>(primitive);
   LeakyReluParameter *leaky_relu_param = reinterpret_cast<LeakyReluParameter *>(malloc(sizeof(LeakyReluParameter)));
   if (leaky_relu_param == nullptr) {
     MS_LOG(ERROR) << "malloc LeakyReluParameter failed.";
@@ -254,17 +254,14 @@ OpParameter *PopulateLeakyReluParameter(const mindspore::lite::PrimitiveC *primi
   }
   memset(leaky_relu_param, 0, sizeof(LeakyReluParameter));
   leaky_relu_param->op_parameter_.type_ = primitive->Type();
-  auto temp = param->GetSlope();
-  leaky_relu_param->slope_ = reinterpret_cast<float *>(malloc(temp.size() * sizeof(float)));
+  leaky_relu_param->slope_ = reinterpret_cast<float *>(malloc(sizeof(float)));
   if (leaky_relu_param->slope_ == nullptr) {
     MS_LOG(ERROR) << "malloc relu slope fail!";
     free(leaky_relu_param);
     return nullptr;
   }
-  for (size_t i = 0; i < temp.size(); i++) {
-    leaky_relu_param->slope_[i] = temp[i];
-  }
-  leaky_relu_param->slope_num_ = temp.size();
+  leaky_relu_param->slope_[0] = param->GetNegativeSlope();
+  leaky_relu_param->slope_num_ = 1;
   return reinterpret_cast<OpParameter *>(leaky_relu_param);
 }
@@ -1598,8 +1595,8 @@ PopulateParameterRegistry::PopulateParameterRegistry() {
   populate_parameter_funcs_[schema::PrimitiveType_ScatterND] = PopulateScatterNDParameter;
   populate_parameter_funcs_[schema::PrimitiveType_Squeeze] = PopulateSqueezeParameter;
   populate_parameter_funcs_[schema::PrimitiveType_Split] = PopulateSplitParameter;
-  populate_parameter_funcs_[schema::PrimitiveType_CaffePReLU] = PopulatePReLUParameter;
-  populate_parameter_funcs_[schema::PrimitiveType_Prelu] = PopulateLeakyReluParameter;
+  populate_parameter_funcs_[schema::PrimitiveType_PReLU] = PopulatePReLUParameter;
+  populate_parameter_funcs_[schema::PrimitiveType_LeakyReLU] = PopulateLeakyReluParameter;
   populate_parameter_funcs_[schema::PrimitiveType_PriorBox] = PopulatePriorBoxParameter;
   populate_parameter_funcs_[schema::PrimitiveType_QuantDTypeCast] = PopulateQuantDTypeCastParameter;
   populate_parameter_funcs_[schema::PrimitiveType_Lstm] = PopulateLstmParameter;
......
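PopulateLeakyReluParameter now allocates exactly one float and fills it from GetNegativeSlope(), so slope_num_ is always 1. Whoever frees the OpParameter must also free slope_; a minimal teardown sketch, assuming the usual malloc/free pairing (this hypothetical helper is not part of the diff):

  // Hypothetical helper mirroring the allocations above.
  void FreeLeakyReluParameter(LeakyReluParameter *param) {
    if (param == nullptr) return;
    free(param->slope_);  // allocated with malloc(sizeof(float)) above
    free(param);
  }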
@@ -29,7 +29,7 @@ using mindspore::schema::PrimitiveType_LeakyReLU;
 namespace mindspore::kernel {
 int LeakyReluBaseCPUKernel::Init() { return RET_OK; }
 
-kernel::LiteKernel *CpuPreluInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
-                                              const std::vector<lite::tensor::Tensor *> &outputs,
-                                              OpParameter *opParameter, const Context *ctx,
-                                              const kernel::KernelKey &desc,
+kernel::LiteKernel *CpuLeakyReluInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
+                                                  const std::vector<lite::tensor::Tensor *> &outputs,
+                                                  OpParameter *opParameter, const Context *ctx,
+                                                  const kernel::KernelKey &desc,
@@ -41,7 +41,7 @@ kernel::LiteKernel *CpuPreluInt8KernelCreator(const std::vector<lite::tensor::Te
   MS_ASSERT(desc.type == schema::PrimitiveType_LeakyRelu);
   auto *kernel = new (std::nothrow) LeakyReluInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
-    MS_LOG(ERROR) << "new PreluCPUKernel fail!";
+    MS_LOG(ERROR) << "new LeakyReluInt8CPUKernel fail!";
     return nullptr;
   }
   auto ret = kernel->Init();
@@ -54,5 +54,5 @@ kernel::LiteKernel *CpuPreluInt8KernelCreator(const std::vector<lite::tensor::Te
   return kernel;
 }
 
-REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LeakyReLU, CpuPreluInt8KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LeakyReLU, CpuLeakyReluInt8KernelCreator)
 } // namespace mindspore::kernel
@@ -38,7 +38,10 @@ int GatherCPUKernel::Init() {
 }
 
 GatherCPUKernel::~GatherCPUKernel() {
-  context_->allocator->Free(indices_data_);
+  if (indices_data_ != nullptr) {
+    free(indices_data_);
+    indices_data_ = nullptr;
+  }
 }
 
 int GatherCPUKernel::ReSize() { return RET_OK; }
@@ -102,7 +105,7 @@ int GatherCPUKernel::Run() {
   }
   auto indices_tensor = in_tensors_.at(1);
-  indices_data_ = reinterpret_cast<int *>(context_->allocator->Malloc(indices_tensor->Size()));
+  indices_data_ = reinterpret_cast<int *>(malloc(indices_tensor->Size()));
   if (indices_data_ == nullptr) {
     MS_LOG(ERROR) << "Memory allocation failed";
     return RET_ERROR;
......
@@ -36,7 +36,7 @@ class GatherCPUKernel : public LiteKernel {
   int DoGather(int task_id);
 
  private:
-  int *indices_data_;
+  int *indices_data_ = nullptr;
 };
 } // namespace mindspore::kernel
......
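The Gather changes pair the allocator with the deallocator: indices_data_ now comes from malloc in Run() and is released with free in the destructor, and the in-class initializer keeps the destructor safe when Run() never executed. The idiom, sketched outside this kernel as illustration only:

  // Illustrative ownership pattern, not repository code.
  struct Buf {
    int *data_ = nullptr;  // default-init: destructor is safe before first use
    ~Buf() {
      if (data_ != nullptr) {
        free(data_);
        data_ = nullptr;
      }
    }
  };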
@@ -26,7 +26,6 @@ using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_LeakyReLU;
-using mindspore::schema::PrimitiveType_Prelu;
 
 namespace mindspore::kernel {
 namespace {
@@ -100,5 +99,4 @@ kernel::LiteKernel *CpuLeakyReluFp32KernelCreator(const std::vector<lite::tensor
 }
 
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LeakyReLU, CpuLeakyReluFp32KernelCreator)
-REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Prelu, CpuLeakyReluFp32KernelCreator)
 } // namespace mindspore::kernel
@@ -24,7 +24,7 @@ using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
-using mindspore::schema::PrimitiveType_CaffePReLU;
+using mindspore::schema::PrimitiveType_PReLU;
 
 namespace mindspore::kernel {
 namespace {
@@ -155,7 +155,7 @@ kernel::LiteKernel *CpuPReluFp32KernelCreator(const std::vector<lite::tensor::Te
     MS_LOG(ERROR) << "input param is nullptr!";
     return nullptr;
   }
-  MS_ASSERT(desc.type == schema::PrimitiveType_Prelu);
   auto *kernel = new (std::nothrow) PReluCPUKernel(param, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new PReluCPUKernel fail!";
@@ -171,5 +171,5 @@ kernel::LiteKernel *CpuPReluFp32KernelCreator(const std::vector<lite::tensor::Te
   return kernel;
 }
 
-REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_CaffePReLU, CpuPReluFp32KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_PReLU, CpuPReluFp32KernelCreator)
 } // namespace mindspore::kernel
@@ -25,9 +25,20 @@ using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
-using mindspore::schema::PrimitiveType_Prelu;
 
 namespace mindspore::kernel {
+namespace {
+int LeakyReluInt8Run(void *cdata, int task_id) {
+  if (cdata == nullptr) {
+    MS_LOG(ERROR) << "input cdata is nullptr!";
+    return RET_ERROR;
+  }
+  auto relu = reinterpret_cast<LeakyReluInt8CPUKernel *>(cdata);
+  relu->DoExecute(task_id);
+  return RET_OK;
+}
+}  // namespace
+
 int LeakyReluInt8CPUKernel::Init() {
   LeakyReluBaseCPUKernel::Init();
   LeakyReluParameter *param = reinterpret_cast<LeakyReluParameter *>(op_parameter_);
@@ -82,17 +93,12 @@ int LeakyReluInt8CPUKernel::Run() {
     MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
     return ret;
   }
-  ret = ParallelLaunch(THREAD_POOL_DEFAULT, PreluInt8Run, this, op_parameter_->thread_num_);
+  ret = ParallelLaunch(THREAD_POOL_DEFAULT, LeakyReluInt8Run, this, op_parameter_->thread_num_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "RunPreluParam failed. errorcode: ";
   }
   return RET_OK;
 }
 
-int PreluInt8Run(void *cdata, int task_id) {
-  auto prelu = reinterpret_cast<LeakyReluInt8CPUKernel *>(cdata);
-  prelu->DoExecute(task_id);
-  return RET_OK;
-}
 int LeakyReluInt8CPUKernel::DoExecute(int task_id) {
   auto input_tensor = in_tensors_.at(kInputIndex);
......
@@ -41,7 +41,6 @@ class LeakyReluInt8CPUKernel : public LeakyReluBaseCPUKernel {
 
  private:
   LeakyReluQuantArg quant_prelu_parm_;
 };
-int PreluInt8Run(void *cdata, int task_id);
 } // namespace mindspore::kernel
 
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_PRELU_INT8_H_
@@ -29,7 +29,7 @@ using mindspore::kernel::KERNEL_ARCH::kGPU;
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
-using mindspore::schema::PrimitiveType_Prelu;
+using mindspore::schema::PrimitiveType_PReLU;
 
 namespace mindspore::kernel {
@@ -154,5 +154,5 @@ kernel::LiteKernel *OpenCLPReluKernelCreator(const std::vector<lite::tensor::Ten
   return kernel;
 }
 
-REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Prelu, OpenCLPReluKernelCreator)
+REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_PReLU, OpenCLPReluKernelCreator)
 } // namespace mindspore::kernel
@@ -65,14 +65,14 @@ TEST_F(TestPreluInt8, prelu_1) {
   outputs_tensor[0] = output0_tensor;
 
   LeakyReluQuantArg op_param;
-  op_param.op_parameter_.type_ = schema::PrimitiveType_Prelu;
+  op_param.op_parameter_.type_ = schema::PrimitiveType_LeakyReLU;
   op_param.slope_ = reinterpret_cast<float *>(malloc(sizeof(float)));
   op_param.slope_[0] = 0.25;
   lite::Context *ctx = new lite::Context;
   ctx->thread_num_ = 2;
   op_param.axis_ = 0.25;
-  kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Prelu};
+  kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_LeakyReLU};
   auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
   ASSERT_NE(creator, nullptr);
   kernel::LiteKernel *kernel =
......
@@ -119,15 +119,6 @@ TEST_F(TestTfliteParserPrelu, OpType) {
   ASSERT_NE(meta_graph, nullptr);
   ASSERT_GT(meta_graph->nodes.size(), 0);
   ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
-  ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Prelu) << "wrong Op Type";
-}
-
-TEST_F(TestTfliteParserPrelu, AttrValue) {
-  ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPrelu(), nullptr);
-  auto val = meta_graph->nodes.front()->primitive->value;
-  std::vector<float> slope(20, 0);
-  ASSERT_EQ(val.AsPrelu()->slope, slope);
-  ASSERT_EQ(val.type, schema::PrimitiveType_Prelu);
 }
 
 class TestTfliteParserLeakyRelu : public TestTfliteParser {
......
@@ -29,7 +29,7 @@ static const std::vector<schema::PrimitiveType> nhwcOpList = {
     schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D,
     schema::PrimitiveType_Pooling,         schema::PrimitiveType_Resize,
     schema::PrimitiveType_BatchNorm,       schema::PrimitiveType_FusedBatchNorm,
-    schema::PrimitiveType_CaffePReLU};
+    schema::PrimitiveType_PReLU};
 
 static const std::vector<schema::PrimitiveType> fp32FullOpList = {
     schema::PrimitiveType_Concat, schema::PrimitiveType_Add,
......
@@ -34,7 +34,7 @@ STATUS CaffePReluParser::Parse(const caffe::LayerParameter &proto,
     return RET_NULL_PTR;
   }
 
-  std::unique_ptr<schema::CaffePReLUT> attr = std::make_unique<schema::CaffePReLUT>();
+  std::unique_ptr<schema::PReLUT> attr = std::make_unique<schema::PReLUT>();
   if (attr == nullptr) {
     MS_LOG(ERROR) << "new op failed";
     return RET_NULL_PTR;
@@ -60,7 +60,7 @@ STATUS CaffePReluParser::Parse(const caffe::LayerParameter &proto,
   weightVec->push_back(slope);
 
   op->name = proto.name();
-  op->primitive->value.type = schema::PrimitiveType_CaffePReLU;
+  op->primitive->value.type = schema::PrimitiveType_PReLU;
   op->primitive->value.value = attr.release();
   return RET_OK;
 }
......
@@ -73,7 +73,7 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
     MS_LOG(ERROR) << "input num should be 2";
     return RET_ERROR;
   }
-  std::unique_ptr<schema::CaffePReLUT> attr = std::make_unique<schema::CaffePReLUT>();
+  std::unique_ptr<schema::PReLUT> attr = std::make_unique<schema::PReLUT>();
 
   std::vector<onnx::TensorProto> params;
   const auto &input_name = onnx_node.input(1);
   for (const auto &it : onnx_graph.initializer()) {
@@ -102,7 +102,7 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
     }
   }
 
-  op->primitive->value.type = schema::PrimitiveType_CaffePReLU;
+  op->primitive->value.type = schema::PrimitiveType_PReLU;
   op->primitive->value.value = attr.release();
   return RET_OK;
 }
......
@@ -84,52 +84,11 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
   return RET_OK;
 }
 
-STATUS TflitePreluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
-                                const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
-                                const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
-                                schema::CNodeT *op,
-                                std::vector<int32_t> *tensors_id,
-                                std::vector<schema::Format> *tensors_format,
-                                std::map<int, int> *tensors_id_map) {
-  MS_LOG(DEBUG) << "parse TflitePreluParser";
-  if (op == nullptr) {
-    MS_LOG(ERROR) << "op is null";
-    return RET_NULL_PTR;
-  }
-  op->primitive = std::make_unique<schema::PrimitiveT>();
-  if (op->primitive == nullptr) {
-    MS_LOG(ERROR) << "op->primitive is null";
-    return RET_NULL_PTR;
-  }
-
-  std::unique_ptr<schema::PreluT> attr = std::make_unique<schema::PreluT>();
-  if (attr == nullptr) {
-    MS_LOG(ERROR) << "new op failed";
-    return RET_NULL_PTR;
-  }
-
-  if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->slope)) {
-    MS_LOG(ERROR) << "get pRelu -> slope failed";
-    return RET_ERROR;
-  }
-
-  op->primitive->value.type = schema::PrimitiveType_Prelu;
-  op->primitive->value.value = attr.release();
-
-  AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
-             tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
-  AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
-              tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
-  return RET_OK;
-}
-
 TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser());
 TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser());
 TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser());
 TfliteNodeRegister g_TfliteHardSwishParser("HardSwish", new TfliteHardSwishParser());
 TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser());
-TfliteNodeRegister g_tflitePreluParser("Prelu", new TflitePreluParser());
 TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser());
 } // namespace lite
 } // namespace mindspore
@@ -68,18 +68,6 @@ class TfliteLeakyReluParser : public TfliteActivationParser {
   TfliteLeakyReluParser() : TfliteActivationParser() {}
 };
 
-class TflitePreluParser : public TfliteNodeParser {
- public:
-  TflitePreluParser() : TfliteNodeParser("Prelu") {}
-
-  STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
-               const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
-               const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
-               schema::CNodeT *op,
-               std::vector<int32_t> *tensors_id,
-               std::vector<schema::Format> *tensors_format,
-               std::map<int, int> *tensors_id_map) override;
-};
-
 } // namespace lite
 } // namespace mindspore
......
@@ -107,7 +107,6 @@ std::map<tflite::BuiltinOperator, std::string> tfMsOpTypeMap{
     {tflite::BuiltinOperator_DEPTH_TO_SPACE, "DepthToSpace"},
     {tflite::BuiltinOperator_SPACE_TO_BATCH_ND, "SpaceToBatchND"},
     {tflite::BuiltinOperator_SPACE_TO_DEPTH, "SpaceToDepth"},
-    {tflite::BuiltinOperator_PRELU, "Prelu"},
    {tflite::BuiltinOperator_ROUND, "Round"},
     {tflite::BuiltinOperator_WHERE, "Where"},
     {tflite::BuiltinOperator_SPARSE_TO_DENSE, "SparseToDense"},
......