diff --git a/CMakeLists.txt b/CMakeLists.txt index 26d94384a9150735aa8341fd8a18cb039895ff91..02752de762ca6df552af362dd4624cce21becda6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -47,33 +47,34 @@ find_package(Threads REQUIRED) include(simd) -################################ Configurations ####################################### +################################ Exposed Configurations ####################################### option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND}) -option(WITH_AMD_GPU "Compile PaddlePaddle with AMD GPU" OFF) +option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON) option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND}) +option(WITH_PYTHON "Compile PaddlePaddle with python interpreter" ON) +option(WITH_TESTING "Compile PaddlePaddle with unit testing" OFF) option(WITH_MKL "Compile PaddlePaddle with MKL support." ${AVX_FOUND}) +option(WITH_SYSTEM_BLAS "Use system blas library" OFF) +option(WITH_DISTRIBUTE "Compile with distributed support" OFF) +option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF) +option(ON_INFER "Turn on inference optimization." OFF) +option(WITH_ANAKIN "Compile with Anakin library" OFF) +################################ Internal Configurations ####################################### +option(WITH_AMD_GPU "Compile PaddlePaddle with AMD GPU" OFF) option(WITH_NGRAPH "Compile PaddlePaddle with nGraph support." OFF) -option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON) -option(WITH_TESTING "Compile PaddlePaddle with unit testing" OFF) -option(WITH_PYTHON "Compile PaddlePaddle with python interpreter" ON) option(WITH_PROFILER "Compile PaddlePaddle with GPU profiler and gperftools" OFF) option(WITH_JEMALLOC "Compile PaddlePaddle with jemalloc" OFF) option(WITH_COVERAGE "Compile PaddlePaddle with code coverage" OFF) option(COVERALLS_UPLOAD "Package code coverage data to coveralls" OFF) -option(WITH_DISTRIBUTE "Compile with distributed support" OFF) option(WITH_PSLIB "Compile with pslib support" OFF) option(WITH_CONTRIB "Compile the third-party contributation" OFF) option(REPLACE_ENFORCE_GLOG "Replace PADDLE_ENFORCE with glog/CHECK for better debug." OFF) # TODO(Superjomn) Remove WITH_ANAKIN option if not needed latter. -option(WITH_ANAKIN "Compile with Anakin library" OFF) option(ANAKIN_BUILD_FAT_BIN "Build anakin cuda fat-bin lib for all device plantform, ignored when WITH_ANAKIN=OFF" OFF) option(ANAKIN_BUILD_CROSS_PLANTFORM "Build anakin lib for any nvidia device plantform. ignored when WITH_ANAKIN=OFF" ON) option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE}) -option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF) -option(ON_INFER "Turn on inference optimization." OFF) option(WITH_INFERENCE_API_TEST "Test fluid inference C++ high-level api interface" OFF) option(WITH_HIGH_LEVEL_API_TEST "Test fluid python high-level api interface" OFF) -option(WITH_SYSTEM_BLAS "Use system blas library" OFF) option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION}) option(WITH_FAST_MATH "Make use of fast math library, might affect the precision to some extent" ON) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index bf39325cc9bfb258051ec1a7fc7f5eb139c60133..f8c97b05bc08a26c1fa0bdcc2cd1abd932af158a 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -241,6 +241,7 @@ paddle.fluid.layers.tree_conv (ArgSpec(args=['nodes_vector', 'edge_set', 'output paddle.fluid.layers.npair_loss (ArgSpec(args=['anchor', 'positive', 'labels', 'l2_reg'], varargs=None, keywords=None, defaults=(0.002,)), ('document', '46994d10276dd4cb803b4062b5d14329')) paddle.fluid.layers.pixel_shuffle (ArgSpec(args=['x', 'upscale_factor'], varargs=None, keywords=None, defaults=None), ('document', '731b21c62a4add60a33bd76d802ffc5c')) paddle.fluid.layers.fsp_matrix (ArgSpec(args=['x', 'y'], varargs=None, keywords=None, defaults=None), ('document', 'b76ccca3735bea4a58a0dbf0d77c5393')) +paddle.fluid.layers.continuous_value_model (ArgSpec(args=['input', 'cvm', 'use_cvm'], varargs=None, keywords=None, defaults=(True,)), ('document', 'a07a44c2bacdcd09c1f5f35a96a0514e')) paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '33bbd42027d872b3818b3d64ec52e139')) paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'b1ae2e1cc0750e58726374061ea90ecc')) paddle.fluid.layers.read_file (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', 'b0a1c2fc51c27a106da28f3308c41f5e')) diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc index e9aad5d264d1745662848d1ba313b573d0974cb7..7f63c07b18f7c6147670656dfc567f8f2ae8429a 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc @@ -64,9 +64,12 @@ void ProcessGraph(std::vector graphs, Scope *scope) { node->Op()->GetNullableAttr("epmap")); auto height_section = boost::get>( node->Op()->GetNullableAttr("sections")); + auto trainer_id = + boost::get(node->Op()->GetNullableAttr("trainer_id")); send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext(send_var_name, send_varnames, - epmap, height_section); + epmap, height_section, + trainer_id); VLOG(3) << "find and init an send op: " << send_varname_to_ctx[send_var_name]; } else if (node->Name() == "recv") { @@ -75,9 +78,11 @@ void ProcessGraph(std::vector graphs, Scope *scope) { node->Op()->GetNullableAttr("recv_varnames")); auto epmap = boost::get>( node->Op()->GetNullableAttr("epmap")); + auto trainer_id = + boost::get(node->Op()->GetNullableAttr("trainer_id")); recv_varname_to_ctx[recv_var_name] = operators::distributed::RpcContext(recv_var_name, recv_varnames, - epmap, {}); + epmap, {}, trainer_id); nodes_to_delete.push_back(node); VLOG(3) << "find and remove an recv op: " << recv_varname_to_ctx[recv_var_name]; diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 0155609a029664da2c3d4c90a152ec556927c32d..fcab1ab186127e40701da9420426b4ef27c7f95d 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -832,6 +832,45 @@ std::string AnalysisPredictor::GetSerializedProgram() const { return inference_program_->Proto()->SerializeAsString(); } +// Add SaveOptimModel +void AnalysisPredictor::SaveOptimModel(const std::string &dir) { + // save model + std::string model_name = dir + "/model"; + std::ofstream outfile; + outfile.open(model_name, std::ios::out | std::ios::binary); + std::string inference_prog_desc = GetSerializedProgram(); + outfile << inference_prog_desc; + // save params + framework::ProgramDesc save_program; + auto *save_block = save_program.MutableBlock(0); + + const framework::ProgramDesc &main_program = program(); + const framework::BlockDesc &global_block = main_program.Block(0); + std::vector save_var_list; + for (framework::VarDesc *var : global_block.AllVars()) { + if (IsPersistable(var)) { + framework::VarDesc *new_var = save_block->Var(var->Name()); + new_var->SetShape(var->GetShape()); + new_var->SetDataType(var->GetDataType()); + new_var->SetType(var->GetType()); + new_var->SetLoDLevel(var->GetLoDLevel()); + new_var->SetPersistable(true); + + save_var_list.push_back(new_var->Name()); + } + } + std::sort(save_var_list.begin(), save_var_list.end()); + auto *op = save_block->AppendOp(); + op->SetType("save_combine"); + op->SetInput("X", save_var_list); + op->SetAttr("file_path", dir + "/params"); + op->CheckAttrs(); + + platform::CPUPlace place; + framework::Executor exe(place); + exe.Run(save_program, scope(), 0, true, true); +} + template <> std::unique_ptr CreatePaddlePredictor( const AnalysisConfig &config) { diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h index e4c537f426650f16ced32d3cb61b944a78c35b43..b5e134ced70f8bf9ef0267bee08ec9836aeb5338 100644 --- a/paddle/fluid/inference/api/analysis_predictor.h +++ b/paddle/fluid/inference/api/analysis_predictor.h @@ -86,6 +86,10 @@ class AnalysisPredictor : public PaddlePredictor { bool MkldnnQuantize(); + // save program to model + // save parameters to params + void SaveOptimModel(const std::string &dir); + protected: // For memory optimization. bool need_collect_var_shapes_for_memory_optim(); diff --git a/paddle/fluid/inference/api/analysis_predictor_tester.cc b/paddle/fluid/inference/api/analysis_predictor_tester.cc index 0429a287c74f9db5257181151d90b77da86c694c..6bc892638c28ca0b5bab82936bf9700289bed6b2 100644 --- a/paddle/fluid/inference/api/analysis_predictor_tester.cc +++ b/paddle/fluid/inference/api/analysis_predictor_tester.cc @@ -196,6 +196,9 @@ TEST(AnalysisPredictor, Clone) { } } +// This function is not released yet, will fail on some machine. +// TODO(Superjomn) Turn on it latter. +/* TEST(AnalysisPredictor, memory_optim) { AnalysisConfig config(FLAGS_dirname); config.DisableGpu(); @@ -246,6 +249,7 @@ TEST(AnalysisPredictor, memory_optim) { inference::CompareResult(output, output1); } +*/ #ifdef PADDLE_WITH_MKLDNN class MkldnnQuantizerTest : public testing::Test { diff --git a/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc b/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc index e10d239a5d1b30e089a110c6155520e3b035860a..c9da5b3ea5581e415f11c8f85e1d6aea757531ab 100644 --- a/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc @@ -170,6 +170,15 @@ void SetConfig(AnalysisConfig *cfg) { cfg->SwitchIrOptim(true); } +void SetOptimConfig(AnalysisConfig *cfg) { + std::string optimModelPath = + FLAGS_infer_model.substr(0, FLAGS_infer_model.find_last_of("/")) + + "/saved_optim_model"; + cfg->SetModel(optimModelPath + "/model", optimModelPath + "/params"); + cfg->SwitchIrOptim(true); + cfg->SwitchSpecifyInputNames(); +} + void SetInput(std::vector> *inputs) { DataRecord data(FLAGS_infer_data, FLAGS_batch_size); std::vector input_slots; @@ -315,5 +324,44 @@ TEST(Analyzer_dam, compare_determine) { input_slots_all); } +// Save optim model +TEST(Analyzer_dam, save_optim_model) { + AnalysisConfig cfg; + SetConfig(&cfg); + std::string optimModelPath = + FLAGS_infer_model.substr(0, FLAGS_infer_model.find_last_of("/")) + + "/saved_optim_model"; + mkdir(optimModelPath.c_str(), 0777); + auto predictor = CreateTestPredictor( + reinterpret_cast(&cfg), + FLAGS_use_analysis); + (static_cast(predictor.get())) + ->SaveOptimModel(optimModelPath); +} + +void CompareOptimAndOrig(const PaddlePredictor::Config *orig_config, + const PaddlePredictor::Config *optim_config, + const std::vector> &inputs) { + PrintConfig(orig_config, true); + PrintConfig(optim_config, true); + std::vector> orig_outputs, optim_outputs; + TestOneThreadPrediction(orig_config, inputs, &orig_outputs, false); + TestOneThreadPrediction(optim_config, inputs, &optim_outputs, false); + CompareResult(orig_outputs.back(), optim_outputs.back()); +} + +TEST(Analyzer_dam, compare_optim_orig) { + AnalysisConfig orig_cfg; + AnalysisConfig optim_cfg; + SetConfig(&orig_cfg); + SetOptimConfig(&optim_cfg); + std::vector> input_slots_all; + SetInput(&input_slots_all); + CompareOptimAndOrig( + reinterpret_cast(&orig_cfg), + reinterpret_cast(&optim_cfg), + input_slots_all); +} + } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc b/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc index d4330e6cddf8818ace01be2f13a4c18a192c46e1..588c80aa607c8d79365bbdfbb42a3d3c7667dbb2 100644 --- a/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc @@ -32,6 +32,17 @@ void SetInput(std::vector> *inputs) { SetFakeImageInput(inputs, FLAGS_infer_model); } +void SetOptimConfig(AnalysisConfig *cfg) { + std::string optimModelPath = + FLAGS_infer_model.substr(0, FLAGS_infer_model.find_last_of("/")) + + "/saved_optim_model"; + cfg->SetModel(optimModelPath + "/model", optimModelPath + "/params"); + cfg->DisableGpu(); + cfg->SwitchIrOptim(); + cfg->SwitchSpecifyInputNames(); + cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads); +} + // Easy for profiling independently. void profile(bool use_mkldnn = false) { AnalysisConfig cfg; @@ -87,13 +98,51 @@ TEST(Analyzer_resnet50, compare_mkldnn) { compare(true /* use_mkldnn */); } TEST(Analyzer_resnet50, compare_determine) { AnalysisConfig cfg; SetConfig(&cfg); - std::vector> input_slots_all; SetInput(&input_slots_all); CompareDeterministic(reinterpret_cast(&cfg), input_slots_all); } +// Save optim model +TEST(Analyzer_resnet50, save_optim_model) { + AnalysisConfig cfg; + SetConfig(&cfg); + std::string optimModelPath = + FLAGS_infer_model.substr(0, FLAGS_infer_model.find_last_of("/")) + + "/saved_optim_model"; + mkdir(optimModelPath.c_str(), 0777); + auto predictor = CreateTestPredictor( + reinterpret_cast(&cfg), + FLAGS_use_analysis); + (static_cast(predictor.get())) + ->SaveOptimModel(optimModelPath); +} + +void CompareOptimAndOrig(const PaddlePredictor::Config *orig_config, + const PaddlePredictor::Config *optim_config, + const std::vector> &inputs) { + PrintConfig(orig_config, true); + PrintConfig(optim_config, true); + std::vector> orig_outputs, optim_outputs; + TestOneThreadPrediction(orig_config, inputs, &orig_outputs, false); + TestOneThreadPrediction(optim_config, inputs, &optim_outputs, false); + CompareResult(orig_outputs.back(), optim_outputs.back()); +} + +TEST(Analyzer_resnet50, compare_optim_orig) { + AnalysisConfig orig_cfg; + AnalysisConfig optim_cfg; + SetConfig(&orig_cfg); + SetOptimConfig(&optim_cfg); + std::vector> input_slots_all; + SetInput(&input_slots_all); + CompareOptimAndOrig( + reinterpret_cast(&orig_cfg), + reinterpret_cast(&optim_cfg), + input_slots_all); +} + } // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/op_use_default_grad_op_maker.spec b/paddle/fluid/op_use_default_grad_op_maker.spec index 63eaa676a43fc784dce2437ca15bc85e2295dbb7..21a25ce7d5e2bad172cf50cee6138ef4b44b07c1 100644 --- a/paddle/fluid/op_use_default_grad_op_maker.spec +++ b/paddle/fluid/op_use_default_grad_op_maker.spec @@ -29,8 +29,6 @@ pool3d prelu quantize rank_loss -reduce_all -reduce_any reduce_max reduce_mean reduce_min diff --git a/paddle/fluid/operators/conv_shift_op.cc b/paddle/fluid/operators/conv_shift_op.cc index 08506ddd18ed35831702814e70962cb36ec958b1..fa4edb70b48e529102f11a1b0b9cac2110a33966 100644 --- a/paddle/fluid/operators/conv_shift_op.cc +++ b/paddle/fluid/operators/conv_shift_op.cc @@ -36,14 +36,17 @@ class ConvShiftOp : public framework::OperatorWithKernel { auto y_dims = ctx->GetInputDim("Y"); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); PADDLE_ENFORCE_EQ(y_dims.size(), 2, "Input(Y)'s rank should be 2."); - PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0], - "The 1st dimension of Input(X) and Input(Y) should " - "be equal."); - PADDLE_ENFORCE_EQ(y_dims[1] % 2, 1, - "The 2nd dimension of Input(Y) should be odd."); - PADDLE_ENFORCE_LE(y_dims[1], x_dims[1], - "The 2nd dimension of Input(Y) should be less than or " - "equal to the 2nd dimension of Input(X)."); + if (ctx->IsRuntime() || (x_dims[0] > 0 && y_dims[0] > 0)) + PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0], + "The 1st dimension of Input(X) and Input(Y) should " + "be equal."); + if (ctx->IsRuntime() || y_dims[1] > 0) + PADDLE_ENFORCE_EQ(y_dims[1] % 2, 1, + "The 2nd dimension of Input(Y) should be odd."); + if (ctx->IsRuntime() || (x_dims[1] > 0 && y_dims[1] > 0)) + PADDLE_ENFORCE_LE(y_dims[1], x_dims[1], + "The 2nd dimension of Input(Y) should be less than or " + "equal to the 2nd dimension of Input(X)."); ctx->ShareDim("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out"); } diff --git a/paddle/fluid/operators/cvm_op.cc b/paddle/fluid/operators/cvm_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..53ed86ade48ce52d49285495388f93f1bc4f5d9e --- /dev/null +++ b/paddle/fluid/operators/cvm_op.cc @@ -0,0 +1,154 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/cvm_op.h" +#include +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class CVMOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("CVM"), "Input(CVM) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null."); + + auto x_dims = ctx->GetInputDim("X"); + auto cvm_dims = ctx->GetInputDim("CVM"); + PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input(X)'s rank should be 2."); + PADDLE_ENFORCE_EQ(cvm_dims.size(), 2UL, "Input(CVM)'s rank should be 2."); + PADDLE_ENFORCE_EQ(cvm_dims[1], 2UL, + "The 2nd dimension of " + "Input(CVM) should be 2."); + + if (ctx->Attrs().Get("use_cvm")) { + ctx->SetOutputDim("Y", {x_dims[0], x_dims[1]}); + } else { + ctx->SetOutputDim("Y", {x_dims[0], x_dims[1] - 2}); + } + ctx->ShareLoD("X", /*->*/ "Y"); + } + + protected: + // Explicitly set that the data type of computation kernel of + // cvm + // is determined by its input "X". + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + platform::CPUPlace()); + } +}; + +class CVMGradientOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("CVM"), "Input(CVM) should be not null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), + "Input(Y@GRAD) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Output(X@GRAD) should be not null."); + + auto x_dims = ctx->GetInputDim("X"); + auto cvm_dims = ctx->GetInputDim("CVM"); + auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y")); + PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); + PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2."); + PADDLE_ENFORCE_EQ(cvm_dims.size(), 2, "Input(CVM)'s rank should be 2."); + + PADDLE_ENFORCE_EQ(x_dims[0], dy_dims[0], + "The 1st dimension of Input(X) and Input(Y@Grad) should " + "be equal."); + + PADDLE_ENFORCE_EQ(cvm_dims[1], 2, + "When Attr(soft_label) == false, the 2nd dimension of " + "Input(CVM) should be 2."); + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + ctx->ShareLoD("X", framework::GradVarName("X")); + } + + protected: + // Explicitly set that the data type of computation kernel of + // cvm + // is determined by its input "X". + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + platform::CPUPlace()); + } +}; + +class CVMOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(LodTensor, default LodTensor), a 2-D tensor with shape " + "[N x D]," + " where N is the batch size and D is the emebdding dim. "); + AddInput("CVM", + "(Tensor), a 2-D Tensor with shape [N x 2], where N is the batch " + "size, 2 is show and click."); + AddOutput("Y", + "(LodTensor, default LodTensor), a 2-D tensor with shape " + "[N x K]."); + AddAttr("use_cvm", "bool, use cvm or not").SetDefault(true); + AddComment(R"DOC( +CVM Operator. + + We assume that input X is a embedding vector with cvm_feature(show and click), which shape is [N * D] (D is 2(cvm_feature) + embedding dim, N is batch_size) + if use_cvm is True, we will log(cvm_feature), and output shape is [N * D]. + if use_cvm is False, we will remove cvm_feature from input, and output shape is [N * (D - 2)]. + +)DOC"); + } +}; + +class CVMGradOpDescMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new framework::OpDesc()); + op->SetType("cvm_grad"); + op->SetInput("X", Input("X")); + op->SetInput("CVM", Input("CVM")); + op->SetInput(framework::GradVarName("Y"), OutputGrad("Y")); + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + op->SetAttrMap(Attrs()); + return op; + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(cvm, ops::CVMOp, ops::CVMOpMaker, ops::CVMGradOpDescMaker); + +REGISTER_OPERATOR(cvm_grad, ops::CVMGradientOp); + +REGISTER_OP_CPU_KERNEL(cvm, ops::CVMOpKernel, ops::CVMOpKernel); + +REGISTER_OP_CPU_KERNEL(cvm_grad, ops::CVMGradOpKernel, + ops::CVMGradOpKernel); diff --git a/paddle/fluid/operators/cvm_op.h b/paddle/fluid/operators/cvm_op.h new file mode 100644 index 0000000000000000000000000000000000000000..38e5a2afa11feace17b8d870cdc3ef0ed38745d7 --- /dev/null +++ b/paddle/fluid/operators/cvm_op.h @@ -0,0 +1,105 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +template +class CVMOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const LoDTensor* x = context.Input("X"); + const T* x_data = x->data(); + auto lod = x->lod()[0]; + int64_t item_size = x->numel() / x->dims()[0]; + int offset = 2; + if (!context.Attr("use_cvm")) { + item_size -= offset; + } + LoDTensor* y = context.Output("Y"); + T* y_data = y->mutable_data(context.GetPlace()); + + int seq_num = static_cast(lod.size()) - 1; + for (int i = 0; i < seq_num; ++i) { + int64_t seq_len = static_cast(lod[i + 1] - lod[i]); + + for (int j = 0; j < seq_len; ++j) { + if (context.Attr("use_cvm")) { + std::memcpy(y_data, x_data, item_size * sizeof(T)); + y_data[0] = log(y_data[0] + 1); + y_data[1] = log(y_data[1] + 1) - y_data[0]; + x_data += item_size; + y_data += item_size; + } else { + std::memcpy(y_data, x_data + offset, item_size * sizeof(T)); + x_data += item_size + offset; + y_data += item_size; + } + } + } + } +}; + +template +class CVMGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + LoDTensor* dx = context.Output(framework::GradVarName("X")); + T* dx_data = dx->mutable_data(context.GetPlace()); + + const Tensor* cvm = context.Input("CVM"); + const T* cvm_data = cvm->data(); + int offset = 2; + const framework::LoDTensor* dOut = + context.Input(framework::GradVarName("Y")); + const T* dout_data = dOut->data(); + + auto lod = dx->lod()[0]; + int64_t item_size = dx->numel() / dx->dims()[0]; + if (!context.Attr("use_cvm")) { + item_size -= offset; + } + + int seq_num = static_cast(lod.size()) - 1; + for (int i = 0; i < seq_num; ++i) { + int64_t seq_len = static_cast(lod[i + 1] - lod[i]); + + for (int j = 0; j < seq_len; ++j) { + if (context.Attr("use_cvm")) { + std::memcpy(dx_data, dout_data, item_size * sizeof(T)); + dx_data[0] = cvm_data[0]; + dx_data[1] = cvm_data[1]; + dx_data += item_size; + dout_data += item_size; + } else { + std::memcpy(dx_data + offset, dout_data, item_size * sizeof(T)); + dx_data[0] = cvm_data[0]; + dx_data[1] = cvm_data[1]; + dx_data += item_size + offset; + dout_data += item_size; + } + } + cvm_data += offset; + } + } +}; +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt index 972b4f67a8388ce68952fa90aaa224cd45c6d226..f6531ec9edca7b425d28853f542d5e46783ba699 100644 --- a/paddle/fluid/operators/distributed/CMakeLists.txt +++ b/paddle/fluid/operators/distributed/CMakeLists.txt @@ -9,6 +9,9 @@ else() endif() configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @ONLY) +cc_library(async_sparse_param_update_recorder SRCS async_sparse_param_update_recorder.cc DEPS enforce simple_threadpool) +cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_recorder_test.cc DEPS async_sparse_param_update_recorder) + # FIXME(typhoonzero): use add_subdirectory once we clean the dependency of these files set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") if(WITH_GRPC) @@ -20,7 +23,7 @@ if(WITH_GRPC) collective_client.cc collective_server.cc ${GRPC_SRCS} PROTO send_recv.proto - DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS}) + DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder) set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set(RPC_DEPS sendrecvop_rpc ${GRPC_DEPS}) diff --git a/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.cc b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.cc new file mode 100644 index 0000000000000000000000000000000000000000..3f3b6b959e30194c10b1a58d6fc3e7a61ad01313 --- /dev/null +++ b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.cc @@ -0,0 +1,27 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h" + +namespace paddle { +namespace operators { +namespace distributed { + +std::once_flag AsyncSparseParamUpdateRecorder::init_flag_; +std::unique_ptr + AsyncSparseParamUpdateRecorder::recorder_(nullptr); + +} // namespace distributed +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h new file mode 100644 index 0000000000000000000000000000000000000000..eadd842c7f6ead56006fd0c34814b1b7bd9b62f4 --- /dev/null +++ b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h @@ -0,0 +1,183 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include // NOLINT +#include +#include +#include +#include +#include +#include + +#include + +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace operators { +namespace distributed { + +class ConcurrentSet { + public: + ConcurrentSet() : pool_(new ::ThreadPool(1)) {} + ~ConcurrentSet() {} + + std::future Update(const std::vector& rows) { + auto task = [this, rows] { + if (VLOG_IS_ON(3)) { + std::ostringstream sstream; + sstream << "["; + for (auto& id : rows) { + sstream << id << ", "; + } + sstream << "]"; + VLOG(3) << "update ids -> " << sstream.str(); + } + for (auto row : rows) { + set_.insert(row); + } + }; + return pool_->enqueue(std::move(task)); + } + + std::future GetAndClear(std::vector* result) { + auto task = [this, &result] { + result->clear(); + for (auto& id : set_) { + result->push_back(id); + } + if (VLOG_IS_ON(3)) { + std::ostringstream sstream; + sstream << "["; + for (auto& id : *result) { + sstream << id << ", "; + } + sstream << "]"; + VLOG(3) << "result ids size: " << result->size() << " " + << sstream.str(); + } + set_.clear(); + }; + return pool_->enqueue(std::move(task)); + } + + private: + std::unordered_set set_; + std::unique_ptr<::ThreadPool> pool_{nullptr}; +}; + +class AsyncSparseParamUpdateRecorder { + using TrainerToRows = std::vector>; + + public: + AsyncSparseParamUpdateRecorder( + int trainer_num, + const std::unordered_map& grad_to_param) + : trainer_num_(trainer_num), grad_to_param_(grad_to_param) { + if (VLOG_IS_ON(3)) { + std::ostringstream sstream; + sstream << "["; + for (auto& item : grad_to_param) { + sstream << item.first << ":" << item.second << ", "; + } + sstream << "]"; + VLOG(3) << "trainer_num: " << trainer_num + << " grad_to_param_: " << sstream.str(); + } + for (auto& iter : grad_to_param) { + param_to_grad_[iter.second] = iter.first; + auto& param_name = iter.second; + param_to_updated_rows_[param_name] = TrainerToRows(); + auto& trainer_to_rows = param_to_updated_rows_[param_name]; + for (auto i = 0; i < trainer_num; ++i) { + trainer_to_rows.emplace_back(new ConcurrentSet()); + } + } + } + + ~AsyncSparseParamUpdateRecorder() = default; + + void Update(const std::string& grad_name, + const std::vector& update_rows) { + VLOG(3) << "update grad: " << grad_name + << " row size: " << update_rows.size(); + auto& param_name = grad_to_param_.at(grad_name); + auto& trainer_to_rows = param_to_updated_rows_.at(param_name); + + std::vector> fs; + for (auto& set : trainer_to_rows) { + fs.push_back(set->Update(update_rows)); + } + for (auto& f : fs) { + f.wait(); + } + } + + void GetAndClear(const std::string& param_name, int trainer_id, + std::vector* result) { + VLOG(3) << "GetAndClear param: " << param_name + << " for trainer: " << trainer_id; + PADDLE_ENFORCE_LT(trainer_id, trainer_num_); + param_to_updated_rows_.at(param_name)[trainer_id] + ->GetAndClear(result) + .wait(); + } + + bool HasParam(const std::string& param_name) { + return param_to_grad_.find(param_name) != param_to_grad_.end(); + } + + bool HasGrad(const std::string& grad_name) { + return grad_to_param_.find(grad_name) != grad_to_param_.end(); + } + + private: + const int trainer_num_; + std::unordered_map grad_to_param_; + std::unordered_map param_to_grad_; + std::unordered_map param_to_updated_rows_; + + // init recorder + public: + static void Init( + int trainer_num, + const std::unordered_map& grad_to_param) { + InitImpl(trainer_num, grad_to_param); + } + + static AsyncSparseParamUpdateRecorder* GetInstance() { + return recorder_.get(); + } + + private: + // Init is called by GetInstance. + static void InitImpl( + int trainer_num, + const std::unordered_map& grad_to_param) { + if (recorder_ == nullptr) { + recorder_.reset( + new AsyncSparseParamUpdateRecorder(trainer_num, grad_to_param)); + } + } + + static std::once_flag init_flag_; + static std::unique_ptr recorder_; +}; + +} // namespace distributed +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/distributed/async_sparse_param_update_recorder_test.cc b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..67e8fd8a0edc4510d0abe885c821e75b528254f8 --- /dev/null +++ b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder_test.cc @@ -0,0 +1,99 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h" + +#include + +#include "gtest/gtest.h" + +namespace paddle { +namespace operators { +namespace distributed { + +TEST(ConcurrentSet, All) { + ConcurrentSet concurrent_set; + std::vector in1 = {1, 2, 3, 4}; + std::vector in2 = {2, 3, 5, 6}; + + std::vector> futures; + futures.push_back(concurrent_set.Update(in1)); + futures.push_back(concurrent_set.Update(in2)); + + for (auto &f : futures) { + f.wait(); + } + + std::unordered_set in; + std::copy(in1.begin(), in1.end(), std::inserter(in, in.begin())); + std::copy(in2.begin(), in2.end(), std::inserter(in, in.begin())); + + std::vector ret; + concurrent_set.GetAndClear(&ret).wait(); + + std::unordered_set out; + std::copy(ret.begin(), ret.end(), std::inserter(out, out.begin())); + + EXPECT_EQ(in, out); + + concurrent_set.GetAndClear(&ret).wait(); + EXPECT_EQ(ret.size(), 0); +} + +TEST(AsyncSparseParamUpdateRecorder, All) { + std::unordered_map grad_to_param; + grad_to_param["grad1"] = "param1"; + grad_to_param["grad2"] = "param2"; + + int trainer_num = 10; + + AsyncSparseParamUpdateRecorder recorder(trainer_num, grad_to_param); + std::vector in1 = {1, 2, 3, 4}; + std::vector in2 = {2, 3, 5, 6}; + + std::unordered_set in; + std::copy(in1.begin(), in1.end(), std::inserter(in, in.begin())); + std::copy(in2.begin(), in2.end(), std::inserter(in, in.begin())); + + recorder.Update("grad1", in1); + recorder.Update("grad1", in2); + + EXPECT_TRUE(recorder.HasParam("param1")); + EXPECT_TRUE(recorder.HasParam("param2")); + EXPECT_FALSE(recorder.HasParam("param3")); + + EXPECT_TRUE(recorder.HasGrad("grad1")); + EXPECT_TRUE(recorder.HasGrad("grad2")); + EXPECT_FALSE(recorder.HasGrad("grad3")); + + std::vector ret; + EXPECT_ANY_THROW(recorder.GetAndClear("param1", trainer_num, &ret)); + + for (int i = 0; i < trainer_num; ++i) { + std::vector ret; + std::unordered_set out; + + recorder.GetAndClear("param1", i, &ret); + std::copy(ret.begin(), ret.end(), std::inserter(out, out.begin())); + + EXPECT_EQ(in, out); + + recorder.GetAndClear("param1", i, &ret); + EXPECT_EQ(ret.size(), 0); + } +} + +} // namespace distributed +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/distributed/brpc/brpc_client.cc b/paddle/fluid/operators/distributed/brpc/brpc_client.cc index a1a3443348129b5cdf057592fced8fdff238ac09..4c22ad8eb4d4b2e23d8a6720e726eb9e2998314e 100644 --- a/paddle/fluid/operators/distributed/brpc/brpc_client.cc +++ b/paddle/fluid/operators/distributed/brpc/brpc_client.cc @@ -234,6 +234,7 @@ VarHandlePtr BRPCClient::AsyncGetVar(const std::string& ep, const framework::Scope& scope, const std::string& var_name, const std::string& out_var_name, + const std::string& table_name, int64_t time_out) { return _AsyncGetVar(ep, ctx, scope, var_name, out_var_name, kGetRPC, time_out); diff --git a/paddle/fluid/operators/distributed/brpc/brpc_client.h b/paddle/fluid/operators/distributed/brpc/brpc_client.h index 501a593b11d35c160348e42ee47216a85647aac4..51864dfdca53eb4b1d9045188a6347781130e785 100644 --- a/paddle/fluid/operators/distributed/brpc/brpc_client.h +++ b/paddle/fluid/operators/distributed/brpc/brpc_client.h @@ -21,8 +21,10 @@ limitations under the License. */ #include #include #include +#include #include // NOLINT #include +#include #include #include "brpc/channel.h" @@ -66,6 +68,7 @@ class BRPCClient : public RPCClient { const framework::Scope& scope, const std::string& var_name, const std::string& out_var_name, + const std::string& table_name = "", int64_t time_out = FLAGS_rpc_deadline) override; VarHandlePtr AsyncGetMonomerBarrier( @@ -107,13 +110,11 @@ class BRPCClient : public RPCClient { void SendComplete() override; private: - VarHandlePtr _AsyncGetVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - const std::string& out_var_name, - const std::string& method_name, - int64_t time_out = FLAGS_rpc_deadline); + VarHandlePtr _AsyncGetVar( + const std::string& ep, const platform::DeviceContext& ctx, + const framework::Scope& scope, const std::string& var_name, + const std::string& out_var_name, const std::string& method_name, + const std::string& table_name, int64_t time_out = FLAGS_rpc_deadline); void Proceed(); ChannelQueuePtr GetChannel(const std::string& ep); diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc index eba18c67771fa26eed855b0f19591e06101f424d..b528bcdd32b11d686f44596d9a1bb663b21691f4 100644 --- a/paddle/fluid/operators/distributed/communicator.cc +++ b/paddle/fluid/operators/distributed/communicator.cc @@ -32,6 +32,9 @@ DEFINE_int32(communicator_send_queue_size, 20, DEFINE_int32(communicator_max_send_grad_num_before_recv, 20, "max grad num to send before recv parameters"); DEFINE_int32(communicator_thread_pool_size, 5, "thread num to do send or recv"); +DEFINE_int32(communicator_send_wait_times, 5, + "times that send thread will wait if merge num does not reach " + "max_merge_var_num"); DEFINE_int32(communicator_max_merge_var_num, 20, "max var num to merge and send"); DEFINE_bool(communicator_fake_rpc, false, @@ -65,6 +68,8 @@ Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx, << FLAGS_communicator_max_send_grad_num_before_recv; VLOG(0) << "communicator_thread_pool_size: " << FLAGS_communicator_thread_pool_size; + VLOG(0) << "communicator_send_wait_times: " + << FLAGS_communicator_send_wait_times; VLOG(0) << "communicator_max_merge_var_num: " << FLAGS_communicator_max_merge_var_num; VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc; @@ -101,20 +106,32 @@ void Communicator::SendThread() { VLOG(3) << var_name << " merge and send"; std::vector> vars; size_t merged_var_num = 0; - while (var_queue->Size() > 0 && - merged_var_num < FLAGS_communicator_max_merge_var_num) { - vars.push_back(var_queue->Pop()); - // only count the send number of the first var - if (var_name == send_varname_to_queue_.begin()->first) { - grad_num_.fetch_add(1, std::memory_order_relaxed); + size_t wait_times = 0; + while (merged_var_num < FLAGS_communicator_max_merge_var_num) { + if (var_queue->Size() == 0) { + VLOG(3) << "wait_times -> " << wait_times; + if (wait_times >= FLAGS_communicator_send_wait_times) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + wait_times++; + continue; + } else { + wait_times = 0; + + vars.push_back(var_queue->Pop()); + // only count the send number of the first var + if (var_name == send_varname_to_queue_.begin()->first) { + grad_num_.fetch_add(1, std::memory_order_relaxed); + } + merged_var_num++; } - merged_var_num++; } auto before_merge = GetCurrentUS(); MergeVars(var_name, vars, send_scope_.get()); auto after_merge = GetCurrentUS(); - VLOG(3) << "merge " << var_name << " use time " - << after_merge - before_merge; + VLOG(3) << "merge " << merged_var_num << " " << var_name + << " use time " << after_merge - before_merge; auto send_functor = distributed::ParameterSend(); auto &ctx = send_varname_to_ctx_.at(var_name); if (!FLAGS_communicator_fake_rpc) { diff --git a/paddle/fluid/operators/distributed/communicator.h b/paddle/fluid/operators/distributed/communicator.h index 41155bfc31bb31520fdcf5bd50b203f2e1f2c516..37c39eb15112f745f6a25e95ce65d431d825182e 100644 --- a/paddle/fluid/operators/distributed/communicator.h +++ b/paddle/fluid/operators/distributed/communicator.h @@ -109,7 +109,7 @@ inline void MergeVars(const std::string& var_name, auto* out_var = scope->Var(var_name); if (var0->IsType()) { auto dims = var0->Get().dims(); - VLOG(3) << "merge " << var_name << " LoDTensor " << dims; + VLOG(3) << "merge " << var_name << " LoDTensor dims " << dims; // init output tensor auto* out_t = out_var->GetMutable(); diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.cc b/paddle/fluid/operators/distributed/grpc/grpc_client.cc index 61e94dae3c7a107e10fa5e5518651014cec078bc..8504110c6e9dbfe22b78063999ed4a9e36850e2c 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_client.cc @@ -128,9 +128,11 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep, const framework::Scope& scope, const std::string& var_name, const std::string& out_varname, + const std::string& table_name, int64_t time_out) { return _AsyncGetVar(ep, ctx, scope, kGetRPC, var_name, out_varname, - "/sendrecv.SendRecvService/GetVariable", time_out); + "/sendrecv.SendRecvService/GetVariable", table_name, + time_out); } VarHandlePtr GRPCClient::AsyncGetVarNoBarrier( @@ -142,7 +144,7 @@ VarHandlePtr GRPCClient::AsyncGetVarNoBarrier( return _AsyncGetVar( ep, ctx, scope, kGetNoBarrierRPC, var_name_no_barrier, out_varname, - "/sendrecv.SendRecvService/GetVariableNoBarrier", time_out); + "/sendrecv.SendRecvService/GetVariableNoBarrier", "", time_out); } VarHandlePtr GRPCClient::AsyncGetMonomerVariable( @@ -150,18 +152,21 @@ VarHandlePtr GRPCClient::AsyncGetMonomerVariable( const framework::Scope& scope, const std::string& var_name, int64_t time_out) { return _AsyncGetVar(ep, ctx, scope, kGetMonomerRPC, var_name, var_name, - "/sendrecv.SendRecvService/GetMonomerVariable", time_out); + "/sendrecv.SendRecvService/GetMonomerVariable", "", + time_out); } VarHandlePtr GRPCClient::_AsyncGetVar( const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& method, const std::string& var_name, const std::string& out_varname, - const std::string& rpc_path, int64_t time_out) { + const std::string& rpc_path, const std::string& table_name, + int64_t time_out) { const platform::DeviceContext* p_ctx = &ctx; const std::string ep_val = ep; const std::string var_name_val = var_name; const std::string out_varname_val = out_varname; + const std::string table_name_val = table_name; const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); GetProcessor* s = new GetProcessor(ch); @@ -169,32 +174,33 @@ VarHandlePtr GRPCClient::_AsyncGetVar( VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope)); s->Prepare(h, time_out); - framework::AsyncIO( - [var_name_val, out_varname_val, s, method, p_ctx, h, rpc_path, this] { - // prepare input - sendrecv::VariableMessage req; - req.set_varname(var_name_val); - req.set_out_varname(out_varname_val); - req.set_trainer_id(trainer_id_); - ::grpc::ByteBuffer buf; - RequestToByteBuffer(req, &buf); + framework::AsyncIO([var_name_val, out_varname_val, table_name_val, s, method, + p_ctx, h, rpc_path, this] { + // prepare input + sendrecv::VariableMessage req; + req.set_varname(var_name_val); + req.set_out_varname(out_varname_val); + req.set_trainer_id(trainer_id_); + req.set_table_name(table_name_val); + ::grpc::ByteBuffer buf; + RequestToByteBuffer(req, &buf); - VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; + VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; - // stub context - s->response_call_back_ = ProcGetResponse; + // stub context + s->response_call_back_ = ProcGetResponse; - platform::RecordRPCEvent record_event(method); + platform::RecordRPCEvent record_event(method); - auto call = - s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_); - call->StartCall(); - call->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); + auto call = + s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_); + call->StartCall(); + call->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); - if (UNLIKELY(platform::IsProfileEnabled())) { - h->Wait(); - } - }); + if (UNLIKELY(platform::IsProfileEnabled())) { + h->Wait(); + } + }); req_count_++; diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.h b/paddle/fluid/operators/distributed/grpc/grpc_client.h index ce0d2152aa27c62b6e12881aaf2ae458597e67e6..ad2f04a6d1dda34e35b67b21dce8ac612ff697a0 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_client.h +++ b/paddle/fluid/operators/distributed/grpc/grpc_client.h @@ -23,9 +23,11 @@ limitations under the License. */ #include #include #include +#include #include // NOLINT #include #include // NOLINT +#include #include #include "grpc++/channel.h" @@ -187,6 +189,7 @@ class GRPCClient : public RPCClient { const framework::Scope& scope, const std::string& var_name, const std::string& out_varname, + const std::string& table_name = "", int64_t time_out = FLAGS_rpc_deadline) override; VarHandlePtr AsyncGetVarNoBarrier( @@ -239,7 +242,8 @@ class GRPCClient : public RPCClient { const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& method, const std::string& var_name, const std::string& out_varname, - const std::string& rpc_path, int64_t time_out = FLAGS_rpc_deadline); + const std::string& rpc_path, const std::string& table_name = "", + int64_t time_out = FLAGS_rpc_deadline); private: grpc::CompletionQueue cq_; diff --git a/paddle/fluid/operators/distributed/grpc/grpc_server.cc b/paddle/fluid/operators/distributed/grpc/grpc_server.cc index 0eb313f75dfa64f8722faa365128f3111f72bd0b..75526bed0f0eadada65279ec05757da7a469f984 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_server.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_server.cc @@ -137,6 +137,7 @@ class RequestGet final : public RequestBase { // proc request. std::string varname = request_.varname(); std::string out_varname = request_.out_varname(); + std::string table_name = request_.table_name(); int trainer_id = request_.trainer_id(); VLOG(4) << "RequestGet " << out_varname << " from " << varname; @@ -145,19 +146,23 @@ class RequestGet final : public RequestBase { framework::Variable* invar = nullptr; framework::Variable* outvar = nullptr; - request_handler_->Handle(varname, scope, invar, &outvar, trainer_id, - out_varname); + tmp_scope_ = std::move(scope->NewTmpScope()); + request_handler_->Handle(varname, tmp_scope_.get(), invar, &outvar, + trainer_id, out_varname, table_name); + VLOG(1) << "before SerializeToByteBuffer"; if (outvar) { SerializeToByteBuffer(out_varname, outvar, *request_handler_->dev_ctx(), &reply_); } + VLOG(1) << "after SerializeToByteBuffer"; Finish(reply_, &responder_); } protected: sendrecv::VariableMessage request_; ::grpc::ByteBuffer reply_; + std::unique_ptr tmp_scope_; ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; }; diff --git a/paddle/fluid/operators/distributed/parameter_recv.cc b/paddle/fluid/operators/distributed/parameter_recv.cc index e7d4c262aa9fad10a23adc61b94ba0c38577c0e8..da73167ae603fb8c8ba9deabe118269891d1f52a 100644 --- a/paddle/fluid/operators/distributed/parameter_recv.cc +++ b/paddle/fluid/operators/distributed/parameter_recv.cc @@ -42,27 +42,23 @@ using DDim = framework::DDim; template void ParameterRecv::operator()(const RpcContext &rpc_ctx, const framework::Scope &scope) { - VLOG(3) << "ParameterRecv in"; + VLOG(3) << "ParameterRecv in " << rpc_ctx.var_name; std::unique_ptr local_scope = scope.NewTmpScope(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &cpu_ctx = *pool.Get(platform::CPUPlace()); distributed::RPCClient *rpc_client = - distributed::RPCClient::GetInstance(0); + distributed::RPCClient::GetInstance(rpc_ctx.trainer_id); auto *recv_var = scope.FindVar(rpc_ctx.var_name); - std::vector recved_tensors; - // recv all vars to local scope if (recv_var->IsType()) { std::vector rets; for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) { auto &recv_var_name = rpc_ctx.splited_var_names[i]; - framework::Tensor *t = - local_scope->Var(recv_var_name)->GetMutable(); - recved_tensors.push_back(t); + local_scope->Var(recv_var_name); VLOG(3) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i]; rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx, *local_scope.get(), recv_var_name, @@ -78,23 +74,61 @@ void ParameterRecv::operator()(const RpcContext &rpc_ctx, // concat recved tensor into one var { size_t output_offset = 0; + size_t row_offset = 0; framework::Tensor *recv_tensor = recv_var->GetMutable(); auto dev_ctx = paddle::platform::CPUDeviceContext(); int64_t recv_numel = 0; - for (auto *in : recved_tensors) { - recv_numel += in->numel(); - auto in_stride = framework::stride_numel(in->dims()); - auto out_stride = framework::stride_numel(recv_tensor->dims()); - StridedNumelCopyWithAxis( - dev_ctx, 0, recv_tensor->data() + output_offset, out_stride, - in->data(), in_stride, in_stride[0]); - output_offset += in_stride[0]; + for (auto &recv_var_name : rpc_ctx.splited_var_names) { + auto *recv_var = local_scope->FindVar(recv_var_name); + if (recv_var->IsType()) { + auto &in = recv_var->Get(); + recv_numel += in.numel(); + auto in_stride = framework::stride_numel(in.dims()); + auto out_stride = framework::stride_numel(recv_tensor->dims()); + StridedNumelCopyWithAxis( + dev_ctx, 0, recv_tensor->data() + output_offset, out_stride, + in.data(), in_stride, in_stride[0]); + output_offset += in_stride[0]; + } else if (recv_var->IsType()) { + auto &recv_slr = recv_var->Get(); + auto &recv_dims = recv_tensor->dims(); + int64_t width = recv_dims[1]; + recv_numel += recv_slr.height() * width; + PADDLE_ENFORCE_EQ(recv_slr.value().dims()[1], width); + PADDLE_ENFORCE_EQ(recv_slr.value().dims()[0], recv_slr.rows().size()); + VLOG(3) << "recv slr " << recv_var_name << " dims " + << recv_slr.value().dims(); + if (VLOG_IS_ON(3)) { + std::ostringstream sstream; + sstream << "["; + for (auto &row_id : recv_slr.rows()) { + sstream << row_id << ", "; + } + sstream << "]"; + VLOG(3) << "recv_slr size: " << recv_slr.rows().size() << " " + << sstream.str(); + } + + for (auto i = 0; i < recv_slr.rows().size(); ++i) { + auto row_id = recv_slr.rows()[i] + row_offset; + PADDLE_ENFORCE_LT(row_id, recv_dims[0]); + memcpy(recv_tensor->data() + row_id * width, + recv_slr.value().data() + i * width, sizeof(T) * width); + } + row_offset += recv_slr.height(); + } else { + PADDLE_THROW("unsupported recieved var type"); + } + } + auto numel = recv_tensor->numel(); + if (recv_numel != numel) { + LOG(FATAL) << "recv_numel: " << recv_numel << " acture numel: " << numel; } - PADDLE_ENFORCE_EQ(recv_numel, recv_tensor->numel()); + PADDLE_ENFORCE_EQ(recv_numel, numel); } - VLOG(3) << "ParameterRecv out"; + VLOG(3) << "ParameterRecv out " << rpc_ctx.var_name; } template struct ParameterRecv; diff --git a/paddle/fluid/operators/distributed/parameter_send.cc b/paddle/fluid/operators/distributed/parameter_send.cc index 9ce424445229cde0a7e775c95f4af8839f4d4d68..dfabad567af590b65b9e777824d476fce2b17238 100644 --- a/paddle/fluid/operators/distributed/parameter_send.cc +++ b/paddle/fluid/operators/distributed/parameter_send.cc @@ -47,7 +47,7 @@ void ParameterSend::operator()(const RpcContext &rpc_ctx, auto &cpu_ctx = *pool.Get(platform::CPUPlace()); distributed::RPCClient *rpc_client = - distributed::RPCClient::GetInstance(0); + distributed::RPCClient::GetInstance(rpc_ctx.trainer_id); auto *send_var = scope.FindVar(rpc_ctx.var_name); size_t out_num = rpc_ctx.splited_var_names.size(); diff --git a/paddle/fluid/operators/distributed/request_handler.h b/paddle/fluid/operators/distributed/request_handler.h index 991158ac72007efc1233f852caed4f90f35fe1cd..de8f30184611aeb961e2ab69b05779c56371b976 100644 --- a/paddle/fluid/operators/distributed/request_handler.h +++ b/paddle/fluid/operators/distributed/request_handler.h @@ -18,7 +18,9 @@ #include // NOLINT #include +#include #include +#include #include #include @@ -180,6 +182,10 @@ class RequestHandler { grad_to_prepared_ctx_ = g; } + void SetSparseGradToParam(std::unordered_map* g) { + sparse_grad_to_param_ = g; + } + void SetRPCServer(RPCServer* rpc_server) { rpc_server_ = rpc_server; } // Get attributes. @@ -228,6 +234,7 @@ class RequestHandler { std::unordered_map>* grad_to_prepared_ctx_; + std::unordered_map* sparse_grad_to_param_; RPCServer* rpc_server_; }; diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index e289ec929dbd6643a2518b92c1a25b7d63e790a9..a41536368abc925531d1a54615546a100482a7eb 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -22,6 +22,7 @@ #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h" #include "paddle/fluid/operators/distributed/rpc_server.h" #include "paddle/fluid/string/piece.h" #include "paddle/fluid/string/printf.h" @@ -59,6 +60,12 @@ bool RequestSendHandler::Handle(const std::string& varname, "async mode should not recv BATCH_BARRIER_MESSAGE or " "COMPLETE_MESSAGE"); } + if (AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(varname)) { + auto& grad_slr = + scope->FindVar(varname)->Get(); + AsyncSparseParamUpdateRecorder::GetInstance()->Update(varname, + grad_slr.rows()); + } executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(), scope); return true; @@ -82,8 +89,9 @@ bool RequestGetHandler::Handle(const std::string& varname, const int trainer_id, const std::string& out_var_name, const std::string& table_name) { - VLOG(4) << "RequestGetHandler:" << varname - << " out_var_name: " << out_var_name; + VLOG(3) << "RequestGetHandler:" << varname + << " out_var_name: " << out_var_name << " trainer_id: " << trainer_id + << " table_name: " << table_name; if (sync_mode_) { if (varname == FETCH_BARRIER_MESSAGE) { @@ -108,7 +116,42 @@ bool RequestGetHandler::Handle(const std::string& varname, VLOG(3) << "copying " << varname << " to " << param_bak_name; framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t); } - *outvar = scope_->FindVar(varname); + if (AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) && + !table_name.empty()) { + std::vector updated_rows; + AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear( + varname, trainer_id, &updated_rows); + if (VLOG_IS_ON(3)) { + std::ostringstream sstream; + sstream << "["; + for (auto& row_id : updated_rows) { + sstream << row_id << ", "; + } + sstream << "]"; + VLOG(3) << "updated_rows size: " << updated_rows.size() << " " + << sstream.str(); + } + auto& origin_tensor = + scope_->FindVar(varname)->Get(); + auto* origin_tensor_data = origin_tensor.data(); + auto& dims = origin_tensor.dims(); + *outvar = scope->Var(); + auto* out_slr = (*outvar)->GetMutable(); + out_slr->set_rows(updated_rows); + out_slr->set_height(dims[0]); + auto out_dims = framework::make_ddim( + {static_cast(updated_rows.size()), dims[1]}); + auto* data = out_slr->mutable_value()->mutable_data( + out_dims, origin_tensor.place()); + auto width = dims[1]; + for (auto i = 0; i < updated_rows.size(); ++i) { + PADDLE_ENFORCE_LT(updated_rows[i], dims[0]); + memcpy(data + i * width, origin_tensor_data + updated_rows[i] * width, + sizeof(float) * width); + } + } else { + *outvar = scope_->FindVar(varname); + } } } return true; diff --git a/paddle/fluid/operators/distributed/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h index ea54e0c2951253fc009672f4cd2e5233ed56944e..d4be2c28fdbaa4beef62402155de5b677ed67e9b 100644 --- a/paddle/fluid/operators/distributed/rpc_client.h +++ b/paddle/fluid/operators/distributed/rpc_client.h @@ -15,6 +15,7 @@ #pragma once #include // NOLINT +#include #include #include "gflags/gflags.h" @@ -44,6 +45,7 @@ class RPCClient { const framework::Scope& scope, const std::string& var_name, const std::string& out_varname, + const std::string& table_name = "", int64_t time_out = FLAGS_rpc_deadline) = 0; virtual VarHandlePtr AsyncGetVarNoBarrier( @@ -96,6 +98,7 @@ class RPCClient { // Init is called by GetInstance. template static void Init(int trainer_id) { + VLOG(0) << "init rpc client with trainer_id " << trainer_id; trainer_id_ = trainer_id; if (rpc_client_.get() == nullptr) { rpc_client_.reset(new T()); diff --git a/paddle/fluid/operators/distributed/rpc_common.h b/paddle/fluid/operators/distributed/rpc_common.h index 3de89c2ae89d29edc317ca123882d1c55038b6ca..eb127bf4ad5a5c9a28210e2fbcdb69b07543f4b9 100644 --- a/paddle/fluid/operators/distributed/rpc_common.h +++ b/paddle/fluid/operators/distributed/rpc_common.h @@ -27,23 +27,26 @@ struct RpcContext { RpcContext(const std::string &name, const std::vector &names, const std::vector &emap, - const std::vector §ions) + const std::vector §ions, int id) : var_name(name), splited_var_names(names), epmap(emap), - height_sections(sections) {} + height_sections(sections), + trainer_id(id) {} RpcContext(const RpcContext &ctx) { var_name = ctx.var_name; splited_var_names = ctx.splited_var_names; epmap = ctx.epmap; height_sections = ctx.height_sections; + trainer_id = ctx.trainer_id; } std::string var_name; std::vector splited_var_names; std::vector epmap; std::vector height_sections; + int trainer_id; }; inline std::ostream &operator<<(std::ostream &os, const RpcContext &rpc_ctx) { diff --git a/paddle/fluid/operators/distributed_ops/CMakeLists.txt b/paddle/fluid/operators/distributed_ops/CMakeLists.txt index a1ef1af39ff2ab1456706ebafbd3d7ce1acc0c07..1096f3773c6d44560d370502b1c550d67d40ca64 100644 --- a/paddle/fluid/operators/distributed_ops/CMakeLists.txt +++ b/paddle/fluid/operators/distributed_ops/CMakeLists.txt @@ -2,9 +2,9 @@ include(operators) set(DISTRIBUTE_DEPS "") if(WITH_GRPC) - set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node) + set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node) else() - set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator brpc leveldb snappystream snappy protobuf ssl crypto zlib node) + set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder brpc leveldb snappystream snappy protobuf ssl crypto zlib node) if(WITH_BRPC_RDMA) find_library(IBVERBS_LIBRARY NAMES ibverbs) ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL) diff --git a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc index 5b30ed472d51a37a0705d1717395da9e4ff7d743..a672fb2a9141a81383d947dcc961a112aee3f7ac 100644 --- a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc +++ b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc @@ -24,8 +24,10 @@ limitations under the License. */ #include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h" + #include "paddle/fluid/platform/profiler.h" DEFINE_int32(rpc_send_thread_num, 12, "number of threads for rpc send"); @@ -292,6 +294,8 @@ static void FillRequestCtx( std::unordered_map> *prefetch_ctx, + std::unordered_map + *sparse_grad_name_to_param_name, std::shared_ptr checkpoint_ctx, distributed::RPCServer *rpc_server) { h->SetScope(scope); @@ -299,6 +303,7 @@ static void FillRequestCtx( h->SetExecutor(executor); h->SetProgram(program); h->SetPrefetchPreparedCtx(prefetch_ctx); + h->SetSparseGradToParam(sparse_grad_name_to_param_name); h->SetRPCServer(rpc_server); h->SetCheckpointNotifyPreparedCtx(checkpoint_ctx); } @@ -414,10 +419,24 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, prefetch_var_name_to_prepared_ctx[prefetch_var_name] = prefetch_prepared[i]; } - auto f = - std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx, - &executor, program, &prefetch_var_name_to_prepared_ctx, - ckpt_pre_context, rpc_service_.get()); + // parse attr of kSparseGradToParam sparse_grad_name -> param_name + std::unordered_map sparse_grad_name_to_param_name; + auto sparse_grad_name_to_param_name_str = + Attr>(kSparseGradToParam); + for (const auto &sparse_grad_name_and_param_name : + sparse_grad_name_to_param_name_str) { + std::vector pieces; + split(sparse_grad_name_and_param_name, ':', &pieces); + PADDLE_ENFORCE_EQ(pieces.size(), 2); + VLOG(3) << "after split, sparse_grad_name = " << pieces[0] + << ", param_name = " << pieces[1]; + sparse_grad_name_to_param_name[pieces[0]] = pieces[1]; + } + + auto f = std::bind( + FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx, &executor, + program, &prefetch_var_name_to_prepared_ctx, + &sparse_grad_name_to_param_name, ckpt_pre_context, rpc_service_.get()); f(request_send_handler_.get()); f(request_get_handler_.get()); @@ -445,6 +464,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, RunSyncLoop(&executor, program, &recv_scope, &dev_ctx, prefetch_block_id_list, checkpoint_block_id); } else { + distributed::AsyncSparseParamUpdateRecorder::Init( + fan_in, sparse_grad_name_to_param_name); RunAsyncLoop(&executor, program, &recv_scope); } } @@ -475,6 +496,10 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr>(kPrefetchVarNameToBlockId, "prefetch blocks to run on server side.") .SetDefault({}); + AddAttr>( + kSparseGradToParam, + "sparse grad name to param name. like: 'emb@Grad:emb'") + .SetDefault({}); AddAttr("Fanin", "How many clients send to this server.") .SetDefault(1); AddAttr(kCheckpointBlockId, diff --git a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.h b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.h index f20442bad7c5bd96173b9d6efc4dceb13feacf5b..1cf2130d7a593077d1145b4f3be379c32557dd53 100644 --- a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.h +++ b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.h @@ -16,8 +16,10 @@ limitations under the License. */ #include #include +#include #include #include +#include #include #include @@ -35,6 +37,7 @@ namespace operators { constexpr char kOptimizeBlocks[] = "optimize_blocks"; constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id"; constexpr char kCheckpointBlockId[] = "checkpint_block_id"; +constexpr char kSparseGradToParam[] = "sparse_grad_to_param"; void RunServer(std::shared_ptr service); diff --git a/paddle/fluid/operators/distributed_ops/recv_op.cc b/paddle/fluid/operators/distributed_ops/recv_op.cc index 3fd0700a077321d931e87b1d94c3637d167c9eff..8e9846b1fc89953526149be3838103526d5c441b 100644 --- a/paddle/fluid/operators/distributed_ops/recv_op.cc +++ b/paddle/fluid/operators/distributed_ops/recv_op.cc @@ -50,17 +50,18 @@ class RecvOp : public framework::OperatorBase { platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &ctx = *pool.Get(place); + auto trainer_id = Attr("trainer_id"); distributed::RPCClient *rpc_client = - distributed::RPCClient::GetInstance( - Attr("trainer_id")); + distributed::RPCClient::GetInstance(trainer_id); std::vector recv_varnames = Attr>("recv_varnames"); if (recv_varnames.size() > 0) { auto recv_functor = distributed::ParameterRecv(); - auto rpc_ctx = distributed::RpcContext(outs[0], recv_varnames, epmap, {}); + auto rpc_ctx = distributed::RpcContext(outs[0], recv_varnames, epmap, {}, + trainer_id); recv_functor(rpc_ctx, scope); } else { if (with_barrier) { diff --git a/paddle/fluid/operators/distributed_ops/send_op.cc b/paddle/fluid/operators/distributed_ops/send_op.cc index b08cd0942f8c89b60d722c931d0cec2063b96578..5731bcc15a07074b3d77873c5cdcbb70dc41aba8 100644 --- a/paddle/fluid/operators/distributed_ops/send_op.cc +++ b/paddle/fluid/operators/distributed_ops/send_op.cc @@ -42,6 +42,7 @@ class SendOp : public framework::OperatorBase { auto epmap = Attr>("epmap"); int sync_send = Attr("sync_mode"); + auto trainer_id = Attr("trainer_id"); auto send_varnames = Attr>("send_varnames"); auto height_sections = Attr>("sections"); @@ -51,7 +52,7 @@ class SendOp : public framework::OperatorBase { if (distributed::Communicator::GetInstance() == nullptr) { auto send_functor = distributed::ParameterSend(); auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap, - height_sections); + height_sections, trainer_id); send_functor(rpc_ctx, scope, true); } else { distributed::Communicator::GetInstance()->Send(ins[0], scope); @@ -62,8 +63,7 @@ class SendOp : public framework::OperatorBase { auto& ctx = *pool.Get(place); distributed::RPCClient* rpc_client = - distributed::RPCClient::GetInstance( - Attr("trainer_id")); + distributed::RPCClient::GetInstance(trainer_id); std::vector rets; for (size_t i = 0; i < ins.size(); i++) { diff --git a/paddle/fluid/operators/fake_quantize_op.cu b/paddle/fluid/operators/fake_quantize_op.cu index 33bd275e5cc507ec700b3694cd8b1df9672ec512..7d551106756070a14f94f39f19b775d022d90777 100644 --- a/paddle/fluid/operators/fake_quantize_op.cu +++ b/paddle/fluid/operators/fake_quantize_op.cu @@ -235,11 +235,13 @@ struct FindRangeAbsMaxFunctor { int g_find_max; memory::Copy(platform::CPUPlace(), &g_find_max, gpu_place, find_max, - sizeof(int), 0); + sizeof(int), ctx.stream()); + ctx.Wait(); if (g_find_max) { int len; memory::Copy(platform::CPUPlace(), &len, gpu_place, out_size_data, - sizeof(int), 0); + sizeof(int), ctx.stream()); + ctx.Wait(); FindAbsMaxFunctor()(ctx, scale_arr, len, out_scale_data); } @@ -258,25 +260,26 @@ struct FindMovingAverageAbsMaxFunctor { const auto gpu_place = boost::get(ctx.GetPlace()); T accum; - memory::Copy(platform::CPUPlace(), &accum, gpu_place, in_accum.data(), - sizeof(T), 0); T state; - memory::Copy(platform::CPUPlace(), &state, gpu_place, in_state.data(), - sizeof(T), 0); T scale; + memory::Copy(platform::CPUPlace(), &accum, gpu_place, in_accum.data(), + sizeof(T), ctx.stream()); + memory::Copy(platform::CPUPlace(), &state, gpu_place, in_state.data(), + sizeof(T), ctx.stream()); memory::Copy(platform::CPUPlace(), &scale, gpu_place, cur_scale, sizeof(T), - 0); - + ctx.stream()); + ctx.Wait(); state = rate * state + 1; accum = rate * accum + scale; scale = accum / state; memory::Copy(gpu_place, out_accum->mutable_data(gpu_place), - platform::CPUPlace(), &accum, sizeof(T), 0); + platform::CPUPlace(), &accum, sizeof(T), ctx.stream()); memory::Copy(gpu_place, out_state->mutable_data(gpu_place), - platform::CPUPlace(), &state, sizeof(T), 0); + platform::CPUPlace(), &state, sizeof(T), ctx.stream()); memory::Copy(gpu_place, out_scale->mutable_data(gpu_place), - platform::CPUPlace(), &scale, sizeof(T), 0); + platform::CPUPlace(), &scale, sizeof(T), ctx.stream()); + ctx.Wait(); } }; diff --git a/paddle/fluid/operators/grid_sampler_op.cc b/paddle/fluid/operators/grid_sampler_op.cc index 241184c6f4a19a1da0d6d75c5d4e2b372c14e9da..57a1fcd42da04a766ebd8713e3863f259b3784ac 100644 --- a/paddle/fluid/operators/grid_sampler_op.cc +++ b/paddle/fluid/operators/grid_sampler_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/grid_sampler_op.h" +#include #include "paddle/fluid/framework/op_registry.h" #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/cudnn_helper.h" @@ -40,10 +41,12 @@ class GridSampleOp : public framework::OperatorWithKernel { "Input(X) of GridSampleOp should be 4-D Tensor."); PADDLE_ENFORCE(grid_dims.size() == 4, "Input(Grid) of GridSampleOp should be 4-D Tensor."); - PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2."); - PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0], - "Input(X) and Input(Grid) dims[0] should be equal."); + if (ctx->IsRuntime() || grid_dims[3] > 0) { + PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2."); + } if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0], + "Input(X) and Input(Grid) dims[0] should be equal."); PADDLE_ENFORCE_EQ( grid_dims[1], x_dims[2], "Input(X) dims[2] and Input(Grid) dims[1] should be equal."); diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h index 82c8171ca52ffb128df103f27bafbdba1e72e52f..7cfe0aabcb7f3ce86ccc3a9a1c54b3b60d384aa1 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.h +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h @@ -238,6 +238,8 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { zero(dev_ctx, w_grad, static_cast(0.0)); bit_code->MulGradWeight(pre_out_grad, w_grad, in); } else { + PADDLE_ENFORCE(path != nullptr, + "Sparse mode should not be used without custom tree!"); framework::Vector real_rows = PathToRows(*path); auto* w_grad = ctx.Output(framework::GradVarName("W")); diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc index 9f2e3ad4a5ac1786096c67154d5a9ef5ea62855c..900b0c636ddafc8c033560adf58d596eb696621f 100644 --- a/paddle/fluid/operators/interpolate_op.cc +++ b/paddle/fluid/operators/interpolate_op.cc @@ -45,9 +45,14 @@ class InterpolateOp : public framework::OperatorWithKernel { // round down out_h = static_cast(dim_x[2] * scale); out_w = static_cast(dim_x[3] * scale); + // protect when input shape is -1 + out_h = out_h > 0 ? out_h : -1; + out_w = out_w > 0 ? out_w : -1; } else { out_h = ctx->Attrs().Get("out_h"); out_w = ctx->Attrs().Get("out_w"); + PADDLE_ENFORCE_GT(out_h, 0, "out_h should be greater than 0."); + PADDLE_ENFORCE_GT(out_w, 0, "out_w should be greater than 0."); } if (ctx->HasInput("OutSize") && ctx->IsRuntime()) { @@ -58,6 +63,7 @@ class InterpolateOp : public framework::OperatorWithKernel { ctx->ShareLoD("X", "Out"); return; } + std::vector dim_out({dim_x[0], dim_x[1], out_h, out_w}); ctx->SetOutputDim("Out", framework::make_ddim(dim_out)); } diff --git a/paddle/fluid/operators/kldiv_loss_op.cc b/paddle/fluid/operators/kldiv_loss_op.cc index a43f22c0496f89943d2fd5110446f1aae6a99315..a7c5d6305b09afb93be0b3b8524a91bd53e719fe 100644 --- a/paddle/fluid/operators/kldiv_loss_op.cc +++ b/paddle/fluid/operators/kldiv_loss_op.cc @@ -35,8 +35,10 @@ class KLDivLossOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(dim_x.size(), dim_target.size(), "Input(X) rank and Input(Target) rank should be same."); for (int i = 0; i < dim_x.size(); i++) { - PADDLE_ENFORCE_EQ(dim_x[i], dim_target[i], - "Input(X) and Input(Target) should in same shape."); + if (ctx->IsRuntime() || (dim_x[i] > 0 && dim_target[i] > 0)) { + PADDLE_ENFORCE_EQ(dim_x[i], dim_target[i], + "Input(X) and Input(Target) should in same shape."); + } } auto reduction = ctx->Attrs().Get("reduction"); diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index b99115e44b31536f0fd0a9078b40d07949be86f0..647d4f14842ee38bbd8a5d07563ea29ff0432e1a 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -296,6 +296,7 @@ struct MergeAdd { auto input_height = has_value_input->height(); framework::SelectedRows& out = *output; std::set merged_row_set; + size_t row_num = 0; for (auto* input : inputs) { if (input->rows().size() == 0) { continue; @@ -305,42 +306,71 @@ struct MergeAdd { "dimension except for the first one"); PADDLE_ENFORCE_EQ(input_height, input->height(), "all input should have same height"); + row_num += input->rows().size(); merged_row_set.insert(input->rows().begin(), input->rows().end()); } - std::vector merge_rows(merged_row_set.begin(), - merged_row_set.end()); - if (sorted_result) { - std::sort(merge_rows.begin(), merge_rows.end()); - } - std::unordered_map rows_to_id; - for (size_t i = 0; i < merge_rows.size(); ++i) { - rows_to_id[merge_rows[i]] = i; - } - out.set_rows(merge_rows); + out.set_height(input_height); out.mutable_value()->mutable_data( framework::make_ddim( - {static_cast(merge_rows.size()), input_width}), + {static_cast(merged_row_set.size()), input_width}), context.GetPlace()); + auto* out_data = out.mutable_value()->data(); - math::SetConstant constant_functor; - constant_functor(context, out.mutable_value(), 0.0); + if (merged_row_set.size() == row_num && !sorted_result) { + // no duplicated ids, just concat the result together + std::vector merge_rows; + merge_rows.reserve(row_num); + // concat rows + for (auto* in : inputs) { + merge_rows.insert(merge_rows.end(), in->rows().begin(), + in->rows().end()); + } + out.set_rows(merge_rows); + auto in_place = inputs[0]->place(); + auto out_place = out.place(); + int64_t copied_numel = 0; + for (auto* in : inputs) { + auto* in_data = in->value().data(); + auto in_numel = in->value().numel(); + memory::Copy(boost::get(out_place), + out_data + copied_numel, + boost::get(in_place), in_data, + in_numel * sizeof(T)); + copied_numel += in_numel; + } + } else { + std::vector merge_rows(merged_row_set.begin(), + merged_row_set.end()); - auto* out_data = out.mutable_value()->data(); + if (sorted_result) { + std::sort(merge_rows.begin(), merge_rows.end()); + } - auto blas = math::GetBlas(context); - for (auto* input : inputs) { - if (input->rows().size() == 0) { - continue; + out.set_rows(merge_rows); + + math::SetConstant constant_functor; + constant_functor(context, out.mutable_value(), 0.0); + + std::unordered_map rows_to_id; + for (size_t i = 0; i < merge_rows.size(); ++i) { + rows_to_id[merge_rows[i]] = i; } - auto* input_data = input->value().data(); - auto& input_rows = input->rows(); - - for (size_t i = 0; i < input_rows.size(); i++) { - size_t out_i = rows_to_id[input_rows[i]]; - elementwise_add_to( - context, &blas, static_cast(input_width), - &input_data[i * input_width], &out_data[out_i * input_width]); + + auto blas = math::GetBlas(context); + for (auto* input : inputs) { + if (input->rows().size() == 0) { + continue; + } + auto* input_data = input->value().data(); + auto& input_rows = input->rows(); + + for (size_t i = 0; i < input_rows.size(); i++) { + size_t out_i = rows_to_id[input_rows[i]]; + elementwise_add_to( + context, &blas, static_cast(input_width), + &input_data[i * input_width], &out_data[out_i * input_width]); + } } } } diff --git a/paddle/fluid/operators/math/selected_rows_functor_test.cc b/paddle/fluid/operators/math/selected_rows_functor_test.cc index aedb82da2f0fb2f15e1586d351af7c9d4364852b..5581b9e040272e224669d612409f88d61f794443 100644 --- a/paddle/fluid/operators/math/selected_rows_functor_test.cc +++ b/paddle/fluid/operators/math/selected_rows_functor_test.cc @@ -13,8 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/selected_rows_functor.h" + +#include #include #include "gtest/gtest.h" + #include "paddle/fluid/operators/math/math_function.h" TEST(selected_rows_functor, cpu_add) { @@ -360,6 +363,69 @@ TEST(selected_rows_functor, cpu_merge_add_multi) { } } +TEST(selected_rows_functor, cpu_merge_add_multi_noduplicated) { + paddle::platform::CPUPlace cpu_place; + paddle::platform::CPUDeviceContext ctx(cpu_place); + paddle::operators::math::SetConstant + set_const; + + int64_t height = 10; + int64_t row_numel = 8; + + std::vector rows1{1, 3, 5, 7, 9}; + std::unique_ptr selected_rows1{ + new paddle::framework::SelectedRows(rows1, height)}; + auto* in1_value = selected_rows1->mutable_value(); + in1_value->mutable_data( + paddle::framework::make_ddim( + {static_cast(rows1.size()), row_numel}), + cpu_place); + set_const(ctx, in1_value, 1.0); + + std::vector rows2{0, 2, 4, 6, 8}; + std::unique_ptr selected_rows2{ + new paddle::framework::SelectedRows(rows2, height)}; + auto* in2_value = selected_rows2->mutable_value(); + in2_value->mutable_data( + paddle::framework::make_ddim( + {static_cast(rows2.size()), row_numel}), + cpu_place); + set_const(ctx, in2_value, 2.0); + + std::unique_ptr output{ + new paddle::framework::SelectedRows()}; + output->set_height(height); + paddle::operators::math::scatter::MergeAdd + merge_add_functor; + + std::vector inputs; + inputs.push_back(selected_rows1.get()); + inputs.push_back(selected_rows2.get()); + merge_add_functor(ctx, inputs, output.get()); + + EXPECT_EQ(output->height(), height); + EXPECT_EQ(output->value().dims(), + paddle::framework::make_ddim({10, row_numel})); + + std::vector ret_rows{1, 3, 5, 7, 9, 0, 2, 4, 6, 8}; + EXPECT_EQ(output->rows(), ret_rows); + + auto* out_data = output->value().data(); + for (size_t i = 0; i < ret_rows.size(); ++i) { + float data_value = 0; + if (i < 5) { + data_value = 1.0; + } else { + data_value = 2.0; + } + for (size_t j = 0; j < static_cast(row_numel); ++j) { + EXPECT_EQ(out_data[i * row_numel + j], data_value); + } + } +} + TEST(selected_rows_functor, cpu_sum_to) { paddle::platform::CPUPlace cpu_place; paddle::platform::CPUDeviceContext ctx(cpu_place); diff --git a/paddle/fluid/operators/reduce_ops/reduce_all_op.cc b/paddle/fluid/operators/reduce_ops/reduce_all_op.cc index b087fbbb94c7ba2f7449f6bda56010dee1c38ea6..a3ca9ae0675472cb4f0bcd6f404f39004e7cc62f 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_all_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_all_op.cc @@ -14,7 +14,7 @@ #include "paddle/fluid/operators/reduce_ops/reduce_all_op.h" -REGISTER_REDUCE_OP(reduce_all); +REGISTER_REDUCE_OP_WITHOUT_GRAD(reduce_all); REGISTER_OP_CPU_KERNEL(reduce_all, ops::ReduceKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_any_op.cc b/paddle/fluid/operators/reduce_ops/reduce_any_op.cc index d865dcb3c935b76b8da25d723a5f780fb4de255b..34f0fffc9adef240c6fa222540710537587010c5 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_any_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_any_op.cc @@ -14,7 +14,7 @@ #include "paddle/fluid/operators/reduce_ops/reduce_any_op.h" -REGISTER_REDUCE_OP(reduce_any); +REGISTER_REDUCE_OP_WITHOUT_GRAD(reduce_any); REGISTER_OP_CPU_KERNEL(reduce_any, ops::ReduceKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index 540742c4cd8b0efc4c6cf095d7a8b3516f551d4c..c86591fdafa3d33bb3c7d75bf9f4f3b041a7a9cb 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -270,3 +270,12 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(op_name, ops::ReduceOp, __##op_name##Maker__, \ paddle::framework::DefaultGradOpDescMaker); \ REGISTER_OPERATOR(op_name##_grad, ops::ReduceGradOp) + +#define REGISTER_REDUCE_OP_WITHOUT_GRAD(op_name) \ + class __##op_name##Maker__ : public ops::ReduceOpMaker { \ + protected: \ + virtual std::string GetName() const { return #op_name; } \ + virtual std::string GetOpType() const { return "Reduce " #op_name; } \ + }; \ + REGISTER_OPERATOR(op_name, ops::ReduceOp, __##op_name##Maker__, \ + paddle::framework::EmptyGradOpMaker); diff --git a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc index 5c92588cc1d073612d2f6a7b315edf16cc14bedd..1c2726454f3d1fb8545e5d3260e59fcafbcb2aee 100644 --- a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc +++ b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc @@ -34,15 +34,22 @@ class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel { auto x_dims = ctx->GetInputDim("X"); auto labels_dims = ctx->GetInputDim("Label"); - PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); - PADDLE_ENFORCE_EQ(labels_dims.size(), 2, - "Input(Label)'s rank should be 2."); - PADDLE_ENFORCE_EQ(x_dims[0], labels_dims[0], - "The 1st dimension of Input(X) and Input(Label) should " - "be equal."); - PADDLE_ENFORCE_EQ(x_dims[1], labels_dims[1], - "The 2nd dimension of Input(X) and Input(Label) should " - "be equal."); + + int rank = x_dims.size(); + PADDLE_ENFORCE_EQ(rank, labels_dims.size(), + "Input(X) and Input(Label) shall have the same rank."); + bool check = true; + if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 || + framework::product(labels_dims) <= 0)) { + check = false; + } + + if (check) { + PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank), + framework::slice_ddim(labels_dims, 0, rank), + "Input(X) and Input(Label) shall have the same shape " + "except the last dimension."); + } ctx->ShareDim("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out"); @@ -65,23 +72,24 @@ class SigmoidCrossEntropyWithLogitsGradOp auto x_dims = ctx->GetInputDim("X"); auto labels_dims = ctx->GetInputDim("Label"); auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out")); - PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); - PADDLE_ENFORCE_EQ(labels_dims.size(), 2, - "Input(Label)'s rank should be 2."); - PADDLE_ENFORCE_EQ(dout_dims.size(), 2, - "Input(Out@Grad)'s rank should be 2."); - PADDLE_ENFORCE_EQ(x_dims[0], labels_dims[0], - "The 1st dimension of Input(X) and Input(Label) should " - "be equal."); - PADDLE_ENFORCE_EQ(x_dims[1], labels_dims[1], - "The 2nd dimension of Input(X) and Input(Label) should " - "be equal."); - PADDLE_ENFORCE_EQ(x_dims[0], dout_dims[0], - "The 1st dimension of Input(X) and Input(Out@Grad) " - "should be equal."); - PADDLE_ENFORCE_EQ(x_dims[1], dout_dims[1], - "The 2nd dimension of Input(X) and Input(Out@Grad) " - "should be equal."); + + int rank = x_dims.size(); + bool check = true; + if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 || + framework::product(labels_dims) <= 0)) { + check = false; + } + + if (check) { + PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank), + framework::slice_ddim(labels_dims, 0, rank), + "Input(X) and Input(Label) shall have the same shape."); + + PADDLE_ENFORCE_EQ( + framework::slice_ddim(x_dims, 0, rank), + framework::slice_ddim(dout_dims, 0, rank), + "Input(X) and Input(Out@Grad) shall have the same shape."); + } ctx->SetOutputDim(framework::GradVarName("X"), x_dims); } diff --git a/paddle/fluid/operators/spectral_norm_op.cc b/paddle/fluid/operators/spectral_norm_op.cc index 04f659a465a345653d251cbe6703309c804fe614..ec5ee487729d0650983d553dbffe14b63c16b26a 100644 --- a/paddle/fluid/operators/spectral_norm_op.cc +++ b/paddle/fluid/operators/spectral_norm_op.cc @@ -56,13 +56,19 @@ class SpectralNormOp : public framework::OperatorWithKernel { } auto dim_u = ctx->GetInputDim("U"); auto dim_v = ctx->GetInputDim("V"); - PADDLE_ENFORCE_EQ(dim_u[0], h, - "Input(U) dims[0] should be equal to " - "Input(Weight) dims[Attr(dim)]"); - PADDLE_ENFORCE_EQ( - dim_v[0], w, - "Input(V) dims[0] should be equal to " - "the product of Input(Weight) dims except dims[Attr(dim)]"); + + if (ctx->IsRuntime() || (dim_u[0] > 0 && h > 0)) { + PADDLE_ENFORCE_EQ(dim_u[0], h, + "Input(U) dims[0] should be equal to " + "Input(Weight) dims[Attr(dim)]"); + } + + if (ctx->IsRuntime() || (dim_v[0] > 0 && w > 0)) { + PADDLE_ENFORCE_EQ( + dim_v[0], w, + "Input(V) dims[0] should be equal to " + "the product of Input(Weight) dims except dims[Attr(dim)]"); + } ctx->SetOutputDim("Out", dim_weight); ctx->ShareLoD("Weight", /*->*/ "Out"); diff --git a/paddle/fluid/operators/split_op.cc b/paddle/fluid/operators/split_op.cc index a05582ae09e16ee17194d299d713d321f28ccace..a43bad878179d02c41d8c8bcd6b43eaffaa6e9a2 100644 --- a/paddle/fluid/operators/split_op.cc +++ b/paddle/fluid/operators/split_op.cc @@ -39,14 +39,22 @@ class SplitOp : public framework::OperatorWithKernel { if (num > 0) { int64_t in_axis_dim = in_dims[axis]; - PADDLE_ENFORCE_EQ(in_axis_dim % num, 0, - "tensor split does not result" - " in an equal division"); - size_t out_axis_dim = in_axis_dim / num; - for (size_t i = 0; i < outs_number; ++i) { - auto dim = in_dims; - dim[axis] = out_axis_dim; - outs_dims.push_back(dim); + if (ctx->IsRuntime() || in_axis_dim > 0) { + PADDLE_ENFORCE_EQ(in_axis_dim % num, 0, + "tensor split does not result" + " in an equal division"); + size_t out_axis_dim = in_axis_dim / num; + for (size_t i = 0; i < outs_number; ++i) { + auto dim = in_dims; + dim[axis] = out_axis_dim; + outs_dims.push_back(dim); + } + } else { + for (size_t i = 0; i < outs_number; ++i) { + auto dim = in_dims; + dim[axis] = -1; + outs_dims.push_back(dim); + } } } else if (sections.size() > 0) { PADDLE_ENFORCE_EQ(sections.size(), outs_number, diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 983d8243b1d8aa6c8d01855d6dbeab76c335f70c..3dc2b0c895116155f41df3ca66125fff3ede5ead 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -175,6 +175,7 @@ def __bootstrap__(): read_env_flags.append('communicator_thread_pool_size') read_env_flags.append('communicator_max_merge_var_num') read_env_flags.append('communicator_fake_rpc') + read_env_flags.append('communicator_send_wait_times') if core.is_compiled_with_brpc(): read_env_flags.append('max_body_size') #set brpc max body size diff --git a/python/paddle/fluid/contrib/tests/test_calibration.py b/python/paddle/fluid/contrib/tests/test_calibration.py index 00885eb5d6057b4a7738705007a9334da6aea9d0..16de60440245dbe6d73dd499e851dd18a280825f 100644 --- a/python/paddle/fluid/contrib/tests/test_calibration.py +++ b/python/paddle/fluid/contrib/tests/test_calibration.py @@ -147,10 +147,11 @@ class TestCalibrationForResnet50(unittest.TestCase): self.data_cache_folder) os.system(cmd) - self.batch_size = 1 - self.sample_iterations = 50 + self.batch_size = 1 if os.environ.get('DATASET') == 'full' else 50 + self.sample_iterations = 50 if os.environ.get( + 'DATASET') == 'full' else 1 self.infer_iterations = 50000 if os.environ.get( - 'DATASET') == 'full' else 50 + 'DATASET') == 'full' else 1 def cache_unzipping(self, target_folder, zip_path): if not os.path.exists(target_folder): @@ -279,15 +280,15 @@ class TestCalibrationForResnet50(unittest.TestCase): def test_calibration(self): self.download_model() print("Start FP32 inference for {0} on {1} images ...").format( - self.model, self.infer_iterations) + self.model, self.infer_iterations * self.batch_size) (fp32_throughput, fp32_latency, fp32_acc1) = self.run_program(self.model_cache_folder + "/model") print("Start INT8 calibration for {0} on {1} images ...").format( - self.model, self.sample_iterations) + self.model, self.sample_iterations * self.batch_size) self.run_program( self.model_cache_folder + "/model", True, algo=self.algo) print("Start INT8 inference for {0} on {1} images ...").format( - self.model, self.infer_iterations) + self.model, self.infer_iterations * self.batch_size) (int8_throughput, int8_latency, int8_acc1) = self.run_program("calibration_out") delta_value = fp32_acc1 - int8_acc1 diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 93e46eef16fb177169db679a8437d9a33ed38e99..81b7eabbbe1d1da5fa4acb851f4d2ed506efd319 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -196,6 +196,7 @@ __all__ = [ 'npair_loss', 'pixel_shuffle', 'fsp_matrix', + 'continuous_value_model', ] kIgnoreIndex = -100 @@ -5720,12 +5721,21 @@ def hsigmoid(input, raise ValueError( "num_classes must not be less than 2 with default tree") + if (not is_custom) and (is_sparse): + print("Sparse mode should not be used without custom tree") + is_sparse = False + + if (not is_custom) and ((path_table is not None) or + (path_code is not None)): + raise ValueError( + "only num_classes should be passed without custom tree") + if (is_custom) and (path_code is None): - raise ValueError("path_code should not be None with costum tree") + raise ValueError("path_code should not be None with custom tree") elif (is_custom) and (path_table is None): - raise ValueError("path_table should not be None with costum tree") + raise ValueError("path_table should not be None with custom tree") elif (is_custom) and (num_classes is None): - raise ValueError("num_classes should not be None with costum tree") + raise ValueError("num_classes should not be None with custom tree") else: pass @@ -11202,3 +11212,54 @@ def fsp_matrix(x, y): input_param_name='x')) helper.append_op(type='fsp', inputs={'X': x, 'Y': y}, outputs={'Out': out}) return out + + +def continuous_value_model(input, cvm, use_cvm=True): + """ + + **continuous_value_model layers** + + continuous value model(cvm). Now, it only considers show and click value in CTR project. + We assume that input is an embedding vector with cvm_feature, whose shape is [N * D] (D is 2 + embedding dim). + If use_cvm is True, it will log(cvm_feature), and output shape is [N * D]. + If use_cvm is False, it will remove cvm_feature from input, and output shape is [N * (D - 2)]. + + This layer accepts a tensor named input which is ID after embedded(lod level is 1), cvm is a show_click info. + + Args: + + input (Variable): a 2-D LodTensor with shape [N x D], where N is the batch size, D is 2 + the embedding dim. lod level = 1. + cvm (Variable): a 2-D Tensor with shape [N x 2], where N is the batch size, 2 is show and click. + use_cvm (bool): use cvm or not. if use cvm, the output dim is the same as input + if don't use cvm, the output dim is input dim - 2(remove show and click) + (cvm op is a customized op, which input is a sequence has embedd_with_cvm default, so we need an op named cvm to decided whever use it or not.) + + Returns: + + Variable: A 2-D LodTensor with shape [N x D], if use cvm, D is equal to input dim, if don't use cvm, D is equal to input dim - 2. + + Examples: + + .. code-block:: python + + input = fluid.layers.data(name="input", shape=[-1, 1], lod_level=1, append_batch_size=False, dtype="int64")#, stop_gradient=False) + label = fluid.layers.data(name="label", shape=[-1, 1], append_batch_size=False, dtype="int64") + embed = fluid.layers.embedding( + input=input, + size=[100, 11], + dtype='float32') + ones = fluid.layers.fill_constant_batch_size_like(input=label, shape=[-1, 1], dtype="int64", value=1) + show_clk = fluid.layers.cast(fluid.layers.concat([ones, label], axis=1), dtype='float32') + show_clk.stop_gradient = True + input_with_cvm = fluid.layers.continuous_value_model(embed, show_clk, True) + + """ + helper = LayerHelper('cvm', **locals()) + out = helper.create_variable(dtype=input.dtype) + helper.append_op( + type='cvm', + inputs={'X': [input], + 'CVM': [cvm]}, + outputs={'Y': [out]}, + attrs={"use_cvm": use_cvm}) + return out diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index a375ba657a6152c6e9fb67b8990ea85925e6670a..c3b7aee2b4d2421927adeb9fd44a516a7999cf83 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -275,15 +275,26 @@ class Optimizer(object): self._create_global_learning_rate() optimize_ops = [] - for param_and_grad in parameters_and_grads: - if param_and_grad[1] is None: - continue - with param_and_grad[0].block.program._optimized_guard( - param_and_grad), name_scope("optimizer"): - if param_and_grad[0].trainable is True: - optimize_op = self._append_optimize_op(global_block, - param_and_grad) - optimize_ops.append(optimize_op) + if framework.in_dygraph_mode(): + for param_and_grad in parameters_and_grads: + if param_and_grad[1] is None: + continue + with param_and_grad[0].block.program._optimized_guard( + param_and_grad): + if param_and_grad[0].trainable is True: + optimize_op = self._append_optimize_op(global_block, + param_and_grad) + optimize_ops.append(optimize_op) + else: + for param_and_grad in parameters_and_grads: + if param_and_grad[1] is None: + continue + with param_and_grad[0].block.program._optimized_guard( + param_and_grad), name_scope("optimizer"): + if param_and_grad[0].trainable is True: + optimize_op = self._append_optimize_op(global_block, + param_and_grad) + optimize_ops.append(optimize_op) # Get custom finish ops for subclasses # FIXME: Need to fix this once we figure out how to handle dependencies diff --git a/python/paddle/fluid/tests/unittests/test_cvm_op.py b/python/paddle/fluid/tests/unittests/test_cvm_op.py new file mode 100644 index 0000000000000000000000000000000000000000..67c310bd2f1155e4c5492e90a96cbdac9e8a3481 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_cvm_op.py @@ -0,0 +1,47 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from math import log +from math import exp +from op_test import OpTest +import unittest + + +class TestCVMOp(OpTest): + """ + Test cvm op with discrete one-hot labels. + """ + + def setUp(self): + self.op_type = "cvm" + batch_size = 4 + dims = 11 + lod = [[1]] + self.inputs = { + 'X': (np.random.uniform(0, 1, [1, dims]).astype("float32"), lod), + 'CVM': np.array([[0.6, 0.4]]).astype("float32"), + } + self.attrs = {'use_cvm': False} + out = [] + for index, emb in enumerate(self.inputs["X"][0]): + out.append(emb[2:]) + self.outputs = {'Y': (np.array(out), lod)} + + def test_check_output(self): + self.check_output() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_sigmoid_cross_entropy_with_logits_op.py b/python/paddle/fluid/tests/unittests/test_sigmoid_cross_entropy_with_logits_op.py index ae1883f1f7e44e06e378ff6d16dbc3c5060027e4..ec10b634091fc521062457b780b0c4cafcbacec0 100644 --- a/python/paddle/fluid/tests/unittests/test_sigmoid_cross_entropy_with_logits_op.py +++ b/python/paddle/fluid/tests/unittests/test_sigmoid_cross_entropy_with_logits_op.py @@ -149,5 +149,98 @@ class TestSigmoidCrossEntropyWithNorm(OpTest): self.check_grad(['X'], 'Out') +class TestSigmoidCrossEntropyWithLogitsOp5(OpTest): + """Test sigmoid_cross_entropy_with_logit_op with probabalistic label + """ + + def setUp(self): + self.op_type = "sigmoid_cross_entropy_with_logits" + batch_size = [10, 10] + num_classes = 20 + self.inputs = { + 'X': logit( + np.random.uniform(0, 1, tuple(batch_size + [num_classes])) + .astype("float32")), + 'Label': np.random.uniform(0, 1, tuple(batch_size + [num_classes])) + .astype("float32") + } + + # Fw Pass is implemented as elementwise sigmoid followed by + # elementwise logistic loss + # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) + sigmoid_X = expit(self.inputs['X']) + term1 = self.inputs['Label'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) + self.outputs = {'Out': -term1 - term2} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestSigmoidCrossEntropyWithNorm2(OpTest): + def setUp(self): + self.op_type = "sigmoid_cross_entropy_with_logits" + batch_size = [10, 10] + num_classes = 20 + ignore_index = -1 + self.inputs = { + 'X': logit( + np.random.uniform(0, 1, tuple(batch_size + [num_classes])) + .astype("float32")), + 'Label': np.random.randint(-1, 2, tuple(batch_size + [num_classes])) + .astype("float32") + } + self.attrs = {'ignore_index': ignore_index, 'normalize': True} + sigmoid_X = expit(self.inputs['X']) + term1 = self.inputs['Label'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) + out = -term1 - term2 + out[np.where(self.inputs['Label'] == ignore_index)] = 0 + if self.attrs['normalize']: + out = out / float( + np.where(self.inputs['Label'] != ignore_index)[0].size) + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestSigmoidCrossEntropyWithLogitsOp6(OpTest): + """Test sigmoid_cross_entropy_with_logit_op with binary label + """ + + def setUp(self): + self.op_type = "sigmoid_cross_entropy_with_logits" + batch_size = [10, 10] + num_classes = 20 + self.inputs = { + 'X': logit( + np.random.uniform(0, 1, tuple(batch_size + [num_classes])) + .astype("float32")), + 'Label': np.random.randint(0, 2, tuple(batch_size + [num_classes])) + .astype("float32") + } + + # Fw Pass is implemented as elementwise sigmoid followed by + # elementwise logistic loss + # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) + sigmoid_X = expit(self.inputs['X']) + term1 = self.inputs['Label'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) + self.outputs = {'Out': -term1 - term2} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 41e5f47976c566306ad141f655a0f6516831d690..19a1f8bf74060905ecb4b81b44f7080db79c45e4 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -658,6 +658,7 @@ class DistributeTranspiler(object): outputs={"Out": splited_var}, attrs={ "epmap": eps, + "trainer_id": self.trainer_id, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE }) @@ -669,6 +670,7 @@ class DistributeTranspiler(object): outputs={"Out": fetch_barrier_out}, attrs={ "endpoints": self.pserver_endpoints, + "trainer_id": self.trainer_id, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE }) @@ -791,11 +793,15 @@ class DistributeTranspiler(object): global_ops = [] + # sparse grad name to param name + sparse_grad_to_param = [] + def __append_optimize_op__(op, block, grad_to_block_id, merged_var, lr_ops): if self._is_optimizer_op(op): self._append_pserver_ops(block, op, endpoint, grad_to_block_id, - self.origin_program, merged_var) + self.origin_program, merged_var, + sparse_grad_to_param) elif op not in lr_ops: self._append_pserver_non_opt_ops(block, op) @@ -911,6 +917,7 @@ class DistributeTranspiler(object): "Fanin": self.trainer_num, "sync_mode": self.sync_mode, "grad_to_block_id": grad_to_block_id, + "sparse_grad_to_param": sparse_grad_to_param, } if self.has_distributed_lookup_table: @@ -1779,7 +1786,8 @@ class DistributeTranspiler(object): return o4 def _append_pserver_ops(self, optimize_block, opt_op, endpoint, - grad_to_block_id, origin_program, merged_var): + grad_to_block_id, origin_program, merged_var, + sparse_grad_to_param): program = optimize_block.program pserver_block = program.global_block() new_inputs = collections.OrderedDict() @@ -1863,6 +1871,12 @@ class DistributeTranspiler(object): outputs=outputs, attrs=opt_op.all_attrs()) + # record sparse grad to param name + if new_inputs["Grad"].type == core.VarDesc.VarType.SELECTED_ROWS: + sparse_grad_to_param.append( + str(new_inputs["Grad"].name) + ":" + str(new_inputs["Param"] + .name)) + def _get_pserver_grad_param_var(self, var, var_dict): """ Return pserver side grad/param variable, return None