提交 1fda6e1c 编写于 作者: F FDInSky 提交者: lvmengsi

test=develop enhance uniform_random op python api (#20376)

上级 3c185cfa
...@@ -75,7 +75,11 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> { ...@@ -75,7 +75,11 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> {
auto diag_val = static_cast<T>(ctx.Attr<float>("diag_val")); auto diag_val = static_cast<T>(ctx.Attr<float>("diag_val"));
if (diag_num > 0) { if (diag_num > 0) {
PADDLE_ENFORCE_GT(size, (diag_num - 1) * (diag_step + 1), PADDLE_ENFORCE_GT(size, (diag_num - 1) * (diag_step + 1),
"The index of diagonal elements is out of bounds"); "ShapeError: the diagonal's elements is equal (num-1) "
"* (step-1) with num %d, step %d,"
"It should be smaller than %d, but received %d",
diag_num, diag_step, (diag_num - 1) * (diag_step + 1),
size);
for (int64_t i = 0; i < diag_num; ++i) { for (int64_t i = 0; i < diag_num; ++i) {
int64_t pos = i * diag_step + i; int64_t pos = i * diag_step + i;
data[pos] = diag_val; data[pos] = diag_val;
...@@ -118,9 +122,10 @@ class UniformRandomOp : public framework::OperatorWithKernel { ...@@ -118,9 +122,10 @@ class UniformRandomOp : public framework::OperatorWithKernel {
auto shape_dims = ctx->GetInputDim("ShapeTensor"); auto shape_dims = ctx->GetInputDim("ShapeTensor");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
shape_dims.size(), 1, shape_dims.size(), 1,
"Input(ShapeTensor)' dimension size of Op(uniform_random) must be 1." "ShapeError: Input(ShapeTensor)' dimension size of "
"Please check the Attr(shape)'s dimension size of" "Op(uniform_random) must be 1."
"Op(fluid.layers.uniform_random).)"); "But received ShapeTensor's dimensions = %d, shape = [%s]",
shape_dims.size(), shape_dims);
int num_ele = 1; int num_ele = 1;
for (int i = 0; i < shape_dims.size(); ++i) { for (int i = 0; i < shape_dims.size(); ++i) {
num_ele *= shape_dims[i]; num_ele *= shape_dims[i];
......
...@@ -16597,10 +16597,18 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0): ...@@ -16597,10 +16597,18 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0):
""" """
if not (isinstance(shape, (list, tuple, Variable))): if not (isinstance(shape, (list, tuple, Variable))):
raise TypeError("Input shape must be a python list,Variable or tuple.") raise TypeError(
"Input shape must be a python list,Variable or tuple. But received %s"
% (type(shape)))
if not isinstance(dtype, core.VarDesc.VarType): if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype) dtype = convert_np_dtype_to_dtype_(dtype)
if convert_dtype(dtype) not in ['float32', 'float64']:
raise TypeError(
"The attribute dtype in uniform_random op must be float32 or float64, but received %s."
% (convert_dtype(dtype)))
def contain_var(one_list): def contain_var(one_list):
for ele in one_list: for ele in one_list:
if isinstance(ele, Variable): if isinstance(ele, Variable):
......
...@@ -20,6 +20,7 @@ from op_test import OpTest ...@@ -20,6 +20,7 @@ from op_test import OpTest
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.op import Operator from paddle.fluid.op import Operator
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
def output_hist(out): def output_hist(out):
...@@ -116,6 +117,27 @@ class TestUniformRandomOp(OpTest): ...@@ -116,6 +117,27 @@ class TestUniformRandomOp(OpTest):
hist, prob, rtol=0, atol=0.01), "hist: " + str(hist)) hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
class TestUniformRandomOpError(OpTest):
    """Verify that uniform_random rejects invalid inputs with TypeError."""

    def test_errors(self):
        # Build throwaway programs so the layer calls below do not pollute
        # the default program used by other tests.
        main_prog = Program()
        start_prog = Program()
        with program_guard(main_prog, start_prog):

            def feed_lod_tensor_as_shape():
                # A raw LoDTensor is not a list/tuple/Variable, so the
                # shape-type check in uniform_random should reject it.
                bad_shape = fluid.create_lod_tensor(
                    np.zeros((4, 784)), [[1, 1, 1, 1]], fluid.CPUPlace())
                fluid.layers.uniform_random(bad_shape)

            self.assertRaises(TypeError, feed_lod_tensor_as_shape)

            def feed_unsupported_dtype():
                # dtype must be float32/float64; 'int32' should be rejected.
                shape_var = fluid.layers.data(
                    name='x2', shape=[4, 784], dtype='float32')
                fluid.layers.uniform_random(shape_var, 'int32')

            self.assertRaises(TypeError, feed_unsupported_dtype)
class TestUniformRandomOpWithDiagInit(TestUniformRandomOp): class TestUniformRandomOpWithDiagInit(TestUniformRandomOp):
def init_attrs(self): def init_attrs(self):
self.attrs = { self.attrs = {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册