infershape_utils.cc 28.9 KB
Newer Older
C
Chen Weihang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/infershape_utils.h"

17
#include <algorithm>
18 19
#include <string>

20
#include "paddle/fluid/framework/convert_utils.h"
C
Chen Weihang 已提交
21
#include "paddle/fluid/framework/framework.pb.h"
22
#include "paddle/fluid/framework/phi_utils.h"
C
Chen Weihang 已提交
23
#include "paddle/fluid/platform/enforce.h"
24
#include "paddle/phi/common/int_array.h"
25
#include "paddle/phi/common/scalar.h"
26 27 28 29 30
#include "paddle/phi/core/compat/arg_map_context.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/infermeta_utils.h"
31
#include "paddle/phi/core/kernel_factory.h"
32
#include "paddle/phi/core/tensor_utils.h"
C
Chen Weihang 已提交
33 34 35 36

namespace paddle {
namespace framework {

37
class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
C
Chen Weihang 已提交
38 39 40 41 42 43 44 45 46 47 48 49
 public:
  explicit InferShapeArgumentMappingContext(const InferShapeContext& ctx)
      : ctx_(ctx) {}

  bool HasInput(const std::string& name) const override {
    return ctx_.HasInput(name);
  }

  bool HasOutput(const std::string& name) const override {
    return ctx_.HasOutput(name);
  }

50 51 52 53
  bool HasAttr(const std::string& name) const override {
    return ctx_.HasAttr(name);
  }

C
Chen Weihang 已提交
54
  paddle::any Attr(const std::string& name) const override {
55 56
    auto* attr = ctx_.Attrs().GetAttr(name);
    PADDLE_ENFORCE_NOT_NULL(
57 58 59
        attr,
        platform::errors::NotFound("Attribute (%s) should be in AttributeMap.",
                                   name));
60
    return GetAttrValue(*attr);
C
Chen Weihang 已提交
61 62 63
  }

  size_t InputSize(const std::string& name) const override {
64 65 66 67 68 69
    if (ctx_.HasInputs(name)) {
      return ctx_.Inputs(name).size();
    } else if (ctx_.HasInput(name)) {
      return 1;
    }
    return 0;
C
Chen Weihang 已提交
70 71 72 73 74 75 76
  }

  size_t OutputSize(const std::string& name) const override {
    return ctx_.Outputs(name).size();
  }

  bool IsDenseTensorInput(const std::string& name) const override {
77 78 79 80 81
    auto var_type = ctx_.GetInputVarType(name);
    return var_type == proto::VarType::LOD_TENSOR;
  }

  bool IsDenseTensorInputs(const std::string& name) const override {
C
Chen Weihang 已提交
82
    auto var_types = ctx_.GetInputsVarType(name);
83 84
    return std::all_of(var_types.begin(),
                       var_types.end(),
85 86 87
                       [](const proto::VarType::Type& type) {
                         return type == proto::VarType::LOD_TENSOR;
                       });
C
Chen Weihang 已提交
88 89 90
  }

  bool IsSelectedRowsInput(const std::string& name) const override {
91 92
    auto var_type = ctx_.GetInputVarType(name);
    return var_type == proto::VarType::SELECTED_ROWS;
C
Chen Weihang 已提交
93 94
  }

95 96
  bool IsDenseTensorVectorInput(const std::string& name) const override {
    auto var_types = ctx_.GetInputsVarType(name);
97 98
    return std::all_of(var_types.begin(),
                       var_types.end(),
99 100 101
                       [](const proto::VarType::Type& type) {
                         return type == proto::VarType::LOD_TENSOR_ARRAY;
                       });
102 103
  }

104 105
  bool IsDenseTensorOutput(const std::string& name) const override {
    auto var_types = ctx_.GetOutputsVarType(name);
106 107
    return std::all_of(var_types.begin(),
                       var_types.end(),
108 109 110
                       [](const proto::VarType::Type& type) {
                         return type == proto::VarType::LOD_TENSOR;
                       });
111 112 113 114
  }

  bool IsSelectedRowsOutput(const std::string& name) const override {
    auto var_types = ctx_.GetOutputsVarType(name);
115 116
    return std::all_of(var_types.begin(),
                       var_types.end(),
117 118 119
                       [](const proto::VarType::Type& type) {
                         return type == proto::VarType::SELECTED_ROWS;
                       });
120 121
  }

122 123
  bool IsForInferShape() const override { return true; }

124 125
  bool IsRuntime() const override { return ctx_.IsRuntime(); }

C
Chen Weihang 已提交
126 127 128 129
 private:
  const InferShapeContext& ctx_;
};

130 131 132 133 134 135 136
// Raises InvalidArgument unless `meta_tensor` wraps an underlying variable;
// every CompatMetaTensor accessor/mutator below calls this first.
static inline void ValidCheck(const phi::MetaTensor& meta_tensor) {
  const bool is_initialized = meta_tensor.initialized();
  PADDLE_ENFORCE_EQ(is_initialized,
                    true,
                    phi::errors::InvalidArgument(
                        "The current CompatMetaTensor is not initialized."));
}

137
int64_t CompatMetaTensor::numel() const {
138
  ValidCheck(*this);
139
  if (is_runtime_) {
R
Ruibiao Chen 已提交
140
    auto* var = PADDLE_GET_CONST(Variable*, var_);
141 142
    return var->Get<Tensor>().numel();
  } else {
R
Ruibiao Chen 已提交
143
    auto* var = PADDLE_GET_CONST(VarDesc*, var_);
144
    return var->ElementSize();
C
Chen Weihang 已提交
145
  }
146
}
C
Chen Weihang 已提交
147

148
// Returns the shape of the wrapped variable.
// Compile time: the VarDesc shape, with an empty shape reported as {0}.
// Runtime: DenseTensor/SelectedRows dims, or {array length} for a
// LoDTensorArray; any other variable type raises Unimplemented.
DDim CompatMetaTensor::dims() const {
  ValidCheck(*this);
  if (!is_runtime_) {
    auto* var_desc = PADDLE_GET_CONST(VarDesc*, var_);
    const auto shape = var_desc->GetShape();
    return shape.empty() ? phi::make_ddim({0UL}) : phi::make_ddim(shape);
  }

  auto* variable = PADDLE_GET_CONST(Variable*, var_);
  if (variable->IsType<phi::DenseTensor>()) {
    return variable->Get<phi::DenseTensor>().dims();
  }
  if (variable->IsType<phi::SelectedRows>()) {
    return variable->Get<phi::SelectedRows>().dims();
  }
  if (variable->IsType<framework::LoDTensorArray>()) {
    // A tensor array exposes its length as a 1-D shape.
    const auto& array = variable->Get<framework::LoDTensorArray>();
    return phi::make_ddim({static_cast<int64_t>(array.size())});
  }
  PADDLE_THROW(platform::errors::Unimplemented(
      "Currently, only can get dims from DenseTensor or SelectedRows or "
      "DenseTensorArray."));
}
C
Chen Weihang 已提交
172

173
// Returns the data type of the wrapped variable.
// Compile time: the VarDesc proto dtype converted to phi::DataType.
// Runtime: DenseTensor/SelectedRows dtype; LoDTensorArray reports UNDEFINED
// (not supported yet); any other variable type raises Unimplemented.
phi::DataType CompatMetaTensor::dtype() const {
  ValidCheck(*this);
  if (!is_runtime_) {
    auto* var_desc = PADDLE_GET_CONST(VarDesc*, var_);
    return paddle::framework::TransToPhiDataType(var_desc->GetDataType());
  }

  auto* variable = PADDLE_GET_CONST(Variable*, var_);
  if (variable->IsType<phi::DenseTensor>()) {
    return variable->Get<phi::DenseTensor>().dtype();
  }
  if (variable->IsType<phi::SelectedRows>()) {
    return variable->Get<phi::SelectedRows>().dtype();
  }
  if (variable->IsType<framework::LoDTensorArray>()) {
    // Reading a dtype from a LoDTensorArray is not supported yet.
    return phi::DataType::UNDEFINED;
  }
  PADDLE_THROW(platform::errors::Unimplemented(
      "Currently, only can get dtype from DenseTensor or SelectedRows."));
}
C
Chen Weihang 已提交
194

195
// Returns the layout of the wrapped variable.
// Compile time: always UNDEFINED, since layout is not read from VarDesc.
// Runtime: DenseTensor/SelectedRows layout; LoDTensorArray reports UNDEFINED
// (not supported yet); any other variable type raises Unimplemented.
DataLayout CompatMetaTensor::layout() const {
  ValidCheck(*this);
  if (!is_runtime_) {
    return DataLayout::UNDEFINED;
  }

  auto* variable = PADDLE_GET_CONST(Variable*, var_);
  if (variable->IsType<phi::DenseTensor>()) {
    return variable->Get<phi::DenseTensor>().layout();
  }
  if (variable->IsType<phi::SelectedRows>()) {
    return variable->Get<phi::SelectedRows>().layout();
  }
  if (variable->IsType<framework::LoDTensorArray>()) {
    // Reading a layout from a LoDTensorArray is not supported yet.
    return phi::DataLayout::UNDEFINED;
  }
  PADDLE_THROW(platform::errors::Unimplemented(
      "Currently, only can get layout from DenseTensor or "
      "SelectedRows."));
}

void CompatMetaTensor::set_dims(const DDim& dims) {
220
  ValidCheck(*this);
221
  if (is_runtime_) {
R
Ruibiao Chen 已提交
222
    auto* var = PADDLE_GET(Variable*, var_);
223 224 225 226 227 228 229 230 231 232 233
    if (var->IsType<phi::DenseTensor>()) {
      auto* tensor = var->GetMutable<phi::DenseTensor>();
      phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
    } else if (var->IsType<phi::SelectedRows>()) {
      auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
      phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
    } else if (var->IsType<framework::LoDTensorArray>()) {
      auto* tensor_array = var->GetMutable<framework::LoDTensorArray>();
      // Note: Here I want enforce `tensor_array->size() == 0UL`, because
      // inplace using on LoDTensorArray is dangerous, but the unittest
      // `test_list` contains this behavior
234 235
      PADDLE_ENFORCE_EQ(dims.size(),
                        1UL,
236 237 238 239
                        platform::errors::InvalidArgument(
                            "LoDTensorArray can only have one dimension."));
      // only set the array size for LoDTensorArray input
      tensor_array->resize(dims[0]);
C
Chen Weihang 已提交
240
    } else {
241 242
      PADDLE_THROW(platform::errors::Unimplemented(
          "Currently, only can set dims from DenseTensor or SelectedRows."));
C
Chen Weihang 已提交
243
    }
244
  } else {
R
Ruibiao Chen 已提交
245
    auto* var = PADDLE_GET(VarDesc*, var_);
246
    var->SetShape(vectorize(dims));
C
Chen Weihang 已提交
247
  }
248 249 250
}

// Writes `dtype` to the wrapped variable.
// Compile time: stores the converted proto dtype on the VarDesc.
// Runtime: updates the meta of the DenseTensor (or the SelectedRows value
// tensor); ignored for a LoDTensorArray; any other variable type raises
// Unimplemented.
void CompatMetaTensor::set_dtype(phi::DataType dtype) {
  ValidCheck(*this);
  if (!is_runtime_) {
    auto* var_desc = PADDLE_GET(VarDesc*, var_);
    var_desc->SetDataType(paddle::framework::TransToProtoVarType(dtype));
    return;
  }

  auto* variable = PADDLE_GET(Variable*, var_);
  phi::DenseTensor* target = nullptr;
  if (variable->IsType<phi::DenseTensor>()) {
    target = variable->GetMutable<phi::DenseTensor>();
  } else if (variable->IsType<phi::SelectedRows>()) {
    target = variable->GetMutable<phi::SelectedRows>()->mutable_value();
  } else if (variable->IsType<framework::LoDTensorArray>()) {
    // Setting a dtype on a LoDTensorArray is not supported yet; ignore it.
    return;
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Currently, only can set dtype from DenseTensor or SelectedRows."));
  }
  phi::DenseTensorUtils::GetMutableMeta(target)->dtype = dtype;
}

// Writes `layout` to the wrapped variable.
// Compile time: no-op, since layout is not stored on VarDesc.
// Runtime: updates the meta of the DenseTensor (or the SelectedRows value
// tensor); ignored for a LoDTensorArray; any other variable type raises
// Unimplemented.
void CompatMetaTensor::set_layout(DataLayout layout) {
  ValidCheck(*this);
  if (!is_runtime_) {
    return;
  }

  auto* variable = PADDLE_GET(Variable*, var_);
  phi::DenseTensor* target = nullptr;
  if (variable->IsType<phi::DenseTensor>()) {
    target = variable->GetMutable<phi::DenseTensor>();
  } else if (variable->IsType<phi::SelectedRows>()) {
    target = variable->GetMutable<phi::SelectedRows>()->mutable_value();
  } else if (variable->IsType<framework::LoDTensorArray>()) {
    // Setting a layout on a LoDTensorArray is not supported yet; ignore it.
    return;
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Currently, only can set layout from DenseTensor or "
        "SelectedRows."));
  }
  phi::DenseTensorUtils::GetMutableMeta(target)->layout = layout;
}

void CompatMetaTensor::share_lod(const MetaTensor& meta_tensor) {
298 299
  ValidCheck(*this);
  ValidCheck(meta_tensor);
300
  if (is_runtime_) {
R
Ruibiao Chen 已提交
301
    auto* var = PADDLE_GET(Variable*, var_);
302 303 304 305
    if (var->IsType<phi::DenseTensor>()) {
      auto* tensor = var->GetMutable<phi::DenseTensor>();
      phi::DenseTensorUtils::GetMutableMeta(tensor)->lod =
          static_cast<const CompatMetaTensor&>(meta_tensor).GetRuntimeLoD();
C
Chen Weihang 已提交
306
    } else {
307 308
      // NOTE(chenweihang): do nothing
      // only LoDTensor need to share lod
C
Chen Weihang 已提交
309
    }
310
  } else {
R
Ruibiao Chen 已提交
311
    auto* var = PADDLE_GET(VarDesc*, var_);
312 313
    var->SetLoDLevel(
        static_cast<const CompatMetaTensor&>(meta_tensor).GetCompileTimeLoD());
C
Chen Weihang 已提交
314
  }
315 316 317
}

void CompatMetaTensor::share_dims(const MetaTensor& meta_tensor) {
318 319
  ValidCheck(*this);
  ValidCheck(meta_tensor);
320 321
  set_dims(meta_tensor.dims());
  if (is_runtime_) {
R
Ruibiao Chen 已提交
322
    auto* var = PADDLE_GET(Variable*, var_);
323 324 325 326 327 328
    if (var->IsType<phi::SelectedRows>()) {
      auto* selected_rows = var->GetMutable<phi::SelectedRows>();
      auto& input_selected_rows =
          static_cast<const CompatMetaTensor&>(meta_tensor).GetSelectedRows();
      selected_rows->set_rows(input_selected_rows.rows());
      selected_rows->set_height(input_selected_rows.height());
329
    }
330
  }
331 332 333 334 335 336 337 338 339
}

void CompatMetaTensor::share_meta(const MetaTensor& meta_tensor) {
  share_dims(meta_tensor);
  set_dtype(meta_tensor.dtype());
  set_layout(meta_tensor.layout());
  // special case: share lod of LoDTensor
  share_lod(meta_tensor);
}
C
Chen Weihang 已提交
340

341 342 343 344 345 346 347 348 349 350 351 352
// Appends one input and records its single-element [begin, begin + 1) range.
void CompatInferMetaContext::EmplaceBackInput(CompatMetaTensor input) {
  const int begin = static_cast<int>(compat_inputs_.size());
  compat_inputs_.emplace_back(std::move(input));
  input_range_.emplace_back(std::make_pair(begin, begin + 1));
}
// Appends one output and records its single-element [begin, begin + 1) range.
void CompatInferMetaContext::EmplaceBackOutput(CompatMetaTensor output) {
  const int begin = static_cast<int>(compat_outputs_.size());
  compat_outputs_.emplace_back(std::move(output));
  output_range_.emplace_back(std::make_pair(begin, begin + 1));
}

void CompatInferMetaContext::EmplaceBackInputs(
C
Chen Weihang 已提交
353
    paddle::small_vector<CompatMetaTensor, phi::kInputSmallVectorSize> inputs) {
354 355 356 357 358 359 360 361
  int index = compat_inputs_.size();
  input_range_.emplace_back(std::pair<int, int>(index, index + inputs.size()));
  compat_inputs_.insert(compat_inputs_.end(),
                        std::make_move_iterator(inputs.begin()),
                        std::make_move_iterator(inputs.end()));
}

void CompatInferMetaContext::EmplaceBackOutputs(
C
Chen Weihang 已提交
362
    paddle::small_vector<CompatMetaTensor, phi::kOutputSmallVectorSize>
363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388
        outputs) {
  int index = compat_outputs_.size();
  output_range_.emplace_back(
      std::pair<int, int>(index, index + outputs.size()));
  compat_outputs_.insert(compat_outputs_.end(),
                         std::make_move_iterator(outputs.begin()),
                         std::make_move_iterator(outputs.end()));
}

// Bounds-checked access to the input at flat index `idx`.
const phi::MetaTensor& CompatInferMetaContext::InputAt(size_t idx) const {
  const CompatMetaTensor& input = compat_inputs_.at(idx);
  return input;
}

// Returns pointers to the inputs in [start, end); uninitialized entries are
// reported as nullptr so positions stay aligned.
std::vector<const phi::MetaTensor*> CompatInferMetaContext::InputsBetween(
    size_t start, size_t end) const {
  std::vector<const phi::MetaTensor*> tensors;
  tensors.reserve(end - start);
  for (size_t idx = start; idx < end; ++idx) {
    const auto& tensor = compat_inputs_.at(idx);
    const phi::MetaTensor* entry = nullptr;
    if (tensor.initialized()) {
      entry = &tensor;
    }
    tensors.push_back(entry);
  }
  return tensors;
}

389
// Returns the inputs in [start, end) as an optional list.
// The list is considered absent (paddle::none) when the range is empty or its
// first tensor is uninitialized; otherwise each entry is a pointer to the
// input, or nullptr for an uninitialized one, so positions stay aligned.
paddle::optional<std::vector<const phi::MetaTensor*>>
CompatInferMetaContext::OptionalInputsBetween(size_t start, size_t end) const {
  // Guard the empty range explicitly: compat_inputs_.at(start) would
  // otherwise inspect the slot of the *next* input (or be out of range), so
  // presence would be decided by an unrelated tensor.
  if (start >= end || !compat_inputs_.at(start).initialized()) {
    return paddle::none;
  }

  std::vector<const phi::MetaTensor*> result;
  result.reserve(end - start);
  for (size_t i = start; i < end; ++i) {
    auto& in = compat_inputs_.at(i);
    result.emplace_back(in.initialized() ? &in : nullptr);
  }

  return paddle::optional<std::vector<const phi::MetaTensor*>>(
      std::move(result));
}

// Bounds-checked access to the output at flat index `idx`; returns nullptr
// when that output is uninitialized.
phi::MetaTensor* CompatInferMetaContext::MutableOutputAt(size_t idx) {
  auto& output = compat_outputs_.at(idx);
  if (!output.initialized()) {
    return nullptr;
  }
  return &output;
}

// Returns mutable pointers to the outputs in [start, end); uninitialized
// entries are reported as nullptr so positions stay aligned.
std::vector<phi::MetaTensor*> CompatInferMetaContext::MutableOutputBetween(
    size_t start, size_t end) {
  std::vector<phi::MetaTensor*> tensors;
  tensors.reserve(end - start);
  for (size_t idx = start; idx < end; ++idx) {
    auto& tensor = compat_outputs_.at(idx);
    phi::MetaTensor* entry = nullptr;
    if (tensor.initialized()) {
      entry = &tensor;
    }
    tensors.push_back(entry);
  }
  return tensors;
}

CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
                                             const std::string& op_type) {
426
  // 1. get kernel args
427
  auto* arg_map_fn = ctx->GetPhiArgumentMappingFn();
428
  InferShapeArgumentMappingContext arg_map_context(*ctx);
429 430 431
  phi::KernelSignature signature = arg_map_fn
                                       ? (*arg_map_fn)(arg_map_context)
                                       : *ctx->GetPhiDefaultKernelSignature();
432 433 434
  VLOG(3) << "BuildInferMetaContext: op kernel signature - " << signature;

  // 2. build infermeta context
435
  CompatInferMetaContext infer_meta_context(
F
From00 已提交
436
      {ctx->IsRuntime(), ctx->IsRunMKLDNNKernel()});
437

438 439 440
  const auto& input_names = signature.input_names;
  const auto& attr_names = signature.attr_names;
  const auto& output_names = signature.output_names;
441

442 443 444
  const auto& args_def =
      phi::KernelFactory::Instance().GetFirstKernelArgsDef(signature.name);
  const auto& attr_defs = args_def.attribute_defs();
445

446
  for (auto& in_name : input_names) {
447
    if (ctx->HasInputs(in_name)) {
448
      auto input_var = std::move(ctx->GetInputVarPtrs(in_name));
449 450
      if (input_var.size() == 1) {
        infer_meta_context.EmplaceBackInput(
451
            std::move(CompatMetaTensor(input_var[0], ctx->IsRuntime())));
452
      } else {
C
Chen Weihang 已提交
453
        paddle::small_vector<CompatMetaTensor, phi::kInputSmallVectorSize>
454
            inputs;
455
        for (const auto& in : input_var) {
456 457
          inputs.emplace_back(
              std::move(CompatMetaTensor(in, ctx->IsRuntime())));
458 459 460
        }
        infer_meta_context.EmplaceBackInputs(std::move(inputs));
      }
461
    } else {
462 463
      infer_meta_context.EmplaceBackInput(
          std::move(CompatMetaTensor(ctx->IsRuntime())));
464
    }
465
  }
466

467 468
  VLOG(6) << "BuildInferMetaContext: Done inputs";

469
  auto attr_reader = ctx->Attrs();
470
  for (size_t i = 0; i < attr_names.size(); ++i) {
471
    auto& attr_name = attr_names[i];
472
    auto* attr_ptr = attr_reader.GetAttr(attr_name);
473 474 475
    bool is_attr_var = attr_ptr != nullptr && HasAttrVar(*attr_ptr);
    VLOG(6) << "BuildInferMetaContext: " << attr_name << ": "
            << attr_defs[i].type_index << ", is_attr_var: " << is_attr_var;
476 477
    switch (attr_defs[i].type_index) {
      case phi::AttributeType::SCALAR:
478
        if (attr_ptr && !is_attr_var) {
479 480 481 482
          auto& attr = *attr_ptr;
          switch (AttrTypeID(attr)) {
            case framework::proto::AttrType::FLOAT:
              infer_meta_context.EmplaceBackAttr(
R
Ruibiao Chen 已提交
483
                  phi::Scalar(PADDLE_GET_CONST(float, attr)));
484 485 486
              break;
            case framework::proto::AttrType::INT:
              infer_meta_context.EmplaceBackAttr(
R
Ruibiao Chen 已提交
487
                  phi::Scalar(PADDLE_GET_CONST(int, attr)));
488 489 490
              break;
            case framework::proto::AttrType::STRING:
              infer_meta_context.EmplaceBackAttr(
R
Ruibiao Chen 已提交
491
                  phi::Scalar(PADDLE_GET_CONST(std::string, attr)));
492
              break;
493 494 495 496
            case framework::proto::AttrType::BOOLEAN:
              infer_meta_context.EmplaceBackAttr(
                  phi::Scalar(PADDLE_GET_CONST(bool, attr)));
              break;
497 498 499 500 501
            default:
              PADDLE_THROW(platform::errors::Unimplemented(
                  "Unsupported cast op attribute `%s` to Scalar when construct "
                  "InferMetaContext.",
                  attr_name));
502
          }
503 504 505 506
        } else if (ctx->HasInput(attr_name)) {
          auto infershape_input = std::move(ctx->GetInputVarPtrs(attr_name));
          if (infershape_input.size() == 1) {
            if (ctx->IsRuntime()) {
R
Ruibiao Chen 已提交
507
              Variable* var = PADDLE_GET_CONST(Variable*, infershape_input[0]);
508 509 510 511 512 513 514
              infer_meta_context.EmplaceBackAttr(
                  std::move(experimental::MakePhiScalarFromVar(*var)));
            } else {
              phi::Scalar tensor_scalar(-1);
              tensor_scalar.SetFromTensor(true);
              infer_meta_context.EmplaceBackAttr(std::move(tensor_scalar));
            }
515
          } else {
516 517 518
            PADDLE_THROW(platform::errors::InvalidArgument(
                "Invalid input.size() when cast op attribute `%s` to Scalar, "
                "expected 1, but actually is %d .",
519 520
                attr_name,
                infershape_input.size()));
521 522
          }
        } else {
523 524 525 526 527
          // do nothing, skip current attr
        }
        break;
      case phi::AttributeType::INT_ARRAY:
        // When attr is a vector_tensor or tensor, transform it to IntArray
528
        if (attr_ptr && !is_attr_var) {
529 530 531 532
          auto& attr = *attr_ptr;
          switch (AttrTypeID(attr)) {
            case framework::proto::AttrType::INTS:
              infer_meta_context.EmplaceBackAttr(std::move(
R
Ruibiao Chen 已提交
533
                  phi::IntArray(PADDLE_GET_CONST(std::vector<int32_t>, attr))));
534 535 536
              break;
            case framework::proto::AttrType::LONGS:
              infer_meta_context.EmplaceBackAttr(std::move(
R
Ruibiao Chen 已提交
537
                  phi::IntArray(PADDLE_GET_CONST(std::vector<int64_t>, attr))));
538 539 540
              break;
            case framework::proto::AttrType::INT:
              infer_meta_context.EmplaceBackAttr(
R
Ruibiao Chen 已提交
541
                  phi::IntArray({PADDLE_GET_CONST(int, attr)}));
542 543 544 545 546 547
              break;
            default:
              PADDLE_THROW(platform::errors::Unimplemented(
                  "Unsupported cast op attribute `%s` to IntArray when "
                  "construct InferMetaContext.",
                  attr_name));
548
          }
549 550 551 552 553 554 555 556
        } else if (ctx->HasInputs(attr_name) || ctx->HasInput(attr_name)) {
          auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name));
          if (ctx->IsRuntime()) {
            // If is in runtime, we will get tensor's value for IntArray
            // and push it into attrs
            std::vector<Variable*> vars;
            vars.reserve(infershape_inputs.size());
            for (size_t i = 0; i < infershape_inputs.size(); i++) {
R
Ruibiao Chen 已提交
557
              vars.push_back(PADDLE_GET_CONST(Variable*, infershape_inputs[i]));
558
            }
559 560 561 562 563 564
            if (infershape_inputs.size() != 1) {
              infer_meta_context.EmplaceBackAttr(
                  std::move(experimental::MakePhiIntArrayFromVarList(vars)));
            } else {
              infer_meta_context.EmplaceBackAttr(
                  std::move(experimental::MakePhiIntArrayFromVar(*vars[0])));
565
            }
566
          } else {
567 568 569 570
            // If is not in runtime, we will set default value(-1) for IntArray
            std::vector<VarDesc*> vars;
            vars.reserve(infershape_inputs.size());
            for (size_t i = 0; i < infershape_inputs.size(); ++i) {
R
Ruibiao Chen 已提交
571
              vars.push_back(PADDLE_GET_CONST(VarDesc*, infershape_inputs[i]));
572 573 574 575 576 577 578 579 580 581 582
            }

            int64_t num_ele = 0;
            if (vars.size() == 1) {
              num_ele = 1;
              const auto& tensor_dims = vars[0]->GetShape();
              for (size_t i = 0; i < tensor_dims.size(); ++i) {
                num_ele *= tensor_dims[i];
              }

              if (num_ele <= 0) {
583
                num_ele = tensor_dims.size();
584 585 586 587 588 589 590 591
              }

            } else {
              num_ele = vars.size();
            }
            phi::IntArray tensor_attr(std::vector<int32_t>(num_ele, -1));
            tensor_attr.SetFromTensor(true);
            infer_meta_context.EmplaceBackAttr(std::move(tensor_attr));
592 593
          }
        } else {
594
          // do nothing, skip current attr
595
        }
596 597 598 599 600 601
        break;
      case phi::AttributeType::SCALARS:
        if (attr_ptr) {
          auto& attr = *attr_ptr;
          switch (AttrTypeID(attr)) {
            case framework::proto::AttrType::INTS: {
R
Ruibiao Chen 已提交
602
              const auto& vec = PADDLE_GET_CONST(std::vector<int32_t>, attr);
603 604 605 606 607 608 609 610
              std::vector<phi::Scalar> scalar_list;
              scalar_list.reserve(vec.size());
              for (const auto& val : vec) {
                scalar_list.emplace_back(val);
              }
              infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
            } break;
            case framework::proto::AttrType::LONGS: {
R
Ruibiao Chen 已提交
611
              const auto& vec = PADDLE_GET_CONST(std::vector<int64_t>, attr);
612 613 614 615 616 617 618 619
              std::vector<phi::Scalar> scalar_list;
              scalar_list.reserve(vec.size());
              for (const auto& val : vec) {
                scalar_list.emplace_back(val);
              }
              infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
            } break;
            case framework::proto::AttrType::FLOATS: {
R
Ruibiao Chen 已提交
620
              const auto& vec = PADDLE_GET_CONST(std::vector<float>, attr);
621 622 623 624 625 626 627 628
              std::vector<phi::Scalar> scalar_list;
              scalar_list.reserve(vec.size());
              for (const auto& val : vec) {
                scalar_list.emplace_back(val);
              }
              infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
            } break;
            case framework::proto::AttrType::FLOAT64S: {
R
Ruibiao Chen 已提交
629
              const auto& vec = PADDLE_GET_CONST(std::vector<double>, attr);
630 631 632 633 634 635 636 637 638 639 640 641
              std::vector<phi::Scalar> scalar_list;
              scalar_list.reserve(vec.size());
              for (const auto& val : vec) {
                scalar_list.emplace_back(val);
              }
              infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
            } break;
            default:
              PADDLE_THROW(platform::errors::Unimplemented(
                  "Unsupported cast op attribute `%s` to vector<Scalar> when "
                  "construct KernelContext.",
                  attr_names[i]));
642 643
          }
        } else {
644
          // do nothing, skip current attr
645
        }
646 647 648 649 650 651
        break;
      default:
        if (attr_ptr) {
          auto& attr = *attr_ptr;
          switch (attr_defs[i].type_index) {
            case phi::AttributeType::FLOAT32:
R
Ruibiao Chen 已提交
652
              infer_meta_context.EmplaceBackAttr(PADDLE_GET_CONST(float, attr));
653 654
              break;
            case phi::AttributeType::INT32:
R
Ruibiao Chen 已提交
655
              infer_meta_context.EmplaceBackAttr(PADDLE_GET_CONST(int, attr));
656 657
              break;
            case phi::AttributeType::BOOL:
R
Ruibiao Chen 已提交
658
              infer_meta_context.EmplaceBackAttr(PADDLE_GET_CONST(bool, attr));
659 660 661
              break;
            case phi::AttributeType::INT64:
              infer_meta_context.EmplaceBackAttr(
R
Ruibiao Chen 已提交
662
                  PADDLE_GET_CONST(int64_t, attr));
663 664 665
              break;
            case phi::AttributeType::INT32S:
              infer_meta_context.EmplaceBackAttr(
R
Ruibiao Chen 已提交
666
                  PADDLE_GET_CONST(std::vector<int>, attr));
667 668 669 670
              break;
            case phi::AttributeType::DATA_TYPE: {
              auto data_type = paddle::framework::TransToPhiDataType(
                  static_cast<framework::proto::VarType::Type>(
R
Ruibiao Chen 已提交
671
                      PADDLE_GET_CONST(int, attr)));
672 673 674 675
              infer_meta_context.EmplaceBackAttr(data_type);
            } break;
            case phi::AttributeType::STRING:
              infer_meta_context.EmplaceBackAttr(
R
Ruibiao Chen 已提交
676
                  PADDLE_GET_CONST(std::string, attr));
677 678 679 680 681
              break;
            case phi::AttributeType::INT64S:
              switch (AttrTypeID(attr)) {
                case framework::proto::AttrType::LONGS:
                  infer_meta_context.EmplaceBackAttr(
R
Ruibiao Chen 已提交
682
                      PADDLE_GET_CONST(std::vector<int64_t>, attr));
683 684 685
                  break;
                case framework::proto::AttrType::INTS: {
                  const auto& vector_int_attr =
R
Ruibiao Chen 已提交
686
                      PADDLE_GET_CONST(std::vector<int>, attr);
687 688 689 690 691 692 693 694 695 696 697 698 699 700
                  const std::vector<int64_t> vector_int64_attr(
                      vector_int_attr.begin(), vector_int_attr.end());
                  infer_meta_context.EmplaceBackAttr(vector_int64_attr);
                } break;
                default:
                  PADDLE_THROW(platform::errors::Unimplemented(
                      "Unsupported cast op attribute `%s` to vector<int64_t> "
                      "when "
                      "construct KernelContext.",
                      attr_names[i]));
              }
              break;
            case phi::AttributeType::FLOAT32S:
              infer_meta_context.EmplaceBackAttr(
R
Ruibiao Chen 已提交
701
                  PADDLE_GET_CONST(std::vector<float>, attr));
702 703 704
              break;
            case phi::AttributeType::STRINGS:
              infer_meta_context.EmplaceBackAttr(
R
Ruibiao Chen 已提交
705
                  PADDLE_GET_CONST(std::vector<std::string>, attr));
706 707 708
              break;
            case phi::AttributeType::BOOLS:
              infer_meta_context.EmplaceBackAttr(
R
Ruibiao Chen 已提交
709
                  PADDLE_GET_CONST(std::vector<bool>, attr));
710 711 712
              break;
            case phi::AttributeType::FLOAT64S:
              infer_meta_context.EmplaceBackAttr(
R
Ruibiao Chen 已提交
713
                  PADDLE_GET_CONST(std::vector<double>, attr));
714 715 716 717 718 719 720
              break;
            default:
              PADDLE_THROW(platform::errors::Unimplemented(
                  "Unsupported cast op attribute `%s` when construct "
                  "KernelContext in dygraph.",
                  attr_names[i]));
          }
H
hong 已提交
721
        } else {
722
          // do nothing, skip currnet attr
H
hong 已提交
723
        }
724 725 726
    }
  }

727 728
  VLOG(6) << "BuildInferMetaContext: Done attrs";

729
  for (auto& out_name : output_names) {
730
    if (ctx->HasOutputs(out_name, true)) {
731
      auto output_var = std::move(ctx->GetOutputVarPtrs(out_name));
732
      if (output_var.size() == 1) {
733 734
        infer_meta_context.EmplaceBackOutput(
            std::move(CompatMetaTensor(output_var[0], ctx->IsRuntime())));
735
      } else {
C
Chen Weihang 已提交
736
        paddle::small_vector<CompatMetaTensor, phi::kOutputSmallVectorSize>
737
            outputs;
738
        for (const auto& out : output_var) {
739
          if (ctx->IsRuntime()) {
R
Ruibiao Chen 已提交
740
            if (PADDLE_GET_CONST(Variable*, out)) {
741
              outputs.emplace_back(
742
                  std::move(CompatMetaTensor(out, ctx->IsRuntime())));
743 744
              continue;
            }
R
Ruibiao Chen 已提交
745
          } else if (PADDLE_GET_CONST(VarDesc*, out)) {
746
            outputs.emplace_back(
747
                std::move(CompatMetaTensor(out, ctx->IsRuntime())));
748 749
            continue;
          }
750
          outputs.emplace_back(std::move(CompatMetaTensor(ctx->IsRuntime())));
751 752 753 754
        }
        infer_meta_context.EmplaceBackOutputs(std::move(outputs));
      }
    } else {
755 756
      infer_meta_context.EmplaceBackOutput(
          std::move(CompatMetaTensor(ctx->IsRuntime())));
757
    }
758 759
  }

760 761
  VLOG(6) << "BuildInferMetaContext: Done outputs";

762 763 764
  return infer_meta_context;
}

C
Chen Weihang 已提交
765 766
}  // namespace framework
}  // namespace paddle