未验证 提交 6bb8206d 编写于 作者: W wangguanzhong 提交者: GitHub

enhance the error message of box_clip, test=develop (#23638)

上级 8987946f
...@@ -21,22 +21,36 @@ class BoxClipOp : public framework::OperatorWithKernel { ...@@ -21,22 +21,36 @@ class BoxClipOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"), PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
"Input(Input) of BoxClipOp should not be null."); platform::errors::NotFound("Input(Input) of BoxClipOp "
PADDLE_ENFORCE(ctx->HasInput("ImInfo"), "is not found."));
"Input(ImInfo) of BoxClipOp should not be null."); PADDLE_ENFORCE_EQ(ctx->HasInput("ImInfo"), true,
platform::errors::NotFound("Input(ImInfo) of BoxClipOp "
"is not found."));
auto input_box_dims = ctx->GetInputDim("Input"); auto input_box_dims = ctx->GetInputDim("Input");
auto im_info_dims = ctx->GetInputDim("ImInfo"); auto im_info_dims = ctx->GetInputDim("ImInfo");
if (ctx->IsRuntime()) { if (ctx->IsRuntime()) {
auto input_box_size = input_box_dims.size(); auto input_box_size = input_box_dims.size();
PADDLE_ENFORCE_EQ(input_box_dims[input_box_size - 1], 4, PADDLE_ENFORCE_EQ(
"The last dimension of Input must be 4"); input_box_dims[input_box_size - 1], 4,
platform::errors::InvalidArgument(
"The last dimension "
"of Input must be 4. But received last dimension = %d",
input_box_dims[input_box_size - 1]));
PADDLE_ENFORCE_EQ(im_info_dims.size(), 2, PADDLE_ENFORCE_EQ(im_info_dims.size(), 2,
"The rank of Input(Input) in BoxClipOp must be 2"); platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(im_info_dims[1], 3, "The rank of "
"The last dimension of ImInfo must be 3"); "Input(Input) in BoxClipOp must be 2. But received "
"rank = %d",
im_info_dims.size()));
PADDLE_ENFORCE_EQ(
im_info_dims[1], 3,
platform::errors::InvalidArgument(
"The last dimension "
"of ImInfo must be 3. But received last dimension = %d",
im_info_dims[1]));
} }
ctx->ShareDim("Input", /*->*/ "Output"); ctx->ShareDim("Input", /*->*/ "Output");
ctx->ShareLoD("Input", /*->*/ "Output"); ctx->ShareLoD("Input", /*->*/ "Output");
......
...@@ -46,8 +46,6 @@ template <typename DeviceContext, typename T> ...@@ -46,8 +46,6 @@ template <typename DeviceContext, typename T>
class GPUBoxClipKernel : public framework::OpKernel<T> { class GPUBoxClipKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
"This kernel only runs on GPU device.");
auto *input = context.Input<LoDTensor>("Input"); auto *input = context.Input<LoDTensor>("Input");
auto *im_info = context.Input<Tensor>("ImInfo"); auto *im_info = context.Input<Tensor>("ImInfo");
auto *output = context.Output<LoDTensor>("Output"); auto *output = context.Output<LoDTensor>("Output");
......
...@@ -33,7 +33,9 @@ class BoxClipKernel : public framework::OpKernel<T> { ...@@ -33,7 +33,9 @@ class BoxClipKernel : public framework::OpKernel<T> {
output_box->mutable_data<T>(context.GetPlace()); output_box->mutable_data<T>(context.GetPlace());
if (input_box->lod().size()) { if (input_box->lod().size()) {
PADDLE_ENFORCE_EQ(input_box->lod().size(), 1UL, PADDLE_ENFORCE_EQ(input_box->lod().size(), 1UL,
"Only support 1 level of LoD."); platform::errors::InvalidArgument(
"Input(Input) of "
"BoxClip only supports 1 level of LoD."));
} }
auto box_lod = input_box->lod().back(); auto box_lod = input_box->lod().back();
int64_t n = static_cast<int64_t>(box_lod.size() - 1); int64_t n = static_cast<int64_t>(box_lod.size() - 1);
......
...@@ -30,6 +30,7 @@ import math ...@@ -30,6 +30,7 @@ import math
import six import six
import numpy import numpy
from functools import reduce from functools import reduce
from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
__all__ = [ __all__ = [
'prior_box', 'prior_box',
...@@ -2866,6 +2867,10 @@ def box_clip(input, im_info, name=None): ...@@ -2866,6 +2867,10 @@ def box_clip(input, im_info, name=None):
input=boxes, im_info=im_info) input=boxes, im_info=im_info)
""" """
check_variable_and_dtype(input, 'input', ['float32', 'float64'], 'box_clip')
check_variable_and_dtype(im_info, 'im_info', ['float32', 'float64'],
'box_clip')
helper = LayerHelper("box_clip", **locals()) helper = LayerHelper("box_clip", **locals())
output = helper.create_variable_for_type_inference(dtype=input.dtype) output = helper.create_variable_for_type_inference(dtype=input.dtype)
inputs = {"Input": input, "ImInfo": im_info} inputs = {"Input": input, "ImInfo": im_info}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册