diff --git a/paddle/fluid/framework/details/execution_strategy.h b/paddle/fluid/framework/details/execution_strategy.h
index 5183be878eb49cccc68603c3fdd8023be5578036..15c496130c2b6c7643ff96661be09e5ac4870344 100644
--- a/paddle/fluid/framework/details/execution_strategy.h
+++ b/paddle/fluid/framework/details/execution_strategy.h
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #pragma once
+#include <cstddef>  // for size_t
 
 namespace paddle {
 namespace framework {
@@ -26,6 +27,7 @@ struct ExecutionStrategy {
   bool allow_op_delay_{false};
   size_t num_iteration_per_drop_scope_{100};
   ExecutorType type_{kDefault};
+  bool dry_run_{false};
 };
 
 }  // namespace details
diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc
index 98fc390e72fab3701538fd6f974460fa5114fdb0..2b2329b9698908fdbe3385f1d555d756c47fc5c0 100644
--- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc
@@ -128,7 +128,9 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
     size_t complete = 0;
     while (op_to_run != nullptr) {
       try {
-        op_to_run->Run(strategy_.use_cuda_);
+        if (LIKELY(!strategy_.dry_run_)) {
+          op_to_run->Run(strategy_.use_cuda_);
+        }
         ++complete;
       } catch (...) {
         exception_.Catch(std::current_exception());
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
index dc63effd1b7c8fe5bb3fc91058eb855e552d3926..2d2bdb604f2d08adbaa0b38d04b8e377b2e6ab6c 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
@@ -211,7 +211,9 @@ void ThreadedSSAGraphExecutor::RunOp(
       if (VLOG_IS_ON(10)) {
         VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
       }
-      op->Run(strategy_.use_cuda_);
+      if (LIKELY(!strategy_.dry_run_)) {
+        op->Run(strategy_.use_cuda_);
+      }
       VLOG(10) << op << " " << op->Name() << " Done ";
       running_ops_--;
       ready_var_q->Extend(op->Outputs());
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
index dbb0b498d995a897b109bd4ef98521b2193276ed..5c0bc169eaf3f54596eb8e08b7bf80a82253c9b2 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
@@ -48,7 +48,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
   // Use topological sort algorithm
   FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
 
-  ~ThreadedSSAGraphExecutor() {}
+  ~ThreadedSSAGraphExecutor() final = default;
 
  private:
  void RunOp(const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q,
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index a45b9ec7a20ac3629d182f009b735d4d82fb5dc2..dfb107688ad7281765049cd9849d56b8a61bdd37 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -38,9 +38,20 @@ class ParallelExecutorPrivate {
   explicit ParallelExecutorPrivate(const std::vector<platform::Place> &places)
       : places_(places) {}
 
+  ~ParallelExecutorPrivate() {
+    if (own_local_scope_) {
+      for (size_t i = 1; i < local_scopes_.size(); ++i) {
+        // Skip the first scope, since it is the global scope.
+        Scope *local_scope = local_scopes_[i];
+        if (global_scope_->HasKid(local_scope)) {
+          global_scope_->DeleteScope(local_scope);
+        }
+      }
+    }
+  }
+
   std::vector<platform::Place> places_;
   std::vector<Scope *> local_scopes_;
-  Scope *global_scope_;
+  Scope *global_scope_;  // not owned
   std::unique_ptr<details::SSAGraphExecutor> executor_;
 
 #ifdef PADDLE_WITH_CUDA
@@ -306,16 +317,6 @@ ParallelExecutor::~ParallelExecutor() {
   for (auto &p : member_->places_) {
     platform::DeviceContextPool::Instance().Get(p)->Wait();
   }
-
-  if (member_->own_local_scope_) {
-    for (size_t i = 1; i < member_->local_scopes_.size(); ++i) {
-      Scope *local_scope = member_->local_scopes_[i];
-      if (member_->global_scope_->HasKid(local_scope)) {
-        member_->global_scope_->DeleteScope(local_scope);
-      }
-    }
-  }
-
   // member_ must be destructed before gcs_ since the destructor of
   // ReferenceCountOpHandle use raw pointers of gcs_ inside.
   member_.reset();
diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt
index d31c8e3b7d66a0cdb2c4725783c9a24f049c666d..e5678cf607a8ff3763e79c1f321a81c021846fb1 100644
--- a/paddle/fluid/inference/CMakeLists.txt
+++ b/paddle/fluid/inference/CMakeLists.txt
@@ -1,5 +1,5 @@
 if(WITH_TESTING)
-  include(test.cmake) # some generic cmake funtion for inference
+  include(tests/test.cmake) # some generic cmake function for inference
 endif()
 # analysis and tensorrt must be added before creating static library,
 # otherwise, there would be undefined reference to them in static library.
diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt
index b57a26b47026d1ecffab23b65c3eeb7de58f94eb..2ca84c80058b35840aff5d072cdc99ecf5165f8e 100644
--- a/paddle/fluid/inference/tests/api/CMakeLists.txt
+++ b/paddle/fluid/inference/tests/api/CMakeLists.txt
@@ -1,5 +1,11 @@
 set(INFERENCE_EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis_predictor)
 
+function(download_model install_dir model_name)
+  if (NOT EXISTS ${install_dir})
+    inference_download_and_uncompress(${install_dir} ${INFERENCE_URL} ${model_name})
+  endif()
+endfunction()
+
 function(download_model_and_data install_dir model_name data_name)
   if (NOT EXISTS ${install_dir})
     inference_download_and_uncompress(${install_dir} ${INFERENCE_URL} ${model_name})
@@ -13,6 +19,13 @@ function(inference_analysis_api_test target install_dir filename)
            ARGS --infer_model=${install_dir}/model --infer_data=${install_dir}/data.txt)
 endfunction()
 
+function(inference_analysis_api_test_with_fake_data target install_dir filename model_name)
+  download_model(${install_dir} ${model_name})
+  inference_analysis_test(${target} SRCS ${filename}
+    EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
+    ARGS --infer_model=${install_dir}/model)
+endfunction()
+
 # RNN1
 if(NOT APPLE)
   set(RNN1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/rnn1")
@@ -61,17 +74,13 @@ inference_analysis_api_test(test_analyzer_seq_conv1 ${SEQ_CONV1_INSTALL_DIR} ana
 # ocr
 set(OCR_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/ocr")
 if (NOT EXISTS ${OCR_INSTALL_DIR})
-    inference_download_and_uncompress(${OCR_INSTALL_DIR} "http://paddlemodels.cdn.bcebos.com/" "inference-vis-demos%2Focr.tar.gz")
+  inference_download_and_uncompress(${OCR_INSTALL_DIR} "http://paddlemodels.cdn.bcebos.com/" "inference-vis-demos%2Focr.tar.gz")
 endif()
 inference_analysis_api_test(test_analyzer_ocr ${OCR_INSTALL_DIR} analyzer_vis_tester.cc)
 
 # resnet50
-set(RESNET50_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/resnet50")
-if (NOT EXISTS ${RESNET50_INSTALL_DIR})
-  inference_download_and_uncompress(${RESNET50_INSTALL_DIR} ${INFERENCE_URL} "resnet50_model.tar.gz")
${INFERENCE_URL} "resnet50_model.tar.gz") -endif() -inference_analysis_test(test_analyzer_resnet50 SRCS analyzer_resnet50_tester.cc - EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} ARGS --infer_model=${RESNET50_INSTALL_DIR}/model) +inference_analysis_api_test_with_fake_data(test_analyzer_resnet50 + "${INFERENCE_DEMO_INSTALL_DIR}/resnet50" analyzer_resnet50_tester.cc "resnet50_model.tar.gz") # anakin if (WITH_ANAKIN AND WITH_MKL) # only needed in CI diff --git a/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc b/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc index c2151eea0823f80feb17b014c1f739d2a15ae862..e5c8dfd22a006d5271248c5b083aab2c22417502 100644 --- a/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc @@ -30,25 +30,7 @@ void SetConfig(AnalysisConfig *cfg) { } void SetInput(std::vector> *inputs) { - PADDLE_ENFORCE_EQ(FLAGS_test_all_data, 0, "Only have single batch of data."); - - PaddleTensor input; - // channel=3, height/width=318 - std::vector shape({FLAGS_batch_size, 3, 318, 318}); - input.shape = shape; - input.dtype = PaddleDType::FLOAT32; - - // fill input data, for profile easily, do not use random data here. - size_t size = FLAGS_batch_size * 3 * 318 * 318; - input.data.Resize(size * sizeof(float)); - float *input_data = static_cast(input.data.data()); - for (size_t i = 0; i < size; i++) { - *(input_data + i) = static_cast(i) / size; - } - - std::vector input_slots; - input_slots.assign({input}); - (*inputs).emplace_back(input_slots); + SetFakeImageInput(inputs, FLAGS_infer_model); } // Easy for profiling independently. @@ -61,13 +43,6 @@ void profile(bool use_mkldnn = false) { std::vector> input_slots_all; SetInput(&input_slots_all); TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads); - - if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) { - PADDLE_ENFORCE_EQ(outputs.size(), 1UL); - size_t size = GetSize(outputs[0]); - // output is a 512-dimension feature - EXPECT_EQ(size, 512 * FLAGS_batch_size); - } } TEST(Analyzer_resnet50, profile) { profile(); } @@ -83,8 +58,7 @@ TEST(Analyzer_resnet50, fuse_statis) { auto predictor = CreatePaddlePredictor(cfg); auto fuse_statis = GetFuseStatis( static_cast(predictor.get()), &num_ops); - ASSERT_TRUE(fuse_statis.count("fc_fuse")); - EXPECT_EQ(fuse_statis.at("fc_fuse"), 1); + LOG(INFO) << "num_ops: " << num_ops; } // Compare result of NativeConfig and AnalysisConfig diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h index 19c3f532d5dcb7588793fa21fa179f6b48649103..8c5888d8da7b33eeca77311c10dd818648e8e524 100644 --- a/paddle/fluid/inference/tests/api/tester_helper.h +++ b/paddle/fluid/inference/tests/api/tester_helper.h @@ -25,6 +25,7 @@ #include "paddle/fluid/inference/api/analysis_predictor.h" #include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/inference/api/paddle_inference_pass.h" +#include "paddle/fluid/inference/tests/test_helper.h" #include "paddle/fluid/platform/profiler.h" DEFINE_string(infer_model, "", "model path"); @@ -105,6 +106,34 @@ std::unordered_map GetFuseStatis(PaddlePredictor *predictor, return fuse_statis; } +void SetFakeImageInput(std::vector> *inputs, + const std::string &dirname) { + // Set fake_image_data + PADDLE_ENFORCE_EQ(FLAGS_test_all_data, 0, "Only have single batch of data."); + std::vector> feed_target_shapes = + GetFeedTargetShapes(dirname, true, "model", "params"); + int dim1 = feed_target_shapes[0][1]; + int dim2 = 
diff --git a/paddle/fluid/inference/test.cmake b/paddle/fluid/inference/tests/test.cmake
similarity index 100%
rename from paddle/fluid/inference/test.cmake
rename to paddle/fluid/inference/tests/test.cmake
diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h
index 94f0550df57e79fa68c135f5c9c4b7effe6ac156..2118fcfd4bb1589947617e462f09971fcc090b98 100644
--- a/paddle/fluid/inference/tests/test_helper.h
+++ b/paddle/fluid/inference/tests/test_helper.h
@@ -18,7 +18,6 @@ limitations under the License. */
 #include <string>
 #include <vector>
 
-#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/inference/io.h"
 #include "paddle/fluid/platform/profiler.h"
@@ -94,15 +93,15 @@ void CheckError(const paddle::framework::LoDTensor& output1,
 
 std::unique_ptr<paddle::framework::ProgramDesc> InitProgram(
     paddle::framework::Executor* executor, paddle::framework::Scope* scope,
-    const std::string& dirname, const bool is_combined = false) {
+    const std::string& dirname, const bool is_combined = false,
+    const std::string& prog_filename = "__model_combined__",
+    const std::string& param_filename = "__params_combined__") {
   std::unique_ptr<paddle::framework::ProgramDesc> inference_program;
   if (is_combined) {
     // All parameters are saved in a single file.
     // Hard-coding the file names of program and parameters in unittest.
     // The file names should be consistent with that used in Python API
     // `fluid.io.save_inference_model`.
-    std::string prog_filename = "__model_combined__";
-    std::string param_filename = "__params_combined__";
     inference_program =
         paddle::inference::Load(executor, scope, dirname + "/" + prog_filename,
                                 dirname + "/" + param_filename);
@@ -115,12 +114,15 @@ std::unique_ptr<paddle::framework::ProgramDesc> InitProgram(
 }
 
 std::vector<std::vector<int64_t>> GetFeedTargetShapes(
-    const std::string& dirname, const bool is_combined = false) {
+    const std::string& dirname, const bool is_combined = false,
+    const std::string& prog_filename = "__model_combined__",
+    const std::string& param_filename = "__params_combined__") {
   auto place = paddle::platform::CPUPlace();
   auto executor = paddle::framework::Executor(place);
   auto* scope = new paddle::framework::Scope();
 
-  auto inference_program = InitProgram(&executor, scope, dirname, is_combined);
+  auto inference_program = InitProgram(&executor, scope, dirname, is_combined,
+                                       prog_filename, param_filename);
   auto& global_block = inference_program->Block(0);
 
   const std::vector<std::string>& feed_target_names =
@@ -136,15 +138,6 @@ std::vector<std::vector<int64_t>> GetFeedTargetShapes(
   return feed_target_shapes;
 }
 
-void Compile(paddle::framework::ProgramDesc* program) {
-  std::unique_ptr<paddle::framework::ir::Graph> g(
-      new paddle::framework::ir::Graph(*program));
-  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
-      "graph_to_program_pass");
-  pass->SetNotOwned("program", program);
-  pass->Apply(std::move(g));
-}
-
 template <typename Place, bool CreateVars = true, bool PrepareContext = false>
 void TestInference(const std::string& dirname,
                    const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
@@ -182,7 +175,6 @@ void TestInference(const std::string& dirname,
         paddle::platform::DeviceContextPool::Instance().Get(place));
     inference_program = InitProgram(&executor, scope, dirname, is_combined);
   }
-  Compile(inference_program.get());
 
   // Disable the profiler and print the timing information
   paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault,
@@ -261,5 +253,3 @@ void TestInference(const std::string& dirname,
 
   delete scope;
 }
-
-USE_PASS(graph_to_program_pass);
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index fc821e04a0baf9278295da18ee5a69afcf2c4605..238cc19189cfd74afa38bdcb5f5c802f9521dfea 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -742,7 +742,12 @@ All parameter, weight, gradient are variables in Paddle.
               will clean up the temp variables at the end of the current iteration.
           2. In some NLP model, it may cause the GPU memory is insufficient,
               in this case, you should reduce `num_iteration_per_drop_scope`.
-      )DOC");
+      )DOC")
+      .def_property("_dry_run",
+                    [](const ExecutionStrategy &self) { return self.dry_run_; },
+                    [](ExecutionStrategy &self, bool dry_run) {
+                      self.dry_run_ = dry_run;
+                    });
 
   exec_strategy.def_property(
       "use_experimental_executor",
diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py
index 80b50022dd1ac5ec739029f6cfff3f7f170ada00..d1c926c4e4d41d55130a37e0bf2492f56fde0658 100644
--- a/python/paddle/fluid/layers/io.py
+++ b/python/paddle/fluid/layers/io.py
@@ -60,7 +60,7 @@ def data(name,
           For example if shape=[1], the resulting shape is [-1, 1].
        2. If shape contains -1, such as shape=[1, -1],
           append_batch_size will be enforced to be be False (ineffective).
-       dtype(int|float): The type of data : float32, float_16, int etc
+       dtype(basestring): The type of data, e.g., float32, float16, int.
        type(VarType): The output type. By default it is LOD_TENSOR.
        lod_level(int): The LoD Level. 0 means the input data is not a sequence.
        stop_gradient(bool): A boolean that mentions whether gradient should flow.
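Taken together with the pybind binding above, the new switch is reachable from Python. A minimal, self-contained sketch (illustrative only, mirroring the unit test added below): with `_dry_run` set, both SSA-graph executors schedule and visit every op but skip `OpHandle::Run`, so `ParallelExecutor`'s construction and scheduling paths can be exercised without real computation.

```python
import os
import paddle.fluid as fluid

os.environ['CPU_NUM'] = '1'  # ParallelExecutor on CPU reads CPU_NUM

# Build a tiny network on the default program.
img = fluid.layers.data(name='img', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
prediction = fluid.layers.fc(input=img, size=10, act='softmax')
loss = fluid.layers.mean(
    fluid.layers.cross_entropy(input=prediction, label=label))
fluid.optimizer.Adam().minimize(loss)

fluid.Executor(fluid.CPUPlace()).run(fluid.default_startup_program())

exe_strategy = fluid.ExecutionStrategy()
exe_strategy._dry_run = True  # the toggle added by this patch

pe = fluid.ParallelExecutor(
    use_cuda=False, loss_name=loss.name, exec_strategy=exe_strategy)
pe.run([])  # ops are scheduled and visited, but OpHandle::Run is skipped
```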
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_dry_run.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_dry_run.py
new file mode 100644
index 0000000000000000000000000000000000000000..c93740669f40aee3a6c143d153cfd0f5bb72dbd9
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_dry_run.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.fluid as fluid
+import unittest
+import logging
+import six
+
+
+class TestBase(unittest.TestCase):
+    def main(self,
+             network_func,
+             iter=100,
+             iter_per_pe=100,
+             use_gpu=True,
+             use_experimental_executor=False):
+        if use_gpu and not fluid.core.is_compiled_with_cuda():
+            logging.warning(
+                "Paddle is not compiled with CUDA, skip GPU unittests")
+            return
+
+        main_prog = fluid.Program()
+        startup_prog = fluid.Program()
+        scope = fluid.Scope()
+        with fluid.program_guard(main_prog, startup_prog):
+            with fluid.scope_guard(scope):
+                loss = network_func()
+                fluid.Executor(
+                    fluid.CUDAPlace(0)
+                    if use_gpu else fluid.CPUPlace()).run(startup_prog)
+
+                for _ in six.moves.xrange(iter):
+                    exe_strategy = fluid.ExecutionStrategy()
+                    exe_strategy._dry_run = True
+                    exe_strategy.use_experimental_executor = use_experimental_executor
+                    pe = fluid.ParallelExecutor(
+                        use_cuda=use_gpu,
+                        loss_name=loss.name,
+                        main_program=main_prog,
+                        exec_strategy=exe_strategy)
+                    for _ in six.moves.xrange(iter_per_pe):
+                        pe.run([])
+
+
+class TestMNISTDryRun(TestBase):
+    def test_mnist_dry_run(self):
+        for use_gpu in (False, True):
+            for use_experimental_executor in (False, True):
+                self.main(
+                    network_func=TestMNISTDryRun.network_func,
+                    use_gpu=use_gpu,
+                    use_experimental_executor=use_experimental_executor)
+
+    @staticmethod
+    def network_func():
+        img = fluid.layers.data(name='img', shape=[784], dtype='float32')
+        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+        hidden = img
+        for _ in six.moves.xrange(10):
+            hidden = fluid.layers.fc(input=hidden, size=200, act='tanh')
+        prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
+        loss = fluid.layers.cross_entropy(input=prediction, label=label)
+        avg_loss = fluid.layers.mean(loss)
+        fluid.optimizer.Adam().minimize(avg_loss)
+        return avg_loss
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py
index af3745987aa3eae96968bdc6b5c9cd951e9ca6fa..3eecc4670152e72443f731c71d7db67ca8e02e72 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py
@@ -14,30 +14,18 @@
 
 from __future__ import print_function
 
-from parallel_executor_test_base import TestParallelExecutorBase
-import paddle.fluid as fluid
-import paddle.fluid.core as core
-import numpy as np
-import paddle
-import paddle.dataset.mnist as mnist
 import unittest
-import os
 
-MNIST_RECORDIO_FILE = "./mnist_test_pe.recordio"
+import numpy as np
+import paddle.fluid.core as core
+import os
+import paddle.fluid as fluid
+from parallel_executor_test_base import TestParallelExecutorBase
 
 
 def simple_fc_net(use_feed):
-    if use_feed:
-        img = fluid.layers.data(name='image', shape=[784], dtype='float32')
-        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
-    else:
-        reader = fluid.layers.open_files(
-            filenames=[MNIST_RECORDIO_FILE],
-            shapes=[[-1, 784], [-1, 1]],
-            lod_levels=[0, 0],
-            dtypes=['float32', 'int64'])
-        reader = fluid.layers.io.double_buffer(reader)
-        img, label = fluid.layers.read_file(reader)
+    img = fluid.layers.data(name='image', shape=[784], dtype='float32')
+    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
     hidden = img
     for _ in range(4):
         hidden = fluid.layers.fc(
@@ -53,17 +41,8 @@ def simple_fc_net(use_feed):
 
 
 def fc_with_batchnorm(use_feed):
-    if use_feed:
-        img = fluid.layers.data(name='image', shape=[784], dtype='float32')
-        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
-    else:
-        reader = fluid.layers.open_files(
-            filenames=[MNIST_RECORDIO_FILE],
-            shapes=[[-1, 784], [-1, 1]],
-            lod_levels=[0, 0],
-            dtypes=['float32', 'int64'])
-        reader = fluid.layers.io.double_buffer(reader)
-        img, label = fluid.layers.read_file(reader)
+    img = fluid.layers.data(name='image', shape=[784], dtype='float32')
+    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
 
     hidden = img
     for _ in range(1):
@@ -88,19 +67,6 @@ class TestMNIST(TestParallelExecutorBase):
     @classmethod
     def setUpClass(cls):
         os.environ['CPU_NUM'] = str(4)
-        # Convert mnist to recordio file
-        with fluid.program_guard(fluid.Program(), fluid.Program()):
-            reader = paddle.batch(mnist.train(), batch_size=4)
-            feeder = fluid.DataFeeder(
-                feed_list=[  # order is image and label
-                    fluid.layers.data(
-                        name='image', shape=[784]),
-                    fluid.layers.data(
-                        name='label', shape=[1], dtype='int64'),
-                ],
-                place=fluid.CPUPlace())
-            fluid.recordio_writer.convert_reader_to_recordio_file(
-                MNIST_RECORDIO_FILE, reader, feeder)
 
     def _init_data(self):
         np.random.seed(5)
@@ -111,10 +77,6 @@ class TestMNIST(TestParallelExecutorBase):
     def _compare_reduce_and_allreduce(self, model, use_cuda):
         if use_cuda and not core.is_compiled_with_cuda():
             return
-        self.check_network_convergence(
-            model, use_cuda=use_cuda, use_reduce=True)
-        self.check_network_convergence(
-            model, use_cuda=use_cuda, allow_op_delay=True, use_reduce=True)
 
         img, label = self._init_data()
 
@@ -140,9 +102,6 @@ class TestMNIST(TestParallelExecutorBase):
     def check_simple_fc_convergence(self, use_cuda, use_reduce=False):
         if use_cuda and not core.is_compiled_with_cuda():
             return
-        self.check_network_convergence(simple_fc_net, use_cuda=use_cuda)
-        self.check_network_convergence(
-            simple_fc_net, use_cuda=use_cuda, allow_op_delay=True)
 
         img, label = self._init_data()
 
@@ -199,8 +158,6 @@ class TestMNIST(TestParallelExecutorBase):
         if use_cuda and not core.is_compiled_with_cuda():
             return
 
-        self.check_network_convergence(fc_with_batchnorm, use_cuda=use_cuda)
-
         img, label = self._init_data()
 
         self.check_network_convergence(
diff --git a/python/setup.py.in b/python/setup.py.in
index ee19294ad5c884cf73a4f14290f61f0b345ea8c7..b1ff9f3a5c3d877edb6bc6a12efce053a44b4c9c 100644
--- a/python/setup.py.in
+++ b/python/setup.py.in
@@ -14,7 +14,8 @@ RC = 0
 def git_commit():
     try:
         cmd = ['git', 'rev-parse', 'HEAD']
-        git_commit = subprocess.Popen(cmd, stdout = subprocess.PIPE).communicate()[0].strip()
+        git_commit = subprocess.Popen(cmd, stdout = subprocess.PIPE,
+            cwd="@PADDLE_SOURCE_DIR@").communicate()[0].strip()
     except:
         git_commit = 'Unknown'
     git_commit = git_commit.decode()
@@ -44,7 +45,7 @@ def get_patch():
 def is_taged():
     try:
         cmd = ['git', 'describe', '--exact-match', '--tags', 'HEAD', '2>/dev/null']
-        git_tag = subprocess.Popen(cmd, stdout = subprocess.PIPE).communicate()[0].strip()
+        git_tag = subprocess.Popen(cmd, stdout = subprocess.PIPE, cwd="@PADDLE_SOURCE_DIR@").communicate()[0].strip()
         git_tag = git_tag.decode()
     except:
         return False
@@ -55,8 +56,7 @@ def is_taged():
         return False
 
 def write_version_py(filename='paddle/version.py'):
-    cnt = '''
-# THIS FILE IS GENERATED FROM PADDLEPADDLE SETUP.PY
+    cnt = '''# THIS FILE IS GENERATED FROM PADDLEPADDLE SETUP.PY
 #
 full_version    = '%(major)d.%(minor)d.%(patch)s'
 major           = '%(major)d'
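A note on the setup.py.in change above: passing `cwd` pins the git commands to the source tree instead of whatever directory the build system happens to launch Python from. A minimal illustration (the path below is hypothetical; the real value is substituted from `@PADDLE_SOURCE_DIR@` at configure time):

```python
import subprocess

# Without cwd=..., `git rev-parse HEAD` runs in the current working
# directory (often an out-of-source build dir), where git may fail or
# find the wrong repository. Pinning cwd to the source tree makes the
# version lookup deterministic.
commit = subprocess.Popen(['git', 'rev-parse', 'HEAD'],
                          stdout=subprocess.PIPE,
                          cwd='/path/to/Paddle').communicate()[0].strip()
print(commit.decode())
```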