From 55827199a625450003835dae08c51a20026c5a05 Mon Sep 17 00:00:00 2001
From: WangXi
Date: Thu, 14 May 2020 19:18:27 +0800
Subject: [PATCH] Optimize error message, include dgc, nccl, size op (#24456), test=release/1.8 (#24524)

---
 paddle/fluid/operators/dgc_clip_by_norm_op.cc |  4 +-
 paddle/fluid/operators/dgc_op.cc              | 37 ++++-----
 paddle/fluid/operators/dgc_op.h               | 23 ++++--
 paddle/fluid/operators/nccl/nccl_op.cc        | 75 ++++++++---------
 paddle/fluid/operators/nccl/nccl_op.cu.cc     | 67 ++++++++---------
 .../fluid/operators/nccl/nccl_op_test.cu.cc   | 59 +++++++------
 paddle/fluid/operators/size_op.cc             |  7 +-
 7 files changed, 145 insertions(+), 127 deletions(-)

diff --git a/paddle/fluid/operators/dgc_clip_by_norm_op.cc b/paddle/fluid/operators/dgc_clip_by_norm_op.cc
index 6ebad4de3c8..85a29271b13 100644
--- a/paddle/fluid/operators/dgc_clip_by_norm_op.cc
+++ b/paddle/fluid/operators/dgc_clip_by_norm_op.cc
@@ -23,8 +23,8 @@ class DGCClipByNormOp : public ClipByNormOp {
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("current_step"),
-                   "current_step should be set.");
+    OP_INOUT_CHECK(ctx->HasInput("current_step"), "Input", "current_step",
+                   "DGCClipByNormOp");
 
     return ClipByNormOp::InferShape(ctx);
   }
diff --git a/paddle/fluid/operators/dgc_op.cc b/paddle/fluid/operators/dgc_op.cc
index 5657349d02d..5fe66fa38a8 100644
--- a/paddle/fluid/operators/dgc_op.cc
+++ b/paddle/fluid/operators/dgc_op.cc
@@ -25,28 +25,21 @@ class DGCOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("U"), "Input(U) of DGCop should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("V"), "Input(V) of DGCop should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Grad"),
-                   "Input(Grad) of DGCop should not be null.");
-    PADDLE_ENFORCE_EQ(
-        ctx->HasInput("Param"), true,
-        platform::errors::NotFound("Input(Param) of DGCop is not found."));
-    PADDLE_ENFORCE(ctx->HasInput("current_step"),
-                   "Input(current_step) of DGCop should not be null.");
-    PADDLE_ENFORCE_EQ(ctx->HasInput("nranks"), true,
-                      "Input(nranks) of DGCop should not be null.");
-
-    PADDLE_ENFORCE(ctx->HasOutput("U_out"),
-                   "Output(U_out) of DGCop should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("V_out"),
-                   "Output(V_out) of DGCop should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("k"),
-                   "Output(k) of DGCop should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("EncodeGrad"),
-                   "Output(EncodeGrad) of DGCop should not be null.");
-    PADDLE_ENFORCE_EQ(ctx->HasOutput("GatherBuff"), true,
-                      "Output(EncodeGrad) of DGCop should not be null.");
+    OP_INOUT_CHECK(ctx->HasInput("U"), "Input", "U", "DGCOp");
+    OP_INOUT_CHECK(ctx->HasInput("V"), "Input", "V", "DGCOp");
+    OP_INOUT_CHECK(ctx->HasInput("Grad"), "Input", "Grad", "DGCOp");
+    OP_INOUT_CHECK(ctx->HasInput("Param"), "Input", "Param", "DGCOp");
+    OP_INOUT_CHECK(ctx->HasInput("current_step"), "Input", "current_step",
+                   "DGCOp");
+    OP_INOUT_CHECK(ctx->HasInput("nranks"), "Input", "nranks", "DGCOp");
+
+    OP_INOUT_CHECK(ctx->HasOutput("U_out"), "Output", "U_out", "DGCOp");
+    OP_INOUT_CHECK(ctx->HasOutput("V_out"), "Output", "V_out", "DGCOp");
+    OP_INOUT_CHECK(ctx->HasOutput("k"), "Output", "k", "DGCOp");
+    OP_INOUT_CHECK(ctx->HasOutput("EncodeGrad"), "Output", "EncodeGrad",
+                   "DGCOp");
+    OP_INOUT_CHECK(ctx->HasOutput("GatherBuff"), "Output", "GatherBuff",
+                   "DGCOp");
   }
 
  protected:
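
Note on the pattern above (illustration only, not part of the patch): OP_INOUT_CHECK centralizes the "is this input/output declared" checks so a missing variable is always reported with the same NotFound-style wording. The standalone sketch below imitates that idea with a hypothetical CHECK_IN_OUT macro and a toy variable set; it is not Paddle's actual OP_INOUT_CHECK implementation.

// Standalone sketch (hypothetical names, compiles on its own):
// one macro produces a uniform "not found" message instead of
// per-operator, hand-written PADDLE_ENFORCE strings.
#include <cstdio>
#include <cstdlib>
#include <set>
#include <string>

static const std::set<std::string> g_vars = {"U", "V", "Grad"};  // toy "scope"

static bool HasVar(const std::string& name) { return g_vars.count(name) > 0; }

#define CHECK_IN_OUT(cond, io_kind, name, op_type)                           \
  do {                                                                       \
    if (!(cond)) {                                                           \
      std::fprintf(stderr,                                                   \
                   "NotFoundError: %s(%s) of operator %s is not found.\n",   \
                   io_kind, name, op_type);                                  \
      std::exit(1);                                                          \
    }                                                                        \
  } while (0)

int main() {
  CHECK_IN_OUT(HasVar("Grad"), "Input", "Grad", "DGCOp");    // passes
  CHECK_IN_OUT(HasVar("Param"), "Input", "Param", "DGCOp");  // exits with message
  return 0;
}
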
diff --git a/paddle/fluid/operators/dgc_op.h b/paddle/fluid/operators/dgc_op.h
index 1736fc36f64..8de57ccf623 100644
--- a/paddle/fluid/operators/dgc_op.h
+++ b/paddle/fluid/operators/dgc_op.h
@@ -24,14 +24,22 @@ namespace operators {
 
 inline float get_period_sparcity(const std::vector<float>& sparsity,
                                  float cur_step, float rampup_steps) {
-  PADDLE_ENFORCE_GE(static_cast<int>(cur_step), 0);
+  PADDLE_ENFORCE_GE(static_cast<int>(cur_step), 0,
+                    platform::errors::InvalidArgument(
+                        "DGC current step=%d, but it must >= 0, "
+                        "please submit issue in github",
+                        static_cast<int>(cur_step)));
 
   size_t idx = static_cast<size_t>(cur_step * sparsity.size() / rampup_steps);
   if (idx >= sparsity.size()) {
     idx = sparsity.size() - 1;
   }
 
-  PADDLE_ENFORCE_LT(idx, sparsity.size());
+  PADDLE_ENFORCE_LT(
+      idx, sparsity.size(),
+      platform::errors::OutOfRange(
+          "sparsity index out of bounds. idx=%d >= sparsity.size=%d", idx,
+          sparsity.size()));
   return sparsity[idx];
 }
 
@@ -55,7 +63,10 @@ class DGCOpKernel : public framework::OpKernel<T> {
     // nranks
     auto nranks_tensor = ctx.Input<framework::Tensor>("nranks");
     const int nranks = static_cast<int>(*nranks_tensor->data<float>());
-    PADDLE_ENFORCE_GT(nranks, 1, "DGC is not useful when num_trainers <= 1");
+    PADDLE_ENFORCE_GT(nranks, 1,
+                      platform::errors::PreconditionNotMet(
+                          "DGC is not useful when num_trainers <= 1. Please "
+                          "use multi card or multi machine GPU"));
 
     // regularization
     auto p = ctx.Input<framework::Tensor>("Param");
@@ -105,8 +116,10 @@ class DGCOpKernel : public framework::OpKernel<T> {
         1 - get_period_sparcity(
                 sparsity, static_cast<float>(*current_step - rampup_begin_step),
                 rampup_step);
-    PADDLE_ENFORCE_GE(ratio, 0.0);
-    PADDLE_ENFORCE_LT(ratio, 1.0);
+    PADDLE_ENFORCE_GE(ratio, 0.0, platform::errors::InvalidArgument(
+                                      "DGC sparsity ratio must >= 0"));
+    PADDLE_ENFORCE_LT(ratio, 1.0, platform::errors::InvalidArgument(
+                                      "DGC sparsity ratio must < 1"));
     int k = static_cast<int>(g->numel() * ratio);
 
     VLOG(10) << "m:" << m << ", use_nesterov:" << use_nesterov
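
For context (illustration only, not part of the patch): get_period_sparcity picks a sparsity value from a warm-up schedule by the current step, clamps the index, and the new checks pin the resulting keep-ratio inside [0, 1). A minimal standalone sketch of that lookup; the schedule and step values below are made up for illustration.

// Standalone sketch of the warm-up sparsity schedule the checks above guard.
#include <cassert>
#include <cstdio>
#include <vector>

static float period_sparsity(const std::vector<float>& sparsity,
                             float cur_step, float rampup_steps) {
  assert(static_cast<int>(cur_step) >= 0);                // mirrors ENFORCE_GE
  size_t idx = static_cast<size_t>(cur_step * sparsity.size() / rampup_steps);
  if (idx >= sparsity.size()) idx = sparsity.size() - 1;  // clamp last period
  return sparsity[idx];
}

int main() {
  std::vector<float> sparsity = {0.75f, 0.9375f, 0.984375f, 0.996f, 0.999f};
  for (float step : {0.f, 50.f, 99.f, 200.f}) {
    float ratio = 1.0f - period_sparsity(sparsity, step, /*rampup_steps=*/100.f);
    assert(ratio >= 0.0f && ratio < 1.0f);                // mirrors GE/LT checks
    std::printf("step=%.0f keep-ratio=%f\n", step, ratio);
  }
  return 0;
}
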
diff --git a/paddle/fluid/operators/nccl/nccl_op.cc b/paddle/fluid/operators/nccl/nccl_op.cc
index 519fcf5924a..86efd703b82 100644
--- a/paddle/fluid/operators/nccl/nccl_op.cc
+++ b/paddle/fluid/operators/nccl/nccl_op.cc
@@ -31,12 +31,15 @@ class NCCLInitOp : public framework::OperatorBase {
  private:
   void RunImpl(const framework::Scope &scope,
                const platform::Place &place) const override {
-    PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kParallelScopes)),
-                            "Can not find variable '%s' in the scope.",
-                            kParallelScopes);
+    PADDLE_ENFORCE_NOT_NULL(
+        scope.FindVar(Input(kParallelScopes)),
+        platform::errors::NotFound("Can not find variable '%s' in the scope.",
+                                   kParallelScopes));
     const auto &name = Output("Communicator");
-    PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name),
-                            "Can not find variable '%s' in the scope.", name);
+    PADDLE_ENFORCE_NOT_NULL(
+        scope.FindVar(name),
+        platform::errors::NotFound(
+            "Output(%s) is needed for ncclInit operator.", name));
     // A parallel do may not use all the gpus. For example, the batch size is 7
     // in the last batch while we have 8 gpu. In this case, parallel_do will
     // create 7 parallel scopes, so should ncclInitOp create 7 gpu peers
@@ -46,11 +49,9 @@ class NCCLInitOp : public framework::OperatorBase {
     for (int i = 0; i < static_cast<int>(parallel_scopes.size()); ++i) {
       gpus[i] = i;
     }
-    PADDLE_ENFORCE(!gpus.empty(), "NCCL init with 0 gpus.");
-
-    if (scope.FindVar(name) == nullptr) {
-      PADDLE_THROW("Output(Communicator) is needed for ncclInit operator.");
-    }
+    PADDLE_ENFORCE_EQ(!gpus.empty(), true,
+                      platform::errors::PreconditionNotMet(
+                          "gpus is empty, NCCL must init with gpus"));
 
     platform::Communicator *comm =
         scope.FindVar(name)->GetMutable<platform::Communicator>();
@@ -92,17 +93,17 @@ class NCCLAllReduceOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"),
-                   " Input(X) of AllReduce op input should not be NULL");
-    PADDLE_ENFORCE(
-        ctx->HasInput("Communicator"),
-        " Input(Communicator) of AllReduce op input should not be NULL");
-    PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   " Output(Out) of AllReduce op output should not be NULL");
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "NCCLAllReduce");
+    OP_INOUT_CHECK(ctx->HasInput("Communicator"), "Input", "Communicator",
+                   "NCCLAllReduce");
+
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "NCCLAllReduce");
+
     std::string reduction = ctx->Attrs().Get<std::string>("reduction");
-    PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" ||
-                    reduction == "ncclMin" || reduction == "ncclMax"),
-                   "invalid reduction.");
+    PADDLE_ENFORCE_EQ(
+        (reduction == "ncclSum" || reduction == "ncclProd" ||
+         reduction == "ncclMin" || reduction == "ncclMax"),
+        true, platform::errors::InvalidArgument("invalid nccl reduction."));
 
     auto x_dims = ctx->GetInputsDim("X");
     ctx->SetOutputsDim("Out", x_dims);
@@ -137,18 +138,17 @@ class NCCLReduceOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"),
-                   " Input(X) of Reduce op input should not be NULL");
-    PADDLE_ENFORCE(
-        ctx->HasInput("Communicator"),
-        " Input(Communicator) of Reduce op input should not be NULL");
-    PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   " Input(X) of Reduce op input should not be NULL");
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "NCCLReduce");
+    OP_INOUT_CHECK(ctx->HasInput("Communicator"), "Input", "Communicator",
+                   "NCCLReduce");
+
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "NCCLReduce");
 
     std::string reduction = ctx->Attrs().Get<std::string>("reduction");
-    PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" ||
-                    reduction == "ncclMin" || reduction == "ncclMax"),
-                   "invalid reduction.");
+    PADDLE_ENFORCE_EQ(
+        (reduction == "ncclSum" || reduction == "ncclProd" ||
+         reduction == "ncclMin" || reduction == "ncclMax"),
+        true, platform::errors::InvalidArgument("invalid nccl reduction."));
 
     auto x_dims = ctx->GetInputsDim("X");
     ctx->SetOutputsDim("Out", x_dims);
@@ -188,15 +188,16 @@ class NCCLBcastOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"),
-                   " Input(X) of Bcast op input should not be NULL");
-    PADDLE_ENFORCE(ctx->HasInput("Communicator"),
-                   " Input(Communicator) of Bcast op input should not be NULL");
-    PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   " Output(Out) of Bcast op output should not be NULL");
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "NCCLBcast");
+    OP_INOUT_CHECK(ctx->HasInput("Communicator"), "Input", "Communicator",
+                   "NCCLBcast");
+
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "NCCLBcast");
 
     int root = ctx->Attrs().Get<int>("root");
-    PADDLE_ENFORCE(root != platform::kInvalidGPUId, "Bcast root must be set.");
+    PADDLE_ENFORCE_EQ(
+        root != platform::kInvalidGPUId, true,
+        platform::errors::InvalidArgument("Bcast root must be set."));
 
     auto x_dims = ctx->GetInputsDim("X");
     ctx->SetOutputsDim("Out", x_dims);
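
The InferShape checks above all boil down to a membership test of the reduction attribute against the four NCCL reduction names. A standalone sketch of that whitelist check follows; is_valid_nccl_reduction is a hypothetical helper used only for illustration, not part of the operator.

// Standalone sketch: validate a reduction name against the allowed set.
#include <array>
#include <cassert>
#include <string>

static bool is_valid_nccl_reduction(const std::string& reduction) {
  const std::array<const char*, 4> allowed = {"ncclSum", "ncclProd", "ncclMin",
                                              "ncclMax"};
  for (const char* name : allowed) {
    if (reduction == name) return true;
  }
  return false;
}

int main() {
  assert(is_valid_nccl_reduction("ncclSum"));
  assert(!is_valid_nccl_reduction("ncclAvg"));  // would trigger InvalidArgument
  return 0;
}
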
OP_INOUT_CHECK(ctx->HasInput("Communicator"), "Input", "Communicator", + "NCCLBcast"); + + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "NCCLBcast"); int root = ctx->Attrs().Get("root"); - PADDLE_ENFORCE(root != platform::kInvalidGPUId, "Bcast root must be set."); + PADDLE_ENFORCE_EQ( + root != platform::kInvalidGPUId, true, + platform::errors::InvalidArgument("Bcast root must be set.")); auto x_dims = ctx->GetInputsDim("X"); ctx->SetOutputsDim("Out", x_dims); diff --git a/paddle/fluid/operators/nccl/nccl_op.cu.cc b/paddle/fluid/operators/nccl/nccl_op.cu.cc index 8de974bc2b3..65100ad8e39 100644 --- a/paddle/fluid/operators/nccl/nccl_op.cu.cc +++ b/paddle/fluid/operators/nccl/nccl_op.cu.cc @@ -10,6 +10,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#include #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" @@ -37,36 +38,42 @@ class NCCLTypeWrapper { static const ncclDataType_t type = ncclDouble; }; +static ncclRedOp_t str_to_nccl_red_type(std::string reduction) { + static const std::unordered_map str_to_type = { + {"ncclSum", ncclSum}, + {"ncclMin", ncclMin}, + {"ncclMax", ncclMax}, + {"ncclProd", ncclProd}, + }; + auto it = str_to_type.find(reduction); + PADDLE_ENFORCE_EQ(it != str_to_type.end(), true, + platform::errors::InvalidArgument( + "Invalid nccl reduction. Must be ncclMin | ncclMax | " + "ncclProd | ncclSum")); + return it->second; +} + template class NCCLAllReduceKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "This kernel only runs on GPU device."); + PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, + platform::errors::PreconditionNotMet( + "This kernel only runs on GPU device.")); auto* x = ctx.Input("X"); auto* out = ctx.Output("Out"); auto* comm = ctx.Input("Communicator"); std::string reduction = ctx.Attr("reduction"); - ncclRedOp_t reduction_op_ = ncclSum; - if (reduction == "ncclMin") { - reduction_op_ = ncclMin; - } else if (reduction == "ncclMax") { - reduction_op_ = ncclMax; - } else if (reduction == "ncclSum") { - reduction_op_ = ncclSum; - } else if (reduction == "ncclProd") { - reduction_op_ = ncclProd; - } else { - PADDLE_THROW("Invalid reduction. default ncclSum."); - } + auto reduction_op_ = str_to_nccl_red_type(reduction); + // device id int gpu_id = boost::get(ctx.GetPlace()).GetDeviceId(); int idx = comm->GetCommId(gpu_id); VLOG(3) << "gpu : " << " invoke allreduce. 
send " << x->numel() << " recv " << out->numel(); - PADDLE_ENFORCE(platform::dynload::ncclAllReduce( + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce( x->data(), out->mutable_data(ctx.GetPlace()), out->numel(), NCCLTypeWrapper::type, reduction_op_, comm->comms().at(idx), ctx.cuda_device_context().stream())); @@ -80,26 +87,17 @@ template class NCCLReduceKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "This kernel only runs on GPU device."); + PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, + platform::errors::InvalidArgument( + "This kernel only runs on GPU device.")); auto x = ctx.Input("X"); // x0, x1, x2 auto out = ctx.Output("Out"); auto* comm = ctx.Input("Communicator"); int root = ctx.Attr("root"); std::string reduction = ctx.Attr("reduction"); - ncclRedOp_t reduction_op_ = ncclSum; - if (reduction == "ncclMin") { - reduction_op_ = ncclMin; - } else if (reduction == "ncclMax") { - reduction_op_ = ncclMax; - } else if (reduction == "ncclSum") { - reduction_op_ = ncclSum; - } else if (reduction == "ncclProd") { - reduction_op_ = ncclProd; - } else { - PADDLE_THROW("Invalid reduction. default ncclSum."); - } + auto reduction_op_ = str_to_nccl_red_type(reduction); + // device id int gpu_id = boost::get(ctx.GetPlace()).GetDeviceId(); int idx = comm->GetCommId(gpu_id); @@ -111,7 +109,7 @@ class NCCLReduceKernel : public framework::OpKernel { } VLOG(3) << "gpu : " << gpu_id << " invoke reduce. send " << x->numel() << " recv " << out->numel(); - PADDLE_ENFORCE(platform::dynload::ncclReduce( + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclReduce( x->data(), recvbuffer, x->numel(), NCCLTypeWrapper::type, reduction_op_, root, comm->comms().at(idx), ctx.cuda_device_context().stream())); @@ -124,8 +122,9 @@ template class NCCLBcastKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "This kernel only runs on GPU device."); + PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, + platform::errors::InvalidArgument( + "This kernel only runs on GPU device.")); int root = ctx.Attr("root"); auto* comm = ctx.Input("Communicator"); // device id @@ -134,7 +133,7 @@ class NCCLBcastKernel : public framework::OpKernel { if (idx == root) { auto* x = ctx.Input("X"); VLOG(3) << "gpu : " << gpu_id << " invoke Bcast. send " << x->numel(); - PADDLE_ENFORCE(platform::dynload::ncclBcast( + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast( reinterpret_cast(const_cast(x->data())), x->numel(), NCCLTypeWrapper::type, root, comm->comms().at(idx), ctx.cuda_device_context().stream())); @@ -143,7 +142,7 @@ class NCCLBcastKernel : public framework::OpKernel { auto* out = ctx.Output("Out"); VLOG(3) << "gpu : " << gpu_id << " invoke Bcast. 
recv buffer " << framework::product(out->dims()); - PADDLE_ENFORCE(platform::dynload::ncclBcast( + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast( out->mutable_data(ctx.GetPlace()), out->numel(), NCCLTypeWrapper::type, root, comm->comms().at(idx), ctx.cuda_device_context().stream())); diff --git a/paddle/fluid/operators/nccl/nccl_op_test.cu.cc b/paddle/fluid/operators/nccl/nccl_op_test.cu.cc index d5fb7a12e5d..216a277938f 100644 --- a/paddle/fluid/operators/nccl/nccl_op_test.cu.cc +++ b/paddle/fluid/operators/nccl/nccl_op_test.cu.cc @@ -45,10 +45,9 @@ class NCCLTester : public ::testing::Test { public: void SetUp() override { int count = p::GetCUDADeviceCount(); - if (count <= 1) { - LOG(WARNING) - << "Cannot test multi-gpu nccl, because the CUDA device count is " - << count; + if (count <= 0) { + LOG(WARNING) << "Cannot test gpu nccl, because the CUDA device count is " + << count; exit(0); } for (int i = 0; i < count; ++i) { @@ -114,8 +113,9 @@ class NCCLTester : public ::testing::Test { lk.unlock(); - PADDLE_ENFORCE(send_tensor->numel() == f::product(kDims), - "Tensor numel not match!"); + PADDLE_ENFORCE_EQ( + send_tensor->numel(), f::product(kDims), + paddle::platform::errors::InvalidArgument("Tensor numel not match!")); auto op = f::OpRegistry::CreateOp(*op1); @@ -126,6 +126,10 @@ class NCCLTester : public ::testing::Test { VLOG(1) << "Device : " << gpu_id << " finished " << op_desc.Type(); } + void testNcclReduceOp(); + void testNcclAllReduceOp(); + void testNcclBcastOp(); + public: std::vector dev_ctxs_; f::Scope g_scope_; @@ -133,13 +137,7 @@ class NCCLTester : public ::testing::Test { std::vector gpu_list_; }; -// ncclInitOp with desc -TEST_F(NCCLTester, ncclInitOp) {} - -// ncclAllReduceOp with desc -// TODO(helin): https://github.com/PaddlePaddle/Paddle/issues/9367 -/* -TEST_F(NCCLTester, ncclAllReduceOp) { +void NCCLTester::testNcclAllReduceOp() { std::unique_ptr op2(new f::OpDesc); op2->SetType("ncclAllReduce"); op2->SetInput("X", {"st"}); @@ -186,10 +184,8 @@ TEST_F(NCCLTester, ncclAllReduceOp) { } } } -*/ -// ncclReduceOp with desc -TEST_F(NCCLTester, ncclReduceOp) { +void NCCLTester::testNcclReduceOp() { std::unique_ptr op2(new f::OpDesc); const int kRoot = 0; op2->SetType("ncclReduce"); @@ -236,10 +232,7 @@ TEST_F(NCCLTester, ncclReduceOp) { } } -// ncclBcastOp with desc -// TODO(helin): https://github.com/PaddlePaddle/Paddle/issues/9540 -/* -TEST_F(NCCLTester, ncclBcastOp) { +void NCCLTester::testNcclBcastOp() { std::unique_ptr op2(new f::OpDesc); const int kRoot = 0; op2->SetType("ncclBcast"); @@ -263,13 +256,17 @@ TEST_F(NCCLTester, ncclBcastOp) { ths[i].join(); } - const int idx = 1; + const int idx = gpu_list_.size() - 1; float result = GetGPUData(kRoot); p::CPUPlace cpu_place; p::CUDAPlace gpu_place(gpu_list_[idx]); - auto &recv_tensor = dev_scopes[idx]->FindVar("rt")->Get(); + std::string rt_str = "rt"; + if (idx == kRoot) { + rt_str = "st"; + } + auto &recv_tensor = dev_scopes[idx]->FindVar(rt_str)->Get(); auto *rt = recv_tensor.data(); auto *result_tensor = dev_scopes[idx]->Var("ct")->GetMutable(); result_tensor->Resize(kDims); @@ -284,4 +281,20 @@ TEST_F(NCCLTester, ncclBcastOp) { ASSERT_NEAR(ct[j], result, 1e-5); } } -*/ + +// ncclInitOp with desc +TEST_F(NCCLTester, ncclInitOp) {} + +TEST_F(NCCLTester, ncclOp) { + // Serial execution is required for the same nccl comm. 
+
+  // ncclAllReduceOp with desc
+  // TODO(helin): https://github.com/PaddlePaddle/Paddle/issues/9367
+  testNcclReduceOp();
+
+  testNcclAllReduceOp();
+
+  // ncclBcastOp with desc
+  // TODO(helin): https://github.com/PaddlePaddle/Paddle/issues/9540
+  testNcclBcastOp();
+}
diff --git a/paddle/fluid/operators/size_op.cc b/paddle/fluid/operators/size_op.cc
index ca7f36205a2..06eaca0216b 100644
--- a/paddle/fluid/operators/size_op.cc
+++ b/paddle/fluid/operators/size_op.cc
@@ -23,10 +23,9 @@ class SizeOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Input"),
-                   "Input (Input) of Size op should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output (Out) of Size op should not be null.");
+    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Size");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Size");
+
     ctx->SetOutputDim("Out", {1});
   }
 };
-- 
GitLab
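
The test change folds the previously separate (and partly disabled) ncclAllReduce, ncclReduce, and ncclBcast tests into one TEST_F so they run serially against the shared NCCL communicator. A minimal gtest sketch of that serialization pattern follows; CommTester and its methods are hypothetical stand-ins for NCCLTester, not the real fixture.

// Standalone gtest sketch: sub-tests that share one communicator are
// invoked in order from a single TEST_F so they cannot interleave.
#include <gtest/gtest.h>

class CommTester : public ::testing::Test {
 protected:
  void InitSharedComm() { comm_ready_ = true; }  // stands in for ncclInit
  void RunAllReduce() { ASSERT_TRUE(comm_ready_); }
  void RunReduce() { ASSERT_TRUE(comm_ready_); }
  void RunBcast() { ASSERT_TRUE(comm_ready_); }
  bool comm_ready_ = false;
};

TEST_F(CommTester, CollectiveOps) {
  InitSharedComm();
  // Serial execution is required for the same communicator.
  RunReduce();
  RunAllReduce();
  RunBcast();
}

int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}
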