grpc_bytebuffer_stream.cc 2.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

// NOTE: This file was originally created by TensorFlow
//       (https://github.com/tensorflow/tensorflow/). We borrowed it and
//       made some modifications so that gRPC requests can be sent
//       without excessive copying of the tensor data.

W
Wu Yi 已提交
20
#include "paddle/fluid/operators/distributed/grpc/grpc_bytebuffer_stream.h"
21

W
wanghuancoder 已提交
22 23 24 25
// Forward declaration of grpc::ByteBuffer. Note that Init() below calls
// src.Dump(), which requires the complete type, so the full definition must
// also be provided by an included header.
namespace grpc {
class ByteBuffer;
}  // namespace grpc

26 27
namespace paddle {
namespace operators {
28
namespace distributed {
29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89

// Constructs a source with no data attached.
//
// The cursor state is reset here exactly as Init() resets it, so that a
// source which was never Init()-ed behaves as an empty stream: Next()
// returns false and ByteCount() returns 0, instead of reading indeterminate
// member values. Assignments (rather than a member-initializer list) are used
// so this does not depend on the member declaration order in the header.
GrpcByteBufferSource::GrpcByteBufferSource() {
  cur_ = -1;
  left_ = 0;
  ptr_ = nullptr;
  byte_count_ = 0;
}

// (Re)binds this source to the contents of `src`.
//
// All cursor state (slice index, bytes left, read pointer, byte counter) is
// rewound first, so calling Init() again starts a fresh stream. The slices
// of `src` are captured via ByteBuffer::Dump(); if that fails, the slice list
// is emptied so the stream reads as empty.
//
// Returns true when the dump succeeded, false otherwise.
bool GrpcByteBufferSource::Init(const grpc::ByteBuffer& src) {
  byte_count_ = 0;
  ptr_ = nullptr;
  left_ = 0;
  cur_ = -1;

  const bool dumped = src.Dump(&slices_).ok();
  if (!dumped) slices_.clear();
  return dumped;
}

bool GrpcByteBufferSource::Next(const void** data, int* size) {
  // Use loop instead of if in case buffer contained empty slices.
  while (left_ == 0) {
    // Advance to next slice.
    cur_++;
    if (cur_ >= slices_.size()) {
      return false;
    }
    const ::grpc::Slice& s = slices_[cur_];
    left_ = s.size();
    ptr_ = reinterpret_cast<const char*>(s.begin());
  }

  *data = ptr_;
  *size = left_;
  byte_count_ += left_;
  ptr_ += left_;
  left_ = 0;
  return true;
}

// Returns the last `count` bytes of the most recent Next() chunk to the
// stream, so the next Next() call yields them again.
//
// No validation is done here. `count` is presumably expected to be in
// [0, size of the last chunk], as the ZeroCopyInputStream contract requires
// — TODO confirm every caller honours this.
void GrpcByteBufferSource::BackUp(int count) {
  byte_count_ -= count;
  left_ += count;
  ptr_ -= count;
}

// Skips `count` bytes of the stream.
//
// Returns true once `count` bytes have been skipped. Returns false if the
// stream ends first; in that case every remaining byte has been consumed,
// so ByteCount() reports the total number of bytes in the stream.
//
// A negative `count` is rejected (returns false, stream untouched): letting
// it through would make BackUp() move ptr_ before the start of the current
// chunk and drive byte_count_ negative. Skipping zero bytes is a no-op that
// always succeeds, even at end of stream.
bool GrpcByteBufferSource::Skip(int count) {
  if (count < 0) {
    return false;
  }
  if (count == 0) {
    return true;
  }

  const void* data = nullptr;
  int size = 0;
  while (Next(&data, &size)) {
    if (size >= count) {
      // Give back the part of this chunk that lies past the skipped range.
      BackUp(size - count);
      return true;
    }
    // The whole chunk falls inside the skipped range; keep going.
    count -= size;
  }
  // Ran out of data before `count` bytes could be skipped.
  return false;
}

// Total number of bytes handed out by Next() so far, minus any bytes
// returned through BackUp(). Reset to 0 by Init().
google::protobuf::int64 GrpcByteBufferSource::ByteCount() const {
  const google::protobuf::int64 consumed = byte_count_;
  return consumed;
}

90
}  // namespace distributed
91
}  // namespace operators
Y
Yancey 已提交
92
}  // namespace paddle