提交 07a75658 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4995 [MS][LITE][Develop]stack support int32

Merge pull request !4995 from chenjianping/lite_dev2
......@@ -17,7 +17,7 @@
#include "nnacl/fp32/stack.h"
#include "nnacl/arithmetic_common.h"
void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, float *output) {
size_t GetStackCopyNum(int axis, int *in_shape, size_t shape_size) {
size_t one_input_size = 1;
for (size_t i = 0; i < shape_size; ++i) {
one_input_size *= in_shape[i];
......@@ -26,11 +26,37 @@ void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t
ComputeStrides(in_shape, in_strides, shape_size);
size_t copy_num = axis > 0 ? in_strides[axis - 1] : one_input_size;
size_t copy_size = copy_num * sizeof(float);
return copy_num;
}
// Returns the product of the dimensions of in_shape that precede the stack
// axis, i.e. how many independent "outer" blocks DoStack must interleave.
// in_shape: dimension sizes of one input tensor (only the first `axis`
//           entries are read).
// axis:     stack axis; axis <= 0 yields 1 (no leading dimensions).
size_t GetStackPreAxisCount(const int *in_shape, int axis) {
  size_t pre_axis_count = 1;
  // Use a signed index: `axis` is int, so a size_t counter caused a
  // signed/unsigned comparison (and a near-infinite loop for negative axis).
  for (int i = 0; i < axis; ++i) {
    pre_axis_count *= in_shape[i];
  }
  return pre_axis_count;
}
// Stacks `input_num` float tensors of identical shape `in_shape` along `axis`,
// writing the interleaved result contiguously into `output`.
void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, float *output) {
  size_t chunk = GetStackCopyNum(axis, in_shape, shape_size);
  size_t outer = GetStackPreAxisCount(in_shape, axis);
  size_t chunk_bytes = chunk * sizeof(float);
  float *dst = output;
  // For each outer block, copy that block's chunk from every input in turn.
  for (size_t block = 0; block < outer; ++block) {
    size_t src_offset = block * chunk;
    for (size_t idx = 0; idx < input_num; ++idx) {
      memcpy(dst, inputs[idx] + src_offset, chunk_bytes);
      dst += chunk;
    }
  }
}
void DoStackInt32(const int32_t *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis,
int32_t *output) {
size_t copy_num = GetStackCopyNum(axis, in_shape, shape_size);
size_t copy_size = copy_num * sizeof(int32_t);
size_t pre_axis_count = GetStackPreAxisCount(in_shape, axis);
size_t in_offset = 0;
size_t out_offset = 0;
for (size_t i = 0; i < pre_axis_count; ++i) {
......
......@@ -27,6 +27,8 @@ typedef struct StackParameter {
extern "C" {
#endif
void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, float *output);
void DoStackInt32(const int32_t *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis,
int32_t *output);
#ifdef __cplusplus
}
#endif
......
......@@ -56,7 +56,8 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
return RET_PARAM_INVALID;
}
auto input = inputs.at(0);
outputs[0]->set_data_type(input->data_type());
auto input0_data_type = input->data_type();
outputs[0]->set_data_type(input0_data_type);
outputs[0]->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
......@@ -69,12 +70,8 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
MS_LOG(ERROR) << "Invalid axis " << GetAxis();
return RET_PARAM_INVALID;
}
schema::Format input0_format = input->GetFormat();
for (size_t i = 1; i < inputs.size(); ++i) {
if (inputs[i]->GetFormat() != input0_format) {
MS_LOG(ERROR) << "All inputs should have the same format!";
return RET_PARAM_INVALID;
}
auto input_shape_tmp = inputs[i]->shape();
if (input_shape_tmp.size() != input_shape.size()) {
MS_LOG(ERROR) << "All input shape size should be the same!";
......@@ -86,6 +83,11 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
return RET_PARAM_INVALID;
}
}
if (inputs[i]->data_type() != input0_data_type) {
MS_LOG(ERROR) << "All input shuld have the same data type!input[" << i << "] data type = "
<< inputs[i]->data_type();
return RET_PARAM_INVALID;
}
}
output_shape.insert(output_shape.begin() + axis, inputs.size());
outputs[0]->set_shape(output_shape);
......
......@@ -49,12 +49,21 @@ int StackCPUKernel::Run() {
}
size_t inputs_num = in_tensors_.size();
auto input0_shape = in_tensors_[0]->shape();
auto *output_data = reinterpret_cast<float *>(out_tensors_[0]->Data());
float *inputs[inputs_num];
for (size_t i = 0; i < inputs_num; ++i) {
inputs[i] = reinterpret_cast<float *>(in_tensors_[i]->Data());
if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) {
auto *output_data = reinterpret_cast<float *>(out_tensors_[0]->Data());
float *inputs[inputs_num];
for (size_t i = 0; i < inputs_num; ++i) {
inputs[i] = reinterpret_cast<float *>(in_tensors_[i]->Data());
}
DoStack(inputs, inputs_num, input0_shape.data(), input0_shape.size(), axis_, output_data);
} else {
auto *output_data = reinterpret_cast<int32_t *>(out_tensors_[0]->Data());
int32_t *inputs[inputs_num];
for (size_t i = 0; i < inputs_num; ++i) {
inputs[i] = reinterpret_cast<int32_t *>(in_tensors_[i]->Data());
}
DoStackInt32(inputs, inputs_num, input0_shape.data(), input0_shape.size(), axis_, output_data);
}
DoStack(inputs, inputs_num, input0_shape.data(), input0_shape.size(), axis_, output_data);
return RET_OK;
}
......@@ -85,4 +94,5 @@ kernel::LiteKernel *CpuStackFp32KernelCreator(const std::vector<lite::tensor::Te
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Stack, CpuStackFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Stack, CpuStackFp32KernelCreator)
} // namespace mindspore::kernel
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册