未验证 提交 6bb8206d 编写于 作者: W wangguanzhong 提交者: GitHub

enhance the error message of box_clip, test=develop (#23638)

上级 8987946f
......@@ -21,22 +21,36 @@ class BoxClipOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of BoxClipOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("ImInfo"),
"Input(ImInfo) of BoxClipOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
platform::errors::NotFound("Input(Input) of BoxClipOp "
"is not found."));
PADDLE_ENFORCE_EQ(ctx->HasInput("ImInfo"), true,
platform::errors::NotFound("Input(ImInfo) of BoxClipOp "
"is not found."));
auto input_box_dims = ctx->GetInputDim("Input");
auto im_info_dims = ctx->GetInputDim("ImInfo");
if (ctx->IsRuntime()) {
auto input_box_size = input_box_dims.size();
PADDLE_ENFORCE_EQ(input_box_dims[input_box_size - 1], 4,
"The last dimension of Input must be 4");
PADDLE_ENFORCE_EQ(
input_box_dims[input_box_size - 1], 4,
platform::errors::InvalidArgument(
"The last dimension "
"of Input must be 4. But received last dimension = %d",
input_box_dims[input_box_size - 1]));
PADDLE_ENFORCE_EQ(im_info_dims.size(), 2,
"The rank of Input(Input) in BoxClipOp must be 2");
PADDLE_ENFORCE_EQ(im_info_dims[1], 3,
"The last dimension of ImInfo must be 3");
platform::errors::InvalidArgument(
"The rank of "
"Input(Input) in BoxClipOp must be 2. But received "
"rank = %d",
im_info_dims.size()));
PADDLE_ENFORCE_EQ(
im_info_dims[1], 3,
platform::errors::InvalidArgument(
"The last dimension "
"of ImInfo must be 3. But received last dimension = %d",
im_info_dims[1]));
}
ctx->ShareDim("Input", /*->*/ "Output");
ctx->ShareLoD("Input", /*->*/ "Output");
......
......@@ -46,8 +46,6 @@ template <typename DeviceContext, typename T>
class GPUBoxClipKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
"This kernel only runs on GPU device.");
auto *input = context.Input<LoDTensor>("Input");
auto *im_info = context.Input<Tensor>("ImInfo");
auto *output = context.Output<LoDTensor>("Output");
......
......@@ -33,7 +33,9 @@ class BoxClipKernel : public framework::OpKernel<T> {
output_box->mutable_data<T>(context.GetPlace());
if (input_box->lod().size()) {
PADDLE_ENFORCE_EQ(input_box->lod().size(), 1UL,
"Only support 1 level of LoD.");
platform::errors::InvalidArgument(
"Input(Input) of "
"BoxClip only supports 1 level of LoD."));
}
auto box_lod = input_box->lod().back();
int64_t n = static_cast<int64_t>(box_lod.size() - 1);
......
......@@ -30,6 +30,7 @@ import math
import six
import numpy
from functools import reduce
from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
__all__ = [
'prior_box',
......@@ -2866,6 +2867,10 @@ def box_clip(input, im_info, name=None):
input=boxes, im_info=im_info)
"""
check_variable_and_dtype(input, 'input', ['float32', 'float64'], 'box_clip')
check_variable_and_dtype(im_info, 'im_info', ['float32', 'float64'],
'box_clip')
helper = LayerHelper("box_clip", **locals())
output = helper.create_variable_for_type_inference(dtype=input.dtype)
inputs = {"Input": input, "ImInfo": im_info}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册