prepared_operator.cc 8.0 KB
Newer Older
J
Jiabin Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/prepared_operator.h"
16

J
Jiabin Yang 已提交
17
#include <sstream>
18

19
#include "paddle/fluid/framework/data_type_transform.h"
20 21
#include "paddle/fluid/imperative/infer_shape_context.h"
#include "paddle/fluid/imperative/infer_var_type_context.h"
J
Jiabin Yang 已提交
22 23 24 25 26 27 28 29 30 31 32 33 34 35

namespace paddle {
namespace imperative {

// Returns a pointer to the tensor held by `var`:
//   - the LoDTensor itself when `var` holds a LoDTensor;
//   - the tensor returned by SelectedRows::value() when `var` holds a
//     SelectedRows;
//   - nullptr for any other held type.
const framework::Tensor* GetTensorFromVar(const framework::Variable& var) {
  if (var.IsType<framework::LoDTensor>()) {
    return &var.Get<framework::LoDTensor>();
  }
  if (var.IsType<framework::SelectedRows>()) {
    const auto& selected_rows = var.Get<framework::SelectedRows>();
    return &selected_rows.value();
  }
  return nullptr;
}

36
template <typename VarType>
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
static void HandleComplexGradToRealGrad(const NameVarMap<VarType>& outs) {
  for (auto& pair : outs) {
    for (auto& var : pair.second) {
      if (var == nullptr) {
        continue;
      }
      if (var->ForwardDataType() ==
          static_cast<framework::proto::VarType::Type>(-1)) {
        VLOG(6) << "Var (" << var->Name()
                << ")'s forward data type is not set.";
        continue;
      }
      if (!framework::IsComplexType(var->DataType()) ||
          framework::IsComplexType(var->ForwardDataType())) {
        continue;
      }
      const auto* tensor = GetTensorFromVar(var->Var());
J
Jiabin Yang 已提交
54
      if (tensor && tensor->IsInitialized()) {
55 56 57 58 59 60 61 62
        VLOG(6) << "Transform " << framework::DataTypeToString(var->DataType())
                << " var `" << var->Name() << "` to "
                << framework::DataTypeToString(var->ForwardDataType())
                << " real var in dynamic graph.";
        framework::Tensor out;
        framework::TransComplexToReal(var->ForwardDataType(), var->DataType(),
                                      *tensor, &out);
        SetTensorToVariable(var->Var(), out, var->MutableVar());
J
Jiabin Yang 已提交
63 64 65 66 67 68 69
      }
    }
  }
}

// Captures everything Run() later hands to PreparedOpRunImpl: the operator,
// its runtime context, the selected kernel type, the kernel function and the
// device context.
PreparedOp::PreparedOp(
    const framework::OperatorBase& op, const framework::RuntimeContext& ctx,
    const framework::OpKernelType& kernel_type,
    const framework::OperatorWithKernel::OpKernelFunc& func,
    platform::DeviceContext* dev_ctx)
    : op_(op),
      ctx_(ctx),
      kernel_type_(kernel_type),
      func_(func),
      dev_ctx_(dev_ctx) {}

79 80 81 82 83 84
// Picks the kernel `op` should run with for the given inputs, outputs and
// attributes, and packages it together with its device context as a
// PreparedOp.
//
// Two conditions are enforced through PADDLE_ENFORCE_NE with a NotFound error:
// the operator type must have kernels registered, and one of them must match
// the expected kernel key.
template <typename VarType>
PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
                       const NameVarMap<VarType>& outs,
                       const framework::OperatorWithKernel& op,
                       const platform::Place& place,
                       const framework::AttributeMap& attrs) {
  auto& ctx_pool = platform::DeviceContextPool::Instance();
  auto* device_context = ctx_pool.Get(place);
  framework::RuntimeContext runtime_ctx({}, {});

#ifdef PADDLE_WITH_MKLDNN
  // The MKLDNN code path reads attributes inside some GetKernelTypeForVar and
  // GetKernelType implementations, so the attributes are copied onto the op
  // first. The const qualifier of Attrs() has to be cast away to overwrite it.
  if (FLAGS_use_mkldnn) {
    const_cast<framework::AttributeMap&>(op.Attrs()) = attrs;
  }
#endif

  // Step 1: ask the operator which kernel key it expects.
  auto kernel_key = op.GetExpectedKernelType(DygraphExecutionContext<VarType>(
      op, framework::Scope(), *device_context, runtime_ctx, ins, outs, attrs));
  VLOG(3) << "expected_kernel_key:" << kernel_key;

  // Step 2: make sure this operator type has kernels registered at all.
  auto& registered_kernels = op.AllOpKernels();
  auto op_entry = registered_kernels.find(op.Type());
  PADDLE_ENFORCE_NE(
      op_entry, registered_kernels.end(),
      platform::errors::NotFound(
          "There are no kernels which are registered in the %s operator.",
          op.Type()));

  // Step 3: look up the kernel matching the expected key.
  auto& candidates = op_entry->second;
  auto kernel_it = candidates.find(kernel_key);
#ifdef PADDLE_WITH_XPU
  // No kernel for an XPU place: retry the lookup with a CPU place.
  if (kernel_it == candidates.end() && is_xpu_place(kernel_key.place_)) {
    kernel_key.place_ = platform::CPUPlace();
    kernel_it = candidates.find(kernel_key);
  }
#endif
  // TODO(jiabin): Add operator.cc's line 1000 part back when we need that case
  PADDLE_ENFORCE_NE(kernel_it, candidates.end(),
                    platform::errors::NotFound(
                        "Operator %s does not have kernel for %s.", op.Type(),
                        KernelTypeToString(kernel_key)));

  // The kernel may live on a different place than the one requested; if so,
  // use the device context that belongs to the kernel's place.
  if (!(kernel_key.place_ == place)) {
    device_context = ctx_pool.Get(kernel_key.place_);
  }

  return PreparedOp(op, runtime_ctx, kernel_key, kernel_it->second,
                    device_context);
}

136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151
// Prepare overload for VarBase inputs and outputs; forwards every argument
// unchanged to PrepareImpl<VarBase> and returns its result.
PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
                               const NameVarMap<VarBase>& outs,
                               const framework::OperatorWithKernel& op,
                               const platform::Place& place,
                               const framework::AttributeMap& attrs) {
  return PrepareImpl<VarBase>(ins, outs, op, place, attrs);
}

// Prepare overload for VariableWrapper inputs and outputs; forwards every
// argument unchanged to PrepareImpl<VariableWrapper> and returns its result.
PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
                               const NameVarMap<VariableWrapper>& outs,
                               const framework::OperatorWithKernel& op,
                               const platform::Place& place,
                               const framework::AttributeMap& attrs) {
  return PrepareImpl<VariableWrapper>(ins, outs, op, place, attrs);
}

152 153 154
// Runs one prepared operator: infers output shapes, invokes the kernel
// function, and finally converts complex gradients back to real ones when the
// kernel's data type is complex.
template <typename VarType>
static void PreparedOpRunImpl(
    const framework::OperatorBase& op, const framework::RuntimeContext& ctx,
    const framework::OpKernelType& kernel_type,
    const framework::OperatorWithKernel::OpKernelFunc& func,
    platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins,
    const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs) {
  // TODO(zjl): remove scope in dygraph
  framework::Scope scope;

  // Shape inference runs before the kernel is invoked.
  DygraphInferShapeContext<VarType> shape_ctx(&ins, &outs, &attrs, op.Type());
  const auto& kernel_op =
      static_cast<const framework::OperatorWithKernel&>(op);
  kernel_op.InferShape(&shape_ctx);

  func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx, ins, outs,
                                        attrs));

  /**
   * [ Why do complex gradients need to be converted to real gradients? ]
   *
   * Ops that support complex number calculations generally also support
   * type promotion, e.g. x(float32) + y(complex64) = out(complex64). The
   * gradient tensors should then be typed dout(complex64), dx(float32),
   * dy(complex64).
   *
   * Because dout is complex64, dx also comes out as complex64 once the grad
   * op kernel has executed. That situation has to be recognized and dx
   * converted to float32, which is what HandleComplexGradToRealGrad does.
   */
  if (framework::IsComplexType(kernel_type.data_type_)) {
    HandleComplexGradToRealGrad<VarType>(outs);
  }
}
H
hong 已提交
186

187 188 189
void PreparedOp::Run(const NameVarMap<VarBase>& ins,
                     const NameVarMap<VarBase>& outs,
                     const framework::AttributeMap& attrs) {
190 191
  PreparedOpRunImpl<VarBase>(op_, ctx_, kernel_type_, func_, dev_ctx_, ins,
                             outs, attrs);
192
}
H
hong 已提交
193

194 195 196
void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
                     const NameVarMap<VariableWrapper>& outs,
                     const framework::AttributeMap& attrs) {
197 198
  PreparedOpRunImpl<VariableWrapper>(op_, ctx_, kernel_type_, func_, dev_ctx_,
                                     ins, outs, attrs);
J
Jiabin Yang 已提交
199 200 201 202
}

}  // namespace imperative
}  // namespace paddle