提交 2921f8a7 编写于 作者: Q Qiao Longfei

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into...

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into optimize-pserver-profiler-thread-pool
cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto node)
cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor)
cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base)
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
......@@ -30,7 +31,9 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
if(WITH_GPU)
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
if (WITH_GPU)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
endif()
......@@ -40,12 +43,13 @@ cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS grap
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)
if(WITH_GPU)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass sequential_execution_pass)
else()
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto sequential_execution_pass)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass)
if (WITH_GPU)
list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
endif()
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context)
......
......@@ -69,6 +69,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Verify that the graph is correct for multi-device executor.
AppendPass("multi_devices_check_pass");
if (strategy_.remove_unnecessary_lock_) {
AppendPass("modify_op_lock_and_record_event_pass");
}
}
private:
......@@ -136,3 +140,4 @@ USE_PASS(multi_devices_pass);
USE_PASS(multi_devices_check_pass);
USE_PASS(multi_devices_print_pass);
USE_PASS(sequential_execution_pass);
USE_PASS(modify_op_lock_and_record_event_pass);
......@@ -73,6 +73,8 @@ struct BuildStrategy {
bool fuse_broadcast_op_{false};
bool remove_unnecessary_lock_{false};
// User normally doesn't need to call this API.
// The PassBuilder allows for more customized insert, remove of passes
// from python side.
......
......@@ -29,9 +29,15 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
void ComputationOpHandle::RunImpl() {
WaitInputVarGenerated(place_);
this->RunAndRecordEvent([this] {
auto run_func = [this]() {
op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
});
};
if (is_lock_and_record_event_free_) {
run_func();
} else {
this->RunAndRecordEvent(run_func);
}
}
bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) {
......
......@@ -36,6 +36,8 @@ struct ComputationOpHandle : public OpHandleBase {
const platform::Place &GetPlace() const { return place_; }
void SetLockAndRecordEventFree(bool b) { is_lock_and_record_event_free_ = b; }
protected:
void RunImpl() override;
......@@ -45,6 +47,7 @@ struct ComputationOpHandle : public OpHandleBase {
std::unique_ptr<OperatorBase> op_;
Scope *scope_;
platform::Place place_;
bool is_lock_and_record_event_free_{false};
};
} // namespace details
} // namespace framework
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h"
namespace paddle {
namespace framework {
namespace details {
static bool IsLockAndRecordEventFreeComputationOpHandle(
ComputationOpHandle *op, const OpGraphView &graph_view) {
if (!platform::is_gpu_place(op->GetPlace())) return false;
for (auto &pending_op : graph_view.PendingOps(op)) {
auto *tmp = dynamic_cast<ComputationOpHandle *>(pending_op);
if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) {
return false;
}
}
return true;
}
std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
std::unique_ptr<ir::Graph> ir_graph) const {
auto &all_ops = ir_graph->Get<GraphOps>(kGraphOps);
OpGraphView graph_view(all_ops);
for (auto &op : all_ops) {
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get());
if (compute_op == nullptr) continue;
bool is_lock_and_record_event_free =
IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view);
compute_op->SetLockAndRecordEventFree(is_lock_and_record_event_free);
if (is_lock_and_record_event_free) {
VLOG(10) << "Set is_lock_and_record_event_free be true in op "
<< compute_op->DebugString();
}
}
return ir_graph;
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(modify_op_lock_and_record_event_pass,
paddle::framework::details::ModifyOpLockAndRecordEventPass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
class ModifyOpLockAndRecordEventPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/op_graph_view.h"
#include <queue>
#include <utility>
namespace paddle {
namespace framework {
namespace details {
OpGraphView::OpGraphView(
const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
Build(ops);
}
void OpGraphView::Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
for (auto &op : ops) {
preceding_ops_[op.get()];
pending_ops_[op.get()];
for (auto &var : op->Outputs()) {
for (auto &pending_op : var->PendingOps()) {
preceding_ops_[pending_op].insert(op.get());
pending_ops_[op.get()].insert(pending_op);
}
}
}
PADDLE_ENFORCE(
preceding_ops_.size() == ops.size() && pending_ops_.size() == ops.size(),
"There are duplicate ops in graph.");
}
size_t OpGraphView::OpNumber() const { return preceding_ops_.size(); }
std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
std::unordered_set<OpHandleBase *> ret;
for (auto &pair : preceding_ops_) {
ret.insert(pair.first);
}
return ret;
}
bool OpGraphView::HasOp(OpHandleBase *op) const {
return preceding_ops_.count(op) != 0;
}
void OpGraphView::EnforceHasOp(OpHandleBase *op) const {
PADDLE_ENFORCE(HasOp(op), "Cannot find op %s in OpGraphView",
op == nullptr ? "nullptr" : op->DebugString());
}
const std::unordered_set<OpHandleBase *> &OpGraphView::PrecedingOps(
OpHandleBase *op) const {
EnforceHasOp(op);
return preceding_ops_.at(op);
}
const std::unordered_set<OpHandleBase *> &OpGraphView::PendingOps(
OpHandleBase *op) const {
EnforceHasOp(op);
return pending_ops_.at(op);
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
namespace paddle {
namespace framework {
namespace details {
class OpGraphView {
public:
explicit OpGraphView(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
size_t OpNumber() const;
std::unordered_set<OpHandleBase *> AllOps() const;
const std::unordered_set<OpHandleBase *> &PrecedingOps(
OpHandleBase *op) const;
const std::unordered_set<OpHandleBase *> &PendingOps(OpHandleBase *op) const;
bool HasOp(OpHandleBase *op) const;
private:
void Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops);
void EnforceHasOp(OpHandleBase *op) const;
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
preceding_ops_;
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
pending_ops_;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -51,7 +51,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
dev_ctx_ = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place));
if (IsStreamGarabageCollector()) {
PADDLE_ENFORCE(cudaSetDevice(place.device));
platform::SetDeviceId(place.device);
PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
}
......@@ -61,7 +61,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
~ReferenceCountOpHandle() {
if (IsStreamGarabageCollector()) {
auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace());
PADDLE_ENFORCE(cudaSetDevice(gpu_place.device));
platform::SetDeviceId(gpu_place.device);
PADDLE_ENFORCE(cudaEventDestroy(event_));
}
}
......
......@@ -43,6 +43,23 @@ static ComputationOpHandle *FindNextComputationOpHandle(VarHandle *var_in) {
return nullptr;
}
static void AddDependencyBetween(OpHandleBase *in, OpHandleBase *out,
ir::Graph *graph) {
auto it = std::find_if(
in->Outputs().begin(), in->Outputs().end(), [](VarHandleBase *var) {
return dynamic_cast<DummyVarHandle *>(var) != nullptr;
});
if (it != in->Outputs().end()) {
out->AddInput(*it);
} else {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
in->AddOutput(dep_var);
out->AddInput(dep_var);
}
}
std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
auto &ref_cnts = Get<DeviceReferenceCountMap>(kGlobalReferenceCount);
......@@ -133,12 +150,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
auto *ref_cnt_handle = new ReferenceCountOpHandle(
ref_cnt_node, next_compute_op->GetScope(), place, {var_name},
gcs[place.device].get(), cur_ref_cnts[place.device].get());
if (next_compute_op->Outputs().empty()) {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
next_compute_op->AddOutput(dep_var);
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
}
ref_cnt_handle->AddInput(next_compute_op->Outputs().front());
AddDependencyBetween(next_compute_op, ref_cnt_handle, graph.get());
compute_ref_cnt_map[next_compute_op].reset(ref_cnt_handle);
}
}
......@@ -160,12 +172,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
auto *ref_cnt_handle = new ReferenceCountOpHandle(
ref_cnt_node, compute_op->GetScope(), place, in_var_names,
gcs[place.device].get(), cur_ref_cnts[place.device].get());
if (compute_op->Outputs().empty()) {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
compute_op->AddOutput(dep_var);
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
}
ref_cnt_handle->AddInput(compute_op->Outputs().front());
AddDependencyBetween(compute_op, ref_cnt_handle, graph.get());
compute_ref_cnt_map[compute_op].reset(ref_cnt_handle);
}
......
......@@ -29,6 +29,15 @@ set(RNN2_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/rnn2")
download_model_and_data(${RNN2_INSTALL_DIR} "rnn2_model.tar.gz" "rnn2_data.txt.tar.gz")
inference_analysis_api_test(test_analyzer_rnn2 ${RNN2_INSTALL_DIR} analyzer_rnn2_tester.cc)
# DAM
set(DAM_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/dam")
download_model_and_data(${DAM_INSTALL_DIR} "DAM_model.tar.gz" "DAM_data.txt.tar.gz")
inference_analysis_test(test_analyzer_dam SRCS analyzer_dam_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} ARGS
--infer_model=${DAM_INSTALL_DIR}/model
--infer_data=${DAM_INSTALL_DIR}/data.txt
--use_analysis=0)
# chinese_ner
set(CHINESE_NER_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/chinese_ner")
download_model_and_data(${CHINESE_NER_INSTALL_DIR} "chinese_ner_model.tar.gz" "chinese_ner-data.txt.tar.gz")
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tests/api/tester_helper.h"
namespace paddle {
namespace inference {
using contrib::AnalysisConfig;
#define MAX_TURN_NUM 9
#define MAX_TURN_LEN 50
static std::vector<float> result_data;
struct DataRecord {
std::vector<std::vector<int64_t>>
turns[MAX_TURN_NUM]; // turns data : MAX_TURN_NUM
std::vector<std::vector<float>>
turns_mask[MAX_TURN_NUM]; // turns mask data : MAX_TURN_NUM
std::vector<std::vector<int64_t>> response; // response data : 1
std::vector<std::vector<float>> response_mask; // response mask data : 1
size_t batch_iter{0};
size_t batch_size{1};
size_t num_samples; // total number of samples
DataRecord() = default;
explicit DataRecord(const std::string &path, int batch_size = 1)
: batch_size(batch_size) {
Load(path);
}
DataRecord NextBatch() {
DataRecord data;
size_t batch_end = batch_iter + batch_size;
// NOTE skip the final batch, if no enough data is provided.
if (batch_end <= response.size()) {
for (int i = 0; i < MAX_TURN_NUM; ++i) {
data.turns[i].assign(turns[i].begin() + batch_iter,
turns[i].begin() + batch_end);
}
for (int i = 0; i < MAX_TURN_NUM; ++i) {
data.turns_mask[i].assign(turns_mask[i].begin() + batch_iter,
turns_mask[i].begin() + batch_end);
}
data.response.assign(response.begin() + batch_iter,
response.begin() + batch_end);
data.response_mask.assign(response_mask.begin() + batch_iter,
response_mask.begin() + batch_end);
CHECK(!data.response.empty());
CHECK(!data.response_mask.empty());
CHECK_EQ(data.response.size(), data.response_mask.size());
}
batch_iter += batch_size;
return data;
}
void Load(const std::string &path) {
std::ifstream file(path);
std::string line;
size_t num_lines = 0;
result_data.clear();
while (std::getline(file, line)) {
num_lines++;
std::vector<std::string> data;
split(line, ',', &data);
CHECK_EQ(data.size(), 2 * MAX_TURN_NUM + 3);
// load turn data
std::vector<int64_t> turns_tmp[MAX_TURN_NUM];
for (int i = 0; i < MAX_TURN_NUM; ++i) {
split_to_int64(data[i], ' ', &turns_tmp[i]);
turns[i].push_back(std::move(turns_tmp[i]));
}
// load turn_mask data
std::vector<float> turns_mask_tmp[MAX_TURN_NUM];
for (int i = 0; i < MAX_TURN_NUM; ++i) {
split_to_float(data[MAX_TURN_NUM + i], ' ', &turns_mask_tmp[i]);
turns_mask[i].push_back(std::move(turns_mask_tmp[i]));
}
// load response data
std::vector<int64_t> response_tmp;
split_to_int64(data[2 * MAX_TURN_NUM], ' ', &response_tmp);
response.push_back(std::move(response_tmp));
// load response_mask data
std::vector<float> response_mask_tmp;
split_to_float(data[2 * MAX_TURN_NUM + 1], ' ', &response_mask_tmp);
response_mask.push_back(std::move(response_mask_tmp));
// load result data
float result_tmp;
result_tmp = std::stof(data[2 * MAX_TURN_NUM + 2]);
result_data.push_back(result_tmp);
}
num_samples = num_lines;
}
};
void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
int batch_size) {
PaddleTensor turns_tensor[MAX_TURN_NUM];
PaddleTensor turns_mask_tensor[MAX_TURN_NUM];
PaddleTensor response_tensor;
PaddleTensor response_mask_tensor;
std::string turn_pre = "turn_";
std::string turn_mask_pre = "turn_mask_";
auto one_batch = data->NextBatch();
int size = one_batch.response[0].size();
CHECK_EQ(size, MAX_TURN_LEN);
// turn tensor assignment
for (int i = 0; i < MAX_TURN_NUM; ++i) {
turns_tensor[i].name = turn_pre + std::to_string(i);
turns_tensor[i].shape.assign({batch_size, size, 1});
turns_tensor[i].dtype = PaddleDType::INT64;
TensorAssignData<int64_t>(&turns_tensor[i], one_batch.turns[i]);
}
// turn mask tensor assignment
for (int i = 0; i < MAX_TURN_NUM; ++i) {
turns_mask_tensor[i].name = turn_mask_pre + std::to_string(i);
turns_mask_tensor[i].shape.assign({batch_size, size, 1});
turns_mask_tensor[i].dtype = PaddleDType::FLOAT32;
TensorAssignData<float>(&turns_mask_tensor[i], one_batch.turns_mask[i]);
}
// response tensor assignment
response_tensor.name = "response";
response_tensor.shape.assign({batch_size, size, 1});
response_tensor.dtype = PaddleDType::INT64;
TensorAssignData<int64_t>(&response_tensor, one_batch.response);
// response mask tensor assignment
response_mask_tensor.name = "response_mask";
response_mask_tensor.shape.assign({batch_size, size, 1});
response_mask_tensor.dtype = PaddleDType::FLOAT32;
TensorAssignData<float>(&response_mask_tensor, one_batch.response_mask);
// Set inputs.
for (int i = 0; i < MAX_TURN_NUM; ++i) {
input_slots->push_back(std::move(turns_tensor[i]));
}
for (int i = 0; i < MAX_TURN_NUM; ++i) {
input_slots->push_back(std::move(turns_mask_tensor[i]));
}
input_slots->push_back(std::move(response_tensor));
input_slots->push_back(std::move(response_mask_tensor));
}
void SetConfig(contrib::AnalysisConfig *cfg) {
cfg->prog_file = FLAGS_infer_model + "/__model__";
cfg->param_file = FLAGS_infer_model + "/param";
cfg->use_gpu = false;
cfg->device = 0;
cfg->specify_input_name = true;
cfg->enable_ir_optim = true;
}
void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
std::vector<PaddleTensor> input_slots;
int test_batch_num =
FLAGS_test_all_data ? data.num_samples / FLAGS_batch_size : 1;
LOG(INFO) << "The number of samples to be test: "
<< test_batch_num * FLAGS_batch_size;
for (int bid = 0; bid < test_batch_num; ++bid) {
input_slots.clear();
PrepareInputs(&input_slots, &data, FLAGS_batch_size);
(*inputs).emplace_back(input_slots);
}
}
// Easy for profiling independently.
TEST(Analyzer_dam, profile) {
contrib::AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<PaddleTensor> outputs;
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads);
if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
PADDLE_ENFORCE_GT(outputs.size(), 0);
size_t size = GetSize(outputs[0]);
PADDLE_ENFORCE_GT(size, 0);
float *result = static_cast<float *>(outputs[0].data.data());
for (size_t i = 0; i < size; i++) {
EXPECT_NEAR(result[i], result_data[i], 1e-3);
}
}
}
// Check the fuse status
TEST(Analyzer_dam, fuse_statis) {
contrib::AnalysisConfig cfg;
SetConfig(&cfg);
if (FLAGS_use_analysis) {
int num_ops;
auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
auto fuse_statis = GetFuseStatis(
static_cast<AnalysisPredictor *>(predictor.get()), &num_ops);
ASSERT_TRUE(fuse_statis.count("fc_fuse"));
EXPECT_EQ(fuse_statis.at("fc_fuse"), 317);
EXPECT_EQ(num_ops, 2020);
}
}
// Compare result of NativeConfig and AnalysisConfig
TEST(Analyzer_dam, compare) {
contrib::AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
if (FLAGS_use_analysis) {
CompareNativeAndAnalysis(cfg, input_slots_all);
}
}
} // namespace inference
} // namespace paddle
......@@ -20,7 +20,6 @@ using contrib::AnalysisConfig;
struct DataRecord {
std::vector<std::vector<int64_t>> word_data_all, mention_data_all;
std::vector<std::vector<int64_t>> rnn_word_datas, rnn_mention_datas;
std::vector<size_t> lod; // two inputs have the same lod info.
size_t batch_iter{0};
size_t batch_size{1};
......@@ -45,8 +44,6 @@ struct DataRecord {
CHECK(!data.mention_data_all.empty());
CHECK_EQ(data.word_data_all.size(), data.mention_data_all.size());
for (size_t j = 0; j < data.word_data_all.size(); j++) {
data.rnn_word_datas.push_back(data.word_data_all[j]);
data.rnn_mention_datas.push_back(data.mention_data_all[j]);
// calculate lod
data.lod.push_back(data.lod.back() + data.word_data_all[j].size());
}
......@@ -87,8 +84,8 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
lod_mention_tensor.shape.assign({size, 1});
lod_mention_tensor.lod.assign({one_batch.lod});
// assign data
TensorAssignData<int64_t>(&lod_word_tensor, one_batch.rnn_word_datas);
TensorAssignData<int64_t>(&lod_mention_tensor, one_batch.rnn_mention_datas);
TensorAssignData<int64_t>(&lod_word_tensor, one_batch.word_data_all);
TensorAssignData<int64_t>(&lod_mention_tensor, one_batch.mention_data_all);
// Set inputs.
input_slots->assign({lod_word_tensor, lod_mention_tensor});
for (auto &tensor : *input_slots) {
......
......@@ -160,6 +160,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv forward ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
for (int i = 0; i < groups; i++) {
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
......@@ -168,7 +169,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes,
&beta, cudnn_output_desc, output_data + i * group_offset_out));
};
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
}
}
};
......@@ -314,6 +315,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv backward data ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset input_grad.
......@@ -327,7 +329,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
data_algo, cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_input_desc, input_grad_data + i * group_offset_in));
};
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
}
}
// ------------------- cudnn conv backward filter ---------------------
......@@ -343,7 +345,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
filter_algo, cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_filter_desc, filter_grad_data + i * group_offset_filter));
};
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
}
}
}
......
......@@ -104,6 +104,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
int output_offset = output->numel() / output->dims()[0] / groups;
int filter_offset = filter->numel() / groups;
T alpha = 1.0f, beta = 0.0f;
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
for (int g = 0; g < groups; g++) {
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
......@@ -112,7 +113,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
algo, cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_output_desc, output_data + output_offset * g));
};
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
}
}
};
......@@ -208,6 +209,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
output_grad->numel() / output_grad->dims()[0] / groups;
int filter_offset = filter->numel() / groups;
T alpha = 1.0f, beta = 0.0f;
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset input_grad.
......@@ -220,7 +222,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc,
input_grad_data + input_offset * g));
};
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
}
}
......@@ -238,7 +240,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_filter_desc, filter_grad_data + filter_offset * g));
};
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
}
}
}
......
......@@ -75,7 +75,12 @@ if(WITH_GPU)
endif()
cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
cc_library(jit_kernel
SRCS jit_kernel.cc jit_gen.cc jit_code.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc
DEPS cpu_info cblas gflags enforce)
set(JIT_KERNEL_SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc)
set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce)
if(WITH_XBYAK)
list(APPEND JIT_KERNEL_SRCS jit_gen.cc jit_code.cc)
list(APPEND JIT_KERNEL_DEPS xbyak)
endif()
cc_library(jit_kernel SRCS ${JIT_KERNEL_SRCS} DEPS ${JIT_KERNEL_DEPS})
cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
......@@ -14,10 +14,13 @@ limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <string>
#include "paddle/fluid/operators/math/jit_code.h"
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/math/jit_code.h"
#endif
#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
......@@ -64,6 +67,7 @@ class VMulKernelImpl : public VMulKernel<T> {
static inline bool useMKL(int d) { return false; }
explicit VMulKernelImpl(int d) : VMulKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
// roughly estimate the size of code
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
......@@ -72,6 +76,7 @@ class VMulKernelImpl : public VMulKernel<T> {
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
return;
}
#endif
#ifdef PADDLE_WITH_MKLML
if (useMKL(d)) {
this->Compute = VMulMKL<T>;
......@@ -81,15 +86,21 @@ class VMulKernelImpl : public VMulKernel<T> {
this->Compute = VMulRefer<T>;
}
#ifdef PADDLE_WITH_XBYAK
private:
std::unique_ptr<gen::VMulJitCode> jitcode_{nullptr};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool VMulKernelImpl<float>::useJIT(int d) {
return gen::VMulJitCode::init(d);
}
#endif
#ifdef PADDLE_WITH_MKLML
template <>
bool VMulKernelImpl<float>::useMKL(int d) {
return jit::MayIUse(jit::avx512f) && d > 512;
......@@ -99,6 +110,7 @@ template <>
bool VMulKernelImpl<double>::useMKL(int d) {
return true;
}
#endif
REGISTER_JITKERNEL(vmul, VMulKernel);
......
......@@ -26,7 +26,7 @@ class RefByTrainerIdKernel : public framework::OpKernel<T> {
auto* out = context.Output<framework::Tensor>("Out");
auto in_list = context.MultiInput<framework::Tensor>("X");
auto* trainer_id_t = context.Input<framework::Tensor>("TrainerId");
int64_t trainer_id;
int64_t trainer_id = 0;
auto* trainer_id_data = trainer_id_t->data<int64_t>();
if (platform::is_gpu_place(context.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
......@@ -38,7 +38,6 @@ class RefByTrainerIdKernel : public framework::OpKernel<T> {
} else {
trainer_id = *trainer_id_data;
}
printf("after get trainer_id %lu\n", trainer_id);
PADDLE_ENFORCE_LT(trainer_id, in_list.size());
out->mutable_data<T>(context.GetPlace());
out->ShareDataWith(*(in_list[trainer_id]));
......
......@@ -179,7 +179,7 @@ class RmspropOpKernel : public framework::OpKernel<T> {
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
auto mg = EigenVector<T>::Flatten(mg_tensor);
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
PADDLE_ENFORCE_EQ(&mg_tensor, mean_grad_out,
"MeanGrad and MeanGradOut must be the same Tensor");
auto mg_out = EigenVector<T>::Flatten(*mean_grad_out);
......@@ -198,7 +198,7 @@ class RmspropOpKernel : public framework::OpKernel<T> {
if (centered) {
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
PADDLE_ENFORCE_EQ(&mg_tensor, mean_grad_out,
"MeanGrad and MeanGradOut must be the same Tensor");
for_range(CenteredRmspropFunctor<T, DenseRmspropGradFunctor<T>>(
param_out->mutable_data<T>(ctx.GetPlace()),
......@@ -243,7 +243,7 @@ class RmspropOpKernel : public framework::OpKernel<T> {
if (centered) {
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
PADDLE_ENFORCE_EQ(&mg_tensor, mean_grad_out,
"MeanGrad and MeanGradOut must be the same Tensor");
for_range(CenteredRmspropFunctor<T, SparseRmspropGradFunctor<T>>(
param_out->mutable_data<T>(ctx.GetPlace()),
......
......@@ -153,34 +153,20 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
mutable unsigned int* semaphore_;
};
class CudnnHolder {
public:
CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
CudnnHolder::CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
: workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) {
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
}
cudnnHandle_t cudnn_handle() const { return cudnn_handle_; }
void RunFunc(const std::function<void(void*)>& cudnn_func,
size_t required_workspace_len) {
std::lock_guard<std::mutex> lock(mtx_);
if (required_workspace_len > workspace_len_) {
ReallocateWorkspace(required_workspace_len);
}
cudnn_func(workspace_);
}
}
~CudnnHolder() {
CudnnHolder::~CudnnHolder() {
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
if (workspace_ != nullptr) {
paddle::memory::Free(place_, workspace_);
}
}
}
private:
void ReallocateWorkspace(size_t required_workspace_len) {
void CudnnHolder::ReallocateWorkspace(size_t required_workspace_len) {
if (required_workspace_len <= workspace_len_) {
return;
}
......@@ -191,17 +177,7 @@ class CudnnHolder {
}
workspace_ = paddle::memory::Alloc(place_, required_workspace_len);
workspace_len_ = required_workspace_len;
}
cudnnHandle_t cudnn_handle_;
void* workspace_;
size_t workspace_len_;
const cudaStream_t* stream_; // not owned;
const CUDAPlace place_;
std::mutex mtx_;
};
}
CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
: place_(place), cudnn_holder_(nullptr) {
......@@ -222,12 +198,12 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
driver_version_ = GetCUDADriverVersion(place_.device);
runtime_version_ = GetCUDARuntimeVersion(place_.device);
LOG(INFO) << "device: " << place_.device
LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: " << place_.device
<< ", CUDA Capability: " << compute_capability_
<< ", Driver Version: " << driver_version_ / 1000 << "."
<< (driver_version_ % 100) / 10
<< ", Runtime Version: " << runtime_version_ / 1000 << "."
<< (runtime_version_ % 100) / 10;
<< ", Driver Version: " << driver_version_ / 1000
<< "." << (driver_version_ % 100) / 10
<< ", Runtime Version: " << runtime_version_ / 1000
<< "." << (runtime_version_ % 100) / 10;
callback_manager_.reset(new StreamCallbackManager(stream_));
}
......@@ -269,9 +245,8 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
return cudnn_holder_->cudnn_handle();
}
void CUDADeviceContext::RunCudnnFuncWithWorkspace(
const std::function<void(void*)>& cudnn_func, size_t workspace_len) const {
cudnn_holder_->RunFunc(cudnn_func, workspace_len);
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
return CudnnWorkspaceHandle(cudnn_holder_.get());
}
cudaStream_t CUDADeviceContext::stream() const { return stream_; }
......
......@@ -73,7 +73,60 @@ struct DefaultDeviceContextType<platform::CPUPlace> {
#ifdef PADDLE_WITH_CUDA
class EigenCudaStreamDevice;
class CudnnHolder;
class CudnnHolder {
public:
CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place);
~CudnnHolder();
cudnnHandle_t cudnn_handle() const { return cudnn_handle_; }
private:
friend class CudnnWorkspaceHandle;
void ReallocateWorkspace(size_t required_workspace_len);
template <typename Callback>
void RunFuncImpl(Callback&& cudnn_func, size_t required_workspace_len) {
if (required_workspace_len > workspace_len_) {
ReallocateWorkspace(required_workspace_len);
}
cudnn_func(workspace_);
}
std::mutex& Mutex() { return mtx_; }
cudnnHandle_t cudnn_handle_;
void* workspace_;
size_t workspace_len_;
const cudaStream_t* stream_; // not owned;
const CUDAPlace place_;
std::mutex mtx_;
};
class CudnnWorkspaceHandle {
public:
/*! \brief The lock would not be acquired when constructor calls.
* The lock would be acquired when RunFunc() is called first time. */
inline explicit CudnnWorkspaceHandle(CudnnHolder* holder) : holder_(holder) {}
/*! \brief Thread which call RunFunc() would acquire the lock first
* before invoking cudnn functions. */
template <typename Callback>
inline void RunFunc(Callback&& cudnn_func, size_t required_workspace_len) {
if (!guard_) {
guard_.reset(new std::lock_guard<std::mutex>(holder_->Mutex()));
}
holder_->RunFuncImpl(std::forward<Callback>(cudnn_func),
required_workspace_len);
}
CudnnWorkspaceHandle(CudnnWorkspaceHandle&&) = default;
CudnnWorkspaceHandle& operator=(CudnnWorkspaceHandle&&) = delete;
private:
CudnnHolder* holder_; // not own
std::unique_ptr<std::lock_guard<std::mutex>> guard_;
};
class CUDADeviceContext : public DeviceContext {
public:
......@@ -101,10 +154,14 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t cudnn_handle() const;
/*! \brief Run a cudnn function with the workspace provided by
* CUDADeviceContext */
void RunCudnnFuncWithWorkspace(const std::function<void(void*)>& cudnn_func,
size_t workspace_len) const;
/*! \brief Return a cudnn workspace handle to call multiple cudnn
* functions without interrupting by other threads.
* Once the first cudnn function is called by the handle, a lock
* would be acquired to prevent other threads from accessing the
* workspace. Once the handle is destructed, the lock would be released.
* CudnnWorkspaceHandle is an RAII object to implement thread-safe
* sequential cudnn function calls. */
CudnnWorkspaceHandle cudnn_workspace_handle() const;
/*! \brief Return cuda stream in the device context. */
cudaStream_t stream() const;
......
......@@ -24,8 +24,6 @@
namespace paddle {
namespace platform {
using StreamCallback = std::function<void(cudaStream_t, cudaError_t)>;
class StreamCallbackManager;
struct StreamCallbackContext {
......@@ -35,7 +33,7 @@ struct StreamCallbackContext {
: manager_(manager), callback_(callback) {}
const StreamCallbackManager *manager_; // do not own
StreamCallback callback_;
std::function<void()> callback_;
};
class StreamCallbackManager {
......@@ -45,16 +43,18 @@ class StreamCallbackManager {
template <typename Callback>
inline void AddCallback(Callback &&callback) const {
AddCallbackWithStreamAndErrorInfo(
[=](cudaStream_t, cudaError_t) { callback(); });
}
template <typename Callback>
inline void AddCallbackWithStreamAndErrorInfo(Callback &&callback) const {
auto *stream_callback_context = new StreamCallbackContext(this, callback);
PADDLE_ENFORCE(cudaStreamAddCallback(
stream_, StreamCallbackManager::StreamCallbackFunc,
stream_callback_context, 0));
auto *stream_callback_context =
new StreamCallbackContext(this, std::forward<Callback>(callback));
PADDLE_ENFORCE(
#if CUDA_VERSION >= 10000
cudaLaunchHostFunc(stream_, StreamCallbackManager::StreamCallbackFunc,
stream_callback_context)
#else
cudaStreamAddCallback(stream_,
StreamCallbackManager::StreamCallbackFunc,
stream_callback_context, 0)
#endif
); // NOLINT
}
void Wait() const { thread_pool_.reset(new ThreadPool(1)); }
......@@ -63,17 +63,21 @@ class StreamCallbackManager {
const cudaStream_t stream_;
mutable std::unique_ptr<ThreadPool> thread_pool_;
// cudaStreamCallback cannot call CUDA API inside, so we have to use
// thread_pool here
// cudaStreamCallback cannot call CUDA API inside, so we have to use
// thread_pool here
#if CUDA_VERSION >= 10000
static void CUDART_CB StreamCallbackFunc(void *user_data)
#else
static void CUDART_CB StreamCallbackFunc(cudaStream_t stream,
cudaError_t status,
void *user_data) {
cudaError_t status, void *user_data)
#endif
{
auto *callback_context_ptr =
reinterpret_cast<StreamCallbackContext *>(user_data);
callback_context_ptr->manager_->thread_pool_->enqueue([=]() {
std::unique_ptr<StreamCallbackContext> callback_context(
callback_context_ptr);
callback_context->callback_(stream, status);
callback_context->callback_();
});
}
};
......
......@@ -821,13 +821,24 @@ All parameter, weight, gradient are variables in Paddle.
[](BuildStrategy &self, bool b) {
self.enable_data_balance_ = b;
}) // FIXME(chengudo): enable_data_balance seems not important
.def_property("enable_sequential_execution",
.def_property(
"enable_sequential_execution",
[](const BuildStrategy &self) {
return self.enable_sequential_execution_;
},
[](BuildStrategy &self, bool b) {
self.enable_sequential_execution_ = b;
})
},
R"DOC(The type is BOOL. If set True, the execution order of ops would be the same as what is in the program. Default False.)DOC")
.def_property(
"remove_unnecessary_lock",
[](const BuildStrategy &self) {
return self.remove_unnecessary_lock_;
},
[](BuildStrategy &self, bool b) {
self.remove_unnecessary_lock_ = b;
},
R"DOC(The type is BOOL. If set True, some locks in GPU ops would be released and ParallelExecutor would run faster. Default False.)DOC")
.def_property(
"fuse_elewise_add_act_ops",
[](const BuildStrategy &self) {
......
......@@ -86,6 +86,8 @@ if(WITH_DISTRIBUTE)
# FIXME(typhoonzero): add this back
#py_test_modules(test_dist_transformer MODULES test_dist_transformer)
#set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000)
# TODO(typhoonzero): make dist test parallel when fix port management issue
set_tests_properties(test_dist_mnist test_dist_word2vec test_dist_se_resnext test_dist_ctr test_dist_simnet_bow test_dist_save_load test_dist_text_classification test_dist_mnist_batch_merge PROPERTIES RUN_SERIAL TRUE)
endif(NOT APPLE)
py_test_modules(test_dist_transpiler MODULES test_dist_transpiler)
endif()
......
......@@ -18,6 +18,7 @@ import multiprocessing
import os
import unittest
import paddle.fluid as fluid
import paddle.fluid.core as core
import time
import numpy as np
import math
......@@ -82,6 +83,8 @@ class TestParallelExecutorBase(unittest.TestCase):
if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
build_strategy.enable_sequential_execution = enable_sequential_execution
if use_cuda and core.is_compiled_with_cuda():
build_strategy.remove_unnecessary_lock = True
if use_parallel_executor:
exe = fluid.ParallelExecutor(
......
......@@ -174,7 +174,6 @@ class TestCRFModel(unittest.TestCase):
print(pe.run(feed=feeder.feed(cur_batch),
fetch_list=[avg_cost.name])[0])
@unittest.skip(reason="CI hangs")
def test_update_sparse_parameter_all_reduce(self):
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
......@@ -183,7 +182,6 @@ class TestCRFModel(unittest.TestCase):
self.check_network_convergence(
is_sparse=True, build_strategy=build_strategy, use_cuda=False)
@unittest.skip(reason="CI hangs")
def test_update_dense_parameter_all_reduce(self):
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
......@@ -192,7 +190,6 @@ class TestCRFModel(unittest.TestCase):
self.check_network_convergence(
is_sparse=False, build_strategy=build_strategy, use_cuda=False)
@unittest.skip(reason="CI hangs")
def test_update_sparse_parameter_reduce(self):
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
......@@ -201,7 +198,6 @@ class TestCRFModel(unittest.TestCase):
self.check_network_convergence(
is_sparse=True, build_strategy=build_strategy, use_cuda=False)
@unittest.skip(reason="CI hangs")
def test_update_dense_parameter_reduce(self):
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
......
......@@ -1588,7 +1588,6 @@ to transpile() call.")
ref_inputs = []
for p, p_bak in self.param_bak_list:
if p.name == param_var.name:
print("#### ref inputs: ", param_var.name, p_bak.name)
ref_inputs.append(p_bak)
block.append_op(
type="ref_by_trainer_id",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册