diff --git a/benchmark/cluster/vgg16/vgg16_fluid.py b/benchmark/cluster/vgg16/vgg16_fluid.py
index 6c7d2c10363d3e311dfae455f3dd8fcfc51077a0..05b5f3977cbed2f08df73c6d8ba2fff687db3313 100644
--- a/benchmark/cluster/vgg16/vgg16_fluid.py
+++ b/benchmark/cluster/vgg16/vgg16_fluid.py
@@ -80,6 +80,8 @@ parser.add_argument(
     type=str,
     default="",
     help="Comma-separated list of hostname:port pairs")
+parser.add_argument(
+    "--profile", action='store_true', help="If set, profile a few steps.")
 
 # Flags for defining the tf.train.Server
 parser.add_argument(
@@ -183,8 +185,8 @@ def main():
             start_time = time.time()
             num_samples = 0
             train_pass_acc.reset()
-            for batch_id, data in enumerate(train_reader()):
-                ts = time.time()
+
+            def run_step(batch_id, data):
                 img_data = np.array(
                     map(lambda x: x[0].reshape(data_shape), data)).astype(
                         "float32")
@@ -196,14 +198,28 @@ def main():
                     feed={"pixel": img_data,
                           "label": y_data},
                     fetch_list=[avg_cost, batch_acc, batch_size])
+                return loss, acc, b_size
+
+            if args.profile and args.task_index == 0:
+                # warmup.
+                for batch_id, data in enumerate(train_reader()):
+                    if batch_id > 5: break
+                    run_step(batch_id, data)
+                with profiler.profiler('All', 'total', '/tmp/profile_vgg'):
+                    for batch_id, data in enumerate(train_reader()):
+                        if batch_id > 5: break
+                        run_step(batch_id, data)
+
+            for batch_id, data in enumerate(train_reader()):
+                ts = time.time()
+                loss, acc, b_size = run_step(batch_id, data)
                 iters += 1
                 num_samples += len(data)
                 train_pass_acc.add(value=acc, weight=b_size)
                 print(
-                    "Task:%d Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, "
-                    "Speed = %.2f img/s " % (args.task_index, pass_id, iters,
-                                             loss, acc,
-                                             len(data) / (time.time() - ts))
+                    "Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, "
+                    "Speed = %.2f img/s" % (pass_id, iters, loss, acc,
+                                            len(data) / (time.time() - ts))
                 )  # The accuracy is the accumulation of batches, but not the current batch.
 
             pass_elapsed = time.time() - start_time
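
The `--profile` branch above follows the usual recipe: run a few steps outside the profiler so one-off costs (allocator growth, cuDNN algorithm search) stay out of the trace, then profile a few steady-state steps. A minimal, self-contained sketch of the same pattern (the reader and step function are placeholders):

```python
import paddle.fluid.profiler as profiler

def profile_few_steps(train_reader, run_step, warmup=5, steps=5):
    # Warmup outside the profiler so one-off setup costs do not
    # pollute the trace.
    for batch_id, data in enumerate(train_reader()):
        if batch_id > warmup:
            break
        run_step(batch_id, data)
    # 'All' traces CPU and GPU; 'total' sorts the printed summary; the
    # raw timeline goes to /tmp/profile_vgg for tools/timeline.py below.
    with profiler.profiler('All', 'total', '/tmp/profile_vgg'):
        for batch_id, data in enumerate(train_reader()):
            if batch_id > steps:
                break
            run_step(batch_id, data)
```
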
diff --git a/paddle/fluid/inference/tensorrt/CMakeLists.txt b/paddle/fluid/inference/tensorrt/CMakeLists.txt
index 288789d6e484100820c937e6081701f1e9245706..c8b656394b403c4965e01e96c9215d9406091907 100644
--- a/paddle/fluid/inference/tensorrt/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/CMakeLists.txt
@@ -1,4 +1,5 @@
 nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
 nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda)
+nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
 set(ENGINE_FILE ${CMAKE_CURRENT_SOURCE_DIR}/engine.cc)
 add_subdirectory(convert)
diff --git a/paddle/fluid/inference/tensorrt/io_converter.cc b/paddle/fluid/inference/tensorrt/io_converter.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2baac96c26453af7e70e541d80b437df3d5c2657
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/io_converter.cc
@@ -0,0 +1,57 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/io_converter.h"
+#include <cuda.h>
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+using platform::is_gpu_place;
+using platform::is_cpu_place;
+
+class DefaultInputConverter : public EngineInputConverter {
+ public:
+  DefaultInputConverter() {}
+  // NOTE out is GPU memory.
+  virtual void operator()(const LoDTensor& in, void* out,
+                          size_t max_size) override {
+    PADDLE_ENFORCE(out != nullptr);
+    PADDLE_ENFORCE_LE(in.memory_size(), max_size);
+    const auto& place = in.place();
+    if (is_cpu_place(place)) {
+      PADDLE_ENFORCE(stream_ != nullptr);
+      PADDLE_ENFORCE_EQ(0,
+                        cudaMemcpyAsync(out, in.data<float>(), in.memory_size(),
+                                        cudaMemcpyHostToDevice, *stream_));
+
+    } else if (is_gpu_place(place)) {
+      PADDLE_ENFORCE_EQ(0,
+                        cudaMemcpyAsync(out, in.data<float>(), in.memory_size(),
+                                        cudaMemcpyDeviceToDevice, *stream_));
+
+    } else {
+      PADDLE_THROW("Unknown device for converter");
+    }
+    cudaStreamSynchronize(*stream_);
+  }
+};
+
+REGISTER_TENSORRT_INPUT_CONVERTER(mul, DefaultInputConverter);
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
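
`DefaultInputConverter` above only decides a copy direction: CPU-resident input is copied host-to-device into the engine's GPU buffer, GPU-resident input device-to-device, always on the caller's stream and followed by a synchronize so the engine never reads a half-written buffer. The same control flow in Python, with a stand-in copy helper (not a real API):

```python
def convert_input(place, data, out, max_size, stream):
    """Stand-in for DefaultInputConverter::operator(); `memcpy` below only
    records the direction a real cudaMemcpyAsync would use."""
    assert len(data) <= max_size  # PADDLE_ENFORCE_LE(in.memory_size(), max_size)

    def memcpy(kind):
        out[:len(data)] = data  # pretend async copy on `stream`
        return kind

    if place == 'cpu':
        kind = memcpy('HostToDevice')
    elif place == 'gpu':
        kind = memcpy('DeviceToDevice')
    else:
        raise RuntimeError('Unknown device for converter')
    # a real implementation would cudaStreamSynchronize(stream) here
    return kind

print(convert_input('cpu', bytearray(b'xy'), bytearray(4), 4, stream=None))
```
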
diff --git a/paddle/fluid/inference/tensorrt/io_converter.h b/paddle/fluid/inference/tensorrt/io_converter.h
new file mode 100644
index 0000000000000000000000000000000000000000..6ea61cbbac05f106f736b7d6a13912157c5ef48c
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/io_converter.h
@@ -0,0 +1,66 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <string>
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/inference/utils/singleton.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+using framework::LoDTensor;
+
+/*
+ * Convert Input from Fluid to an Engine.
+ * TensorRT's ITensor follows row major, NCHW. Fluid is also row major,
+ * so in most cases we just need to copy the data.
+ */
+class EngineInputConverter {
+ public:
+  EngineInputConverter() {}
+
+  virtual void operator()(const LoDTensor& in, void* out, size_t max_size) {}
+
+  void SetStream(cudaStream_t* stream) { stream_ = stream; }
+
+  static void Run(const std::string& in_op_type, const LoDTensor& in, void* out,
+                  size_t max_size, cudaStream_t* stream) {
+    PADDLE_ENFORCE(stream != nullptr);
+    auto* converter = Registry<EngineInputConverter>::Lookup(in_op_type);
+    PADDLE_ENFORCE_NOT_NULL(converter);
+    converter->SetStream(stream);
+    (*converter)(in, out, max_size);
+  }
+
+  virtual ~EngineInputConverter() {}
+
+ protected:
+  cudaStream_t* stream_{nullptr};
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+#define REGISTER_TENSORRT_INPUT_CONVERTER(in_op_type__, Converter__)          \
+  struct trt_input_##in_op_type__##_converter {                               \
+    trt_input_##in_op_type__##_converter() {                                  \
+      ::paddle::inference::Registry<EngineInputConverter>::Register<          \
+          Converter__>(#in_op_type__);                                        \
+    }                                                                         \
+  };                                                                          \
+  trt_input_##in_op_type__##_converter trt_input_##in_op_type__##_converter__;
diff --git a/paddle/fluid/inference/tensorrt/test_io_converter.cc b/paddle/fluid/inference/tensorrt/test_io_converter.cc
new file mode 100644
index 0000000000000000000000000000000000000000..365e9366862bee25c70dba0cdd92f318ab3ee90f
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/test_io_converter.cc
@@ -0,0 +1,55 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/inference/tensorrt/io_converter.h"
+
+#include <gtest/gtest.h>
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+class EngineInputConverterTester : public ::testing::Test {
+ public:
+  void SetUp() override { tensor.Resize({10, 10}); }
+
+  framework::LoDTensor tensor;
+};
+
+TEST_F(EngineInputConverterTester, DefaultCPU) {
+  void* buffer;
+  tensor.mutable_data<float>(platform::CPUPlace());
+  ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0);
+
+  cudaStream_t stream;
+  cudaStreamCreate(&stream);
+  EngineInputConverter::Run("mul", tensor, buffer, tensor.memory_size(),
+                            &stream);
+}
+
+TEST_F(EngineInputConverterTester, DefaultGPU) {
+  void* buffer;
+  tensor.mutable_data<float>(platform::CUDAPlace());
+  ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0);
+
+  cudaStream_t stream;
+  cudaStreamCreate(&stream);
+  EngineInputConverter::Run("mul", tensor, buffer, tensor.memory_size(),
+                            &stream);
+}
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tests/book/CMakeLists.txt b/paddle/fluid/inference/tests/book/CMakeLists.txt
index 97d9f03f88ad3e851a2dd4256d34e8ca76fdfb01..cc179a86256e6b552c08a091402157bdcc86b383 100644
--- a/paddle/fluid/inference/tests/book/CMakeLists.txt
+++ b/paddle/fluid/inference/tests/book/CMakeLists.txt
@@ -24,6 +24,11 @@ function(inference_test TARGET_NAME)
   endforeach()
 endfunction(inference_test)
 
+####################
+# Inference tests here depend on fluid/tests/book. If users want to run
+# an individual test with ctest, they need to run the tests in
+# fluid/tests/book first to generate the saved models.
+####################
 # This unittest is buggy!
 #inference_test(fit_a_line)
 inference_test(image_classification ARGS vgg resnet)
diff --git a/paddle/fluid/inference/utils/singleton.h b/paddle/fluid/inference/utils/singleton.h
new file mode 100644
index 0000000000000000000000000000000000000000..f05921067c45f156319375b613f51101cfda8e90
--- /dev/null
+++ b/paddle/fluid/inference/utils/singleton.h
@@ -0,0 +1,73 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <unordered_map>
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace inference {
+
+// NOTE not thread-safe.
+template <typename T>
+struct Singleton {
+  static T& Global() {
+    static T* x = new T;
+    return *x;
+  }
+
+  Singleton() = delete;
+  Singleton& operator=(const Singleton&) = delete;
+};
+
+/*
+ * A registry for any type.
+ * NOTE not thread-safe.
+ */
+template <typename ItemParent>
+struct Registry {
+  static Registry& Global() {
+    static auto* x = new Registry;
+    return *x;
+  }
+
+  template <typename ItemChild>
+  static void Register(const std::string& name) {
+    PADDLE_ENFORCE_EQ(items_.count(name), 0);
+    items_[name] = new ItemChild;
+  }
+
+  static ItemParent* Lookup(const std::string& name) {
+    auto it = items_.find(name);
+    if (it == items_.end()) return nullptr;
+    return it->second;
+  }
+
+  ~Registry() {
+    for (auto& item : items_) {
+      delete item.second;
+    }
+  }
+
+ private:
+  Registry() = default;
+  static std::unordered_map<std::string, ItemParent*> items_;
+};
+
+template <typename ItemParent>
+std::unordered_map<std::string, ItemParent*> Registry<ItemParent>::items_;
+
+}  // namespace inference
+}  // namespace paddle
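
Taken together with `io_converter.h`, this is the whole registration story: `REGISTER_TENSORRT_INPUT_CONVERTER(mul, DefaultInputConverter)` expands to a static struct whose constructor runs `Registry<EngineInputConverter>::Register<DefaultInputConverter>("mul")` at load time, and `EngineInputConverter::Run` later fetches the instance by op type. A compact Python analogue of that name-to-instance registry (illustrative only, not Paddle API):

```python
class Registry(object):
    """Name -> singleton instance, like inference::Registry<ItemParent>."""
    _items = {}

    @classmethod
    def register(cls, name, item_cls):
        # Mirrors PADDLE_ENFORCE_EQ(items_.count(name), 0).
        assert name not in cls._items, 'already registered: %s' % name
        cls._items[name] = item_cls()

    @classmethod
    def lookup(cls, name):
        return cls._items.get(name)  # None when not registered


class DefaultInputConverter(object):
    def __call__(self, tensor, out, max_size):
        pass  # would copy `tensor` into the engine buffer

# REGISTER_TENSORRT_INPUT_CONVERTER(mul, DefaultInputConverter) boils down to:
Registry.register('mul', DefaultInputConverter)
# and EngineInputConverter::Run("mul", ...) starts with:
converter = Registry.lookup('mul')
assert converter is not None
```
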
diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/detail/send_recv.proto
index 3b343a7e5675c1ef599c12a44063a7ae20317fff..9478c5702bcbf99fc88207b8c4843dbccf8a5925 100644
--- a/paddle/fluid/operators/detail/send_recv.proto
+++ b/paddle/fluid/operators/detail/send_recv.proto
@@ -70,6 +70,10 @@ message VariableMessage {
   bytes rows = 9;
   // Look up table block execution output variable name.
   string out_varname = 10;
+  // If true, the PS server starts profiling; it stops profiling and
+  // generates a profile to /tmp/profile_ps_* when `profile` switches
+  // from true to false.
+  bool profile = 11;
 }
 
 message VoidMessage {}
diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc
index bde418a45a74319b73d6460ac1bf020f89ef1790..7b746a880575139f7bffb7a1da5e706752aed468 100644
--- a/paddle/fluid/operators/detail/sendrecvop_utils.cc
+++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc
@@ -26,6 +26,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/detail/bytebuffer_stream.h"
 #include "paddle/fluid/operators/detail/proto_encoder_helper.h"
 #include "paddle/fluid/operators/detail/variable_response.h"
+#include "paddle/fluid/platform/profiler.h"
 
 namespace paddle {
 namespace operators {
@@ -48,6 +49,13 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
   void* payload = nullptr;
   size_t payload_size = 0;
   ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
+  // Note: normally the profiler is enabled in one trainer, hence only
+  // that trainer returns true for ShouldSendProfileState(). It tells the
+  // PS servers the trainer's profiling state so that the PS can follow
+  // the trainer.
+  if (platform::ShouldSendProfileState()) {
+    e.WriteBool(VarMsg::kProfileFieldNumber, platform::IsProfileEnabled());
+  }
   e.WriteString(VarMsg::kVarnameFieldNumber, name);
   if (var->IsType<framework::LoDTensor>()) {
     e.WriteUint64(VarMsg::kTypeFieldNumber, 0);
diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/detail/variable_response.cc
index 9cf6dd90fc03f706015870368f862a22a321aa15..71c5e807eb1fb4d04d07302e30e13a8ec8634dc9 100644
--- a/paddle/fluid/operators/detail/variable_response.cc
+++ b/paddle/fluid/operators/detail/variable_response.cc
@@ -20,6 +20,7 @@
 #ifdef PADDLE_WITH_CUDA
 #include <nccl.h>
 #endif
+#include "paddle/fluid/platform/profiler.h"
 
 #include "paddle/fluid/operators/detail/send_recv.pb.h"
 #include "paddle/fluid/operators/detail/sendrecvop_utils.h"
@@ -446,7 +447,26 @@ int VariableResponse::Parse(Source* source) {
         meta_.set_out_varname(temp);
         break;
       }
-
+      case sendrecv::VariableMessage::kProfileFieldNumber: {
+        bool profiling;
+        if (!input.ReadRaw(reinterpret_cast<void*>(&profiling), 1)) {
+          return tag;
+        }
+        meta_.set_profile(profiling);
+        int64_t listener_id = platform::ListenerId();
+        if (listener_id <= 0) {
+          break;
+        }
+        if (profiling && !platform::IsProfileEnabled()) {
+          platform::EnableProfiler(platform::ProfilerState::kCPU);
+        } else if (!profiling && platform::IsProfileEnabled()) {
+          // TODO(panyx0718): Should we allow customizing the file dir?
+          platform::DisableProfiler(
+              platform::EventSortingKey::kDefault,
+              string::Sprintf("/tmp/profile_ps_%lld", listener_id));
+        }
+        break;
+      }
       default: {
         // Unknown tag, return unknown error.
         return -1;
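
These three pieces are one feature: a trainer that toggles its profiler sets the `profile` bit on every variable it sends, and a PS marked as a listener mirrors that state, dumping `/tmp/profile_ps_<listener_id>` when the bit falls back to false. On the trainer nothing changes beyond enabling the profiler; a sketch, with the usual distributed-training boilerplate (`exe`, `feeder`, `train_reader`) passed in as assumptions:

```python
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler

def profile_trainer_side(exe, feeder, train_reader, steps=5):
    # Enable on one trainer only: EnableProfiler flips ShouldSendProfileState(),
    # so the sends below carry profile=true to every PS; leaving the `with`
    # block sends profile=false and each PS dumps /tmp/profile_ps_<listener_id>.
    with profiler.profiler('CPU', 'total', '/tmp/profile_trainer'):
        for batch_id, data in enumerate(train_reader()):
            if batch_id > steps:
                break
            exe.run(fluid.default_main_program(), feed=feeder.feed(data))
```
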
diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc
index 318d3a2ad3f569a881f5a2f5cf579fdcbd49262b..8acbf820250957163397342c645b333f0da0801c 100644
--- a/paddle/fluid/operators/listen_and_serv_op.cc
+++ b/paddle/fluid/operators/listen_and_serv_op.cc
@@ -18,6 +18,7 @@ limitations under the License. */
 #include <ostream>
 
 #include "paddle/fluid/operators/listen_and_serv_op.h"
+#include "paddle/fluid/platform/profiler.h"
 
 namespace paddle {
 namespace operators {
@@ -294,6 +295,8 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
 
 void ListenAndServOp::RunImpl(const framework::Scope &scope,
                               const platform::Place &dev_place) const {
+  // Mark this process as a PS: it decides profiling by listening to the trainer.
+  platform::SetProfileListener();
   platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
   auto &dev_ctx = *pool.Get(dev_place);
   framework::Scope &recv_scope = scope.NewScope();
diff --git a/paddle/fluid/platform/profiler.cc b/paddle/fluid/platform/profiler.cc
index 412cdda286c3a77af002fdc5eb6a5ae440606b82..cfddd8e8711f8005e0eff7ef7a2980f535b2f851 100644
--- a/paddle/fluid/platform/profiler.cc
+++ b/paddle/fluid/platform/profiler.cc
@@ -13,12 +13,15 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/platform/profiler.h"
+
 #include <sys/time.h>
 #include <time.h>
 #include <algorithm>
 #include <iomanip>
+#include <limits>
 #include <map>
 #include <mutex>  // NOLINT
+#include <random>
 #include <string>
 #ifdef PADDLE_WITH_CUDA
 #include <cuda.h>
 #endif
@@ -33,6 +36,9 @@ namespace platform {
 
 struct EventList;
 
+static int64_t profiler_listener_id = 0;
+static bool should_send_profile_state = false;
+
 // The profiler state, the initial value is ProfilerState::kDisabled
 static ProfilerState g_state = ProfilerState::kDisabled;
 // The thread local event list only can be accessed by the specific thread
@@ -219,13 +225,12 @@ void EnableProfiler(ProfilerState state) {
   PADDLE_ENFORCE(state != ProfilerState::kDisabled,
                  "Can't enbale profling, since the input state is ",
                  "ProfilerState::kDisabled");
-  PADDLE_ENFORCE(g_state == ProfilerState::kDisabled,
-                 "The profiling state should be disabled when calling ",
-                 "EnableProfiler.");
-  g_state = state;
-  if (g_state == ProfilerState::kAll) {
-    GetDeviceTracer()->Enable();
+  if (state == g_state) {
+    return;
   }
+  g_state = state;
+  should_send_profile_state = true;
+  GetDeviceTracer()->Enable();
 #ifdef PADDLE_WITH_CUDA
   if (g_state == ProfilerState::kCUDA) {
     // Generate some dummy events first to reduce the startup overhead.
@@ -435,8 +440,7 @@ void ParseEvents(const std::vector<std::vector<Event>>& events,
 
 void DisableProfiler(EventSortingKey sorted_key,
                      const std::string& profile_path) {
-  PADDLE_ENFORCE(g_state != ProfilerState::kDisabled,
-                 "Can't disable profiling, since it's not starting.");
+  if (g_state == ProfilerState::kDisabled) return;
   // Mark the profiling stop.
   Mark("_stop_profiler_", nullptr);
 
@@ -444,12 +448,25 @@ void DisableProfiler(EventSortingKey sorted_key,
   ParseEvents(all_events, sorted_key);
   ResetProfiler();
   DeviceTracer* tracer = GetDeviceTracer();
-  if (g_state == ProfilerState::kAll && tracer && tracer->IsEnabled()) {
+  if (tracer->IsEnabled()) {
     tracer->Disable();
     tracer->GenProfile(profile_path);
   }
   g_state = ProfilerState::kDisabled;
+  should_send_profile_state = true;
+}
+
+bool IsProfileEnabled() { return g_state != ProfilerState::kDisabled; }
+bool ShouldSendProfileState() { return should_send_profile_state; }
+
+void SetProfileListener() {
+  std::mt19937 rng;
+  rng.seed(std::random_device()());
+  std::uniform_int_distribution<int64_t> dist6(
+      1, std::numeric_limits<int64_t>::max());
+  profiler_listener_id = dist6(rng);
 }
+
+int64_t ListenerId() { return profiler_listener_id; }
 
 }  // namespace platform
 }  // namespace paddle
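
`SetProfileListener` only draws a random non-zero id; since trainers never call it, the `ListenerId() <= 0` check in `variable_response.cc` cleanly separates trainers from PS processes, and the id also namespaces the PS output file. Roughly, in Python terms:

```python
import random

_listener_id = 0  # stays 0 on trainers, i.e. "not a profile listener"

def set_profile_listener():
    """Analogue of platform::SetProfileListener()."""
    global _listener_id
    # Any value in [1, 2^63 - 1] works; it only has to be non-zero and
    # unlikely to collide across PS processes.
    _listener_id = random.SystemRandom().randint(1, 2**63 - 1)

def listener_id():
    return _listener_id

set_profile_listener()
print('/tmp/profile_ps_%d' % listener_id())
```
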
diff --git a/paddle/fluid/platform/profiler.h b/paddle/fluid/platform/profiler.h
index 428d9ebceaabd987261c1dcd6e66faf044b702c0..61b98143e41abb9e47d2c717c7876f1bab7f5077 100644
--- a/paddle/fluid/platform/profiler.h
+++ b/paddle/fluid/platform/profiler.h
@@ -114,5 +114,13 @@ void ResetProfiler();
 void DisableProfiler(EventSortingKey sorted_key,
                      const std::string& profile_path);
 
+// Test if the profiler is currently enabled.
+bool IsProfileEnabled();
+// Whether the trainer should send its profiling state to the PS.
+bool ShouldSendProfileState();
+// Mark the current process as a PS by assigning it a random listener id.
+void SetProfileListener();
+int64_t ListenerId();
+
 }  // namespace platform
 }  // namespace paddle
diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py
index 37d368946770978700abe49eef6825e1d96839f0..c8a435748dc5b51bf9e57b5b597e1422f0380e8e 100644
--- a/python/paddle/fluid/__init__.py
+++ b/python/paddle/fluid/__init__.py
@@ -60,6 +60,7 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ +\
     'io',
     'initializer',
     'layers',
+    'transpiler',
     'nets',
     'optimizer',
     'learning_rate_decay',
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 92c1ee2f2ab66476d0fed2e43a0b3569383ba3a5..9e26b2ce510829b3ea6825e46f6f8b087debf7c0 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -1042,13 +1042,14 @@ class Program(object):
 
         Returns(Program): The cloned Program object.
         """
-        p = Program()
         if for_test:
-            p.desc = core.inference_optimize(self.desc)
+            p = self.inference_optimize()
         else:
+            p = Program()
             p.desc = core.ProgramDesc(self.desc)
-        p.blocks = [Block(p, i) for i in xrange(self.desc.num_blocks())]
-        p.sync_with_cpp()
+            p.blocks = [Block(p, i) for i in xrange(self.desc.num_blocks())]
+            p.sync_with_cpp()
+        p.copy_param_info_from(self)
         return p
 
@@ -1061,7 +1062,7 @@ class Program(object):
             if isinstance(t, Variable):
                 # After transpiler processing, the op that output this
                 # variable maybe has been changed, so t.op is not reliable
-                # and we need to find the current op that generate this
+                # and we need to find the current op that generates this
                 # variable here.
                 t.op = None
                 global_block = self.global_block()
@@ -1087,8 +1088,16 @@ class Program(object):
         return res
 
     def inference_optimize(self):
+        # This is an alternative implementation, to be used until
+        # core.inference_optimize is fixed.
         res = Program()
-        res.desc = core.inference_optimize(self.desc)
+        res.desc = core.ProgramDesc(self.desc)
+        for i in xrange(res.desc.num_blocks()):
+            block = res.desc.block(i)
+            for j in xrange(block.op_size()):
+                op = block.op(j)
+                if op.has_attr('is_test'):
+                    op.set_attr('is_test', True)
         res.blocks = [Block(res, i) for i in xrange(res.desc.num_blocks())]
         res.sync_with_cpp()
         return res
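
The `clone(for_test=True)` rewrite keeps the documented contract: the clone flips `is_test` on every op that has the attribute (dropout, batch_norm, and the like) and now also copies parameter info from the source program. The usual idiom, as a sketch with a throwaway network:

```python
import paddle.fluid as fluid

img = fluid.layers.data(name='img', shape=[784], dtype='float32')
hidden = fluid.layers.fc(input=img, size=200, act='relu')
hidden = fluid.layers.dropout(hidden, dropout_prob=0.5)
loss = fluid.layers.mean(hidden)

# Derive the evaluation program before appending an optimizer; in the
# clone, dropout has is_test=True and runs in inference mode.
test_program = fluid.default_main_program().clone(for_test=True)
```
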
diff --git a/python/paddle/fluid/metrics.py b/python/paddle/fluid/metrics.py
index c618b02a768f2ca3e2b2914d8ee0134836d5c0d2..bb9c6fdc60089fc2b43573a6421a6f9781d2d4a8 100644
--- a/python/paddle/fluid/metrics.py
+++ b/python/paddle/fluid/metrics.py
@@ -251,7 +251,7 @@ class EditDistance(MetricBase):
         self.instance_error += seq_num - seq_right_count
         self.total_distance += total_distance
 
-    def eval():
+    def eval(self):
         if self.seq_num == 0:
             raise ValueError(
                 "There is no data in EditDistance Metric. Please check layers.edit_distance output has been added to EditDistance."
@@ -280,6 +280,7 @@ class DetectionMAP(MetricBase):
         super(DetectionMAP, self).__init__(name)
         # the current map value
         self.value = .0
+        self.weight = .0
 
     def update(self, value, weight):
         if not _is_number_or_matrix_(value):
@@ -340,8 +341,8 @@ class Auc(MetricBase):
             raise ValueError("The 'predictions' must be a numpy ndarray.")
 
         kepsilon = 1e-7  # to account for floating point imprecisions
-        thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
-                      for i in range(num_thresholds - 2)]
+        thresholds = [(i + 1) * 1.0 / (self._num_thresholds - 1)
+                      for i in range(self._num_thresholds - 2)]
         thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
 
         # caculate TP, FN, TN, FP count
@@ -358,19 +359,20 @@ class Auc(MetricBase):
                     fp += 1
                 else:
                     tn += 1
-            tp_list[idx_thresh] += tp
-            fn_list[idx_thresh] += fn
-            tn_list[idx_thresh] += tn
-            fp_list[idx_thresh] += fp
+            self.tp_list[idx_thresh] += tp
+            self.fn_list[idx_thresh] += fn
+            self.tn_list[idx_thresh] += tn
+            self.fp_list[idx_thresh] += fp
 
     def eval(self):
         epsilon = self._epsilon
         num_thresholds = self._num_thresholds
-        tpr = (tp_list.astype("float32") + epsilon) / (
-            tp_list + fn_list + epsilon)
-        fpr = fp_list.astype("float32") / (fp_list + tn_list + epsilon)
-        rec = (tp_list.astype("float32") + epsilon) / (
-            tp_list + fp_list + epsilon)
+        tpr = (self.tp_list.astype("float32") + epsilon) / (
+            self.tp_list + self.fn_list + epsilon)
+        fpr = self.fp_list.astype("float32") / (
+            self.fp_list + self.tn_list + epsilon)
+        rec = (self.tp_list.astype("float32") + epsilon) / (
+            self.tp_list + self.fp_list + epsilon)
 
         x = fpr[:num_thresholds - 1] - fpr[1:]
         y = (tpr[:num_thresholds - 1] + tpr[1:]) / 2.0
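
The metric fixes above are all of one kind: `update`/`eval` touched bare names (`tp_list`, `num_thresholds`, a missing `self`) where only instance attributes exist, so the metrics raised `NameError` on first use. A smoke test for the repaired `Auc` path (shapes and the two-column prediction layout are illustrative):

```python
import numpy as np
import paddle.fluid.metrics as metrics

auc = metrics.Auc('auc')
for _ in range(10):  # one update per mini-batch
    preds = np.random.random((32, 2))          # column 1 = positive score
    labels = np.random.randint(0, 2, (32, 1))
    auc.update(preds, labels)
print(auc.eval())
```
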
diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py
index 8252592c8ce0ea0a9959f882170d42bdc74e996a..1cbecd69e59882212d623b7fcaf741f0370a7a15 100644
--- a/python/paddle/fluid/trainer.py
+++ b/python/paddle/fluid/trainer.py
@@ -19,10 +19,11 @@ import executor
 import data_feeder
 import contextlib
 import io
+import transpiler
 
 # optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module
 import optimizer as opt_module
-import distribute_transpiler
+from transpiler import distribute_transpiler
 
 __all__ = [
     'Trainer',
diff --git a/python/setup.py.in b/python/setup.py.in
index a811b509a90b8b0d84451f54462a0308c062d022..c42601d335f01491156dc3591341c1a3213aecfe 100644
--- a/python/setup.py.in
+++ b/python/setup.py.in
@@ -68,7 +68,8 @@ packages=['paddle',
           'paddle.fluid',
           'paddle.fluid.proto',
           'paddle.fluid.proto.profiler',
-          'paddle.fluid.layers']
+          'paddle.fluid.layers',
+          'paddle.fluid.transpiler']
 
 if '${WITH_FLUID_ONLY}'== 'OFF':
     packages+=['paddle.proto',
diff --git a/tools/timeline.py b/tools/timeline.py
index f4083c824e7333a74661d096d4954609f767c83e..8cd6353d46f496831cb61c1cdbbd156ca0579fb4 100644
--- a/tools/timeline.py
+++ b/tools/timeline.py
@@ -22,7 +22,11 @@ import paddle.fluid.proto.profiler.profiler_pb2 as profiler_pb2
 
 parser = argparse.ArgumentParser(description=__doc__)
 parser.add_argument(
-    '--profile_path', type=str, default='', help='Input profile file name.')
+    '--profile_path',
+    type=str,
+    default='',
+    help='Input profile file name. If there are multiple files, the format '
+    'should be trainer1=file1,trainer2=file2,ps=file3')
 parser.add_argument(
     '--timeline_path', type=str, default='', help='Output timeline file name.')
 args = parser.parse_args()
@@ -108,8 +112,8 @@ class _ChromeTraceFormatter(object):
 
 
 class Timeline(object):
-    def __init__(self, profile_pb):
-        self._profile_pb = profile_pb
+    def __init__(self, profile_dict):
+        self._profile_dict = profile_dict
         self._pid = 0
         self._devices = dict()
         self._chrome_trace = _ChromeTraceFormatter()
@@ -120,35 +124,37 @@ class Timeline(object):
         return cur_pid
 
     def _allocate_pids(self):
-        for event in self._profile_pb.events:
-            if event.type == profiler_pb2.Event.CPU:
-                if (event.device_id, "CPU") not in self._devices:
-                    pid = self._allocate_pid()
-                    self._devices[(event.device_id, "CPU")] = pid
-                    self._chrome_trace.emit_pid("cpu:block:%d" %
-                                                (event.device_id), pid)
-            elif event.type == profiler_pb2.Event.GPUKernel:
-                if (event.device_id, "GPUKernel") not in self._devices:
-                    pid = self._allocate_pid()
-                    self._devices[(event.device_id, "GPUKernel")] = pid
-                    self._chrome_trace.emit_pid("gpu:%d" % (event.device_id),
-                                                pid)
+        for k, profile_pb in self._profile_dict.iteritems():
+            for event in profile_pb.events:
+                if event.type == profiler_pb2.Event.CPU:
+                    if (k, event.device_id, "CPU") not in self._devices:
+                        pid = self._allocate_pid()
+                        self._devices[(k, event.device_id, "CPU")] = pid
+                        self._chrome_trace.emit_pid("%s:cpu:block:%d" %
+                                                    (k, event.device_id), pid)
+                elif event.type == profiler_pb2.Event.GPUKernel:
+                    if (k, event.device_id, "GPUKernel") not in self._devices:
+                        pid = self._allocate_pid()
+                        self._devices[(k, event.device_id, "GPUKernel")] = pid
+                        self._chrome_trace.emit_pid("%s:gpu:%d" %
+                                                    (k, event.device_id), pid)
 
     def _allocate_events(self):
-        for event in self._profile_pb.events:
-            if event.type == profiler_pb2.Event.CPU:
-                type = "CPU"
-            elif event.type == profiler_pb2.Event.GPUKernel:
-                type = "GPUKernel"
-            pid = self._devices[(event.device_id, type)]
-            args = {'name': event.name}
-            if event.memcopy.bytes > 0:
-                args = {'mem_bytes': event.memcopy.bytes}
-            # TODO(panyx0718): Chrome tracing only handles ms. However, some
-            # ops takes micro-seconds. Hence, we keep the ns here.
-            self._chrome_trace.emit_region(
-                event.start_ns, (event.end_ns - event.start_ns) / 1.0, pid,
-                event.sub_device_id, 'Op', event.name, args)
+        for k, profile_pb in self._profile_dict.iteritems():
+            for event in profile_pb.events:
+                if event.type == profiler_pb2.Event.CPU:
+                    type = "CPU"
+                elif event.type == profiler_pb2.Event.GPUKernel:
+                    type = "GPUKernel"
+                pid = self._devices[(k, event.device_id, type)]
+                args = {'name': event.name}
+                if event.memcopy.bytes > 0:
+                    args = {'mem_bytes': event.memcopy.bytes}
+                # TODO(panyx0718): Chrome tracing only handles ms. However, some
+                # ops take micro-seconds. Hence, we keep the ns here.
+                self._chrome_trace.emit_region(
+                    event.start_ns, (event.end_ns - event.start_ns) / 1.0, pid,
+                    event.sub_device_id, 'Op', event.name, args)
 
     def generate_chrome_trace(self):
         self._allocate_pids()
@@ -163,11 +169,23 @@ timeline_path = '/tmp/timeline'
 if args.timeline_path:
     timeline_path = args.timeline_path
 
-with open(profile_path, 'r') as f:
-    profile_s = f.read()
-    profile_pb = profiler_pb2.Profile()
-    profile_pb.ParseFromString(profile_s)
-
-tl = Timeline(profile_pb)
+profile_paths = profile_path.split(',')
+profile_dict = dict()
+if len(profile_paths) == 1:
+    with open(profile_path, 'r') as f:
+        profile_s = f.read()
+        profile_pb = profiler_pb2.Profile()
+        profile_pb.ParseFromString(profile_s)
+    profile_dict['trainer'] = profile_pb
+else:
+    for profile_path in profile_paths:
+        k, v = profile_path.split('=')
+        with open(v, 'r') as f:
+            profile_s = f.read()
+            profile_pb = profiler_pb2.Profile()
+            profile_pb.ParseFromString(profile_s)
+        profile_dict[k] = profile_pb
+
+tl = Timeline(profile_dict)
 with open(timeline_path, 'w') as f:
     f.write(tl.generate_chrome_trace())
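
With pids keyed by `(process, device)` rather than device alone, the tool can now merge traces from a whole job into one view. Assuming the trainer profile from the `--profile` run above and a PS dump named by its listener id (both paths are placeholders), the invocation looks like:

```
python tools/timeline.py \
    --profile_path trainer0=/tmp/profile_vgg,ps=/tmp/profile_ps_12345 \
    --timeline_path /tmp/timeline
```

Then open chrome://tracing and load /tmp/timeline; each process appears as its own row group.
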