提交 0023c3bc 编写于 作者: Y Yu Yang

Use atomic bool

上级 09935ab9
...@@ -633,7 +633,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, ...@@ -633,7 +633,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
auto fetched_data = std::make_shared<FetchedData>(fetch_tensors.size()); auto fetched_data = std::make_shared<FetchedData>(fetch_tensors.size());
// Version --> VarHandle // Version --> VarHandle
member_->exception_.reset(); member_->exception_.reset();
std::unordered_map<VarHandleBase *, bool> pending_vars; std::unordered_map<VarHandleBase *, std::atomic<bool>> pending_vars;
std::unordered_map<OpHandle *, size_t> pending_ops; std::unordered_map<OpHandle *, size_t> pending_ops;
for (auto &place_pair : member_->vars_) { for (auto &place_pair : member_->vars_) {
...@@ -737,9 +737,9 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, ...@@ -737,9 +737,9 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
} }
void ParallelExecutor::RunOp( void ParallelExecutor::RunOp(
std::unordered_map<VarHandleBase *, bool> &pending_vars, std::unordered_map<VarHandleBase *, std::atomic<bool>> &pending_vars,
OpHandle *op) const { OpHandle *op) const {
std::vector<bool *> ready_buffer; std::vector<std::atomic<bool> *> ready_buffer;
for (auto *var : op->outputs_) { for (auto *var : op->outputs_) {
ready_buffer.emplace_back(&pending_vars[var]); ready_buffer.emplace_back(&pending_vars[var]);
} }
......
...@@ -60,8 +60,9 @@ class ParallelExecutor { ...@@ -60,8 +60,9 @@ class ParallelExecutor {
void BuildNCCLCommunicator() const; void BuildNCCLCommunicator() const;
void RunOp(std::unordered_map<VarHandleBase*, bool>& pending_vars, void RunOp(
OpHandle* op) const; std::unordered_map<VarHandleBase*, std::atomic<bool>>& pending_vars,
OpHandle* op) const;
void PolishGraphToSupportDataHarzaeds() const; void PolishGraphToSupportDataHarzaeds() const;
}; };
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
#include "cuda_runtime.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
TEST(Event, CpuElapsedTime) { TEST(Event, CpuElapsedTime) {
...@@ -157,3 +158,11 @@ TEST(RecordEvent, RecordEvent) { ...@@ -157,3 +158,11 @@ TEST(RecordEvent, RecordEvent) {
// Will remove parsing-related code from test later // Will remove parsing-related code from test later
DisableProfiler(EventSortingKey::kTotal, "/tmp/profiler"); DisableProfiler(EventSortingKey::kTotal, "/tmp/profiler");
} }
// Scratch test: creates a CUDA stream and synchronizes it three times in a
// row. Each CUDA runtime call is checked against cudaSuccess so a failure is
// reported instead of being silently ignored, and the stream is destroyed at
// the end so the test does not leak it.
TEST(TMP, stream_wait) {
  cudaStream_t stream;
  // ASSERT (not EXPECT): if creation fails, `stream` is not a valid handle and
  // nothing below may use it.
  ASSERT_EQ(cudaSuccess, cudaStreamCreate(&stream));

  // The original issues three back-to-back synchronizations; that count is
  // preserved. EXPECT is used so the stream is still released on failure.
  EXPECT_EQ(cudaSuccess, cudaStreamSynchronize(stream));
  EXPECT_EQ(cudaSuccess, cudaStreamSynchronize(stream));
  EXPECT_EQ(cudaSuccess, cudaStreamSynchronize(stream));

  // Release the stream created above.
  EXPECT_EQ(cudaSuccess, cudaStreamDestroy(stream));
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册