未验证 提交 a009272e 编写于 作者: Y Yan Chunwei 提交者: GitHub

inference/unify output buffer management (#11569)

上级 5f0c780a
...@@ -40,10 +40,9 @@ void Main(bool use_gpu) { ...@@ -40,10 +40,9 @@ void Main(bool use_gpu) {
//# 2. Prepare input. //# 2. Prepare input.
int64_t data[4] = {1, 2, 3, 4}; int64_t data[4] = {1, 2, 3, 4};
PaddleBuf buf{.data = data, .length = sizeof(data)};
PaddleTensor tensor{.name = "", PaddleTensor tensor{.name = "",
.shape = std::vector<int>({4, 1}), .shape = std::vector<int>({4, 1}),
.data = buf, .data = PaddleBuf(data, sizeof(data)),
.dtype = PaddleDType::INT64}; .dtype = PaddleDType::INT64};
// For simplicity, we set all the slots with the same data. // For simplicity, we set all the slots with the same data.
...@@ -55,14 +54,12 @@ void Main(bool use_gpu) { ...@@ -55,14 +54,12 @@ void Main(bool use_gpu) {
//# 4. Get output. //# 4. Get output.
ASSERT_EQ(outputs.size(), 1UL); ASSERT_EQ(outputs.size(), 1UL);
LOG(INFO) << "output buffer size: " << outputs.front().data.length; LOG(INFO) << "output buffer size: " << outputs.front().data.length();
const size_t num_elements = outputs.front().data.length / sizeof(float); const size_t num_elements = outputs.front().data.length() / sizeof(float);
// The outputs' buffers are in CPU memory. // The outputs' buffers are in CPU memory.
for (size_t i = 0; i < std::min(5UL, num_elements); i++) { for (size_t i = 0; i < std::min(5UL, num_elements); i++) {
LOG(INFO) << static_cast<float*>(outputs.front().data.data)[i]; LOG(INFO) << static_cast<float*>(outputs.front().data.data())[i];
} }
// TODO(Superjomn): this is should be free automatically
free(outputs[0].data.data);
} }
} }
...@@ -86,10 +83,9 @@ void MainThreads(int num_threads, bool use_gpu) { ...@@ -86,10 +83,9 @@ void MainThreads(int num_threads, bool use_gpu) {
for (int batch_id = 0; batch_id < num_batches; ++batch_id) { for (int batch_id = 0; batch_id < num_batches; ++batch_id) {
// 2. Dummy Input Data // 2. Dummy Input Data
int64_t data[4] = {1, 2, 3, 4}; int64_t data[4] = {1, 2, 3, 4};
PaddleBuf buf{.data = data, .length = sizeof(data)};
PaddleTensor tensor{.name = "", PaddleTensor tensor{.name = "",
.shape = std::vector<int>({4, 1}), .shape = std::vector<int>({4, 1}),
.data = buf, .data = PaddleBuf(data, sizeof(data)),
.dtype = PaddleDType::INT64}; .dtype = PaddleDType::INT64};
std::vector<PaddleTensor> inputs(4, tensor); std::vector<PaddleTensor> inputs(4, tensor);
std::vector<PaddleTensor> outputs; std::vector<PaddleTensor> outputs;
...@@ -99,13 +95,13 @@ void MainThreads(int num_threads, bool use_gpu) { ...@@ -99,13 +95,13 @@ void MainThreads(int num_threads, bool use_gpu) {
// 4. Get output. // 4. Get output.
ASSERT_EQ(outputs.size(), 1UL); ASSERT_EQ(outputs.size(), 1UL);
LOG(INFO) << "TID: " << tid << ", " LOG(INFO) << "TID: " << tid << ", "
<< "output buffer size: " << outputs.front().data.length; << "output buffer size: " << outputs.front().data.length();
const size_t num_elements = outputs.front().data.length / sizeof(float); const size_t num_elements =
outputs.front().data.length() / sizeof(float);
// The outputs' buffers are in CPU memory. // The outputs' buffers are in CPU memory.
for (size_t i = 0; i < std::min(5UL, num_elements); i++) { for (size_t i = 0; i < std::min(5UL, num_elements); i++) {
LOG(INFO) << static_cast<float*>(outputs.front().data.data)[i]; LOG(INFO) << static_cast<float*>(outputs.front().data.data())[i];
} }
free(outputs[0].data.data);
} }
}); });
} }
......
...@@ -13,3 +13,53 @@ See the License for the specific language governing permissions and ...@@ -13,3 +13,53 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/contrib/inference/paddle_inference_api.h" #include "paddle/contrib/inference/paddle_inference_api.h"
namespace paddle {
PaddleBuf::PaddleBuf(PaddleBuf&& other)
: data_(other.data_),
length_(other.length_),
memory_owned_(other.memory_owned_) {
other.memory_owned_ = false;
other.data_ = nullptr;
other.length_ = 0;
}
// Copy constructor: the members start from their in-class defaults and the
// copy assignment operator then performs the actual copy.
PaddleBuf::PaddleBuf(const PaddleBuf& other) { this->operator=(other); }
// Copy assignment.
//
// - If `other` wraps external memory (or is empty), the pointer and length are
//   shared; no bytes are duplicated and this buffer does not own the memory.
// - If `other` owns its memory, the bytes are duplicated into a fresh
//   allocation owned by this buffer, so the two objects never release the same
//   pointer twice.
//
// Storage previously owned by this buffer is released. Self-assignment is a
// no-op. Returns *this.
PaddleBuf& PaddleBuf::operator=(const PaddleBuf& other) {
  if (this == &other) return *this;

  void* new_data = other.data_;
  if (other.memory_owned_ && other.data_ != nullptr) {
    // Allocate and fill the copy before touching our own state, so a throwing
    // `new` leaves this buffer unchanged.
    char* copy = new char[other.length_];
    const char* src = static_cast<const char*>(other.data_);
    for (size_t i = 0; i < other.length_; ++i) {
      copy[i] = src[i];
    }
    new_data = copy;
  }

  // Release whatever we owned before; the original code overwrote the pointer
  // and leaked it.
  Free();
  data_ = new_data;
  length_ = other.length_;
  memory_owned_ = other.memory_owned_;
  return *this;
}
// Makes this buffer hold exactly `length` bytes of owned memory.
//
// Nothing happens when the buffer already has that length. Otherwise the old
// storage is released and new storage is allocated; the previous contents are
// NOT preserved. Only buffers that own their memory may be resized; a buffer
// wrapping external memory triggers the assertion.
void PaddleBuf::Resize(size_t length) {
  if (length_ == length) return;
  // External memory belongs to the caller and cannot be reallocated here.
  assert(memory_owned_);
  Free();
  // A zero-length buffer is represented by a null pointer rather than a
  // zero-sized allocation, keeping "non-null owned data implies length > 0".
  data_ = (length > 0) ? new char[length] : nullptr;
  length_ = length;
  memory_owned_ = true;
}
// Points this buffer at caller-provided memory of `length` bytes. Storage
// previously owned by this buffer is released first; afterwards the buffer
// does not own `data` and will not release it.
void PaddleBuf::Reset(void* data, size_t length) {
  Free();
  data_ = data;
  length_ = length;
  memory_owned_ = false;
}
// Releases the storage if, and only if, this buffer owns it, and then marks
// the buffer as empty. Buffers that wrap external memory are left untouched.
void PaddleBuf::Free() {
  if (memory_owned_ && data_ != nullptr) {
    // Owned storage is always allocated with `new char[n]`, so it must be
    // released with the array form of delete; scalar `delete` is undefined
    // behaviour here. Zero-length owned allocations are released as well.
    delete[] static_cast<char*>(data_);
    data_ = nullptr;
    length_ = 0;
  }
}
} // namespace paddle
\ No newline at end of file
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#pragma once #pragma once
#include <cassert>
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -32,12 +33,38 @@ enum PaddleDType { ...@@ -32,12 +33,38 @@ enum PaddleDType {
INT64, INT64,
}; };
// PaddleBuf: a block of raw bytes described by a pointer and a byte count.
// The buffer either owns its memory (allocated by this class and released by
// Free()) or merely refers to memory managed by the caller; `memory_owned_`
// records which of the two applies.
struct PaddleBuf { class PaddleBuf {
void* data; // pointer to the data memory. public:
size_t length; // number of memory bytes. PaddleBuf() = default;
// Move: takes over the source's pointer, length and ownership flag and
// leaves the source empty and non-owning.
PaddleBuf(PaddleBuf&& other);
// Copy only available when memory is managed externally.
explicit PaddleBuf(const PaddleBuf&);
PaddleBuf& operator=(const PaddleBuf&);
// Do not own the memory.
// Wraps `length` bytes at `data`; the caller keeps responsibility for them.
PaddleBuf(void* data, size_t length)
: data_(data), length_(length), memory_owned_{false} {}
// Own memory.
// Allocates `length` bytes with new char[]. Not marked explicit, so a size_t
// converts implicitly to a PaddleBuf.
PaddleBuf(size_t length)
: data_(new char[length]), length_(length), memory_owned_(true) {}
// Resize to `length` bytes.
// No-op when the length already matches; otherwise the old storage is
// released and new storage allocated, so previous contents are not kept.
void Resize(size_t length);
// Reset to external memory.
// Releases owned storage first; afterwards the buffer is non-owning.
void Reset(void* data, size_t length);
// True when the buffer holds zero bytes.
bool empty() const { return length_ == 0; }
// Raw pointer to the first byte (nullptr for a default-constructed buffer).
void* data() const { return data_; }
// Size of the buffer in bytes.
size_t length() const { return length_; }
// Releases the memory if this buffer owns it.
~PaddleBuf() { Free(); }
private:
// Releases owned memory and marks the buffer empty; does nothing for
// external memory.
void Free();
void* data_{nullptr}; // pointer to the data memory.
size_t length_{0}; // number of memory bytes.
// Whether this object is responsible for releasing `data_`. Defaults to true,
// so a default-constructed (empty) buffer counts as owning.
bool memory_owned_{true};
}; };
struct PaddleTensor { struct PaddleTensor {
PaddleTensor() = default;
std::string name; // variable name. std::string name; // variable name.
std::vector<int> shape; std::vector<int> shape;
// TODO(Superjomn) for LoD support, add a vector<vector<int>> field if needed. // TODO(Superjomn) for LoD support, add a vector<vector<int>> field if needed.
...@@ -67,8 +94,9 @@ class PaddlePredictor { ...@@ -67,8 +94,9 @@ class PaddlePredictor {
// Predict a record. // Predict a record.
// The caller should be responsible for allocating and releasing the memory of // The caller should be responsible for allocating and releasing the memory of
// `inputs`. `inputs` should be alive until Run returns. caller should be // `inputs`. `inputs` should be available until Run returns. Caller should be
// responsible for releasing the memory of `output_data`. // responsible for the output tensor's buffer, either allocated or passed from
// outside.
virtual bool Run(const std::vector<PaddleTensor>& inputs, virtual bool Run(const std::vector<PaddleTensor>& inputs,
std::vector<PaddleTensor>* output_data) = 0; std::vector<PaddleTensor>* output_data) = 0;
......
...@@ -48,7 +48,7 @@ bool PaddleInferenceAnakinPredictor::Run( ...@@ -48,7 +48,7 @@ bool PaddleInferenceAnakinPredictor::Run(
auto d_tensor_in_p = executor_.get_in(input.name); auto d_tensor_in_p = executor_.get_in(input.name);
float *d_data_p = d_tensor_in_p->mutable_data(); float *d_data_p = d_tensor_in_p->mutable_data();
if (cudaMemcpy(d_data_p, if (cudaMemcpy(d_data_p,
static_cast<float *>(input.data.data), static_cast<float *>(input.data.data()),
d_tensor_in_p->valid_size() * sizeof(float), d_tensor_in_p->valid_size() * sizeof(float),
cudaMemcpyHostToDevice) != 0) { cudaMemcpyHostToDevice) != 0) {
LOG(ERROR) << "copy data from CPU to GPU error"; LOG(ERROR) << "copy data from CPU to GPU error";
...@@ -65,8 +65,11 @@ bool PaddleInferenceAnakinPredictor::Run( ...@@ -65,8 +65,11 @@ bool PaddleInferenceAnakinPredictor::Run(
for (auto &output : *output_data) { for (auto &output : *output_data) {
auto *tensor = executor_.get_out(output.name); auto *tensor = executor_.get_out(output.name);
output.shape = tensor->shape(); output.shape = tensor->shape();
if (output.data.length() < tensor->valid_size() * sizeof(float)) {
output.data.Resize(tensor->valid_size() * sizeof(float));
}
// Copy data from GPU -> CPU // Copy data from GPU -> CPU
if (cudaMemcpy(output.data.data, if (cudaMemcpy(output.data.data(),
tensor->mutable_data(), tensor->mutable_data(),
tensor->valid_size() * sizeof(float), tensor->valid_size() * sizeof(float),
cudaMemcpyDeviceToHost) != 0) { cudaMemcpyDeviceToHost) != 0) {
......
...@@ -37,28 +37,26 @@ TEST(inference, anakin) { ...@@ -37,28 +37,26 @@ TEST(inference, anakin) {
float data[1 * 3 * 224 * 224] = {1.0f}; float data[1 * 3 * 224 * 224] = {1.0f};
PaddleBuf buf{.data = data, .length = sizeof(data)};
PaddleTensor tensor{.name = "input_0", PaddleTensor tensor{.name = "input_0",
.shape = std::vector<int>({1, 3, 224, 224}), .shape = std::vector<int>({1, 3, 224, 224}),
.data = buf, .data = PaddleBuf(data, sizeof(data)),
.dtype = PaddleDType::FLOAT32}; .dtype = PaddleDType::FLOAT32};
// For simplicity, we set all the slots with the same data. // For simplicity, we set all the slots with the same data.
std::vector<PaddleTensor> paddle_tensor_feeds(1, tensor); std::vector<PaddleTensor> paddle_tensor_feeds;
paddle_tensor_feeds.emplace_back(std::move(tensor));
float data_out[1000];
PaddleBuf buf_out{.data = data_out, .length = sizeof(data)};
PaddleTensor tensor_out{.name = "prob_out", PaddleTensor tensor_out{.name = "prob_out",
.shape = std::vector<int>({1000, 1}), .shape = std::vector<int>({1000, 1}),
.data = buf_out, .data = PaddleBuf(),
.dtype = PaddleDType::FLOAT32}; .dtype = PaddleDType::FLOAT32};
std::vector<PaddleTensor> outputs(1, tensor_out); std::vector<PaddleTensor> outputs;
outputs.emplace_back(std::move(tensor_out));
ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs)); ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs));
float* data_o = static_cast<float*>(outputs[0].data.data); float* data_o = static_cast<float*>(outputs[0].data.data());
for (size_t j = 0; j < 1000; ++j) { for (size_t j = 0; j < 1000; ++j) {
LOG(INFO) << "output[" << j << "]: " << data_o[j]; LOG(INFO) << "output[" << j << "]: " << data_o[j];
} }
......
...@@ -178,8 +178,8 @@ bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs, ...@@ -178,8 +178,8 @@ bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
// TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy. // TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy.
std::memcpy(static_cast<void *>(input_ptr), std::memcpy(static_cast<void *>(input_ptr),
inputs[i].data.data, inputs[i].data.data(),
inputs[i].data.length); inputs[i].data.length());
feeds->push_back(input); feeds->push_back(input);
} }
return true; return true;
...@@ -241,10 +241,11 @@ bool NativePaddlePredictor::GetFetch( ...@@ -241,10 +241,11 @@ bool NativePaddlePredictor::GetFetch(
} }
outputs->at(i).shape = shape; outputs->at(i).shape = shape;
outputs->at(i).data.length = sizeof(float) * data.size(); auto &buffer = outputs->at(i).data;
outputs->at(i).data.data = malloc(outputs->at(i).data.length); if (buffer.empty() || buffer.length() < sizeof(float) * data.size()) {
std::memcpy( buffer.Resize(sizeof(float) * data.size());
outputs->at(i).data.data, data.data(), outputs->at(i).data.length); }
std::memcpy(buffer.data(), data.data(), buffer.length());
outputs->at(i).dtype = PaddleDType::FLOAT32; outputs->at(i).dtype = PaddleDType::FLOAT32;
// TODO(panyx0718): support other types? fill tensor name? avoid a copy. // TODO(panyx0718): support other types? fill tensor name? avoid a copy.
} }
......
...@@ -27,13 +27,12 @@ namespace paddle { ...@@ -27,13 +27,12 @@ namespace paddle {
PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) { PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) {
PaddleTensor pt; PaddleTensor pt;
pt.data.data = t->data<void>();
if (t->type() == typeid(int64_t)) { if (t->type() == typeid(int64_t)) {
pt.data.length = t->numel() * sizeof(int64_t); pt.data.Reset(t->data<void>(), t->numel() * sizeof(int64_t));
pt.dtype = PaddleDType::INT64; pt.dtype = PaddleDType::INT64;
} else if (t->type() == typeid(float)) { } else if (t->type() == typeid(float)) {
pt.data.length = t->numel() * sizeof(float); pt.data.Reset(t->data<void>(), t->numel() * sizeof(float));
pt.dtype = PaddleDType::FLOAT32; pt.dtype = PaddleDType::FLOAT32;
} else { } else {
LOG(FATAL) << "unsupported type."; LOG(FATAL) << "unsupported type.";
...@@ -79,8 +78,8 @@ void MainWord2Vec(bool use_gpu) { ...@@ -79,8 +78,8 @@ void MainWord2Vec(bool use_gpu) {
std::vector<PaddleTensor> outputs; std::vector<PaddleTensor> outputs;
ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs)); ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs));
ASSERT_EQ(outputs.size(), 1UL); ASSERT_EQ(outputs.size(), 1UL);
size_t len = outputs[0].data.length; size_t len = outputs[0].data.length();
float* data = static_cast<float*>(outputs[0].data.data); float* data = static_cast<float*>(outputs[0].data.data());
for (size_t j = 0; j < len / sizeof(float); ++j) { for (size_t j = 0; j < len / sizeof(float); ++j) {
ASSERT_LT(data[j], 1.0); ASSERT_LT(data[j], 1.0);
ASSERT_GT(data[j], -1.0); ASSERT_GT(data[j], -1.0);
...@@ -103,8 +102,6 @@ void MainWord2Vec(bool use_gpu) { ...@@ -103,8 +102,6 @@ void MainWord2Vec(bool use_gpu) {
EXPECT_LT(lod_data[i] - data[i], 1e-3); EXPECT_LT(lod_data[i] - data[i], 1e-3);
EXPECT_GT(lod_data[i] - data[i], -1e-3); EXPECT_GT(lod_data[i] - data[i], -1e-3);
} }
free(outputs[0].data.data);
} }
void MainImageClassification(bool use_gpu) { void MainImageClassification(bool use_gpu) {
...@@ -143,13 +140,12 @@ void MainImageClassification(bool use_gpu) { ...@@ -143,13 +140,12 @@ void MainImageClassification(bool use_gpu) {
std::vector<PaddleTensor> outputs; std::vector<PaddleTensor> outputs;
ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs)); ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs));
ASSERT_EQ(outputs.size(), 1UL); ASSERT_EQ(outputs.size(), 1UL);
size_t len = outputs[0].data.length; size_t len = outputs[0].data.length();
float* data = static_cast<float*>(outputs[0].data.data); float* data = static_cast<float*>(outputs[0].data.data());
float* lod_data = output1.data<float>(); float* lod_data = output1.data<float>();
for (size_t j = 0; j < len / sizeof(float); ++j) { for (size_t j = 0; j < len / sizeof(float); ++j) {
EXPECT_NEAR(lod_data[j], data[j], 1e-3); EXPECT_NEAR(lod_data[j], data[j], 1e-3);
} }
free(data);
} }
void MainThreadsWord2Vec(bool use_gpu) { void MainThreadsWord2Vec(bool use_gpu) {
...@@ -192,8 +188,8 @@ void MainThreadsWord2Vec(bool use_gpu) { ...@@ -192,8 +188,8 @@ void MainThreadsWord2Vec(bool use_gpu) {
// check outputs range // check outputs range
ASSERT_EQ(local_outputs.size(), 1UL); ASSERT_EQ(local_outputs.size(), 1UL);
const size_t len = local_outputs[0].data.length; const size_t len = local_outputs[0].data.length();
float* data = static_cast<float*>(local_outputs[0].data.data); float* data = static_cast<float*>(local_outputs[0].data.data());
for (size_t j = 0; j < len / sizeof(float); ++j) { for (size_t j = 0; j < len / sizeof(float); ++j) {
ASSERT_LT(data[j], 1.0); ASSERT_LT(data[j], 1.0);
ASSERT_GT(data[j], -1.0); ASSERT_GT(data[j], -1.0);
...@@ -205,7 +201,6 @@ void MainThreadsWord2Vec(bool use_gpu) { ...@@ -205,7 +201,6 @@ void MainThreadsWord2Vec(bool use_gpu) {
for (int i = 0; i < refs[tid].numel(); ++i) { for (int i = 0; i < refs[tid].numel(); ++i) {
EXPECT_NEAR(ref_data[i], data[i], 1e-3); EXPECT_NEAR(ref_data[i], data[i], 1e-3);
} }
free(data);
}); });
} }
for (int i = 0; i < num_jobs; ++i) { for (int i = 0; i < num_jobs; ++i) {
...@@ -251,14 +246,13 @@ void MainThreadsImageClassification(bool use_gpu) { ...@@ -251,14 +246,13 @@ void MainThreadsImageClassification(bool use_gpu) {
// check outputs correctness // check outputs correctness
ASSERT_EQ(local_outputs.size(), 1UL); ASSERT_EQ(local_outputs.size(), 1UL);
const size_t len = local_outputs[0].data.length; const size_t len = local_outputs[0].data.length();
float* data = static_cast<float*>(local_outputs[0].data.data); float* data = static_cast<float*>(local_outputs[0].data.data());
float* ref_data = refs[tid].data<float>(); float* ref_data = refs[tid].data<float>();
EXPECT_EQ(refs[tid].numel(), len / sizeof(float)); EXPECT_EQ(refs[tid].numel(), len / sizeof(float));
for (int i = 0; i < refs[tid].numel(); ++i) { for (int i = 0; i < refs[tid].numel(); ++i) {
EXPECT_NEAR(ref_data[i], data[i], 1e-3); EXPECT_NEAR(ref_data[i], data[i], 1e-3);
} }
free(data);
}); });
} }
for (int i = 0; i < num_jobs; ++i) { for (int i = 0; i < num_jobs; ++i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册