/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/infershape_utils.h"

#include <algorithm>
#include <string>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/compat/arg_map_context.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/core/tensor_utils.h"

namespace paddle {
namespace framework {

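// Adapts the fluid InferShapeContext to phi's ArgumentMappingContext
// interface, so that phi argument mapping functions can query inputs,
// outputs, and attributes during shape inference.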
class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
 public:
  explicit InferShapeArgumentMappingContext(const InferShapeContext& ctx)
      : ctx_(ctx) {}

  bool HasInput(const std::string& name) const override {
    return ctx_.HasInput(name);
  }

  bool HasOutput(const std::string& name) const override {
    return ctx_.HasOutput(name);
  }

  bool HasAttr(const std::string& name) const override {
    return ctx_.HasAttr(name);
  }

  paddle::any Attr(const std::string& name) const override {
    auto* attr = ctx_.Attrs().GetAttr(name);
    PADDLE_ENFORCE_NOT_NULL(
        attr,
        platform::errors::NotFound("Attribute (%s) should be in AttributeMap.",
                                   name));
    return GetAttrValue(*attr);
  }

  size_t InputSize(const std::string& name) const override {
    if (ctx_.HasInputs(name)) {
      return ctx_.Inputs(name).size();
    } else if (ctx_.HasInput(name)) {
      return 1;
    }
    return 0;
  }

  size_t OutputSize(const std::string& name) const override {
    return ctx_.Outputs(name).size();
  }

  bool IsDenseTensorInput(const std::string& name) const override {
    auto var_type = ctx_.GetInputVarType(name);
    return var_type == proto::VarType::LOD_TENSOR;
  }

  bool IsDenseTensorInputs(const std::string& name) const override {
    auto var_types = ctx_.GetInputsVarType(name);
    return std::all_of(var_types.begin(),
                       var_types.end(),
                       [](const proto::VarType::Type& type) {
                         return type == proto::VarType::LOD_TENSOR;
                       });
  }

  bool IsSelectedRowsInput(const std::string& name) const override {
    auto var_type = ctx_.GetInputVarType(name);
    return var_type == proto::VarType::SELECTED_ROWS;
  }

  bool IsDenseTensorVectorInput(const std::string& name) const override {
    auto var_types = ctx_.GetInputsVarType(name);
    return std::all_of(var_types.begin(),
                       var_types.end(),
                       [](const proto::VarType::Type& type) {
                         return type == proto::VarType::LOD_TENSOR_ARRAY;
                       });
  }

  bool IsDenseTensorOutput(const std::string& name) const override {
    auto var_types = ctx_.GetOutputsVarType(name);
    return std::all_of(var_types.begin(),
                       var_types.end(),
                       [](const proto::VarType::Type& type) {
                         return type == proto::VarType::LOD_TENSOR;
                       });
  }

  bool IsSelectedRowsOutput(const std::string& name) const override {
    auto var_types = ctx_.GetOutputsVarType(name);
    return std::all_of(var_types.begin(),
                       var_types.end(),
                       [](const proto::VarType::Type& type) {
                         return type == proto::VarType::SELECTED_ROWS;
                       });
  }

  bool IsForInferShape() const override { return true; }

  bool IsRuntime() const override { return ctx_.IsRuntime(); }

 private:
  const InferShapeContext& ctx_;
};

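// A CompatMetaTensor must be bound to a runtime Variable or a compile-time
// VarDesc before any of its meta interfaces are used.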
static inline void ValidCheck(const phi::MetaTensor& meta_tensor) {
  PADDLE_ENFORCE_EQ(meta_tensor.initialized(),
                    true,
                    phi::errors::InvalidArgument(
                        "The current CompatMetaTensor is not initialized."));
}

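// CompatMetaTensor wraps either a runtime Variable or a compile-time
// VarDesc; each accessor below dispatches on is_runtime_.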
int64_t CompatMetaTensor::numel() const {
  ValidCheck(*this);
  if (is_runtime_) {
    auto* var = PADDLE_GET_CONST(Variable*, var_);
    return var->Get<Tensor>().numel();
  } else {
    auto* var = PADDLE_GET_CONST(VarDesc*, var_);
    return var->ElementSize();
  }
}

DDim CompatMetaTensor::dims() const {
  ValidCheck(*this);
  if (is_runtime_) {
    auto* var = PADDLE_GET_CONST(Variable*, var_);
    if (var->IsType<phi::DenseTensor>()) {
      return var->Get<phi::DenseTensor>().dims();
    } else if (var->IsType<phi::SelectedRows>()) {
      return var->Get<phi::SelectedRows>().dims();
    } else if (var->IsType<framework::LoDTensorArray>()) {
      // use tensor array size as dims
      auto& tensor_array = var->Get<framework::LoDTensorArray>();
      return phi::make_ddim({static_cast<int64_t>(tensor_array.size())});
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Currently, dims can only be obtained from DenseTensor, "
          "SelectedRows, or DenseTensorArray."));
    }
  } else {
    auto* var = PADDLE_GET_CONST(VarDesc*, var_);

    return var->GetShape().empty() ? phi::make_ddim({0UL})
                                   : phi::make_ddim(var->GetShape());
  }
}

phi::DataType CompatMetaTensor::dtype() const {
  ValidCheck(*this);
  if (is_runtime_) {
    auto* var = PADDLE_GET_CONST(Variable*, var_);
    if (var->IsType<phi::DenseTensor>()) {
      return var->Get<phi::DenseTensor>().dtype();
    } else if (var->IsType<phi::SelectedRows>()) {
      return var->Get<phi::SelectedRows>().dtype();
    } else if (var->IsType<framework::LoDTensorArray>()) {
      // NOTE(chenweihang): getting dtype from LoDTensorArray is not
      // supported yet, so return UNDEFINED
      return phi::DataType::UNDEFINED;
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Currently, dtype can only be obtained from DenseTensor or "
          "SelectedRows."));
    }
  } else {
    auto* var = PADDLE_GET_CONST(VarDesc*, var_);
    return paddle::framework::TransToPhiDataType(var->GetDataType());
  }
}

DataLayout CompatMetaTensor::layout() const {
  ValidCheck(*this);
  if (is_runtime_) {
    auto* var = PADDLE_GET_CONST(Variable*, var_);
    if (var->IsType<phi::DenseTensor>()) {
      return var->Get<phi::DenseTensor>().layout();
    } else if (var->IsType<phi::SelectedRows>()) {
      return var->Get<phi::SelectedRows>().layout();
    } else if (var->IsType<framework::LoDTensorArray>()) {
      // NOTE(chenweihang): getting layout from LoDTensorArray is not
      // supported yet, so return UNDEFINED
      return phi::DataLayout::UNDEFINED;
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Currently, layout can only be obtained from DenseTensor or "
          "SelectedRows."));
    }
  } else {
    // NOTE(chenweihang): getting layout from VarDesc is not supported yet,
    // so return UNDEFINED
    return DataLayout::UNDEFINED;
  }
}

void CompatMetaTensor::set_dims(const DDim& dims) {
  ValidCheck(*this);
  if (is_runtime_) {
    auto* var = PADDLE_GET(Variable*, var_);
    if (var->IsType<phi::DenseTensor>()) {
      auto* tensor = var->GetMutable<phi::DenseTensor>();
      phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
    } else if (var->IsType<phi::SelectedRows>()) {
      auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
      phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
    } else if (var->IsType<framework::LoDTensorArray>()) {
      auto* tensor_array = var->GetMutable<framework::LoDTensorArray>();
      // Note: Ideally we would enforce `tensor_array->size() == 0UL` here,
      // because in-place use of LoDTensorArray is dangerous, but the
      // unittest `test_list` relies on this behavior
      PADDLE_ENFORCE_EQ(dims.size(),
                        1UL,
                        platform::errors::InvalidArgument(
                            "LoDTensorArray can only have one dimension."));
      // only set the array size for LoDTensorArray input
      tensor_array->resize(dims[0]);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Currently, dims can only be set for DenseTensor or "
          "SelectedRows."));
    }
  } else {
    auto* var = PADDLE_GET(VarDesc*, var_);
    var->SetShape(vectorize(dims));
  }
}

void CompatMetaTensor::set_dtype(phi::DataType dtype) {
  ValidCheck(*this);
  if (is_runtime_) {
    auto* var = PADDLE_GET(Variable*, var_);
    if (var->IsType<phi::DenseTensor>()) {
      auto* tensor = var->GetMutable<phi::DenseTensor>();
      phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype;
    } else if (var->IsType<phi::SelectedRows>()) {
      auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
      phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype;
    } else if (var->IsType<framework::LoDTensorArray>()) {
      // NOTE(chenweihang): setting dtype for LoDTensorArray is not
      // supported yet, so do nothing
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Currently, dtype can only be set for DenseTensor or "
          "SelectedRows."));
    }
  } else {
    auto* var = PADDLE_GET(VarDesc*, var_);
    var->SetDataType(paddle::framework::TransToProtoVarType(dtype));
  }
}

void CompatMetaTensor::set_layout(DataLayout layout) {
  ValidCheck(*this);
  if (is_runtime_) {
    auto* var = PADDLE_GET(Variable*, var_);
    if (var->IsType<phi::DenseTensor>()) {
      auto* tensor = var->GetMutable<phi::DenseTensor>();
      phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
    } else if (var->IsType<phi::SelectedRows>()) {
      auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
      phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
    } else if (var->IsType<framework::LoDTensorArray>()) {
      // NOTE(chenweihang): setting layout for LoDTensorArray is not
      // supported yet, so do nothing
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Currently, layout can only be set for DenseTensor or "
          "SelectedRows."));
    }
  } else {
    // NOTE(chenweihang): setting layout for VarDesc is not supported yet,
    // so do nothing
  }
}

void CompatMetaTensor::share_lod(const MetaTensor& meta_tensor) {
  ValidCheck(*this);
  ValidCheck(meta_tensor);
  if (is_runtime_) {
    auto* var = PADDLE_GET(Variable*, var_);
    if (var->IsType<phi::DenseTensor>()) {
      auto* tensor = var->GetMutable<phi::DenseTensor>();
      phi::DenseTensorUtils::GetMutableMeta(tensor)->lod =
          static_cast<const CompatMetaTensor&>(meta_tensor).GetRuntimeLoD();
    } else {
      // NOTE(chenweihang): do nothing, only LoDTensor needs to share lod
    }
  } else {
    auto* var = PADDLE_GET(VarDesc*, var_);
    var->SetLoDLevel(
        static_cast<const CompatMetaTensor&>(meta_tensor).GetCompileTimeLoD());
  }
}

void CompatMetaTensor::share_dims(const MetaTensor& meta_tensor) {
  ValidCheck(*this);
  ValidCheck(meta_tensor);
  set_dims(meta_tensor.dims());
  if (is_runtime_) {
    auto* var = PADDLE_GET(Variable*, var_);
    if (var->IsType<phi::SelectedRows>()) {
      auto* selected_rows = var->GetMutable<phi::SelectedRows>();
      auto& input_selected_rows =
          static_cast<const CompatMetaTensor&>(meta_tensor).GetSelectedRows();
      selected_rows->set_rows(input_selected_rows.rows());
      selected_rows->set_height(input_selected_rows.height());
    }
  }
}

void CompatMetaTensor::share_meta(const MetaTensor& meta_tensor) {
  share_dims(meta_tensor);
  set_dtype(meta_tensor.dtype());
  set_layout(meta_tensor.layout());
  // special case: share lod of LoDTensor
  share_lod(meta_tensor);
}

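// Each EmplaceBack* call records a [start, end) index range, so that a
// multi-tensor argument bound to one op input/output name can later be
// addressed as a contiguous span.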
void CompatInferMetaContext::EmplaceBackInput(CompatMetaTensor input) {
  int index = compat_inputs_.size();
  compat_inputs_.emplace_back(std::move(input));
  input_range_.emplace_back(std::pair<int, int>(index, index + 1));
}
void CompatInferMetaContext::EmplaceBackOutput(CompatMetaTensor output) {
  int index = compat_outputs_.size();
  compat_outputs_.emplace_back(std::move(output));
  output_range_.emplace_back(std::pair<int, int>(index, index + 1));
}

void CompatInferMetaContext::EmplaceBackInputs(
    paddle::small_vector<CompatMetaTensor, phi::kInputSmallVectorSize> inputs) {
  int index = compat_inputs_.size();
  input_range_.emplace_back(std::pair<int, int>(index, index + inputs.size()));
  compat_inputs_.insert(compat_inputs_.end(),
                        std::make_move_iterator(inputs.begin()),
                        std::make_move_iterator(inputs.end()));
}

void CompatInferMetaContext::EmplaceBackOutputs(
    paddle::small_vector<CompatMetaTensor, phi::kOutputSmallVectorSize>
        outputs) {
  int index = compat_outputs_.size();
  output_range_.emplace_back(
      std::pair<int, int>(index, index + outputs.size()));
  compat_outputs_.insert(compat_outputs_.end(),
                         std::make_move_iterator(outputs.begin()),
                         std::make_move_iterator(outputs.end()));
}

const phi::MetaTensor& CompatInferMetaContext::InputAt(size_t idx) const {
  return compat_inputs_.at(idx);
}

std::vector<const phi::MetaTensor*> CompatInferMetaContext::InputsBetween(
    size_t start, size_t end) const {
  std::vector<const phi::MetaTensor*> result;
  result.reserve(end - start);

  for (size_t i = start; i < end; ++i) {
    auto& in = compat_inputs_.at(i);
    result.emplace_back(in.initialized() ? &in : nullptr);
  }

  return result;
}

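// For an optional multi-tensor input, the first tensor being uninitialized
// means the whole argument was not provided.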
paddle::optional<std::vector<const phi::MetaTensor*>>
CompatInferMetaContext::OptionalInputsBetween(size_t start, size_t end) const {
  const auto& first = compat_inputs_.at(start);

  if (first.initialized()) {
    std::vector<const phi::MetaTensor*> result;
    result.reserve(end - start);

    for (size_t i = start; i < end; ++i) {
      auto& in = compat_inputs_.at(i);
      result.emplace_back(in.initialized() ? &in : nullptr);
    }

    return paddle::optional<std::vector<const phi::MetaTensor*>>(
        std::move(result));
  }
  return paddle::none;
}

phi::MetaTensor* CompatInferMetaContext::MutableOutputAt(size_t idx) {
  auto& out = compat_outputs_.at(idx);
  return out.initialized() ? &out : nullptr;
}

std::vector<phi::MetaTensor*> CompatInferMetaContext::MutableOutputBetween(
    size_t start, size_t end) {
  std::vector<phi::MetaTensor*> result;
  result.reserve(end - start);
  for (size_t i = start; i < end; ++i) {
    auto& out = compat_outputs_.at(i);
    result.emplace_back(out.initialized() ? &out : nullptr);
  }
  return result;
}

CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
                                             const std::string& op_type) {
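  // The context is assembled in three phases that mirror the phi kernel
  // signature: inputs, then attributes, then outputs.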
  // 1. get kernel args
  auto* arg_map_fn = ctx->GetPhiArgumentMappingFn();
  InferShapeArgumentMappingContext arg_map_context(*ctx);
  phi::KernelSignature signature = arg_map_fn
                                       ? (*arg_map_fn)(arg_map_context)
                                       : *ctx->GetPhiDefaultKernelSignature();
  VLOG(3) << "BuildInferMetaContext: op kernel signature - " << signature;

  // 2. build infermeta context
  CompatInferMetaContext infer_meta_context(
      {ctx->IsRuntime(), ctx->IsRunMKLDNNKernel()});

  const auto& input_names = signature.input_names;
  const auto& attr_names = signature.attr_names;
  const auto& output_names = signature.output_names;

  const auto& args_def =
      phi::KernelFactory::Instance().GetFirstKernelArgsDef(signature.name);
  const auto& attr_defs = args_def.attribute_defs();

  for (auto& in_name : input_names) {
    if (ctx->HasInputs(in_name)) {
      auto input_var = std::move(ctx->GetInputVarPtrs(in_name));
      if (input_var.size() == 1) {
        infer_meta_context.EmplaceBackInput(
            std::move(CompatMetaTensor(input_var[0], ctx->IsRuntime())));
      } else {
        paddle::small_vector<CompatMetaTensor, phi::kInputSmallVectorSize>
            inputs;
        for (const auto& in : input_var) {
          inputs.emplace_back(
              std::move(CompatMetaTensor(in, ctx->IsRuntime())));
        }
        infer_meta_context.EmplaceBackInputs(std::move(inputs));
      }
    } else {
      infer_meta_context.EmplaceBackInput(
          std::move(CompatMetaTensor(ctx->IsRuntime())));
    }
  }

  VLOG(6) << "BuildInferMetaContext: Done inputs";

  auto attr_reader = ctx->Attrs();
  for (size_t i = 0; i < attr_names.size(); ++i) {
    auto& attr_name = attr_names[i];
    auto* attr_ptr = attr_reader.GetAttr(attr_name);
    bool is_attr_var = attr_ptr != nullptr && HasAttrVar(*attr_ptr);
    VLOG(6) << "BuildInferMetaContext: " << attr_name << ": "
            << attr_defs[i].type_index << ", is_attr_var: " << is_attr_var;
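    // Scalar and IntArray attributes may come either from a real attribute
    // or from an input tensor (attr var); both paths are handled below.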
    switch (attr_defs[i].type_index) {
      case phi::AttributeType::SCALAR:
        if (attr_ptr && !is_attr_var) {
          auto& attr = *attr_ptr;
          switch (AttrTypeID(attr)) {
            case framework::proto::AttrType::FLOAT:
              infer_meta_context.EmplaceBackAttr(
                  phi::Scalar(PADDLE_GET_CONST(float, attr)));
              break;
            case framework::proto::AttrType::FLOAT64:
              infer_meta_context.EmplaceBackAttr(
                  phi::Scalar(PADDLE_GET_CONST(double, attr)));
              break;
            case framework::proto::AttrType::INT:
              infer_meta_context.EmplaceBackAttr(
                  phi::Scalar(PADDLE_GET_CONST(int, attr)));
              break;
            case framework::proto::AttrType::STRING:
              infer_meta_context.EmplaceBackAttr(
                  phi::Scalar(PADDLE_GET_CONST(std::string, attr)));
              break;
            case framework::proto::AttrType::BOOLEAN:
              infer_meta_context.EmplaceBackAttr(
                  phi::Scalar(PADDLE_GET_CONST(bool, attr)));
              break;
            default:
              PADDLE_THROW(platform::errors::Unimplemented(
                  "Unsupported cast of op attribute `%s` to Scalar when "
                  "constructing InferMetaContext.",
                  attr_name));
          }
        } else if (ctx->HasInput(attr_name)) {
          auto infershape_input = std::move(ctx->GetInputVarPtrs(attr_name));
          if (infershape_input.size() == 1) {
            if (ctx->IsRuntime()) {
              Variable* var = PADDLE_GET_CONST(Variable*, infershape_input[0]);
              infer_meta_context.EmplaceBackAttr(
                  std::move(experimental::MakePhiScalarFromVar(*var)));
            } else {
              phi::Scalar tensor_scalar(-1);
              tensor_scalar.SetFromTensor(true);
              infer_meta_context.EmplaceBackAttr(std::move(tensor_scalar));
            }
          } else {
            PADDLE_THROW(platform::errors::InvalidArgument(
                "Invalid input size when casting op attribute `%s` to "
                "Scalar: expected 1, but got %d.",
                attr_name,
                infershape_input.size()));
          }
        } else {
          // do nothing, skip current attr
        }
        break;
      case phi::AttributeType::INT_ARRAY:
        // When the attr is a tensor or a vector of tensors, transform it to
        // IntArray
        if (attr_ptr && !is_attr_var) {
          auto& attr = *attr_ptr;
          switch (AttrTypeID(attr)) {
            case framework::proto::AttrType::INTS:
              infer_meta_context.EmplaceBackAttr(std::move(
                  phi::IntArray(PADDLE_GET_CONST(std::vector<int32_t>, attr))));
              break;
            case framework::proto::AttrType::LONGS:
              infer_meta_context.EmplaceBackAttr(std::move(
                  phi::IntArray(PADDLE_GET_CONST(std::vector<int64_t>, attr))));
              break;
            case framework::proto::AttrType::INT:
              infer_meta_context.EmplaceBackAttr(
                  phi::IntArray({PADDLE_GET_CONST(int, attr)}));
              break;
            default:
              PADDLE_THROW(platform::errors::Unimplemented(
                  "Unsupported cast of op attribute `%s` to IntArray when "
                  "constructing InferMetaContext.",
                  attr_name));
          }
        } else if (ctx->HasInputs(attr_name) || ctx->HasInput(attr_name)) {
          auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name));
          if (ctx->IsRuntime()) {
            // In runtime, read the tensor's value for the IntArray and push
            // it into attrs
            std::vector<Variable*> vars;
            vars.reserve(infershape_inputs.size());
            for (size_t i = 0; i < infershape_inputs.size(); i++) {
              vars.push_back(PADDLE_GET_CONST(Variable*, infershape_inputs[i]));
            }
            if (infershape_inputs.size() != 1) {
              infer_meta_context.EmplaceBackAttr(
                  std::move(experimental::MakePhiIntArrayFromVarList(vars)));
            } else {
              infer_meta_context.EmplaceBackAttr(
                  std::move(experimental::MakePhiIntArrayFromVar(*vars[0])));
            }
          } else {
            // Not in runtime, so set the default value (-1) for the IntArray
            std::vector<VarDesc*> vars;
            vars.reserve(infershape_inputs.size());
            for (size_t i = 0; i < infershape_inputs.size(); ++i) {
              vars.push_back(PADDLE_GET_CONST(VarDesc*, infershape_inputs[i]));
            }

            int64_t num_ele = 0;
            if (vars.size() == 1) {
              num_ele = 1;
              const auto& tensor_dims = vars[0]->GetShape();
              for (size_t i = 0; i < tensor_dims.size(); ++i) {
                num_ele *= tensor_dims[i];
              }

              if (num_ele <= 0) {
                num_ele = tensor_dims.size();
              }

            } else {
              num_ele = vars.size();
            }
            phi::IntArray tensor_attr(std::vector<int32_t>(num_ele, -1));
            tensor_attr.SetFromTensor(true);
            infer_meta_context.EmplaceBackAttr(std::move(tensor_attr));
          }
        } else {
          // do nothing, skip current attr
        }
        break;
      case phi::AttributeType::SCALARS:
        if (attr_ptr) {
          auto& attr = *attr_ptr;
          switch (AttrTypeID(attr)) {
            case framework::proto::AttrType::INTS: {
              const auto& vec = PADDLE_GET_CONST(std::vector<int32_t>, attr);
              std::vector<phi::Scalar> scalar_list;
              scalar_list.reserve(vec.size());
              for (const auto& val : vec) {
                scalar_list.emplace_back(val);
              }
              infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
            } break;
            case framework::proto::AttrType::LONGS: {
              const auto& vec = PADDLE_GET_CONST(std::vector<int64_t>, attr);
              std::vector<phi::Scalar> scalar_list;
              scalar_list.reserve(vec.size());
              for (const auto& val : vec) {
                scalar_list.emplace_back(val);
              }
              infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
            } break;
            case framework::proto::AttrType::FLOATS: {
              const auto& vec = PADDLE_GET_CONST(std::vector<float>, attr);
              std::vector<phi::Scalar> scalar_list;
              scalar_list.reserve(vec.size());
              for (const auto& val : vec) {
                scalar_list.emplace_back(val);
              }
              infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
            } break;
            case framework::proto::AttrType::FLOAT64S: {
              const auto& vec = PADDLE_GET_CONST(std::vector<double>, attr);
              std::vector<phi::Scalar> scalar_list;
              scalar_list.reserve(vec.size());
              for (const auto& val : vec) {
                scalar_list.emplace_back(val);
              }
              infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
            } break;
            default:
              PADDLE_THROW(platform::errors::Unimplemented(
                  "Unsupported cast of op attribute `%s` to vector<Scalar> "
                  "when constructing InferMetaContext.",
                  attr_names[i]));
          }
        } else {
          // do nothing, skip current attr
        }
        break;
      default:
        if (attr_ptr) {
          auto& attr = *attr_ptr;
          switch (attr_defs[i].type_index) {
            case phi::AttributeType::FLOAT32:
              infer_meta_context.EmplaceBackAttr(PADDLE_GET_CONST(float, attr));
              break;
            case phi::AttributeType::FLOAT64:
              infer_meta_context.EmplaceBackAttr(
                  PADDLE_GET_CONST(double, attr));
              break;
            case phi::AttributeType::INT32:
              infer_meta_context.EmplaceBackAttr(PADDLE_GET_CONST(int, attr));
              break;
            case phi::AttributeType::BOOL:
              infer_meta_context.EmplaceBackAttr(PADDLE_GET_CONST(bool, attr));
              break;
            case phi::AttributeType::INT64:
              infer_meta_context.EmplaceBackAttr(
                  PADDLE_GET_CONST(int64_t, attr));
              break;
            case phi::AttributeType::INT32S:
              infer_meta_context.EmplaceBackAttr(
                  PADDLE_GET_CONST(std::vector<int>, attr));
              break;
            case phi::AttributeType::DATA_TYPE: {
              auto data_type = paddle::framework::TransToPhiDataType(
                  static_cast<framework::proto::VarType::Type>(
                      PADDLE_GET_CONST(int, attr)));
              infer_meta_context.EmplaceBackAttr(data_type);
            } break;
            case phi::AttributeType::STRING:
              infer_meta_context.EmplaceBackAttr(
                  PADDLE_GET_CONST(std::string, attr));
              break;
            case phi::AttributeType::INT64S:
              switch (AttrTypeID(attr)) {
                case framework::proto::AttrType::LONGS:
                  infer_meta_context.EmplaceBackAttr(
                      PADDLE_GET_CONST(std::vector<int64_t>, attr));
                  break;
                case framework::proto::AttrType::INTS: {
                  const auto& vector_int_attr =
                      PADDLE_GET_CONST(std::vector<int>, attr);
                  const std::vector<int64_t> vector_int64_attr(
                      vector_int_attr.begin(), vector_int_attr.end());
                  infer_meta_context.EmplaceBackAttr(vector_int64_attr);
                } break;
                default:
                  PADDLE_THROW(platform::errors::Unimplemented(
                      "Unsupported cast of op attribute `%s` to "
                      "vector<int64_t> when constructing InferMetaContext.",
                      attr_names[i]));
              }
              break;
            case phi::AttributeType::FLOAT32S:
              infer_meta_context.EmplaceBackAttr(
                  PADDLE_GET_CONST(std::vector<float>, attr));
              break;
            case phi::AttributeType::STRINGS:
              infer_meta_context.EmplaceBackAttr(
                  PADDLE_GET_CONST(std::vector<std::string>, attr));
              break;
            case phi::AttributeType::BOOLS:
              infer_meta_context.EmplaceBackAttr(
                  PADDLE_GET_CONST(std::vector<bool>, attr));
              break;
            case phi::AttributeType::FLOAT64S:
              infer_meta_context.EmplaceBackAttr(
                  PADDLE_GET_CONST(std::vector<double>, attr));
              break;
            default:
              PADDLE_THROW(platform::errors::Unimplemented(
                  "Unsupported cast of op attribute `%s` when constructing "
                  "InferMetaContext.",
                  attr_names[i]));
          }
        } else {
          // do nothing, skip current attr
        }
    }
  }

  VLOG(6) << "BuildInferMetaContext: Done attrs";

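  // Outputs may be absent (dispensable); uninitialized CompatMetaTensors
  // are pushed for them so positions stay aligned with the kernel
  // signature.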
  for (auto& out_name : output_names) {
    if (ctx->HasOutputs(out_name, true)) {
      auto output_var = std::move(ctx->GetOutputVarPtrs(out_name));
      if (output_var.size() == 1) {
        infer_meta_context.EmplaceBackOutput(
            std::move(CompatMetaTensor(output_var[0], ctx->IsRuntime())));
      } else {
        paddle::small_vector<CompatMetaTensor, phi::kOutputSmallVectorSize>
            outputs;
        for (const auto& out : output_var) {
          if (ctx->IsRuntime()) {
            if (PADDLE_GET_CONST(Variable*, out)) {
              outputs.emplace_back(
                  std::move(CompatMetaTensor(out, ctx->IsRuntime())));
              continue;
            }
          } else if (PADDLE_GET_CONST(VarDesc*, out)) {
            outputs.emplace_back(
                std::move(CompatMetaTensor(out, ctx->IsRuntime())));
            continue;
          }
          outputs.emplace_back(std::move(CompatMetaTensor(ctx->IsRuntime())));
        }
        infer_meta_context.EmplaceBackOutputs(std::move(outputs));
      }
    } else {
      infer_meta_context.EmplaceBackOutput(
          std::move(CompatMetaTensor(ctx->IsRuntime())));
    }
  }

  VLOG(6) << "BuildInferMetaContext: Done outputs";

  return infer_meta_context;
}

}  // namespace framework
}  // namespace paddle