// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/imperative/execution_context.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/var_helper.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/selected_rows.h"

DECLARE_bool(use_mkldnn);

namespace paddle {
namespace imperative {

const framework::Tensor* GetTensorFromVar(const framework::Variable& var);

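// Records the forward dtype on a variable's grad variable (when one exists),
// specialized below for each variable abstraction; the eager variant only
// logs and is effectively a no-op.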
template <typename VarType>
static void SetForwardDataTypeOfGradVar(const std::shared_ptr<VarType>& var);

template <>
void SetForwardDataTypeOfGradVar<VariableWrapper>(
    const std::shared_ptr<VariableWrapper>& var) {
  if (var->HasGradVar()) {
    auto grad_var = var->GetGradVar();
    VLOG(6) << "Set grad var (" << grad_var->Name() << ")'s forward dtype to ("
            << framework::DataTypeToString(var->DataType()) << ").";
    grad_var->SetForwardDataType(var->DataType());
  }
}

template <>
void SetForwardDataTypeOfGradVar<VarBase>(const std::shared_ptr<VarBase>& var) {
  if (var->HasGradVar()) {
    auto& shared_var = var->SharedVar();
    SetForwardDataTypeOfGradVar<VariableWrapper>(shared_var);
  }
}

template <>
void SetForwardDataTypeOfGradVar<egr::EagerVariable>(
    const std::shared_ptr<egr::EagerVariable>& var) {
  VLOG(10) << "Var in Eager does not support SetForwardDataTypeOfGradVar: "
           << var->name();
  // TODO(jiabin): SetForwardDataType of grad var is not supported yet in
  // EagerMode.
}

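// Transforms each initialized input tensor in `ins` (layout / dtype / place)
// to match `expected_kernel_key`, reusing cached transformed variables when
// possible. Returns a copy of `ins` holding the transformed variables, or
// nullptr when no out-of-place transform was needed.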
template <typename VarType>
std::shared_ptr<NameVarMap<VarType>> PrepareData(
    const framework::OperatorWithKernel& op,
    const NameVarMap<VarType>& ins,
    const framework::OpKernelType& expected_kernel_key) {
  std::shared_ptr<NameVarMap<VarType>> tmp_ins_ptr = nullptr;
  for (const auto& name_pair : ins) {
    for (size_t i = 0; i < name_pair.second.size(); ++i) {
      auto& template_var = name_pair.second[i];
      SetForwardDataTypeOfGradVar(template_var);
      const auto* tensor = GetTensorFromVar(template_var->Var());
      if (tensor && tensor->IsInitialized()) {
        auto kernel_type_for_var = op.GetKernelTypeForVar(
            name_pair.first, *tensor, expected_kernel_key);
        if (!NeedTransform(kernel_type_for_var, expected_kernel_key)) {
          continue;
        } else {
          VLOG(3) << "Transform Variable " << GetNameFromVar(template_var)
                  << " from " << kernel_type_for_var << " to "
                  << expected_kernel_key;

          if (CheckCachedKey(template_var, expected_kernel_key)) {
            VLOG(3) << "Hit variable_wrapper cache: key="
                    << expected_kernel_key;
            std::shared_ptr<VariableWrapper> cache_var =
                GetCachedValue(template_var, expected_kernel_key);
            if (tmp_ins_ptr == nullptr) {
              tmp_ins_ptr = std::make_shared<NameVarMap<VarType>>(ins);
            }

            const auto* tensor = GetTensorFromVar(cache_var->Var());
            auto tmp_var =
                std::make_shared<VarType>(GetNameFromVar(template_var));
            SetType(tmp_var, GetType(template_var));
            SetTensorToVariable(
                cache_var->Var(), *tensor, tmp_var->MutableVar());
            (*tmp_ins_ptr)[name_pair.first][i] = tmp_var;
          } else {
            framework::Tensor out;
            TransformData(
                expected_kernel_key, kernel_type_for_var, *tensor, &out);
            if (NeedTransformDataType(kernel_type_for_var,
                                      expected_kernel_key)) {
              // To avoid NameVarMap copy construction overhead in general
              // scenarios, the original input is returned directly when the
              // transform was done inplace.
              if (tmp_ins_ptr == nullptr) {
                tmp_ins_ptr = std::make_shared<NameVarMap<VarType>>(ins);
              }
              auto tmp_var =
                  std::make_shared<VarType>(GetNameFromVar(template_var));
              SetType(tmp_var, GetType(template_var));
              SetTensorToVariable(
                  template_var->Var(), out, tmp_var->MutableVar());
              (*tmp_ins_ptr)[name_pair.first][i] = tmp_var;
              SetCachedValue(template_var, expected_kernel_key, tmp_var);
              VLOG(3) << "Set cache to variable_wrapper: key="
                      << expected_kernel_key;
            } else {
              // If the dtype is unchanged, an inplace transform does not
              // alter the original value, so transform inplace to avoid an
              // extra copy.
              SetTensorToVariable(
                  template_var->Var(), out, template_var->MutableVar());
            }
          }
        }
      }
    }
  }
  return tmp_ins_ptr;
}

// PreparedOp pairs an operator with the kernel chosen for it (either a
// legacy fluid OpKernelFunc or a phi kernel) and the device context it will
// run on. Prepare() performs the kernel selection; Run() executes the
// selected kernel on the given inputs and outputs.
class PreparedOp {
 public:
  PreparedOp(const framework::OperatorBase& op,
             const framework::RuntimeContext& ctx,
             const framework::OpKernelType& kernel_type,
             const framework::OperatorWithKernel::OpKernelFunc& func,
             const phi::ArgumentMappingFn* arg_map_fn,
             const phi::KernelSignature* default_kernel_signature,
             platform::DeviceContext* dev_ctx);

  PreparedOp(const framework::OperatorBase& op,
             const framework::RuntimeContext& ctx,
             const framework::OpKernelType& kernel_type,
             const phi::ArgumentMappingFn* arg_map_fn,
             const phi::KernelSignature* default_kernel_signature,
             phi::KernelSignature&& kernel_signature,
             const phi::Kernel& phi_kernel,
             platform::DeviceContext* dev_ctx);

  static PreparedOp Prepare(const NameVarMap<VarBase>& ins,
                            const NameVarMap<VarBase>& outs,
                            const framework::OperatorWithKernel& op,
                            const platform::Place& place,
                            const framework::AttributeMap& attrs,
                            const framework::AttributeMap& default_attrs);

  static PreparedOp Prepare(const NameVarMap<VariableWrapper>& ins,
                            const NameVarMap<VariableWrapper>& outs,
                            const framework::OperatorWithKernel& op,
                            const platform::Place& place,
                            const framework::AttributeMap& attrs,
                            const framework::AttributeMap& default_attrs);

  static PreparedOp Prepare(const NameVarMap<egr::EagerVariable>& ins,
                            const NameVarMap<egr::EagerVariable>& outs,
                            const framework::OperatorWithKernel& op,
                            const platform::Place& place,
                            const framework::AttributeMap& attrs,
                            const framework::AttributeMap& default_attrs);

  void Run(const NameVarMap<VarBase>& in,
           const NameVarMap<VarBase>& out,
           const framework::AttributeMap& attrs,
           const framework::AttributeMap& default_attrs);

  void Run(const NameVarMap<VariableWrapper>& ins,
           const NameVarMap<VariableWrapper>& outs,
           const framework::AttributeMap& attrs,
           const framework::AttributeMap& default_attrs);

  void Run(const NameVarMap<egr::EagerVariable>& ins,
           const NameVarMap<egr::EagerVariable>& outs,
           const framework::AttributeMap& attrs,
           const framework::AttributeMap& default_attrs);

  const framework::OpKernelType& kernel_type() const { return kernel_type_; }

 private:
  const framework::OperatorBase& op_;
  const framework::RuntimeContext& ctx_;
  framework::OpKernelType kernel_type_;
  framework::OperatorWithKernel::OpKernelFunc func_;
  platform::DeviceContext* dev_ctx_;
  // NOTE(chenweihang): The following members are used to adapt to the new
  // phi kernel; if a better design emerges in the future, we may polish
  // the implementation here.
  bool run_phi_kernel_{false};
  bool run_kp_kernel_{false};
  const phi::ArgumentMappingFn* arg_map_fn_;
  const phi::KernelSignature* default_kernel_signature_;
  phi::KernelSignature kernel_signature_;
  const phi::Kernel& phi_kernel_;

  static const phi::KernelFactory& phi_kernel_factory;
  static const phi::OpUtilsMap& phi_op_utils_map;
  static const phi::DefaultKernelSignatureMap& default_phi_kernel_sig_map;
};
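
// A minimal usage sketch, based on the signatures declared above (the
// authoritative call sequence lives in the dygraph tracer, not in this
// header):
//
//   auto prepared = PreparedOp::Prepare(ins, outs, op, place, attrs,
//                                       default_attrs);
//   auto tmp_ins = PrepareData<VarBase>(op, ins, prepared.kernel_type());
//   prepared.Run(tmp_ins ? *tmp_ins : ins, outs, attrs, default_attrs);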

// Looks `name` up in `attrs` first and falls back to `default_attrs`;
// returns nullptr when the attribute is present in neither map.
const inline framework::Attribute* GetAttr(
    const framework::AttributeMap& attrs,
    const framework::AttributeMap& default_attrs,
    const std::string& name) {
  auto it = attrs.find(name);
  bool found = it != attrs.end();
  if (!found) {
    it = default_attrs.find(name);
    found = it != default_attrs.end();
  }
  if (found) {
    return &it->second;
  }
  return nullptr;
}

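// Builds the phi::KernelContext for a dygraph op: inputs, outputs, and
// attributes are matched positionally against the kernel's argument
// definitions. Scalar and IntArray attributes may also be read from input
// variables when they are absent from both attribute maps.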
template <typename VarType>
void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
                                  const phi::Kernel& phi_kernel,
                                  const NameVarMap<VarType>& ins,
                                  const NameVarMap<VarType>& outs,
                                  const framework::AttributeMap& attrs,
                                  const framework::AttributeMap& default_attrs,
                                  platform::DeviceContext* dev_ctx,
                                  phi::KernelContext* kernel_ctx) {
  kernel_ctx->SetDeviceContext(dev_ctx);

  const auto& input_names = kernel_signature.input_names;
  const auto& attr_names = kernel_signature.attr_names;
  const auto& output_names = kernel_signature.output_names;

  auto& input_defs = phi_kernel.args_def().input_defs();
  auto& output_defs = phi_kernel.args_def().output_defs();
  auto& attr_defs = phi_kernel.args_def().attribute_defs();

  PADDLE_ENFORCE_EQ(input_names.size(),
                    input_defs.size(),
                    platform::errors::InvalidArgument(
                        "the size of inputs_args names (%d) must be equal to "
                        "the size of kernel input_defs (%d).",
                        input_names.size(),
                        input_defs.size()));

  PADDLE_ENFORCE_EQ(output_names.size(),
                    output_defs.size(),
                    platform::errors::InvalidArgument(
                        "the size of outputs_args names (%d) must be equal to "
                        "the size of kernel output_defs (%d).",
                        output_names.size(),
                        output_defs.size()));

  PADDLE_ENFORCE_EQ(attr_names.size(),
                    attr_defs.size(),
                    platform::errors::InvalidArgument(
                        "the size of attribute_args names (%d) must be equal "
                        "to the size of kernel attribute_defs (%d).",
                        attr_names.size(),
                        attr_defs.size()));

  for (size_t i = 0; i < input_names.size(); ++i) {
    auto it = ins.find(input_names[i]);

    size_t start_idx = (i == 0 ? 0 : kernel_ctx->InputRangeAt(i - 1).second);

    if (it == ins.end()) {
      if (LIKELY(input_defs[i].type_index ==
                 std::type_index(typeid(paddle::optional<phi::DenseTensor>)))) {
        kernel_ctx->EmplaceBackInputWithoutSetRange(nullptr);
        auto end_idx = start_idx + 1;
        kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
        continue;
      } else if (input_defs[i].type_index ==
                 std::type_index(typeid(
                     paddle::optional<std::vector<const phi::DenseTensor*>>))) {
        kernel_ctx->EmplaceBackInputWithoutSetRange(nullptr);
        auto end_idx = start_idx + 1;
        kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
        continue;
      } else {
        PADDLE_THROW(phi::errors::NotFound(
            "Can not find input variable '%s' for %s OP, please check whether "
            "the name setting in OpArgumentMapping is consistent with that in "
            "OpMaker.",
            input_names[i],
            kernel_signature.name));
      }
    }

    auto& ins_vector = it->second;
    size_t end_idx = start_idx + ins_vector.size();

    for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
      const phi::TensorBase* tensor_in = nullptr;
      auto& var = ins_vector[offset]->Var();
      if (var.template IsType<phi::DenseTensor>()) {
        tensor_in = &(var.template Get<phi::DenseTensor>());
        kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in);
      } else if (var.template IsType<phi::SelectedRows>()) {
        tensor_in = &(var.template Get<phi::SelectedRows>());
        kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in);
      } else if (var.template IsType<framework::LoDTensorArray>()) {
        paddle::small_vector<const phi::TensorBase*> tensor_vector;
        auto& tensor_array = var.template Get<framework::LoDTensorArray>();
        for (auto& t : tensor_array) {
          tensor_vector.emplace_back(&t);
        }
        kernel_ctx->EmplaceBackInputsWithoutSetRange(tensor_vector);
        end_idx += tensor_array.size() - 1;
      } else {
        PADDLE_THROW(platform::errors::Unimplemented(
            "Unsupported input `%s` type when calling the phi kernel.",
            framework::ToTypeName(var.Type())));
      }
    }
    kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
  }
  VLOG(6) << "BuildDygraphPhiKernelContext: Inputs parsing completed.";

  for (size_t i = 0; i < output_names.size(); ++i) {
    size_t start_idx = (i == 0 ? 0 : kernel_ctx->OutputRangeAt(i - 1).second);

    auto iter = outs.find(output_names[i]);
    if (iter == outs.end()) {
      kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr});
      kernel_ctx->AssignOutputRange(std::make_pair(start_idx, start_idx + 1),
                                    i);
      continue;
    }

    auto& outs_vector = iter->second;
    size_t end_idx = start_idx + outs_vector.size();

    for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
      if (outs_vector[offset] == nullptr) {
        kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr});
        continue;
      }

      phi::TensorBase* tensor_out = nullptr;
      auto* var = outs_vector[offset]->MutableVar();
      if (var) {
        if (var->template IsType<phi::DenseTensor>()) {
          tensor_out = var->template GetMutable<phi::DenseTensor>();
          kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out);
        } else if (var->template IsType<phi::SelectedRows>()) {
          tensor_out = var->template GetMutable<phi::SelectedRows>();
          kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out);
        } else if (var->template IsType<framework::LoDTensorArray>()) {
          paddle::small_vector<phi::TensorBase*> tensor_vector;
          auto* tensor_array =
              var->template GetMutable<framework::LoDTensorArray>();
          for (auto& t : *tensor_array) {
            tensor_vector.emplace_back(&t);
          }
          kernel_ctx->EmplaceBackOutputsWithoutSetRange(tensor_vector);
          end_idx += tensor_array->size() - 1;
        } else {
          PADDLE_THROW(platform::errors::Unimplemented(
              "Unsupported output `%s` type when calling the phi kernel.",
              framework::ToTypeName(var->Type())));
        }
      } else {
        kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out);
      }
    }
    kernel_ctx->AssignOutputRange(std::make_pair(start_idx, end_idx), i);
  }
  VLOG(6) << "BuildDygraphPhiKernelContext: Outputs parsing completed.";

  for (size_t i = 0; i < attr_names.size(); ++i) {
    VLOG(6) << "BuildDygraphPhiKernelContext: " << attr_names[i] << ": "
            << attr_defs[i].type_index;
    auto* attr_ptr = GetAttr(attrs, default_attrs, attr_names[i]);
    switch (attr_defs[i].type_index) {
      case phi::AttributeType::SCALAR:
        if (attr_ptr) {
          // scalar is in the attribute
          auto& attr = *attr_ptr;
          switch (AttrTypeID(attr)) {
            case framework::proto::AttrType::FLOAT:
              kernel_ctx->EmplaceBackAttr(
                  std::move(phi::Scalar(BOOST_GET_CONST(float, attr))));
              break;
            case framework::proto::AttrType::INT:
              kernel_ctx->EmplaceBackAttr(
                  std::move(phi::Scalar(BOOST_GET_CONST(int, attr))));
              break;
            case framework::proto::AttrType::STRING:
              kernel_ctx->EmplaceBackAttr(
                  std::move(phi::Scalar(BOOST_GET_CONST(std::string, attr))));
              break;
            default:
              PADDLE_THROW(platform::errors::Unimplemented(
                  "Unsupported cast op attribute `%s` to Scalar when construct "
                  "KernelContext in dygraph.",
                  attr_names[i]));
          }
        } else {  // scalar is in the input
          auto& ins_vector = ins.at(attr_names[i]);
          kernel_ctx->EmplaceBackAttr(std::move(
              experimental::MakePhiScalarFromVar(ins_vector[0]->Var())));
        }
        break;
      case phi::AttributeType::INT_ARRAY:
        if (attr_ptr) {
          auto& attr = *attr_ptr;
          switch (AttrTypeID(attr)) {
            case framework::proto::AttrType::INTS:
              kernel_ctx->EmplaceBackAttr(std::move(
                  phi::IntArray(BOOST_GET_CONST(std::vector<int32_t>, attr))));
              break;
            case framework::proto::AttrType::LONGS:
              kernel_ctx->EmplaceBackAttr(std::move(
                  phi::IntArray(BOOST_GET_CONST(std::vector<int64_t>, attr))));
              break;
            case framework::proto::AttrType::INT:
              kernel_ctx->EmplaceBackAttr(
                  std::move(phi::IntArray(&BOOST_GET_CONST(int32_t, attr), 1)));
              break;
            case framework::proto::AttrType::LONG:
              kernel_ctx->EmplaceBackAttr(
                  std::move(phi::IntArray(&BOOST_GET_CONST(int64_t, attr), 1)));
              break;
            default:
              PADDLE_THROW(platform::errors::Unimplemented(
                  "Unsupported cast op attribute `%s` to IntArray when "
                  "construct KernelContext.",
                  attr_names[i]));
          }
        } else {  // shape is in the input
          auto& ins_vector = ins.at(attr_names[i]);
          if (ins_vector.size() == 1) {  // ShapeTensor
            kernel_ctx->EmplaceBackAttr(std::move(
                experimental::MakePhiIntArrayFromVar(ins_vector[0]->Var())));
          } else {  // ShapeTensorList
            std::vector<framework::Variable*> variables;
            variables.reserve(ins_vector.size());
            for (const auto& var_base : ins_vector) {
              variables.push_back(var_base->MutableVar());
            }
            kernel_ctx->EmplaceBackAttr(
                std::move(experimental::MakePhiIntArrayFromVarList(variables)));
          }
        }
        break;
      case phi::AttributeType::SCALARS: {
        PADDLE_ENFORCE_NOT_NULL(
            attr_ptr,
            platform::errors::NotFound("(%s) is not found in AttributeMap when "
                                       "buildind dygraph KernelContext.",
                                       attr_names[i]));
        auto& attr = *attr_ptr;
        switch (AttrTypeID(attr)) {
          case framework::proto::AttrType::INTS: {
            const auto& vec = BOOST_GET_CONST(std::vector<int32_t>, attr);
            std::vector<phi::Scalar> scalar_list;
            scalar_list.reserve(vec.size());
            for (const auto& val : vec) {
              scalar_list.emplace_back(val);
            }
            kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
          } break;
          case framework::proto::AttrType::LONGS: {
            const auto& vec = BOOST_GET_CONST(std::vector<int64_t>, attr);
            std::vector<phi::Scalar> scalar_list;
            scalar_list.reserve(vec.size());
            for (const auto& val : vec) {
              scalar_list.emplace_back(val);
            }
            kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
          } break;
          case framework::proto::AttrType::FLOATS: {
            const auto& vec = BOOST_GET_CONST(std::vector<float>, attr);
            std::vector<phi::Scalar> scalar_list;
            scalar_list.reserve(vec.size());
            for (const auto& val : vec) {
              scalar_list.emplace_back(val);
            }
            kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
          } break;
          case framework::proto::AttrType::FLOAT64S: {
            const auto& vec = BOOST_GET_CONST(std::vector<double>, attr);
            std::vector<phi::Scalar> scalar_list;
            scalar_list.reserve(vec.size());
            for (const auto& val : vec) {
              scalar_list.emplace_back(val);
            }
            kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
          } break;
          case framework::proto::AttrType::BOOLEANS: {
            const auto& vec = BOOST_GET_CONST(std::vector<bool>, attr);
            std::vector<phi::Scalar> scalar_list;
            scalar_list.reserve(vec.size());
            for (const auto& val : vec) {
              scalar_list.emplace_back(val);
            }
            kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
          } break;
          default:
            PADDLE_THROW(platform::errors::Unimplemented(
                "Unsupported cast op attribute `%s` to vector<Scalar> when "
                "construct KernelContext.",
                attr_names[i]));
        }
      } break;
      default: {
        PADDLE_ENFORCE_NOT_NULL(
            attr_ptr,
            platform::errors::NotFound("(%s) is not found in AttributeMap when "
                                       "buildind dygraph KernelContext.",
                                       attr_names[i]));
        auto& attr = *attr_ptr;
        switch (attr_defs[i].type_index) {
          case phi::AttributeType::FLOAT32:
            kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
            break;
          case phi::AttributeType::INT32:
            kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int, attr));
            break;
          case phi::AttributeType::BOOL:
            kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
            break;
          case phi::AttributeType::INT64:
            kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr));
            break;
          case phi::AttributeType::INT32S:
            kernel_ctx->EmplaceBackAttr(
                BOOST_GET_CONST(std::vector<int>, attr));
            break;
          case phi::AttributeType::DATA_TYPE: {
            auto data_type = framework::TransToPhiDataType(
                static_cast<framework::proto::VarType::Type>(
                    BOOST_GET_CONST(int, attr)));
            kernel_ctx->EmplaceBackAttr(data_type);
          } break;
          case phi::AttributeType::STRING:
            kernel_ctx->EmplaceBackAttr(
                std::move(BOOST_GET_CONST(std::string, attr)));
            break;
          case phi::AttributeType::INT64S: {
            switch (AttrTypeID(attr)) {
              case framework::proto::AttrType::LONGS:
                kernel_ctx->EmplaceBackAttr(
                    BOOST_GET_CONST(std::vector<int64_t>, attr));
                break;
              case framework::proto::AttrType::INTS: {
                const auto& vector_int_attr =
                    BOOST_GET_CONST(std::vector<int>, attr);
                const std::vector<int64_t> vector_int64_attr(
                    vector_int_attr.begin(), vector_int_attr.end());
                kernel_ctx->EmplaceBackAttr(vector_int64_attr);
              } break;
              default:
                PADDLE_THROW(platform::errors::Unimplemented(
                    "Unsupported cast op attribute `%s` to vector<int64_t> "
                    "when "
                    "construct KernelContext.",
                    attr_names[i]));
            }
          } break;
          case phi::AttributeType::FLOAT32S:
            kernel_ctx->EmplaceBackAttr(
                BOOST_GET_CONST(std::vector<float>, attr));
            break;
          case phi::AttributeType::STRINGS:
            kernel_ctx->EmplaceBackAttr(
                BOOST_GET_CONST(std::vector<std::string>, attr));
            break;
          default:
            PADDLE_THROW(platform::errors::Unimplemented(
                "Unsupported cast op attribute `%s` when construct "
                "KernelContext in dygraph.",
                attr_names[i]));
        }
      }
    }
  }
  VLOG(6) << "BuildDygraphPhiKernelContext: Attributes parsing completed.";
}

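// Copies any initialized input tensor that lives on a different backend than
// the one the phi kernel expects (GPUDNN is treated as compatible with GPU)
// to the expected place, writing the copy back into the variable in place.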
template <typename VarType>
void PreparePhiData(const phi::Kernel& phi_kernel,
                    const phi::KernelSignature& kernel_signature,
                    const NameVarMap<VarType>& ins) {
  const auto& input_names = kernel_signature.input_names;
  auto& input_defs = phi_kernel.args_def().input_defs();

  PADDLE_ENFORCE_EQ(input_names.size(),
                    input_defs.size(),
                    platform::errors::InvalidArgument(
                        "the size of inputs_args names (%d) must be equal to "
                        "the size of kernel input_defs (%d).",
                        input_names.size(),
                        input_defs.size()));

  for (size_t i = 0; i < input_names.size(); ++i) {
    auto& in_def = input_defs.at(i);
    auto iter = ins.find(input_names[i]);
    if (iter == ins.end()) {
      continue;
    }
    auto& ins_vector = iter->second;

    for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
      auto& var = ins_vector[offset];
      const auto* tensor_in = GetTensorFromVar(var->Var());
      if (tensor_in && tensor_in->IsInitialized()) {
        if (in_def.backend == phi::Backend::ALL_BACKEND) {
          continue;
        }
        auto tensor_backend = phi::TransToPhiBackend(tensor_in->place());
        if (in_def.backend == tensor_backend ||
            (in_def.backend == phi::Backend::GPUDNN &&
             tensor_backend == phi::Backend::GPU)) {
          continue;
        }

        auto expected_place = phi::TransToPhiPlace(in_def.backend);

        VLOG(3) << "Phi Transform Variable " << input_names[i] << " from "
                << tensor_in->place() << " to " << expected_place;

        framework::Tensor tmp_tensor;
        framework::TensorCopySync(*tensor_in, expected_place, &tmp_tensor);

        SetTensorToVariable(var->Var(), tmp_tensor, var->MutableVar());
      }
    }
  }
}

}  // namespace imperative
}  // namespace paddle