提交 71101c9c 编写于 作者: D dengkaipeng

fix input_grad not set zero. test=develop

上级 c9e0ade5
...@@ -129,6 +129,9 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -129,6 +129,9 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
const T* output_grad_data = output_grad->data<T>(); const T* output_grad_data = output_grad->data<T>();
T* input_grad_data = input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace()); T* input_grad_data = input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
math::SetConstant<platform::CUDADeviceContext, T>()(
ctx.template device_context<platform::CUDADeviceContext>(), input_grad,
static_cast<T>(0));
int pixelNum = nt * chw; int pixelNum = nt * chw;
int grid_dim = (pixelNum + 512 - 1) / 512; int grid_dim = (pixelNum + 512 - 1) / 512;
......
...@@ -88,6 +88,7 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> { ...@@ -88,6 +88,7 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> {
const T* output_grad_data = output_grad->data<T>(); const T* output_grad_data = output_grad->data<T>();
T* input_grad_data = input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace()); T* input_grad_data = input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
memset(input_grad_data, 0, input_grad->numel() * sizeof(T));
int src_it = 0; int src_it = 0;
for (int i = 0; i < output_grad->numel(); i++) { for (int i = 0; i < output_grad->numel(); i++) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册