/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 16
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
#define TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
17

18 19
#include <memory>

20
#include "absl/types/span.h"
21
#include "tensorflow/c/eager/abstract_tensor_handle.h"
22 23
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/types.pb.h"
24
#include "tensorflow/core/platform/status.h"
25

G
Gaurav Jain 已提交
26 27
namespace tensorflow {

28
// Abstract interface to an operation.
29 30 31 32
// This interface allows building and executing an operation in either
// tracing or immediate execution mode.
class AbstractOperation {
 protected:
33 34 35 36 37 38 39 40
  enum AbstractOperationKind {
    kGraph,
    kMlir,
    kEager,
    kTfrt,
    kTape,
    kOpHandler
  };
41 42 43
  explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {}
  virtual ~AbstractOperation() {}

44
 public:
45 46
  AbstractOperationKind getKind() const { return kind_; }

47 48 49 50 51 52 53
  // Release any underlying resources, including the interface object.
  //
  // WARNING: The destructor of this class is marked as protected to disallow
  // clients from directly destroying this object since it may manage it's own
  // lifetime through ref counting. Thus this must be allocated on the heap and
  // clients MUST call Release() in order to destroy an instance of this class.
  virtual void Release() = 0;
54

G
Gaurav Jain 已提交
55
  virtual Status Reset(const char* op, const char* raw_device_name) = 0;
56

G
Gaurav Jain 已提交
57
  virtual const string& Name() const = 0;
58 59 60 61 62 63 64 65 66 67 68 69

  // Returns the operation's device name.
  //
  // The value returned may be different from the one set by SetDeviceName, but
  // it will be compatible with it: the name will be updated by device placement
  // logic to refer to the specific device chosen.
  //
  // Example: If one calls `op->SetDeviceName("/device:GPU")`, the value
  // returned by DeviceName should be "/device:GPU:*" until a particular GPU is
  // chosen for the operation by the device placement logic in the
  // executor. After that, the value returned by DeviceName will be a full
  // device name such as "/job:localhost/replica:0/task:0/device:GPU:1".
G
Gaurav Jain 已提交
70
  virtual const string& DeviceName() const = 0;
71 72 73 74 75 76 77 78 79

  // Sets the operation device name.
  //
  // The given `name` must be parseable by DeviceNameUtils::ParseFullName, and
  // the result will be used as a constraint for device placement. See the
  // documentation for DeviceName for more details.
  //
  // The value will override the previous value - that is, no "merging" of
  // existing and given constraints will be performed.
G
Gaurav Jain 已提交
80
  virtual Status SetDeviceName(const char* name) = 0;
81

82
  virtual Status AddInput(AbstractTensorHandle* input) = 0;
83 84
  virtual Status AddInputList(
      absl::Span<AbstractTensorHandle* const> inputs) = 0;
85
  virtual Status Execute(absl::Span<AbstractTensorHandle*> retvals,
86
                         int* num_retvals) = 0;
87

G
Gaurav Jain 已提交
88 89 90 91 92
  virtual Status SetAttrString(const char* attr_name, const char* data,
                               size_t length) = 0;
  virtual Status SetAttrInt(const char* attr_name, int64_t value) = 0;
  virtual Status SetAttrFloat(const char* attr_name, float value) = 0;
  virtual Status SetAttrBool(const char* attr_name, bool value) = 0;
93
  virtual Status SetAttrType(const char* attr_name, DataType value) = 0;
G
Gaurav Jain 已提交
94 95
  virtual Status SetAttrShape(const char* attr_name, const int64_t* dims,
                              const int num_dims) = 0;
96
  virtual Status SetAttrFunction(const char* attr_name,
97
                                 const AbstractOperation* value) = 0;
G
Gaurav Jain 已提交
98 99
  virtual Status SetAttrFunctionName(const char* attr_name, const char* value,
                                     size_t length) = 0;
100 101
  virtual Status SetAttrTensor(const char* attr_name,
                               AbstractTensorInterface* tensor) = 0;
G
Gaurav Jain 已提交
102 103 104 105 106 107 108
  virtual Status SetAttrStringList(const char* attr_name,
                                   const void* const* values,
                                   const size_t* lengths, int num_values) = 0;
  virtual Status SetAttrFloatList(const char* attr_name, const float* values,
                                  int num_values) = 0;
  virtual Status SetAttrIntList(const char* attr_name, const int64_t* values,
                                int num_values) = 0;
109 110
  virtual Status SetAttrTypeList(const char* attr_name, const DataType* values,
                                 int num_values) = 0;
G
Gaurav Jain 已提交
111 112 113 114 115
  virtual Status SetAttrBoolList(const char* attr_name,
                                 const unsigned char* values,
                                 int num_values) = 0;
  virtual Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
                                  const int* num_dims, int num_values) = 0;
116
  virtual Status SetAttrFunctionList(
117
      const char* attr_name, absl::Span<const AbstractOperation*> values) = 0;
G
Gaurav Jain 已提交
118

119 120
 private:
  const AbstractOperationKind kind_;
121 122
};

123 124 125 126 127 128 129 130 131 132
namespace internal {
struct AbstractOperationDeleter {
  void operator()(AbstractOperation* p) const {
    if (p != nullptr) {
      p->Release();
    }
  }
};
}  // namespace internal

133
using AbstractOperationPtr =
134 135
    std::unique_ptr<AbstractOperation, internal::AbstractOperationDeleter>;

136 137
}  // namespace tensorflow

138
#endif  // TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_