diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc index eb4e7ec52f907f9403e21ec2734d61824f51a58b..1d80bab90f513139f807b57258177c6b2ac53ac0 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" +#include #include #include #include "paddle/fluid/framework/executor.h" @@ -53,8 +54,14 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( } } } + std::vector fetch_data; + std::exception_ptr eptr; + try { + fetch_data = underlying_executor_->Run(fetch_tensors); + } catch (...) { + eptr = std::current_exception(); + } - auto fetch_data = underlying_executor_->Run(fetch_tensors); drop_scope_counter_ += 1; if (!fetch_tensors.empty() || drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) { @@ -69,7 +76,11 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( scope->DeleteScope(local_scope); } } - return fetch_data; + if (eptr) { + std::rethrow_exception(eptr); + } else { + return fetch_data; + } } } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 99b10254a7961bf7b27b256acaece573a71c4115..07097c7e75c6ce638549716cd6523f387cdefd92 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -78,6 +78,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( set.clear(); }; + // Clean run context + run_op_futures_.clear(); + exception_.reset(); + // Step 3. Execution while (!pending_vars.empty()) { // 1. Run All Ready ops @@ -96,16 +100,19 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( auto cur_ready_vars = ready_vars.PopAll(1, &timeout); if (timeout) { - std::lock_guard l(exception_mu_); + std::unique_lock l(exception_mu_); if (exception_) { + l.unlock(); + for (auto &run_op_future : run_op_futures_) { + run_op_future.wait(); + } + l.lock(); std::exception *exp = exception_.get(); if (dynamic_cast(exp)) { auto e = *static_cast(exp); - exception_.reset(); throw e; } else if (dynamic_cast(exp)) { auto e = *static_cast(exp); - exception_.reset(); throw e; } else { LOG(FATAL) << "Unknown exception."; @@ -222,7 +229,7 @@ void ThreadedSSAGraphExecutor::RunOp( } }; if (pool_) { - pool_->enqueue(op_run); + run_op_futures_.emplace_back(pool_->enqueue(op_run)); } else { op_run(); } diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index c69e0487e2e503a0d445300aa2fd6bb9c30b06c9..09973b7a72881464ad9e7776d4aad3d2261a118d 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include #include @@ -77,6 +78,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { private: ExecutionStrategy strategy_; + // use std::list because clear(), push_back, and for_each are O(1) + std::list> run_op_futures_; }; } // namespace details diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index ab1d2143330fb8cbfd535758a83bc71de939c4e0..d265150f25419509126028e36e629aee3ee6bd0f 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -259,12 +259,15 @@ op_library(max_sequence_len_op DEPS lod_rank_table) op_library(sequence_conv_op DEPS context_project) op_library(sequence_pool_op DEPS sequence_pooling) op_library(lstm_op DEPS sequence2batch lstm_compute) +op_library(hierarchical_sigmoid_op DEPS matrix_bit_code) op_library(lstmp_op DEPS sequence2batch lstm_compute) op_library(gru_op DEPS sequence2batch gru_compute) op_library(recurrent_op DEPS executor) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) op_library(cos_sim_op DEPS cos_sim_functor) op_library(parallel_do_op DEPS executor) +op_library(unsqueeze_op DEPS reshape_op) +op_library(squeeze_op DEPS reshape_op) if (WITH_GPU) op_library(conv_op DEPS vol2col depthwise_conv im2col) diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..dadd054b9a6f8d44f4e5832888052bffde34c827 --- /dev/null +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc @@ -0,0 +1,167 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/hierarchical_sigmoid_op.h" +#include + +namespace paddle { +namespace operators { + +/** + * Organize the classes into a binary tree. At each node, a sigmoid function + * is used to calculate the probability of belonging to the right branch. + * This idea is from "F. Morin, Y. Bengio (AISTATS 05): + * Hierarchical Probabilistic Neural Network Language Model." + * + * Here we uses a simple way of making the binary tree. + * Assuming the number of classes C = 6, + * The classes are organized as a binary tree in the following way: + * + * @code{.py} + * *-*-*- 2 + * | | |- 3 + * | | + * | |-*- 4 + * | |- 5 + * | + * |-*- 0 + * |- 1 + * @endcode + * + * where * indicates an internal node, and each leaf node represents a class. + * - Node 0 ... C-2 are internal nodes. + * - Node C-1 ... 2C-2 are leaf nodes. + * - Class c is represented by leaf node \f$c+C-1\f$. + * + * We assign an id for each node: + * - the id of root be 0. + * - the left child of a node i is 2*i+1. + * - the right child of a node i is 2*i+2. + * + * It's easy to see that: + * - the parent of node i is \f$\left\lfloor(i-1)/2\right\rfloor\f$. + * - the j-th level ancestor of node i is + * \f$\left\lfloor(i+1)/2^{j+1}\right\rfloor - 1\f$. + * - A node i is a left child of its parent if \f$(i-1)\%2==0\f$. + * + */ + +class HierarchicalSigmoidOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("PreOut"), + "Output(PreOut) should not be null."); + const int64_t batch_size = ctx->GetInputDim("X")[0]; + std::vector output_shape({batch_size, 1}); + ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.GetPlace()); + } +}; + +template +class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor, required) The input tensor with shape [N, D], " + "where N is the size of mini-batch, and D is the feature size."); + AddInput("W", + "(Tensor, required), The parameters of hierarchical " + "sigmoid operator, each of them is a 2-D tensor, the shape is" + "[num_classes - 1, D]."); + AddInput("Label", + "(Tensor, required), The labels of training data. It's a" + "tensor with shape [N, 1]."); + AddInput("Bias", + "(Tensor, optional), The bias is a tensor with shape" + "[1, num_classes - 1]."); + AddOutput("Out", + "(Tensor, required) The output of hierarchical sigmoid operator." + "The shape is [N, 1]."); + AddOutput("PreOut", + "(Tensor, required) A intermedia 2-D tensor with shape " + "[batch_size, code_length], where code_length represents the " + "maximum path length from root to leaf nodes.") + .AsIntermediate(); + AddAttr("num_classes", "(int, required), The number of classes") + .SetDefault(2); + AddComment(R"DOC( +The hierarchical sigmoid operator organize the classes into a binary tree. +At each node, a sigmoid function is used to calculate the probability of +belonging to the right branch. This idea is from +"F. Morin, Y. Bengio (AISTATS 05): +Hierarchical Probabilistic Neural Network Language Model." + )DOC"); + } +}; + +class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("PreOut"), + "Input(Preout) should not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")), + "Output(W@Grad should not be null.)"); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X"))); + if (ctx->HasOutput(framework::GradVarName("Bias"))) { + ctx->SetOutputDim(framework::GradVarName("Bias"), + ctx->GetInputDim("Bias")); + } + ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W")); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.GetPlace()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp, + ops::HierarchicalSigmoidOpMaker, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp); +REGISTER_OP_CPU_KERNEL( + hierarchical_sigmoid, + ops::HierarchicalSigmoidOpKernel, + ops::HierarchicalSigmoidOpKernel); +REGISTER_OP_CPU_KERNEL( + hierarchical_sigmoid_grad, + ops::HierarchicalSigmoidGradOpKernel, + ops::HierarchicalSigmoidGradOpKernel); diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h new file mode 100644 index 0000000000000000000000000000000000000000..64096a717b12ed231344649f5eb76b7e4b9af4a6 --- /dev/null +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h @@ -0,0 +1,135 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/clip_op.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/matrix_bit_code.h" +#include "paddle/fluid/platform/transform.h" +namespace paddle { +namespace operators { + +template +using EigenMatrix = framework::EigenMatrix; +using platform::Transform; + +template +class HierarchicalSigmoidOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* w = ctx.Input("W"); + auto* label = ctx.Input("Label"); + auto* bias = ctx.Input("Bias"); + auto* out = ctx.Output("Out"); + auto* pre_out = ctx.Output("PreOut"); + size_t num_classes = static_cast(ctx.Attr("num_classes")); + int64_t code_length = math::FindLastSet(num_classes - 1); + int64_t batch_size = in->dims()[0]; + framework::Tensor sum; + auto& dev_ctx = ctx.template device_context(); + auto* pre_out_data = pre_out->mutable_data( + framework::make_ddim({batch_size, code_length}), ctx.GetPlace()); + auto pre_out_mat = EigenMatrix::From(*pre_out); + // Not all class(leaf) nodes' path lengths equal code_length, thus init as + // 0s can avoid out of path's loss. + math::SetConstant zero; + zero(dev_ctx, pre_out, static_cast(0.0)); + auto& place = *ctx.template device_context().eigen_device(); + math::RowwiseSum row_sum; + math::MatrixBitCodeFunctor bit_code(num_classes, label->data()); + + std::vector sum_dims({batch_size, 1UL}); + sum.mutable_data(framework::make_ddim(sum_dims), ctx.GetPlace()); + auto sum_mat = EigenMatrix::From(sum); + out->mutable_data(ctx.GetPlace()); + auto out_mat = framework::EigenVector::Flatten(*out); + if (bias) { + bit_code.Add(pre_out, *bias); + } + bit_code.Mul(pre_out, *w, *in); + // clip to [-40, 40] + Transform trans; + trans(ctx.template device_context(), pre_out_data, + pre_out_data + pre_out->numel(), pre_out_data, + ClipFunctor(static_cast(-40.0), static_cast(40.0))); + bit_code.Sum(*pre_out, out, static_cast(-1)); + // use softrelu to calculate cross entropy + pre_out_mat.device(place) = (static_cast(1.0) + pre_out_mat.exp()).log(); + row_sum(dev_ctx, *pre_out, &sum); + // TODO(guosheng): Subtract the out of path's loss, since not all + // class(leaf) nodes' path lengths equal code_length. But it won't break the + // gradient check since both have the out of path's loss and will cancel out + // each other. + out_mat.device(place) = sum_mat + out_mat; + } +}; + +template +class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* w = ctx.Input("W"); + auto* in_grad = ctx.Output(framework::GradVarName("X")); + auto* w_grad = ctx.Output(framework::GradVarName("W")); + auto* bias_grad = + ctx.Output(framework::GradVarName("Bias")); + auto* label = ctx.Input("Label"); + auto* pre_out = ctx.Input("PreOut"); + auto* out_grad = + ctx.Input(framework::GradVarName("Out")); + framework::Tensor pre_out_grad; + + pre_out_grad.mutable_data(pre_out->dims(), ctx.GetPlace()); + in_grad->mutable_data(ctx.GetPlace()); + w_grad->mutable_data(ctx.GetPlace()); + auto& dev_ctx = ctx.template device_context(); + math::SetConstant zero; + zero(dev_ctx, in_grad, static_cast(0.0)); + zero(dev_ctx, w_grad, static_cast(0.0)); + + size_t num_classes = static_cast(ctx.Attr("num_classes")); + math::MatrixBitCodeFunctor bit_code(num_classes, label->data()); + + auto& place = *ctx.template device_context().eigen_device(); + auto pre_out_mat = EigenMatrix::From(*pre_out); + auto pre_out_grad_mat = EigenMatrix::From(pre_out_grad); + auto out_grad_mat = EigenMatrix::From(*out_grad); + Eigen::array bcast({{1, static_cast(pre_out_grad.dims()[1])}}); + + // softrelu derivative + pre_out_grad_mat.device(place) = + static_cast(1.0) - static_cast(1.0) / pre_out_mat.exp(); + bit_code.Sub(&pre_out_grad); // the gradient of clip(w * x + b) + pre_out_grad_mat.device(place) = + pre_out_grad_mat * out_grad_mat.broadcast(bcast); + // TODO(guosheng): multiply pre_out_grad with subgradient of clipping to + // be consistent with the clipping in forward. + if (bias_grad) { + bias_grad->mutable_data(ctx.GetPlace()); + zero(dev_ctx, bias_grad, static_cast(0.0)); + bit_code.AddGrad(pre_out_grad, bias_grad); + } + bit_code.MulGradWeight(pre_out_grad, w_grad, *in); + bit_code.MulGradError(pre_out_grad, *w, in_grad); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 5571ff9a7151c1f971ad1805bf001815a651202b..d2b772d11379c218be77277b89f3ded7b59ab9f3 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -51,6 +51,7 @@ math_library(sequence_padding) math_library(sequence_pooling DEPS math_function) math_library(sequence_scale) math_library(softmax DEPS math_function) +math_library(matrix_bit_code) math_library(unpooling) math_library(vol2col) diff --git a/paddle/fluid/operators/math/math_function_impl.h b/paddle/fluid/operators/math/math_function_impl.h index b9bd49d77d935e985705f78402ffe1ea90f24cb3..895a7019aa10e5d9bb8f0c17e433a4344eac3bf4 100644 --- a/paddle/fluid/operators/math/math_function_impl.h +++ b/paddle/fluid/operators/math/math_function_impl.h @@ -155,7 +155,7 @@ class RowwiseSum { PADDLE_ENFORCE_EQ(in_dims.size(), 2U); auto height = in_dims[0]; auto size = in_dims[1]; - PADDLE_ENFORCE_EQ(out->numel(), size); + PADDLE_ENFORCE_EQ(out->numel(), height); T* out_buf = out->mutable_data(out->place()); const T* in_buf = input.data(); diff --git a/paddle/fluid/operators/math/matrix_bit_code.cc b/paddle/fluid/operators/math/matrix_bit_code.cc new file mode 100644 index 0000000000000000000000000000000000000000..1e56e297396c6e37867a53f039478191f0caf08e --- /dev/null +++ b/paddle/fluid/operators/math/matrix_bit_code.cc @@ -0,0 +1,176 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/math/matrix_bit_code.h" +#include +namespace paddle { +namespace operators { +namespace math { + +template +void MatrixBitCodeFunctor::Add(framework::Tensor* tmat, + const framework::Tensor& vec) { + SimpleCodeTable code_table(num_classes_); + size_t batch_size = tmat->dims()[0]; + size_t width = tmat->dims()[1]; + for (size_t i = 0; i < batch_size; ++i) { + auto code = code_table(static_cast(ids_[i])); + int code_length = code.get_length(); + for (int j = 0; j < code_length; ++j) { + size_t index = code.calc_index(j); + tmat->data()[i * width + j] += vec.data()[index]; + } + } +} + +template +void MatrixBitCodeFunctor::AddGrad(const framework::Tensor& tmat, + framework::Tensor* vec) { + SimpleCodeTable code_table(num_classes_); + size_t batch_size = tmat.dims()[0]; + size_t width = tmat.dims()[1]; + for (size_t i = 0; i < batch_size; ++i) { + auto code = code_table(static_cast(ids_[i])); + int code_length = code.get_length(); + for (int j = 0; j < code_length; ++j) { + size_t index = code.calc_index(j); + vec->data()[index] += tmat.data()[i * width + j]; + } + } +} + +template +void MatrixBitCodeFunctor::Sum(const framework::Tensor& tmat, + framework::Tensor* sum, T scale_sum) { + SimpleCodeTable code_table(num_classes_); + size_t num_samples = tmat.dims()[0]; + size_t o_width = tmat.dims()[1]; + for (size_t i = 0; i < num_samples; ++i) { + T sm = static_cast(0.0); + auto code = code_table(static_cast(ids_[i])); + int code_length = code.get_length(); + for (int j = 0; j < code_length; ++j) { + if (code.calc_bit(j)) { + // calc_bit starts from right most bit, while data in tmat[i] is in the + // reverse order. + sm += tmat.data()[i * o_width + j]; + } + } + sum->data()[i] = scale_sum * sm; + } +} + +template +void MatrixBitCodeFunctor::Mul(framework::Tensor* tmat, + const framework::Tensor& weight, + const framework::Tensor& input) { + SimpleCodeTable code_table(num_classes_); + size_t num_samples = tmat->dims()[0]; + size_t tmat_width = tmat->dims()[1]; + size_t input_width = input.dims()[1]; + size_t weight_width = weight.dims()[1]; + auto tmat_value = tmat->data(); + auto weight_value = weight.data(); + auto input_value = input.data(); + for (size_t i = 0; i < num_samples; ++i) { + auto code = code_table(static_cast(ids_[i])); + int code_length = code.get_length(); + for (int j = 0; j < code_length; ++j) { + size_t index = code.calc_index(j); + T sum = static_cast(0.0); + for (size_t k = 0; k < input_width; ++k) { + sum += weight_value[weight_width * index + k] * + input_value[input_width * i + k]; + } + tmat_value[i * tmat_width + j] += sum; + } + } +} + +template +void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor& tmat, + framework::Tensor* weight, + const framework::Tensor& input) { + SimpleCodeTable code_table(num_classes_); + size_t num_samples = tmat.dims()[0]; + size_t input_width = input.dims()[1]; + size_t tmat_width = tmat.dims()[1]; + size_t weight_width = weight->dims()[1]; + auto tmat_value = tmat.data(); + auto weight_value = weight->data(); + auto input_value = input.data(); + for (size_t i = 0; i < num_samples; ++i) { + auto code = code_table(static_cast(ids_[i])); + int code_length = code.get_length(); + for (int j = 0; j < code_length; ++j) { + size_t index = code.calc_index(j); + + for (size_t k = 0; k < input_width; ++k) { + weight_value[weight_width * index + k] += + tmat_value[i * tmat_width + j] * input_value[input_width * i + k]; + } + } + } +} + +template +void MatrixBitCodeFunctor::MulGradError(const framework::Tensor& tmat, + const framework::Tensor& weight, + framework::Tensor* input) { + SimpleCodeTable code_table(num_classes_); + size_t num_samples = tmat.dims()[0]; + size_t tmat_width = tmat.dims()[1]; + size_t input_width = input->dims()[1]; + size_t weight_width = weight.dims()[1]; + auto tmat_value = tmat.data(); + auto weight_value = weight.data(); + auto input_value = input->data(); + + for (size_t i = 0; i < num_samples; ++i) { + auto code = code_table(static_cast(ids_[i])); + int code_length = code.get_length(); + for (int j = 0; j < code_length; ++j) { + size_t index = code.calc_index(j); + + for (size_t k = 0; k < input_width; ++k) { + input_value[input_width * i + k] += + tmat_value[i * tmat_width + j] * + weight_value[weight_width * index + k]; + } + } + } +} + +template +void MatrixBitCodeFunctor::Sub(framework::Tensor* tmat) { + SimpleCodeTable code_table(num_classes_); + size_t num_samples = tmat->dims()[0]; + size_t o_width = tmat->dims()[1]; + for (size_t i = 0; i < num_samples; ++i) { + auto code = code_table(static_cast(ids_[i])); + int code_length = code.get_length(); + for (int j = 0; j < code_length; ++j) { + if (code.calc_bit(j)) { + tmat->data()[i * o_width + j] -= 1; + } + } + } +} + +template class MatrixBitCodeFunctor; +template class MatrixBitCodeFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/matrix_bit_code.h b/paddle/fluid/operators/math/matrix_bit_code.h new file mode 100644 index 0000000000000000000000000000000000000000..5454d58f371afb5f5d6a1c3208318f80d4e0aa36 --- /dev/null +++ b/paddle/fluid/operators/math/matrix_bit_code.h @@ -0,0 +1,143 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/platform/device_context.h" + +namespace paddle { +namespace operators { +namespace math { +/** + * SimpleCodeTable class should support 3 functions: + * + * size_t size() + * return the number of ids + * + * int get_max_code_length() + * return the maximal code length + * + * SimpleCode operator()(size_t i) + * return the i-th code. Code class is descriebed below. + * + * SimpleCode class should support 3 functions: + * + * int get_length() + * return the length of the code + * + * size_t cal_index(int bit) + * bit ranges from 0 to get_length() - 1 + * return the index for the (1+bit) level parent + * + * bool calc_bit(int bit) + * return true if the bit level parent is the right child of (1+bit) level + * parent + * + */ + +/** + * return the 1-based index of the highest bit set + * + * for x > 0: + * \f[ + * FindLastSet(x) = 1 + \floor*{\log_{2}x} + * \f] + */ +inline constexpr size_t FindLastSet(size_t x) { + return std::is_same::value + ? (x ? 8 * sizeof(x) - __builtin_clz(x) : 0) + : (std::is_same::value // NOLINT + ? (x ? 8 * sizeof(x) - __builtin_clzl(x) : 0) + : (x ? 8 * sizeof(x) - __builtin_clzll(x) : 0)); +} + +struct SimpleCode { + SimpleCode(size_t code, size_t num_classes) : c_(code + num_classes) {} + /** + * Here the id of root shoud be 1 rather than 0, thus the encoding of class c + * is `c + num_classes` and all siblings can get the same weight indice using + * prefixes. + * Weight index is the prefixes of encoding, thus leave out the right most + * bit in calc_index. + * Binary classification path is the suffixes of encoding, thus leave out the + * left most bit in calc_bit. + */ + inline size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; } + inline bool calc_bit(int bit) const { return c_ & (1 << bit); } + inline int get_length() const { return FindLastSet(c_) - 1; } + + private: + size_t c_; +}; + +struct SimpleCodeTable { + explicit SimpleCodeTable(size_t num_classes) : num_classes_(num_classes) {} + SimpleCode operator()(size_t code) const { + return SimpleCode(code, num_classes_); + } + size_t size() const { return num_classes_; } + int get_max_code_length() const { return FindLastSet(num_classes_ - 1); } + + private: + size_t num_classes_; +}; + +template +class MatrixBitCodeFunctor { + public: + explicit MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids) + : num_classes_(num_classes), ids_(ids) {} + /* For j < code_length + tmat(i, j) += vec(0, index(i, j)) + */ + void Add(framework::Tensor* tmat, const framework::Tensor& vec); + + /* For j < code_length + vec(0, index(i, j)) += tmat(i, j) + */ + void AddGrad(const framework::Tensor& tmat, framework::Tensor* vec); + + /* For j < code_length + sum(i, 0) = \sum_j bit(i, j) * tmat(i, j) + */ + void Sum(const framework::Tensor& tmat, framework::Tensor* sum, T scale_sum); + + /* For j < code_length + tmat(i, j) -= bit(i, j) + */ + void Sub(framework::Tensor* tmat); + /* For j < code_length + input.row(i) += tmat(i, j) * weight.row(index(i, j)) + */ + void Mul(framework::Tensor* tmat, const framework::Tensor& weight, + const framework::Tensor& input); + + /* For index(i, j) >= 0: + weight.row(index(i, j)) += tmat(i, j) * input.row(i) + */ + void MulGradWeight(const framework::Tensor& tmat, framework::Tensor* weight, + const framework::Tensor& input); + /* For j < code_length + input.row(i) += tmat(i, j) * weight.row(index(i, j)) + */ + void MulGradError(const framework::Tensor& tmat, + const framework::Tensor& weight, framework::Tensor* input); + + size_t num_classes_; + const int64_t* ids_; +}; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/reader/create_batch_reader_op.cc b/paddle/fluid/operators/reader/create_batch_reader_op.cc index 1dbafd23e92732bdaf0d263a01e267227786d839..e17c2ffd39eea31fe85933eda144ab97cf8c3dd8 100644 --- a/paddle/fluid/operators/reader/create_batch_reader_op.cc +++ b/paddle/fluid/operators/reader/create_batch_reader_op.cc @@ -23,7 +23,7 @@ class BatchReader : public framework::DecoratedReader { BatchReader(const std::shared_ptr& reader, int batch_size, bool discard_leftover) : DecoratedReader(reader), - batch_size_(batch_size), + batch_size_(static_cast(batch_size)), discard_leftover_(discard_leftover) { buffer_.reserve(batch_size_); } @@ -31,7 +31,7 @@ class BatchReader : public framework::DecoratedReader { void ReadNextImpl(std::vector* out) override; private: - int batch_size_; + size_t batch_size_; bool discard_leftover_; std::vector> buffer_; }; @@ -78,7 +78,7 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase { void BatchReader::ReadNextImpl(std::vector* out) { buffer_.clear(); buffer_.reserve(batch_size_); - for (int i = 0; i < batch_size_; ++i) { + for (size_t i = 0; i < batch_size_; ++i) { buffer_.push_back(std::vector()); reader_->ReadNext(&buffer_.back()); if (buffer_.back().empty()) { @@ -95,9 +95,9 @@ void BatchReader::ReadNextImpl(std::vector* out) { // if buffer_ is empty, the 'out' will return as an empty vector. return; } - int out_num = buffer_[0].size(); + size_t out_num = buffer_[0].size(); out->reserve(out_num); - for (int j = 0; j < out_num; ++j) { + for (size_t j = 0; j < out_num; ++j) { // Merge shape and check date type std::type_index batch_type = buffer_[0][j].type(); framework::DDim batch_shape = buffer_[0][j].dims(); diff --git a/paddle/fluid/operators/squeeze_op.cc b/paddle/fluid/operators/squeeze_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..6c507baf3a0ab0a557d29a53700685753616193b --- /dev/null +++ b/paddle/fluid/operators/squeeze_op.cc @@ -0,0 +1,202 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class SqueezeOpInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SqueezeOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SqueezeOp should not be null."); + + const auto &x_dims = ctx->GetInputDim("X"); + // Check input tensor dims (<6) Eigen limit. + PADDLE_ENFORCE(x_dims.size() <= 6, + "Invalid dimnesions, the rank of Input(X) " + "should be in the range of [1, 6] (Eigen limit)."); + + const auto &axes = ctx->Attrs().Get>("axes"); + for (int a : axes) { + PADDLE_ENFORCE_LT(a, x_dims.size(), + "The squeeze axis should be less than input " + "tensor's rank."); + } + + auto out_dims = GetOutputShape(axes, x_dims); + ctx->SetOutputDim("Out", out_dims); + if (x_dims[0] == out_dims[0]) { + // Only pass LoD when the first dimension of output and Input(X) + // are the same. + ctx->ShareLoD("X", "Out"); + } + } + + static framework::DDim GetOutputShape(const std::vector squeeze_dims, + const framework::DDim &in_dims) { + size_t num_squeeze_dims = squeeze_dims.size(); + int cnt_squeezed_dims = 0; + bool should_squeeze[9] = {false}; + + // Determines number of dimensions of output tensor after squeeze. + // Mark and count the dimensions need to be squeezed + if (num_squeeze_dims == 0) { + for (int idx = 0; idx < in_dims.size(); ++idx) { + if (in_dims[idx] == 1) { + should_squeeze[idx] = true; + ++cnt_squeezed_dims; + } + } + } else { + for (size_t idx = 0; idx < num_squeeze_dims; ++idx) { + int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size() + : squeeze_dims[idx]; + // Check current index, the upper limit has beed checked in line 36. + PADDLE_ENFORCE(current >= 0, + "Invalid axis, the negative axis is out of range."); + PADDLE_ENFORCE(in_dims[current] == 1, + "Invalid axis index, the axis that will be squeezed " + "should be equal to 1."); + + if (!(should_squeeze[current])) { + ++cnt_squeezed_dims; + } + should_squeeze[current] = true; + } + } + + // Make output dimensions + std::vector output_shape(in_dims.size() - cnt_squeezed_dims, 0); + for (int in_idx = 0, out_idx = 0; in_idx < in_dims.size(); ++in_idx) { + if (!should_squeeze[in_idx]) { + output_shape[out_idx++] = in_dims[in_idx]; + } + } + + return framework::make_ddim(output_shape); + } +}; + +class SqueezeOp : public framework::OperatorBase { + public: + using OperatorBase::OperatorBase; + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { + auto &axes = Attr>("axes"); + auto x_dims = scope.FindVar(Input("X"))->Get().dims(); + auto out_dims = SqueezeOpInferShape::GetOutputShape(axes, x_dims); + + framework::AttributeMap attrs; + attrs["shape"] = framework::vectorize2int(out_dims); + attrs["inplace"] = Attr("inplace"); + // Invoke Reshape Op + auto reshape_op = framework::OpRegistry::CreateOp( + "reshape", {{"X", {Input("X")}}, {"Shape", {}}}, + {{"Out", {Output("Out")}}}, attrs); + reshape_op->Run(scope, place); + } +}; + +class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor). The input tensor of squeeze operator."); + AddOutput("Out", "(Tensor). The output tensor of squeeze operator."); + AddAttr>("axes", + "(std::vector). List of integers," + " indicating the dimensions to squeeze.") + .SetDefault({}); + AddAttr("inplace", + "(default: false) Squeeze the source tensor's shape without " + "memory copy. When Attr(inplace) is set true, the output " + "tensor shares memory with Input(X), otherwise, a new output " + "tensor is created, and its data are copied from Input(x).") + .SetDefault(false); + AddComment(R"DOC( + Squeeze Operator. + + Remove single-dimensional entries from the shape of a tensor. + Takes a parameter axes with a list of axes to squeeze. + If axes is not provided, all the single dimensions will be removed from the shape. + If an axis is selected with shape entry not equal to one, an error is raised. + + Examples: + Case 1: + Given + X.shape = (1, 3, 1, 5) + and + axes = [0] + we get: + Out.shape = (3, 1, 5) + + Case 2: + Given + X.shape = (1, 3, 1, 5) + and + axes = [] + we get: + Out.shape = (3, 5) + )DOC"); + } +}; + +class SqueezeGradInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *context) const override { + context->SetOutputDim(framework::GradVarName("X"), + context->GetInputDim("X")); + context->ShareLoD("X", framework::GradVarName("X")); + } +}; + +class SqueezeGradOp : public framework::OperatorBase { + public: + using OperatorBase::OperatorBase; + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { + auto dx_name = Output(framework::GradVarName("X")); + auto dout_name = Input(framework::GradVarName("Out")); + auto x_dims = scope.FindVar(Input("X"))->Get().dims(); + framework::AttributeMap attrs; + attrs["shape"] = framework::vectorize2int(x_dims); + attrs["inplace"] = Attr("inplace"); + + auto reshape_op = framework::OpRegistry::CreateOp( + "reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}}, + attrs); + reshape_op->Run(scope, place); + } +}; + +} // namespace operators +} // namespace paddle + +// Tell linker to use reshape op +USE_OP(reshape); + +namespace ops = paddle::operators; +REGISTER_OPERATOR(squeeze, ops::SqueezeOp, ops::SqueezeOpMaker, + ops::SqueezeOpInferShape, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(squeeze_grad, ops::SqueezeGradOp, ops::SqueezeGradInferShape); diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..f2a15fdf572e0de30f9949dda5020e130b0c5585 --- /dev/null +++ b/paddle/fluid/operators/unsqueeze_op.cc @@ -0,0 +1,191 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class UnsqueezeOpInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of UnsqueezeOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of UnsqueezeOp should not be null."); + + const auto &axes = ctx->Attrs().Get>("axes"); + const auto &x_dims = ctx->GetInputDim("X"); + // Validity Check: input tensor dims (<6). + PADDLE_ENFORCE(x_dims.size() <= 6, + "Invalid dimensions, the rank of Input(X) " + "should be in the range of [1, 6] (Eigen limit)"); + auto out_dims = GetOutputShape(axes, x_dims); + ctx->SetOutputDim("Out", out_dims); + if (x_dims[0] == out_dims[0]) { + // Only pass LoD when the first dimension of output and Input(X) + // are the same. + ctx->ShareLoD("X", "Out"); + } + } + + static framework::DDim GetOutputShape(const std::vector unsqz_dims, + const framework::DDim &in_dims) { + int output_size = in_dims.size() + static_cast(unsqz_dims.size()); + int cur_output_size = in_dims.size(); + std::vector output_shape(output_size, 0); + + // Validity Check: rank range. + PADDLE_ENFORCE(output_size <= 6, + "The output tensor's rank should be less than 6."); + + for (int axis : unsqz_dims) { + int cur = axis < 0 ? axis + cur_output_size + 1 : axis; + // Vaildity Check: the axis bound + PADDLE_ENFORCE( + cur >= 0 && cur <= cur_output_size, + "The unsqueeze dims must be within range of current rank."); + // Move old axis, and insert new axis + for (int i = cur_output_size; i >= cur; --i) { + if (output_shape[i] == 1) { + // Move axis + output_shape[i + 1] = 1; + output_shape[i] = 0; + } + } + output_shape[cur] = 1; + // Add the output size. + cur_output_size++; + } + + // Make output shape + for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) { + if (output_shape[out_idx] == 0) { + output_shape[out_idx] = in_dims[in_idx++]; + } + } + + return framework::make_ddim(output_shape); + } +}; + +class UnsqueezeOp : public framework::OperatorBase { + public: + using OperatorBase::OperatorBase; + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { + auto &axes = Attr>("axes"); + auto x_dims = scope.FindVar(Input("X"))->Get().dims(); + auto out_dims = UnsqueezeOpInferShape::GetOutputShape(axes, x_dims); + + framework::AttributeMap attrs; + attrs["shape"] = framework::vectorize2int(out_dims); + attrs["inplace"] = Attr("inplace"); + // Invoke Reshape op. + auto reshape_op = framework::OpRegistry::CreateOp( + "reshape", {{"X", {Input("X")}}, {"Shape", {}}}, + {{"Out", {Output("Out")}}}, attrs); + reshape_op->Run(scope, place); + } +}; + +class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor). The input tensor of unsqueeze operator."); + AddOutput("Out", "(Tensor). The output tensor of unsqueeze operator."); + AddAttr>("axes", + "(std::vector). List of integers," + " indicating the dimensions to be inserted") + .AddCustomChecker([](const std::vector &axes) { + PADDLE_ENFORCE(!axes.empty(), + "Invalid axes, The unsqueeze axes is empty."); + // Validity Check: axes dims (<6). + PADDLE_ENFORCE(static_cast(axes.size()) < 6, + "Invalid dimensions, dynamic dimensions should be " + "within [1, 6] dimensions (Eigen limit)."); + // Validity Check: the range of unsqueeze aixs. + for (int axis : axes) { + PADDLE_ENFORCE(axis < 6, + "Invalid dimensions, input axis should be" + " within [1, 6] dimensions (Eigen limit)."); + } + }); + AddAttr( + "inplace", + "(default: false) Unsqueeze the source tensor's shape without " + "memory copy. When Attr(inplace) is set true, the output " + "tensor shares memory with Input(X), otherwise, a new output " + "tensor is created, and its data are copied from Input(x).") + .SetDefault(false); + AddComment(R"DOC( + Unsqueeze Operator. + + Insert single-dimensional entries to the shape of a tensor. + Takes one required argument axes, a list of dimensions that will be inserted. + Dimension indices in axes are as seen in the output tensor. + + For example: + Given a tensor such that tensor with shape [3, 4, 5], + then Unsqueeze(tensor, axes=[0, 4]) has shape [1, 3, 4, 5, 1] + )DOC"); + } +}; + +class UnsqueezeGradInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + ctx->ShareLoD("X", framework::GradVarName("X")); + } +}; + +class UnsqueezeGradOp : public framework::OperatorBase { + public: + using OperatorBase::OperatorBase; + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { + auto dx_name = Output(framework::GradVarName("X")); + auto dout_name = Input(framework::GradVarName("Out")); + auto x_dims = scope.FindVar(Input("X"))->Get().dims(); + + framework::AttributeMap attrs; + attrs["shape"] = framework::vectorize2int(x_dims); + attrs["inplace"] = Attr("inplace"); + + auto reshape_op = framework::OpRegistry::CreateOp( + "reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}}, + attrs); + reshape_op->Run(scope, place); + } +}; + +} // namespace operators +} // namespace paddle + +// Tell linker to use reshape op. +USE_OP(reshape); + +namespace ops = paddle::operators; +REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker, + ops::UnsqueezeOpInferShape, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp, + ops::UnsqueezeGradInferShape); diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index be9d375c69dc1d84f685687f26b3d1d950fad63c..69a067b99b2413b7ec990de43963182a022d29c6 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -66,6 +66,14 @@ bool IsCompiledWithCUDA() { #endif } +bool IsCompiledWithDIST() { +#ifdef PADDLE_WITH_DIST + return true; +#else + return false; +#endif +} + PYBIND11_PLUGIN(core) { py::module m("core", "C++ core of PaddlePaddle"); @@ -508,6 +516,7 @@ All parameter, weight, gradient are variables in Paddle. [](bool init_p2p) { framework::InitDevices(init_p2p); }); m.def("is_compiled_with_cuda", IsCompiledWithCUDA); + m.def("is_compiled_with_dist", IsCompiledWithDIST); #ifdef PADDLE_WITH_CUDA m.def("is_float16_supported", [](const platform::CUDAPlace &place) -> bool { // Only GPUs with Compute Capability >= 53 support float16 diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index bec161a571d0f98b99ac656f08cbc2d401c943fc..25900811509aee8b37fdaf09cf902ea2ae3eee57 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -92,8 +92,15 @@ install(DIRECTORY ${PADDLE_PYTHON_PACKAGE_DIR} DESTINATION opt/paddle/share/wheels ) -find_program(PATCHELF_EXECUTABLE patchelf) -if(NOT PATCHELF_EXECUTABLE) - message(FATAL_ERROR "patchelf not found, please install it.\n" - "For Ubuntu, the command is: apt-get install -y patchelf.") -endif() +if(APPLE) + find_program(INSTALL_NAME_TOOL_EXECUTABLE install_name_tool) + if(NOT INSTALL_NAME_TOOL_EXECUTABLE) + message(FATAL_ERROR "install_name_tool not found, please check.\n") + endif() +else(APPLE) + find_program(PATCHELF_EXECUTABLE patchelf) + if(NOT PATCHELF_EXECUTABLE) + message(FATAL_ERROR "patchelf not found, please install it.\n" + "For Ubuntu, the command is: apt-get install -y patchelf.") + endif() +endif(APPLE) diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 3034c1a0875a71421bcba172c16ee32d809df152..ba562d3ba9755292a1962ea4fa6d65f8367138d3 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -121,6 +121,9 @@ def __bootstrap__(): 'eager_delete_scope', 'use_mkldnn', 'initial_cpu_memory_in_mb', 'init_allocated_mem' ] + if core.is_compiled_with_dist(): + read_env_flags.append('rpc_deadline') + if core.is_compiled_with_cuda(): read_env_flags += [ 'fraction_of_gpu_memory_to_use', 'cudnn_deterministic' diff --git a/python/paddle/fluid/annotations.py b/python/paddle/fluid/annotations.py new file mode 100644 index 0000000000000000000000000000000000000000..bb8756a4664013643c278c013ca21bb237a6b4a7 --- /dev/null +++ b/python/paddle/fluid/annotations.py @@ -0,0 +1,38 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import sys + +__all__ = ['deprecated'] + + +def deprecated(since, instead, extra_message=""): + def decorator(func): + err_msg = "API {0} is deprecated since {1}. Please use {2} instead.".format( + func.__name__, since, instead) + if len(extra_message) != 0: + err_msg += "\n" + err_msg += extra_message + + @functools.wraps(func) + def wrapper(*args, **kwargs): + print >> sys.stderr, err_msg + return func(*args, **kwargs) + + wrapper.__doc__ += "\n " + wrapper.__doc__ += err_msg + return wrapper + + return decorator diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index e7a065599ec034039442449ea20e1fcd02c44f8d..ddcde04716d21df1f18e7202936f470d3d58a661 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -18,10 +18,7 @@ import collections import copy import unique_name -__all__ = [ - 'append_backward', - 'calc_gradient', -] +__all__ = ['append_backward'] def _rename_arg_(op_descs, old_name, new_name, begin_idx=None, end_idx=None): diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 5c8f4f6507c7dd9b3d005639d962ce1e55b2c704..0eb1194e2754331dcbc8436f6680ab776a999c29 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -24,10 +24,7 @@ from . import core __all__ = [ 'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params', 'load_persistables', 'save_inference_model', 'load_inference_model', - 'get_inference_program', 'save_checkpoint', 'load_checkpoint', - 'clean_checkpoint', 'load_persist_vars_without_grad', - 'load_lookup_table_vars', 'save_persist_vars_without_grad', - 'get_latest_checkpoint_serial' + 'get_inference_program' ] @@ -794,588 +791,6 @@ def get_parameter_value_by_name(name, executor, program=None): return get_parameter_value(var, executor) -SUCCESS_MARK_FILENAME = "_SUCCESS" -CHECKPOINT_PREFIX = "checkpoint" -MODEL_DIR = "__model__" -LOOKUP_TABLE_DIR = "__lookup_table__" -TRAINER_PREFIX = "trainer" -CHECKPOINT_SEPARATOR = "_" - - -def save_checkpoint(executor, - checkpoint_dir, - trainer_id, - trainer_args=None, - main_program=None, - max_num_checkpoints=3, - lookup_table=None, - ps_endpoint_list=None): - """ - This function filters out all checkpoint variables from the give - main_program and then saves these variables to the `checkpoint_dir` - directory. - - In the training precess, we generally save a checkpoint in each - iteration. So there might be a lot of checkpoints in the - `checkpoint_dir`. To avoid them taking too much disk space, the - `max_num_checkpoints` are introduced to limit the total number of - checkpoints. If the number of existing checkpints is greater than - the `max_num_checkpoints`, oldest ones will be scroll deleted. - - A variable is a checkpoint variable and will be saved if it meets - all following conditions: - 1. It's persistable. - 2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW. - 3. It's name contains no "@GRAD" nor ".trainer_" nor ".block". - - Args: - executor(Executor): The executor to run for save checkpoint. - checkpoint_dir(str): The folder where to save checkpoints. - trainer_id(int): currect trainer id, if id is equal to 0, the trainer - is chief. - trainer_args(dict|None): Current training arguments. Such as 'epoch_id' - and 'step_id'. - Defaut: None - main_program(Program|None): The program whose checkpoint variables will - be saved. If it is None, the default main program will be used. - max_num_checkpoints(int): The max number of total number of existing - checkpoints. - Default: 3 - lookup_table(string|None): the lookup table name, when use distribute - lookup table, we can get lookup table name by DistributeTranspiler. - table_name - ps_endpoint_list(list|None): the parameter server ip:port list. - when use distribute lookup table, we can get ps_endpoint_list by - distribute arguments. - - Returns: - None - - Raises: - ValueError: If `checkpoint_dir` is None. - AssertionError: If `trainer_args` is not a dict. - - Examples: - .. code-block:: python - - exe = fluid.Executor(fluid.CPUPlace()) - path = "./checkpoints" - prog = fluid.default_main_program() - trainer_args = {"epoch_id": 200, - "step_id": 20} # just an example - table_name = "share_w" - ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"] - - fluid.io.save_checkpoint(executor=exe, - checkpoint_dir=path, - trainer_id=0, - trainer_args=trainer_args, - main_program=prog, - max_num_checkpoints=3, - lookup_table=table_name, - ps_endpoint_list = ps_endpoints) - """ - if checkpoint_dir is None: - raise ValueError("'checkpoint_dir' should not be None") - assert checkpoint_dir - - if trainer_args: - assert isinstance(trainer_args, dict) - - is_chief = trainer_id == 0 - - _make_chekcpoint_dirs(checkpoint_dir) - serial = get_latest_checkpoint_serial(checkpoint_dir) + 1 - cur_dir = _get_serial_dir(checkpoint_dir, serial) - - save_trainer_args(cur_dir, trainer_id, trainer_args) - - if is_chief: - save_persist_vars_without_grad(executor, cur_dir, main_program) - - if is_chief and lookup_table and ps_endpoint_list: - save_pserver_vars_by_notify(executor, cur_dir, lookup_table, - ps_endpoint_list) - - _scroll_delete(checkpoint_dir, max_num_checkpoints) - - -def load_checkpoint(executor, checkpoint_dir, serial, main_program): - """ - This function filters out all checkpoint variables from the give - main_program and then try to load these variables from the - `checkpoint_dir` directory. - - In the training precess, we generally save a checkpoint in each - iteration. So there are more than one checkpoint in the - `checkpoint_dir` (each checkpoint has its own sub folder), use - `serial` to specify which serial of checkpoint you would like to - load. - - A variable is a checkpoint variable and will be loaded if it meets - all following conditions: - 1. It's persistable. - 2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW. - 3. It's name contains no "@GRAD" nor ".trainer_" nor ".block". - - Args: - executor(Executor): The executor to run for loading checkpoint. - checkpoint_dir(str): The folder where all checkpoints are. - serial(int): The serial of checkpoint you would like to load. - main_program(Program): The program whose checkpoint variables will - be loaded. - - Returns: - None - - Raises: - ValueError: If `checkpoint_dir` is None. - ValueError: If `serial` is None or `serial` is less than 0. - ValueError: If `main_program` is None. - - Examples: - .. code-block:: python - - exe = fluid.Executor(fluid.CPUPlace()) - path = "./checkpoints" - prog = fluid.default_main_program() - fluid.io.load_checkpoint(executor=exe, checkpoint_dir=path, - serial=9, main_program=prog) - - # In this example, `load_checkpoint` function - # will first filters out all checkpoint variables in the default - # main program, and then try to load these variables form the - # folder "./checkpoints/checkpoint_9/__model__". - """ - - if checkpoint_dir is None: - raise ValueError("'checkpoint_dir' should not be None") - - if serial is None or serial < 0: - raise ValueError("'serial' should not be None or <0 ") - - if main_program is None: - raise ValueError('main_program should not be None.') - - cur_dir = _get_serial_dir(checkpoint_dir, serial) - load_persist_vars_without_grad(executor, cur_dir, main_program, True) - - -def clean_checkpoint(checkpoint_dir, delete_dir=False): - """ - clean the checkpoint dir, when the train exits normally, - the trainer will call clean_checkpoint to delete checkpoint directory saved before. - delete_dir only works when the directory is empty, otherwise, OSError is raised. - - : param checkpoint_dir - : param delete_dir - """ - - if checkpoint_dir is None: - raise ValueError("'checkpoint_dir' should not be None") - _scroll_delete(checkpoint_dir, max_num_checkpoints=0) - - if delete_dir and not os.listdir(checkpoint_dir): - os.rmdir(checkpoint_dir) - - -def load_persist_vars_without_grad(executor, - dirname, - program, - has_model_dir=False): - """ - This function filters out all checkpoint variables from the give - program and then trys to load these variables from the given directory. - - A variable is a checkpoint variable if it meets all following - conditions: - 1. It's persistable. - 2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW. - 3. It's name contains no "@GRAD" nor ".trainer_" nor ".block". - - Args: - executor(Executor): The executor to run for loading variables. - dirname(str): The directory path. - program(Program): The program whose checkpoint variables will - be loaded. - has_model_dir(bool): if True, the function loads variables - from a sub directory named '__model__'. - Default: False - - Returns: - None - - Examples: - .. code-block:: python - - exe = fluid.Executor(fluid.CPUPlace()) - param_path = "./my_paddle_model" - prog = fluid.default_main_program() - fluid.io.load_persist_vars_without_grad(executor=exe, - dirname=param_path, program=prog, has_model_dir=True) - - # In this example, `load_persist_vars_without_grad` function - # will first filters out all checkpoint variables in the default - # main program, and then trys to load these variables form the - # folder "./my_paddle_model/__model__". - """ - - if has_model_dir: - dirname = _get_model_dir(dirname) - - load_vars( - executor, - dirname=dirname, - main_program=program, - predicate=_is_checkpoint_var, - filename=None) - - -def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name): - """ - The parameter server will load lookup table's local file in - selectedrows variable. - - Args: - executor(Executor): The executor to run for loading persistable variables - dirname(str): The directory path - main_program(Program): Find the variable named table_name in main_program - pserver_id(int): the serial number in pserver_endpoints list - table_name(str): lookup table name - - Returns: - None - - Examples: - .. code-block:: python - - exe = fluid.Executor(fluid.CPUPlace()) - dirname = "./checkpoints/checkpoint_9/__model__" - prog = fluid.default_main_program() - pserver_id = 1 - table_name = "share_w" - fluid.io.load_lookup_table_vars(executor=exe, - dirname=dirname, program=prog, pserver_id=pserver_id, - table_name=table_name) - """ - - for var in program.list_vars(): - if var.name == table_name: - lookup_table_var = var - break - - assert lookup_table_var is not None - - lookup_table_dir = os.path.join(dirname, LOOKUP_TABLE_DIR) - table_file = table_name + CHECKPOINT_SEPARATOR + str(pserver_id) - - load_prog = Program() - load_block = load_prog.global_block() - - load_block.append_op( - type='load', - inputs={}, - outputs={'Out': [lookup_table_var]}, - attrs={'file_path': os.path.join(lookup_table_dir, table_file)}) - - executor.run(load_prog) - - -def save_persist_vars_without_grad(executor, dirname, program): - """ - This function filters out all checkpoint variables from the give - program and then save these variables to a sub-folder '__model__' of - the given directory. - - A variable is a checkpoint variable if it meets all following - conditions: - 1. It's persistable. - 2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW. - 3. It's name contains no "@GRAD" nor ".trainer_" nor ".block". - - Args: - executor(Executor): The executor to run for saving variables. - dirname(str): The directory path. - program(Program): The program whose checkpoint variables will - be saved. - - Returns: - None - - Examples: - .. code-block:: python - - exe = fluid.Executor(fluid.CPUPlace()) - param_path = "./my_paddle_model" - prog = fluid.default_main_program() - fluid.io.save_persist_vars_without_grad(executor=exe, - dirname=param_path, program=prog) - - # In this example, `save_persist_vars_without_grad` function - # will first filters out all checkpoint variables in the default - # main program, and then saves these variables to the folder - # "./my_paddle_model/__model__". - """ - cur_dir = _get_model_dir(dirname) - save_vars( - executor, - dirname=cur_dir, - main_program=program, - vars=None, - predicate=_is_checkpoint_var, - filename=None) - _write_success(cur_dir) - - -def save_pserver_vars_by_notify(executor, dirname, lookup_table, - ps_endpoint_list): - """ - This function will send checkpoint notify message from Trainer 0 - to all the pservers. - The checkpoint notify message contains lookup table name, - the absolute path on pserver to save lookup_table. - - Args: - executor(Executor): The executor to run for send checkpoint notify. - dirname(str): The folder where to save checkpoints. - lookup_table(string): the lookup table name, when use distribute - lookup table, we can get lookup table name by DistributeTranspiler. - table_name - ps_endpoint_list(list): the parameter server ip:port list. - when use distribute lookup table, we can get ps_endpoint_list by - distribute arguments. - Return: - None - - Examples: - .. code-block:: python - - exe = fluid.Executor(fluid.CPUPlace()) - param_path = "./my_paddle_model" - prog = fluid.default_main_program() - table_name = "share_w" - ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"] - - fluid.io.save_pserver_vars_by_notify(executor=exe, - dirname=param_path, lookup_table=table_name, - ps_endpoint_list=ps_endpoints) - """ - cur_dir = _get_lookuptable_dir(dirname) - - checkpoint_notify_program = Program() - checkpoint_notify_block = checkpoint_notify_program.global_block() - - attrs = {} - attrs['epmap'] = ps_endpoint_list - attrs['dir'] = cur_dir - attrs['lookup_table'] = lookup_table - - checkpoint_notify_block.append_op( - type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs) - executor.run(checkpoint_notify_program) - - -def save_trainer_args(dirname, trainer_id, trainer_args): - assert isinstance(trainer_args, dict) - - cur_dir = _get_trainer_dir(dirname, trainer_id) - - for name, value in trainer_args.iteritems(): - args_file = os.path.join(cur_dir, name) - with open(args_file, 'w') as f: - f.write(str(value)) - _write_success(cur_dir) - - -def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args): - """ - trainer will load some args from it's independent directory, - such as epoch_id and step_id. - - Args: - checkpoint_dir(str): The folder where all checkpoints are. - serial(int): The serial of checkpoint you would like to load. - trainer_id(int): current trainer id. - trainer_args(list): list about load trainer args - Return: - None - - Examples: - .. code-block:: python - - param_path = "./checkpoint/" - serial = 7 - trainer_id = 2 - trainer_args = ["epoch_id", "step_id"] - - fluid.io.load_trainer_args(checkpoint_dir=param_path, serial=serial, - trainer_id=trainer_id, trainer_args=trainer_args) - """ - assert isinstance(trainer_args, list) - - cur_dir = _get_serial_dir(checkpoint_dir, serial) - cur_dir = _get_trainer_dir(cur_dir, trainer_id) - - ret_values = [] - - for arg in trainer_args: - cur_file = os.path.join(cur_dir, arg) - with open(cur_file, 'r') as f: - contents = f.read() - ret_values.append(contents.strip()) - return ret_values - - -def _is_checkpoint_var(var): - """ - the checkpoint will not save or load all the variables. - var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded. - - : param var(Variable) - """ - if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ - var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ - var.desc.type() == core.VarDesc.VarType.RAW: - return False - # @GRAD are named for gradient variables, checkpoint will not save it. - if "@GRAD" in var.name: - return False - # .trainer_ are named for distribute train variables, checkpoint will not save it. - if ".trainer_" in var.name: - return False - - # .block is named for distribute train variables, checkpoint will not save it. - if ".block" in var.name: - return False - - return var.persistable - - -def _make_chekcpoint_dirs(dirs): - """ - _make_chekcpoint_dirs will makdir local directory directly, when the directory is exist, it will igore it. - """ - assert dirs is not None - - if os.path.isfile(dirs): - raise OSError(errno.ENOTDIR, "dirs path shoule be a Directory.", dirs) - - if not os.path.isdir(dirs): - try: - os.makedirs(dirs) - except OSError as err: - if err.errno != errno.EEXIST: - raise err - - -def _get_dir_serial(dirname): - _, serial = dirname.split(CHECKPOINT_SEPARATOR) - - try: - serial_num = int(serial) - except ValueError: - serial_num = -1 - return serial_num - - -def _get_serial_dir(dirname, serial): - serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial) - serial_dir = os.path.join(dirname, serial_folder) - _make_chekcpoint_dirs(serial_dir) - - return serial_dir - - -def _get_model_dir(dirname): - model_dir = os.path.join(dirname, MODEL_DIR) - _make_chekcpoint_dirs(model_dir) - return model_dir - - -def _get_lookuptable_dir(dirname): - lookuptable_dir = os.path.join(dirname, LOOKUP_TABLE_DIR) - _make_chekcpoint_dirs(lookuptable_dir) - return lookuptable_dir - - -def _get_trainer_dir(dirname, trainer_id): - trainer_folder = TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str(trainer_id) - trainer_dir = os.path.join(dirname, trainer_folder) - _make_chekcpoint_dirs(trainer_dir) - return trainer_dir - - -def _scroll_delete(dirname, max_num_checkpoints=3): - dirs = os.listdir(dirname) - serial_map = {} - for serial in dirs: - serial_num = _get_dir_serial(serial) - serial_map[serial_num] = serial - - if len(serial_map.keys()) <= max_num_checkpoints: - return - - serials = serial_map.keys() - serials.sort(reverse=True) - serials = serials[max_num_checkpoints:] - for serial in serials: - cur_dir = _get_serial_dir(dirname, serial) - try: - shutil.rmtree(cur_dir) - except OSError as err: - if err.errno != errno.ENOENT: - raise err - - -def _write_success(dirname): - """ - write an empty file named "_SUCCESS" in checkpoint dir, indicate this checkpoint is correct. - - : param dirname - """ - success_file = os.path.join(dirname, SUCCESS_MARK_FILENAME) - with open(success_file, 'a') as f: - now = time.ctime() - f.write(now) - - -def get_latest_checkpoint_serial(checkpoint_dir): - """ - get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory - - : param checkpoint_dir - """ - if not checkpoint_dir: - return -1 - - def has_success(checkpoint_dir, cur_dir): - """ - is _SUCCESS in this dir - """ - - serial = _get_dir_serial(cur_dir) - if serial == -1 or not os.path.isdir( - os.path.join(checkpoint_dir, cur_dir)): - return -1 - - success_path = os.path.join( - _get_serial_dir(checkpoint_dir, serial), MODEL_DIR, - SUCCESS_MARK_FILENAME) - if os.path.isfile(success_path): - return serial - - if not os.path.isdir(checkpoint_dir): - return -1 - - current_dir = -1 - dirs = os.listdir(checkpoint_dir) - for cur_dir in dirs: - success_num = has_success(checkpoint_dir, cur_dir) - if success_num > current_dir: - current_dir = success_num - return current_dir - - def get_test_program(filelist, program=None, startup_program=None): """ Transpile current train program to a program to read test dataset diff --git a/python/paddle/fluid/layers/device.py b/python/paddle/fluid/layers/device.py index e0c1aab230aeed7fb858e91e7da7eae58032ee16..384d302a709eeec220864b9e8c9210ed028470f6 100644 --- a/python/paddle/fluid/layers/device.py +++ b/python/paddle/fluid/layers/device.py @@ -18,10 +18,12 @@ All util layers. from layer_function_generator import autodoc from ..framework import unique_name from ..layer_helper import LayerHelper +from ..annotations import deprecated -__all__ = ['get_places'] +__all__ = [] +@deprecated(since='0.15.0', instead="ParallelExecutor") @autodoc() def get_places(device_count=None, device_type=None): helper = LayerHelper('get_places', **locals()) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 07b806f544497ccabe4dde9a370e90da372e6cba..cc223899c73deb173701db0fba4123c8442bfd43 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -85,6 +85,7 @@ __all__ = [ 'transpose', 'im2sequence', 'nce', + 'hsigmoid', 'beam_search', 'row_conv', 'multiplex', @@ -3871,6 +3872,74 @@ def nce(input, return cost / (num_neg_samples + 1) +def hsigmoid(input, label, num_classes, param_attr=None, bias_attr=None): + """ + The hierarchical sigmoid operator is used to accelerate the training + process of language model. This operator organizes the classes into a + complete binary tree, each leaf node represents a class(a word) and each + internal node acts as a binary classifier. For each word there's a unique + path from root to it's leaf node, hsigmoid calculate the cost for each + internal node on the path, and sum them to get a total cost. hsigmoid can + achive a acceleration from :math:`O(N)` to :math:`O(logN)`, where :math:`N` + represents the size of word dict. + + Refer to `Hierarchical Probabilistic Neural Network Language Model + `_ + + Args: + input (Variable): The input tensor variable with shape + :math:`[N \\times D]`, where :math:`N` is the size of mini-batch, + and :math:`D` is the feature size. + label (Variable): The tensor variable contains labels of training data. + It's a tensor with shape is :math:`[N \\times 1]`. + num_classes: (int), The number of classes, must not be less than 2. + param_attr (ParamAttr|list of ParamAttr, default None): The parameter + attribute for learnable parameters/weights of this layer. + bias_attr (ParamAttr|list of ParamAttr, default None): The parameter + attribute for the bias of this layer. If it is set to False, no + bias will be applied. + + Returns: + Out: (Tensor) The cost of hierarchical sigmoid operator. the shape is [N, 1] + + Examples: + + .. code-block:: python + + x = fluid.layers.data(name='x', shape=[2], dtype='float32') + y = fluid.layers.data(name='y', shape=[1], dtype='int64') + out = fluid.layers.hsigmoid(input=x, label=y, num_classes=6) + """ + + helper = LayerHelper('hierarchical_sigmoid', **locals()) + dtype = helper.input_dtype() + out = helper.create_tmp_variable(dtype) + pre_out = helper.create_tmp_variable(dtype) + dim = input.shape[1] + if num_classes < 2: + raise ValueError("num_classes must not be less than 2.") + weights = helper.create_parameter( + attr=helper.param_attr, + shape=[num_classes - 1, dim], + is_bias=False, + dtype=input.dtype) + inputs = {"X": input, "W": weights, "Label": label} + if helper.bias_attr: + bias = helper.create_parameter( + attr=helper.bias_attr, + shape=[1, num_classes - 1], + is_bias=True, + dtype=input.dtype) + inputs['Bias'] = bias + helper.append_op( + type="hierarchical_sigmoid", + inputs=inputs, + outputs={"Out": out, + "PreOut": pre_out}, + attrs={"num_classes": num_classes}) + return out + + def transpose(x, perm, name=None): """ Permute the dimensions of `input` according to `perm`. diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 75ee40fa9ca94cdd84ee7acbb62d6e652ac7fa33..e2acf6d41a0085e6f741e46063b47d2ff1e769cb 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -29,7 +29,7 @@ __all__ = [ 'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl', 'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer', 'AdamaxOptimizer', 'DecayedAdagradOptimizer', 'RMSPropOptimizer', - 'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'Optimizer', 'RMSPropOptimizer' + 'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'RMSPropOptimizer' ] @@ -67,7 +67,7 @@ class Optimizer(object): self._LARS_weight_decay = LARS_weight_decay def _create_global_learning_rate(self): - lr = self.global_learning_rate() + lr = self._global_learning_rate() if isinstance(lr, framework.Variable): return @@ -86,7 +86,7 @@ class Optimizer(object): dtype='float32' if self._dtype == None else self._dtype, persistable=True) - def global_learning_rate(self, program=None): + def _global_learning_rate(self, program=None): """ get global decayed learning rate :return: @@ -110,9 +110,9 @@ class Optimizer(object): return param_lr else: if param_lr == 1.0: - return self.global_learning_rate() + return self._global_learning_rate() else: - return self.global_learning_rate() * param_lr + return self._global_learning_rate() * param_lr def _create_accumulators(self, block, parameters): """Create all accumulators needed by the parameters @@ -185,10 +185,10 @@ class Optimizer(object): format(name, param.name)) return self._accumulators[name][param.name] - def create_optimization_pass(self, - parameters_and_grads, - loss, - startup_program=None): + def _create_optimization_pass(self, + parameters_and_grads, + loss, + startup_program=None): """Add optimization operators to update gradients to variables. Args: @@ -221,7 +221,7 @@ class Optimizer(object): self._create_global_learning_rate() if self._LARS_weight_decay > 0.0: layers.append_LARS(parameters_and_grads, - self.global_learning_rate(), + self._global_learning_rate(), self._LARS_weight_decay) optimize_ops = [] @@ -262,8 +262,8 @@ class Optimizer(object): params_grads = append_regularization_ops(params_grads, self.regularization) - optimize_ops = self.create_optimization_pass(params_grads, loss, - startup_program) + optimize_ops = self._create_optimization_pass(params_grads, loss, + startup_program) return optimize_ops, params_grads diff --git a/python/paddle/fluid/tests/book/notest_understand_sentiment.py b/python/paddle/fluid/tests/book/notest_understand_sentiment.py index 1df7b99aad6094a8b8ddfe783b9de35cef61c524..95002aa7f9bb639828b47eb1e86e4ef954fb85ff 100644 --- a/python/paddle/fluid/tests/book/notest_understand_sentiment.py +++ b/python/paddle/fluid/tests/book/notest_understand_sentiment.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from __future__ import print_function - +from paddle.fluid.layers.device import get_places import unittest import paddle.fluid as fluid import paddle @@ -144,7 +144,7 @@ def train(word_dict, cost, acc_out, prediction = net_method( data, label, input_dim=dict_dim, class_dim=class_dim) else: - places = fluid.layers.get_places() + places = get_places() pd = fluid.layers.ParallelDo(places) with pd.do(): cost, acc, _ = net_method( diff --git a/python/paddle/fluid/tests/book/test_recognize_digits.py b/python/paddle/fluid/tests/book/test_recognize_digits.py index 5f5c8544bbdb87421f129b201a0ebaf4cb8602a1..49f549fa184037a64aa846f0d1d0e1b57db1f2ef 100644 --- a/python/paddle/fluid/tests/book/test_recognize_digits.py +++ b/python/paddle/fluid/tests/book/test_recognize_digits.py @@ -12,15 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. from __future__ import print_function -import argparse -import paddle.fluid as fluid -import paddle -import sys -import numpy -import unittest + import math -import sys import os +import sys +import unittest + +import numpy + +import paddle +import paddle.fluid as fluid +from paddle.fluid.layers.device import get_places BATCH_SIZE = 64 @@ -76,7 +78,7 @@ def train(nn_type, net_conf = conv_net if parallel: - places = fluid.layers.get_places() + places = get_places() pd = fluid.layers.ParallelDo(places) with pd.do(): img_ = pd.read_input(img) diff --git a/python/paddle/fluid/tests/book/test_word2vec.py b/python/paddle/fluid/tests/book/test_word2vec.py index 49bd72c7a53c0ae740bdbabe15b1d37340699d41..80e0692bc640efc280c43bd5b929847ad29207c4 100644 --- a/python/paddle/fluid/tests/book/test_word2vec.py +++ b/python/paddle/fluid/tests/book/test_word2vec.py @@ -14,6 +14,7 @@ import paddle import paddle.fluid as fluid +from paddle.fluid.layers.device import get_places import unittest import os import numpy as np @@ -80,7 +81,7 @@ def train(use_cuda, is_sparse, is_parallel, save_dirname, is_local=True): avg_cost, predict_word = __network__( [first_word, second_word, third_word, forth_word, next_word]) else: - places = fluid.layers.get_places() + places = get_places() pd = fluid.layers.ParallelDo(places) with pd.do(): avg_cost, predict_word = __network__( diff --git a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py index be347cd5315668dde0454d7959dbf9bcfa465b5f..bec9f8594ff7c1aff8ae5ed55c9623754d9ea091 100644 --- a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py +++ b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np -import paddle -import paddle.fluid as fluid import math import sys +import paddle +import paddle.fluid as fluid +from paddle.fluid.layers.device import get_places + # need to fix random seed and training data to compare the loss # value accurately calculated by the default and the memory optimization # version. @@ -34,7 +35,7 @@ if fluid.core.is_compiled_with_cuda(): use_nccl = False place = fluid.CUDAPlace(0) -places = fluid.layers.get_places(device_count=0, device_type=device_type) +places = get_places(device_count=0, device_type=device_type) pd = fluid.layers.ParallelDo(places, use_nccl=use_nccl) with pd.do(): x_ = pd.read_input(x) diff --git a/python/paddle/fluid/tests/unittests/test_calc_gradient.py b/python/paddle/fluid/tests/unittests/test_calc_gradient.py index 06e676cd83e77549afd679e730426c590cc046bf..7f2a9e6971ed933463216e38498d48ab132a1a37 100644 --- a/python/paddle/fluid/tests/unittests/test_calc_gradient.py +++ b/python/paddle/fluid/tests/unittests/test_calc_gradient.py @@ -16,8 +16,6 @@ import unittest import paddle.fluid as fluid import paddle.fluid.layers as layers -import paddle.fluid.framework as framework -import paddle.fluid.optimizer as optimizer from paddle.fluid.backward import calc_gradient diff --git a/python/paddle/fluid/tests/unittests/test_checkpoint.py b/python/paddle/fluid/tests/unittests/test_checkpoint.py deleted file mode 100644 index e22400a045ced16c46b0bf005155f621f249d263..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/tests/unittests/test_checkpoint.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import paddle.fluid as fluid -import unittest -import os -import tempfile - - -class TestCheckpoint(unittest.TestCase): - def setUp(self): - self.dirname = tempfile.mktemp() - self.max_num_checkpoints = 3 - self.epoch_interval = 1 - self.step_interval = 1 - self.trainer_id = 0 - self.chief = self.trainer_id == 0 - self.place = fluid.CPUPlace() - self.epoch_id = 100 - self.step_id = 20 - - def test_checkpoint(self): - self.save_checkpoint() - serial = fluid.io.get_latest_checkpoint_serial(self.dirname) - self.assertTrue(serial >= 0) - trainer_args = ["epoch_id", "step_id"] - epoch_id, step_id = fluid.io.load_trainer_args( - self.dirname, serial, self.trainer_id, trainer_args) - self.assertEqual(self.step_id, int(step_id)) - self.assertEqual(self.epoch_id, int(epoch_id)) - - program = fluid.Program() - with fluid.program_guard(program): - exe = fluid.Executor(self.place) - fluid.io.load_checkpoint(exe, self.dirname, serial, program) - - fluid.io.clean_checkpoint(self.dirname, delete_dir=True) - self.assertFalse(os.path.isdir(self.dirname)) - - def save_checkpoint(self): - config = fluid.CheckpointConfig(self.dirname, self.max_num_checkpoints, - self.epoch_interval, self.step_interval) - - trainer_args = {} - trainer_args["epoch_id"] = self.epoch_id - trainer_args["step_id"] = self.step_id - - program = fluid.Program() - with fluid.program_guard(program): - program.global_block().create_var( - name="scale_0", - psersistable=True, - dtype="float32", - shape=[32, 32]) - - exe = fluid.Executor(self.place) - for i in xrange(10): - fluid.io.save_checkpoint(exe, config.checkpoint_dir, - self.trainer_id, trainer_args, program, - config.max_num_checkpoints) - - -if __name__ == '__main__': - unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_get_places_op.py b/python/paddle/fluid/tests/unittests/test_get_places_op.py index 6dab1e22f0c50ab011d6b8e8944097600cf3fecc..964423e2d2638224244b4ca774d8eee08f3ec989 100644 --- a/python/paddle/fluid/tests/unittests/test_get_places_op.py +++ b/python/paddle/fluid/tests/unittests/test_get_places_op.py @@ -13,6 +13,7 @@ # limitations under the License. import paddle.fluid as fluid +from paddle.fluid.layers.device import get_places import decorators import unittest @@ -20,7 +21,7 @@ import unittest class TestGetPlaces(unittest.TestCase): @decorators.prog_scope() def test_get_places(self): - places = fluid.layers.get_places() + places = get_places() cpu = fluid.CPUPlace() exe = fluid.Executor(cpu) exe.run(fluid.default_main_program()) diff --git a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py new file mode 100644 index 0000000000000000000000000000000000000000..d090960c84e47da68a0ebea4609dfc3ed76e114e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py @@ -0,0 +1,99 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import math +from op_test import OpTest + + +def find_latest_set(num): + return 1 + int(math.floor(math.log(num, 2))) + + +class CodeTable(object): + def __init__(self, num_classes, code): + self.c = num_classes + code + + def cal_index(self, bit): + return (self.c >> (bit + 1)) - 1 + + def get_length(self): + return find_latest_set(self.c) - 1 + + def cal_bit(self, bit): + return self.c & (1 << bit) + + +def hsigmoid(x, w, label, bias, num_classes): + batch_size = x.shape[0] + code_length = find_latest_set(num_classes - 1) + code_table = [0 for _ in range(code_length)] + pre_output = np.zeros((batch_size, code_length)) + pre_sum = np.zeros((batch_size, 1)) + out = np.zeros((batch_size, 1)).astype("float32") + for i in range(batch_size): + code_table = CodeTable(num_classes, label[i]) + length = code_table.get_length() + for j in range(length): + idx = code_table.cal_index(j) + pre_output[i][j] += bias[0][idx] + for i in range(batch_size): + code_table = CodeTable(num_classes, label[i]) + length = code_table.get_length() + for j in range(length): + idx = code_table.cal_index(j) + pre_output[i][j] += np.dot(w[idx], x[i]) + # clip[-40.0, 40.0] + pre_output = np.clip(pre_output, -40.0, 40.0) + # out(i, 0) = \sum_j bit(i, j) * preout(i, j) + for i in range(batch_size): + code_table = CodeTable(num_classes, label[i]) + length = code_table.get_length() + sum = 0.0 + for j in range(length): + if code_table.cal_bit(j): + sum += pre_output[i][j] + out[i] = -1.0 * sum + # soft relu + pre_output = np.log(1 + np.exp(pre_output)) + pre_sum = pre_output.sum(1).reshape((batch_size, 1)) + out += pre_sum + return pre_output, out + + +class TestHSigmoidOp(OpTest): + def setUp(self): + self.op_type = "hierarchical_sigmoid" + num_classes = 6 + feature_size = 8 + batch_size = 4 + x = np.random.random((batch_size, feature_size)).astype("float32") + w = np.random.random((num_classes - 1, feature_size)).astype("float32") + label = np.random.randint(0, num_classes, (batch_size, 1)) + bias = np.random.random((1, num_classes - 1)).astype("float32") + self.attrs = {'num_classes': num_classes} + self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias} + pre_output, out = hsigmoid(x, w, label, bias, num_classes) + self.outputs = {'PreOut': pre_output, 'Out': out} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label')) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 31ae4e7d8c4107fce7022234d3aa47827647d012..6b1f206ea2f5a6226cfdb01c70a8ce4646ae4788 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -16,6 +16,7 @@ from __future__ import print_function import unittest import paddle.fluid.layers as layers +from paddle.fluid.layers.device import get_places import paddle.fluid.nets as nets from paddle.fluid.framework import Program, program_guard, default_main_program from paddle.fluid.param_attr import ParamAttr @@ -173,6 +174,16 @@ class TestBook(unittest.TestCase): x=dat, label=lbl)) print(str(program)) + def test_hsigmoid(self): + program = Program() + with program_guard(program): + x = layers.data(name='x', shape=[2], dtype='float32') + y = layers.data(name='y', shape=[2], dtype='int64') + self.assertIsNotNone( + layers.hsigmoid( + input=x, label=y, num_classes=2)) + print(str(program)) + def test_sequence_expand(self): program = Program() with program_guard(program): @@ -238,7 +249,7 @@ class TestBook(unittest.TestCase): def test_get_places(self): program = Program() with program_guard(program): - x = layers.get_places(device_count=4) + x = get_places(device_count=4) self.assertIsNotNone(x) print(str(program)) diff --git a/python/paddle/fluid/tests/unittests/test_optimizer.py b/python/paddle/fluid/tests/unittests/test_optimizer.py index 7286c7c450108c4b5ad7136041bc4e989894a2ba..43385691bb3960004b5b69a1c55e41dd4252fa71 100644 --- a/python/paddle/fluid/tests/unittests/test_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_optimizer.py @@ -97,7 +97,7 @@ class TestMomentumOptimizer(unittest.TestCase): params_grads = append_backward(mean_out) self.assertEqual(len(params_grads), 1) self.assertEqual(len(momentum_optimizer.get_accumulators()), 0) - opts = momentum_optimizer.create_optimization_pass( + opts = momentum_optimizer._create_optimization_pass( params_grads, mul_out, init_program) self.assertEqual(len(opts), 3) sgd_op = opts[-1] @@ -151,7 +151,7 @@ class TestMomentumOptimizer(unittest.TestCase): params_grads = append_backward(mean_out) self.assertEqual(len(params_grads), 1) self.assertEqual(len(momentum_optimizer.get_accumulators()), 0) - opts = momentum_optimizer.create_optimization_pass( + opts = momentum_optimizer._create_optimization_pass( params_grads, mul_out, init_program) self.assertEqual(len(opts), 3) sgd_op = opts[-1] @@ -214,8 +214,8 @@ class TestAdagradOptimizer(unittest.TestCase): params_grads = append_backward(mean_out) self.assertEqual(len(params_grads), 1) self.assertEqual(len(adagrad_optimizer.get_accumulators()), 0) - opts = adagrad_optimizer.create_optimization_pass(params_grads, mul_out, - init_program) + opts = adagrad_optimizer._create_optimization_pass( + params_grads, mul_out, init_program) self.assertEqual(len(opts), 3) self.assertEqual([op.type for op in opts], ["fill_constant", "elementwise_mul", "adagrad"]) @@ -278,8 +278,8 @@ class TestAdamOptimizer(unittest.TestCase): params_grads = append_backward(mean_out) self.assertEqual(len(params_grads), 1) self.assertEqual(len(adam_optimizer.get_accumulators()), 0) - opts = adam_optimizer.create_optimization_pass(params_grads, mul_out, - init_program) + opts = adam_optimizer._create_optimization_pass(params_grads, mul_out, + init_program) self.assertEqual(len(opts), 5) self.assertEqual( [op.type for op in opts], @@ -345,8 +345,8 @@ class TestAdamaxOptimizer(unittest.TestCase): params_grads = append_backward(mean_out) self.assertEqual(len(params_grads), 1) self.assertEqual(len(adamax_optimizer.get_accumulators()), 0) - opts = adamax_optimizer.create_optimization_pass(params_grads, mul_out, - init_program) + opts = adamax_optimizer._create_optimization_pass(params_grads, mul_out, + init_program) self.assertEqual(len(opts), 4) self.assertEqual( [op.type for op in opts], @@ -409,7 +409,7 @@ class TestDecayedAdagradOptimizer(unittest.TestCase): params_grads = append_backward(mean_out) self.assertEqual(len(params_grads), 1) self.assertEqual(len(decayed_adagrad_optimizer.get_accumulators()), 0) - opts = decayed_adagrad_optimizer.create_optimization_pass( + opts = decayed_adagrad_optimizer._create_optimization_pass( params_grads, mul_out, init_program) self.assertEqual(len(opts), 3) self.assertEqual( @@ -475,8 +475,8 @@ class TestFtrlOptimizer(unittest.TestCase): params_grads = append_backward(mean_out) self.assertEqual(len(params_grads), 1) self.assertEqual(len(ftrl_optimizer.get_accumulators()), 0) - opts = ftrl_optimizer.create_optimization_pass(params_grads, mul_out, - init_program) + opts = ftrl_optimizer._create_optimization_pass(params_grads, mul_out, + init_program) self.assertEqual(len(opts), 3) self.assertEqual([op.type for op in opts], ["fill_constant", "elementwise_mul", "ftrl"]) diff --git a/python/paddle/fluid/tests/unittests/test_parallel_op.py b/python/paddle/fluid/tests/unittests/test_parallel_op.py index 9ba5f988f317a515b77c0b428da236626419a2c3..9ec05e02973138e3ec233ef07f98afd598ec86b1 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_op.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_op.py @@ -15,6 +15,7 @@ import unittest import paddle.fluid as fluid +from paddle.fluid.layers.device import get_places import paddle.fluid.profiler as profiler import numpy @@ -115,7 +116,7 @@ class BaseParallelForTest(unittest.TestCase): if use_parallel: thread_num = fluid.core.get_cuda_device_count( ) if use_gpu else 8 - places = fluid.layers.get_places(thread_num) + places = get_places(thread_num) pd = fluid.layers.ParallelDo(places, use_nccl=use_nccl) data = next(generator) diff --git a/python/paddle/fluid/tests/unittests/test_squeeze_op.py b/python/paddle/fluid/tests/unittests/test_squeeze_op.py new file mode 100644 index 0000000000000000000000000000000000000000..bca6af2fd5dfadbc48cf1a76cfa6ffd4f1fdfdef --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_squeeze_op.py @@ -0,0 +1,114 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np + +from op_test import OpTest + + +# Correct: General. +class TestSqueezeOp(OpTest): + def setUp(self): + self.op_type = "squeeze" + self.init_test_case() + self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")} + self.init_attrs() + self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + def init_test_case(self): + self.ori_shape = (1, 3, 1, 5) + self.axes = (0, 2) + self.new_shape = (3, 5) + + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": False} + + +# Correct: There is mins axis. +class TestSqueezeOp1(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (1, 3, 1, 5) + self.axes = (0, -2) + self.new_shape = (3, 5) + + +# Correct: No axes input. +class TestSqueezeOp2(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (1, 3, 1, 5) + self.axes = () + self.new_shape = (3, 5) + + +# Correct: Just part of axes be squeezed. +class TestSqueezeOp3(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 1, 5, 1, 4, 1) + self.axes = (1, -1) + self.new_shape = (3, 5, 1, 4) + + +# Correct: Inplace. +class TestSqueezeOpInplace1(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (1, 3, 1, 5) + self.axes = (0, 2) + self.new_shape = (3, 5) + + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": True} + + +# Correct: Inplace. There is mins axis. +class TestSqueezeOpInplace2(TestSqueezeOp): + def inti_test_case(self): + self.ori_shape = (1, 3, 1, 5) + self.axes = (0, -2) + self.new_shape = (3, 5) + + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": True} + + +# Correct: Inplace. No axes input. +class TestSqueezeOpInplace3(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (1, 3, 1, 5) + self.axes = () + self.new_shape = (3, 5) + + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": True} + + +# Correct: Inpalce. Just part of axes be squeezed. +class TestSqueezeOpInplace4(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 1, 5, 1, 4, 1) + self.axes = (1, -1) + self.new_shape = (3, 5, 1, 4) + + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": True} + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py new file mode 100644 index 0000000000000000000000000000000000000000..7a4aa0a40b5eb494f6027e800ca6b466bbe1c302 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py @@ -0,0 +1,111 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np + +from op_test import OpTest + + +# Correct: General. +class TestUnsqueezeOp(OpTest): + def setUp(self): + self.init_test_case() + self.op_type = "unsqueeze" + self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")} + self.init_attrs() + self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (1, 2) + self.new_shape = (3, 1, 1, 5) + + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": False} + + +# Correct: Single input index. +class TestUnsqueezeOp1(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (-1, ) + self.new_shape = (3, 5, 1) + + +# Correct: Mixed input axis. +class TestUnsqueezeOp2(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (0, -1) + self.new_shape = (1, 3, 5, 1) + + +# Correct: There is duplicated axis. +class TestUnsqueezeOp3(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 2, 5) + self.axes = (0, 3, 3) + self.new_shape = (1, 3, 2, 1, 1, 5) + + +# Correct: Reversed axes. +class TestUnsqueezeOp4(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 2, 5) + self.axes = (3, 1, 1) + self.new_shape = (3, 1, 1, 2, 5, 1) + + +# Correct: Inplace. +class TestUnsqueezeOpInplace1(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (0, 2) + self.new_shape = (1, 3, 1, 5) + + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": True} + + +# Correct: Inplace. There is mins index. +class TestUnsqueezeOpInplace2(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (0, -2) + self.new_shape = (1, 3, 1, 5) + + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": True} + + +# Correct: Inplace. There is duplicated axis. +class TestUnsqueezeOpInplace3(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 2, 5) + self.axes = (0, 3, 3) + self.new_shape = (1, 3, 2, 1, 1, 5) + + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": True} + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index b6e0241265b18377874efb0d223441994b4650d0..64049a93cb0a267722de9cd94961b6256551330d 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -14,6 +14,9 @@ import contextlib import os +import errno +import shutil +import time import core @@ -94,7 +97,7 @@ class EndStepEvent(object): class CheckpointConfig(object): """ - Parameter object for :code:`fluid.io.save_checkpoint` and + Parameter object for :code:`save_checkpoint` and :code:`fluid.Trainer`. Used to configuration how to save checkpoint. Args: @@ -237,7 +240,7 @@ class Trainer(object): self.checkpoint_cfg = checkpoint_config if self.checkpoint_cfg: assert isinstance(self.checkpoint_cfg, CheckpointConfig) - serial = io.get_latest_checkpoint_serial( + serial = _get_latest_checkpoint_serial( self.checkpoint_cfg.checkpoint_dir) self.checkpoint_cfg.load_serial = serial if serial >= 0 else None @@ -276,32 +279,15 @@ class Trainer(object): exe = executor.Executor(place) exe.run(self.startup_program) - if self.checkpoint_cfg and self.checkpoint_cfg.load_serial: - with self._prog_and_scope_guard(): - exe = executor.Executor(place) - io.load_checkpoint(exe, self.checkpoint_cfg.checkpoint_dir, - self.checkpoint_cfg.load_serial, - self.startup_program) - - if not self.checkpoint_cfg.pserver_id: - epoch_id, step_id = io.load_trainer_args( - self.checkpoint_cfg.checkpoint_dir, - self.checkpoint_cfg.load_serial, self.trainer_id, - self._get_checkpoint_load_args()) - self.checkpoint_cfg.epoch_id = int(epoch_id) - self.checkpoint_cfg.step_id = int(step_id) - else: - if self.checkpoint_cfg.lookup_table_name: - io.load_lookup_table_vars( - exe, self.checkpoint_cfg.checkpoint_dir, - self.startup_program, - self.checkpoint_cfg.pserver_id, - self.checkpoint_cfg.lookup_table_name) + if self.checkpoint_cfg and self.checkpoint_cfg.load_serial is not None: + self._load_checkpoint() if param_path and os.path.isdir(param_path): # load params from param_path into scope - io.load_persist_vars_without_grad( - exe, dirname=param_path, program=self.startup_program) + io.load_persistables( + executor=exe, + dirname=param_path, + main_program=self.startup_program) def _transpile_nccl2_dist(self): # PADDLE_TRAINER_IPS @@ -549,7 +535,7 @@ class Trainer(object): def _clean_checkpoint(self): assert self.checkpoint_cfg - io.clean_checkpoint(checkpoint_dir=self.checkpoint_cfg.checkpoint_dir) + clean_checkpoint(checkpoint_dir=self.checkpoint_cfg.checkpoint_dir) def _get_checkpoint_load_args(self): """ @@ -572,7 +558,7 @@ class Trainer(object): if epoch_id % self.checkpoint_cfg.epoch_interval == 0 \ and step_id % self.checkpoint_cfg.step_interval == 0: exe = executor.Executor(self.place) - io.save_checkpoint( + save_checkpoint( executor=exe, checkpoint_dir=self.checkpoint_cfg.checkpoint_dir, trainer_id=self.trainer_id, @@ -580,6 +566,41 @@ class Trainer(object): main_program=self.train_program, max_num_checkpoints=self.checkpoint_cfg.max_num_checkpoints) + def _load_checkpoint(self): + with self._prog_and_scope_guard(): + exe = executor.Executor(self.place) + load_checkpoint( + executor=exe, + checkpoint_dir=self.checkpoint_cfg.checkpoint_dir, + main_program=self.startup_program) + + if not self.checkpoint_cfg.pserver_id: + load_trainer_args = self._get_checkpoint_load_args() + trainer_args = load_checkpoint( + executor=exe, + checkpoint_dir=self.checkpoint_cfg.checkpoint_dir, + main_program=self.startup_program, + role_id=self.trainer_id, + is_trainer=True, + load_trainer_args=load_trainer_args) + + if len(trainer_args) != 2: + raise ValueError( + "the return trainer_args length do not equal _get_checkpoint_load_args" + ) + self.checkpoint_cfg.epoch_id = int(trainer_args[0]) + self.checkpoint_cfg.step_id = int(trainer_args[1]) + else: + if self.checkpoint_cfg.lookup_table_name: + load_checkpoint( + executor=exe, + checkpoint_dir=self.checkpoint_cfg.checkpoint_dir, + main_program=self.startup_program, + role_id=self.checkpoint_cfg.pserver_id, + is_trainer=False, + load_trainer_args=None, + load_lookup_table=self.checkpoint_cfg.lookup_table_name) + def build_feed_var_list(program, feed_order): if not isinstance(program, framework.Program): @@ -602,3 +623,610 @@ def build_feed_var_list(program, feed_order): program.global_block().var(pair[0]) for pair in sorted_pair_list ] return feed_var_list + + +# move Checkpoint APIs from io.py to trainer.py, make all of them are private. +SUCCESS_MARK_FILENAME = "_SUCCESS" +CHECKPOINT_PREFIX = "checkpoint" +MODEL_DIR = "__model__" +LOOKUP_TABLE_DIR = "__lookup_table__" +TRAINER_PREFIX = "trainer" +CHECKPOINT_SEPARATOR = "_" + + +def save_checkpoint(executor, + checkpoint_dir, + trainer_id, + main_program, + trainer_args=None, + max_num_checkpoints=3, + lookup_table=None, + pserver_endpoints=None): + """ + This function filters out all checkpoint variables from the give + main_program and then saves these variables to the `checkpoint_dir` + directory. + + In the training precess, we generally save a checkpoint in each + iteration. So there might be a lot of checkpoints in the + `checkpoint_dir`. To avoid them taking too much disk space, the + `max_num_checkpoints` are introduced to limit the total number of + checkpoints. If the number of existing checkpints is greater than + the `max_num_checkpoints`, oldest ones will be scroll deleted. + + A variable is a checkpoint variable and will be saved if it meets + all following conditions: + 1. It's persistable. + 2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW. + 3. It's name contains no "@GRAD" nor ".trainer_" nor ".block". + + Args: + executor(Executor): The executor to run for save checkpoint. + checkpoint_dir(str): The folder where to save checkpoints. + trainer_id(int): currect trainer id, if id is equal to 0, the trainer + is chief. + trainer_args(dict|None): Current training arguments. Such as 'epoch_id' + and 'step_id'. + Defaut: None + main_program(Program): The program whose checkpoint variables will + be saved. + max_num_checkpoints(int): The max number of total number of existing + checkpoints. + Default: 3 + lookup_table(string|None): the lookup table name, when use distribute + lookup table, we can get lookup table name by DistributeTranspiler. + table_name + pserver_endpoints(list|None): the parameter server ip:port list. + when use distribute lookup table, we can get pserver_endpoints by + distribute arguments. + + Returns: + None + + Raises: + ValueError: If `checkpoint_dir` is None. + AssertionError: If `trainer_args` is not a dict. + + Examples: + .. code-block:: python + + exe = fluid.Executor(fluid.CPUPlace()) + path = "./checkpoints" + prog = fluid.default_main_program() + trainer_args = {"epoch_id": 200, + "step_id": 20} # just an example + table_name = "share_w" + ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"] + + save_checkpoint(executor=exe, + checkpoint_dir=path, + trainer_id=0, + trainer_args=trainer_args, + main_program=prog, + max_num_checkpoints=3, + lookup_table=table_name, + pserver_endpoints = ps_endpoints) + """ + if checkpoint_dir is None: + raise ValueError("'checkpoint_dir' should not be None") + + if main_program is None: + raise ValueError('main_program should not be None.') + + if trainer_args: + assert isinstance(trainer_args, dict) + + is_chief = trainer_id == 0 + + _make_chekcpoint_dirs(checkpoint_dir) + serial = _get_latest_checkpoint_serial(checkpoint_dir) + 1 + cur_dir = _get_serial_dir(checkpoint_dir, serial) + + _save_trainer_args(cur_dir, trainer_id, trainer_args) + + if is_chief: + _save_persist_vars_without_grad(executor, cur_dir, main_program) + + if is_chief and lookup_table and pserver_endpoints: + _save_pserver_vars_by_notify(executor, cur_dir, lookup_table, + pserver_endpoints) + + _scroll_delete(checkpoint_dir, max_num_checkpoints) + + +def load_checkpoint(executor, + checkpoint_dir, + main_program, + role_id=0, + is_trainer=True, + load_trainer_args=None, + load_lookup_table=None): + """ + This function filters out all checkpoint variables from the give + main_program and then try to load these variables from the + `checkpoint_dir` directory. + + In the training precess, we generally save a checkpoint in each + iteration. So there are more than one checkpoint in the + `checkpoint_dir` (each checkpoint has its own sub folder), use + `serial` to specify which serial of checkpoint you would like to + load. + + A variable is a checkpoint variable and will be loaded if it meets + all following conditions: + 1. It's persistable. + 2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW. + 3. It's name contains no "@GRAD" nor ".trainer_" nor ".block". + + Args: + executor(Executor): The executor to run for loading checkpoint. + checkpoint_dir(str): The folder where all checkpoints are. + serial(int): The serial of checkpoint you would like to load. + main_program(Program): The program whose checkpoint variables will + be loaded. + role_id(int): the trainer id or the parameter server id. + is_trainer(bool): trainer is True and parameter server is False. + load_trainer_args(list|None): list about load trainer args. + load_lookup_table(str|None): the lookup table name + + Returns: + None + + Raises: + ValueError: If `checkpoint_dir` is None. + ValueError: If `main_program` is None. + + Examples: + .. code-block:: python + + exe = fluid.Executor(fluid.CPUPlace()) + path = "./checkpoints" + prog = fluid.default_main_program() + load_checkpoint(executor=exe, checkpoint_dir=path, + serial=9, main_program=prog) + + # In this example, `load_checkpoint` function + # will first filters out all checkpoint variables in the default + # main program, and then try to load these variables form the + # folder "./checkpoints/checkpoint_9/__model__". + """ + + if checkpoint_dir is None: + raise ValueError("'checkpoint_dir' should not be None") + + serial = _get_latest_checkpoint_serial(checkpoint_dir) + + # there are nothing need to be loaded + if serial is None or serial < 0: + return + + if main_program is None: + raise ValueError('main_program should not be None.') + + if is_trainer and load_trainer_args is None: + cur_dir = _get_serial_dir(checkpoint_dir, serial) + _load_persist_vars_without_grad(executor, cur_dir, main_program, True) + return + + if is_trainer and load_trainer_args: + return _load_trainer_args(checkpoint_dir, serial, role_id, + load_trainer_args) + + if not is_trainer and load_lookup_table: + _load_lookup_table_vars(executor, checkpoint_dir, main_program, role_id, + load_lookup_table) + + +def clean_checkpoint(checkpoint_dir, delete_dir=False): + """ + clean the checkpoint dir, when the train exits normally, + the trainer will call clean_checkpoint to delete checkpoint directory saved before. + delete_dir only works when the directory is empty, otherwise, OSError is raised. + + : param checkpoint_dir + : param delete_dir + """ + + if checkpoint_dir is None: + raise ValueError("'checkpoint_dir' should not be None") + _scroll_delete(checkpoint_dir, max_num_checkpoints=0) + + if delete_dir and not os.listdir(checkpoint_dir): + os.rmdir(checkpoint_dir) + + +def _load_persist_vars_without_grad(executor, + dirname, + program, + has_model_dir=False): + """ + This function filters out all checkpoint variables from the give + program and then trys to load these variables from the given directory. + + A variable is a checkpoint variable if it meets all following + conditions: + 1. It's persistable. + 2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW. + 3. It's name contains no "@GRAD" nor ".trainer_" nor ".block". + + Args: + executor(Executor): The executor to run for loading variables. + dirname(str): The directory path. + program(Program): The program whose checkpoint variables will + be loaded. + has_model_dir(bool): if True, the function loads variables + from a sub directory named '__model__'. + Default: False + + Returns: + None + + Examples: + .. code-block:: python + + exe = fluid.Executor(fluid.CPUPlace()) + param_path = "./my_paddle_model" + prog = fluid.default_main_program() + _load_persist_vars_without_grad(executor=exe, + dirname=param_path, program=prog, has_model_dir=True) + + # In this example, `_load_persist_vars_without_grad` function + # will first filters out all checkpoint variables in the default + # main program, and then trys to load these variables form the + # folder "./my_paddle_model/__model__". + """ + + if has_model_dir: + dirname = _get_model_dir(dirname) + + io.load_vars( + executor, + dirname=dirname, + main_program=program, + predicate=_is_checkpoint_var, + filename=None) + + +def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name): + """ + The parameter server will load lookup table's local file in + selectedrows variable. + + Args: + executor(Executor): The executor to run for loading persistable variables + dirname(str): The directory path + main_program(Program): Find the variable named table_name in main_program + pserver_id(int): the serial number in pserver_endpoints list + table_name(str): lookup table name + + Returns: + None + + Examples: + .. code-block:: python + + exe = fluid.Executor(fluid.CPUPlace()) + dirname = "./checkpoints/checkpoint_9/" + prog = fluid.default_main_program() + pserver_id = 1 + table_name = "share_w" + _load_lookup_table_vars(executor=exe, + dirname=dirname, program=prog, pserver_id=pserver_id, + table_name=table_name) + """ + + for var in program.list_vars(): + if var.name == table_name: + lookup_table_var = var + break + + assert lookup_table_var is not None + + lookup_table_dir = os.path.join(dirname, LOOKUP_TABLE_DIR) + table_file = table_name + CHECKPOINT_SEPARATOR + str(pserver_id) + + load_prog = framework.Program() + load_block = load_prog.global_block() + + load_block.append_op( + type='load', + inputs={}, + outputs={'Out': [lookup_table_var]}, + attrs={'file_path': os.path.join(lookup_table_dir, table_file)}) + + executor.run(load_prog) + + +def _save_persist_vars_without_grad(executor, dirname, program): + """ + This function filters out all checkpoint variables from the give + program and then save these variables to a sub-folder '__model__' of + the given directory. + + A variable is a checkpoint variable if it meets all following + conditions: + 1. It's persistable. + 2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW. + 3. It's name contains no "@GRAD" nor ".trainer_" nor ".block". + + Args: + executor(Executor): The executor to run for saving variables. + dirname(str): The directory path. + program(Program): The program whose checkpoint variables will + be saved. + + Returns: + None + + Examples: + .. code-block:: python + + exe = fluid.Executor(fluid.CPUPlace()) + param_path = "./my_paddle_model" + prog = fluid.default_main_program() + _save_persist_vars_without_grad(executor=exe, + dirname=param_path, program=prog) + + # In this example, `_save_persist_vars_without_grad` function + # will first filters out all checkpoint variables in the default + # main program, and then saves these variables to the folder + # "./my_paddle_model/__model__". + """ + cur_dir = _get_model_dir(dirname) + io.save_vars( + executor, + dirname=cur_dir, + main_program=program, + vars=None, + predicate=_is_checkpoint_var, + filename=None) + _write_success(cur_dir) + + +def _save_pserver_vars_by_notify(executor, dirname, lookup_table, + ps_endpoint_list): + """ + This function will send checkpoint notify message from Trainer 0 + to all the pservers. + The checkpoint notify message contains lookup table name, + the absolute path on pserver to save lookup_table. + + Args: + executor(Executor): The executor to run for send checkpoint notify. + dirname(str): The folder where to save checkpoints. + lookup_table(string): the lookup table name, when use distribute + lookup table, we can get lookup table name by DistributeTranspiler. + table_name + ps_endpoint_list(list): the parameter server ip:port list. + when use distribute lookup table, we can get ps_endpoint_list by + distribute arguments. + Return: + None + + Examples: + .. code-block:: python + + exe = fluid.Executor(fluid.CPUPlace()) + param_path = "./my_paddle_model" + prog = fluid.default_main_program() + table_name = "share_w" + ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"] + + _save_pserver_vars_by_notify(executor=exe, + dirname=param_path, lookup_table=table_name, + ps_endpoint_list=ps_endpoints) + """ + cur_dir = _get_lookuptable_dir(dirname) + + checkpoint_notify_program = framework.Program() + checkpoint_notify_block = checkpoint_notify_program.global_block() + + attrs = {} + attrs['epmap'] = ps_endpoint_list + attrs['dir'] = cur_dir + attrs['lookup_table'] = lookup_table + + checkpoint_notify_block.append_op( + type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs) + executor.run(checkpoint_notify_program) + + +def _save_trainer_args(dirname, trainer_id, trainer_args): + assert isinstance(trainer_args, dict) + + cur_dir = _get_trainer_dir(dirname, trainer_id) + + for name, value in trainer_args.iteritems(): + args_file = os.path.join(cur_dir, name) + with open(args_file, 'w') as f: + f.write(str(value)) + _write_success(cur_dir) + + +def _load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args): + """ + trainer will load some args from it's independent directory, + such as epoch_id and step_id. + + Args: + checkpoint_dir(str): The folder where all checkpoints are. + serial(int): The serial of checkpoint you would like to load. + trainer_id(int): current trainer id. + trainer_args(list): list about load trainer args + Return: + None + + Examples: + .. code-block:: python + + param_path = "./checkpoint/" + serial = 7 + trainer_id = 2 + trainer_args = ["epoch_id", "step_id"] + + _load_trainer_args(checkpoint_dir=param_path, serial=serial, + trainer_id=trainer_id, trainer_args=trainer_args) + """ + assert isinstance(trainer_args, list) + + cur_dir = _get_serial_dir(checkpoint_dir, serial) + cur_dir = _get_trainer_dir(cur_dir, trainer_id) + + ret_values = [] + + for arg in trainer_args: + cur_file = os.path.join(cur_dir, arg) + with open(cur_file, 'r') as f: + contents = f.read() + ret_values.append(contents.strip()) + return ret_values + + +def _is_checkpoint_var(var): + """ + the checkpoint will not save or load all the variables. + var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded. + + : param var(Variable) + """ + if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ + var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ + var.desc.type() == core.VarDesc.VarType.RAW: + return False + # @GRAD are named for gradient variables, checkpoint will not save it. + if "@GRAD" in var.name: + return False + # .trainer_ are named for distribute train variables, checkpoint will not save it. + if ".trainer_" in var.name: + return False + + # .block is named for distribute train variables, checkpoint will not save it. + if ".block" in var.name: + return False + + return var.persistable + + +def _make_chekcpoint_dirs(dirs): + """ + _make_chekcpoint_dirs will makdir local directory directly, when the directory is exist, it will igore it. + """ + assert dirs is not None + + if os.path.isfile(dirs): + raise OSError(errno.ENOTDIR, "dirs path shoule be a Directory.", dirs) + + if not os.path.isdir(dirs): + try: + os.makedirs(dirs) + except OSError as err: + if err.errno != errno.EEXIST: + raise err + + +def _get_dir_serial(dirname): + _, serial = dirname.split(CHECKPOINT_SEPARATOR) + + try: + serial_num = int(serial) + except ValueError: + serial_num = -1 + return serial_num + + +def _get_serial_dir(dirname, serial): + serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial) + serial_dir = os.path.join(dirname, serial_folder) + _make_chekcpoint_dirs(serial_dir) + + return serial_dir + + +def _get_model_dir(dirname): + model_dir = os.path.join(dirname, MODEL_DIR) + _make_chekcpoint_dirs(model_dir) + return model_dir + + +def _get_lookuptable_dir(dirname): + lookuptable_dir = os.path.join(dirname, LOOKUP_TABLE_DIR) + _make_chekcpoint_dirs(lookuptable_dir) + return lookuptable_dir + + +def _get_trainer_dir(dirname, trainer_id): + trainer_folder = TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str(trainer_id) + trainer_dir = os.path.join(dirname, trainer_folder) + _make_chekcpoint_dirs(trainer_dir) + return trainer_dir + + +def _scroll_delete(dirname, max_num_checkpoints=3): + dirs = os.listdir(dirname) + serial_map = {} + for serial in dirs: + serial_num = _get_dir_serial(serial) + serial_map[serial_num] = serial + + if len(serial_map.keys()) <= max_num_checkpoints: + return + + serials = serial_map.keys() + serials.sort(reverse=True) + serials = serials[max_num_checkpoints:] + for serial in serials: + cur_dir = _get_serial_dir(dirname, serial) + try: + shutil.rmtree(cur_dir) + except OSError as err: + if err.errno != errno.ENOENT: + raise err + + +def _write_success(dirname): + """ + write an empty file named "_SUCCESS" in checkpoint dir, indicate this checkpoint is correct. + + : param dirname + """ + success_file = os.path.join(dirname, SUCCESS_MARK_FILENAME) + with open(success_file, 'a') as f: + now = time.ctime() + f.write(now) + + +def _get_latest_checkpoint_serial(checkpoint_dir): + """ + get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory + + : param checkpoint_dir + """ + if not checkpoint_dir: + return -1 + + def has_success(checkpoint_dir, cur_dir): + """ + is _SUCCESS in this dir + """ + + serial = _get_dir_serial(cur_dir) + if serial == -1 or not os.path.isdir( + os.path.join(checkpoint_dir, cur_dir)): + return -1 + + success_path = os.path.join( + _get_serial_dir(checkpoint_dir, serial), MODEL_DIR, + SUCCESS_MARK_FILENAME) + if os.path.isfile(success_path): + return serial + + if not os.path.isdir(checkpoint_dir): + return -1 + + current_dir = -1 + dirs = os.listdir(checkpoint_dir) + for cur_dir in dirs: + success_num = has_success(checkpoint_dir, cur_dir) + if success_num > current_dir: + current_dir = success_num + return current_dir