/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/infershape_utils.h"

#include <algorithm>
#include <string>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/compat/arg_map_context.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/core/tensor_utils.h"

#include "glog/logging.h"

namespace paddle {
namespace framework {

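// Adapter that exposes the fluid InferShapeContext through the
// phi::ArgumentMappingContext interface, so phi argument mapping functions
// can query inputs, outputs, and attributes during InferShape.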
class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
 public:
  explicit InferShapeArgumentMappingContext(const InferShapeContext& ctx)
      : ctx_(ctx) {}

  bool HasInput(const std::string& name) const override {
    return ctx_.HasInput(name);
  }

  bool HasOutput(const std::string& name) const override {
    return ctx_.HasOutput(name);
  }

  bool HasAttr(const std::string& name) const override {
    return ctx_.HasAttr(name);
  }

  paddle::any Attr(const std::string& name) const override {
    auto* attr = ctx_.Attrs().GetAttr(name);
    PADDLE_ENFORCE_NOT_NULL(
        attr,
        platform::errors::NotFound("Attribute (%s) should be in AttributeMap.",
                                   name));
    return GetAttrValue(*attr);
  }

  size_t InputSize(const std::string& name) const override {
    if (ctx_.HasInputs(name)) {
      return ctx_.Inputs(name).size();
    } else if (ctx_.HasInput(name)) {
      return 1;
    }
    return 0;
  }

  size_t OutputSize(const std::string& name) const override {
    return ctx_.Outputs(name).size();
  }

  bool IsDenseTensorInput(const std::string& name) const override {
    auto var_type = ctx_.GetInputVarType(name);
    return var_type == proto::VarType::LOD_TENSOR;
  }

  bool IsDenseTensorInputs(const std::string& name) const override {
    auto var_types = ctx_.GetInputsVarType(name);
    return std::all_of(var_types.begin(),
                       var_types.end(),
                       [](const proto::VarType::Type& type) {
                         return type == proto::VarType::LOD_TENSOR;
                       });
  }

  bool IsSelectedRowsInputs(const std::string& name) const override {
    auto var_types = ctx_.GetInputsVarType(name);
    return std::all_of(var_types.begin(),
                       var_types.end(),
                       [](const proto::VarType::Type& type) {
                         return type == proto::VarType::SELECTED_ROWS;
                       });
  }

  bool IsSelectedRowsInput(const std::string& name) const override {
    auto var_type = ctx_.GetInputVarType(name);
    return var_type == proto::VarType::SELECTED_ROWS;
  }

  bool IsDenseTensorVectorInput(const std::string& name) const override {
    auto var_types = ctx_.GetInputsVarType(name);
    return std::all_of(var_types.begin(),
                       var_types.end(),
                       [](const proto::VarType::Type& type) {
                         return type == proto::VarType::LOD_TENSOR_ARRAY;
                       });
  }

  bool IsSparseCooTensorInput(const std::string& name) const override {
    auto var_type = ctx_.GetInputVarType(name);
    return var_type == proto::VarType::SPARSE_COO;
  }

  bool IsDenseTensorOutput(const std::string& name) const override {
    auto var_types = ctx_.GetOutputsVarType(name);
    return std::all_of(var_types.begin(),
                       var_types.end(),
                       [](const proto::VarType::Type& type) {
                         return type == proto::VarType::LOD_TENSOR;
                       });
  }

  bool IsSelectedRowsOutput(const std::string& name) const override {
    auto var_types = ctx_.GetOutputsVarType(name);
    return std::all_of(var_types.begin(),
                       var_types.end(),
                       [](const proto::VarType::Type& type) {
                         return type == proto::VarType::SELECTED_ROWS;
                       });
  }

  bool IsForInferShape() const override { return true; }

  bool IsRuntime() const override { return ctx_.IsRuntime(); }

 private:
  const InferShapeContext& ctx_;
};

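// CompatMetaTensor wraps either a runtime Variable* or a compile-time
// VarDesc* (selected by is_runtime_), so the same InferMeta functions can be
// used during both static-graph compilation and execution. ValidCheck guards
// each accessor below against an uninitialized CompatMetaTensor.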
static inline void ValidCheck(const phi::MetaTensor& meta_tensor) {
  PADDLE_ENFORCE_EQ(meta_tensor.initialized(),
                    true,
                    phi::errors::InvalidArgument(
                        "The current CompatMetaTensor is not initialized."));
}

int64_t CompatMetaTensor::numel() const {
  ValidCheck(*this);
  if (is_runtime_) {
    auto* var = PADDLE_GET_CONST(Variable*, var_);
    return var->Get<phi::DenseTensor>().numel();
  } else {
    auto* var = PADDLE_GET_CONST(VarDesc*, var_);
    return var->ElementSize();
  }
}

bool CompatMetaTensor::is_selected_rows() const {
  if (is_runtime_) {
    auto* var = PADDLE_GET_CONST(Variable*, var_);
    return var->IsType<phi::SelectedRows>();
  } else {
    auto* var = PADDLE_GET_CONST(VarDesc*, var_);
    return var->GetType() == proto::VarType::SELECTED_ROWS;
  }
}

bool CompatMetaTensor::is_dense() const {
  if (is_runtime_) {
    auto* var = PADDLE_GET_CONST(Variable*, var_);
    return var->IsType<phi::DenseTensor>();
  } else {
    auto* var = PADDLE_GET_CONST(VarDesc*, var_);
    return var->GetType() == proto::VarType::LOD_TENSOR;
  }
}

bool CompatMetaTensor::is_tensor_array() const {
  if (is_runtime_) {
    auto* var = PADDLE_GET_CONST(Variable*, var_);
    return var->IsType<framework::LoDTensorArray>();
  } else {
    auto* var = PADDLE_GET_CONST(VarDesc*, var_);
    return var->GetType() == proto::VarType::LOD_TENSOR_ARRAY;
  }
}

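// dims() dispatches on the underlying holder type; for a LoDTensorArray the
// array length is reported as a 1-D dim, and at compile time the shape is
// read from the VarDesc.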
DDim CompatMetaTensor::dims() const {
  ValidCheck(*this);
  if (is_runtime_) {
    auto* var = PADDLE_GET_CONST(Variable*, var_);
    if (var->IsType<phi::DenseTensor>()) {
      return var->Get<phi::DenseTensor>().dims();
    } else if (var->IsType<phi::SelectedRows>()) {
      return var->Get<phi::SelectedRows>().GetCompleteDims();
    } else if (var->IsType<phi::SparseCooTensor>()) {
      return var->Get<phi::SparseCooTensor>().dims();
    } else if (var->IsType<framework::LoDTensorArray>()) {
      // use tensor array size as dims
      auto& tensor_array = var->Get<framework::LoDTensorArray>();
      return phi::make_ddim({static_cast<int64_t>(tensor_array.size())});
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Currently, dims can only be obtained from DenseTensor, "
          "SelectedRows, SparseCooTensor or DenseTensorArray."));
    }
  } else {
    auto* var = PADDLE_GET_CONST(VarDesc*, var_);

    return phi::make_ddim(var->GetShape());
    // return var->GetShape().empty() ? phi::make_ddim({0UL}) :
    // phi::make_ddim(var->GetShape());
  }
}

phi::DataType CompatMetaTensor::dtype() const {
  ValidCheck(*this);
  if (is_runtime_) {
    auto* var = PADDLE_GET_CONST(Variable*, var_);
    if (var->IsType<phi::DenseTensor>()) {
      return var->Get<phi::DenseTensor>().dtype();
    } else if (var->IsType<phi::SelectedRows>()) {
      return var->Get<phi::SelectedRows>().dtype();
    } else if (var->IsType<phi::SparseCooTensor>()) {
      return var->Get<phi::SparseCooTensor>().dtype();
    } else if (var->IsType<framework::LoDTensorArray>()) {
      // NOTE(chenweihang): do nothing
      // Unsupported get dtype from LoDTensorArray now
      return phi::DataType::UNDEFINED;
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Currently, dtype can only be obtained from DenseTensor, "
          "SelectedRows or SparseCooTensor."));
    }
  } else {
    auto* var = PADDLE_GET_CONST(VarDesc*, var_);
    return paddle::framework::TransToPhiDataType(var->GetDataType());
  }
}

DataLayout CompatMetaTensor::layout() const {
  ValidCheck(*this);
  if (is_runtime_) {
    auto* var = PADDLE_GET_CONST(Variable*, var_);
    if (var->IsType<phi::DenseTensor>()) {
      return var->Get<phi::DenseTensor>().layout();
    } else if (var->IsType<phi::SelectedRows>()) {
      return var->Get<phi::SelectedRows>().layout();
    } else if (var->IsType<phi::SparseCooTensor>()) {
      return var->Get<phi::SparseCooTensor>().layout();
    } else if (var->IsType<framework::LoDTensorArray>()) {
      // NOTE(chenweihang): do nothing
      // Unsupported get layout from LoDTensorArray now
      return phi::DataLayout::UNDEFINED;
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Currently, layout can only be obtained from DenseTensor, "
          "SelectedRows or SparseCooTensor."));
    }
  } else {
    // NOTE(chenweihang): do nothing
    // Unsupported get layout for VarDesc now
    return DataLayout::UNDEFINED;
  }
}

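// The setters below mirror the getters: at runtime they mutate the tensor
// meta in place (via DenseTensorUtils::GetMutableMeta), while at compile time
// they write back into the VarDesc.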
void CompatMetaTensor::set_dims(const DDim& dims) {
  ValidCheck(*this);
  if (is_runtime_) {
    auto* var = PADDLE_GET(Variable*, var_);
    if (var == nullptr) return;
    if (var->IsType<phi::DenseTensor>()) {
      auto* tensor = var->GetMutable<phi::DenseTensor>();
      phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
    } else if (var->IsType<phi::SelectedRows>()) {
      var->GetMutable<phi::SelectedRows>()->set_height(dims[0]);
    } else if (var->IsType<phi::SparseCooTensor>()) {
      auto* tensor = var->GetMutable<phi::SparseCooTensor>();
      phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
    } else if (var->IsType<framework::LoDTensorArray>()) {
      auto* tensor_array = var->GetMutable<framework::LoDTensorArray>();
      // Note: Ideally we would enforce `tensor_array->size() == 0UL` here,
      // because in-place use of a LoDTensorArray is dangerous, but the
      // unittest `test_list` relies on this behavior
      PADDLE_ENFORCE_EQ(dims.size(),
                        1UL,
                        platform::errors::InvalidArgument(
                            "LoDTensorArray can only have one dimension."));
      // only set the array size for LoDTensorArray input
      tensor_array->resize(dims[0]);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Currently, dims can only be set on DenseTensor, SelectedRows, "
          "SparseCooTensor or LoDTensorArray."));
    }
  } else {
    auto* var = PADDLE_GET(VarDesc*, var_);
    if (var) {
      var->SetShape(vectorize(dims));
    }
  }
}

void CompatMetaTensor::set_dtype(phi::DataType dtype) {
  ValidCheck(*this);
  if (is_runtime_) {
    auto* var = PADDLE_GET(Variable*, var_);
    if (var == nullptr) return;
    if (var->IsType<phi::DenseTensor>()) {
      auto* tensor = var->GetMutable<phi::DenseTensor>();
      phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype;
    } else if (var->IsType<phi::SelectedRows>()) {
      auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
      phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype;
    } else if (var->IsType<phi::SparseCooTensor>()) {
      auto* tensor = var->GetMutable<phi::SparseCooTensor>();
      phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype;
    } else if (var->IsType<framework::LoDTensorArray>()) {
      // NOTE(chenweihang): do nothing
      // Unsupported set dtype for LoDTensorArray now
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Currently, dtype can only be set on DenseTensor, SelectedRows or "
          "SparseCooTensor."));
    }
  } else {
    auto* var = PADDLE_GET(VarDesc*, var_);
    if (var) {
      var->SetDataType(paddle::framework::TransToProtoVarType(dtype));
    }
  }
}

void CompatMetaTensor::set_layout(DataLayout layout) {
  ValidCheck(*this);
  if (is_runtime_) {
    auto* var = PADDLE_GET(Variable*, var_);
    if (var == nullptr) return;
    if (var->IsType<phi::DenseTensor>()) {
      auto* tensor = var->GetMutable<phi::DenseTensor>();
      phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
    } else if (var->IsType<phi::SelectedRows>()) {
      auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
      phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
    } else if (var->IsType<phi::SparseCooTensor>()) {
      auto* tensor = var->GetMutable<phi::SparseCooTensor>();
      phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
    } else if (var->IsType<framework::LoDTensorArray>()) {
      // NOTE(chenweihang): do nothing
      // Unsupported: cannot set layout for LoDTensorArray now
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Currently, layout can only be set on DenseTensor, SelectedRows or "
          "SparseCooTensor."));
    }
  } else {
    // NOTE(chenweihang): do nothing
    // Unsupported set layout for VarDesc now
  }
}

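// The share_* methods copy metadata from another CompatMetaTensor: share_lod
// propagates LoD (meaningful only for dense tensors), and share_dims
// additionally propagates rows/height for SelectedRows.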
void CompatMetaTensor::share_lod(const MetaTensor& meta_tensor) {
  ValidCheck(*this);
  ValidCheck(meta_tensor);
  if (is_runtime_) {
    auto* var = PADDLE_GET(Variable*, var_);
    if (var == nullptr) return;
    if (var->IsType<phi::DenseTensor>() && meta_tensor.is_dense()) {
      auto* tensor = var->GetMutable<phi::DenseTensor>();
      phi::DenseTensorUtils::GetMutableMeta(tensor)->lod =
          static_cast<const CompatMetaTensor&>(meta_tensor).GetRuntimeLoD();
    } else {
      // NOTE(chenweihang): do nothing
      // only LoDTensor need to share lod
    }
  } else {
    auto* var = PADDLE_GET(VarDesc*, var_);
    if (!meta_tensor.is_dense() && !meta_tensor.is_tensor_array()) {
      VLOG(3) << "input metatensor is not LoDTensor or LoDTensorArray.";
      return;
    }
    if (var) {
      var->SetLoDLevel(static_cast<const CompatMetaTensor&>(meta_tensor)
                           .GetCompileTimeLoD());
    }
  }
}

void CompatMetaTensor::share_dims(const MetaTensor& meta_tensor) {
  ValidCheck(*this);
  ValidCheck(meta_tensor);
  set_dims(meta_tensor.dims());
  if (is_runtime_) {
    auto* var = PADDLE_GET(Variable*, var_);
    if (var == nullptr) return;
    if (var->IsType<phi::SelectedRows>()) {
      auto* selected_rows = var->GetMutable<phi::SelectedRows>();
      auto& input_selected_rows =
          static_cast<const CompatMetaTensor&>(meta_tensor).GetSelectedRows();
      selected_rows->set_rows(input_selected_rows.rows());
      selected_rows->set_height(input_selected_rows.height());
    }
  }
}

void CompatMetaTensor::share_meta(const MetaTensor& meta_tensor) {
  share_dims(meta_tensor);
  set_dtype(meta_tensor.dtype());
  set_layout(meta_tensor.layout());
  // special case: share lod of LoDTensor
  share_lod(meta_tensor);
}

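// CompatInferMetaContext stores CompatMetaTensors directly and records, for
// each argument, the [start, end) index range it occupies, so duplicable
// (vector) inputs and outputs can be addressed as contiguous slices.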
void CompatInferMetaContext::EmplaceBackInput(CompatMetaTensor input) {
  int index = compat_inputs_.size();
  compat_inputs_.emplace_back(std::move(input));
  input_range_.emplace_back(std::pair<int, int>(index, index + 1));
}
void CompatInferMetaContext::EmplaceBackOutput(CompatMetaTensor output) {
  int index = compat_outputs_.size();
  compat_outputs_.emplace_back(std::move(output));
  output_range_.emplace_back(std::pair<int, int>(index, index + 1));
}

void CompatInferMetaContext::EmplaceBackInputs(
    paddle::small_vector<CompatMetaTensor, phi::kInputSmallVectorSize> inputs) {
  int index = compat_inputs_.size();
  input_range_.emplace_back(std::pair<int, int>(index, index + inputs.size()));
  compat_inputs_.insert(compat_inputs_.end(),
                        std::make_move_iterator(inputs.begin()),
                        std::make_move_iterator(inputs.end()));
}

void CompatInferMetaContext::EmplaceBackOutputs(
    paddle::small_vector<CompatMetaTensor, phi::kOutputSmallVectorSize>
        outputs) {
  int index = compat_outputs_.size();
  output_range_.emplace_back(
      std::pair<int, int>(index, index + outputs.size()));
  compat_outputs_.insert(compat_outputs_.end(),
                         std::make_move_iterator(outputs.begin()),
                         std::make_move_iterator(outputs.end()));
}

const phi::MetaTensor& CompatInferMetaContext::InputAt(size_t idx) const {
  return compat_inputs_.at(idx);
}

std::vector<const phi::MetaTensor*> CompatInferMetaContext::InputsBetween(
    size_t start, size_t end) const {
  std::vector<const phi::MetaTensor*> result;
  result.reserve(end - start);

  for (size_t i = start; i < end; ++i) {
    auto& in = compat_inputs_.at(i);
    result.emplace_back(in.initialized() ? &in : nullptr);
  }

  return result;
}

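// An optional input is treated as absent when its first element is
// uninitialized; paddle::none is returned in that case.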
paddle::optional<std::vector<const phi::MetaTensor*>>
CompatInferMetaContext::OptionalInputsBetween(size_t start, size_t end) const {
  const auto& first = compat_inputs_.at(start);

  if (first.initialized()) {
    std::vector<const phi::MetaTensor*> result;
    result.reserve(end - start);

    for (size_t i = start; i < end; ++i) {
      auto& in = compat_inputs_.at(i);
      result.emplace_back(in.initialized() ? &in : nullptr);
    }

    return paddle::optional<std::vector<const phi::MetaTensor*>>(
        std::move(result));
  }
  return paddle::none;
}

phi::MetaTensor* CompatInferMetaContext::MutableOutputAt(size_t idx) {
  auto& out = compat_outputs_.at(idx);
  return out.initialized() ? &out : nullptr;
}

std::vector<phi::MetaTensor*> CompatInferMetaContext::MutableOutputBetween(
    size_t start, size_t end) {
  std::vector<phi::MetaTensor*> result;
  result.reserve(end - start);
  for (size_t i = start; i < end; ++i) {
    auto& out = compat_outputs_.at(i);
    result.emplace_back(out.initialized() ? &out : nullptr);
  }
  return result;
}

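// BuildInferMetaContext assembles a CompatInferMetaContext for `op_type` in
// four steps: resolve the op's phi kernel signature (via the argument mapping
// function, or the default signature), wrap the inputs as CompatMetaTensors,
// translate the attributes into phi attributes (Scalar/IntArray attrs may be
// tensor-backed), and wrap the outputs.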
CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
                                             const std::string& op_type) {
  // 1. get kernel args
  auto* arg_map_fn = ctx->GetPhiArgumentMappingFn();
  InferShapeArgumentMappingContext arg_map_context(*ctx);
  phi::KernelSignature signature = arg_map_fn
                                       ? (*arg_map_fn)(arg_map_context)
                                       : *ctx->GetPhiDefaultKernelSignature();
  VLOG(3) << "BuildInferMetaContext: op kernel signature - " << signature;

  // 2. build infermeta context
  CompatInferMetaContext infer_meta_context(
      {ctx->IsRuntime(), ctx->IsRunMKLDNNKernel()});

  const auto& input_names = signature.input_names;
  const auto& attr_names = signature.attr_names;
  const auto& output_names = signature.output_names;

  const auto& args_def =
      phi::KernelFactory::Instance().GetFirstKernelArgsDef(signature.name);
  const auto& attr_defs = args_def.attribute_defs();

  for (auto& in_name : input_names) {
    if (ctx->HasInputs(in_name)) {
      auto input_var = std::move(ctx->GetInputVarPtrs(in_name));
      if (input_var.size() == 1) {
        infer_meta_context.EmplaceBackInput(
            std::move(CompatMetaTensor(input_var[0], ctx->IsRuntime())));
      } else {
        paddle::small_vector<CompatMetaTensor, phi::kInputSmallVectorSize>
            inputs;
        for (const auto& in : input_var) {
          inputs.emplace_back(
              std::move(CompatMetaTensor(in, ctx->IsRuntime())));
        }
        infer_meta_context.EmplaceBackInputs(std::move(inputs));
      }
    } else {
      // Note: Because InferMetaFn takes each input as a `const MetaTensor&`,
      // InferMetaContext->InputAt() must be able to return a const reference
      // even for absent inputs, so we emplace an uninitialized MetaTensor.
      infer_meta_context.EmplaceBackInput(
          std::move(CompatMetaTensor(ctx->IsRuntime())));
    }
  }

  VLOG(6) << "BuildInferMetaContext: Done inputs";

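  // Translate each attribute named in the kernel signature. Scalar and
  // IntArray attributes may come either from plain framework attrs or from
  // (lists of) tensors; at compile time, tensor-backed values are replaced by
  // placeholders marked with SetFromTensor(true).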
  auto attr_reader = ctx->Attrs();
  for (size_t i = 0; i < attr_names.size(); ++i) {
    auto& attr_name = attr_names[i];
    auto* attr_ptr = attr_reader.GetAttr(attr_name);
    bool is_attr_var = attr_ptr != nullptr && HasAttrVar(*attr_ptr);
    VLOG(6) << "BuildInferMetaContext: " << attr_name << ": "
            << attr_defs[i].type_index << ", is_attr_var: " << is_attr_var;
    switch (attr_defs[i].type_index) {
      case phi::AttributeType::SCALAR:
        if (attr_ptr && !is_attr_var) {
          auto& attr = *attr_ptr;
          switch (AttrTypeID(attr)) {
            case framework::proto::AttrType::FLOAT:
              infer_meta_context.EmplaceBackAttr(
                  phi::Scalar(PADDLE_GET_CONST(float, attr)));
              break;
            case framework::proto::AttrType::FLOAT64:
              infer_meta_context.EmplaceBackAttr(
                  phi::Scalar(PADDLE_GET_CONST(double, attr)));
              break;
            case framework::proto::AttrType::INT:
              infer_meta_context.EmplaceBackAttr(
                  phi::Scalar(PADDLE_GET_CONST(int, attr)));
              break;
            case framework::proto::AttrType::LONG:
              infer_meta_context.EmplaceBackAttr(
                  phi::Scalar(PADDLE_GET_CONST(int64_t, attr)));
              break;
            case framework::proto::AttrType::STRING:
              infer_meta_context.EmplaceBackAttr(
                  phi::Scalar(PADDLE_GET_CONST(std::string, attr)));
              break;
            case framework::proto::AttrType::BOOLEAN:
              infer_meta_context.EmplaceBackAttr(
                  phi::Scalar(PADDLE_GET_CONST(bool, attr)));
              break;
            default:
              PADDLE_THROW(platform::errors::Unimplemented(
                  "Unsupported cast of op attribute `%s` to Scalar when "
                  "constructing InferMetaContext.",
                  attr_name));
          }
        } else if (ctx->HasInput(attr_name)) {
          auto infershape_input = std::move(ctx->GetInputVarPtrs(attr_name));
          if (infershape_input.size() == 1) {
            if (ctx->IsRuntime()) {
              Variable* var = PADDLE_GET_CONST(Variable*, infershape_input[0]);
              infer_meta_context.EmplaceBackAttr(
                  std::move(experimental::MakePhiScalarFromVar(*var)));
            } else {
              phi::Scalar tensor_scalar(-1);
              tensor_scalar.SetFromTensor(true);
              infer_meta_context.EmplaceBackAttr(std::move(tensor_scalar));
            }
          } else {
            PADDLE_THROW(platform::errors::InvalidArgument(
                "Invalid input size when casting op attribute `%s` to Scalar; "
                "expected 1, but got %d.",
                attr_name,
                infershape_input.size()));
          }
        } else {
          // do nothing, skip current attr
        }
        break;
      case phi::AttributeType::INT_ARRAY:
        // When the attr is a tensor or a vector of tensors, transform it to
        // IntArray
        if (attr_ptr && !is_attr_var) {
          auto& attr = *attr_ptr;
          switch (AttrTypeID(attr)) {
            case framework::proto::AttrType::INTS:
              infer_meta_context.EmplaceBackAttr(std::move(
                  phi::IntArray(PADDLE_GET_CONST(std::vector<int32_t>, attr))));
              break;
            case framework::proto::AttrType::LONGS:
              infer_meta_context.EmplaceBackAttr(std::move(
                  phi::IntArray(PADDLE_GET_CONST(std::vector<int64_t>, attr))));
              break;
            case framework::proto::AttrType::INT:
              infer_meta_context.EmplaceBackAttr(
                  phi::IntArray({PADDLE_GET_CONST(int, attr)}));
              break;
            default:
              PADDLE_THROW(platform::errors::Unimplemented(
                  "Unsupported cast of op attribute `%s` to IntArray when "
                  "constructing InferMetaContext.",
                  attr_name));
          }
        } else if (ctx->HasInputs(attr_name) || ctx->HasInput(attr_name)) {
          auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name));
          if (ctx->IsRuntime()) {
            // In runtime, read the tensors' values to build the IntArray
            // and push it into attrs
            std::vector<Variable*> vars;
            vars.reserve(infershape_inputs.size());
            for (size_t i = 0; i < infershape_inputs.size(); i++) {
              vars.push_back(PADDLE_GET_CONST(Variable*, infershape_inputs[i]));
            }
            if (infershape_inputs.size() != 1) {
              infer_meta_context.EmplaceBackAttr(
                  std::move(experimental::MakePhiIntArrayFromVarList(vars)));
            } else {
              infer_meta_context.EmplaceBackAttr(
                  std::move(experimental::MakePhiIntArrayFromVar(*vars[0])));
652
            }
653
          } else {
654 655 656 657
            // If is not in runtime, we will set default value(-1) for IntArray
            std::vector<VarDesc*> vars;
            vars.reserve(infershape_inputs.size());
            for (size_t i = 0; i < infershape_inputs.size(); ++i) {
              vars.push_back(PADDLE_GET_CONST(VarDesc*, infershape_inputs[i]));
            }

            int64_t num_ele = 0;
            if (vars.size() == 1) {
              num_ele = 1;
              const auto& tensor_dims = vars[0]->GetShape();
              for (size_t i = 0; i < tensor_dims.size(); ++i) {
                num_ele *= tensor_dims[i];
              }

              if (num_ele <= 0) {
                num_ele = tensor_dims.size();
              }

            } else {
              num_ele = vars.size();
            }
            phi::IntArray tensor_attr(std::vector<int32_t>(num_ele, -1));
            tensor_attr.SetFromTensor(true);
            infer_meta_context.EmplaceBackAttr(std::move(tensor_attr));
          }
        } else {
          // do nothing, skip current attr
        }
        break;
      case phi::AttributeType::SCALARS:
        if (attr_ptr) {
          auto& attr = *attr_ptr;
          switch (AttrTypeID(attr)) {
            case framework::proto::AttrType::INTS: {
              const auto& vec = PADDLE_GET_CONST(std::vector<int32_t>, attr);
              std::vector<phi::Scalar> scalar_list;
              scalar_list.reserve(vec.size());
              for (const auto& val : vec) {
                scalar_list.emplace_back(val);
              }
              infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
            } break;
            case framework::proto::AttrType::LONGS: {
              const auto& vec = PADDLE_GET_CONST(std::vector<int64_t>, attr);
              std::vector<phi::Scalar> scalar_list;
              scalar_list.reserve(vec.size());
              for (const auto& val : vec) {
                scalar_list.emplace_back(val);
              }
              infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
            } break;
            case framework::proto::AttrType::FLOATS: {
              const auto& vec = PADDLE_GET_CONST(std::vector<float>, attr);
              std::vector<phi::Scalar> scalar_list;
              scalar_list.reserve(vec.size());
              for (const auto& val : vec) {
                scalar_list.emplace_back(val);
              }
              infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
            } break;
            case framework::proto::AttrType::FLOAT64S: {
              const auto& vec = PADDLE_GET_CONST(std::vector<double>, attr);
              std::vector<phi::Scalar> scalar_list;
              scalar_list.reserve(vec.size());
              for (const auto& val : vec) {
                scalar_list.emplace_back(val);
              }
              infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
            } break;
            default:
              PADDLE_THROW(platform::errors::Unimplemented(
                  "Unsupported cast of op attribute `%s` to vector<Scalar> "
                  "when constructing InferMetaContext.",
                  attr_names[i]));
          }
        } else {
          // do nothing, skip current attr
        }
        break;
      default:
        if (attr_ptr) {
          auto& attr = *attr_ptr;
          switch (attr_defs[i].type_index) {
            case phi::AttributeType::FLOAT32:
              infer_meta_context.EmplaceBackAttr(PADDLE_GET_CONST(float, attr));
              break;
            case phi::AttributeType::FLOAT64:
              infer_meta_context.EmplaceBackAttr(
                  PADDLE_GET_CONST(double, attr));
              break;
            case phi::AttributeType::INT32:
              infer_meta_context.EmplaceBackAttr(PADDLE_GET_CONST(int, attr));
              break;
            case phi::AttributeType::BOOL:
              infer_meta_context.EmplaceBackAttr(PADDLE_GET_CONST(bool, attr));
              break;
            case phi::AttributeType::INT64:
              infer_meta_context.EmplaceBackAttr(
                  PADDLE_GET_CONST(int64_t, attr));
              break;
            case phi::AttributeType::INT32S:
              infer_meta_context.EmplaceBackAttr(
                  PADDLE_GET_CONST(std::vector<int>, attr));
              break;
            case phi::AttributeType::DATA_TYPE: {
              auto data_type = paddle::framework::TransToPhiDataType(
                  static_cast<framework::proto::VarType::Type>(
                      PADDLE_GET_CONST(int, attr)));
              infer_meta_context.EmplaceBackAttr(data_type);
            } break;
            case phi::AttributeType::STRING:
              infer_meta_context.EmplaceBackAttr(
                  PADDLE_GET_CONST(std::string, attr));
              break;
            case phi::AttributeType::INT64S:
              switch (AttrTypeID(attr)) {
                case framework::proto::AttrType::LONGS:
                  infer_meta_context.EmplaceBackAttr(
                      PADDLE_GET_CONST(std::vector<int64_t>, attr));
                  break;
                case framework::proto::AttrType::INTS: {
                  const auto& vector_int_attr =
                      PADDLE_GET_CONST(std::vector<int>, attr);
                  const std::vector<int64_t> vector_int64_attr(
                      vector_int_attr.begin(), vector_int_attr.end());
                  infer_meta_context.EmplaceBackAttr(vector_int64_attr);
                } break;
                default:
                  PADDLE_THROW(platform::errors::Unimplemented(
                      "Unsupported cast of op attribute `%s` to "
                      "vector<int64_t> when constructing InferMetaContext.",
                      attr_names[i]));
              }
              break;
            case phi::AttributeType::FLOAT32S:
              infer_meta_context.EmplaceBackAttr(
                  PADDLE_GET_CONST(std::vector<float>, attr));
              break;
            case phi::AttributeType::STRINGS:
              infer_meta_context.EmplaceBackAttr(
                  PADDLE_GET_CONST(std::vector<std::string>, attr));
              break;
            case phi::AttributeType::BOOLS:
              infer_meta_context.EmplaceBackAttr(
                  PADDLE_GET_CONST(std::vector<bool>, attr));
              break;
            case phi::AttributeType::FLOAT64S:
              infer_meta_context.EmplaceBackAttr(
                  PADDLE_GET_CONST(std::vector<double>, attr));
              break;
            default:
              PADDLE_THROW(platform::errors::Unimplemented(
                  "Unsupported cast of op attribute `%s` when constructing "
                  "InferMetaContext.",
                  attr_names[i]));
          }
        } else {
          // do nothing, skip current attr
        }
    }
  }

  VLOG(6) << "BuildInferMetaContext: Done attrs";

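  // Wrap outputs; absent outputs still get an uninitialized placeholder
  // CompatMetaTensor so that output index ranges stay aligned with the
  // kernel signature.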
  for (auto& out_name : output_names) {
    if (ctx->HasOutputs(out_name, true)) {
      auto output_var = std::move(ctx->GetOutputVarPtrs(out_name));
      if (output_var.size() == 1) {
        infer_meta_context.EmplaceBackOutput(
            std::move(CompatMetaTensor(output_var[0], ctx->IsRuntime())));
      } else {
        paddle::small_vector<CompatMetaTensor, phi::kOutputSmallVectorSize>
            outputs;
        for (const auto& out : output_var) {
          if (ctx->IsRuntime()) {
            if (PADDLE_GET_CONST(Variable*, out)) {
              outputs.emplace_back(
                  std::move(CompatMetaTensor(out, ctx->IsRuntime())));
              continue;
            }
          } else if (PADDLE_GET_CONST(VarDesc*, out)) {
            outputs.emplace_back(
                std::move(CompatMetaTensor(out, ctx->IsRuntime())));
            continue;
          }
          outputs.emplace_back(std::move(CompatMetaTensor(ctx->IsRuntime())));
        }
        infer_meta_context.EmplaceBackOutputs(std::move(outputs));
      }
    } else {
      infer_meta_context.EmplaceBackOutput(
          std::move(CompatMetaTensor(ctx->IsRuntime())));
    }
  }

  VLOG(6) << "BuildInferMetaContext: Done outputs";

  return infer_meta_context;
}

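// A minimal usage sketch (hypothetical op name; the real call sites are wired
// up by the framework's InferShape machinery, not shown here): an InferShape
// implementation typically builds the compat context and then invokes the
// op's registered phi InferMeta function with it, e.g.
//
//   void SketchInferShape(framework::InferShapeContext* ctx) {
//     auto infer_meta_context = BuildInferMetaContext(ctx, "my_op");
//     // ...then call the registered InferMetaFn with &infer_meta_context.
//   }
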
}  // namespace framework
}  // namespace paddle