/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/infershape_utils.h"

#include <algorithm>
#include <string>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/compat/arg_map_context.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/core/tensor_utils.h"

namespace paddle {
namespace framework {

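// Adapts the fluid InferShapeContext to phi's ArgumentMappingContext
// interface, so phi argument-mapping functions can query an op's inputs,
// outputs, and attributes uniformly during shape inference.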
class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
 public:
  explicit InferShapeArgumentMappingContext(const InferShapeContext& ctx)
      : ctx_(ctx) {}

  bool HasInput(const std::string& name) const override {
    return ctx_.HasInput(name);
  }

  bool HasOutput(const std::string& name) const override {
    return ctx_.HasOutput(name);
  }

  bool HasAttr(const std::string& name) const override {
    return ctx_.HasAttr(name);
  }

  paddle::any Attr(const std::string& name) const override {
    auto* attr = ctx_.Attrs().GetAttr(name);
    PADDLE_ENFORCE_NOT_NULL(
        attr,
        platform::errors::NotFound("Attribute (%s) should be in AttributeMap.",
                                   name));
    return GetAttrValue(*attr);
  }

  size_t InputSize(const std::string& name) const override {
    if (ctx_.HasInputs(name)) {
      return ctx_.Inputs(name).size();
    } else if (ctx_.HasInput(name)) {
      return 1;
    }
    return 0;
  }

  size_t OutputSize(const std::string& name) const override {
    return ctx_.Outputs(name).size();
  }

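  // In the proto layer a dense tensor is stored as a LOD_TENSOR variable,
  // so the IsDenseTensor* queries below are answered via proto var types.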
  bool IsDenseTensorInput(const std::string& name) const override {
    auto var_type = ctx_.GetInputVarType(name);
    return var_type == proto::VarType::LOD_TENSOR;
  }

  bool IsDenseTensorInputs(const std::string& name) const override {
    auto var_types = ctx_.GetInputsVarType(name);
    return std::all_of(var_types.begin(),
                       var_types.end(),
                       [](const proto::VarType::Type& type) {
                         return type == proto::VarType::LOD_TENSOR;
                       });
  }

  bool IsSelectedRowsInput(const std::string& name) const override {
    auto var_type = ctx_.GetInputVarType(name);
    return var_type == proto::VarType::SELECTED_ROWS;
  }

  bool IsDenseTensorVectorInput(const std::string& name) const override {
    auto var_types = ctx_.GetInputsVarType(name);
    return std::all_of(var_types.begin(),
                       var_types.end(),
                       [](const proto::VarType::Type& type) {
                         return type == proto::VarType::LOD_TENSOR_ARRAY;
                       });
  }

  bool IsDenseTensorOutput(const std::string& name) const override {
    auto var_types = ctx_.GetOutputsVarType(name);
    return std::all_of(var_types.begin(),
                       var_types.end(),
                       [](const proto::VarType::Type& type) {
                         return type == proto::VarType::LOD_TENSOR;
                       });
  }

  bool IsSelectedRowsOutput(const std::string& name) const override {
    auto var_types = ctx_.GetOutputsVarType(name);
    return std::all_of(var_types.begin(),
                       var_types.end(),
                       [](const proto::VarType::Type& type) {
                         return type == proto::VarType::SELECTED_ROWS;
                       });
  }

  bool IsForInferShape() const override { return true; }

  bool IsRuntime() const override { return ctx_.IsRuntime(); }

 private:
  const InferShapeContext& ctx_;
};

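// CompatMetaTensor is a phi::MetaTensor view over either a runtime
// Variable* (is_runtime_ == true) or a compile-time VarDesc*, letting the
// same InferMeta function serve both execution and program construction.
// ValidCheck guards every accessor against an uninitialized tensor.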
static inline void ValidCheck(const phi::MetaTensor& meta_tensor) {
  PADDLE_ENFORCE_EQ(meta_tensor.initialized(),
                    true,
                    phi::errors::InvalidArgument(
                        "The current CompatMetaTensor is not initialized."));
}

int64_t CompatMetaTensor::numel() const {
  ValidCheck(*this);
  if (is_runtime_) {
    auto* var = PADDLE_GET_CONST(Variable*, var_);
    return var->Get<Tensor>().numel();
  } else {
    auto* var = PADDLE_GET_CONST(VarDesc*, var_);
    return var->ElementSize();
  }
}

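// dims(): at runtime this reads the live tensor's shape; at compile time it
// reads the VarDesc shape, mapping an empty shape to a one-element {0} DDim.
// For a LoDTensorArray, the array length is reported as a 1-D "shape".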
DDim CompatMetaTensor::dims() const {
  ValidCheck(*this);
  if (is_runtime_) {
    auto* var = PADDLE_GET_CONST(Variable*, var_);
    if (var->IsType<phi::DenseTensor>()) {
      return var->Get<phi::DenseTensor>().dims();
    } else if (var->IsType<phi::SelectedRows>()) {
      return var->Get<phi::SelectedRows>().dims();
    } else if (var->IsType<framework::LoDTensorArray>()) {
      // use tensor array size as dims
      auto& tensor_array = var->Get<framework::LoDTensorArray>();
      return phi::make_ddim({static_cast<int64_t>(tensor_array.size())});
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Currently, dims can only be got from DenseTensor, SelectedRows, "
          "or DenseTensorArray."));
    }
  } else {
    auto* var = PADDLE_GET_CONST(VarDesc*, var_);

    return var->GetShape().empty() ? phi::make_ddim({0UL})
                                   : phi::make_ddim(var->GetShape());
  }
}

phi::DataType CompatMetaTensor::dtype() const {
  ValidCheck(*this);
  if (is_runtime_) {
    auto* var = PADDLE_GET_CONST(Variable*, var_);
    if (var->IsType<phi::DenseTensor>()) {
      return var->Get<phi::DenseTensor>().dtype();
    } else if (var->IsType<phi::SelectedRows>()) {
      return var->Get<phi::SelectedRows>().dtype();
    } else if (var->IsType<framework::LoDTensorArray>()) {
      // NOTE(chenweihang): do nothing
      // Getting dtype from LoDTensorArray is unsupported for now
      return phi::DataType::UNDEFINED;
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Currently, dtype can only be got from DenseTensor or "
          "SelectedRows."));
    }
  } else {
    auto* var = PADDLE_GET_CONST(VarDesc*, var_);
    return paddle::framework::TransToPhiDataType(var->GetDataType());
  }
}

DataLayout CompatMetaTensor::layout() const {
  ValidCheck(*this);
  if (is_runtime_) {
    auto* var = PADDLE_GET_CONST(Variable*, var_);
    if (var->IsType<phi::DenseTensor>()) {
      return var->Get<phi::DenseTensor>().layout();
    } else if (var->IsType<phi::SelectedRows>()) {
      return var->Get<phi::SelectedRows>().layout();
    } else if (var->IsType<framework::LoDTensorArray>()) {
      // NOTE(chenweihang): do nothing
      // Getting layout from LoDTensorArray is unsupported for now
      return phi::DataLayout::UNDEFINED;
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Currently, layout can only be got from DenseTensor or "
          "SelectedRows."));
    }
  } else {
    // NOTE(chenweihang): do nothing
    // Getting layout from a VarDesc is unsupported for now
    return DataLayout::UNDEFINED;
  }
}

void CompatMetaTensor::set_dims(const DDim& dims) {
  ValidCheck(*this);
  if (is_runtime_) {
    auto* var = PADDLE_GET(Variable*, var_);
    if (var->IsType<phi::DenseTensor>()) {
      auto* tensor = var->GetMutable<phi::DenseTensor>();
      phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
    } else if (var->IsType<phi::SelectedRows>()) {
      auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
      phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
    } else if (var->IsType<framework::LoDTensorArray>()) {
      auto* tensor_array = var->GetMutable<framework::LoDTensorArray>();
      // Note: Here we would like to enforce `tensor_array->size() == 0UL`,
      // because inplace use of a LoDTensorArray is dangerous, but the
      // unittest `test_list` relies on this behavior
      PADDLE_ENFORCE_EQ(dims.size(),
                        1UL,
                        platform::errors::InvalidArgument(
                            "LoDTensorArray can only have one dimension."));
      // only set the array size for LoDTensorArray input
      tensor_array->resize(dims[0]);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Currently, dims can only be set for DenseTensor or "
          "SelectedRows."));
    }
  } else {
    auto* var = PADDLE_GET(VarDesc*, var_);
    var->SetShape(vectorize(dims));
  }
}

void CompatMetaTensor::set_dtype(phi::DataType dtype) {
  ValidCheck(*this);
  if (is_runtime_) {
    auto* var = PADDLE_GET(Variable*, var_);
    if (var->IsType<phi::DenseTensor>()) {
      auto* tensor = var->GetMutable<phi::DenseTensor>();
      phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype;
    } else if (var->IsType<phi::SelectedRows>()) {
      auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
      phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype;
    } else if (var->IsType<framework::LoDTensorArray>()) {
      // NOTE(chenweihang): do nothing
      // Setting dtype for LoDTensorArray is unsupported for now
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Currently, dtype can only be set for DenseTensor or "
          "SelectedRows."));
    }
  } else {
    auto* var = PADDLE_GET(VarDesc*, var_);
    var->SetDataType(paddle::framework::TransToProtoVarType(dtype));
  }
}

void CompatMetaTensor::set_layout(DataLayout layout) {
  ValidCheck(*this);
  if (is_runtime_) {
    auto* var = PADDLE_GET(Variable*, var_);
    if (var->IsType<phi::DenseTensor>()) {
      auto* tensor = var->GetMutable<phi::DenseTensor>();
      phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
    } else if (var->IsType<phi::SelectedRows>()) {
      auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
      phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
    } else if (var->IsType<framework::LoDTensorArray>()) {
      // NOTE(chenweihang): do nothing
      // Setting layout for LoDTensorArray is unsupported for now
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Currently, layout can only be set for DenseTensor or "
          "SelectedRows."));
    }
  } else {
    // NOTE(chenweihang): do nothing
    // Setting layout for a VarDesc is unsupported for now
  }
}

void CompatMetaTensor::share_lod(const MetaTensor& meta_tensor) {
  ValidCheck(*this);
  ValidCheck(meta_tensor);
  if (is_runtime_) {
    auto* var = PADDLE_GET(Variable*, var_);
    if (var->IsType<phi::DenseTensor>()) {
      auto* tensor = var->GetMutable<phi::DenseTensor>();
      phi::DenseTensorUtils::GetMutableMeta(tensor)->lod =
          static_cast<const CompatMetaTensor&>(meta_tensor).GetRuntimeLoD();
    } else {
      // NOTE(chenweihang): do nothing
      // only LoDTensor needs to share lod
    }
  } else {
    auto* var = PADDLE_GET(VarDesc*, var_);
    var->SetLoDLevel(
        static_cast<const CompatMetaTensor&>(meta_tensor).GetCompileTimeLoD());
  }
}

void CompatMetaTensor::share_dims(const MetaTensor& meta_tensor) {
  ValidCheck(*this);
  ValidCheck(meta_tensor);
  set_dims(meta_tensor.dims());
  if (is_runtime_) {
    auto* var = PADDLE_GET(Variable*, var_);
    if (var->IsType<phi::SelectedRows>()) {
      auto* selected_rows = var->GetMutable<phi::SelectedRows>();
      auto& input_selected_rows =
          static_cast<const CompatMetaTensor&>(meta_tensor).GetSelectedRows();
      selected_rows->set_rows(input_selected_rows.rows());
      selected_rows->set_height(input_selected_rows.height());
    }
  }
}

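// share_meta copies dims (plus SelectedRows rows/height), dtype, and layout
// from `meta_tensor`, then shares LoD as a final DenseTensor-only step.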
void CompatMetaTensor::share_meta(const MetaTensor& meta_tensor) {
  share_dims(meta_tensor);
  set_dtype(meta_tensor.dtype());
  set_layout(meta_tensor.layout());
  // special case: share lod of LoDTensor
  share_lod(meta_tensor);
}

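// CompatInferMetaContext stores inputs and outputs in flat vectors;
// input_range_/output_range_ record the [begin, end) slice that each named
// argument occupies, so duplicable (vector) arguments can be recovered.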
void CompatInferMetaContext::EmplaceBackInput(CompatMetaTensor input) {
  int index = compat_inputs_.size();
  compat_inputs_.emplace_back(std::move(input));
  input_range_.emplace_back(std::pair<int, int>(index, index + 1));
}
void CompatInferMetaContext::EmplaceBackOutput(CompatMetaTensor output) {
  int index = compat_outputs_.size();
  compat_outputs_.emplace_back(std::move(output));
  output_range_.emplace_back(std::pair<int, int>(index, index + 1));
}

void CompatInferMetaContext::EmplaceBackInputs(
    paddle::small_vector<CompatMetaTensor, phi::kInputSmallVectorSize> inputs) {
  int index = compat_inputs_.size();
  input_range_.emplace_back(std::pair<int, int>(index, index + inputs.size()));
  compat_inputs_.insert(compat_inputs_.end(),
                        std::make_move_iterator(inputs.begin()),
                        std::make_move_iterator(inputs.end()));
}

void CompatInferMetaContext::EmplaceBackOutputs(
    paddle::small_vector<CompatMetaTensor, phi::kOutputSmallVectorSize>
        outputs) {
  int index = compat_outputs_.size();
  output_range_.emplace_back(
      std::pair<int, int>(index, index + outputs.size()));
  compat_outputs_.insert(compat_outputs_.end(),
                         std::make_move_iterator(outputs.begin()),
                         std::make_move_iterator(outputs.end()));
}

const phi::MetaTensor& CompatInferMetaContext::InputAt(size_t idx) const {
  return compat_inputs_.at(idx);
}

std::vector<const phi::MetaTensor*> CompatInferMetaContext::InputsBetween(
    size_t start, size_t end) const {
  std::vector<const phi::MetaTensor*> result;
  result.reserve(end - start);

  for (size_t i = start; i < end; ++i) {
    auto& in = compat_inputs_.at(i);
    result.emplace_back(in.initialized() ? &in : nullptr);
  }

  return result;
}

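// An optional duplicable input is reported as absent (paddle::none) when its
// first slot is uninitialized; otherwise every slot is surfaced, with
// nullptr standing in for individual holes.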
paddle::optional<std::vector<const phi::MetaTensor*>>
CompatInferMetaContext::OptionalInputsBetween(size_t start, size_t end) const {
  const auto& first = compat_inputs_.at(start);

  if (first.initialized()) {
    std::vector<const phi::MetaTensor*> result;
    result.reserve(end - start);

    for (size_t i = start; i < end; ++i) {
      auto& in = compat_inputs_.at(i);
      result.emplace_back(in.initialized() ? &in : nullptr);
    }

    return paddle::optional<std::vector<const phi::MetaTensor*>>(
        std::move(result));
  }
  return paddle::none;
}

phi::MetaTensor* CompatInferMetaContext::MutableOutputAt(size_t idx) {
  auto& out = compat_outputs_.at(idx);
  return out.initialized() ? &out : nullptr;
}

std::vector<phi::MetaTensor*> CompatInferMetaContext::MutableOutputBetween(
    size_t start, size_t end) {
  std::vector<phi::MetaTensor*> result;
  result.reserve(end - start);
  for (size_t i = start; i < end; ++i) {
    auto& out = compat_outputs_.at(i);
    result.emplace_back(out.initialized() ? &out : nullptr);
  }
  return result;
}

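// Builds a CompatInferMetaContext from a fluid InferShapeContext in three
// steps: resolve the phi kernel signature for `op_type`, wrap every input
// and output variable as a CompatMetaTensor, and translate fluid attributes
// into phi attributes (Scalar, IntArray, plain types).
//
// A minimal sketch of a caller (illustrative only; in the real code the
// wrapper is generated by a macro in infershape_utils.h, and the op name
// and InferMeta function here are just examples):
//
//   void AbsOpInferShape(InferShapeContext* ctx) {
//     auto meta_ctx = BuildInferMetaContext(ctx, "abs");
//     phi::UnchangedInferMeta(meta_ctx.InputAt(0),
//                             meta_ctx.MutableOutputAt(0));
//   }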
CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
                                             const std::string& op_type) {
  // 1. get kernel args
  auto* arg_map_fn = ctx->GetPhiArgumentMappingFn();
  InferShapeArgumentMappingContext arg_map_context(*ctx);
  phi::KernelSignature signature = arg_map_fn
                                       ? (*arg_map_fn)(arg_map_context)
                                       : *ctx->GetPhiDefaultKernelSignature();
  VLOG(3) << "BuildInferMetaContext: op kernel signature - " << signature;

  // 2. build infermeta context
  CompatInferMetaContext infer_meta_context(
      {ctx->IsRuntime(), ctx->IsRunMKLDNNKernel()});

  const auto& input_names = signature.input_names;
  const auto& attr_names = signature.attr_names;
  const auto& output_names = signature.output_names;

  const auto& args_def =
      phi::KernelFactory::Instance().GetFirstKernelArgsDef(signature.name);
  const auto& attr_defs = args_def.attribute_defs();

  for (auto& in_name : input_names) {
    if (ctx->HasInputs(in_name)) {
      auto input_var = std::move(ctx->GetInputVarPtrs(in_name));
      if (input_var.size() == 1) {
        infer_meta_context.EmplaceBackInput(
            std::move(CompatMetaTensor(input_var[0], ctx->IsRuntime())));
      } else {
        paddle::small_vector<CompatMetaTensor, phi::kInputSmallVectorSize>
            inputs;
        for (const auto& in : input_var) {
          inputs.emplace_back(
              std::move(CompatMetaTensor(in, ctx->IsRuntime())));
        }
        infer_meta_context.EmplaceBackInputs(std::move(inputs));
      }
    } else {
      // Note: Because InferMetaFn takes its inputs as const MetaTensor&,
      // InferMetaContext->InputAt() must still be able to return a const
      // reference for a missing input, so an empty MetaTensor is emplaced
      // here as a placeholder.
      infer_meta_context.EmplaceBackInput(
          std::move(CompatMetaTensor(ctx->IsRuntime())));
    }
  }

  VLOG(6) << "BuildInferMetaContext: Done inputs";

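  // Translate attributes: a Scalar or IntArray attribute may be carried
  // either by a plain framework attribute or by an auxiliary input tensor;
  // the tensor-backed form is detected via HasAttrVar / ctx->HasInput below.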
  auto attr_reader = ctx->Attrs();
  for (size_t i = 0; i < attr_names.size(); ++i) {
    auto& attr_name = attr_names[i];
    auto* attr_ptr = attr_reader.GetAttr(attr_name);
    bool is_attr_var = attr_ptr != nullptr && HasAttrVar(*attr_ptr);
    VLOG(6) << "BuildInferMetaContext: " << attr_name << ": "
            << attr_defs[i].type_index << ", is_attr_var: " << is_attr_var;
    switch (attr_defs[i].type_index) {
      case phi::AttributeType::SCALAR:
        if (attr_ptr && !is_attr_var) {
          auto& attr = *attr_ptr;
          switch (AttrTypeID(attr)) {
            case framework::proto::AttrType::FLOAT:
              infer_meta_context.EmplaceBackAttr(
                  phi::Scalar(PADDLE_GET_CONST(float, attr)));
              break;
            case framework::proto::AttrType::FLOAT64:
              infer_meta_context.EmplaceBackAttr(
                  phi::Scalar(PADDLE_GET_CONST(double, attr)));
              break;
            case framework::proto::AttrType::INT:
              infer_meta_context.EmplaceBackAttr(
                  phi::Scalar(PADDLE_GET_CONST(int, attr)));
              break;
            case framework::proto::AttrType::LONG:
              infer_meta_context.EmplaceBackAttr(
                  phi::Scalar(PADDLE_GET_CONST(int64_t, attr)));
              break;
            case framework::proto::AttrType::STRING:
              infer_meta_context.EmplaceBackAttr(
                  phi::Scalar(PADDLE_GET_CONST(std::string, attr)));
              break;
            case framework::proto::AttrType::BOOLEAN:
              infer_meta_context.EmplaceBackAttr(
                  phi::Scalar(PADDLE_GET_CONST(bool, attr)));
              break;
            default:
              PADDLE_THROW(platform::errors::Unimplemented(
                  "Unsupported cast of op attribute `%s` to Scalar when "
                  "constructing InferMetaContext.",
                  attr_name));
          }
        } else if (ctx->HasInput(attr_name)) {
          auto infershape_input = std::move(ctx->GetInputVarPtrs(attr_name));
          if (infershape_input.size() == 1) {
            if (ctx->IsRuntime()) {
              Variable* var = PADDLE_GET_CONST(Variable*, infershape_input[0]);
              infer_meta_context.EmplaceBackAttr(
                  std::move(experimental::MakePhiScalarFromVar(*var)));
            } else {
              phi::Scalar tensor_scalar(-1);
              tensor_scalar.SetFromTensor(true);
              infer_meta_context.EmplaceBackAttr(std::move(tensor_scalar));
            }
          } else {
            PADDLE_THROW(platform::errors::InvalidArgument(
                "Invalid input.size() when casting op attribute `%s` to "
                "Scalar: expected 1, but actually got %d.",
                attr_name,
                infershape_input.size()));
          }
        } else {
          // do nothing, skip current attr
        }
        break;
      case phi::AttributeType::INT_ARRAY:
        // When the attr is a vector_tensor or tensor, transform it to IntArray
        if (attr_ptr && !is_attr_var) {
          auto& attr = *attr_ptr;
          switch (AttrTypeID(attr)) {
            case framework::proto::AttrType::INTS:
              infer_meta_context.EmplaceBackAttr(std::move(
                  phi::IntArray(PADDLE_GET_CONST(std::vector<int32_t>, attr))));
              break;
            case framework::proto::AttrType::LONGS:
              infer_meta_context.EmplaceBackAttr(std::move(
                  phi::IntArray(PADDLE_GET_CONST(std::vector<int64_t>, attr))));
              break;
            case framework::proto::AttrType::INT:
              infer_meta_context.EmplaceBackAttr(
                  phi::IntArray({PADDLE_GET_CONST(int, attr)}));
              break;
            default:
              PADDLE_THROW(platform::errors::Unimplemented(
                  "Unsupported cast of op attribute `%s` to IntArray when "
                  "constructing InferMetaContext.",
                  attr_name));
          }
        } else if (ctx->HasInputs(attr_name) || ctx->HasInput(attr_name)) {
          auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name));
          if (ctx->IsRuntime()) {
            // At runtime, read the tensors' values to build the IntArray
            // and push it into attrs
            std::vector<Variable*> vars;
            vars.reserve(infershape_inputs.size());
            for (size_t i = 0; i < infershape_inputs.size(); i++) {
              vars.push_back(PADDLE_GET_CONST(Variable*, infershape_inputs[i]));
            }
            if (infershape_inputs.size() != 1) {
              infer_meta_context.EmplaceBackAttr(
                  std::move(experimental::MakePhiIntArrayFromVarList(vars)));
            } else {
              infer_meta_context.EmplaceBackAttr(
                  std::move(experimental::MakePhiIntArrayFromVar(*vars[0])));
            }
          } else {
            // At compile time, fill the IntArray with the default value (-1)
            std::vector<VarDesc*> vars;
            vars.reserve(infershape_inputs.size());
            for (size_t i = 0; i < infershape_inputs.size(); ++i) {
              vars.push_back(PADDLE_GET_CONST(VarDesc*, infershape_inputs[i]));
            }

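            // Length heuristic for the placeholder IntArray: a single shape
            // tensor contributes one entry per element (falling back to its
            // rank when the static element count is not positive); a list of
            // scalar tensors contributes one entry per tensor.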
            int64_t num_ele = 0;
            if (vars.size() == 1) {
              num_ele = 1;
              const auto& tensor_dims = vars[0]->GetShape();
              for (size_t i = 0; i < tensor_dims.size(); ++i) {
                num_ele *= tensor_dims[i];
              }

              if (num_ele <= 0) {
                num_ele = tensor_dims.size();
              }
            } else {
              num_ele = vars.size();
            }
            phi::IntArray tensor_attr(std::vector<int32_t>(num_ele, -1));
            tensor_attr.SetFromTensor(true);
            infer_meta_context.EmplaceBackAttr(std::move(tensor_attr));
          }
        } else {
          // do nothing, skip current attr
        }
        break;
      case phi::AttributeType::SCALARS:
        if (attr_ptr) {
          auto& attr = *attr_ptr;
          switch (AttrTypeID(attr)) {
            case framework::proto::AttrType::INTS: {
              const auto& vec = PADDLE_GET_CONST(std::vector<int32_t>, attr);
              std::vector<phi::Scalar> scalar_list;
              scalar_list.reserve(vec.size());
              for (const auto& val : vec) {
                scalar_list.emplace_back(val);
              }
              infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
            } break;
            case framework::proto::AttrType::LONGS: {
              const auto& vec = PADDLE_GET_CONST(std::vector<int64_t>, attr);
              std::vector<phi::Scalar> scalar_list;
              scalar_list.reserve(vec.size());
              for (const auto& val : vec) {
                scalar_list.emplace_back(val);
              }
              infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
            } break;
            case framework::proto::AttrType::FLOATS: {
              const auto& vec = PADDLE_GET_CONST(std::vector<float>, attr);
              std::vector<phi::Scalar> scalar_list;
              scalar_list.reserve(vec.size());
              for (const auto& val : vec) {
                scalar_list.emplace_back(val);
              }
              infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
            } break;
            case framework::proto::AttrType::FLOAT64S: {
              const auto& vec = PADDLE_GET_CONST(std::vector<double>, attr);
              std::vector<phi::Scalar> scalar_list;
              scalar_list.reserve(vec.size());
              for (const auto& val : vec) {
                scalar_list.emplace_back(val);
              }
              infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
            } break;
            default:
              PADDLE_THROW(platform::errors::Unimplemented(
                  "Unsupported cast of op attribute `%s` to vector<Scalar> "
                  "when constructing InferMetaContext.",
                  attr_names[i]));
          }
        } else {
          // do nothing, skip current attr
        }
        break;
      default:
        if (attr_ptr) {
          auto& attr = *attr_ptr;
          switch (attr_defs[i].type_index) {
            case phi::AttributeType::FLOAT32:
              infer_meta_context.EmplaceBackAttr(PADDLE_GET_CONST(float, attr));
              break;
            case phi::AttributeType::FLOAT64:
              infer_meta_context.EmplaceBackAttr(
                  PADDLE_GET_CONST(double, attr));
              break;
            case phi::AttributeType::INT32:
              infer_meta_context.EmplaceBackAttr(PADDLE_GET_CONST(int, attr));
              break;
            case phi::AttributeType::BOOL:
              infer_meta_context.EmplaceBackAttr(PADDLE_GET_CONST(bool, attr));
              break;
            case phi::AttributeType::INT64:
              infer_meta_context.EmplaceBackAttr(
                  PADDLE_GET_CONST(int64_t, attr));
              break;
            case phi::AttributeType::INT32S:
              infer_meta_context.EmplaceBackAttr(
                  PADDLE_GET_CONST(std::vector<int>, attr));
              break;
            case phi::AttributeType::DATA_TYPE: {
              auto data_type = paddle::framework::TransToPhiDataType(
                  static_cast<framework::proto::VarType::Type>(
                      PADDLE_GET_CONST(int, attr)));
              infer_meta_context.EmplaceBackAttr(data_type);
            } break;
            case phi::AttributeType::STRING:
              infer_meta_context.EmplaceBackAttr(
                  PADDLE_GET_CONST(std::string, attr));
              break;
            case phi::AttributeType::INT64S:
              switch (AttrTypeID(attr)) {
                case framework::proto::AttrType::LONGS:
                  infer_meta_context.EmplaceBackAttr(
                      PADDLE_GET_CONST(std::vector<int64_t>, attr));
                  break;
                case framework::proto::AttrType::INTS: {
                  const auto& vector_int_attr =
                      PADDLE_GET_CONST(std::vector<int>, attr);
                  const std::vector<int64_t> vector_int64_attr(
                      vector_int_attr.begin(), vector_int_attr.end());
                  infer_meta_context.EmplaceBackAttr(vector_int64_attr);
                } break;
                default:
                  PADDLE_THROW(platform::errors::Unimplemented(
                      "Unsupported cast of op attribute `%s` to "
                      "vector<int64_t> when constructing InferMetaContext.",
                      attr_names[i]));
              }
              break;
            case phi::AttributeType::FLOAT32S:
              infer_meta_context.EmplaceBackAttr(
                  PADDLE_GET_CONST(std::vector<float>, attr));
              break;
            case phi::AttributeType::STRINGS:
              infer_meta_context.EmplaceBackAttr(
                  PADDLE_GET_CONST(std::vector<std::string>, attr));
              break;
            case phi::AttributeType::BOOLS:
              infer_meta_context.EmplaceBackAttr(
                  PADDLE_GET_CONST(std::vector<bool>, attr));
              break;
            case phi::AttributeType::FLOAT64S:
              infer_meta_context.EmplaceBackAttr(
                  PADDLE_GET_CONST(std::vector<double>, attr));
              break;
            default:
              PADDLE_THROW(platform::errors::Unimplemented(
                  "Unsupported cast of op attribute `%s` when constructing "
                  "InferMetaContext.",
                  attr_names[i]));
          }
        } else {
          // do nothing, skip current attr
        }
    }
  }

  VLOG(6) << "BuildInferMetaContext: Done attrs";

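  // Wrap outputs. Optional or not-yet-created outputs are emplaced as
  // uninitialized CompatMetaTensors so that output positions stay aligned
  // with the kernel signature.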
  for (auto& out_name : output_names) {
    if (ctx->HasOutputs(out_name, true)) {
      auto output_var = std::move(ctx->GetOutputVarPtrs(out_name));
      if (output_var.size() == 1) {
        infer_meta_context.EmplaceBackOutput(
            std::move(CompatMetaTensor(output_var[0], ctx->IsRuntime())));
      } else {
        paddle::small_vector<CompatMetaTensor, phi::kOutputSmallVectorSize>
            outputs;
        for (const auto& out : output_var) {
          if (ctx->IsRuntime()) {
            if (PADDLE_GET_CONST(Variable*, out)) {
              outputs.emplace_back(
                  std::move(CompatMetaTensor(out, ctx->IsRuntime())));
              continue;
            }
          } else if (PADDLE_GET_CONST(VarDesc*, out)) {
            outputs.emplace_back(
                std::move(CompatMetaTensor(out, ctx->IsRuntime())));
            continue;
          }
          outputs.emplace_back(std::move(CompatMetaTensor(ctx->IsRuntime())));
        }
        infer_meta_context.EmplaceBackOutputs(std::move(outputs));
      }
    } else {
      infer_meta_context.EmplaceBackOutput(
          std::move(CompatMetaTensor(ctx->IsRuntime())));
    }
  }

  VLOG(6) << "BuildInferMetaContext: Done outputs";

  return infer_meta_context;
}

}  // namespace framework
}  // namespace paddle