channel.h 5.5 KB
Newer Older
W
wangguibao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include <typeinfo>
#include <utility>

#include "common/inner_common.h"

namespace baidu {
namespace paddle_serving {
namespace predictor {

class Channel;

class Bus {
W
wangguibao 已提交
27 28
 public:
  Bus() { clear(); }
W
wangguibao 已提交
29

W
wangguibao 已提交
30 31 32 33 34 35
  int regist(const std::string& op, Channel* channel) {
    std::pair<boost::unordered_map<std::string, Channel*>::iterator, bool> r =
        _op_channels.insert(std::make_pair(op, channel));
    if (!r.second) {
      LOG(ERROR) << "Failed insert op&channel into bus:" << op;
      return -1;
W
wangguibao 已提交
36
    }
W
wangguibao 已提交
37 38
    return 0;
  }
W
wangguibao 已提交
39

W
wangguibao 已提交
40 41 42 43 44 45
  Channel* channel_by_name(const std::string& op_name) {
    typename boost::unordered_map<std::string, Channel*>::iterator it =
        _op_channels.find(op_name);
    if (it == _op_channels.end()) {
      LOG(WARNING) << "Not found channel in bus, op_name:" << op_name << ".";
      return NULL;
W
wangguibao 已提交
46 47
    }

W
wangguibao 已提交
48 49
    return it->second;
  }
W
wangguibao 已提交
50

W
wangguibao 已提交
51 52 53
  void clear() { _op_channels.clear(); }

  size_t size() const { return _op_channels.size(); }
W
wangguibao 已提交
54

W
wangguibao 已提交
55 56
 private:
  boost::unordered_map<std::string, Channel*> _op_channels;
W
wangguibao 已提交
57 58 59
};

class Channel {
W
wangguibao 已提交
60 61
 public:
  Channel() {}
W
wangguibao 已提交
62

W
wangguibao 已提交
63 64 65 66 67
  void init(uint32_t id, const char* op) {
    _id = id;
    _op = std::string(op);
    clear_data();
  }
W
wangguibao 已提交
68

W
wangguibao 已提交
69
  void deinit() { clear_data(); }
W
wangguibao 已提交
70

W
wangguibao 已提交
71
  uint32_t id() const { return _id; }
W
wangguibao 已提交
72

W
wangguibao 已提交
73
  const std::string& op() { return _op; }
W
wangguibao 已提交
74

W
wangguibao 已提交
75 76 77 78
  int share_to_bus(Bus* bus) {
    if (bus->regist(_op, this) != 0) {
      LOG(ERROR) << "Failed regist channel[" << _op << "] to bus!";
      return -1;
W
wangguibao 已提交
79 80
    }

W
wangguibao 已提交
81 82 83 84
    return 0;
  }

  virtual void clear_data() = 0;
W
wangguibao 已提交
85

W
wangguibao 已提交
86 87
  virtual void* param() = 0;
  virtual const void* param() const = 0;
W
wangguibao 已提交
88

W
wangguibao 已提交
89 90
  virtual google::protobuf::Message* message() = 0;
  virtual const google::protobuf::Message* message() const = 0;
W
wangguibao 已提交
91

W
wangguibao 已提交
92
  virtual Channel& operator=(const Channel& channel) = 0;
W
wangguibao 已提交
93

W
wangguibao 已提交
94
  virtual std::string debug_string() const = 0;
W
wangguibao 已提交
95

W
wangguibao 已提交
96 97 98
 private:
  uint32_t _id;
  std::string _op;
W
wangguibao 已提交
99 100
};

W
wangguibao 已提交
101
template <typename T>
W
wangguibao 已提交
102
class OpChannel : public Channel {
W
wangguibao 已提交
103 104
 public:
  OpChannel() {}
W
wangguibao 已提交
105

W
wangguibao 已提交
106
  void clear_data() { _data.Clear(); }
W
wangguibao 已提交
107

W
wangguibao 已提交
108
  void* param() { return &_data; }
W
wangguibao 已提交
109

W
wangguibao 已提交
110
  const void* param() const { return &_data; }
W
wangguibao 已提交
111

W
wangguibao 已提交
112 113 114 115 116
  google::protobuf::Message* message() {
    return message_impl(
        derived_from_message<
            TIsDerivedFromB<T, google::protobuf::Message>::RESULT>());
  }
W
wangguibao 已提交
117

W
wangguibao 已提交
118 119 120
  google::protobuf::Message* message_impl(derived_from_message<true>) {
    return dynamic_cast<const google::protobuf::Message*>(&_data);
  }
W
wangguibao 已提交
121

W
wangguibao 已提交
122 123 124 125 126
  google::protobuf::Message* message_impl(derived_from_message<false>) {
    LOG(ERROR) << "Current type: " << typeid(T).name()
               << " is not derived from protobuf.";
    return NULL;
  }
W
wangguibao 已提交
127

W
wangguibao 已提交
128 129 130 131 132
  const google::protobuf::Message* message() const {
    return message_impl(
        derived_from_message<
            TIsDerivedFromB<T, google::protobuf::Message>::RESULT>());
  }
W
wangguibao 已提交
133

W
wangguibao 已提交
134 135 136 137
  const google::protobuf::Message* message_impl(
      derived_from_message<true>) const {
    return dynamic_cast<const google::protobuf::Message*>(&_data);
  }
W
wangguibao 已提交
138

W
wangguibao 已提交
139 140 141 142 143 144
  const google::protobuf::Message* message_impl(
      derived_from_message<false>) const {
    LOG(ERROR) << "Current type: " << typeid(T).name()
               << " is not derived from protobuf.";
    return NULL;
  }
W
wangguibao 已提交
145

W
wangguibao 已提交
146 147 148 149
  Channel& operator=(const Channel& channel) {
    _data = *(dynamic_cast<const OpChannel<T>&>(channel)).data();
    return *this;
  }
W
wangguibao 已提交
150

W
wangguibao 已提交
151
  std::string debug_string() const { return _data.ShortDebugString(); }
W
wangguibao 已提交
152

W
wangguibao 已提交
153
  // functions of derived class
W
wangguibao 已提交
154

W
wangguibao 已提交
155
  T* data() { return &_data; }
W
wangguibao 已提交
156

W
wangguibao 已提交
157 158 159 160 161 162 163 164 165
  const T* data() const { return &_data; }

  Channel& operator=(const T& obj) {
    _data = obj;
    return *this;
  }

 private:
  T _data;
W
wangguibao 已提交
166 167
};

W
wangguibao 已提交
168
template <>
W
wangguibao 已提交
169
class OpChannel<google::protobuf::Message> : public Channel {
W
wangguibao 已提交
170 171
 public:
  OpChannel<google::protobuf::Message>() : _data(NULL) {}
W
wangguibao 已提交
172

W
wangguibao 已提交
173
  virtual ~OpChannel<google::protobuf::Message>() { _data = NULL; }
W
wangguibao 已提交
174

W
wangguibao 已提交
175
  void clear_data() { _data = NULL; }
W
wangguibao 已提交
176

W
wangguibao 已提交
177
  void* param() { return const_cast<void*>((const void*)_data); }
W
wangguibao 已提交
178

W
wangguibao 已提交
179
  const void* param() const { return _data; }
W
wangguibao 已提交
180

W
wangguibao 已提交
181 182 183
  google::protobuf::Message* message() {
    return const_cast<google::protobuf::Message*>(_data);
  }
W
wangguibao 已提交
184

W
wangguibao 已提交
185
  const google::protobuf::Message* message() const { return _data; }
W
wangguibao 已提交
186

W
wangguibao 已提交
187 188 189 190
  Channel& operator=(const Channel& channel) {
    _data = channel.message();
    return *this;
  }
W
wangguibao 已提交
191

W
wangguibao 已提交
192 193 194 195 196
  std::string debug_string() const {
    if (_data) {
      return _data->ShortDebugString();
    } else {
      return "{\"Error\": \"Null Message Ptr\"}";
W
wangguibao 已提交
197
    }
W
wangguibao 已提交
198
  }
W
wangguibao 已提交
199

W
wangguibao 已提交
200 201 202 203
  // derived function imiplements
  google::protobuf::Message* data() {
    return const_cast<google::protobuf::Message*>(_data);
  }
W
wangguibao 已提交
204

W
wangguibao 已提交
205
  const google::protobuf::Message* data() const { return _data; }
W
wangguibao 已提交
206

W
wangguibao 已提交
207 208 209 210 211
  OpChannel<google::protobuf::Message>& operator=(
      google::protobuf::Message* message) {
    _data = message;
    return *this;
  }
W
wangguibao 已提交
212

W
wangguibao 已提交
213 214 215 216 217
  OpChannel<google::protobuf::Message>& operator=(
      const google::protobuf::Message* message) {
    _data = message;
    return *this;
  }
W
wangguibao 已提交
218

W
wangguibao 已提交
219 220
 private:
  const google::protobuf::Message* _data;
W
wangguibao 已提交
221 222
};

W
wangguibao 已提交
223 224 225
}  // namespace predictor
}  // namespace paddle_serving
}  // namespace baidu