提交 005fc8f4 编写于 作者: W wangguibao

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into async_executor

cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto node) cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto node)
cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor) cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor)
cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base)
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
...@@ -30,7 +31,9 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_ ...@@ -30,7 +31,9 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope) cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
if(WITH_GPU) cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
if (WITH_GPU)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass) all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
endif() endif()
...@@ -40,12 +43,13 @@ cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS grap ...@@ -40,12 +43,13 @@ cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS grap
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle) scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)
if(WITH_GPU) set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass sequential_execution_pass) if (WITH_GPU)
else() list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto sequential_execution_pass)
endif() endif()
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context) simple_threadpool device_context)
......
...@@ -69,6 +69,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -69,6 +69,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Verify that the graph is correct for multi-device executor. // Verify that the graph is correct for multi-device executor.
AppendPass("multi_devices_check_pass"); AppendPass("multi_devices_check_pass");
if (strategy_.remove_unnecessary_lock_) {
AppendPass("modify_op_lock_and_record_event_pass");
}
} }
private: private:
...@@ -136,3 +140,4 @@ USE_PASS(multi_devices_pass); ...@@ -136,3 +140,4 @@ USE_PASS(multi_devices_pass);
USE_PASS(multi_devices_check_pass); USE_PASS(multi_devices_check_pass);
USE_PASS(multi_devices_print_pass); USE_PASS(multi_devices_print_pass);
USE_PASS(sequential_execution_pass); USE_PASS(sequential_execution_pass);
USE_PASS(modify_op_lock_and_record_event_pass);
...@@ -73,6 +73,8 @@ struct BuildStrategy { ...@@ -73,6 +73,8 @@ struct BuildStrategy {
bool fuse_broadcast_op_{false}; bool fuse_broadcast_op_{false};
bool remove_unnecessary_lock_{false};
// User normally doesn't need to call this API. // User normally doesn't need to call this API.
// The PassBuilder allows for more customized insert, remove of passes // The PassBuilder allows for more customized insert, remove of passes
// from python side. // from python side.
......
...@@ -29,9 +29,15 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope, ...@@ -29,9 +29,15 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
void ComputationOpHandle::RunImpl() { void ComputationOpHandle::RunImpl() {
WaitInputVarGenerated(place_); WaitInputVarGenerated(place_);
this->RunAndRecordEvent([this] { auto run_func = [this]() {
op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_); op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
}); };
if (is_lock_and_record_event_free_) {
run_func();
} else {
this->RunAndRecordEvent(run_func);
}
} }
bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) { bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) {
......
...@@ -36,6 +36,8 @@ struct ComputationOpHandle : public OpHandleBase { ...@@ -36,6 +36,8 @@ struct ComputationOpHandle : public OpHandleBase {
const platform::Place &GetPlace() const { return place_; } const platform::Place &GetPlace() const { return place_; }
void SetLockAndRecordEventFree(bool b) { is_lock_and_record_event_free_ = b; }
protected: protected:
void RunImpl() override; void RunImpl() override;
...@@ -45,6 +47,7 @@ struct ComputationOpHandle : public OpHandleBase { ...@@ -45,6 +47,7 @@ struct ComputationOpHandle : public OpHandleBase {
std::unique_ptr<OperatorBase> op_; std::unique_ptr<OperatorBase> op_;
Scope *scope_; Scope *scope_;
platform::Place place_; platform::Place place_;
bool is_lock_and_record_event_free_{false};
}; };
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h"
namespace paddle {
namespace framework {
namespace details {
static bool IsLockAndRecordEventFreeComputationOpHandle(
ComputationOpHandle *op, const OpGraphView &graph_view) {
if (!platform::is_gpu_place(op->GetPlace())) return false;
for (auto &pending_op : graph_view.PendingOps(op)) {
auto *tmp = dynamic_cast<ComputationOpHandle *>(pending_op);
if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) {
return false;
}
}
return true;
}
std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
std::unique_ptr<ir::Graph> ir_graph) const {
auto &all_ops = ir_graph->Get<GraphOps>(kGraphOps);
OpGraphView graph_view(all_ops);
for (auto &op : all_ops) {
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get());
if (compute_op == nullptr) continue;
bool is_lock_and_record_event_free =
IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view);
compute_op->SetLockAndRecordEventFree(is_lock_and_record_event_free);
if (is_lock_and_record_event_free) {
VLOG(10) << "Set is_lock_and_record_event_free be true in op "
<< compute_op->DebugString();
}
}
return ir_graph;
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(modify_op_lock_and_record_event_pass,
paddle::framework::details::ModifyOpLockAndRecordEventPass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
class ModifyOpLockAndRecordEventPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/op_graph_view.h"
#include <queue>
#include <utility>
namespace paddle {
namespace framework {
namespace details {
OpGraphView::OpGraphView(
const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
Build(ops);
}
void OpGraphView::Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
for (auto &op : ops) {
preceding_ops_[op.get()];
pending_ops_[op.get()];
for (auto &var : op->Outputs()) {
for (auto &pending_op : var->PendingOps()) {
preceding_ops_[pending_op].insert(op.get());
pending_ops_[op.get()].insert(pending_op);
}
}
}
PADDLE_ENFORCE(
preceding_ops_.size() == ops.size() && pending_ops_.size() == ops.size(),
"There are duplicate ops in graph.");
}
size_t OpGraphView::OpNumber() const { return preceding_ops_.size(); }
std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
std::unordered_set<OpHandleBase *> ret;
for (auto &pair : preceding_ops_) {
ret.insert(pair.first);
}
return ret;
}
bool OpGraphView::HasOp(OpHandleBase *op) const {
return preceding_ops_.count(op) != 0;
}
void OpGraphView::EnforceHasOp(OpHandleBase *op) const {
PADDLE_ENFORCE(HasOp(op), "Cannot find op %s in OpGraphView",
op == nullptr ? "nullptr" : op->DebugString());
}
const std::unordered_set<OpHandleBase *> &OpGraphView::PrecedingOps(
OpHandleBase *op) const {
EnforceHasOp(op);
return preceding_ops_.at(op);
}
const std::unordered_set<OpHandleBase *> &OpGraphView::PendingOps(
OpHandleBase *op) const {
EnforceHasOp(op);
return pending_ops_.at(op);
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
namespace paddle {
namespace framework {
namespace details {
class OpGraphView {
public:
explicit OpGraphView(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
size_t OpNumber() const;
std::unordered_set<OpHandleBase *> AllOps() const;
const std::unordered_set<OpHandleBase *> &PrecedingOps(
OpHandleBase *op) const;
const std::unordered_set<OpHandleBase *> &PendingOps(OpHandleBase *op) const;
bool HasOp(OpHandleBase *op) const;
private:
void Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
void EnforceHasOp(OpHandleBase *op) const;
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
preceding_ops_;
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
pending_ops_;
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -51,7 +51,7 @@ class ReferenceCountOpHandle : public OpHandleBase { ...@@ -51,7 +51,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
dev_ctx_ = static_cast<platform::CUDADeviceContext *>( dev_ctx_ = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place)); platform::DeviceContextPool::Instance().Get(place));
if (IsStreamGarabageCollector()) { if (IsStreamGarabageCollector()) {
PADDLE_ENFORCE(cudaSetDevice(place.device)); platform::SetDeviceId(place.device);
PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
} }
...@@ -61,7 +61,7 @@ class ReferenceCountOpHandle : public OpHandleBase { ...@@ -61,7 +61,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
~ReferenceCountOpHandle() { ~ReferenceCountOpHandle() {
if (IsStreamGarabageCollector()) { if (IsStreamGarabageCollector()) {
auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace()); auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace());
PADDLE_ENFORCE(cudaSetDevice(gpu_place.device)); platform::SetDeviceId(gpu_place.device);
PADDLE_ENFORCE(cudaEventDestroy(event_)); PADDLE_ENFORCE(cudaEventDestroy(event_));
} }
} }
......
...@@ -43,6 +43,23 @@ static ComputationOpHandle *FindNextComputationOpHandle(VarHandle *var_in) { ...@@ -43,6 +43,23 @@ static ComputationOpHandle *FindNextComputationOpHandle(VarHandle *var_in) {
return nullptr; return nullptr;
} }
static void AddDependencyBetween(OpHandleBase *in, OpHandleBase *out,
ir::Graph *graph) {
auto it = std::find_if(
in->Outputs().begin(), in->Outputs().end(), [](VarHandleBase *var) {
return dynamic_cast<DummyVarHandle *>(var) != nullptr;
});
if (it != in->Outputs().end()) {
out->AddInput(*it);
} else {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
in->AddOutput(dep_var);
out->AddInput(dep_var);
}
}
std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
auto &ref_cnts = Get<DeviceReferenceCountMap>(kGlobalReferenceCount); auto &ref_cnts = Get<DeviceReferenceCountMap>(kGlobalReferenceCount);
...@@ -133,12 +150,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -133,12 +150,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
auto *ref_cnt_handle = new ReferenceCountOpHandle( auto *ref_cnt_handle = new ReferenceCountOpHandle(
ref_cnt_node, next_compute_op->GetScope(), place, {var_name}, ref_cnt_node, next_compute_op->GetScope(), place, {var_name},
gcs[place.device].get(), cur_ref_cnts[place.device].get()); gcs[place.device].get(), cur_ref_cnts[place.device].get());
if (next_compute_op->Outputs().empty()) { AddDependencyBetween(next_compute_op, ref_cnt_handle, graph.get());
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
next_compute_op->AddOutput(dep_var);
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
}
ref_cnt_handle->AddInput(next_compute_op->Outputs().front());
compute_ref_cnt_map[next_compute_op].reset(ref_cnt_handle); compute_ref_cnt_map[next_compute_op].reset(ref_cnt_handle);
} }
} }
...@@ -160,12 +172,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -160,12 +172,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
auto *ref_cnt_handle = new ReferenceCountOpHandle( auto *ref_cnt_handle = new ReferenceCountOpHandle(
ref_cnt_node, compute_op->GetScope(), place, in_var_names, ref_cnt_node, compute_op->GetScope(), place, in_var_names,
gcs[place.device].get(), cur_ref_cnts[place.device].get()); gcs[place.device].get(), cur_ref_cnts[place.device].get());
if (compute_op->Outputs().empty()) { AddDependencyBetween(compute_op, ref_cnt_handle, graph.get());
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
compute_op->AddOutput(dep_var);
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
}
ref_cnt_handle->AddInput(compute_op->Outputs().front());
compute_ref_cnt_map[compute_op].reset(ref_cnt_handle); compute_ref_cnt_map[compute_op].reset(ref_cnt_handle);
} }
......
...@@ -160,6 +160,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -160,6 +160,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv forward --------------------- // ------------------- cudnn conv forward ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f; ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
for (int i = 0; i < groups; i++) { for (int i = 0; i < groups; i++) {
auto cudnn_func = [&](void* cudnn_workspace) { auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
...@@ -168,7 +169,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -168,7 +169,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes, cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes,
&beta, cudnn_output_desc, output_data + i * group_offset_out)); &beta, cudnn_output_desc, output_data + i * group_offset_out));
}; };
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
} }
} }
}; };
...@@ -314,6 +315,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -314,6 +315,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv backward data --------------------- // ------------------- cudnn conv backward data ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f; ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
if (input_grad) { if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace()); T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset input_grad. // Because beta is zero, it is unnecessary to reset input_grad.
...@@ -327,7 +329,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -327,7 +329,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
data_algo, cudnn_workspace, workspace_size_in_bytes, &beta, data_algo, cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_input_desc, input_grad_data + i * group_offset_in)); cudnn_input_desc, input_grad_data + i * group_offset_in));
}; };
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
} }
} }
// ------------------- cudnn conv backward filter --------------------- // ------------------- cudnn conv backward filter ---------------------
...@@ -343,7 +345,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -343,7 +345,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
filter_algo, cudnn_workspace, workspace_size_in_bytes, &beta, filter_algo, cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_filter_desc, filter_grad_data + i * group_offset_filter)); cudnn_filter_desc, filter_grad_data + i * group_offset_filter));
}; };
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
} }
} }
} }
......
...@@ -104,6 +104,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> { ...@@ -104,6 +104,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
int output_offset = output->numel() / output->dims()[0] / groups; int output_offset = output->numel() / output->dims()[0] / groups;
int filter_offset = filter->numel() / groups; int filter_offset = filter->numel() / groups;
T alpha = 1.0f, beta = 0.0f; T alpha = 1.0f, beta = 0.0f;
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
for (int g = 0; g < groups; g++) { for (int g = 0; g < groups; g++) {
auto cudnn_func = [&](void* cudnn_workspace) { auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
...@@ -112,7 +113,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> { ...@@ -112,7 +113,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
algo, cudnn_workspace, workspace_size_in_bytes, &beta, algo, cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_output_desc, output_data + output_offset * g)); cudnn_output_desc, output_data + output_offset * g));
}; };
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
} }
} }
}; };
...@@ -208,6 +209,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> { ...@@ -208,6 +209,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
output_grad->numel() / output_grad->dims()[0] / groups; output_grad->numel() / output_grad->dims()[0] / groups;
int filter_offset = filter->numel() / groups; int filter_offset = filter->numel() / groups;
T alpha = 1.0f, beta = 0.0f; T alpha = 1.0f, beta = 0.0f;
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
if (input_grad) { if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace()); T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset input_grad. // Because beta is zero, it is unnecessary to reset input_grad.
...@@ -220,7 +222,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> { ...@@ -220,7 +222,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc, cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc,
input_grad_data + input_offset * g)); input_grad_data + input_offset * g));
}; };
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
} }
} }
...@@ -238,7 +240,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> { ...@@ -238,7 +240,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_filter_desc, filter_grad_data + filter_offset * g)); cudnn_filter_desc, filter_grad_data + filter_offset * g));
}; };
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
} }
} }
} }
......
...@@ -75,7 +75,12 @@ if(WITH_GPU) ...@@ -75,7 +75,12 @@ if(WITH_GPU)
endif() endif()
cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split) cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
cc_library(jit_kernel
SRCS jit_kernel.cc jit_gen.cc jit_code.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc set(JIT_KERNEL_SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc)
DEPS cpu_info cblas gflags enforce) set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce)
if(WITH_XBYAK)
list(APPEND JIT_KERNEL_SRCS jit_gen.cc jit_code.cc)
list(APPEND JIT_KERNEL_DEPS xbyak)
endif()
cc_library(jit_kernel SRCS ${JIT_KERNEL_SRCS} DEPS ${JIT_KERNEL_DEPS})
cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel) cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
...@@ -14,10 +14,13 @@ limitations under the License. */ ...@@ -14,10 +14,13 @@ limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/math/jit_kernel.h"
#include <string> #include <string>
#include "paddle/fluid/operators/math/jit_code.h"
#include "paddle/fluid/operators/math/jit_kernel_macro.h" #include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/math/jit_code.h"
#endif
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h" #include "paddle/fluid/platform/dynload/mklml.h"
#endif #endif
...@@ -64,6 +67,7 @@ class VMulKernelImpl : public VMulKernel<T> { ...@@ -64,6 +67,7 @@ class VMulKernelImpl : public VMulKernel<T> {
static inline bool useMKL(int d) { return false; } static inline bool useMKL(int d) { return false; }
explicit VMulKernelImpl(int d) : VMulKernel<T>() { explicit VMulKernelImpl(int d) : VMulKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
// roughly estimate the size of code // roughly estimate the size of code
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
...@@ -72,6 +76,7 @@ class VMulKernelImpl : public VMulKernel<T> { ...@@ -72,6 +76,7 @@ class VMulKernelImpl : public VMulKernel<T> {
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>(); jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
return; return;
} }
#endif
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
if (useMKL(d)) { if (useMKL(d)) {
this->Compute = VMulMKL<T>; this->Compute = VMulMKL<T>;
...@@ -81,15 +86,21 @@ class VMulKernelImpl : public VMulKernel<T> { ...@@ -81,15 +86,21 @@ class VMulKernelImpl : public VMulKernel<T> {
this->Compute = VMulRefer<T>; this->Compute = VMulRefer<T>;
} }
#ifdef PADDLE_WITH_XBYAK
private: private:
std::unique_ptr<gen::VMulJitCode> jitcode_{nullptr}; std::unique_ptr<gen::VMulJitCode> jitcode_{nullptr};
#endif
}; };
#ifdef PADDLE_WITH_XBYAK
template <> template <>
bool VMulKernelImpl<float>::useJIT(int d) { bool VMulKernelImpl<float>::useJIT(int d) {
return gen::VMulJitCode::init(d); return gen::VMulJitCode::init(d);
} }
#endif
#ifdef PADDLE_WITH_MKLML
template <> template <>
bool VMulKernelImpl<float>::useMKL(int d) { bool VMulKernelImpl<float>::useMKL(int d) {
return jit::MayIUse(jit::avx512f) && d > 512; return jit::MayIUse(jit::avx512f) && d > 512;
...@@ -99,6 +110,7 @@ template <> ...@@ -99,6 +110,7 @@ template <>
bool VMulKernelImpl<double>::useMKL(int d) { bool VMulKernelImpl<double>::useMKL(int d) {
return true; return true;
} }
#endif
REGISTER_JITKERNEL(vmul, VMulKernel); REGISTER_JITKERNEL(vmul, VMulKernel);
......
...@@ -179,8 +179,8 @@ class RmspropOpKernel : public framework::OpKernel<T> { ...@@ -179,8 +179,8 @@ class RmspropOpKernel : public framework::OpKernel<T> {
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad"); auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
auto mg = EigenVector<T>::Flatten(mg_tensor); auto mg = EigenVector<T>::Flatten(mg_tensor);
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut"); auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
PADDLE_ENFORCE(&mg_tensor, mean_grad_out, PADDLE_ENFORCE_EQ(&mg_tensor, mean_grad_out,
"MeanGrad and MeanGradOut must be the same Tensor"); "MeanGrad and MeanGradOut must be the same Tensor");
auto mg_out = EigenVector<T>::Flatten(*mean_grad_out); auto mg_out = EigenVector<T>::Flatten(*mean_grad_out);
mg_out.device(place) = rho * mg + (1 - rho) * g; mg_out.device(place) = rho * mg + (1 - rho) * g;
...@@ -198,8 +198,8 @@ class RmspropOpKernel : public framework::OpKernel<T> { ...@@ -198,8 +198,8 @@ class RmspropOpKernel : public framework::OpKernel<T> {
if (centered) { if (centered) {
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad"); auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut"); auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
PADDLE_ENFORCE(&mg_tensor, mean_grad_out, PADDLE_ENFORCE_EQ(&mg_tensor, mean_grad_out,
"MeanGrad and MeanGradOut must be the same Tensor"); "MeanGrad and MeanGradOut must be the same Tensor");
for_range(CenteredRmspropFunctor<T, DenseRmspropGradFunctor<T>>( for_range(CenteredRmspropFunctor<T, DenseRmspropGradFunctor<T>>(
param_out->mutable_data<T>(ctx.GetPlace()), param_out->mutable_data<T>(ctx.GetPlace()),
mean_square_out->mutable_data<T>(ctx.GetPlace()), mean_square_out->mutable_data<T>(ctx.GetPlace()),
...@@ -243,8 +243,8 @@ class RmspropOpKernel : public framework::OpKernel<T> { ...@@ -243,8 +243,8 @@ class RmspropOpKernel : public framework::OpKernel<T> {
if (centered) { if (centered) {
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad"); auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut"); auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
PADDLE_ENFORCE(&mg_tensor, mean_grad_out, PADDLE_ENFORCE_EQ(&mg_tensor, mean_grad_out,
"MeanGrad and MeanGradOut must be the same Tensor"); "MeanGrad and MeanGradOut must be the same Tensor");
for_range(CenteredRmspropFunctor<T, SparseRmspropGradFunctor<T>>( for_range(CenteredRmspropFunctor<T, SparseRmspropGradFunctor<T>>(
param_out->mutable_data<T>(ctx.GetPlace()), param_out->mutable_data<T>(ctx.GetPlace()),
mean_square_out->mutable_data<T>(ctx.GetPlace()), mean_square_out->mutable_data<T>(ctx.GetPlace()),
......
...@@ -153,55 +153,31 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { ...@@ -153,55 +153,31 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
mutable unsigned int* semaphore_; mutable unsigned int* semaphore_;
}; };
class CudnnHolder { CudnnHolder::CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
public: : workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) {
CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place) PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
: workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) { PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); }
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
}
cudnnHandle_t cudnn_handle() const { return cudnn_handle_; }
void RunFunc(const std::function<void(void*)>& cudnn_func, CudnnHolder::~CudnnHolder() {
size_t required_workspace_len) { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
std::lock_guard<std::mutex> lock(mtx_); if (workspace_ != nullptr) {
if (required_workspace_len > workspace_len_) { paddle::memory::Free(place_, workspace_);
ReallocateWorkspace(required_workspace_len);
}
cudnn_func(workspace_);
} }
}
~CudnnHolder() { void CudnnHolder::ReallocateWorkspace(size_t required_workspace_len) {
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); if (required_workspace_len <= workspace_len_) {
if (workspace_ != nullptr) { return;
paddle::memory::Free(place_, workspace_);
}
} }
if (workspace_ != nullptr) {
private: // Maybe someone is using the current workspace
void ReallocateWorkspace(size_t required_workspace_len) { PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
if (required_workspace_len <= workspace_len_) { paddle::memory::Free(place_, workspace_);
return;
}
if (workspace_ != nullptr) {
// Maybe someone is using the current workspace
PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
paddle::memory::Free(place_, workspace_);
}
workspace_ = paddle::memory::Alloc(place_, required_workspace_len);
workspace_len_ = required_workspace_len;
} }
workspace_ = paddle::memory::Alloc(place_, required_workspace_len);
cudnnHandle_t cudnn_handle_; workspace_len_ = required_workspace_len;
void* workspace_; }
size_t workspace_len_;
const cudaStream_t* stream_; // not owned;
const CUDAPlace place_;
std::mutex mtx_;
};
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
: place_(place), cudnn_holder_(nullptr) { : place_(place), cudnn_holder_(nullptr) {
...@@ -269,9 +245,8 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { ...@@ -269,9 +245,8 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
return cudnn_holder_->cudnn_handle(); return cudnn_holder_->cudnn_handle();
} }
void CUDADeviceContext::RunCudnnFuncWithWorkspace( CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
const std::function<void(void*)>& cudnn_func, size_t workspace_len) const { return CudnnWorkspaceHandle(cudnn_holder_.get());
cudnn_holder_->RunFunc(cudnn_func, workspace_len);
} }
cudaStream_t CUDADeviceContext::stream() const { return stream_; } cudaStream_t CUDADeviceContext::stream() const { return stream_; }
......
...@@ -73,7 +73,60 @@ struct DefaultDeviceContextType<platform::CPUPlace> { ...@@ -73,7 +73,60 @@ struct DefaultDeviceContextType<platform::CPUPlace> {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
class EigenCudaStreamDevice; class EigenCudaStreamDevice;
class CudnnHolder; class CudnnHolder {
public:
CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place);
~CudnnHolder();
cudnnHandle_t cudnn_handle() const { return cudnn_handle_; }
private:
friend class CudnnWorkspaceHandle;
void ReallocateWorkspace(size_t required_workspace_len);
template <typename Callback>
void RunFuncImpl(Callback&& cudnn_func, size_t required_workspace_len) {
if (required_workspace_len > workspace_len_) {
ReallocateWorkspace(required_workspace_len);
}
cudnn_func(workspace_);
}
std::mutex& Mutex() { return mtx_; }
cudnnHandle_t cudnn_handle_;
void* workspace_;
size_t workspace_len_;
const cudaStream_t* stream_; // not owned;
const CUDAPlace place_;
std::mutex mtx_;
};
class CudnnWorkspaceHandle {
public:
/*! \brief The lock would not be acquired when constructor calls.
* The lock would be acquired when RunFunc() is called first time. */
inline explicit CudnnWorkspaceHandle(CudnnHolder* holder) : holder_(holder) {}
/*! \brief Thread which call RunFunc() would acquire the lock first
* before invoking cudnn functions. */
template <typename Callback>
inline void RunFunc(Callback&& cudnn_func, size_t required_workspace_len) {
if (!guard_) {
guard_.reset(new std::lock_guard<std::mutex>(holder_->Mutex()));
}
holder_->RunFuncImpl(std::forward<Callback>(cudnn_func),
required_workspace_len);
}
CudnnWorkspaceHandle(CudnnWorkspaceHandle&&) = default;
CudnnWorkspaceHandle& operator=(CudnnWorkspaceHandle&&) = delete;
private:
CudnnHolder* holder_; // not own
std::unique_ptr<std::lock_guard<std::mutex>> guard_;
};
class CUDADeviceContext : public DeviceContext { class CUDADeviceContext : public DeviceContext {
public: public:
...@@ -101,10 +154,14 @@ class CUDADeviceContext : public DeviceContext { ...@@ -101,10 +154,14 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return cudnn handle in the device context. */ /*! \brief Return cudnn handle in the device context. */
cudnnHandle_t cudnn_handle() const; cudnnHandle_t cudnn_handle() const;
/*! \brief Run a cudnn function with the workspace provided by /*! \brief Return a cudnn workspace handle to call multiple cudnn
* CUDADeviceContext */ * functions without interrupting by other threads.
void RunCudnnFuncWithWorkspace(const std::function<void(void*)>& cudnn_func, * Once the first cudnn function is called by the handle, a lock
size_t workspace_len) const; * would be acquired to prevent other threads from accessing the
* workspace. Once the handle is destructed, the lock would be released.
* CudnnWorkspaceHandle is an RAII object to implement thread-safe
* sequential cudnn function calls. */
CudnnWorkspaceHandle cudnn_workspace_handle() const;
/*! \brief Return cuda stream in the device context. */ /*! \brief Return cuda stream in the device context. */
cudaStream_t stream() const; cudaStream_t stream() const;
......
...@@ -24,8 +24,6 @@ ...@@ -24,8 +24,6 @@
namespace paddle { namespace paddle {
namespace platform { namespace platform {
using StreamCallback = std::function<void(cudaStream_t, cudaError_t)>;
class StreamCallbackManager; class StreamCallbackManager;
struct StreamCallbackContext { struct StreamCallbackContext {
...@@ -35,7 +33,7 @@ struct StreamCallbackContext { ...@@ -35,7 +33,7 @@ struct StreamCallbackContext {
: manager_(manager), callback_(callback) {} : manager_(manager), callback_(callback) {}
const StreamCallbackManager *manager_; // do not own const StreamCallbackManager *manager_; // do not own
StreamCallback callback_; std::function<void()> callback_;
}; };
class StreamCallbackManager { class StreamCallbackManager {
...@@ -45,16 +43,18 @@ class StreamCallbackManager { ...@@ -45,16 +43,18 @@ class StreamCallbackManager {
template <typename Callback> template <typename Callback>
inline void AddCallback(Callback &&callback) const { inline void AddCallback(Callback &&callback) const {
AddCallbackWithStreamAndErrorInfo( auto *stream_callback_context =
[=](cudaStream_t, cudaError_t) { callback(); }); new StreamCallbackContext(this, std::forward<Callback>(callback));
} PADDLE_ENFORCE(
#if CUDA_VERSION >= 10000
template <typename Callback> cudaLaunchHostFunc(stream_, StreamCallbackManager::StreamCallbackFunc,
inline void AddCallbackWithStreamAndErrorInfo(Callback &&callback) const { stream_callback_context)
auto *stream_callback_context = new StreamCallbackContext(this, callback); #else
PADDLE_ENFORCE(cudaStreamAddCallback( cudaStreamAddCallback(stream_,
stream_, StreamCallbackManager::StreamCallbackFunc, StreamCallbackManager::StreamCallbackFunc,
stream_callback_context, 0)); stream_callback_context, 0)
#endif
); // NOLINT
} }
void Wait() const { thread_pool_.reset(new ThreadPool(1)); } void Wait() const { thread_pool_.reset(new ThreadPool(1)); }
...@@ -63,17 +63,21 @@ class StreamCallbackManager { ...@@ -63,17 +63,21 @@ class StreamCallbackManager {
const cudaStream_t stream_; const cudaStream_t stream_;
mutable std::unique_ptr<ThreadPool> thread_pool_; mutable std::unique_ptr<ThreadPool> thread_pool_;
// cudaStreamCallback cannot call CUDA API inside, so we have to use // cudaStreamCallback cannot call CUDA API inside, so we have to use
// thread_pool here // thread_pool here
#if CUDA_VERSION >= 10000
static void CUDART_CB StreamCallbackFunc(void *user_data)
#else
static void CUDART_CB StreamCallbackFunc(cudaStream_t stream, static void CUDART_CB StreamCallbackFunc(cudaStream_t stream,
cudaError_t status, cudaError_t status, void *user_data)
void *user_data) { #endif
{
auto *callback_context_ptr = auto *callback_context_ptr =
reinterpret_cast<StreamCallbackContext *>(user_data); reinterpret_cast<StreamCallbackContext *>(user_data);
callback_context_ptr->manager_->thread_pool_->enqueue([=]() { callback_context_ptr->manager_->thread_pool_->enqueue([=]() {
std::unique_ptr<StreamCallbackContext> callback_context( std::unique_ptr<StreamCallbackContext> callback_context(
callback_context_ptr); callback_context_ptr);
callback_context->callback_(stream, status); callback_context->callback_();
}); });
} }
}; };
......
...@@ -822,13 +822,24 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -822,13 +822,24 @@ All parameter, weight, gradient are variables in Paddle.
[](BuildStrategy &self, bool b) { [](BuildStrategy &self, bool b) {
self.enable_data_balance_ = b; self.enable_data_balance_ = b;
}) // FIXME(chengudo): enable_data_balance seems not important }) // FIXME(chengudo): enable_data_balance seems not important
.def_property("enable_sequential_execution", .def_property(
[](const BuildStrategy &self) { "enable_sequential_execution",
return self.enable_sequential_execution_; [](const BuildStrategy &self) {
}, return self.enable_sequential_execution_;
[](BuildStrategy &self, bool b) { },
self.enable_sequential_execution_ = b; [](BuildStrategy &self, bool b) {
}) self.enable_sequential_execution_ = b;
},
R"DOC(The type is BOOL. If set True, the execution order of ops would be the same as what is in the program. Default False.)DOC")
.def_property(
"remove_unnecessary_lock",
[](const BuildStrategy &self) {
return self.remove_unnecessary_lock_;
},
[](BuildStrategy &self, bool b) {
self.remove_unnecessary_lock_ = b;
},
R"DOC(The type is BOOL. If set True, some locks in GPU ops would be released and ParallelExecutor would run faster. Default False.)DOC")
.def_property( .def_property(
"fuse_elewise_add_act_ops", "fuse_elewise_add_act_ops",
[](const BuildStrategy &self) { [](const BuildStrategy &self) {
......
...@@ -86,6 +86,8 @@ if(WITH_DISTRIBUTE) ...@@ -86,6 +86,8 @@ if(WITH_DISTRIBUTE)
# FIXME(typhoonzero): add this back # FIXME(typhoonzero): add this back
#py_test_modules(test_dist_transformer MODULES test_dist_transformer) #py_test_modules(test_dist_transformer MODULES test_dist_transformer)
#set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000) #set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000)
# TODO(typhoonzero): make dist test parallel when fix port management issue
set_tests_properties(test_dist_mnist test_dist_word2vec test_dist_se_resnext test_dist_ctr test_dist_simnet_bow test_dist_save_load test_dist_text_classification test_dist_mnist_batch_merge PROPERTIES RUN_SERIAL TRUE)
endif(NOT APPLE) endif(NOT APPLE)
py_test_modules(test_dist_transpiler MODULES test_dist_transpiler) py_test_modules(test_dist_transpiler MODULES test_dist_transpiler)
endif() endif()
......
...@@ -18,6 +18,7 @@ import multiprocessing ...@@ -18,6 +18,7 @@ import multiprocessing
import os import os
import unittest import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core
import time import time
import numpy as np import numpy as np
import math import math
...@@ -82,6 +83,8 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -82,6 +83,8 @@ class TestParallelExecutorBase(unittest.TestCase):
if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
build_strategy.enable_sequential_execution = enable_sequential_execution build_strategy.enable_sequential_execution = enable_sequential_execution
if use_cuda and core.is_compiled_with_cuda():
build_strategy.remove_unnecessary_lock = True
if use_parallel_executor: if use_parallel_executor:
exe = fluid.ParallelExecutor( exe = fluid.ParallelExecutor(
......
...@@ -174,7 +174,6 @@ class TestCRFModel(unittest.TestCase): ...@@ -174,7 +174,6 @@ class TestCRFModel(unittest.TestCase):
print(pe.run(feed=feeder.feed(cur_batch), print(pe.run(feed=feeder.feed(cur_batch),
fetch_list=[avg_cost.name])[0]) fetch_list=[avg_cost.name])[0])
@unittest.skip(reason="CI hangs")
def test_update_sparse_parameter_all_reduce(self): def test_update_sparse_parameter_all_reduce(self):
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
...@@ -183,7 +182,6 @@ class TestCRFModel(unittest.TestCase): ...@@ -183,7 +182,6 @@ class TestCRFModel(unittest.TestCase):
self.check_network_convergence( self.check_network_convergence(
is_sparse=True, build_strategy=build_strategy, use_cuda=False) is_sparse=True, build_strategy=build_strategy, use_cuda=False)
@unittest.skip(reason="CI hangs")
def test_update_dense_parameter_all_reduce(self): def test_update_dense_parameter_all_reduce(self):
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
...@@ -192,7 +190,6 @@ class TestCRFModel(unittest.TestCase): ...@@ -192,7 +190,6 @@ class TestCRFModel(unittest.TestCase):
self.check_network_convergence( self.check_network_convergence(
is_sparse=False, build_strategy=build_strategy, use_cuda=False) is_sparse=False, build_strategy=build_strategy, use_cuda=False)
@unittest.skip(reason="CI hangs")
def test_update_sparse_parameter_reduce(self): def test_update_sparse_parameter_reduce(self):
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
...@@ -201,7 +198,6 @@ class TestCRFModel(unittest.TestCase): ...@@ -201,7 +198,6 @@ class TestCRFModel(unittest.TestCase):
self.check_network_convergence( self.check_network_convergence(
is_sparse=True, build_strategy=build_strategy, use_cuda=False) is_sparse=True, build_strategy=build_strategy, use_cuda=False)
@unittest.skip(reason="CI hangs")
def test_update_dense_parameter_reduce(self): def test_update_dense_parameter_reduce(self):
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册