diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc index eb4e7ec52f907f9403e21ec2734d61824f51a58b..1d80bab90f513139f807b57258177c6b2ac53ac0 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" +#include #include #include #include "paddle/fluid/framework/executor.h" @@ -53,8 +54,14 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( } } } + std::vector fetch_data; + std::exception_ptr eptr; + try { + fetch_data = underlying_executor_->Run(fetch_tensors); + } catch (...) { + eptr = std::current_exception(); + } - auto fetch_data = underlying_executor_->Run(fetch_tensors); drop_scope_counter_ += 1; if (!fetch_tensors.empty() || drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) { @@ -69,7 +76,11 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( scope->DeleteScope(local_scope); } } - return fetch_data; + if (eptr) { + std::rethrow_exception(eptr); + } else { + return fetch_data; + } } } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 99b10254a7961bf7b27b256acaece573a71c4115..07097c7e75c6ce638549716cd6523f387cdefd92 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -78,6 +78,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( set.clear(); }; + // Clean run context + run_op_futures_.clear(); + exception_.reset(); + // Step 3. Execution while (!pending_vars.empty()) { // 1. 
Run All Ready ops @@ -96,16 +100,19 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( auto cur_ready_vars = ready_vars.PopAll(1, &timeout); if (timeout) { - std::lock_guard l(exception_mu_); + std::unique_lock l(exception_mu_); if (exception_) { + l.unlock(); + for (auto &run_op_future : run_op_futures_) { + run_op_future.wait(); + } + l.lock(); std::exception *exp = exception_.get(); if (dynamic_cast(exp)) { auto e = *static_cast(exp); - exception_.reset(); throw e; } else if (dynamic_cast(exp)) { auto e = *static_cast(exp); - exception_.reset(); throw e; } else { LOG(FATAL) << "Unknown exception."; @@ -222,7 +229,7 @@ void ThreadedSSAGraphExecutor::RunOp( } }; if (pool_) { - pool_->enqueue(op_run); + run_op_futures_.emplace_back(pool_->enqueue(op_run)); } else { op_run(); } diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index c69e0487e2e503a0d445300aa2fd6bb9c30b06c9..09973b7a72881464ad9e7776d4aad3d2261a118d 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include #include @@ -77,6 +78,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { private: ExecutionStrategy strategy_; + // use std::list because clear(), push_back, and for_each are O(1) + std::list> run_op_futures_; }; } // namespace details diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index ab1d2143330fb8cbfd535758a83bc71de939c4e0..bc07bbe67e3dafc08aeb62cd75629966c216ce6e 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -265,6 +265,8 @@ op_library(recurrent_op DEPS executor) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) op_library(cos_sim_op DEPS cos_sim_functor) op_library(parallel_do_op DEPS executor) +op_library(unsqueeze_op DEPS reshape_op) +op_library(squeeze_op DEPS reshape_op) if (WITH_GPU) op_library(conv_op DEPS vol2col depthwise_conv im2col) diff --git a/paddle/fluid/operators/reader/create_batch_reader_op.cc b/paddle/fluid/operators/reader/create_batch_reader_op.cc index 1dbafd23e92732bdaf0d263a01e267227786d839..e17c2ffd39eea31fe85933eda144ab97cf8c3dd8 100644 --- a/paddle/fluid/operators/reader/create_batch_reader_op.cc +++ b/paddle/fluid/operators/reader/create_batch_reader_op.cc @@ -23,7 +23,7 @@ class BatchReader : public framework::DecoratedReader { BatchReader(const std::shared_ptr& reader, int batch_size, bool discard_leftover) : DecoratedReader(reader), - batch_size_(batch_size), + batch_size_(static_cast(batch_size)), discard_leftover_(discard_leftover) { buffer_.reserve(batch_size_); } @@ -31,7 +31,7 @@ class BatchReader : public framework::DecoratedReader { void ReadNextImpl(std::vector* out) override; private: - int batch_size_; + size_t batch_size_; bool discard_leftover_; std::vector> buffer_; }; @@ -78,7 +78,7 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase { void BatchReader::ReadNextImpl(std::vector* out) { buffer_.clear(); buffer_.reserve(batch_size_); - for (int i = 0; i < batch_size_; ++i) { + for (size_t i = 0; i < batch_size_; ++i) { buffer_.push_back(std::vector()); reader_->ReadNext(&buffer_.back()); if (buffer_.back().empty()) { @@ -95,9 +95,9 @@ void BatchReader::ReadNextImpl(std::vector* out) { // if buffer_ is empty, the 'out' will return as an empty vector. 
return; } - int out_num = buffer_[0].size(); + size_t out_num = buffer_[0].size(); out->reserve(out_num); - for (int j = 0; j < out_num; ++j) { + for (size_t j = 0; j < out_num; ++j) { // Merge shape and check date type std::type_index batch_type = buffer_[0][j].type(); framework::DDim batch_shape = buffer_[0][j].dims(); diff --git a/paddle/fluid/operators/squeeze_op.cc b/paddle/fluid/operators/squeeze_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..6c507baf3a0ab0a557d29a53700685753616193b --- /dev/null +++ b/paddle/fluid/operators/squeeze_op.cc @@ -0,0 +1,202 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class SqueezeOpInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SqueezeOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SqueezeOp should not be null."); + + const auto &x_dims = ctx->GetInputDim("X"); + // Check input tensor dims (<6) Eigen limit. + PADDLE_ENFORCE(x_dims.size() <= 6, + "Invalid dimnesions, the rank of Input(X) " + "should be in the range of [1, 6] (Eigen limit)."); + + const auto &axes = ctx->Attrs().Get>("axes"); + for (int a : axes) { + PADDLE_ENFORCE_LT(a, x_dims.size(), + "The squeeze axis should be less than input " + "tensor's rank."); + } + + auto out_dims = GetOutputShape(axes, x_dims); + ctx->SetOutputDim("Out", out_dims); + if (x_dims[0] == out_dims[0]) { + // Only pass LoD when the first dimension of output and Input(X) + // are the same. + ctx->ShareLoD("X", "Out"); + } + } + + static framework::DDim GetOutputShape(const std::vector squeeze_dims, + const framework::DDim &in_dims) { + size_t num_squeeze_dims = squeeze_dims.size(); + int cnt_squeezed_dims = 0; + bool should_squeeze[9] = {false}; + + // Determines number of dimensions of output tensor after squeeze. + // Mark and count the dimensions need to be squeezed + if (num_squeeze_dims == 0) { + for (int idx = 0; idx < in_dims.size(); ++idx) { + if (in_dims[idx] == 1) { + should_squeeze[idx] = true; + ++cnt_squeezed_dims; + } + } + } else { + for (size_t idx = 0; idx < num_squeeze_dims; ++idx) { + int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size() + : squeeze_dims[idx]; + // Check current index, the upper limit has beed checked in line 36. 
+ PADDLE_ENFORCE(current >= 0, + "Invalid axis, the negative axis is out of range."); + PADDLE_ENFORCE(in_dims[current] == 1, + "Invalid axis index, the axis that will be squeezed " + "should be equal to 1."); + + if (!(should_squeeze[current])) { + ++cnt_squeezed_dims; + } + should_squeeze[current] = true; + } + } + + // Make output dimensions + std::vector output_shape(in_dims.size() - cnt_squeezed_dims, 0); + for (int in_idx = 0, out_idx = 0; in_idx < in_dims.size(); ++in_idx) { + if (!should_squeeze[in_idx]) { + output_shape[out_idx++] = in_dims[in_idx]; + } + } + + return framework::make_ddim(output_shape); + } +}; + +class SqueezeOp : public framework::OperatorBase { + public: + using OperatorBase::OperatorBase; + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { + auto &axes = Attr>("axes"); + auto x_dims = scope.FindVar(Input("X"))->Get().dims(); + auto out_dims = SqueezeOpInferShape::GetOutputShape(axes, x_dims); + + framework::AttributeMap attrs; + attrs["shape"] = framework::vectorize2int(out_dims); + attrs["inplace"] = Attr("inplace"); + // Invoke Reshape Op + auto reshape_op = framework::OpRegistry::CreateOp( + "reshape", {{"X", {Input("X")}}, {"Shape", {}}}, + {{"Out", {Output("Out")}}}, attrs); + reshape_op->Run(scope, place); + } +}; + +class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor). The input tensor of squeeze operator."); + AddOutput("Out", "(Tensor). The output tensor of squeeze operator."); + AddAttr>("axes", + "(std::vector). List of integers," + " indicating the dimensions to squeeze.") + .SetDefault({}); + AddAttr("inplace", + "(default: false) Squeeze the source tensor's shape without " + "memory copy. When Attr(inplace) is set true, the output " + "tensor shares memory with Input(X), otherwise, a new output " + "tensor is created, and its data are copied from Input(x).") + .SetDefault(false); + AddComment(R"DOC( + Squeeze Operator. + + Remove single-dimensional entries from the shape of a tensor. + Takes a parameter axes with a list of axes to squeeze. + If axes is not provided, all the single dimensions will be removed from the shape. + If an axis is selected with shape entry not equal to one, an error is raised. 
+ + Examples: + Case 1: + Given + X.shape = (1, 3, 1, 5) + and + axes = [0] + we get: + Out.shape = (3, 1, 5) + + Case 2: + Given + X.shape = (1, 3, 1, 5) + and + axes = [] + we get: + Out.shape = (3, 5) + )DOC"); + } +}; + +class SqueezeGradInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *context) const override { + context->SetOutputDim(framework::GradVarName("X"), + context->GetInputDim("X")); + context->ShareLoD("X", framework::GradVarName("X")); + } +}; + +class SqueezeGradOp : public framework::OperatorBase { + public: + using OperatorBase::OperatorBase; + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { + auto dx_name = Output(framework::GradVarName("X")); + auto dout_name = Input(framework::GradVarName("Out")); + auto x_dims = scope.FindVar(Input("X"))->Get().dims(); + framework::AttributeMap attrs; + attrs["shape"] = framework::vectorize2int(x_dims); + attrs["inplace"] = Attr("inplace"); + + auto reshape_op = framework::OpRegistry::CreateOp( + "reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}}, + attrs); + reshape_op->Run(scope, place); + } +}; + +} // namespace operators +} // namespace paddle + +// Tell linker to use reshape op +USE_OP(reshape); + +namespace ops = paddle::operators; +REGISTER_OPERATOR(squeeze, ops::SqueezeOp, ops::SqueezeOpMaker, + ops::SqueezeOpInferShape, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(squeeze_grad, ops::SqueezeGradOp, ops::SqueezeGradInferShape); diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..f2a15fdf572e0de30f9949dda5020e130b0c5585 --- /dev/null +++ b/paddle/fluid/operators/unsqueeze_op.cc @@ -0,0 +1,191 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class UnsqueezeOpInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of UnsqueezeOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of UnsqueezeOp should not be null."); + + const auto &axes = ctx->Attrs().Get>("axes"); + const auto &x_dims = ctx->GetInputDim("X"); + // Validity Check: input tensor dims (<6). + PADDLE_ENFORCE(x_dims.size() <= 6, + "Invalid dimensions, the rank of Input(X) " + "should be in the range of [1, 6] (Eigen limit)"); + auto out_dims = GetOutputShape(axes, x_dims); + ctx->SetOutputDim("Out", out_dims); + if (x_dims[0] == out_dims[0]) { + // Only pass LoD when the first dimension of output and Input(X) + // are the same. 
+ ctx->ShareLoD("X", "Out"); + } + } + + static framework::DDim GetOutputShape(const std::vector unsqz_dims, + const framework::DDim &in_dims) { + int output_size = in_dims.size() + static_cast(unsqz_dims.size()); + int cur_output_size = in_dims.size(); + std::vector output_shape(output_size, 0); + + // Validity Check: rank range. + PADDLE_ENFORCE(output_size <= 6, + "The output tensor's rank should be less than 6."); + + for (int axis : unsqz_dims) { + int cur = axis < 0 ? axis + cur_output_size + 1 : axis; + // Vaildity Check: the axis bound + PADDLE_ENFORCE( + cur >= 0 && cur <= cur_output_size, + "The unsqueeze dims must be within range of current rank."); + // Move old axis, and insert new axis + for (int i = cur_output_size; i >= cur; --i) { + if (output_shape[i] == 1) { + // Move axis + output_shape[i + 1] = 1; + output_shape[i] = 0; + } + } + output_shape[cur] = 1; + // Add the output size. + cur_output_size++; + } + + // Make output shape + for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) { + if (output_shape[out_idx] == 0) { + output_shape[out_idx] = in_dims[in_idx++]; + } + } + + return framework::make_ddim(output_shape); + } +}; + +class UnsqueezeOp : public framework::OperatorBase { + public: + using OperatorBase::OperatorBase; + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { + auto &axes = Attr>("axes"); + auto x_dims = scope.FindVar(Input("X"))->Get().dims(); + auto out_dims = UnsqueezeOpInferShape::GetOutputShape(axes, x_dims); + + framework::AttributeMap attrs; + attrs["shape"] = framework::vectorize2int(out_dims); + attrs["inplace"] = Attr("inplace"); + // Invoke Reshape op. + auto reshape_op = framework::OpRegistry::CreateOp( + "reshape", {{"X", {Input("X")}}, {"Shape", {}}}, + {{"Out", {Output("Out")}}}, attrs); + reshape_op->Run(scope, place); + } +}; + +class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor). The input tensor of unsqueeze operator."); + AddOutput("Out", "(Tensor). The output tensor of unsqueeze operator."); + AddAttr>("axes", + "(std::vector). List of integers," + " indicating the dimensions to be inserted") + .AddCustomChecker([](const std::vector &axes) { + PADDLE_ENFORCE(!axes.empty(), + "Invalid axes, The unsqueeze axes is empty."); + // Validity Check: axes dims (<6). + PADDLE_ENFORCE(static_cast(axes.size()) < 6, + "Invalid dimensions, dynamic dimensions should be " + "within [1, 6] dimensions (Eigen limit)."); + // Validity Check: the range of unsqueeze aixs. + for (int axis : axes) { + PADDLE_ENFORCE(axis < 6, + "Invalid dimensions, input axis should be" + " within [1, 6] dimensions (Eigen limit)."); + } + }); + AddAttr( + "inplace", + "(default: false) Unsqueeze the source tensor's shape without " + "memory copy. When Attr(inplace) is set true, the output " + "tensor shares memory with Input(X), otherwise, a new output " + "tensor is created, and its data are copied from Input(x).") + .SetDefault(false); + AddComment(R"DOC( + Unsqueeze Operator. + + Insert single-dimensional entries to the shape of a tensor. + Takes one required argument axes, a list of dimensions that will be inserted. + Dimension indices in axes are as seen in the output tensor. 
+ + For example: + Given a tensor such that tensor with shape [3, 4, 5], + then Unsqueeze(tensor, axes=[0, 4]) has shape [1, 3, 4, 5, 1] + )DOC"); + } +}; + +class UnsqueezeGradInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + ctx->ShareLoD("X", framework::GradVarName("X")); + } +}; + +class UnsqueezeGradOp : public framework::OperatorBase { + public: + using OperatorBase::OperatorBase; + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { + auto dx_name = Output(framework::GradVarName("X")); + auto dout_name = Input(framework::GradVarName("Out")); + auto x_dims = scope.FindVar(Input("X"))->Get().dims(); + + framework::AttributeMap attrs; + attrs["shape"] = framework::vectorize2int(x_dims); + attrs["inplace"] = Attr("inplace"); + + auto reshape_op = framework::OpRegistry::CreateOp( + "reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}}, + attrs); + reshape_op->Run(scope, place); + } +}; + +} // namespace operators +} // namespace paddle + +// Tell linker to use reshape op. +USE_OP(reshape); + +namespace ops = paddle::operators; +REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker, + ops::UnsqueezeOpInferShape, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp, + ops::UnsqueezeGradInferShape); diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index bec161a571d0f98b99ac656f08cbc2d401c943fc..25900811509aee8b37fdaf09cf902ea2ae3eee57 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -92,8 +92,15 @@ install(DIRECTORY ${PADDLE_PYTHON_PACKAGE_DIR} DESTINATION opt/paddle/share/wheels ) -find_program(PATCHELF_EXECUTABLE patchelf) -if(NOT PATCHELF_EXECUTABLE) - message(FATAL_ERROR "patchelf not found, please install it.\n" - "For Ubuntu, the command is: apt-get install -y patchelf.") -endif() +if(APPLE) + find_program(INSTALL_NAME_TOOL_EXECUTABLE install_name_tool) + if(NOT INSTALL_NAME_TOOL_EXECUTABLE) + message(FATAL_ERROR "install_name_tool not found, please check.\n") + endif() +else(APPLE) + find_program(PATCHELF_EXECUTABLE patchelf) + if(NOT PATCHELF_EXECUTABLE) + message(FATAL_ERROR "patchelf not found, please install it.\n" + "For Ubuntu, the command is: apt-get install -y patchelf.") + endif() +endif(APPLE) diff --git a/python/paddle/fluid/annotations.py b/python/paddle/fluid/annotations.py new file mode 100644 index 0000000000000000000000000000000000000000..bb8756a4664013643c278c013ca21bb237a6b4a7 --- /dev/null +++ b/python/paddle/fluid/annotations.py @@ -0,0 +1,38 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import sys + +__all__ = ['deprecated'] + + +def deprecated(since, instead, extra_message=""): + def decorator(func): + err_msg = "API {0} is deprecated since {1}. 
Please use {2} instead.".format( + func.__name__, since, instead) + if len(extra_message) != 0: + err_msg += "\n" + err_msg += extra_message + + @functools.wraps(func) + def wrapper(*args, **kwargs): + print >> sys.stderr, err_msg + return func(*args, **kwargs) + + wrapper.__doc__ += "\n " + wrapper.__doc__ += err_msg + return wrapper + + return decorator diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index e7a065599ec034039442449ea20e1fcd02c44f8d..ddcde04716d21df1f18e7202936f470d3d58a661 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -18,10 +18,7 @@ import collections import copy import unique_name -__all__ = [ - 'append_backward', - 'calc_gradient', -] +__all__ = ['append_backward'] def _rename_arg_(op_descs, old_name, new_name, begin_idx=None, end_idx=None): diff --git a/python/paddle/fluid/layers/device.py b/python/paddle/fluid/layers/device.py index e0c1aab230aeed7fb858e91e7da7eae58032ee16..384d302a709eeec220864b9e8c9210ed028470f6 100644 --- a/python/paddle/fluid/layers/device.py +++ b/python/paddle/fluid/layers/device.py @@ -18,10 +18,12 @@ All util layers. from layer_function_generator import autodoc from ..framework import unique_name from ..layer_helper import LayerHelper +from ..annotations import deprecated -__all__ = ['get_places'] +__all__ = [] +@deprecated(since='0.15.0', instead="ParallelExecutor") @autodoc() def get_places(device_count=None, device_type=None): helper = LayerHelper('get_places', **locals()) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index d458e7bd07d26ea6ee28fe7060523ad8a8822997..65fef50ead6f3da7abde844747e4e81f45c96207 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -29,7 +29,7 @@ __all__ = [ 'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl', 'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer', 'AdamaxOptimizer', 'DecayedAdagradOptimizer', 'RMSPropOptimizer', - 'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'Optimizer', 'RMSPropOptimizer' + 'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'RMSPropOptimizer' ] @@ -67,7 +67,7 @@ class Optimizer(object): self._LARS_weight_decay = LARS_weight_decay def _create_global_learning_rate(self): - lr = self.global_learning_rate() + lr = self._global_learning_rate() if isinstance(lr, framework.Variable): return @@ -86,7 +86,7 @@ class Optimizer(object): dtype='float32' if self._dtype == None else self._dtype, persistable=True) - def global_learning_rate(self, program=None): + def _global_learning_rate(self, program=None): """ get global decayed learning rate :return: @@ -110,9 +110,9 @@ class Optimizer(object): return param_lr else: if param_lr == 1.0: - return self.global_learning_rate() + return self._global_learning_rate() else: - return self.global_learning_rate() * param_lr + return self._global_learning_rate() * param_lr def _create_accumulators(self, block, parameters): """Create all accumulators needed by the parameters @@ -185,10 +185,10 @@ class Optimizer(object): format(name, param.name)) return self._accumulators[name][param.name] - def create_optimization_pass(self, - parameters_and_grads, - loss, - startup_program=None): + def _create_optimization_pass(self, + parameters_and_grads, + loss, + startup_program=None): """Add optimization operators to update gradients to variables. 
Args: @@ -221,7 +221,7 @@ class Optimizer(object): self._create_global_learning_rate() if self._LARS_weight_decay > 0.0: layers.append_LARS(parameters_and_grads, - self.global_learning_rate(), + self._global_learning_rate(), self._LARS_weight_decay) optimize_ops = [] @@ -263,8 +263,8 @@ class Optimizer(object): params_grads = append_regularization_ops(params_grads, self.regularization) - optimize_ops = self.create_optimization_pass(params_grads, loss, - startup_program) + optimize_ops = self._create_optimization_pass(params_grads, loss, + startup_program) return optimize_ops, params_grads diff --git a/python/paddle/fluid/tests/book/notest_understand_sentiment.py b/python/paddle/fluid/tests/book/notest_understand_sentiment.py index 1df7b99aad6094a8b8ddfe783b9de35cef61c524..95002aa7f9bb639828b47eb1e86e4ef954fb85ff 100644 --- a/python/paddle/fluid/tests/book/notest_understand_sentiment.py +++ b/python/paddle/fluid/tests/book/notest_understand_sentiment.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from __future__ import print_function - +from paddle.fluid.layers.device import get_places import unittest import paddle.fluid as fluid import paddle @@ -144,7 +144,7 @@ def train(word_dict, cost, acc_out, prediction = net_method( data, label, input_dim=dict_dim, class_dim=class_dim) else: - places = fluid.layers.get_places() + places = get_places() pd = fluid.layers.ParallelDo(places) with pd.do(): cost, acc, _ = net_method( diff --git a/python/paddle/fluid/tests/book/test_recognize_digits.py b/python/paddle/fluid/tests/book/test_recognize_digits.py index 5f5c8544bbdb87421f129b201a0ebaf4cb8602a1..49f549fa184037a64aa846f0d1d0e1b57db1f2ef 100644 --- a/python/paddle/fluid/tests/book/test_recognize_digits.py +++ b/python/paddle/fluid/tests/book/test_recognize_digits.py @@ -12,15 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from __future__ import print_function -import argparse -import paddle.fluid as fluid -import paddle -import sys -import numpy -import unittest + import math -import sys import os +import sys +import unittest + +import numpy + +import paddle +import paddle.fluid as fluid +from paddle.fluid.layers.device import get_places BATCH_SIZE = 64 @@ -76,7 +78,7 @@ def train(nn_type, net_conf = conv_net if parallel: - places = fluid.layers.get_places() + places = get_places() pd = fluid.layers.ParallelDo(places) with pd.do(): img_ = pd.read_input(img) diff --git a/python/paddle/fluid/tests/book/test_word2vec.py b/python/paddle/fluid/tests/book/test_word2vec.py index 49bd72c7a53c0ae740bdbabe15b1d37340699d41..80e0692bc640efc280c43bd5b929847ad29207c4 100644 --- a/python/paddle/fluid/tests/book/test_word2vec.py +++ b/python/paddle/fluid/tests/book/test_word2vec.py @@ -14,6 +14,7 @@ import paddle import paddle.fluid as fluid +from paddle.fluid.layers.device import get_places import unittest import os import numpy as np @@ -80,7 +81,7 @@ def train(use_cuda, is_sparse, is_parallel, save_dirname, is_local=True): avg_cost, predict_word = __network__( [first_word, second_word, third_word, forth_word, next_word]) else: - places = fluid.layers.get_places() + places = get_places() pd = fluid.layers.ParallelDo(places) with pd.do(): avg_cost, predict_word = __network__( diff --git a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py index be347cd5315668dde0454d7959dbf9bcfa465b5f..bec9f8594ff7c1aff8ae5ed55c9623754d9ea091 100644 --- a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py +++ b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np -import paddle -import paddle.fluid as fluid import math import sys +import paddle +import paddle.fluid as fluid +from paddle.fluid.layers.device import get_places + # need to fix random seed and training data to compare the loss # value accurately calculated by the default and the memory optimization # version. 
@@ -34,7 +35,7 @@ if fluid.core.is_compiled_with_cuda(): use_nccl = False place = fluid.CUDAPlace(0) -places = fluid.layers.get_places(device_count=0, device_type=device_type) +places = get_places(device_count=0, device_type=device_type) pd = fluid.layers.ParallelDo(places, use_nccl=use_nccl) with pd.do(): x_ = pd.read_input(x) diff --git a/python/paddle/fluid/tests/unittests/test_calc_gradient.py b/python/paddle/fluid/tests/unittests/test_calc_gradient.py index 06e676cd83e77549afd679e730426c590cc046bf..7f2a9e6971ed933463216e38498d48ab132a1a37 100644 --- a/python/paddle/fluid/tests/unittests/test_calc_gradient.py +++ b/python/paddle/fluid/tests/unittests/test_calc_gradient.py @@ -16,8 +16,6 @@ import unittest import paddle.fluid as fluid import paddle.fluid.layers as layers -import paddle.fluid.framework as framework -import paddle.fluid.optimizer as optimizer from paddle.fluid.backward import calc_gradient diff --git a/python/paddle/fluid/tests/unittests/test_get_places_op.py b/python/paddle/fluid/tests/unittests/test_get_places_op.py index 6dab1e22f0c50ab011d6b8e8944097600cf3fecc..964423e2d2638224244b4ca774d8eee08f3ec989 100644 --- a/python/paddle/fluid/tests/unittests/test_get_places_op.py +++ b/python/paddle/fluid/tests/unittests/test_get_places_op.py @@ -13,6 +13,7 @@ # limitations under the License. import paddle.fluid as fluid +from paddle.fluid.layers.device import get_places import decorators import unittest @@ -20,7 +21,7 @@ import unittest class TestGetPlaces(unittest.TestCase): @decorators.prog_scope() def test_get_places(self): - places = fluid.layers.get_places() + places = get_places() cpu = fluid.CPUPlace() exe = fluid.Executor(cpu) exe.run(fluid.default_main_program()) diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 31ae4e7d8c4107fce7022234d3aa47827647d012..82418f34ccb7e665a041079a19880c7bb34b0b0f 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -16,6 +16,7 @@ from __future__ import print_function import unittest import paddle.fluid.layers as layers +from paddle.fluid.layers.device import get_places import paddle.fluid.nets as nets from paddle.fluid.framework import Program, program_guard, default_main_program from paddle.fluid.param_attr import ParamAttr @@ -238,7 +239,7 @@ class TestBook(unittest.TestCase): def test_get_places(self): program = Program() with program_guard(program): - x = layers.get_places(device_count=4) + x = get_places(device_count=4) self.assertIsNotNone(x) print(str(program)) diff --git a/python/paddle/fluid/tests/unittests/test_optimizer.py b/python/paddle/fluid/tests/unittests/test_optimizer.py index 7286c7c450108c4b5ad7136041bc4e989894a2ba..43385691bb3960004b5b69a1c55e41dd4252fa71 100644 --- a/python/paddle/fluid/tests/unittests/test_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_optimizer.py @@ -97,7 +97,7 @@ class TestMomentumOptimizer(unittest.TestCase): params_grads = append_backward(mean_out) self.assertEqual(len(params_grads), 1) self.assertEqual(len(momentum_optimizer.get_accumulators()), 0) - opts = momentum_optimizer.create_optimization_pass( + opts = momentum_optimizer._create_optimization_pass( params_grads, mul_out, init_program) self.assertEqual(len(opts), 3) sgd_op = opts[-1] @@ -151,7 +151,7 @@ class TestMomentumOptimizer(unittest.TestCase): params_grads = append_backward(mean_out) self.assertEqual(len(params_grads), 1) 
self.assertEqual(len(momentum_optimizer.get_accumulators()), 0) - opts = momentum_optimizer.create_optimization_pass( + opts = momentum_optimizer._create_optimization_pass( params_grads, mul_out, init_program) self.assertEqual(len(opts), 3) sgd_op = opts[-1] @@ -214,8 +214,8 @@ class TestAdagradOptimizer(unittest.TestCase): params_grads = append_backward(mean_out) self.assertEqual(len(params_grads), 1) self.assertEqual(len(adagrad_optimizer.get_accumulators()), 0) - opts = adagrad_optimizer.create_optimization_pass(params_grads, mul_out, - init_program) + opts = adagrad_optimizer._create_optimization_pass( + params_grads, mul_out, init_program) self.assertEqual(len(opts), 3) self.assertEqual([op.type for op in opts], ["fill_constant", "elementwise_mul", "adagrad"]) @@ -278,8 +278,8 @@ class TestAdamOptimizer(unittest.TestCase): params_grads = append_backward(mean_out) self.assertEqual(len(params_grads), 1) self.assertEqual(len(adam_optimizer.get_accumulators()), 0) - opts = adam_optimizer.create_optimization_pass(params_grads, mul_out, - init_program) + opts = adam_optimizer._create_optimization_pass(params_grads, mul_out, + init_program) self.assertEqual(len(opts), 5) self.assertEqual( [op.type for op in opts], @@ -345,8 +345,8 @@ class TestAdamaxOptimizer(unittest.TestCase): params_grads = append_backward(mean_out) self.assertEqual(len(params_grads), 1) self.assertEqual(len(adamax_optimizer.get_accumulators()), 0) - opts = adamax_optimizer.create_optimization_pass(params_grads, mul_out, - init_program) + opts = adamax_optimizer._create_optimization_pass(params_grads, mul_out, + init_program) self.assertEqual(len(opts), 4) self.assertEqual( [op.type for op in opts], @@ -409,7 +409,7 @@ class TestDecayedAdagradOptimizer(unittest.TestCase): params_grads = append_backward(mean_out) self.assertEqual(len(params_grads), 1) self.assertEqual(len(decayed_adagrad_optimizer.get_accumulators()), 0) - opts = decayed_adagrad_optimizer.create_optimization_pass( + opts = decayed_adagrad_optimizer._create_optimization_pass( params_grads, mul_out, init_program) self.assertEqual(len(opts), 3) self.assertEqual( @@ -475,8 +475,8 @@ class TestFtrlOptimizer(unittest.TestCase): params_grads = append_backward(mean_out) self.assertEqual(len(params_grads), 1) self.assertEqual(len(ftrl_optimizer.get_accumulators()), 0) - opts = ftrl_optimizer.create_optimization_pass(params_grads, mul_out, - init_program) + opts = ftrl_optimizer._create_optimization_pass(params_grads, mul_out, + init_program) self.assertEqual(len(opts), 3) self.assertEqual([op.type for op in opts], ["fill_constant", "elementwise_mul", "ftrl"]) diff --git a/python/paddle/fluid/tests/unittests/test_parallel_op.py b/python/paddle/fluid/tests/unittests/test_parallel_op.py index 9ba5f988f317a515b77c0b428da236626419a2c3..9ec05e02973138e3ec233ef07f98afd598ec86b1 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_op.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_op.py @@ -15,6 +15,7 @@ import unittest import paddle.fluid as fluid +from paddle.fluid.layers.device import get_places import paddle.fluid.profiler as profiler import numpy @@ -115,7 +116,7 @@ class BaseParallelForTest(unittest.TestCase): if use_parallel: thread_num = fluid.core.get_cuda_device_count( ) if use_gpu else 8 - places = fluid.layers.get_places(thread_num) + places = get_places(thread_num) pd = fluid.layers.ParallelDo(places, use_nccl=use_nccl) data = next(generator) diff --git a/python/paddle/fluid/tests/unittests/test_squeeze_op.py 
b/python/paddle/fluid/tests/unittests/test_squeeze_op.py new file mode 100644 index 0000000000000000000000000000000000000000..bca6af2fd5dfadbc48cf1a76cfa6ffd4f1fdfdef --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_squeeze_op.py @@ -0,0 +1,114 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np + +from op_test import OpTest + + +# Correct: General. +class TestSqueezeOp(OpTest): + def setUp(self): + self.op_type = "squeeze" + self.init_test_case() + self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")} + self.init_attrs() + self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + def init_test_case(self): + self.ori_shape = (1, 3, 1, 5) + self.axes = (0, 2) + self.new_shape = (3, 5) + + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": False} + + +# Correct: There is a negative axis. +class TestSqueezeOp1(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (1, 3, 1, 5) + self.axes = (0, -2) + self.new_shape = (3, 5) + + +# Correct: No axes input. +class TestSqueezeOp2(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (1, 3, 1, 5) + self.axes = () + self.new_shape = (3, 5) + + +# Correct: Only part of the axes are squeezed. +class TestSqueezeOp3(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 1, 5, 1, 4, 1) + self.axes = (1, -1) + self.new_shape = (3, 5, 1, 4) + + +# Correct: Inplace. +class TestSqueezeOpInplace1(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (1, 3, 1, 5) + self.axes = (0, 2) + self.new_shape = (3, 5) + + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": True} + + +# Correct: Inplace. There is a negative axis. +class TestSqueezeOpInplace2(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (1, 3, 1, 5) + self.axes = (0, -2) + self.new_shape = (3, 5) + + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": True} + + +# Correct: Inplace. No axes input. +class TestSqueezeOpInplace3(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (1, 3, 1, 5) + self.axes = () + self.new_shape = (3, 5) + + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": True} + + +# Correct: Inplace. Only part of the axes are squeezed.
+class TestSqueezeOpInplace4(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 1, 5, 1, 4, 1) + self.axes = (1, -1) + self.new_shape = (3, 5, 1, 4) + + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": True} + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py new file mode 100644 index 0000000000000000000000000000000000000000..7a4aa0a40b5eb494f6027e800ca6b466bbe1c302 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py @@ -0,0 +1,111 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np + +from op_test import OpTest + + +# Correct: General. +class TestUnsqueezeOp(OpTest): + def setUp(self): + self.init_test_case() + self.op_type = "unsqueeze" + self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")} + self.init_attrs() + self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (1, 2) + self.new_shape = (3, 1, 1, 5) + + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": False} + + +# Correct: Single input index. +class TestUnsqueezeOp1(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (-1, ) + self.new_shape = (3, 5, 1) + + +# Correct: Mixed input axis. +class TestUnsqueezeOp2(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (0, -1) + self.new_shape = (1, 3, 5, 1) + + +# Correct: There is duplicated axis. +class TestUnsqueezeOp3(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 2, 5) + self.axes = (0, 3, 3) + self.new_shape = (1, 3, 2, 1, 1, 5) + + +# Correct: Reversed axes. +class TestUnsqueezeOp4(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 2, 5) + self.axes = (3, 1, 1) + self.new_shape = (3, 1, 1, 2, 5, 1) + + +# Correct: Inplace. +class TestUnsqueezeOpInplace1(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (0, 2) + self.new_shape = (1, 3, 1, 5) + + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": True} + + +# Correct: Inplace. There is mins index. +class TestUnsqueezeOpInplace2(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (0, -2) + self.new_shape = (1, 3, 1, 5) + + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": True} + + +# Correct: Inplace. There is duplicated axis. 
+class TestUnsqueezeOpInplace3(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 2, 5) + self.axes = (0, 3, 3) + self.new_shape = (1, 3, 2, 1, 1, 5) + + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": True} + + +if __name__ == "__main__": + unittest.main()
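
Since both new operators are implemented by delegating to the existing reshape op, the tests above build their expected outputs with a plain reshape. The following numpy-only sketch (illustrative, not part of the patch) reproduces the shape semantics described in the SqueezeOpMaker and UnsqueezeOpMaker comments:

import numpy as np

# squeeze: removing size-1 dims is equivalent to a reshape to the inferred shape.
x = np.random.random((1, 3, 1, 5)).astype("float32")
assert x.reshape((3, 1, 5)).shape == np.squeeze(x, axis=0).shape  # axes = [0]
assert x.reshape((3, 5)).shape == np.squeeze(x).shape             # axes = []

# unsqueeze: inserting size-1 dims, e.g. axes = [0, 4] on a (3, 4, 5) tensor,
# matches the example in UnsqueezeOpMaker's comment.
y = np.random.random((3, 4, 5)).astype("float32")
expected = np.expand_dims(np.expand_dims(y, 0), 4)
assert y.reshape((1, 3, 4, 5, 1)).shape == expected.shape == (1, 3, 4, 5, 1)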