From 594aabd15bcc48225e373e279a034c6febe9f305 Mon Sep 17 00:00:00 2001 From: liuqi Date: Tue, 31 Oct 2017 09:22:48 +0800 Subject: [PATCH] Change the way to sync at batch norm opencl benchmark. --- mace/ops/batch_norm_benchmark.cc | 4 ++-- mace/ops/ops_test_util.h | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/mace/ops/batch_norm_benchmark.cc b/mace/ops/batch_norm_benchmark.cc index e1b53b92..c1ac84ef 100644 --- a/mace/ops/batch_norm_benchmark.cc +++ b/mace/ops/batch_norm_benchmark.cc @@ -39,14 +39,14 @@ static void BatchNorm( // Warm-up for (int i = 0; i < 5; ++i) { net.RunOp(D); - net.Sync(); } + net.Sync(); mace::testing::StartTiming(); while (iters--) { net.RunOp(D); - net.Sync(); } + net.Sync(); } #define BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, DEVICE) \ diff --git a/mace/ops/ops_test_util.h b/mace/ops/ops_test_util.h index f8bb94c2..4abde486 100644 --- a/mace/ops/ops_test_util.h +++ b/mace/ops/ops_test_util.h @@ -141,6 +141,7 @@ class OpsTestNet { net_def.add_op()->CopyFrom(op_def_); VLOG(3) << net_def.DebugString(); net_ = CreateNet(net_def, &ws_, device); + device_ = device; return net_->Run(); } @@ -151,7 +152,7 @@ class OpsTestNet { } void Sync() { - if (net_) { + if (net_ && device_ == DeviceType::OPENCL) { OpenCLRuntime::Get()->command_queue().finish(); } } @@ -160,6 +161,7 @@ class OpsTestNet { Workspace ws_; OperatorDef op_def_; std::unique_ptr net_; + DeviceType device_; }; class OpsTestBase : public ::testing::Test { -- GitLab