未验证 提交 a13a4dbc 编写于 作者: L lilong12 提交者: GitHub

Improving error reporting messages for ops (#24438)

* improve error reporting message
上级 897cec81
...@@ -27,16 +27,18 @@ class CropOp : public framework::OperatorWithKernel { ...@@ -27,16 +27,18 @@ class CropOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Crop");
"Input(X) of CropOp should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Crop");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of CropOp should not be null.");
auto x_dim = ctx->GetInputDim("X"); auto x_dim = ctx->GetInputDim("X");
if (!ctx->HasInput("Y")) { if (!ctx->HasInput("Y")) {
auto shape = ctx->Attrs().Get<std::vector<int>>("shape"); auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
int64_t(shape.size()), x_dim.size(), int64_t(shape.size()), x_dim.size(),
"Shape size should be equal to dimension size of input tensor."); platform::errors::InvalidArgument(
"The number of elements (%d) of CropOp's "
"'shape' attribute should be equal to the number of dimensions "
"(%d) of the Input(X).",
shape.size(), x_dim.size()));
std::vector<int64_t> tensor_shape(shape.size()); std::vector<int64_t> tensor_shape(shape.size());
for (size_t i = 0; i < shape.size(); ++i) { for (size_t i = 0; i < shape.size(); ++i) {
tensor_shape[i] = static_cast<int64_t>(shape[i]); tensor_shape[i] = static_cast<int64_t>(shape[i]);
...@@ -45,8 +47,10 @@ class CropOp : public framework::OperatorWithKernel { ...@@ -45,8 +47,10 @@ class CropOp : public framework::OperatorWithKernel {
} else { } else {
auto y_dim = ctx->GetInputDim("Y"); auto y_dim = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(framework::arity(x_dim), framework::arity(y_dim), PADDLE_ENFORCE_EQ(framework::arity(x_dim), framework::arity(y_dim),
"Tensor rank of both CropOp's " platform::errors::InvalidArgument(
"inputs must be same."); "The number of dimensions (%d) of CropOp's input(X)"
" must be equal to that (%d) of input(Y).",
framework::arity(x_dim), framework::arity(y_dim)));
ctx->SetOutputDim("Out", y_dim); ctx->SetOutputDim("Out", y_dim);
} }
} }
...@@ -163,9 +167,9 @@ class CropOpGrad : public framework::OperatorWithKernel { ...@@ -163,9 +167,9 @@ class CropOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CropGrad");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Input(Out@GRAD) should not be null"); framework::GradVarName("Out"), "CropGrad");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) { if (ctx->HasOutput(x_grad_name)) {
......
...@@ -31,14 +31,23 @@ static std::vector<int> GetOffsets(const framework::ExecutionContext& ctx) { ...@@ -31,14 +31,23 @@ static std::vector<int> GetOffsets(const framework::ExecutionContext& ctx) {
std::vector<int> res; std::vector<int> res;
int rank = ctx.Input<Tensor>("X")->dims().size(); int rank = ctx.Input<Tensor>("X")->dims().size();
if (ctx.HasInput("Offsets")) { if (ctx.HasInput("Offsets")) {
PADDLE_ENFORCE(ctx.Attr<std::vector<int>>("offsets").empty(), PADDLE_ENFORCE_EQ(ctx.Attr<std::vector<int>>("offsets").empty(), true,
"Input 'Offsets' and attribute 'offsets' should not be used " platform::errors::InvalidArgument(
"at the same time."); "Input 'Offsets' and attribute 'offsets' "
"should not be used at the same time for CropOp."));
const auto* offsets_tensor = ctx.Input<Tensor>("Offsets"); const auto* offsets_tensor = ctx.Input<Tensor>("Offsets");
PADDLE_ENFORCE_EQ(offsets_tensor->dims().size(), 1); PADDLE_ENFORCE_EQ(offsets_tensor->dims().size(), 1,
platform::errors::InvalidArgument(
"The number of dimensions of input 'Offsets' for "
"CropOp must be 1, but the value received is %d.",
offsets_tensor->dims().size()));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
rank, offsets_tensor->dims()[0], rank, offsets_tensor->dims()[0],
"Offsets size should be equal to dimension size of input tensor."); platform::errors::InvalidArgument("The number of elements (%d) for "
"input 'Offsets' must be equal to "
"the number of dimensions (%d) "
"of the input tensor.",
offsets_tensor->dims()[0], rank));
const int* offsets_data; const int* offsets_data;
framework::Tensor cpu_tmp_tensor; framework::Tensor cpu_tmp_tensor;
if (platform::is_cpu_place(offsets_tensor->place())) { if (platform::is_cpu_place(offsets_tensor->place())) {
...@@ -53,7 +62,11 @@ static std::vector<int> GetOffsets(const framework::ExecutionContext& ctx) { ...@@ -53,7 +62,11 @@ static std::vector<int> GetOffsets(const framework::ExecutionContext& ctx) {
res = ctx.Attr<std::vector<int>>("offsets"); res = ctx.Attr<std::vector<int>>("offsets");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
rank, static_cast<int>(res.size()), rank, static_cast<int>(res.size()),
"Offsets size should be equal to dimension size of input tensor."); platform::errors::InvalidArgument("The number of elements (%d) for "
"input 'Offsets' must be equal to "
"the number of dimensions (%d) "
"of the input tensor.",
res.size(), rank));
} }
return res; return res;
} }
...@@ -92,6 +105,18 @@ class CropKernel : public framework::OpKernel<T> { ...@@ -92,6 +105,18 @@ class CropKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
int rank = context.Input<Tensor>("X")->dims().size(); int rank = context.Input<Tensor>("X")->dims().size();
PADDLE_ENFORCE_GE(
rank, 1,
platform::errors::InvalidArgument(
"The number of dimensions of the Input(X) for CropOp must be "
"greater than or equal to 1, but the value received is %d.",
rank));
PADDLE_ENFORCE_LE(
rank, 6,
platform::errors::InvalidArgument(
"The number of dimensions of the Input(X) for CropOp must be "
"less than or equal to 6, but the value received is %d.",
rank));
switch (rank) { switch (rank) {
case 1: case 1:
CropFunction<DeviceContext, T, 1>(context); CropFunction<DeviceContext, T, 1>(context);
...@@ -111,9 +136,6 @@ class CropKernel : public framework::OpKernel<T> { ...@@ -111,9 +136,6 @@ class CropKernel : public framework::OpKernel<T> {
case 6: case 6:
CropFunction<DeviceContext, T, 6>(context); CropFunction<DeviceContext, T, 6>(context);
break; break;
default:
PADDLE_THROW(
"CropOp only support tensors with no more than 6 dimensions.");
} }
} }
}; };
...@@ -145,6 +167,18 @@ class CropGradKernel : public framework::OpKernel<T> { ...@@ -145,6 +167,18 @@ class CropGradKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
size_t rank = size_t rank =
context.Input<Tensor>(framework::GradVarName("Out"))->dims().size(); context.Input<Tensor>(framework::GradVarName("Out"))->dims().size();
PADDLE_ENFORCE_GE(
rank, 1, platform::errors::InvalidArgument(
"The number of dimensions of the input 'Out@GRAD' for "
"CropGrad must be greater than or equal "
"to 1, but the value received is %d.",
rank));
PADDLE_ENFORCE_LE(
rank, 6, platform::errors::InvalidArgument(
"The number of dimensions of the input 'Out@GRAD' for "
"CropGrad must be less than or equal "
"to 6, but the value received is %d.",
rank));
switch (rank) { switch (rank) {
case 1: case 1:
CropGradFunction<DeviceContext, T, 1>(context); CropGradFunction<DeviceContext, T, 1>(context);
...@@ -164,9 +198,6 @@ class CropGradKernel : public framework::OpKernel<T> { ...@@ -164,9 +198,6 @@ class CropGradKernel : public framework::OpKernel<T> {
case 6: case 6:
CropGradFunction<DeviceContext, T, 6>(context); CropGradFunction<DeviceContext, T, 6>(context);
break; break;
default:
PADDLE_THROW(
"CropOp only support tensors with no more than 6 dimensions.");
} }
} }
}; };
......
...@@ -27,10 +27,8 @@ class CropTensorOp : public framework::OperatorWithKernel { ...@@ -27,10 +27,8 @@ class CropTensorOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CropTensor");
"Input(X) of Op(crop_tensor) should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "CropTensor");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of Op(crop_tensor) should not be null.");
auto x_dim = ctx->GetInputDim("X"); auto x_dim = ctx->GetInputDim("X");
auto shape = ctx->Attrs().Get<std::vector<int>>("shape"); auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
auto offsets = ctx->Attrs().Get<std::vector<int>>("offsets"); auto offsets = ctx->Attrs().Get<std::vector<int>>("offsets");
...@@ -39,9 +37,11 @@ class CropTensorOp : public framework::OperatorWithKernel { ...@@ -39,9 +37,11 @@ class CropTensorOp : public framework::OperatorWithKernel {
auto inputs_name = ctx->Inputs("ShapeTensor"); auto inputs_name = ctx->Inputs("ShapeTensor");
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
inputs_name.size(), 0, inputs_name.size(), 0,
"Input(ShapeTensor)'size of Op(crop_tensor) can't be zero. " platform::errors::InvalidArgument(
"Please check the Attr(shape)'s size of " "The number of elements of the input 'ShapeTensor' for "
"Op(fluid.layers.crop_tensor)."); "CropTensor must be greater than zero, "
"but the value received is %d.",
inputs_name.size()));
auto out_dims = std::vector<int>(inputs_name.size(), -1); auto out_dims = std::vector<int>(inputs_name.size(), -1);
for (size_t i = 0; i < shape.size(); ++i) { for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] > 0) { if (shape[i] > 0) {
...@@ -59,16 +59,18 @@ class CropTensorOp : public framework::OperatorWithKernel { ...@@ -59,16 +59,18 @@ class CropTensorOp : public framework::OperatorWithKernel {
if (ctx->HasInput("Shape")) { if (ctx->HasInput("Shape")) {
auto shape_dim = ctx->GetInputDim("Shape"); auto shape_dim = ctx->GetInputDim("Shape");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(shape_dim.size(), 1,
shape_dim.size(), 1, platform::errors::InvalidArgument(
"Input(Shape)'s dimension size of Op(crop_tensor) must be 1. " "The number of dimensions of the input "
"Please check the Attr(shape)'s dimension size of " "'Shape' for CropTensor must be 1, "
"Op(fluid.layers.crop_tensor)."); "but the value received is %d.",
shape_dim.size()));
PADDLE_ENFORCE_EQ(shape_dim[0], x_dim.size(), PADDLE_ENFORCE_EQ(shape_dim[0], x_dim.size(),
"Input(Shape)'s size of Op(crop_tensor) must be equal " platform::errors::InvalidArgument(
"to dimension size of input tensor. " "The number of elements (%d) of the input 'Shape' "
"Please check the Attr(shape)'s size of " "for CropTensor must be equal to the number of"
"Op(fluid.layers.crop_tensor)."); " dimensions (%d) of the input.",
shape_dim[0], x_dim.size()));
if (ctx->IsRuntime()) { if (ctx->IsRuntime()) {
// If true, set the shape of Output(Out) according to Input(Shape) in // If true, set the shape of Output(Out) according to Input(Shape) in
// CropTensorKernel with ExecutionContext. Also check LoD in // CropTensorKernel with ExecutionContext. Also check LoD in
...@@ -80,9 +82,13 @@ class CropTensorOp : public framework::OperatorWithKernel { ...@@ -80,9 +82,13 @@ class CropTensorOp : public framework::OperatorWithKernel {
} }
return; return;
} }
PADDLE_ENFORCE_EQ(int64_t(shape.size()), x_dim.size(), PADDLE_ENFORCE_EQ(
"Attr(shape)'size of Op(crop_tensor) should be equal to " int64_t(shape.size()), x_dim.size(),
"dimension size of input tensor."); platform::errors::InvalidArgument(
"The number of elements (%d) of attribute 'shape' for "
"CropTensor must be equal to the number of "
"dimensions (%d) of the input.",
shape.size(), x_dim.size()));
std::vector<int64_t> out_shape(shape.size(), -1); std::vector<int64_t> out_shape(shape.size(), -1);
for (size_t i = 0; i < shape.size(); ++i) { for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] > 0) { if (shape[i] > 0) {
...@@ -242,10 +248,9 @@ class CropTensorOpGrad : public framework::OperatorWithKernel { ...@@ -242,10 +248,9 @@ class CropTensorOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CropTensorGrad");
"Input(X) of Op(crop_tensor) should not be null."); OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true, framework::GradVarName("Out"), "CropTensorGrad");
"Input(Out@GRAD) of Op(crop_tensor) should not be null.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) { if (ctx->HasOutput(x_grad_name)) {
......
...@@ -35,7 +35,10 @@ inline std::vector<int> get_new_data( ...@@ -35,7 +35,10 @@ inline std::vector<int> get_new_data(
auto tensor = list_new_tensor[i]; auto tensor = list_new_tensor[i];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
tensor->dims(), framework::make_ddim({1}), tensor->dims(), framework::make_ddim({1}),
"The tensor's shape in list of Op(crop_tensor) should be [1]."); platform::errors::InvalidArgument(
"The tensor's shape in list of Op(crop_tensor) should be [1], "
"but the value received is %d.",
tensor->dims()));
if (platform::is_gpu_place(tensor->place())) { if (platform::is_gpu_place(tensor->place())) {
framework::Tensor temp; framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp); TensorCopySync(*tensor, platform::CPUPlace(), &temp);
...@@ -56,18 +59,23 @@ static framework::DDim ValidateShape(const std::vector<int> shape, ...@@ -56,18 +59,23 @@ static framework::DDim ValidateShape(const std::vector<int> shape,
auto shape_size = shape.size(); auto shape_size = shape.size();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
in_dim_size, shape_size, in_dim_size, shape_size,
"Attr(shape)'s size of Op(crop_tensor) should be equal " platform::errors::InvalidArgument(
"to that of input Tensor. " "The number of elements (%d) for shape of Op(crop_tensor) should be "
"Please check the Attr(shape)'s size of Op(fluid.layers.crop_tensor)."); "equal to the number of dimensions (%d) of the input tensor.",
shape_size, in_dim_size));
std::vector<int64_t> output_shape(shape.size(), 0); std::vector<int64_t> output_shape(shape.size(), 0);
for (size_t i = 0; i < shape.size(); ++i) { for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] <= 0 && in_dims[i] > 0) { if (shape[i] <= 0 && in_dims[i] > 0) {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(shape[i], 0,
shape[i], 0, platform::errors::InvalidArgument(
"The element in Attr(shape) of Op(crop_tensor) should not be zero."); "The value (%d) of the %uth element for shape of "
PADDLE_ENFORCE_EQ(shape[i], -1, "Op(crop_tensor) should not be zero.",
"When the element in Attr(shape) of Op(crop_tensor) is " shape[i], i));
"negative, only -1 is supported."); PADDLE_ENFORCE_EQ(shape[i], -1, platform::errors::InvalidArgument(
"When the value (%d) of the %uth "
"element for shape of Op(crop_tensor)"
" is negative, only -1 is supported.",
shape[i], i));
output_shape[i] = in_dims[i] - offsets[i]; output_shape[i] = in_dims[i] - offsets[i];
} else { } else {
output_shape[i] = static_cast<int64_t>(shape[i]); output_shape[i] = static_cast<int64_t>(shape[i]);
...@@ -83,9 +91,13 @@ static std::vector<int> GetShape(const framework::ExecutionContext& ctx) { ...@@ -83,9 +91,13 @@ static std::vector<int> GetShape(const framework::ExecutionContext& ctx) {
auto list_new_shape_tensor = ctx.MultiInput<framework::Tensor>("ShapeTensor"); auto list_new_shape_tensor = ctx.MultiInput<framework::Tensor>("ShapeTensor");
if (list_new_shape_tensor.size() > 0) { if (list_new_shape_tensor.size() > 0) {
// have offsets tensor list // have offsets tensor list
PADDLE_ENFORCE_EQ(list_new_shape_tensor.size(), rank, PADDLE_ENFORCE_EQ(
"Input(ShapeTensor)'s length of Op(crop_tensor) should " list_new_shape_tensor.size(), rank,
"be equal to dimension size of input tensor."); platform::errors::InvalidArgument(
"The number of tensors (%d) for the input ShapeTensor of "
"Op(crop_tensor) must be equal to the number of "
"dimensions (%d) of the input.",
list_new_shape_tensor.size(), rank));
res = get_new_data(list_new_shape_tensor); res = get_new_data(list_new_shape_tensor);
return res; return res;
...@@ -122,13 +134,21 @@ static std::vector<int> GetOffsets(const framework::ExecutionContext& ctx) { ...@@ -122,13 +134,21 @@ static std::vector<int> GetOffsets(const framework::ExecutionContext& ctx) {
if (ctx.HasInput("Offsets")) { if (ctx.HasInput("Offsets")) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ctx.Attr<std::vector<int>>("offsets").empty(), true, ctx.Attr<std::vector<int>>("offsets").empty(), true,
"Input 'Offsets' and attribute 'offsets' should not be used " platform::errors::InvalidArgument(
"at the same time."); "Input 'Offsets' and attribute 'offsets' for Op(crop_tensor) "
"cannot be used at the same time."));
const auto* offsets_tensor = ctx.Input<Tensor>("Offsets"); const auto* offsets_tensor = ctx.Input<Tensor>("Offsets");
PADDLE_ENFORCE_EQ(offsets_tensor->dims().size(), 1); PADDLE_ENFORCE_EQ(offsets_tensor->dims().size(), 1,
PADDLE_ENFORCE_EQ( platform::errors::InvalidArgument(
rank, offsets_tensor->dims()[0], "The number of dimensions of input 'Offsets' must "
"Offsets size should be equal to dimension size of input tensor."); "be 1, but the value received is: %d.",
offsets_tensor->dims().size()));
PADDLE_ENFORCE_EQ(rank, offsets_tensor->dims()[0],
platform::errors::InvalidArgument(
"The number of elements (%d) for "
"input 'Offsets' must be equal to "
"the number of dimensions (%d) of the input tensor.",
offsets_tensor->dims()[0], rank));
const int* offsets_data; const int* offsets_data;
framework::Tensor cpu_tmp_tensor; framework::Tensor cpu_tmp_tensor;
if (platform::is_cpu_place(offsets_tensor->place())) { if (platform::is_cpu_place(offsets_tensor->place())) {
...@@ -143,7 +163,11 @@ static std::vector<int> GetOffsets(const framework::ExecutionContext& ctx) { ...@@ -143,7 +163,11 @@ static std::vector<int> GetOffsets(const framework::ExecutionContext& ctx) {
res = ctx.Attr<std::vector<int>>("offsets"); res = ctx.Attr<std::vector<int>>("offsets");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
rank, static_cast<int>(res.size()), rank, static_cast<int>(res.size()),
"Offsets size should be equal to dimension size of input tensor."); platform::errors::InvalidArgument("The number of elements (%d) for "
"input 'Offsets' must be equal to "
"the number of dimensions (%d) "
"of the input tensor.",
static_cast<int>(res.size()), rank));
} }
return res; return res;
} }
...@@ -168,10 +192,13 @@ void CropTensorFunction(const framework::ExecutionContext& context) { ...@@ -168,10 +192,13 @@ void CropTensorFunction(const framework::ExecutionContext& context) {
out_dims = ValidateShape(shape, offsets, x->dims()); out_dims = ValidateShape(shape, offsets, x->dims());
out->mutable_data<T>(out_dims, context.GetPlace()); out->mutable_data<T>(out_dims, context.GetPlace());
for (size_t i = 0; i < offsets.size(); ++i) { for (size_t i = 0; i < offsets.size(); ++i) {
PADDLE_ENFORCE_LE( PADDLE_ENFORCE_LE(offsets[i] + shape[i], x_dims[i],
offsets[i] + shape[i], x_dims[i], platform::errors::InvalidArgument(
"The sum of the Attr(offsets) and Attr(shape) of Op(crop_tensor) " "The sum of the %uth elements of "
"should be less than or equal to corresponding input dimension size."); "offsets (%d) and shape (%d) of Op(crop_tensor) "
"should be less than or "
"equal to the size of %uth dimension of the input.",
i, offsets[i], shape[i], i));
} }
auto x_tensor = EigenTensor<T, D>::From(*x); auto x_tensor = EigenTensor<T, D>::From(*x);
...@@ -192,6 +219,19 @@ class CropTensorKernel : public framework::OpKernel<T> { ...@@ -192,6 +219,19 @@ class CropTensorKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
int rank = context.Input<Tensor>("X")->dims().size(); int rank = context.Input<Tensor>("X")->dims().size();
PADDLE_ENFORCE_GE(
rank, 1,
platform::errors::InvalidArgument(
"The number of dimensions of the input 'x' for "
"Op(crop_tensor) must be greater than or equal to 1, but the "
"value received is %d.",
rank));
PADDLE_ENFORCE_LE(
rank, 6, platform::errors::InvalidArgument(
"The number of dimensions of the input 'x' for "
"Op(crop_tensor) must be less than or equal to 6, but the "
"value received is %d.",
rank));
switch (rank) { switch (rank) {
case 1: case 1:
CropTensorFunction<DeviceContext, T, 1>(context); CropTensorFunction<DeviceContext, T, 1>(context);
...@@ -211,10 +251,6 @@ class CropTensorKernel : public framework::OpKernel<T> { ...@@ -211,10 +251,6 @@ class CropTensorKernel : public framework::OpKernel<T> {
case 6: case 6:
CropTensorFunction<DeviceContext, T, 6>(context); CropTensorFunction<DeviceContext, T, 6>(context);
break; break;
default:
PADDLE_THROW(
"CropTensorOp only support tensors with no more than 6 "
"dimensions.");
} }
} }
}; };
...@@ -246,6 +282,20 @@ class CropTensorGradKernel : public framework::OpKernel<T> { ...@@ -246,6 +282,20 @@ class CropTensorGradKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
size_t rank = size_t rank =
context.Input<Tensor>(framework::GradVarName("Out"))->dims().size(); context.Input<Tensor>(framework::GradVarName("Out"))->dims().size();
PADDLE_ENFORCE_GE(
rank, 1,
platform::errors::InvalidArgument(
"The number of dimensions of the input 'Out@GRAD' for "
"Op(crop_tensor_grad) must be greater than or equal to 1, but the "
"value received is %d.",
rank));
PADDLE_ENFORCE_LE(
rank, 6,
platform::errors::InvalidArgument(
"The number of dimensions of the input 'Out@GRAD' for "
"Op(crop_tensor_grad) must be less than or equal to 6, but the "
"value received is %d.",
rank));
switch (rank) { switch (rank) {
case 1: case 1:
CropTensorGradFunction<DeviceContext, T, 1>(context); CropTensorGradFunction<DeviceContext, T, 1>(context);
...@@ -265,10 +315,6 @@ class CropTensorGradKernel : public framework::OpKernel<T> { ...@@ -265,10 +315,6 @@ class CropTensorGradKernel : public framework::OpKernel<T> {
case 6: case 6:
CropTensorGradFunction<DeviceContext, T, 6>(context); CropTensorGradFunction<DeviceContext, T, 6>(context);
break; break;
default:
PADDLE_THROW(
"CropTensorOp only support tensors with no more than 6 "
"dimensions.");
} }
} }
}; };
......
...@@ -28,9 +28,8 @@ class ExpandOp : public framework::OperatorWithKernel { ...@@ -28,9 +28,8 @@ class ExpandOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null."); OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Expand");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Expand");
"Output(Out) should not be null.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto expand_times = ctx->Attrs().Get<std::vector<int>>("expand_times"); auto expand_times = ctx->Attrs().Get<std::vector<int>>("expand_times");
...@@ -38,11 +37,19 @@ class ExpandOp : public framework::OperatorWithKernel { ...@@ -38,11 +37,19 @@ class ExpandOp : public framework::OperatorWithKernel {
expand_times = std::vector<int>(x_dims.size(), -1); expand_times = std::vector<int>(x_dims.size(), -1);
} }
PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dims.size()), expand_times.size(), PADDLE_ENFORCE_EQ(
"The number of Attr(expand_times)'s value must be equal " static_cast<size_t>(x_dims.size()), expand_times.size(),
"to the rank of Input(X)."); platform::errors::InvalidArgument(
PADDLE_ENFORCE_LE(x_dims.size(), 6, "The number of elements (%d) of 'expand_times' for "
"The rank of Input(X) must not be greater than 6."); "Op(expand) must be equal to the number of dimensions "
"(%d) of the input.",
expand_times.size(), static_cast<size_t>(x_dims.size())));
PADDLE_ENFORCE_LE(
x_dims.size(), 6,
platform::errors::InvalidArgument(
"The number of dimensions of the input for Op(expand) "
"must not be greater than 6, but the value received is %d.",
x_dims.size()));
std::vector<int64_t> out_shape(x_dims.size()); std::vector<int64_t> out_shape(x_dims.size());
for (size_t i = 0; i < expand_times.size(); ++i) { for (size_t i = 0; i < expand_times.size(); ++i) {
...@@ -51,7 +58,10 @@ class ExpandOp : public framework::OperatorWithKernel { ...@@ -51,7 +58,10 @@ class ExpandOp : public framework::OperatorWithKernel {
} else { } else {
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
expand_times[i], 0, expand_times[i], 0,
"The element of Attr(expand_times) must greater than 0."); platform::errors::InvalidArgument(
"The %uth element of 'expand_times' for Op(expand) must be "
"greater than 0, but the value given is %d.",
i, expand_times[i]));
out_shape[i] = x_dims[i] * expand_times[i]; out_shape[i] = x_dims[i] * expand_times[i];
} }
} }
...@@ -139,9 +149,9 @@ class ExpandGradOp : public framework::OperatorWithKernel { ...@@ -139,9 +149,9 @@ class ExpandGradOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null."); OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ExpandGrad");
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true, OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Input(Out@GRAD) should not be null."); framework::GradVarName("Out"), "ExpandGrad");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
std::vector<int> expand_times = std::vector<int> expand_times =
...@@ -153,8 +163,10 @@ class ExpandGradOp : public framework::OperatorWithKernel { ...@@ -153,8 +163,10 @@ class ExpandGradOp : public framework::OperatorWithKernel {
if (!ctx->IsRuntime() && x_dims[0] < 0) { if (!ctx->IsRuntime() && x_dims[0] < 0) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
x_dims[0], out_dims[0], x_dims[0], out_dims[0],
"The first dimension size of Input(Out@GRAD) should be " platform::errors::InvalidArgument(
"equal to the crroresponding dimension size of Input(X)"); "The first dimension size (%d) of Input(Out@GRAD) should be "
"equal to the crroresponding dimension size (%d) of Input(X)",
out_dims[0], x_dims[0]));
start_pos = 1u; start_pos = 1u;
} }
...@@ -165,9 +177,11 @@ class ExpandGradOp : public framework::OperatorWithKernel { ...@@ -165,9 +177,11 @@ class ExpandGradOp : public framework::OperatorWithKernel {
if (ctx->IsRuntime()) { if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
x_dims[i] * expand_times[i], out_dims[i], x_dims[i] * expand_times[i], out_dims[i],
"Each dimension size of Input(Out@GRAD) should be " platform::errors::InvalidArgument(
"equal to multiplication of crroresponding dimension " "The %uth dimension size (%d) of Input(Out@GRAD) should be "
"size of Input(X) and Attr(expand_times) value."); "equal to the multiplication of the crroresponding dimension "
"sizes of Input(X) (%d) and expand_times (%d).",
i, out_dims[i], x_dims[i], expand_times[i]));
} }
} }
} }
......
...@@ -97,12 +97,19 @@ class ExpandKernel : public framework::OpKernel<T> { ...@@ -97,12 +97,19 @@ class ExpandKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto rank = context.Input<Tensor>("X")->dims().size(); auto rank = context.Input<Tensor>("X")->dims().size();
switch (rank) { PADDLE_ENFORCE_GE(
REP_EXPAND_TEMPLATE(MAX_RANK_SUPPORTED) rank, 1,
default: platform::errors::InvalidArgument(
PADDLE_ENFORCE(false, "The number of dimensions of the input 'x' for Op(expand) "
"Only support tensor with rank being between 1 and 6."); "must be greater than or equal to 1, but the value received is %d.",
} rank));
PADDLE_ENFORCE_LE(
rank, MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The number of dimensions of the input 'x' for Op(expand) "
"must be less than or equal to %d, but the value received is %d.",
MAX_RANK_SUPPORTED, rank));
switch (rank) { REP_EXPAND_TEMPLATE(MAX_RANK_SUPPORTED) }
} }
protected: protected:
...@@ -112,9 +119,13 @@ class ExpandKernel : public framework::OpKernel<T> { ...@@ -112,9 +119,13 @@ class ExpandKernel : public framework::OpKernel<T> {
auto in_dims = in0->dims(); auto in_dims = in0->dims();
auto expand_times = get_expand_times(context); auto expand_times = get_expand_times(context);
PADDLE_ENFORCE_EQ(static_cast<size_t>(in_dims.size()), expand_times.size(), PADDLE_ENFORCE_EQ(
"The number of Attr(expand_times)'s value must be equal " static_cast<size_t>(in_dims.size()), expand_times.size(),
"to the rank of Input(X)."); platform::errors::InvalidArgument(
"The number of elements (%d) of 'expand_times' for "
"Op(expand) must be equal to the number "
"of dimensions (%d) of the input.",
expand_times.size(), static_cast<size_t>(in_dims.size())));
auto* out0 = context.Output<Tensor>("Out"); auto* out0 = context.Output<Tensor>("Out");
Eigen::DSizes<int, Rank> bcast_dims; Eigen::DSizes<int, Rank> bcast_dims;
for (size_t i = 0; i < expand_times.size(); ++i) { for (size_t i = 0; i < expand_times.size(); ++i) {
...@@ -179,12 +190,19 @@ class ExpandGradKernel : public framework::OpKernel<T> { ...@@ -179,12 +190,19 @@ class ExpandGradKernel : public framework::OpKernel<T> {
framework::TensorCopy(*in0, context.GetPlace(), context.device_context(), framework::TensorCopy(*in0, context.GetPlace(), context.device_context(),
out0); out0);
} else { } else {
switch (dims) { PADDLE_ENFORCE_GE(dims, 1, platform::errors::InvalidArgument(
REP_EXPAND_GRAD_TEMPLATE(MAX_RANK_SUPPORTED) "The number of dimensions of the input "
default: "'Out@GRAD' for Op(expand_grad)"
PADDLE_ENFORCE( " must be greater than or equal to 1, but "
false, "Only support tensor with rank being between 1 and 6."); "the value received is %d.",
} dims));
PADDLE_ENFORCE_LE(dims, MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The number of dimensions of the input 'Out@GRAD' "
"for Op(expand_grad) must be less than or equal "
"to %d, but the value received is %d.",
MAX_RANK_SUPPORTED, dims));
switch (dims) { REP_EXPAND_GRAD_TEMPLATE(MAX_RANK_SUPPORTED) }
} }
} }
...@@ -196,11 +214,15 @@ class ExpandGradKernel : public framework::OpKernel<T> { ...@@ -196,11 +214,15 @@ class ExpandGradKernel : public framework::OpKernel<T> {
size_t reshape_size = reshape_dims_vec.size(); size_t reshape_size = reshape_dims_vec.size();
size_t reduce_size = reduce_dims_vec.size(); size_t reduce_size = reduce_dims_vec.size();
PADDLE_ENFORCE_EQ(reshape_size, reshape_dims_vec.size(), PADDLE_ENFORCE_EQ(reshape_size, reshape_dims_vec.size(),
"Inconsistent size between template Dims and " platform::errors::InvalidArgument(
"reshape dimensions."); "Inconsistent size between template Dims (%d) and "
"reshape dimensions (%d).",
reshape_size, reshape_dims_vec.size()));
PADDLE_ENFORCE_EQ(reduce_size, reduce_dims_vec.size(), PADDLE_ENFORCE_EQ(reduce_size, reduce_dims_vec.size(),
"Inconsistent size between template Dims and " platform::errors::InvalidArgument(
"reduce dimensions."); "Inconsistent size between template Dims (%d) and "
"reduce dimensions (%d).",
reduce_size, reduce_dims_vec.size()));
auto* in0 = context.Input<Tensor>(framework::GradVarName("Out")); auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* out0 = context.Output<Tensor>(framework::GradVarName("X")); auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
out0->mutable_data<T>(context.GetPlace()); out0->mutable_data<T>(context.GetPlace());
......
...@@ -23,16 +23,18 @@ class MergeSelectedRowsOp : public framework::OperatorWithKernel { ...@@ -23,16 +23,18 @@ class MergeSelectedRowsOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "MergeSelectedRows");
"Input(X) of MergeSelectedRowsOp should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "MergeSelectedRows");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE_EQ(
"Output(Out) of MergeSelectedRowsOp should not be null."); ctx->GetInputsVarType("X").front(),
PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("X").front(), framework::proto::VarType::SELECTED_ROWS,
framework::proto::VarType::SELECTED_ROWS, platform::errors::InvalidArgument("Input(X) of MergeSelectedRowsOp "
"Input X only should be SelectedRows."); "should be of type SelectedRows."));
PADDLE_ENFORCE_EQ(ctx->GetOutputsVarType("Out").front(), PADDLE_ENFORCE_EQ(
framework::proto::VarType::SELECTED_ROWS, ctx->GetOutputsVarType("Out").front(),
"Output Y only should be SelectedRows."); framework::proto::VarType::SELECTED_ROWS,
platform::errors::InvalidArgument("Output(Out) of MergeSelectedRowsOp "
"should be of type SelectedRows."));
ctx->ShareDim("X", /*->*/ "Out"); ctx->ShareDim("X", /*->*/ "Out");
} }
......
...@@ -21,17 +21,21 @@ class ShardIndexOp : public framework::OperatorWithKernel { ...@@ -21,17 +21,21 @@ class ShardIndexOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ShardIndex");
"Input(X) of ShardIndexOp should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ShardIndex");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ShardIndexOp should not be null.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_GE(x_dims.size(), 2, PADDLE_ENFORCE_GE(x_dims.size(), 2,
"Rank of Input(X) should be at least 2."); platform::errors::InvalidArgument(
"Rank of Input(X) should be at least 2, "
"but the value given is %d.",
x_dims.size()));
if (ctx->IsRuntime() || x_dims[x_dims.size() - 1] > 0) { if (ctx->IsRuntime() || x_dims[x_dims.size() - 1] > 0) {
PADDLE_ENFORCE_GE(x_dims[x_dims.size() - 1], 1U, PADDLE_ENFORCE_GE(x_dims[x_dims.size() - 1], 1U,
"Last dimension of Input(X) should be 1."); platform::errors::InvalidArgument(
"The last dimension of Input(X) should be 1, "
"but the value given is %d.",
x_dims[x_dims.size() - 1]));
} }
ctx->SetOutputDim("Out", x_dims); ctx->SetOutputDim("Out", x_dims);
......
...@@ -50,10 +50,29 @@ class ShardIndexCUDAKernel : public framework::OpKernel<T> { ...@@ -50,10 +50,29 @@ class ShardIndexCUDAKernel : public framework::OpKernel<T> {
int nshards = context.Attr<int>("nshards"); int nshards = context.Attr<int>("nshards");
int shard_id = context.Attr<int>("shard_id"); int shard_id = context.Attr<int>("shard_id");
int ignore_value = context.Attr<int>("ignore_value"); int ignore_value = context.Attr<int>("ignore_value");
PADDLE_ENFORCE_GT(index_num, 0); PADDLE_ENFORCE_GT(
PADDLE_ENFORCE_GT(nshards, 0); index_num, 0,
PADDLE_ENFORCE(shard_id >= 0 && shard_id < nshards, platform::errors::InvalidArgument(
"shard_id(%d) is not in range [0, %d)", shard_id, nshards); "The value 'index_num' for Op(shard_index) must be greater than 0, "
"but the value given is %d.",
index_num));
PADDLE_ENFORCE_GT(nshards, 0,
platform::errors::InvalidArgument(
"The value 'nshard' for Op(shard_index) must be "
"greater than 0, but the value given is %d.",
nshards));
PADDLE_ENFORCE_GE(
shard_id, 0,
platform::errors::InvalidArgument(
"The value 'shard_id' for Op(shard_index) must be greater or "
"equal to 0, but the value given is %d.",
shard_id));
PADDLE_ENFORCE_LT(
shard_id, nshards,
platform::errors::InvalidArgument(
"The value 'shard_id' for Op(shard_index) must be less than "
"nshards (%d), but the value given is %d.",
nshards, shard_id));
out->Resize(in->dims()); out->Resize(in->dims());
out->set_lod(in->lod()); out->set_lod(in->lod());
......
...@@ -29,10 +29,29 @@ class ShardIndexCPUKernel : public framework::OpKernel<T> { ...@@ -29,10 +29,29 @@ class ShardIndexCPUKernel : public framework::OpKernel<T> {
int nshards = context.Attr<int>("nshards"); int nshards = context.Attr<int>("nshards");
int shard_id = context.Attr<int>("shard_id"); int shard_id = context.Attr<int>("shard_id");
int ignore_value = context.Attr<int>("ignore_value"); int ignore_value = context.Attr<int>("ignore_value");
PADDLE_ENFORCE_GT(index_num, 0); PADDLE_ENFORCE_GT(
PADDLE_ENFORCE_GT(nshards, 0); index_num, 0,
PADDLE_ENFORCE(shard_id >= 0 && shard_id < nshards, platform::errors::InvalidArgument(
"shard_id(%d) is not in range [0, %d)", shard_id, nshards); "The value 'index_num' for Op(shard_index) must be greater than 0, "
"but the value given is %d.",
index_num));
PADDLE_ENFORCE_GT(nshards, 0,
platform::errors::InvalidArgument(
"The value 'nshard' for Op(shard_index) must be "
"greater than 0, but the value given is %d.",
nshards));
PADDLE_ENFORCE_GE(
shard_id, 0,
platform::errors::InvalidArgument(
"The value 'shard_id' for Op(shard_index) must be greater or "
"equal to 0, but the value given is %d.",
shard_id));
PADDLE_ENFORCE_LT(
shard_id, nshards,
platform::errors::InvalidArgument(
"The value 'shard_id' for Op(shard_index) must be less than "
"nshards (%d), but the value given is %d.",
nshards, shard_id));
int shard_size = (index_num + nshards - 1) / nshards; int shard_size = (index_num + nshards - 1) / nshards;
...@@ -42,9 +61,16 @@ class ShardIndexCPUKernel : public framework::OpKernel<T> { ...@@ -42,9 +61,16 @@ class ShardIndexCPUKernel : public framework::OpKernel<T> {
auto* out_data = out->mutable_data<T>(context.GetPlace()); auto* out_data = out->mutable_data<T>(context.GetPlace());
int64_t numel = in->numel(); int64_t numel = in->numel();
for (int64_t i = 0; i < numel; ++i) { for (int64_t i = 0; i < numel; ++i) {
PADDLE_ENFORCE(in_data[i] >= 0 && in_data[i] < index_num, PADDLE_ENFORCE_GE(in_data[i], 0,
"Input index(%d) is out of range [0,%d)", in_data[i], platform::errors::InvalidArgument(
index_num); "The input_index for Op(shard_index) must be "
"greater or equal to 0, but the value given is %d.",
in_data[i]));
PADDLE_ENFORCE_LT(in_data[i], index_num,
platform::errors::InvalidArgument(
"The input_index for Op(shard_index) must be less "
"than index_num (%d), but the value given is %d.",
index_num, in_data[i]));
if (in_data[i] / shard_size == shard_id) { if (in_data[i] / shard_size == shard_id) {
out_data[i] = in_data[i] % shard_size; out_data[i] = in_data[i] % shard_size;
} else { } else {
......
...@@ -8843,12 +8843,10 @@ def crop(x, shape=None, offsets=None, name=None): ...@@ -8843,12 +8843,10 @@ def crop(x, shape=None, offsets=None, name=None):
crop = fluid.layers.crop(z, shape=[2, 2, 3]) crop = fluid.layers.crop(z, shape=[2, 2, 3])
""" """
check_variable_and_dtype(x, 'x', ['float32'], 'crop')
check_type(shape, 'shape', (list, tuple, Variable), 'crop')
helper = LayerHelper('crop', **locals()) helper = LayerHelper('crop', **locals())
if not (isinstance(shape, list) or isinstance(shape, tuple) or \
isinstance(shape, Variable)):
raise ValueError("The shape should be a list, tuple or Variable.")
if offsets is None: if offsets is None:
offsets = [0] * len(x.shape) offsets = [0] * len(x.shape)
...@@ -14679,6 +14677,7 @@ def shard_index(input, index_num, nshards, shard_id, ignore_value=-1): ...@@ -14679,6 +14677,7 @@ def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
nshards=2, nshards=2,
shard_id=0) shard_id=0)
""" """
check_variable_and_dtype(input, 'input', ['int64'], 'shard_index')
op_type = 'shard_index' op_type = 'shard_index'
helper = LayerHelper(op_type, **locals()) helper = LayerHelper(op_type, **locals())
if shard_id < 0 or shard_id >= nshards: if shard_id < 0 or shard_id >= nshards:
......
...@@ -1895,6 +1895,14 @@ class TestLayer(LayerTest): ...@@ -1895,6 +1895,14 @@ class TestLayer(LayerTest):
self.assertIsNotNone(out2) self.assertIsNotNone(out2)
self.assertIsNotNone(out3) self.assertIsNotNone(out3)
def test_shard_index(self):
with self.static_graph():
x = fluid.layers.data(name="label", shape=[4, 1], dtype='int64')
shard_label = fluid.layers.shard_index(
input=x, index_num=20, nshards=2, shard_id=0)
self.assertIsNotNone(shard_label)
def test_accuracy(self): def test_accuracy(self):
x = np.random.rand(3, 32, 32).astype("float32") x = np.random.rand(3, 32, 32).astype("float32")
y = np.array([[1], [0], [1]]) y = np.array([[1], [0], [1]])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册