未验证 提交 57be5c6c 编写于 作者: D dzhwinter 提交者: GitHub

"fix double type error" (#10322)

* "fix double type error"

* "fix ci"
上级 faebadd9
......@@ -87,9 +87,13 @@ class BatchNormOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("X")->type());
// For float or float16 input tensor, the type of the scale, bias, mean,
// and var tensors should both be float.
// By default, the type of the scale, bias, mean,
// and var tensors should both be float. (For float or float16 input tensor)
// or double (For double input tensor).
auto bn_param_type = framework::proto::VarType::FP32;
if (input_data_type == framework::proto::VarType::FP64) {
bn_param_type = framework::proto::VarType::FP64;
}
PADDLE_ENFORCE_EQ(bn_param_type,
framework::ToDataType(ctx.Input<Tensor>("Scale")->type()),
"Scale input should be of float type");
......@@ -492,8 +496,9 @@ REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp);
REGISTER_OP_CPU_KERNEL(
batch_norm,
ops::BatchNormKernel<paddle::platform::CPUDeviceContext, float>);
batch_norm, ops::BatchNormKernel<paddle::platform::CPUDeviceContext, float>,
ops::BatchNormKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
batch_norm_grad,
ops::BatchNormGradKernel<paddle::platform::CPUDeviceContext, float>);
ops::BatchNormGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::BatchNormGradKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -287,6 +287,8 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
batch_norm, ops::BatchNormKernel<plat::CUDADeviceContext, float>,
ops::BatchNormKernel<plat::CUDADeviceContext, double>,
ops::BatchNormKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>);
batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>,
ops::BatchNormGradKernel<plat::CUDADeviceContext, double>);
......@@ -204,6 +204,8 @@ REGISTER_OPERATOR(mul, ops::MulOp, ops::MulOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(mul_grad, ops::MulGradOp);
REGISTER_OP_CPU_KERNEL(
mul, ops::MulKernel<paddle::platform::CPUDeviceContext, float>);
mul, ops::MulKernel<paddle::platform::CPUDeviceContext, float>,
ops::MulKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
mul_grad, ops::MulGradKernel<paddle::platform::CPUDeviceContext, float>);
mul_grad, ops::MulGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::MulGradKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -18,6 +18,8 @@ limitations under the License. */
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(mul, ops::MulKernel<plat::CUDADeviceContext, float>,
ops::MulKernel<plat::CUDADeviceContext, double>,
ops::MulKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(mul_grad,
ops::MulGradKernel<plat::CUDADeviceContext, float>);
ops::MulGradKernel<plat::CUDADeviceContext, float>,
ops::MulGradKernel<plat::CUDADeviceContext, double>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册