Unverified commit 4ff1bde5, authored by: Y yuyang18

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/hide_api_cont

......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include <stdexcept>
#include <string>
#include <vector>
#include "paddle/fluid/framework/executor.h"
......@@ -53,8 +54,14 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
}
}
}
std::vector<framework::LoDTensor> fetch_data;
std::exception_ptr eptr;
try {
fetch_data = underlying_executor_->Run(fetch_tensors);
} catch (...) {
eptr = std::current_exception();
}
auto fetch_data = underlying_executor_->Run(fetch_tensors);
drop_scope_counter_ += 1;
if (!fetch_tensors.empty() ||
drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
......@@ -69,7 +76,11 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
scope->DeleteScope(local_scope);
}
}
return fetch_data;
if (eptr) {
std::rethrow_exception(eptr);
} else {
return fetch_data;
}
}
} // namespace details
} // namespace framework
......
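The hunk above captures any exception thrown by the underlying executor in a std::exception_ptr so that the local scopes are still cleaned up before the error is surfaced. A minimal standalone sketch of that pattern, with hypothetical DoWork/Cleanup stand-ins for the executor run and the scope deletion:

#include <exception>
#include <iostream>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins for underlying_executor_->Run and the scope cleanup.
std::vector<int> DoWork() { throw std::runtime_error("EOF reached"); }
void Cleanup() { std::cout << "local scopes dropped\n"; }

std::vector<int> RunAndCleanup() {
  std::vector<int> result;
  std::exception_ptr eptr;
  try {
    result = DoWork();
  } catch (...) {
    eptr = std::current_exception();  // remember the error, do not unwind yet
  }
  Cleanup();  // always runs, even when DoWork() failed
  if (eptr) {
    std::rethrow_exception(eptr);  // surface the original exception afterwards
  }
  return result;
}

int main() {
  try {
    RunAndCleanup();
  } catch (const std::exception &e) {
    std::cout << "caught after cleanup: " << e.what() << "\n";
  }
  return 0;
}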
......@@ -78,6 +78,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
set.clear();
};
// Clean run context
run_op_futures_.clear();
exception_.reset();
// Step 3. Execution
while (!pending_vars.empty()) {
// 1. Run All Ready ops
......@@ -96,16 +100,19 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto cur_ready_vars = ready_vars.PopAll(1, &timeout);
if (timeout) {
std::lock_guard<std::mutex> l(exception_mu_);
std::unique_lock<std::mutex> l(exception_mu_);
if (exception_) {
l.unlock();
for (auto &run_op_future : run_op_futures_) {
run_op_future.wait();
}
l.lock();
std::exception *exp = exception_.get();
if (dynamic_cast<platform::EOFException *>(exp)) {
auto e = *static_cast<platform::EOFException *>(exp);
exception_.reset();
throw e;
} else if (dynamic_cast<platform::EnforceNotMet *>(exp)) {
auto e = *static_cast<platform::EnforceNotMet *>(exp);
exception_.reset();
throw e;
} else {
LOG(FATAL) << "Unknown exception.";
......@@ -222,7 +229,7 @@ void ThreadedSSAGraphExecutor::RunOp(
}
};
if (pool_) {
pool_->enqueue(op_run);
run_op_futures_.emplace_back(pool_->enqueue(op_run));
} else {
op_run();
}
......
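run_op_futures_ is what allows the timeout branch above to wait for every in-flight op before rethrowing the recorded exception, so no worker thread is still running when the error propagates. A rough standalone sketch of that wait-then-rethrow idea, using std::async in place of the executor's ThreadPool (all names here are illustrative):

#include <exception>
#include <future>
#include <iostream>
#include <list>
#include <mutex>
#include <stdexcept>

int main() {
  std::list<std::future<void>> run_op_futures;
  std::mutex exception_mu;
  std::exception_ptr exception;

  // Launch a few "ops"; one of them fails and records its exception.
  for (int i = 0; i < 4; ++i) {
    run_op_futures.emplace_back(std::async(std::launch::async, [&, i] {
      if (i == 2) {
        // In the real executor this would be std::current_exception() in a catch.
        std::lock_guard<std::mutex> guard(exception_mu);
        exception = std::make_exception_ptr(std::runtime_error("op 2 failed"));
      }
    }));
  }

  // Wait until every launched op has finished before touching the exception.
  for (auto &f : run_op_futures) {
    f.wait();
  }

  std::lock_guard<std::mutex> guard(exception_mu);
  if (exception) {
    try {
      std::rethrow_exception(exception);
    } catch (const std::exception &e) {
      std::cout << "rethrown after all ops finished: " << e.what() << "\n";
    }
  }
  return 0;
}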
......@@ -15,6 +15,7 @@
#pragma once
#include <deque>
#include <list>
#include <string>
#include <unordered_set>
#include <utility>
......@@ -77,6 +78,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
private:
ExecutionStrategy strategy_;
// use std::list because push_back is O(1) and does not invalidate references to
// the stored futures
std::list<std::future<void>> run_op_futures_;
};
} // namespace details
......
......@@ -95,7 +95,7 @@ ParallelExecutor::ParallelExecutor(
}
if (member_->local_scopes_.size() != 1 && local_scopes.empty()) {
BCastParamsToGPUs(bcast_vars);
BCastParamsToDevs(bcast_vars);
}
// Startup Program has been run. All local scopes has correct parameters.
......@@ -131,7 +131,7 @@ ParallelExecutor::ParallelExecutor(
member_->places_, std::move(member_->executor_)));
}
void ParallelExecutor::BCastParamsToGPUs(
void ParallelExecutor::BCastParamsToDevs(
const std::unordered_set<std::string> &vars) const {
// during the initializing bcast, all vars would be bcast from device(0),
// otherwise
......@@ -202,7 +202,11 @@ void ParallelExecutor::BCastParamsToGPUs(
#endif
} else {
platform::CPUPlace cpu;
for (size_t i = 1; i < member_->places_.size(); ++i) {
for (size_t i = 0; i < member_->places_.size(); ++i) {
if ((initializing && i == 0) ||
(!initializing && static_cast<int>(i) == var_dev_id))
continue;
auto local_scope = member_->local_scopes_[i];
auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
t->Resize(dims);
......
......@@ -66,7 +66,7 @@ class ParallelExecutor {
void Run(const std::vector<std::string> &fetch_tensors,
const std::string &fetched_var_name);
void BCastParamsToGPUs(const std::unordered_set<std::string> &vars) const;
void BCastParamsToDevs(const std::unordered_set<std::string> &vars) const;
private:
ParallelExecutorPrivate *member_;
......
......@@ -29,11 +29,11 @@ enum ReaderStatus { kRunning, kStopped };
class ReaderBase {
public:
void ReadNext(std::vector<LoDTensor>* out);
virtual void ReadNext(std::vector<LoDTensor>* out);
void Shutdown();
virtual void Shutdown();
void Start();
virtual void Start();
// Return the readers at the end of the decorating chain. Basically
// they are the readers just before the read op.
......@@ -42,7 +42,7 @@ class ReaderBase {
virtual ~ReaderBase();
protected:
virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0;
virtual void ReadNextImpl(std::vector<LoDTensor>* out) {}
virtual void ShutdownImpl() {}
......
......@@ -259,12 +259,15 @@ op_library(max_sequence_len_op DEPS lod_rank_table)
op_library(sequence_conv_op DEPS context_project)
op_library(sequence_pool_op DEPS sequence_pooling)
op_library(lstm_op DEPS sequence2batch lstm_compute)
op_library(hierarchical_sigmoid_op DEPS matrix_bit_code)
op_library(lstmp_op DEPS sequence2batch lstm_compute)
op_library(gru_op DEPS sequence2batch gru_compute)
op_library(recurrent_op DEPS executor)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
op_library(cos_sim_op DEPS cos_sim_functor)
op_library(parallel_do_op DEPS executor)
op_library(unsqueeze_op DEPS reshape_op)
op_library(squeeze_op DEPS reshape_op)
if (WITH_GPU)
op_library(conv_op DEPS vol2col depthwise_conv im2col)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/hierarchical_sigmoid_op.h"
#include <vector>
namespace paddle {
namespace operators {
/**
* Organize the classes into a binary tree. At each node, a sigmoid function
* is used to calculate the probability of belonging to the right branch.
* This idea is from "F. Morin, Y. Bengio (AISTATS 05):
* Hierarchical Probabilistic Neural Network Language Model."
*
* Here we use a simple way of building the binary tree.
* Assuming the number of classes C = 6,
* The classes are organized as a binary tree in the following way:
*
* @code{.py}
* *-*-*- 2
* | | |- 3
* | |
* | |-*- 4
* | |- 5
* |
* |-*- 0
* |- 1
* @endcode
*
* where * indicates an internal node, and each leaf node represents a class.
* - Node 0 ... C-2 are internal nodes.
* - Node C-1 ... 2C-2 are leaf nodes.
* - Class c is represented by leaf node \f$c+C-1\f$.
*
* We assign an id for each node:
* - the id of the root is 0.
* - the left child of a node i is 2*i+1.
* - the right child of a node i is 2*i+2.
*
* It's easy to see that:
* - the parent of node i is \f$\left\lfloor(i-1)/2\right\rfloor\f$.
* - the j-th level ancestor of node i is
* \f$\left\lfloor(i+1)/2^{j+1}\right\rfloor - 1\f$.
* - A node i is a left child of its parent if \f$(i-1)\%2==0\f$.
*
*/
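A quick standalone check of the index arithmetic described in the comment above, for C = 6 (the helper functions and main() are illustrative, not part of the operator):

#include <cassert>
#include <cstdio>

// Index helpers for the implicit binary tree described in the comment above.
inline int LeftChild(int i) { return 2 * i + 1; }
inline int RightChild(int i) { return 2 * i + 2; }
inline int Parent(int i) { return (i - 1) / 2; }
inline int LeafOfClass(int c, int num_classes) { return c + num_classes - 1; }

int main() {
  const int C = 6;  // number of classes in the example above
  // Internal nodes are 0 ... C-2, leaf nodes are C-1 ... 2C-2.
  for (int c = 0; c < C; ++c) {
    int leaf = LeafOfClass(c, C);
    assert(leaf >= C - 1 && leaf <= 2 * C - 2);
    std::printf("class %d -> leaf node %d, parent node %d\n", c, leaf, Parent(leaf));
  }
  // The parent/child formulas are mutually consistent for the internal nodes.
  for (int i = 0; i < C - 1; ++i) {
    assert(Parent(LeftChild(i)) == i && Parent(RightChild(i)) == i);
  }
  return 0;
}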
class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("PreOut"),
"Output(PreOut) should not be null.");
const int64_t batch_size = ctx->GetInputDim("X")[0];
std::vector<int64_t> output_shape({batch_size, 1});
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.GetPlace());
}
};
template <typename AttrType>
class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, required) The input tensor with shape [N, D], "
"where N is the size of mini-batch, and D is the feature size.");
AddInput("W",
"(Tensor, required), The parameters of hierarchical "
"sigmoid operator, each of them is a 2-D tensor, the shape is"
"[num_classes - 1, D].");
AddInput("Label",
"(Tensor, required), The labels of training data. It's a"
"tensor with shape [N, 1].");
AddInput("Bias",
"(Tensor, optional), The bias is a tensor with shape"
"[1, num_classes - 1].");
AddOutput("Out",
"(Tensor, required) The output of hierarchical sigmoid operator."
"The shape is [N, 1].");
AddOutput("PreOut",
"(Tensor, required) A intermedia 2-D tensor with shape "
"[batch_size, code_length], where code_length represents the "
"maximum path length from root to leaf nodes.")
.AsIntermediate();
AddAttr<AttrType>("num_classes", "(int, required), The number of classes")
.SetDefault(2);
AddComment(R"DOC(
The hierarchical sigmoid operator organizes the classes into a binary tree.
At each node, a sigmoid function is used to calculate the probability of
belonging to the right branch. This idea is from
"F. Morin, Y. Bengio (AISTATS 05):
Hierarchical Probabilistic Neural Network Language Model."
)DOC");
}
};
class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("PreOut"),
"Input(Preout) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")),
"Output(W@Grad should not be null.)");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")));
if (ctx->HasOutput(framework::GradVarName("Bias"))) {
ctx->SetOutputDim(framework::GradVarName("Bias"),
ctx->GetInputDim("Bias"));
}
ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.GetPlace());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp,
ops::HierarchicalSigmoidOpMaker<int>,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp);
REGISTER_OP_CPU_KERNEL(
hierarchical_sigmoid,
ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext,
double>);
REGISTER_OP_CPU_KERNEL(
hierarchical_sigmoid_grad,
ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
float>,
ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
using platform::Transform;
template <typename DeviceContext, typename T>
class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* w = ctx.Input<framework::Tensor>("W");
auto* label = ctx.Input<framework::Tensor>("Label");
auto* bias = ctx.Input<framework::Tensor>("Bias");
auto* out = ctx.Output<framework::Tensor>("Out");
auto* pre_out = ctx.Output<framework::Tensor>("PreOut");
size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
int64_t code_length = math::FindLastSet(num_classes - 1);
int64_t batch_size = in->dims()[0];
framework::Tensor sum;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto* pre_out_data = pre_out->mutable_data<T>(
framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
// Not all class (leaf) nodes' path lengths equal code_length, so initializing
// to zeros avoids adding loss from positions beyond a sample's path.
math::SetConstant<DeviceContext, T> zero;
zero(dev_ctx, pre_out, static_cast<T>(0.0));
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
math::RowwiseSum<DeviceContext, T> row_sum;
math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());
std::vector<int64_t> sum_dims({batch_size, 1UL});
sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
auto sum_mat = EigenMatrix<T>::From(sum);
out->mutable_data<T>(ctx.GetPlace());
auto out_mat = framework::EigenVector<T>::Flatten(*out);
if (bias) {
bit_code.Add(pre_out, *bias);
}
bit_code.Mul(pre_out, *w, *in);
// clip to [-40, 40]
Transform<DeviceContext> trans;
trans(ctx.template device_context<DeviceContext>(), pre_out_data,
pre_out_data + pre_out->numel(), pre_out_data,
ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
bit_code.Sum(*pre_out, out, static_cast<T>(-1));
// use softrelu to calculate cross entropy
pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
row_sum(dev_ctx, *pre_out, &sum);
// TODO(guosheng): Subtract the out-of-path loss, since not all class (leaf)
// nodes' path lengths equal code_length. It won't break the gradient check,
// though, since both sides include the out-of-path loss and it cancels out.
out_mat.device(place) = sum_mat + out_mat;
}
};
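Reading the forward kernel above end to end (and setting aside the out-of-path log(2) terms flagged in the TODO), the per-sample cost it accumulates can be written as

out_i = \sum_{j=0}^{L_i - 1} \left[ \log\left(1 + e^{preout_{ij}}\right) - bit_{ij}\, preout_{ij} \right]

where L_i is the code length of sample i's class and bit_{ij} = calc_bit(j): each term is the binary cross-entropy, with logits, of taking the observed branch at the j-th internal node on the path, which is what the Sum, softrelu, and row_sum steps compute together.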
template <typename DeviceContext, typename T>
class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* w = ctx.Input<framework::Tensor>("W");
auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* w_grad = ctx.Output<framework::Tensor>(framework::GradVarName("W"));
auto* bias_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("Bias"));
auto* label = ctx.Input<framework::Tensor>("Label");
auto* pre_out = ctx.Input<framework::Tensor>("PreOut");
auto* out_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
framework::Tensor pre_out_grad;
pre_out_grad.mutable_data<T>(pre_out->dims(), ctx.GetPlace());
in_grad->mutable_data<T>(ctx.GetPlace());
w_grad->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> zero;
zero(dev_ctx, in_grad, static_cast<T>(0.0));
zero(dev_ctx, w_grad, static_cast<T>(0.0));
size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
auto pre_out_grad_mat = EigenMatrix<T>::From(pre_out_grad);
auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
Eigen::array<int, 2> bcast({{1, static_cast<int>(pre_out_grad.dims()[1])}});
// softrelu derivative
pre_out_grad_mat.device(place) =
static_cast<T>(1.0) - static_cast<T>(1.0) / pre_out_mat.exp();
bit_code.Sub(&pre_out_grad); // the gradient of clip(w * x + b)
pre_out_grad_mat.device(place) =
pre_out_grad_mat * out_grad_mat.broadcast(bcast);
// TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
// be consistent with the clipping in forward.
if (bias_grad) {
bias_grad->mutable_data<T>(ctx.GetPlace());
zero(dev_ctx, bias_grad, static_cast<T>(0.0));
bit_code.AddGrad(pre_out_grad, bias_grad);
}
bit_code.MulGradWeight(pre_out_grad, w_grad, *in);
bit_code.MulGradError(pre_out_grad, *w, in_grad);
}
};
} // namespace operators
} // namespace paddle
......@@ -51,6 +51,7 @@ math_library(sequence_padding)
math_library(sequence_pooling DEPS math_function)
math_library(sequence_scale)
math_library(softmax DEPS math_function)
math_library(matrix_bit_code)
math_library(unpooling)
math_library(vol2col)
......
......@@ -155,7 +155,7 @@ class RowwiseSum<platform::CPUDeviceContext, T> {
PADDLE_ENFORCE_EQ(in_dims.size(), 2U);
auto height = in_dims[0];
auto size = in_dims[1];
PADDLE_ENFORCE_EQ(out->numel(), size);
PADDLE_ENFORCE_EQ(out->numel(), height);
T* out_buf = out->mutable_data<T>(out->place());
const T* in_buf = input.data<T>();
......
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include <iostream>
namespace paddle {
namespace operators {
namespace math {
template <typename T>
void MatrixBitCodeFunctor<T>::Add(framework::Tensor* tmat,
const framework::Tensor& vec) {
SimpleCodeTable code_table(num_classes_);
size_t batch_size = tmat->dims()[0];
size_t width = tmat->dims()[1];
for (size_t i = 0; i < batch_size; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
tmat->data<T>()[i * width + j] += vec.data<T>()[index];
}
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat,
framework::Tensor* vec) {
SimpleCodeTable code_table(num_classes_);
size_t batch_size = tmat.dims()[0];
size_t width = tmat.dims()[1];
for (size_t i = 0; i < batch_size; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
vec->data<T>()[index] += tmat.data<T>()[i * width + j];
}
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor& tmat,
framework::Tensor* sum, T scale_sum) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat.dims()[0];
size_t o_width = tmat.dims()[1];
for (size_t i = 0; i < num_samples; ++i) {
T sm = static_cast<T>(0.0);
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
if (code.calc_bit(j)) {
// calc_bit starts from the rightmost bit, while the data in tmat[i] is
// stored in the reverse order.
sm += tmat.data<T>()[i * o_width + j];
}
}
sum->data<T>()[i] = scale_sum * sm;
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
const framework::Tensor& weight,
const framework::Tensor& input) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat->dims()[0];
size_t tmat_width = tmat->dims()[1];
size_t input_width = input.dims()[1];
size_t weight_width = weight.dims()[1];
auto tmat_value = tmat->data<T>();
auto weight_value = weight.data<T>();
auto input_value = input.data<T>();
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
T sum = static_cast<T>(0.0);
for (size_t k = 0; k < input_width; ++k) {
sum += weight_value[weight_width * index + k] *
input_value[input_width * i + k];
}
tmat_value[i * tmat_width + j] += sum;
}
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
framework::Tensor* weight,
const framework::Tensor& input) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat.dims()[0];
size_t input_width = input.dims()[1];
size_t tmat_width = tmat.dims()[1];
size_t weight_width = weight->dims()[1];
auto tmat_value = tmat.data<T>();
auto weight_value = weight->data<T>();
auto input_value = input.data<T>();
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
for (size_t k = 0; k < input_width; ++k) {
weight_value[weight_width * index + k] +=
tmat_value[i * tmat_width + j] * input_value[input_width * i + k];
}
}
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
const framework::Tensor& weight,
framework::Tensor* input) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat.dims()[0];
size_t tmat_width = tmat.dims()[1];
size_t input_width = input->dims()[1];
size_t weight_width = weight.dims()[1];
auto tmat_value = tmat.data<T>();
auto weight_value = weight.data<T>();
auto input_value = input->data<T>();
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
for (size_t k = 0; k < input_width; ++k) {
input_value[input_width * i + k] +=
tmat_value[i * tmat_width + j] *
weight_value[weight_width * index + k];
}
}
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::Sub(framework::Tensor* tmat) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat->dims()[0];
size_t o_width = tmat->dims()[1];
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
if (code.calc_bit(j)) {
tmat->data<T>()[i * o_width + j] -= 1;
}
}
}
}
template class MatrixBitCodeFunctor<float>;
template class MatrixBitCodeFunctor<double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
namespace math {
/**
* SimpleCodeTable class should support 3 functions:
*
* size_t size()
* return the number of ids
*
* int get_max_code_length()
* return the maximal code length
*
* SimpleCode operator()(size_t i)
* return the i-th code. Code class is described below.
*
* SimpleCode class should support 3 functions:
*
* int get_length()
* return the length of the code
*
* size_t calc_index(int bit)
* bit ranges from 0 to get_length() - 1
* return the index for the (1+bit) level parent
*
* bool calc_bit(int bit)
* return true if the bit level parent is the right child of (1+bit) level
* parent
*
*/
/**
* return the 1-based index of the highest bit set
*
* for x > 0:
* \f[
* FindLastSet(x) = 1 + \left\lfloor \log_{2} x \right\rfloor
* \f]
*/
inline constexpr size_t FindLastSet(size_t x) {
return std::is_same<size_t, unsigned int>::value
? (x ? 8 * sizeof(x) - __builtin_clz(x) : 0)
: (std::is_same<size_t, unsigned long>::value // NOLINT
? (x ? 8 * sizeof(x) - __builtin_clzl(x) : 0)
: (x ? 8 * sizeof(x) - __builtin_clzll(x) : 0));
}
struct SimpleCode {
SimpleCode(size_t code, size_t num_classes) : c_(code + num_classes) {}
/**
* Here the id of the root should be 1 rather than 0, so the encoding of class
* c is `c + num_classes` and all siblings share the same weight indices via
* their common prefixes.
* The weight index is a prefix of the encoding, so calc_index leaves out the
* rightmost bit.
* The binary classification path is a suffix of the encoding, so calc_bit
* leaves out the leftmost bit.
*/
inline size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; }
inline bool calc_bit(int bit) const { return c_ & (1 << bit); }
inline int get_length() const { return FindLastSet(c_) - 1; }
private:
size_t c_;
};
struct SimpleCodeTable {
explicit SimpleCodeTable(size_t num_classes) : num_classes_(num_classes) {}
SimpleCode operator()(size_t code) const {
return SimpleCode(code, num_classes_);
}
size_t size() const { return num_classes_; }
int get_max_code_length() const { return FindLastSet(num_classes_ - 1); }
private:
size_t num_classes_;
};
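A small standalone check of the encoding above, using trimmed local copies of FindLastSet and SimpleCode (for num_classes = 6, class 5 is encoded as c_ = 5 + 6 = 11 = 0b1011):

#include <cassert>
#include <cstddef>
#include <cstdio>

// Trimmed copies of the helpers above, specialized to unsigned long long.
inline int FindLastSet(unsigned long long x) {
  return x ? static_cast<int>(8 * sizeof(x)) - __builtin_clzll(x) : 0;
}

struct SimpleCode {
  SimpleCode(size_t code, size_t num_classes) : c_(code + num_classes) {}
  size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; }
  bool calc_bit(int bit) const { return c_ & (1 << bit); }
  int get_length() const { return FindLastSet(c_) - 1; }
  size_t c_;
};

int main() {
  const size_t num_classes = 6;
  SimpleCode code(/*class id=*/5, num_classes);  // c_ = 11 = 0b1011
  assert(code.get_length() == 3);                // three internal nodes on the path
  // Weight rows (prefixes of the encoding), visited from the root downwards.
  assert(code.calc_index(2) == 0);               // (11 >> 3) - 1: the root's row
  assert(code.calc_index(1) == 1);               // (11 >> 2) - 1
  assert(code.calc_index(0) == 4);               // (11 >> 1) - 1
  // Branch taken at each level; 1 means "go to the right child".
  std::printf("bits: %d %d %d\n", static_cast<int>(code.calc_bit(2)),
              static_cast<int>(code.calc_bit(1)), static_cast<int>(code.calc_bit(0)));
  return 0;
}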
template <typename T>
class MatrixBitCodeFunctor {
public:
explicit MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids)
: num_classes_(num_classes), ids_(ids) {}
/* For j < code_length
tmat(i, j) += vec(0, index(i, j))
*/
void Add(framework::Tensor* tmat, const framework::Tensor& vec);
/* For j < code_length
vec(0, index(i, j)) += tmat(i, j)
*/
void AddGrad(const framework::Tensor& tmat, framework::Tensor* vec);
/* For j < code_length
sum(i, 0) = \sum_j bit(i, j) * tmat(i, j)
*/
void Sum(const framework::Tensor& tmat, framework::Tensor* sum, T scale_sum);
/* For j < code_length
tmat(i, j) -= bit(i, j)
*/
void Sub(framework::Tensor* tmat);
/* For j < code_length
tmat(i, j) += dot(weight.row(index(i, j)), input.row(i))
*/
void Mul(framework::Tensor* tmat, const framework::Tensor& weight,
const framework::Tensor& input);
/* For index(i, j) >= 0:
weight.row(index(i, j)) += tmat(i, j) * input.row(i)
*/
void MulGradWeight(const framework::Tensor& tmat, framework::Tensor* weight,
const framework::Tensor& input);
/* For j < code_length
input.row(i) += tmat(i, j) * weight.row(index(i, j))
*/
void MulGradError(const framework::Tensor& tmat,
const framework::Tensor& weight, framework::Tensor* input);
size_t num_classes_;
const int64_t* ids_;
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -81,6 +81,15 @@ class BlockingQueue {
}
}
void ReOpen() {
std::lock_guard<std::mutex> lock(mutex_);
closed_ = false;
std::deque<T> new_deque;
queue_.swap(new_deque);
send_cv_.notify_all();
receive_cv_.notify_all();
}
void Close() {
std::lock_guard<std::mutex> lock(mutex_);
closed_ = true;
......
......@@ -23,7 +23,7 @@ class BatchReader : public framework::DecoratedReader {
BatchReader(const std::shared_ptr<ReaderBase>& reader, int batch_size,
bool discard_leftover)
: DecoratedReader(reader),
batch_size_(batch_size),
batch_size_(static_cast<size_t>(batch_size)),
discard_leftover_(discard_leftover) {
buffer_.reserve(batch_size_);
}
......@@ -31,7 +31,7 @@ class BatchReader : public framework::DecoratedReader {
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
private:
int batch_size_;
size_t batch_size_;
bool discard_leftover_;
std::vector<std::vector<framework::LoDTensor>> buffer_;
};
......@@ -78,7 +78,7 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
void BatchReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
buffer_.clear();
buffer_.reserve(batch_size_);
for (int i = 0; i < batch_size_; ++i) {
for (size_t i = 0; i < batch_size_; ++i) {
buffer_.push_back(std::vector<framework::LoDTensor>());
reader_->ReadNext(&buffer_.back());
if (buffer_.back().empty()) {
......@@ -95,9 +95,9 @@ void BatchReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
// if buffer_ is empty, the 'out' will return as an empty vector.
return;
}
int out_num = buffer_[0].size();
size_t out_num = buffer_[0].size();
out->reserve(out_num);
for (int j = 0; j < out_num; ++j) {
for (size_t j = 0; j < out_num; ++j) {
// Merge shape and check data type
std::type_index batch_type = buffer_[0][j].type();
framework::DDim batch_shape = buffer_[0][j].dims();
......
......@@ -27,19 +27,17 @@ class PyReader : public framework::FileReader {
queue_ = queue;
}
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
void ReadNext(std::vector<framework::LoDTensor>* out) override {
bool success;
*out = queue_->Pop(&success);
if (!success) out->clear();
}
private:
void ShutdownImpl() override { /* TODO */
}
void Shutdown() override { queue_->Close(); }
void StartImpl() override { /* TODO */
}
void Start() override { queue_->ReOpen(); }
private:
std::shared_ptr<LoDTensorBlockingQueue> queue_;
};
......
......@@ -58,12 +58,15 @@ class LoDTensorBlockingQueue {
inline size_t Size() const { return queue_.Size(); }
inline void Close() { return queue_.Close(); }
inline void ReOpen() { queue_.ReOpen(); }
inline void Close() { queue_.Close(); }
inline bool IsClosed() const { return queue_.IsClosed(); }
private:
void CheckDims(const std::vector<framework::LoDTensor>& lod_tensor_vec) {
void CheckDims(
const std::vector<framework::LoDTensor>& lod_tensor_vec) const {
PADDLE_ENFORCE(dims_.size() == lod_tensor_vec.size(),
"Expect input size is %d but found %s", dims_.size(),
lod_tensor_vec.size());
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class SqueezeOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SqueezeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SqueezeOp should not be null.");
const auto &x_dims = ctx->GetInputDim("X");
// Check input tensor dims (<6) Eigen limit.
PADDLE_ENFORCE(x_dims.size() <= 6,
"Invalid dimnesions, the rank of Input(X) "
"should be in the range of [1, 6] (Eigen limit).");
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
for (int a : axes) {
PADDLE_ENFORCE_LT(a, x_dims.size(),
"The squeeze axis should be less than input "
"tensor's rank.");
}
auto out_dims = GetOutputShape(axes, x_dims);
ctx->SetOutputDim("Out", out_dims);
if (x_dims[0] == out_dims[0]) {
// Only pass LoD when the first dimension of output and Input(X)
// are the same.
ctx->ShareLoD("X", "Out");
}
}
static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
const framework::DDim &in_dims) {
size_t num_squeeze_dims = squeeze_dims.size();
int cnt_squeezed_dims = 0;
bool should_squeeze[9] = {false};
// Determine the number of dimensions of the output tensor after squeeze.
// Mark and count the dimensions that need to be squeezed
if (num_squeeze_dims == 0) {
for (int idx = 0; idx < in_dims.size(); ++idx) {
if (in_dims[idx] == 1) {
should_squeeze[idx] = true;
++cnt_squeezed_dims;
}
}
} else {
for (size_t idx = 0; idx < num_squeeze_dims; ++idx) {
int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size()
: squeeze_dims[idx];
// Check current index, the upper limit has been checked in line 36.
PADDLE_ENFORCE(current >= 0,
"Invalid axis, the negative axis is out of range.");
PADDLE_ENFORCE(in_dims[current] == 1,
"Invalid axis index, the axis that will be squeezed "
"should be equal to 1.");
if (!(should_squeeze[current])) {
++cnt_squeezed_dims;
}
should_squeeze[current] = true;
}
}
// Make output dimensions
std::vector<int64_t> output_shape(in_dims.size() - cnt_squeezed_dims, 0);
for (int in_idx = 0, out_idx = 0; in_idx < in_dims.size(); ++in_idx) {
if (!should_squeeze[in_idx]) {
output_shape[out_idx++] = in_dims[in_idx];
}
}
return framework::make_ddim(output_shape);
}
};
class SqueezeOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &axes = Attr<std::vector<int>>("axes");
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
auto out_dims = SqueezeOpInferShape::GetOutputShape(axes, x_dims);
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(out_dims);
attrs["inplace"] = Attr<bool>("inplace");
// Invoke Reshape Op
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape", {{"X", {Input("X")}}, {"Shape", {}}},
{{"Out", {Output("Out")}}}, attrs);
reshape_op->Run(scope, place);
}
};
class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor). The input tensor of squeeze operator.");
AddOutput("Out", "(Tensor). The output tensor of squeeze operator.");
AddAttr<std::vector<int>>("axes",
"(std::vector<int>). List of integers,"
" indicating the dimensions to squeeze.")
.SetDefault({});
AddAttr<bool>("inplace",
"(default: false) Squeeze the source tensor's shape without "
"memory copy. When Attr(inplace) is set true, the output "
"tensor shares memory with Input(X), otherwise, a new output "
"tensor is created, and its data are copied from Input(x).")
.SetDefault(false);
AddComment(R"DOC(
Squeeze Operator.
Remove single-dimensional entries from the shape of a tensor.
Takes a parameter axes with a list of axes to squeeze.
If axes is not provided, all the single dimensions will be removed from the shape.
If an axis is selected with shape entry not equal to one, an error is raised.
Examples:
Case 1:
Given
X.shape = (1, 3, 1, 5)
and
axes = [0]
we get:
Out.shape = (3, 1, 5)
Case 2:
Given
X.shape = (1, 3, 1, 5)
and
axes = []
we get:
Out.shape = (3, 5)
)DOC");
}
};
class SqueezeGradInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
context->SetOutputDim(framework::GradVarName("X"),
context->GetInputDim("X"));
context->ShareLoD("X", framework::GradVarName("X"));
}
};
class SqueezeGradOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto dx_name = Output(framework::GradVarName("X"));
auto dout_name = Input(framework::GradVarName("Out"));
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(x_dims);
attrs["inplace"] = Attr<bool>("inplace");
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}},
attrs);
reshape_op->Run(scope, place);
}
};
} // namespace operators
} // namespace paddle
// Tell linker to use reshape op
USE_OP(reshape);
namespace ops = paddle::operators;
REGISTER_OPERATOR(squeeze, ops::SqueezeOp, ops::SqueezeOpMaker,
ops::SqueezeOpInferShape,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(squeeze_grad, ops::SqueezeGradOp, ops::SqueezeGradInferShape);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class UnsqueezeOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of UnsqueezeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of UnsqueezeOp should not be null.");
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
const auto &x_dims = ctx->GetInputDim("X");
// Validity Check: input tensor dims (<6).
PADDLE_ENFORCE(x_dims.size() <= 6,
"Invalid dimensions, the rank of Input(X) "
"should be in the range of [1, 6] (Eigen limit)");
auto out_dims = GetOutputShape(axes, x_dims);
ctx->SetOutputDim("Out", out_dims);
if (x_dims[0] == out_dims[0]) {
// Only pass LoD when the first dimension of output and Input(X)
// are the same.
ctx->ShareLoD("X", "Out");
}
}
static framework::DDim GetOutputShape(const std::vector<int> unsqz_dims,
const framework::DDim &in_dims) {
int output_size = in_dims.size() + static_cast<int>(unsqz_dims.size());
int cur_output_size = in_dims.size();
std::vector<int64_t> output_shape(output_size, 0);
// Validity Check: rank range.
PADDLE_ENFORCE(output_size <= 6,
"The output tensor's rank should be less than 6.");
for (int axis : unsqz_dims) {
int cur = axis < 0 ? axis + cur_output_size + 1 : axis;
// Validity Check: the axis bound
PADDLE_ENFORCE(
cur >= 0 && cur <= cur_output_size,
"The unsqueeze dims must be within range of current rank.");
// Move old axis, and insert new axis
for (int i = cur_output_size; i >= cur; --i) {
if (output_shape[i] == 1) {
// Move axis
output_shape[i + 1] = 1;
output_shape[i] = 0;
}
}
output_shape[cur] = 1;
// Add the output size.
cur_output_size++;
}
// Make output shape
for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) {
if (output_shape[out_idx] == 0) {
output_shape[out_idx] = in_dims[in_idx++];
}
}
return framework::make_ddim(output_shape);
}
};
class UnsqueezeOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &axes = Attr<std::vector<int>>("axes");
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
auto out_dims = UnsqueezeOpInferShape::GetOutputShape(axes, x_dims);
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(out_dims);
attrs["inplace"] = Attr<bool>("inplace");
// Invoke Reshape op.
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape", {{"X", {Input("X")}}, {"Shape", {}}},
{{"Out", {Output("Out")}}}, attrs);
reshape_op->Run(scope, place);
}
};
class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor). The input tensor of unsqueeze operator.");
AddOutput("Out", "(Tensor). The output tensor of unsqueeze operator.");
AddAttr<std::vector<int>>("axes",
"(std::vector<int>). List of integers,"
" indicating the dimensions to be inserted")
.AddCustomChecker([](const std::vector<int> &axes) {
PADDLE_ENFORCE(!axes.empty(),
"Invalid axes, The unsqueeze axes is empty.");
// Validity Check: axes dims (<6).
PADDLE_ENFORCE(static_cast<int>(axes.size()) < 6,
"Invalid dimensions, dynamic dimensions should be "
"within [1, 6] dimensions (Eigen limit).");
// Validity Check: the range of unsqueeze axis.
for (int axis : axes) {
PADDLE_ENFORCE(axis < 6,
"Invalid dimensions, input axis should be"
" within [1, 6] dimensions (Eigen limit).");
}
});
AddAttr<bool>(
"inplace",
"(default: false) Unsqueeze the source tensor's shape without "
"memory copy. When Attr(inplace) is set true, the output "
"tensor shares memory with Input(X), otherwise, a new output "
"tensor is created, and its data are copied from Input(x).")
.SetDefault(false);
AddComment(R"DOC(
Unsqueeze Operator.
Insert single-dimensional entries to the shape of a tensor.
Takes one required argument axes, a list of dimensions that will be inserted.
Dimension indices in axes are as seen in the output tensor.
For example:
Given a tensor with shape [3, 4, 5],
then Unsqueeze(tensor, axes=[0, 4]) has shape [1, 3, 4, 5, 1]
)DOC");
}
};
class UnsqueezeGradInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", framework::GradVarName("X"));
}
};
class UnsqueezeGradOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto dx_name = Output(framework::GradVarName("X"));
auto dout_name = Input(framework::GradVarName("Out"));
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(x_dims);
attrs["inplace"] = Attr<bool>("inplace");
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}},
attrs);
reshape_op->Run(scope, place);
}
};
} // namespace operators
} // namespace paddle
// Tell linker to use reshape op.
USE_OP(reshape);
namespace ops = paddle::operators;
REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker,
ops::UnsqueezeOpInferShape,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp,
ops::UnsqueezeGradInferShape);
......@@ -14,6 +14,7 @@ limitations under the License. */
#include <Python.h>
#include <algorithm>
#include <map>
#include <memory>
#include <mutex> // NOLINT // for call_once
#include <string>
#include <unordered_map>
......@@ -66,6 +67,14 @@ bool IsCompiledWithCUDA() {
#endif
}
bool IsCompiledWithDIST() {
#ifdef PADDLE_WITH_DIST
return true;
#else
return false;
#endif
}
PYBIND11_PLUGIN(core) {
py::module m("core", "C++ core of PaddlePaddle");
......@@ -302,7 +311,8 @@ All parameter, weight, gradient are variables in Paddle.
::paddle::operators::reader::LoDTensorBlockingQueue;
using LoDTensorBlockingQueueHolder =
::paddle::operators::reader::LoDTensorBlockingQueueHolder;
py::class_<LoDTensorBlockingQueue>(m, "LoDTensorBlockingQueue", "")
py::class_<LoDTensorBlockingQueue, std::shared_ptr<LoDTensorBlockingQueue>>(
m, "LoDTensorBlockingQueue", "")
.def("push",
[](LoDTensorBlockingQueue &self,
const std::vector<framework::LoDTensor> &lod_tensor_vec) {
......@@ -317,7 +327,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("init_lod_tensor_blocking_queue",
[](Variable &var, size_t capacity,
const std::vector<std::vector<int64_t>> &shapes)
-> LoDTensorBlockingQueue * {
-> std::shared_ptr<LoDTensorBlockingQueue> {
std::vector<DDim> dims(shapes.size());
std::transform(shapes.begin(), shapes.end(), dims.begin(),
[](const std::vector<int64_t> &shape) {
......@@ -325,9 +335,9 @@ All parameter, weight, gradient are variables in Paddle.
});
auto *holder = var.GetMutable<LoDTensorBlockingQueueHolder>();
holder->InitOnce(capacity, dims);
return holder->GetQueue().get();
return holder->GetQueue();
},
py::return_value_policy::reference);
py::return_value_policy::copy);
py::class_<Scope>(m, "Scope", "")
.def("var",
......@@ -508,6 +518,7 @@ All parameter, weight, gradient are variables in Paddle.
[](bool init_p2p) { framework::InitDevices(init_p2p); });
m.def("is_compiled_with_cuda", IsCompiledWithCUDA);
m.def("is_compiled_with_dist", IsCompiledWithDIST);
#ifdef PADDLE_WITH_CUDA
m.def("is_float16_supported", [](const platform::CUDAPlace &place) -> bool {
// Only GPUs with Compute Capability >= 53 support float16
......@@ -534,6 +545,8 @@ All parameter, weight, gradient are variables in Paddle.
});
py::class_<LoDTensorArray>(m, "LoDTensorArray")
.def("__init__",
[](LoDTensorArray &instance) { new (&instance) LoDTensorArray(); })
.def("__getitem__",
[](LoDTensorArray &self, size_t i) { return &self.at(i); },
py::return_value_policy::reference)
......@@ -656,7 +669,7 @@ All parameter, weight, gradient are variables in Paddle.
const std::string &, Scope *, std::vector<Scope *> &,
const ExecutionStrategy &, const BuildStrategy &, size_t,
size_t>())
.def("bcast_params", &ParallelExecutor::BCastParamsToGPUs)
.def("bcast_params", &ParallelExecutor::BCastParamsToDevs)
// NOTE: even though we return a vec<Scope*>* to Python with reference policy,
// we still cannot get local_scope from this vector, since the element
// of vec<Scope*> will be freed by Python GC. We can only return Scope*
......
......@@ -92,8 +92,15 @@ install(DIRECTORY ${PADDLE_PYTHON_PACKAGE_DIR}
DESTINATION opt/paddle/share/wheels
)
find_program(PATCHELF_EXECUTABLE patchelf)
if(NOT PATCHELF_EXECUTABLE)
message(FATAL_ERROR "patchelf not found, please install it.\n"
"For Ubuntu, the command is: apt-get install -y patchelf.")
endif()
if(APPLE)
find_program(INSTALL_NAME_TOOL_EXECUTABLE install_name_tool)
if(NOT INSTALL_NAME_TOOL_EXECUTABLE)
message(FATAL_ERROR "install_name_tool not found, please check.\n")
endif()
else(APPLE)
find_program(PATCHELF_EXECUTABLE patchelf)
if(NOT PATCHELF_EXECUTABLE)
message(FATAL_ERROR "patchelf not found, please install it.\n"
"For Ubuntu, the command is: apt-get install -y patchelf.")
endif()
endif(APPLE)
......@@ -44,7 +44,7 @@ import metrics
import transpiler
from param_attr import ParamAttr, WeightNormParamAttr
from data_feeder import DataFeeder
from core import LoDTensor, CPUPlace, CUDAPlace, CUDAPinnedPlace, Scope
from core import LoDTensor, LoDTensorArray, CPUPlace, CUDAPlace, CUDAPinnedPlace, Scope
from transpiler import DistributeTranspiler, InferenceTranspiler, \
memory_optimize, release_memory
from concurrency import (Go, make_channel, channel_send, channel_recv,
......@@ -65,13 +65,14 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + \
'io',
'initializer',
'layers',
'transpiler'
'transpiler',
'nets',
'optimizer',
'learning_rate_decay',
'backward',
'regularizer',
'LoDTensor',
'LoDTensorArray',
'CPUPlace',
'CUDAPlace',
'CUDAPinnedPlace',
......@@ -121,6 +122,9 @@ def __bootstrap__():
'eager_delete_scope', 'use_mkldnn', 'initial_cpu_memory_in_mb',
'init_allocated_mem'
]
if core.is_compiled_with_dist():
read_env_flags.append('rpc_deadline')
if core.is_compiled_with_cuda():
read_env_flags += [
'fraction_of_gpu_memory_to_use', 'cudnn_deterministic'
......
This diff is collapsed.
......@@ -24,7 +24,8 @@ from layer_function_generator import generate_layer_fn, templatedoc
__all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'Recv',
'open_recordio_file', 'open_files', 'read_file', 'shuffle', 'batch',
'double_buffer', 'random_data_generator', 'Preprocessor', 'load'
'double_buffer', 'random_data_generator', 'py_reader', 'Preprocessor',
'load'
]
......@@ -445,6 +446,88 @@ def random_data_generator(low, high, shapes, lod_levels, for_parallel=True):
return monkey_patch_reader_methods(main_prog_var)
def py_reader(capacity, shapes, dtypes, lod_levels=None):
"""
Create a reader and blocking queue for data feeding in Python
This layer returns a Reader Variable and a BlockingQueue.
The BlockingQueue provides a `push()` method to push a `LoDTensorArray`
object into the queue on the Python side. On the C++ side, the Reader
Variable invokes the queue's `pop()` method to retrieve the fed data.
Feeding data on the Python side and fetching data on the C++ side can
run in parallel. The BlockingQueue should be closed with its `close()`
method when it is no longer needed.
Args:
capacity(int): The maximum capacity of the BlockingQueue.
shapes(list): List of tuples declaring the data shapes.
dtypes(list): List of strs declaring the data types.
lod_levels(list): List of ints declaring the data lod_levels.
Returns:
tuple(Variable, BlockingQueue):
A Reader Variable from which we can get feeding data.
A BlockingQueue object for data feeding.
Examples:
.. code-block:: python
reader, queue = fluid.layers.py_reader(
capacity=10,
shapes=[[-1,3,224,224], [-1,1]],
dtypes=['float32', 'int64'])
# Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.read_file(reader)
# Via the blocking queue, we can feed data using threads
def feed_data(queue, feed_images, feed_labels):
for feed_image, feed_label in zip(feed_images, feed_labels):
data = core.LoDTensorArray()
data.append(feed_image)
data.append(feed_label)
queue.push(data)
thread = threading.Thread(target=feed_data, args=(queue, feed_images, feed_labels))
thread.start()
"""
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
shape_concat = []
ranks = []
for shape in shapes:
shape_concat.extend(shape)
ranks.append(len(shape))
if lod_levels is None:
lod_levels = [0] * len(shapes)
queue_name = unique_name('lod_tensor_blocking_queue')
var = global_scope().var(queue_name)
feed_queue = core.init_lod_tensor_blocking_queue(var, capacity, shapes)
startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=unique_name('create_py_reader'))
startup_blk.append_op(
type='create_py_reader',
inputs={'blocking_queue': queue_name},
outputs={'Out': [startup_var]},
attrs={
'shape_concat': shape_concat,
'lod_levels': lod_levels,
'ranks': ranks
})
startup_var.desc.set_dtypes(dtypes)
startup_var.persistable = True
main_prog_var = _copy_reader_var_(default_main_program().current_block(),
startup_var)
return monkey_patch_reader_methods(main_prog_var), feed_queue
def open_files(filenames,
shapes,
lod_levels,
......
......@@ -85,6 +85,7 @@ __all__ = [
'transpose',
'im2sequence',
'nce',
'hsigmoid',
'beam_search',
'row_conv',
'multiplex',
......@@ -3871,6 +3872,74 @@ def nce(input,
return cost / (num_neg_samples + 1)
def hsigmoid(input, label, num_classes, param_attr=None, bias_attr=None):
"""
The hierarchical sigmoid operator is used to accelerate the training
process of language models. This operator organizes the classes into a
complete binary tree: each leaf node represents a class (a word) and each
internal node acts as a binary classifier. For each word there is a unique
path from the root to its leaf node; hsigmoid calculates the cost at each
internal node on the path and sums them to get the total cost. hsigmoid can
achieve an acceleration from :math:`O(N)` to :math:`O(logN)`, where :math:`N`
is the size of the word dictionary.
Refer to `Hierarchical Probabilistic Neural Network Language Model
<http://www.iro.umontreal.ca/~lisa/pointeurs/hierarchical-nnlm-aistats05.pdf>`_
Args:
input (Variable): The input tensor variable with shape
:math:`[N \\times D]`, where :math:`N` is the size of mini-batch,
and :math:`D` is the feature size.
label (Variable): The tensor variable contains labels of training data.
It's a tensor with shape :math:`[N \\times 1]`.
num_classes: (int), The number of classes, must not be less than 2.
param_attr (ParamAttr|list of ParamAttr, default None): The parameter
attribute for learnable parameters/weights of this layer.
bias_attr (ParamAttr|list of ParamAttr, default None): The parameter
attribute for the bias of this layer. If it is set to False, no
bias will be applied.
Returns:
Out: (Tensor) The cost of the hierarchical sigmoid operator, with shape [N, 1].
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[2], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='int64')
out = fluid.layers.hsigmoid(input=x, label=y, num_classes=6)
"""
helper = LayerHelper('hierarchical_sigmoid', **locals())
dtype = helper.input_dtype()
out = helper.create_tmp_variable(dtype)
pre_out = helper.create_tmp_variable(dtype)
dim = input.shape[1]
if num_classes < 2:
raise ValueError("num_classes must not be less than 2.")
weights = helper.create_parameter(
attr=helper.param_attr,
shape=[num_classes - 1, dim],
is_bias=False,
dtype=input.dtype)
inputs = {"X": input, "W": weights, "Label": label}
if helper.bias_attr:
bias = helper.create_parameter(
attr=helper.bias_attr,
shape=[1, num_classes - 1],
is_bias=True,
dtype=input.dtype)
inputs['Bias'] = bias
helper.append_op(
type="hierarchical_sigmoid",
inputs=inputs,
outputs={"Out": out,
"PreOut": pre_out},
attrs={"num_classes": num_classes})
return out
def transpose(x, perm, name=None):
"""
Permute the dimensions of `input` according to `perm`.
......
......@@ -123,7 +123,7 @@ class Optimizer(object):
"""
pass
def _finish_update(self, block):
def _finish_update(self, block, parameters):
"""Finish any custom updates needed
before completing an optimization step
......@@ -132,7 +132,7 @@ class Optimizer(object):
parameters: list of parameter variables for the optimizer
Returns:
list of finish ops or None
None
"""
pass
......@@ -236,7 +236,8 @@ class Optimizer(object):
# Get custom finish ops for subclasses
# FIXME: Need to fix this once we figure out how to handle dependencies
self._finish_update(loss.block)
self._finish_update(loss.block,
[p[0] for p in parameters_and_grads])
end = len(global_block.ops)
return global_block.slice_ops(start, end)
......@@ -486,6 +487,8 @@ class AdamOptimizer(Optimizer):
"""
_moment1_acc_str = "moment1"
_moment2_acc_str = "moment2"
_beta1_pow_acc_str = "beta1_pow_acc"
_beta2_pow_acc_str = "beta2_pow_acc"
def __init__(self,
learning_rate=0.001,
......@@ -507,32 +510,22 @@ class AdamOptimizer(Optimizer):
def _create_accumulators(self, block, parameters):
assert isinstance(block, framework.Block)
main_block = block.program.global_block()
# Create beta1 and beta2 power tensors
beta_shape = [1]
self._beta1_pow_acc = self.helper.create_global_variable(
name=unique_name.generate('beta1_pow_acc'),
dtype='float32' if self._dtype == None else self._dtype,
shape=beta_shape,
lod_level=0,
persistable=True)
self.helper.set_variable_initializer(
self._beta1_pow_acc, initializer=Constant(self._beta1))
self._beta2_pow_acc = self.helper.create_global_variable(
name=unique_name.generate('beta2_pow_acc'),
dtype='float32' if self._dtype == None else self._dtype,
shape=beta_shape,
lod_level=0,
persistable=True)
self.helper.set_variable_initializer(
self._beta2_pow_acc, initializer=Constant(self._beta2))
# Create accumulator tensors for first and second moments
for p in parameters:
self._add_accumulator(self._moment1_acc_str, p)
self._add_accumulator(self._moment2_acc_str, p)
self._add_accumulator(
name=self._beta1_pow_acc_str,
param=p,
dtype='float32',
fill_value=self._beta1,
shape=[1])
self._add_accumulator(
name=self._beta2_pow_acc_str,
param=p,
dtype='float32',
fill_value=self._beta2,
shape=[1])
def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, framework.Block)
......@@ -541,6 +534,11 @@ class AdamOptimizer(Optimizer):
param_and_grad[0])
moment2 = self._get_accumulator(self._moment2_acc_str,
param_and_grad[0])
beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
param_and_grad[0])
beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
param_and_grad[0])
# create the adam optimize op
adam_op = block.append_op(
type=self.type,
......@@ -550,8 +548,8 @@ class AdamOptimizer(Optimizer):
"LearningRate": self._create_param_lr(param_and_grad),
"Moment1": moment1,
"Moment2": moment2,
"Beta1Pow": self._beta1_pow_acc,
"Beta2Pow": self._beta2_pow_acc
"Beta1Pow": beta1_pow_acc,
"Beta2Pow": beta2_pow_acc
},
outputs={
"ParamOut": param_and_grad[0],
......@@ -566,24 +564,28 @@ class AdamOptimizer(Optimizer):
return adam_op
def _finish_update(self, block):
def _finish_update(self, block, parameters):
"""Update Beta1 and Beta2 Power accumulators
"""
assert isinstance(block, framework.Block)
main_block = block.program.global_block()
scale_beta1 = main_block.append_op(
type="scale",
inputs={"X": self._beta1_pow_acc},
outputs={"Out": self._beta1_pow_acc},
attrs={"scale": self._beta1})
scale_beta2 = main_block.append_op(
type="scale",
inputs={"X": self._beta2_pow_acc},
outputs={"Out": self._beta2_pow_acc},
attrs={"scale": self._beta2})
return [scale_beta1, scale_beta2]
for param in parameters:
with param.block.program.optimized_guard(param):
beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
param)
beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
param)
main_block.append_op(
type="scale",
inputs={"X": beta1_pow_acc},
outputs={"Out": beta1_pow_acc},
attrs={"scale": self._beta1})
main_block.append_op(
type="scale",
inputs={"X": beta2_pow_acc},
outputs={"Out": beta2_pow_acc},
attrs={"scale": self._beta2})
class AdamaxOptimizer(Optimizer):
......@@ -626,6 +628,7 @@ class AdamaxOptimizer(Optimizer):
"""
_moment_acc_str = "moment"
_inf_norm_acc_str = "inf_norm"
_beta1_pow_acc_str = "beta1_pow_acc"
def __init__(self,
learning_rate=0.001,
......@@ -645,21 +648,16 @@ class AdamaxOptimizer(Optimizer):
self._epsilon = epsilon
def _create_accumulators(self, block, parameters):
# Create beta1 power accumulator tensor
beta_shape = [1]
self._beta1_pow_acc = self.helper.create_global_variable(
name=unique_name.generate('beta1_pow_acc'),
dtype='float32' if self._dtype == None else self._dtype,
shape=beta_shape,
lod_level=0,
persistable=True)
self.helper.set_variable_initializer(
self._beta1_pow_acc, initializer=Constant(self._beta1))
# Create accumulator tensors for first moment and infinity norm
for p in parameters:
self._add_accumulator(self._moment_acc_str, p)
self._add_accumulator(self._inf_norm_acc_str, p)
self._add_accumulator(
name=self._beta1_pow_acc_str,
param=p,
dtype='float32',
fill_value=self._beta1,
shape=[1])
def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, framework.Block)
......@@ -667,6 +665,8 @@ class AdamaxOptimizer(Optimizer):
moment = self._get_accumulator(self._moment_acc_str, param_and_grad[0])
inf_norm = self._get_accumulator(self._inf_norm_acc_str,
param_and_grad[0])
beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
param_and_grad[0])
# create the adamax optimize op
adamax_op = block.append_op(
type=self.type,
......@@ -676,7 +676,7 @@ class AdamaxOptimizer(Optimizer):
"LearningRate": self._create_param_lr(param_and_grad),
"Moment": moment,
"InfNorm": inf_norm,
"Beta1Pow": self._beta1_pow_acc
"Beta1Pow": beta1_pow_acc
},
outputs={
"ParamOut": param_and_grad[0],
......@@ -691,18 +691,20 @@ class AdamaxOptimizer(Optimizer):
return adamax_op
def _finish_update(self, block):
def _finish_update(self, block, parameters):
"""Update Beta1 Power accumulator
"""
assert isinstance(block, framework.Block)
main_block = block.program.global_block()
scale_beta1 = main_block.append_op(
type="scale",
inputs={"X": self._beta1_pow_acc},
outputs={"Out": self._beta1_pow_acc},
attrs={"scale": self._beta1})
return [scale_beta1]
for param in parameters:
with param.block.program.optimized_guard(param):
beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
param)
main_block.append_op(
type="scale",
inputs={"X": beta1_pow_acc},
outputs={"Out": beta1_pow_acc},
attrs={"scale": self._beta1})
class DecayedAdagradOptimizer(Optimizer):
......@@ -1156,7 +1158,8 @@ class ModelAverage(Optimizer):
self.params_grads.append((param, grad))
for param, grad in self.params_grads:
self._append_average_accumulate_op(param)
with param.block.program.optimized_guard(param):
self._append_average_accumulate_op(param)
self.apply_program = Program()
block = self.apply_program.global_block()
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import unittest
import os
import tempfile
class TestCheckpoint(unittest.TestCase):
def setUp(self):
self.dirname = tempfile.mktemp()
self.max_num_checkpoints = 3
self.epoch_interval = 1
self.step_interval = 1
self.trainer_id = 0
self.chief = self.trainer_id == 0
self.place = fluid.CPUPlace()
self.epoch_id = 100
self.step_id = 20
def test_checkpoint(self):
self.save_checkpoint()
serial = fluid.io.get_latest_checkpoint_serial(self.dirname)
self.assertTrue(serial >= 0)
trainer_args = ["epoch_id", "step_id"]
epoch_id, step_id = fluid.io.load_trainer_args(
self.dirname, serial, self.trainer_id, trainer_args)
self.assertEqual(self.step_id, int(step_id))
self.assertEqual(self.epoch_id, int(epoch_id))
program = fluid.Program()
with fluid.program_guard(program):
exe = fluid.Executor(self.place)
fluid.io.load_checkpoint(exe, self.dirname, serial, program)
fluid.io.clean_checkpoint(self.dirname, delete_dir=True)
self.assertFalse(os.path.isdir(self.dirname))
def save_checkpoint(self):
config = fluid.CheckpointConfig(self.dirname, self.max_num_checkpoints,
self.epoch_interval, self.step_interval)
trainer_args = {}
trainer_args["epoch_id"] = self.epoch_id
trainer_args["step_id"] = self.step_id
program = fluid.Program()
with fluid.program_guard(program):
program.global_block().create_var(
name="scale_0",
persistable=True,
dtype="float32",
shape=[32, 32])
exe = fluid.Executor(self.place)
for i in xrange(10):
fluid.io.save_checkpoint(exe, config.checkpoint_dir,
self.trainer_id, trainer_args, program,
config.max_num_checkpoints)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import math
from op_test import OpTest
def find_latest_set(num):
return 1 + int(math.floor(math.log(num, 2)))
class CodeTable(object):
def __init__(self, num_classes, code):
self.c = num_classes + code
def cal_index(self, bit):
return (self.c >> (bit + 1)) - 1
def get_length(self):
return find_latest_set(self.c) - 1
def cal_bit(self, bit):
return self.c & (1 << bit)
def hsigmoid(x, w, label, bias, num_classes):
batch_size = x.shape[0]
code_length = find_latest_set(num_classes - 1)
code_table = [0 for _ in range(code_length)]
pre_output = np.zeros((batch_size, code_length))
pre_sum = np.zeros((batch_size, 1))
out = np.zeros((batch_size, 1)).astype("float32")
for i in range(batch_size):
code_table = CodeTable(num_classes, label[i])
length = code_table.get_length()
for j in range(length):
idx = code_table.cal_index(j)
pre_output[i][j] += bias[0][idx]
for i in range(batch_size):
code_table = CodeTable(num_classes, label[i])
length = code_table.get_length()
for j in range(length):
idx = code_table.cal_index(j)
pre_output[i][j] += np.dot(w[idx], x[i])
# clip[-40.0, 40.0]
pre_output = np.clip(pre_output, -40.0, 40.0)
# out(i, 0) = \sum_j bit(i, j) * preout(i, j)
for i in range(batch_size):
code_table = CodeTable(num_classes, label[i])
length = code_table.get_length()
sum = 0.0
for j in range(length):
if code_table.cal_bit(j):
sum += pre_output[i][j]
out[i] = -1.0 * sum
# soft relu
pre_output = np.log(1 + np.exp(pre_output))
pre_sum = pre_output.sum(1).reshape((batch_size, 1))
out += pre_sum
return pre_output, out
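In terms of the arrays above, each row's loss is out[i] = sum_j log(1 + exp(pre_output[i][j])) minus the sum of pre_output[i][j] over the set code bits j, i.e. sigmoid cross-entropy summed over the nodes on the code path of label[i].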
class TestHSigmoidOp(OpTest):
def setUp(self):
self.op_type = "hierarchical_sigmoid"
num_classes = 6
feature_size = 8
batch_size = 4
x = np.random.random((batch_size, feature_size)).astype("float32")
w = np.random.random((num_classes - 1, feature_size)).astype("float32")
label = np.random.randint(0, num_classes, (batch_size, 1))
bias = np.random.random((1, num_classes - 1)).astype("float32")
self.attrs = {'num_classes': num_classes}
self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias}
pre_output, out = hsigmoid(x, w, label, bias, num_classes)
self.outputs = {'PreOut': pre_output, 'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label'))
if __name__ == '__main__':
unittest.main()
......@@ -174,6 +174,16 @@ class TestBook(unittest.TestCase):
x=dat, label=lbl))
print(str(program))
def test_hsigmoid(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[2], dtype='float32')
y = layers.data(name='y', shape=[2], dtype='int64')
self.assertIsNotNone(
layers.hsigmoid(
input=x, label=y, num_classes=2))
print(str(program))
def test_sequence_expand(self):
program = Program()
with program_guard(program):
......
......@@ -287,7 +287,7 @@ class TestAdamOptimizer(unittest.TestCase):
# Check accumulators
accumulators = adam_optimizer.get_accumulators()
self.assertEqual(len(accumulators), 2)
self.assertEqual(len(accumulators), 4)
self.assertTrue(adam_optimizer.get_moment1_str() in accumulators)
self.assertTrue(adam_optimizer.get_moment2_str() in accumulators)
moment1_acc = accumulators[adam_optimizer.get_moment1_str()]
......@@ -354,7 +354,7 @@ class TestAdamaxOptimizer(unittest.TestCase):
# Check accumulators
accumulators = adamax_optimizer.get_accumulators()
self.assertEqual(len(accumulators), 2)
self.assertEqual(len(accumulators), 3)
self.assertTrue(adamax_optimizer.get_moment_str() in accumulators)
self.assertTrue(adamax_optimizer.get_inf_norm_str() in accumulators)
moment_acc = accumulators[adamax_optimizer.get_moment_str()]
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
import numpy as np
from threading import Thread
def feed_data(feed_queue, inputs):
for in_data in inputs:
feed_queue.push(in_data)
class TestPyReader(unittest.TestCase):
def setUp(self):
self.capacity = 10
self.batch_size_min = 10
self.batch_size_max = 20
self.shapes = [(-1, 3, 2, 1), (-1, 1)]
self.lod_levels = [0, 0]
self.dtypes = ['float32', 'int64']
self.iterations = 20
def test_single_thread_main(self):
self.main(use_thread=False)
def test_multiple_thread_main(self):
self.main(use_thread=True)
def main(self, use_thread=False):
with fluid.program_guard(fluid.Program(), fluid.Program()):
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
executor = fluid.Executor(place)
data_file, feed_queue = fluid.layers.py_reader(
capacity=self.capacity,
dtypes=self.dtypes,
lod_levels=self.lod_levels,
shapes=self.shapes)
read_out_data = fluid.layers.read_file(data_file)
self.inputs = []
for i in range(self.iterations):
in_data = fluid.LoDTensorArray()
batch_size = np.random.random_integers(self.batch_size_min,
self.batch_size_max)
for shape, dtype in zip(self.shapes, self.dtypes):
next_data = np.random.uniform(
low=0, high=1000,
size=(batch_size, ) + shape[1:]).astype(dtype)
in_data.append(executor.as_lodtensor(next_data))
self.inputs.append(in_data)
executor.run(fluid.default_startup_program())
self.outputs = []
if use_thread:
thread = Thread(
target=feed_data, args=(feed_queue, self.inputs))
thread.start()
for in_data in self.inputs:
self.outputs.append(
executor.run(fetch_list=list(read_out_data)))
else:
for in_data in self.inputs:
feed_queue.push(in_data)
self.outputs.append(
executor.run(fetch_list=list(read_out_data)))
feed_queue.close()
self.validate()
def validate(self):
self.assertEqual(len(self.inputs), len(self.outputs))
for in_data_list, out_data_list in zip(self.inputs, self.outputs):
self.assertEqual(len(in_data_list), len(out_data_list))
in_data_list_np = [
np.array(in_lod_tensor) for in_lod_tensor in in_data_list
]
for in_data, out_data in zip(in_data_list_np, out_data_list):
self.assertTrue((in_data == out_data).all())
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
import paddle.fluid.core as core
import numpy as np
import threading
import multiprocessing
import os
def as_tensor(np_array_or_tensor, place=None):
if isinstance(np_array_or_tensor, fluid.LoDTensor):
return np_array_or_tensor
if place is None:
place = fluid.CPUPlace()
tensor = fluid.LoDTensor()
tensor.set(np_array_or_tensor, place)
return tensor
def as_numpy(tensor_or_numpy):
return tensor_or_numpy if isinstance(
tensor_or_numpy, np.ndarray) else np.array(tensor_or_numpy)
def feed_data(feed_queue, reader):
data_generator = reader()
while True:
data = next(data_generator, None)
if data is None or not feed_queue.push(data):
break
def simple_fc_net(in_size,
class_num,
hidden_sizes,
batch_size,
queue_capacity,
use_double_buffer=False):
reader, feed_queue = fluid.layers.py_reader(
capacity=queue_capacity,
shapes=[[-1, in_size], [-1, 1]],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
reader = fluid.layers.batch(reader, batch_size=batch_size)
if use_double_buffer:
reader = fluid.layers.double_buffer(reader)
in_data, label = fluid.layers.read_file(reader)
hidden = in_data
for hidden_size in hidden_sizes:
hidden = fluid.layers.fc(
hidden,
size=hidden_size,
act='tanh',
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=1.0)))
predict_label = fluid.layers.fc(hidden, size=class_num, act='softmax')
loss = fluid.layers.mean(
fluid.layers.cross_entropy(
input=predict_label, label=label))
optimizer = fluid.optimizer.Adam()
optimizer.minimize(loss)
return in_data, label, loss, optimizer, feed_queue
class TestPyReaderUsingExecutor(unittest.TestCase):
def setUp(self):
self.in_size = 1000
self.hidden_sizes = [50, 30, 20]
self.class_num = 10
self.batch_size = 32
self.iterations = 10
self.queue_capacity = 50
def test(self):
for use_cuda in [False, True]:
for use_parallel_executor in [False, True]:
for use_double_buffer in [False, True]:
print('Test Parameters:'),
print({
'use_cuda': use_cuda,
'use_parallel_executor': use_parallel_executor,
'use_double_buffer': use_double_buffer
})
self.main(use_cuda, use_parallel_executor,
use_double_buffer)
def random_reader(self):
def reader():
self.inputs = []
cnt = 0
while True:
tensors = fluid.LoDTensorArray()
in_data = np.random.uniform(
low=0, high=1, size=(1, self.in_size)).astype('float32')
tensors.append(as_tensor(in_data))
label = np.random.random_integers(
low=0, high=self.class_num - 1, size=(1, 1)).astype('int64')
tensors.append(as_tensor(label))
if cnt < self.iterations * self.batch_size * self.batch_size_times:
if cnt % (self.batch_size * self.batch_size_times) == 0:
self.inputs.append([in_data, label])
else:
self.inputs[-1][0] = np.concatenate(
(self.inputs[-1][0], in_data), axis=0)
self.inputs[-1][1] = np.concatenate(
(self.inputs[-1][1], label), axis=0)
elif not self.use_double_buffer:
break
yield tensors
cnt += 1
yield None
return reader
def main(self,
use_cuda=True,
use_parallel_executor=False,
use_double_buffer=False):
assert not use_cuda or use_cuda and core.is_compiled_with_cuda()
self.use_cuda = use_cuda
self.use_parallel_executor = use_parallel_executor
self.use_double_buffer = use_double_buffer
startup_program = fluid.Program()
main_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
in_data, label, loss, optimizer, feed_queue = simple_fc_net(
in_size=self.in_size,
class_num=self.class_num,
hidden_sizes=self.hidden_sizes,
batch_size=self.batch_size,
queue_capacity=self.queue_capacity,
use_double_buffer=self.use_double_buffer)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
startup_exe = fluid.Executor(place)
startup_exe.run(startup_program)
if use_parallel_executor:
main_exe = fluid.ParallelExecutor(use_cuda, loss_name=loss.name)
if use_cuda:
self.batch_size_times = core.get_cuda_device_count()
else:
self.batch_size_times = int(
os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
else:
main_exe = startup_exe
self.batch_size_times = 1
reader = self.random_reader()
thread = threading.Thread(
target=feed_data, args=(feed_queue, reader))
thread.start()
self.outputs = []
for _ in range(self.iterations):
fetches = main_exe.run(fetch_list=[in_data.name, label.name])
fetches = [as_numpy(fetch) for fetch in fetches]
self.outputs.append(fetches)
feed_queue.close()
self.validate()
def validate(self):
self.assertEqual(len(self.inputs), len(self.outputs))
for batch_in, batch_out in zip(self.inputs, self.outputs):
self.assertEqual(len(batch_in), len(batch_out))
if self.use_parallel_executor and not self.use_double_buffer:
self.validate_unordered_batch(batch_in, batch_out)
else:
for in_data, out_data in zip(batch_in, batch_out):
self.assertEqual(in_data.shape, out_data.shape)
if not self.use_parallel_executor:
self.assertTrue((in_data == out_data).all())
def validate_unordered_batch(self, batch_in, batch_out):
out_index_left_set = set(range(self.batch_size * self.batch_size_times))
mapping_num = 0
for i in range(self.batch_size * self.batch_size_times):
for j in out_index_left_set:
flag = True
for k in range(len(batch_in)):
in_data = batch_in[k][i]
out_data = batch_out[k][j]
if (in_data != out_data).any():
flag = False
break
if flag:
out_index_left_set.remove(j)
mapping_num += 1
break
self.assertEqual(mapping_num, self.batch_size * self.batch_size_times)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import paddle.v2 as paddle
import numpy as np
import unittest
class TestReaderReset(unittest.TestCase):
def prepare_data(self):
def fake_data_generator():
for n in xrange(self.total_ins_num):
yield np.ones(self.ins_shape) * n, n
# Prepare data
with fluid.program_guard(fluid.Program(), fluid.Program()):
reader = paddle.batch(fake_data_generator, batch_size=1)
feeder = fluid.DataFeeder(
feed_list=[
fluid.layers.data(
name='data', shape=[3], dtype='float32'),
fluid.layers.data(
name='label', shape=[1], dtype='int64'),
],
place=fluid.CPUPlace())
fluid.recordio_writer.convert_reader_to_recordio_file(
self.data_file_name, reader, feeder)
def setUp(self):
self.use_cuda = fluid.core.is_compiled_with_cuda()
self.data_file_name = './reader_reset_test.recordio'
self.ins_shape = [3]
self.batch_size = 5
self.total_ins_num = self.batch_size * 20
self.test_pass_num = 100
self.prepare_data()
def main(self, with_double_buffer):
main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
data_reader_handle = fluid.layers.io.open_files(
filenames=[self.data_file_name],
shapes=[[-1] + self.ins_shape, [-1, 1]],
lod_levels=[0, 0],
dtypes=['float32', 'int64'],
thread_num=1,
pass_num=1)
data_reader = fluid.layers.io.batch(data_reader_handle,
self.batch_size)
if with_double_buffer:
data_reader = fluid.layers.double_buffer(data_reader)
image, label = fluid.layers.read_file(data_reader)
fetch_list = [image.name, label.name]
place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)
build_strategy = fluid.BuildStrategy()
if with_double_buffer:
build_strategy.enable_data_balance = True
exec_strategy = fluid.ExecutionStrategy()
parallel_exe = fluid.ParallelExecutor(
use_cuda=self.use_cuda,
main_program=main_prog,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
data_appeared = [False] * self.total_ins_num
pass_count = 0
while (True):
try:
data_val, label_val = parallel_exe.run(fetch_list,
return_numpy=True)
ins_num = data_val.shape[0]
broadcasted_label = np.ones((ins_num, ) + tuple(
self.ins_shape)) * label_val.reshape((ins_num, 1))
self.assertTrue((data_val == broadcasted_label).all())
for l in label_val:
self.assertFalse(data_appeared[l[0]])
data_appeared[l[0]] = True
except fluid.core.EOFException:
pass_count += 1
if with_double_buffer:
data_appeared = data_appeared[:-parallel_exe.device_count *
self.batch_size]
for i in data_appeared:
self.assertTrue(i)
if pass_count < self.test_pass_num:
data_appeared = [False] * self.total_ins_num
data_reader_handle.reset()
else:
break
def test_all(self):
self.main(with_double_buffer=False)
self.main(with_double_buffer=True)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
# Correct: General.
class TestSqueezeOp(OpTest):
def setUp(self):
self.op_type = "squeeze"
self.init_test_case()
self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
self.init_attrs()
self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
def init_test_case(self):
self.ori_shape = (1, 3, 1, 5)
self.axes = (0, 2)
self.new_shape = (3, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": False}
# Correct: There is a minus (negative) axis.
class TestSqueezeOp1(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (1, 3, 1, 5)
self.axes = (0, -2)
self.new_shape = (3, 5)
# Correct: No axes input.
class TestSqueezeOp2(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (1, 3, 1, 5)
self.axes = ()
self.new_shape = (3, 5)
# Correct: Only part of the axes are squeezed.
class TestSqueezeOp3(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 1, 5, 1, 4, 1)
self.axes = (1, -1)
self.new_shape = (3, 5, 1, 4)
# Correct: Inplace.
class TestSqueezeOpInplace1(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (1, 3, 1, 5)
self.axes = (0, 2)
self.new_shape = (3, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
# Correct: Inplace. There is mins axis.
class TestSqueezeOpInplace2(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (1, 3, 1, 5)
self.axes = (0, -2)
self.new_shape = (3, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
# Correct: Inplace. No axes input.
class TestSqueezeOpInplace3(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (1, 3, 1, 5)
self.axes = ()
self.new_shape = (3, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
# Correct: Inplace. Only part of the axes are squeezed.
class TestSqueezeOpInplace4(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 1, 5, 1, 4, 1)
self.axes = (1, -1)
self.new_shape = (3, 5, 1, 4)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
# Correct: General.
class TestUnsqueezeOp(OpTest):
def setUp(self):
self.init_test_case()
self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
self.init_attrs()
self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
def init_test_case(self):
self.ori_shape = (3, 5)
self.axes = (1, 2)
self.new_shape = (3, 1, 1, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": False}
# Correct: Single input index.
class TestUnsqueezeOp1(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 5)
self.axes = (-1, )
self.new_shape = (3, 5, 1)
# Correct: Mixed input axis.
class TestUnsqueezeOp2(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 5)
self.axes = (0, -1)
self.new_shape = (1, 3, 5, 1)
# Correct: There are duplicated axes.
class TestUnsqueezeOp3(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 2, 5)
self.axes = (0, 3, 3)
self.new_shape = (1, 3, 2, 1, 1, 5)
# Correct: Reversed axes.
class TestUnsqueezeOp4(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 2, 5)
self.axes = (3, 1, 1)
self.new_shape = (3, 1, 1, 2, 5, 1)
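As the expected shapes in the duplicated-axis and reversed-axes cases above imply, the axes are applied one at a time: each entry inserts a size-1 dimension into the shape produced by the previous insertion, which is why (0, 3, 3) maps (3, 2, 5) to (1, 3, 2, 1, 1, 5).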
# Correct: Inplace.
class TestUnsqueezeOpInplace1(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 5)
self.axes = (0, 2)
self.new_shape = (1, 3, 1, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
# Correct: Inplace. There is a minus (negative) index.
class TestUnsqueezeOpInplace2(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 5)
self.axes = (0, -2)
self.new_shape = (1, 3, 1, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
# Correct: Inplace. There are duplicated axes.
class TestUnsqueezeOpInplace3(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 2, 5)
self.axes = (0, 3, 3)
self.new_shape = (1, 3, 2, 1, 1, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
if __name__ == "__main__":
unittest.main()
This diff has been collapsed.
......@@ -377,11 +377,6 @@ class DistributeTranspiler(object):
# append it into the sub program.
global_ops = []
# HACK: optimization global ops only used to scale beta1 and beta2
# replace it with dependency engine.
for op in self.optimize_ops:
if self._is_adam_connected_op(op):
global_ops.append(op)
def __append_optimize_op__(op, block, grad_to_block_id, merged_var,
lr_ops):
......@@ -1289,22 +1284,16 @@ class DistributeTranspiler(object):
# If one op's input is another op's output or
# one op's output is another op's input, we say
# the two operators are connected.
def _append_inname_remove_beta(varname_list):
def _append_inname(varname_list):
op_input_names = []
for in_name in varname_list:
# HACK: remove beta1 and beta2 to avoid connecting
# all the ops.
if in_name.startswith("beta2_pow_acc") or \
in_name.startswith("beta1_pow_acc"):
continue
else:
op_input_names.append(in_name)
op_input_names.append(in_name)
return op_input_names
op1_input_names = _append_inname_remove_beta(op1.desc.input_arg_names())
op1_input_names = _append_inname(op1.desc.input_arg_names())
op1_output_names = op1.desc.output_arg_names()
op2_input_names = _append_inname_remove_beta(op2.desc.input_arg_names())
op2_input_names = _append_inname(op2.desc.input_arg_names())
op2_output_names = op2.desc.output_arg_names()
if set(op1_output_names) & set(op2_input_names) or \
......@@ -1413,7 +1402,7 @@ class DistributeTranspiler(object):
def _get_optimize_pass(self):
"""
Get optimizer operators, paramters and gradients from origin_program
Get optimizer operators, parameters and gradients from origin_program
Returns:
opt_ops (list): optimize operators.
params_grads (dict): parameter->gradient.
......@@ -1436,20 +1425,6 @@ class DistributeTranspiler(object):
origin_var_dict[param_name],
origin_var_dict[input_name]
])
elif self._is_adam_connected_op(op):
opt_ops.append(op)
else:
pass
return opt_ops, params_grads
def _is_adam_connected_op(self, op):
"""
A hack function to determinate whether the input operator
is connected to optimize operator.
"""
if op.type == "scale":
for in_name in op.input_arg_names:
if in_name.startswith("beta1_pow_acc") or \
in_name.startswith("beta2_pow_acc"):
return True
return False
......@@ -42,12 +42,12 @@ def get_patch():
def is_taged():
try:
cmd = ['git', 'describe', '--exact-match', '--tags']
cmd = ['git', 'describe', '--exact-match', '--tags', 'HEAD', '2>/dev/null']
git_tag = subprocess.Popen(cmd, stdout = subprocess.PIPE).communicate()[0].strip()
except:
return False
if git_tag.replace('v', '') == '@PADDLE_VERSION@':
if str(git_tag).replace('v', '') == '@PADDLE_VERSION@':
return True
else:
return False
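Note that '2>/dev/null' is passed to git as a literal argument here rather than being interpreted as a shell redirection, because Popen is called without shell=True. A sketch of the presumable intent, silencing stderr only (not part of this diff):

import subprocess

cmd = ['git', 'describe', '--exact-match', '--tags', 'HEAD']
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
git_tag = proc.communicate()[0].strip()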
......