提交 82e8884e 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4662 [MS][LITE]add fp16 reshape kernel and fix register kernel bug

Merge pull request !4662 from 张学同/to_merge
......@@ -82,6 +82,11 @@ kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) {
return nullptr;
}
int index = GetCreatorFuncIndex(desc);
if (index >= array_size_) {
MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type "
<< desc.type;
return nullptr;
}
auto it = creator_arrays_[index];
if (it != nullptr) {
return it;
......@@ -91,9 +96,9 @@ kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) {
int KernelRegistry::GetCreatorFuncIndex(const kernel::KernelKey desc) {
int index;
int device_index = static_cast<int>(desc.arch);
int dType_index = static_cast<int>(desc.data_type);
int op_index = static_cast<int>(desc.type);
int device_index = static_cast<int>(desc.arch) - kKernelArch_MIN;
int dType_index = static_cast<int>(desc.data_type) - kNumberTypeBegin;
int op_index = static_cast<int>(desc.type) - PrimitiveType_MIN;
index = device_index * data_type_length_ * op_type_length_ + dType_index * op_type_length_ + op_index;
return index;
}
......@@ -115,6 +120,11 @@ void KernelRegistry::RegKernel(const KERNEL_ARCH arch, const TypeId data_type, c
}
KernelKey desc = {arch, data_type, op_type};
int index = GetCreatorFuncIndex(desc);
if (index >= array_size_) {
MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type "
<< desc.type;
return;
}
creator_arrays_[index] = creator;
}
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp16/reshape_fp16.h"
#include <vector>
#include "nnacl/fp16/cast_fp16.h"
#include "nnacl/reshape.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Reshape;
namespace mindspore::kernel {
int ReshapeCPUKernel::Run() {
auto ret = Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
return ret;
}
auto in_tensor = in_tensors_.at(kInputIndex);
auto out_tensor = out_tensors_.at(kOutputIndex);
auto input_ptr = in_tensor->Data();
auto output_ptr = out_tensor->Data();
size_t data_size = out_tensor->Size();
auto in_datatype = in_tensor->data_type();
auto out_datatype = out_tensor->data_type();
if (in_datatype != out_datatype) {
if (in_datatype == kNumberTypeFloat32 && out_datatype == kNumberTypeFloat16) {
input_ptr = context_->allocator->Malloc(in_tensor->ElementsNum() * sizeof(float16_t));
if (input_ptr == nullptr) {
MS_LOG(ERROR) << "malloc in tensor fail!";
return mindspore::lite::RET_MEMORY_FAILED;
}
Float32ToFloat16(reinterpret_cast<float *>(in_tensor->Data()), reinterpret_cast<float16_t *>(input_ptr),
in_tensor->ElementsNum());
} else if ((in_datatype == kNumberTypeFloat16 && out_datatype == kNumberTypeFloat32)) {
input_ptr = context_->allocator->Malloc(in_tensor->ElementsNum() * sizeof(float));
if (input_ptr == nullptr) {
MS_LOG(ERROR) << "malloc in tensor fail!";
return mindspore::lite::RET_MEMORY_FAILED;
}
Float16ToFloat32(reinterpret_cast<float16_t *>(in_tensor->Data()), reinterpret_cast<float *>(input_ptr),
in_tensor->ElementsNum());
} else {
MS_LOG(ERROR) << "unsupported data type, in_datatype: " << in_datatype << ",out_datatype: " << out_datatype;
return RET_ERROR;
}
}
Reshape(input_ptr, output_ptr, data_size);
if (in_datatype != out_datatype) {
context_->allocator->Free(input_ptr);
}
return RET_OK;
} // namespace mindspore::kernel
} // namespace mindspore::kernel
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_RESHAPE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_RESHAPE_H_
#include <vector>
#include "src/lite_kernel.h"
#include "include/context.h"
#include "src/runtime/kernel/arm/fp32/reshape.h"
using mindspore::lite::Context;
namespace mindspore::kernel {
class ReshapeFp16CPUKernel : public ReshapeCPUKernel {
public:
ReshapeFp16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
const mindspore::lite::PrimitiveC *primitive)
: ReshapeCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~ReshapeFp16CPUKernel() = default;
int Run() override;
private:
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_RESHAPE_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册