backend_set.h 2.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cstdint>
#include <ostream>

#include "paddle/fluid/platform/enforce.h"
#include "paddle/pten/common/backend.h"
namespace paddle {
namespace experimental {

/**
 * We use the backend to form a bit set to assist the runtime kernel selection,
 * and the higher backend bit has a higher priority.
 *
 * A Tensor may belong to multiple backends at the same time, such as CPU and
 * MKLDNN, so a single backend value cannot describe it; a set is used instead.
 */
class BackendSet final {
 public:
  // Creates an empty set: no backend bit is set.
  constexpr BackendSet() : bitset_(0) {}

  // Creates a set that contains only backend `b`.
  //
  // Backend::UNDEFINED yields the empty set. Any other backend whose
  // underlying value is v occupies bit (v - 1) of the mask.
  // NOTE(review): the shift is unchecked; it assumes every Backend value is
  // at most 64 so the bit fits in the 64-bit mask -- confirm against the
  // Backend enum definition.
  explicit constexpr BackendSet(Backend b)
      : bitset_(b == Backend::UNDEFINED
                    ? 0
                    : 1ULL << (static_cast<uint8_t>(b) - 1)) {}

  // Returns the raw 64-bit mask backing this set.
  constexpr uint64_t bitset() const { return bitset_; }

  // Returns true when backend `b` is a member of this set.
  //
  // `b` must not be Backend::UNDEFINED; this is checked with
  // PADDLE_ENFORCE_NE, which reports an InvalidArgument error otherwise.
  // Not constexpr because of that runtime check.
  bool Has(Backend b) const {
    PADDLE_ENFORCE_NE(b,
                      Backend::UNDEFINED,
                      platform::errors::InvalidArgument(
                          "Backend argument can't be UNDEFINED."));
    return static_cast<bool>(bitset_ & BackendSet(b).bitset());
  }

  // Returns true when no backend bit is set.
  constexpr bool IsEmpty() const { return bitset_ == 0; }

  // Union: backends present in either set.
  constexpr BackendSet operator|(const BackendSet& other) const {
    return BackendSet(bitset_ | other.bitset());
  }
  // Intersection: backends present in both sets.
  constexpr BackendSet operator&(const BackendSet& other) const {
    return BackendSet(bitset_ & other.bitset());
  }
  // Difference: backends present in this set but not in `other`.
  constexpr BackendSet operator-(const BackendSet& other) const {
    return BackendSet(bitset_ & ~other.bitset());
  }
  // Symmetric difference: backends present in exactly one of the two sets.
  constexpr BackendSet operator^(const BackendSet& other) const {
    return BackendSet(bitset_ ^ other.bitset());
  }

  // Two sets are equal when their masks are identical.
  constexpr bool operator==(const BackendSet& other) const {
    return bitset_ == other.bitset();
  }
  // Provided alongside operator== so both comparisons are available.
  constexpr bool operator!=(const BackendSet& other) const {
    return bitset_ != other.bitset();
  }

 private:
  // Builds a set directly from a raw mask. Kept private and explicit so an
  // integer can never be silently converted into a BackendSet.
  explicit constexpr BackendSet(uint64_t bitset) : bitset_(bitset) {}

  uint64_t bitset_;
};

}  // namespace experimental
}  // namespace paddle