未验证 提交 9e8f9037 编写于 作者: K Kqnonrime 提交者: GitHub

fix two error message (#32039)

* fix two error message

* fix two error message

* fix error

* fix error

* fix error

* fix error
上级 2e82b6c8
......@@ -102,9 +102,13 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
// check src shape and dst shape should match
for (int i = 1; i < src_dims.size(); i++)
PADDLE_ENFORCE_EQ(src_dims[i], dst_dims[i],
platform::errors::InvalidArgument(
"src shape and dst shape should match"));
PADDLE_ENFORCE_EQ(
src_dims[i], dst_dims[i],
platform::errors::InvalidArgument(
"The dimensions of the source tensor and target tensor should"
" match, but received source tensor's %d-th dimension is %d,"
"target tensor's %d-th dimension is %d.",
i, src_dims[i], i, dst_dims[i]));
// slice size
size_t slice_size = 1;
......@@ -146,9 +150,13 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
// check src shape and dst shape should match
for (int i = 1; i < src_dims.size(); i++)
PADDLE_ENFORCE_EQ(src_dims[i], dst_dims[i],
platform::errors::InvalidArgument(
"src shape and dst shape should match"));
PADDLE_ENFORCE_EQ(
src_dims[i], dst_dims[i],
platform::errors::InvalidArgument(
"The dimensions of the source tensor and target tensor should"
" match, but received source tensor's %d-th dimension is %d,"
"target tensor's %d-th dimension is %d.",
i, src_dims[i], i, dst_dims[i]));
// slice size
size_t slice_size = 1;
......
......@@ -101,14 +101,18 @@ class UnStackGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_GT(ctx->Inputs(framework::GradVarName("Y")).size(), 0,
platform::errors::InvalidArgument(
"Number of Inputs(Y@Grad) must be larger than 0"));
"The Inputs(Y@Grad) of unstack operator are empty."));
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", "X",
"UnStackGrad");
auto input_dims = ctx->GetInputsDim(framework::GradVarName("Y"));
for (size_t i = 1; i < input_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0],
platform::errors::InvalidArgument(
"Dims of all Inputs(Y@Grad) must be the same"));
PADDLE_ENFORCE_EQ(
input_dims[i], input_dims[0],
platform::errors::InvalidArgument(
"The dimensions of all Inputs(Y@Grad) must be the same,"
"but received Inputs(Y@Grad)'s %d-th dimension is %d, "
"Inputs(Y@Grad)'s 0-th to %d-th dimension is %d.",
i, input_dims[i], i - 1, input_dims[0]));
}
int axis = ctx->Attrs().Get<int>("axis");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册