From df7cc457599cf4e26c7607b184f2e24a82ee1427 Mon Sep 17 00:00:00 2001 From: arlesniak Date: Mon, 15 Nov 2021 07:39:02 +0100 Subject: [PATCH] Added BF16 to mean op (#37104) * Added BF16 to mean op * fix for CI * fix for CI * fix for CI --- paddle/fluid/operators/mean_op.cc | 8 ++++++-- paddle/pten/kernels/cpu/math.cc | 3 ++- .../paddle/fluid/tests/book/test_fit_a_line.py | 13 +++++++++++++ .../paddle/fluid/tests/unittests/test_mean_op.py | 16 +++++++++++++++- 4 files changed, 36 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/mean_op.cc b/paddle/fluid/operators/mean_op.cc index 764529a15b6..83fe1aa6dd1 100644 --- a/paddle/fluid/operators/mean_op.cc +++ b/paddle/fluid/operators/mean_op.cc @@ -96,7 +96,11 @@ REGISTER_OPERATOR(mean_grad, ops::MeanGradOp, ops::MeanGradNoNeedBufferVarsInferer); REGISTER_OP_CPU_KERNEL( mean, ops::MeanKernel, - ops::MeanKernel); + ops::MeanKernel, + ops::MeanKernel); REGISTER_OP_CPU_KERNEL( mean_grad, ops::MeanGradKernel, - ops::MeanGradKernel); + ops::MeanGradKernel, + ops::MeanGradKernel); diff --git a/paddle/pten/kernels/cpu/math.cc b/paddle/pten/kernels/cpu/math.cc index 25c4671baad..d63292ba1f5 100644 --- a/paddle/pten/kernels/cpu/math.cc +++ b/paddle/pten/kernels/cpu/math.cc @@ -97,7 +97,8 @@ using complex128 = ::paddle::platform::complex; // using bfloat16 = ::paddle::platform::bfloat16; PT_REGISTER_KERNEL("sign", CPU, ANY, pten::Sign, float, double) {} -PT_REGISTER_KERNEL("mean", CPU, ANY, pten::Mean, float, double) {} +PT_REGISTER_KERNEL( + "mean", CPU, ANY, pten::Mean, float, double, paddle::platform::bfloat16) {} PT_REGISTER_KERNEL("scale", CPU, ANY, diff --git a/python/paddle/fluid/tests/book/test_fit_a_line.py b/python/paddle/fluid/tests/book/test_fit_a_line.py index 8db8b793597..a8a5c8bf315 100644 --- a/python/paddle/fluid/tests/book/test_fit_a_line.py +++ b/python/paddle/fluid/tests/book/test_fit_a_line.py @@ -24,10 +24,19 @@ import unittest import math import sys import os +import struct paddle.enable_static() +def convert_uint16_to_float(in_list): + in_list = numpy.asarray(in_list) + out = numpy.vectorize( + lambda x: struct.unpack('