diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index 2e9cfaa599ec50fa2b2b61c97d59984710e80a0e..0d0d92f42a1a80e022fc1fef15c71797637e9a46 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -23,7 +23,7 @@ limitations under the License. */ namespace paddle { namespace operators { -using framework::Tensor; +using Tensor = framework::Tensor; class ConcatOp : public framework::OperatorWithKernel { public: @@ -80,12 +80,12 @@ class ConcatOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto vars = ctx.MultiInputVar("X"); + auto inputs = ctx.MultiInput("X"); auto input_data_type = framework::proto::VarType::Type(0); bool flag = 0; - for (auto *var : vars) { - if (var->IsInitialized()) { - input_data_type = framework::GetDataTypeOfVar(var); + for (auto *input : inputs) { + if (input->IsInitialized() && input->numel() > 0) { + input_data_type = input->type(); flag = 1; break; }