Unverified commit e7c81600, authored by joanna.wozna.intel, committed by GitHub

Add BF16 uniform random initializer (#32468) (#32677)

* Add bf16 uniform random initializer

* Remove duplicated section

* Change UT to CPU place only

* Put detail functions into anonymous namespace
Parent 93535c59
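A minimal sketch of what this commit enables, assuming a Paddle build that includes the patch: a parameter stored as "uint16" (Paddle's carrier dtype for bfloat16) can now be initialized by uniform_random directly, without the FP32-staging-plus-cast detour. The program and parameter names below are illustrative and mirror the unit tests further down in this diff.

import paddle
import paddle.fluid.framework as framework
import paddle.fluid.initializer as initializer

paddle.enable_static()
program = framework.Program()
block = program.global_block()
# "uint16" is how Paddle exposes bfloat16 storage on the Python side.
param = block.create_parameter(
    dtype="uint16",
    shape=[5, 10],
    lod_level=0,
    name="param",
    initializer=initializer.UniformInitializer(-1.0, 1.0, seed=123))
# After this patch the block holds a single uniform_random op and no cast op.
print([op.type for op in block.ops])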
......@@ -117,6 +117,9 @@ class FillConstantKernel : public framework::OpKernel<T> {
}
if (actual_place == 0) {
VLOG(4) << "[CPU] FillConstantKernel"
<< ((data_type == framework::proto::VarType::BF16) ? "<bfloat16>"
: "<T>");
tensor->mutable_data(platform::CPUPlace(), data_type);
math::SetConstant<platform::CPUDeviceContext, T> functor;
functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),
......
......@@ -18,10 +18,41 @@ limitations under the License. */
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/bfloat16.h"
namespace paddle {
namespace operators {
namespace {
template <typename T>
inline void UniformRealDistribution(T *data, const int64_t &size,
const float &min, const float &max,
const unsigned int &seed) {
VLOG(4) << "[CPU] UniformRandomKernel<T>";
std::uniform_real_distribution<T> dist(static_cast<T>(min),
static_cast<T>(max));
auto engine = paddle::framework::GetCPURandomEngine(seed);
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(*engine);
}
}
template <>
inline void UniformRealDistribution(paddle::platform::bfloat16 *data,
const int64_t &size, const float &min,
const float &max,
const unsigned int &seed) {
VLOG(4) << "[CPU] UniformRandomKernel<bfloat16>";
std::uniform_real_distribution<float> dist(min, max);
auto engine = paddle::framework::GetCPURandomEngine(seed);
for (int64_t i = 0; i < size; ++i) {
data[i] = static_cast<paddle::platform::bfloat16>(dist(*engine));
}
}
} // namespace
// It seems that Eigen::Tensor::random in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
......@@ -61,17 +92,11 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> {
framework::ToTypeName(out_var->Type())));
}
T *data = tensor->mutable_data<T>(ctx.GetPlace());
int64_t size = tensor->numel();
std::uniform_real_distribution<T> dist(
static_cast<T>(ctx.Attr<float>("min")),
static_cast<T>(ctx.Attr<float>("max")));
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
auto engine = framework::GetCPURandomEngine(seed);
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(*engine);
}
UniformRealDistribution<T>(
data, size, ctx.Attr<float>("min"), ctx.Attr<float>("max"),
static_cast<unsigned int>(ctx.Attr<int>("seed")));
unsigned int diag_num =
static_cast<unsigned int>(ctx.Attr<int>("diag_num"));
......@@ -257,9 +282,12 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
paddle::operators::UniformRandomOpVarTypeInference);
REGISTER_OP_CPU_KERNEL(uniform_random,
paddle::operators::CPUUniformRandomKernel<float>,
paddle::operators::CPUUniformRandomKernel<double>);
REGISTER_OP_CPU_KERNEL(uniform_random_batch_size_like,
paddle::operators::CPUUniformRandomKernel<float>,
paddle::operators::CPUUniformRandomKernel<double>);
REGISTER_OP_CPU_KERNEL(
uniform_random, paddle::operators::CPUUniformRandomKernel<float>,
paddle::operators::CPUUniformRandomKernel<double>,
paddle::operators::CPUUniformRandomKernel<paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL(
uniform_random_batch_size_like,
paddle::operators::CPUUniformRandomKernel<float>,
paddle::operators::CPUUniformRandomKernel<double>,
paddle::operators::CPUUniformRandomKernel<paddle::platform::bfloat16>);
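The bfloat16 specialization of UniformRealDistribution above samples with a float distribution and then narrows each value, because std::uniform_real_distribution is only defined for float, double and long double. For intuition only, a rough numpy sketch of that narrowing step (the helper name is made up, not Paddle API): bfloat16 keeps float32's exponent and the top 7 mantissa bits, so the uint16 payload Paddle stores is essentially the upper half of the float32 bit pattern.

import numpy as np

def uniform_bf16_bits(size, low, high, seed):
    # Sample in float32, then keep the upper 16 bits of each bit pattern,
    # which is (ignoring rounding) the bfloat16 value stored as uint16.
    rng = np.random.default_rng(seed)
    samples = rng.uniform(low, high, size).astype(np.float32)
    return (samples.view(np.uint32) >> 16).astype(np.uint16)

bits = uniform_bf16_bits(8, -5.0, 10.0, seed=10)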
......@@ -24,9 +24,9 @@ namespace operators {
using Tensor = framework::Tensor;
inline std::vector<int64_t> GetNewDataFromShapeTensor(
const Tensor *new_data_tensor) {
const Tensor* new_data_tensor) {
if (new_data_tensor->type() == framework::proto::VarType::INT64) {
auto *new_data = new_data_tensor->data<int64_t>();
auto* new_data = new_data_tensor->data<int64_t>();
framework::Tensor cpu_starts_tensor;
if (platform::is_gpu_place(new_data_tensor->place())) {
TensorCopySync(*new_data_tensor, platform::CPUPlace(),
......@@ -37,7 +37,7 @@ inline std::vector<int64_t> GetNewDataFromShapeTensor(
new_data + new_data_tensor->numel());
return vec_new_data;
} else if (new_data_tensor->type() == framework::proto::VarType::INT32) {
auto *new_data = new_data_tensor->data<int32_t>();
auto* new_data = new_data_tensor->data<int32_t>();
std::vector<int64_t> vec_new_data;
framework::Tensor cpu_starts_tensor;
if (platform::is_gpu_place(new_data_tensor->place())) {
......@@ -58,7 +58,7 @@ inline std::vector<int64_t> GetNewDataFromShapeTensor(
}
inline std::vector<int64_t> GetNewDataFromShapeTensorList(
const std::vector<const Tensor *> &list_new_shape_tensor) {
const std::vector<const Tensor*>& list_new_shape_tensor) {
std::vector<int64_t> vec_new_shape;
vec_new_shape.reserve(list_new_shape_tensor.size());
for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) {
......@@ -97,6 +97,5 @@ inline std::vector<int64_t> GetNewDataFromShapeTensorList(
return vec_new_shape;
}
} // namespace operators
} // namespace paddle
......@@ -245,7 +245,7 @@ class UniformInitializer(Initializer):
self._seed = block.program.random_seed
# to be compatible of fp16 initializers
if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
if var.dtype == VarDesc.VarType.FP16:
out_dtype = VarDesc.VarType.FP32
out_var = block.create_var(
name=unique_name.generate(".".join(
......@@ -274,7 +274,7 @@ class UniformInitializer(Initializer):
},
stop_gradient=True)
if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
if var.dtype == VarDesc.VarType.FP16:
block.append_op(
type="cast",
inputs={"X": out_var},
......@@ -540,7 +540,8 @@ class XavierInitializer(Initializer):
self._seed = block.program.random_seed
# to be compatible of fp16 initalizers
if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
if var.dtype == VarDesc.VarType.FP16 or (
var.dtype == VarDesc.VarType.BF16 and not self._uniform):
out_dtype = VarDesc.VarType.FP32
out_var = block.create_var(
name=unique_name.generate(".".join(
......@@ -582,7 +583,8 @@ class XavierInitializer(Initializer):
},
stop_gradient=True)
if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
if var.dtype == VarDesc.VarType.FP16 or (
var.dtype == VarDesc.VarType.BF16 and not self._uniform):
block.append_op(
type="cast",
inputs={"X": out_var},
......@@ -671,7 +673,8 @@ class MSRAInitializer(Initializer):
self._seed = block.program.random_seed
# to be compatible of fp16 initalizers
if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
if var.dtype == VarDesc.VarType.FP16 or (
var.dtype == VarDesc.VarType.BF16 and not self._uniform):
out_dtype = VarDesc.VarType.FP32
out_var = block.create_var(
name=unique_name.generate(".".join(
......@@ -713,7 +716,8 @@ class MSRAInitializer(Initializer):
},
stop_gradient=True)
if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
if var.dtype == VarDesc.VarType.FP16 or (
var.dtype == VarDesc.VarType.BF16 and not self._uniform):
block.append_op(
type="cast",
inputs={"X": out_var},
......
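Taken together, the initializer.py changes above mean FP16 parameters still take the create-in-FP32-then-cast path, while BF16 parameters take it only when the underlying op is gaussian_random (Xavier/MSRA with uniform=False), since uniform_random now has a native BF16 CPU kernel; UniformInitializer itself never needs the cast for BF16. A hedged restatement of that predicate (the helper name is illustrative, not part of Paddle):

import paddle.fluid.core as core

def needs_fp32_staging(var_dtype, uniform):
    # FP16 always initializes in FP32 and casts back; BF16 does so only on
    # the gaussian_random path, because uniform_random now emits BF16 directly.
    return var_dtype == core.VarDesc.VarType.FP16 or (
        var_dtype == core.VarDesc.VarType.BF16 and not uniform)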
......@@ -10524,10 +10524,10 @@ def uniform_random_batch_size_like(input,
"""
check_variable_and_dtype(input, 'Input', ("float32", 'float64'),
check_variable_and_dtype(input, 'Input', ("float32", 'float64', "uint16"),
'uniform_random_batch_size_like')
check_type(shape, 'shape', (list, tuple), 'uniform_random_batch_size_like')
check_dtype(dtype, 'dtype', ('float32', 'float64'),
check_dtype(dtype, 'dtype', ('float32', 'float64', "uint16"),
'uniform_random_batch_size_like')
helper = LayerHelper('uniform_random_batch_size_like', **locals())
......@@ -15121,7 +15121,8 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0,
float(max), 'seed', seed, 'dtype', dtype)
check_type(shape, 'shape', (list, tuple, Variable), 'uniform_random/rand')
check_dtype(dtype, 'dtype', ('float32', 'float64'), 'uniform_random/rand')
check_dtype(dtype, 'dtype', ('float32', 'float64', 'uint16'),
'uniform_random/rand')
inputs = dict()
attrs = {'seed': seed, 'min': min, 'max': max, 'dtype': dtype}
......
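With "uint16" accepted by the dtype checks above, the Python API can request bfloat16 output from uniform_random directly. A small CPU-only sketch in the spirit of the new unit tests added by this change (variable names are illustrative):

import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()
startup_program = fluid.Program()
train_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
    # np.uint16 maps to VarType.BF16, as in the tests below.
    out = fluid.layers.uniform_random(
        [2, 3], dtype=np.uint16, min=-5.0, max=10.0, seed=10)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program)
result, = exe.run(train_program, fetch_list=[out])
# result is a uint16 array holding the raw bfloat16 payload.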
......@@ -53,7 +53,7 @@ class TestConstantInitializer(unittest.TestCase):
lod_level=0,
name="param",
initializer=initializer.ConstantInitializer())
num_ops = 2 if dtype in ["float16"] else 1
num_ops = 2 if dtype == "float16" else 1
self.assertEqual(len(block.ops), num_ops)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'fill_constant')
......@@ -72,7 +72,7 @@ class TestConstantInitializer(unittest.TestCase):
lod_level=0,
name="param",
initializer=initializer.ConstantInitializer(2.3))
num_ops = 2 if dtype in ["float16"] else 1
num_ops = 2 if dtype == "float16" else 1
self.assertEqual(len(block.ops), num_ops)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'fill_constant')
......@@ -108,7 +108,7 @@ class TestUniformInitializer(unittest.TestCase):
lod_level=0,
name="param",
initializer=initializer.UniformInitializer())
num_ops = 2 if dtype in ["float16", "uint16"] else 1
num_ops = 2 if dtype == "float16" else 1
self.assertEqual(len(block.ops), num_ops)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'uniform_random')
......@@ -153,7 +153,7 @@ class TestUniformInitializer(unittest.TestCase):
lod_level=0,
name="param",
initializer=initializer.UniformInitializer(-4.2, 3.1, 123))
num_ops = 2 if dtype in ["float16", "uint16"] else 1
num_ops = 2 if dtype == "float16" else 1
self.assertEqual(len(block.ops), num_ops)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'uniform_random')
......@@ -174,7 +174,7 @@ class TestUniformInitializer(unittest.TestCase):
lod_level=0,
name="param",
initializer=initializer.UniformInitializer(-4.2, float(i), 123))
num_ops = 2 if dtype in ["float16", "uint16"] else 1
num_ops = 2 if dtype == "float16" else 1
self.assertEqual(len(block.ops), num_ops)
init_op0 = block.ops[0]
self.assertEqual(init_op0.type, 'uniform_random')
......@@ -195,13 +195,11 @@ class TestUniformInitializer(unittest.TestCase):
def test_uniform_initializer_bf16(self):
"""Test uniform initializer with bfloat16
No cast operator has been added here
"""
block = self.test_uniform_initializer_default_value("uint16")
self.assertTrue(check_cast_op(block.ops[1]))
block = self.test_uniform_initializer(dtype="uint16")
self.assertTrue(check_cast_op(block.ops[1]))
block = self.test_uniform_initializer_two_op("uint16")
self.assertTrue(check_cast_op(block.ops[1]))
class TestNormalInitializer(unittest.TestCase):
......@@ -347,7 +345,9 @@ class TestXavierInitializer(unittest.TestCase):
self.assertAlmostEqual(init_op.attr('std'), std, delta=DELTA)
self.assertEqual(init_op.attr('seed'), 0)
def test_xavier_initializer_supplied_arguments(self, dtype="float32"):
def test_xavier_initializer_supplied_arguments(self,
dtype="float32",
uniform=True):
"""Test the Xavier initializer with supplied arguments
"""
program = framework.Program()
......@@ -359,14 +359,18 @@ class TestXavierInitializer(unittest.TestCase):
lod_level=0,
name="param",
initializer=initializer.XavierInitializer(
fan_in=12, fan_out=23, seed=134))
num_ops = 2 if dtype in ["float16", "uint16"] else 1
uniform=uniform, fan_in=12, fan_out=23, seed=134))
num_ops = 2 if (dtype == "float16" or (dtype == "uint16" and
not uniform)) else 1
self.assertEqual(len(block.ops), num_ops)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'uniform_random')
limit = np.sqrt(6.0 / (12 + 23))
self.assertAlmostEqual(init_op.attr('min'), -limit, delta=DELTA)
self.assertAlmostEqual(init_op.attr('max'), limit, delta=DELTA)
if uniform:
self.assertEqual(init_op.type, 'uniform_random')
limit = np.sqrt(6.0 / (12 + 23))
self.assertAlmostEqual(init_op.attr('min'), -limit, delta=DELTA)
self.assertAlmostEqual(init_op.attr('max'), limit, delta=DELTA)
else:
self.assertEqual(init_op.type, 'gaussian_random')
self.assertEqual(init_op.attr('seed'), 134)
return block
......@@ -379,8 +383,12 @@ class TestXavierInitializer(unittest.TestCase):
def test_xavier_initializer_bf16(self):
"""Test the Xavier initializer with bfloat16
"""
block = self.test_xavier_initializer_supplied_arguments("uint16")
self.assertTrue(check_cast_op(block.ops[1]))
block_uniform = self.test_xavier_initializer_supplied_arguments(
"uint16")
self.assertEqual(len(block_uniform.ops), 1)
block_gaussian = self.test_xavier_initializer_supplied_arguments(
"uint16", False)
self.assertTrue(check_cast_op(block_gaussian.ops[1]))
class TestMSRAInitializer(unittest.TestCase):
......@@ -483,7 +491,7 @@ class TestMSRAInitializer(unittest.TestCase):
name="param",
initializer=initializer.MSRAInitializer(
fan_in=12, seed=134))
num_ops = 2 if dtype in ["float16", "uint16"] else 1
num_ops = 2 if dtype == "float16" else 1
self.assertEqual(len(block.ops), num_ops)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'uniform_random')
......@@ -503,7 +511,6 @@ class TestMSRAInitializer(unittest.TestCase):
"""Test the MSRA initializer with bfloat16
"""
block = self.test_msra_initializer_supplied_arguments("uint16")
self.assertTrue(check_cast_op(block.ops[1]))
class TestBilinearInitializer(unittest.TestCase):
......
......@@ -225,7 +225,7 @@ class TestUniform(unittest.TestCase):
lod_level=0,
name="param",
initializer=initializer.Uniform())
num_ops = 2 if dtype in ["float16", "uint16"] else 1
num_ops = 2 if dtype == "float16" else 1
self.assertEqual(len(block.ops), num_ops)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'uniform_random')
......@@ -256,7 +256,7 @@ class TestUniform(unittest.TestCase):
lod_level=0,
name="param",
initializer=initializer.Uniform())
num_ops = 2 if dtype in ["float16", "uint16"] else 1
num_ops = 2 if dtype == "float16" else 1
self.assertEqual(len(block.ops), num_ops)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'uniform_random')
......@@ -287,7 +287,7 @@ class TestUniform(unittest.TestCase):
lod_level=0,
name="param",
initializer=initializer.Uniform(min_value, max_vlaue))
num_ops = 2 if dtype in ["float16", "uint16"] else 1
num_ops = 2 if dtype == "float16" else 1
self.assertEqual(len(block.ops), num_ops)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'uniform_random')
......@@ -317,7 +317,7 @@ class TestUniform(unittest.TestCase):
lod_level=0,
name="param",
initializer=initializer.Uniform(min_value, float(i)))
num_ops = 2 if dtype in ["float16", "uint16"] else 1
num_ops = 2 if dtype == "float16" else 1
self.assertEqual(len(block.ops), num_ops)
init_op0 = block.ops[0]
self.assertEqual(init_op0.type, 'uniform_random')
......@@ -343,11 +343,8 @@ class TestUniform(unittest.TestCase):
"""Test uniform initializer with bfloat16
"""
block = self.test_uniform_initializer_default_value("uint16") #bfloat16
self.assertTrue(check_cast_op(block.ops[1]))
block = self.test_uniform_initializer(dtype="uint16") #bfloat16
self.assertTrue(check_cast_op(block.ops[1]))
block = self.test_uniform_initializer_two_op("uint16") #bfloat16
self.assertTrue(check_cast_op(block.ops[1]))
def test_uniform_initializer_dygraph(self):
"""Test uniform initializer in dygraph model.
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest, convert_uint16_to_float, convert_float_to_uint16
import paddle
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
from paddle.fluid.tests.unittests.test_uniform_random_op import output_hist, output_hist_diag
class TestUniformRandomOpBF16(OpTest):
def setUp(self):
self.op_type = "uniform_random"
self.dtype = "uint16"
self.inputs = {}
self.init_attrs()
self.outputs = {"Out": np.zeros((1000, 784)).astype("uint16")}
def init_attrs(self):
self.attrs = {
"shape": [1000, 784],
"min": -5.0,
"max": 10.0,
"seed": 10,
'dtype': int(core.VarDesc.VarType.BF16)
}
self.output_hist = output_hist
def verify_output(self, outs):
if np.array(outs[0]).dtype == np.uint16:
result = convert_uint16_to_float(np.array(outs[0]))
else:
result = np.array(outs[0])
hist, prob = self.output_hist(result)
self.assertTrue(
np.allclose(
hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
def test_check_output(self):
outs = self.calc_output(core.CPUPlace())
outs = [np.array(out) for out in outs]
outs.sort(key=len)
self.verify_output(outs)
class TestUniformRandomOpBF16AttrTensorList(TestUniformRandomOpBF16):
def setUp(self):
self.op_type = "uniform_random"
self.new_shape = (1000, 784)
self.dtype = "uint16"
shape_tensor = []
for index, ele in enumerate(self.new_shape):
shape_tensor.append(("x" + str(index), np.ones(
(1)).astype("int64") * ele))
self.inputs = {'ShapeTensorList': shape_tensor}
self.init_attrs()
self.outputs = {"Out": np.zeros((1000, 784)).astype("uint16")}
def init_attrs(self):
self.attrs = {
"min": -5.0,
"max": 10.0,
"seed": 10,
'dtype': int(core.VarDesc.VarType.BF16)
}
self.output_hist = output_hist
class TestUniformRandomOpBF16AttrTensorInt32(
TestUniformRandomOpBF16AttrTensorList):
def setUp(self):
self.op_type = "uniform_random"
self.dtype = "uint16"
self.inputs = {"ShapeTensor": np.array([1000, 784]).astype("int32")}
self.init_attrs()
self.outputs = {"Out": np.zeros((1000, 784)).astype("uint16")}
class TestUniformRandomOpBF16WithDiagInit(TestUniformRandomOpBF16):
def init_attrs(self):
self.attrs = {
"shape": [1000, 784],
"min": -5.0,
"max": 10.0,
"seed": 10,
"diag_num": 784,
"diag_step": 784,
"diag_val": 1.0,
'dtype': int(core.VarDesc.VarType.BF16)
}
self.output_hist = output_hist_diag
class TestUniformRandomOpBF16SelectedRows(unittest.TestCase):
def test_check_output(self):
self.check_with_place(core.CPUPlace())
def check_with_place(self, place):
scope = core.Scope()
out = scope.var("X").get_selected_rows()
paddle.seed(10)
op = Operator(
"uniform_random",
Out="X",
shape=[1000, 784],
min=-5.0,
max=10.0,
seed=10,
dtype=int(core.VarDesc.VarType.BF16))
op.run(scope, place)
self.assertEqual(out.get_tensor().shape(), [1000, 784])
result = convert_uint16_to_float(np.array(out.get_tensor()))
hist, prob = output_hist(result)
self.assertTrue(
np.allclose(
hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
class TestUniformRandomOpBF16SelectedRowsWithDiagInit(
TestUniformRandomOpBF16SelectedRows):
def check_with_place(self, place):
scope = core.Scope()
out = scope.var("X").get_selected_rows()
paddle.seed(10)
op = Operator(
"uniform_random",
Out="X",
shape=[500, 784],
min=-5.0,
max=10.0,
seed=10,
diag_num=500,
diag_step=784,
diag_val=1.0,
dtype=int(core.VarDesc.VarType.BF16))
op.run(scope, place)
self.assertEqual(out.get_tensor().shape(), [500, 784])
result = convert_uint16_to_float(np.array(out.get_tensor()))
hist, prob = output_hist(result)
self.assertTrue(
np.allclose(
hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
class TestUniformRandomOpBF16AttrTensorAPI(unittest.TestCase):
def test_attr_tensor_API(self):
startup_program = fluid.Program()
train_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
dim_tensor = fluid.layers.fill_constant([1], "int64", 3)
ret = fluid.layers.nn.uniform_random(
[1, dim_tensor, 2], dtype=np.uint16)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)
outs = exe.run(train_program, fetch_list=[ret])
class TestUniformRandomOpAPISeed(unittest.TestCase):
def test_attr_tensor_API(self):
_seed = 10
gen = paddle.seed(_seed)
gen._is_init_py = False
startup_program = fluid.Program()
train_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
_min = 5
_max = 10
ret = fluid.layers.nn.uniform_random(
[2, 3, 2], min=_min, max=_max, seed=_seed)
ret_2 = fluid.layers.nn.uniform_random(
[2, 3, 2], min=_min, max=_max, seed=_seed)
res = fluid.layers.equal(ret, ret_2)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)
ret_value, cmp_value = exe.run(train_program, fetch_list=[ret, res])
self.assertTrue(np.array(cmp_value).all())
for i in ret_value.flatten():
self.assertGreaterEqual(i, _min)
self.assertLess(i, _max)
class TestUniformRandomOpBF16SelectedRowsShapeTensor(unittest.TestCase):
def test_check_output(self):
place = core.CPUPlace()
scope = core.Scope()
out = scope.var("X").get_selected_rows()
shape_tensor = scope.var("Shape").get_tensor()
shape_tensor.set(np.array([1000, 784]).astype("int64"), place)
paddle.seed(10)
op = Operator(
"uniform_random",
ShapeTensor="Shape",
Out="X",
min=-5.0,
max=10.0,
seed=10,
dtype=int(core.VarDesc.VarType.BF16))
op.run(scope, place)
self.assertEqual(out.get_tensor().shape(), [1000, 784])
result = convert_uint16_to_float(np.array(out.get_tensor()))
hist, prob = output_hist(result)
self.assertTrue(
np.allclose(
hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
class TestUniformRandomOpBF16SelectedRowsShapeTensorList(
TestUniformRandomOpBF16SelectedRowsShapeTensor):
def test_check_output(self):
place = core.CPUPlace()
scope = core.Scope()
out = scope.var("X").get_selected_rows()
shape_1 = scope.var("shape1").get_tensor()
shape_1.set(np.array([1000]).astype("int64"), place)
shape_2 = scope.var("shape2").get_tensor()
shape_2.set(np.array([784]).astype("int64"), place)
paddle.seed(10)
op = Operator(
"uniform_random",
ShapeTensorList=["shape1", "shape2"],
Out="X",
min=-5.0,
max=10.0,
seed=10,
dtype=int(core.VarDesc.VarType.BF16))
op.run(scope, place)
self.assertEqual(out.get_tensor().shape(), [1000, 784])
result = convert_uint16_to_float(np.array(out.get_tensor()))
hist, prob = output_hist(result)
self.assertTrue(
np.allclose(
hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
class TestUniformRandomBatchSizeLikeOpBF16API(unittest.TestCase):
def test_attr_tensorlist_int32_API(self):
startup_program = fluid.Program()
train_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
input = fluid.data(name="input", shape=[1, 3], dtype='uint16')
out_1 = fluid.layers.uniform_random_batch_size_like(
input, [2, 4], dtype=np.uint16) # out_1.shape=[1, 4]
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)
outs = exe.run(train_program, fetch_list=[out_1])
if __name__ == "__main__":
from paddle import enable_static
enable_static()
unittest.main()
......@@ -498,6 +498,7 @@ STATIC_MODE_TESTING_LIST = [
'test_truncated_gaussian_random_op',
'test_unbind_op',
'test_unfold_op',
'test_uniform_random_bf16_op',
'test_uniform_random_op',
'test_unique',
'test_unique_with_counts',
......