Commit 1e1202b6 authored by wanghaox

merge detection.py

@@ -60,7 +60,7 @@ option(USE_NNPACK "Compile PaddlePaddle with NNPACK library" OFF)
 option(WITH_DISTRIBUTE "Compile with grpc distributed support" OFF)
 option(USE_EIGEN_FOR_BLAS "Use matrix multiplication in Eigen" OFF)
 option(WITH_ARM_FP16 "Use half precision support on armv8.2-a cpu" OFF)
-option(WITH_FAST_BUNDLE_TEST "Bundle tests that can be run in a single process together to reduce launch overhead" ON)
+option(WITH_FAST_BUNDLE_TEST "Bundle tests that can be run in a single process together to reduce launch overhead" OFF)
 # CMAKE_BUILD_TYPE
 if(NOT CMAKE_BUILD_TYPE)
...
@@ -65,6 +65,7 @@
 output_file = "output.paddle.model"
 merge_v2_model(net, param_file, output_file)
 ```
 For the [handwritten digit recognition](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/capi/examples/model_inference/dense) example, you can run `python` [merge_v2_model.py](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/capi/examples/model_inference/dense/merge_v2_model.py) directly. The serialized result is written to the file `output.paddle.model` in the current working directory. With this approach, the C-API can load the inference model at runtime by pointing at the path of `output.paddle.model`.
 #### Notes
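To make the merge step concrete, here is a minimal sketch. The import path `paddle.utils.merge_model` and the `mnist_v2` config module are assumptions drawn from the linked example, not something this commit verifies:

```python
# A sketch of merging a v2 topology and its trained parameters into one
# file for the C-API; names and paths below are illustrative assumptions.
from paddle.utils.merge_model import merge_v2_model
from mnist_v2 import network  # hypothetical config module defining the net

net = network(is_infer=True)           # inference topology
param_file = './params_pass_4.tar'     # trained parameters (illustrative path)
output_file = 'output.paddle.model'    # file the runtime C-API will load
merge_v2_model(net, param_file, output_file)
```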
...
@@ -58,13 +58,13 @@ static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
     var->GetMutable<ReaderHolder>();
   } else if (var_type == proto::VarType::CHANNEL) {
     var->GetMutable<ChannelHolder>();
-  } else if (var_type == proto::VarType::NCCL_COM) {
-    // GetMutable will be called in ncclInit
+  } else if (var_type == proto::VarType::RAW) {
+    // GetMutable will be called in operator
   } else {
     PADDLE_THROW(
         "Variable type %d is not in "
         "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
-        "LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, NCCL_COM]",
+        "LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, RAW]",
         var_type);
   }
 }
...
@@ -113,7 +113,10 @@ message VarType {
     PLACE_LIST = 14;
     READER = 15;
     CHANNEL = 16;
-    NCCL_COM = 17;
+    // Any runtime decided variable type is raw
+    // raw variables should manage their own allocations
+    // in operators like nccl_op
+    RAW = 17;
   }
   required Type type = 1;
...
@@ -31,8 +31,14 @@ std::ostream &operator<<(std::ostream &os, const LoD &lod) {
   os << "{";
   for (auto &v : lod) {
     os << "{";
+    bool is_first = true;
     for (auto &i : v) {
-      os << i << ",";
+      if (is_first) {
+        os << i;
+        is_first = false;
+      } else {
+        os << ", " << i;
+      }
     }
     os << "}";
   }
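The change above replaces a trailing comma after every element with a `", "` separator between elements. A quick Python sketch of the same logic (not Paddle code) shows the effect:

```python
# Emit the first element bare, prefix every later one with ", ".
def format_level(level):
    out = "{"
    for idx, i in enumerate(level):
        out += str(i) if idx == 0 else ", " + str(i)
    return out + "}"

lod = [[0, 4, 10]]
print("{" + "".join(format_level(v) for v in lod) + "}")
# before the change: {{0,4,10,}}    after: {{0, 4, 10}}
```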
...
@@ -32,23 +32,11 @@ void ReadBinaryFile(const std::string& filename, std::string& contents) {
   inputfs.close();
 }
-bool IsParameter(const framework::VarDesc* var,
-                 const framework::ProgramDesc& main_program) {
-  if (var->Persistable()) {
-    // There are many unreachable variables in the program
-    for (size_t i = 0; i < main_program.Size(); ++i) {
-      const framework::BlockDesc& block = main_program.Block(i);
-      for (auto* op : block.AllOps()) {
-        if (op->Type() == framework::kFeedOpType) {
-          continue;
-        }
-        for (auto input_argument_name : op->InputArgumentNames()) {
-          if (input_argument_name == var->Name()) {
-            return true;
-          }
-        }
-      }
-    }
+bool IsPersistable(const framework::VarDesc* var) {
+  if (var->Persistable() &&
+      var->GetType() != framework::proto::VarType::FEED_MINIBATCH &&
+      var->GetType() != framework::proto::VarType::FETCH_LIST) {
+    return true;
   }
   return false;
 }
@@ -65,8 +65,8 @@ void LoadPersistables(framework::Executor& executor,
   std::vector<std::string> paramlist;
   for (auto* var : global_block.AllVars()) {
-    if (IsParameter(var, main_program)) {
-      VLOG(3) << "parameter's name: " << var->Name();
+    if (IsPersistable(var)) {
+      VLOG(3) << "persistable variable's name: " << var->Name();
       framework::VarDesc* new_var = load_block->Var(var->Name());
       new_var->SetShape(var->GetShape());
@@ -101,7 +89,6 @@ void LoadPersistables(framework::Executor& executor,
   executor.Run(*load_program, &scope, 0, true, true);
-  VLOG(3) << "Ran loading successfully";
   delete load_program;
 }
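The new C++ predicate brings inference loading in line with the notion of a persistable variable used on the Python side. A rough Python equivalent (not part of this commit, shown only to illustrate the rule; `var` is a fluid `Variable`) would be:

```python
import paddle.fluid.core as core

def is_persistable(var):
    # Feed and fetch variables are persistable but are runtime plumbing,
    # not model state, so they are excluded just like in the C++ code.
    if var.desc.type() in (core.VarDesc.VarType.FEED_MINIBATCH,
                           core.VarDesc.VarType.FETCH_LIST):
        return False
    return var.persistable
```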
...
@@ -30,5 +30,5 @@ inference_test(label_semantic_roles)
 inference_test(recognize_digits ARGS mlp conv)
 inference_test(recommender_system)
 #inference_test(rnn_encoder_decoder)
-inference_test(understand_sentiment)
+inference_test(understand_sentiment ARGS conv)
 inference_test(word2vec)
@@ -32,16 +32,42 @@ TEST(inference, label_semantic_roles) {
   paddle::framework::LoDTensor word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1,
       ctx_p2, mark;
   paddle::framework::LoD lod{{0, 4, 10}};
+  int64_t word_dict_len = 44068;
+  int64_t predicate_dict_len = 3162;
+  int64_t mark_dict_len = 2;
-  SetupLoDTensor(word, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
-  SetupLoDTensor(
-      predicate, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
-  SetupLoDTensor(ctx_n2, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
-  SetupLoDTensor(ctx_n1, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
-  SetupLoDTensor(ctx_0, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
-  SetupLoDTensor(ctx_p1, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
-  SetupLoDTensor(ctx_p2, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
-  SetupLoDTensor(mark, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
+  SetupLoDTensor(word,
+                 lod,
+                 static_cast<int64_t>(0),
+                 static_cast<int64_t>(word_dict_len - 1));
+  SetupLoDTensor(predicate,
+                 lod,
+                 static_cast<int64_t>(0),
+                 static_cast<int64_t>(predicate_dict_len - 1));
+  SetupLoDTensor(ctx_n2,
+                 lod,
+                 static_cast<int64_t>(0),
+                 static_cast<int64_t>(word_dict_len - 1));
+  SetupLoDTensor(ctx_n1,
+                 lod,
+                 static_cast<int64_t>(0),
+                 static_cast<int64_t>(word_dict_len - 1));
+  SetupLoDTensor(ctx_0,
+                 lod,
+                 static_cast<int64_t>(0),
+                 static_cast<int64_t>(word_dict_len - 1));
+  SetupLoDTensor(ctx_p1,
+                 lod,
+                 static_cast<int64_t>(0),
+                 static_cast<int64_t>(word_dict_len - 1));
+  SetupLoDTensor(ctx_p2,
+                 lod,
+                 static_cast<int64_t>(0),
+                 static_cast<int64_t>(word_dict_len - 1));
+  SetupLoDTensor(mark,
+                 lod,
+                 static_cast<int64_t>(0),
+                 static_cast<int64_t>(mark_dict_len - 1));
   std::vector<paddle::framework::LoDTensor*> cpu_feeds;
   cpu_feeds.push_back(&word);
...
@@ -31,7 +31,12 @@ TEST(inference, understand_sentiment) {
   paddle::framework::LoDTensor words;
   paddle::framework::LoD lod{{0, 4, 10}};
-  SetupLoDTensor(words, lod, static_cast<int64_t>(0), static_cast<int64_t>(10));
+  int64_t word_dict_len = 5147;
+  SetupLoDTensor(words,
+                 lod,
+                 static_cast<int64_t>(0),
+                 static_cast<int64_t>(word_dict_len - 1));
   std::vector<paddle::framework::LoDTensor*> cpu_feeds;
   cpu_feeds.push_back(&words);
...
@@ -31,12 +31,12 @@ TEST(inference, word2vec) {
   paddle::framework::LoDTensor first_word, second_word, third_word, fourth_word;
   paddle::framework::LoD lod{{0, 1}};
-  int64_t dict_size = 2072;  // Hard-coding the size of dictionary
+  int64_t dict_size = 2073;  // The size of dictionary
-  SetupLoDTensor(first_word, lod, static_cast<int64_t>(0), dict_size);
-  SetupLoDTensor(second_word, lod, static_cast<int64_t>(0), dict_size);
-  SetupLoDTensor(third_word, lod, static_cast<int64_t>(0), dict_size);
-  SetupLoDTensor(fourth_word, lod, static_cast<int64_t>(0), dict_size);
+  SetupLoDTensor(first_word, lod, static_cast<int64_t>(0), dict_size - 1);
+  SetupLoDTensor(second_word, lod, static_cast<int64_t>(0), dict_size - 1);
+  SetupLoDTensor(third_word, lod, static_cast<int64_t>(0), dict_size - 1);
+  SetupLoDTensor(fourth_word, lod, static_cast<int64_t>(0), dict_size - 1);
   std::vector<paddle::framework::LoDTensor*> cpu_feeds;
   cpu_feeds.push_back(&first_word);
...
@@ -101,8 +101,8 @@ void TestInference(const std::string& dirname,
   if (IsCombined) {
     // All parameters are saved in a single file.
     // Hard-coding the file names of program and parameters in unittest.
-    // Users are free to specify different filename
-    // (provided: the filenames are changed in the python api as well: io.py)
+    // The file names should be consistent with that used in Python API
+    // `fluid.io.save_inference_model`.
     std::string prog_filename = "__model_combined__";
     std::string param_filename = "__params_combined__";
     inference_program = paddle::inference::Load(executor,
...
@@ -11,6 +11,8 @@ function(op_library TARGET)
   set(cc_srcs)
   set(cu_srcs)
   set(cu_cc_srcs)
+  set(cudnn_cu_cc_srcs)
+  set(CUDNN_FILE)
   set(op_common_deps operator op_registry math_function)
   set(options "")
   set(oneValueArgs "")
@@ -30,10 +32,16 @@ function(op_library TARGET)
     if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu)
       list(APPEND cu_srcs ${TARGET}.cu)
     endif()
+    string(REPLACE "_op" "_cudnn_op" CUDNN_FILE "${TARGET}")
+    if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${CUDNN_FILE}.cu.cc)
+      list(APPEND cudnn_cu_cc_srcs ${CUDNN_FILE}.cu.cc)
+    endif()
   else()
     foreach(src ${op_library_SRCS})
       if (${src} MATCHES ".*\\.cu$")
         list(APPEND cu_srcs ${src})
+      elseif(${src} MATCHES ".*_cudnn_op.cu.cc$")
+        list(APPEND cudnn_cu_cc_srcs ${src})
       elseif(${src} MATCHES ".*\\.cu.cc$")
         list(APPEND cu_cc_srcs ${src})
       elseif(${src} MATCHES ".*\\.cc$")
@@ -54,7 +62,7 @@ function(op_library TARGET)
     set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE)
   endif()
   if (WITH_GPU)
-    nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
+    nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
       ${op_common_deps})
   else()
     cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${op_library_DEPS}
@@ -98,6 +106,12 @@ function(op_library TARGET)
     set(pybind_flag 1)
   endif()
+  # pybind USE_OP_DEVICE_KERNEL for CUDNN
+  list(LENGTH cudnn_cu_cc_srcs cudnn_cu_cc_srcs_len)
+  if (WITH_GPU AND ${cudnn_cu_cc_srcs_len} GREATER 0)
+    file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n")
+  endif()
   # pybind USE_OP
   if (${pybind_flag} EQUAL 0)
     file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
@@ -152,43 +166,24 @@ op_library(lstm_op DEPS sequence2batch lstm_compute)
 op_library(lstmp_op DEPS sequence2batch lstm_compute)
 op_library(gru_op DEPS sequence2batch gru_compute)
 op_library(recurrent_op DEPS executor)
-op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale math_function)
+op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
 op_library(cos_sim_op DEPS cos_sim_functor)
 op_library(parallel_do_op DEPS executor)
 op_library(create_reader_op DEPS reader)
 # Register multiple kernels to pybind
 if (WITH_GPU)
-  op_library(conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS
-    vol2col depthwise_conv)
-  op_library(edit_distance_op SRCS edit_distance_op.cc edit_distance_op.cu DEPS math_function)
-  op_library(pool_op SRCS pool_op.cc pool_op.cu.cc pool_cudnn_op.cu.cc DEPS pooling)
-  op_library(conv_transpose_op SRCS conv_transpose_op.cc conv_transpose_op.cu.cc
-    conv_transpose_cudnn_op.cu.cc DEPS vol2col)
-  file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(conv2d, CUDNN);\n")
-  file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(pool2d, CUDNN);\n")
-  file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(conv2d_transpose, CUDNN);\n")
+  op_library(conv_op DEPS vol2col depthwise_conv)
 else()
-  op_library(conv_op SRCS conv_op.cc DEPS vol2col)
-  op_library(pool_op SRCS pool_op.cc DEPS pooling)
-  op_library(conv_transpose_op SRCS conv_transpose_op.cc DEPS vol2col)
+  op_library(conv_op DEPS vol2col)
 endif()
+op_library(pool_op DEPS pooling)
+op_library(conv_transpose_op DEPS vol2col)
 cc_library(batch_size_like SRCS batch_size_like.cc DEPS op_registry)
-op_library(fill_constant_batch_size_like_op
-  SRCS fill_constant_batch_size_like_op.cc fill_constant_batch_size_like_op.cu.cc
-  DEPS batch_size_like)
-op_library(uniform_random_batch_size_like_op
-  SRCS uniform_random_batch_size_like_op.cc
-  DEPS batch_size_like uniform_random_op)
-op_library(gaussian_random_batch_size_like_op
-  SRCS gaussian_random_batch_size_like_op.cc
-  DEPS batch_size_like gaussian_random_op)
+op_library(fill_constant_batch_size_like_op DEPS batch_size_like)
+op_library(uniform_random_batch_size_like_op DEPS batch_size_like uniform_random_op)
+op_library(gaussian_random_batch_size_like_op DEPS batch_size_like gaussian_random_op)
 # FIXME(typhoonzero): save/load depends lodtensor serialization functions
 op_library(save_op DEPS lod_tensor)
...
@@ -94,6 +94,38 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
     }
   }
+  void ArgMaxMatch(const Tensor& dist, int* match_indices, T* match_dist,
+                   T overlap_threshold) const {
+    constexpr T kEPS = static_cast<T>(1e-6);
+    int64_t row = dist.dims()[0];
+    int64_t col = dist.dims()[1];
+    auto* dist_data = dist.data<T>();
+    for (int64_t j = 0; j < col; ++j) {
+      if (match_indices[j] != -1) {
+        // the j-th column has been matched to one entity.
+        continue;
+      }
+      int max_row_idx = -1;
+      T max_dist = -1;
+      for (int i = 0; i < row; ++i) {
+        T dist = dist_data[i * col + j];
+        if (dist < kEPS) {
+          // distance is 0 between i-th row and j-th column
+          continue;
+        }
+        if (dist >= overlap_threshold && dist > max_dist) {
+          max_row_idx = i;
+          max_dist = dist;
+        }
+      }
+      if (max_row_idx != -1) {
+        PADDLE_ENFORCE_EQ(match_indices[j], -1);
+        match_indices[j] = max_row_idx;
+        match_dist[j] = max_dist;
+      }
+    }
+  }
   void Compute(const framework::ExecutionContext& context) const override {
     auto* dist_mat = context.Input<LoDTensor>("DistMat");
     auto* match_indices = context.Output<Tensor>("ColToRowMatchIndices");
@@ -120,13 +152,21 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
     int* indices = match_indices->data<int>();
     T* dist = match_dist->data<T>();
+    auto type = context.Attr<std::string>("match_type");
+    auto threshold = context.Attr<float>("dist_threshold");
     if (n == 1) {
       BipartiteMatch(*dist_mat, indices, dist);
+      if (type == "per_prediction") {
+        ArgMaxMatch(*dist_mat, indices, dist, threshold);
+      }
     } else {
       auto lod = dist_mat->lod().back();
       for (size_t i = 0; i < lod.size() - 1; ++i) {
         Tensor one_ins = dist_mat->Slice(lod[i], lod[i + 1]);
         BipartiteMatch(one_ins, indices + i * col, dist + i * col);
+        if (type == "per_prediction") {
+          ArgMaxMatch(one_ins, indices + i * col, dist + i * col, threshold);
+        }
       }
     }
   }
@@ -147,6 +187,19 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
              "This tensor can contain LoD information to represent a batch of "
              "inputs. One instance of this batch can contain different numbers of "
              "entities.");
+    AddAttr<std::string>(
+        "match_type",
+        "(string, default: bipartite) "
+        "The type of matching method, should be 'bipartite' or "
+        "'per_prediction', 'bipartite' by default.")
+        .SetDefault("bipartite")
+        .InEnum({"bipartite", "per_prediction"});
+    AddAttr<float>(
+        "dist_threshold",
+        "(float, default: 0.5) "
+        "If `match_type` is 'per_prediction', this threshold is to determine "
+        "the extra matching bboxes based on the maximum distance.")
+        .SetDefault(0.5);
     AddOutput("ColToRowMatchIndices",
               "(Tensor) A 2-D Tensor with shape [N, M] in int type. "
               "N is the batch size. If ColToRowMatchIndices[i][j] is -1, it "
@@ -168,10 +221,10 @@ distance matrix. For input 2D matrix, the bipartite matching algorithm can
 find the matched column for each row, also can find the matched row for
 each column. And this operator only calculate matched indices from column
 to row. For each instance, the number of matched indices is the number of
-of columns of the input ditance matrix.
+columns of the input distance matrix.
 There are two outputs to save matched indices and distance.
-A simple description, this algothrim matched the best (maximum distance)
+A simple description, this algorithm matched the best (maximum distance)
 row entity to the column entity and the matched indices are not duplicated
 in each row of ColToRowMatchIndices. If the column entity is not matched
 any row entity, set -1 in ColToRowMatchIndices.
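For reference, the `per_prediction` pass added above reduces to a few lines of numpy (a sketch only, ignoring the `kEPS` near-zero guard): after the bipartite pass, each still-unmatched column is assigned the row with the maximum distance, provided that distance reaches the threshold.

```python
import numpy as np

def argmax_match(dist, match_indices, match_dist, overlap_threshold):
    rows, cols = dist.shape
    for j in range(cols):
        if match_indices[j] != -1:
            continue  # column already matched by the bipartite pass
        i = int(np.argmax(dist[:, j]))  # best row for this column
        if dist[i, j] >= overlap_threshold:
            match_indices[j] = i
            match_dist[j] = dist[i, j]
```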
...
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once
+#include <utility>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/strided_memcpy.h"
@@ -34,12 +35,46 @@ class ConcatKernel : public framework::OpKernel<T> {
     auto out_stride = framework::stride_numel(out->dims());
     size_t output_offset = 0;
-    for (auto* in : ins) {
-      auto in_stride = framework::stride_numel(in->dims());
-      StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
-                                  out->data<T>() + output_offset, out_stride,
-                                  in->data<T>(), in_stride, in_stride[axis]);
-      output_offset += in_stride[axis];
+    // If axis >= 1, copying straight to the output would require many CUDA
+    // memcpy calls. Copy the inputs to the CPU, do the strided copy there,
+    // then copy the result to the GPU output.
+    if (platform::is_gpu_place(place) && axis >= 1) {
+      platform::CPUPlace copy_place;
+      auto& cpu_ctx = *platform::DeviceContextPool::Instance().Get(copy_place);
+      framework::Tensor cpu_out;
+      cpu_out.Resize(out->dims());
+      cpu_out.mutable_data<T>(copy_place);
+      auto& dev_ctx = ctx.device_context();
+      std::vector<std::unique_ptr<framework::Tensor>> cpu_ins;
+      for (auto* in : ins) {
+        std::unique_ptr<framework::Tensor> cpu_in(new framework::Tensor);
+        framework::TensorCopy(*in, copy_place, dev_ctx, cpu_in.get());
+        cpu_ins.emplace_back(std::move(cpu_in));
+      }
+      // TODO(dzhwinter): overlap copy and compute stream
+      // https://devblogs.nvidia.com/how-overlap-data-transfers-cuda-cc/
+      dev_ctx.Wait();
+      for (auto& in : cpu_ins) {
+        auto& cpu_in = *in.get();
+        auto in_stride = framework::stride_numel(cpu_in.dims());
+        StridedNumelCopyWithAxis<T>(
+            cpu_ctx, axis, cpu_out.data<T>() + output_offset, out_stride,
+            cpu_in.data<T>(), in_stride, in_stride[axis]);
+        output_offset += in_stride[axis];
+      }
+      framework::TensorCopy(cpu_out, place, dev_ctx, out);
+    } else {
+      for (auto* in : ins) {
+        auto in_stride = framework::stride_numel(in->dims());
+        StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
+                                    out->data<T>() + output_offset, out_stride,
+                                    in->data<T>(), in_stride, in_stride[axis]);
+        output_offset += in_stride[axis];
+      }
     }
   }
 };
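The rationale for the new branch: for `axis >= 1` each input contributes one small strided copy per output row-block, so a direct GPU concat issues many tiny memcpys. Staging through the host trades them for N device-to-host copies, one CPU-side concat, and a single host-to-device copy. A numpy sketch of the staging idea only (the real kernel operates on framework tensors):

```python
import numpy as np

def concat_via_host(device_arrays, axis):
    host_inputs = [np.asarray(a) for a in device_arrays]  # device -> host copies
    host_out = np.concatenate(host_inputs, axis=axis)     # strided copies on the CPU
    return host_out  # the kernel then issues one host -> device copy of the result
```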
...
@@ -177,8 +177,8 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
   args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
   args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
-  auto ch = std::shared_ptr<grpc::Channel>(
-      grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args));
+  auto ch =
+      grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
   channels_[ep] = ch;
   return ch;
...
@@ -129,6 +129,8 @@ class ListenAndServOp : public framework::OperatorBase {
       }
       if (exit_flag) {
         rpc_service_->ShutDown();
+        rpc_service_->SetCond(1);
+        break;
       }
       try {
         executor.Run(*program, &recv_scope, block->ID(), /*global_block*/
...
@@ -65,7 +65,7 @@ class NCCLInitOpVarTypeInference : public framework::VarTypeInference {
                   framework::BlockDesc *block) const override {
     auto out_var_name = op_desc.Output("Communicator").front();
     auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
-    auto var_type = framework::proto::VarType::NCCL_COM;
+    auto var_type = framework::proto::VarType::RAW;
     out_var.SetType(var_type);
   }
 };
...
@@ -121,9 +121,27 @@ This operator will send tensor to recv_op at the parameter server.
   }
 };
+class SendOpVarTypeInference : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const override {
+    auto out_var_name = op_desc.Output("RPCClient").front();
+    auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
+    auto var_type = framework::proto::VarType::RAW;
+    out_var.SetType(var_type);
+  }
+};
+
+class SendOpShapeInference : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext* ctx) const override {}
+};
+
 }  // namespace operators
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(send, ops::SendOp, ops::SendOpMaker);
+REGISTER_OPERATOR(send, ops::SendOp, paddle::framework::EmptyGradOpMaker,
+                  ops::SendOpMaker, ops::SendOpVarTypeInference,
+                  ops::SendOpShapeInference);
@@ -95,7 +95,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
   for (auto kv : outputs) {
     for (auto v : kv.second) {
       auto var = block->Var(v);
-      var->SetDataType(f::proto::DataType::FP32);
+      var->SetDataType(f::proto::VarType::FP32);
     }
   }
@@ -122,33 +122,37 @@ void StartServerNet(bool is_sparse) {
   // sub program run in listen_and_serv_op, for simple test we use sum
   f::ProgramDesc program;
-  f::BlockDesc *block = program.MutableBlock(0);
+  f::BlockDesc *optimize_block = program.MutableBlock(0);
   // X for server side tensors, RX for received tensors, must be of same shape.
-  AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, block);
+  AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block);
   f::AttributeMap attrs;
   attrs.insert({"endpoint", std::string("127.0.0.1:6174")});
+  attrs.insert({"Fanin", 1});
   attrs.insert({"ParamList", std::vector<std::string>({"Out"})});
   attrs.insert({"GradList", std::vector<std::string>({"x1"})});
-  attrs.insert({"OptimizeBlock", block});
+  attrs.insert({"OptimizeBlock", optimize_block});
   listen_and_serv_op =
-      f::OpRegistry::CreateOp("listen_and_serv", {}, {}, attrs);
+      f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs);
   listen_and_serv_op->Run(scope, place);
 }
 TEST(SendRecvOp, CPUDense) {
   std::thread server_thread(StartServerNet, false);
-  sleep(10);  // wait server to start
+  sleep(5);  // wait server to start
   // local net
   f::Scope scope;
   p::CPUPlace place;
   InitTensorsInScope(scope, place);
+  // create rpc client var
+  scope.Var("RPC_CLIENT_VAR");
   f::AttributeMap attrs;
   attrs.insert({"endpoints", std::vector<std::string>({"127.0.0.1:6174"})});
   attrs.insert({"epmap", std::vector<std::string>({"127.0.0.1:6174"})});
-  auto send_op = f::OpRegistry::CreateOp("send", {{"X", {"x1"}}},
-                                         {{"Out", {"Out"}}}, attrs);
+  auto send_op = f::OpRegistry::CreateOp(
+      "send", {{"X", {"x1"}}},
+      {{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs);
   send_op->Run(scope, place);
   auto in_var = scope.Var("x1");
@@ -175,11 +179,13 @@ TEST(SendRecvOp, CPUSparse) {
   p::CPUPlace place;
   p::CPUDeviceContext ctx(place);
   InitSelectedRowsInScope(scope, place);
+  scope.Var("RPC_CLIENT_VAR");
   f::AttributeMap attrs;
   attrs.insert({"endpoints", std::vector<std::string>({"127.0.0.1:6174"})});
   attrs.insert({"epmap", std::vector<std::string>({"127.0.0.1:6174"})});
-  auto send_op = f::OpRegistry::CreateOp("send", {{"X", {"x1"}}},
-                                         {{"Out", {"Out"}}}, attrs);
+  auto send_op = f::OpRegistry::CreateOp(
+      "send", {{"X", {"x1"}}},
+      {{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs);
   send_op->Run(scope, place);
   auto x0 = scope.Var("x0")->GetMutable<f::SelectedRows>();
...
@@ -252,7 +252,7 @@ void BindVarDsec(py::module &m) {
       .value("CHANNEL", proto::VarType::CHANNEL)
       .value("PLACE_LIST", proto::VarType::PLACE_LIST)
       .value("READER", proto::VarType::READER)
-      .value("NCCL_COM", proto::VarType::NCCL_COM);
+      .value("RAW", proto::VarType::RAW);
 }
 void BindOpDesc(py::module &m) {
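On the Python side the renamed enum value is visible immediately; a minimal sketch, assuming the usual module path for this release:

```python
import paddle.fluid.core as core

# RAW marks variables whose concrete type is decided at runtime
# (e.g. the NCCL communicator created by nccl_init).
print(core.VarDesc.VarType.RAW)
```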
...
@@ -49,6 +49,7 @@ function cmake_gen() {
         -DCUDNN_ROOT=/usr/
         -DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-ON}
         -DWITH_TESTING=${WITH_TESTING:-ON}
+        -DWITH_FAST_BUNDLE_TEST=ON
         -DCMAKE_EXPORT_COMPILE_COMMANDS=ON
     ========================================
 EOF
@@ -72,6 +73,7 @@ EOF
         -DCUDNN_ROOT=/usr/ \
         -DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-ON} \
         -DWITH_TESTING=${WITH_TESTING:-ON} \
+        -DWITH_FAST_BUNDLE_TEST=ON \
         -DCMAKE_EXPORT_COMPILE_COMMANDS=ON
 }
...
@@ -226,8 +226,7 @@ class DistributeTranspiler:
         rpc_client_var = program.global_block().create_var(
             name="RPC_CLIENT_VAR",
             persistable=True,
-            dtype='float32',  # dtype and shape is not used in fact
-            shape=[0])
+            type=core.VarDesc.VarType.RAW)
         # create send_op
         program.global_block().append_op(
...
@@ -68,7 +68,7 @@ def save_vars(executor,
               main_program=None,
               vars=None,
               predicate=None,
-              save_file_name=None):
+              filename=None):
     """
     Save variables to directory by executor.
@@ -80,8 +80,8 @@
     as a bool. If it returns true, the corresponding input variable will be saved.
     :param vars: variables need to be saved. If vars is specified, program & predicate
                  will be ignored
-    :param save_file_name: The name of a single file that all vars are saved to.
-                           If it is None, save variables to separate files.
+    :param filename: The name of a single file that all vars are saved to.
+                     If it is None, save variables to separate files.
     :return: None
     """
@@ -95,7 +95,7 @@ def save_vars(executor,
             executor,
             dirname=dirname,
             vars=filter(predicate, main_program.list_vars()),
-            save_file_name=save_file_name)
+            filename=filename)
     else:
         save_program = Program()
         save_block = save_program.global_block()
@@ -103,7 +103,7 @@ def save_vars(executor,
         save_var_map = {}
         for each_var in vars:
             new_var = _clone_var_in_block_(save_block, each_var)
-            if save_file_name is None:
+            if filename is None:
                 save_block.append_op(
                     type='save',
                     inputs={'X': [new_var]},
@@ -112,7 +112,7 @@ def save_vars(executor,
             else:
                 save_var_map[new_var.name] = new_var
-        if save_file_name is not None:
+        if filename is not None:
             save_var_list = []
             for name in sorted(save_var_map.keys()):
                 save_var_list.append(save_var_map[name])
@@ -121,12 +121,12 @@ def save_vars(executor,
                 type='save_combine',
                 inputs={'X': save_var_list},
                 outputs={},
-                attrs={'file_path': os.path.join(dirname, save_file_name)})
+                attrs={'file_path': os.path.join(dirname, filename)})
         executor.run(save_program)
-def save_params(executor, dirname, main_program=None, save_file_name=None):
+def save_params(executor, dirname, main_program=None, filename=None):
     """
     Save all parameters to directory with executor.
     """
@@ -136,11 +136,10 @@ def save_params(executor, dirname, main_program=None, save_file_name=None):
         main_program=main_program,
         vars=None,
         predicate=is_parameter,
-        save_file_name=save_file_name)
+        filename=filename)
-def save_persistables(executor, dirname, main_program=None,
-                      save_file_name=None):
+def save_persistables(executor, dirname, main_program=None, filename=None):
     """
     Save all persistables to directory with executor.
     """
@@ -150,7 +149,7 @@ def save_persistables(executor, dirname, main_program=None,
         main_program=main_program,
         vars=None,
         predicate=is_persistable,
-        save_file_name=save_file_name)
+        filename=filename)
 def load_vars(executor,
@@ -158,7 +157,7 @@ def load_vars(executor,
               main_program=None,
               vars=None,
               predicate=None,
-              load_file_name=None):
+              filename=None):
     """
     Load variables from directory by executor.
@@ -170,8 +169,8 @@ def load_vars(executor,
    as a bool. If it returns true, the corresponding input variable will be loaded.
    :param vars: variables need to be loaded. If vars is specified, program &
                 predicate will be ignored
-    :param load_file_name: The name of the single file that all vars are loaded from.
-                           If it is None, load variables from separate files.
+    :param filename: The name of the single file that all vars are loaded from.
+                     If it is None, load variables from separate files.
    :return: None
    """
@@ -185,7 +184,7 @@ def load_vars(executor,
             executor,
             dirname=dirname,
             vars=filter(predicate, main_program.list_vars()),
-            load_file_name=load_file_name)
+            filename=filename)
     else:
         load_prog = Program()
         load_block = load_prog.global_block()
@@ -194,7 +193,7 @@ def load_vars(executor,
         for each_var in vars:
             assert isinstance(each_var, Variable)
             new_var = _clone_var_in_block_(load_block, each_var)
-            if load_file_name is None:
+            if filename is None:
                 load_block.append_op(
                     type='load',
                     inputs={},
@@ -203,7 +202,7 @@ def load_vars(executor,
             else:
                 load_var_map[new_var.name] = new_var
-        if load_file_name is not None:
+        if filename is not None:
             load_var_list = []
             for name in sorted(load_var_map.keys()):
                 load_var_list.append(load_var_map[name])
@@ -212,12 +211,12 @@ def load_vars(executor,
                 type='load_combine',
                 inputs={},
                 outputs={"Out": load_var_list},
-                attrs={'file_path': os.path.join(dirname, load_file_name)})
+                attrs={'file_path': os.path.join(dirname, filename)})
         executor.run(load_prog)
-def load_params(executor, dirname, main_program=None, load_file_name=None):
+def load_params(executor, dirname, main_program=None, filename=None):
     """
     load all parameters from directory by executor.
     """
@@ -226,11 +225,10 @@ def load_params(executor, dirname, main_program=None, load_file_name=None):
         dirname=dirname,
         main_program=main_program,
         predicate=is_parameter,
-        load_file_name=load_file_name)
+        filename=filename)
-def load_persistables(executor, dirname, main_program=None,
-                      load_file_name=None):
+def load_persistables(executor, dirname, main_program=None, filename=None):
     """
     load all persistables from directory by executor.
     """
@@ -239,7 +237,7 @@ def load_persistables(executor, dirname, main_program=None,
         dirname=dirname,
         main_program=main_program,
         predicate=is_persistable,
-        load_file_name=load_file_name)
+        filename=filename)
 def get_inference_program(target_vars, main_program=None):
@@ -299,7 +297,8 @@ def save_inference_model(dirname,
                          target_vars,
                          executor,
                          main_program=None,
-                         save_file_name=None):
+                         model_filename=None,
+                         params_filename=None):
     """
     Build a model especially for inference,
     and save it to directory by the executor.
@@ -310,8 +309,11 @@ def save_inference_model(dirname,
    :param executor: executor that save inference model
    :param main_program: original program, which will be pruned to build the inference model.
                         Default default_main_program().
-    :param save_file_name: The name of a single file that all parameters are saved to.
-                           If it is None, save parameters to separate files.
+    :param model_filename: The name of file to save inference program.
+                           If not specified, default filename `__model__` will be used.
+    :param params_filename: The name of file to save parameters.
+                            It is used for the case that all parameters are saved in a single binary file.
+                            If not specified, parameters are considered saved in separate files.
    :return: None
    """
@@ -342,15 +344,19 @@ def save_inference_model(dirname,
     prepend_feed_ops(inference_program, feeded_var_names)
     append_fetch_ops(inference_program, fetch_var_names)
-    if save_file_name == None:
-        model_file_name = dirname + "/__model__"
+    if model_filename is not None:
+        model_filename = os.path.basename(model_filename)
     else:
-        model_file_name = dirname + "/__model_combined__"
+        model_filename = "__model__"
+    model_filename = os.path.join(dirname, model_filename)
+
+    if params_filename is not None:
+        params_filename = os.path.basename(params_filename)
-    with open(model_file_name, "wb") as f:
+    with open(model_filename, "wb") as f:
         f.write(inference_program.desc.serialize_to_string())
-    save_persistables(executor, dirname, inference_program, save_file_name)
+    save_persistables(executor, dirname, inference_program, params_filename)
 def get_feed_targets_names(program):
@@ -371,15 +377,21 @@ def get_fetch_targets_names(program):
     return fetch_targets_names
-def load_inference_model(dirname, executor, load_file_name=None):
+def load_inference_model(dirname,
+                         executor,
+                         model_filename=None,
+                         params_filename=None):
     """
     Load inference model from a directory
    :param dirname: directory path
    :param executor: executor that load inference model
-    :param load_file_name: The name of the single file that all parameters are loaded from.
-                           If it is None, load parameters from separate files.
+    :param model_filename: The name of file to load inference program.
+                           If not specified, default filename `__model__` will be used.
+    :param params_filename: The name of file to load parameters.
+                            It is used for the case that all parameters are saved in a single binary file.
+                            If not specified, parameters are considered saved in separate files.
    :return: [program, feed_target_names, fetch_targets]
             program: program especially for inference.
             feed_target_names: Names of variables that need to feed data
@@ -388,16 +400,20 @@ def load_inference_model(dirname, executor, load_file_name=None):
     if not os.path.isdir(dirname):
         raise ValueError("There is no directory named '%s'", dirname)
-    if load_file_name == None:
-        model_file_name = dirname + "/__model__"
+    if model_filename is not None:
+        model_filename = os.path.basename(model_filename)
     else:
-        model_file_name = dirname + "/__model_combined__"
+        model_filename = "__model__"
+    model_filename = os.path.join(dirname, model_filename)
+
+    if params_filename is not None:
+        params_filename = os.path.basename(params_filename)
-    with open(model_file_name, "rb") as f:
+    with open(model_filename, "rb") as f:
         program_desc_str = f.read()
     program = Program.parse_from_string(program_desc_str)
-    load_persistables(executor, dirname, program, load_file_name)
+    load_persistables(executor, dirname, program, params_filename)
     feed_target_names = get_feed_targets_names(program)
     fetch_target_names = get_fetch_targets_names(program)
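A hedged usage sketch of the renamed interface (network construction and the `exe`/`prediction` objects are assumed to come from a previously built fluid program; the combined file names mirror the unittest above):

```python
import paddle.fluid as fluid

# Save program and parameters into two combined files ...
fluid.io.save_inference_model(
    dirname='./infer_model',
    feeded_var_names=['img'],
    target_vars=[prediction],  # output variable of a previously built net
    executor=exe,
    model_filename='__model_combined__',
    params_filename='__params_combined__')

# ... and load them back with matching file names.
[program, feed_names, fetch_targets] = fluid.io.load_inference_model(
    './infer_model', exe,
    model_filename='__model_combined__',
    params_filename='__params_combined__')
```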
...
@@ -172,7 +172,10 @@ def detection_map(detect_res,
     return map_out, accum_pos_count_out, accum_true_pos_out, accum_false_pos_out
-def bipartite_match(dist_matrix, name=None):
+def bipartite_match(dist_matrix,
+                    match_type=None,
+                    dist_threshold=None,
+                    name=None):
     """
     **Bipartite matching operator**
@@ -204,6 +207,11 @@ def bipartite_match(dist_matrix, name=None):
            This tensor can contain LoD information to represent a batch of
            inputs. One instance of this batch can contain different numbers of
            entities.
+        match_type(string|None): The type of matching method, should be
+            'bipartite' or 'per_prediction', 'bipartite' by default.
+        dist_threshold(float|None): If `match_type` is 'per_prediction',
+            this threshold is to determine the extra matching bboxes based
+            on the maximum distance, 0.5 by default.
     Returns:
         match_indices(Variable): A 2-D Tensor with shape [N, M] in int type.
             N is the batch size. If match_indices[i][j] is -1, it
@@ -223,6 +231,10 @@ def bipartite_match(dist_matrix, name=None):
     helper.append_op(
         type='bipartite_match',
         inputs={'DistMat': dist_matrix},
+        attrs={
+            'match_type': match_type,
+            'dist_threshold': dist_threshold,
+        },
         outputs={
             'ColToRowMatchIndices': match_indices,
             'ColToRowMatchDist': match_distance
@@ -373,7 +385,7 @@ def ssd_loss(location,
        loc_loss_weight (float): Weight for localization loss, 1.0 by default.
        conf_loss_weight (float): Weight for confidence loss, 1.0 by default.
        match_type (str): The type of matching method during training, should
-            be 'bipartite' or 'per_prediction'.
+            be 'bipartite' or 'per_prediction', 'per_prediction' by default.
        mining_type (str): The hard example mining type, should be 'hard_example'
            or 'max_negative', now only support `max_negative`.
@@ -421,7 +433,8 @@ def ssd_loss(location,
     # 1.1 Compute IOU similarity between ground-truth boxes and prior boxes.
     iou = iou_similarity(x=gt_box, y=prior_box)
     # 1.2 Compute matched bounding box by bipartite matching algorithm.
-    matched_indices, matched_dist = bipartite_match(iou)
+    matched_indices, matched_dist = bipartite_match(iou, match_type,
+                                                    overlap_threshold)
     # 2. Compute confidence for mining hard examples
     # 2.1. Get the target label based on matched indices
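A hedged example of the extended layer, assuming `bipartite_match` is re-exported under `fluid.layers` (shapes are illustrative; in `ssd_loss` the distance matrix comes from `iou_similarity`):

```python
import paddle.fluid as fluid

# A [4, 16] distance matrix: 4 ground-truth entities vs. 16 predictions.
dist = fluid.layers.data(name='dist', shape=[4, 16], dtype='float32',
                         lod_level=1, append_batch_size=False)
matched_indices, matched_dist = fluid.layers.bipartite_match(
    dist, match_type='per_prediction', dist_threshold=0.5)
```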
...
@@ -21,6 +21,7 @@ from ..framework import Variable
 from ..param_attr import ParamAttr
 from layer_function_generator import autodoc
 from tensor import concat
+import utils
 __all__ = [
     'fc',
@@ -1138,8 +1139,8 @@ def sequence_conv(input,
 def conv2d(input,
            num_filters,
            filter_size,
-           stride=None,
-           padding=None,
+           stride=1,
+           padding=0,
            groups=None,
            param_attr=None,
            bias_attr=None,
@@ -1252,12 +1253,10 @@ def conv2d(input,
         raise ValueError("num_channels must be divisible by groups.")
     num_filter_channels = num_channels / groups
-    if isinstance(filter_size, int):
-        filter_size = [filter_size, filter_size]
-    if isinstance(stride, int):
-        stride = [stride, stride]
-    if isinstance(padding, int):
-        padding = [padding, padding]
+    filter_size = utils.convert_to_list(filter_size, 2, 'filter_size')
+    stride = utils.convert_to_list(stride, 2, 'stride')
+    padding = utils.convert_to_list(padding, 2, 'padding')
     if not isinstance(use_cudnn, bool):
         raise ValueError("use_cudnn should be True or False")
@@ -1432,10 +1431,10 @@ def sequence_last_step(input):
 def pool2d(input,
-           pool_size,
-           pool_type,
-           pool_stride=None,
-           pool_padding=None,
+           pool_size=-1,
+           pool_type="max",
+           pool_stride=1,
+           pool_padding=0,
            global_pooling=False,
            use_cudnn=True,
            name=None):
@@ -1443,20 +1442,20 @@ def pool2d(input,
     This function adds the operator for pooling in 2 dimensions, using the
     pooling configurations mentioned in input parameters.
     """
-    if pool_padding is None:
-        pool_padding = [0, 0]
-    if pool_stride is None:
-        pool_stride = [1, 1]
     if pool_type not in ["max", "avg"]:
         raise ValueError(
             "Unknown pool_type: '%s'. It can only be 'max' or 'avg'.",
             str(pool_type))
-    if isinstance(pool_size, int):
-        pool_size = [pool_size, pool_size]
-    if isinstance(pool_stride, int):
-        pool_stride = [pool_stride, pool_stride]
-    if isinstance(pool_padding, int):
-        pool_padding = [pool_padding, pool_padding]
+
+    if global_pooling is False and pool_size == -1:
+        raise ValueError(
+            "When the global_pooling is False, pool_size must be passed "
+            "and be a valid value. Received pool_size: " + str(pool_size))
+
+    pool_size = utils.convert_to_list(pool_size, 2, 'pool_size')
+    pool_padding = utils.convert_to_list(pool_padding, 2, 'pool_padding')
+    pool_stride = utils.convert_to_list(pool_stride, 2, 'pool_stride')
     if not isinstance(use_cudnn, bool):
         raise ValueError("use_cudnn should be True or False")
@@ -1685,9 +1684,9 @@ def conv2d_transpose(input,
                      num_filters,
                      output_size=None,
                      filter_size=None,
-                     padding=None,
-                     stride=None,
-                     dilation=None,
+                     padding=0,
+                     stride=1,
+                     dilation=1,
                      param_attr=None,
                      use_cudnn=True,
                      name=None):
@@ -1783,26 +1782,12 @@ def conv2d_transpose(input,
         raise TypeError("Input of conv2d_transpose must be Variable")
     input_channel = input.shape[1]
-    op_attr = dict()
-
-    if isinstance(padding, int):
-        op_attr['paddings'] = [padding, padding]
-    elif padding is not None:
-        op_attr['paddings'] = padding
-
-    if isinstance(stride, int):
-        op_attr['strides'] = [stride, stride]
-    elif stride is not None:
-        op_attr['strides'] = stride
-
-    if isinstance(dilation, int):
-        op_attr['dilations'] = [dilation, dilation]
-    elif dilation is not None:
-        op_attr['dilations'] = dilation
+    padding = utils.convert_to_list(padding, 2, 'padding')
+    stride = utils.convert_to_list(stride, 2, 'stride')
+    dilation = utils.convert_to_list(dilation, 2, 'dilation')
     if not isinstance(use_cudnn, bool):
         raise ValueError("use_cudnn should be True or False")
-    op_attr['use_cudnn'] = use_cudnn
     if filter_size is None:
         if output_size is None:
@@ -1810,10 +1795,6 @@ def conv2d_transpose(input,
         if isinstance(output_size, int):
             output_size = [output_size, output_size]
-        padding = op_attr.get('paddings', [0, 0])
-        stride = op_attr.get('strides', [1, 1])
-        dilation = op_attr.get('dilations', [1, 1])
-
         h_in = input.shape[2]
         w_in = input.shape[3]
@@ -1822,9 +1803,9 @@ def conv2d_transpose(input,
         filter_size_w = (output_size[1] - (w_in - 1) * stride[1] + 2 *
                          padding[1] - 1) / dilation[1] + 1
         filter_size = [filter_size_h, filter_size_w]
-
-    elif isinstance(filter_size, int):
-        filter_size = [filter_size, filter_size]
+    else:
+        filter_size = utils.convert_to_list(filter_size, 2,
+                                            'conv2d_transpose.filter_size')
     filter_shape = [input_channel, num_filters] + filter_size
     img_filter = helper.create_parameter(
@@ -1836,7 +1817,12 @@ def conv2d_transpose(input,
         inputs={'Input': [input],
                 'Filter': [img_filter]},
         outputs={'Output': out},
-        attrs=op_attr)
+        attrs={
+            'strides': stride,
+            'paddings': padding,
+            'dilations': dilation,
+            'use_cudnn': use_cudnn
+        })
     return out
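With the scalar defaults above, the size arguments accept either an int or a pair; `utils.convert_to_list` normalizes both. An illustrative snippet (shapes are arbitrary):

```python
import paddle.fluid as fluid

img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
# filter_size=3 expands to [3, 3]; stride/padding now default to 1 and 0.
conv = fluid.layers.conv2d(input=img, num_filters=32, filter_size=3)
# pool_size=2 expands to [2, 2]; pool_type now defaults to "max".
pool = fluid.layers.pool2d(input=conv, pool_size=2, pool_stride=2)
```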
...
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
def convert_to_list(value, n, name, dtype=np.int):
"""
    Converts a single numerical type or an iterable of numerical
    types into a numerical type list.
Arguments:
      value: The value to validate and convert. Could be an int, or any iterable
of ints.
n: The size of the list to be returned.
name: The name of the argument being validated, e.g. "stride" or
"filter_size". This is only used to format error messages.
dtype: the numerical type of the element of the list to be returned.
Returns:
A list of n dtypes.
Raises:
      ValueError: If something other than an int/long or an iterable thereof
        was passed.
"""
if isinstance(value, dtype):
return [value, ] * n
else:
try:
value_list = list(value)
except TypeError:
raise ValueError("The " + name +
"'s type must be list or tuple. Received: " + str(
value))
if len(value_list) != n:
raise ValueError("The " + name + "'s length must be " + str(n) +
". Received: " + str(value))
for single_value in value_list:
try:
dtype(single_value)
except (ValueError, TypeError):
raise ValueError(
"The " + name + "'s type must be a list or tuple of " + str(
n) + " " + str(dtype) + " . Received: " + str(
value) + " "
"including element " + str(single_value) + " of type" + " "
+ str(type(single_value)))
return value_list
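The contract of `convert_to_list` in a few hedged examples; the expected results are inferred from the implementation above:

```python
print(convert_to_list(3, 2, 'stride'))        # [3, 3]
print(convert_to_list((1, 2), 2, 'padding'))  # [1, 2]
# convert_to_list([1, 2, 3], 2, 'padding')    # ValueError: length must be 2
# convert_to_list(['a', 'b'], 2, 'stride')    # ValueError: non-integer element
```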
@@ -228,32 +228,34 @@ def infer(use_cuda, save_dirname=None):
     place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
     exe = fluid.Executor(place)
-    # Use fluid.io.load_inference_model to obtain the inference program desc,
-    # the feed_target_names (the names of variables that will be feeded
-    # data using feed operators), and the fetch_targets (variables that
-    # we want to obtain data from using fetch operators).
-    [inference_program, feed_target_names,
-     fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
-    lod = [0, 4, 10]
-    word_data = create_random_lodtensor(lod, place, low=0, high=1)
-    trg_word = create_random_lodtensor(lod, place, low=0, high=1)
-    # Construct feed as a dictionary of {feed_target_name: feed_target_data}
-    # and results will contain a list of data corresponding to fetch_targets.
-    assert feed_target_names[0] == 'source_sequence'
-    assert feed_target_names[1] == 'target_sequence'
-    results = exe.run(inference_program,
-                      feed={
-                          feed_target_names[0]: word_data,
-                          feed_target_names[1]: trg_word,
-                      },
-                      fetch_list=fetch_targets,
-                      return_numpy=False)
-    print(results[0].lod())
-    np_data = np.array(results[0])
-    print("Inference shape: ", np_data.shape)
-    print("Inference results: ", np_data)
+    inference_scope = fluid.core.Scope()
+    with fluid.scope_guard(inference_scope):
+        # Use fluid.io.load_inference_model to obtain the inference program desc,
+        # the feed_target_names (the names of variables that will be feeded
+        # data using feed operators), and the fetch_targets (variables that
+        # we want to obtain data from using fetch operators).
+        [inference_program, feed_target_names,
+         fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
+        lod = [0, 4, 10]
+        word_data = create_random_lodtensor(lod, place, low=0, high=1)
+        trg_word = create_random_lodtensor(lod, place, low=0, high=1)
+        # Construct feed as a dictionary of {feed_target_name: feed_target_data}
+        # and results will contain a list of data corresponding to fetch_targets.
+        assert feed_target_names[0] == 'source_sequence'
+        assert feed_target_names[1] == 'target_sequence'
+        results = exe.run(inference_program,
+                          feed={
+                              feed_target_names[0]: word_data,
+                              feed_target_names[1]: trg_word,
+                          },
+                          fetch_list=fetch_targets,
+                          return_numpy=False)
+        print(results[0].lod())
+        np_data = np.array(results[0])
+        print("Inference shape: ", np_data.shape)
+        print("Inference results: ", np_data)

 def main(use_cuda):
...
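This and the following test files all apply the same refactoring: inference now runs inside a fresh `Scope`, so variables created by `load_inference_model` cannot collide with, or leak into, the scope the training pass used. A minimal sketch of the pattern; the model directory name here is a placeholder:

```python
import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)

inference_scope = fluid.core.Scope()
with fluid.scope_guard(inference_scope):
    # Everything created here, including the loaded parameters, lives in
    # inference_scope rather than in the global (training) scope.
    [program, feed_names, fetch_targets] = fluid.io.load_inference_model(
        "some_model.inference.model", exe)  # placeholder directory name
```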
@@ -72,23 +72,26 @@ def infer(use_cuda, save_dirname=None):
     place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
     exe = fluid.Executor(place)
-    # Use fluid.io.load_inference_model to obtain the inference program desc,
-    # the feed_target_names (the names of variables that will be feeded
-    # data using feed operators), and the fetch_targets (variables that
-    # we want to obtain data from using fetch operators).
-    [inference_program, feed_target_names,
-     fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
-    # The input's dimension should be 2-D and the second dim is 13
-    # The input data should be >= 0
-    batch_size = 10
-    tensor_x = numpy.random.uniform(0, 10, [batch_size, 13]).astype("float32")
-    assert feed_target_names[0] == 'x'
-    results = exe.run(inference_program,
-                      feed={feed_target_names[0]: tensor_x},
-                      fetch_list=fetch_targets)
-    print("infer shape: ", results[0].shape)
-    print("infer results: ", results[0])
+    inference_scope = fluid.core.Scope()
+    with fluid.scope_guard(inference_scope):
+        # Use fluid.io.load_inference_model to obtain the inference program desc,
+        # the feed_target_names (the names of variables that will be feeded
+        # data using feed operators), and the fetch_targets (variables that
+        # we want to obtain data from using fetch operators).
+        [inference_program, feed_target_names,
+         fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
+        # The input's dimension should be 2-D and the second dim is 13
+        # The input data should be >= 0
+        batch_size = 10
+        tensor_x = numpy.random.uniform(0, 10,
+                                        [batch_size, 13]).astype("float32")
+        assert feed_target_names[0] == 'x'
+        results = exe.run(inference_program,
+                          feed={feed_target_names[0]: tensor_x},
+                          fetch_list=fetch_targets)
+        print("infer shape: ", results[0].shape)
+        print("infer results: ", results[0])

 def main(use_cuda):
...
@@ -174,22 +174,26 @@ def infer(use_cuda, save_dirname=None):
     place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
     exe = fluid.Executor(place)
-    # Use fluid.io.load_inference_model to obtain the inference program desc,
-    # the feed_target_names (the names of variables that will be feeded
-    # data using feed operators), and the fetch_targets (variables that
-    # we want to obtain data from using fetch operators).
-    [inference_program, feed_target_names,
-     fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
-    # The input's dimension of conv should be 4-D or 5-D.
-    tensor_img = numpy.random.rand(1, 3, 32, 32).astype("float32")
-    # Construct feed as a dictionary of {feed_target_name: feed_target_data}
-    # and results will contain a list of data corresponding to fetch_targets.
-    results = exe.run(inference_program,
-                      feed={feed_target_names[0]: tensor_img},
-                      fetch_list=fetch_targets)
-    print("infer results: ", results[0])
+    inference_scope = fluid.core.Scope()
+    with fluid.scope_guard(inference_scope):
+        # Use fluid.io.load_inference_model to obtain the inference program desc,
+        # the feed_target_names (the names of variables that will be feeded
+        # data using feed operators), and the fetch_targets (variables that
+        # we want to obtain data from using fetch operators).
+        [inference_program, feed_target_names,
+         fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
+        # The input's dimension of conv should be 4-D or 5-D.
+        # Use normilized image pixels as input data, which should be in the range [0, 1.0].
+        batch_size = 1
+        tensor_img = numpy.random.rand(batch_size, 3, 32, 32).astype("float32")
+        # Construct feed as a dictionary of {feed_target_name: feed_target_data}
+        # and results will contain a list of data corresponding to fetch_targets.
+        results = exe.run(inference_program,
+                          feed={feed_target_names[0]: tensor_img},
+                          fetch_list=fetch_targets)
+        print("infer results: ", results[0])

 def main(net_type, use_cuda):
...
@@ -26,7 +26,7 @@ import unittest
 word_dict, verb_dict, label_dict = conll05.get_dict()
 word_dict_len = len(word_dict)
 label_dict_len = len(label_dict)
-pred_len = len(verb_dict)
+pred_dict_len = len(verb_dict)
 mark_dict_len = 2
 word_dim = 32
@@ -53,7 +53,7 @@ def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark,
     # 8 features
     predicate_embedding = fluid.layers.embedding(
         input=predicate,
-        size=[pred_len, word_dim],
+        size=[pred_dict_len, word_dim],
         dtype='float32',
         is_sparse=IS_SPARSE,
         param_attr='vemb')
@@ -234,6 +234,7 @@ def train(use_cuda, save_dirname=None):
         # Set the threshold low to speed up the CI test
         if float(pass_precision) > 0.05:
             if save_dirname is not None:
+                # TODO(liuyiqun): Change the target to crf_decode
                 fluid.io.save_inference_model(save_dirname, [
                     'word_data', 'verb_data', 'ctx_n2_data',
                     'ctx_n1_data', 'ctx_0_data', 'ctx_p1_data',
@@ -251,51 +252,60 @@ def infer(use_cuda, save_dirname=None):
     place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
     exe = fluid.Executor(place)
-    # Use fluid.io.load_inference_model to obtain the inference program desc,
-    # the feed_target_names (the names of variables that will be feeded
-    # data using feed operators), and the fetch_targets (variables that
-    # we want to obtain data from using fetch operators).
-    [inference_program, feed_target_names,
-     fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
-    lod = [0, 4, 10]
-    ts_word = create_random_lodtensor(lod, place, low=0, high=1)
-    ts_pred = create_random_lodtensor(lod, place, low=0, high=1)
-    ts_ctx_n2 = create_random_lodtensor(lod, place, low=0, high=1)
-    ts_ctx_n1 = create_random_lodtensor(lod, place, low=0, high=1)
-    ts_ctx_0 = create_random_lodtensor(lod, place, low=0, high=1)
-    ts_ctx_p1 = create_random_lodtensor(lod, place, low=0, high=1)
-    ts_ctx_p2 = create_random_lodtensor(lod, place, low=0, high=1)
-    ts_mark = create_random_lodtensor(lod, place, low=0, high=1)
-    # Construct feed as a dictionary of {feed_target_name: feed_target_data}
-    # and results will contain a list of data corresponding to fetch_targets.
-    assert feed_target_names[0] == 'word_data'
-    assert feed_target_names[1] == 'verb_data'
-    assert feed_target_names[2] == 'ctx_n2_data'
-    assert feed_target_names[3] == 'ctx_n1_data'
-    assert feed_target_names[4] == 'ctx_0_data'
-    assert feed_target_names[5] == 'ctx_p1_data'
-    assert feed_target_names[6] == 'ctx_p2_data'
-    assert feed_target_names[7] == 'mark_data'
-    results = exe.run(inference_program,
-                      feed={
-                          feed_target_names[0]: ts_word,
-                          feed_target_names[1]: ts_pred,
-                          feed_target_names[2]: ts_ctx_n2,
-                          feed_target_names[3]: ts_ctx_n1,
-                          feed_target_names[4]: ts_ctx_0,
-                          feed_target_names[5]: ts_ctx_p1,
-                          feed_target_names[6]: ts_ctx_p2,
-                          feed_target_names[7]: ts_mark
-                      },
-                      fetch_list=fetch_targets,
-                      return_numpy=False)
-    print(results[0].lod())
-    np_data = np.array(results[0])
-    print("Inference Shape: ", np_data.shape)
-    print("Inference results: ", np_data)
+    inference_scope = fluid.core.Scope()
+    with fluid.scope_guard(inference_scope):
+        # Use fluid.io.load_inference_model to obtain the inference program desc,
+        # the feed_target_names (the names of variables that will be feeded
+        # data using feed operators), and the fetch_targets (variables that
+        # we want to obtain data from using fetch operators).
+        [inference_program, feed_target_names,
+         fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
+        lod = [0, 4, 10]
+        word = create_random_lodtensor(
+            lod, place, low=0, high=word_dict_len - 1)
+        pred = create_random_lodtensor(
+            lod, place, low=0, high=pred_dict_len - 1)
+        ctx_n2 = create_random_lodtensor(
+            lod, place, low=0, high=word_dict_len - 1)
+        ctx_n1 = create_random_lodtensor(
+            lod, place, low=0, high=word_dict_len - 1)
+        ctx_0 = create_random_lodtensor(
+            lod, place, low=0, high=word_dict_len - 1)
+        ctx_p1 = create_random_lodtensor(
+            lod, place, low=0, high=word_dict_len - 1)
+        ctx_p2 = create_random_lodtensor(
+            lod, place, low=0, high=word_dict_len - 1)
+        mark = create_random_lodtensor(
+            lod, place, low=0, high=mark_dict_len - 1)
+        # Construct feed as a dictionary of {feed_target_name: feed_target_data}
+        # and results will contain a list of data corresponding to fetch_targets.
+        assert feed_target_names[0] == 'word_data'
+        assert feed_target_names[1] == 'verb_data'
+        assert feed_target_names[2] == 'ctx_n2_data'
+        assert feed_target_names[3] == 'ctx_n1_data'
+        assert feed_target_names[4] == 'ctx_0_data'
+        assert feed_target_names[5] == 'ctx_p1_data'
+        assert feed_target_names[6] == 'ctx_p2_data'
+        assert feed_target_names[7] == 'mark_data'
+        results = exe.run(inference_program,
+                          feed={
+                              feed_target_names[0]: word,
+                              feed_target_names[1]: pred,
+                              feed_target_names[2]: ctx_n2,
+                              feed_target_names[3]: ctx_n1,
+                              feed_target_names[4]: ctx_0,
+                              feed_target_names[5]: ctx_p1,
+                              feed_target_names[6]: ctx_p2,
+                              feed_target_names[7]: mark
+                          },
+                          fetch_list=fetch_targets,
+                          return_numpy=False)
+        print(results[0].lod())
+        np_data = np.array(results[0])
+        print("Inference Shape: ", np_data.shape)

 def main(use_cuda):
...
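A note on the `lod = [0, 4, 10]` inputs used throughout these tests: the offsets describe two sequences, rows 0-3 and rows 4-9, packed into a single [10, 1] tensor. A hedged sketch (dictionary size assumed) of what `create_random_lodtensor` assembles:

```python
import numpy as np
import paddle.fluid as fluid

lod = [0, 4, 10]
# lod[-1] rows in total; sequence i occupies rows lod[i]:lod[i + 1].
data = np.random.random_integers(0, 99, [lod[-1], 1]).astype("int64")
res = fluid.LoDTensor()
res.set(data, fluid.CPUPlace())
res.set_lod([lod])
```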
@@ -78,7 +78,12 @@ def conv_net(img, label):
     return loss_net(conv_pool_2, label)

-def train(nn_type, use_cuda, parallel, save_dirname, save_param_filename):
+def train(nn_type,
+          use_cuda,
+          parallel,
+          save_dirname=None,
+          model_filename=None,
+          params_filename=None):
     if use_cuda and not fluid.core.is_compiled_with_cuda():
         return
     img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
@@ -146,7 +151,8 @@ def train(nn_type, use_cuda, parallel, save_dirname, save_param_filename):
                 fluid.io.save_inference_model(
                     save_dirname, ["img"], [prediction],
                     exe,
-                    save_file_name=save_param_filename)
+                    model_filename=model_filename,
+                    params_filename=params_filename)
                 return
         else:
             print(
@@ -158,54 +164,62 @@ def train(nn_type, use_cuda, parallel, save_dirname, save_param_filename):
     raise AssertionError("Loss of recognize digits is too large")

-def infer(use_cuda, save_dirname=None, param_filename=None):
+def infer(use_cuda,
+          save_dirname=None,
+          model_filename=None,
+          params_filename=None):
     if save_dirname is None:
         return
     place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
     exe = fluid.Executor(place)
-    # Use fluid.io.load_inference_model to obtain the inference program desc,
-    # the feed_target_names (the names of variables that will be feeded
-    # data using feed operators), and the fetch_targets (variables that
-    # we want to obtain data from using fetch operators).
-    [inference_program, feed_target_names, fetch_targets
-     ] = fluid.io.load_inference_model(save_dirname, exe, param_filename)
-    # The input's dimension of conv should be 4-D or 5-D.
-    # Use normilized image pixels as input data, which should be in the range [-1.0, 1.0].
-    batch_size = 1
-    tensor_img = numpy.random.uniform(-1.0, 1.0,
-                                      [batch_size, 1, 28, 28]).astype("float32")
-    # Construct feed as a dictionary of {feed_target_name: feed_target_data}
-    # and results will contain a list of data corresponding to fetch_targets.
-    results = exe.run(inference_program,
-                      feed={feed_target_names[0]: tensor_img},
-                      fetch_list=fetch_targets)
-    print("infer results: ", results[0])
+    inference_scope = fluid.core.Scope()
+    with fluid.scope_guard(inference_scope):
+        # Use fluid.io.load_inference_model to obtain the inference program desc,
+        # the feed_target_names (the names of variables that will be feeded
+        # data using feed operators), and the fetch_targets (variables that
+        # we want to obtain data from using fetch operators).
+        [inference_program, feed_target_names,
+         fetch_targets] = fluid.io.load_inference_model(
+             save_dirname, exe, model_filename, params_filename)
+        # The input's dimension of conv should be 4-D or 5-D.
+        # Use normilized image pixels as input data, which should be in the range [-1.0, 1.0].
+        batch_size = 1
+        tensor_img = numpy.random.uniform(
+            -1.0, 1.0, [batch_size, 1, 28, 28]).astype("float32")
+        # Construct feed as a dictionary of {feed_target_name: feed_target_data}
+        # and results will contain a list of data corresponding to fetch_targets.
+        results = exe.run(inference_program,
+                          feed={feed_target_names[0]: tensor_img},
+                          fetch_list=fetch_targets)
+        print("infer results: ", results[0])

 def main(use_cuda, parallel, nn_type, combine):
+    save_dirname = None
+    model_filename = None
+    params_filename = None
     if not use_cuda and not parallel:
         save_dirname = "recognize_digits_" + nn_type + ".inference.model"
-        save_filename = None
         if combine == True:
-            save_filename = "__params_combined__"
-    else:
-        save_dirname = None
-        save_filename = None
+            model_filename = "__model_combined__"
+            params_filename = "__params_combined__"
     train(
         nn_type=nn_type,
         use_cuda=use_cuda,
         parallel=parallel,
         save_dirname=save_dirname,
-        save_param_filename=save_filename)
+        model_filename=model_filename,
+        params_filename=params_filename)
     infer(
         use_cuda=use_cuda,
         save_dirname=save_dirname,
-        param_filename=save_filename)
+        model_filename=model_filename,
+        params_filename=params_filename)

 class TestRecognizeDigits(unittest.TestCase):
...
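The renamed keyword arguments make the combined persistence layout explicit: the program and the parameters are each written as a single file inside `save_dirname`, and inference must pass the same two file names back. A hedged sketch of the load side; the directory is assumed to exist from a prior training run like the one above:

```python
import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)

# The file names must match the ones given to save_inference_model.
[program, feed_names, fetch_targets] = fluid.io.load_inference_model(
    "recognize_digits_mlp.inference.model", exe,
    "__model_combined__", "__params_combined__")
```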
@@ -251,13 +251,6 @@ def infer(use_cuda, save_dirname=None):
     place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
     exe = fluid.Executor(place)
-    # Use fluid.io.load_inference_model to obtain the inference program desc,
-    # the feed_target_names (the names of variables that will be feeded
-    # data using feed operators), and the fetch_targets (variables that
-    # we want to obtain data from using fetch operators).
-    [inference_program, feed_target_names,
-     fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
     def create_lod_tensor(data, lod=None):
         tensor = fluid.LoDTensor()
         if lod is None:
@@ -275,44 +268,53 @@ def infer(use_cuda, save_dirname=None):
         tensor.set(flattened_data, place)
         return tensor

-    # Use the first data from paddle.dataset.movielens.test() as input
-    assert feed_target_names[0] == "user_id"
-    user_id = create_lod_tensor([[1]])
-    assert feed_target_names[1] == "gender_id"
-    gender_id = create_lod_tensor([[1]])
-    assert feed_target_names[2] == "age_id"
-    age_id = create_lod_tensor([[0]])
-    assert feed_target_names[3] == "job_id"
-    job_id = create_lod_tensor([[10]])
-    assert feed_target_names[4] == "movie_id"
-    movie_id = create_lod_tensor([[783]])
-    assert feed_target_names[5] == "category_id"
-    category_id = create_lod_tensor([[10], [8], [9]], [[0, 3]])
-    assert feed_target_names[6] == "movie_title"
-    movie_title = create_lod_tensor([[1069], [4140], [2923], [710], [988]],
-                                    [[0, 5]])
-    # Construct feed as a dictionary of {feed_target_name: feed_target_data}
-    # and results will contain a list of data corresponding to fetch_targets.
-    results = exe.run(inference_program,
-                      feed={
-                          feed_target_names[0]: user_id,
-                          feed_target_names[1]: gender_id,
-                          feed_target_names[2]: age_id,
-                          feed_target_names[3]: job_id,
-                          feed_target_names[4]: movie_id,
-                          feed_target_names[5]: category_id,
-                          feed_target_names[6]: movie_title
-                      },
-                      fetch_list=fetch_targets,
-                      return_numpy=False)
-    print("inferred score: ", np.array(results[0]))
+    inference_scope = fluid.core.Scope()
+    with fluid.scope_guard(inference_scope):
+        # Use fluid.io.load_inference_model to obtain the inference program desc,
+        # the feed_target_names (the names of variables that will be feeded
+        # data using feed operators), and the fetch_targets (variables that
+        # we want to obtain data from using fetch operators).
+        [inference_program, feed_target_names,
+         fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
+        # Use the first data from paddle.dataset.movielens.test() as input
+        assert feed_target_names[0] == "user_id"
+        user_id = create_lod_tensor([[1]])
+        assert feed_target_names[1] == "gender_id"
+        gender_id = create_lod_tensor([[1]])
+        assert feed_target_names[2] == "age_id"
+        age_id = create_lod_tensor([[0]])
+        assert feed_target_names[3] == "job_id"
+        job_id = create_lod_tensor([[10]])
+        assert feed_target_names[4] == "movie_id"
+        movie_id = create_lod_tensor([[783]])
+        assert feed_target_names[5] == "category_id"
+        category_id = create_lod_tensor([[10], [8], [9]], [[0, 3]])
+        assert feed_target_names[6] == "movie_title"
+        movie_title = create_lod_tensor([[1069], [4140], [2923], [710], [988]],
+                                        [[0, 5]])
+        # Construct feed as a dictionary of {feed_target_name: feed_target_data}
+        # and results will contain a list of data corresponding to fetch_targets.
+        results = exe.run(inference_program,
+                          feed={
+                              feed_target_names[0]: user_id,
+                              feed_target_names[1]: gender_id,
+                              feed_target_names[2]: age_id,
+                              feed_target_names[3]: job_id,
+                              feed_target_names[4]: movie_id,
+                              feed_target_names[5]: category_id,
+                              feed_target_names[6]: movie_title
+                          },
+                          fetch_list=fetch_targets,
+                          return_numpy=False)
+        print("inferred score: ", np.array(results[0]))

 def main(use_cuda):
...
 # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -193,36 +193,39 @@ def train(word_dict, net_method, use_cuda, parallel=False, save_dirname=None):
             net_method.__name__))

-def infer(use_cuda, save_dirname=None):
+def infer(word_dict, use_cuda, save_dirname=None):
     if save_dirname is None:
         return
     place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
     exe = fluid.Executor(place)
-    # Use fluid.io.load_inference_model to obtain the inference program desc,
-    # the feed_target_names (the names of variables that will be feeded
-    # data using feed operators), and the fetch_targets (variables that
-    # we want to obtain data from using fetch operators).
-    [inference_program, feed_target_names,
-     fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
-    lod = [0, 4, 10]
-    word_dict = paddle.dataset.imdb.word_dict()
-    tensor_words = create_random_lodtensor(
-        lod, place, low=0, high=len(word_dict) - 1)
-    # Construct feed as a dictionary of {feed_target_name: feed_target_data}
-    # and results will contain a list of data corresponding to fetch_targets.
-    assert feed_target_names[0] == "words"
-    results = exe.run(inference_program,
-                      feed={feed_target_names[0]: tensor_words},
-                      fetch_list=fetch_targets,
-                      return_numpy=False)
-    print(results[0].lod())
-    np_data = np.array(results[0])
-    print("Inference Shape: ", np_data.shape)
-    print("Inference results: ", np_data)
+    inference_scope = fluid.core.Scope()
+    with fluid.scope_guard(inference_scope):
+        # Use fluid.io.load_inference_model to obtain the inference program desc,
+        # the feed_target_names (the names of variables that will be feeded
+        # data using feed operators), and the fetch_targets (variables that
+        # we want to obtain data from using fetch operators).
+        [inference_program, feed_target_names,
+         fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
+        word_dict_len = len(word_dict)
+        lod = [0, 4, 10]
+        tensor_words = create_random_lodtensor(
+            lod, place, low=0, high=word_dict_len - 1)
+        # Construct feed as a dictionary of {feed_target_name: feed_target_data}
+        # and results will contain a list of data corresponding to fetch_targets.
+        assert feed_target_names[0] == "words"
+        results = exe.run(inference_program,
+                          feed={feed_target_names[0]: tensor_words},
+                          fetch_list=fetch_targets,
+                          return_numpy=False)
+        print(results[0].lod())
+        np_data = np.array(results[0])
+        print("Inference Shape: ", np_data.shape)
+        print("Inference results: ", np_data)

 def main(word_dict, net_method, use_cuda, parallel=False, save_dirname=None):
@@ -258,7 +261,7 @@ class TestUnderstandSentiment(unittest.TestCase):
                 self.word_dict,
                 net_method=convolution_net,
                 use_cuda=False,
-                save_dirname="understand_sentiment.inference.model")
+                save_dirname="understand_sentiment_conv.inference.model")

     def test_conv_cpu_parallel(self):
         with self.new_program_scope():
@@ -271,7 +274,11 @@ class TestUnderstandSentiment(unittest.TestCase):
     @unittest.skip(reason="make CI faster")
     def test_stacked_lstm_cpu(self):
         with self.new_program_scope():
-            main(self.word_dict, net_method=stacked_lstm_net, use_cuda=False)
+            main(
+                self.word_dict,
+                net_method=stacked_lstm_net,
+                use_cuda=False,
+                save_dirname="understand_sentiment_stacked_lstm.inference.model")

     def test_stacked_lstm_cpu_parallel(self):
         with self.new_program_scope():
@@ -287,7 +294,7 @@ class TestUnderstandSentiment(unittest.TestCase):
                 self.word_dict,
                 net_method=convolution_net,
                 use_cuda=True,
-                save_dirname="understand_sentiment.inference.model")
+                save_dirname="understand_sentiment_conv.inference.model")

     def test_conv_gpu_parallel(self):
         with self.new_program_scope():
@@ -300,7 +307,11 @@ class TestUnderstandSentiment(unittest.TestCase):
     @unittest.skip(reason="make CI faster")
     def test_stacked_lstm_gpu(self):
         with self.new_program_scope():
-            main(self.word_dict, net_method=stacked_lstm_net, use_cuda=True)
+            main(
+                self.word_dict,
+                net_method=stacked_lstm_net,
+                use_cuda=True,
+                save_dirname="understand_sentiment_stacked_lstm.inference.model")

     def test_stacked_lstm_gpu_parallel(self):
         with self.new_program_scope():
...
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
+# Licensed under the Apache License, Version 2.0 (the "License");
 #
-# Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
@@ -21,6 +22,7 @@ import sys

 def create_random_lodtensor(lod, place, low, high):
+    # The range of data elements is [low, high]
     data = np.random.random_integers(low, high, [lod[-1], 1]).astype("int64")
     res = fluid.LoDTensor()
     res.set(data, place)
@@ -28,54 +30,7 @@ def create_random_lodtensor(lod, place, low, high):
     return res

-def infer(use_cuda, save_dirname=None):
-    if save_dirname is None:
-        return
-    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
-    exe = fluid.Executor(place)
-    # Use fluid.io.load_inference_model to obtain the inference program desc,
-    # the feed_target_names (the names of variables that will be feeded
-    # data using feed operators), and the fetch_targets (variables that
-    # we want to obtain data from using fetch operators).
-    [inference_program, feed_target_names,
-     fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
-    word_dict = paddle.dataset.imikolov.build_dict()
-    dict_size = len(word_dict) - 1
-    # Setup input, by creating 4 words, and setting up lod required for
-    # lookup_table_op
-    lod = [0, 1]
-    first_word = create_random_lodtensor(lod, place, low=0, high=dict_size)
-    second_word = create_random_lodtensor(lod, place, low=0, high=dict_size)
-    third_word = create_random_lodtensor(lod, place, low=0, high=dict_size)
-    fourth_word = create_random_lodtensor(lod, place, low=0, high=dict_size)
-    assert feed_target_names[0] == 'firstw'
-    assert feed_target_names[1] == 'secondw'
-    assert feed_target_names[2] == 'thirdw'
-    assert feed_target_names[3] == 'forthw'
-    # Construct feed as a dictionary of {feed_target_name: feed_target_data}
-    # and results will contain a list of data corresponding to fetch_targets.
-    results = exe.run(inference_program,
-                      feed={
-                          feed_target_names[0]: first_word,
-                          feed_target_names[1]: second_word,
-                          feed_target_names[2]: third_word,
-                          feed_target_names[3]: fourth_word
-                      },
-                      fetch_list=fetch_targets,
-                      return_numpy=False)
-    print(results[0].lod())
-    np_data = np.array(results[0])
-    print("Inference Shape: ", np_data.shape)
-    print("Inference results: ", np_data)

-def train(use_cuda, is_sparse, parallel, save_dirname):
+def train(use_cuda, is_sparse, is_parallel, save_dirname):
     PASS_NUM = 100
     EMBED_SIZE = 32
     HIDDEN_SIZE = 256
@@ -130,7 +85,7 @@ def train(use_cuda, is_sparse, is_parallel, save_dirname):
     forth_word = fluid.layers.data(name='forthw', shape=[1], dtype='int64')
     next_word = fluid.layers.data(name='nextw', shape=[1], dtype='int64')

-    if not parallel:
+    if not is_parallel:
         avg_cost, predict_word = __network__(
             [first_word, second_word, third_word, forth_word, next_word])
     else:
@@ -176,11 +131,67 @@ def train(use_cuda, is_sparse, parallel, save_dirname):
     raise AssertionError("Cost is too large {0:2.2}".format(avg_cost_np[0]))

-def main(use_cuda, is_sparse, parallel):
+def infer(use_cuda, save_dirname=None):
+    if save_dirname is None:
+        return
+
+    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+    exe = fluid.Executor(place)
+
+    inference_scope = fluid.core.Scope()
+    with fluid.scope_guard(inference_scope):
+        # Use fluid.io.load_inference_model to obtain the inference program desc,
+        # the feed_target_names (the names of variables that will be feeded
+        # data using feed operators), and the fetch_targets (variables that
+        # we want to obtain data from using fetch operators).
+        [inference_program, feed_target_names,
+         fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
+
+        word_dict = paddle.dataset.imikolov.build_dict()
+        dict_size = len(word_dict)
+
+        # Setup inputs, by creating 4 words, the lod of which should be [0, 1]
+        lod = [0, 1]
+        first_word = create_random_lodtensor(
+            lod, place, low=0, high=dict_size - 1)
+        second_word = create_random_lodtensor(
+            lod, place, low=0, high=dict_size - 1)
+        third_word = create_random_lodtensor(
+            lod, place, low=0, high=dict_size - 1)
+        fourth_word = create_random_lodtensor(
+            lod, place, low=0, high=dict_size - 1)
+
+        assert feed_target_names[0] == 'firstw'
+        assert feed_target_names[1] == 'secondw'
+        assert feed_target_names[2] == 'thirdw'
+        assert feed_target_names[3] == 'forthw'
+
+        # Construct feed as a dictionary of {feed_target_name: feed_target_data}
+        # and results will contain a list of data corresponding to fetch_targets.
+        results = exe.run(inference_program,
+                          feed={
+                              feed_target_names[0]: first_word,
+                              feed_target_names[1]: second_word,
+                              feed_target_names[2]: third_word,
+                              feed_target_names[3]: fourth_word
+                          },
+                          fetch_list=fetch_targets,
+                          return_numpy=False)
+        print(results[0].lod())
+        np_data = np.array(results[0])
+        print("Inference Shape: ", np_data.shape)
+
+def main(use_cuda, is_sparse, is_parallel):
     if use_cuda and not fluid.core.is_compiled_with_cuda():
         return
-    save_dirname = "word2vec.inference.model"
-    train(use_cuda, is_sparse, parallel, save_dirname)
+
+    if not is_parallel:
+        save_dirname = "word2vec.inference.model"
+    else:
+        save_dirname = None
+
+    train(use_cuda, is_sparse, is_parallel, save_dirname)
     infer(use_cuda, save_dirname)
@@ -193,10 +204,10 @@ class W2VTest(unittest.TestCase):
     pass

-def inject_test_method(use_cuda, is_sparse, parallel):
+def inject_test_method(use_cuda, is_sparse, is_parallel):
     fn_name = "test_{0}_{1}_{2}".format("cuda" if use_cuda else "cpu", "sparse"
                                         if is_sparse else "dense", "parallel"
-                                        if parallel else "normal")
+                                        if is_parallel else "normal")

     def __impl__(*args, **kwargs):
         prog = fluid.Program()
@@ -204,10 +215,12 @@ def inject_test_method(use_cuda, is_sparse, parallel):
         scope = fluid.core.Scope()
         with fluid.scope_guard(scope):
             with fluid.program_guard(prog, startup_prog):
-                main(use_cuda=use_cuda, is_sparse=is_sparse, parallel=parallel)
+                main(
+                    use_cuda=use_cuda,
+                    is_sparse=is_sparse,
+                    is_parallel=is_parallel)

-    # run only 2 cases: use_cuda is either True or False
-    if is_sparse == False and parallel == False:
+    if use_cuda and is_sparse:
         fn = __impl__
     else:
         # skip the other test when on CI server
@@ -219,8 +232,8 @@ def inject_test_method(use_cuda, is_sparse, parallel):
 for use_cuda in (False, True):
     for is_sparse in (False, True):
-        for parallel in (False, True):
-            inject_test_method(use_cuda, is_sparse, parallel)
+        for is_parallel in (False, True):
+            inject_test_method(use_cuda, is_sparse, is_parallel)

 if __name__ == '__main__':
     unittest.main()
@@ -46,7 +46,20 @@ def bipartite_match(distance, match_indices, match_dist):
         idx += 1

-def batch_bipartite_match(distance, lod):
+def argmax_match(distance, match_indices, match_dist, threshold):
+    r, c = distance.shape
+    for j in xrange(c):
+        if match_indices[j] != -1:
+            continue
+        col_dist = distance[:, j]
+        indices = np.argwhere(col_dist >= threshold).flatten()
+        if len(indices) < 1:
+            continue
+        match_indices[j] = indices[np.argmax(col_dist[indices])]
+        match_dist[j] = col_dist[match_indices[j]]
+
+def batch_bipartite_match(distance, lod, match_type=None, dist_threshold=None):
     """Bipartite Matching algorithm for batch input.
     Arg:
         distance (numpy.array) : The distance of two entries with shape [M, N].
@@ -59,6 +72,9 @@ def batch_bipartite_match(distance, lod, match_type=None, dist_threshold=None):
     for i in range(len(lod) - 1):
         bipartite_match(distance[lod[i]:lod[i + 1], :], match_indices[i, :],
                         match_dist[i, :])
+        if match_type == 'per_prediction':
+            argmax_match(distance[lod[i]:lod[i + 1], :], match_indices[i, :],
+                         match_dist[i, :], dist_threshold)
     return match_indices, match_dist
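A tiny worked example (made-up numbers) of the `per_prediction` fallback added above: after bipartite matching, a still-unmatched column is assigned the row with the largest distance, provided it clears `dist_threshold`:

```python
import numpy as np

dist = np.array([[0.9, 0.2, 0.6],
                 [0.3, 0.8, 0.4]])
match_indices = np.array([0, 1, -1])   # column 2 left unmatched by bipartite
match_dist = np.array([0.9, 0.8, 0.0])

col = dist[:, 2]                          # [0.6, 0.4]
cand = np.argwhere(col >= 0.5).flatten()  # rows clearing dist_threshold=0.5
if len(cand) > 0:
    match_indices[2] = cand[np.argmax(col[cand])]  # row 0 wins
    match_dist[2] = col[match_indices[2]]          # 0.6
print(match_indices, match_dist)  # [0 1 0] [0.9 0.8 0.6]
```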
@@ -71,8 +87,8 @@ class TestBipartiteMatchOpWithLoD(OpTest):
         self.inputs = {'DistMat': (dist, lod)}
         self.outputs = {
-            'ColToRowMatchIndices': (match_indices),
-            'ColToRowMatchDist': (match_dist),
+            'ColToRowMatchIndices': match_indices,
+            'ColToRowMatchDist': match_dist,
         }

     def test_check_output(self):
@@ -96,5 +112,27 @@ class TestBipartiteMatchOpWithoutLoD(OpTest):
         self.check_output()

+class TestBipartiteMatchOpWithPerPredictionType(OpTest):
+    def setUp(self):
+        self.op_type = 'bipartite_match'
+        lod = [[0, 5, 11, 23]]
+        dist = np.random.random((23, 237)).astype('float32')
+        match_indices, match_dist = batch_bipartite_match(dist, lod[0],
+                                                          'per_prediction', 0.5)
+
+        self.inputs = {'DistMat': (dist, lod)}
+        self.outputs = {
+            'ColToRowMatchIndices': match_indices,
+            'ColToRowMatchDist': match_dist,
+        }
+        self.attrs = {
+            'match_type': 'per_prediction',
+            'dist_threshold': 0.5,
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
 if __name__ == '__main__':
     unittest.main()