diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index c9939e8602ed341d37784ca292a55326899e8e65..f8b2505ccfb143f9f74cf0b16d92e8e1ca059709 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -87,9 +87,13 @@ class BatchNormOp : public framework::OperatorWithKernel { const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::ToDataType(ctx.Input("X")->type()); - // For float or float16 input tensor, the type of the scale, bias, mean, - // and var tensors should both be float. + // By default, the type of the scale, bias, mean, + // and var tensors should both be float. (For float or float16 input tensor) + // or double (For double input tensor). auto bn_param_type = framework::proto::VarType::FP32; + if (input_data_type == framework::proto::VarType::FP64) { + bn_param_type = framework::proto::VarType::FP64; + } PADDLE_ENFORCE_EQ(bn_param_type, framework::ToDataType(ctx.Input("Scale")->type()), "Scale input should be of float type"); @@ -492,8 +496,9 @@ REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker, REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp); REGISTER_OP_CPU_KERNEL( - batch_norm, - ops::BatchNormKernel); + batch_norm, ops::BatchNormKernel, + ops::BatchNormKernel); REGISTER_OP_CPU_KERNEL( batch_norm_grad, - ops::BatchNormGradKernel); + ops::BatchNormGradKernel, + ops::BatchNormGradKernel); diff --git a/paddle/fluid/operators/batch_norm_op.cu.cc b/paddle/fluid/operators/batch_norm_op.cu.cc index cb1927bc0f2eb735f0a3184df5f0f8fada2f9dca..550dd32d36767f90e880415bfffaf01aeb623609 100644 --- a/paddle/fluid/operators/batch_norm_op.cu.cc +++ b/paddle/fluid/operators/batch_norm_op.cu.cc @@ -287,6 +287,8 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( batch_norm, ops::BatchNormKernel, + ops::BatchNormKernel, ops::BatchNormKernel); REGISTER_OP_CUDA_KERNEL( - batch_norm_grad, ops::BatchNormGradKernel); + batch_norm_grad, ops::BatchNormGradKernel, + ops::BatchNormGradKernel); diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc index c9fabc8d485b3bba2c8ae14b3616d0bdcae058a7..6903cf83b41a54b54382fac2cf58f7bfe192b55f 100644 --- a/paddle/fluid/operators/mul_op.cc +++ b/paddle/fluid/operators/mul_op.cc @@ -204,6 +204,8 @@ REGISTER_OPERATOR(mul, ops::MulOp, ops::MulOpMaker, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(mul_grad, ops::MulGradOp); REGISTER_OP_CPU_KERNEL( - mul, ops::MulKernel); + mul, ops::MulKernel, + ops::MulKernel); REGISTER_OP_CPU_KERNEL( - mul_grad, ops::MulGradKernel); + mul_grad, ops::MulGradKernel, + ops::MulGradKernel); diff --git a/paddle/fluid/operators/mul_op.cu.cc b/paddle/fluid/operators/mul_op.cu.cc index 757f9c3ee2665c7ac654659416fe8dd727dca16d..81f3e42bf412fa4d2cb48405f2f8ee49b6aa0b67 100644 --- a/paddle/fluid/operators/mul_op.cu.cc +++ b/paddle/fluid/operators/mul_op.cu.cc @@ -18,6 +18,8 @@ limitations under the License. */ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL(mul, ops::MulKernel, + ops::MulKernel, ops::MulKernel); REGISTER_OP_CUDA_KERNEL(mul_grad, - ops::MulGradKernel); + ops::MulGradKernel, + ops::MulGradKernel);