storage_properties.h 3.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>
18
#include "paddle/phi/core/ddim.h"
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44
#include "paddle/phi/core/utils/type_registry.h"

#ifdef PADDLE_WITH_MKLDNN
#include "dnnl.hpp"  // NOLINT
#endif

namespace phi {

struct StorageProperties {
 public:
  virtual ~StorageProperties() = default;
  TypeInfo<StorageProperties> type_info() const { return type_info_; }

 private:
  template <typename T, typename U>
  friend class TypeInfoTraits;
  TypeInfo<StorageProperties> type_info_{
      TypeInfo<StorageProperties>::kUnknownType};
};

struct NPUStorageProperties
    : public StorageProperties,
      public TypeInfoTraits<StorageProperties, NPUStorageProperties> {
  virtual ~NPUStorageProperties() = default;
  static const char* name() { return "NPUStorageProperties"; }

45 46 47 48
  int64_t origin_format{-1};
  int64_t storage_format{-1};
  DDim origin_dims;
  DDim storage_dims;
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
};

// Add OneDNNStorageProperties firstly for unittest covergae
#ifdef PADDLE_WITH_MKLDNN
struct OneDNNStorageProperties
    : public StorageProperties,
      public TypeInfoTraits<StorageProperties, OneDNNStorageProperties> {
  virtual ~OneDNNStorageProperties() = default;
  static const char* name() { return "OneDNNStorageProperties"; }

  /**
   * @brief the detail format of memory block which have layout as ONEDNN
   *
   * @note ONEDNN lib support various memory format like nchw, nhwc, nChw8C,
   *       nChw16c, etc. For a ONEDNN memory block, layout will be set as
   *       DataLayout::ONEDNN meanwhile detail memory format will be kept in
   *       this field.
   */
  dnnl::memory::format_tag format = dnnl::memory::format_tag::undef;

  /// \brief memory descriptor of tensor which have layout set as ONEDNN
  dnnl::memory::desc mem_desc;
};
#endif

static std::unique_ptr<StorageProperties> CopyStorageProperties(
    const std::unique_ptr<StorageProperties>& sp) {
  if (sp) {
    if (NPUStorageProperties::classof(sp.get())) {
      auto result = std::make_unique<NPUStorageProperties>();
79 80
      result->origin_format =
          static_cast<NPUStorageProperties*>(sp.get())->origin_format;
81 82
      result->storage_format =
          static_cast<NPUStorageProperties*>(sp.get())->storage_format;
83 84 85 86
      result->origin_dims =
          static_cast<NPUStorageProperties*>(sp.get())->origin_dims;
      result->storage_dims =
          static_cast<NPUStorageProperties*>(sp.get())->storage_dims;
87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
      return result;
#ifdef PADDLE_WITH_MKLDNN
    } else if (OneDNNStorageProperties::classof(sp.get())) {
      auto result = std::make_unique<OneDNNStorageProperties>();
      result->format = static_cast<OneDNNStorageProperties*>(sp.get())->format;
      result->mem_desc =
          static_cast<OneDNNStorageProperties*>(sp.get())->mem_desc;
      return result;
#endif
    } else {
      return nullptr;
    }
  }
  return nullptr;
}

}  // namespace phi