// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <vector>

#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/var_helper.h"
#include "paddle/fluid/imperative/variable_wrapper.h"
#include "paddle/phi/core/ddim.h"

namespace paddle {
namespace imperative {

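// DygraphInferShapeContext adapts the imperative (dygraph) in/out variable
// maps to the static-graph InferShapeContext interface, so that an
// operator's InferShape function can run eagerly against in-memory
// variables instead of block descriptions.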
template <typename VarType>
class DygraphInferShapeContext : public framework::InferShapeContext {
  using DDim = framework::DDim;

 public:
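  // The pointer arguments are borrowed and must outlive this context; the
  // trailing three arguments are optional and may be nullptr.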
  DygraphInferShapeContext(
      const NameVarMap<VarType>* in,
      const NameVarMap<VarType>* out,
      const framework::AttributeMap* attr,
      const framework::AttributeMap* default_attr,
      const std::string& op_type,
      const framework::OpKernelType* op_kernel_type = nullptr,
      const phi::ArgumentMappingFn* arg_map_fn = nullptr,
      const phi::KernelSignature* default_kernel_signature = nullptr)
      : var_map_in_(in),
        var_map_out_(out),
        attrs_(attr),
        default_attrs_(default_attr),
        op_type_(op_type),
        op_kernel_type_(op_kernel_type),
        arg_map_fn_(arg_map_fn),
        default_kernel_signature_(default_kernel_signature) {}

  bool HasInput(const std::string& name) const override {
    // has only one input
    auto it = var_map_in_->find(name);
    if (it == var_map_in_->end()) {
      return false;
    }
    const auto& in = it->second;
    if (in.size() == 0) return false;
    PADDLE_ENFORCE_EQ(
        in.size(),
        1UL,
        platform::errors::PreconditionNotMet(
            "Input %s should not have more than one input", name));
    return in[0] != nullptr;
  }

  bool HasOutput(const std::string& name) const override {
    // has only one output
    auto it = var_map_out_->find(name);
    if (it == var_map_out_->end()) {
      return false;
    }
    const auto& out = it->second;
    if (out.size() == 0) {
      return false;
    }
    PADDLE_ENFORCE_EQ(
        out.size(),
        1UL,
        platform::errors::PreconditionNotMet(
            "Output %s should not have more than one output", name));
    return out[0] != nullptr;
  }

  bool HasAttr(const std::string& name) const override {
    return attrs_->count(name) > 0 || default_attrs_->count(name) > 0;
  }

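  // Unlike HasInput, the slot may hold several variables; it counts as
  // present only if every element is non-null.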
  bool HasInputs(const std::string& name) const override {
    auto it = var_map_in_->find(name);
    if (it == var_map_in_->end() || it->second.empty()) {
      return false;
    }
    for (auto& input : it->second) {
      if (input == nullptr) {
        return false;
      }
    }
    return true;
  }

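  // With allow_null set, a slot that contains null elements still counts as
  // present.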
  bool HasOutputs(const std::string& name,
                  bool allow_null = false) const override {
    auto it = var_map_out_->find(name);
    if (it == var_map_out_->end() || it->second.empty()) {
      return false;
    }
    if (!allow_null) {
      for (auto& output : it->second) {
        if (output == nullptr) {
          return false;
        }
      }
    }
    return true;
  }

  framework::AttrReader Attrs() const override {
    return framework::AttrReader(*attrs_, *default_attrs_);
  }

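  // Translates the variables in input slot `name` back to variable names;
  // null entries are reported as framework::kEmptyVarName.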
  std::vector<std::string> Inputs(const std::string& name) const override {
    std::vector<std::string> vec_res;
    auto it = var_map_in_->find(name);
    PADDLE_ENFORCE_NE(
        it,
        var_map_in_->end(),
        platform::errors::NotFound("can not find [%s] in input", name));

    vec_res.reserve(it->second.size());
    for (auto& var : it->second) {
      if (var) {
        vec_res.push_back(GetNameFromVar(var));
      } else {
        vec_res.push_back(framework::kEmptyVarName);
      }
    }

    return vec_res;
  }

  std::vector<std::string> Outputs(const std::string& name) const override {
    std::vector<std::string> vec_res;
    auto it = var_map_out_->find(name);
    PADDLE_ENFORCE_NE(
        it,
        var_map_out_->end(),
        platform::errors::NotFound("can not find [%s] in output", name));

    vec_res.reserve(it->second.size());
    for (auto& var : it->second) {
      if (var) {
        vec_res.push_back(GetNameFromVar(var));
      } else {
        vec_res.push_back(framework::kEmptyVarName);
      }
    }

    return vec_res;
  }

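  // The two helpers below map a positional index to the slot name declared
  // in the operator's proto.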
  std::string GetInputNameByIdx(size_t idx) const override {
    auto& op_proto =
        paddle::framework::OpInfoMap::Instance().Get(op_type_).proto_;
    PADDLE_ENFORCE_LT(idx,
                      op_proto->inputs().size(),
                      platform::errors::OutOfRange(
                          "The index should be less than the size of inputs of "
                          "operator %s, but got index is %d and size is %d",
                          op_type_,
                          idx,
                          op_proto->inputs().size()));
    return op_proto->inputs()[idx].name();
  }

  std::string GetOutputNameByIdx(size_t idx) const override {
    auto& op_proto =
        paddle::framework::OpInfoMap::Instance().Get(op_type_).proto_;
    PADDLE_ENFORCE_LT(
        idx,
        op_proto->outputs().size(),
        platform::errors::OutOfRange(
            "The index should be less than the size of outputs of "
            "operator %s, but got index is %d and size is %d",
            op_type_,
            idx,
            op_proto->outputs().size()));
    return op_proto->outputs()[idx].name();
  }

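  // Copies the dims (and SelectedRows metadata) of input slot `in[i]` into
  // output slot `out[j]`; both variables must already hold the same type.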
  void ShareDim(const std::string& in,
                const std::string& out,
                size_t i = 0,
                size_t j = 0) override {
    auto in_it = var_map_in_->find(in);
    auto out_it = var_map_out_->find(out);
    PADDLE_ENFORCE_NE(
        in_it,
        var_map_in_->end(),
        platform::errors::NotFound("can not find [%s] in input", in));
    PADDLE_ENFORCE_GT(in_it->second.size(),
                      i,
                      platform::errors::PreconditionNotMet(
                          "Input %s should have more than %llu arguments",
                          in,
                          i));
    PADDLE_ENFORCE_NE(
        out_it,
        var_map_out_->end(),
        platform::errors::NotFound("can not find [%s] in output", out));
    PADDLE_ENFORCE_GT(out_it->second.size(),
                      j,
                      platform::errors::PreconditionNotMet(
                          "Output %s should have more than %llu arguments",
                          out,
                          j));

    framework::Variable* in_var = in_it->second[i]->MutableVar();
    framework::Variable* out_var = out_it->second[j]->MutableVar();

    PADDLE_ENFORCE_EQ(in_var->Type(),
                      out_var->Type(),
                      platform::errors::PreconditionNotMet(
                          "The type of %s and %s is not the same.", in, out));

    if (in_var->IsType<framework::LoDTensor>()) {
      auto& in_lod_tensor = in_var->Get<framework::LoDTensor>();
      auto* out_lod_tensor = out_var->GetMutable<framework::LoDTensor>();
      out_lod_tensor->Resize(in_lod_tensor.dims());
    } else {
      auto& in_sele_rows = in_var->Get<phi::SelectedRows>();
      auto out_sele_rows = out_var->GetMutable<phi::SelectedRows>();
      out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims());
      out_sele_rows->set_rows(in_sele_rows.rows());
      out_sele_rows->set_height(in_sele_rows.height());
    }
  }

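  // LoD is not propagated in dygraph mode, so the two LoD-sharing overrides
  // below are intentionally no-ops.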
  void ShareAllLoD(const std::string& in,
                   const std::string& out) const override {
    // do nothing
  }
  void ShareLoD(const std::string& in,
                const std::string& out,
                size_t i = 0,
                size_t j = 0) const override {
    // do nothing
  }

  bool IsRuntime() const override { return true; }

  bool IsRunMKLDNNKernel() const override {
    return (op_kernel_type_ &&
            (op_kernel_type_->data_layout_ == framework::DataLayout::kMKLDNN));
  }

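  // Exposes the underlying framework::Variable pointers (inputs here,
  // outputs below) so that infer-shape routines can access them directly.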
  paddle::small_vector<framework::InferShapeVarPtr, phi::kInputSmallVectorSize>
  GetInputVarPtrs(const std::string& name) const override {
    paddle::small_vector<framework::InferShapeVarPtr,
                         phi::kInputSmallVectorSize>
        res;
    auto it = var_map_in_->find(name);
    PADDLE_ENFORCE_NE(
        it,
        var_map_in_->end(),
        platform::errors::NotFound("Can not find [%s] in inputs.", name));
    for (auto& var : it->second) {
      res.emplace_back(var->MutableVar());
    }
    return res;
  }

  paddle::small_vector<framework::InferShapeVarPtr, phi::kOutputSmallVectorSize>
  GetOutputVarPtrs(const std::string& name) const override {
    paddle::small_vector<framework::InferShapeVarPtr,
                         phi::kOutputSmallVectorSize>
        res;
    auto it = var_map_out_->find(name);
    PADDLE_ENFORCE_NE(
        it,
        var_map_out_->end(),
        platform::errors::NotFound("Can not find [%s] in outputs.", name));
    for (auto& var : it->second) {
      if (var) {
        res.emplace_back(var->MutableVar());
      } else {
        res.emplace_back(framework::InferShapeVarPtr());
      }
    }
    return res;
  }

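  // GetInputDim requires the slot to hold exactly one variable; use
  // GetInputsDim for multi-element slots.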
  DDim GetInputDim(const std::string& name) const override {
    auto it = var_map_in_->find(name);
    PADDLE_ENFORCE_NE(
        it,
        var_map_in_->end(),
        platform::errors::NotFound("can not find [%s] in input", name));
    PADDLE_ENFORCE_EQ(
        it->second.size(),
        1UL,
        platform::errors::PreconditionNotMet(
            "Input(%s) should hold one element, but now it holds %d",
            name,
            it->second.size()));
    return this->GetDim(it->second[0]->MutableVar());
  }

  std::vector<DDim> GetInputsDim(const std::string& name) const override {
    std::vector<DDim> vec_res;
    auto it = var_map_in_->find(name);
    PADDLE_ENFORCE_NE(
        it,
        var_map_in_->end(),
        platform::errors::NotFound("can not find [%s] in input", name));
    vec_res.reserve(it->second.size());
    for (size_t i = 0; i < it->second.size(); ++i) {
      if (it->second[i]) {
        vec_res.emplace_back(GetDim(it->second[i]->MutableVar()));
      } else {
        vec_res.emplace_back();
      }
    }

    return vec_res;
  }

  framework::proto::VarType::Type GetInputVarType(
      const std::string& name) const override {
    auto it = var_map_in_->find(name);
    PADDLE_ENFORCE_NE(
        it,
        var_map_in_->end(),
        platform::errors::NotFound("can not find [%s] in input", name));
    return framework::ToVarType(it->second[0]->Var().Type());
  }

  std::vector<framework::proto::VarType::Type> GetInputsVarType(
      const std::string& name) const override {
    std::vector<framework::proto::VarType::Type> vec_res;
    auto it = var_map_in_->find(name);
    PADDLE_ENFORCE_NE(
        it,
        var_map_in_->end(),
        platform::errors::NotFound("can not find [%s] in input", name));
    vec_res.reserve(it->second.size());
    for (size_t i = 0; i < it->second.size(); ++i) {
      if (it->second[i]) {
        vec_res.emplace_back(
            framework::ToVarType(it->second[i]->MutableVar()->Type()));
      } else {
        vec_res.emplace_back();
      }
    }
    return vec_res;
  }

  std::vector<framework::proto::VarType::Type> GetOutputsVarType(
      const std::string& name) const override {
    std::vector<framework::proto::VarType::Type> vec_res;
    auto it = var_map_out_->find(name);
    PADDLE_ENFORCE_NE(
        it,
        var_map_out_->end(),
        platform::errors::NotFound("can not find [%s] in output", name));
    vec_res.reserve(it->second.size());
    for (size_t i = 0; i < it->second.size(); ++i) {
      if (it->second[i]) {
        vec_res.emplace_back(
            framework::ToVarType(it->second[i]->MutableVar()->Type()));
      } else {
        vec_res.emplace_back(static_cast<framework::proto::VarType::Type>(-1));
      }
    }
    return vec_res;
  }

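  // Null output variables are skipped silently when writing dims back.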
  void SetOutputDim(const std::string& name, const DDim& dim) override {
    auto it = var_map_out_->find(name);
    PADDLE_ENFORCE_NE(
        it,
        var_map_out_->end(),
        platform::errors::NotFound("can not find [%s] in output", name));

    if (it->second[0]) {
      SetDim(it->second[0]->MutableVar(), dim);
    }
  }

  void SetOutputsDim(const std::string& name,
                     const std::vector<DDim>& dims) override {
    auto it = var_map_out_->find(name);
    PADDLE_ENFORCE_NE(
        it,
        var_map_out_->end(),
        platform::errors::NotFound("can not find [%s] in output", name));

    PADDLE_ENFORCE_EQ(dims.size(),
                      it->second.size(),
                      platform::errors::InvalidArgument(
                          "The number of dims is expected to be equal to the "
                          "number of Outputs(%s). But received: the number of "
                          "dims = %d, the number of Outputs(%s) = %d.",
                          name,
                          dims.size(),
                          name,
                          it->second.size()));

    for (size_t i = 0; i < dims.size(); ++i) {
      if (it->second[i]) {
        SetDim(it->second[i]->MutableVar(), dims[i]);
      }
    }
  }

  int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "GetLoDLevel function not support in dygraph mode"));
  }

  void SetLoDLevel(const std::string& out,
                   int32_t lod_level,
                   size_t j = 0) const override {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "SetLoDLevel function is not supported in dygraph mode"));
  }

  const phi::ArgumentMappingFn* GetPhiArgumentMappingFn() const override {
    return arg_map_fn_;
  }

  const phi::KernelSignature* GetPhiDefaultKernelSignature() const override {
    return default_kernel_signature_;
  }

 protected:
  DDim GetDim(framework::Variable* var) const {
    PADDLE_ENFORCE_NOT_NULL(var,
                            platform::errors::PreconditionNotMet(
                                "Input variable should not be null"));
    if (var->IsType<framework::LoDTensor>()) {
      return var->Get<framework::LoDTensor>().dims();
    } else if (var->IsType<phi::SelectedRows>()) {
      return var->Get<phi::SelectedRows>().GetCompleteDims();
    } else {
      PADDLE_THROW(platform::errors::PermissionDenied(
          "Only LoDTensor/SelectedRows support 'GetDim', but Variables "
          "type_id is: %s.",
          framework::ToTypeName(var->Type())));
    }
  }

  std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "GetRepeatedDims not support in dygraph runtime"));
  }

  void SetDim(framework::Variable* var, const DDim& dim) {
    if (var->IsType<framework::LoDTensor>()) {
      var->GetMutable<framework::LoDTensor>()->Resize(dim);
    } else if (var->IsType<phi::SelectedRows>()) {
      var->GetMutable<phi::SelectedRows>()->set_height(dim[0]);
    } else {
      PADDLE_THROW(platform::errors::PermissionDenied(
          "Variable type_id %s, expect LoDTensor/SelectedRows.",
          framework::ToTypeName(var->Type())));
    }
  }

  void SetDims(const std::vector<framework::Variable*>& vars,
               const std::vector<DDim>& dims) {
    size_t length = vars.size();
    PADDLE_ENFORCE_EQ(
        length,
        dims.size(),
        platform::errors::PreconditionNotMet(
            "Vars number [%d] should be equal to dims number [%d]",
            length,
            dims.size()));
    for (size_t i = 0; i < length; ++i) {
      if (vars[i] == nullptr) {
        continue;
      }
      SetDim(vars[i], dims[i]);
    }
  }

  void SetRepeatedDims(const std::string& name,
                       const std::vector<DDim>& dims) override {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "SetRepeatedDims not support in dygraph runtime"));
  }

 private:
  const NameVarMap<VarType>* var_map_in_;
  const NameVarMap<VarType>* var_map_out_;
  const framework::AttributeMap* attrs_;
  const framework::AttributeMap* default_attrs_;
  const std::string op_type_;
  const framework::OpKernelType* op_kernel_type_;
  // arg_map_fn_ and default_kernel_signature_ may be nullptr
  const phi::ArgumentMappingFn* arg_map_fn_;
  const phi::KernelSignature* default_kernel_signature_;
};
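
// A minimal usage sketch (hedged: `ins`, `outs`, `attrs`, `default_attrs`,
// and `op` stand for caller-side values; the context is typically built
// like this inside the prepared-op execution path):
//
//   DygraphInferShapeContext<VarBase> infer_shape_ctx(
//       &ins, &outs, &attrs, &default_attrs, op.Type());
//   op.Info().infer_shape_(&infer_shape_ctx);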

}  // namespace imperative
}  // namespace paddle