diff --git a/CMakeLists.txt b/CMakeLists.txt index 5db5c228be2d6491463ec1ddb17de7bec730bd44..7500e8ed3ca1a93bb7fb4716e98b2660b82ad430 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -60,7 +60,7 @@ option(USE_NNPACK "Compile PaddlePaddle with NNPACK library" OFF) option(WITH_DISTRIBUTE "Compile with grpc distributed support" OFF) option(USE_EIGEN_FOR_BLAS "Use matrix multiplication in Eigen" OFF) option(WITH_ARM_FP16 "Use half precision support on armv8.2-a cpu" OFF) -option(WITH_FAST_BUNDLE_TEST "Bundle tests that can be run in a single process together to reduce launch overhead" ON) +option(WITH_FAST_BUNDLE_TEST "Bundle tests that can be run in a single process together to reduce launch overhead" OFF) # CMAKE_BUILD_TYPE if(NOT CMAKE_BUILD_TYPE) diff --git a/doc/howto/capi/workflow_of_capi_cn.md b/doc/howto/capi/workflow_of_capi_cn.md index a61d2267bfdb7c32da528735b20d7c6a531aaa1f..1ccc72eefbc730b2eab2d51f5b04e50728b735d7 100644 --- a/doc/howto/capi/workflow_of_capi_cn.md +++ b/doc/howto/capi/workflow_of_capi_cn.md @@ -65,6 +65,7 @@ output_file = "output.paddle.model" merge_v2_model(net, param_file, output_file) ``` + 对[手写数字识别](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/capi/examples/model_inference/dense)这个示例,可直接运行 `python` [merge_v2_model.py](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/capi/examples/model_inference/dense/merge_v2_model.py)。序列化结果会写入当前运行目录下的`output.paddle.model`文件中。使用这种方式,运行时C-API可以通过指定`output.paddle.model`文件的路径来加载预测模型。 #### 注意事项 diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 0d2691e8115ad6de46dcd4fcd5b7fd79ed60ecb9..88863ab99eb765124bc825b4e9ec9dff890ba3cc 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -58,13 +58,13 @@ static void CreateTensor(Variable* var, proto::VarType::Type var_type) { var->GetMutable(); } else if (var_type == proto::VarType::CHANNEL) { var->GetMutable(); - } else if (var_type == proto::VarType::NCCL_COM) { - // GetMutable will be called in ncclInit + } else if (var_type == proto::VarType::RAW) { + // GetMutable will be called in operator } else { PADDLE_THROW( "Variable type %d is not in " "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, " - "LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, NCCL_COM]", + "LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, RAW]", var_type); } } diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto index 5b43f5a8a4a1c128b04ac206d387e30c55f533fe..53725d3d802c27202a6379cee518991a628cf9a1 100644 --- a/paddle/fluid/framework/framework.proto +++ b/paddle/fluid/framework/framework.proto @@ -113,7 +113,10 @@ message VarType { PLACE_LIST = 14; READER = 15; CHANNEL = 16; - NCCL_COM = 17; + // Any runtime decided variable type is raw + // raw variables should manage their own allocations + // in operators like nccl_op + RAW = 17; } required Type type = 1; diff --git a/paddle/fluid/framework/lod_tensor.cc b/paddle/fluid/framework/lod_tensor.cc index 4cf14c8da547d79258e99d0c64e83f9218a92910..e2f4e9cad1996578b7c51257785e1273d126f80f 100644 --- a/paddle/fluid/framework/lod_tensor.cc +++ b/paddle/fluid/framework/lod_tensor.cc @@ -31,8 +31,14 @@ std::ostream &operator<<(std::ostream &os, const LoD &lod) { os << "{"; for (auto &v : lod) { os << "{"; + bool is_first = true; for (auto &i : v) { - os << i << ","; + if (is_first) { + os << i; + is_first = false; + } else { + os << ", " << i; + } } os << "}"; } diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc index 71c5ab3db937f70ff84391e98d28f023f6dddcfb..80eb9889670744ae527ea29609b33631a021bfa8 100644 --- a/paddle/fluid/inference/io.cc +++ b/paddle/fluid/inference/io.cc @@ -32,23 +32,11 @@ void ReadBinaryFile(const std::string& filename, std::string& contents) { inputfs.close(); } -bool IsParameter(const framework::VarDesc* var, - const framework::ProgramDesc& main_program) { - if (var->Persistable()) { - // There are many unreachable variables in the program - for (size_t i = 0; i < main_program.Size(); ++i) { - const framework::BlockDesc& block = main_program.Block(i); - for (auto* op : block.AllOps()) { - if (op->Type() == framework::kFeedOpType) { - continue; - } - for (auto input_argument_name : op->InputArgumentNames()) { - if (input_argument_name == var->Name()) { - return true; - } - } - } - } +bool IsPersistable(const framework::VarDesc* var) { + if (var->Persistable() && + var->GetType() != framework::proto::VarType::FEED_MINIBATCH && + var->GetType() != framework::proto::VarType::FETCH_LIST) { + return true; } return false; } @@ -65,8 +53,8 @@ void LoadPersistables(framework::Executor& executor, std::vector paramlist; for (auto* var : global_block.AllVars()) { - if (IsParameter(var, main_program)) { - VLOG(3) << "parameter's name: " << var->Name(); + if (IsPersistable(var)) { + VLOG(3) << "persistable variable's name: " << var->Name(); framework::VarDesc* new_var = load_block->Var(var->Name()); new_var->SetShape(var->GetShape()); @@ -101,7 +89,6 @@ void LoadPersistables(framework::Executor& executor, executor.Run(*load_program, &scope, 0, true, true); - VLOG(3) << "Ran loading successfully"; delete load_program; } diff --git a/paddle/fluid/inference/tests/book/CMakeLists.txt b/paddle/fluid/inference/tests/book/CMakeLists.txt index 4ead540e5dd87ccf66168ab29c9d4aeaf6921269..e7ffb00ec8d8926193fe510ebdb7185f75c90906 100644 --- a/paddle/fluid/inference/tests/book/CMakeLists.txt +++ b/paddle/fluid/inference/tests/book/CMakeLists.txt @@ -30,5 +30,5 @@ inference_test(label_semantic_roles) inference_test(recognize_digits ARGS mlp conv) inference_test(recommender_system) #inference_test(rnn_encoder_decoder) -inference_test(understand_sentiment) +inference_test(understand_sentiment ARGS conv) inference_test(word2vec) diff --git a/paddle/fluid/inference/tests/book/test_inference_label_semantic_roles.cc b/paddle/fluid/inference/tests/book/test_inference_label_semantic_roles.cc index 443193aae8b38323883d460bc37a9c14430fc8bb..184924016634bba26204d937744ca5fa87cd443c 100644 --- a/paddle/fluid/inference/tests/book/test_inference_label_semantic_roles.cc +++ b/paddle/fluid/inference/tests/book/test_inference_label_semantic_roles.cc @@ -32,16 +32,42 @@ TEST(inference, label_semantic_roles) { paddle::framework::LoDTensor word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark; paddle::framework::LoD lod{{0, 4, 10}}; - - SetupLoDTensor(word, lod, static_cast(0), static_cast(1)); - SetupLoDTensor( - predicate, lod, static_cast(0), static_cast(1)); - SetupLoDTensor(ctx_n2, lod, static_cast(0), static_cast(1)); - SetupLoDTensor(ctx_n1, lod, static_cast(0), static_cast(1)); - SetupLoDTensor(ctx_0, lod, static_cast(0), static_cast(1)); - SetupLoDTensor(ctx_p1, lod, static_cast(0), static_cast(1)); - SetupLoDTensor(ctx_p2, lod, static_cast(0), static_cast(1)); - SetupLoDTensor(mark, lod, static_cast(0), static_cast(1)); + int64_t word_dict_len = 44068; + int64_t predicate_dict_len = 3162; + int64_t mark_dict_len = 2; + + SetupLoDTensor(word, + lod, + static_cast(0), + static_cast(word_dict_len - 1)); + SetupLoDTensor(predicate, + lod, + static_cast(0), + static_cast(predicate_dict_len - 1)); + SetupLoDTensor(ctx_n2, + lod, + static_cast(0), + static_cast(word_dict_len - 1)); + SetupLoDTensor(ctx_n1, + lod, + static_cast(0), + static_cast(word_dict_len - 1)); + SetupLoDTensor(ctx_0, + lod, + static_cast(0), + static_cast(word_dict_len - 1)); + SetupLoDTensor(ctx_p1, + lod, + static_cast(0), + static_cast(word_dict_len - 1)); + SetupLoDTensor(ctx_p2, + lod, + static_cast(0), + static_cast(word_dict_len - 1)); + SetupLoDTensor(mark, + lod, + static_cast(0), + static_cast(mark_dict_len - 1)); std::vector cpu_feeds; cpu_feeds.push_back(&word); diff --git a/paddle/fluid/inference/tests/book/test_inference_understand_sentiment.cc b/paddle/fluid/inference/tests/book/test_inference_understand_sentiment.cc index e67064fb61d18ff8db540a68e94729649e44cd1a..824b3274ebc7ba046e61798b3f61ef9924a75679 100644 --- a/paddle/fluid/inference/tests/book/test_inference_understand_sentiment.cc +++ b/paddle/fluid/inference/tests/book/test_inference_understand_sentiment.cc @@ -31,7 +31,12 @@ TEST(inference, understand_sentiment) { paddle::framework::LoDTensor words; paddle::framework::LoD lod{{0, 4, 10}}; - SetupLoDTensor(words, lod, static_cast(0), static_cast(10)); + int64_t word_dict_len = 5147; + + SetupLoDTensor(words, + lod, + static_cast(0), + static_cast(word_dict_len - 1)); std::vector cpu_feeds; cpu_feeds.push_back(&words); diff --git a/paddle/fluid/inference/tests/book/test_inference_word2vec.cc b/paddle/fluid/inference/tests/book/test_inference_word2vec.cc index e2f2f36a8222e03f77eca65d6331b4a52c0eea82..1481760c529c29a7290f476e2a22e1ded5ab7787 100644 --- a/paddle/fluid/inference/tests/book/test_inference_word2vec.cc +++ b/paddle/fluid/inference/tests/book/test_inference_word2vec.cc @@ -31,12 +31,12 @@ TEST(inference, word2vec) { paddle::framework::LoDTensor first_word, second_word, third_word, fourth_word; paddle::framework::LoD lod{{0, 1}}; - int64_t dict_size = 2072; // Hard-coding the size of dictionary + int64_t dict_size = 2073; // The size of dictionary - SetupLoDTensor(first_word, lod, static_cast(0), dict_size); - SetupLoDTensor(second_word, lod, static_cast(0), dict_size); - SetupLoDTensor(third_word, lod, static_cast(0), dict_size); - SetupLoDTensor(fourth_word, lod, static_cast(0), dict_size); + SetupLoDTensor(first_word, lod, static_cast(0), dict_size - 1); + SetupLoDTensor(second_word, lod, static_cast(0), dict_size - 1); + SetupLoDTensor(third_word, lod, static_cast(0), dict_size - 1); + SetupLoDTensor(fourth_word, lod, static_cast(0), dict_size - 1); std::vector cpu_feeds; cpu_feeds.push_back(&first_word); diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h index abe2032cc058e50a63ac72cccd90e060c6e14479..49518e50d8541477234f17ac5b8709aeb57662ff 100644 --- a/paddle/fluid/inference/tests/test_helper.h +++ b/paddle/fluid/inference/tests/test_helper.h @@ -101,8 +101,8 @@ void TestInference(const std::string& dirname, if (IsCombined) { // All parameters are saved in a single file. // Hard-coding the file names of program and parameters in unittest. - // Users are free to specify different filename - // (provided: the filenames are changed in the python api as well: io.py) + // The file names should be consistent with that used in Python API + // `fluid.io.save_inference_model`. std::string prog_filename = "__model_combined__"; std::string param_filename = "__params_combined__"; inference_program = paddle::inference::Load(executor, diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 8f14fd376ae51eff0f56c5a8d679c49cec23bd68..0bda0e05e0c83eeddd9829e5a7a9d3e97585ed64 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -11,6 +11,8 @@ function(op_library TARGET) set(cc_srcs) set(cu_srcs) set(cu_cc_srcs) + set(cudnn_cu_cc_srcs) + set(CUDNN_FILE) set(op_common_deps operator op_registry math_function) set(options "") set(oneValueArgs "") @@ -30,10 +32,16 @@ function(op_library TARGET) if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu) list(APPEND cu_srcs ${TARGET}.cu) endif() + string(REPLACE "_op" "_cudnn_op" CUDNN_FILE "${TARGET}") + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${CUDNN_FILE}.cu.cc) + list(APPEND cudnn_cu_cc_srcs ${CUDNN_FILE}.cu.cc) + endif() else() foreach(src ${op_library_SRCS}) if (${src} MATCHES ".*\\.cu$") list(APPEND cu_srcs ${src}) + elseif(${src} MATCHES ".*_cudnn_op.cu.cc$") + list(APPEND cudnn_cu_cc_srcs ${src}) elseif(${src} MATCHES ".*\\.cu.cc$") list(APPEND cu_cc_srcs ${src}) elseif(${src} MATCHES ".*\\.cc$") @@ -54,7 +62,7 @@ function(op_library TARGET) set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE) endif() if (WITH_GPU) - nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS} + nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS} ${op_common_deps}) else() cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${op_library_DEPS} @@ -98,6 +106,12 @@ function(op_library TARGET) set(pybind_flag 1) endif() + # pybind USE_OP_DEVICE_KERNEL for CUDNN + list(LENGTH cudnn_cu_cc_srcs cudnn_cu_cc_srcs_len) + if (WITH_GPU AND ${cudnn_cu_cc_srcs_len} GREATER 0) + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n") + endif() + # pybind USE_OP if (${pybind_flag} EQUAL 0) file(APPEND ${pybind_file} "USE_OP(${TARGET});\n") @@ -152,43 +166,24 @@ op_library(lstm_op DEPS sequence2batch lstm_compute) op_library(lstmp_op DEPS sequence2batch lstm_compute) op_library(gru_op DEPS sequence2batch gru_compute) op_library(recurrent_op DEPS executor) -op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale math_function) +op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) op_library(cos_sim_op DEPS cos_sim_functor) op_library(parallel_do_op DEPS executor) op_library(create_reader_op DEPS reader) # Regist multiple Kernel to pybind if (WITH_GPU) - -op_library(conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS - vol2col depthwise_conv) - -op_library(edit_distance_op SRCS edit_distance_op.cc edit_distance_op.cu DEPS math_function) -op_library(pool_op SRCS pool_op.cc pool_op.cu.cc pool_cudnn_op.cu.cc DEPS pooling) -op_library(conv_transpose_op SRCS conv_transpose_op.cc conv_transpose_op.cu.cc - conv_transpose_cudnn_op.cu.cc DEPS vol2col) -file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(conv2d, CUDNN);\n") -file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(pool2d, CUDNN);\n") -file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(conv2d_transpose, CUDNN);\n") + op_library(conv_op DEPS vol2col depthwise_conv) else() -op_library(conv_op SRCS conv_op.cc DEPS vol2col) -op_library(pool_op SRCS pool_op.cc DEPS pooling) -op_library(conv_transpose_op SRCS conv_transpose_op.cc DEPS vol2col) + op_library(conv_op DEPS vol2col) endif() +op_library(pool_op DEPS pooling) +op_library(conv_transpose_op DEPS vol2col) cc_library(batch_size_like SRCS batch_size_like.cc DEPS op_registry) - -op_library(fill_constant_batch_size_like_op - SRCS fill_constant_batch_size_like_op.cc fill_constant_batch_size_like_op.cu.cc - DEPS batch_size_like) - -op_library(uniform_random_batch_size_like_op - SRCS uniform_random_batch_size_like_op.cc - DEPS batch_size_like uniform_random_op) - -op_library(gaussian_random_batch_size_like_op - SRCS gaussian_random_batch_size_like_op.cc - DEPS batch_size_like gaussian_random_op) +op_library(fill_constant_batch_size_like_op DEPS batch_size_like) +op_library(uniform_random_batch_size_like_op DEPS batch_size_like uniform_random_op) +op_library(gaussian_random_batch_size_like_op DEPS batch_size_like gaussian_random_op) # FIXME(typhoonzero): save/load depends lodtensor serialization functions op_library(save_op DEPS lod_tensor) diff --git a/paddle/fluid/operators/bipartite_match_op.cc b/paddle/fluid/operators/bipartite_match_op.cc index c536cf6b6b822c8d9553d7d2cf57902e5e6e5343..2b3f26c0a890c33f9b4f4c8a5a271123d7ff0b31 100644 --- a/paddle/fluid/operators/bipartite_match_op.cc +++ b/paddle/fluid/operators/bipartite_match_op.cc @@ -94,6 +94,38 @@ class BipartiteMatchKernel : public framework::OpKernel { } } + void ArgMaxMatch(const Tensor& dist, int* match_indices, T* match_dist, + T overlap_threshold) const { + constexpr T kEPS = static_cast(1e-6); + int64_t row = dist.dims()[0]; + int64_t col = dist.dims()[1]; + auto* dist_data = dist.data(); + for (int64_t j = 0; j < col; ++j) { + if (match_indices[j] != -1) { + // the j-th column has been matched to one entity. + continue; + } + int max_row_idx = -1; + T max_dist = -1; + for (int i = 0; i < row; ++i) { + T dist = dist_data[i * col + j]; + if (dist < kEPS) { + // distance is 0 between m-th row and j-th column + continue; + } + if (dist >= overlap_threshold && dist > max_dist) { + max_row_idx = i; + max_dist = dist; + } + } + if (max_row_idx != -1) { + PADDLE_ENFORCE_EQ(match_indices[j], -1); + match_indices[j] = max_row_idx; + match_dist[j] = max_dist; + } + } + } + void Compute(const framework::ExecutionContext& context) const override { auto* dist_mat = context.Input("DistMat"); auto* match_indices = context.Output("ColToRowMatchIndices"); @@ -120,13 +152,21 @@ class BipartiteMatchKernel : public framework::OpKernel { int* indices = match_indices->data(); T* dist = match_dist->data(); + auto type = context.Attr("match_type"); + auto threshold = context.Attr("dist_threshold"); if (n == 1) { BipartiteMatch(*dist_mat, indices, dist); + if (type == "per_prediction") { + ArgMaxMatch(*dist_mat, indices, dist, threshold); + } } else { auto lod = dist_mat->lod().back(); for (size_t i = 0; i < lod.size() - 1; ++i) { Tensor one_ins = dist_mat->Slice(lod[i], lod[i + 1]); BipartiteMatch(one_ins, indices + i * col, dist + i * col); + if (type == "per_prediction") { + ArgMaxMatch(one_ins, indices + i * col, dist + i * col, threshold); + } } } } @@ -147,6 +187,19 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker { "This tensor can contain LoD information to represent a batch of " "inputs. One instance of this batch can contain different numbers of " "entities."); + AddAttr( + "match_type", + "(string, defalut: per_prediction) " + "The type of matching method, should be 'bipartite' or " + "'per_prediction', 'bipartite' by defalut.") + .SetDefault("bipartite") + .InEnum({"bipartite", "per_prediction"}); + AddAttr( + "dist_threshold", + "(float, defalut: 0.5) " + "If `match_type` is 'per_prediction', this threshold is to determine " + "the extra matching bboxes based on the maximum distance.") + .SetDefault(0.5); AddOutput("ColToRowMatchIndices", "(Tensor) A 2-D Tensor with shape [N, M] in int type. " "N is the batch size. If ColToRowMatchIndices[i][j] is -1, it " @@ -168,10 +221,10 @@ distance matrix. For input 2D matrix, the bipartite matching algorithm can find the matched column for each row, also can find the matched row for each column. And this operator only calculate matched indices from column to row. For each instance, the number of matched indices is the number of -of columns of the input ditance matrix. +of columns of the input distance matrix. There are two outputs to save matched indices and distance. -A simple description, this algothrim matched the best (maximum distance) +A simple description, this algorithm matched the best (maximum distance) row entity to the column entity and the matched indices are not duplicated in each row of ColToRowMatchIndices. If the column entity is not matched any row entity, set -1 in ColToRowMatchIndices. diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h index eb0e43ad2d84f681f39ed4adc5a27f6d3ab00f08..208a4481c6afe1b8f62e8f675c951c3349639f46 100644 --- a/paddle/fluid/operators/concat_op.h +++ b/paddle/fluid/operators/concat_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/strided_memcpy.h" @@ -34,12 +35,46 @@ class ConcatKernel : public framework::OpKernel { auto out_stride = framework::stride_numel(out->dims()); size_t output_offset = 0; - for (auto* in : ins) { - auto in_stride = framework::stride_numel(in->dims()); - StridedNumelCopyWithAxis(ctx.device_context(), axis, - out->data() + output_offset, out_stride, - in->data(), in_stride, in_stride[axis]); - output_offset += in_stride[axis]; + + // If axis >=1, copy to out immediately need to call many times + // of cuda memcpy. Copy the input to cpu and do the stride copy, + // then copy to gpu output. + + if (platform::is_gpu_place(place) && axis >= 1) { + platform::CPUPlace copy_place; + auto& cpu_ctx = *platform::DeviceContextPool::Instance().Get(copy_place); + framework::Tensor cpu_out; + cpu_out.Resize(out->dims()); + cpu_out.mutable_data(copy_place); + auto& dev_ctx = ctx.device_context(); + std::vector> cpu_ins; + for (auto* in : ins) { + std::unique_ptr cpu_in(new framework::Tensor); + framework::TensorCopy(*in, copy_place, dev_ctx, cpu_in.get()); + cpu_ins.emplace_back(std::move(cpu_in)); + } + // TODO(dzhwinter): overlap copy and compute stream + // https://devblogs.nvidia.com/how-overlap-data-transfers-cuda-cc/ + dev_ctx.Wait(); + + for (auto& in : cpu_ins) { + auto& cpu_in = *in.get(); + auto in_stride = framework::stride_numel(cpu_in.dims()); + + StridedNumelCopyWithAxis( + cpu_ctx, axis, cpu_out.data() + output_offset, out_stride, + cpu_in.data(), in_stride, in_stride[axis]); + output_offset += in_stride[axis]; + } + framework::TensorCopy(cpu_out, place, dev_ctx, out); + } else { + for (auto* in : ins) { + auto in_stride = framework::stride_numel(in->dims()); + StridedNumelCopyWithAxis(ctx.device_context(), axis, + out->data() + output_offset, out_stride, + in->data(), in_stride, in_stride[axis]); + output_offset += in_stride[axis]; + } } } }; diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index ee9044b1f5d46dc725c9583d0d90ab5681d2850c..7266f3276477891d3c7b6827316a428ef7a31c6e 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -177,8 +177,8 @@ std::shared_ptr RPCClient::GetChannel(const std::string& ep) { args.SetMaxSendMessageSize(std::numeric_limits::max()); args.SetMaxReceiveMessageSize(std::numeric_limits::max()); - auto ch = std::shared_ptr( - grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args)); + auto ch = + grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args); channels_[ep] = ch; return ch; diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index ee0e3533ce028992af3d4558e3fd198a09c4816b..8e9923c87ce22ed229f78ef15430e50cab16c947 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -129,6 +129,8 @@ class ListenAndServOp : public framework::OperatorBase { } if (exit_flag) { rpc_service_->ShutDown(); + rpc_service_->SetCond(1); + break; } try { executor.Run(*program, &recv_scope, block->ID(), /*global_block*/ diff --git a/paddle/fluid/operators/nccl_op.cc b/paddle/fluid/operators/nccl_op.cc index 0994bba782b42be994ae479f4c9c4de5a2e384ed..9185666c56c4621d42429c9cfdb079001c6336f1 100644 --- a/paddle/fluid/operators/nccl_op.cc +++ b/paddle/fluid/operators/nccl_op.cc @@ -65,7 +65,7 @@ class NCCLInitOpVarTypeInference : public framework::VarTypeInference { framework::BlockDesc *block) const override { auto out_var_name = op_desc.Output("Communicator").front(); auto &out_var = block->FindRecursiveOrCreateVar(out_var_name); - auto var_type = framework::proto::VarType::NCCL_COM; + auto var_type = framework::proto::VarType::RAW; out_var.SetType(var_type); } }; diff --git a/paddle/fluid/operators/send_op.cc b/paddle/fluid/operators/send_op.cc index 58850bf566e00f88de19305110e2ef696b73467e..178976f96fdbd08cead7b7c518ea1fbaaa2a5db8 100644 --- a/paddle/fluid/operators/send_op.cc +++ b/paddle/fluid/operators/send_op.cc @@ -121,9 +121,27 @@ This operator will send tensor to recv_op at the parameter server. } }; +class SendOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc& op_desc, + framework::BlockDesc* block) const override { + auto out_var_name = op_desc.Output("RPCClient").front(); + auto& out_var = block->FindRecursiveOrCreateVar(out_var_name); + auto var_type = framework::proto::VarType::RAW; + out_var.SetType(var_type); + } +}; + +class SendOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override {} +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(send, ops::SendOp, ops::SendOpMaker); +REGISTER_OPERATOR(send, ops::SendOp, paddle::framework::EmptyGradOpMaker, + ops::SendOpMaker, ops::SendOpVarTypeInference, + ops::SendOpShapeInference); diff --git a/paddle/fluid/operators/send_recv_op_test.cc b/paddle/fluid/operators/send_recv_op_test.cc index 008c012a32e0c88dfb0c05d7e485ffc367b3cac5..e9fb845b475ff5776bf948ab120a44c16ed87aa0 100644 --- a/paddle/fluid/operators/send_recv_op_test.cc +++ b/paddle/fluid/operators/send_recv_op_test.cc @@ -95,7 +95,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs, for (auto kv : outputs) { for (auto v : kv.second) { auto var = block->Var(v); - var->SetDataType(f::proto::DataType::FP32); + var->SetDataType(f::proto::VarType::FP32); } } @@ -122,33 +122,37 @@ void StartServerNet(bool is_sparse) { // sub program run in listen_and_serv_op, for simple test we use sum f::ProgramDesc program; - f::BlockDesc *block = program.MutableBlock(0); + f::BlockDesc *optimize_block = program.MutableBlock(0); // X for server side tensors, RX for received tensers, must be of same shape. - AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, block); + AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block); f::AttributeMap attrs; attrs.insert({"endpoint", std::string("127.0.0.1:6174")}); + attrs.insert({"Fanin", 1}); attrs.insert({"ParamList", std::vector({"Out"})}); attrs.insert({"GradList", std::vector({"x1"})}); - attrs.insert({"OptimizeBlock", block}); + attrs.insert({"OptimizeBlock", optimize_block}); listen_and_serv_op = - f::OpRegistry::CreateOp("listen_and_serv", {}, {}, attrs); + f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs); listen_and_serv_op->Run(scope, place); } TEST(SendRecvOp, CPUDense) { std::thread server_thread(StartServerNet, false); - sleep(10); // wait server to start + sleep(5); // wait server to start // local net f::Scope scope; p::CPUPlace place; InitTensorsInScope(scope, place); + // create rpc client var + scope.Var("RPC_CLIENT_VAR"); f::AttributeMap attrs; attrs.insert({"endpoints", std::vector({"127.0.0.1:6174"})}); attrs.insert({"epmap", std::vector({"127.0.0.1:6174"})}); - auto send_op = f::OpRegistry::CreateOp("send", {{"X", {"x1"}}}, - {{"Out", {"Out"}}}, attrs); + auto send_op = f::OpRegistry::CreateOp( + "send", {{"X", {"x1"}}}, + {{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs); send_op->Run(scope, place); auto in_var = scope.Var("x1"); @@ -175,11 +179,13 @@ TEST(SendRecvOp, CPUSparse) { p::CPUPlace place; p::CPUDeviceContext ctx(place); InitSelectedRowsInScope(scope, place); + scope.Var("RPC_CLIENT_VAR"); f::AttributeMap attrs; attrs.insert({"endpoints", std::vector({"127.0.0.1:6174"})}); attrs.insert({"epmap", std::vector({"127.0.0.1:6174"})}); - auto send_op = f::OpRegistry::CreateOp("send", {{"X", {"x1"}}}, - {{"Out", {"Out"}}}, attrs); + auto send_op = f::OpRegistry::CreateOp( + "send", {{"X", {"x1"}}}, + {{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs); send_op->Run(scope, place); auto x0 = scope.Var("x0")->GetMutable(); diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index b725be79529c5ccdde12446b5b5c09eaf47550e6..b0a2497d919b65afbe5eeaf4fe47c19baa1aba1c 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -252,7 +252,7 @@ void BindVarDsec(py::module &m) { .value("CHANNEL", proto::VarType::CHANNEL) .value("PLACE_LIST", proto::VarType::PLACE_LIST) .value("READER", proto::VarType::READER) - .value("NCCL_COM", proto::VarType::NCCL_COM); + .value("RAW", proto::VarType::RAW); } void BindOpDesc(py::module &m) { diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh index 8ec3d0c657400165c2225238f21facfb6c84df7c..2220a593b3bf3658a3bbb272a1ac0dc5a1d24f94 100644 --- a/paddle/scripts/docker/build.sh +++ b/paddle/scripts/docker/build.sh @@ -49,6 +49,7 @@ function cmake_gen() { -DCUDNN_ROOT=/usr/ -DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-ON} -DWITH_TESTING=${WITH_TESTING:-ON} + -DWITH_FAST_BUNDLE_TEST=ON -DCMAKE_EXPORT_COMPILE_COMMANDS=ON ======================================== EOF @@ -72,6 +73,7 @@ EOF -DCUDNN_ROOT=/usr/ \ -DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-ON} \ -DWITH_TESTING=${WITH_TESTING:-ON} \ + -DWITH_FAST_BUNDLE_TEST=ON \ -DCMAKE_EXPORT_COMPILE_COMMANDS=ON } diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py index 2fcf3753c5f1211d3b27f38fbdc8d097c437c79a..8da9ca290b22ae69b1fd195d8614c31dc4e13e00 100644 --- a/python/paddle/fluid/distribute_transpiler.py +++ b/python/paddle/fluid/distribute_transpiler.py @@ -226,8 +226,7 @@ class DistributeTranspiler: rpc_client_var = program.global_block().create_var( name="RPC_CLIENT_VAR", persistable=True, - dtype='float32', # dtype and shape is not used in fact - shape=[0]) + type=core.VarDesc.VarType.RAW) # create send_op program.global_block().append_op( diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 33f709ece48c450fbf893855edd59cd687cb0d9d..1817caa94275e4efa47ec1a5a0aa861255c75561 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -68,7 +68,7 @@ def save_vars(executor, main_program=None, vars=None, predicate=None, - save_file_name=None): + filename=None): """ Save variables to directory by executor. @@ -80,8 +80,8 @@ def save_vars(executor, as a bool. If it returns true, the corresponding input variable will be saved. :param vars: variables need to be saved. If vars is specified, program & predicate will be ignored - :param save_file_name: The name of a single file that all vars are saved to. - If it is None, save variables to separate files. + :param filename: The name of a single file that all vars are saved to. + If it is None, save variables to separate files. :return: None """ @@ -95,7 +95,7 @@ def save_vars(executor, executor, dirname=dirname, vars=filter(predicate, main_program.list_vars()), - save_file_name=save_file_name) + filename=filename) else: save_program = Program() save_block = save_program.global_block() @@ -103,7 +103,7 @@ def save_vars(executor, save_var_map = {} for each_var in vars: new_var = _clone_var_in_block_(save_block, each_var) - if save_file_name is None: + if filename is None: save_block.append_op( type='save', inputs={'X': [new_var]}, @@ -112,7 +112,7 @@ def save_vars(executor, else: save_var_map[new_var.name] = new_var - if save_file_name is not None: + if filename is not None: save_var_list = [] for name in sorted(save_var_map.keys()): save_var_list.append(save_var_map[name]) @@ -121,12 +121,12 @@ def save_vars(executor, type='save_combine', inputs={'X': save_var_list}, outputs={}, - attrs={'file_path': os.path.join(dirname, save_file_name)}) + attrs={'file_path': os.path.join(dirname, filename)}) executor.run(save_program) -def save_params(executor, dirname, main_program=None, save_file_name=None): +def save_params(executor, dirname, main_program=None, filename=None): """ Save all parameters to directory with executor. """ @@ -136,11 +136,10 @@ def save_params(executor, dirname, main_program=None, save_file_name=None): main_program=main_program, vars=None, predicate=is_parameter, - save_file_name=save_file_name) + filename=filename) -def save_persistables(executor, dirname, main_program=None, - save_file_name=None): +def save_persistables(executor, dirname, main_program=None, filename=None): """ Save all persistables to directory with executor. """ @@ -150,7 +149,7 @@ def save_persistables(executor, dirname, main_program=None, main_program=main_program, vars=None, predicate=is_persistable, - save_file_name=save_file_name) + filename=filename) def load_vars(executor, @@ -158,7 +157,7 @@ def load_vars(executor, main_program=None, vars=None, predicate=None, - load_file_name=None): + filename=None): """ Load variables from directory by executor. @@ -170,8 +169,8 @@ def load_vars(executor, as a bool. If it returns true, the corresponding input variable will be loaded. :param vars: variables need to be loaded. If vars is specified, program & predicate will be ignored - :param load_file_name: The name of the single file that all vars are loaded from. - If it is None, load variables from separate files. + :param filename: The name of the single file that all vars are loaded from. + If it is None, load variables from separate files. :return: None """ @@ -185,7 +184,7 @@ def load_vars(executor, executor, dirname=dirname, vars=filter(predicate, main_program.list_vars()), - load_file_name=load_file_name) + filename=filename) else: load_prog = Program() load_block = load_prog.global_block() @@ -194,7 +193,7 @@ def load_vars(executor, for each_var in vars: assert isinstance(each_var, Variable) new_var = _clone_var_in_block_(load_block, each_var) - if load_file_name is None: + if filename is None: load_block.append_op( type='load', inputs={}, @@ -203,7 +202,7 @@ def load_vars(executor, else: load_var_map[new_var.name] = new_var - if load_file_name is not None: + if filename is not None: load_var_list = [] for name in sorted(load_var_map.keys()): load_var_list.append(load_var_map[name]) @@ -212,12 +211,12 @@ def load_vars(executor, type='load_combine', inputs={}, outputs={"Out": load_var_list}, - attrs={'file_path': os.path.join(dirname, load_file_name)}) + attrs={'file_path': os.path.join(dirname, filename)}) executor.run(load_prog) -def load_params(executor, dirname, main_program=None, load_file_name=None): +def load_params(executor, dirname, main_program=None, filename=None): """ load all parameters from directory by executor. """ @@ -226,11 +225,10 @@ def load_params(executor, dirname, main_program=None, load_file_name=None): dirname=dirname, main_program=main_program, predicate=is_parameter, - load_file_name=load_file_name) + filename=filename) -def load_persistables(executor, dirname, main_program=None, - load_file_name=None): +def load_persistables(executor, dirname, main_program=None, filename=None): """ load all persistables from directory by executor. """ @@ -239,7 +237,7 @@ def load_persistables(executor, dirname, main_program=None, dirname=dirname, main_program=main_program, predicate=is_persistable, - load_file_name=load_file_name) + filename=filename) def get_inference_program(target_vars, main_program=None): @@ -299,7 +297,8 @@ def save_inference_model(dirname, target_vars, executor, main_program=None, - save_file_name=None): + model_filename=None, + params_filename=None): """ Build a model especially for inference, and save it to directory by the executor. @@ -310,8 +309,11 @@ def save_inference_model(dirname, :param executor: executor that save inference model :param main_program: original program, which will be pruned to build the inference model. Default default_main_program(). - :param save_file_name: The name of a single file that all parameters are saved to. - If it is None, save parameters to separate files. + :param model_filename: The name of file to save inference program. + If not specified, default filename `__model__` will be used. + :param params_filename: The name of file to save parameters. + It is used for the case that all parameters are saved in a single binary file. + If not specified, parameters are considered saved in separate files. :return: None """ @@ -342,15 +344,19 @@ def save_inference_model(dirname, prepend_feed_ops(inference_program, feeded_var_names) append_fetch_ops(inference_program, fetch_var_names) - if save_file_name == None: - model_file_name = dirname + "/__model__" + if model_filename is not None: + model_filename = os.path.basename(model_filename) else: - model_file_name = dirname + "/__model_combined__" + model_filename = "__model__" + model_filename = os.path.join(dirname, model_filename) - with open(model_file_name, "wb") as f: + if params_filename is not None: + params_filename = os.path.basename(params_filename) + + with open(model_filename, "wb") as f: f.write(inference_program.desc.serialize_to_string()) - save_persistables(executor, dirname, inference_program, save_file_name) + save_persistables(executor, dirname, inference_program, params_filename) def get_feed_targets_names(program): @@ -371,15 +377,21 @@ def get_fetch_targets_names(program): return fetch_targets_names -def load_inference_model(dirname, executor, load_file_name=None): +def load_inference_model(dirname, + executor, + model_filename=None, + params_filename=None): """ Load inference model from a directory :param dirname: directory path :param executor: executor that load inference model - :param load_file_name: The name of the single file that all parameters are loaded from. - If it is None, load parameters from separate files. - + :param model_filename: The name of file to load inference program. + If not specified, default filename `__model__` will be used. + :param params_filename: The name of file to load parameters. + It is used for the case that all parameters are saved in a single binary file. + If not specified, parameters are considered saved in separate files. + :return: [program, feed_target_names, fetch_targets] program: program especially for inference. feed_target_names: Names of variables that need to feed data @@ -388,16 +400,20 @@ def load_inference_model(dirname, executor, load_file_name=None): if not os.path.isdir(dirname): raise ValueError("There is no directory named '%s'", dirname) - if load_file_name == None: - model_file_name = dirname + "/__model__" + if model_filename is not None: + model_filename = os.path.basename(model_filename) else: - model_file_name = dirname + "/__model_combined__" + model_filename = "__model__" + model_filename = os.path.join(dirname, model_filename) + + if params_filename is not None: + params_filename = os.path.basename(params_filename) - with open(model_file_name, "rb") as f: + with open(model_filename, "rb") as f: program_desc_str = f.read() program = Program.parse_from_string(program_desc_str) - load_persistables(executor, dirname, program, load_file_name) + load_persistables(executor, dirname, program, params_filename) feed_target_names = get_feed_targets_names(program) fetch_target_names = get_fetch_targets_names(program) diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 2c2f80b4bfc713647b2faf2f9672b2925cac09b5..fff64a57a43bc3f1ce5806d66e857d033f780620 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -172,7 +172,10 @@ def detection_map(detect_res, return map_out, accum_pos_count_out, accum_true_pos_out, accum_false_pos_out -def bipartite_match(dist_matrix, name=None): +def bipartite_match(dist_matrix, + match_type=None, + dist_threshold=None, + name=None): """ **Bipartite matchint operator** @@ -204,6 +207,11 @@ def bipartite_match(dist_matrix, name=None): This tensor can contain LoD information to represent a batch of inputs. One instance of this batch can contain different numbers of entities. + match_type(string|None): The type of matching method, should be + 'bipartite' or 'per_prediction', 'bipartite' by defalut. + dist_threshold(float|None): If `match_type` is 'per_prediction', + this threshold is to determine the extra matching bboxes based + on the maximum distance, 0.5 by defalut. Returns: match_indices(Variable): A 2-D Tensor with shape [N, M] in int type. N is the batch size. If match_indices[i][j] is -1, it @@ -223,6 +231,10 @@ def bipartite_match(dist_matrix, name=None): helper.append_op( type='bipartite_match', inputs={'DistMat': dist_matrix}, + attrs={ + 'match_type': match_type, + 'dist_threshold': dist_threshold, + }, outputs={ 'ColToRowMatchIndices': match_indices, 'ColToRowMatchDist': match_distance @@ -373,7 +385,7 @@ def ssd_loss(location, loc_loss_weight (float): Weight for localization loss, 1.0 by default. conf_loss_weight (float): Weight for confidence loss, 1.0 by default. match_type (str): The type of matching method during training, should - be 'bipartite' or 'per_prediction'. + be 'bipartite' or 'per_prediction', 'per_prediction' by defalut. mining_type (str): The hard example mining type, should be 'hard_example' or 'max_negative', now only support `max_negative`. @@ -421,7 +433,8 @@ def ssd_loss(location, # 1.1 Compute IOU similarity between ground-truth boxes and prior boxes. iou = iou_similarity(x=gt_box, y=prior_box) # 1.2 Compute matched boundding box by bipartite matching algorithm. - matched_indices, matched_dist = bipartite_match(iou) + matched_indices, matched_dist = bipartite_match(iou, match_type, + overlap_threshold) # 2. Compute confidence for mining hard examples # 2.1. Get the target label based on matched indices diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 3453dd945d558a93a854f99209a6ea8055875d84..ead7041b7b20c7036bbea3da544f3b422c9f31fa 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -21,6 +21,7 @@ from ..framework import Variable from ..param_attr import ParamAttr from layer_function_generator import autodoc from tensor import concat +import utils __all__ = [ 'fc', @@ -1138,8 +1139,8 @@ def sequence_conv(input, def conv2d(input, num_filters, filter_size, - stride=None, - padding=None, + stride=1, + padding=0, groups=None, param_attr=None, bias_attr=None, @@ -1252,12 +1253,10 @@ def conv2d(input, raise ValueError("num_channels must be divisible by groups.") num_filter_channels = num_channels / groups - if isinstance(filter_size, int): - filter_size = [filter_size, filter_size] - if isinstance(stride, int): - stride = [stride, stride] - if isinstance(padding, int): - padding = [padding, padding] + filter_size = utils.convert_to_list(filter_size, 2, 'filter_size') + stride = utils.convert_to_list(stride, 2, 'stride') + padding = utils.convert_to_list(padding, 2, 'padding') + if not isinstance(use_cudnn, bool): raise ValueError("use_cudnn should be True or False") @@ -1432,10 +1431,10 @@ def sequence_last_step(input): def pool2d(input, - pool_size, - pool_type, - pool_stride=None, - pool_padding=None, + pool_size=-1, + pool_type="max", + pool_stride=1, + pool_padding=0, global_pooling=False, use_cudnn=True, name=None): @@ -1443,20 +1442,20 @@ def pool2d(input, This function adds the operator for pooling in 2 dimensions, using the pooling configurations mentioned in input parameters. """ - if pool_padding is None: - pool_padding = [0, 0] - if pool_stride is None: - pool_stride = [1, 1] if pool_type not in ["max", "avg"]: raise ValueError( "Unknown pool_type: '%s'. It can only be 'max' or 'avg'.", str(pool_type)) - if isinstance(pool_size, int): - pool_size = [pool_size, pool_size] - if isinstance(pool_stride, int): - pool_stride = [pool_stride, pool_stride] - if isinstance(pool_padding, int): - pool_padding = [pool_padding, pool_padding] + + if global_pooling is False and pool_size == -1: + raise ValueError( + "When the global_pooling is False, pool_size must be passed " + "and be a valid value. Received pool_size: " + str(pool_size)) + + pool_size = utils.convert_to_list(pool_size, 2, 'pool_size') + pool_padding = utils.convert_to_list(pool_padding, 2, 'pool_padding') + pool_stride = utils.convert_to_list(pool_stride, 2, 'pool_stride') + if not isinstance(use_cudnn, bool): raise ValueError("use_cudnn should be True or False") @@ -1685,9 +1684,9 @@ def conv2d_transpose(input, num_filters, output_size=None, filter_size=None, - padding=None, - stride=None, - dilation=None, + padding=0, + stride=1, + dilation=1, param_attr=None, use_cudnn=True, name=None): @@ -1783,26 +1782,12 @@ def conv2d_transpose(input, raise TypeError("Input of conv2d_transpose must be Variable") input_channel = input.shape[1] - op_attr = dict() - - if isinstance(padding, int): - op_attr['paddings'] = [padding, padding] - elif padding is not None: - op_attr['paddings'] = padding - - if isinstance(stride, int): - op_attr['strides'] = [stride, stride] - elif stride is not None: - op_attr['strides'] = stride - - if isinstance(dilation, int): - op_attr['dilations'] = [dilation, dilation] - elif dilation is not None: - op_attr['dilations'] = dilation + padding = utils.convert_to_list(padding, 2, 'padding') + stride = utils.convert_to_list(stride, 2, 'stride') + dilation = utils.convert_to_list(dilation, 2, 'dilation') if not isinstance(use_cudnn, bool): raise ValueError("use_cudnn should be True or False") - op_attr['use_cudnn'] = use_cudnn if filter_size is None: if output_size is None: @@ -1810,10 +1795,6 @@ def conv2d_transpose(input, if isinstance(output_size, int): output_size = [output_size, output_size] - padding = op_attr.get('paddings', [0, 0]) - stride = op_attr.get('strides', [1, 1]) - dilation = op_attr.get('dilations', [1, 1]) - h_in = input.shape[2] w_in = input.shape[3] @@ -1822,9 +1803,9 @@ def conv2d_transpose(input, filter_size_w = (output_size[1] - (w_in - 1) * stride[1] + 2 * padding[1] - 1) / dilation[1] + 1 filter_size = [filter_size_h, filter_size_w] - - elif isinstance(filter_size, int): - filter_size = [filter_size, filter_size] + else: + filter_size = utils.convert_to_list(filter_size, 2, + 'conv2d_transpose.filter_size') filter_shape = [input_channel, num_filters] + filter_size img_filter = helper.create_parameter( @@ -1836,7 +1817,12 @@ def conv2d_transpose(input, inputs={'Input': [input], 'Filter': [img_filter]}, outputs={'Output': out}, - attrs=op_attr) + attrs={ + 'strides': stride, + 'paddings': padding, + 'dilations': dilation, + 'use_cudnn': use_cudnn + }) return out diff --git a/python/paddle/fluid/layers/utils.py b/python/paddle/fluid/layers/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..49ec3088831dff415e042e1b0a632f63106eb07b --- /dev/null +++ b/python/paddle/fluid/layers/utils.py @@ -0,0 +1,59 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np + + +def convert_to_list(value, n, name, dtype=np.int): + """ + Converts a single numerical type or iterable of numerical + types into an numerical type list. + + Arguments: + value: The value to validate and convert. Could an int, or any iterable + of ints. + n: The size of the list to be returned. + name: The name of the argument being validated, e.g. "stride" or + "filter_size". This is only used to format error messages. + dtype: the numerical type of the element of the list to be returned. + + Returns: + A list of n dtypes. + + Raises: + ValueError: If something else than an int/long or iterable thereof was + passed. + """ + if isinstance(value, dtype): + return [value, ] * n + else: + try: + value_list = list(value) + except TypeError: + raise ValueError("The " + name + + "'s type must be list or tuple. Received: " + str( + value)) + if len(value_list) != n: + raise ValueError("The " + name + "'s length must be " + str(n) + + ". Received: " + str(value)) + for single_value in value_list: + try: + dtype(single_value) + except (ValueError, TypeError): + raise ValueError( + "The " + name + "'s type must be a list or tuple of " + str( + n) + " " + str(dtype) + " . Received: " + str( + value) + " " + "including element " + str(single_value) + " of type" + " " + + str(type(single_value))) + return value_list diff --git a/python/paddle/fluid/tests/book/notest_rnn_encoder_decoer.py b/python/paddle/fluid/tests/book/notest_rnn_encoder_decoer.py index 1ed58c3d3d170a938cef813692d7841227964b16..983f8f4dbeac83566839de25ec9765eb248be768 100644 --- a/python/paddle/fluid/tests/book/notest_rnn_encoder_decoer.py +++ b/python/paddle/fluid/tests/book/notest_rnn_encoder_decoer.py @@ -228,32 +228,34 @@ def infer(use_cuda, save_dirname=None): place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() exe = fluid.Executor(place) - # Use fluid.io.load_inference_model to obtain the inference program desc, - # the feed_target_names (the names of variables that will be feeded - # data using feed operators), and the fetch_targets (variables that - # we want to obtain data from using fetch operators). - [inference_program, feed_target_names, - fetch_targets] = fluid.io.load_inference_model(save_dirname, exe) - - lod = [0, 4, 10] - word_data = create_random_lodtensor(lod, place, low=0, high=1) - trg_word = create_random_lodtensor(lod, place, low=0, high=1) - - # Construct feed as a dictionary of {feed_target_name: feed_target_data} - # and results will contain a list of data corresponding to fetch_targets. - assert feed_target_names[0] == 'source_sequence' - assert feed_target_names[1] == 'target_sequence' - results = exe.run(inference_program, - feed={ - feed_target_names[0]: word_data, - feed_target_names[1]: trg_word, - }, - fetch_list=fetch_targets, - return_numpy=False) - print(results[0].lod()) - np_data = np.array(results[0]) - print("Inference shape: ", np_data.shape) - print("Inference results: ", np_data) + inference_scope = fluid.core.Scope() + with fluid.scope_guard(inference_scope): + # Use fluid.io.load_inference_model to obtain the inference program desc, + # the feed_target_names (the names of variables that will be feeded + # data using feed operators), and the fetch_targets (variables that + # we want to obtain data from using fetch operators). + [inference_program, feed_target_names, + fetch_targets] = fluid.io.load_inference_model(save_dirname, exe) + + lod = [0, 4, 10] + word_data = create_random_lodtensor(lod, place, low=0, high=1) + trg_word = create_random_lodtensor(lod, place, low=0, high=1) + + # Construct feed as a dictionary of {feed_target_name: feed_target_data} + # and results will contain a list of data corresponding to fetch_targets. + assert feed_target_names[0] == 'source_sequence' + assert feed_target_names[1] == 'target_sequence' + results = exe.run(inference_program, + feed={ + feed_target_names[0]: word_data, + feed_target_names[1]: trg_word, + }, + fetch_list=fetch_targets, + return_numpy=False) + print(results[0].lod()) + np_data = np.array(results[0]) + print("Inference shape: ", np_data.shape) + print("Inference results: ", np_data) def main(use_cuda): diff --git a/python/paddle/fluid/tests/book/test_fit_a_line.py b/python/paddle/fluid/tests/book/test_fit_a_line.py index 8ceee52ff9c425a6cc5479acb9c5b8f0928fc991..8a45533e3bfbacffbef3fc226892062d8cc8e6c7 100644 --- a/python/paddle/fluid/tests/book/test_fit_a_line.py +++ b/python/paddle/fluid/tests/book/test_fit_a_line.py @@ -72,23 +72,26 @@ def infer(use_cuda, save_dirname=None): place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() exe = fluid.Executor(place) - # Use fluid.io.load_inference_model to obtain the inference program desc, - # the feed_target_names (the names of variables that will be feeded - # data using feed operators), and the fetch_targets (variables that - # we want to obtain data from using fetch operators). - [inference_program, feed_target_names, - fetch_targets] = fluid.io.load_inference_model(save_dirname, exe) - - # The input's dimension should be 2-D and the second dim is 13 - # The input data should be >= 0 - batch_size = 10 - tensor_x = numpy.random.uniform(0, 10, [batch_size, 13]).astype("float32") - assert feed_target_names[0] == 'x' - results = exe.run(inference_program, - feed={feed_target_names[0]: tensor_x}, - fetch_list=fetch_targets) - print("infer shape: ", results[0].shape) - print("infer results: ", results[0]) + inference_scope = fluid.core.Scope() + with fluid.scope_guard(inference_scope): + # Use fluid.io.load_inference_model to obtain the inference program desc, + # the feed_target_names (the names of variables that will be feeded + # data using feed operators), and the fetch_targets (variables that + # we want to obtain data from using fetch operators). + [inference_program, feed_target_names, + fetch_targets] = fluid.io.load_inference_model(save_dirname, exe) + + # The input's dimension should be 2-D and the second dim is 13 + # The input data should be >= 0 + batch_size = 10 + tensor_x = numpy.random.uniform(0, 10, + [batch_size, 13]).astype("float32") + assert feed_target_names[0] == 'x' + results = exe.run(inference_program, + feed={feed_target_names[0]: tensor_x}, + fetch_list=fetch_targets) + print("infer shape: ", results[0].shape) + print("infer results: ", results[0]) def main(use_cuda): diff --git a/python/paddle/fluid/tests/book/test_image_classification.py b/python/paddle/fluid/tests/book/test_image_classification.py index 615e23529a9ac613d5e37ae68175cc09ad73b43f..60c66bc22c69ec836949d40ce2e18f8ecf0e07b8 100644 --- a/python/paddle/fluid/tests/book/test_image_classification.py +++ b/python/paddle/fluid/tests/book/test_image_classification.py @@ -174,22 +174,26 @@ def infer(use_cuda, save_dirname=None): place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() exe = fluid.Executor(place) - # Use fluid.io.load_inference_model to obtain the inference program desc, - # the feed_target_names (the names of variables that will be feeded - # data using feed operators), and the fetch_targets (variables that - # we want to obtain data from using fetch operators). - [inference_program, feed_target_names, - fetch_targets] = fluid.io.load_inference_model(save_dirname, exe) - - # The input's dimension of conv should be 4-D or 5-D. - tensor_img = numpy.random.rand(1, 3, 32, 32).astype("float32") - - # Construct feed as a dictionary of {feed_target_name: feed_target_data} - # and results will contain a list of data corresponding to fetch_targets. - results = exe.run(inference_program, - feed={feed_target_names[0]: tensor_img}, - fetch_list=fetch_targets) - print("infer results: ", results[0]) + inference_scope = fluid.core.Scope() + with fluid.scope_guard(inference_scope): + # Use fluid.io.load_inference_model to obtain the inference program desc, + # the feed_target_names (the names of variables that will be feeded + # data using feed operators), and the fetch_targets (variables that + # we want to obtain data from using fetch operators). + [inference_program, feed_target_names, + fetch_targets] = fluid.io.load_inference_model(save_dirname, exe) + + # The input's dimension of conv should be 4-D or 5-D. + # Use normilized image pixels as input data, which should be in the range [0, 1.0]. + batch_size = 1 + tensor_img = numpy.random.rand(batch_size, 3, 32, 32).astype("float32") + + # Construct feed as a dictionary of {feed_target_name: feed_target_data} + # and results will contain a list of data corresponding to fetch_targets. + results = exe.run(inference_program, + feed={feed_target_names[0]: tensor_img}, + fetch_list=fetch_targets) + print("infer results: ", results[0]) def main(net_type, use_cuda): diff --git a/python/paddle/fluid/tests/book/test_label_semantic_roles.py b/python/paddle/fluid/tests/book/test_label_semantic_roles.py index 336e6ed2a32aeefe07a17055ae29d6b82eb5041e..cbb4d4b0401d160db7b97ad3d5e6489e2766d19c 100644 --- a/python/paddle/fluid/tests/book/test_label_semantic_roles.py +++ b/python/paddle/fluid/tests/book/test_label_semantic_roles.py @@ -26,7 +26,7 @@ import unittest word_dict, verb_dict, label_dict = conll05.get_dict() word_dict_len = len(word_dict) label_dict_len = len(label_dict) -pred_len = len(verb_dict) +pred_dict_len = len(verb_dict) mark_dict_len = 2 word_dim = 32 @@ -53,7 +53,7 @@ def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark, # 8 features predicate_embedding = fluid.layers.embedding( input=predicate, - size=[pred_len, word_dim], + size=[pred_dict_len, word_dim], dtype='float32', is_sparse=IS_SPARSE, param_attr='vemb') @@ -234,6 +234,7 @@ def train(use_cuda, save_dirname=None): # Set the threshold low to speed up the CI test if float(pass_precision) > 0.05: if save_dirname is not None: + # TODO(liuyiqun): Change the target to crf_decode fluid.io.save_inference_model(save_dirname, [ 'word_data', 'verb_data', 'ctx_n2_data', 'ctx_n1_data', 'ctx_0_data', 'ctx_p1_data', @@ -251,51 +252,60 @@ def infer(use_cuda, save_dirname=None): place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() exe = fluid.Executor(place) - # Use fluid.io.load_inference_model to obtain the inference program desc, - # the feed_target_names (the names of variables that will be feeded - # data using feed operators), and the fetch_targets (variables that - # we want to obtain data from using fetch operators). - [inference_program, feed_target_names, - fetch_targets] = fluid.io.load_inference_model(save_dirname, exe) - - lod = [0, 4, 10] - ts_word = create_random_lodtensor(lod, place, low=0, high=1) - ts_pred = create_random_lodtensor(lod, place, low=0, high=1) - ts_ctx_n2 = create_random_lodtensor(lod, place, low=0, high=1) - ts_ctx_n1 = create_random_lodtensor(lod, place, low=0, high=1) - ts_ctx_0 = create_random_lodtensor(lod, place, low=0, high=1) - ts_ctx_p1 = create_random_lodtensor(lod, place, low=0, high=1) - ts_ctx_p2 = create_random_lodtensor(lod, place, low=0, high=1) - ts_mark = create_random_lodtensor(lod, place, low=0, high=1) - - # Construct feed as a dictionary of {feed_target_name: feed_target_data} - # and results will contain a list of data corresponding to fetch_targets. - assert feed_target_names[0] == 'word_data' - assert feed_target_names[1] == 'verb_data' - assert feed_target_names[2] == 'ctx_n2_data' - assert feed_target_names[3] == 'ctx_n1_data' - assert feed_target_names[4] == 'ctx_0_data' - assert feed_target_names[5] == 'ctx_p1_data' - assert feed_target_names[6] == 'ctx_p2_data' - assert feed_target_names[7] == 'mark_data' - - results = exe.run(inference_program, - feed={ - feed_target_names[0]: ts_word, - feed_target_names[1]: ts_pred, - feed_target_names[2]: ts_ctx_n2, - feed_target_names[3]: ts_ctx_n1, - feed_target_names[4]: ts_ctx_0, - feed_target_names[5]: ts_ctx_p1, - feed_target_names[6]: ts_ctx_p2, - feed_target_names[7]: ts_mark - }, - fetch_list=fetch_targets, - return_numpy=False) - print(results[0].lod()) - np_data = np.array(results[0]) - print("Inference Shape: ", np_data.shape) - print("Inference results: ", np_data) + inference_scope = fluid.core.Scope() + with fluid.scope_guard(inference_scope): + # Use fluid.io.load_inference_model to obtain the inference program desc, + # the feed_target_names (the names of variables that will be feeded + # data using feed operators), and the fetch_targets (variables that + # we want to obtain data from using fetch operators). + [inference_program, feed_target_names, + fetch_targets] = fluid.io.load_inference_model(save_dirname, exe) + + lod = [0, 4, 10] + word = create_random_lodtensor( + lod, place, low=0, high=word_dict_len - 1) + pred = create_random_lodtensor( + lod, place, low=0, high=pred_dict_len - 1) + ctx_n2 = create_random_lodtensor( + lod, place, low=0, high=word_dict_len - 1) + ctx_n1 = create_random_lodtensor( + lod, place, low=0, high=word_dict_len - 1) + ctx_0 = create_random_lodtensor( + lod, place, low=0, high=word_dict_len - 1) + ctx_p1 = create_random_lodtensor( + lod, place, low=0, high=word_dict_len - 1) + ctx_p2 = create_random_lodtensor( + lod, place, low=0, high=word_dict_len - 1) + mark = create_random_lodtensor( + lod, place, low=0, high=mark_dict_len - 1) + + # Construct feed as a dictionary of {feed_target_name: feed_target_data} + # and results will contain a list of data corresponding to fetch_targets. + assert feed_target_names[0] == 'word_data' + assert feed_target_names[1] == 'verb_data' + assert feed_target_names[2] == 'ctx_n2_data' + assert feed_target_names[3] == 'ctx_n1_data' + assert feed_target_names[4] == 'ctx_0_data' + assert feed_target_names[5] == 'ctx_p1_data' + assert feed_target_names[6] == 'ctx_p2_data' + assert feed_target_names[7] == 'mark_data' + + results = exe.run(inference_program, + feed={ + feed_target_names[0]: word, + feed_target_names[1]: pred, + feed_target_names[2]: ctx_n2, + feed_target_names[3]: ctx_n1, + feed_target_names[4]: ctx_0, + feed_target_names[5]: ctx_p1, + feed_target_names[6]: ctx_p2, + feed_target_names[7]: mark + }, + fetch_list=fetch_targets, + return_numpy=False) + print(results[0].lod()) + np_data = np.array(results[0]) + print("Inference Shape: ", np_data.shape) def main(use_cuda): diff --git a/python/paddle/fluid/tests/book/test_recognize_digits.py b/python/paddle/fluid/tests/book/test_recognize_digits.py index 12307111d5dda549bff7ea40ac7c341c69c3e4bd..285e91420375f63d8b37138f1565e7b77defb0c7 100644 --- a/python/paddle/fluid/tests/book/test_recognize_digits.py +++ b/python/paddle/fluid/tests/book/test_recognize_digits.py @@ -78,7 +78,12 @@ def conv_net(img, label): return loss_net(conv_pool_2, label) -def train(nn_type, use_cuda, parallel, save_dirname, save_param_filename): +def train(nn_type, + use_cuda, + parallel, + save_dirname=None, + model_filename=None, + params_filename=None): if use_cuda and not fluid.core.is_compiled_with_cuda(): return img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32') @@ -146,7 +151,8 @@ def train(nn_type, use_cuda, parallel, save_dirname, save_param_filename): fluid.io.save_inference_model( save_dirname, ["img"], [prediction], exe, - save_file_name=save_param_filename) + model_filename=model_filename, + params_filename=params_filename) return else: print( @@ -158,54 +164,62 @@ def train(nn_type, use_cuda, parallel, save_dirname, save_param_filename): raise AssertionError("Loss of recognize digits is too large") -def infer(use_cuda, save_dirname=None, param_filename=None): +def infer(use_cuda, + save_dirname=None, + model_filename=None, + params_filename=None): if save_dirname is None: return place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() exe = fluid.Executor(place) - # Use fluid.io.load_inference_model to obtain the inference program desc, - # the feed_target_names (the names of variables that will be feeded - # data using feed operators), and the fetch_targets (variables that - # we want to obtain data from using fetch operators). - [inference_program, feed_target_names, fetch_targets - ] = fluid.io.load_inference_model(save_dirname, exe, param_filename) - - # The input's dimension of conv should be 4-D or 5-D. - # Use normilized image pixels as input data, which should be in the range [-1.0, 1.0]. - batch_size = 1 - tensor_img = numpy.random.uniform(-1.0, 1.0, - [batch_size, 1, 28, 28]).astype("float32") - - # Construct feed as a dictionary of {feed_target_name: feed_target_data} - # and results will contain a list of data corresponding to fetch_targets. - results = exe.run(inference_program, - feed={feed_target_names[0]: tensor_img}, - fetch_list=fetch_targets) - print("infer results: ", results[0]) + inference_scope = fluid.core.Scope() + with fluid.scope_guard(inference_scope): + # Use fluid.io.load_inference_model to obtain the inference program desc, + # the feed_target_names (the names of variables that will be feeded + # data using feed operators), and the fetch_targets (variables that + # we want to obtain data from using fetch operators). + [inference_program, feed_target_names, + fetch_targets] = fluid.io.load_inference_model( + save_dirname, exe, model_filename, params_filename) + + # The input's dimension of conv should be 4-D or 5-D. + # Use normilized image pixels as input data, which should be in the range [-1.0, 1.0]. + batch_size = 1 + tensor_img = numpy.random.uniform( + -1.0, 1.0, [batch_size, 1, 28, 28]).astype("float32") + + # Construct feed as a dictionary of {feed_target_name: feed_target_data} + # and results will contain a list of data corresponding to fetch_targets. + results = exe.run(inference_program, + feed={feed_target_names[0]: tensor_img}, + fetch_list=fetch_targets) + print("infer results: ", results[0]) def main(use_cuda, parallel, nn_type, combine): + save_dirname = None + model_filename = None + params_filename = None if not use_cuda and not parallel: save_dirname = "recognize_digits_" + nn_type + ".inference.model" - save_filename = None if combine == True: - save_filename = "__params_combined__" - else: - save_dirname = None - save_filename = None + model_filename = "__model_combined__" + params_filename = "__params_combined__" train( nn_type=nn_type, use_cuda=use_cuda, parallel=parallel, save_dirname=save_dirname, - save_param_filename=save_filename) + model_filename=model_filename, + params_filename=params_filename) infer( use_cuda=use_cuda, save_dirname=save_dirname, - param_filename=save_filename) + model_filename=model_filename, + params_filename=params_filename) class TestRecognizeDigits(unittest.TestCase): diff --git a/python/paddle/fluid/tests/book/test_recommender_system.py b/python/paddle/fluid/tests/book/test_recommender_system.py index c190107e02e044635ff0b47c61de41c8bfed5acc..7c58c3e7823a82b5ccc7bb55a5e833969242ad96 100644 --- a/python/paddle/fluid/tests/book/test_recommender_system.py +++ b/python/paddle/fluid/tests/book/test_recommender_system.py @@ -251,13 +251,6 @@ def infer(use_cuda, save_dirname=None): place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() exe = fluid.Executor(place) - # Use fluid.io.load_inference_model to obtain the inference program desc, - # the feed_target_names (the names of variables that will be feeded - # data using feed operators), and the fetch_targets (variables that - # we want to obtain data from using fetch operators). - [inference_program, feed_target_names, - fetch_targets] = fluid.io.load_inference_model(save_dirname, exe) - def create_lod_tensor(data, lod=None): tensor = fluid.LoDTensor() if lod is None: @@ -275,44 +268,53 @@ def infer(use_cuda, save_dirname=None): tensor.set(flattened_data, place) return tensor - # Use the first data from paddle.dataset.movielens.test() as input - assert feed_target_names[0] == "user_id" - user_id = create_lod_tensor([[1]]) - - assert feed_target_names[1] == "gender_id" - gender_id = create_lod_tensor([[1]]) - - assert feed_target_names[2] == "age_id" - age_id = create_lod_tensor([[0]]) - - assert feed_target_names[3] == "job_id" - job_id = create_lod_tensor([[10]]) - - assert feed_target_names[4] == "movie_id" - movie_id = create_lod_tensor([[783]]) - - assert feed_target_names[5] == "category_id" - category_id = create_lod_tensor([[10], [8], [9]], [[0, 3]]) - - assert feed_target_names[6] == "movie_title" - movie_title = create_lod_tensor([[1069], [4140], [2923], [710], [988]], - [[0, 5]]) - - # Construct feed as a dictionary of {feed_target_name: feed_target_data} - # and results will contain a list of data corresponding to fetch_targets. - results = exe.run(inference_program, - feed={ - feed_target_names[0]: user_id, - feed_target_names[1]: gender_id, - feed_target_names[2]: age_id, - feed_target_names[3]: job_id, - feed_target_names[4]: movie_id, - feed_target_names[5]: category_id, - feed_target_names[6]: movie_title - }, - fetch_list=fetch_targets, - return_numpy=False) - print("inferred score: ", np.array(results[0])) + inference_scope = fluid.core.Scope() + with fluid.scope_guard(inference_scope): + # Use fluid.io.load_inference_model to obtain the inference program desc, + # the feed_target_names (the names of variables that will be feeded + # data using feed operators), and the fetch_targets (variables that + # we want to obtain data from using fetch operators). + [inference_program, feed_target_names, + fetch_targets] = fluid.io.load_inference_model(save_dirname, exe) + + # Use the first data from paddle.dataset.movielens.test() as input + assert feed_target_names[0] == "user_id" + user_id = create_lod_tensor([[1]]) + + assert feed_target_names[1] == "gender_id" + gender_id = create_lod_tensor([[1]]) + + assert feed_target_names[2] == "age_id" + age_id = create_lod_tensor([[0]]) + + assert feed_target_names[3] == "job_id" + job_id = create_lod_tensor([[10]]) + + assert feed_target_names[4] == "movie_id" + movie_id = create_lod_tensor([[783]]) + + assert feed_target_names[5] == "category_id" + category_id = create_lod_tensor([[10], [8], [9]], [[0, 3]]) + + assert feed_target_names[6] == "movie_title" + movie_title = create_lod_tensor([[1069], [4140], [2923], [710], [988]], + [[0, 5]]) + + # Construct feed as a dictionary of {feed_target_name: feed_target_data} + # and results will contain a list of data corresponding to fetch_targets. + results = exe.run(inference_program, + feed={ + feed_target_names[0]: user_id, + feed_target_names[1]: gender_id, + feed_target_names[2]: age_id, + feed_target_names[3]: job_id, + feed_target_names[4]: movie_id, + feed_target_names[5]: category_id, + feed_target_names[6]: movie_title + }, + fetch_list=fetch_targets, + return_numpy=False) + print("inferred score: ", np.array(results[0])) def main(use_cuda): diff --git a/python/paddle/fluid/tests/book/test_understand_sentiment.py b/python/paddle/fluid/tests/book/test_understand_sentiment.py index ab8df93651c01f75eeda1eab1ac95db867678106..fae74c355710e472734b8b15176baf2cfdc5acc4 100644 --- a/python/paddle/fluid/tests/book/test_understand_sentiment.py +++ b/python/paddle/fluid/tests/book/test_understand_sentiment.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -193,36 +193,39 @@ def train(word_dict, net_method, use_cuda, parallel=False, save_dirname=None): net_method.__name__)) -def infer(use_cuda, save_dirname=None): +def infer(word_dict, use_cuda, save_dirname=None): if save_dirname is None: return place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() exe = fluid.Executor(place) - # Use fluid.io.load_inference_model to obtain the inference program desc, - # the feed_target_names (the names of variables that will be feeded - # data using feed operators), and the fetch_targets (variables that - # we want to obtain data from using fetch operators). - [inference_program, feed_target_names, - fetch_targets] = fluid.io.load_inference_model(save_dirname, exe) - - lod = [0, 4, 10] - word_dict = paddle.dataset.imdb.word_dict() - tensor_words = create_random_lodtensor( - lod, place, low=0, high=len(word_dict) - 1) - - # Construct feed as a dictionary of {feed_target_name: feed_target_data} - # and results will contain a list of data corresponding to fetch_targets. - assert feed_target_names[0] == "words" - results = exe.run(inference_program, - feed={feed_target_names[0]: tensor_words}, - fetch_list=fetch_targets, - return_numpy=False) - print(results[0].lod()) - np_data = np.array(results[0]) - print("Inference Shape: ", np_data.shape) - print("Inference results: ", np_data) + inference_scope = fluid.core.Scope() + with fluid.scope_guard(inference_scope): + # Use fluid.io.load_inference_model to obtain the inference program desc, + # the feed_target_names (the names of variables that will be feeded + # data using feed operators), and the fetch_targets (variables that + # we want to obtain data from using fetch operators). + [inference_program, feed_target_names, + fetch_targets] = fluid.io.load_inference_model(save_dirname, exe) + + word_dict_len = len(word_dict) + + lod = [0, 4, 10] + tensor_words = create_random_lodtensor( + lod, place, low=0, high=word_dict_len - 1) + + # Construct feed as a dictionary of {feed_target_name: feed_target_data} + # and results will contain a list of data corresponding to fetch_targets. + assert feed_target_names[0] == "words" + results = exe.run(inference_program, + feed={feed_target_names[0]: tensor_words}, + fetch_list=fetch_targets, + return_numpy=False) + print(results[0].lod()) + np_data = np.array(results[0]) + print("Inference Shape: ", np_data.shape) + print("Inference results: ", np_data) def main(word_dict, net_method, use_cuda, parallel=False, save_dirname=None): @@ -258,7 +261,7 @@ class TestUnderstandSentiment(unittest.TestCase): self.word_dict, net_method=convolution_net, use_cuda=False, - save_dirname="understand_sentiment.inference.model") + save_dirname="understand_sentiment_conv.inference.model") def test_conv_cpu_parallel(self): with self.new_program_scope(): @@ -271,7 +274,11 @@ class TestUnderstandSentiment(unittest.TestCase): @unittest.skip(reason="make CI faster") def test_stacked_lstm_cpu(self): with self.new_program_scope(): - main(self.word_dict, net_method=stacked_lstm_net, use_cuda=False) + main( + self.word_dict, + net_method=stacked_lstm_net, + use_cuda=False, + save_dirname="understand_sentiment_stacked_lstm.inference.model") def test_stacked_lstm_cpu_parallel(self): with self.new_program_scope(): @@ -287,7 +294,7 @@ class TestUnderstandSentiment(unittest.TestCase): self.word_dict, net_method=convolution_net, use_cuda=True, - save_dirname="understand_sentiment.inference.model") + save_dirname="understand_sentiment_conv.inference.model") def test_conv_gpu_parallel(self): with self.new_program_scope(): @@ -300,7 +307,11 @@ class TestUnderstandSentiment(unittest.TestCase): @unittest.skip(reason="make CI faster") def test_stacked_lstm_gpu(self): with self.new_program_scope(): - main(self.word_dict, net_method=stacked_lstm_net, use_cuda=True) + main( + self.word_dict, + net_method=stacked_lstm_net, + use_cuda=True, + save_dirname="understand_sentiment_stacked_lstm.inference.model") def test_stacked_lstm_gpu_parallel(self): with self.new_program_scope(): diff --git a/python/paddle/fluid/tests/book/test_word2vec.py b/python/paddle/fluid/tests/book/test_word2vec.py index f33a759240f21f52817c482a2ebe008155dbd97b..696abd5499c826eda5c868ab1e7c9f4f839cdce3 100644 --- a/python/paddle/fluid/tests/book/test_word2vec.py +++ b/python/paddle/fluid/tests/book/test_word2vec.py @@ -1,5 +1,6 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# # Licensed under the Apache License, Version 2.0 (the "License"); +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -21,6 +22,7 @@ import sys def create_random_lodtensor(lod, place, low, high): + # The range of data elements is [low, high] data = np.random.random_integers(low, high, [lod[-1], 1]).astype("int64") res = fluid.LoDTensor() res.set(data, place) @@ -28,54 +30,7 @@ def create_random_lodtensor(lod, place, low, high): return res -def infer(use_cuda, save_dirname=None): - if save_dirname is None: - return - - place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() - exe = fluid.Executor(place) - - # Use fluid.io.load_inference_model to obtain the inference program desc, - # the feed_target_names (the names of variables that will be feeded - # data using feed operators), and the fetch_targets (variables that - # we want to obtain data from using fetch operators). - [inference_program, feed_target_names, - fetch_targets] = fluid.io.load_inference_model(save_dirname, exe) - - word_dict = paddle.dataset.imikolov.build_dict() - dict_size = len(word_dict) - 1 - - # Setup input, by creating 4 words, and setting up lod required for - # lookup_table_op - lod = [0, 1] - first_word = create_random_lodtensor(lod, place, low=0, high=dict_size) - second_word = create_random_lodtensor(lod, place, low=0, high=dict_size) - third_word = create_random_lodtensor(lod, place, low=0, high=dict_size) - fourth_word = create_random_lodtensor(lod, place, low=0, high=dict_size) - - assert feed_target_names[0] == 'firstw' - assert feed_target_names[1] == 'secondw' - assert feed_target_names[2] == 'thirdw' - assert feed_target_names[3] == 'forthw' - - # Construct feed as a dictionary of {feed_target_name: feed_target_data} - # and results will contain a list of data corresponding to fetch_targets. - results = exe.run(inference_program, - feed={ - feed_target_names[0]: first_word, - feed_target_names[1]: second_word, - feed_target_names[2]: third_word, - feed_target_names[3]: fourth_word - }, - fetch_list=fetch_targets, - return_numpy=False) - print(results[0].lod()) - np_data = np.array(results[0]) - print("Inference Shape: ", np_data.shape) - print("Inference results: ", np_data) - - -def train(use_cuda, is_sparse, parallel, save_dirname): +def train(use_cuda, is_sparse, is_parallel, save_dirname): PASS_NUM = 100 EMBED_SIZE = 32 HIDDEN_SIZE = 256 @@ -130,7 +85,7 @@ def train(use_cuda, is_sparse, parallel, save_dirname): forth_word = fluid.layers.data(name='forthw', shape=[1], dtype='int64') next_word = fluid.layers.data(name='nextw', shape=[1], dtype='int64') - if not parallel: + if not is_parallel: avg_cost, predict_word = __network__( [first_word, second_word, third_word, forth_word, next_word]) else: @@ -176,11 +131,67 @@ def train(use_cuda, is_sparse, parallel, save_dirname): raise AssertionError("Cost is too large {0:2.2}".format(avg_cost_np[0])) -def main(use_cuda, is_sparse, parallel): +def infer(use_cuda, save_dirname=None): + if save_dirname is None: + return + + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + + inference_scope = fluid.core.Scope() + with fluid.scope_guard(inference_scope): + # Use fluid.io.load_inference_model to obtain the inference program desc, + # the feed_target_names (the names of variables that will be feeded + # data using feed operators), and the fetch_targets (variables that + # we want to obtain data from using fetch operators). + [inference_program, feed_target_names, + fetch_targets] = fluid.io.load_inference_model(save_dirname, exe) + + word_dict = paddle.dataset.imikolov.build_dict() + dict_size = len(word_dict) + + # Setup inputs, by creating 4 words, the lod of which should be [0, 1] + lod = [0, 1] + first_word = create_random_lodtensor( + lod, place, low=0, high=dict_size - 1) + second_word = create_random_lodtensor( + lod, place, low=0, high=dict_size - 1) + third_word = create_random_lodtensor( + lod, place, low=0, high=dict_size - 1) + fourth_word = create_random_lodtensor( + lod, place, low=0, high=dict_size - 1) + + assert feed_target_names[0] == 'firstw' + assert feed_target_names[1] == 'secondw' + assert feed_target_names[2] == 'thirdw' + assert feed_target_names[3] == 'forthw' + + # Construct feed as a dictionary of {feed_target_name: feed_target_data} + # and results will contain a list of data corresponding to fetch_targets. + results = exe.run(inference_program, + feed={ + feed_target_names[0]: first_word, + feed_target_names[1]: second_word, + feed_target_names[2]: third_word, + feed_target_names[3]: fourth_word + }, + fetch_list=fetch_targets, + return_numpy=False) + print(results[0].lod()) + np_data = np.array(results[0]) + print("Inference Shape: ", np_data.shape) + + +def main(use_cuda, is_sparse, is_parallel): if use_cuda and not fluid.core.is_compiled_with_cuda(): return - save_dirname = "word2vec.inference.model" - train(use_cuda, is_sparse, parallel, save_dirname) + + if not is_parallel: + save_dirname = "word2vec.inference.model" + else: + save_dirname = None + + train(use_cuda, is_sparse, is_parallel, save_dirname) infer(use_cuda, save_dirname) @@ -193,10 +204,10 @@ class W2VTest(unittest.TestCase): pass -def inject_test_method(use_cuda, is_sparse, parallel): +def inject_test_method(use_cuda, is_sparse, is_parallel): fn_name = "test_{0}_{1}_{2}".format("cuda" if use_cuda else "cpu", "sparse" if is_sparse else "dense", "parallel" - if parallel else "normal") + if is_parallel else "normal") def __impl__(*args, **kwargs): prog = fluid.Program() @@ -204,10 +215,12 @@ def inject_test_method(use_cuda, is_sparse, parallel): scope = fluid.core.Scope() with fluid.scope_guard(scope): with fluid.program_guard(prog, startup_prog): - main(use_cuda=use_cuda, is_sparse=is_sparse, parallel=parallel) + main( + use_cuda=use_cuda, + is_sparse=is_sparse, + is_parallel=is_parallel) - # run only 2 cases: use_cuda is either True or False - if is_sparse == False and parallel == False: + if use_cuda and is_sparse: fn = __impl__ else: # skip the other test when on CI server @@ -219,8 +232,8 @@ def inject_test_method(use_cuda, is_sparse, parallel): for use_cuda in (False, True): for is_sparse in (False, True): - for parallel in (False, True): - inject_test_method(use_cuda, is_sparse, parallel) + for is_parallel in (False, True): + inject_test_method(use_cuda, is_sparse, is_parallel) if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_bipartite_match_op.py b/python/paddle/fluid/tests/unittests/test_bipartite_match_op.py index 9f9af2f55e2e9a1c624fb95f1c113e24c2de4a89..f7461ee6dab699064153332116449c8e20a0bac0 100644 --- a/python/paddle/fluid/tests/unittests/test_bipartite_match_op.py +++ b/python/paddle/fluid/tests/unittests/test_bipartite_match_op.py @@ -46,7 +46,20 @@ def bipartite_match(distance, match_indices, match_dist): idx += 1 -def batch_bipartite_match(distance, lod): +def argmax_match(distance, match_indices, match_dist, threshold): + r, c = distance.shape + for j in xrange(c): + if match_indices[j] != -1: + continue + col_dist = distance[:, j] + indices = np.argwhere(col_dist >= threshold).flatten() + if len(indices) < 1: + continue + match_indices[j] = indices[np.argmax(col_dist[indices])] + match_dist[j] = col_dist[match_indices[j]] + + +def batch_bipartite_match(distance, lod, match_type=None, dist_threshold=None): """Bipartite Matching algorithm for batch input. Arg: distance (numpy.array) : The distance of two entries with shape [M, N]. @@ -59,6 +72,9 @@ def batch_bipartite_match(distance, lod): for i in range(len(lod) - 1): bipartite_match(distance[lod[i]:lod[i + 1], :], match_indices[i, :], match_dist[i, :]) + if match_type == 'per_prediction': + argmax_match(distance[lod[i]:lod[i + 1], :], match_indices[i, :], + match_dist[i, :], dist_threshold) return match_indices, match_dist @@ -71,8 +87,8 @@ class TestBipartiteMatchOpWithLoD(OpTest): self.inputs = {'DistMat': (dist, lod)} self.outputs = { - 'ColToRowMatchIndices': (match_indices), - 'ColToRowMatchDist': (match_dist), + 'ColToRowMatchIndices': match_indices, + 'ColToRowMatchDist': match_dist, } def test_check_output(self): @@ -96,5 +112,27 @@ class TestBipartiteMatchOpWithoutLoD(OpTest): self.check_output() +class TestBipartiteMatchOpWithPerPredictionType(OpTest): + def setUp(self): + self.op_type = 'bipartite_match' + lod = [[0, 5, 11, 23]] + dist = np.random.random((23, 237)).astype('float32') + match_indices, match_dist = batch_bipartite_match(dist, lod[0], + 'per_prediction', 0.5) + + self.inputs = {'DistMat': (dist, lod)} + self.outputs = { + 'ColToRowMatchIndices': match_indices, + 'ColToRowMatchDist': match_dist, + } + self.attrs = { + 'match_type': 'per_prediction', + 'dist_threshold': 0.5, + } + + def test_check_output(self): + self.check_output() + + if __name__ == '__main__': unittest.main()