未验证 提交 643a268e 编写于 作者: S sneaxiy 提交者: GitHub

Support FP16 mean (#38289)

* mean first version

* fix scalar mean

* add fp16 dtype for api
上级 c197d73b
...@@ -14,7 +14,10 @@ ...@@ -14,7 +14,10 @@
#pragma once #pragma once
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/eigen_ext.h" #include "paddle/fluid/platform/eigen_ext.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -74,16 +77,20 @@ struct IdentityFunctor { ...@@ -74,16 +77,20 @@ struct IdentityFunctor {
*/ */
template <typename Tx, typename Ty = Tx> template <typename Tx, typename Ty = Tx>
struct DivideFunctor { struct DivideFunctor {
HOSTDEVICE inline DivideFunctor() { n_inv = static_cast<Tx>(1.0f); } private:
using MPType = typename ::paddle::operators::details::MPTypeTrait<Tx>::Type;
public:
HOSTDEVICE inline DivideFunctor() { n_inv = static_cast<MPType>(1.0f); }
HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((Tx)(1.0 / n)) {} HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((MPType)(1.0 / n)) {}
HOSTDEVICE inline Ty operator()(const Tx& x) const { HOSTDEVICE inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(x * n_inv); return static_cast<Ty>(static_cast<MPType>(x) * n_inv);
} }
private: private:
Tx n_inv; MPType n_inv;
}; };
/** /**
......
...@@ -18,30 +18,23 @@ limitations under the License. */ ...@@ -18,30 +18,23 @@ limitations under the License. */
#include <hipcub/hipcub.hpp> #include <hipcub/hipcub.hpp>
namespace cub = hipcub; namespace cub = hipcub;
#endif #endif
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
#include "paddle/fluid/operators/mean_op.h" #include "paddle/fluid/operators/mean_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T>
struct DivideFunctor {
HOSTDEVICE explicit inline DivideFunctor(int n)
: n_inv(static_cast<T>(1.0 / n)) {}
HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; }
private:
T n_inv;
};
template <typename T> template <typename T>
__global__ void MeanRunKernel(const T* in_data, T* out_data, int N) { __global__ void MeanRunKernel(const T* in_data, T* out_data, int N) {
using MT = typename details::MPTypeTrait<T>::Type;
int idx = blockDim.x * blockIdx.x + threadIdx.x; int idx = blockDim.x * blockIdx.x + threadIdx.x;
T data = in_data[0]; auto data = static_cast<MT>(in_data[0]);
for (; idx < N; idx += blockDim.x * gridDim.x) { for (; idx < N; idx += blockDim.x * gridDim.x) {
out_data[idx] = data / (static_cast<T>(N)); out_data[idx] = static_cast<T>(data / (static_cast<MT>(N)));
} }
} }
...@@ -52,27 +45,29 @@ class MeanCUDAKernel : public framework::OpKernel<T> { ...@@ -52,27 +45,29 @@ class MeanCUDAKernel : public framework::OpKernel<T> {
auto* input = context.Input<Tensor>("X"); auto* input = context.Input<Tensor>("X");
auto* output = context.Output<Tensor>("Out"); auto* output = context.Output<Tensor>("Out");
output->mutable_data<T>(context.GetPlace());
auto size_prob = input->numel();
const T* in_data = input->data<T>(); const T* in_data = input->data<T>();
T* out_data = output->mutable_data<T>(context.GetPlace()); T* out_data = output->mutable_data<T>(context.GetPlace());
auto numel = input->numel();
auto rank = input->dims().size();
auto place = context.GetPlace();
auto stream = context.cuda_device_context().stream(); auto stream = context.cuda_device_context().stream();
DivideFunctor<T> transformer(size_prob); if (rank == 0) { // scalar
cub::TransformInputIterator<T, DivideFunctor<T>, const T*> trans_x( auto gpu_place = BOOST_GET(platform::CUDAPlace, place);
in_data, transformer); memory::Copy(gpu_place, out_data, gpu_place, in_data, numel * sizeof(T),
size_t temp_storage_bytes = 0; stream);
return;
}
auto err = cub::DeviceReduce::Sum(nullptr, temp_storage_bytes, trans_x, using MT = typename details::MPTypeTrait<T>::Type;
out_data, size_prob, stream); using Div = kernel_primitives::DivideFunctor<T, MT>;
PADDLE_ENFORCE_GPU_SUCCESS(err); std::vector<int> reduce_dims;
framework::Tensor tmp; reduce_dims.reserve(rank);
auto* temp_storage = tmp.mutable_data<uint8_t>( for (decltype(rank) i = 0; i < rank; ++i) {
framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}), reduce_dims.push_back(i);
context.GetPlace()); }
err = cub::DeviceReduce::Sum(temp_storage, temp_storage_bytes, trans_x, TensorReduceFunctorImpl<T, T, kernel_primitives::AddFunctor, Div>(
out_data, size_prob, stream); *input, output, Div(numel), reduce_dims, stream);
PADDLE_ENFORCE_GPU_SUCCESS(err);
} }
}; };
......
...@@ -77,7 +77,7 @@ struct CustomSub { ...@@ -77,7 +77,7 @@ struct CustomSub {
template <typename Tx, typename Ty = Tx> template <typename Tx, typename Ty = Tx>
struct CustomMean { struct CustomMean {
using Transformer = kps::DivideFunctor<Tx>; using Transformer = kps::DivideFunctor<Tx, Ty>;
inline Ty initial() { return static_cast<Ty>(0.0f); } inline Ty initial() { return static_cast<Ty>(0.0f); }
......
...@@ -19,5 +19,7 @@ ...@@ -19,5 +19,7 @@
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
reduce_mean, reduce_mean,
ops::ReduceCudaKernel<bool, kps::AddFunctor, kps::DivideFunctor>, ops::ReduceCudaKernel<bool, kps::AddFunctor, kps::DivideFunctor>,
ops::ReduceCudaKernel<paddle::platform::float16, kps::AddFunctor,
kps::DivideFunctor>,
ops::ReduceCudaKernel<float, kps::AddFunctor, kps::DivideFunctor>, ops::ReduceCudaKernel<float, kps::AddFunctor, kps::DivideFunctor>,
ops::ReduceCudaKernel<double, kps::AddFunctor, kps::DivideFunctor>); ops::ReduceCudaKernel<double, kps::AddFunctor, kps::DivideFunctor>);
...@@ -35,5 +35,18 @@ struct MeanGradFunctor { ...@@ -35,5 +35,18 @@ struct MeanGradFunctor {
} }
}; };
// TODO(zengjinle): Should refine the numeric stability of FP16 reduce_mean
// and reduce_mean_grad later.
struct FP16MeanGradFunctor {
template <typename DeviceContext, typename X, typename Y, typename DX,
typename DY, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
const Dim& dim, int size) {
dx->device(place) = (dy->template cast<float>().broadcast(dim) /
dx->template cast<float>().constant(size))
.template cast<platform::float16>();
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -20,6 +20,12 @@ using CUDAReduceMeanGradKernel = ...@@ -20,6 +20,12 @@ using CUDAReduceMeanGradKernel =
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, T, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, T,
ops::MeanGradFunctor, true>; ops::MeanGradFunctor, true>;
using FP16CUDAReduceMeanGradKernel =
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16, ops::FP16MeanGradFunctor,
true>;
REGISTER_OP_CUDA_KERNEL(reduce_mean_grad, CUDAReduceMeanGradKernel<bool>, REGISTER_OP_CUDA_KERNEL(reduce_mean_grad, CUDAReduceMeanGradKernel<bool>,
FP16CUDAReduceMeanGradKernel,
CUDAReduceMeanGradKernel<float>, CUDAReduceMeanGradKernel<float>,
CUDAReduceMeanGradKernel<double>); CUDAReduceMeanGradKernel<double>);
...@@ -38,7 +38,9 @@ namespace cub = hipcub; ...@@ -38,7 +38,9 @@ namespace cub = hipcub;
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h" #include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/fast_divmod.h" #include "paddle/fluid/platform/fast_divmod.h"
#include "paddle/fluid/string/string_helper.h"
// Reduce split or not, Whether to use ReduceHigherDim // Reduce split or not, Whether to use ReduceHigherDim
#define REDUCE_SPLIT_BOUNDARY 512 #define REDUCE_SPLIT_BOUNDARY 512
...@@ -814,11 +816,42 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data, ...@@ -814,11 +816,42 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
} }
} }
template <typename Tx, typename Ty, template <typename> class ReduceOp,
typename TransformOp>
static typename std::enable_if<!std::is_same<Tx, platform::float16>::value,
void>::type
CubTensorReduceFunctorImpl(const Tx* x_data, Ty* y_data,
const TransformOp& transform, int reduce_num,
const platform::Place& place, gpuStream_t stream) {
auto reducer = ReduceOp<Ty>();
cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(x_data,
transform);
size_t temp_storage_bytes = 0;
cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
reduce_num, reducer, reducer.initial(), stream);
framework::Tensor tmp;
auto* temp_storage = tmp.mutable_data<uint8_t>(
framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}), place);
cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
reduce_num, reducer, reducer.initial(), stream);
}
template <typename Tx, typename Ty, template <typename> class ReduceOp,
typename TransformOp>
static typename std::enable_if<std::is_same<Tx, platform::float16>::value,
void>::type
CubTensorReduceFunctorImpl(const Tx* x_data, Ty* y_data,
const TransformOp& transform, int reduce_num,
const platform::Place& place, gpuStream_t stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Tx should not be float16 when using cub::DeviceReduce::Reduce()."));
}
template <typename Tx, typename Ty, template <typename> class ReduceOp, template <typename Tx, typename Ty, template <typename> class ReduceOp,
typename TransformOp> typename TransformOp>
void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y, void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
const TransformOp& transform, const TransformOp& transform,
std::vector<int> origin_reduce_dims, const std::vector<int>& origin_reduce_dims,
gpuStream_t stream) { gpuStream_t stream) {
auto x_dim = framework::vectorize<int>(x.dims()); auto x_dim = framework::vectorize<int>(x.dims());
auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim); auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
...@@ -848,25 +881,11 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y, ...@@ -848,25 +881,11 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
} }
config.SetOutputData(y_data, x.place(), &tmp); config.SetOutputData(y_data, x.place(), &tmp);
bool use_cub_reduce = (config.reduce_num == numel) && constexpr bool kIsTxFP16 = std::is_same<Tx, paddle::platform::float16>::value;
(!std::is_same<Tx, paddle::platform::float16>::value); bool use_cub_reduce = config.reduce_num == numel && !kIsTxFP16;
if (use_cub_reduce) { if (use_cub_reduce) {
// launch CUB::Reduce CubTensorReduceFunctorImpl<Tx, Ty, ReduceOp, TransformOp>(
auto reducer = ReduceOp<Ty>(); x_data, y_data, transform, config.reduce_num, x.place(), stream);
cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(x_data,
transform);
size_t temp_storage_bytes = 0;
cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
config.reduce_num, reducer, reducer.initial(),
stream);
framework::Tensor tmp;
auto* temp_storage = tmp.mutable_data<uint8_t>(
framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
x.place());
cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
config.reduce_num, reducer, reducer.initial(),
stream);
return; return;
} }
......
...@@ -703,7 +703,7 @@ class ReduceCudaKernel : public framework::OpKernel<T> { ...@@ -703,7 +703,7 @@ class ReduceCudaKernel : public framework::OpKernel<T> {
std::vector<int> reduce_dims = std::vector<int> reduce_dims =
GetReduceDim(dims, input->dims().size(), reduce_all); GetReduceDim(dims, input->dims().size(), reduce_all);
int reduce_num = 1; int reduce_num = 1;
for (int i = 0; i < input->dims().size(); i++) { for (auto i : reduce_dims) {
reduce_num *= (input->dims())[i]; reduce_num *= (input->dims())[i];
} }
gpuStream_t stream = context.cuda_device_context().stream(); gpuStream_t stream = context.cuda_device_context().stream();
...@@ -713,8 +713,10 @@ class ReduceCudaKernel : public framework::OpKernel<T> { ...@@ -713,8 +713,10 @@ class ReduceCudaKernel : public framework::OpKernel<T> {
TensorReduceFunc<T, ReduceOp, TransformOp>( TensorReduceFunc<T, ReduceOp, TransformOp>(
*input, output, reduce_dims, reduce_num, stream)); *input, output, reduce_dims, reduce_num, stream));
} else { } else {
TensorReduceFunctorImpl<T, T, ReduceOp, TransformOp<T, T>>( using MPType = typename details::MPTypeTrait<T>::Type;
*input, output, TransformOp<T, T>(reduce_num), reduce_dims, stream); TensorReduceFunctorImpl<T, T, ReduceOp, TransformOp<T, MPType>>(
*input, output, TransformOp<T, MPType>(reduce_num), reduce_dims,
stream);
} }
} }
}; };
......
...@@ -63,17 +63,25 @@ class TestMeanOpError(unittest.TestCase): ...@@ -63,17 +63,25 @@ class TestMeanOpError(unittest.TestCase):
class TestFP16MeanOp(TestMeanOp): class TestFP16MeanOp(TestMeanOp):
def init_dtype_type(self): def init_dtype_type(self):
self.dtype = np.float16 self.dtype = np.float16
self.__class__.no_need_check_grad = True
def test_check_output(self): def test_check_output(self):
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
if core.is_float16_supported(place): if core.is_float16_supported(place):
self.check_output_with_place(place, atol=2e-3) self.check_output_with_place(place)
def test_checkout_grad(self): def test_checkout_grad(self):
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
if core.is_float16_supported(place): if core.is_float16_supported(place):
self.check_grad_with_place( with fluid.dygraph.guard():
place, ['X'], 'Out', max_relative_error=0.8) x_np = np.random.random((10, 10)).astype(self.dtype)
x = paddle.to_tensor(x_np)
x.stop_gradient = False
y = fluid.layers.mean(x)
dx = paddle.grad(y, x)[0].numpy()
dx_expected = self.dtype(1.0 / np.prod(x_np.shape)) * np.ones(
x_np.shape).astype(self.dtype)
self.assertTrue(np.array_equal(dx, dx_expected))
@OpTestTool.skip_if_not_cpu_bf16() @OpTestTool.skip_if_not_cpu_bf16()
...@@ -98,6 +106,14 @@ def ref_reduce_mean(x, axis=None, keepdim=False, reduce_all=False): ...@@ -98,6 +106,14 @@ def ref_reduce_mean(x, axis=None, keepdim=False, reduce_all=False):
return np.mean(x, axis=axis, keepdims=keepdim) return np.mean(x, axis=axis, keepdims=keepdim)
def ref_reduce_mean_grad(x, axis, dtype):
if reduce_all:
axis = list(range(x.ndim))
shape = [x.shape[i] for i in axis]
return (1.0 / np.prod(shape) * np.ones(shape)).astype(dtype)
class TestReduceMeanOp(OpTest): class TestReduceMeanOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = 'reduce_mean' self.op_type = 'reduce_mean'
...@@ -105,11 +121,13 @@ class TestReduceMeanOp(OpTest): ...@@ -105,11 +121,13 @@ class TestReduceMeanOp(OpTest):
self.shape = [2, 3, 4, 5] self.shape = [2, 3, 4, 5]
self.axis = [0] self.axis = [0]
self.keepdim = False self.keepdim = False
self.reduce_all = False
self.set_attrs() self.set_attrs()
np.random.seed(10) np.random.seed(10)
x_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype) x_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
if not hasattr(self, "reduce_all"):
self.reduce_all = (not self.axis) or len(self.axis) == len(x_np)
out_np = ref_reduce_mean(x_np, self.axis, self.keepdim, self.reduce_all) out_np = ref_reduce_mean(x_np, self.axis, self.keepdim, self.reduce_all)
self.inputs = {'X': x_np} self.inputs = {'X': x_np}
self.outputs = {'Out': out_np} self.outputs = {'Out': out_np}
...@@ -119,14 +137,39 @@ class TestReduceMeanOp(OpTest): ...@@ -119,14 +137,39 @@ class TestReduceMeanOp(OpTest):
'reduce_all': self.reduce_all 'reduce_all': self.reduce_all
} }
if self.dtype == 'float16':
self.__class__.no_need_check_grad = True
def set_attrs(self): def set_attrs(self):
pass pass
def test_check_output(self): def test_check_output(self):
self.check_output() if self.dtype != 'float16':
self.check_output()
else:
if not core.is_compiled_with_cuda():
return
place = paddle.CUDAPlace(0)
self.check_output_with_place(place=place)
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X'], ['Out']) if self.dtype != 'float16':
self.check_grad(['X'], ['Out'])
else:
return
if not core.is_compiled_with_cuda():
return
place = paddle.CUDAPlace(0)
if core.is_float16_supported(place):
return
with fluid.dygraph.guard(place=place):
x = paddle.tensor(self.inputs['X'])
y = paddle.mean(
x, axis=self.attrs['dim'], keepdim=self.attrs['keep_dim'])
dx = paddle.grad(y, x)[0].numpy()
dx_expected = ref_reduce_mean_grad(
self.inputs['X'], self.attrs['dim'], self.dtype)
self.assertTrue(np.array_equal(dx, dx_expected))
class TestReduceMeanOpDefaultAttrs(TestReduceMeanOp): class TestReduceMeanOpDefaultAttrs(TestReduceMeanOp):
...@@ -146,47 +189,101 @@ class TestReduceMeanOpFloat32(TestReduceMeanOp): ...@@ -146,47 +189,101 @@ class TestReduceMeanOpFloat32(TestReduceMeanOp):
self.dtype = 'float32' self.dtype = 'float32'
class TestReduceMeanOpFloat16(TestReduceMeanOp):
def set_attrs(self):
self.dtype = 'float16'
class TestReduceMeanOpShape1D(TestReduceMeanOp): class TestReduceMeanOpShape1D(TestReduceMeanOp):
def set_attrs(self): def set_attrs(self):
self.shape = [100] self.shape = [100]
class TestReduceMeanOpShape1DFP16(TestReduceMeanOp):
def set_attrs(self):
self.shape = [100]
self.dtype = 'float16'
class TestReduceMeanOpShape6D(TestReduceMeanOp): class TestReduceMeanOpShape6D(TestReduceMeanOp):
def set_attrs(self): def set_attrs(self):
self.shape = [2, 3, 4, 5, 6, 7] self.shape = [2, 3, 4, 5, 6, 7]
class TestReduceMeanOpShape6DFP16(TestReduceMeanOp):
def set_attrs(self):
self.shape = [2, 3, 4, 5, 6, 7]
self.dtype = 'float16'
class TestReduceMeanOpAxisAll(TestReduceMeanOp): class TestReduceMeanOpAxisAll(TestReduceMeanOp):
def set_attrs(self): def set_attrs(self):
self.axis = [0, 1, 2, 3] self.axis = [0, 1, 2, 3]
class TestReduceMeanOpAxisAllFP16(TestReduceMeanOp):
def set_attrs(self):
self.axis = [0, 1, 2, 3]
self.dtype = 'float16'
class TestReduceMeanOpAxisTuple(TestReduceMeanOp): class TestReduceMeanOpAxisTuple(TestReduceMeanOp):
def set_attrs(self): def set_attrs(self):
self.axis = (0, 1, 2) self.axis = (0, 1, 2)
class TestReduceMeanOpAxisTupleFP16(TestReduceMeanOp):
def set_attrs(self):
self.axis = (0, 1, 2)
self.dtype = 'float16'
class TestReduceMeanOpAxisNegative(TestReduceMeanOp): class TestReduceMeanOpAxisNegative(TestReduceMeanOp):
def set_attrs(self): def set_attrs(self):
self.axis = [-2, -1] self.axis = [-2, -1]
class TestReduceMeanOpAxisNegativeFP16(TestReduceMeanOp):
def set_attrs(self):
self.axis = [-2, -1]
self.dtype = 'float16'
class TestReduceMeanOpKeepdimTrue1(TestReduceMeanOp): class TestReduceMeanOpKeepdimTrue1(TestReduceMeanOp):
def set_attrs(self): def set_attrs(self):
self.keepdim = True self.keepdim = True
class TestReduceMeanOpKeepdimTrue1FP16(TestReduceMeanOp):
def set_attrs(self):
self.keepdim = True
self.dtype = 'float16'
class TestReduceMeanOpKeepdimTrue2(TestReduceMeanOp): class TestReduceMeanOpKeepdimTrue2(TestReduceMeanOp):
def set_attrs(self): def set_attrs(self):
self.axis = [0, 1, 2, 3] self.axis = [0, 1, 2, 3]
self.keepdim = True self.keepdim = True
class TestReduceMeanOpKeepdimTrue2FP16(TestReduceMeanOp):
def set_attrs(self):
self.axis = [0, 1, 2, 3]
self.keepdim = True
self.dtype = 'float16'
class TestReduceMeanOpReduceAllTrue(TestReduceMeanOp): class TestReduceMeanOpReduceAllTrue(TestReduceMeanOp):
def set_attrs(self): def set_attrs(self):
self.reduce_all = True self.reduce_all = True
class TestReduceMeanOpReduceAllTrueFP16(TestReduceMeanOp):
def set_attrs(self):
self.reduce_all = True
self.dtype = 'float16'
class TestMeanAPI(unittest.TestCase): class TestMeanAPI(unittest.TestCase):
# test paddle.tensor.stat.mean # test paddle.tensor.stat.mean
......
...@@ -92,7 +92,8 @@ def mean(x, axis=None, keepdim=False, name=None): ...@@ -92,7 +92,8 @@ def mean(x, axis=None, keepdim=False, name=None):
return _C_ops.reduce_mean(x, 'dim', axis, 'keep_dim', keepdim, return _C_ops.reduce_mean(x, 'dim', axis, 'keep_dim', keepdim,
'reduce_all', reduce_all) 'reduce_all', reduce_all)
check_variable_and_dtype(x, 'x/input', ['uint16', 'float32', 'float64'], check_variable_and_dtype(x, 'x/input',
['uint16', 'float16', 'float32', 'float64'],
'mean/reduce_mean') 'mean/reduce_mean')
check_type(axis, 'axis/dim', (int, list, tuple), 'mean/reduce_mean') check_type(axis, 'axis/dim', (int, list, tuple), 'mean/reduce_mean')
if isinstance(axis, (list, tuple)): if isinstance(axis, (list, tuple)):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册