提交 a7fa2051 编写于 作者: G guosheng

Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into add-multiBatch-chunkEval-dev

......@@ -2,8 +2,8 @@
[![Build Status](https://travis-ci.org/PaddlePaddle/Paddle.svg?branch=develop)](https://travis-ci.org/PaddlePaddle/Paddle)
[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://doc.paddlepaddle.org/develop/doc/)
[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://doc.paddlepaddle.org/develop/doc_cn/)
[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://www.paddlepaddle.org/docs/develop/documentation/en/getstarted/index_en.html)
[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://www.paddlepaddle.org/docs/develop/documentation/zh/getstarted/index_cn.html)
[![Coverage Status](https://coveralls.io/repos/github/PaddlePaddle/Paddle/badge.svg?branch=develop)](https://coveralls.io/github/PaddlePaddle/Paddle?branch=develop)
[![Release](https://img.shields.io/github/release/PaddlePaddle/Paddle.svg)](https://github.com/PaddlePaddle/Paddle/releases)
[![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE)
......
......@@ -14,7 +14,7 @@ limitations under the License. */
#include "error.h"
const char* paddle_error_string(paddle_error err) {
extern "C" const char* paddle_error_string(paddle_error err) {
switch (err) {
case kPD_NULLPTR:
return "nullptr error";
......
......@@ -29,9 +29,17 @@ typedef enum {
kPD_UNDEFINED_ERROR = -1,
} paddle_error;
#ifdef __cplusplus
extern "C" {
#endif
/**
* Error string for Paddle API.
*/
PD_API const char* paddle_error_string(paddle_error err);
#ifdef __cplusplus
}
#endif
#endif
......@@ -430,14 +430,14 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
std::vector<std::unique_ptr<OpDescBind>> op_grads;
if ((*it)->Type() == "recurrent" || (*it)->Type() == "while") {
int step_block_idx = (*it)->GetBlockAttr("step_block");
int step_block_idx = (*it)->GetBlockAttr("sub_block");
BlockDescBind* backward_block = CreateStepBlock(
program_desc, no_grad_vars, grad_to_var, step_block_idx);
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
} else if ((*it)->Type() == "conditional_block") {
BlockDescBind* backward_block =
CreateStepBlock(program_desc, no_grad_vars, grad_to_var,
(*it)->GetBlockAttr("block"));
(*it)->GetBlockAttr("sub_block"));
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
} else {
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var);
......
......@@ -65,7 +65,7 @@ class ConditionalBlockOp : public ConditionalOp {
scopes->front() = &scope.NewScope();
auto &cur_scope = *scopes->front();
auto *block = Attr<framework::BlockDescBind *>("block");
auto *block = Attr<framework::BlockDescBind *>("sub_block");
framework::Executor exec(dev_ctx);
exec.Run(*block->Program(), &cur_scope, block->ID(), false);
}
......@@ -88,7 +88,7 @@ class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"unify the conditional block, rnn and while op, the type of "
"scope is std::vector<Scope*>");
AddAttr<framework::BlockDescBind *>(
"block", "The step block of conditional block operator");
"sub_block", "The step block of conditional block operator");
AddComment(R"DOC(Conditional block operator
Run the sub-block if X is not empty. Params is the other inputs and Out is the
......@@ -117,7 +117,7 @@ class ConditionalBlockGradOp : public ConditionalOp {
auto &scopes = scope_var->Get<std::vector<framework::Scope *>>();
framework::Scope &cur_scope = *scopes[0];
auto *block = Attr<framework::BlockDescBind *>("block");
auto *block = Attr<framework::BlockDescBind *>("sub_block");
framework::Executor exec(dev_ctx);
exec.Run(*block->Program(), &cur_scope, block->ID(), false);
......@@ -181,7 +181,7 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker {
grad_op->SetInput("Scope", Output("Scope"));
grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
grad_op->SetOutput(framework::GradVarName("Params"), InputGrad("Params"));
grad_op->SetBlockAttr("block", *this->grad_block_[0]);
grad_op->SetBlockAttr("sub_block", *this->grad_block_[0]);
return std::unique_ptr<framework::OpDescBind>(grad_op);
}
};
......
......@@ -273,6 +273,13 @@ void set_constant_with_place<platform::GPUPlace>(
TensorSetConstantGPU(context, tensor, value));
}
template <>
void set_constant_with_place<platform::CudnnPlace>(
const platform::DeviceContext& context, framework::Tensor* tensor,
float value) {
set_constant_with_place<platform::GPUPlace>(context, tensor, value);
}
template struct RowwiseAdd<platform::CUDADeviceContext, float>;
template struct RowwiseAdd<platform::CUDADeviceContext, double>;
template struct ColwiseSum<platform::CUDADeviceContext, float>;
......
......@@ -25,7 +25,7 @@ constexpr char kOutputs[] = "outputs";
constexpr char kStepScopes[] = "step_scopes";
constexpr char kExStates[] = "ex_states";
constexpr char kStates[] = "states";
constexpr char kStepBlock[] = "step_block";
constexpr char kStepBlock[] = "sub_block";
constexpr char kReverse[] = "reverse";
constexpr char kIsTrain[] = "is_train";
#define GRAD_SUFFIX "@GRAD"
......
......@@ -25,7 +25,7 @@ namespace operators {
using StepScopeVar = std::vector<framework::Scope *>;
using LoDTensor = framework::LoDTensor;
constexpr char kStepBlock[] = "step_block";
constexpr char kStepBlock[] = "sub_block";
constexpr char kCondition[] = "Condition";
constexpr char kStepScopes[] = "StepScopes";
constexpr char kParameters[] = "X";
......
......@@ -125,6 +125,22 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
cudaStream_t CUDADeviceContext::stream() const { return stream_; }
CudnnDeviceContext::CudnnDeviceContext(CudnnPlace place)
: CUDADeviceContext(place), place_(place) {
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream()));
}
CudnnDeviceContext::~CudnnDeviceContext() {
SetDeviceId(place_.device);
Wait();
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
}
Place CudnnDeviceContext::GetPlace() const { return CudnnPlace(); }
cudnnHandle_t CudnnDeviceContext::cudnn_handle() const { return cudnn_handle_; }
#endif
} // namespace platform
......
......@@ -86,6 +86,22 @@ class CUDADeviceContext : public DeviceContext {
cublasHandle_t cublas_handle_;
};
class CudnnDeviceContext : public CUDADeviceContext {
public:
explicit CudnnDeviceContext(CudnnPlace place);
virtual ~CudnnDeviceContext();
/*! \brief Return place in the device context. */
Place GetPlace() const final;
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t cudnn_handle() const;
private:
cudnnHandle_t cudnn_handle_;
CudnnPlace place_;
};
#endif
} // namespace platform
......
......@@ -46,3 +46,19 @@ TEST(Device, CUDADeviceContext) {
delete device_context;
}
}
TEST(Device, CudnnDeviceContext) {
using paddle::platform::CudnnDeviceContext;
using paddle::platform::CudnnPlace;
if (paddle::platform::dynload::HasCUDNN()) {
int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; ++i) {
CudnnDeviceContext* device_context =
new CudnnDeviceContext(CudnnPlace(i));
cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
ASSERT_NE(nullptr, cudnn_handle);
ASSERT_NE(nullptr, device_context->stream());
delete device_context;
}
}
}
......@@ -43,6 +43,11 @@ struct GPUPlace {
int device;
};
struct CudnnPlace : public GPUPlace {
CudnnPlace() : GPUPlace() {}
explicit CudnnPlace(int d) : GPUPlace(d) {}
};
struct IsGPUPlace : public boost::static_visitor<bool> {
bool operator()(const CPUPlace &) const { return false; }
bool operator()(const GPUPlace &gpu) const { return true; }
......@@ -52,7 +57,7 @@ struct IsGPUPlace : public boost::static_visitor<bool> {
// should be less equal than 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
#define NUM_PLACE_TYPE_LIMIT_IN_BIT 4
typedef boost::variant<GPUPlace, CPUPlace> Place;
typedef boost::variant<CudnnPlace, GPUPlace, CPUPlace> Place;
// static check number of place types is less equal than
// 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
......
......@@ -282,6 +282,23 @@ All parameter, weight, gradient are variables in Paddle.
}
return ret_values;
});
m.def("get_grad_op_descs",
[](const OpDescBind &op_desc,
const std::unordered_set<std::string> &no_grad_set,
std::unordered_map<std::string, std::string> &grad_to_var,
const std::vector<BlockDescBind *> &grad_sub_block) {
std::vector<std::unique_ptr<OpDescBind>> grad_op_descs =
framework::OpInfoMap::Instance()
.Get(op_desc.Type())
.GradOpMaker()(op_desc, no_grad_set, &grad_to_var,
grad_sub_block);
std::vector<OpDescBind *> grad_op_desc_ptrs(grad_op_descs.size());
std::transform(
grad_op_descs.begin(), grad_op_descs.end(),
grad_op_desc_ptrs.begin(),
[](std::unique_ptr<OpDescBind> &p) { return p.release(); });
return grad_op_desc_ptrs;
});
m.def("prune", [](const ProgramDescBind &origin,
const std::vector<std::array<size_t, 2>> &targets) {
ProgramDescBind prog_with_targets(origin);
......
import ops
from ops import *
import nn
from nn import *
import io
from io import *
import tensor
from tensor import *
import control_flow
from control_flow import *
__all__ = []
__all__ += nn.__all__
__all__ += io.__all__
__all__ += tensor.__all__
__all__ += control_flow.__all__
__all__ += ops.__all__
from ..layer_helper import LayerHelper, unique_name
from ..framework import Program, Variable, Operator
from .. import core
from tensor import assign, fill_constant
import contextlib
import proto.framework_pb2 as framework_pb2
import core
from framework import OpProtoHolder, Variable, Program, Operator
from initializer import Constant, Normal, Xavier, Initializer
from paddle.v2.fluid.layer_helper import LayerHelper, unique_name
from registry import register_layer
from param_attr import ParamAttr
__all__ = [
'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat',
'StaticRNN', 'cast', 'sequence_conv', 'sequence_pool', 'sums', 'cos_sim',
'batch_norm', 'accuracy', 'split_lod_tensor', 'While'
]
_REGISTER_LAYER_FROM_OPS = [
'mean', 'mul', 'dropout', 'reshape', 'sigmoid', 'scale', 'transpose',
'sigmoid_cross_entropy_with_logits', 'elementwise_add', 'elementwise_div',
'elementwise_sub', 'elementwise_mul', 'clip', 'abs'
'split_lod_tensor', 'merge_lod_tensor', 'BlockGuard', 'StaticRNNGuard',
'StaticRNNMemoryLink', 'WhileGuard', 'While', 'lod_rank_table',
'max_sequence_len', 'topk', 'lod_tensor_to_array', 'array_to_lod_tensor',
'increment', 'array_write', 'create_array', 'less_than', 'array_read',
'shrink_memory', 'array_length', 'IfElse', 'DynamicRNN', 'ConditionalBlock',
'StaticRNN'
]
for _OP in set(_REGISTER_LAYER_FROM_OPS):
globals()[_OP] = register_layer(_OP)
__all__.append(_OP)
def fc(input,
size,
num_flatten_dims=1,
param_attr=None,
bias_attr=None,
act=None,
name=None,
main_program=None,
startup_program=None):
"""
Fully Connected Layer.
Args:
input: The input tensor to the function
size: The size of the layer
num_flatten_dims: Number of columns in input
param_attr: The parameters/weights to the FC Layer
param_initializer: Initializer used for the weight/parameter. If None, XavierInitializer() is used
bias_attr: The bias parameter for the FC layer
bias_initializer: Initializer used for the bias. If None, then ConstantInitializer() is used
act: Activation to be applied to the output of FC layer
name: Name/alias of the function
main_program: Name of the main program that calls this
startup_program: Name of the startup program
This function can take in multiple inputs and performs the Fully Connected
function (linear transformation) on top of each of them.
So for input x, the output will be : Wx + b. Where W is the parameter,
b the bias and x is the input.
The function also applies an activation (non-linearity) on top of the
output, if activation is passed in the input.
All the input variables of this function are passed in as local variables
to the LayerHelper constructor.
"""
helper = LayerHelper('fc', **locals())
dtype = helper.input_dtype()
mul_results = []
for input_var, param_attr in helper.iter_inputs_and_params():
input_shape = input_var.shape
param_shape = [
reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1)
] + [size]
w = helper.create_parameter(
attr=param_attr, shape=param_shape, dtype=dtype, is_bias=False)
tmp = helper.create_tmp_variable(dtype)
helper.append_op(
type="mul",
inputs={
"X": input_var,
"Y": w,
},
outputs={"Out": tmp},
attrs={'x_num_col_dims': num_flatten_dims,
'y_num_col_dims': 1})
mul_results.append(tmp)
# sum
if len(mul_results) == 1:
pre_bias = mul_results[0]
else:
pre_bias = helper.create_tmp_variable(dtype)
helper.append_op(
type="sum", inputs={"X": mul_results}, outputs={"Out": pre_bias})
# add bias
pre_activation = helper.append_bias_op(pre_bias)
# add activation
return helper.append_activation(pre_activation)
def embedding(input,
size,
is_sparse=False,
param_attr=None,
dtype='float32',
main_program=None,
startup_program=None):
"""
Embedding Layer.
Args:
param_initializer:
input: The input to the function
size: The size of the layer
is_sparse: A flag that decleares whether the input is sparse
param_attr: Parameters for this layer
dtype: The type of data : float32, float_16, int etc
main_program: Name of the main program that calls this
startup_program: Name of the startup program
This function can take in the input (which is a vector of IDs) and
performs a lookup in the lookup_table using these IDs, to result into
the embedding of each ID in the input.
All the input variables of this function are passed in as local variables
to the LayerHelper constructor.
"""
helper = LayerHelper('embedding', **locals())
w = helper.create_parameter(
attr=helper.param_attr, shape=size, dtype=dtype, is_bias=False)
tmp = helper.create_tmp_variable(dtype)
helper.append_op(
type='lookup_table',
inputs={'Ids': input,
'W': w},
outputs={'Out': tmp},
attrs={'is_sparse': is_sparse})
return tmp
# TODO(qijun): expose H0 and C0
def dynamic_lstm(input,
size,
param_attr=None,
bias_attr=None,
use_peepholes=True,
is_reverse=False,
gate_activation='sigmoid',
cell_activation='tanh',
candidate_activation='tanh',
dtype='float32',
main_program=None,
startup_program=None):
helper = LayerHelper('lstm', **locals())
size = size / 4
weight = helper.create_parameter(
attr=helper.param_attr, shape=[size, 4 * size], dtype=dtype)
bias_size = [1, 7 * size]
if not use_peepholes:
bias_size[1] = 4 * size
bias = helper.create_parameter(
attr=helper.bias_attr, shape=bias_size, dtype=dtype, is_bias=True)
hidden = helper.create_tmp_variable(dtype)
cell = helper.create_tmp_variable(dtype)
batch_gate = helper.create_tmp_variable(dtype)
batch_cell_pre_act = helper.create_tmp_variable(dtype)
helper.append_op(
type='lstm',
inputs={'Input': input,
'Weight': weight,
'Bias': bias},
outputs={
'Hidden': hidden,
'Cell': cell,
'BatchGate': batch_gate,
'BatchCellPreAct': batch_cell_pre_act
},
attrs={
'use_peepholes': use_peepholes,
'is_reverse': is_reverse,
'gate_activation': gate_activation,
'cell_activation': cell_activation,
'candidate_activation': candidate_activation
})
return hidden, cell
def gru_unit(input,
hidden,
size,
weight=None,
bias=None,
activation='tanh',
gate_activation='sigmoid',
main_program=None,
startup_program=None):
"""
GRUUnit Operator implements partial calculations of the GRU unit as following:
$$
update \ gate: u_t = actGate(xu_t + W_u * h_{t-1} + b_u) \\
reset \ gate: r_t = actGate(xr_t + W_r * h_{t-1} + b_r) \\
output \ candidate: {h}_t = actNode(xc_t + W_c * dot(r_t, h_{t-1}) + b_c) \\
output: h_t = dot((1 - u_t), h_{t-1}) + dot(u_t, {h}_t)
$$
which is same as one time step of GRU Operator.
@note To implement the complete GRU unit, fully-connected operator must be
used before to feed xu, xr and xc as the Input of GRUUnit operator.
TODO(ChunweiYan) add more document here
"""
activation_dict = dict(
identity=0,
sigmoid=1,
tanh=2,
relu=3, )
activation = activation_dict[activation]
gate_activation = activation_dict[gate_activation]
helper = LayerHelper('gru_unit', **locals())
dtype = helper.input_dtype()
size = size / 3
# create weight
if weight is None:
weight = helper.create_parameter(
attr=helper.param_attr, shape=[size, 3 * size], dtype=dtype)
# create bias
if bias is None:
bias_size = [1, 3 * size]
bias = helper.create_parameter(
attr=helper.bias_attr, shape=bias_size, dtype=dtype, is_bias=True)
gate = helper.create_tmp_variable(dtype)
reset_hidden_pre = helper.create_tmp_variable(dtype)
updated_hidden = helper.create_tmp_variable(dtype)
helper.append_op(
type='gru_unit',
inputs={'Input': input,
'HiddenPrev': hidden,
'Weight': weight},
outputs={
'Gate': gate,
'ResetHiddenPrev': reset_hidden_pre,
'Hidden': updated_hidden,
},
attrs={
'activation': 0,
'gate_activation': 1,
})
return updated_hidden, reset_hidden_pre, gate
def data(name,
shape,
append_batch_size=True,
dtype='float32',
lod_level=0,
type=core.VarDesc.VarType.LOD_TENSOR,
main_program=None,
startup_program=None,
stop_gradient=True):
"""
Data Layer.
Args:
name: The name/alias of the function
shape: Tuple declaring the shape.
append_batch_size: Whether or not to append the data as a batch.
dtype: The type of data : float32, float_16, int etc
type: The output type. By default it is LOD_TENSOR.
lod_level(int): The LoD Level. 0 means the input data is not a sequence.
main_program: Name of the main program that calls this
startup_program: Name of the startup program
stop_gradient: A boolean that mentions whether gradient should flow.
This function takes in input and based on whether data has
to be returned back as a minibatch, it creates the global variable using
the helper functions. The global variables can be accessed by all the
following operations and layers in the graph.
All the input variables of this function are passed in as local variables
to the LayerHelper constructor.
"""
helper = LayerHelper('data', **locals())
shape = list(shape)
for i in xrange(len(shape)):
if shape[i] is None:
shape[i] = -1
append_batch_size = False
elif shape[i] < 0:
append_batch_size = False
if append_batch_size:
shape = [-1] + shape # append batch size as -1
return helper.create_global_variable(
name=name,
shape=shape,
dtype=dtype,
type=type,
stop_gradient=stop_gradient,
lod_level=lod_level)
def create_tensor(dtype, name=None, main_program=None, startup_program=None):
helper = LayerHelper("create_tensor", **locals())
return helper.create_variable(name=helper.name, dtype=dtype)
def cast(x, dtype, main_program=None):
"""
This function takes in the input with input_dtype
and casts it to the output_dtype as the output.
"""
helper = LayerHelper('cast', **locals())
out = helper.create_tmp_variable(dtype=dtype)
helper.append_op(
type='cast',
inputs={'X': [x]},
outputs={'Out': [out]},
attrs={'in_dtype': x.dtype,
'out_dtype': out.dtype})
return out
def concat(input, axis, main_program=None, startup_program=None):
"""
This function concats the input along the axis mentioned
and returns that as the output.
"""
helper = LayerHelper('concat', **locals())
out = helper.create_tmp_variable(dtype=helper.input_dtype())
helper.append_op(
type='concat',
inputs={'X': input},
outputs={'Out': [out]},
attrs={'axis': axis})
return out
def sums(input, out=None, main_program=None, startup_program=None):
"""
This function takes in the input and performs the sum operation on it
and returns that as the output.
"""
helper = LayerHelper('sum', **locals())
if out is None:
out = helper.create_tmp_variable(dtype=helper.input_dtype())
helper.append_op(type='sum', inputs={'X': input}, outputs={'Out': out})
return out
def linear_chain_crf(input,
label,
param_attr=None,
main_program=None,
startup_program=None):
helper = LayerHelper('linear_chain_crf', **locals())
size = input.shape[1]
transition = helper.create_parameter(
attr=helper.param_attr,
shape=[size + 2, size],
dtype=helper.input_dtype())
alpha = helper.create_tmp_variable(dtype=helper.input_dtype())
emission_exps = helper.create_tmp_variable(dtype=helper.input_dtype())
transition_exps = helper.create_tmp_variable(dtype=helper.input_dtype())
log_likelihood = helper.create_tmp_variable(dtype=helper.input_dtype())
helper.append_op(
type='linear_chain_crf',
inputs={"Emission": [input],
"Transition": transition,
"Label": label},
outputs={
"Alpha": [alpha],
"EmissionExps": [emission_exps],
"TransitionExps": transition_exps,
"LogLikelihood": log_likelihood
})
return log_likelihood
def crf_decoding(input,
param_attr,
label=None,
main_program=None,
startup_program=None):
helper = LayerHelper('crf_decoding', **locals())
transition = helper.get_parameter(param_attr.name)
viterbi_path = helper.create_tmp_variable(dtype=helper.input_dtype())
helper.append_op(
type='crf_decoding',
inputs={"Emission": [input],
"Transition": transition,
"Label": label},
outputs={"ViterbiPath": [viterbi_path]})
return viterbi_path
def assign(input, output, main_program=None, startup_program=None):
helper = LayerHelper('assign', **locals())
helper.append_op(
type='scale',
inputs={'X': [input]},
outputs={'Out': [output]},
attrs={'scale': 1.0})
return output
def split_lod_tensor(input,
mask,
......@@ -460,410 +54,6 @@ def merge_lod_tensor(in_true,
return out
def cos_sim(X, Y, **kwargs):
"""
This function performs the cosine similarity between two tensors
X and Y and returns that as the output.
"""
helper = LayerHelper('cos_sim', **kwargs)
out = helper.create_tmp_variable(dtype=X.dtype)
xnorm = helper.create_tmp_variable(dtype=X.dtype)
ynorm = helper.create_tmp_variable(dtype=X.dtype)
helper.append_op(
type='cos_sim',
inputs={'X': [X],
'Y': [Y]},
outputs={'Out': [out],
'XNorm': [xnorm],
'YNorm': [ynorm]})
return out
def cross_entropy(input, label, **kwargs):
"""
This function computes cross_entropy using the input and label.
"""
helper = LayerHelper('cross_entropy', **kwargs)
out = helper.create_tmp_variable(dtype=input.dtype)
helper.append_op(
type='cross_entropy',
inputs={'X': [input],
'Label': [label]},
outputs={'Y': [out]},
attrs=kwargs)
return out
def square_error_cost(input, label, **kwargs):
"""
This functions returns the squared error cost using the input and label.
The output is appending the op to do the above.
"""
helper = LayerHelper('square_error_cost', **kwargs)
minus_out = helper.create_tmp_variable(dtype=input.dtype)
helper.append_op(
type='elementwise_sub',
inputs={'X': [input],
'Y': [label]},
outputs={'Out': [minus_out]})
square_out = helper.create_tmp_variable(dtype=input.dtype)
helper.append_op(
type='square', inputs={'X': [minus_out]}, outputs={'Y': [square_out]})
return square_out
def accuracy(input, label, k=1, correct=None, total=None, **kwargs):
"""
This function computes the accuracy using the input and label.
The output is the top_k inputs and their indices.
"""
helper = LayerHelper("accuracy", **kwargs)
topk_out = helper.create_tmp_variable(dtype=input.dtype)
topk_indices = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="top_k",
inputs={"X": [input]},
outputs={"Out": [topk_out],
"Indices": [topk_indices]},
attrs={"k": k})
acc_out = helper.create_tmp_variable(dtype="float32")
if correct is None:
correct = helper.create_tmp_variable(dtype="int64")
if total is None:
total = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="accuracy",
inputs={
"Out": [topk_out],
"Indices": [topk_indices],
"Label": [label]
},
outputs={
"Accuracy": [acc_out],
"Correct": [correct],
"Total": [total],
})
return acc_out
def chunk_eval(input,
label,
chunk_scheme,
num_chunk_types,
excluded_chunk_types=None,
**kwargs):
"""
This function computes and outputs the precision, recall and
F1-score of chunk detection.
"""
helper = LayerHelper("chunk_eval", **kwargs)
# prepare output
precision = helper.create_tmp_variable(dtype="float32")
recall = helper.create_tmp_variable(dtype="float32")
f1_score = helper.create_tmp_variable(dtype="float32")
num_infer_chunks = helper.create_tmp_variable(dtype="int64")
num_label_chunks = helper.create_tmp_variable(dtype="int64")
num_correct_chunks = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="chunk_eval",
inputs={"Inference": [input],
"Label": [label]},
outputs={
"Precision": [precision],
"Recall": [recall],
"F1-Score": [f1_score],
"NumInferChunks": [num_infer_chunks],
"NumLabelChunks": [num_label_chunks],
"NumCorrectChunks": [num_correct_chunks]
},
attrs={
"num_chunk_types": num_chunk_types,
'chunk_scheme': chunk_scheme,
'excluded_chunk_types': excluded_chunk_types or []
})
return precision, recall, f1_score, num_infer_chunks, num_label_chunks, num_correct_chunks
def sequence_conv(input,
num_filters,
filter_size=3,
filter_stride=1,
padding=None,
bias_attr=None,
param_attr=None,
act=None,
main_program=None,
startup_program=None):
"""
This function creates the op for sequence_conv, using the inputs and
other convolutional configurations for the filters and stride as given
in the input parameters to the function.
"""
# FIXME(dzh) : want to unify the argument of python layer
# function. So we ignore some unecessary attributes.
# such as, padding_trainable, context_start.
helper = LayerHelper('sequence_conv', **locals())
dtype = helper.input_dtype()
filter_shape = [filter_size * input.shape[1], num_filters]
filter_param = helper.create_parameter(
attr=helper.param_attr, shape=filter_shape, dtype=dtype)
pre_bias = helper.create_tmp_variable(dtype)
helper.append_op(
type='sequence_conv',
inputs={
'X': [input],
'Filter': [filter_param],
},
outputs={"Out": pre_bias},
attrs={
'contextStride': filter_stride,
'contextStart': -int(filter_size / 2),
'contextLength': filter_size
})
pre_act = helper.append_bias_op(pre_bias)
return helper.append_activation(pre_act)
def conv2d(input,
num_filters,
filter_size,
stride=None,
padding=None,
groups=None,
param_attr=None,
bias_attr=None,
act=None,
name=None,
main_program=None,
startup_program=None):
"""
This function creates the op for a 2-dimensional Convolution.
This is performed using the parameters of filters(size, dimensionality etc)
, stride and other configurations for a Convolution operation.
This funciton can also append an activation on top of the
conv-2d output, if mentioned in the input parameters.
"""
if stride is None:
stride = [1, 1]
helper = LayerHelper('conv2d', **locals())
dtype = helper.input_dtype()
num_channels = input.shape[1]
if groups is None:
num_filter_channels = num_channels
else:
if num_channels % groups != 0:
raise ValueError("num_channels must be divisible by groups.")
num_filter_channels = num_channels / groups
if isinstance(filter_size, int):
filter_size = [filter_size, filter_size]
if isinstance(stride, int):
stride = [stride, stride]
if isinstance(padding, int):
padding = [padding, padding]
input_shape = input.shape
filter_shape = [num_filters, num_filter_channels] + filter_size
def _get_default_param_initializer():
std = (2.0 / (filter_size[0]**2 * num_channels))**0.5
return Normal(0.0, std, 0)
filter_param = helper.create_parameter(
attr=helper.param_attr,
shape=filter_shape,
dtype=dtype,
default_initializer=_get_default_param_initializer())
pre_bias = helper.create_tmp_variable(dtype)
helper.append_op(
type='conv2d_cudnn',
inputs={
'Input': input,
'Filter': filter_param,
},
outputs={"Output": pre_bias},
attrs={'strides': stride,
'paddings': padding,
'groups': groups})
pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
return helper.append_activation(pre_act)
def sequence_pool(input, pool_type, **kwargs):
"""
This function add the operator for sequence pooling.
This is applied on top of the input using pool_type mentioned
in the parameters.
"""
helper = LayerHelper('sequence_pool', input=input, **kwargs)
dtype = helper.input_dtype()
pool_out = helper.create_tmp_variable(dtype)
max_index = helper.create_tmp_variable(dtype)
helper.append_op(
type="sequence_pool",
inputs={"X": input},
outputs={"Out": pool_out,
"MaxIndex": max_index},
attrs={"pooltype": pool_type.upper()})
return pool_out
def pool2d(input,
pool_size,
pool_type,
pool_stride=None,
pool_padding=None,
global_pooling=False,
main_program=None,
startup_program=None):
"""
This function adds the operator for pooling in 2 dimensions, using the
pooling configurations mentioned in input parameters.
"""
if pool_padding is None:
pool_padding = [0, 0]
if pool_stride is None:
pool_stride = [1, 1]
if pool_type not in ["max", "avg"]:
raise ValueError(
"Unknown pool_type: '%s'. It can only be 'max' or 'avg'.",
str(pool_type))
if isinstance(pool_size, int):
pool_size = [pool_size, pool_size]
if isinstance(pool_stride, int):
pool_stride = [pool_stride, pool_stride]
if isinstance(pool_padding, int):
pool_padding = [pool_padding, pool_padding]
helper = LayerHelper('pool2d', **locals())
dtype = helper.input_dtype()
pool_out = helper.create_tmp_variable(dtype)
helper.append_op(
type="pool2d",
inputs={"X": input},
outputs={"Out": pool_out},
attrs={
"pooling_type": pool_type,
"ksize": pool_size,
"global_pooling": global_pooling,
"strides": pool_stride,
"paddings": pool_padding
})
return pool_out
def batch_norm(input,
act=None,
is_test=False,
momentum=0.9,
epsilon=1e-05,
param_attr=None,
bias_attr=None,
data_layout='NCHW',
main_program=None,
startup_program=None):
"""
This function helps create an operator to implement
the BatchNorm layer using the configurations from the input parameters.
"""
helper = LayerHelper('batch_norm', **locals())
dtype = helper.input_dtype()
input_shape = input.shape
if data_layout == 'NCHW':
channel_num = input_shape[1]
else:
if data_layout == 'NHWC':
channel_num = input_shape[-1]
else:
raise ValueError("unsupported data layout:" + data_layout)
param_shape = [channel_num]
# create parameter
scale = helper.create_parameter(
attr=helper.param_attr,
shape=param_shape,
dtype=dtype,
default_initializer=Constant(1.0))
bias = helper.create_parameter(
attr=helper.param_attr, shape=param_shape, dtype=dtype, is_bias=True)
mean = helper.create_global_variable(
dtype=input.dtype, shape=param_shape, persistable=True)
helper.set_variable_initializer(var=mean, initializer=Constant(0.0))
variance = helper.create_global_variable(
dtype=input.dtype, shape=param_shape, persistable=True)
helper.set_variable_initializer(var=variance, initializer=Constant(1.0))
# create output
# mean and mean_out share the same memory
mean_out = mean
# variance and variance out share the same memory
variance_out = variance
saved_mean = helper.create_tmp_variable(dtype)
saved_variance = helper.create_tmp_variable(dtype)
batch_norm_out = helper.create_tmp_variable(dtype)
helper.append_op(
type="batch_norm",
inputs={
"X": input,
"Scale": scale,
"Bias": bias,
"Mean": mean,
"Variance": variance
},
outputs={
"Y": batch_norm_out,
"MeanOut": mean_out,
"VarianceOut": variance_out,
"SavedMean": saved_mean,
"SavedVariance": saved_variance
},
attrs={"momentum": momentum,
"epsilon": epsilon,
"is_test": is_test})
return helper.append_activation(batch_norm_out)
def beam_search_decode(ids, scores, main_program=None, startup_program=None):
helper = LayerHelper('beam_search_decode', **locals())
sentence_ids = helper.create_tmp_variable(dtype=ids.dtype)
sentence_scores = helper.create_tmp_variable(dtype=ids.dtype)
helper.append_op(
type="beam_search_decode",
inputs={"Ids": ids,
"Scores": scores},
outputs={
"SentenceIds": sentence_ids,
"SentenceScores": sentence_scores
})
return sentence_ids, sentence_scores
class BlockGuard(object):
"""
BlockGuard class.
......@@ -1136,7 +326,7 @@ class StaticRNN(object):
attrs={
'ex_states': pre_memories,
'states': memories,
'step_block': rnn_block
'sub_block': rnn_block
})
......@@ -1213,51 +403,7 @@ class While(object):
},
outputs={'Out': out_vars,
'StepScopes': [step_scope]},
attrs={'step_block': while_block})
def lstm(x,
c_pre_init,
hidden_dim,
forget_bias=None,
main_program=None,
startup_program=None):
"""
This function helps create an operator for the LSTM (Long Short Term
Memory) cell that can be used inside an RNN.
"""
helper = LayerHelper('lstm_unit', **locals())
rnn = StaticRNN()
with rnn.step():
c_pre = rnn.memory(init=c_pre_init)
x_t = rnn.step_input(x)
before_fc = concat(
input=[x_t, c_pre],
axis=1,
main_program=main_program,
startup_program=startup_program)
after_fc = fc(input=before_fc,
size=hidden_dim * 4,
main_program=main_program,
startup_program=startup_program)
dtype = x.dtype
c = helper.create_tmp_variable(dtype)
h = helper.create_tmp_variable(dtype)
helper.append_op(
type='lstm_unit',
inputs={"X": after_fc,
"C_prev": c_pre},
outputs={"C": c,
"H": h},
attrs={"forget_bias": forget_bias})
rnn.update_memory(c_pre, c)
rnn.output(h)
return rnn()
attrs={'sub_block': while_block})
def lod_rank_table(x, level=0, main_program=None):
......@@ -1337,72 +483,6 @@ def array_to_lod_tensor(x, table, main_program=None, startup_program=None):
return tmp
def fill_constant(shape,
dtype,
value,
out=None,
main_program=None,
startup_program=None):
"""
This function creates a tensor , with shape as mentioned in the input and
specified dtype and fills this up with a constant value that
comes in the input. It also sets the stop_gradient to be True.
"""
helper = LayerHelper("fill_constant", **locals())
if out is None:
out = helper.create_tmp_variable(dtype=dtype)
helper.append_op(
type='fill_constant',
inputs={},
outputs={'Out': [out]},
attrs={'shape': shape,
'dtype': out.dtype,
'value': float(value)})
out.stop_gradient = True
return out
def fill_constant_batch_size_like(input,
shape,
dtype,
value,
input_dim_idx=0,
output_dim_idx=0,
main_program=None,
startup_program=None):
helper = LayerHelper("fill_constant_batch_size_like", **locals())
out = helper.create_tmp_variable(dtype=dtype)
helper.append_op(
type='fill_constant_batch_size_like',
inputs={'Input': input},
outputs={'Out': [out]},
attrs={
'shape': shape,
'dtype': out.dtype,
'value': float(value),
'input_dim_idx': input_dim_idx,
'output_dim_idx': output_dim_idx
})
out.stop_gradient = True
return out
def ones(shape, dtype, main_program=None):
"""
This function performs the same function as fill_constant() declared above
with the constant value being 1.0.
"""
return fill_constant(value=1.0, **locals())
def zeros(shape, dtype, main_program=None):
"""
This function performs the same function as fill_constant() declared above
with the constant value being 0.0.
"""
return fill_constant(value=0.0, **locals())
def increment(x,
value=1.0,
in_place=True,
......@@ -1514,95 +594,6 @@ def array_length(array, main_program=None):
return tmp
def conv2d_transpose(input,
num_filters,
output_size=None,
filter_size=None,
padding=None,
stride=None,
param_attr=None,
main_program=None,
startup_program=None):
"""
The transpose of conv2d layer.
This layer is also known as deconvolution layer.
Args:
input(Variable): The input image with [N, C, H, W] format.
num_filters(int): The number of filter. It is as same as the output
image channel.
output_size(int|tuple|None): The output image size. If output size is a
tuple, it must contain two integers, (image_H, image_W). This
parameter only works when filter_size is None.
filter_size(int|tuple|None): The filter size. If filter_size is a tuple,
it must contain two integers, (filter_size_H, filter_size_W).
Otherwise, the filter will be a square. None if use output size to
calculate filter_size
padding(int|tuple): The padding size. If padding is a tuple, it must
contain two integers, (padding_H, padding_W). Otherwise, the
padding_H = padding_W = padding.
stride(int|tuple): The stride size. If stride is a tuple, it must
contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride.
param_attr: Parameter Attribute.
main_program(Program): the main program
startup_program(Program): the startup program
Returns:
Variable: Output image.
"""
helper = LayerHelper("conv2d_transpose", **locals())
if not isinstance(input, Variable):
raise TypeError("Input of conv2d_transpose must be Variable")
input_channel = input.shape[1]
op_attr = dict()
if isinstance(padding, int):
op_attr['paddings'] = [padding, padding]
elif padding is not None:
op_attr['paddings'] = padding
if isinstance(stride, int):
op_attr['strides'] = stride
elif stride is not None:
op_attr['strides'] = stride
if filter_size is None:
if output_size is None:
raise ValueError("output_size must be set when filter_size is None")
if isinstance(output_size, int):
output_size = [output_size, output_size]
padding = op_attr.get('paddings', [0, 0])
stride = op_attr.get('strides', [1, 1])
h_in = input.shape[2]
w_in = input.shape[3]
filter_size_h = output_size[0] - \
(h_in - 1) * stride[0] + 2 * padding[0]
filter_size_w = output_size[1] - \
(w_in - 1) * stride[1] + 2 * padding[1]
filter_size = [filter_size_h, filter_size_w]
elif isinstance(filter_size, int):
filter_size = [filter_size, filter_size]
filter_shape = [input_channel, num_filters] + filter_size
img_filter = helper.create_parameter(
dtype=input.dtype, shape=filter_shape, attr=helper.param_attr)
out = helper.create_tmp_variable(dtype=input.dtype)
helper.append_op(
type='conv2d_transpose',
inputs={'Input': [input],
'Filter': [img_filter]},
outputs={'Output': out},
attrs=op_attr)
return out
class ConditionalBlockGuard(BlockGuard):
def __init__(self, block):
if not isinstance(block, ConditionalBlock):
......@@ -1677,7 +668,7 @@ class ConditionalBlock(object):
},
outputs={'Out': out_list,
'Scope': [step_scope]},
attrs={'block': inside_block})
attrs={'sub_block': inside_block})
class IfElseBlockGuard(object):
......
from .. import core
from ..layer_helper import LayerHelper
__all__ = ['data']
def data(name,
shape,
append_batch_size=True,
dtype='float32',
lod_level=0,
type=core.VarDesc.VarType.LOD_TENSOR,
main_program=None,
startup_program=None,
stop_gradient=True):
"""
Data Layer.
Args:
name: The name/alias of the function
shape: Tuple declaring the shape.
append_batch_size: Whether or not to append the data as a batch.
dtype: The type of data : float32, float_16, int etc
type: The output type. By default it is LOD_TENSOR.
lod_level(int): The LoD Level. 0 means the input data is not a sequence.
main_program: Name of the main program that calls this
startup_program: Name of the startup program
stop_gradient: A boolean that mentions whether gradient should flow.
This function takes in input and based on whether data has
to be returned back as a minibatch, it creates the global variable using
the helper functions. The global variables can be accessed by all the
following operations and layers in the graph.
All the input variables of this function are passed in as local variables
to the LayerHelper constructor.
"""
helper = LayerHelper('data', **locals())
shape = list(shape)
for i in xrange(len(shape)):
if shape[i] is None:
shape[i] = -1
append_batch_size = False
elif shape[i] < 0:
append_batch_size = False
if append_batch_size:
shape = [-1] + shape # append batch size as -1
return helper.create_global_variable(
name=name,
shape=shape,
dtype=dtype,
type=type,
stop_gradient=stop_gradient,
lod_level=lod_level)
"""
All layers just related to the neural network.
"""
from ..layer_helper import LayerHelper
from ..initializer import Normal, Constant
from ..framework import Variable
__all__ = [
'fc', 'embedding', 'dynamic_lstm', 'gru_unit', 'linear_chain_crf',
'crf_decoding', 'cos_sim', 'cross_entropy', 'square_error_cost', 'accuracy',
'chunk_eval', 'sequence_conv', 'conv2d', 'sequence_pool', 'pool2d',
'batch_norm', 'beam_search_decode', 'conv2d_transpose'
]
def fc(input,
size,
num_flatten_dims=1,
param_attr=None,
bias_attr=None,
act=None,
name=None,
main_program=None,
startup_program=None):
"""
Fully Connected Layer.
Args:
input: The input tensor to the function
size: The size of the layer
num_flatten_dims: Number of columns in input
param_attr: The parameters/weights to the FC Layer
param_initializer: Initializer used for the weight/parameter. If None, XavierInitializer() is used
bias_attr: The bias parameter for the FC layer
bias_initializer: Initializer used for the bias. If None, then ConstantInitializer() is used
act: Activation to be applied to the output of FC layer
name: Name/alias of the function
main_program: Name of the main program that calls this
startup_program: Name of the startup program
This function can take in multiple inputs and performs the Fully Connected
function (linear transformation) on top of each of them.
So for input x, the output will be : Wx + b. Where W is the parameter,
b the bias and x is the input.
The function also applies an activation (non-linearity) on top of the
output, if activation is passed in the input.
All the input variables of this function are passed in as local variables
to the LayerHelper constructor.
"""
helper = LayerHelper('fc', **locals())
dtype = helper.input_dtype()
mul_results = []
for input_var, param_attr in helper.iter_inputs_and_params():
input_shape = input_var.shape
param_shape = [
reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1)
] + [size]
w = helper.create_parameter(
attr=param_attr, shape=param_shape, dtype=dtype, is_bias=False)
tmp = helper.create_tmp_variable(dtype)
helper.append_op(
type="mul",
inputs={
"X": input_var,
"Y": w,
},
outputs={"Out": tmp},
attrs={'x_num_col_dims': num_flatten_dims,
'y_num_col_dims': 1})
mul_results.append(tmp)
# sum
if len(mul_results) == 1:
pre_bias = mul_results[0]
else:
pre_bias = helper.create_tmp_variable(dtype)
helper.append_op(
type="sum", inputs={"X": mul_results}, outputs={"Out": pre_bias})
# add bias
pre_activation = helper.append_bias_op(pre_bias)
# add activation
return helper.append_activation(pre_activation)
def embedding(input,
size,
is_sparse=False,
param_attr=None,
dtype='float32',
main_program=None,
startup_program=None):
"""
Embedding Layer.
Args:
param_initializer:
input: The input to the function
size: The size of the layer
is_sparse: A flag that decleares whether the input is sparse
param_attr: Parameters for this layer
dtype: The type of data : float32, float_16, int etc
main_program: Name of the main program that calls this
startup_program: Name of the startup program
This function can take in the input (which is a vector of IDs) and
performs a lookup in the lookup_table using these IDs, to result into
the embedding of each ID in the input.
All the input variables of this function are passed in as local variables
to the LayerHelper constructor.
"""
helper = LayerHelper('embedding', **locals())
w = helper.create_parameter(
attr=helper.param_attr, shape=size, dtype=dtype, is_bias=False)
tmp = helper.create_tmp_variable(dtype)
helper.append_op(
type='lookup_table',
inputs={'Ids': input,
'W': w},
outputs={'Out': tmp},
attrs={'is_sparse': is_sparse})
return tmp
# TODO(qijun): expose H0 and C0
def dynamic_lstm(input,
size,
param_attr=None,
bias_attr=None,
use_peepholes=True,
is_reverse=False,
gate_activation='sigmoid',
cell_activation='tanh',
candidate_activation='tanh',
dtype='float32',
main_program=None,
startup_program=None):
helper = LayerHelper('lstm', **locals())
size = size / 4
weight = helper.create_parameter(
attr=helper.param_attr, shape=[size, 4 * size], dtype=dtype)
bias_size = [1, 7 * size]
if not use_peepholes:
bias_size[1] = 4 * size
bias = helper.create_parameter(
attr=helper.bias_attr, shape=bias_size, dtype=dtype, is_bias=True)
hidden = helper.create_tmp_variable(dtype)
cell = helper.create_tmp_variable(dtype)
batch_gate = helper.create_tmp_variable(dtype)
batch_cell_pre_act = helper.create_tmp_variable(dtype)
helper.append_op(
type='lstm',
inputs={'Input': input,
'Weight': weight,
'Bias': bias},
outputs={
'Hidden': hidden,
'Cell': cell,
'BatchGate': batch_gate,
'BatchCellPreAct': batch_cell_pre_act
},
attrs={
'use_peepholes': use_peepholes,
'is_reverse': is_reverse,
'gate_activation': gate_activation,
'cell_activation': cell_activation,
'candidate_activation': candidate_activation
})
return hidden, cell
def gru_unit(input,
hidden,
size,
weight=None,
bias=None,
activation='tanh',
gate_activation='sigmoid',
main_program=None,
startup_program=None):
"""
GRUUnit Operator implements partial calculations of the GRU unit as following:
$$
update \ gate: u_t = actGate(xu_t + W_u * h_{t-1} + b_u) \\
reset \ gate: r_t = actGate(xr_t + W_r * h_{t-1} + b_r) \\
output \ candidate: {h}_t = actNode(xc_t + W_c * dot(r_t, h_{t-1}) + b_c) \\
output: h_t = dot((1 - u_t), h_{t-1}) + dot(u_t, {h}_t)
$$
which is same as one time step of GRU Operator.
@note To implement the complete GRU unit, fully-connected operator must be
used before to feed xu, xr and xc as the Input of GRUUnit operator.
TODO(ChunweiYan) add more document here
"""
activation_dict = dict(
identity=0,
sigmoid=1,
tanh=2,
relu=3, )
activation = activation_dict[activation]
gate_activation = activation_dict[gate_activation]
helper = LayerHelper('gru_unit', **locals())
dtype = helper.input_dtype()
size = size / 3
# create weight
if weight is None:
weight = helper.create_parameter(
attr=helper.param_attr, shape=[size, 3 * size], dtype=dtype)
# create bias
if bias is None:
bias_size = [1, 3 * size]
bias = helper.create_parameter(
attr=helper.bias_attr, shape=bias_size, dtype=dtype, is_bias=True)
gate = helper.create_tmp_variable(dtype)
reset_hidden_pre = helper.create_tmp_variable(dtype)
updated_hidden = helper.create_tmp_variable(dtype)
helper.append_op(
type='gru_unit',
inputs={'Input': input,
'HiddenPrev': hidden,
'Weight': weight},
outputs={
'Gate': gate,
'ResetHiddenPrev': reset_hidden_pre,
'Hidden': updated_hidden,
},
attrs={
'activation': 0,
'gate_activation': 1,
})
return updated_hidden, reset_hidden_pre, gate
def linear_chain_crf(input,
label,
param_attr=None,
main_program=None,
startup_program=None):
helper = LayerHelper('linear_chain_crf', **locals())
size = input.shape[1]
transition = helper.create_parameter(
attr=helper.param_attr,
shape=[size + 2, size],
dtype=helper.input_dtype())
alpha = helper.create_tmp_variable(dtype=helper.input_dtype())
emission_exps = helper.create_tmp_variable(dtype=helper.input_dtype())
transition_exps = helper.create_tmp_variable(dtype=helper.input_dtype())
log_likelihood = helper.create_tmp_variable(dtype=helper.input_dtype())
helper.append_op(
type='linear_chain_crf',
inputs={"Emission": [input],
"Transition": transition,
"Label": label},
outputs={
"Alpha": [alpha],
"EmissionExps": [emission_exps],
"TransitionExps": transition_exps,
"LogLikelihood": log_likelihood
})
return log_likelihood
def crf_decoding(input,
param_attr,
label=None,
main_program=None,
startup_program=None):
helper = LayerHelper('crf_decoding', **locals())
transition = helper.get_parameter(param_attr.name)
viterbi_path = helper.create_tmp_variable(dtype=helper.input_dtype())
helper.append_op(
type='crf_decoding',
inputs={"Emission": [input],
"Transition": transition,
"Label": label},
outputs={"ViterbiPath": [viterbi_path]})
return viterbi_path
def cos_sim(X, Y, **kwargs):
"""
This function performs the cosine similarity between two tensors
X and Y and returns that as the output.
"""
helper = LayerHelper('cos_sim', **kwargs)
out = helper.create_tmp_variable(dtype=X.dtype)
xnorm = helper.create_tmp_variable(dtype=X.dtype)
ynorm = helper.create_tmp_variable(dtype=X.dtype)
helper.append_op(
type='cos_sim',
inputs={'X': [X],
'Y': [Y]},
outputs={'Out': [out],
'XNorm': [xnorm],
'YNorm': [ynorm]})
return out
def cross_entropy(input, label, **kwargs):
"""
This function computes cross_entropy using the input and label.
"""
helper = LayerHelper('cross_entropy', **kwargs)
out = helper.create_tmp_variable(dtype=input.dtype)
helper.append_op(
type='cross_entropy',
inputs={'X': [input],
'Label': [label]},
outputs={'Y': [out]},
attrs=kwargs)
return out
def square_error_cost(input, label, **kwargs):
"""
This functions returns the squared error cost using the input and label.
The output is appending the op to do the above.
"""
helper = LayerHelper('square_error_cost', **kwargs)
minus_out = helper.create_tmp_variable(dtype=input.dtype)
helper.append_op(
type='elementwise_sub',
inputs={'X': [input],
'Y': [label]},
outputs={'Out': [minus_out]})
square_out = helper.create_tmp_variable(dtype=input.dtype)
helper.append_op(
type='square', inputs={'X': [minus_out]}, outputs={'Y': [square_out]})
return square_out
def accuracy(input, label, k=1, correct=None, total=None, **kwargs):
"""
This function computes the accuracy using the input and label.
The output is the top_k inputs and their indices.
"""
helper = LayerHelper("accuracy", **kwargs)
topk_out = helper.create_tmp_variable(dtype=input.dtype)
topk_indices = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="top_k",
inputs={"X": [input]},
outputs={"Out": [topk_out],
"Indices": [topk_indices]},
attrs={"k": k})
acc_out = helper.create_tmp_variable(dtype="float32")
if correct is None:
correct = helper.create_tmp_variable(dtype="int64")
if total is None:
total = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="accuracy",
inputs={
"Out": [topk_out],
"Indices": [topk_indices],
"Label": [label]
},
outputs={
"Accuracy": [acc_out],
"Correct": [correct],
"Total": [total],
})
return acc_out
def chunk_eval(input,
label,
chunk_scheme,
num_chunk_types,
excluded_chunk_types=None,
**kwargs):
"""
This function computes and outputs the precision, recall and
F1-score of chunk detection.
"""
helper = LayerHelper("chunk_eval", **kwargs)
# prepare output
precision = helper.create_tmp_variable(dtype="float32")
recall = helper.create_tmp_variable(dtype="float32")
f1_score = helper.create_tmp_variable(dtype="float32")
num_infer_chunks = helper.create_tmp_variable(dtype="int64")
num_label_chunks = helper.create_tmp_variable(dtype="int64")
num_correct_chunks = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="chunk_eval",
inputs={"Inference": [input],
"Label": [label]},
outputs={
"Precision": [precision],
"Recall": [recall],
"F1-Score": [f1_score],
"NumInferChunks": [num_infer_chunks],
"NumLabelChunks": [num_label_chunks],
"NumCorrectChunks": [num_correct_chunks]
},
attrs={
"num_chunk_types": num_chunk_types,
'chunk_scheme': chunk_scheme,
'excluded_chunk_types': excluded_chunk_types or []
})
return precision, recall, f1_score, num_infer_chunks, num_label_chunks, num_correct_chunks
def sequence_conv(input,
num_filters,
filter_size=3,
filter_stride=1,
padding=None,
bias_attr=None,
param_attr=None,
act=None,
main_program=None,
startup_program=None):
"""
This function creates the op for sequence_conv, using the inputs and
other convolutional configurations for the filters and stride as given
in the input parameters to the function.
"""
# FIXME(dzh) : want to unify the argument of python layer
# function. So we ignore some unecessary attributes.
# such as, padding_trainable, context_start.
helper = LayerHelper('sequence_conv', **locals())
dtype = helper.input_dtype()
filter_shape = [filter_size * input.shape[1], num_filters]
filter_param = helper.create_parameter(
attr=helper.param_attr, shape=filter_shape, dtype=dtype)
pre_bias = helper.create_tmp_variable(dtype)
helper.append_op(
type='sequence_conv',
inputs={
'X': [input],
'Filter': [filter_param],
},
outputs={"Out": pre_bias},
attrs={
'contextStride': filter_stride,
'contextStart': -int(filter_size / 2),
'contextLength': filter_size
})
pre_act = helper.append_bias_op(pre_bias)
return helper.append_activation(pre_act)
def conv2d(input,
num_filters,
filter_size,
stride=None,
padding=None,
groups=None,
param_attr=None,
bias_attr=None,
act=None,
name=None,
main_program=None,
startup_program=None):
"""
This function creates the op for a 2-dimensional Convolution.
This is performed using the parameters of filters(size, dimensionality etc)
, stride and other configurations for a Convolution operation.
This funciton can also append an activation on top of the
conv-2d output, if mentioned in the input parameters.
"""
if stride is None:
stride = [1, 1]
helper = LayerHelper('conv2d', **locals())
dtype = helper.input_dtype()
num_channels = input.shape[1]
if groups is None:
num_filter_channels = num_channels
else:
if num_channels % groups != 0:
raise ValueError("num_channels must be divisible by groups.")
num_filter_channels = num_channels / groups
if isinstance(filter_size, int):
filter_size = [filter_size, filter_size]
if isinstance(stride, int):
stride = [stride, stride]
if isinstance(padding, int):
padding = [padding, padding]
input_shape = input.shape
filter_shape = [num_filters, num_filter_channels] + filter_size
def _get_default_param_initializer():
std = (2.0 / (filter_size[0]**2 * num_channels))**0.5
return Normal(0.0, std, 0)
filter_param = helper.create_parameter(
attr=helper.param_attr,
shape=filter_shape,
dtype=dtype,
default_initializer=_get_default_param_initializer())
pre_bias = helper.create_tmp_variable(dtype)
helper.append_op(
type='conv2d_cudnn',
inputs={
'Input': input,
'Filter': filter_param,
},
outputs={"Output": pre_bias},
attrs={'strides': stride,
'paddings': padding,
'groups': groups})
pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
return helper.append_activation(pre_act)
def sequence_pool(input, pool_type, **kwargs):
"""
This function add the operator for sequence pooling.
This is applied on top of the input using pool_type mentioned
in the parameters.
"""
helper = LayerHelper('sequence_pool', input=input, **kwargs)
dtype = helper.input_dtype()
pool_out = helper.create_tmp_variable(dtype)
max_index = helper.create_tmp_variable(dtype)
helper.append_op(
type="sequence_pool",
inputs={"X": input},
outputs={"Out": pool_out,
"MaxIndex": max_index},
attrs={"pooltype": pool_type.upper()})
return pool_out
def pool2d(input,
pool_size,
pool_type,
pool_stride=None,
pool_padding=None,
global_pooling=False,
main_program=None,
startup_program=None):
"""
This function adds the operator for pooling in 2 dimensions, using the
pooling configurations mentioned in input parameters.
"""
if pool_padding is None:
pool_padding = [0, 0]
if pool_stride is None:
pool_stride = [1, 1]
if pool_type not in ["max", "avg"]:
raise ValueError(
"Unknown pool_type: '%s'. It can only be 'max' or 'avg'.",
str(pool_type))
if isinstance(pool_size, int):
pool_size = [pool_size, pool_size]
if isinstance(pool_stride, int):
pool_stride = [pool_stride, pool_stride]
if isinstance(pool_padding, int):
pool_padding = [pool_padding, pool_padding]
helper = LayerHelper('pool2d', **locals())
dtype = helper.input_dtype()
pool_out = helper.create_tmp_variable(dtype)
helper.append_op(
type="pool2d",
inputs={"X": input},
outputs={"Out": pool_out},
attrs={
"pooling_type": pool_type,
"ksize": pool_size,
"global_pooling": global_pooling,
"strides": pool_stride,
"paddings": pool_padding
})
return pool_out
def batch_norm(input,
act=None,
is_test=False,
momentum=0.9,
epsilon=1e-05,
param_attr=None,
bias_attr=None,
data_layout='NCHW',
main_program=None,
startup_program=None):
"""
This function helps create an operator to implement
the BatchNorm layer using the configurations from the input parameters.
"""
helper = LayerHelper('batch_norm', **locals())
dtype = helper.input_dtype()
input_shape = input.shape
if data_layout == 'NCHW':
channel_num = input_shape[1]
else:
if data_layout == 'NHWC':
channel_num = input_shape[-1]
else:
raise ValueError("unsupported data layout:" + data_layout)
param_shape = [channel_num]
# create parameter
scale = helper.create_parameter(
attr=helper.param_attr,
shape=param_shape,
dtype=dtype,
default_initializer=Constant(1.0))
bias = helper.create_parameter(
attr=helper.param_attr, shape=param_shape, dtype=dtype, is_bias=True)
mean = helper.create_global_variable(
dtype=input.dtype, shape=param_shape, persistable=True)
helper.set_variable_initializer(var=mean, initializer=Constant(0.0))
variance = helper.create_global_variable(
dtype=input.dtype, shape=param_shape, persistable=True)
helper.set_variable_initializer(var=variance, initializer=Constant(1.0))
# create output
# mean and mean_out share the same memory
mean_out = mean
# variance and variance out share the same memory
variance_out = variance
saved_mean = helper.create_tmp_variable(dtype)
saved_variance = helper.create_tmp_variable(dtype)
batch_norm_out = helper.create_tmp_variable(dtype)
helper.append_op(
type="batch_norm",
inputs={
"X": input,
"Scale": scale,
"Bias": bias,
"Mean": mean,
"Variance": variance
},
outputs={
"Y": batch_norm_out,
"MeanOut": mean_out,
"VarianceOut": variance_out,
"SavedMean": saved_mean,
"SavedVariance": saved_variance
},
attrs={"momentum": momentum,
"epsilon": epsilon,
"is_test": is_test})
return helper.append_activation(batch_norm_out)
def beam_search_decode(ids, scores, main_program=None, startup_program=None):
helper = LayerHelper('beam_search_decode', **locals())
sentence_ids = helper.create_tmp_variable(dtype=ids.dtype)
sentence_scores = helper.create_tmp_variable(dtype=ids.dtype)
helper.append_op(
type="beam_search_decode",
inputs={"Ids": ids,
"Scores": scores},
outputs={
"SentenceIds": sentence_ids,
"SentenceScores": sentence_scores
})
return sentence_ids, sentence_scores
def conv2d_transpose(input,
num_filters,
output_size=None,
filter_size=None,
padding=None,
stride=None,
param_attr=None,
main_program=None,
startup_program=None):
"""
The transpose of conv2d layer.
This layer is also known as deconvolution layer.
Args:
input(Variable): The input image with [N, C, H, W] format.
num_filters(int): The number of filter. It is as same as the output
image channel.
output_size(int|tuple|None): The output image size. If output size is a
tuple, it must contain two integers, (image_H, image_W). This
parameter only works when filter_size is None.
filter_size(int|tuple|None): The filter size. If filter_size is a tuple,
it must contain two integers, (filter_size_H, filter_size_W).
Otherwise, the filter will be a square. None if use output size to
calculate filter_size
padding(int|tuple): The padding size. If padding is a tuple, it must
contain two integers, (padding_H, padding_W). Otherwise, the
padding_H = padding_W = padding.
stride(int|tuple): The stride size. If stride is a tuple, it must
contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride.
param_attr: Parameter Attribute.
main_program(Program): the main program
startup_program(Program): the startup program
Returns:
Variable: Output image.
"""
helper = LayerHelper("conv2d_transpose", **locals())
if not isinstance(input, Variable):
raise TypeError("Input of conv2d_transpose must be Variable")
input_channel = input.shape[1]
op_attr = dict()
if isinstance(padding, int):
op_attr['paddings'] = [padding, padding]
elif padding is not None:
op_attr['paddings'] = padding
if isinstance(stride, int):
op_attr['strides'] = stride
elif stride is not None:
op_attr['strides'] = stride
if filter_size is None:
if output_size is None:
raise ValueError("output_size must be set when filter_size is None")
if isinstance(output_size, int):
output_size = [output_size, output_size]
padding = op_attr.get('paddings', [0, 0])
stride = op_attr.get('strides', [1, 1])
h_in = input.shape[2]
w_in = input.shape[3]
filter_size_h = output_size[0] - \
(h_in - 1) * stride[0] + 2 * padding[0]
filter_size_w = output_size[1] - \
(w_in - 1) * stride[1] + 2 * padding[1]
filter_size = [filter_size_h, filter_size_w]
elif isinstance(filter_size, int):
filter_size = [filter_size, filter_size]
filter_shape = [input_channel, num_filters] + filter_size
img_filter = helper.create_parameter(
dtype=input.dtype, shape=filter_shape, attr=helper.param_attr)
out = helper.create_tmp_variable(dtype=input.dtype)
helper.append_op(
type='conv2d_transpose',
inputs={'Input': [input],
'Filter': [img_filter]},
outputs={'Output': out},
attrs=op_attr)
return out
from ..registry import register_layer
__all__ = [
'mean', 'mul', 'dropout', 'reshape', 'sigmoid', 'scale', 'transpose',
'sigmoid_cross_entropy_with_logits', 'elementwise_add', 'elementwise_div',
'elementwise_sub', 'elementwise_mul', 'clip', 'abs'
]
for _OP in set(__all__):
globals()[_OP] = register_layer(_OP)
from ..layer_helper import LayerHelper
__all__ = [
'create_tensor', 'cast', 'concat', 'sums', 'assign',
'fill_constant_batch_size_like', 'fill_constant', 'ones', 'zeros'
]
def create_tensor(dtype, name=None, main_program=None, startup_program=None):
helper = LayerHelper("create_tensor", **locals())
return helper.create_variable(name=helper.name, dtype=dtype)
def cast(x, dtype, main_program=None):
"""
This function takes in the input with input_dtype
and casts it to the output_dtype as the output.
"""
helper = LayerHelper('cast', **locals())
out = helper.create_tmp_variable(dtype=dtype)
helper.append_op(
type='cast',
inputs={'X': [x]},
outputs={'Out': [out]},
attrs={'in_dtype': x.dtype,
'out_dtype': out.dtype})
return out
def concat(input, axis, main_program=None, startup_program=None):
"""
This function concats the input along the axis mentioned
and returns that as the output.
"""
helper = LayerHelper('concat', **locals())
out = helper.create_tmp_variable(dtype=helper.input_dtype())
helper.append_op(
type='concat',
inputs={'X': input},
outputs={'Out': [out]},
attrs={'axis': axis})
return out
def sums(input, out=None, main_program=None, startup_program=None):
"""
This function takes in the input and performs the sum operation on it
and returns that as the output.
"""
helper = LayerHelper('sum', **locals())
if out is None:
out = helper.create_tmp_variable(dtype=helper.input_dtype())
helper.append_op(type='sum', inputs={'X': input}, outputs={'Out': out})
return out
def assign(input, output, main_program=None, startup_program=None):
helper = LayerHelper('assign', **locals())
helper.append_op(
type='scale',
inputs={'X': [input]},
outputs={'Out': [output]},
attrs={'scale': 1.0})
return output
def fill_constant(shape,
dtype,
value,
out=None,
main_program=None,
startup_program=None):
"""
This function creates a tensor , with shape as mentioned in the input and
specified dtype and fills this up with a constant value that
comes in the input. It also sets the stop_gradient to be True.
"""
helper = LayerHelper("fill_constant", **locals())
if out is None:
out = helper.create_tmp_variable(dtype=dtype)
helper.append_op(
type='fill_constant',
inputs={},
outputs={'Out': [out]},
attrs={'shape': shape,
'dtype': out.dtype,
'value': float(value)})
out.stop_gradient = True
return out
def fill_constant_batch_size_like(input,
shape,
dtype,
value,
input_dim_idx=0,
output_dim_idx=0,
main_program=None,
startup_program=None):
helper = LayerHelper("fill_constant_batch_size_like", **locals())
out = helper.create_tmp_variable(dtype=dtype)
helper.append_op(
type='fill_constant_batch_size_like',
inputs={'Input': input},
outputs={'Out': [out]},
attrs={
'shape': shape,
'dtype': out.dtype,
'value': float(value),
'input_dim_idx': input_dim_idx,
'output_dim_idx': output_dim_idx
})
out.stop_gradient = True
return out
def ones(shape, dtype, main_program=None):
"""
This function performs the same function as fill_constant() declared above
with the constant value being 1.0.
"""
return fill_constant(value=1.0, **locals())
def zeros(shape, dtype, main_program=None):
"""
This function performs the same function as fill_constant() declared above
with the constant value being 0.0.
"""
return fill_constant(value=0.0, **locals())
from __future__ import print_function
import numpy as np
import sys
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import sys
def resnet_cifar10(input, depth=32):
......
import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
from paddle.v2.fluid.layer_helper import LayerHelper
def lstm(x,
c_pre_init,
hidden_dim,
forget_bias=None,
main_program=None,
startup_program=None):
"""
This function helps create an operator for the LSTM (Long Short Term
Memory) cell that can be used inside an RNN.
"""
helper = LayerHelper('lstm_unit', **locals())
rnn = fluid.layers.StaticRNN()
with rnn.step():
c_pre = rnn.memory(init=c_pre_init)
x_t = rnn.step_input(x)
before_fc = fluid.layers.concat(
input=[x_t, c_pre],
axis=1,
main_program=main_program,
startup_program=startup_program)
after_fc = fluid.layers.fc(input=before_fc,
size=hidden_dim * 4,
main_program=main_program,
startup_program=startup_program)
dtype = x.dtype
c = helper.create_tmp_variable(dtype)
h = helper.create_tmp_variable(dtype)
helper.append_op(
type='lstm_unit',
inputs={"X": after_fc,
"C_prev": c_pre},
outputs={"C": c,
"H": h},
attrs={"forget_bias": forget_bias})
rnn.update_memory(c_pre, c)
rnn.output(h)
return rnn()
def lstm_net(dict_dim, class_dim=2, emb_dim=32, seq_len=80, batch_size=50):
......@@ -23,8 +68,7 @@ def lstm_net(dict_dim, class_dim=2, emb_dim=32, seq_len=80, batch_size=50):
c_pre_init = fluid.layers.fill_constant(
dtype=emb.dtype, shape=[batch_size, emb_dim], value=0.0)
c_pre_init.stop_gradient = False
layer_1_out = fluid.layers.lstm(
emb, c_pre_init=c_pre_init, hidden_dim=emb_dim)
layer_1_out = lstm(emb, c_pre_init=c_pre_init, hidden_dim=emb_dim)
layer_1_out = fluid.layers.transpose(x=layer_1_out, axis=[1, 0, 2])
prediction = fluid.layers.fc(input=layer_1_out,
......
......@@ -68,6 +68,7 @@ packages=['paddle',
'paddle.v2.plot',
'paddle.v2.fluid',
'paddle.v2.fluid.proto',
'paddle.v2.fluid.layers',
'py_paddle']
with open('@PADDLE_SOURCE_DIR@/python/requirements.txt') as f:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册