提交 fc94f765 编写于 作者: W wangjiawei04

replace feed_alias_name with feed_name

上级 eeded12a
......@@ -17,12 +17,13 @@
#include <iostream>
#include <memory>
#include <sstream>
#include <unordered_map>
#include <utility>
#include "core/cube/cube-api/include/cube_api.h"
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
#include "core/predictor/framework/resource.h"
#include "core/util/include/timer.h"
#include <utility>
namespace baidu {
namespace paddle_serving {
......@@ -56,7 +57,7 @@ int GeneralDistKVInferOp::inference() {
std::vector<rec::mcube::CubeValue> values;
int sparse_count = 0;
int dense_count = 0;
std::vector<std::pair<int64_t*, size_t>> dataptr_size_pairs;
std::vector<std::pair<int64_t *, size_t>> dataptr_size_pairs;
size_t key_len = 0;
for (size_t i = 0; i < in->size(); ++i) {
if (in->at(i).dtype != paddle::PaddleDType::INT64) {
......@@ -75,15 +76,16 @@ int GeneralDistKVInferOp::inference() {
keys.resize(key_len);
int key_idx = 0;
for (size_t i = 0; i < dataptr_size_pairs.size(); ++i) {
std::copy(dataptr_size_pairs[i].first, dataptr_size_pairs[i].first + dataptr_size_pairs[i].second, keys.begin() + key_idx);
key_idx += dataptr_size_pairs[i].second;
std::copy(dataptr_size_pairs[i].first,
dataptr_size_pairs[i].first + dataptr_size_pairs[i].second,
keys.begin() + key_idx);
key_idx += dataptr_size_pairs[i].second;
}
rec::mcube::CubeAPI *cube = rec::mcube::CubeAPI::instance();
// TODO: temp hard code "test_dict" here, fix this with next commit
// related to cube conf
std::vector<std::string> table_names = cube->get_table_names();
if (table_names.size() == 0) {
LOG(ERROR) << "cube init error or cube config not given.";
return -1;
}
int ret = cube->seek(table_names[0], keys, &values);
......@@ -121,7 +123,7 @@ int GeneralDistKVInferOp::inference() {
sparse_out[sparse_idx].shape.push_back(
sparse_out[sparse_idx].lod[0].back());
sparse_out[sparse_idx].shape.push_back(EMBEDDING_SIZE);
sparse_out[sparse_idx].name = model_config->_feed_alias_name[i];
sparse_out[sparse_idx].name = model_config->_feed_name[i];
sparse_out[sparse_idx].data.Resize(sparse_out[sparse_idx].lod[0].back() *
EMBEDDING_SIZE * sizeof(float));
float *dst_ptr = static_cast<float *>(sparse_out[sparse_idx].data.data());
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册