From aab4d12c0eb0af10d70878527dd6d402b31d6468 Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Mon, 10 Jun 2019 17:39:52 +0800 Subject: [PATCH] refine GetExpectedKernelType in conat op, test=develop (#17934) --- paddle/fluid/operators/concat_op.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index 2e9cfaa59..0d0d92f42 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -23,7 +23,7 @@ limitations under the License. */ namespace paddle { namespace operators { -using framework::Tensor; +using Tensor = framework::Tensor; class ConcatOp : public framework::OperatorWithKernel { public: @@ -80,12 +80,12 @@ class ConcatOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto vars = ctx.MultiInputVar("X"); + auto inputs = ctx.MultiInput("X"); auto input_data_type = framework::proto::VarType::Type(0); bool flag = 0; - for (auto *var : vars) { - if (var->IsInitialized()) { - input_data_type = framework::GetDataTypeOfVar(var); + for (auto *input : inputs) { + if (input->IsInitialized() && input->numel() > 0) { + input_data_type = input->type(); flag = 1; break; } -- GitLab