提交 39c676e2 编写于 作者: K Kexin Zhao

initial commit

上级 3f5705c3
......@@ -270,9 +270,9 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
batch_norm,
ops::BatchNormKernel<paddle::platform::CUDADeviceContext, float>);
batch_norm, ops::BatchNormKernel<plat::CUDADeviceContext, float>,
ops::BatchNormKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
batch_norm_grad,
ops::BatchNormGradKernel<paddle::platform::CUDADeviceContext, float>);
batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>);
......@@ -278,6 +278,7 @@ void axpy<platform::CPUDeviceContext, double>(
cblas_daxpy(n, alpha, x, 1, y, 1);
}
template struct SetConstant<platform::CPUDeviceContext, platform::float16>;
template struct SetConstant<platform::CPUDeviceContext, float>;
template struct SetConstant<platform::CPUDeviceContext, double>;
template struct SetConstant<platform::CPUDeviceContext, int>;
......
......@@ -348,6 +348,7 @@ void axpy<platform::CUDADeviceContext, double>(
&alpha, x, 1, y, 1));
}
template struct SetConstant<platform::CUDADeviceContext, platform::float16>;
template struct SetConstant<platform::CUDADeviceContext, float>;
template struct SetConstant<platform::CUDADeviceContext, double>;
template struct SetConstant<platform::CUDADeviceContext, int>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册