diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 2f461e7b2a1b971264009a086cfcb27e18fd1fed..19a9fc3802a2f2348ad7d50a267615ed70bbc4fe 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -228,6 +228,10 @@ class OpKernelRegistrar : public Registrar { USE_OP_ITSELF(op_type); \ USE_OP_DEVICE_KERNEL(op_type, CPU); +#define USE_GPU_ONLY_OP(op_type) \ + USE_OP_ITSELF(op_type); \ + USE_OP_DEVICE_KERNEL(op_type, GPU) + #define USE_OP(op_type) \ USE_OP_ITSELF(op_type); \ USE_OP_KERNEL(op_type) diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 3a9c7a732885b489c23c7a7961066a6e87849203..93885fa3028e072bc0bd021ea9287087678f3621 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -122,7 +122,7 @@ class OperatorBase { protected: std::string type_; // NOTE: in case of OpGrad, inputs_ contains: - // I (Inputs)opear + // I (Inputs) // O (Outputs) // OG (Output Gradients) VariableNameMap inputs_; @@ -287,6 +287,16 @@ class ExecutionContext { return device_context_; } + //! Get actual name vector for this input. + const std::vector& Inputs(const std::string& name) const { + return op_.Inputs(name); + } + + //! Get actual name vector for this output. + const std::vector& Outputs(const std::string& name) const { + return op_.Outputs(name); + } + #ifdef PADDLE_WITH_CUDA const platform::CUDADeviceContext& cuda_device_context() const { PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace())); @@ -398,6 +408,7 @@ class OperatorWithKernel : public OperatorBase { // indicate kernel DataType by input data. Defaultly all input data must be // same. 
virtual DataType IndicateDataType(const ExecutionContext& ctx) const { + VLOG(3) << "Default IndicateDataType " << this->Type(); auto& scope = ctx.scope(); int data_type = -1; for (auto& input : this->inputs_) { diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 9d2dc6a32bb2d4f6368fd9c7264c55fb9588819c..7b9a5b75e1087a1cc3b6c6c7a6e4dc185c32dd42 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -126,11 +126,16 @@ class Tensor { inline Tensor Slice(const int& begin_idx, const int& end_idx) const; platform::Place place() const { - PADDLE_ENFORCE_NOT_NULL(holder_, "Tensor get place() must contains holder"); + PADDLE_ENFORCE_NOT_NULL( + holder_, "Tensor not initialized yet when Tensor::place() is called."); return holder_->place(); } - std::type_index type() const { return holder_->type(); } + std::type_index type() const { + PADDLE_ENFORCE_NOT_NULL( + holder_, "Tensor not initialized yet when Tensor::type() is called."); + return holder_->type(); + } size_t memory_size() const; diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index c72261710173a0f3af199646d6800bf9d7c27b67..60dc55a32f5f05875e4f3ce77431556e14adc74a 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -90,6 +90,13 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP(sigmoid);\n") endif() + # nccl_op contains several operators + if ("${TARGET}" STREQUAL "nccl_op") + set(pybind_flag 1) + # It's enough to just adding one operator to pybind + file(APPEND ${pybind_file} "USE_GPU_ONLY_OP(ncclAllReduce);\n") + endif() + # reduce_op contains several operators if ("${TARGET}" STREQUAL "reduce_op") set(pybind_flag 1) @@ -121,6 +128,7 @@ function(op_library TARGET) endfunction() add_subdirectory(math) +add_subdirectory(nccl) set(DEPS_OPS recurrent_op @@ -130,6 +138,7 @@ set(DEPS_OPS sum_op pool_op pool_with_index_op + nccl_op sequence_conv_op lstm_op) @@ -142,6 +151,9 @@ 
op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) op_library(sum_op DEPS net_op selected_rows_functor) op_library(pool_op DEPS pooling) op_library(pool_with_index_op DEPS pooling) +if(WITH_GPU) +op_library(nccl_op DEPS nccl_common) +endif() op_library(sequence_conv_op DEPS context_project) op_library(lstm_op DEPS sequence2batch lstm_compute) @@ -157,4 +169,8 @@ cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory) cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc DEPS dynamic_recurrent_op recurrent_op tensor_array) + +if(WITH_GPU) + nv_test(nccl_op_test SRCS nccl_op_test.cu DEPS nccl_op gpu_info device_context) +endif() cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) diff --git a/paddle/operators/batch_norm_op.cc b/paddle/operators/batch_norm_op.cc index f7dc990f0db8ae4891ff068fb97899e6d01478da..f2c8be4c54eed9cd0aeb004eeb74a42adc0695f5 100644 --- a/paddle/operators/batch_norm_op.cc +++ b/paddle/operators/batch_norm_op.cc @@ -18,6 +18,7 @@ namespace paddle { namespace operators { using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; template using EigenMatrix = framework::EigenMatrix; @@ -64,6 +65,9 @@ class BatchNormOp : public framework::OperatorWithKernel { (tensor_format == TensorFormat::NCHW ? 
x_dims[1] : x_dims[x_dims.size() - 1]); + PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5, + "Input x must have 3 to 5 dimensions."); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL); PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], C); PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL); @@ -108,10 +112,12 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker { "Store the global Variance when training"); AddOutput("SavedMean", "Mean of the current mini batch, " - "will apply to output when training"); + "will apply to output when training") + .AsIntermediate(); AddOutput("SavedVariance", "Variance of the current mini batch, " - "will apply to output when training"); + "will apply to output when training") + .AsIntermediate(); AddComment(R"DOC( https://arxiv.org/pdf/1502.03167.pdf @@ -135,7 +141,6 @@ class BatchNormKernel : public framework::OpKernel { const auto *x = ctx.Input("X"); const auto &x_dims = x->dims(); - PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5, "The Input dim size should be between 3 and 5"); const int N = x_dims[0]; @@ -289,6 +294,25 @@ class BatchNormGradOp : public framework::OperatorWithKernel { ctx->SetOutputDim(framework::GradVarName("Scale"), {C}); ctx->SetOutputDim(framework::GradVarName("Bias"), {C}); } + + framework::DataType IndicateDataType( + const framework::ExecutionContext &ctx) const override { + VLOG(3) << "IndicateDataType " << this->Type(); + const auto *var = ctx.InputVar(framework::GradVarName("Y")); + if (var == nullptr) { + PADDLE_THROW("can't find Y@GRAD"); + } + const Tensor *t = nullptr; + if (var->IsType()) { + t = &var->Get(); + } else if (var->IsType()) { + t = &var->Get(); + } + if (t == nullptr) { + PADDLE_THROW("can't find Y@GRAD"); + } + return framework::ToDataType(t->type()); + } }; template diff --git a/paddle/operators/nccl/CMakeLists.txt b/paddle/operators/nccl/CMakeLists.txt new file mode 100644 index 
0000000000000000000000000000000000000000..ce0ddd89bfb0d73e237a6f9a777376624d8ef2d4 --- /dev/null +++ b/paddle/operators/nccl/CMakeLists.txt @@ -0,0 +1,3 @@ +if(WITH_GPU) + nv_library(nccl_common SRCS nccl_gpu_common.cc DEPS device_context operator ) +endif() diff --git a/paddle/operators/nccl/nccl_gpu_common.cc b/paddle/operators/nccl/nccl_gpu_common.cc new file mode 100644 index 0000000000000000000000000000000000000000..6be735e4c731f79684e0bdac3d69a30b328fed84 --- /dev/null +++ b/paddle/operators/nccl/nccl_gpu_common.cc @@ -0,0 +1,17 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/nccl/nccl_gpu_common.h" +#include "paddle/platform/gpu_info.h" + +namespace paddle { +namespace platform {} // namespace platform +} // namespace paddle diff --git a/paddle/operators/nccl/nccl_gpu_common.h b/paddle/operators/nccl/nccl_gpu_common.h new file mode 100644 index 0000000000000000000000000000000000000000..5858cd4839d367bb888b2b98cde2225751391162 --- /dev/null +++ b/paddle/operators/nccl/nccl_gpu_common.h @@ -0,0 +1,63 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/platform/device_context.h" +#include "paddle/platform/dynload/nccl.h" +#include "paddle/platform/enforce.h" +#include "paddle/platform/macros.h" + +namespace paddle { +namespace platform { + +constexpr int kInvalidGPUId = -1; + +struct Communicator { + std::vector comms_; + std::unordered_map comm_id_map_; + + Communicator() {} + + int GetCommId(int device_id) const { return comm_id_map_.at(device_id); } + + void InitAll(const std::vector& gpus) { + comms_.resize(gpus.size()); + for (size_t i = 0; i < gpus.size(); ++i) { + comm_id_map_[gpus[i]] = i; + } + PADDLE_ENFORCE( + dynload::ncclCommInitAll(comms_.data(), gpus.size(), gpus.data())); + } + + ~Communicator() { + for (size_t i = 0; i < comms_.size(); ++i) { + // FIXME(dzh) : PADDLE_ENFORCE return void + dynload::ncclCommDestroy(comms_[i]); + } + } + + DISABLE_COPY_AND_ASSIGN(Communicator); +}; + +} // namespace platform +} // namespace paddle diff --git a/paddle/operators/nccl_op.cc b/paddle/operators/nccl_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..d39cb2fcf9cc205edf86f8ab1d5e04b5672e00f6 --- /dev/null +++ b/paddle/operators/nccl_op.cc @@ -0,0 +1,206 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/nccl/nccl_gpu_common.h" + +namespace paddle { +namespace operators { + +// NCCLinitOp +class NCCLInitOp : public framework::OperatorBase { + public: + NCCLInitOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void Run(const framework::Scope &scope, + const platform::DeviceContext &dev_ctx) const override { + const auto &name = Output("Communicator"); + PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name), + "Can not find variable '%s' in the scope.", name); + std::vector gpus = Attr>("gpus"); + PADDLE_ENFORCE(!gpus.empty(), "Attr(gpus) should not be empty."); + + if (scope.FindVar(name) == nullptr) { + PADDLE_THROW("Output(Communicator) is needed for ncclInit operator."); + } + + platform::Communicator *comm = + scope.FindVar(name)->GetMutable(); + comm->InitAll(gpus); + } +}; + +class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker { + public: + NCCLInitOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddOutput("Communicator", + "Create Communicator for communicating between gpus"); + AddAttr>("gpus", "gpu id lists"); + AddAttr("data_type", "output data type") + .SetDefault(framework::DataType::FP32); + AddComment(R"DOC( + create communicator. 
+ )DOC"); + } +}; + +// AllReduceOp +class NCCLAllReduceOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + " Input(X) of AllReduce op input should not be NULL"); + PADDLE_ENFORCE( + ctx->HasInput("Communicator"), + " Input(Communicator) of AllReduce op input should not be NULL"); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + " Input(X) of AllReduce op input should not be NULL"); + + auto x_dims = ctx->GetInputsDim("X"); + + std::string reduction = ctx->Attrs().Get("reduction"); + PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" || + reduction == "ncclMin" || reduction == "ncclMax"), + "invalid reduction."); + + ctx->SetOutputsDim("Out", x_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +// ReduceOp +class NCCLReduceOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + " Input(X) of Reduce op input should not be NULL"); + PADDLE_ENFORCE( + ctx->HasInput("Communicator"), + " Input(Communicator) of Reduce op input should not be NULL"); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + " Input(X) of Reduce op input should not be NULL"); + + std::string reduction = ctx->Attrs().Get("reduction"); + PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" || + reduction == "ncclMin" || reduction == "ncclMax"), + "invalid reduction."); + + auto x_dims = ctx->GetInputsDim("X"); + ctx->SetOutputsDim("Out", x_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +// BcastOp +class NCCLBcastOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + 
PADDLE_ENFORCE(ctx->HasInput("X"), + " Input(X) of Bcast op input should not be NULL"); + PADDLE_ENFORCE(ctx->HasInput("Communicator"), + " Input(Communicator) of Bcast op input should not be NULL"); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + " Output(Out) of Bcast op output should not be NULL"); + + int root = ctx->Attrs().Get("root"); + PADDLE_ENFORCE(root != platform::kInvalidGPUId, "Bcast root must be set."); + + auto x_dims = ctx->GetInputsDim("X"); + ctx->SetOutputsDim("Out", x_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +// AllreduceOp +class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker { + public: + NCCLAllReduceOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The input of AllReduce op"); + AddInput("Communicator", "Communicator for communicating between gpus"); + AddOutput("Out", "The output of AllReduce op"); + AddAttr("reduction", + "{'ncclMin', 'ncclMax', 'ncclProd', 'ncclSum'}.") + .SetDefault("ncclSum"); + AddComment(R"DOC( + AllReduce the input tensors. + )DOC"); + } +}; + +// ReduceOp +class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker { + public: + NCCLReduceOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The input of Reduce op"); + AddInput("Communicator", "Communicator for communicating between gpus"); + AddOutput("Out", "The output of Reduce op"); + AddAttr("reduction", + "{'ncclMin', 'ncclMax', 'ncclProd', 'ncclSum'}.") + .SetDefault("ncclSum"); + AddAttr("root", + "root gpu of the parameter. if not " + "set(platform::kInvalidGPUId). 
hashed by name.") + .SetDefault(platform::kInvalidGPUId); + AddComment(R"DOC( + Reduce the tensors)DOC"); + } +}; + +// BcastOp +class NCCLBcastOpMaker : public framework::OpProtoAndCheckerMaker { + public: + NCCLBcastOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The input of BcastSend op"); + AddInput("Communicator", "Communicator for communicating between gpus"); + AddOutput("Out", "The output of Bcast"); + AddAttr("root", + "root gpu of the parameter. if not " + "set(platform::kInvalidGPUId). hashed by name.") + .SetDefault(platform::kInvalidGPUId); + AddComment(R"DOC( + Bcast the tensors. + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(ncclInit, ops::NCCLInitOp, + paddle::framework::EmptyGradOpMaker, ops::NCCLInitOpMaker); + +REGISTER_OP_WITHOUT_GRADIENT(ncclAllReduce, ops::NCCLAllReduceOp, + ops::NCCLAllReduceOpMaker); +REGISTER_OP_WITHOUT_GRADIENT(ncclBcast, ops::NCCLBcastOp, + ops::NCCLBcastOpMaker); +REGISTER_OP_WITHOUT_GRADIENT(ncclReduce, ops::NCCLReduceOp, + ops::NCCLReduceOpMaker); diff --git a/paddle/operators/nccl_op.cu b/paddle/operators/nccl_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..86dee8ee8e1c1a1041d6bc9fa515d669a9c4e466 --- /dev/null +++ b/paddle/operators/nccl_op.cu @@ -0,0 +1,211 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed +to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include "paddle/framework/lod_tensor.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/nccl/nccl_gpu_common.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; +using platform::Communicator; +using framework::LoDTensor; + +template +class NCCLTypeWrapper; + +template <> +class NCCLTypeWrapper { + public: + static const ncclDataType_t type = ncclFloat; +}; + +template <> +class NCCLTypeWrapper { + public: + static const ncclDataType_t type = ncclDouble; +}; + +template +class NCCLAllReduceKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "This kernel only runs on GPU device."); + + auto ins = ctx.MultiInput("X"); + auto outs = ctx.MultiOutput("Out"); + + std::string reduction = ctx.Attr("reduction"); + ncclRedOp_t reduction_op_ = ncclSum; + + if (reduction == "ncclMin") { + reduction_op_ = ncclMin; + } else if (reduction == "ncclMax") { + reduction_op_ = ncclMax; + } else if (reduction == "ncclSum") { + reduction_op_ = ncclSum; + } else if (reduction == "ncclProd") { + reduction_op_ = ncclProd; + } else { + PADDLE_THROW("Invalid reduction. default ncclSum."); + } + + auto* comm = ctx.Input("Communicator"); + + auto stream = reinterpret_cast( + ctx.device_context()) + .stream(); + + // device id + int gpu_id = boost::get(ctx.GetPlace()).GetDeviceId(); + int idx = comm->GetCommId(gpu_id); + + for (size_t i = 0; i < ins.size(); ++i) { + VLOG(1) << "gpu : " + << " invoke allreduce. 
send " << ins[i]->numel() << " recv " + << outs[i]->numel(); + + PADDLE_ENFORCE(platform::dynload::ncclAllReduce( + ins[i]->data(), outs[i]->mutable_data(ctx.GetPlace()), + outs[i]->numel(), NCCLTypeWrapper::type, reduction_op_, + comm->comms_[idx], stream)); + PADDLE_ENFORCE(cudaStreamSynchronize(stream)); + + VLOG(1) << "gpu : " + << " finished allreduce. send " << ins[i]->numel() << " recv " + << outs[i]->numel(); + } + } +}; + +template +class NCCLReduceKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "This kernel only runs on GPU device."); + + auto ins = ctx.MultiInput("X"); // x0, x1, x2 + auto outs = ctx.MultiOutput("Out"); + + std::string reduction = ctx.Attr("reduction"); + ncclRedOp_t reduction_op_ = ncclSum; + + if (reduction == "ncclMin") { + reduction_op_ = ncclMin; + } else if (reduction == "ncclMax") { + reduction_op_ = ncclMax; + } else if (reduction == "ncclSum") { + reduction_op_ = ncclSum; + } else if (reduction == "ncclProd") { + reduction_op_ = ncclProd; + } else { + PADDLE_THROW("Invalid reduction. default ncclSum."); + } + + int root = ctx.Attr("root"); + auto* comm = ctx.Input("Communicator"); + + auto stream = reinterpret_cast( + ctx.device_context()) + .stream(); + // device id + int gpu_id = boost::get(ctx.GetPlace()).GetDeviceId(); + int idx = comm->GetCommId(gpu_id); + + auto ins_names = ctx.Inputs("X"); + std::hash hasher; + for (size_t i = 0; i < ins.size(); ++i) { + if (root == platform::kInvalidGPUId) { + root = hasher(ins_names[i]) % comm->comms_.size(); + } + T* recvbuffer = nullptr; + if (root == gpu_id) { + recvbuffer = outs[i]->mutable_data(ctx.GetPlace()); + } + + VLOG(1) << "gpu : " << gpu_id << " invoke reduce. 
send " + << ins[i]->numel() << " recv " << outs[i]->numel(); + + PADDLE_ENFORCE(platform::dynload::ncclReduce( + ins[i]->data(), recvbuffer, ins[i]->numel(), + NCCLTypeWrapper::type, reduction_op_, root, comm->comms_[idx], + stream)); + PADDLE_ENFORCE(cudaStreamSynchronize(stream)); + + VLOG(1) << "gpu : " << gpu_id << " finished reduce. send " + << ins[i]->numel() << " recv " << outs[i]->numel(); + } + } +}; + +template +class NCCLBcastKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "This kernel only runs on GPU device."); + + int root = ctx.Attr("root"); + + auto* comm = ctx.Input("Communicator"); + + auto stream = reinterpret_cast( + ctx.device_context()) + .stream(); + // device id + int gpu_id = boost::get(ctx.GetPlace()).GetDeviceId(); + int idx = comm->GetCommId(gpu_id); + + if (idx == root) { + auto ins = ctx.MultiInput("X"); + for (size_t i = 0; i < ins.size(); ++i) { + VLOG(1) << "gpu : " << gpu_id << " invoke Bcast. send " + << ins[i]->numel(); + + VLOG(1) << " before ncclBcast"; + PADDLE_ENFORCE(platform::dynload::ncclBcast( + (void*)ins[i]->data(), ins[i]->numel(), NCCLTypeWrapper::type, + root, comm->comms_[idx], stream)); + VLOG(1) << " after ncclBcast"; + PADDLE_ENFORCE(cudaStreamSynchronize(stream)); + + VLOG(1) << "gpu : " << gpu_id << " finished Bcast."; + } + } else { + auto outs = ctx.MultiOutput("Out"); + for (size_t i = 0; i < outs.size(); ++i) { + VLOG(1) << "gpu : " << gpu_id << " invoke Bcast. recv buffer " + << framework::product(outs[i]->dims()); + + PADDLE_ENFORCE(platform::dynload::ncclBcast( + outs[i]->mutable_data(ctx.GetPlace()), outs[i]->numel(), + NCCLTypeWrapper::type, root, comm->comms_[idx], stream)); + PADDLE_ENFORCE(cudaStreamSynchronize(stream)); + + VLOG(1) << "gpu : " << gpu_id << " finished Bcast. 
recv " + << outs[i]->numel(); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(ncclAllReduce, ops::NCCLAllReduceKernel); +REGISTER_OP_GPU_KERNEL(ncclBcast, ops::NCCLBcastKernel); +REGISTER_OP_GPU_KERNEL(ncclReduce, ops::NCCLReduceKernel); diff --git a/paddle/operators/nccl_op_test.cu b/paddle/operators/nccl_op_test.cu new file mode 100644 index 0000000000000000000000000000000000000000..80c50a28a9e5d560fc693c518b9e62091ddc5724 --- /dev/null +++ b/paddle/operators/nccl_op_test.cu @@ -0,0 +1,307 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/framework/block_desc.h" +#include "paddle/framework/op_desc.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/program_desc.h" +#include "paddle/framework/var_desc.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/operators/nccl/nccl_gpu_common.h" +#include "paddle/platform/device_context.h" +#include "paddle/platform/enforce.h" +#include "paddle/platform/gpu_info.h" +#include "paddle/platform/place.h" + +USE_NO_KERNEL_OP(ncclInit); +USE_GPU_ONLY_OP(ncclAllReduce); +USE_GPU_ONLY_OP(ncclReduce); +USE_GPU_ONLY_OP(ncclBcast); + +namespace f = paddle::framework; +namespace p = paddle::platform; + +static std::vector gpu_list; + +// test data amount +const f::DDim kDims = {100, 100}; + +// nccl op common tester, init communicator. +class NCCLTester : public ::testing::Test { + public: + virtual void SetUp() override { + cpu_ctx = new p::CPUDeviceContext(p::CPUPlace()); + for (size_t i = 0; i < gpu_list.size(); ++i) { + p::GPUPlace place(i); + dev_ctxs.emplace_back(new p::CUDADeviceContext(place)); + } + + NCCLInitOp(); + } + + virtual void TearDown() override { + for (auto &device_context : dev_ctxs) { + delete device_context; + } + } + + void NCCLInitOp() { + std::unique_ptr op1(new f::OpDescBind); + + op1->SetType("ncclInit"); + op1->SetOutput("Communicator", {"comm"}); + op1->SetAttr("gpus", {gpu_list}); + + auto *var = g_scope.Var("comm"); + var->GetMutable(); + + auto op = f::OpRegistry::CreateOp(*op1); + VLOG(1) << "invoke NCCLInitOp."; + op->Run(g_scope, *cpu_ctx); + VLOG(1) << "NCCLInitOp finished."; + } + + template + void PerThreadProgram(int gpu_id, const f::OpDescBind &op_desc, + f::Scope *scope) { + std::unique_lock lk(mu); + const f::OpDescBind *op1 = &op_desc; + + p::GPUPlace place(gpu_id); + auto &ctx = dev_ctxs.at(gpu_id); + + auto *send_tensor = scope->Var("st")->GetMutable(); + auto *recv_tensor = 
scope->Var("rt")->GetMutable(); + + if (!send_tensor->numel()) { + send_tensor->Resize(kDims); + send_tensor->mutable_data(kDims, place); + + std::vector send_vector(f::product(kDims), gpu_id); + send_tensor->CopyFromVector(send_vector, *ctx); + ctx->Wait(); + VLOG(1) << "Send Tensor filled with elements " << send_tensor->numel(); + } + + lk.unlock(); + + PADDLE_ENFORCE(send_tensor->numel() == f::product(kDims), + "Tensor numel not match!"); + + auto op = f::OpRegistry::CreateOp(*op1); + + VLOG(1) << "Device : " << gpu_id << " invoke " << op_desc.Type(); + VLOG(1) << " send_tensor : " << send_tensor->numel() + << " recv_tensor : " << recv_tensor->numel(); + op->Run(*scope, *ctx); + VLOG(1) << "Device : " << gpu_id << " finished " << op_desc.Type(); + } + + public: + std::vector dev_ctxs; + p::DeviceContext *cpu_ctx; + f::Scope g_scope; + std::mutex mu; +}; + +// ncclInitOp with desc +TEST(NCCL, ncclInitOp) { + std::unique_ptr op_desc(new f::OpDescBind); + + op_desc->SetType("ncclInit"); + op_desc->SetOutput("Communicator", {"x1"}); + op_desc->SetAttr("gpus", {gpu_list}); + + f::Scope g_scope; + std::unique_ptr ctx(new p::CPUDeviceContext(p::CPUPlace())); + + auto *var = g_scope.Var("x1"); + var->GetMutable(); + + auto op = f::OpRegistry::CreateOp(*op_desc); + VLOG(1) << "invoke NCCLInitOp."; + op->Run(g_scope, *ctx.get()); + VLOG(1) << "NCCLInitOp finished."; +} + +// ncclAllReduceOp with desc +TEST_F(NCCLTester, ncclAllReduceOp) { + std::unique_ptr op2(new f::OpDescBind); + op2->SetType("ncclAllReduce"); + op2->SetInput("X", {"st"}); + op2->SetInput("Communicator", {"comm"}); + op2->SetOutput("Out", {"rt"}); + + std::vector dev_scopes; + + std::vector ths; + + for (size_t i = 0; i < gpu_list.size(); ++i) { + dev_scopes.emplace_back(&g_scope.NewScope()); + std::thread th(&NCCLTester::PerThreadProgram, this, gpu_list[i], + *op2.get(), dev_scopes[i]); + ths.emplace_back(std::move(th)); + } + + for (size_t i = 0; i < gpu_list.size(); ++i) { + ths[i].join(); + } + + // 
check results + float result = std::accumulate(gpu_list.begin(), gpu_list.end(), 0); + + for (size_t i = 0; i < dev_scopes.size(); ++i) { + p::CPUPlace cpu_place; + p::GPUPlace gpu_place(gpu_list[i]); + + auto &recv_tensor = dev_scopes[i]->FindVar("rt")->Get(); + auto *rt = recv_tensor.data(); + auto *result_tensor = dev_scopes[i]->Var("ct")->GetMutable(); + result_tensor->Resize(kDims); + auto *ct = result_tensor->mutable_data(cpu_place); + + paddle::memory::Copy( + cpu_place, ct, p::GPUPlace(gpu_list[i]), rt, + recv_tensor.numel() * sizeof(float), + static_cast(dev_ctxs[i])->stream()); + + for (size_t j = 0; j < f::product(kDims); ++j) { + ASSERT_NEAR(ct[j], result, 1e-5); + } + } +} + +// ncclReduceOp with desc +TEST_F(NCCLTester, ncclReduceOp) { + std::unique_ptr op2(new f::OpDescBind); + const int kRoot = 0; + op2->SetType("ncclReduce"); + op2->SetInput("X", {"st"}); + op2->SetInput("Communicator", {"comm"}); + op2->SetOutput("Out", {"rt"}); + op2->SetAttr("root", kRoot); + + std::vector dev_scopes; + + std::vector ths; + + for (size_t i = 0; i < gpu_list.size(); ++i) { + dev_scopes.emplace_back(&g_scope.NewScope()); + std::thread th(&NCCLTester::PerThreadProgram, this, gpu_list[i], + *op2.get(), dev_scopes[i]); + ths.emplace_back(std::move(th)); + } + + for (size_t i = 0; i < gpu_list.size(); ++i) { + ths[i].join(); + } + + // check results on + float result = std::accumulate(gpu_list.begin(), gpu_list.end(), 0); + + p::CPUPlace cpu_place; + p::GPUPlace gpu_place(gpu_list[kRoot]); + + auto &recv_tensor = dev_scopes[kRoot]->FindVar("rt")->Get(); + auto *rt = recv_tensor.data(); + auto *result_tensor = + dev_scopes[kRoot]->Var("ct")->GetMutable(); + result_tensor->Resize(kDims); + auto *ct = result_tensor->mutable_data(cpu_place); + + paddle::memory::Copy( + cpu_place, ct, p::GPUPlace(gpu_list[kRoot]), rt, + recv_tensor.numel() * sizeof(float), + static_cast(dev_ctxs[kRoot])->stream()); + + for (int j = 0; j < f::product(kDims); ++j) { + ASSERT_NEAR(ct[j], 
result, 1e-5); + } +} + +// ncclBcastOp with desc +TEST_F(NCCLTester, ncclBcastOp) { + std::unique_ptr op2(new f::OpDescBind); + const int kRoot = 5; + op2->SetType("ncclBcast"); + op2->SetInput("X", {"st"}); + op2->SetInput("Communicator", {"comm"}); + op2->SetOutput("Out", {"rt"}); + op2->SetAttr("root", kRoot); + + std::vector dev_scopes; + + std::vector ths; + + for (size_t i = 0; i < gpu_list.size(); ++i) { + dev_scopes.emplace_back(&g_scope.NewScope()); + std::thread th(&NCCLTester::PerThreadProgram, this, gpu_list[i], + *op2.get(), dev_scopes[i]); + ths.emplace_back(std::move(th)); + } + + for (size_t i = 0; i < gpu_list.size(); ++i) { + ths[i].join(); + } + + const int idx = 1; + // check results on + float result = kRoot; + + p::CPUPlace cpu_place; + p::GPUPlace gpu_place(gpu_list[idx]); + + auto &recv_tensor = dev_scopes[idx]->FindVar("rt")->Get(); + auto *rt = recv_tensor.data(); + auto *result_tensor = dev_scopes[idx]->Var("ct")->GetMutable(); + result_tensor->Resize(kDims); + auto *ct = result_tensor->mutable_data(cpu_place); + + paddle::memory::Copy( + cpu_place, ct, p::GPUPlace(gpu_list[idx]), rt, + recv_tensor.numel() * sizeof(float), + static_cast(dev_ctxs[idx])->stream()); + + for (size_t j = 0; j < f::product(kDims); ++j) { + ASSERT_NEAR(ct[j], result, 1e-5); + } +} + +int main(int argc, char **argv) { + const int dev_count = p::GetCUDADeviceCount(); + if (dev_count <= 1) { + LOG(WARNING) + << "Cannot test multi-gpu nccl, because the CUDA device count is " + << dev_count; + return 0; + } + + for (int i = 0; i < dev_count; ++i) { + gpu_list.emplace_back(i); + } + testing::InitGoogleTest(&argc, argv); + + // device context should be release before scope. + // otherwise driver will down. 
+ return RUN_ALL_TESTS(); +} diff --git a/paddle/operators/reshape_op.cc b/paddle/operators/reshape_op.cc index a8eb8d45eec214842ee756a260127b9d0aacb0f4..eda8226480a66ae1a631391e9335db04604039c5 100644 --- a/paddle/operators/reshape_op.cc +++ b/paddle/operators/reshape_op.cc @@ -34,13 +34,19 @@ class ReshapeOp : public framework::OperatorWithKernel { auto shape = ctx->Attrs().Get>("shape"); PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty."); - for (auto dim : shape) { - PADDLE_ENFORCE(dim > 0, "Each dimension of shape must be positive."); + auto x_dims = ctx->GetInputDim("X"); + // TODO(qiao) change batch_size + for (int i = 1; i < shape.size(); ++i) { + PADDLE_ENFORCE(shape[i] > 0, + "Each dimension of shape " + "must be positiv except the first."); + } + if (shape[0] < 0) { + shape[0] = x_dims[0]; } // capacity check int64_t capacity = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); - auto x_dims = ctx->GetInputDim("X"); int64_t in_size = framework::product(x_dims); PADDLE_ENFORCE_EQ(capacity, in_size, "The size of Input(X) mismatches with Attr(shape)."); diff --git a/paddle/operators/reshape_op.h b/paddle/operators/reshape_op.h index c89cdf8cab9f209667c5e09b521b8f6e30f202fd..beb951713ae2a9fd83fe7c1a5e97ee8c642158a8 100644 --- a/paddle/operators/reshape_op.h +++ b/paddle/operators/reshape_op.h @@ -26,13 +26,8 @@ class ReshapeKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const { auto* out = ctx.Output("Out"); auto* in = ctx.Input("X"); + auto out_dims = out->dims(); out->mutable_data(ctx.GetPlace()); - - auto shape = ctx.Attr>("shape"); - std::vector shape_int64(shape.size(), 0); - std::transform(shape.begin(), shape.end(), shape_int64.begin(), - [](int a) { return static_cast(a); }); - auto out_dims = framework::make_ddim(shape_int64); out->CopyFrom(*in, ctx.GetPlace(), ctx.device_context()); out->Resize(out_dims); } diff --git a/paddle/platform/nccl_test.cu 
b/paddle/platform/nccl_test.cu index ab8b96f7263aed83407866fedf9e529ce0affe3f..c99dae68bef67c58d3efea42fef45e84bb3d9255 100644 --- a/paddle/platform/nccl_test.cu +++ b/paddle/platform/nccl_test.cu @@ -31,9 +31,7 @@ namespace platform { TEST(NCCL, init) { std::vector comms; comms.resize(dev_count); - - auto status = dynload::ncclCommInitAll(comms.data(), dev_count, nullptr); - PADDLE_ENFORCE(status); + PADDLE_ENFORCE(dynload::ncclCommInitAll(comms.data(), dev_count, nullptr)); for (int i = 0; i < dev_count; ++i) { dynload::ncclCommDestroy(comms[i]); } @@ -64,8 +62,7 @@ TEST(NCCL, all_reduce) { std::vector comms; comms.resize(dev_count); VLOG(1) << "Initializing ncclComm"; - auto status = dynload::ncclCommInitAll(comms.data(), dev_count, nullptr); - PADDLE_ENFORCE(status); + PADDLE_ENFORCE(dynload::ncclCommInitAll(comms.data(), dev_count, nullptr)); VLOG(1) << "ncclComm initialized"; VLOG(1) << "Creating thread data"; std::vector>> data; diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index e9c1d40de185768048159c10ee278ffde12335f7..bf6e12264269c7603484e0acf502adab25645856 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -33,6 +33,11 @@ limitations under the License. */ #include "paddle/pybind/tensor_py.h" #include "paddle/string/to_string.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/operators/nccl/nccl_gpu_common.h" +#include "paddle/platform/gpu_info.h" +#endif + namespace paddle { namespace pybind { static size_t UniqueIntegerGenerator() { @@ -204,6 +209,13 @@ All parameter, weight, gradient are variables in Paddle. return self.GetMutable(); }, py::return_value_policy::reference) +#ifdef PADDLE_WITH_CUDA + .def("get_communicator", + [](Variable &self) -> platform::Communicator * { + return self.GetMutable(); + }, + py::return_value_policy::reference) +#endif .def("get_net", [](Variable &self) -> operators::NetOp * { return self.GetMutable(); @@ -269,8 +281,11 @@ All parameter, weight, gradient are variables in Paddle. 
return new paddle::platform::CUDADeviceContext(place); #endif }); - // clang-format on +// clang-format on +#ifdef PADDLE_WITH_CUDA + py::class_(m, "Communicator").def(py::init<>()); +#endif py::class_(m, "GPUPlace") .def(py::init()) .def("__str__", string::to_string); @@ -479,6 +494,9 @@ All parameter, weight, gradient are variables in Paddle. BindOpDesc(m); m.def("op_support_gpu", OpSupportGPU); +#ifdef PADDLE_WITH_CUDA + m.def("get_cuda_device_count", platform::GetCUDADeviceCount); +#endif return m.ptr(); } diff --git a/python/paddle/v2/framework/framework.py b/python/paddle/v2/framework/framework.py index 348c393913b3d73f9c9c16580d19a19551f2a57b..43101c9ddad76b7c1c322130dc0362a5c8ea4336 100644 --- a/python/paddle/v2/framework/framework.py +++ b/python/paddle/v2/framework/framework.py @@ -352,7 +352,10 @@ class Block(object): return {v for k, v in self.vars.iteritems() if isinstance(v, Parameter)} def create_var(self, *args, **kwargs): - return Variable(self, *args, **kwargs) + var = Variable(self, *args, **kwargs) + if 'init_attr' in kwargs: + self._prepend_initialize_ops_(var, kwargs['init_attr']) + return var def has_var(self, name): return name in self.vars diff --git a/python/paddle/v2/framework/layers.py b/python/paddle/v2/framework/layers.py index 9e6d5f49db6f073833ad5f3a5faa3e1097287526..041a3b2c0b03c8171c2af9d856b33f461bb486c1 100644 --- a/python/paddle/v2/framework/layers.py +++ b/python/paddle/v2/framework/layers.py @@ -161,6 +161,7 @@ def _create_op_func_(op_type): _create_op_func_('mean') _create_op_func_('mul') _create_op_func_('dropout') +_create_op_func_('reshape') def cast(x, data_type, program=None): @@ -308,6 +309,96 @@ def pool2d(input, return pool_out +def batch_norm(input, + act=None, + is_test=False, + momentum=0.9, + epsilon=1e05, + param_attr=None, + bias_attr=None, + data_layout='NCHW', + program=None, + init_program=None): + helper = LayerHelper('batch_norm', **locals()) + dtype = helper.input_dtype() + + input_shape = input.shape + if 
data_layout == 'NCHW': + channel_num = input_shape[1] + else: + if data_layout == 'NHWC': + channel_num = input_shape[-1] + else: + raise ValueError("unsupported data layout:" + data_layout) + + def get_init_attr(value): + if not isinstance(value, float): + raise ValueError("attr value should be a float") + return {'type': 'fill_constant', 'value': value} + + def prepend_init_op(var, init_attr): + assert isinstance(var, Variable) + op_type = init_attr['type'] + init_attr['shape'] = var.shape + init_attr['data_type'] = int(var.data_type) + op = var.block.prepend_op( + type=op_type, inputs=None, outputs={'Out': [var]}, attrs=init_attr) + return op + + def create_persistable_var(dtype, shape, init_attr=None): + name = unique_name(".".join([helper.name, "xxxx"])) + var = init_program.global_block().create_var( + dtype=dtype, shape=shape, name=name, persistable=True) + if 'init_attr' is not None: + prepend_init_op(var, init_attr) + return program.global_block().create_var( + name=name, dtype=dtype, shape=shape, persistable=True) + + param_shape = [channel_num] + + # create parameter + scale = helper.create_parameter( + attr=helper.param_attr, shape=param_shape, dtype=dtype) + bias = helper.create_parameter( + attr=helper.param_attr, shape=param_shape, dtype=dtype) + + # create input + mean = create_persistable_var(dtype, param_shape, get_init_attr(0.0)) + variance = create_persistable_var(dtype, param_shape, get_init_attr(1.0)) + + # create output + # mean and mean_out share the same memory + mean_out = mean + # variance and variance out share the same memory + variance_out = variance + saved_mean = helper.create_tmp_variable(dtype) + saved_variance = helper.create_tmp_variable(dtype) + + batch_norm_out = helper.create_tmp_variable(dtype) + + helper.append_op( + type="batch_norm", + inputs={ + "X": input, + "Scale": scale, + "Bias": bias, + "Mean": mean, + "Variance": variance + }, + outputs={ + "Y": batch_norm_out, + "MeanOut": mean_out, + "VarianceOut": variance_out, 
+ "SavedMean": saved_mean, + "SavedVariance": saved_variance + }, + attrs={"momentum": momentum, + "epsilon": epsilon, + "is_test": is_test}) + + return helper.append_activation(batch_norm_out) + + class BlockGuard(object): """ BlockGuard used to create sub-block in program by using Python `with` diff --git a/python/paddle/v2/framework/nets.py b/python/paddle/v2/framework/nets.py index 8a83ebfb9639f6fae6344b68509a80580881dab0..803534fa391c49d646c5d98a442d35d06b98603e 100644 --- a/python/paddle/v2/framework/nets.py +++ b/python/paddle/v2/framework/nets.py @@ -7,6 +7,7 @@ def simple_img_conv_pool(input, pool_size, pool_stride, act, + pool_type='max', program=None, init_program=None): conv_out = layers.conv2d( @@ -20,7 +21,75 @@ def simple_img_conv_pool(input, pool_out = layers.pool2d( input=conv_out, pool_size=pool_size, - pool_type='max', + pool_type=pool_type, + pool_stride=pool_stride, + program=program, + init_program=init_program) + return pool_out + + +def img_conv_group(input, + conv_num_filter, + pool_size, + conv_padding=1, + conv_filter_size=3, + conv_act=None, + conv_with_batchnorm=False, + conv_batchnorm_drop_rate=None, + pool_stride=1, + pool_type=None, + program=None, + init_program=None): + """ + Image Convolution Group, Used for vgg net. 
+ """ + tmp = input + assert isinstance(conv_num_filter, list) or \ + isinstance(conv_num_filter, tuple) + + def __extend_list__(obj): + if not hasattr(obj, '__len__'): + return [obj] * len(conv_num_filter) + else: + return obj + + conv_padding = __extend_list__(conv_padding) + conv_filter_size = __extend_list__(conv_filter_size) + conv_with_batchnorm = __extend_list__(conv_with_batchnorm) + conv_batchnorm_drop_rate = __extend_list__(conv_batchnorm_drop_rate) + + for i in xrange(len(conv_num_filter)): + local_conv_act = conv_act + if conv_with_batchnorm[i]: + local_conv_act = None + + tmp = layers.conv2d( + input=tmp, + num_filters=conv_num_filter[i], + filter_size=conv_filter_size[i], + padding=conv_padding[i], + act=local_conv_act, + program=program, + init_program=init_program) + + if conv_with_batchnorm[i]: + tmp = layers.batch_norm( + input=tmp, + act=conv_act, + program=program, + init_program=init_program) + drop_rate = conv_batchnorm_drop_rate[i] + if abs(drop_rate) > 1e-5: + tmp = layers.dropout( + x=tmp, + dropout_prob=drop_rate, + program=program, + init_program=init_program) + + pool_out = layers.pool2d( + input=tmp, + pool_size=pool_size, + pool_type=pool_type, pool_stride=pool_stride, program=program, init_program=init_program) diff --git a/python/paddle/v2/framework/tests/test_image_classification_layer.py b/python/paddle/v2/framework/tests/test_image_classification_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..908cf44b88a5de88690f5e17a1da1b5f8b1d8079 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_image_classification_layer.py @@ -0,0 +1,75 @@ +import unittest + +import paddle.v2.framework.layers as layers +import paddle.v2.framework.nets as nets +from paddle.v2.framework.framework import Program + + +def conv_block(input, + num_filter, + groups, + dropouts, + program=None, + init_program=None): + return nets.img_conv_group( + input=input, + pool_size=2, + pool_stride=2, + conv_num_filter=[num_filter] * 
groups, + conv_filter_size=3, + conv_act='relu', + conv_with_batchnorm=True, + conv_batchnorm_drop_rate=dropouts, + pool_type='max', + program=program, + init_program=init_program) + + +class TestLayer(unittest.TestCase): + def test_batch_norm_layer(self): + program = Program() + init_program = Program() + images = layers.data( + name='pixel', + shape=[3, 48, 48], + data_type='float32', + program=program) + layers.batch_norm( + input=images, program=program, init_program=init_program) + + #print str(program) + + def test_dropout_layer(self): + program = Program() + init_program = Program() + images = layers.data( + name='pixel', + shape=[3, 48, 48], + data_type='float32', + program=program) + layers.dropout( + x=images, + dropout_prob=0.5, + program=program, + init_program=init_program) + + #print str(program) + + def test_img_conv_group(self): + program = Program() + init_program = Program() + + images = layers.data( + name='pixel', + shape=[3, 48, 48], + data_type='float32', + program=program, + init_program=init_program) + conv1 = conv_block(images, 64, 2, [0.3, 0], program, init_program) + conv2 = conv_block(conv1, 256, 3, [0.4, 0.4, 0], program, init_program) + + # print str(program) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_image_classification_train.py b/python/paddle/v2/framework/tests/test_image_classification_train.py new file mode 100644 index 0000000000000000000000000000000000000000..4eb9051261ee6786ba78f62ea3bfd89ae90e1d74 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_image_classification_train.py @@ -0,0 +1,133 @@ +import paddle.v2 as paddle +import paddle.v2.framework.layers as layers +import paddle.v2.framework.nets as nets +import paddle.v2.framework.core as core +import paddle.v2.framework.optimizer as optimizer + +from paddle.v2.framework.framework import Program, g_program +from paddle.v2.framework.executor import Executor + +import numpy as np + + +def vgg16_bn_drop(input, 
program, init_program): + def conv_block(input, + num_filter, + groups, + dropouts, + program=None, + init_program=None): + return nets.img_conv_group( + input=input, + pool_size=2, + pool_stride=2, + conv_num_filter=[num_filter] * groups, + conv_filter_size=3, + conv_act='relu', + conv_with_batchnorm=True, + conv_batchnorm_drop_rate=dropouts, + pool_type='max', + program=program, + init_program=init_program) + + conv1 = conv_block(input, 64, 2, [0.3, 0], program, init_program) + conv2 = conv_block(conv1, 128, 2, [0.4, 0], program, init_program) + conv3 = conv_block(conv2, 256, 3, [0.4, 0.4, 0], program, init_program) + conv4 = conv_block(conv3, 512, 3, [0.4, 0.4, 0], program, init_program) + conv5 = conv_block(conv4, 512, 3, [0.4, 0.4, 0], program, init_program) + + drop = layers.dropout( + x=conv5, dropout_prob=0.5, program=program, init_program=init_program) + fc1 = layers.fc(input=drop, + size=512, + act=None, + program=program, + init_program=init_program) + reshape1 = layers.reshape( + x=fc1, + shape=list(fc1.shape + (1, 1)), + program=program, + init_program=init_program) + bn = layers.batch_norm( + input=reshape1, act='relu', program=program, init_program=init_program) + drop2 = layers.dropout( + x=bn, dropout_prob=0.5, program=program, init_program=init_program) + fc2 = layers.fc(input=drop2, + size=512, + act=None, + program=program, + init_program=init_program) + return fc2 + + +init_program = Program() +program = Program() + +classdim = 10 +data_shape = [3, 32, 32] + +images = layers.data( + name='pixel', shape=data_shape, data_type='float32', program=program) + +label = layers.data( + name='label', + shape=[1], + data_type='int64', + program=program, + init_program=init_program) +vgg_net = vgg16_bn_drop(images, program, init_program) +predict = layers.fc(input=vgg_net, + size=classdim, + act='softmax', + program=program, + init_program=init_program) +cost = layers.cross_entropy( + input=predict, label=label, program=program, init_program=init_program) 
+avg_cost = layers.mean(x=cost, program=program, init_program=init_program) + +sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.001) +opts = sgd_optimizer.minimize(avg_cost) + +BATCH_SIZE = 128 +PASS_NUM = 1 + +train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.cifar.train10(), buf_size=128 * 10), + batch_size=BATCH_SIZE) + +place = core.CPUPlace() +exe = Executor(place) + +exe.run(init_program, feed={}, fetch_list=[]) + +for pass_id in range(PASS_NUM): + batch_id = 0 + for data in train_reader(): + img_data = np.array(map(lambda x: x[0].reshape(data_shape), + data)).astype("float32") + y_data = np.array(map(lambda x: x[1], data)).astype("int64") + batch_size = 1 + for i in y_data.shape: + batch_size = batch_size * i + y_data = y_data.reshape([batch_size, 1]) + + tensor_img = core.LoDTensor() + tensor_y = core.LoDTensor() + tensor_img.set(img_data, place) + tensor_y.set(y_data, place) + + outs = exe.run(program, + feed={"pixel": tensor_img, + "label": tensor_y}, + fetch_list=[avg_cost]) + + loss = np.array(outs[0]) + # print("pass_id:" + str(pass_id) + " batch_id:" + str(batch_id) + + # " loss:" + str(loss)) + batch_id = batch_id + 1 + + if batch_id > 1: + # this model is slow, so if we can train two mini batch, we think it works properly. 
+ exit(0) +exit(1) diff --git a/python/paddle/v2/framework/tests/test_nccl_init_op.py b/python/paddle/v2/framework/tests/test_nccl_init_op.py new file mode 100644 index 0000000000000000000000000000000000000000..054909fdf5517a68c6a07971c65a1d5bdc20d4fa --- /dev/null +++ b/python/paddle/v2/framework/tests/test_nccl_init_op.py @@ -0,0 +1,39 @@ +import unittest, os +import numpy as np +import paddle.v2 as paddle +from paddle.v2.framework.op import Operator +import paddle.v2.framework.core as core +from op_test import OpTest, create_op, set_input + +if not core.is_compile_gpu(): + exit(0) + +gpu_count = core.get_cuda_device_count() + +if gpu_count <= 1: + exit(0) + +g_scope = core.Scope() +g_ctx = core.DeviceContext.create(core.CPUPlace()) + + +class TestNCCLInit(unittest.TestCase): + def test_init(self): + self.op_type = "ncclInit" + self.gpus = range(gpu_count) + + self.inputs = {} + self.attrs = {"gpus": self.gpus} + g_scope.var("Communicator").get_communicator() + self.outputs = {"Communicator": g_scope.find_var("Communicator")} + nccl_init = create_op( + g_scope, + op_type=self.op_type, + inputs=self.inputs, + outputs=self.outputs, + attrs=self.attrs) + nccl_init.run(g_scope, g_ctx) + + +if __name__ == "__main__": + unittest.main()