提交 39c676e2 编写于 作者: K Kexin Zhao

initial commit

上级 3f5705c3
...@@ -270,9 +270,9 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -270,9 +270,9 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
batch_norm, batch_norm, ops::BatchNormKernel<plat::CUDADeviceContext, float>,
ops::BatchNormKernel<paddle::platform::CUDADeviceContext, float>); ops::BatchNormKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
batch_norm_grad, batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>);
ops::BatchNormGradKernel<paddle::platform::CUDADeviceContext, float>);
...@@ -278,6 +278,7 @@ void axpy<platform::CPUDeviceContext, double>( ...@@ -278,6 +278,7 @@ void axpy<platform::CPUDeviceContext, double>(
cblas_daxpy(n, alpha, x, 1, y, 1); cblas_daxpy(n, alpha, x, 1, y, 1);
} }
template struct SetConstant<platform::CPUDeviceContext, platform::float16>;
template struct SetConstant<platform::CPUDeviceContext, float>; template struct SetConstant<platform::CPUDeviceContext, float>;
template struct SetConstant<platform::CPUDeviceContext, double>; template struct SetConstant<platform::CPUDeviceContext, double>;
template struct SetConstant<platform::CPUDeviceContext, int>; template struct SetConstant<platform::CPUDeviceContext, int>;
......
...@@ -348,6 +348,7 @@ void axpy<platform::CUDADeviceContext, double>( ...@@ -348,6 +348,7 @@ void axpy<platform::CUDADeviceContext, double>(
&alpha, x, 1, y, 1)); &alpha, x, 1, y, 1));
} }
template struct SetConstant<platform::CUDADeviceContext, platform::float16>;
template struct SetConstant<platform::CUDADeviceContext, float>; template struct SetConstant<platform::CUDADeviceContext, float>;
template struct SetConstant<platform::CUDADeviceContext, double>; template struct SetConstant<platform::CUDADeviceContext, double>;
template struct SetConstant<platform::CUDADeviceContext, int>; template struct SetConstant<platform::CUDADeviceContext, int>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册