Unverified commit 99acf1da, authored by chengduo, committed by GitHub

Merge pull request #10351 from chengduoZH/feature/update_sparse_parameter

Feature/update sparse parameter
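This change teaches ParallelExecutor to update sparse parameters: dense gradients keep using NCCL all-reduce, while gradients stored as SELECTED_ROWS are reduced onto one device and then broadcast back to the others (see the new CreateReduceOp/CreateBroadcastOp paths in multi_devices_graph_builder.cc below). The snippet that follows is a minimal standalone sketch of that dispatch decision only; the enum and function names are invented for illustration and are not Paddle's API.

// Standalone sketch of the gradient-synchronization choice made in
// MultiDevSSAGraphBuilder::Build(); all names here are illustrative.
#include <iostream>
#include <string>
#include <unordered_map>

enum class VarType { LOD_TENSOR, SELECTED_ROWS };
enum class SyncStrategy { NCCL_ALL_REDUCE, REDUCE_THEN_BROADCAST };

// A sparse gradient is stored as SELECTED_ROWS (only the touched rows), so it
// is cheaper to gather those rows on one device and broadcast the result than
// to all-reduce a dense buffer.
SyncStrategy ChooseStrategy(
    const std::unordered_map<std::string, VarType> &var_types,
    const std::string &grad_name) {
  auto it = var_types.find(grad_name);
  if (it != var_types.end() && it->second == VarType::SELECTED_ROWS) {
    return SyncStrategy::REDUCE_THEN_BROADCAST;
  }
  return SyncStrategy::NCCL_ALL_REDUCE;
}

int main() {
  std::unordered_map<std::string, VarType> var_types{
      {"emb@GRAD", VarType::SELECTED_ROWS},  // from a sparse embedding layer
      {"fc_w@GRAD", VarType::LOD_TENSOR}};   // ordinary dense gradient
  for (const auto &name : {"emb@GRAD", "fc_w@GRAD"}) {
    bool sparse =
        ChooseStrategy(var_types, name) == SyncStrategy::REDUCE_THEN_BROADCAST;
    std::cout << name << ": "
              << (sparse ? "reduce + broadcast" : "nccl all-reduce") << "\n";
  }
  return 0;
}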
@@ -15,12 +15,14 @@ if(WITH_GPU)
           dynload_cuda)
   set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle)
   nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda)
+  nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda)
 else()
   set(multi_devices_graph_builder_deps)
   cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim)
+  cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
 endif()
-cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
 cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
 cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
......
@@ -19,14 +19,12 @@
 namespace paddle {
 namespace framework {
 namespace details {
-BroadcastOpHandle::BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
-                                     const std::vector<platform::Place> &places)
-    : local_scopes_(local_scopes), places_(places) {}
 void BroadcastOpHandle::RunImpl() {
-  // the input and output may have dummy var.
-  VarHandle *in_var_handle;
+  if (places_.size() == 1) return;
+
+  // The input and output may have dummy vars.
+  VarHandle *in_var_handle;
   {
     auto in_var_handles = DynamicCast<VarHandle>(inputs_);
     PADDLE_ENFORCE_EQ(in_var_handles.size(), 1,
@@ -55,27 +53,97 @@ void BroadcastOpHandle::RunImpl() {
   Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);

-  for (auto *out : out_var_handles) {
-    if (*out == *in_var_handle) {
+  // NOTE: The tensors' Place of input and output must be all on GPU or all on
+  // CPU.
+  for (auto *out_var_handle : out_var_handles) {
+    if (out_var_handle->IsTheSameVar(*in_var_handle)) {
       continue;
     }

-    auto &out_p = out->place_;
-    auto *out_var = var_scopes.at(out->scope_idx_)->FindVar(out->name_);
+    auto t_out_p = out_var_handle->place_;
+    auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
+                        ->FindVar(out_var_handle->name_);
     PADDLE_ENFORCE_NOT_NULL(out_var);
-    PADDLE_ENFORCE_EQ(out_p.which(), in_var_handle->place_.which(),
-                      "Places must be all on CPU or all on CUDA.");
+    if (platform::is_gpu_place(in_tensor.place())) {
+      PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
+                     "Places of input and output must be all on GPU.");
+    } else {
+      t_out_p = platform::CPUPlace();
+    }

     VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
-    VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p,
+    VariableVisitor::GetMutableTensor(out_var).mutable_data(t_out_p,
                                                             in_tensor.type());
+  }

+  if (platform::is_cpu_place(in_tensor.place())) {
+    for (auto *out_var_handle : out_var_handles) {
+      if (out_var_handle->IsTheSameVar(*in_var_handle)) {
+        continue;
+      }
+
+      auto &out_p = out_var_handle->place_;
+      auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
+                          ->FindVar(out_var_handle->name_);
+
+      RunAndRecordEvent(out_p, [in_tensor, out_var] {
+        paddle::framework::TensorCopy(
+            in_tensor, platform::CPUPlace(),
+            &VariableVisitor::GetMutableTensor(out_var));
+      });
+    }
+  } else {
+#ifdef PADDLE_WITH_CUDA
+    VarHandle *out_handle = nullptr;
+    int root_id = boost::get<platform::CUDAPlace>(in_tensor.place()).device;
+    std::vector<std::function<void()>> broadcast_calls;
+
+    for (auto out_var_handle : out_var_handles) {
+      Variable *out_var = var_scopes.at(out_var_handle->scope_idx_)
+                              ->FindVar(out_var_handle->name_);
+      int dst_id =
+          boost::get<platform::CUDAPlace>(out_var_handle->place_).device;
+
+      auto &nccl_ctx = nccl_ctxs_->at(dst_id);
+
+      void *send_recv_buffer = nullptr;
+      if (root_id == dst_id) {
+        send_recv_buffer = const_cast<void *>(in_tensor.data<void>());
+        out_handle = out_var_handle;
+      } else {
+        send_recv_buffer =
+            VariableVisitor::GetMutableTensor(out_var).mutable_data(
+                out_var_handle->place_);
+      }
+
+      int type = platform::ToNCCLDataType(in_tensor.type());
+      size_t numel = static_cast<size_t>(in_tensor.numel());
+      broadcast_calls.emplace_back(
+          [send_recv_buffer, numel, type, root_id, &nccl_ctx] {
+            PADDLE_ENFORCE(platform::dynload::ncclBcast(
+                send_recv_buffer, numel, static_cast<ncclDataType_t>(type),
+                root_id, nccl_ctx.comm_, nccl_ctx.stream()));
+          });
+    }

-    auto dev_ctx = dev_ctxes_.at(out_p);
-    RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] {
-      paddle::framework::TensorCopy(
-          in_tensor, out_p, *(dev_ctx),
-          &VariableVisitor::GetMutableTensor(out_var));
+    this->RunAndRecordEvent([&] {
+      {
+        platform::NCCLGroupGuard guard;
+        for (auto &call : broadcast_calls) {
+          call();
+        }
+      }
+
+      if (!out_handle->IsTheSameVar(*in_var_handle)) {
+        auto out_var = var_scopes.at(in_var_handle->scope_idx_)
+                           ->FindVar(out_var_handles[0]->name_);
+        paddle::framework::TensorCopy(
+            in_tensor, in_var_handle->place_,
+            *(dev_ctxes_.at(in_var_handle->place_)),
+            &VariableVisitor::GetMutableTensor(out_var));
+      }
     });
+#else
+    PADDLE_THROW("CUDA is not enabled.");
+#endif
   }
 }
......
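On the GPU path above, BroadcastOpHandle queues one ncclBcast per destination device and runs them inside platform::NCCLGroupGuard so NCCL launches them as one grouped collective. Below is a minimal self-contained sketch of that raw NCCL pattern outside of Paddle; device discovery, buffer contents, and error handling are simplified assumptions for illustration.

// Standalone sketch of the grouped ncclBcast pattern used by BroadcastOpHandle.
// Requires the CUDA runtime and NCCL; error checking is omitted for brevity.
#include <cuda_runtime.h>
#include <nccl.h>
#include <vector>

int main() {
  int ndev = 0;
  cudaGetDeviceCount(&ndev);
  if (ndev < 2) return 0;  // nothing to broadcast across

  std::vector<ncclComm_t> comms(ndev);
  ncclCommInitAll(comms.data(), ndev, nullptr);  // one communicator per GPU

  const size_t numel = 1 << 20;
  const int root = 0;  // device that owns the up-to-date tensor
  std::vector<float *> bufs(ndev);
  std::vector<cudaStream_t> streams(ndev);
  for (int i = 0; i < ndev; ++i) {
    cudaSetDevice(i);
    cudaMalloc(reinterpret_cast<void **>(&bufs[i]), numel * sizeof(float));
    cudaStreamCreate(&streams[i]);
  }

  // Group the per-device broadcasts so NCCL treats them as a single
  // collective; this is what platform::NCCLGroupGuard does around
  // broadcast_calls in the handle above.
  ncclGroupStart();
  for (int i = 0; i < ndev; ++i) {
    ncclBcast(bufs[i], numel, ncclFloat, root, comms[i], streams[i]);
  }
  ncclGroupEnd();

  for (int i = 0; i < ndev; ++i) {
    cudaSetDevice(i);
    cudaStreamSynchronize(streams[i]);
    cudaFree(bufs[i]);
    ncclCommDestroy(comms[i]);
  }
  return 0;
}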
@@ -24,14 +24,32 @@
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/platform/device_context.h"

+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/platform/nccl_helper.h"
+#endif
+
 namespace paddle {
 namespace framework {
 namespace details {

 struct BroadcastOpHandle : public OpHandleBase {
  public:
+#ifdef PADDLE_WITH_CUDA
+  BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
+                    const std::vector<platform::Place> &places,
+                    const platform::NCCLContextMap *nccl_ctxs)
+      : local_scopes_(local_scopes), places_(places), nccl_ctxs_(nccl_ctxs) {
+    if (nccl_ctxs_) {
+      for (auto &p_ctx : nccl_ctxs_->contexts_) {
+        dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get();
+      }
+    }
+  }
+#else
   BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
-                    const std::vector<platform::Place> &places);
+                    const std::vector<platform::Place> &places)
+      : local_scopes_(local_scopes), places_(places) {}
+#endif

   std::string Name() const override;

@@ -44,6 +62,9 @@ struct BroadcastOpHandle : public OpHandleBase {
  private:
   const std::vector<Scope *> &local_scopes_;
   const std::vector<platform::Place> &places_;
+#ifdef PADDLE_WITH_CUDA
+  const platform::NCCLContextMap *nccl_ctxs_;
+#endif
 };

 } // namespace details
 } // namespace framework
......
@@ -35,15 +35,25 @@ struct TestBroadcastOpHandle {
   std::unique_ptr<OpHandleBase> op_handle_;
   std::vector<std::unique_ptr<VarHandleBase>> vars_;
   std::vector<p::Place> gpu_list_;
+  bool use_gpu_;
+#ifdef PADDLE_WITH_CUDA
+  std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
+#endif

   void WaitAll() {
     for (size_t j = 0; j < ctxs_.size(); ++j) {
       ctxs_[j]->Wait();
     }
+#ifdef PADDLE_WITH_CUDA
+    if (nccl_ctxs_) {
+      nccl_ctxs_->WaitAll();
+    }
+#endif
   }

   void InitCtxOnGpu(bool use_gpu) {
-    if (use_gpu) {
+    use_gpu_ = use_gpu;
+    if (use_gpu_) {
 #ifdef PADDLE_WITH_CUDA
       int count = p::GetCUDADeviceCount();
       if (count <= 1) {

@@ -57,6 +67,7 @@ struct TestBroadcastOpHandle {
         gpu_list_.push_back(p);
         ctxs_.emplace_back(new p::CUDADeviceContext(p));
       }
+      nccl_ctxs_.reset(new platform::NCCLContextMap(gpu_list_));
 #else
       PADDLE_THROW("CUDA is not support.");
 #endif

@@ -67,6 +78,9 @@ struct TestBroadcastOpHandle {
         gpu_list_.push_back(p);
         ctxs_.emplace_back(new p::CPUDeviceContext(p));
       }
+#ifdef PADDLE_WITH_CUDA
+      nccl_ctxs_.reset(nullptr);
+#endif
     }
   }

@@ -82,7 +96,21 @@ struct TestBroadcastOpHandle {
     }
     param_scopes_[input_scope_idx]->Var("input");

-    op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_));
+    if (use_gpu_) {
+#ifdef PADDLE_WITH_CUDA
+      op_handle_.reset(
+          new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
+#else
+      PADDLE_THROW("CUDA is not support.");
+#endif
+    } else {
+#ifdef PADDLE_WITH_CUDA
+      op_handle_.reset(
+          new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
+#else
+      op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_));
+#endif
+    }

     auto* in_var_handle =
         new VarHandle(1, input_scope_idx, "input", gpu_list_[input_scope_idx]);

@@ -97,7 +125,9 @@ struct TestBroadcastOpHandle {
     op_handle_->AddInput(dummy_var_handle);

     for (size_t j = 0; j < gpu_list_.size(); ++j) {
-      op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
+      if (!use_gpu_) {
+        op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
+      }
       VarHandle* out_var_handle = new VarHandle(2, j, "out", gpu_list_[j]);
       vars_.emplace_back(out_var_handle);
       op_handle_->AddOutput(out_var_handle);
......
@@ -25,6 +25,7 @@ GatherOpHandle::GatherOpHandle(const std::vector<Scope *> &local_scopes,
     : local_scopes_(local_scopes), places_(places) {}

 void GatherOpHandle::RunImpl() {
+  if (places_.size() == 1) return;
   // the input and output may have dummy var.
   auto in_var_handles = DynamicCast<VarHandle>(inputs_);

@@ -35,7 +36,6 @@ void GatherOpHandle::RunImpl() {
   VarHandle *out_var_handle;
   {
     auto out_var_handles = DynamicCast<VarHandle>(outputs_);
-
     PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
                       "The number of output should be one.");
     out_var_handle = out_var_handles.front();

@@ -50,68 +50,62 @@ void GatherOpHandle::RunImpl() {
   auto pre_in_var =
       var_scopes.at(in_0_handle->scope_idx_)->FindVar(in_0_handle->name_);
   PADDLE_ENFORCE_NOT_NULL(pre_in_var);
   PADDLE_ENFORCE(pre_in_var->IsType<framework::SelectedRows>(),
                  "Currently, gather_op only can gather SelectedRows.");

-  auto pre_place = in_0_handle->place_;
-  PADDLE_ENFORCE_EQ(out_var_handle->place_.which(), pre_place.which(),
-                    "The place of input and output should be the same.");
-
   // Wait input done, this Wait is asynchronous operation
   WaitInputVarGenerated(in_var_handles);

+  auto &pre_in_value = pre_in_var->Get<framework::SelectedRows>();
   std::vector<int64_t> out_rows;
   std::vector<Tensor> in_tensors;
-  std::vector<platform::Place> in_places;

-  auto &pre_in = pre_in_var->Get<framework::SelectedRows>();
-  // gather the inputs
+  // Gather the inputs
   for (auto *in_handle : in_var_handles) {
-    auto in_p = in_handle->place_;
-    in_places.push_back(in_p);
-    PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(),
-                      "Places must be all on CPU or all on CUDA.");
     auto *in_var =
         var_scopes.at(in_handle->scope_idx_)->FindVar(in_handle->name_);
-    auto &in_sr = in_var->Get<framework::SelectedRows>();
+    PADDLE_ENFORCE_NOT_NULL(in_var);
+    VariableVisitor::EnforceShapeAndDTypeEQ(*in_var, *pre_in_var);

-    PADDLE_ENFORCE_EQ(in_sr.value().type(), pre_in.value().type(),
-                      "The type of input is not consistent.");
-    PADDLE_ENFORCE_EQ(pre_in.height(), in_sr.height(),
-                      "The height of inputs is not consistent.");
-    PADDLE_ENFORCE_EQ(pre_in.GetCompleteDims(), in_sr.GetCompleteDims(),
-                      "The dims of inputs is not consistent.");
+    auto &in_sr_value = in_var->Get<framework::SelectedRows>();

-    auto &in_sr_rows = in_sr.rows();
+    auto &in_sr_rows = in_sr_value.rows();
     out_rows.insert(out_rows.end(), in_sr_rows.begin(), in_sr_rows.end());
-    in_tensors.emplace_back(in_sr.value());
+    in_tensors.emplace_back(in_sr_value.value());
   }

-  // write the output
-  auto &out_place = out_var_handle->place_;
-  auto out_scope_idx = out_var_handle->scope_idx_;
-  auto out_var = var_scopes.at(out_scope_idx)->FindVar(out_var_handle->name_);
+  // NOTE: The Places of all input tensor must be all on CPU or all on GPU.
+  platform::Place t_out_p = out_var_handle->place_;
+  if (platform::is_gpu_place(pre_in_value.place())) {
+    PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
+                   "Places of input and output must be all on GPU.");
+  } else {
+    t_out_p = platform::CPUPlace();
+  }

-  auto out = out_var->GetMutable<framework::SelectedRows>();
-  out->set_height(pre_in.height());
-  out->set_rows(out_rows);
+  auto out_var =
+      var_scopes.at(out_var_handle->scope_idx_)->FindVar(out_var_handle->name_);
+  PADDLE_ENFORCE_NOT_NULL(out_var);
+  auto out_value = out_var->GetMutable<framework::SelectedRows>();
+  out_value->set_height(pre_in_value.height());
+  out_value->set_rows(out_rows);
   size_t rows = out_rows.size();
-  DDim out_dim = pre_in.GetCompleteDims();
+  DDim out_dim = pre_in_value.GetCompleteDims();
   out_dim[0] = static_cast<int64_t>(rows);
-  out->mutable_value()->Resize(out_dim);
-  out->mutable_value()->mutable_data(out_place, pre_in.value().type());
-  Tensor *out_tensor = out->mutable_value();
+  out_value->mutable_value()->Resize(out_dim).mutable_data(
+      t_out_p, pre_in_value.value().type());
+  Tensor *out_tensor = out_value->mutable_value();

   // copy
-  auto dev_ctx = dev_ctxes_[out_place];
-  RunAndRecordEvent(out_place, [in_tensors, out_tensor, dev_ctx, out_place] {
+  auto dev_ctx = dev_ctxes_[out_var_handle->place_];
+  RunAndRecordEvent(out_var_handle->place_, [in_tensors, out_tensor, &dev_ctx,
+                                             t_out_p] {
     int s = 0, e = 0;
     for (size_t j = 0; j < in_tensors.size(); ++j) {
       e += in_tensors[j].dims()[0];
       auto sub_out = out_tensor->Slice(s, e);
-      paddle::framework::TensorCopy(in_tensors[j], out_place, *(dev_ctx),
-                                    &sub_out);
+      paddle::framework::TensorCopy(in_tensors[j], t_out_p, *dev_ctx, &sub_out);
       s = e;
     }
   });
......
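For reference, the gather path above concatenates the row indices of every device's SelectedRows and appends the corresponding value blocks in the same order, so duplicate row indices simply appear more than once in the output. A tiny standalone illustration of that concatenation, with a plain struct standing in for framework::SelectedRows (types and names are made up for the example):

// Standalone illustration of gathering SelectedRows-style inputs: row indices
// are concatenated and each input's value block is appended in the same order.
#include <cstdint>
#include <iostream>
#include <vector>

struct FakeSelectedRows {      // stand-in for framework::SelectedRows
  std::vector<int64_t> rows;   // which rows of the full parameter are present
  std::vector<float> values;   // one value per row (width 1 for simplicity)
};

FakeSelectedRows Gather(const std::vector<FakeSelectedRows> &inputs) {
  FakeSelectedRows out;
  for (const auto &in : inputs) {
    out.rows.insert(out.rows.end(), in.rows.begin(), in.rows.end());
    out.values.insert(out.values.end(), in.values.begin(), in.values.end());
  }
  return out;  // duplicate row ids simply appear more than once in the output
}

int main() {
  std::vector<FakeSelectedRows> per_device = {
      {{0, 3}, {0.1f, 0.3f}},   // rows touched on device 0
      {{3, 7}, {0.4f, 0.7f}}};  // rows touched on device 1
  FakeSelectedRows merged = Gather(per_device);
  for (size_t i = 0; i < merged.rows.size(); ++i) {
    std::cout << "row " << merged.rows[i] << " -> " << merged.values[i] << "\n";
  }
  return 0;
}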
@@ -11,9 +11,11 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
+#include <utility>
+#include "paddle/fluid/framework/details/broadcast_op_handle.h"
 #include "paddle/fluid/framework/details/computation_op_handle.h"
+#include "paddle/fluid/framework/details/reduce_op_handle.h"
 #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
 #include "paddle/fluid/framework/details/send_op_handle.h"
 #include "paddle/fluid/framework/scope.h"

@@ -34,8 +36,8 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
     const std::vector<platform::Place> &places,
     const std::string &loss_var_name,
     const std::unordered_set<std::string> &params,
-    const std::vector<Scope *> &local_scopes, bool use_default_grad_scale,
-    platform::NCCLContextMap *nccl_ctxs)
+    const std::vector<Scope *> &local_scopes,
+    platform::NCCLContextMap *nccl_ctxs, bool use_default_grad_scale)
     : loss_var_name_(loss_var_name),
       places_(places),
       local_scopes_(local_scopes),

@@ -105,6 +107,11 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,

 std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     const ProgramDesc &program) const {
+  std::unordered_map<std::string, proto::VarType::Type> var_types;
+  for (auto *var : program.Block(0).AllVars()) {
+    var_types[var->Name()] = var->GetType();
+  }
+
   auto graph = new SSAGraph();
   SSAGraph &result = *graph;
   std::unordered_set<std::string> og_has_been_broadcast;

@@ -133,12 +140,17 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
         is_forwarding = false;
       } else {
         CreateComputationalOps(&result, *op, places_.size());
-        if (!is_forwarding) {
+        if (!is_forwarding && places_.size() > 1) {
           // Currently, we assume that once gradient is generated, it can be
           // broadcast, and each gradient is only broadcast once.
           for (auto &og : op->OutputArgumentNames()) {
             if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
-              InsertNCCLAllReduceOp(&result, og);
+              if (IsSparseGradient(var_types, og)) {
+                CreateReduceOp(&result, og, 0);
+                CreateBroadcastOp(&result, og, 0);
+              } else {
+                InsertNCCLAllReduceOp(&result, og);
+              }
             }
           }
         }

@@ -165,6 +177,50 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   return std::unique_ptr<SSAGraph>(graph);
 }

+bool MultiDevSSAGraphBuilder::IsSparseGradient(
+    const std::unordered_map<std::string, proto::VarType::Type> &var_types,
+    const std::string &og) const {
+  PADDLE_ENFORCE(var_types.count(og) != 0);
+  if (var_types.at(og) == proto::VarType::SELECTED_ROWS) {
+    return true;
+  }
+  return false;
+}
+
+void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
+                                                const std::string &p_name,
+                                                size_t src_dev_id) const {
+#ifdef PADDLE_WITH_CUDA
+  auto *op_handle = new BroadcastOpHandle(local_scopes_, places_, nccl_ctxs_);
+#else
+  auto *op_handle = new BroadcastOpHandle(local_scopes_, places_);
+#endif
+
+  result->ops_.emplace_back(op_handle);
+  auto *in = result->vars_.at(src_dev_id).at(p_name).back().get();
+  op_handle->AddInput(in);
+
+  for (size_t i = 0; i < places_.size(); ++i) {
+    auto &vars = result->vars_.at(i).at(p_name);
+    auto &p = places_[i];
+    auto *out_var = new VarHandle(vars.size(), i, p_name, p);
+    vars.emplace_back(out_var);
+    op_handle->AddOutput(out_var);
+#ifndef PADDLE_WITH_CUDA
+    op_handle->SetDeviceContext(p,
+                                platform::DeviceContextPool::Instance().Get(p));
+#endif
+  }
+}
+
+void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
+                                                    const OpDesc &op,
+                                                    int dev_id) const {
+  result->ops_.emplace_back(
+      new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id]));
+  CreateOpHandleIOs(result, op, dev_id);
+}
+
 OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc(
     const ProgramDesc &program) const {
   for (auto *op : program.Block(0).AllOps()) {

@@ -174,7 +230,6 @@ OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc(
   }
   return nullptr;
 }
-
 void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
     SSAGraph *result, const std::string &og) const {
 #ifdef PADDLE_WITH_CUDA

@@ -247,6 +302,36 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
   }
 }

+VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
+                                                   const std::string &og,
+                                                   int dst_dev_id) const {
+#ifdef PADDLE_WITH_CUDA
+  result->ops_.emplace_back(
+      new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
+#else
+  result->ops_.emplace_back(new ReduceOpHandle(local_scopes_, places_));
+#endif
+  auto *op_handle = result->ops_.back().get();
+
+  for (size_t i = 0; i < places_.size(); ++i) {
+    auto &vars = result->vars_[i][og];
+#ifndef PADDLE_WITH_CUDA
+    auto &p = places_[i];
+    op_handle->SetDeviceContext(p,
+                                platform::DeviceContextPool::Instance().Get(p));
+#endif
+    PADDLE_ENFORCE(!vars.empty());
+    auto &prev_grad = vars.back();
+    op_handle->AddInput(prev_grad.get());
+  }
+  auto &vars = result->vars_[dst_dev_id][og];
+  auto var =
+      new VarHandle(vars.size() - 1, dst_dev_id, og, places_[dst_dev_id]);
+  vars.emplace_back(var);
+  op_handle->AddOutput(var);
+  return var;
+}
+
 void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
                                            const OpDesc &op) const {
   auto &p = places_[0];

@@ -263,6 +348,7 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
   return op.OutputArgumentNames().size() == 1 &&
          op.OutputArgumentNames()[0] == GradVarName(loss_var_name_);
 }
+
 } // namespace details
 } // namespace framework
 } // namespace paddle
@@ -13,8 +13,8 @@
 // limitations under the License.

 #pragma once

 #include <string>
+#include <utility>
 #include <vector>

 #include "paddle/fluid/framework/details/ssa_graph_builder.h"

@@ -27,6 +27,7 @@ class NCCLContextMap;
 namespace framework {
 class Scope;
 namespace details {
+
 class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
  public:
 #ifdef PADDLE_WITH_CUDA

@@ -34,8 +35,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
                           const std::string &loss_var_name,
                           const std::unordered_set<std::string> &params,
                           const std::vector<Scope *> &local_scopes,
-                          bool skip_scale_loss,
-                          platform::NCCLContextMap *nccl_ctxs);
+                          platform::NCCLContextMap *nccl_ctxs,
+                          bool use_default_grad_scale);
 #else
   MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
                           const std::string &loss_var_name,

@@ -74,6 +75,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
                               size_t num_places) const;

   void CreateScaleLossGradOp(SSAGraph *result) const;
+  VarHandle *CreateReduceOp(SSAGraph *result, const std::string &og,
+                            int dst_dev_id) const;
+  void CreateComputationalOp(SSAGraph *result, const OpDesc &op,
+                             int dev_id) const;

   bool IsParameterGradientOnce(
       const std::string &og,

@@ -81,11 +86,18 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {

   void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const;

+  void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
+                         size_t src_dev_id) const;
+
   /**
    * Get send op in the global block of program.
    * nullptr if not found.
    */
   OpDesc *GetSendOpDesc(const ProgramDesc &program) const;
+
+  bool IsSparseGradient(
+      const std::unordered_map<std::string, proto::VarType::Type> &var_types,
+      const std::string &og) const;
 };

 } // namespace details
 } // namespace framework
......
@@ -22,6 +22,7 @@ namespace framework {
 namespace details {

 void ReduceOpHandle::RunImpl() {
+  if (places_.size() == 1) return;
   // the input and output may have dummy var.
   auto in_var_handles = DynamicCast<VarHandle>(inputs_);

@@ -51,44 +52,48 @@ void ReduceOpHandle::RunImpl() {
   // Wait input done, this Wait is asynchronous operation
   WaitInputVarGenerated(in_var_handles);

-  auto pre_place = in_0_handle->place_;
-  std::vector<platform::Place> in_places;
-  auto pre_in_tensor = VariableVisitor::GetMutableTensor(pre_in_var);
-  for (auto *in_handle : in_var_handles) {
-    auto in_p = in_handle->place_;
-    PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(),
-                      "Places must be all on CPU or all on CUDA.");
-    in_places.emplace_back(in_p);
-
+  // NOTE: The Places of all input tensor must be all on CPU or all on GPU.
+  std::vector<platform::Place> in_places;  // used to get dev_ctx
+  for (auto *in_handle : in_var_handles) {
+    in_places.emplace_back(in_handle->place_);
     auto in_var =
         var_scopes.at(in_handle->scope_idx_)->FindVar(in_handle->name_);
     PADDLE_ENFORCE_NOT_NULL(in_var);
-
-    auto in_tensor = VariableVisitor::GetMutableTensor(in_var);
-    PADDLE_ENFORCE_EQ(in_tensor.type(), pre_in_tensor.type(),
-                      "The type of input is not consistent.");
+    VariableVisitor::EnforceShapeAndDTypeEQ(*pre_in_var, *in_var);
   }

   auto out_var =
       var_scopes.at(out_var_handle->scope_idx_)->FindVar(out_var_handle->name_);
   PADDLE_ENFORCE_NOT_NULL(out_var);

+  // NOTE: The tensors' Place of input and output must be all on GPU or all on
+  // CPU.
+  auto in_p = VariableVisitor::GetMutableTensor(pre_in_var).place();
+  platform::Place t_out_p;
+  if (platform::is_gpu_place(in_p)) {
+    PADDLE_ENFORCE(platform::is_gpu_place(out_var_handle->place_),
+                   "Places of input and output must be all on GPU.");
+    t_out_p = out_var_handle->place_;
+  } else {
+    t_out_p = platform::CPUPlace();
+  }
+
   if (pre_in_var->IsType<framework::SelectedRows>()) {
     std::vector<const SelectedRows *> in_selected_rows =
         GetInputValues<SelectedRows>(in_var_handles, var_scopes);
-    GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_,
-                       out_var_handle->place_,
+    GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, t_out_p,
                        out_var->GetMutable<framework::SelectedRows>());
   } else {
     std::vector<const LoDTensor *> lod_tensors =
         GetInputValues<LoDTensor>(in_var_handles, var_scopes);
-    if (paddle::platform::is_cpu_place(pre_place)) {
+    if (paddle::platform::is_cpu_place(lod_tensors[0]->place())) {
       ReduceLoDTensor func(lod_tensors,
                            out_var->GetMutable<framework::LoDTensor>());
       VisitDataType(ToDataType(lod_tensors[0]->type()), func);
-    } else if (paddle::platform::is_gpu_place(pre_place)) {
+    } else if (paddle::platform::is_gpu_place(lod_tensors[0]->place())) {
 #ifdef PADDLE_WITH_CUDA
       auto pre_in = pre_in_var->Get<framework::LoDTensor>();
       VariableVisitor::ShareDimsAndLoD(*pre_in_var, out_var);

@@ -96,7 +101,7 @@ void ReduceOpHandle::RunImpl() {
           out_var_handle->place_, pre_in.type());

       auto out_p = out_var_handle->place_;
-      int root = boost::get<platform::CUDAPlace>(out_p).device;
+      int root_id = boost::get<platform::CUDAPlace>(out_p).device;
       std::vector<std::function<void()>> all_reduce_calls;
       for (size_t i = 0; i < var_scopes.size(); ++i) {
         auto &p = in_places[i];

@@ -104,23 +109,23 @@ void ReduceOpHandle::RunImpl() {
         int dev_id = boost::get<platform::CUDAPlace>(p).device;
         auto &nccl_ctx = nccl_ctxs_->at(dev_id);
-        auto stream = nccl_ctx.stream();
-        auto comm = nccl_ctx.comm_;

         void *buffer = const_cast<void *>(lod_tensor.data<void>());
         void *recvbuffer = nullptr;
-        if (root == dev_id) {
+        if (root_id == dev_id) {
           recvbuffer =
               out_var->GetMutable<framework::LoDTensor>()->mutable_data(
                   out_var_handle->place_);
         }

         int type = platform::ToNCCLDataType(lod_tensor.type());
-        all_reduce_calls.emplace_back([=] {
-          PADDLE_ENFORCE(platform::dynload::ncclReduce(
-              buffer, recvbuffer, static_cast<size_t>(lod_tensor.numel()),
-              static_cast<ncclDataType_t>(type), ncclSum, root, comm, stream));
-        });
+        size_t numel = static_cast<size_t>(lod_tensor.numel());
+        all_reduce_calls.emplace_back(
+            [buffer, recvbuffer, type, numel, root_id, &nccl_ctx] {
+              PADDLE_ENFORCE(platform::dynload::ncclReduce(
+                  buffer, recvbuffer, numel, static_cast<ncclDataType_t>(type),
+                  ncclSum, root_id, nccl_ctx.comm_, nccl_ctx.stream()));
+            });
       }

       this->RunAndRecordEvent([&] {

@@ -130,7 +135,7 @@ void ReduceOpHandle::RunImpl() {
         }
       });
 #else
-      PADDLE_THROW("CUDA is not support.");
+      PADDLE_THROW("CUDA is not enabled.");
 #endif
     } else {
       PADDLE_THROW("Place should be CPUPlace or CUDAPlace.");
......
@@ -55,7 +55,7 @@ struct ReduceOpHandle : public OpHandleBase {
   std::string Name() const override;

-  bool IsMultiDeviceTransfer() override { return false; };
+  bool IsMultiDeviceTransfer() override { return true; };

  protected:
   void RunImpl() override;
......
@@ -62,7 +62,7 @@ struct VarHandle : public VarHandleBase {
   std::string name_;
   platform::Place place_;

-  bool operator==(const VarHandle& o) const {
+  bool IsTheSameVar(const VarHandle& o) const {
     return o.generated_op_ == generated_op_ && o.name_ == name_ &&
            o.scope_idx_ == scope_idx_;
   }
......
@@ -88,6 +88,52 @@ void VariableVisitor::ShareDimsAndLoD(const Variable& src, Variable* trg) {
   VisitVariable(src, &visitor);
 }

+struct EnforceShapeAndDTypeEQVisitor {
+  const Variable* trg_;
+
+  void operator()(const LoDTensor& src) {
+    auto& tensor = trg_->Get<LoDTensor>();
+    PADDLE_ENFORCE_EQ(
+        src.place().which(), tensor.place().which(),
+        "The Places of the two Variable must be all on CPU or all on GPU.");
+    PADDLE_ENFORCE_EQ(src.type(), tensor.type(),
+                      "The dtype of the two Variable is not equal.");
+    PADDLE_ENFORCE_EQ(src.dims(), tensor.dims(),
+                      "The dims of the two Variable is not equal.");
+    PADDLE_ENFORCE_EQ(src.lod(), tensor.lod(),
+                      "The lod of the two Variable is not equal.");
+    PADDLE_ENFORCE_EQ(src.layout(), tensor.layout(),
+                      "The layout of the two Variable's tensor is not equal.");
+  }
+
+  void operator()(const SelectedRows& src) {
+    auto& selected_rows = trg_->Get<SelectedRows>();
+    PADDLE_ENFORCE_EQ(
+        src.place().which(), selected_rows.place().which(),
+        "The Places of the two Variable must be all on CPU or all on GPU.");
+    PADDLE_ENFORCE_EQ(src.value().type(), selected_rows.value().type(),
+                      "The dtype of the two Variable is not equal.");
+    PADDLE_ENFORCE_EQ(src.value().layout(), selected_rows.value().layout(),
+                      "The layout of the two Variable's tensor is not equal.");
+    PADDLE_ENFORCE_EQ(src.height(), selected_rows.height(),
+                      "The height of the two Variable is not equal.");
+    PADDLE_ENFORCE_EQ(src.GetCompleteDims(), selected_rows.GetCompleteDims(),
+                      "The dims of the two Variable is not equal.");
+  }
+
+  template <typename T>
+  void operator()(const T&) {
+    PADDLE_ENFORCE("EnforceShapeAndDTypeEQ is not supported by type %s",
+                   typeid(T).name());
+  }
+};
+
+void VariableVisitor::EnforceShapeAndDTypeEQ(const Variable& var1,
+                                             const Variable& var2) {
+  EnforceShapeAndDTypeEQVisitor visitor{&var1};
+  VisitVariable(var2, &visitor);
+}
+
 } // namespace details
 } // namespace framework
 } // namespace paddle
@@ -26,6 +26,9 @@ class VariableVisitor {
   static Tensor &GetMutableTensor(Variable *var);

   static void ShareDimsAndLoD(const Variable &src, Variable *trg);
+
+  static void EnforceShapeAndDTypeEQ(const Variable &var1,
+                                     const Variable &var2);
 };

 } // namespace details
......
@@ -93,7 +93,7 @@ ParallelExecutor::ParallelExecutor(
 #ifdef PADDLE_WITH_CUDA
   details::MultiDevSSAGraphBuilder builder(
       member_->places_, loss_var_name, params, member_->local_scopes_,
-      use_default_grad_scale, member_->nccl_ctxs_.get());
+      member_->nccl_ctxs_.get(), use_default_grad_scale);
 #else
   details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name,
                                            params, member_->local_scopes_,
......
@@ -43,7 +43,7 @@ class ParallelExecutor(object):
             training.
         allow_op_delay(bool, default False): Whether to delay and buffer
             some operators together for scheduling or not, which may
-            improve performance in some cases, defalut False.
+            improve performance in some cases, default False.
         share_vars_from(ParallelExecutor, default None): If provied,
             it will share variables from the specified ParallelExecutor.
         use_default_grad_scale(bool, default True): If set True, a default

@@ -93,7 +93,7 @@ class ParallelExecutor(object):
         if use_cuda:
             # Experiments on se-resnext shows that too many threads hurt
             # performance. Worth tunning for other models in the future.
-            num_threads = len(self._places)
+            num_threads = len(self._places) * 2
         else:
             num_threads = min(
                 len(self._places) * 2, multiprocessing.cpu_count())

@@ -130,6 +130,7 @@ class ParallelExecutor(object):
             local_scopes,
             allow_op_delay,
             use_default_grad_scale)
+
         self.scope = scope

     def run(self, fetch_list, feed=None, feed_dict=None):
......
@@ -280,7 +280,7 @@ class TestMNIST(TestParallelExecutorBase):
         fluid.recordio_writer.convert_reader_to_recordio_file(
             './mnist.recordio', reader, feeder)

-    def test_simple_fc(self):
+    def check_simple_fc_convergence(self):
         self.check_network_convergence(simple_fc_net)
         self.check_network_convergence(simple_fc_net, allow_op_delay=True)

@@ -290,7 +290,10 @@ class TestMNIST(TestParallelExecutorBase):
             simple_fc_net, feed_dict={"image": img,
                                       "label": label})

-    def test_simple_fc_parallel_accuracy(self):
+    def test_simple_fc(self):
+        self.check_simple_fc_convergence()
+
+    def check_simple_fc_parallel_accuracy(self):
         img = numpy.zeros(shape=[32, 784], dtype='float32')
         label = numpy.ones(shape=[32, 1], dtype='int64')
         single_first_loss, single_last_loss = self.check_network_convergence(

@@ -311,7 +314,10 @@ class TestMNIST(TestParallelExecutorBase):
         for p_l in parallel_last_loss:
             self.assertAlmostEquals(p_l, single_last_loss[0], delta=1e-6)

-    def test_batchnorm_fc(self):
+    def test_simple_fc_parallel_accuracy(self):
+        self.check_simple_fc_parallel_accuracy()
+
+    def check_batchnorm_fc_convergence(self):
         self.check_network_convergence(fc_with_batchnorm)
         img = numpy.zeros(shape=[32, 784], dtype='float32')
         label = numpy.ones(shape=[32, 1], dtype='int64')

@@ -319,6 +325,9 @@ class TestMNIST(TestParallelExecutorBase):
             fc_with_batchnorm, feed_dict={"image": img,
                                           "label": label})

+    def test_batchnorm_fc(self):
+        self.check_batchnorm_fc_convergence()
+

 class TestResnet(TestParallelExecutorBase):
     # @classmethod

@@ -339,7 +348,7 @@ class TestResnet(TestParallelExecutorBase):
     #     fluid.recordio_writer.convert_reader_to_recordio_file(
     #         "./flowers.recordio", reader, feeder, compressor=fluid.core.RecordIOWriter.Compressor.NoCompress)

-    def test_resnet(self):
+    def check_resnet_convergence(self):
         import functools
         batch_size = 2
         self.check_network_convergence(

@@ -348,6 +357,9 @@ class TestResnet(TestParallelExecutorBase):
             iter=20,
             batch_size=batch_size)

+    def test_resnet(self):
+        self.check_resnet_convergence()
+

 class ModelHyperParams(object):
     # Dictionary size for source and target language. This model directly uses

@@ -510,7 +522,7 @@ class TestTransformer(TestParallelExecutorBase):

 class ParallelExecutorTestingDuringTraining(unittest.TestCase):
-    def test_parallel_testing(self):
+    def check_network_convergence(self):
         main = fluid.Program()
         startup = fluid.Program()
         with fluid.program_guard(main, startup):

@@ -550,6 +562,9 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
                 "Train loss: " + str(train_loss) + "\n Test loss:" +
                 str(test_loss))

+    def test_parallel(self):
+        self.check_network_convergence()
+

 import paddle.dataset.conll05 as conll05
 import paddle.fluid as fluid

@@ -568,21 +583,26 @@ embedding_name = 'emb'

 def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark,
-            **ignored):
+            is_sparse, **ignored):
     # 8 features
     predicate_embedding = fluid.layers.embedding(
         input=predicate,
+        is_sparse=is_sparse,
         size=[pred_dict_len, word_dim],
         dtype='float32',
         param_attr='vemb')

     mark_embedding = fluid.layers.embedding(
-        input=mark, size=[mark_dict_len, mark_dim], dtype='float32')
+        input=mark,
+        is_sparse=is_sparse,
+        size=[mark_dict_len, mark_dim],
+        dtype='float32')

     word_input = [word, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2]
     emb_layers = [
         fluid.layers.embedding(
             size=[word_dict_len, word_dim],
+            is_sparse=is_sparse,
             input=x,
             param_attr=fluid.ParamAttr(
                 name=embedding_name, trainable=False)) for x in word_input

@@ -632,7 +652,7 @@ def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark,

 class TestCRFModel(unittest.TestCase):
-    def test_all(self):
+    def check_network_convergence(self, is_sparse):
         main = fluid.Program()
         startup = fluid.Program()
         with fluid.program_guard(main, startup):

@@ -652,6 +672,7 @@ class TestCRFModel(unittest.TestCase):
                 name='ctx_p2_data', shape=[1], dtype='int64', lod_level=1)
             mark = fluid.layers.data(
                 name='mark_data', shape=[1], dtype='int64', lod_level=1)
+
             feature_out = db_lstm(**locals())
             target = fluid.layers.data(
                 name='target', shape=[1], dtype='int64', lod_level=1)

@@ -694,3 +715,9 @@ class TestCRFModel(unittest.TestCase):
                 print map(numpy.array,
                           pe.run(feed=feeder.feed(cur_batch),
                                  fetch_list=[avg_cost.name]))[0]
+
+    def test_update_sparse_parameter(self):
+        self.check_network_convergence(is_sparse=True)
+
+    def test_update_dense_parameter(self):
+        self.check_network_convergence(is_sparse=False)