未验证 提交 9c3739b0 编写于 作者: Y Yang Zhang 提交者: GitHub

Refine `squared_l2_distance_grad` error message (#24418)

test=develop
上级 24cf3932
......@@ -77,8 +77,16 @@ class SquaredL2DistanceGradKernel : public framework::OpKernel<T> {
auto* x_g = context.Output<Tensor>(framework::GradVarName("X"));
auto* y_g = context.Output<Tensor>(framework::GradVarName("Y"));
PADDLE_ENFORCE_NOT_NULL(x_g);
PADDLE_ENFORCE_NOT_NULL(y_g);
PADDLE_ENFORCE_NOT_NULL(
x_g, platform::errors::NotFound(
"variable(%s) cannot be found "
"in scope for operator 'squared_l2_distance_grad'.",
framework::GradVarName("X")));
PADDLE_ENFORCE_NOT_NULL(
y_g, platform::errors::NotFound(
"variable(%s) cannot be found "
"in scope for operator 'squared_l2_distance_grad'.",
framework::GradVarName("Y")));
auto sub_result = EigenMatrix<T>::From(*in0);
auto out_grad = EigenMatrix<T>::From(*in1);
......@@ -106,8 +114,11 @@ class SquaredL2DistanceGradKernel : public framework::OpKernel<T> {
y_g->mutable_data<T>(context.GetPlace());
PADDLE_ENFORCE_GE(sub_result.dimensions()[0], y_dims[0],
platform::errors::InvalidArgument(
"First dimension of gradient must be greater or "
"equal than first dimension of target.");
"equal than first dimension of target. But received "
"gradient dimension = %d and target dimension is %d.",
sub_result.dimensions()[0], y_dims[0]));
if (sub_result.dimensions()[0] == y_dims[0]) {
auto y_grad =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册