提交 58e6227a 编写于 作者: Z zhongligeng

fix space to depth bug

上级 201bcdd9
......@@ -77,6 +77,8 @@ int SpaceToDepthRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
}
int SpaceToDepthCPUKernel::Run() {
input_ptr_ = reinterpret_cast<float *>(inputs_[0]->Data());
output_ptr_ = reinterpret_cast<float *>(outputs_[0]->Data());
if (inputs_[0]->GetFormat() == schema::Format_NHWC) {
int ret = LiteBackendParallelLaunch(SpaceToDepthRun, this, thread_h_num_);
if (ret != RET_OK) {
......
......@@ -158,6 +158,8 @@ TEST_F(SpaceToBatchTestFp32, SpaceToBatchTest3) {
}
std::cout << "\n";
CompareOutputData(output.data(), expect_out, out_size, 0.000001);
input_tensor.SetData(nullptr);
output_tensor.SetData(nullptr);
}
} // namespace mindspore
......@@ -56,7 +56,7 @@ TEST_F(SpaceToDepthTestFp32, SpaceToDepthTest2) {
input_tensor.SetFormat(schema::Format_NHWC);
input_tensor.set_data_type(kNumberTypeFloat32);
std::vector<lite::tensor::Tensor *> inputs_tensor;
inputs_tensor.emplace_back(&input_tensor);
inputs_tensor.push_back(&input_tensor);
const int out_size = 16;
float expect_out[16] = {1, 2, 10, 20, 5, 6, 3, 8, 18, 10, 11, 55, 3, 4, 15, 25};
......@@ -68,7 +68,7 @@ TEST_F(SpaceToDepthTestFp32, SpaceToDepthTest2) {
output_tensor.SetFormat(schema::Format_NHWC);
output_tensor.set_data_type(kNumberTypeFloat32);
std::vector<lite::tensor::Tensor *> outputs_tensor;
outputs_tensor.emplace_back(&output_tensor);
outputs_tensor.push_back(&output_tensor);
SpaceToDepthParameter op_param;
op_param.op_parameter_.type_ = schema::PrimitiveType_SpaceToDepth;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册