未验证 提交 64a40442 编写于 作者: L liu zhengxi 提交者: GitHub

add double register op_data_type of pad2d and fix compile error, test=develop (#22075)

上级 7ba7acd1
......@@ -661,5 +661,8 @@ REGISTER_OPERATOR(pad2d, ops::Pad2dOp, ops::Pad2dOpMaker,
ops::Pad2dOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(pad2d_grad, ops::Pad2dOpGrad,
ops::Pad2dOpGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(pad2d, ops::Pad2dCPUKernel<float>);
REGISTER_OP_CPU_KERNEL(pad2d_grad, ops::Pad2dGradCPUKernel<float>);
REGISTER_OP_CPU_KERNEL(pad2d, ops::Pad2dCPUKernel<float>,
ops::Pad2dCPUKernel<double>, ops::Pad2dCPUKernel<int>,
ops::Pad2dCPUKernel<int64_t>);
REGISTER_OP_CPU_KERNEL(pad2d_grad, ops::Pad2dGradCPUKernel<float>,
ops::Pad2dGradCPUKernel<double>);
......@@ -215,7 +215,8 @@ __global__ void Pad2DGradReflectNCHW(const int out_size, T* d_in_data,
in_w = max(in_w, -in_w);
in_h = min(in_h, 2 * in_height - in_h - 2);
in_w = min(in_w, 2 * in_width - in_w - 2);
atomicAdd(&d_in_data[(nc * in_height + in_h) * in_width + in_w],
platform::CudaAtomicAdd(
&d_in_data[(nc * in_height + in_h) * in_width + in_w],
d_out_data[out_index]);
}
}
......@@ -240,7 +241,7 @@ __global__ void Pad2DGradReflectNHWC(const int out_size, T* d_in_data,
in_w = max(in_w, -in_w);
in_h = min(in_h, in_height * 2 - in_h - 2);
in_w = min(in_w, in_width * 2 - in_w - 2);
atomicAdd(
platform::CudaAtomicAdd(
&d_in_data[((n * in_height + in_h) * in_width + in_w) * channels + c],
d_out_data[out_index]);
}
......@@ -260,7 +261,8 @@ __global__ void Pad2DGradEdgeNCHW(const int out_size, T* d_in_data,
nc /= out_height;
const int in_h = min(in_height - 1, max(out_h - pad_top, 0));
const int in_w = min(in_width - 1, max(out_w - pad_left, 0));
atomicAdd(&d_in_data[(nc * in_height + in_h) * in_width + in_w],
platform::CudaAtomicAdd(
&d_in_data[(nc * in_height + in_h) * in_width + in_w],
d_out_data[out_index]);
}
}
......@@ -281,7 +283,7 @@ __global__ void Pad2DGradEdgeNHWC(const int out_size, T* d_in_data,
n /= out_height;
const int in_h = min(in_height - 1, max(out_h - pad_top, 0));
const int in_w = min(in_width - 1, max(out_w - pad_left, 0));
atomicAdd(
platform::CudaAtomicAdd(
&d_in_data[((n * in_height + in_h) * in_width + in_w) * channels + c],
d_out_data[out_index]);
}
......@@ -459,5 +461,8 @@ class Pad2dGradCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(pad2d, ops::Pad2dCUDAKernel<float>);
REGISTER_OP_CUDA_KERNEL(pad2d_grad, ops::Pad2dGradCUDAKernel<float>);
REGISTER_OP_CUDA_KERNEL(pad2d, ops::Pad2dCUDAKernel<float>,
ops::Pad2dCUDAKernel<double>, ops::Pad2dCUDAKernel<int>,
ops::Pad2dCUDAKernel<int64_t>);
REGISTER_OP_CUDA_KERNEL(pad2d_grad, ops::Pad2dGradCUDAKernel<float>,
ops::Pad2dGradCUDAKernel<double>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册