diff --git a/mace/ops/batch_norm_benchmark.cc b/mace/ops/batch_norm_benchmark.cc index e1b53b92e0bc36a471676f45f30aa4631d92e10c..c1ac84ef60e7c89ab2042f3815f18f1fbaf63da4 100644 --- a/mace/ops/batch_norm_benchmark.cc +++ b/mace/ops/batch_norm_benchmark.cc @@ -39,14 +39,14 @@ static void BatchNorm( // Warm-up for (int i = 0; i < 5; ++i) { net.RunOp(D); - net.Sync(); } + net.Sync(); mace::testing::StartTiming(); while (iters--) { net.RunOp(D); - net.Sync(); } + net.Sync(); } #define BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, DEVICE) \ diff --git a/mace/ops/ops_test_util.h b/mace/ops/ops_test_util.h index f8bb94c2d5fee156774222fff5d6a04e4cd4ca6b..4abde486559c3140820d5d144303d4b72ee73b82 100644 --- a/mace/ops/ops_test_util.h +++ b/mace/ops/ops_test_util.h @@ -141,6 +141,7 @@ class OpsTestNet { net_def.add_op()->CopyFrom(op_def_); VLOG(3) << net_def.DebugString(); net_ = CreateNet(net_def, &ws_, device); + device_ = device; return net_->Run(); } @@ -151,7 +152,7 @@ class OpsTestNet { } void Sync() { - if (net_) { + if (net_ && device_ == DeviceType::OPENCL) { OpenCLRuntime::Get()->command_queue().finish(); } } @@ -160,6 +161,7 @@ class OpsTestNet { Workspace ws_; OperatorDef op_def_; std::unique_ptr net_; + DeviceType device_; }; class OpsTestBase : public ::testing::Test {