// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/imperative/execution_context.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/var_helper.h"

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/selected_rows.h"

DECLARE_bool(use_mkldnn);

namespace paddle {
namespace imperative {

const framework::Tensor* GetTensorFromVar(const framework::Variable& var);

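// Propagates a variable's forward dtype to its grad var (if it has one) so
// that the backward pass can select a kernel with a matching data type.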
template <typename VarType>
static void SetForwardDataTypeOfGradVar(const std::shared_ptr<VarType>& var);

template <>
void SetForwardDataTypeOfGradVar<VariableWrapper>(
    const std::shared_ptr<VariableWrapper>& var) {
  if (var->HasGradVar()) {
    auto grad_var = var->GetGradVar();
    VLOG(6) << "Set grad var (" << grad_var->Name() << ")'s forward dtype to ("
            << framework::DataTypeToString(var->DataType()) << ").";
    grad_var->SetForwardDataType(var->DataType());
  }
}

template <>
void SetForwardDataTypeOfGradVar<VarBase>(const std::shared_ptr<VarBase>& var) {
  if (var->HasGradVar()) {
    auto& shared_var = var->SharedVar();
    SetForwardDataTypeOfGradVar<VariableWrapper>(shared_var);
  }
}

template <>
void SetForwardDataTypeOfGradVar<egr::EagerVariable>(
    const std::shared_ptr<egr::EagerVariable>& var) {
  VLOG(10) << "Var in Eager does not support SetForwardDataTypeOfGradVar: "
           << var->name();
  // TODO(jiabin): SetForwardDataType of Grad var is not supported yet in
  // EagerMode.
}

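// Checks each input against `expected_kernel_key` and, when a tensor needs a
// layout/dtype/place transform, returns a copy of `ins` with the transformed
// variables substituted (dtype transforms are cached on the variable, keyed
// by kernel type). Returns nullptr if no input needed a transform.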
template <typename VarType>
std::shared_ptr<NameVarMap<VarType>> PrepareData(
    const framework::OperatorWithKernel& op, const NameVarMap<VarType>& ins,
    const framework::OpKernelType& expected_kernel_key) {
  std::shared_ptr<NameVarMap<VarType>> tmp_ins_ptr = nullptr;
  for (const auto& name_pair : ins) {
    for (size_t i = 0; i < name_pair.second.size(); ++i) {
      auto& template_var = name_pair.second[i];
      SetForwardDataTypeOfGradVar(template_var);
      const auto* tensor = GetTensorFromVar(template_var->Var());
      if (tensor && tensor->IsInitialized()) {
        auto kernel_type_for_var = op.GetKernelTypeForVar(
            name_pair.first, *tensor, expected_kernel_key);
        if (!NeedTransform(kernel_type_for_var, expected_kernel_key)) {
          continue;
        } else {
          VLOG(3) << "Transform Variable " << GetNameFromVar(template_var)
                  << " from " << kernel_type_for_var << " to "
                  << expected_kernel_key;

          if (CheckCachedKey(template_var, expected_kernel_key)) {
            VLOG(3) << "Hit variable_wrapper cache: key="
                    << expected_kernel_key;
            std::shared_ptr<VariableWrapper> cache_var =
                GetCachedValue(template_var, expected_kernel_key);
            if (tmp_ins_ptr == nullptr) {
              tmp_ins_ptr = std::make_shared<NameVarMap<VarType>>(ins);
            }

            const auto* tensor = GetTensorFromVar(cache_var->Var());
            auto tmp_var =
                std::make_shared<VarType>(GetNameFromVar(template_var));
            SetType(tmp_var, GetType(template_var));
            SetTensorToVariable(cache_var->Var(), *tensor,
                                tmp_var->MutableVar());
            (*tmp_ins_ptr)[name_pair.first][i] = tmp_var;
          } else {
            framework::Tensor out;
            TransformData(expected_kernel_key, kernel_type_for_var, *tensor,
                          &out);
            if (NeedTransformDataType(kernel_type_for_var,
                                      expected_kernel_key)) {
              // To avoid NameVarMap copy-construction overhead in the common
              // case, copy `ins` only when the dtype actually changes;
              // otherwise the transform below is done in place and the
              // original inputs are returned directly.
              if (tmp_ins_ptr == nullptr) {
                tmp_ins_ptr = std::make_shared<NameVarMap<VarType>>(ins);
              }
              auto tmp_var =
                  std::make_shared<VarType>(GetNameFromVar(template_var));
              SetType(tmp_var, GetType(template_var));
              SetTensorToVariable(template_var->Var(), out,
                                  tmp_var->MutableVar());
              (*tmp_ins_ptr)[name_pair.first][i] = tmp_var;
              SetCachedValue(template_var, expected_kernel_key, tmp_var);
              VLOG(3) << "Set cache to variable_wrapper: key="
                      << expected_kernel_key;
            } else {
              // If the dtype is unchanged, an in-place transform does not
              // modify the original value, so transform in place to avoid an
              // extra copy.
              SetTensorToVariable(template_var->Var(), out,
                                  template_var->MutableVar());
            }
          }
        }
      }
    }
  }
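  // A nullptr result means no input needed transformation, so the caller can
  // keep using the original `ins`.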
  return tmp_ins_ptr;
}

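// Couples an operator with the kernel chosen for it (either a fluid
// OpKernelFunc or a phi kernel) and the target device context, so dygraph
// can execute the op without repeating kernel selection.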
class PreparedOp {
 public:
  PreparedOp(const framework::OperatorBase& op,
             const framework::RuntimeContext& ctx,
             const framework::OpKernelType& kernel_type,
             const framework::OperatorWithKernel::OpKernelFunc& func,
             platform::DeviceContext* dev_ctx);

  PreparedOp(const framework::OperatorBase& op,
             const framework::RuntimeContext& ctx,
             const framework::OpKernelType& kernel_type,
             framework::KernelSignature&& kernel_signature,
             const phi::Kernel& pt_kernel, platform::DeviceContext* dev_ctx);

  static PreparedOp Prepare(const NameVarMap<VarBase>& ins,
                            const NameVarMap<VarBase>& outs,
                            const framework::OperatorWithKernel& op,
                            const platform::Place& place,
                            const framework::AttributeMap& attrs,
                            const framework::AttributeMap& default_attrs);

  static PreparedOp Prepare(const NameVarMap<VariableWrapper>& ins,
                            const NameVarMap<VariableWrapper>& outs,
                            const framework::OperatorWithKernel& op,
                            const platform::Place& place,
                            const framework::AttributeMap& attrs,
                            const framework::AttributeMap& default_attrs);

  static PreparedOp Prepare(const NameVarMap<egr::EagerVariable>& ins,
                            const NameVarMap<egr::EagerVariable>& outs,
                            const framework::OperatorWithKernel& op,
                            const platform::Place& place,
                            const framework::AttributeMap& attrs,
                            const framework::AttributeMap& default_attrs);

  void Run(const NameVarMap<VarBase>& in, const NameVarMap<VarBase>& out,
           const framework::AttributeMap& attrs,
           const framework::AttributeMap& default_attrs);

  void Run(const NameVarMap<VariableWrapper>& ins,
           const NameVarMap<VariableWrapper>& outs,
           const framework::AttributeMap& attrs,
           const framework::AttributeMap& default_attrs);

  void Run(const NameVarMap<egr::EagerVariable>& ins,
           const NameVarMap<egr::EagerVariable>& outs,
           const framework::AttributeMap& attrs,
           const framework::AttributeMap& default_attrs);

  const framework::OpKernelType& kernel_type() const { return kernel_type_; }

 private:
  const framework::OperatorBase& op_;
  const framework::RuntimeContext& ctx_;
  framework::OpKernelType kernel_type_;
  framework::OperatorWithKernel::OpKernelFunc func_;
  platform::DeviceContext* dev_ctx_;
  // NOTE(chenweihang): the members below exist to adapt to the new phi
  // kernel; if a better design emerges, we may polish the implementation here
  bool run_phi_kernel_{false};
  bool run_kp_kernel_{false};
  framework::KernelSignature pt_kernel_signature_;
  const phi::Kernel& pt_kernel_;
};

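// Typical dygraph flow (a sketch): Prepare() selects the kernel for the given
// inputs and place, then Run() executes it:
//
//   auto prepared = PreparedOp::Prepare(ins, outs, op, place, attrs,
//                                       default_attrs);
//   prepared.Run(ins, outs, attrs, default_attrs);

// Returns the attribute named `name`, looking in `attrs` first and falling
// back to `default_attrs`; throws NotFound if it is in neither map.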
const inline framework::Attribute& GetAttr(
    const framework::AttributeMap& attrs,
    const framework::AttributeMap& default_attrs, const std::string& name) {
  auto it = attrs.find(name);
  bool found = it != attrs.end();
  if (!found) {
    it = default_attrs.find(name);
    found = it != default_attrs.end();
  }
  PADDLE_ENFORCE_EQ(
      found, true,
      platform::errors::NotFound("(%s) is not found in AttributeMap.", name));
  return it->second;
}

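// Fills `kernel_ctx` from the dygraph NameVarMaps: inputs, outputs, and
// attributes are emplaced in the order given by the kernel signature, with
// per-argument index ranges recorded (a LoDTensorArray expands into several
// tensor slots).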
template <typename VarType>
void BuildDygraphPhiKernelContext(
    const framework::KernelSignature& pt_kernel_signature,
    const phi::Kernel& pt_kernel, const NameVarMap<VarType>& ins,
    const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
    const framework::AttributeMap& default_attrs,
    platform::DeviceContext* dev_ctx, phi::KernelContext* kernel_ctx) {
  kernel_ctx->SetDeviceContext(dev_ctx);

  const auto& input_names = pt_kernel_signature.input_names;
  const auto& attr_names = pt_kernel_signature.attr_names;
  const auto& output_names = pt_kernel_signature.output_names;

  auto& input_defs = pt_kernel.args_def().input_defs();
  auto& output_defs = pt_kernel.args_def().output_defs();
  auto& attr_defs = pt_kernel.args_def().attribute_defs();

  PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(),
                    platform::errors::InvalidArgument(
                        "the size of inputs_args names (%d) must be equal to "
                        "the size of kernel input_defs (%d).",
                        input_names.size(), input_defs.size()));

  PADDLE_ENFORCE_EQ(output_names.size(), output_defs.size(),
                    platform::errors::InvalidArgument(
                        "the size of outputs_args names (%d) must be equal to "
                        "the size of kernel output_defs (%d).",
                        output_names.size(), output_defs.size()));

  PADDLE_ENFORCE_EQ(attr_names.size(), attr_defs.size(),
                    platform::errors::InvalidArgument(
                        "the size of attribute_args names (%d) must be equal "
                        "to the size of kernel attribute_defs (%d).",
                        attr_names.size(), attr_defs.size()));

  for (size_t i = 0; i < input_names.size(); ++i) {
    auto it = ins.find(input_names[i]);

    size_t start_idx = (i == 0 ? 0 : kernel_ctx->InputRangeAt(i - 1).second);

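    // A missing input is legal only when the kernel declares the parameter
    // as optional; emplace a nullptr placeholder so the index ranges stay
    // aligned with the kernel signature.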
    if (it == ins.end()) {
      if (LIKELY(input_defs[i].type_index ==
                 std::type_index(
                     typeid(paddle::optional<const phi::DenseTensor&>)))) {
        kernel_ctx->EmplaceBackInputWithoutSetRange(nullptr);
        auto end_idx = start_idx + 1;
        kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
        continue;
      } else if (input_defs[i].type_index ==
                 std::type_index(
                     typeid(paddle::optional<
                            const std::vector<const phi::DenseTensor*>>))) {
        kernel_ctx->EmplaceBackInputWithoutSetRange(nullptr);
        auto end_idx = start_idx + 1;
        kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
        continue;
      } else {
        PADDLE_THROW(phi::errors::NotFound(
            "Can not find input variable '%s' for %s OP, please check whether "
            "the name setting in OpArgumentMapping is consistent with that in "
            "OpMaker.",
            input_names[i], pt_kernel_signature.name));
      }
    }

    auto& ins_vector = it->second;
    size_t end_idx = start_idx + ins_vector.size();

    for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
      const phi::TensorBase* tensor_in = nullptr;
      auto& var = ins_vector[offset]->Var();
      if (var.template IsType<phi::DenseTensor>()) {
        tensor_in = &(var.template Get<phi::DenseTensor>());
        kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in);
      } else if (var.template IsType<phi::SelectedRows>()) {
        tensor_in = &(var.template Get<phi::SelectedRows>());
        kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in);
      } else if (var.template IsType<framework::LoDTensorArray>()) {
        paddle::SmallVector<const phi::TensorBase*> tensor_vector;
        auto& tensor_array = var.template Get<framework::LoDTensorArray>();
        for (auto& t : tensor_array) {
          tensor_vector.emplace_back(&t);
        }
        kernel_ctx->EmplaceBackInputsWithoutSetRange(tensor_vector);
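        // A LoDTensorArray occupies one slot per element, so widen this
        // input's index range accordingly.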
        end_idx += tensor_array.size() - 1;
      } else {
        PADDLE_THROW(platform::errors::Unimplemented(
            "Unsupported input `%s` type when call pt kernel.",
            framework::ToTypeName(var.Type())));
      }
    }
    kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
  }

  for (size_t i = 0; i < output_names.size(); ++i) {
    size_t start_idx = (i == 0 ? 0 : kernel_ctx->OutputRangeAt(i - 1).second);

    auto iter = outs.find(output_names[i]);
    if (iter == outs.end()) {
      kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr});
      kernel_ctx->AssignOutputRange(std::make_pair(start_idx, start_idx + 1),
                                    i);
      continue;
    }

    auto& outs_vector = iter->second;
    size_t end_idx = start_idx + outs_vector.size();

    for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
      if (outs_vector[offset] == nullptr) {
        kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr});
        continue;
      }

      phi::TensorBase* tensor_out = nullptr;
      auto* var = outs_vector[offset]->MutableVar();
      if (var) {
        if (var->template IsType<phi::DenseTensor>()) {
          tensor_out = var->template GetMutable<phi::DenseTensor>();
          kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out);
        } else if (var->template IsType<phi::SelectedRows>()) {
          tensor_out = var->template GetMutable<phi::SelectedRows>();
          kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out);
        } else if (var->template IsType<framework::LoDTensorArray>()) {
          paddle::SmallVector<phi::TensorBase*> tensor_vector;
          auto* tensor_array =
              var->template GetMutable<framework::LoDTensorArray>();
          for (auto& t : *tensor_array) {
            tensor_vector.emplace_back(&t);
          }
          kernel_ctx->EmplaceBackOutputsWithoutSetRange(tensor_vector);
          end_idx += tensor_array->size() - 1;
        } else {
          PADDLE_THROW(platform::errors::Unimplemented(
              "Unsupported output `%s` type when call pt kernel.",
              framework::ToTypeName(var->Type())));
        }
      } else {
        kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out);
      }
    }
    kernel_ctx->AssignOutputRange(std::make_pair(start_idx, end_idx), i);
  }

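  // Attributes are dispatched on the kernel's declared attribute type.
  // IntArray and Scalar attributes may come either from the AttributeMap or
  // from input tensors (e.g. ShapeTensor / ShapeTensorList).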
  for (size_t i = 0; i < attr_names.size(); ++i) {
    if (attr_defs[i].type_index == std::type_index(typeid(phi::IntArray))) {
      if (attrs.find(attr_names[i]) !=
          attrs.end()) {  // shape is in the attribute
        auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
        if (std::type_index(attr.type()) ==
            std::type_index(typeid(std::vector<int64_t>))) {
          kernel_ctx->EmplaceBackAttr(std::move(
              phi::IntArray(BOOST_GET_CONST(std::vector<int64_t>, attr))));
        } else if (std::type_index(attr.type()) ==
                   std::type_index(typeid(std::vector<int32_t>))) {
          kernel_ctx->EmplaceBackAttr(std::move(
              phi::IntArray(BOOST_GET_CONST(std::vector<int32_t>, attr))));
        } else if (std::type_index(attr.type()) ==
                   std::type_index(typeid(int64_t))) {
          kernel_ctx->EmplaceBackAttr(
              std::move(phi::IntArray(&BOOST_GET_CONST(int64_t, attr), 1)));
        } else if (std::type_index(attr.type()) ==
                   std::type_index(typeid(int32_t))) {
          kernel_ctx->EmplaceBackAttr(
              std::move(phi::IntArray(&BOOST_GET_CONST(int32_t, attr), 1)));
        } else if (attr_defs[i].type_index ==
                   std::type_index(typeid(std::vector<int32_t>))) {
          const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
          kernel_ctx->EmplaceBackAttr(vector_int_attr);
        } else {
          PADDLE_THROW(platform::errors::Unimplemented(
              "Unsupported cast op attribute `%s` to VectorTensor when "
              "construct KernelContext.",
              attr_names[i]));
        }
      } else {  // shape is in the input
        auto& ins_vector = ins.at(attr_names[i]);
        if (ins_vector.size() == 1) {  // ShapeTensor
          kernel_ctx->EmplaceBackAttr(std::move(
              experimental::MakePhiIntArrayFromVar(ins_vector[0]->Var())));
        } else {  // ShapeTensorList
          std::vector<framework::Variable*> variables;
          variables.reserve(ins_vector.size());
          for (const auto& var_base : ins_vector) {
            variables.push_back(var_base->MutableVar());
          }
          kernel_ctx->EmplaceBackAttr(
              std::move(experimental::MakePhiIntArrayFromVarList(variables)));
        }
      }
    } else if (attr_defs[i].type_index ==
               std::type_index(typeid(phi::Scalar))) {
      // TODO(chenweihang): support other attrs later
      // TODO(zhangyunfei): Scalar should hold scalar type, and we should check
      // attribute type by attr_defs
      if (attrs.find(attr_names[i]) != attrs.end() ||
          default_attrs.find(attr_names[i]) !=
              default_attrs.end()) {  // scalar is in the attribute
        auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
        if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
          kernel_ctx->EmplaceBackAttr(
              std::move(phi::Scalar(BOOST_GET_CONST(float, attr))));
        } else if (std::type_index(attr.type()) ==
                   std::type_index(typeid(std::string))) {
          kernel_ctx->EmplaceBackAttr(
              std::move(phi::Scalar(BOOST_GET_CONST(std::string, attr))));
        } else if (std::type_index(attr.type()) ==
                   std::type_index(typeid(int))) {
          kernel_ctx->EmplaceBackAttr(
              std::move(phi::Scalar(BOOST_GET_CONST(int, attr))));
        } else {
          PADDLE_THROW(platform::errors::Unimplemented(
              "Unsupported cast op attribute `%s` to Scalar when construct "
              "KernelContext in dygraph.",
              attr_names[i]));
        }
      } else {  // scalar is in the input
        auto& ins_vector = ins.at(attr_names[i]);
        kernel_ctx->EmplaceBackAttr(std::move(
            experimental::MakePhiScalarFromVar(ins_vector[0]->Var())));
      }

    } else if (ins.find(attr_names[i]) != ins.end()) {
      // Handle attributes passed in as input tensors here.
      auto& ins_vector = ins.at(attr_names[i]);
      auto tensor_attr =
          experimental::MakePhiScalarFromVar(ins_vector[0]->Var());
      if (attr_defs[i].type_index == std::type_index(typeid(int))) {
        int val = tensor_attr.template to<int>();
        kernel_ctx->EmplaceBackAttr(val);
      } else {
        PADDLE_THROW(platform::errors::Unimplemented("only support int here"));
      }
    } else if (attr_defs[i].type_index ==
               std::type_index(typeid(std::vector<phi::Scalar>))) {
      auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
      if (std::type_index(attr.type()) ==
          std::type_index(typeid(std::vector<int32_t>))) {
        const auto& vec = BOOST_GET_CONST(std::vector<int32_t>, attr);
        std::vector<phi::Scalar> scalar_list;
        scalar_list.reserve(vec.size());
        for (const auto& val : vec) {
          scalar_list.emplace_back(val);
        }
        kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
      } else if (std::type_index(attr.type()) ==
                 std::type_index(typeid(std::vector<int64_t>))) {
        const auto& vec = BOOST_GET_CONST(std::vector<int64_t>, attr);
        std::vector<phi::Scalar> scalar_list;
        scalar_list.reserve(vec.size());
        for (const auto& val : vec) {
          scalar_list.emplace_back(val);
        }
        kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
      } else if (std::type_index(attr.type()) ==
                 std::type_index(typeid(std::vector<float>))) {
        const auto& vec = BOOST_GET_CONST(std::vector<float>, attr);
        std::vector<phi::Scalar> scalar_list;
        scalar_list.reserve(vec.size());
        for (const auto& val : vec) {
          scalar_list.emplace_back(val);
        }
        kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
      } else if (std::type_index(attr.type()) ==
                 std::type_index(typeid(std::vector<double>))) {
        const auto& vec = BOOST_GET_CONST(std::vector<double>, attr);
        std::vector<phi::Scalar> scalar_list;
        scalar_list.reserve(vec.size());
        for (const auto& val : vec) {
          scalar_list.emplace_back(val);
        }
        kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
      } else if (std::type_index(attr.type()) ==
                 std::type_index(typeid(std::vector<bool>))) {
        const auto& vec = BOOST_GET_CONST(std::vector<bool>, attr);
        std::vector<phi::Scalar> scalar_list;
        scalar_list.reserve(vec.size());
        for (const auto& val : vec) {
          scalar_list.emplace_back(val);
        }
        kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
      } else {
        PADDLE_THROW(platform::errors::Unimplemented(
            "Unsupported cast op attribute `%s` to vector<Scalar> when "
            "construct KernelContext.",
            attr_names[i]));
      }
    } else {
      // TODO(chenweihang): support other attrs later

      auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
      if (attr_defs[i].type_index == std::type_index(typeid(int))) {
        kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int, attr));
      } else if (attr_defs[i].type_index == std::type_index(typeid(float))) {
        kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
      } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
        kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
      } else if (attr_defs[i].type_index == std::type_index(typeid(int64_t))) {
        kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr));
      } else if (attr_defs[i].type_index ==
                 std::type_index(typeid(std::string))) {
        kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::string, attr));
      } else if (attr_defs[i].type_index ==
                 std::type_index(typeid(phi::DataType))) {
        auto data_type = framework::TransToPhiDataType(
            static_cast<framework::proto::VarType::Type>(
                BOOST_GET_CONST(int, attr)));
        kernel_ctx->EmplaceBackAttr(data_type);
      } else if (attr_defs[i].type_index ==
                 std::type_index(typeid(std::vector<int64_t>))) {
        if (std::type_index(attr.type()) ==
            std::type_index(typeid(std::vector<int64_t>))) {
          kernel_ctx->EmplaceBackAttr(
              BOOST_GET_CONST(std::vector<int64_t>, attr));
        } else if (std::type_index(attr.type()) ==
                   std::type_index(typeid(std::vector<int>))) {
          // Emplace the attribute converted to the type the phi kernel
          // argument expects.
          const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
          const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
                                                       vector_int_attr.end());
          kernel_ctx->EmplaceBackAttr(vector_int64_attr);
        }
      } else if (attr_defs[i].type_index ==
                 std::type_index(typeid(std::vector<int>))) {
        kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector<int>, attr));
      } else if (attr_defs[i].type_index ==
                 std::type_index(typeid(std::vector<std::string>))) {
        kernel_ctx->EmplaceBackAttr(
            BOOST_GET_CONST(std::vector<std::string>, attr));
      } else if (attr_defs[i].type_index ==
                 std::type_index(typeid(std::vector<float>))) {
        kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector<float>, attr));
      } else {
        PADDLE_THROW(platform::errors::Unimplemented(
            "Unsupported cast op attribute `%s` when construct "
            "KernelContext in dygraph.",
            attr_names[i]));
      }
    }
  }
}

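// Synchronously copies each initialized input tensor whose backend does not
// match the one the phi kernel expects to the expected place, writing the
// result back into the variable in place.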
template <typename VarType>
void PreparePhiData(const phi::Kernel& pt_kernel,
                    const framework::KernelSignature& pt_kernel_signature,
                    const NameVarMap<VarType>& ins) {
  const auto& input_names = pt_kernel_signature.input_names;
  auto& input_defs = pt_kernel.args_def().input_defs();

  PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(),
                    platform::errors::InvalidArgument(
                        "the size of inputs_args names (%d) must be equal to "
                        "the size of kernel input_defs (%d).",
                        input_names.size(), input_defs.size()));

  for (size_t i = 0; i < input_names.size(); ++i) {
    auto& in_def = input_defs.at(i);
    auto iter = ins.find(input_names[i]);
    if (iter == ins.end()) {
      continue;
    }
    auto& ins_vector = iter->second;

    for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
      auto& var = ins_vector[offset];
      const auto* tensor_in = GetTensorFromVar(var->Var());
      if (tensor_in && tensor_in->IsInitialized()) {
        if (in_def.backend == phi::Backend::ALL_BACKEND) {
          continue;
        }
        auto tensor_backend = phi::TransToPhiBackend(tensor_in->place());
        if (in_def.backend == tensor_backend ||
            (in_def.backend == phi::Backend::GPUDNN &&
             tensor_backend == phi::Backend::GPU)) {
          continue;
        }

605 606
        auto expected_place = phi::TransToPhiPlace(in_def.backend);

607
        VLOG(3) << "Phi Transform Variable " << input_names[i] << " from "
608 609 610 611 612
                << tensor_in->place() << " to " << expected_place;

        framework::Tensor tmp_tensor;
        framework::TensorCopySync(*tensor_in, expected_place, &tmp_tensor);

        SetTensorToVariable(var->Var(), tmp_tensor, var->MutableVar());
      }
    }
  }
}

}  // namespace imperative
}  // namespace paddle