Commit fc94f765 authored by W wangjiawei04

replace feed_alias_name with feed_name

Parent eeded12a
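The diff below swaps the source of the output tensor name from the model config's alias names to its raw feed names. A minimal sketch of the distinction, assuming a simplified stand-in for the config: only the member names _feed_name and _feed_alias_name come from the diff, the struct and example values are hypothetical.

#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for the serving model config: each feed variable has a
// raw feed name and an alias name, stored in parallel vectors.
struct ModelConfigSketch {
  std::vector<std::string> _feed_name;        // e.g. "embedding_0.tmp_0" (made-up value)
  std::vector<std::string> _feed_alias_name;  // e.g. "sparse_0" (made-up value)
};

int main() {
  ModelConfigSketch cfg;
  cfg._feed_name = {"embedding_0.tmp_0"};
  cfg._feed_alias_name = {"sparse_0"};

  // Before this commit the sparse output tensor was named by the alias;
  // after it, by the raw feed name.
  std::cout << "old tensor name (alias): " << cfg._feed_alias_name[0] << "\n";
  std::cout << "new tensor name (feed):  " << cfg._feed_name[0] << "\n";
  return 0;
}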
... @@ -17,12 +17,13 @@
 #include <iostream>
 #include <memory>
 #include <sstream>
+#include <unordered_map>
+#include <utility>
 #include "core/cube/cube-api/include/cube_api.h"
 #include "core/predictor/framework/infer.h"
 #include "core/predictor/framework/memory.h"
 #include "core/predictor/framework/resource.h"
 #include "core/util/include/timer.h"
-#include <utility>
 namespace baidu {
 namespace paddle_serving {
... @@ -56,7 +57,7 @@ int GeneralDistKVInferOp::inference() {
   std::vector<rec::mcube::CubeValue> values;
   int sparse_count = 0;
   int dense_count = 0;
-  std::vector<std::pair<int64_t*, size_t>> dataptr_size_pairs;
+  std::vector<std::pair<int64_t *, size_t>> dataptr_size_pairs;
   size_t key_len = 0;
   for (size_t i = 0; i < in->size(); ++i) {
     if (in->at(i).dtype != paddle::PaddleDType::INT64) {
... @@ -75,15 +76,16 @@ int GeneralDistKVInferOp::inference() {
   keys.resize(key_len);
   int key_idx = 0;
   for (size_t i = 0; i < dataptr_size_pairs.size(); ++i) {
-    std::copy(dataptr_size_pairs[i].first, dataptr_size_pairs[i].first + dataptr_size_pairs[i].second, keys.begin() + key_idx);
-    key_idx += dataptr_size_pairs[i].second;
+    std::copy(dataptr_size_pairs[i].first,
+              dataptr_size_pairs[i].first + dataptr_size_pairs[i].second,
+              keys.begin() + key_idx);
+    key_idx += dataptr_size_pairs[i].second;
   }
   rec::mcube::CubeAPI *cube = rec::mcube::CubeAPI::instance();
-  // TODO: temp hard code "test_dict" here, fix this with next commit
-  // related to cube conf
   std::vector<std::string> table_names = cube->get_table_names();
   if (table_names.size() == 0) {
     LOG(ERROR) << "cube init error or cube config not given.";
+    return -1;
   }
   int ret = cube->seek(table_names[0], keys, &values);
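The reworked loop above only reflows the existing std::copy call across lines; the logic, flattening each input tensor's INT64 ids into one contiguous key buffer before the cube->seek lookup, is unchanged. A standalone sketch of that flattening pattern (the container and key types here are illustrative, not necessarily the op's actual ones):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  // Two input "tensors" of INT64 ids, tracked as (data pointer, element count)
  // pairs, mirroring dataptr_size_pairs in the diff above.
  std::vector<int64_t> t0 = {101, 102, 103};
  std::vector<int64_t> t1 = {201, 202};
  std::vector<std::pair<int64_t *, size_t>> dataptr_size_pairs = {
      {t0.data(), t0.size()}, {t1.data(), t1.size()}};

  // Flatten every id into one contiguous buffer so a single batched lookup
  // (cube->seek in the real op) can be issued.
  size_t key_len = t0.size() + t1.size();
  std::vector<int64_t> keys(key_len);
  size_t key_idx = 0;
  for (size_t i = 0; i < dataptr_size_pairs.size(); ++i) {
    std::copy(dataptr_size_pairs[i].first,
              dataptr_size_pairs[i].first + dataptr_size_pairs[i].second,
              keys.begin() + key_idx);
    key_idx += dataptr_size_pairs[i].second;
  }

  for (int64_t k : keys) std::cout << k << " ";  // prints: 101 102 103 201 202
  std::cout << "\n";
  return 0;
}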
... @@ -121,7 +123,7 @@ int GeneralDistKVInferOp::inference() {
     sparse_out[sparse_idx].shape.push_back(
         sparse_out[sparse_idx].lod[0].back());
     sparse_out[sparse_idx].shape.push_back(EMBEDDING_SIZE);
-    sparse_out[sparse_idx].name = model_config->_feed_alias_name[i];
+    sparse_out[sparse_idx].name = model_config->_feed_name[i];
     sparse_out[sparse_idx].data.Resize(sparse_out[sparse_idx].lod[0].back() *
                                        EMBEDDING_SIZE * sizeof(float));
     float *dst_ptr = static_cast<float *>(sparse_out[sparse_idx].data.data());
......
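The only functional change in this last hunk is the tensor name source (_feed_alias_name to _feed_name); the surrounding context shows how the sparse output buffer is sized from the level-0 LoD and the embedding width. A small arithmetic sketch of that sizing, with assumed placeholder values for EMBEDDING_SIZE and the LoD offsets:

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Assumed values for illustration only; not taken from the serving config.
  const int EMBEDDING_SIZE = 9;
  std::vector<size_t> lod0 = {0, 3, 5, 8};  // level-0 offsets; back() = total ids

  size_t total_ids = lod0.back();
  size_t bytes = total_ids * EMBEDDING_SIZE * sizeof(float);

  // Mirrors the sizing in the diff's context: shape {total_ids, EMBEDDING_SIZE}
  // and a float buffer of total_ids * EMBEDDING_SIZE elements.
  std::cout << "shape: {" << total_ids << ", " << EMBEDDING_SIZE << "}"
            << ", buffer bytes: " << bytes << "\n";
  return 0;
}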
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
......