diff --git a/paddle/fluid/operators/distributed_ops/split_byref_op.cc b/paddle/fluid/operators/distributed_ops/split_byref_op.cc index d65e7ffe5a492fe5df038bb6bd469e09de6f95ca..43980107c14176f1751a3db2858c80cb65c764de 100644 --- a/paddle/fluid/operators/distributed_ops/split_byref_op.cc +++ b/paddle/fluid/operators/distributed_ops/split_byref_op.cc @@ -31,14 +31,16 @@ class SplitByrefOp : public framework::OperatorWithKernel { auto in_dims = ctx->GetInputDim("X"); auto outs_names = ctx->Outputs("Out"); size_t num = static_cast(ctx->Attrs().Get("num")); - std::vector sections = static_cast>( - ctx->Attrs().Get>("sections")); + auto sections = ctx->Attrs().Get>("sections"); const size_t outs_number = outs_names.size(); std::vector outs_dims; outs_dims.reserve(outs_number); if (num > 0) { - int64_t in_axis_dim = in_dims[0]; + int64_t in_axis_dim = 0; + if (ctx->IsRuntime()) { + in_axis_dim = in_dims[0]; + } PADDLE_ENFORCE_EQ(in_axis_dim % num, 0, "tensor split does not result" " in an equal division");