diff --git a/CMakeLists.txt b/CMakeLists.txt index 23bbe829ac16180088bfa37df66e23f19b021ea3..030bd19b3fd2f561a847bbc4613e5d2030812a92 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,7 +25,6 @@ message(STATUS "CXX compiler: ${CMAKE_CXX_COMPILER}, version: " message(STATUS "C compiler: ${CMAKE_C_COMPILER}, version: " "${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}") -find_package(Sphinx) if(NOT CMAKE_CROSSCOMPILING) find_package(CUDA QUIET) endif(NOT CMAKE_CROSSCOMPILING) @@ -226,5 +225,7 @@ if(WITH_PYTHON) endif() if(WITH_DOC) + find_package(Sphinx REQUIRED) + find_python_module(recommonmark REQUIRED) add_subdirectory(doc) endif() diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index d373c48b1a75c5f75c7520b56f230bc2c146b174..a4eb6f706edab9479cbce436311eb96da8845646 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -192,6 +192,10 @@ class ExecutionContext { return op_.Attr(name); } + bool HasInput(const std::string& name) const { return op_.HasInputs(name); } + + bool HasOutput(const std::string& name) const { return op_.HasOutputs(name); } + size_t InputSize(const std::string& name) const { return op_.Inputs(name).size(); } diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 20ef7e09f630140c44774147aa727780df6333fa..95e807c0afa45bc4f4feb84d450b2d0584bc3b28 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -58,7 +58,8 @@ ParallelExecutor::ParallelExecutor( const std::unordered_set &bcast_vars, const ProgramDesc &main_program, const std::string &loss_var_name, Scope *scope, const std::vector &local_scopes, bool allow_op_delay, - bool use_default_grad_scale, bool balance_parameter_opt_between_cards) + bool use_default_grad_scale, bool balance_parameter_opt_between_cards, + size_t num_trainers, size_t trainer_id) : member_(new ParallelExecutorPrivate(places)) { member_->global_scope_ = scope; @@ -80,7 +81,13 @@ ParallelExecutor::ParallelExecutor( // Bcast Parameters to all GPUs #ifdef PADDLE_WITH_CUDA - member_->nccl_ctxs_.reset(new platform::NCCLContextMap(member_->places_)); + auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME); + ncclUniqueId *nccl_id = nullptr; + if (nccl_id_var != nullptr) { + nccl_id = nccl_id_var->GetMutable(); + } + member_->nccl_ctxs_.reset(new platform::NCCLContextMap( + member_->places_, nccl_id, num_trainers, trainer_id)); #endif if (platform::is_gpu_place(places[0]) && member_->local_scopes_.size() != 1 && local_scopes.empty()) { // Is CUDA diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index b251fc91417a1c00e61e9c3c952460e6268d2819..9e279876cfeef20a1921f8bd1c27046a477b9f56 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -41,7 +41,8 @@ class ParallelExecutor { const std::string& loss_var_name, Scope* scope, const std::vector& local_scopes, bool allow_op_delay, bool use_default_grad_scale, - bool balance_parameter_opt_between_cards); + bool balance_parameter_opt_between_cards, + size_t num_trainers = 1, size_t trainer_id = 0); ~ParallelExecutor(); diff --git a/paddle/fluid/inference/tensorrt/convert/activation_op.cc b/paddle/fluid/inference/tensorrt/convert/activation_op.cc index 543784289cfc51048057e467d36fdd1f334eb903..6297051e5a30f1daa512d25d5aa3ab3b2f79f1d1 100644 --- a/paddle/fluid/inference/tensorrt/convert/activation_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/activation_op.cc @@ -21,15 +21,18 @@ namespace tensorrt { class ReluOpConverter : public OpConverter { public: ReluOpConverter() {} - void operator()(const framework::OpDesc& op) override { + void operator()(const framework::proto::OpDesc& op) override { + // Here the two nullptr looks strange, that's because the + // framework::OpDesc's constructor is strange. + framework::OpDesc op_desc(op, nullptr, nullptr); LOG(INFO) << "convert a fluid relu op to tensorrt activation layer whose " "type is Relu"; const nvinfer1::ITensor* input_tensor = - engine_->GetITensor(op.Input("X")[0]); + engine_->GetITensor(op_desc.Input("X")[0]); nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER( engine_, Activation, *const_cast(input_tensor), nvinfer1::ActivationType::kRELU); - engine_->SetITensor(op.Output("Out")[0], layer->getOutput(0)); + engine_->SetITensor(op_desc.Output("Out")[0], layer->getOutput(0)); } }; diff --git a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc index 431500b90e144e2a30fe705b72e93452f806ca65..209936c3bafb0d31546856dc36c1b48053a0634b 100644 --- a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc @@ -21,7 +21,7 @@ namespace tensorrt { class Conv2dOpConverter : public OpConverter { public: Conv2dOpConverter() {} - void operator()(const framework::OpDesc& op) override { + void operator()(const framework::proto::OpDesc& op) override { LOG(INFO) << "convert a fluid conv2d op to tensorrt conv layer without bias"; } diff --git a/paddle/fluid/inference/tensorrt/convert/io_converter.cc b/paddle/fluid/inference/tensorrt/convert/io_converter.cc index 2d583c0012367d1a08f4a90b39d0bca211f82102..854f434d93e81237dc85c5df62debcf3b3824b78 100644 --- a/paddle/fluid/inference/tensorrt/convert/io_converter.cc +++ b/paddle/fluid/inference/tensorrt/convert/io_converter.cc @@ -39,7 +39,7 @@ class DefaultIOConverter : public EngineIOConverter { cudaMemcpyHostToDevice, *stream_)); } else if (is_gpu_place(place)) { PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(out, in.data(), size, - cudaMemcpyHostToHost, *stream_)); + cudaMemcpyDeviceToDevice, *stream_)); } else { PADDLE_THROW("Unknown device for converter"); } diff --git a/paddle/fluid/inference/tensorrt/convert/mul_op.cc b/paddle/fluid/inference/tensorrt/convert/mul_op.cc index f9834ab156c9dcc11f4e89075b7bf5457cf00268..3ca58b139bd3af1947ae7f063060e11d2ea7d577 100644 --- a/paddle/fluid/inference/tensorrt/convert/mul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/mul_op.cc @@ -21,7 +21,7 @@ namespace tensorrt { class MulOpConverter : public OpConverter { public: MulOpConverter() {} - void operator()(const framework::OpDesc& op) override { + void operator()(const framework::proto::OpDesc& op) override { LOG(INFO) << "convert a fluid mul op to tensorrt fc layer without bias"; } }; diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h index 77c788550b2c7df1f483b926661789b2a54d8fff..abc9ebf472498f6653d5bb1113ae2f3ce7e5a923 100644 --- a/paddle/fluid/inference/tensorrt/convert/op_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h @@ -31,10 +31,10 @@ namespace tensorrt { class OpConverter { public: OpConverter() {} - virtual void operator()(const framework::OpDesc& op) {} + virtual void operator()(const framework::proto::OpDesc& op) {} - void Run(const framework::OpDesc& op, TensorRTEngine* engine) { - std::string type = op.Type(); + void Run(const framework::proto::OpDesc& op, TensorRTEngine* engine) { + std::string type = op.type(); auto* it = Registry::Lookup(type); PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", type); it->SetEngine(engine); @@ -42,14 +42,16 @@ class OpConverter { } // convert fluid op to tensorrt layer - void ConvertOp(const framework::OpDesc& op, TensorRTEngine* engine) { + void ConvertOp(const framework::proto::OpDesc& op, TensorRTEngine* engine) { OpConverter::Run(op, engine); } // convert fluid block to tensorrt network - void ConvertBlock(const framework::BlockDesc& block, TensorRTEngine* engine) { - for (auto op : block.AllOps()) { - OpConverter::Run(*op, engine); + void ConvertBlock(const framework::proto::BlockDesc& block, + TensorRTEngine* engine) { + for (size_t i = 0; i < block.ops_size(); i++) { + const auto& op = block.ops(i); + OpConverter::Run(op, engine); } } diff --git a/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc index 3275b65bc6ddcda39414204a6e639070fdc3fe15..ec33f97c8240dfc09a203d68599bffe78a4abb12 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc @@ -51,7 +51,7 @@ void Compare(const std::string op_type, float input, float expect) { op_desc.SetInput("X", {"X"}); op_desc.SetOutput("Out", {"Out"}); - auto op = framework::OpRegistry::CreateOp(op_desc); + auto op = framework::OpRegistry::CreateOp(*op_desc.Proto()); // run fluid op op->Run(scope, place); @@ -68,7 +68,8 @@ void Compare(const std::string op_type, float input, float expect) { nvinfer1::DimsCHW{1, 1, 1}); // convert op OpConverter op_converter; - op_converter.ConvertOp(op_desc, engine); + op_converter.ConvertOp(*op_desc.Proto(), engine); + engine->DeclareOutput("Out"); engine->FreezeNetwork(); diff --git a/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc b/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc index aa5fb726f1129eda65a6f39791330b795aad660d..8d66543eb7637c5a8ae670b89ef5996954ba2e7b 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc @@ -29,7 +29,7 @@ TEST(OpConverter, ConvertBlock) { conv2d_op->SetType("conv2d"); OpConverter converter; - converter.ConvertBlock(*block, nullptr /*TensorRTEngine*/); + converter.ConvertBlock(*block->Proto(), nullptr /*TensorRTEngine*/); } } // namespace tensorrt diff --git a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc index c4fd1e298b0daea85db2a407d04ad2d7bcdee0f0..60c761c5281e2f535aab0200c93fb738addcdb87 100644 --- a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc +++ b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc @@ -16,7 +16,6 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/inference/tests/test_helper.h" -DEFINE_string(data_set, "cifar10", "Data set to test"); DEFINE_string(dirname, "", "Directory of the inference model."); DEFINE_string(fp16_dirname, "", "Directory of the float16 inference model."); DEFINE_int32(batch_size, 1, "Batch size of input data"); @@ -35,19 +34,19 @@ TEST(inference, image_classification) { // 0. Call `paddle::framework::InitDevices()` initialize all the devices // In unittests, this is done in paddle/testing/paddle_gtest_main.cc + const bool is_combined = false; + std::vector> feed_target_shapes = + GetFeedTargetShapes(dirname, is_combined); + paddle::framework::LoDTensor input; // Use normilized image pixels as input data, // which should be in the range [0.0, 1.0]. - if (FLAGS_data_set == "cifar10") { - SetupTensor(&input, {FLAGS_batch_size, 3, 32, 32}, - static_cast(0), static_cast(1)); - } else if (FLAGS_data_set == "imagenet") { - SetupTensor(&input, {FLAGS_batch_size, 3, 224, 224}, - static_cast(0), static_cast(1)); - } else { - LOG(FATAL) << "Only cifar10 or imagenet is supported."; - } - + feed_target_shapes[0][0] = FLAGS_batch_size; + paddle::framework::DDim input_dims = + paddle::framework::make_ddim(feed_target_shapes[0]); + LOG(INFO) << input_dims; + SetupTensor(&input, input_dims, static_cast(0), + static_cast(1)); std::vector cpu_feeds; cpu_feeds.push_back(&input); @@ -60,7 +59,7 @@ TEST(inference, image_classification) { LOG(INFO) << "--- CPU Runs: ---"; LOG(INFO) << "Batch size is " << FLAGS_batch_size; TestInference( - dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat); + dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat, is_combined); LOG(INFO) << output1.dims(); } @@ -73,7 +72,7 @@ TEST(inference, image_classification) { LOG(INFO) << "--- GPU Runs: ---"; LOG(INFO) << "Batch size is " << FLAGS_batch_size; TestInference( - dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat); + dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat, is_combined); LOG(INFO) << output2.dims(); if (!FLAGS_skip_cpu) { diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h index af2a7a5620487a10c1df6152fc4e4bf67b150752..b02e5c99f00eaf03c3753e43575cbc67e834774e 100644 --- a/paddle/fluid/inference/tests/test_helper.h +++ b/paddle/fluid/inference/tests/test_helper.h @@ -89,6 +89,50 @@ void CheckError(const paddle::framework::LoDTensor& output1, EXPECT_EQ(count, 0U) << "There are " << count << " different elements."; } +std::unique_ptr InitProgram( + paddle::framework::Executor* executor, paddle::framework::Scope* scope, + const std::string& dirname, const bool is_combined = false) { + std::unique_ptr inference_program; + if (is_combined) { + // All parameters are saved in a single file. + // Hard-coding the file names of program and parameters in unittest. + // The file names should be consistent with that used in Python API + // `fluid.io.save_inference_model`. + std::string prog_filename = "__model_combined__"; + std::string param_filename = "__params_combined__"; + inference_program = + paddle::inference::Load(executor, scope, dirname + "/" + prog_filename, + dirname + "/" + param_filename); + } else { + // Parameters are saved in separate files sited in the specified + // `dirname`. + inference_program = paddle::inference::Load(executor, scope, dirname); + } + return inference_program; +} + +std::vector> GetFeedTargetShapes( + const std::string& dirname, const bool is_combined = false) { + auto place = paddle::platform::CPUPlace(); + auto executor = paddle::framework::Executor(place); + auto* scope = new paddle::framework::Scope(); + + auto inference_program = InitProgram(&executor, scope, dirname, is_combined); + auto& global_block = inference_program->Block(0); + + const std::vector& feed_target_names = + inference_program->GetFeedTargetNames(); + std::vector> feed_target_shapes; + for (size_t i = 0; i < feed_target_names.size(); ++i) { + auto* var = global_block.FindVar(feed_target_names[i]); + std::vector var_shape = var->GetShape(); + feed_target_shapes.push_back(var_shape); + } + + delete scope; + return feed_target_shapes; +} + template void TestInference(const std::string& dirname, const std::vector& cpu_feeds, @@ -124,22 +168,7 @@ void TestInference(const std::string& dirname, paddle::platform::RecordEvent record_event( "init_program", paddle::platform::DeviceContextPool::Instance().Get(place)); - - if (is_combined) { - // All parameters are saved in a single file. - // Hard-coding the file names of program and parameters in unittest. - // The file names should be consistent with that used in Python API - // `fluid.io.save_inference_model`. - std::string prog_filename = "__model_combined__"; - std::string param_filename = "__params_combined__"; - inference_program = paddle::inference::Load( - &executor, scope, dirname + "/" + prog_filename, - dirname + "/" + param_filename); - } else { - // Parameters are saved in separate files sited in the specified - // `dirname`. - inference_program = paddle::inference::Load(&executor, scope, dirname); - } + inference_program = InitProgram(&executor, scope, dirname, is_combined); } // Disable the profiler and print the timing information paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault, diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index c14a2b7786f9f7c06d59479d3bbce9c5d542e495..d38a9ce58726a1d045d6905354b0b592166c0110 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -186,6 +186,11 @@ endif() add_subdirectory(detail) if(WITH_DISTRIBUTE) + if(WITH_GPU) + op_library(gen_nccl_id_op DEPS nccl_common) + else() + set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op) + endif() set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") op_library(send_op DEPS ${DISTRIBUTE_DEPS}) @@ -202,8 +207,9 @@ if(WITH_DISTRIBUTE) set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op listen_and_serv_op sum_op executor) + cc_test(test_send_nccl_id SRCS test_send_nccl_id.cc DEPS send_op listen_and_serv_op executor) else() - set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op) + set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op gen_nccl_id_op) endif() op_library(cross_entropy_op DEPS cross_entropy) diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index 661dfa69fe1580ff3890f12defcd124225be0c06..ae60ab15325ef101feb7270a4f5d840cb2112be0 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -52,7 +52,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, // stub context SendProcessor* s = new SendProcessor(ch); s->Prepare(var_h, time_out); - s->response_call_back_ = NULL; + s->response_call_back_ = nullptr; auto call = s->stub_g_.PrepareUnaryCall( s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_); diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index f6229b71bc01a6de51f50f5fe880ada6e15e74dd..dabce7414d2f0dca74193f1cd10c341793c10ec9 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -57,7 +57,9 @@ void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg); class BaseProcessor { public: - explicit BaseProcessor(std::shared_ptr ch) { context_ = NULL; } + explicit BaseProcessor(std::shared_ptr ch) { + context_ = nullptr; + } virtual ~BaseProcessor() {} @@ -105,7 +107,7 @@ class SendProcessor : public BaseProcessor { ::grpc::GenericStub stub_g_; ::grpc::ByteBuffer reply_; - RequestSendCallBack response_call_back_ = NULL; + RequestSendCallBack response_call_back_ = nullptr; }; typedef std::function diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index 7f9cae21ccca8dd51f9fbe98148d01a51ac6eb84..18f1bc53d0f561f412a5bbbe018bc3d427ac9ef9 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -47,6 +47,7 @@ class AsyncGRPCServer final { explicit AsyncGRPCServer(const std::string &address, bool sync_mode) : address_(address), sync_mode_(sync_mode), ready_(0) {} + ~AsyncGRPCServer() {} void WaitServerReady(); void RunSyncUpdate(); diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/detail/send_recv.proto index fffa9ae7a43ea5cd7b2bda6fbbf6ef9f7d23009d..9478c5702bcbf99fc88207b8c4843dbccf8a5925 100644 --- a/paddle/fluid/operators/detail/send_recv.proto +++ b/paddle/fluid/operators/detail/send_recv.proto @@ -32,6 +32,7 @@ service SendRecvService { enum VarType { LOD_TENSOR = 0; SELECTED_ROWS = 1; + NCCL_ID = 2; } // NOTICE(gongwb):don't modify this proto if you are not diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc index 1a8a1af20fa446dbd537944409ef0ca1e3e9116f..07c43554bc6a0d71d688a5a5772d0ab3d2de319a 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.cc +++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc @@ -14,6 +14,9 @@ limitations under the License. */ #include "paddle/fluid/operators/detail/sendrecvop_utils.h" +#ifdef PADDLE_WITH_CUDA +#include +#endif #include #include // NOLINT @@ -129,6 +132,10 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, } else if (var->IsType()) { request.set_type(::sendrecv::SELECTED_ROWS); GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size); +#ifdef PADDLE_WITH_CUDA + } else if (var->IsType()) { + request.set_type(::sendrecv::NCCL_ID); +#endif } else { PADDLE_THROW("Serialize does not support type: %s", typeid(var->Type()).name()); @@ -149,6 +156,24 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, void* buf = buffer.get(); ProtoEncodeHelper e(static_cast(buf), 1024); e.WriteRawBytes(std::string(header.data(), header.size())); +// NCCLID is copied directly to the message, return bytebuffer +// with only one slice if serializing NCCLID. +#ifdef PADDLE_WITH_CUDA + if (var->IsType()) { + e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, + NCCL_UNIQUE_ID_BYTES); + const ncclUniqueId& uid = var->Get(); + e.WriteRawBytes(std::string(uid.internal, NCCL_UNIQUE_ID_BYTES)); + + // for serialize NCCL_ID + ::grpc::Slice slices(e.size()); + memcpy(const_cast(slices.begin()), e.data(), e.size()); + ::grpc::ByteBuffer tmp(&slices, 1); + msg->Swap(&tmp); + return; + } +#endif + e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size); // steal reference of tensor data ::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/detail/variable_response.cc index 99602a05d023f30c2eed8df25e7534fdc9ef2ced..462e303096e609c6797ca8cc16266ec3621623fc 100644 --- a/paddle/fluid/operators/detail/variable_response.cc +++ b/paddle/fluid/operators/detail/variable_response.cc @@ -17,6 +17,9 @@ #include #include #include +#ifdef PADDLE_WITH_CUDA +#include +#endif #include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/operators/detail/send_recv.pb.h" @@ -368,7 +371,8 @@ int VariableResponse::Parse(Source* source) { } case sendrecv::VariableMessage::kSerializedFieldNumber: { PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS || - meta_.type() == sendrecv::LOD_TENSOR) && + meta_.type() == sendrecv::LOD_TENSOR || + meta_.type() == sendrecv::NCCL_ID) && meta_.varname() != "", "meta info should be got first!"); @@ -378,6 +382,22 @@ int VariableResponse::Parse(Source* source) { return tag; } + if (meta_.type() == sendrecv::NCCL_ID) { +#ifdef PADDLE_WITH_CUDA + auto* var = scope_->FindVar(meta_.varname()); + if (var != nullptr) { + ncclUniqueId* id = var->GetMutable(); + if (!ReadRaw(&input, *dev_ctx_, platform::CPUPlace(), id->internal, + num_bytes)) { + return tag; + } + } + break; +#else + PADDLE_THROW("Not compiled with CUDA!"); +#endif + } + framework::DDim dims = GetDims(meta_.dims()); if (meta_.type() == sendrecv::LOD_TENSOR) { PADDLE_ENFORCE(meta_.lod_size() >= 0, diff --git a/paddle/fluid/operators/gen_nccl_id_op.cc b/paddle/fluid/operators/gen_nccl_id_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..a5678f63466d368b3dd59380c18f9625cabd368b --- /dev/null +++ b/paddle/fluid/operators/gen_nccl_id_op.cc @@ -0,0 +1,128 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include + +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/threadpool.h" +#include "paddle/fluid/operators/detail/grpc_client.h" +#include "paddle/fluid/operators/detail/grpc_server.h" +#include "paddle/fluid/platform/nccl_helper.h" + +namespace paddle { +namespace operators { + +class GenNCCLIdOp : public framework::OperatorBase { + public: + GenNCCLIdOp(const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + // put nccl id in CPUPlace + auto& dev_ctx = *pool.Get(platform::CPUPlace()); + int trainer_id = Attr("trainer_id"); + framework::Scope& local_scope = scope.NewScope(); + + if (trainer_id == 0) { + GenerateAndSend(&local_scope, dev_ctx); + } else { + GetIdByServer(&local_scope, dev_ctx); + } + } + + private: + void GenerateAndSend(framework::Scope* scope, + const platform::DeviceContext& dev_ctx) const { + auto var = scope->FindVar(NCCL_ID_VARNAME); + PADDLE_ENFORCE_NOT_NULL(var); + auto id = var->GetMutable(); + PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(id)); + + std::vector endpoint_list = + Attr>("endpoint_list"); + detail::RPCClient client; + for (auto& ep : endpoint_list) { + VLOG(3) << "sending nccl id to " << ep; + client.AsyncSendVariable(ep, dev_ctx, *scope, NCCL_ID_VARNAME); + } + client.Wait(); + VLOG(3) << "sending completed..."; + } + + void GetIdByServer(framework::Scope* scope, + const platform::DeviceContext& dev_ctx) const { + std::string endpoint = Attr("endpoint"); + // NOTE: Can not use unique_ptr here because the default + // deleter will call GRPC Server's base class's dtor and + // that will cause a wired crash. + detail::AsyncGRPCServer rpc_service(endpoint, true); + framework::ProgramDesc empty_program; + framework::Executor executor(dev_ctx.GetPlace()); + rpc_service.SetScope(scope); + rpc_service.SetDevCtx(&dev_ctx); + rpc_service.SetProgram(&empty_program); + rpc_service.SetExecutor(&executor); + + std::thread server_thread( + std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, &rpc_service)); + rpc_service.SetCond(0); + VLOG(3) << "start getting nccl id from trainer 0..."; + auto recv = rpc_service.Get(); + VLOG(3) << "got nccl id and stop server..."; + rpc_service.ShutDown(); + VLOG(3) << "rpc server stopped"; + server_thread.join(); + } +}; + +class GenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddOutput("NCCLID", "Raw variable contains a NCCL UniqueId instaces."); + AddComment(R"DOC( +GenNCCLId operator + +For trainer 0: generate a new UniqueId and send it to all the other trainers. +For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server. +)DOC"); + AddAttr("endpoint", + "(string), e.g. 127.0.0.1:6175 " + "current listen endpoint"); + AddAttr>( + "endpoint_list", + "['trainer1_ip:port', 'trainer2_ip:port', ...] " + "list of trainer endpoints start from trainer 1") + .SetDefault({}); + AddAttr("trainer_id", + "(int default 0) " + "The index of the trainer in distributed training.") + .SetDefault(0); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(gen_nccl_id, ops::GenNCCLIdOp, ops::GenNCCLIdOpMaker); diff --git a/paddle/fluid/operators/math/sequence2batch.h b/paddle/fluid/operators/math/sequence2batch.h index 0abda999a52bcbb94e6503692bd11aff26e849ba..62e6307ae9f4236a38c49daaf09fc05c54268159 100644 --- a/paddle/fluid/operators/math/sequence2batch.h +++ b/paddle/fluid/operators/math/sequence2batch.h @@ -64,18 +64,22 @@ class LoDTensor2BatchFunctor { bool is_reverse = false) const { if (!is_cal_batch_lod) { auto lods = batch->lod(); - PADDLE_ENFORCE_GT(lods.size(), 2UL); - PADDLE_ENFORCE_EQ(lods[1].size(), - static_cast(lod_tensor.dims()[0])); + PADDLE_ENFORCE_GT(lods.size(), 2UL, + "The LoD of LoDTensor should inlcude at least 2-level " + "sequence information."); + PADDLE_ENFORCE_EQ( + lods[1].size(), static_cast(lod_tensor.dims()[0]), + "The LoD information should be consistent with the dims."); CopyMatrixRowsFunctor to_batch; to_batch(context, lod_tensor, lods[1], batch, true); return; } auto lods = lod_tensor.lod(); - auto lod = lods[0]; PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now."); + auto lod = lods[0]; + std::vector seq_info; for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) { int length = lod[seq_id + 1] - lod[seq_id]; @@ -157,9 +161,12 @@ class Batch2LoDTensorFunctor { const framework::LoDTensor& batch, framework::LoDTensor* lod_tensor) const { auto in_lod = batch.lod(); - PADDLE_ENFORCE_GT(in_lod.size(), 2UL); - PADDLE_ENFORCE_EQ(in_lod[1].size(), - static_cast(lod_tensor->dims()[0])); + PADDLE_ENFORCE_GT(in_lod.size(), 2UL, + "The LoD of LoDTensor should inlcude at least 2-level " + "sequence information."); + PADDLE_ENFORCE_EQ( + in_lod[1].size(), static_cast(lod_tensor->dims()[0]), + "The LoD information should be consistent with the dims."); CopyMatrixRowsFunctor to_seq; to_seq(context, batch, in_lod[1], lod_tensor, false); } diff --git a/paddle/fluid/operators/reshape_op.h b/paddle/fluid/operators/reshape_op.h index ccd7063fe69e0f21b4d2a821bb70902b39c9b9de..3dd8c7c11eca241e747bfa129962032d882ce44c 100644 --- a/paddle/fluid/operators/reshape_op.h +++ b/paddle/fluid/operators/reshape_op.h @@ -92,14 +92,16 @@ class ReshapeOp : public framework::OperatorWithKernel { } if (unk_dim_idx != -1) { - output_shape[unk_dim_idx] = -in_size / capacity; - // in_size < 0 and is un-determinate in compile time, skip the check, - // for example, in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8], - // capacity = -24, in_size = -8, output_shape[0] = 0 - // the following check will fail. if (in_size > 0) { + // in_size < 0 and is un-determinate in compile time, skip the check, + // for example, in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8], + // capacity = -24, in_size = -8, output_shape[0] = 0 + // the following check will fail. + output_shape[unk_dim_idx] = -in_size / capacity; PADDLE_ENFORCE_EQ(output_shape[unk_dim_idx] * capacity, -in_size, "Invalid shape is given."); + } else { + output_shape[unk_dim_idx] = -1; } } else { PADDLE_ENFORCE_EQ(capacity, in_size, "Invalid shape is given."); @@ -122,7 +124,10 @@ class ReshapeKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext &ctx) const { auto *out = ctx.Output("Out"); auto *in = ctx.Input("X"); - auto *shape_tensor = ctx.Input("Shape"); + + auto *shape_tensor = ctx.HasInput("Shape") + ? ctx.Input("Shape") + : nullptr; framework::DDim out_dims = out->dims(); diff --git a/paddle/fluid/operators/test_send_nccl_id.cc b/paddle/fluid/operators/test_send_nccl_id.cc new file mode 100644 index 0000000000000000000000000000000000000000..bbae1d54aa3524fd45cb8ab13c86df8d54b8e643 --- /dev/null +++ b/paddle/fluid/operators/test_send_nccl_id.cc @@ -0,0 +1,94 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include // NOLINT + +#include "gtest/gtest.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/operators/detail/grpc_client.h" +#include "paddle/fluid/operators/listen_and_serv_op.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/fluid/platform/nccl_helper.h" +#include "paddle/fluid/string/printf.h" + +USE_NO_KERNEL_OP(listen_and_serv); + +namespace f = paddle::framework; +namespace p = paddle::platform; +namespace m = paddle::operators::math; +namespace detail = paddle::operators::detail; +namespace string = paddle::string; + +std::unique_ptr rpc_service; + +void StartServer(std::atomic* initialized) { + f::Scope scope; + p::CPUPlace place; + scope.Var(NCCL_ID_VARNAME); + p::DeviceContextPool& pool = p::DeviceContextPool::Instance(); + auto& dev_ctx = *pool.Get(p::CPUPlace()); + + rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", true)); + + f::ProgramDesc empty_program; + f::Executor executor(dev_ctx.GetPlace()); + rpc_service->SetScope(&scope); + rpc_service->SetDevCtx(&dev_ctx); + rpc_service->SetProgram(&empty_program); + rpc_service->SetExecutor(&executor); + + std::thread server_thread( + std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, rpc_service.get())); + *initialized = true; + rpc_service->SetCond(0); + auto recv = rpc_service->Get(); + LOG(INFO) << "got nccl id and stop server..."; + rpc_service->ShutDown(); + server_thread.join(); +} + +TEST(SendNcclId, Normal) { + std::atomic initialized{false}; + std::thread server_thread(StartServer, &initialized); + while (!initialized) { + } + // wait server to start + // sleep(2); + rpc_service->WaitServerReady(); + + f::Scope scope; + p::CPUPlace place; + p::DeviceContextPool& pool = p::DeviceContextPool::Instance(); + auto& dev_ctx = *pool.Get(p::CPUPlace()); + + auto var = scope.Var(NCCL_ID_VARNAME); + // var->SetType(f::proto::VarType_Type_RAW); + auto id = var->GetMutable(); + p::dynload::ncclGetUniqueId(id); + + int port = rpc_service->GetSelectedPort(); + std::string ep = string::Sprintf("127.0.0.1:%d", port); + detail::RPCClient client; + + client.AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME); + client.Wait(); + server_thread.join(); + auto* ptr = rpc_service.release(); + delete ptr; +} diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h index 0013597fd516d15c7d502370eec77e1a6a5dca88..e30c1a9ebf08365a9856fb32b1ce5790869e2b33 100644 --- a/paddle/fluid/platform/nccl_helper.h +++ b/paddle/fluid/platform/nccl_helper.h @@ -14,12 +14,15 @@ #pragma once +#include #include // NOLINT #include #include #include "paddle/fluid/platform/dynload/nccl.h" #include "paddle/fluid/platform/enforce.h" +#define NCCL_ID_VARNAME "NCCLID" + namespace paddle { namespace platform { @@ -73,7 +76,9 @@ struct NCCLContextMap { std::unordered_map contexts_; std::vector order_; - explicit NCCLContextMap(const std::vector &places) { + explicit NCCLContextMap(const std::vector &places, + ncclUniqueId *nccl_id = nullptr, + size_t num_trainers = 1, size_t trainer_id = 0) { PADDLE_ENFORCE(!places.empty()); order_.reserve(places.size()); for (auto &p : places) { @@ -85,18 +90,34 @@ struct NCCLContextMap { order_.size(), contexts_.size(), "NCCL Context Map does not support contain two or more same device"); - if (places.size() > 1) { - std::unique_ptr comms(new ncclComm_t[order_.size()]); + if (places.size() <= 1) { + return; + } + std::unique_ptr comms(new ncclComm_t[order_.size()]); + // if pass nccl_id here, can assume we are doing multi node training + if (nccl_id == nullptr) { + std::lock_guard guard(NCCLGroupGuard::NCCLMutex()); + PADDLE_ENFORCE(platform::dynload::ncclCommInitAll( + comms.get(), static_cast(order_.size()), order_.data())); + } else { + PADDLE_ENFORCE_GT(num_trainers, 1); + // TODO(wuyi): need to ensure each node have same number of GPUs { - std::lock_guard guard(NCCLGroupGuard::NCCLMutex()); - PADDLE_ENFORCE(platform::dynload::ncclCommInitAll( - comms.get(), static_cast(order_.size()), order_.data())); - } - int i = 0; - for (auto &dev_id : order_) { - contexts_.at(dev_id).comm_ = comms[i++]; + int nranks = num_trainers * order_.size(); + NCCLGroupGuard gurad; + for (auto &gpu_id : order_) { + int rank = trainer_id * order_.size() + gpu_id; + VLOG(3) << "init nccl rank: " << rank << " nranks: " << nranks; + PADDLE_ENFORCE(cudaSetDevice(gpu_id)); + PADDLE_ENFORCE(platform::dynload::ncclCommInitRank( + comms.get() + gpu_id, nranks, *nccl_id, rank)); + } } } + int i = 0; + for (auto &dev_id : order_) { + contexts_.at(dev_id).comm_ = comms[i++]; + } } NCCLContextMap(const NCCLContextMap &other) = delete; diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 3e2eed31b446b83843fba943e4f2bc9e3787d7f6..b62291a99d34457dd17bf2bcafc1fc611419f086 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -503,12 +503,13 @@ All parameter, weight, gradient are variables in Paddle. const ProgramDesc &main_program, const std::string &loss_var_name, Scope *scope, std::vector &local_scopes, bool allow_op_delay, bool use_default_grad_scale, - bool balance_parameter_opt_between_cards) { + bool balance_parameter_opt_between_cards, size_t num_trainers, + size_t trainer_id) { new (&self) ParallelExecutor( num_threads, use_event, places, params, bcast_vars, main_program, loss_var_name, scope, local_scopes, allow_op_delay, use_default_grad_scale, - balance_parameter_opt_between_cards); + balance_parameter_opt_between_cards, num_trainers, trainer_id); }) .def("bcast_params", &ParallelExecutor::BCastParamsToGPUs) // NOTE: even we return a vec* to Python use reference policy. diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 28e54f5492e7b04a1406e319cecf977d4a55725e..38c765938fe9d7b2103bfdd926874c485d0ff4dc 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -489,7 +489,7 @@ class Operator(object): 'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', 'recv', 'listen_and_serv', 'parallel_do', 'save_combine', 'load_combine', 'ncclInit', 'channel_create', 'channel_close', - 'channel_send', 'channel_recv', 'select' + 'channel_send', 'channel_recv', 'select', 'gen_nccl_id' } if type not in no_kernel_op_set: self.desc.infer_var_type(self.block.desc) diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index 5b43f860e7075745bbf6e76c2f9d0e9a87a86db0..7358c4b60e87893b9c04e3da2221dfb69d3ba0c7 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -31,7 +31,9 @@ class ParallelExecutor(object): allow_op_delay=False, share_vars_from=None, use_default_grad_scale=True, - balance_parameter_opt_between_cards=False): + balance_parameter_opt_between_cards=False, + num_trainers=1, + trainer_id=0): """ ParallelExecutor can run program in parallel. @@ -55,6 +57,11 @@ class ParallelExecutor(object): balance_parameter_opt_between_cards(bool, default True): Whether updating different gradients on different cards. Currently, it is not recommended. + num_trainers(int, default 1): If greater than 1, NCCL will be + initialized with multpile rank of nodes, each node should have + same number of GPUs. Distributed training will be enabled then. + trainer_id(int, default 0): Must use together with num_trainers. + trainer_id is the "rank" of current node starts from 0. Returns: A ParallelExecutor object. @@ -134,8 +141,9 @@ class ParallelExecutor(object): local_scopes, allow_op_delay, use_default_grad_scale, - balance_parameter_opt_between_cards) - + balance_parameter_opt_between_cards, + num_trainers, + trainer_id) self.scope = scope def run(self, fetch_list, feed=None, feed_dict=None): diff --git a/python/paddle/fluid/tests/book/high-level-api/CMakeLists.txt b/python/paddle/fluid/tests/book/high-level-api/CMakeLists.txt index 9ab00325a2eef3bbc79757ad1a3e6f8511c49552..c2a15bdb3b17b65fe861dd429f548074c13e2f09 100644 --- a/python/paddle/fluid/tests/book/high-level-api/CMakeLists.txt +++ b/python/paddle/fluid/tests/book/high-level-api/CMakeLists.txt @@ -6,4 +6,5 @@ foreach(src ${TEST_OPS}) py_test(${src} SRCS ${src}.py) endforeach() +add_subdirectory(fit_a_line) add_subdirectory(recognize_digits) diff --git a/python/paddle/fluid/tests/book/high-level-api/fit_a_line/CMakeLists.txt b/python/paddle/fluid/tests/book/high-level-api/fit_a_line/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..673c965b662a022739f8d489c331f4de9455a926 --- /dev/null +++ b/python/paddle/fluid/tests/book/high-level-api/fit_a_line/CMakeLists.txt @@ -0,0 +1,7 @@ +file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") +string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") + +# default test +foreach(src ${TEST_OPS}) + py_test(${src} SRCS ${src}.py) +endforeach() diff --git a/python/paddle/fluid/tests/book/high-level-api/fit_a_line/test_fit_a_line.py b/python/paddle/fluid/tests/book/high-level-api/fit_a_line/test_fit_a_line.py new file mode 100644 index 0000000000000000000000000000000000000000..8c9bbb52d769282460c571ebc51d5eff18de3114 --- /dev/null +++ b/python/paddle/fluid/tests/book/high-level-api/fit_a_line/test_fit_a_line.py @@ -0,0 +1,137 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.fluid as fluid +import contextlib +import numpy +import unittest + +# train reader +BATCH_SIZE = 20 + +train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.uci_housing.train(), buf_size=500), + batch_size=BATCH_SIZE) + +test_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.uci_housing.test(), buf_size=500), + batch_size=BATCH_SIZE) + + +def inference_program(): + x = fluid.layers.data(name='x', shape=[13], dtype='float32') + y_predict = fluid.layers.fc(input=x, size=1, act=None) + return y_predict + + +def linear(): + y = fluid.layers.data(name='y', shape=[1], dtype='float32') + y_predict = inference_program() + + loss = fluid.layers.square_error_cost(input=y_predict, label=y) + avg_loss = fluid.layers.mean(loss) + + return avg_loss + + +def train(use_cuda, save_dirname): + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + + trainer = fluid.Trainer( + train_func=linear, + infer_func=inference_program, + place=place, + optimizer=fluid.optimizer.SGD(learning_rate=0.001)) + + def event_handler(event): + if isinstance(event, fluid.EndEpochEvent): + test_metrics = trainer.test( + reader=test_reader, feed_order=['x', 'y']) + print test_metrics + ''' + + ... + ['25.768919467926025'] + ['15.343549569447836'] + ... + + ''' + if float(test_metrics[0]) < 20.0: + if save_dirname is not None: + # NOT clear yet + # fluid.io.save_inference_model(save_dirname, ['x'], [y_predict]) + # trainer.save_params(save_dirname) + # https://github.com/PaddlePaddle/Paddle/pull/10445 + trainer.save_inference_model(save_dirname) + return + + trainer.train( + reader=train_reader, + num_epochs=100, + event_handler=event_handler, + feed_order=['x', 'y']) + + +# infer +def infer(use_cuda, save_dirname=None): + if save_dirname is None: + return + + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + inferencer = fluid.Inferencer(param_path=save_dirname, place=place) + + batch_size = 10 + tensor_x = numpy.random.uniform(0, 10, [batch_size, 13]).astype("float32") + + results = inferencer.infer({'x': tensor_x}) + print("infer results: ", results[0]) + + +def main(use_cuda): + if use_cuda and not fluid.core.is_compiled_with_cuda(): + return + + # Directory for saving the trained model + save_dirname = "fit_a_line.inference.model" + + train(use_cuda, save_dirname) + infer(use_cuda, save_dirname) + + +class TestFitALine(unittest.TestCase): + def test_cpu(self): + with self.program_scope_guard(): + with fluid.unique_name.guard(): + main(use_cuda=False) + + def test_cuda(self): + with self.program_scope_guard(): + with fluid.unique_name.guard(): + main(use_cuda=True) + + @contextlib.contextmanager + def program_scope_guard(self): + prog = fluid.Program() + startup_prog = fluid.Program() + scope = fluid.core.Scope() + with fluid.scope_guard(scope): + with fluid.program_guard(prog, startup_prog): + yield + + +if __name__ == '__main__': + unittest.main()