Unverified commit 361c6ccc authored by Zhong Hui, committed by GitHub

OP error message enhancement of l2_normalize, matmul, mean, etc

* fix error messages of l2_normalize, matmul, mean, etc.
* add test cases for those ops
Parent b3520b14
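The change applies one convention across all of the touched operators: input/output existence checks move from bare PADDLE_ENFORCE calls with hand-written strings to the OP_INOUT_CHECK macro, and shape/value checks move to comparing macros (PADDLE_ENFORCE_EQ, PADDLE_ENFORCE_LE, ...) whose messages are wrapped in a platform::errors::* builder and report the received values. A minimal sketch of the target style inside an op's InferShape, assuming a hypothetical operator named some_op (the op name and the rank limit are illustrative, not part of this patch):

void InferShape(framework::InferShapeContext *ctx) const override {
  // Existence checks: OP_INOUT_CHECK(condition, "Input"/"Output", variable name, op name).
  OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "some_op");
  OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "some_op");

  auto x_dims = ctx->GetInputDim("X");
  // Value checks: comparing macro plus an error type; the message includes the received value.
  PADDLE_ENFORCE_LE(x_dims.size(), 6,
                    platform::errors::InvalidArgument(
                        "The rank of Input(X) of some_op should be at most 6, "
                        "but received rank = %d.",
                        x_dims.size()));
}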
......@@ -324,12 +324,9 @@ class MatMulOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("X"),
"Input(X) of MatMulOp should not be null.");
PADDLE_ENFORCE(context->HasInput("Y"),
"Input(Y) of MatMulOp should not be null.");
PADDLE_ENFORCE(context->HasOutput("Out"),
"Output(Out) of MatMulOp should not be null.");
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "matmul");
OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", "matmul");
OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "matmul");
auto dim_x = context->GetInputDim("X");
auto dim_y = context->GetInputDim("Y");
......@@ -349,14 +346,15 @@ class MatMulOp : public framework::OperatorWithKernel {
}
if (context->IsRuntime()) {
PADDLE_ENFORCE(
PADDLE_ENFORCE_EQ(
mat_dim_x.batch_size_ == mat_dim_y.batch_size_ ||
mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0,
"ShapeError: The batch size of the two matrices should be equal, or "
"at least one is zero.\n"
"But received X's shape: %s, Y's shape: %s.",
DumpMatrixShape(mat_dim_x).c_str(),
DumpMatrixShape(mat_dim_y).c_str());
true, platform::errors::InvalidArgument(
"The batch size of the two matrices should be equal, or "
"at least one is zero.\n"
"But received X's shape: %s, Y's shape: %s.",
DumpMatrixShape(mat_dim_x).c_str(),
DumpMatrixShape(mat_dim_y).c_str()));
}
int64_t dim_out_y = mat_dim_y.width_;
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
......@@ -365,23 +363,23 @@ class MatMulOp : public framework::OperatorWithKernel {
if (context->IsRuntime()) {
PADDLE_ENFORCE_LE(
head_number, mat_dim_x.width_,
"ShapeError: Unsatisfied mkl acceleration library requirements: "
"The number of heads "
"(%d) must be equal to X's width. But received X's shape: %s.",
head_number, DumpMatrixShape(mat_dim_x).c_str());
platform::errors::InvalidArgument(
"Unsatisfied mkl acceleration library requirements: "
"The number of heads "
"(%d) must be equal to X's width. But received X's shape: %s.",
head_number, DumpMatrixShape(mat_dim_x).c_str()));
if (!split_vertical_y && head_number > 0) {
dim_out_y = head_number * mat_dim_y.width_;
}
}
#else
PADDLE_ENFORCE_EQ(
mat_dim_x.width_, mat_dim_y.height_,
platform::errors::InvalidArgument(
"ShapeError: Input X's width should be equal to the Y's height, "
"but received X's shape: [%s],"
"Y's shape: [%s].",
dim_x, dim_y));
PADDLE_ENFORCE_EQ(mat_dim_x.width_, mat_dim_y.height_,
platform::errors::InvalidArgument(
"Input X's width should be equal to the Y's height, "
"but received X's shape: [%s],"
"Y's shape: [%s].",
dim_x, dim_y));
#endif
std::vector<int64_t> dim_out;
......@@ -520,10 +518,10 @@ class MatMulOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(context->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "matmul");
OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", "matmul");
OP_INOUT_CHECK(context->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "matmul");
auto x_dims = context->GetInputDim("X");
auto y_dims = context->GetInputDim("Y");
......
......@@ -25,10 +25,8 @@ class MeanOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of MeanOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of MeanOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "mean");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "mean");
ctx->SetOutputDim("Out", {1});
}
};
......
......@@ -59,17 +59,20 @@ class MeanCUDAKernel : public framework::OpKernel<T> {
auto err = cub::DeviceReduce::Sum(nullptr, temp_storage_bytes, trans_x,
out_data, size_prob, stream);
PADDLE_ENFORCE_CUDA_SUCCESS(err,
"MeanOP failed to get reduce workspace size",
cudaGetErrorString(err));
PADDLE_ENFORCE_CUDA_SUCCESS(
err, platform::errors::External(
"MeanOP failed to get reduce workspace size %s.",
cudaGetErrorString(err)));
framework::Tensor tmp;
auto* temp_storage = tmp.mutable_data<uint8_t>(
framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
context.GetPlace());
err = cub::DeviceReduce::Sum(temp_storage, temp_storage_bytes, trans_x,
out_data, size_prob, stream);
PADDLE_ENFORCE_CUDA_SUCCESS(err, "MeanOP failed to run reduce computation",
cudaGetErrorString(err));
PADDLE_ENFORCE_CUDA_SUCCESS(
err, platform::errors::External(
"MeanOP failed to run CUDA reduce computation: %s.",
cudaGetErrorString(err)));
}
};
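The CUDA-side checks in this kernel follow the same convention: keep the returned cudaError_t, pass it to PADDLE_ENFORCE_CUDA_SUCCESS, and wrap the text in platform::errors::External with cudaGetErrorString(err) substituted into the message. A small sketch, assuming an out_data buffer, its size num_bytes, and a CUDA stream already in scope (all illustrative names, not from this patch):

// Zero an output buffer and surface any CUDA failure as an External error.
cudaError_t err = cudaMemsetAsync(out_data, 0, num_bytes, stream);
PADDLE_ENFORCE_CUDA_SUCCESS(
    err, platform::errors::External(
             "Failed to zero the output buffer: %s.", cudaGetErrorString(err)));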
......@@ -78,11 +81,11 @@ class MeanCUDAGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto OG = context.Input<Tensor>(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(
OG->numel(), 1,
platform::errors::InvalidArgument(
"Mean Gradient Input Tensor len should be 1. But received %d",
OG->numel()));
PADDLE_ENFORCE_EQ(OG->numel(), 1,
platform::errors::InvalidArgument(
"Mean Gradient Input Tensor len should be 1. But "
"received Out@Grad's elements num is %d.",
OG->numel()));
auto IG = context.Output<Tensor>(framework::GradVarName("X"));
IG->mutable_data<T>(context.GetPlace());
......
......@@ -50,7 +50,11 @@ class MeanGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto OG = context.Input<Tensor>(framework::GradVarName("Out"));
PADDLE_ENFORCE(OG->numel() == 1, "Mean Gradient should be scalar");
PADDLE_ENFORCE_EQ(OG->numel(), 1UL,
platform::errors::InvalidArgument(
"Mean Gradient should be scalar. But received "
"Out@Grad's elements num is %d.",
OG->numel()));
auto IG = context.Output<Tensor>(framework::GradVarName("X"));
IG->mutable_data<T>(context.GetPlace());
......
......@@ -74,10 +74,9 @@ class MulOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(
x_mat_dims[1], y_mat_dims[0],
platform::errors::InvalidArgument(
"After flatten the input tensor X and Y to 2-D dimensions "
"matrix X1 and Y1, the matrix X1's width must be equal with matrix "
"Y1's height. But received X's shape = [%s], X1's shape = [%s], "
"X1's "
"After flatten the input tensor X and Y to 2-D dimensions matrix "
"X1 and Y1, the matrix X1's width must be equal with matrix Y1's "
"height. But received X's shape = [%s], X1's shape = [%s], X1's "
"width = %s; Y's shape = [%s], Y1's shape = [%s], Y1's height = "
"%s.",
x_dims, x_mat_dims, x_mat_dims[1], y_dims, y_mat_dims,
......@@ -212,10 +211,10 @@ class MulGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "mul");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "mul");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "mul");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
......@@ -253,9 +252,9 @@ class MulDoubleGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE(ctx->HasInput("DOut"), "Input(DOut) should not be null");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "mul");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "mul");
OP_INOUT_CHECK(ctx->HasInput("DOut"), "Input", "DOut", "mul");
if (ctx->HasOutput("DDOut") &&
(ctx->HasInput("DDX") || (ctx->HasInput("DDY")))) {
......
......@@ -43,7 +43,11 @@ struct Array {
template <typename VectorLikeType>
static inline Array<T, ElementCount> From(const VectorLikeType& vec) {
PADDLE_ENFORCE_EQ(vec.size(), ElementCount, "size not match");
PADDLE_ENFORCE_EQ(vec.size(), ElementCount,
platform::errors::InvalidArgument(
"Cub reduce Array: size not match. Received "
"vec.size() %d != ElementCount %d.",
vec.size(), ElementCount));
size_t n = static_cast<size_t>(vec.size());
Array<T, ElementCount> ret;
for (size_t i = 0; i < n; ++i) ret[i] = vec[i];
......@@ -159,13 +163,20 @@ static inline int GetDesiredBlockDim(int block_dim) {
static inline void CheckReduceRankIsValid(int reduce_rank, int rank) {
if (rank % 2 == 0) {
PADDLE_ENFORCE_EQ(reduce_rank, rank / 2);
PADDLE_ENFORCE_EQ(reduce_rank, rank / 2,
platform::errors::InvalidArgument(
"ReduceOp: invalid reduce rank. When rank = %d, "
"reduce_rank must be %d, but got %d.",
rank, rank / 2, reduce_rank));
} else {
auto lower_rank = (rank - 1) / 2;
auto upper_rank = (rank + 1) / 2;
PADDLE_ENFORCE(reduce_rank == lower_rank || reduce_rank == upper_rank,
"When rank = %d, reduce_rank must be %d or %d, but got %d",
rank, lower_rank, upper_rank, reduce_rank);
PADDLE_ENFORCE_EQ(
reduce_rank == lower_rank || reduce_rank == upper_rank, true,
platform::errors::InvalidArgument(
"ReduceOp: invalid reduce rank. When rank = %d, reduce_rank "
"must be %d or %d, but got %d.",
rank, lower_rank, upper_rank, reduce_rank));
}
}
......
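The rule enforced by CheckReduceRankIsValid can be stated compactly: for an even tensor rank the reduce rank must be exactly rank / 2, while for an odd rank it must be either (rank - 1) / 2 or (rank + 1) / 2. A self-contained sketch of the same predicate (IsValidReduceRank is a hypothetical helper, not part of the patch):

#include <cassert>

// Mirrors the predicate behind CheckReduceRankIsValid above.
static bool IsValidReduceRank(int reduce_rank, int rank) {
  if (rank % 2 == 0) {
    return reduce_rank == rank / 2;
  }
  const int lower_rank = (rank - 1) / 2;
  const int upper_rank = (rank + 1) / 2;
  return reduce_rank == lower_rank || reduce_rank == upper_rank;
}

int main() {
  assert(IsValidReduceRank(2, 4));   // even rank: only rank / 2 is accepted
  assert(!IsValidReduceRank(1, 4));  // this is the case that raises InvalidArgument
  assert(IsValidReduceRank(2, 5));   // odd rank: (rank - 1) / 2 ...
  assert(IsValidReduceRank(3, 5));   // ... or (rank + 1) / 2
  return 0;
}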
......@@ -264,31 +264,31 @@ class ReduceOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ReduceOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ReduceOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ReduceOp");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ReduceOp");
auto x_dims = ctx->GetInputDim("X");
auto x_rank = x_dims.size();
PADDLE_ENFORCE_LE(x_rank, 6,
"ShapeError: The input tensor X's dimensions of Reduce "
"should be less equal than 6. But received X's "
"dimensions = %d, X's shape = [%s].",
x_rank, x_dims);
platform::errors::InvalidArgument(
"The input tensor X's dimensions of ReduceOp "
"should be less equal than 6. But received X's "
"dimensions = %d, X's shape = [%s].",
x_rank, x_dims));
auto dims = ctx->Attrs().Get<std::vector<int>>("dim");
PADDLE_ENFORCE_GT(
dims.size(), 0,
"ShapeError: The input dim dimensions of Reduce "
"should be greater than 0. But received the dim dimesions of Reduce "
" = %d",
dims.size());
PADDLE_ENFORCE_GT(dims.size(), 0,
platform::errors::InvalidArgument(
"The input dim dimensions of ReduceOp "
"should be greater than 0. But received the dim "
"dimesions of Reduce = %d.",
dims.size()));
for (size_t i = 0; i < dims.size(); ++i) {
PADDLE_ENFORCE_LT(dims[i], x_rank,
"ShapeError: The reduce dim index %d should be in the "
"range [-dimension(X), dimension(X)]."
"which dimesion = %d, But received dim index = %d",
i, x_rank, dims[i]);
platform::errors::InvalidArgument(
"The reduce dim index %d should be in the "
"range [-dimension(X), dimension(X)] "
"which dimesion = %d. But received dim index = %d.",
i, x_rank, dims[i]));
if (dims[i] < 0) dims[i] = x_rank + dims[i];
}
sort(dims.begin(), dims.end());
......@@ -346,19 +346,24 @@ class ReduceGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ReduceOp");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "ReduceOp");
auto x_dims = ctx->GetInputDim("X");
auto x_rank = x_dims.size();
PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
PADDLE_ENFORCE_LE(x_rank, 6,
platform::errors::InvalidArgument(
"Tensors with rank at most 6 are supported by "
"ReduceOp. Received tensor with rank %d.",
x_rank));
auto dims = ctx->Attrs().Get<std::vector<int>>("dim");
for (size_t i = 0; i < dims.size(); ++i) {
PADDLE_ENFORCE_LT(dims[i], x_rank,
"ShapeError: The reduce dim index %d should be in the "
"range [-dimension(X), dimension(X)]."
"which dimesion = %d, But received dim index = %d",
i, x_rank, dims[i]);
platform::errors::InvalidArgument(
"The reduce dim index %d should be in the "
"range [-dimension(X), dimension(X)], "
"which dimesion = %d. But received dim index = %d.",
i, x_rank, dims[i]));
if (dims[i] < 0) dims[i] = x_rank + dims[i];
}
sort(dims.begin(), dims.end());
......
......@@ -94,17 +94,24 @@ class UniformRandomOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "UniformRandom");
PADDLE_ENFORCE_LT(ctx->Attrs().Get<float>("min"),
ctx->Attrs().Get<float>("max"),
platform::errors::InvalidArgument(
"uniform_random's min must less then max"));
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "UniformRandomOp");
PADDLE_ENFORCE_LT(
ctx->Attrs().Get<float>("min"), ctx->Attrs().Get<float>("max"),
platform::errors::InvalidArgument(
"The uniform_random's min must less then max. But received min = "
"%f great than or equal max = %f.",
ctx->Attrs().Get<float>("min"), ctx->Attrs().Get<float>("max")));
PADDLE_ENFORCE_GE(ctx->Attrs().Get<int>("diag_num"), 0,
platform::errors::InvalidArgument(
"diag_num must greater than or equal 0"));
"The uniform_random's diag_num must greater than or "
"equal 0. But recevied diag_num (%d) < 0.",
ctx->Attrs().Get<int>("diag_num")));
PADDLE_ENFORCE_GE(ctx->Attrs().Get<int>("diag_step"), 0,
platform::errors::InvalidArgument(
"diag_step must greater than or equal 0"));
"The uniform_random's diag_step must greater than or "
"equal 0. But recevied diag_step (%d) < 0.",
ctx->Attrs().Get<int>("diag_step")));
if (ctx->HasInputs("ShapeTensorList")) {
// top prority shape
......
......@@ -59,8 +59,12 @@ inline std::vector<int64_t> GetNewDataFromShapeTensorList(
vec_new_shape.reserve(list_new_shape_tensor.size());
for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) {
auto tensor = list_new_shape_tensor[i];
PADDLE_ENFORCE_EQ(tensor->dims(), framework::make_ddim({1}),
"shape of dim tensor should be [1]");
PADDLE_ENFORCE_EQ(
tensor->dims(), framework::make_ddim({1}),
platform::errors::InvalidArgument(
"Shape of dim tensor in uniform_random_op should be [1]"
"But received tensor's dim=%s.",
tensor->dims()));
if (tensor->type() == framework::proto::VarType::INT32) {
if (platform::is_gpu_place(tensor->place())) {
......
......@@ -4627,6 +4627,7 @@ def reduce_all(input, dim=None, keep_dim=False, name=None):
# keep_dim=True, x.shape=(2,2), out.shape=(2,1)
"""
check_variable_and_dtype(input, 'input', ('bool'), 'reduce_all')
helper = LayerHelper('reduce_all', **locals())
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
if dim is not None and not isinstance(dim, list):
......@@ -4686,6 +4687,7 @@ def reduce_any(input, dim=None, keep_dim=False, name=None):
# keep_dim=True, x.shape=(2,2), out.shape=(2,1)
"""
check_variable_and_dtype(input, 'input', ('bool'), 'reduce_any')
helper = LayerHelper('reduce_any', **locals())
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
if dim is not None and not isinstance(dim, list):
......@@ -4919,8 +4921,9 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None):
if len(x.shape) == 1:
axis = 0
helper = LayerHelper("l2_normalize", **locals())
check_variable_and_dtype(x, "X", ("float32", "float64"), "norm")
helper = LayerHelper("l2_normalize", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
norm = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
......@@ -9951,6 +9954,11 @@ def uniform_random_batch_size_like(input,
"""
check_variable_and_dtype(input, 'Input', ("float32", 'float64'),
'uniform_random_batch_size_like')
check_type(shape, 'shape', (list, tuple), 'uniform_random_batch_size_like')
check_dtype(dtype, 'dtype', ('float32', 'float64'),
'uniform_random_batch_size_like')
helper = LayerHelper('uniform_random_batch_size_like', **locals())
out = helper.create_variable_for_type_inference(dtype)
......@@ -10634,7 +10642,7 @@ def rank(input):
input = fluid.data(name="input", shape=[3, 100, 100], dtype="float32")
rank = fluid.layers.rank(input) # rank=(3,)
"""
check_type(input, 'input', (Variable), 'input')
ndims = len(input.shape)
out = assign(np.array(ndims, 'int32'))
......@@ -14263,7 +14271,7 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0):
check_type(shape, 'shape', (list, tuple, Variable), 'uniform_random')
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
check_dtype(dtype, 'dtype', ['float32', 'float64'], 'uniform_random')
check_dtype(dtype, 'dtype', ('float32', 'float64'), 'uniform_random')
def get_new_shape_tensor(list_shape):
new_shape_tensor = []
......
......@@ -16,6 +16,8 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from op_test import OpTest, skip_check_grad_ci
......@@ -87,5 +89,16 @@ class TestNormOp5(TestNormOp):
pass
class API_NormTest(unittest.TestCase):
def test_errors(self):
with fluid.program_guard(fluid.Program()):
def test_norm_x_type():
data = fluid.data(name="x", shape=[3, 3], dtype="int64")
out = fluid.layers.l2_normalize(data)
self.assertRaises(TypeError, test_norm_x_type)
if __name__ == '__main__':
unittest.main()
......@@ -138,6 +138,18 @@ class TestAllOpWithKeepDim(OpTest):
self.check_output()
class TestAllOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# The input type of reduce_all_op must be Variable.
input1 = 12
self.assertRaises(TypeError, fluid.layers.reduce_all, input1)
# The input dtype of reduce_all_op must be bool.
input2 = fluid.layers.data(
name='input2', shape=[12, 10], dtype="int32")
self.assertRaises(TypeError, fluid.layers.reduce_all, input2)
class TestAnyOp(OpTest):
def setUp(self):
self.op_type = "reduce_any"
......@@ -174,6 +186,18 @@ class TestAnyOpWithKeepDim(OpTest):
self.check_output()
class TestAnyOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# The input type of reduce_any_op must be Variable.
input1 = 12
self.assertRaises(TypeError, fluid.layers.reduce_any, input1)
# The input dtype of reduce_any_op must be bool.
input2 = fluid.layers.data(
name='input2', shape=[12, 10], dtype="int32")
self.assertRaises(TypeError, fluid.layers.reduce_any, input2)
class Test1DReduce(OpTest):
def setUp(self):
self.op_type = "reduce_sum"
......
......@@ -177,6 +177,12 @@ class TestUniformRandomOpError(unittest.TestCase):
self.assertRaises(TypeError, test_Variable)
def test_Variable2():
x1 = np.zeros((4, 784))
fluid.layers.uniform_random(x1)
self.assertRaises(TypeError, test_Variable2)
def test_dtype():
x2 = fluid.layers.data(
name='x2', shape=[4, 784], dtype='float32')
......@@ -426,5 +432,33 @@ class TestUniformRandomDygraphMode(unittest.TestCase):
self.assertTrue((x_np[i] > 0 and x_np[i] < 1.0))
class TestUniformRandomBatchSizeLikeOpError(unittest.TestCase):
def test_errors(self):
main_prog = Program()
start_prog = Program()
with program_guard(main_prog, start_prog):
def test_Variable():
x1 = fluid.create_lod_tensor(
np.zeros((4, 784)), [[1, 1, 1, 1]], fluid.CPUPlace())
fluid.layers.uniform_random_batch_size_like(x1)
self.assertRaises(TypeError, test_Variable)
def test_shape():
x1 = fluid.layers.data(
name='x2', shape=[4, 784], dtype='float32')
fluid.layers.uniform_random_batch_size_like(x1, shape="shape")
self.assertRaises(TypeError, test_shape)
def test_dtype():
x2 = fluid.layers.data(
name='x2', shape=[4, 784], dtype='float32')
fluid.layers.uniform_random_batch_size_like(x2, 'int32')
self.assertRaises(TypeError, test_dtype)
if __name__ == "__main__":
unittest.main()