Unverified · Commit 57be5c6c authored by dzhwinter, committed by GitHub

"fix double type error" (#10322)

* "fix double type error"

* "fix ci"
Parent faebadd9
@@ -87,9 +87,13 @@ class BatchNormOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext &ctx) const override {
     auto input_data_type =
         framework::ToDataType(ctx.Input<Tensor>("X")->type());
-    // For float or float16 input tensor, the type of the scale, bias, mean,
-    // and var tensors should both be float.
+    // By default, the type of the scale, bias, mean,
+    // and var tensors should both be float. (For float or float16 input tensor)
+    // or double (For double input tensor).
     auto bn_param_type = framework::proto::VarType::FP32;
+    if (input_data_type == framework::proto::VarType::FP64) {
+      bn_param_type = framework::proto::VarType::FP64;
+    }
     PADDLE_ENFORCE_EQ(bn_param_type,
                       framework::ToDataType(ctx.Input<Tensor>("Scale")->type()),
                       "Scale input should be of float type");
@@ -492,8 +496,9 @@ REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
 REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp);
 REGISTER_OP_CPU_KERNEL(
-    batch_norm,
-    ops::BatchNormKernel<paddle::platform::CPUDeviceContext, float>);
+    batch_norm, ops::BatchNormKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::BatchNormKernel<paddle::platform::CPUDeviceContext, double>);
 REGISTER_OP_CPU_KERNEL(
     batch_norm_grad,
-    ops::BatchNormGradKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::BatchNormGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::BatchNormGradKernel<paddle::platform::CPUDeviceContext, double>);
@@ -287,6 +287,8 @@ namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
     batch_norm, ops::BatchNormKernel<plat::CUDADeviceContext, float>,
+    ops::BatchNormKernel<plat::CUDADeviceContext, double>,
     ops::BatchNormKernel<plat::CUDADeviceContext, plat::float16>);
 REGISTER_OP_CUDA_KERNEL(
-    batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>);
+    batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>,
+    ops::BatchNormGradKernel<plat::CUDADeviceContext, double>);
@@ -204,6 +204,8 @@ REGISTER_OPERATOR(mul, ops::MulOp, ops::MulOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(mul_grad, ops::MulGradOp);
 REGISTER_OP_CPU_KERNEL(
-    mul, ops::MulKernel<paddle::platform::CPUDeviceContext, float>);
+    mul, ops::MulKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::MulKernel<paddle::platform::CPUDeviceContext, double>);
 REGISTER_OP_CPU_KERNEL(
-    mul_grad, ops::MulGradKernel<paddle::platform::CPUDeviceContext, float>);
+    mul_grad, ops::MulGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::MulGradKernel<paddle::platform::CPUDeviceContext, double>);
@@ -18,6 +18,8 @@ limitations under the License. */
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(mul, ops::MulKernel<plat::CUDADeviceContext, float>,
+                        ops::MulKernel<plat::CUDADeviceContext, double>,
                         ops::MulKernel<plat::CUDADeviceContext, plat::float16>);
 REGISTER_OP_CUDA_KERNEL(mul_grad,
-                        ops::MulGradKernel<plat::CUDADeviceContext, float>);
+                        ops::MulGradKernel<plat::CUDADeviceContext, float>,
+                        ops::MulGradKernel<plat::CUDADeviceContext, double>);
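All four registration hunks follow the same pattern: `REGISTER_OP_CPU_KERNEL` / `REGISTER_OP_CUDA_KERNEL` take a list of kernel instantiations and file each one under the op name plus its element type, so adding the `<..., double>` entries is what makes FP64 inputs dispatchable at all. A toy, self-contained sketch of that dispatch idea (a deliberate simplification for illustration, not Paddle's actual registry code):

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <typeindex>
#include <utility>

// Toy registry: kernels are keyed by (op name, element type), so each
// template instantiation passed to the macro becomes one entry.
using Kernel = std::function<void()>;
std::map<std::pair<std::string, std::type_index>, Kernel> g_kernels;

template <typename T>
void RegisterKernel(const std::string& op, Kernel k) {
  g_kernels[{op, std::type_index(typeid(T))}] = std::move(k);
}

template <typename T>
void Run(const std::string& op) {
  auto it = g_kernels.find({op, std::type_index(typeid(T))});
  if (it == g_kernels.end()) {
    // Without a registered double kernel, an FP64 input has nowhere to go.
    std::cout << op << ": no kernel for this dtype\n";
  } else {
    it->second();
  }
}

int main() {
  RegisterKernel<float>("mul", [] { std::cout << "mul<float>\n"; });
  // The analogue of the added <..., double> registration lines:
  RegisterKernel<double>("mul", [] { std::cout << "mul<double>\n"; });
  Run<float>("mul");
  Run<double>("mul");  // now dispatches instead of failing
  return 0;
}
```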