未验证 提交 7b453631 编写于 作者: Zhen Wang 提交者: GitHub

Merge pull request #16855 from wzzju/fix_quantize_op

Fix hang bugs in memory copying. test=develop
...@@ -235,11 +235,13 @@ struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, T> { ...@@ -235,11 +235,13 @@ struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, T> {
int g_find_max; int g_find_max;
memory::Copy(platform::CPUPlace(), &g_find_max, gpu_place, find_max, memory::Copy(platform::CPUPlace(), &g_find_max, gpu_place, find_max,
sizeof(int), 0); sizeof(int), ctx.stream());
ctx.Wait();
if (g_find_max) { if (g_find_max) {
int len; int len;
memory::Copy(platform::CPUPlace(), &len, gpu_place, out_size_data, memory::Copy(platform::CPUPlace(), &len, gpu_place, out_size_data,
sizeof(int), 0); sizeof(int), ctx.stream());
ctx.Wait();
FindAbsMaxFunctor<platform::CUDADeviceContext, T>()(ctx, scale_arr, len, FindAbsMaxFunctor<platform::CUDADeviceContext, T>()(ctx, scale_arr, len,
out_scale_data); out_scale_data);
} }
...@@ -258,25 +260,26 @@ struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> { ...@@ -258,25 +260,26 @@ struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> {
const auto gpu_place = boost::get<platform::CUDAPlace>(ctx.GetPlace()); const auto gpu_place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
T accum; T accum;
memory::Copy(platform::CPUPlace(), &accum, gpu_place, in_accum.data<T>(),
sizeof(T), 0);
T state; T state;
memory::Copy(platform::CPUPlace(), &state, gpu_place, in_state.data<T>(),
sizeof(T), 0);
T scale; T scale;
memory::Copy(platform::CPUPlace(), &accum, gpu_place, in_accum.data<T>(),
sizeof(T), ctx.stream());
memory::Copy(platform::CPUPlace(), &state, gpu_place, in_state.data<T>(),
sizeof(T), ctx.stream());
memory::Copy(platform::CPUPlace(), &scale, gpu_place, cur_scale, sizeof(T), memory::Copy(platform::CPUPlace(), &scale, gpu_place, cur_scale, sizeof(T),
0); ctx.stream());
ctx.Wait();
state = rate * state + 1; state = rate * state + 1;
accum = rate * accum + scale; accum = rate * accum + scale;
scale = accum / state; scale = accum / state;
memory::Copy(gpu_place, out_accum->mutable_data<T>(gpu_place), memory::Copy(gpu_place, out_accum->mutable_data<T>(gpu_place),
platform::CPUPlace(), &accum, sizeof(T), 0); platform::CPUPlace(), &accum, sizeof(T), ctx.stream());
memory::Copy(gpu_place, out_state->mutable_data<T>(gpu_place), memory::Copy(gpu_place, out_state->mutable_data<T>(gpu_place),
platform::CPUPlace(), &state, sizeof(T), 0); platform::CPUPlace(), &state, sizeof(T), ctx.stream());
memory::Copy(gpu_place, out_scale->mutable_data<T>(gpu_place), memory::Copy(gpu_place, out_scale->mutable_data<T>(gpu_place),
platform::CPUPlace(), &scale, sizeof(T), 0); platform::CPUPlace(), &scale, sizeof(T), ctx.stream());
ctx.Wait();
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册