未验证 提交 76766636 编写于 作者: L liu zhengxi 提交者: GitHub

fix the cuda bilinear and nearest precision, test=develop (#2426)

Fix the precision problem in the CUDA bilinear and nearest interpolation kernels that was caused by a data type conversion.
上级 9dcd9914
...@@ -28,8 +28,8 @@ inline std::vector<int> get_new_shape( ...@@ -28,8 +28,8 @@ inline std::vector<int> get_new_shape(
for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) { for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) {
auto tensor = list_new_shape_tensor[i]; auto tensor = list_new_shape_tensor[i];
lite::Tensor temp; lite::Tensor temp;
auto temp_data = temp.mutable_data<int32_t>(); auto temp_data = temp.mutable_data<float>();
auto tensor_data = tensor->data<int32_t>(); auto tensor_data = tensor->data<float>();
cudaMemcpy(temp_data, cudaMemcpy(temp_data,
tensor_data, tensor_data,
tensor->dims().production() * sizeof(float), tensor->dims().production() * sizeof(float),
......
...@@ -28,8 +28,8 @@ inline std::vector<int> get_new_shape( ...@@ -28,8 +28,8 @@ inline std::vector<int> get_new_shape(
for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) { for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) {
auto tensor = list_new_shape_tensor[i]; auto tensor = list_new_shape_tensor[i];
lite::Tensor temp; lite::Tensor temp;
auto temp_data = temp.mutable_data<int32_t>(); auto temp_data = temp.mutable_data<float>();
auto tensor_data = tensor->data<int32_t>(); auto tensor_data = tensor->data<float>();
cudaMemcpy(temp_data, cudaMemcpy(temp_data,
tensor_data, tensor_data,
tensor->dims().production() * sizeof(float), tensor->dims().production() * sizeof(float),
......
...@@ -202,6 +202,7 @@ TEST(nearest_interp, update) { ...@@ -202,6 +202,7 @@ TEST(nearest_interp, update) {
float* size_tensor1_ref_data = size_tensor_ref[1].mutable_data<float>(); float* size_tensor1_ref_data = size_tensor_ref[1].mutable_data<float>();
float* input_scale_ref_data = input_scale_ref.mutable_data<float>(); float* input_scale_ref_data = input_scale_ref.mutable_data<float>();
float* osz_ref_data = osz_ref.mutable_data<float>(); float* osz_ref_data = osz_ref.mutable_data<float>();
float* out_ref_data = out_ref.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); ++i) { for (int i = 0; i < x_cpu.numel(); ++i) {
x_cpu_data[i] = i + 5.0; x_cpu_data[i] = i + 5.0;
...@@ -247,8 +248,9 @@ TEST(nearest_interp, update) { ...@@ -247,8 +248,9 @@ TEST(nearest_interp, update) {
CopySync<TARGET(kCUDA)>( CopySync<TARGET(kCUDA)>(
out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH); out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
NearestInterpRef(&x_ref, &out_ref, false);
for (int i = 0; i < out.numel(); i++) { for (int i = 0; i < out.numel(); i++) {
LOG(INFO) << out_cpu_data[i]; EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册