提交 9f32b61c 编写于 作者: W wanghaoshuang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into seq_expand_op

......@@ -228,6 +228,10 @@ class OpKernelRegistrar : public Registrar {
USE_OP_ITSELF(op_type); \
USE_OP_DEVICE_KERNEL(op_type, CPU);
#define USE_GPU_ONLY_OP(op_type) \
USE_OP_ITSELF(op_type); \
USE_OP_DEVICE_KERNEL(op_type, GPU)
#define USE_OP(op_type) \
USE_OP_ITSELF(op_type); \
USE_OP_KERNEL(op_type)
......
......@@ -122,7 +122,7 @@ class OperatorBase {
protected:
std::string type_;
// NOTE: in case of OpGrad, inputs_ contains:
// I (Inputs)opear
// I (Inputs)
// O (Outputs)
// OG (Output Gradients)
VariableNameMap inputs_;
......@@ -287,6 +287,16 @@ class ExecutionContext {
return device_context_;
}
//! Get actual name vector for this input.
const std::vector<std::string>& Inputs(const std::string& name) const {
return op_.Inputs(name);
}
//! Get actual name vector for this output.
const std::vector<std::string>& Outputs(const std::string& name) const {
return op_.Outputs(name);
}
#ifdef PADDLE_WITH_CUDA
const platform::CUDADeviceContext& cuda_device_context() const {
PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace()));
......@@ -398,6 +408,7 @@ class OperatorWithKernel : public OperatorBase {
// indicate kernel DataType by input data. Defaultly all input data must be
// same.
virtual DataType IndicateDataType(const ExecutionContext& ctx) const {
VLOG(3) << "Default IndicateDataType " << this->Type();
auto& scope = ctx.scope();
int data_type = -1;
for (auto& input : this->inputs_) {
......
......@@ -126,11 +126,16 @@ class Tensor {
inline Tensor Slice(const int& begin_idx, const int& end_idx) const;
platform::Place place() const {
PADDLE_ENFORCE_NOT_NULL(holder_, "Tensor get place() must contains holder");
PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor not initialized yet when Tensor::place() is called.");
return holder_->place();
}
std::type_index type() const { return holder_->type(); }
std::type_index type() const {
PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor not initialized yet when Tensor::type() is called.");
return holder_->type();
}
size_t memory_size() const;
......
......@@ -90,6 +90,13 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(sigmoid);\n")
endif()
# nccl_op contains several operators
if ("${TARGET}" STREQUAL "nccl_op")
set(pybind_flag 1)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_GPU_ONLY_OP(ncclAllReduce);\n")
endif()
# reduce_op contains several operators
if ("${TARGET}" STREQUAL "reduce_op")
set(pybind_flag 1)
......@@ -121,6 +128,7 @@ function(op_library TARGET)
endfunction()
add_subdirectory(math)
add_subdirectory(nccl)
set(DEPS_OPS
recurrent_op
......@@ -130,6 +138,7 @@ set(DEPS_OPS
sum_op
pool_op
pool_with_index_op
nccl_op
sequence_conv_op
lstm_op)
......@@ -142,6 +151,9 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
op_library(sum_op DEPS net_op selected_rows_functor)
op_library(pool_op DEPS pooling)
op_library(pool_with_index_op DEPS pooling)
if(WITH_GPU)
op_library(nccl_op DEPS nccl_common)
endif()
op_library(sequence_conv_op DEPS context_project)
op_library(lstm_op DEPS sequence2batch lstm_compute)
......@@ -157,4 +169,8 @@ cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory)
cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc DEPS dynamic_recurrent_op recurrent_op tensor_array)
if(WITH_GPU)
nv_test(nccl_op_test SRCS nccl_op_test.cu DEPS nccl_op gpu_info device_context)
endif()
cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
......@@ -64,6 +65,9 @@ class BatchNormOp : public framework::OperatorWithKernel {
(tensor_format == TensorFormat::NCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5,
"Input x must have 3 to 5 dimensions.");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], C);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL);
......@@ -108,10 +112,12 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
"Store the global Variance when training");
AddOutput("SavedMean",
"Mean of the current mini batch, "
"will apply to output when training");
"will apply to output when training")
.AsIntermediate();
AddOutput("SavedVariance",
"Variance of the current mini batch, "
"will apply to output when training");
"will apply to output when training")
.AsIntermediate();
AddComment(R"DOC(
https://arxiv.org/pdf/1502.03167.pdf
......@@ -135,7 +141,6 @@ class BatchNormKernel<platform::CPUPlace, T> : public framework::OpKernel<T> {
const auto *x = ctx.Input<Tensor>("X");
const auto &x_dims = x->dims();
PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5,
"The Input dim size should be between 3 and 5");
const int N = x_dims[0];
......@@ -289,6 +294,25 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
}
framework::DataType IndicateDataType(
const framework::ExecutionContext &ctx) const override {
VLOG(3) << "IndicateDataType " << this->Type();
const auto *var = ctx.InputVar(framework::GradVarName("Y"));
if (var == nullptr) {
PADDLE_THROW("can't find Y@GRAD");
}
const Tensor *t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
}
if (t == nullptr) {
PADDLE_THROW("can't find Y@GRAD");
}
return framework::ToDataType(t->type());
}
};
template <typename T>
......
if(WITH_GPU)
nv_library(nccl_common SRCS nccl_gpu_common.cc DEPS device_context operator )
endif()
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/nccl/nccl_gpu_common.h"
#include "paddle/platform/gpu_info.h"
namespace paddle {
namespace platform {} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/platform/device_context.h"
#include "paddle/platform/dynload/nccl.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/macros.h"
namespace paddle {
namespace platform {
constexpr int kInvalidGPUId = -1;
struct Communicator {
std::vector<ncclComm_t> comms_;
std::unordered_map<int, int> comm_id_map_;
Communicator() {}
int GetCommId(int device_id) const { return comm_id_map_.at(device_id); }
void InitAll(const std::vector<int>& gpus) {
comms_.resize(gpus.size());
for (size_t i = 0; i < gpus.size(); ++i) {
comm_id_map_[gpus[i]] = i;
}
PADDLE_ENFORCE(
dynload::ncclCommInitAll(comms_.data(), gpus.size(), gpus.data()));
}
~Communicator() {
for (size_t i = 0; i < comms_.size(); ++i) {
// FIXME(dzh) : PADDLE_ENFORCE return void
dynload::ncclCommDestroy(comms_[i]);
}
}
DISABLE_COPY_AND_ASSIGN(Communicator);
};
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/operators/nccl/nccl_gpu_common.h"
namespace paddle {
namespace operators {
// NCCLinitOp
class NCCLInitOp : public framework::OperatorBase {
public:
NCCLInitOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const auto &name = Output("Communicator");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name),
"Can not find variable '%s' in the scope.", name);
std::vector<int> gpus = Attr<std::vector<int>>("gpus");
PADDLE_ENFORCE(!gpus.empty(), "Attr(gpus) should not be empty.");
if (scope.FindVar(name) == nullptr) {
PADDLE_THROW("Output(Communicator) is needed for ncclInit operator.");
}
platform::Communicator *comm =
scope.FindVar(name)->GetMutable<platform::Communicator>();
comm->InitAll(gpus);
}
};
class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLInitOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Communicator",
"Create Communicator for communicating between gpus");
AddAttr<std::vector<int>>("gpus", "gpu id lists");
AddAttr<int>("data_type", "output data type")
.SetDefault(framework::DataType::FP32);
AddComment(R"DOC(
create communicator.
)DOC");
}
};
// AllReduceOp
class NCCLAllReduceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
" Input(X) of AllReduce op input should not be NULL");
PADDLE_ENFORCE(
ctx->HasInput("Communicator"),
" Input(Communicator) of AllReduce op input should not be NULL");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
" Input(X) of AllReduce op input should not be NULL");
auto x_dims = ctx->GetInputsDim("X");
std::string reduction = ctx->Attrs().Get<std::string>("reduction");
PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" ||
reduction == "ncclMin" || reduction == "ncclMax"),
"invalid reduction.");
ctx->SetOutputsDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
// ReduceOp
class NCCLReduceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
" Input(X) of Reduce op input should not be NULL");
PADDLE_ENFORCE(
ctx->HasInput("Communicator"),
" Input(Communicator) of Reduce op input should not be NULL");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
" Input(X) of Reduce op input should not be NULL");
std::string reduction = ctx->Attrs().Get<std::string>("reduction");
PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" ||
reduction == "ncclMin" || reduction == "ncclMax"),
"invalid reduction.");
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
// BcastOp
class NCCLBcastOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
" Input(X) of Bcast op input should not be NULL");
PADDLE_ENFORCE(ctx->HasInput("Communicator"),
" Input(Communicator) of Bcast op input should not be NULL");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
" Output(Out) of Bcast op output should not be NULL");
int root = ctx->Attrs().Get<int>("root");
PADDLE_ENFORCE(root != platform::kInvalidGPUId, "Bcast root must be set.");
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
// AllreduceOp
class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLAllReduceOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of AllReduce op");
AddInput("Communicator", "Communicator for communicating between gpus");
AddOutput("Out", "The output of AllReduce op");
AddAttr<std::string>("reduction",
"{'ncclMin', 'ncclMax', 'ncclProd', 'ncclSum'}.")
.SetDefault("ncclSum");
AddComment(R"DOC(
AllReduce the input tensors.
)DOC");
}
};
// ReduceOp
class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLReduceOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of Reduce op");
AddInput("Communicator", "Communicator for communicating between gpus");
AddOutput("Out", "The output of Reduce op");
AddAttr<std::string>("reduction",
"{'ncclMin', 'ncclMax', 'ncclProd', 'ncclSum'}.")
.SetDefault("ncclSum");
AddAttr<int>("root",
"root gpu of the parameter. if not "
"set(platform::kInvalidGPUId). hashed by name.")
.SetDefault(platform::kInvalidGPUId);
AddComment(R"DOC(
Reduce the tensors)DOC");
}
};
// BcastOp
class NCCLBcastOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLBcastOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of BcastSend op");
AddInput("Communicator", "Communicator for communicating between gpus");
AddOutput("Out", "The output of Bcast");
AddAttr<int>("root",
"root gpu of the parameter. if not "
"set(platform::kInvalidGPUId). hashed by name.")
.SetDefault(platform::kInvalidGPUId);
AddComment(R"DOC(
Bcast the tensors.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(ncclInit, ops::NCCLInitOp,
paddle::framework::EmptyGradOpMaker, ops::NCCLInitOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(ncclAllReduce, ops::NCCLAllReduceOp,
ops::NCCLAllReduceOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(ncclBcast, ops::NCCLBcastOp,
ops::NCCLBcastOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(ncclReduce, ops::NCCLReduceOp,
ops::NCCLReduceOpMaker);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenseshashernless required by applicable law or agreed
to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <functional>
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/nccl/nccl_gpu_common.h"
namespace paddle {
namespace operators {
using framework::Tensor;
using platform::Communicator;
using framework::LoDTensor;
template <typename Type>
class NCCLTypeWrapper;
template <>
class NCCLTypeWrapper<float> {
public:
static const ncclDataType_t type = ncclFloat;
};
template <>
class NCCLTypeWrapper<double> {
public:
static const ncclDataType_t type = ncclDouble;
};
template <typename T>
class NCCLAllReduceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto ins = ctx.MultiInput<LoDTensor>("X");
auto outs = ctx.MultiOutput<LoDTensor>("Out");
std::string reduction = ctx.Attr<std::string>("reduction");
ncclRedOp_t reduction_op_ = ncclSum;
if (reduction == "ncclMin") {
reduction_op_ = ncclMin;
} else if (reduction == "ncclMax") {
reduction_op_ = ncclMax;
} else if (reduction == "ncclSum") {
reduction_op_ = ncclSum;
} else if (reduction == "ncclProd") {
reduction_op_ = ncclProd;
} else {
PADDLE_THROW("Invalid reduction. default ncclSum.");
}
auto* comm = ctx.Input<Communicator>("Communicator");
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream();
// device id
int gpu_id = boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
int idx = comm->GetCommId(gpu_id);
for (size_t i = 0; i < ins.size(); ++i) {
VLOG(1) << "gpu : "
<< " invoke allreduce. send " << ins[i]->numel() << " recv "
<< outs[i]->numel();
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
ins[i]->data<T>(), outs[i]->mutable_data<T>(ctx.GetPlace()),
outs[i]->numel(), NCCLTypeWrapper<T>::type, reduction_op_,
comm->comms_[idx], stream));
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
VLOG(1) << "gpu : "
<< " finished allreduce. send " << ins[i]->numel() << " recv "
<< outs[i]->numel();
}
}
};
template <typename T>
class NCCLReduceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto ins = ctx.MultiInput<LoDTensor>("X"); // x0, x1, x2
auto outs = ctx.MultiOutput<LoDTensor>("Out");
std::string reduction = ctx.Attr<std::string>("reduction");
ncclRedOp_t reduction_op_ = ncclSum;
if (reduction == "ncclMin") {
reduction_op_ = ncclMin;
} else if (reduction == "ncclMax") {
reduction_op_ = ncclMax;
} else if (reduction == "ncclSum") {
reduction_op_ = ncclSum;
} else if (reduction == "ncclProd") {
reduction_op_ = ncclProd;
} else {
PADDLE_THROW("Invalid reduction. default ncclSum.");
}
int root = ctx.Attr<int>("root");
auto* comm = ctx.Input<Communicator>("Communicator");
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream();
// device id
int gpu_id = boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
int idx = comm->GetCommId(gpu_id);
auto ins_names = ctx.Inputs("X");
std::hash<std::string> hasher;
for (size_t i = 0; i < ins.size(); ++i) {
if (root == platform::kInvalidGPUId) {
root = hasher(ins_names[i]) % comm->comms_.size();
}
T* recvbuffer = nullptr;
if (root == gpu_id) {
recvbuffer = outs[i]->mutable_data<T>(ctx.GetPlace());
}
VLOG(1) << "gpu : " << gpu_id << " invoke reduce. send "
<< ins[i]->numel() << " recv " << outs[i]->numel();
PADDLE_ENFORCE(platform::dynload::ncclReduce(
ins[i]->data<T>(), recvbuffer, ins[i]->numel(),
NCCLTypeWrapper<T>::type, reduction_op_, root, comm->comms_[idx],
stream));
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
VLOG(1) << "gpu : " << gpu_id << " finished reduce. send "
<< ins[i]->numel() << " recv " << outs[i]->numel();
}
}
};
template <typename T>
class NCCLBcastKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
int root = ctx.Attr<int>("root");
auto* comm = ctx.Input<Communicator>("Communicator");
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream();
// device id
int gpu_id = boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
int idx = comm->GetCommId(gpu_id);
if (idx == root) {
auto ins = ctx.MultiInput<LoDTensor>("X");
for (size_t i = 0; i < ins.size(); ++i) {
VLOG(1) << "gpu : " << gpu_id << " invoke Bcast. send "
<< ins[i]->numel();
VLOG(1) << " before ncclBcast";
PADDLE_ENFORCE(platform::dynload::ncclBcast(
(void*)ins[i]->data<T>(), ins[i]->numel(), NCCLTypeWrapper<T>::type,
root, comm->comms_[idx], stream));
VLOG(1) << " after ncclBcast";
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
VLOG(1) << "gpu : " << gpu_id << " finished Bcast.";
}
} else {
auto outs = ctx.MultiOutput<LoDTensor>("Out");
for (size_t i = 0; i < outs.size(); ++i) {
VLOG(1) << "gpu : " << gpu_id << " invoke Bcast. recv buffer "
<< framework::product(outs[i]->dims());
PADDLE_ENFORCE(platform::dynload::ncclBcast(
outs[i]->mutable_data<T>(ctx.GetPlace()), outs[i]->numel(),
NCCLTypeWrapper<T>::type, root, comm->comms_[idx], stream));
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
VLOG(1) << "gpu : " << gpu_id << " finished Bcast. recv "
<< outs[i]->numel();
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(ncclAllReduce, ops::NCCLAllReduceKernel<float>);
REGISTER_OP_GPU_KERNEL(ncclBcast, ops::NCCLBcastKernel<float>);
REGISTER_OP_GPU_KERNEL(ncclReduce, ops::NCCLReduceKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <memory>
#include <mutex>
#include <thread>
#include <utility>
#include <vector>
#include "paddle/framework/block_desc.h"
#include "paddle/framework/op_desc.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/program_desc.h"
#include "paddle/framework/var_desc.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/nccl/nccl_gpu_common.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/gpu_info.h"
#include "paddle/platform/place.h"
USE_NO_KERNEL_OP(ncclInit);
USE_GPU_ONLY_OP(ncclAllReduce);
USE_GPU_ONLY_OP(ncclReduce);
USE_GPU_ONLY_OP(ncclBcast);
namespace f = paddle::framework;
namespace p = paddle::platform;
static std::vector<int> gpu_list;
// test data amount
const f::DDim kDims = {100, 100};
// nccl op common tester, init communicator.
class NCCLTester : public ::testing::Test {
public:
virtual void SetUp() override {
cpu_ctx = new p::CPUDeviceContext(p::CPUPlace());
for (size_t i = 0; i < gpu_list.size(); ++i) {
p::GPUPlace place(i);
dev_ctxs.emplace_back(new p::CUDADeviceContext(place));
}
NCCLInitOp();
}
virtual void TearDown() override {
for (auto &device_context : dev_ctxs) {
delete device_context;
}
}
void NCCLInitOp() {
std::unique_ptr<f::OpDescBind> op1(new f::OpDescBind);
op1->SetType("ncclInit");
op1->SetOutput("Communicator", {"comm"});
op1->SetAttr("gpus", {gpu_list});
auto *var = g_scope.Var("comm");
var->GetMutable<p::Communicator>();
auto op = f::OpRegistry::CreateOp(*op1);
VLOG(1) << "invoke NCCLInitOp.";
op->Run(g_scope, *cpu_ctx);
VLOG(1) << "NCCLInitOp finished.";
}
template <class T>
void PerThreadProgram(int gpu_id, const f::OpDescBind &op_desc,
f::Scope *scope) {
std::unique_lock<std::mutex> lk(mu);
const f::OpDescBind *op1 = &op_desc;
p::GPUPlace place(gpu_id);
auto &ctx = dev_ctxs.at(gpu_id);
auto *send_tensor = scope->Var("st")->GetMutable<f::LoDTensor>();
auto *recv_tensor = scope->Var("rt")->GetMutable<f::LoDTensor>();
if (!send_tensor->numel()) {
send_tensor->Resize(kDims);
send_tensor->mutable_data<T>(kDims, place);
std::vector<T> send_vector(f::product(kDims), gpu_id);
send_tensor->CopyFromVector<T>(send_vector, *ctx);
ctx->Wait();
VLOG(1) << "Send Tensor filled with elements " << send_tensor->numel();
}
lk.unlock();
PADDLE_ENFORCE(send_tensor->numel() == f::product(kDims),
"Tensor numel not match!");
auto op = f::OpRegistry::CreateOp(*op1);
VLOG(1) << "Device : " << gpu_id << " invoke " << op_desc.Type();
VLOG(1) << " send_tensor : " << send_tensor->numel()
<< " recv_tensor : " << recv_tensor->numel();
op->Run(*scope, *ctx);
VLOG(1) << "Device : " << gpu_id << " finished " << op_desc.Type();
}
public:
std::vector<p::DeviceContext *> dev_ctxs;
p::DeviceContext *cpu_ctx;
f::Scope g_scope;
std::mutex mu;
};
// ncclInitOp with desc
TEST(NCCL, ncclInitOp) {
std::unique_ptr<f::OpDescBind> op_desc(new f::OpDescBind);
op_desc->SetType("ncclInit");
op_desc->SetOutput("Communicator", {"x1"});
op_desc->SetAttr("gpus", {gpu_list});
f::Scope g_scope;
std::unique_ptr<p::DeviceContext> ctx(new p::CPUDeviceContext(p::CPUPlace()));
auto *var = g_scope.Var("x1");
var->GetMutable<p::Communicator>();
auto op = f::OpRegistry::CreateOp(*op_desc);
VLOG(1) << "invoke NCCLInitOp.";
op->Run(g_scope, *ctx.get());
VLOG(1) << "NCCLInitOp finished.";
}
// ncclAllReduceOp with desc
TEST_F(NCCLTester, ncclAllReduceOp) {
std::unique_ptr<f::OpDescBind> op2(new f::OpDescBind);
op2->SetType("ncclAllReduce");
op2->SetInput("X", {"st"});
op2->SetInput("Communicator", {"comm"});
op2->SetOutput("Out", {"rt"});
std::vector<f::Scope *> dev_scopes;
std::vector<std::thread> ths;
for (size_t i = 0; i < gpu_list.size(); ++i) {
dev_scopes.emplace_back(&g_scope.NewScope());
std::thread th(&NCCLTester::PerThreadProgram<float>, this, gpu_list[i],
*op2.get(), dev_scopes[i]);
ths.emplace_back(std::move(th));
}
for (size_t i = 0; i < gpu_list.size(); ++i) {
ths[i].join();
}
// check results
float result = std::accumulate(gpu_list.begin(), gpu_list.end(), 0);
for (size_t i = 0; i < dev_scopes.size(); ++i) {
p::CPUPlace cpu_place;
p::GPUPlace gpu_place(gpu_list[i]);
auto &recv_tensor = dev_scopes[i]->FindVar("rt")->Get<f::LoDTensor>();
auto *rt = recv_tensor.data<float>();
auto *result_tensor = dev_scopes[i]->Var("ct")->GetMutable<f::LoDTensor>();
result_tensor->Resize(kDims);
auto *ct = result_tensor->mutable_data<float>(cpu_place);
paddle::memory::Copy(
cpu_place, ct, p::GPUPlace(gpu_list[i]), rt,
recv_tensor.numel() * sizeof(float),
static_cast<p::CUDADeviceContext *>(dev_ctxs[i])->stream());
for (size_t j = 0; j < f::product(kDims); ++j) {
ASSERT_NEAR(ct[j], result, 1e-5);
}
}
}
// ncclReduceOp with desc
TEST_F(NCCLTester, ncclReduceOp) {
std::unique_ptr<f::OpDescBind> op2(new f::OpDescBind);
const int kRoot = 0;
op2->SetType("ncclReduce");
op2->SetInput("X", {"st"});
op2->SetInput("Communicator", {"comm"});
op2->SetOutput("Out", {"rt"});
op2->SetAttr("root", kRoot);
std::vector<f::Scope *> dev_scopes;
std::vector<std::thread> ths;
for (size_t i = 0; i < gpu_list.size(); ++i) {
dev_scopes.emplace_back(&g_scope.NewScope());
std::thread th(&NCCLTester::PerThreadProgram<float>, this, gpu_list[i],
*op2.get(), dev_scopes[i]);
ths.emplace_back(std::move(th));
}
for (size_t i = 0; i < gpu_list.size(); ++i) {
ths[i].join();
}
// check results on
float result = std::accumulate(gpu_list.begin(), gpu_list.end(), 0);
p::CPUPlace cpu_place;
p::GPUPlace gpu_place(gpu_list[kRoot]);
auto &recv_tensor = dev_scopes[kRoot]->FindVar("rt")->Get<f::LoDTensor>();
auto *rt = recv_tensor.data<float>();
auto *result_tensor =
dev_scopes[kRoot]->Var("ct")->GetMutable<f::LoDTensor>();
result_tensor->Resize(kDims);
auto *ct = result_tensor->mutable_data<float>(cpu_place);
paddle::memory::Copy(
cpu_place, ct, p::GPUPlace(gpu_list[kRoot]), rt,
recv_tensor.numel() * sizeof(float),
static_cast<p::CUDADeviceContext *>(dev_ctxs[kRoot])->stream());
for (int j = 0; j < f::product(kDims); ++j) {
ASSERT_NEAR(ct[j], result, 1e-5);
}
}
// ncclBcastOp with desc
TEST_F(NCCLTester, ncclBcastOp) {
std::unique_ptr<f::OpDescBind> op2(new f::OpDescBind);
const int kRoot = 5;
op2->SetType("ncclBcast");
op2->SetInput("X", {"st"});
op2->SetInput("Communicator", {"comm"});
op2->SetOutput("Out", {"rt"});
op2->SetAttr("root", kRoot);
std::vector<f::Scope *> dev_scopes;
std::vector<std::thread> ths;
for (size_t i = 0; i < gpu_list.size(); ++i) {
dev_scopes.emplace_back(&g_scope.NewScope());
std::thread th(&NCCLTester::PerThreadProgram<float>, this, gpu_list[i],
*op2.get(), dev_scopes[i]);
ths.emplace_back(std::move(th));
}
for (size_t i = 0; i < gpu_list.size(); ++i) {
ths[i].join();
}
const int idx = 1;
// check results on
float result = kRoot;
p::CPUPlace cpu_place;
p::GPUPlace gpu_place(gpu_list[idx]);
auto &recv_tensor = dev_scopes[idx]->FindVar("rt")->Get<f::LoDTensor>();
auto *rt = recv_tensor.data<float>();
auto *result_tensor = dev_scopes[idx]->Var("ct")->GetMutable<f::LoDTensor>();
result_tensor->Resize(kDims);
auto *ct = result_tensor->mutable_data<float>(cpu_place);
paddle::memory::Copy(
cpu_place, ct, p::GPUPlace(gpu_list[idx]), rt,
recv_tensor.numel() * sizeof(float),
static_cast<p::CUDADeviceContext *>(dev_ctxs[idx])->stream());
for (size_t j = 0; j < f::product(kDims); ++j) {
ASSERT_NEAR(ct[j], result, 1e-5);
}
}
int main(int argc, char **argv) {
const int dev_count = p::GetCUDADeviceCount();
if (dev_count <= 1) {
LOG(WARNING)
<< "Cannot test multi-gpu nccl, because the CUDA device count is "
<< dev_count;
return 0;
}
for (int i = 0; i < dev_count; ++i) {
gpu_list.emplace_back(i);
}
testing::InitGoogleTest(&argc, argv);
// device context should be release before scope.
// otherwise driver will down.
return RUN_ALL_TESTS();
}
......@@ -34,13 +34,19 @@ class ReshapeOp : public framework::OperatorWithKernel {
auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty.");
for (auto dim : shape) {
PADDLE_ENFORCE(dim > 0, "Each dimension of shape must be positive.");
auto x_dims = ctx->GetInputDim("X");
// TODO(qiao) change batch_size
for (int i = 1; i < shape.size(); ++i) {
PADDLE_ENFORCE(shape[i] > 0,
"Each dimension of shape "
"must be positiv except the first.");
}
if (shape[0] < 0) {
shape[0] = x_dims[0];
}
// capacity check
int64_t capacity =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
auto x_dims = ctx->GetInputDim("X");
int64_t in_size = framework::product(x_dims);
PADDLE_ENFORCE_EQ(capacity, in_size,
"The size of Input(X) mismatches with Attr(shape).");
......
......@@ -26,13 +26,8 @@ class ReshapeKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const {
auto* out = ctx.Output<framework::Tensor>("Out");
auto* in = ctx.Input<framework::Tensor>("X");
auto out_dims = out->dims();
out->mutable_data<T>(ctx.GetPlace());
auto shape = ctx.Attr<std::vector<int>>("shape");
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
[](int a) { return static_cast<int64_t>(a); });
auto out_dims = framework::make_ddim(shape_int64);
out->CopyFrom(*in, ctx.GetPlace(), ctx.device_context());
out->Resize(out_dims);
}
......
......@@ -31,9 +31,7 @@ namespace platform {
TEST(NCCL, init) {
std::vector<ncclComm_t> comms;
comms.resize(dev_count);
auto status = dynload::ncclCommInitAll(comms.data(), dev_count, nullptr);
PADDLE_ENFORCE(status);
PADDLE_ENFORCE(dynload::ncclCommInitAll(comms.data(), dev_count, nullptr));
for (int i = 0; i < dev_count; ++i) {
dynload::ncclCommDestroy(comms[i]);
}
......@@ -64,8 +62,7 @@ TEST(NCCL, all_reduce) {
std::vector<ncclComm_t> comms;
comms.resize(dev_count);
VLOG(1) << "Initializing ncclComm";
auto status = dynload::ncclCommInitAll(comms.data(), dev_count, nullptr);
PADDLE_ENFORCE(status);
PADDLE_ENFORCE(dynload::ncclCommInitAll(comms.data(), dev_count, nullptr));
VLOG(1) << "ncclComm initialized";
VLOG(1) << "Creating thread data";
std::vector<std::unique_ptr<PerThreadData<double>>> data;
......
......@@ -33,6 +33,11 @@ limitations under the License. */
#include "paddle/pybind/tensor_py.h"
#include "paddle/string/to_string.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/operators/nccl/nccl_gpu_common.h"
#include "paddle/platform/gpu_info.h"
#endif
namespace paddle {
namespace pybind {
static size_t UniqueIntegerGenerator() {
......@@ -204,6 +209,13 @@ All parameter, weight, gradient are variables in Paddle.
return self.GetMutable<SelectedRows>();
},
py::return_value_policy::reference)
#ifdef PADDLE_WITH_CUDA
.def("get_communicator",
[](Variable &self) -> platform::Communicator * {
return self.GetMutable<platform::Communicator>();
},
py::return_value_policy::reference)
#endif
.def("get_net",
[](Variable &self) -> operators::NetOp * {
return self.GetMutable<operators::NetOp>();
......@@ -269,8 +281,11 @@ All parameter, weight, gradient are variables in Paddle.
return new paddle::platform::CUDADeviceContext(place);
#endif
});
// clang-format on
// clang-format on
#ifdef PADDLE_WITH_CUDA
py::class_<platform::Communicator>(m, "Communicator").def(py::init<>());
#endif
py::class_<platform::GPUPlace>(m, "GPUPlace")
.def(py::init<int>())
.def("__str__", string::to_string<const platform::GPUPlace &>);
......@@ -479,6 +494,9 @@ All parameter, weight, gradient are variables in Paddle.
BindOpDesc(m);
m.def("op_support_gpu", OpSupportGPU);
#ifdef PADDLE_WITH_CUDA
m.def("get_cuda_device_count", platform::GetCUDADeviceCount);
#endif
return m.ptr();
}
......
......@@ -352,7 +352,10 @@ class Block(object):
return {v for k, v in self.vars.iteritems() if isinstance(v, Parameter)}
def create_var(self, *args, **kwargs):
return Variable(self, *args, **kwargs)
var = Variable(self, *args, **kwargs)
if 'init_attr' in kwargs:
self._prepend_initialize_ops_(var, kwargs['init_attr'])
return var
def has_var(self, name):
return name in self.vars
......
......@@ -161,6 +161,7 @@ def _create_op_func_(op_type):
_create_op_func_('mean')
_create_op_func_('mul')
_create_op_func_('dropout')
_create_op_func_('reshape')
def cast(x, data_type, program=None):
......@@ -308,6 +309,96 @@ def pool2d(input,
return pool_out
def batch_norm(input,
act=None,
is_test=False,
momentum=0.9,
epsilon=1e05,
param_attr=None,
bias_attr=None,
data_layout='NCHW',
program=None,
init_program=None):
helper = LayerHelper('batch_norm', **locals())
dtype = helper.input_dtype()
input_shape = input.shape
if data_layout == 'NCHW':
channel_num = input_shape[1]
else:
if data_layout == 'NHWC':
channel_num = input_shape[-1]
else:
raise ValueError("unsupported data layout:" + data_layout)
def get_init_attr(value):
if not isinstance(value, float):
raise ValueError("attr value should be a float")
return {'type': 'fill_constant', 'value': value}
def prepend_init_op(var, init_attr):
assert isinstance(var, Variable)
op_type = init_attr['type']
init_attr['shape'] = var.shape
init_attr['data_type'] = int(var.data_type)
op = var.block.prepend_op(
type=op_type, inputs=None, outputs={'Out': [var]}, attrs=init_attr)
return op
def create_persistable_var(dtype, shape, init_attr=None):
name = unique_name(".".join([helper.name, "xxxx"]))
var = init_program.global_block().create_var(
dtype=dtype, shape=shape, name=name, persistable=True)
if 'init_attr' is not None:
prepend_init_op(var, init_attr)
return program.global_block().create_var(
name=name, dtype=dtype, shape=shape, persistable=True)
param_shape = [channel_num]
# create parameter
scale = helper.create_parameter(
attr=helper.param_attr, shape=param_shape, dtype=dtype)
bias = helper.create_parameter(
attr=helper.param_attr, shape=param_shape, dtype=dtype)
# create input
mean = create_persistable_var(dtype, param_shape, get_init_attr(0.0))
variance = create_persistable_var(dtype, param_shape, get_init_attr(1.0))
# create output
# mean and mean_out share the same memory
mean_out = mean
# variance and variance out share the same memory
variance_out = variance
saved_mean = helper.create_tmp_variable(dtype)
saved_variance = helper.create_tmp_variable(dtype)
batch_norm_out = helper.create_tmp_variable(dtype)
helper.append_op(
type="batch_norm",
inputs={
"X": input,
"Scale": scale,
"Bias": bias,
"Mean": mean,
"Variance": variance
},
outputs={
"Y": batch_norm_out,
"MeanOut": mean_out,
"VarianceOut": variance_out,
"SavedMean": saved_mean,
"SavedVariance": saved_variance
},
attrs={"momentum": momentum,
"epsilon": epsilon,
"is_test": is_test})
return helper.append_activation(batch_norm_out)
class BlockGuard(object):
"""
BlockGuard used to create sub-block in program by using Python `with`
......
......@@ -7,6 +7,7 @@ def simple_img_conv_pool(input,
pool_size,
pool_stride,
act,
pool_type='max',
program=None,
init_program=None):
conv_out = layers.conv2d(
......@@ -20,7 +21,75 @@ def simple_img_conv_pool(input,
pool_out = layers.pool2d(
input=conv_out,
pool_size=pool_size,
pool_type='max',
pool_type=pool_type,
pool_stride=pool_stride,
program=program,
init_program=init_program)
return pool_out
def img_conv_group(input,
conv_num_filter,
pool_size,
conv_padding=1,
conv_filter_size=3,
conv_act=None,
conv_with_batchnorm=False,
conv_batchnorm_drop_rate=None,
pool_stride=1,
pool_type=None,
program=None,
init_program=None):
"""
Image Convolution Group, Used for vgg net.
"""
tmp = input
assert isinstance(conv_num_filter, list) or \
isinstance(conv_num_filter, tuple)
def __extend_list__(obj):
if not hasattr(obj, '__len__'):
return [obj] * len(conv_num_filter)
else:
return obj
conv_padding = __extend_list__(conv_padding)
conv_filter_size = __extend_list__(conv_filter_size)
conv_with_batchnorm = __extend_list__(conv_with_batchnorm)
conv_batchnorm_drop_rate = __extend_list__(conv_batchnorm_drop_rate)
for i in xrange(len(conv_num_filter)):
local_conv_act = conv_act
if conv_with_batchnorm[i]:
local_conv_act = None
tmp = layers.conv2d(
input=tmp,
num_filters=conv_num_filter[i],
filter_size=conv_filter_size[i],
padding=conv_padding[i],
act=local_conv_act,
program=program,
init_program=init_program)
if conv_with_batchnorm[i]:
tmp = layers.batch_norm(
input=tmp,
act=conv_act,
program=program,
init_program=init_program)
drop_rate = conv_batchnorm_drop_rate[i]
if abs(drop_rate) > 1e-5:
tmp = layers.dropout(
x=tmp,
dropout_prob=drop_rate,
program=program,
init_program=init_program)
pool_out = layers.pool2d(
input=tmp,
pool_size=pool_size,
pool_type=pool_type,
pool_stride=pool_stride,
program=program,
init_program=init_program)
......
import unittest
import paddle.v2.framework.layers as layers
import paddle.v2.framework.nets as nets
from paddle.v2.framework.framework import Program
def conv_block(input,
num_filter,
groups,
dropouts,
program=None,
init_program=None):
return nets.img_conv_group(
input=input,
pool_size=2,
pool_stride=2,
conv_num_filter=[num_filter] * groups,
conv_filter_size=3,
conv_act='relu',
conv_with_batchnorm=True,
conv_batchnorm_drop_rate=dropouts,
pool_type='max',
program=program,
init_program=init_program)
class TestLayer(unittest.TestCase):
def test_batch_norm_layer(self):
program = Program()
init_program = Program()
images = layers.data(
name='pixel',
shape=[3, 48, 48],
data_type='float32',
program=program)
layers.batch_norm(
input=images, program=program, init_program=init_program)
#print str(program)
def test_dropout_layer(self):
program = Program()
init_program = Program()
images = layers.data(
name='pixel',
shape=[3, 48, 48],
data_type='float32',
program=program)
layers.dropout(
x=images,
dropout_prob=0.5,
program=program,
init_program=init_program)
#print str(program)
def test_img_conv_group(self):
program = Program()
init_program = Program()
images = layers.data(
name='pixel',
shape=[3, 48, 48],
data_type='float32',
program=program,
init_program=init_program)
conv1 = conv_block(images, 64, 2, [0.3, 0], program, init_program)
conv2 = conv_block(conv1, 256, 3, [0.4, 0.4, 0], program, init_program)
# print str(program)
if __name__ == '__main__':
unittest.main()
import paddle.v2 as paddle
import paddle.v2.framework.layers as layers
import paddle.v2.framework.nets as nets
import paddle.v2.framework.core as core
import paddle.v2.framework.optimizer as optimizer
from paddle.v2.framework.framework import Program, g_program
from paddle.v2.framework.executor import Executor
import numpy as np
def vgg16_bn_drop(input, program, init_program):
def conv_block(input,
num_filter,
groups,
dropouts,
program=None,
init_program=None):
return nets.img_conv_group(
input=input,
pool_size=2,
pool_stride=2,
conv_num_filter=[num_filter] * groups,
conv_filter_size=3,
conv_act='relu',
conv_with_batchnorm=True,
conv_batchnorm_drop_rate=dropouts,
pool_type='max',
program=program,
init_program=init_program)
conv1 = conv_block(input, 64, 2, [0.3, 0], program, init_program)
conv2 = conv_block(conv1, 128, 2, [0.4, 0], program, init_program)
conv3 = conv_block(conv2, 256, 3, [0.4, 0.4, 0], program, init_program)
conv4 = conv_block(conv3, 512, 3, [0.4, 0.4, 0], program, init_program)
conv5 = conv_block(conv4, 512, 3, [0.4, 0.4, 0], program, init_program)
drop = layers.dropout(
x=conv5, dropout_prob=0.5, program=program, init_program=init_program)
fc1 = layers.fc(input=drop,
size=512,
act=None,
program=program,
init_program=init_program)
reshape1 = layers.reshape(
x=fc1,
shape=list(fc1.shape + (1, 1)),
program=program,
init_program=init_program)
bn = layers.batch_norm(
input=reshape1, act='relu', program=program, init_program=init_program)
drop2 = layers.dropout(
x=bn, dropout_prob=0.5, program=program, init_program=init_program)
fc2 = layers.fc(input=drop2,
size=512,
act=None,
program=program,
init_program=init_program)
return fc2
init_program = Program()
program = Program()
classdim = 10
data_shape = [3, 32, 32]
images = layers.data(
name='pixel', shape=data_shape, data_type='float32', program=program)
label = layers.data(
name='label',
shape=[1],
data_type='int64',
program=program,
init_program=init_program)
vgg_net = vgg16_bn_drop(images, program, init_program)
predict = layers.fc(input=vgg_net,
size=classdim,
act='softmax',
program=program,
init_program=init_program)
cost = layers.cross_entropy(
input=predict, label=label, program=program, init_program=init_program)
avg_cost = layers.mean(x=cost, program=program, init_program=init_program)
sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.001)
opts = sgd_optimizer.minimize(avg_cost)
BATCH_SIZE = 128
PASS_NUM = 1
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.cifar.train10(), buf_size=128 * 10),
batch_size=BATCH_SIZE)
place = core.CPUPlace()
exe = Executor(place)
exe.run(init_program, feed={}, fetch_list=[])
for pass_id in range(PASS_NUM):
batch_id = 0
for data in train_reader():
img_data = np.array(map(lambda x: x[0].reshape(data_shape),
data)).astype("float32")
y_data = np.array(map(lambda x: x[1], data)).astype("int64")
batch_size = 1
for i in y_data.shape:
batch_size = batch_size * i
y_data = y_data.reshape([batch_size, 1])
tensor_img = core.LoDTensor()
tensor_y = core.LoDTensor()
tensor_img.set(img_data, place)
tensor_y.set(y_data, place)
outs = exe.run(program,
feed={"pixel": tensor_img,
"label": tensor_y},
fetch_list=[avg_cost])
loss = np.array(outs[0])
# print("pass_id:" + str(pass_id) + " batch_id:" + str(batch_id) +
# " loss:" + str(loss))
batch_id = batch_id + 1
if batch_id > 1:
# this model is slow, so if we can train two mini batch, we think it works properly.
exit(0)
exit(1)
import unittest, os
import numpy as np
import paddle.v2 as paddle
from paddle.v2.framework.op import Operator
import paddle.v2.framework.core as core
from op_test import OpTest, create_op, set_input
if not core.is_compile_gpu():
exit(0)
gpu_count = core.get_cuda_device_count()
if gpu_count <= 1:
exit(0)
g_scope = core.Scope()
g_ctx = core.DeviceContext.create(core.CPUPlace())
class TestNCCLInit(unittest.TestCase):
def test_init(self):
self.op_type = "ncclInit"
self.gpus = range(gpu_count)
self.inputs = {}
self.attrs = {"gpus": self.gpus}
g_scope.var("Communicator").get_communicator()
self.outputs = {"Communicator": g_scope.find_var("Communicator")}
nccl_init = create_op(
g_scope,
op_type=self.op_type,
inputs=self.inputs,
outputs=self.outputs,
attrs=self.attrs)
nccl_init.run(g_scope, g_ctx)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册