From 57be5c6c743408c010875cde372d17df529c71a0 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Wed, 2 May 2018 17:00:01 +0800 Subject: [PATCH] "fix double type error" (#10322) * "fix double type error" * "fix ci" --- paddle/fluid/operators/batch_norm_op.cc | 15 ++++++++++----- paddle/fluid/operators/batch_norm_op.cu.cc | 4 +++- paddle/fluid/operators/mul_op.cc | 6 ++++-- paddle/fluid/operators/mul_op.cu.cc | 4 +++- 4 files changed, 20 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index c9939e8602e..f8b2505ccfb 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -87,9 +87,13 @@ class BatchNormOp : public framework::OperatorWithKernel { const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::ToDataType(ctx.Input("X")->type()); - // For float or float16 input tensor, the type of the scale, bias, mean, - // and var tensors should both be float. + // By default, the type of the scale, bias, mean, + // and var tensors should both be float. (For float or float16 input tensor) + // or double (For double input tensor). auto bn_param_type = framework::proto::VarType::FP32; + if (input_data_type == framework::proto::VarType::FP64) { + bn_param_type = framework::proto::VarType::FP64; + } PADDLE_ENFORCE_EQ(bn_param_type, framework::ToDataType(ctx.Input("Scale")->type()), "Scale input should be of float type"); @@ -492,8 +496,9 @@ REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker, REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp); REGISTER_OP_CPU_KERNEL( - batch_norm, - ops::BatchNormKernel); + batch_norm, ops::BatchNormKernel, + ops::BatchNormKernel); REGISTER_OP_CPU_KERNEL( batch_norm_grad, - ops::BatchNormGradKernel); + ops::BatchNormGradKernel, + ops::BatchNormGradKernel); diff --git a/paddle/fluid/operators/batch_norm_op.cu.cc b/paddle/fluid/operators/batch_norm_op.cu.cc index cb1927bc0f2..550dd32d367 100644 --- a/paddle/fluid/operators/batch_norm_op.cu.cc +++ b/paddle/fluid/operators/batch_norm_op.cu.cc @@ -287,6 +287,8 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( batch_norm, ops::BatchNormKernel, + ops::BatchNormKernel, ops::BatchNormKernel); REGISTER_OP_CUDA_KERNEL( - batch_norm_grad, ops::BatchNormGradKernel); + batch_norm_grad, ops::BatchNormGradKernel, + ops::BatchNormGradKernel); diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc index c9fabc8d485..6903cf83b41 100644 --- a/paddle/fluid/operators/mul_op.cc +++ b/paddle/fluid/operators/mul_op.cc @@ -204,6 +204,8 @@ REGISTER_OPERATOR(mul, ops::MulOp, ops::MulOpMaker, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(mul_grad, ops::MulGradOp); REGISTER_OP_CPU_KERNEL( - mul, ops::MulKernel); + mul, ops::MulKernel, + ops::MulKernel); REGISTER_OP_CPU_KERNEL( - mul_grad, ops::MulGradKernel); + mul_grad, ops::MulGradKernel, + ops::MulGradKernel); diff --git a/paddle/fluid/operators/mul_op.cu.cc b/paddle/fluid/operators/mul_op.cu.cc index 757f9c3ee26..81f3e42bf41 100644 --- a/paddle/fluid/operators/mul_op.cu.cc +++ b/paddle/fluid/operators/mul_op.cu.cc @@ -18,6 +18,8 @@ limitations under the License. */ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL(mul, ops::MulKernel, + ops::MulKernel, ops::MulKernel); REGISTER_OP_CUDA_KERNEL(mul_grad, - ops::MulGradKernel); + ops::MulGradKernel, + ops::MulGradKernel); -- GitLab