提交 39d88ebc 编写于 作者: Q qiaolongfei

Merge branch 'fix-optimizer-accumulator' of...

Merge branch 'fix-optimizer-accumulator' of ssh://github.com/jacquesqiao/Paddle into distribute-transpiler-handle-adam-accumulator
......@@ -259,6 +259,7 @@ op_library(max_sequence_len_op DEPS lod_rank_table)
op_library(sequence_conv_op DEPS context_project)
op_library(sequence_pool_op DEPS sequence_pooling)
op_library(lstm_op DEPS sequence2batch lstm_compute)
op_library(hierarchical_sigmoid_op DEPS matrix_bit_code)
op_library(lstmp_op DEPS sequence2batch lstm_compute)
op_library(gru_op DEPS sequence2batch gru_compute)
op_library(recurrent_op DEPS executor)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/hierarchical_sigmoid_op.h"
#include <vector>
namespace paddle {
namespace operators {
/**
* Organize the classes into a binary tree. At each node, a sigmoid function
* is used to calculate the probability of belonging to the right branch.
* This idea is from "F. Morin, Y. Bengio (AISTATS 05):
* Hierarchical Probabilistic Neural Network Language Model."
*
* Here we uses a simple way of making the binary tree.
* Assuming the number of classes C = 6,
* The classes are organized as a binary tree in the following way:
*
* @code{.py}
* *-*-*- 2
* | | |- 3
* | |
* | |-*- 4
* | |- 5
* |
* |-*- 0
* |- 1
* @endcode
*
* where * indicates an internal node, and each leaf node represents a class.
* - Node 0 ... C-2 are internal nodes.
* - Node C-1 ... 2C-2 are leaf nodes.
* - Class c is represented by leaf node \f$c+C-1\f$.
*
* We assign an id for each node:
* - the id of root be 0.
* - the left child of a node i is 2*i+1.
* - the right child of a node i is 2*i+2.
*
* It's easy to see that:
* - the parent of node i is \f$\left\lfloor(i-1)/2\right\rfloor\f$.
* - the j-th level ancestor of node i is
* \f$\left\lfloor(i+1)/2^{j+1}\right\rfloor - 1\f$.
* - A node i is a left child of its parent if \f$(i-1)\%2==0\f$.
*
*/
class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("PreOut"),
"Output(PreOut) should not be null.");
const int64_t batch_size = ctx->GetInputDim("X")[0];
std::vector<int64_t> output_shape({batch_size, 1});
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.GetPlace());
}
};
template <typename AttrType>
class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, required) The input tensor with shape [N, D], "
"where N is the size of mini-batch, and D is the feature size.");
AddInput("W",
"(Tensor, required), The parameters of hierarchical "
"sigmoid operator, each of them is a 2-D tensor, the shape is"
"[num_classes - 1, D].");
AddInput("Label",
"(Tensor, required), The labels of training data. It's a"
"tensor with shape [N, 1].");
AddInput("Bias",
"(Tensor, optional), The bias is a tensor with shape"
"[1, num_classes - 1].");
AddOutput("Out",
"(Tensor, required) The output of hierarchical sigmoid operator."
"The shape is [N, 1].");
AddOutput("PreOut",
"(Tensor, required) A intermedia 2-D tensor with shape "
"[batch_size, code_length], where code_length represents the "
"maximum path length from root to leaf nodes.")
.AsIntermediate();
AddAttr<AttrType>("num_classes", "(int, required), The number of classes")
.SetDefault(2);
AddComment(R"DOC(
The hierarchical sigmoid operator organize the classes into a binary tree.
At each node, a sigmoid function is used to calculate the probability of
belonging to the right branch. This idea is from
"F. Morin, Y. Bengio (AISTATS 05):
Hierarchical Probabilistic Neural Network Language Model."
)DOC");
}
};
class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("PreOut"),
"Input(Preout) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")),
"Output(W@Grad should not be null.)");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")));
if (ctx->HasOutput(framework::GradVarName("Bias"))) {
ctx->SetOutputDim(framework::GradVarName("Bias"),
ctx->GetInputDim("Bias"));
}
ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.GetPlace());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp,
ops::HierarchicalSigmoidOpMaker<int>,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp);
REGISTER_OP_CPU_KERNEL(
hierarchical_sigmoid,
ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext,
double>);
REGISTER_OP_CPU_KERNEL(
hierarchical_sigmoid_grad,
ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
float>,
ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
using platform::Transform;
template <typename DeviceContext, typename T>
class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* w = ctx.Input<framework::Tensor>("W");
auto* label = ctx.Input<framework::Tensor>("Label");
auto* bias = ctx.Input<framework::Tensor>("Bias");
auto* out = ctx.Output<framework::Tensor>("Out");
auto* pre_out = ctx.Output<framework::Tensor>("PreOut");
size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
int64_t code_length = math::FindLastSet(num_classes - 1);
int64_t batch_size = in->dims()[0];
framework::Tensor sum;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto* pre_out_data = pre_out->mutable_data<T>(
framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
// Not all class(leaf) nodes' path lengths equal code_length, thus init as
// 0s can avoid out of path's loss.
math::SetConstant<DeviceContext, T> zero;
zero(dev_ctx, pre_out, static_cast<T>(0.0));
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
math::RowwiseSum<DeviceContext, T> row_sum;
math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());
std::vector<int64_t> sum_dims({batch_size, 1UL});
sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
auto sum_mat = EigenMatrix<T>::From(sum);
out->mutable_data<T>(ctx.GetPlace());
auto out_mat = framework::EigenVector<T>::Flatten(*out);
if (bias) {
bit_code.Add(pre_out, *bias);
}
bit_code.Mul(pre_out, *w, *in);
// clip to [-40, 40]
Transform<DeviceContext> trans;
trans(ctx.template device_context<DeviceContext>(), pre_out_data,
pre_out_data + pre_out->numel(), pre_out_data,
ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
bit_code.Sum(*pre_out, out, static_cast<T>(-1));
// use softrelu to calculate cross entropy
pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
row_sum(dev_ctx, *pre_out, &sum);
// TODO(guosheng): Subtract the out of path's loss, since not all
// class(leaf) nodes' path lengths equal code_length. But it won't break the
// gradient check since both have the out of path's loss and will cancel out
// each other.
out_mat.device(place) = sum_mat + out_mat;
}
};
template <typename DeviceContext, typename T>
class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* w = ctx.Input<framework::Tensor>("W");
auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* w_grad = ctx.Output<framework::Tensor>(framework::GradVarName("W"));
auto* bias_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("Bias"));
auto* label = ctx.Input<framework::Tensor>("Label");
auto* pre_out = ctx.Input<framework::Tensor>("PreOut");
auto* out_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
framework::Tensor pre_out_grad;
pre_out_grad.mutable_data<T>(pre_out->dims(), ctx.GetPlace());
in_grad->mutable_data<T>(ctx.GetPlace());
w_grad->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> zero;
zero(dev_ctx, in_grad, static_cast<T>(0.0));
zero(dev_ctx, w_grad, static_cast<T>(0.0));
size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
auto pre_out_grad_mat = EigenMatrix<T>::From(pre_out_grad);
auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
Eigen::array<int, 2> bcast({{1, static_cast<int>(pre_out_grad.dims()[1])}});
// softrelu derivative
pre_out_grad_mat.device(place) =
static_cast<T>(1.0) - static_cast<T>(1.0) / pre_out_mat.exp();
bit_code.Sub(&pre_out_grad); // the gradient of clip(w * x + b)
pre_out_grad_mat.device(place) =
pre_out_grad_mat * out_grad_mat.broadcast(bcast);
// TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
// be consistent with the clipping in forward.
if (bias_grad) {
bias_grad->mutable_data<T>(ctx.GetPlace());
zero(dev_ctx, bias_grad, static_cast<T>(0.0));
bit_code.AddGrad(pre_out_grad, bias_grad);
}
bit_code.MulGradWeight(pre_out_grad, w_grad, *in);
bit_code.MulGradError(pre_out_grad, *w, in_grad);
}
};
} // namespace operators
} // namespace paddle
......@@ -51,6 +51,7 @@ math_library(sequence_padding)
math_library(sequence_pooling DEPS math_function)
math_library(sequence_scale)
math_library(softmax DEPS math_function)
math_library(matrix_bit_code)
math_library(unpooling)
math_library(vol2col)
......
......@@ -155,7 +155,7 @@ class RowwiseSum<platform::CPUDeviceContext, T> {
PADDLE_ENFORCE_EQ(in_dims.size(), 2U);
auto height = in_dims[0];
auto size = in_dims[1];
PADDLE_ENFORCE_EQ(out->numel(), size);
PADDLE_ENFORCE_EQ(out->numel(), height);
T* out_buf = out->mutable_data<T>(out->place());
const T* in_buf = input.data<T>();
......
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include <iostream>
namespace paddle {
namespace operators {
namespace math {
template <typename T>
void MatrixBitCodeFunctor<T>::Add(framework::Tensor* tmat,
const framework::Tensor& vec) {
SimpleCodeTable code_table(num_classes_);
size_t batch_size = tmat->dims()[0];
size_t width = tmat->dims()[1];
for (size_t i = 0; i < batch_size; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
tmat->data<T>()[i * width + j] += vec.data<T>()[index];
}
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat,
framework::Tensor* vec) {
SimpleCodeTable code_table(num_classes_);
size_t batch_size = tmat.dims()[0];
size_t width = tmat.dims()[1];
for (size_t i = 0; i < batch_size; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
vec->data<T>()[index] += tmat.data<T>()[i * width + j];
}
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor& tmat,
framework::Tensor* sum, T scale_sum) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat.dims()[0];
size_t o_width = tmat.dims()[1];
for (size_t i = 0; i < num_samples; ++i) {
T sm = static_cast<T>(0.0);
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
if (code.calc_bit(j)) {
// calc_bit starts from right most bit, while data in tmat[i] is in the
// reverse order.
sm += tmat.data<T>()[i * o_width + j];
}
}
sum->data<T>()[i] = scale_sum * sm;
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
const framework::Tensor& weight,
const framework::Tensor& input) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat->dims()[0];
size_t tmat_width = tmat->dims()[1];
size_t input_width = input.dims()[1];
size_t weight_width = weight.dims()[1];
auto tmat_value = tmat->data<T>();
auto weight_value = weight.data<T>();
auto input_value = input.data<T>();
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
T sum = static_cast<T>(0.0);
for (size_t k = 0; k < input_width; ++k) {
sum += weight_value[weight_width * index + k] *
input_value[input_width * i + k];
}
tmat_value[i * tmat_width + j] += sum;
}
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
framework::Tensor* weight,
const framework::Tensor& input) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat.dims()[0];
size_t input_width = input.dims()[1];
size_t tmat_width = tmat.dims()[1];
size_t weight_width = weight->dims()[1];
auto tmat_value = tmat.data<T>();
auto weight_value = weight->data<T>();
auto input_value = input.data<T>();
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
for (size_t k = 0; k < input_width; ++k) {
weight_value[weight_width * index + k] +=
tmat_value[i * tmat_width + j] * input_value[input_width * i + k];
}
}
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
const framework::Tensor& weight,
framework::Tensor* input) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat.dims()[0];
size_t tmat_width = tmat.dims()[1];
size_t input_width = input->dims()[1];
size_t weight_width = weight.dims()[1];
auto tmat_value = tmat.data<T>();
auto weight_value = weight.data<T>();
auto input_value = input->data<T>();
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
for (size_t k = 0; k < input_width; ++k) {
input_value[input_width * i + k] +=
tmat_value[i * tmat_width + j] *
weight_value[weight_width * index + k];
}
}
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::Sub(framework::Tensor* tmat) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat->dims()[0];
size_t o_width = tmat->dims()[1];
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
if (code.calc_bit(j)) {
tmat->data<T>()[i * o_width + j] -= 1;
}
}
}
}
template class MatrixBitCodeFunctor<float>;
template class MatrixBitCodeFunctor<double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
namespace math {
/**
* SimpleCodeTable class should support 3 functions:
*
* size_t size()
* return the number of ids
*
* int get_max_code_length()
* return the maximal code length
*
* SimpleCode operator()(size_t i)
* return the i-th code. Code class is descriebed below.
*
* SimpleCode class should support 3 functions:
*
* int get_length()
* return the length of the code
*
* size_t cal_index(int bit)
* bit ranges from 0 to get_length() - 1
* return the index for the (1+bit) level parent
*
* bool calc_bit(int bit)
* return true if the bit level parent is the right child of (1+bit) level
* parent
*
*/
/**
* return the 1-based index of the highest bit set
*
* for x > 0:
* \f[
* FindLastSet(x) = 1 + \floor*{\log_{2}x}
* \f]
*/
inline constexpr size_t FindLastSet(size_t x) {
return std::is_same<size_t, unsigned int>::value
? (x ? 8 * sizeof(x) - __builtin_clz(x) : 0)
: (std::is_same<size_t, unsigned long>::value // NOLINT
? (x ? 8 * sizeof(x) - __builtin_clzl(x) : 0)
: (x ? 8 * sizeof(x) - __builtin_clzll(x) : 0));
}
struct SimpleCode {
SimpleCode(size_t code, size_t num_classes) : c_(code + num_classes) {}
/**
* Here the id of root shoud be 1 rather than 0, thus the encoding of class c
* is `c + num_classes` and all siblings can get the same weight indice using
* prefixes.
* Weight index is the prefixes of encoding, thus leave out the right most
* bit in calc_index.
* Binary classification path is the suffixes of encoding, thus leave out the
* left most bit in calc_bit.
*/
inline size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; }
inline bool calc_bit(int bit) const { return c_ & (1 << bit); }
inline int get_length() const { return FindLastSet(c_) - 1; }
private:
size_t c_;
};
struct SimpleCodeTable {
explicit SimpleCodeTable(size_t num_classes) : num_classes_(num_classes) {}
SimpleCode operator()(size_t code) const {
return SimpleCode(code, num_classes_);
}
size_t size() const { return num_classes_; }
int get_max_code_length() const { return FindLastSet(num_classes_ - 1); }
private:
size_t num_classes_;
};
template <typename T>
class MatrixBitCodeFunctor {
public:
explicit MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids)
: num_classes_(num_classes), ids_(ids) {}
/* For j < code_length
tmat(i, j) += vec(0, index(i, j))
*/
void Add(framework::Tensor* tmat, const framework::Tensor& vec);
/* For j < code_length
vec(0, index(i, j)) += tmat(i, j)
*/
void AddGrad(const framework::Tensor& tmat, framework::Tensor* vec);
/* For j < code_length
sum(i, 0) = \sum_j bit(i, j) * tmat(i, j)
*/
void Sum(const framework::Tensor& tmat, framework::Tensor* sum, T scale_sum);
/* For j < code_length
tmat(i, j) -= bit(i, j)
*/
void Sub(framework::Tensor* tmat);
/* For j < code_length
input.row(i) += tmat(i, j) * weight.row(index(i, j))
*/
void Mul(framework::Tensor* tmat, const framework::Tensor& weight,
const framework::Tensor& input);
/* For index(i, j) >= 0:
weight.row(index(i, j)) += tmat(i, j) * input.row(i)
*/
void MulGradWeight(const framework::Tensor& tmat, framework::Tensor* weight,
const framework::Tensor& input);
/* For j < code_length
input.row(i) += tmat(i, j) * weight.row(index(i, j))
*/
void MulGradError(const framework::Tensor& tmat,
const framework::Tensor& weight, framework::Tensor* input);
size_t num_classes_;
const int64_t* ids_;
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -66,6 +66,14 @@ bool IsCompiledWithCUDA() {
#endif
}
bool IsCompiledWithDIST() {
#ifdef PADDLE_WITH_DIST
return true;
#else
return false;
#endif
}
PYBIND11_PLUGIN(core) {
py::module m("core", "C++ core of PaddlePaddle");
......@@ -508,6 +516,7 @@ All parameter, weight, gradient are variables in Paddle.
[](bool init_p2p) { framework::InitDevices(init_p2p); });
m.def("is_compiled_with_cuda", IsCompiledWithCUDA);
m.def("is_compiled_with_dist", IsCompiledWithDIST);
#ifdef PADDLE_WITH_CUDA
m.def("is_float16_supported", [](const platform::CUDAPlace &place) -> bool {
// Only GPUs with Compute Capability >= 53 support float16
......
......@@ -121,6 +121,9 @@ def __bootstrap__():
'eager_delete_scope', 'use_mkldnn', 'initial_cpu_memory_in_mb',
'init_allocated_mem'
]
if core.is_compiled_with_dist():
read_env_flags.append('rpc_deadline')
if core.is_compiled_with_cuda():
read_env_flags += [
'fraction_of_gpu_memory_to_use', 'cudnn_deterministic'
......
......@@ -24,10 +24,7 @@ from . import core
__all__ = [
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
'load_persistables', 'save_inference_model', 'load_inference_model',
'get_inference_program', 'save_checkpoint', 'load_checkpoint',
'clean_checkpoint', 'load_persist_vars_without_grad',
'load_lookup_table_vars', 'save_persist_vars_without_grad',
'get_latest_checkpoint_serial'
'get_inference_program'
]
......@@ -794,588 +791,6 @@ def get_parameter_value_by_name(name, executor, program=None):
return get_parameter_value(var, executor)
SUCCESS_MARK_FILENAME = "_SUCCESS"
CHECKPOINT_PREFIX = "checkpoint"
MODEL_DIR = "__model__"
LOOKUP_TABLE_DIR = "__lookup_table__"
TRAINER_PREFIX = "trainer"
CHECKPOINT_SEPARATOR = "_"
def save_checkpoint(executor,
checkpoint_dir,
trainer_id,
trainer_args=None,
main_program=None,
max_num_checkpoints=3,
lookup_table=None,
ps_endpoint_list=None):
"""
This function filters out all checkpoint variables from the give
main_program and then saves these variables to the `checkpoint_dir`
directory.
In the training precess, we generally save a checkpoint in each
iteration. So there might be a lot of checkpoints in the
`checkpoint_dir`. To avoid them taking too much disk space, the
`max_num_checkpoints` are introduced to limit the total number of
checkpoints. If the number of existing checkpints is greater than
the `max_num_checkpoints`, oldest ones will be scroll deleted.
A variable is a checkpoint variable and will be saved if it meets
all following conditions:
1. It's persistable.
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
Args:
executor(Executor): The executor to run for save checkpoint.
checkpoint_dir(str): The folder where to save checkpoints.
trainer_id(int): currect trainer id, if id is equal to 0, the trainer
is chief.
trainer_args(dict|None): Current training arguments. Such as 'epoch_id'
and 'step_id'.
Defaut: None
main_program(Program|None): The program whose checkpoint variables will
be saved. If it is None, the default main program will be used.
max_num_checkpoints(int): The max number of total number of existing
checkpoints.
Default: 3
lookup_table(string|None): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
ps_endpoint_list(list|None): the parameter server ip:port list.
when use distribute lookup table, we can get ps_endpoint_list by
distribute arguments.
Returns:
None
Raises:
ValueError: If `checkpoint_dir` is None.
AssertionError: If `trainer_args` is not a dict.
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
path = "./checkpoints"
prog = fluid.default_main_program()
trainer_args = {"epoch_id": 200,
"step_id": 20} # just an example
table_name = "share_w"
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
fluid.io.save_checkpoint(executor=exe,
checkpoint_dir=path,
trainer_id=0,
trainer_args=trainer_args,
main_program=prog,
max_num_checkpoints=3,
lookup_table=table_name,
ps_endpoint_list = ps_endpoints)
"""
if checkpoint_dir is None:
raise ValueError("'checkpoint_dir' should not be None")
assert checkpoint_dir
if trainer_args:
assert isinstance(trainer_args, dict)
is_chief = trainer_id == 0
_make_chekcpoint_dirs(checkpoint_dir)
serial = get_latest_checkpoint_serial(checkpoint_dir) + 1
cur_dir = _get_serial_dir(checkpoint_dir, serial)
save_trainer_args(cur_dir, trainer_id, trainer_args)
if is_chief:
save_persist_vars_without_grad(executor, cur_dir, main_program)
if is_chief and lookup_table and ps_endpoint_list:
save_pserver_vars_by_notify(executor, cur_dir, lookup_table,
ps_endpoint_list)
_scroll_delete(checkpoint_dir, max_num_checkpoints)
def load_checkpoint(executor, checkpoint_dir, serial, main_program):
"""
This function filters out all checkpoint variables from the give
main_program and then try to load these variables from the
`checkpoint_dir` directory.
In the training precess, we generally save a checkpoint in each
iteration. So there are more than one checkpoint in the
`checkpoint_dir` (each checkpoint has its own sub folder), use
`serial` to specify which serial of checkpoint you would like to
load.
A variable is a checkpoint variable and will be loaded if it meets
all following conditions:
1. It's persistable.
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
Args:
executor(Executor): The executor to run for loading checkpoint.
checkpoint_dir(str): The folder where all checkpoints are.
serial(int): The serial of checkpoint you would like to load.
main_program(Program): The program whose checkpoint variables will
be loaded.
Returns:
None
Raises:
ValueError: If `checkpoint_dir` is None.
ValueError: If `serial` is None or `serial` is less than 0.
ValueError: If `main_program` is None.
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
path = "./checkpoints"
prog = fluid.default_main_program()
fluid.io.load_checkpoint(executor=exe, checkpoint_dir=path,
serial=9, main_program=prog)
# In this example, `load_checkpoint` function
# will first filters out all checkpoint variables in the default
# main program, and then try to load these variables form the
# folder "./checkpoints/checkpoint_9/__model__".
"""
if checkpoint_dir is None:
raise ValueError("'checkpoint_dir' should not be None")
if serial is None or serial < 0:
raise ValueError("'serial' should not be None or <0 ")
if main_program is None:
raise ValueError('main_program should not be None.')
cur_dir = _get_serial_dir(checkpoint_dir, serial)
load_persist_vars_without_grad(executor, cur_dir, main_program, True)
def clean_checkpoint(checkpoint_dir, delete_dir=False):
"""
clean the checkpoint dir, when the train exits normally,
the trainer will call clean_checkpoint to delete checkpoint directory saved before.
delete_dir only works when the directory is empty, otherwise, OSError is raised.
: param checkpoint_dir
: param delete_dir
"""
if checkpoint_dir is None:
raise ValueError("'checkpoint_dir' should not be None")
_scroll_delete(checkpoint_dir, max_num_checkpoints=0)
if delete_dir and not os.listdir(checkpoint_dir):
os.rmdir(checkpoint_dir)
def load_persist_vars_without_grad(executor,
dirname,
program,
has_model_dir=False):
"""
This function filters out all checkpoint variables from the give
program and then trys to load these variables from the given directory.
A variable is a checkpoint variable if it meets all following
conditions:
1. It's persistable.
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
Args:
executor(Executor): The executor to run for loading variables.
dirname(str): The directory path.
program(Program): The program whose checkpoint variables will
be loaded.
has_model_dir(bool): if True, the function loads variables
from a sub directory named '__model__'.
Default: False
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
fluid.io.load_persist_vars_without_grad(executor=exe,
dirname=param_path, program=prog, has_model_dir=True)
# In this example, `load_persist_vars_without_grad` function
# will first filters out all checkpoint variables in the default
# main program, and then trys to load these variables form the
# folder "./my_paddle_model/__model__".
"""
if has_model_dir:
dirname = _get_model_dir(dirname)
load_vars(
executor,
dirname=dirname,
main_program=program,
predicate=_is_checkpoint_var,
filename=None)
def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
"""
The parameter server will load lookup table's local file in
selectedrows variable.
Args:
executor(Executor): The executor to run for loading persistable variables
dirname(str): The directory path
main_program(Program): Find the variable named table_name in main_program
pserver_id(int): the serial number in pserver_endpoints list
table_name(str): lookup table name
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
dirname = "./checkpoints/checkpoint_9/__model__"
prog = fluid.default_main_program()
pserver_id = 1
table_name = "share_w"
fluid.io.load_lookup_table_vars(executor=exe,
dirname=dirname, program=prog, pserver_id=pserver_id,
table_name=table_name)
"""
for var in program.list_vars():
if var.name == table_name:
lookup_table_var = var
break
assert lookup_table_var is not None
lookup_table_dir = os.path.join(dirname, LOOKUP_TABLE_DIR)
table_file = table_name + CHECKPOINT_SEPARATOR + str(pserver_id)
load_prog = Program()
load_block = load_prog.global_block()
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [lookup_table_var]},
attrs={'file_path': os.path.join(lookup_table_dir, table_file)})
executor.run(load_prog)
def save_persist_vars_without_grad(executor, dirname, program):
"""
This function filters out all checkpoint variables from the give
program and then save these variables to a sub-folder '__model__' of
the given directory.
A variable is a checkpoint variable if it meets all following
conditions:
1. It's persistable.
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
Args:
executor(Executor): The executor to run for saving variables.
dirname(str): The directory path.
program(Program): The program whose checkpoint variables will
be saved.
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
fluid.io.save_persist_vars_without_grad(executor=exe,
dirname=param_path, program=prog)
# In this example, `save_persist_vars_without_grad` function
# will first filters out all checkpoint variables in the default
# main program, and then saves these variables to the folder
# "./my_paddle_model/__model__".
"""
cur_dir = _get_model_dir(dirname)
save_vars(
executor,
dirname=cur_dir,
main_program=program,
vars=None,
predicate=_is_checkpoint_var,
filename=None)
_write_success(cur_dir)
def save_pserver_vars_by_notify(executor, dirname, lookup_table,
ps_endpoint_list):
"""
This function will send checkpoint notify message from Trainer 0
to all the pservers.
The checkpoint notify message contains lookup table name,
the absolute path on pserver to save lookup_table.
Args:
executor(Executor): The executor to run for send checkpoint notify.
dirname(str): The folder where to save checkpoints.
lookup_table(string): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
ps_endpoint_list(list): the parameter server ip:port list.
when use distribute lookup table, we can get ps_endpoint_list by
distribute arguments.
Return:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
table_name = "share_w"
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
fluid.io.save_pserver_vars_by_notify(executor=exe,
dirname=param_path, lookup_table=table_name,
ps_endpoint_list=ps_endpoints)
"""
cur_dir = _get_lookuptable_dir(dirname)
checkpoint_notify_program = Program()
checkpoint_notify_block = checkpoint_notify_program.global_block()
attrs = {}
attrs['epmap'] = ps_endpoint_list
attrs['dir'] = cur_dir
attrs['lookup_table'] = lookup_table
checkpoint_notify_block.append_op(
type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
executor.run(checkpoint_notify_program)
def save_trainer_args(dirname, trainer_id, trainer_args):
assert isinstance(trainer_args, dict)
cur_dir = _get_trainer_dir(dirname, trainer_id)
for name, value in trainer_args.iteritems():
args_file = os.path.join(cur_dir, name)
with open(args_file, 'w') as f:
f.write(str(value))
_write_success(cur_dir)
def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
"""
trainer will load some args from it's independent directory,
such as epoch_id and step_id.
Args:
checkpoint_dir(str): The folder where all checkpoints are.
serial(int): The serial of checkpoint you would like to load.
trainer_id(int): current trainer id.
trainer_args(list): list about load trainer args
Return:
None
Examples:
.. code-block:: python
param_path = "./checkpoint/"
serial = 7
trainer_id = 2
trainer_args = ["epoch_id", "step_id"]
fluid.io.load_trainer_args(checkpoint_dir=param_path, serial=serial,
trainer_id=trainer_id, trainer_args=trainer_args)
"""
assert isinstance(trainer_args, list)
cur_dir = _get_serial_dir(checkpoint_dir, serial)
cur_dir = _get_trainer_dir(cur_dir, trainer_id)
ret_values = []
for arg in trainer_args:
cur_file = os.path.join(cur_dir, arg)
with open(cur_file, 'r') as f:
contents = f.read()
ret_values.append(contents.strip())
return ret_values
def _is_checkpoint_var(var):
"""
the checkpoint will not save or load all the variables.
var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded.
: param var(Variable)
"""
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
var.desc.type() == core.VarDesc.VarType.RAW:
return False
# @GRAD are named for gradient variables, checkpoint will not save it.
if "@GRAD" in var.name:
return False
# .trainer_ are named for distribute train variables, checkpoint will not save it.
if ".trainer_" in var.name:
return False
# .block is named for distribute train variables, checkpoint will not save it.
if ".block" in var.name:
return False
return var.persistable
def _make_chekcpoint_dirs(dirs):
"""
_make_chekcpoint_dirs will makdir local directory directly, when the directory is exist, it will igore it.
"""
assert dirs is not None
if os.path.isfile(dirs):
raise OSError(errno.ENOTDIR, "dirs path shoule be a Directory.", dirs)
if not os.path.isdir(dirs):
try:
os.makedirs(dirs)
except OSError as err:
if err.errno != errno.EEXIST:
raise err
def _get_dir_serial(dirname):
_, serial = dirname.split(CHECKPOINT_SEPARATOR)
try:
serial_num = int(serial)
except ValueError:
serial_num = -1
return serial_num
def _get_serial_dir(dirname, serial):
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
serial_dir = os.path.join(dirname, serial_folder)
_make_chekcpoint_dirs(serial_dir)
return serial_dir
def _get_model_dir(dirname):
model_dir = os.path.join(dirname, MODEL_DIR)
_make_chekcpoint_dirs(model_dir)
return model_dir
def _get_lookuptable_dir(dirname):
lookuptable_dir = os.path.join(dirname, LOOKUP_TABLE_DIR)
_make_chekcpoint_dirs(lookuptable_dir)
return lookuptable_dir
def _get_trainer_dir(dirname, trainer_id):
trainer_folder = TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str(trainer_id)
trainer_dir = os.path.join(dirname, trainer_folder)
_make_chekcpoint_dirs(trainer_dir)
return trainer_dir
def _scroll_delete(dirname, max_num_checkpoints=3):
dirs = os.listdir(dirname)
serial_map = {}
for serial in dirs:
serial_num = _get_dir_serial(serial)
serial_map[serial_num] = serial
if len(serial_map.keys()) <= max_num_checkpoints:
return
serials = serial_map.keys()
serials.sort(reverse=True)
serials = serials[max_num_checkpoints:]
for serial in serials:
cur_dir = _get_serial_dir(dirname, serial)
try:
shutil.rmtree(cur_dir)
except OSError as err:
if err.errno != errno.ENOENT:
raise err
def _write_success(dirname):
"""
write an empty file named "_SUCCESS" in checkpoint dir, indicate this checkpoint is correct.
: param dirname
"""
success_file = os.path.join(dirname, SUCCESS_MARK_FILENAME)
with open(success_file, 'a') as f:
now = time.ctime()
f.write(now)
def get_latest_checkpoint_serial(checkpoint_dir):
"""
get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory
: param checkpoint_dir
"""
if not checkpoint_dir:
return -1
def has_success(checkpoint_dir, cur_dir):
"""
is _SUCCESS in this dir
"""
serial = _get_dir_serial(cur_dir)
if serial == -1 or not os.path.isdir(
os.path.join(checkpoint_dir, cur_dir)):
return -1
success_path = os.path.join(
_get_serial_dir(checkpoint_dir, serial), MODEL_DIR,
SUCCESS_MARK_FILENAME)
if os.path.isfile(success_path):
return serial
if not os.path.isdir(checkpoint_dir):
return -1
current_dir = -1
dirs = os.listdir(checkpoint_dir)
for cur_dir in dirs:
success_num = has_success(checkpoint_dir, cur_dir)
if success_num > current_dir:
current_dir = success_num
return current_dir
def get_test_program(filelist, program=None, startup_program=None):
"""
Transpile current train program to a program to read test dataset
......
......@@ -85,6 +85,7 @@ __all__ = [
'transpose',
'im2sequence',
'nce',
'hsigmoid',
'beam_search',
'row_conv',
'multiplex',
......@@ -3871,6 +3872,74 @@ def nce(input,
return cost / (num_neg_samples + 1)
def hsigmoid(input, label, num_classes, param_attr=None, bias_attr=None):
"""
The hierarchical sigmoid operator is used to accelerate the training
process of language model. This operator organizes the classes into a
complete binary tree, each leaf node represents a class(a word) and each
internal node acts as a binary classifier. For each word there's a unique
path from root to it's leaf node, hsigmoid calculate the cost for each
internal node on the path, and sum them to get a total cost. hsigmoid can
achive a acceleration from :math:`O(N)` to :math:`O(logN)`, where :math:`N`
represents the size of word dict.
Refer to `Hierarchical Probabilistic Neural Network Language Model
<http://www.iro.umontreal.ca/~lisa/pointeurs/hierarchical-nnlm-aistats05.pdf>`_
Args:
input (Variable): The input tensor variable with shape
:math:`[N \\times D]`, where :math:`N` is the size of mini-batch,
and :math:`D` is the feature size.
label (Variable): The tensor variable contains labels of training data.
It's a tensor with shape is :math:`[N \\times 1]`.
num_classes: (int), The number of classes, must not be less than 2.
param_attr (ParamAttr|list of ParamAttr, default None): The parameter
attribute for learnable parameters/weights of this layer.
bias_attr (ParamAttr|list of ParamAttr, default None): The parameter
attribute for the bias of this layer. If it is set to False, no
bias will be applied.
Returns:
Out: (Tensor) The cost of hierarchical sigmoid operator. the shape is [N, 1]
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[2], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='int64')
out = fluid.layers.hsigmoid(input=x, label=y, num_classes=6)
"""
helper = LayerHelper('hierarchical_sigmoid', **locals())
dtype = helper.input_dtype()
out = helper.create_tmp_variable(dtype)
pre_out = helper.create_tmp_variable(dtype)
dim = input.shape[1]
if num_classes < 2:
raise ValueError("num_classes must not be less than 2.")
weights = helper.create_parameter(
attr=helper.param_attr,
shape=[num_classes - 1, dim],
is_bias=False,
dtype=input.dtype)
inputs = {"X": input, "W": weights, "Label": label}
if helper.bias_attr:
bias = helper.create_parameter(
attr=helper.bias_attr,
shape=[1, num_classes - 1],
is_bias=True,
dtype=input.dtype)
inputs['Bias'] = bias
helper.append_op(
type="hierarchical_sigmoid",
inputs=inputs,
outputs={"Out": out,
"PreOut": pre_out},
attrs={"num_classes": num_classes})
return out
def transpose(x, perm, name=None):
"""
Permute the dimensions of `input` according to `perm`.
......
......@@ -123,7 +123,7 @@ class Optimizer(object):
"""
pass
def _finish_update(self, block):
def _finish_update(self, block, parameters):
"""Finish any custom updates needed
before completing an optimization step
......@@ -132,7 +132,7 @@ class Optimizer(object):
parameters: list of parameter variables for the optimizer
Returns:
list of finish ops or None
None
"""
pass
......@@ -236,7 +236,8 @@ class Optimizer(object):
# Get custom finish ops for subclasses
# FIXME: Need to fix this once we figure out how to handle dependencies
self._finish_update(loss.block)
self._finish_update(loss.block,
[p[0] for p in parameters_and_grads])
end = len(global_block.ops)
return global_block.slice_ops(start, end)
......@@ -486,6 +487,8 @@ class AdamOptimizer(Optimizer):
"""
_moment1_acc_str = "moment1"
_moment2_acc_str = "moment2"
_beta1_pow_acc_str = "beta1_pow_acc"
_beta2_pow_acc_str = "beta2_pow_acc"
def __init__(self,
learning_rate=0.001,
......@@ -507,32 +510,22 @@ class AdamOptimizer(Optimizer):
def _create_accumulators(self, block, parameters):
assert isinstance(block, framework.Block)
main_block = block.program.global_block()
# Create beta1 and beta2 power tensors
beta_shape = [1]
self._beta1_pow_acc = self.helper.create_global_variable(
name=unique_name.generate('beta1_pow_acc'),
dtype='float32' if self._dtype == None else self._dtype,
shape=beta_shape,
lod_level=0,
persistable=True)
self.helper.set_variable_initializer(
self._beta1_pow_acc, initializer=Constant(self._beta1))
self._beta2_pow_acc = self.helper.create_global_variable(
name=unique_name.generate('beta2_pow_acc'),
dtype='float32' if self._dtype == None else self._dtype,
shape=beta_shape,
lod_level=0,
persistable=True)
self.helper.set_variable_initializer(
self._beta2_pow_acc, initializer=Constant(self._beta2))
# Create accumulator tensors for first and second moments
for p in parameters:
self._add_accumulator(self._moment1_acc_str, p)
self._add_accumulator(self._moment2_acc_str, p)
self._add_accumulator(
name=self._beta1_pow_acc_str,
param=p,
dtype='float32',
fill_value=self._beta1,
shape=[1])
self._add_accumulator(
name=self._beta2_pow_acc_str,
param=p,
dtype='float32',
fill_value=self._beta2,
shape=[1])
def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, framework.Block)
......@@ -541,6 +534,11 @@ class AdamOptimizer(Optimizer):
param_and_grad[0])
moment2 = self._get_accumulator(self._moment2_acc_str,
param_and_grad[0])
beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
param_and_grad[0])
beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
param_and_grad[0])
# create the adam optimize op
adam_op = block.append_op(
type=self.type,
......@@ -550,8 +548,8 @@ class AdamOptimizer(Optimizer):
"LearningRate": self._create_param_lr(param_and_grad),
"Moment1": moment1,
"Moment2": moment2,
"Beta1Pow": self._beta1_pow_acc,
"Beta2Pow": self._beta2_pow_acc
"Beta1Pow": beta1_pow_acc,
"Beta2Pow": beta2_pow_acc
},
outputs={
"ParamOut": param_and_grad[0],
......@@ -566,24 +564,28 @@ class AdamOptimizer(Optimizer):
return adam_op
def _finish_update(self, block):
def _finish_update(self, block, parameters):
"""Update Beta1 and Beta2 Power accumulators
"""
assert isinstance(block, framework.Block)
main_block = block.program.global_block()
scale_beta1 = main_block.append_op(
type="scale",
inputs={"X": self._beta1_pow_acc},
outputs={"Out": self._beta1_pow_acc},
attrs={"scale": self._beta1})
scale_beta2 = main_block.append_op(
type="scale",
inputs={"X": self._beta2_pow_acc},
outputs={"Out": self._beta2_pow_acc},
attrs={"scale": self._beta2})
return [scale_beta1, scale_beta2]
for param in parameters:
with param.block.program.optimized_guard(param):
beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
param)
beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
param)
main_block.append_op(
type="scale",
inputs={"X": beta1_pow_acc},
outputs={"Out": beta1_pow_acc},
attrs={"scale": self._beta1})
main_block.append_op(
type="scale",
inputs={"X": beta2_pow_acc},
outputs={"Out": beta2_pow_acc},
attrs={"scale": self._beta2})
class AdamaxOptimizer(Optimizer):
......@@ -626,6 +628,7 @@ class AdamaxOptimizer(Optimizer):
"""
_moment_acc_str = "moment"
_inf_norm_acc_str = "inf_norm"
_beta1_pow_acc_str = "beta1_pow_acc"
def __init__(self,
learning_rate=0.001,
......@@ -645,21 +648,16 @@ class AdamaxOptimizer(Optimizer):
self._epsilon = epsilon
def _create_accumulators(self, block, parameters):
# Create beta1 power accumulator tensor
beta_shape = [1]
self._beta1_pow_acc = self.helper.create_global_variable(
name=unique_name.generate('beta1_pow_acc'),
dtype='float32' if self._dtype == None else self._dtype,
shape=beta_shape,
lod_level=0,
persistable=True)
self.helper.set_variable_initializer(
self._beta1_pow_acc, initializer=Constant(self._beta1))
# Create accumulator tensors for first moment and infinity norm
for p in parameters:
self._add_accumulator(self._moment_acc_str, p)
self._add_accumulator(self._inf_norm_acc_str, p)
self._add_accumulator(
name=self._beta1_pow_acc_str,
param=p,
dtype='float32',
fill_value=self._beta1,
shape=[1])
def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, framework.Block)
......@@ -667,6 +665,8 @@ class AdamaxOptimizer(Optimizer):
moment = self._get_accumulator(self._moment_acc_str, param_and_grad[0])
inf_norm = self._get_accumulator(self._inf_norm_acc_str,
param_and_grad[0])
beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
param_and_grad[0])
# create the adamax optimize op
adamax_op = block.append_op(
type=self.type,
......@@ -676,7 +676,7 @@ class AdamaxOptimizer(Optimizer):
"LearningRate": self._create_param_lr(param_and_grad),
"Moment": moment,
"InfNorm": inf_norm,
"Beta1Pow": self._beta1_pow_acc
"Beta1Pow": beta1_pow_acc
},
outputs={
"ParamOut": param_and_grad[0],
......@@ -691,18 +691,20 @@ class AdamaxOptimizer(Optimizer):
return adamax_op
def _finish_update(self, block):
def _finish_update(self, block, parameters):
"""Update Beta1 Power accumulator
"""
assert isinstance(block, framework.Block)
main_block = block.program.global_block()
scale_beta1 = main_block.append_op(
type="scale",
inputs={"X": self._beta1_pow_acc},
outputs={"Out": self._beta1_pow_acc},
attrs={"scale": self._beta1})
return [scale_beta1]
for param in parameters:
with param.block.program.optimized_guard(param):
beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
param)
main_block.append_op(
type="scale",
inputs={"X": beta1_pow_acc},
outputs={"Out": beta1_pow_acc},
attrs={"scale": self._beta1})
class DecayedAdagradOptimizer(Optimizer):
......@@ -1156,7 +1158,8 @@ class ModelAverage(Optimizer):
self.params_grads.append((param, grad))
for param, grad in self.params_grads:
self._append_average_accumulate_op(param)
with param.block.program.optimized_guard(param):
self._append_average_accumulate_op(param)
self.apply_program = Program()
block = self.apply_program.global_block()
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import unittest
import os
import tempfile
class TestCheckpoint(unittest.TestCase):
def setUp(self):
self.dirname = tempfile.mktemp()
self.max_num_checkpoints = 3
self.epoch_interval = 1
self.step_interval = 1
self.trainer_id = 0
self.chief = self.trainer_id == 0
self.place = fluid.CPUPlace()
self.epoch_id = 100
self.step_id = 20
def test_checkpoint(self):
self.save_checkpoint()
serial = fluid.io.get_latest_checkpoint_serial(self.dirname)
self.assertTrue(serial >= 0)
trainer_args = ["epoch_id", "step_id"]
epoch_id, step_id = fluid.io.load_trainer_args(
self.dirname, serial, self.trainer_id, trainer_args)
self.assertEqual(self.step_id, int(step_id))
self.assertEqual(self.epoch_id, int(epoch_id))
program = fluid.Program()
with fluid.program_guard(program):
exe = fluid.Executor(self.place)
fluid.io.load_checkpoint(exe, self.dirname, serial, program)
fluid.io.clean_checkpoint(self.dirname, delete_dir=True)
self.assertFalse(os.path.isdir(self.dirname))
def save_checkpoint(self):
config = fluid.CheckpointConfig(self.dirname, self.max_num_checkpoints,
self.epoch_interval, self.step_interval)
trainer_args = {}
trainer_args["epoch_id"] = self.epoch_id
trainer_args["step_id"] = self.step_id
program = fluid.Program()
with fluid.program_guard(program):
program.global_block().create_var(
name="scale_0",
psersistable=True,
dtype="float32",
shape=[32, 32])
exe = fluid.Executor(self.place)
for i in xrange(10):
fluid.io.save_checkpoint(exe, config.checkpoint_dir,
self.trainer_id, trainer_args, program,
config.max_num_checkpoints)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import math
from op_test import OpTest
def find_latest_set(num):
return 1 + int(math.floor(math.log(num, 2)))
class CodeTable(object):
def __init__(self, num_classes, code):
self.c = num_classes + code
def cal_index(self, bit):
return (self.c >> (bit + 1)) - 1
def get_length(self):
return find_latest_set(self.c) - 1
def cal_bit(self, bit):
return self.c & (1 << bit)
def hsigmoid(x, w, label, bias, num_classes):
batch_size = x.shape[0]
code_length = find_latest_set(num_classes - 1)
code_table = [0 for _ in range(code_length)]
pre_output = np.zeros((batch_size, code_length))
pre_sum = np.zeros((batch_size, 1))
out = np.zeros((batch_size, 1)).astype("float32")
for i in range(batch_size):
code_table = CodeTable(num_classes, label[i])
length = code_table.get_length()
for j in range(length):
idx = code_table.cal_index(j)
pre_output[i][j] += bias[0][idx]
for i in range(batch_size):
code_table = CodeTable(num_classes, label[i])
length = code_table.get_length()
for j in range(length):
idx = code_table.cal_index(j)
pre_output[i][j] += np.dot(w[idx], x[i])
# clip[-40.0, 40.0]
pre_output = np.clip(pre_output, -40.0, 40.0)
# out(i, 0) = \sum_j bit(i, j) * preout(i, j)
for i in range(batch_size):
code_table = CodeTable(num_classes, label[i])
length = code_table.get_length()
sum = 0.0
for j in range(length):
if code_table.cal_bit(j):
sum += pre_output[i][j]
out[i] = -1.0 * sum
# soft relu
pre_output = np.log(1 + np.exp(pre_output))
pre_sum = pre_output.sum(1).reshape((batch_size, 1))
out += pre_sum
return pre_output, out
class TestHSigmoidOp(OpTest):
def setUp(self):
self.op_type = "hierarchical_sigmoid"
num_classes = 6
feature_size = 8
batch_size = 4
x = np.random.random((batch_size, feature_size)).astype("float32")
w = np.random.random((num_classes - 1, feature_size)).astype("float32")
label = np.random.randint(0, num_classes, (batch_size, 1))
bias = np.random.random((1, num_classes - 1)).astype("float32")
self.attrs = {'num_classes': num_classes}
self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias}
pre_output, out = hsigmoid(x, w, label, bias, num_classes)
self.outputs = {'PreOut': pre_output, 'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label'))
if __name__ == '__main__':
unittest.main()
......@@ -174,6 +174,16 @@ class TestBook(unittest.TestCase):
x=dat, label=lbl))
print(str(program))
def test_hsigmoid(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[2], dtype='float32')
y = layers.data(name='y', shape=[2], dtype='int64')
self.assertIsNotNone(
layers.hsigmoid(
input=x, label=y, num_classes=2))
print(str(program))
def test_sequence_expand(self):
program = Program()
with program_guard(program):
......
......@@ -287,7 +287,7 @@ class TestAdamOptimizer(unittest.TestCase):
# Check accumulators
accumulators = adam_optimizer.get_accumulators()
self.assertEqual(len(accumulators), 2)
self.assertEqual(len(accumulators), 4)
self.assertTrue(adam_optimizer.get_moment1_str() in accumulators)
self.assertTrue(adam_optimizer.get_moment2_str() in accumulators)
moment1_acc = accumulators[adam_optimizer.get_moment1_str()]
......@@ -354,7 +354,7 @@ class TestAdamaxOptimizer(unittest.TestCase):
# Check accumulators
accumulators = adamax_optimizer.get_accumulators()
self.assertEqual(len(accumulators), 2)
self.assertEqual(len(accumulators), 3)
self.assertTrue(adamax_optimizer.get_moment_str() in accumulators)
self.assertTrue(adamax_optimizer.get_inf_norm_str() in accumulators)
moment_acc = accumulators[adamax_optimizer.get_moment_str()]
......
......@@ -14,6 +14,9 @@
import contextlib
import os
import errno
import shutil
import time
import core
......@@ -94,7 +97,7 @@ class EndStepEvent(object):
class CheckpointConfig(object):
"""
Parameter object for :code:`fluid.io.save_checkpoint` and
Parameter object for :code:`save_checkpoint` and
:code:`fluid.Trainer`. Used to configuration how to save checkpoint.
Args:
......@@ -237,7 +240,7 @@ class Trainer(object):
self.checkpoint_cfg = checkpoint_config
if self.checkpoint_cfg:
assert isinstance(self.checkpoint_cfg, CheckpointConfig)
serial = io.get_latest_checkpoint_serial(
serial = _get_latest_checkpoint_serial(
self.checkpoint_cfg.checkpoint_dir)
self.checkpoint_cfg.load_serial = serial if serial >= 0 else None
......@@ -276,32 +279,15 @@ class Trainer(object):
exe = executor.Executor(place)
exe.run(self.startup_program)
if self.checkpoint_cfg and self.checkpoint_cfg.load_serial:
with self._prog_and_scope_guard():
exe = executor.Executor(place)
io.load_checkpoint(exe, self.checkpoint_cfg.checkpoint_dir,
self.checkpoint_cfg.load_serial,
self.startup_program)
if not self.checkpoint_cfg.pserver_id:
epoch_id, step_id = io.load_trainer_args(
self.checkpoint_cfg.checkpoint_dir,
self.checkpoint_cfg.load_serial, self.trainer_id,
self._get_checkpoint_load_args())
self.checkpoint_cfg.epoch_id = int(epoch_id)
self.checkpoint_cfg.step_id = int(step_id)
else:
if self.checkpoint_cfg.lookup_table_name:
io.load_lookup_table_vars(
exe, self.checkpoint_cfg.checkpoint_dir,
self.startup_program,
self.checkpoint_cfg.pserver_id,
self.checkpoint_cfg.lookup_table_name)
if self.checkpoint_cfg and self.checkpoint_cfg.load_serial is not None:
self._load_checkpoint()
if param_path and os.path.isdir(param_path):
# load params from param_path into scope
io.load_persist_vars_without_grad(
exe, dirname=param_path, program=self.startup_program)
io.load_persistables(
executor=exe,
dirname=param_path,
main_program=self.startup_program)
def _transpile_nccl2_dist(self):
# PADDLE_TRAINER_IPS
......@@ -549,7 +535,7 @@ class Trainer(object):
def _clean_checkpoint(self):
assert self.checkpoint_cfg
io.clean_checkpoint(checkpoint_dir=self.checkpoint_cfg.checkpoint_dir)
clean_checkpoint(checkpoint_dir=self.checkpoint_cfg.checkpoint_dir)
def _get_checkpoint_load_args(self):
"""
......@@ -572,7 +558,7 @@ class Trainer(object):
if epoch_id % self.checkpoint_cfg.epoch_interval == 0 \
and step_id % self.checkpoint_cfg.step_interval == 0:
exe = executor.Executor(self.place)
io.save_checkpoint(
save_checkpoint(
executor=exe,
checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
trainer_id=self.trainer_id,
......@@ -580,6 +566,41 @@ class Trainer(object):
main_program=self.train_program,
max_num_checkpoints=self.checkpoint_cfg.max_num_checkpoints)
def _load_checkpoint(self):
with self._prog_and_scope_guard():
exe = executor.Executor(self.place)
load_checkpoint(
executor=exe,
checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
main_program=self.startup_program)
if not self.checkpoint_cfg.pserver_id:
load_trainer_args = self._get_checkpoint_load_args()
trainer_args = load_checkpoint(
executor=exe,
checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
main_program=self.startup_program,
role_id=self.trainer_id,
is_trainer=True,
load_trainer_args=load_trainer_args)
if len(trainer_args) != 2:
raise ValueError(
"the return trainer_args length do not equal _get_checkpoint_load_args"
)
self.checkpoint_cfg.epoch_id = int(trainer_args[0])
self.checkpoint_cfg.step_id = int(trainer_args[1])
else:
if self.checkpoint_cfg.lookup_table_name:
load_checkpoint(
executor=exe,
checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
main_program=self.startup_program,
role_id=self.checkpoint_cfg.pserver_id,
is_trainer=False,
load_trainer_args=None,
load_lookup_table=self.checkpoint_cfg.lookup_table_name)
def build_feed_var_list(program, feed_order):
if not isinstance(program, framework.Program):
......@@ -602,3 +623,610 @@ def build_feed_var_list(program, feed_order):
program.global_block().var(pair[0]) for pair in sorted_pair_list
]
return feed_var_list
# move Checkpoint APIs from io.py to trainer.py, make all of them are private.
SUCCESS_MARK_FILENAME = "_SUCCESS"
CHECKPOINT_PREFIX = "checkpoint"
MODEL_DIR = "__model__"
LOOKUP_TABLE_DIR = "__lookup_table__"
TRAINER_PREFIX = "trainer"
CHECKPOINT_SEPARATOR = "_"
def save_checkpoint(executor,
checkpoint_dir,
trainer_id,
main_program,
trainer_args=None,
max_num_checkpoints=3,
lookup_table=None,
pserver_endpoints=None):
"""
This function filters out all checkpoint variables from the give
main_program and then saves these variables to the `checkpoint_dir`
directory.
In the training precess, we generally save a checkpoint in each
iteration. So there might be a lot of checkpoints in the
`checkpoint_dir`. To avoid them taking too much disk space, the
`max_num_checkpoints` are introduced to limit the total number of
checkpoints. If the number of existing checkpints is greater than
the `max_num_checkpoints`, oldest ones will be scroll deleted.
A variable is a checkpoint variable and will be saved if it meets
all following conditions:
1. It's persistable.
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
Args:
executor(Executor): The executor to run for save checkpoint.
checkpoint_dir(str): The folder where to save checkpoints.
trainer_id(int): currect trainer id, if id is equal to 0, the trainer
is chief.
trainer_args(dict|None): Current training arguments. Such as 'epoch_id'
and 'step_id'.
Defaut: None
main_program(Program): The program whose checkpoint variables will
be saved.
max_num_checkpoints(int): The max number of total number of existing
checkpoints.
Default: 3
lookup_table(string|None): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
pserver_endpoints(list|None): the parameter server ip:port list.
when use distribute lookup table, we can get pserver_endpoints by
distribute arguments.
Returns:
None
Raises:
ValueError: If `checkpoint_dir` is None.
AssertionError: If `trainer_args` is not a dict.
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
path = "./checkpoints"
prog = fluid.default_main_program()
trainer_args = {"epoch_id": 200,
"step_id": 20} # just an example
table_name = "share_w"
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
save_checkpoint(executor=exe,
checkpoint_dir=path,
trainer_id=0,
trainer_args=trainer_args,
main_program=prog,
max_num_checkpoints=3,
lookup_table=table_name,
pserver_endpoints = ps_endpoints)
"""
if checkpoint_dir is None:
raise ValueError("'checkpoint_dir' should not be None")
if main_program is None:
raise ValueError('main_program should not be None.')
if trainer_args:
assert isinstance(trainer_args, dict)
is_chief = trainer_id == 0
_make_chekcpoint_dirs(checkpoint_dir)
serial = _get_latest_checkpoint_serial(checkpoint_dir) + 1
cur_dir = _get_serial_dir(checkpoint_dir, serial)
_save_trainer_args(cur_dir, trainer_id, trainer_args)
if is_chief:
_save_persist_vars_without_grad(executor, cur_dir, main_program)
if is_chief and lookup_table and pserver_endpoints:
_save_pserver_vars_by_notify(executor, cur_dir, lookup_table,
pserver_endpoints)
_scroll_delete(checkpoint_dir, max_num_checkpoints)
def load_checkpoint(executor,
checkpoint_dir,
main_program,
role_id=0,
is_trainer=True,
load_trainer_args=None,
load_lookup_table=None):
"""
This function filters out all checkpoint variables from the give
main_program and then try to load these variables from the
`checkpoint_dir` directory.
In the training precess, we generally save a checkpoint in each
iteration. So there are more than one checkpoint in the
`checkpoint_dir` (each checkpoint has its own sub folder), use
`serial` to specify which serial of checkpoint you would like to
load.
A variable is a checkpoint variable and will be loaded if it meets
all following conditions:
1. It's persistable.
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
Args:
executor(Executor): The executor to run for loading checkpoint.
checkpoint_dir(str): The folder where all checkpoints are.
serial(int): The serial of checkpoint you would like to load.
main_program(Program): The program whose checkpoint variables will
be loaded.
role_id(int): the trainer id or the parameter server id.
is_trainer(bool): trainer is True and parameter server is False.
load_trainer_args(list|None): list about load trainer args.
load_lookup_table(str|None): the lookup table name
Returns:
None
Raises:
ValueError: If `checkpoint_dir` is None.
ValueError: If `main_program` is None.
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
path = "./checkpoints"
prog = fluid.default_main_program()
load_checkpoint(executor=exe, checkpoint_dir=path,
serial=9, main_program=prog)
# In this example, `load_checkpoint` function
# will first filters out all checkpoint variables in the default
# main program, and then try to load these variables form the
# folder "./checkpoints/checkpoint_9/__model__".
"""
if checkpoint_dir is None:
raise ValueError("'checkpoint_dir' should not be None")
serial = _get_latest_checkpoint_serial(checkpoint_dir)
# there are nothing need to be loaded
if serial is None or serial < 0:
return
if main_program is None:
raise ValueError('main_program should not be None.')
if is_trainer and load_trainer_args is None:
cur_dir = _get_serial_dir(checkpoint_dir, serial)
_load_persist_vars_without_grad(executor, cur_dir, main_program, True)
return
if is_trainer and load_trainer_args:
return _load_trainer_args(checkpoint_dir, serial, role_id,
load_trainer_args)
if not is_trainer and load_lookup_table:
_load_lookup_table_vars(executor, checkpoint_dir, main_program, role_id,
load_lookup_table)
def clean_checkpoint(checkpoint_dir, delete_dir=False):
"""
clean the checkpoint dir, when the train exits normally,
the trainer will call clean_checkpoint to delete checkpoint directory saved before.
delete_dir only works when the directory is empty, otherwise, OSError is raised.
: param checkpoint_dir
: param delete_dir
"""
if checkpoint_dir is None:
raise ValueError("'checkpoint_dir' should not be None")
_scroll_delete(checkpoint_dir, max_num_checkpoints=0)
if delete_dir and not os.listdir(checkpoint_dir):
os.rmdir(checkpoint_dir)
def _load_persist_vars_without_grad(executor,
dirname,
program,
has_model_dir=False):
"""
This function filters out all checkpoint variables from the give
program and then trys to load these variables from the given directory.
A variable is a checkpoint variable if it meets all following
conditions:
1. It's persistable.
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
Args:
executor(Executor): The executor to run for loading variables.
dirname(str): The directory path.
program(Program): The program whose checkpoint variables will
be loaded.
has_model_dir(bool): if True, the function loads variables
from a sub directory named '__model__'.
Default: False
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
_load_persist_vars_without_grad(executor=exe,
dirname=param_path, program=prog, has_model_dir=True)
# In this example, `_load_persist_vars_without_grad` function
# will first filters out all checkpoint variables in the default
# main program, and then trys to load these variables form the
# folder "./my_paddle_model/__model__".
"""
if has_model_dir:
dirname = _get_model_dir(dirname)
io.load_vars(
executor,
dirname=dirname,
main_program=program,
predicate=_is_checkpoint_var,
filename=None)
def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
"""
The parameter server will load lookup table's local file in
selectedrows variable.
Args:
executor(Executor): The executor to run for loading persistable variables
dirname(str): The directory path
main_program(Program): Find the variable named table_name in main_program
pserver_id(int): the serial number in pserver_endpoints list
table_name(str): lookup table name
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
dirname = "./checkpoints/checkpoint_9/"
prog = fluid.default_main_program()
pserver_id = 1
table_name = "share_w"
_load_lookup_table_vars(executor=exe,
dirname=dirname, program=prog, pserver_id=pserver_id,
table_name=table_name)
"""
for var in program.list_vars():
if var.name == table_name:
lookup_table_var = var
break
assert lookup_table_var is not None
lookup_table_dir = os.path.join(dirname, LOOKUP_TABLE_DIR)
table_file = table_name + CHECKPOINT_SEPARATOR + str(pserver_id)
load_prog = framework.Program()
load_block = load_prog.global_block()
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [lookup_table_var]},
attrs={'file_path': os.path.join(lookup_table_dir, table_file)})
executor.run(load_prog)
def _save_persist_vars_without_grad(executor, dirname, program):
"""
This function filters out all checkpoint variables from the give
program and then save these variables to a sub-folder '__model__' of
the given directory.
A variable is a checkpoint variable if it meets all following
conditions:
1. It's persistable.
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
Args:
executor(Executor): The executor to run for saving variables.
dirname(str): The directory path.
program(Program): The program whose checkpoint variables will
be saved.
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
_save_persist_vars_without_grad(executor=exe,
dirname=param_path, program=prog)
# In this example, `_save_persist_vars_without_grad` function
# will first filters out all checkpoint variables in the default
# main program, and then saves these variables to the folder
# "./my_paddle_model/__model__".
"""
cur_dir = _get_model_dir(dirname)
io.save_vars(
executor,
dirname=cur_dir,
main_program=program,
vars=None,
predicate=_is_checkpoint_var,
filename=None)
_write_success(cur_dir)
def _save_pserver_vars_by_notify(executor, dirname, lookup_table,
ps_endpoint_list):
"""
This function will send checkpoint notify message from Trainer 0
to all the pservers.
The checkpoint notify message contains lookup table name,
the absolute path on pserver to save lookup_table.
Args:
executor(Executor): The executor to run for send checkpoint notify.
dirname(str): The folder where to save checkpoints.
lookup_table(string): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
ps_endpoint_list(list): the parameter server ip:port list.
when use distribute lookup table, we can get ps_endpoint_list by
distribute arguments.
Return:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
table_name = "share_w"
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
_save_pserver_vars_by_notify(executor=exe,
dirname=param_path, lookup_table=table_name,
ps_endpoint_list=ps_endpoints)
"""
cur_dir = _get_lookuptable_dir(dirname)
checkpoint_notify_program = framework.Program()
checkpoint_notify_block = checkpoint_notify_program.global_block()
attrs = {}
attrs['epmap'] = ps_endpoint_list
attrs['dir'] = cur_dir
attrs['lookup_table'] = lookup_table
checkpoint_notify_block.append_op(
type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
executor.run(checkpoint_notify_program)
def _save_trainer_args(dirname, trainer_id, trainer_args):
assert isinstance(trainer_args, dict)
cur_dir = _get_trainer_dir(dirname, trainer_id)
for name, value in trainer_args.iteritems():
args_file = os.path.join(cur_dir, name)
with open(args_file, 'w') as f:
f.write(str(value))
_write_success(cur_dir)
def _load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
"""
trainer will load some args from it's independent directory,
such as epoch_id and step_id.
Args:
checkpoint_dir(str): The folder where all checkpoints are.
serial(int): The serial of checkpoint you would like to load.
trainer_id(int): current trainer id.
trainer_args(list): list about load trainer args
Return:
None
Examples:
.. code-block:: python
param_path = "./checkpoint/"
serial = 7
trainer_id = 2
trainer_args = ["epoch_id", "step_id"]
_load_trainer_args(checkpoint_dir=param_path, serial=serial,
trainer_id=trainer_id, trainer_args=trainer_args)
"""
assert isinstance(trainer_args, list)
cur_dir = _get_serial_dir(checkpoint_dir, serial)
cur_dir = _get_trainer_dir(cur_dir, trainer_id)
ret_values = []
for arg in trainer_args:
cur_file = os.path.join(cur_dir, arg)
with open(cur_file, 'r') as f:
contents = f.read()
ret_values.append(contents.strip())
return ret_values
def _is_checkpoint_var(var):
"""
the checkpoint will not save or load all the variables.
var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded.
: param var(Variable)
"""
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
var.desc.type() == core.VarDesc.VarType.RAW:
return False
# @GRAD are named for gradient variables, checkpoint will not save it.
if "@GRAD" in var.name:
return False
# .trainer_ are named for distribute train variables, checkpoint will not save it.
if ".trainer_" in var.name:
return False
# .block is named for distribute train variables, checkpoint will not save it.
if ".block" in var.name:
return False
return var.persistable
def _make_chekcpoint_dirs(dirs):
"""
_make_chekcpoint_dirs will makdir local directory directly, when the directory is exist, it will igore it.
"""
assert dirs is not None
if os.path.isfile(dirs):
raise OSError(errno.ENOTDIR, "dirs path shoule be a Directory.", dirs)
if not os.path.isdir(dirs):
try:
os.makedirs(dirs)
except OSError as err:
if err.errno != errno.EEXIST:
raise err
def _get_dir_serial(dirname):
_, serial = dirname.split(CHECKPOINT_SEPARATOR)
try:
serial_num = int(serial)
except ValueError:
serial_num = -1
return serial_num
def _get_serial_dir(dirname, serial):
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
serial_dir = os.path.join(dirname, serial_folder)
_make_chekcpoint_dirs(serial_dir)
return serial_dir
def _get_model_dir(dirname):
model_dir = os.path.join(dirname, MODEL_DIR)
_make_chekcpoint_dirs(model_dir)
return model_dir
def _get_lookuptable_dir(dirname):
lookuptable_dir = os.path.join(dirname, LOOKUP_TABLE_DIR)
_make_chekcpoint_dirs(lookuptable_dir)
return lookuptable_dir
def _get_trainer_dir(dirname, trainer_id):
trainer_folder = TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str(trainer_id)
trainer_dir = os.path.join(dirname, trainer_folder)
_make_chekcpoint_dirs(trainer_dir)
return trainer_dir
def _scroll_delete(dirname, max_num_checkpoints=3):
dirs = os.listdir(dirname)
serial_map = {}
for serial in dirs:
serial_num = _get_dir_serial(serial)
serial_map[serial_num] = serial
if len(serial_map.keys()) <= max_num_checkpoints:
return
serials = serial_map.keys()
serials.sort(reverse=True)
serials = serials[max_num_checkpoints:]
for serial in serials:
cur_dir = _get_serial_dir(dirname, serial)
try:
shutil.rmtree(cur_dir)
except OSError as err:
if err.errno != errno.ENOENT:
raise err
def _write_success(dirname):
"""
write an empty file named "_SUCCESS" in checkpoint dir, indicate this checkpoint is correct.
: param dirname
"""
success_file = os.path.join(dirname, SUCCESS_MARK_FILENAME)
with open(success_file, 'a') as f:
now = time.ctime()
f.write(now)
def _get_latest_checkpoint_serial(checkpoint_dir):
"""
get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory
: param checkpoint_dir
"""
if not checkpoint_dir:
return -1
def has_success(checkpoint_dir, cur_dir):
"""
is _SUCCESS in this dir
"""
serial = _get_dir_serial(cur_dir)
if serial == -1 or not os.path.isdir(
os.path.join(checkpoint_dir, cur_dir)):
return -1
success_path = os.path.join(
_get_serial_dir(checkpoint_dir, serial), MODEL_DIR,
SUCCESS_MARK_FILENAME)
if os.path.isfile(success_path):
return serial
if not os.path.isdir(checkpoint_dir):
return -1
current_dir = -1
dirs = os.listdir(checkpoint_dir)
for cur_dir in dirs:
success_num = has_success(checkpoint_dir, cur_dir)
if success_num > current_dir:
current_dir = success_num
return current_dir
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册