提交 abfa81b1 编写于 作者: L Luo Tao

Merge branch 'develop' into seqpool

......@@ -76,6 +76,17 @@ LoDTensor PackDynamicBatch(const std::vector<LoDTensor>& source,
const std::vector<DySeqMeta>& meta, const LoD& lod,
size_t level);
std::vector<size_t> GenDyBatchIndice(const DySeqMetaBatch& meta, int batch_id) {
// collect indice need to copy to the batch
std::vector<size_t> indice;
for (const auto& seq : meta) {
size_t id = seq.begin + batch_id;
if (id >= seq.end) break;
indice.push_back(id);
}
return indice;
}
} // namespace detail
const LoDTensor& TensorArray::Read(size_t index) const {
......@@ -113,8 +124,8 @@ LoDTensor TensorArray::Pack(size_t level, const std::vector<DySeqMeta>& meta,
return detail::PackDynamicBatch(values_, meta, lod, level);
}
std::vector<DySeqMeta> TensorArray::Unpack(const LoDTensor& source, int level,
bool length_desend) {
DySeqMetaBatch TensorArray::Unpack(const LoDTensor& source, int level,
bool length_desend) {
detail::DynamicBatchUnpacker unpacker(source, level,
length_desend /*descend*/);
......@@ -129,6 +140,7 @@ std::vector<DySeqMeta> TensorArray::Unpack(const LoDTensor& source, int level,
Write(batch_id, unpacker.GetBatch(batch_id));
}
PADDLE_ENFORCE(!unpacker.meta.empty());
return unpacker.meta;
}
......@@ -218,13 +230,7 @@ LoDTensor DynamicBatchUnpacker::GetBatch(size_t index) {
PADDLE_ENFORCE(!meta.empty(), "should build meta first");
LoDTensor result;
// collect indice need to copy to the batch
std::vector<size_t> indice;
for (const auto& seq : meta) {
size_t id = seq.begin + index;
if (id >= seq.end) break;
indice.push_back(id);
}
auto indice = detail::GenDyBatchIndice(meta, index);
PADDLE_ENFORCE(!indice.empty(), "invalid batch at %d", index);
// copy the indice of records in LoDTensor
......@@ -237,9 +243,9 @@ LoDTensor DynamicBatchUnpacker::GetBatch(size_t index) {
for (size_t i = 0; i < indice.size(); i++) {
auto index = indice[i];
auto target = result.Slice<value_type>(i, i + 1);
auto source_ = source->Slice<value_type>(index, index + 1);
auto slice = source->Slice<value_type>(index, index + 1);
target.CopyFrom<value_type>(source_, platform::CPUPlace(),
target.CopyFrom<value_type>(slice, platform::CPUPlace(),
platform::CPUDeviceContext());
}
......
......@@ -34,6 +34,13 @@ struct DySeqMeta {
size_t ori_idx;
};
using DySeqMetaBatch = std::vector<DySeqMeta>;
/*
* Extract the indices of instances.
*/
std::vector<size_t> GenDyBatchIndice(const DySeqMetaBatch &metas, int batch_id);
/*
* TensorArray is a C-array-like array of tensors, it is meant to be used with
* dynamic iteration primitives such as while_loop. It is used to segment inputs
......@@ -69,7 +76,7 @@ class TensorArray {
* Recover the original LoD-arranged LoDTensor with the `values`, `level` and
* `indice_map`.
*/
LoDTensor Pack(size_t level, const std::vector<DySeqMeta> &meta,
LoDTensor Pack(size_t level, const DySeqMetaBatch &meta,
const LoD &lod) const;
/*
......@@ -77,8 +84,7 @@ class TensorArray {
* `values`, if set `desend`, will sort by length in descending order else in
* ascending order.
*/
std::vector<DySeqMeta> Unpack(const LoDTensor &source, int level,
bool length_desend);
DySeqMetaBatch Unpack(const LoDTensor &source, int level, bool length_desend);
/*
* Pack the values into a tensor with rank one higher than each tensor in
......
......@@ -84,8 +84,9 @@ function(op_library TARGET)
endif()
# pybind USE_NO_KERNEL_OP
# HACK: if REGISTER_OP_CPU_KERNEL presents the operator must have kernel
file(READ ${TARGET}.cc TARGET_CONTENT)
string(REGEX MATCH "OperatorWithKernel" regex_result "${TARGET_CONTENT}")
string(REGEX MATCH "REGISTER_OP_CPU_KERNEL" regex_result "${TARGET_CONTENT}")
string(REPLACE "_op" "" TARGET "${TARGET}")
if (${pybind_flag} EQUAL 0 AND regex_result STREQUAL "")
file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(${TARGET});\n")
......
......@@ -338,6 +338,38 @@ class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
template <typename AttrType>
class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
public:
HardSigmoidOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of HardSigmoid operator");
AddOutput("Y", "Output of HardSigmoid operator");
AddComment(R"DOC(
Hard Sigmoid activation operator.
Segment-wise linear approximation of sigmoid[1].
This is much faster than sigmoid.
hard_sigmoid = max(0, min(1, slope * x + shift))
The slope should be positive. The offset can be either positive or negative.
The default slope and shift are set from [1].
It is recommended to use the defaults for this activation.
References:
[1] Noisy Activation Functions
(https://arxiv.org/abs/1603.00391)
)DOC");
AddAttr<AttrType>("slope", "Slope for linear approximation of sigmoid")
.SetDefault(static_cast<AttrType>(0.2));
AddAttr<AttrType>("offset", "Offset for linear approximation of sigmoid")
.SetDefault(static_cast<AttrType>(0.5));
}
};
} // namespace operators
} // namespace paddle
......@@ -413,6 +445,9 @@ REGISTER_OP(thresholded_relu, ops::ActivationOp,
ops::ThresholdedReluOpMaker<float>, thresholded_relu_grad,
ops::ActivationOpGrad);
REGISTER_OP(hard_sigmoid, ops::ActivationOp, ops::HardSigmoidOpMaker<float>,
hard_sigmoid_grad, ops::ActivationOpGrad);
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_CPU_KERNEL( \
act_type, \
......
......@@ -616,30 +616,63 @@ struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct HardSigmoidFunctor : public BaseActivationFunctor<T> {
float slope;
float offset;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"slope", &slope}, {"offset", &offset}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
auto temp = x * static_cast<T>(slope) + static_cast<T>(offset);
y.device(d) = temp.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(1));
}
};
template <typename T>
struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
float slope;
float offset;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"slope", &slope}, {"offset", &offset}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) =
dy *
((y > static_cast<T>(0)) * (y < static_cast<T>(1))).template cast<T>() *
static_cast<T>(slope);
}
};
} // namespace operators
} // namespace paddle
#define FOR_EACH_KERNEL_FUNCTOR(__macro) \
__macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \
__macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \
__macro(exp, ExpFunctor, ExpGradFunctor); \
__macro(relu, ReluFunctor, ReluGradFunctor); \
__macro(tanh, TanhFunctor, TanhGradFunctor); \
__macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \
__macro(sqrt, SqrtFunctor, SqrtGradFunctor); \
__macro(abs, AbsFunctor, AbsGradFunctor); \
__macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
__macro(log, LogFunctor, LogGradFunctor); \
__macro(square, SquareFunctor, SquareGradFunctor); \
__macro(brelu, BReluFunctor, BReluGradFunctor); \
__macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \
__macro(pow, PowFunctor, PowGradFunctor); \
__macro(stanh, STanhFunctor, STanhGradFunctor); \
__macro(softplus, SoftplusFunctor, SoftplusGradFunctor); \
__macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \
__macro(relu6, Relu6Functor, Relu6GradFunctor); \
__macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \
__macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \
__macro(elu, ELUFunctor, ELUGradFunctor); \
__macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor); \
#define FOR_EACH_KERNEL_FUNCTOR(__macro) \
__macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \
__macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \
__macro(exp, ExpFunctor, ExpGradFunctor); \
__macro(relu, ReluFunctor, ReluGradFunctor); \
__macro(tanh, TanhFunctor, TanhGradFunctor); \
__macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \
__macro(sqrt, SqrtFunctor, SqrtGradFunctor); \
__macro(abs, AbsFunctor, AbsGradFunctor); \
__macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
__macro(log, LogFunctor, LogGradFunctor); \
__macro(square, SquareFunctor, SquareGradFunctor); \
__macro(brelu, BReluFunctor, BReluGradFunctor); \
__macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \
__macro(pow, PowFunctor, PowGradFunctor); \
__macro(stanh, STanhFunctor, STanhGradFunctor); \
__macro(softplus, SoftplusFunctor, SoftplusGradFunctor); \
__macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \
__macro(relu6, Relu6Functor, Relu6GradFunctor); \
__macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \
__macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \
__macro(elu, ELUFunctor, ELUGradFunctor); \
__macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor); \
__macro(hard_sigmoid, HardSigmoidFunctor, HardSigmoidGradFunctor); \
__macro(thresholded_relu, ThresholdedReluFunctor, ThresholdedReluGradFunctor);
......@@ -23,6 +23,7 @@ using framework::Scope;
using framework::TensorArray;
using framework::LoDTensor;
using framework::Variable;
using framework::DySeqMetaBatch;
namespace detail {
......@@ -33,6 +34,29 @@ inline void CreateVariables(Scope& scope,
}
}
/*
* The inputs with sequence should be reordered when they are split, so the
* boot_states should be reordered in the same order.
*
* NOTE This may require that the `pre_state` of the first time step should just
* copy the `boot_state` rather than reference it, for that the content should
* be reordered, but the RNN op should not change the `boot_state` as an input
* variable's content.
*/
template <typename T>
inline void ReorderBootState(const DySeqMetaBatch& metas,
const LoDTensor& boot_state, LoDTensor* tensor,
const platform::Place& dst_place) {
for (size_t seq_id = 0; seq_id < metas.size(); seq_id++) {
auto slice = tensor->Slice<T>(seq_id, seq_id + 1);
auto boot_slice =
boot_state.Slice<T>(metas[seq_id].ori_idx, metas[seq_id].ori_idx + 1);
// TODO(superjom) pass in device context as an argument
slice.template CopyFrom<T>(boot_slice, dst_place,
platform::CPUDeviceContext());
}
}
} // namespace detail
class DynamicRecurrentOpProtoAndCheckerMaker
......@@ -69,6 +93,7 @@ void DynamicRecurrentOp::Run(const Scope& scope,
CreateScopes();
WriteStepInputs();
InitStates();
WriteStepOutputs();
// call stepnet in all the time steps
for (size_t step = 0; step < cache_.num_steps; step++) {
......@@ -76,7 +101,6 @@ void DynamicRecurrentOp::Run(const Scope& scope,
stepnet_->Run(step_scope, dev_ctx);
}
WriteStepOutputs();
ConcatOutputs();
}
......@@ -84,11 +108,11 @@ void DynamicRecurrentOp::SplitInputs() const {
// TODO(superjom) make level a config
// TODO(superjom) check all the inputs has the same LoD
int level = 0;
const auto& inlinks = cache_.inlinks;
for (const auto& item : inlinks) {
for (const auto& item : cache_.inlinks) {
const auto& var = item.second;
const auto& tensor = var->Get<LoDTensor>();
TensorArray& ta = step_inputs_[item.first];
dy_seq_metas_[item.first] =
ta.Unpack(tensor, level, true /*length_descend*/);
......@@ -120,17 +144,11 @@ void DynamicRecurrentOp::WriteStepInputs() const {
}
void DynamicRecurrentOp::WriteStepOutputs() const {
for (size_t step = 0; step < cache_.scopes->size(); step++) {
auto& scope = cache_.GetScope(step);
for (auto& item : step_outputs_) {
auto* var = scope.FindVar(item.first);
if (var == nullptr) {
var = scope.NewVar(item.first);
}
auto* tensor = var->GetMutable<LoDTensor>();
item.second.WriteShared(step, *tensor);
}
// initialize step outputs
for (const auto& item : cache_.outlinks) {
step_outputs_.emplace(item.first, TensorArray());
}
PADDLE_ENFORCE_GT(step_outputs_.size(), 0UL);
}
void DynamicRecurrentOp::CreateScopes() const {
......@@ -145,12 +163,18 @@ void DynamicRecurrentOp::CreateScopes() const {
PADDLE_ENFORCE_NOT_NULL(stepnet_, "stepnet should be set first");
std::vector<std::string> memories;
std::vector<std::string> pre_memories;
std::vector<std::string> stepnet_outputs;
std::transform(arg_.memories.begin(), arg_.memories.end(),
std::back_inserter(memories),
[](const rnn::MemoryAttr& m) { return m.var; });
std::transform(arg_.memories.begin(), arg_.memories.end(),
std::back_inserter(pre_memories),
[](const rnn::MemoryAttr& m) { return m.pre_var; });
for (const auto& item : stepnet_->Outputs()) {
for (const auto& var : item.second) {
stepnet_outputs.push_back(var);
}
}
for (size_t step = 0; step < cache_.num_steps; step++) {
auto& scope = cache_.GetScope(step);
......@@ -158,60 +182,88 @@ void DynamicRecurrentOp::CreateScopes() const {
detail::CreateVariables(scope, arg_.outlinks);
detail::CreateVariables(scope, memories);
detail::CreateVariables(scope, pre_memories);
detail::CreateVariables(scope, stepnet_outputs);
}
}
void DynamicRecurrentOp::ConcatOutputs() const {
// TODO(superjom) transform this to a config
int level = 0;
// TODO(superjom) pass in some lod
// just a placeholder
framework::LoD lod;
for (size_t step = 0; step < cache_.num_steps; step++) {
auto& scope = cache_.GetScope(step);
for (auto& item : step_outputs_) {
auto* var = scope.FindVar(item.first);
PADDLE_ENFORCE_NOT_NULL(var);
auto* tensor = var->GetMutable<LoDTensor>();
tensor->mutable_data<value_type>(platform::CPUPlace());
item.second.WriteShared(step, *tensor);
}
}
// the inlinks' lods should be the same, so randomly get one lod.
const auto& some_lod =
cache_.scope->FindVar(arg_.inlinks.front())->Get<LoDTensor>().lod();
const auto& some_meta = dy_seq_metas_[arg_.inlinks.front()];
for (auto& item : step_outputs_) {
auto tensor = item.second.Pack(level, dy_seq_metas_[item.first], lod);
auto& output = cache_.outlinks[item.first]->Get<LoDTensor>();
const_cast<LoDTensor*>(&output)->ShareDataWith<value_type>(tensor);
auto tensor = item.second.Pack(level, some_meta, some_lod);
auto* output = cache_.outlinks[item.first]->GetMutable<LoDTensor>();
const_cast<LoDTensor*>(output)->ShareDataWith<value_type>(tensor);
}
}
void DynamicRecurrentOp::InitStates() const {
// init the first state
// TODO(superjom) parepare the scenerio that boot state not exists
for (auto memory : arg_.memories) {
auto* boot_state_var = cache_.scope->FindVar(memory.boot_var);
PADDLE_ENFORCE_NOT_NULL(boot_state_var);
auto& boot_state = boot_state_var->Get<LoDTensor>();
const auto& dims = boot_state.dims();
for (size_t step = 0; step < cache_.num_steps; step++) {
auto& cur_scope = cache_.GetScope(step);
// link pre-state to boot_state
// init state and pre-state
auto* pre_state = cur_scope.FindVar(memory.pre_var);
PADDLE_ENFORCE_NOT_NULL(pre_state);
pre_state->GetMutable<LoDTensor>();
auto* state = cur_scope.FindVar(memory.var);
PADDLE_ENFORCE_NOT_NULL(state);
state->GetMutable<LoDTensor>()->Resize(dims);
state->GetMutable<LoDTensor>()->mutable_data<value_type>(
platform::CPUPlace());
if (step == 0) {
auto* pre_state_tensor = pre_state->GetMutable<LoDTensor>();
pre_state_tensor->Resize(boot_state.dims());
pre_state_tensor->ShareDataWith<value_type>(boot_state);
} else {
auto& pre_scope = cache_.GetScope(step - 1);
auto* state_pre = pre_scope.FindVar(memory.var);
PADDLE_ENFORCE_NOT_NULL(state_pre);
pre_state->GetMutable<LoDTensor>()->ShareDataWith<value_type>(
*state_pre->GetMutable<LoDTensor>());
}
for (size_t step = 0; step < cache_.num_steps; step++) {
for (const auto& memory : arg_.memories) {
CreateState(memory, step);
LinkState(memory, step);
}
}
}
void DynamicRecurrentOp::CreateState(const rnn::MemoryAttr& memory,
size_t step) const {
auto& scope = cache_.GetScope(step);
auto& state = *cache_.GetTensor(scope, memory.var);
auto& boot_state = *cache_.GetTensor(*cache_.scope, memory.boot_var);
size_t num_instances =
step_inputs_[arg_.inlinks.front()].Read(step).dims()[0];
auto dims = boot_state.dims();
dims[0] = num_instances;
state.Resize(dims);
state.mutable_data<value_type>(platform::CPUPlace());
states_[memory.var].WriteShared(step, state);
}
void DynamicRecurrentOp::LinkState(const rnn::MemoryAttr& memory,
size_t step) const {
auto& scope = cache_.GetScope(step);
auto& state_pre = *cache_.GetTensor(scope, memory.pre_var);
// all the step_inputs' metas should be the same, just randomly select one
// and get the dyseq meta.
const auto& some_meta = dy_seq_metas_[arg_.inlinks.front()];
size_t num_instances =
step_inputs_[arg_.inlinks.front()].Read(step).dims()[0];
LoDTensor* pre_state{nullptr};
if (step == 0) {
pre_state = cache_.GetTensor(*cache_.scope, memory.boot_var);
pre_state->mutable_data<float>(platform::CPUPlace());
// allocate memory
state_pre.Resize(pre_state->dims());
state_pre.mutable_data<value_type>(platform::CPUPlace());
detail::ReorderBootState<value_type>(some_meta, *pre_state, &state_pre,
pre_state->place());
} else {
pre_state = cache_.GetTensor(cache_.GetScope(step - 1), memory.var);
}
// shink and share from previous state
auto shrinked_pre_state = pre_state->Slice<value_type>(0, num_instances);
state_pre.ShareDataWith<value_type>(shrinked_pre_state);
}
void DynamicRecurrentOp::ArgCache::Init(
const rnn::ArgumentName& name, const paddle::framework::OperatorBase& op,
const paddle::framework::Scope& scope, rnn::Argument* arg) {
......@@ -261,6 +313,12 @@ Variable* DynamicRecurrentOp::ArgCache::GetVariable(const Scope& scope,
return var;
}
LoDTensor* DynamicRecurrentOp::ArgCache::GetTensor(
const framework::Scope& scope, const std::string& name) {
auto* var = GetVariable(scope, name);
return var->GetMutable<LoDTensor>();
}
const rnn::ArgumentName DynamicRecurrentOp::kArgName{
"step_net", "step_scopes", "inlinks", "outlinks",
"memories", "pre_memories", "boot_memories"};
......
......@@ -77,6 +77,17 @@ class DynamicRecurrentOp : public framework::OperatorBase {
*/
void InitStates() const;
/*
* Create state variables for each time step.
*/
void CreateState(const rnn::MemoryAttr& memory, size_t step) const;
/*
* Link pre-state variable in current scope to the state variable in the
* previous time step (scope).
*/
void LinkState(const rnn::MemoryAttr& memory, size_t step) const;
/*
* Concatenate outputs in each time step and generate a LoDTensor.
*/
......@@ -91,6 +102,16 @@ class DynamicRecurrentOp : public framework::OperatorBase {
}
const OperatorBase& GetStepNet() const { return *stepnet_; }
const framework::TensorArray& state(const std::string& name) const {
return states_[name];
}
const framework::TensorArray& step_input(const std::string& name) const {
return step_inputs_[name];
}
const framework::TensorArray& step_output(const std::string& name) const {
return step_outputs_[name];
}
protected:
struct ArgCache {
framework::Scope const* scope;
......@@ -108,6 +129,9 @@ class DynamicRecurrentOp : public framework::OperatorBase {
return *scopes->at(index);
}
framework::LoDTensor* GetTensor(const framework::Scope& scope,
const std::string& name);
private:
void InitArgument(const rnn::ArgumentName& name, const OperatorBase& op,
rnn::Argument* arg);
......@@ -122,7 +146,7 @@ class DynamicRecurrentOp : public framework::OperatorBase {
private:
std::unique_ptr<OperatorBase> stepnet_;
mutable framework::TensorArray states_;
mutable std::map<std::string, framework::TensorArray> states_;
mutable std::map<std::string, framework::TensorArray> step_inputs_;
mutable std::map<std::string, framework::TensorArray> step_outputs_;
mutable std::map<std::string, std::vector<framework::DySeqMeta>>
......
......@@ -87,7 +87,6 @@ class DynamicRecurrentOpTestHelper : public ::testing::Test {
platform::CPUPlace place;
scope.NewVar("step_scopes");
CreateVar(scope, "boot_mem", framework::make_ddim({10, 20}), place);
// auto* out0 =
CreateVar(scope, "out0", framework::make_ddim({10, 20}), place);
auto* in0 = CreateVar(scope, "in0", framework::make_ddim({10, 8}), place);
// 10 instanes with 4 sentences, length is 4, 3, 2, 1 respectively.
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/gru_unit_op.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class GRUUnitOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(%s) of GRUUnitOp should not be null.", "Input");
PADDLE_ENFORCE(ctx->HasInput("HiddenPrev"),
"Input(%s) of GRUUnitOp should not be null.", "HiddenPrev");
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(%s) of GRUUnitOp should not be null.", "Weight");
PADDLE_ENFORCE(ctx->HasOutput("Gate"),
"Output(%s) of GRUUnitOp should not be null.", "Gate");
PADDLE_ENFORCE(ctx->HasOutput("ResetHiddenPrev"),
"Output(%s) of GRUUnitOp should not be null.",
"ResetHiddenPrev");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(%s) of GRUUnitOp should not be null.", "Hidden");
auto input_dims = ctx->GetInputDim("Input");
auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev");
auto weight_dims = ctx->GetInputDim("Weight");
int batch_size = input_dims[0];
int input_size = input_dims[1];
int frame_size = hidden_prev_dims[1];
int weight_height = weight_dims[0];
int weight_width = weight_dims[1];
PADDLE_ENFORCE_EQ(
input_size, frame_size * 3,
"The input_size must be 3 times of frame_size in GRUUnitOp.");
PADDLE_ENFORCE_EQ(
weight_height, frame_size,
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
PADDLE_ENFORCE_EQ(
weight_width, frame_size * 3,
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
auto bias = Input("Bias");
if (bias != framework::kEmptyVarName) {
auto bias_dims = ctx->GetInputDim("Bias");
int bias_height = bias_dims[0];
int bias_width = bias_dims[1];
PADDLE_ENFORCE_EQ(bias_height, 1,
"The shape of Bias must be [1, frame_size * 3].");
PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
"The shape of Bias must be [1, frame_size * 3].");
}
ctx->SetOutputDim("Gate", {batch_size, frame_size * 3});
ctx->SetOutputDim("ResetHiddenPrev", {batch_size, frame_size});
ctx->SetOutputDim("Hidden", {batch_size, frame_size});
}
};
class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
GRUUnitOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input",
"(Tensor) Matrix with shape [batch_size, frame_size * 3] for the "
"input.");
AddInput("HiddenPrev",
"(Tensor) Matrix with shape [batch_size, frame_size] for the "
"states of previous time step.");
AddInput("Weight",
"(Tensor) Weight matrix with shape [frame_size, frame_size * 3]. "
"The elements continuous in memory can be divided into two parts. "
"The first part are weights of the update gate and reset gate "
"with shape [frame_size, frame_size * 2], and the second part are "
"weights of output candidate with shape [frame_size, frame_size]");
AddInput("Bias",
"(Tensor) Bias vector with shape [1, frame_size * 3] concating "
"bias of the update gate, reset gate and output candidate.");
AddOutput("Gate",
"(Tensor) Matrix with shape [batch_size, frame_size * 3] for the "
"output of update gate, reset gate and output candidate")
.AsIntermediate();
AddOutput("ResetHiddenPrev",
"(Tensor) Matrix with shape [batch_size, frame_size] for the "
"reseted hidden state of previous time step.")
.AsIntermediate();
AddOutput("Hidden",
"(Tensor) The GRU hidden state of the current time step "
"with shape [batch_size, frame_size].");
AddAttr<int>("activation",
"(enum int, default tanh) "
"The activation type used for output candidate {h}_t.")
.SetDefault(tanh)
.InEnum({identity, sigmoid, tanh, relu});
AddAttr<int>("gate_activation",
"(enum int, default sigmoid) "
"The activation type used in update gate and reset gate.")
.SetDefault(sigmoid)
.InEnum({identity, sigmoid, tanh, relu});
AddComment(R"DOC(
GRUUnitOp implements part calculations of the GRU unit as following:
\f[
update \ gate: u_t = actGate(xu_t + W_u * hidden_prev + bias_u) \\
reset \ gate: r_t = actGate(xr_t + W_r * hidden_prev + bias_r) \\
output \ candidate: {h}_t = actNode(xc_t + W_c * dot(r_t, hidden_prev) + bias_c) \\
output: h_t = dot((1-u_t), {h}_t) + dot(u_t, hidden_prev)
\f]
The rest of GRU unit can be completed by using FCOp's output as the input of GRUUnitOp.
)DOC");
}
};
class GRUUnitGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(%s) of GRUUnitGradOp should not be null.", "Input");
PADDLE_ENFORCE(ctx->HasInput("HiddenPrev"),
"Input(%s) of GRUUnitGradOp should not be null.",
"HiddenPrev");
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(%s) of GRUUnitGradOp should not be null.", "Weight");
PADDLE_ENFORCE(ctx->HasInput("Gate"),
"Input(%s) of GRUUnitGradOp should not be null.", "Gate");
PADDLE_ENFORCE(ctx->HasInput("ResetHiddenPrev"),
"Input(%s) of GRUUnitGradOp should not be null.",
"ResetHiddenPrev");
PADDLE_ENFORCE(ctx->HasInput("Hidden"),
"Input(%s) of GRUUnitGradOp should not be null.", "Hidden");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Gate")),
"Input(%s@GRAD) of GRUUnitGradOp should not be null.",
"Gate");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("ResetHiddenPrev")),
"Input(%s@GRAD) of GRUUnitGradOp should not be null.",
"ResetHiddenPrev");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
"Input(%s@GRAD) of GRUUnitGradOp should not be null.",
"Hidden");
auto input_dims = ctx->GetInputDim("Input");
auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev");
auto weight_dims = ctx->GetInputDim("Weight");
// int batch_size = input_dims[0];
int input_size = input_dims[1];
int frame_size = hidden_prev_dims[1];
int weight_height = weight_dims[0];
int weight_width = weight_dims[1];
PADDLE_ENFORCE_EQ(
input_size, frame_size * 3,
"The input_size must be 3 times of frame_size in GRUUnitOp.");
PADDLE_ENFORCE_EQ(
weight_height, frame_size,
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
PADDLE_ENFORCE_EQ(
weight_width, frame_size * 3,
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
auto bias = Input("Bias");
if (bias != framework::kEmptyVarName) {
auto bias_dims = ctx->GetInputDim("Bias");
int bias_height = bias_dims[0];
int bias_width = bias_dims[1];
PADDLE_ENFORCE_EQ(bias_height, 1,
"The shape of Bias must be [1, frame_size * 3].");
PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
"The shape of Bias must be [1, frame_size * 3].");
auto bias_grad_name = framework::GradVarName("Bias");
if (ctx->HasOutput(bias_grad_name))
ctx->SetOutputDim(bias_grad_name, bias_dims);
}
auto input_grad_name = framework::GradVarName("Input");
if (ctx->HasOutput(input_grad_name))
ctx->SetOutputDim(input_grad_name, input_dims);
auto hidden_prev_grad_name = framework::GradVarName("HiddenPrev");
if (ctx->HasOutput(hidden_prev_grad_name))
ctx->SetOutputDim(hidden_prev_grad_name, hidden_prev_dims);
auto weight_grad_name = framework::GradVarName("Weight");
if (ctx->HasOutput(weight_grad_name))
ctx->SetOutputDim(weight_grad_name, weight_dims);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(gru_unit, ops::GRUUnitOp, ops::GRUUnitOpMaker, gru_unit_grad,
ops::GRUUnitGradOp);
REGISTER_OP_CPU_KERNEL(gru_unit,
ops::GRUUnitKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
gru_unit_grad, ops::GRUUnitGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/gru_unit_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(gru_unit,
ops::GRUUnitKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
gru_unit_grad, ops::GRUUnitGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/operators/activation_op.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
enum GRUActivationType { identity = 0, sigmoid = 1, tanh = 2, relu = 3 };
template <typename Place, typename T>
class GRUUnitKernel : public framework::OpKernel<T> {
public:
template <typename Device, typename X, typename Y>
void ActCompute(const int act_type, const Device& d, X x, Y y) const {
if (act_type == identity)
y.device(d) = x;
else if (act_type == sigmoid)
SigmoidFunctor<T>()(d, x, y);
else if (act_type == tanh)
TanhFunctor<T>()(d, x, y);
else if (act_type == relu)
ReluFunctor<T>()(d, x, y);
else
PADDLE_THROW("unsupported activation type");
}
void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<Tensor>("Input");
auto* hidden_prev = context.Input<Tensor>("HiddenPrev");
auto* weight = context.Input<Tensor>("Weight");
auto* bias = context.Input<Tensor>("Bias");
auto* gate = context.Output<Tensor>("Gate");
gate->mutable_data<T>(context.GetPlace());
auto* reset_hidden_prev = context.Output<Tensor>("ResetHiddenPrev");
reset_hidden_prev->mutable_data<T>(context.GetPlace());
auto* hidden = context.Output<Tensor>("Hidden");
hidden->mutable_data<T>(context.GetPlace());
int batch_size = input->dims()[0];
int frame_size = hidden_prev->dims()[1];
auto x = EigenMatrix<T>::From(*input);
auto h_p = EigenMatrix<T>::From(*hidden_prev);
auto g = EigenMatrix<T>::From(*gate);
auto r_h_p = EigenMatrix<T>::From(*reset_hidden_prev);
auto h = EigenMatrix<T>::From(*hidden);
auto place = context.GetEigenDevice<Place>();
// calculate unactivated gate outputs
if (bias) {
auto b = EigenMatrix<T>::From(*bias);
g.device(place) = x +
b.reshape(Eigen::array<int, 2>({{1, frame_size * 3}}))
.broadcast(Eigen::array<int, 2>({{batch_size, 1}}));
} else {
g.device(place) = x;
}
const T* hidden_prev_data = hidden_prev->data<T>();
const T* weight_data = weight->data<T>();
T* gate_data = gate->data<T>();
T* reset_hidden_prev_data = reset_hidden_prev->data<T>();
math::gemm<Place, T>(context.device_context(), false, false, batch_size,
2 * frame_size, frame_size, 1, hidden_prev_data,
frame_size, weight_data, frame_size * 2, 1, gate_data,
frame_size * 3);
// calculate activited gate
Eigen::array<int, 2> extents({{batch_size, frame_size}});
Eigen::array<int, 2> u_offsets({{0, 0}});
ActCompute(context.Attr<int>("gate_activation"), place,
g.slice(u_offsets, extents), g.slice(u_offsets, extents));
auto u = g.slice(u_offsets, extents); // update gate
Eigen::array<int, 2> r_offsets({{0, frame_size}});
ActCompute(context.Attr<int>("gate_activation"), place,
g.slice(r_offsets, extents), g.slice(r_offsets, extents));
auto r = g.slice(r_offsets, extents); // reset gate
r_h_p.device(place) = r * h_p; // reset previous hidden state
math::gemm<Place, T>(context.device_context(), false, false, batch_size,
frame_size, frame_size, 1, reset_hidden_prev_data,
frame_size, weight_data + frame_size * frame_size * 2,
frame_size, 1, gate_data + frame_size * 2,
frame_size * 3);
Eigen::array<int, 2> c_offsets({{0, frame_size * 2}});
ActCompute(context.Attr<int>("activation"), place,
g.slice(c_offsets, extents), g.slice(c_offsets, extents));
auto c = g.slice(c_offsets, extents); // output candidate
// calculate final output
h.device(place) = u * (h_p - c) + c;
}
};
template <typename Place, typename T>
class GRUUnitGradKernel : public framework::OpKernel<T> {
public:
template <typename Device, typename X, typename Y, typename DX, typename DY>
void ActGradCompute(const int act_type, const Device& d, X x, Y y, DX dx,
DY dy) const {
// x is dummy and won't be used even in Relu(use y instead)
if (act_type == identity)
dx.device(d) = dy;
else if (act_type == sigmoid)
SigmoidGradFunctor<T>()(d, x, y, dy, dx);
else if (act_type == tanh)
TanhGradFunctor<T>()(d, x, y, dy, dx);
else if (act_type == relu)
ReluGradFunctor<T>()(d, x, y, dy, dx);
else
PADDLE_THROW("unsupported activation type");
}
void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<Tensor>("Input");
auto* hidden_prev = context.Input<Tensor>("HiddenPrev");
auto* weight = context.Input<Tensor>("Weight");
auto* gate = context.Input<Tensor>("Gate");
auto* reset_hidden_prev = context.Input<Tensor>("ResetHiddenPrev");
auto* hidden_grad = context.Input<Tensor>(framework::GradVarName("Hidden"));
auto* input_grad = context.Output<Tensor>(framework::GradVarName("Input"));
auto* hidden_prev_grad =
context.Output<Tensor>(framework::GradVarName("HiddenPrev"));
auto* weight_grad =
context.Output<Tensor>(framework::GradVarName("Weight"));
auto* bias_grad = context.Output<Tensor>(framework::GradVarName("Bias"));
input_grad->mutable_data<T>(context.GetPlace());
hidden_prev_grad->mutable_data<T>(context.GetPlace());
weight_grad->mutable_data<T>(context.GetPlace());
Tensor gate_grad;
gate_grad.mutable_data<T>(input->dims(), context.GetPlace());
Tensor reset_hidden_prev_grad;
reset_hidden_prev_grad.mutable_data<T>(reset_hidden_prev->dims(),
context.GetPlace());
int batch_size = input->dims()[0];
int frame_size = hidden_prev->dims()[1];
const T* hidden_prev_data = hidden_prev->data<T>();
T* hidden_prev_grad_data = hidden_prev_grad->data<T>();
const T* weight_data = weight->data<T>();
T* weight_grad_data = weight_grad->data<T>();
T* gate_grad_data = gate_grad.data<T>();
const T* reset_hidden_prev_data = reset_hidden_prev->data<T>();
T* reset_hidden_prev_grad_data = reset_hidden_prev_grad.data<T>();
auto h_p = EigenMatrix<T>::From(*hidden_prev);
auto g = EigenMatrix<T>::From(*gate);
auto d_h = EigenMatrix<T>::From(*hidden_grad);
auto d_x = EigenMatrix<T>::From(*input_grad);
auto d_h_p = EigenMatrix<T>::From(*hidden_prev_grad);
auto d_g = EigenMatrix<T>::From(gate_grad);
auto d_r_h_p = EigenMatrix<T>::From(reset_hidden_prev_grad);
auto place = context.GetEigenDevice<Place>();
Eigen::array<int, 2> extents({{batch_size, frame_size}});
Eigen::array<int, 2> u_offsets({{0, 0}});
auto u = g.slice(u_offsets, extents); // update gate
Eigen::array<int, 2> r_offsets({{0, frame_size}});
auto r = g.slice(r_offsets, extents); // reset gate
Eigen::array<int, 2> c_offsets({{0, frame_size * 2}});
auto c = g.slice(c_offsets, extents); // output candidate
// backward for unactivated update gate
ActGradCompute(context.Attr<int>("gate_activation"), place, u, u,
d_g.slice(u_offsets, extents), d_h * (h_p - c));
// backward for unactivated output candidate
ActGradCompute(context.Attr<int>("activation"), place, c, c,
d_g.slice(c_offsets, extents), d_h * (u.constant(T(1)) - u));
// backward for reset_hidden_prev
math::gemm<Place, T>(context.device_context(), false, true, batch_size,
frame_size, frame_size, 1,
gate_grad_data + frame_size * 2, frame_size * 3,
weight_data + frame_size * frame_size * 2, frame_size,
0, reset_hidden_prev_grad_data, frame_size);
// backward for state_weight
math::gemm<Place, T>(
context.device_context(), true, false, frame_size, frame_size,
batch_size, 1, reset_hidden_prev_data, frame_size,
gate_grad_data + frame_size * 2, frame_size * 3, 0,
weight_grad_data + frame_size * frame_size * 2, frame_size);
// backward for unactivated reset gate
ActGradCompute(context.Attr<int>("gate_activation"), place, r, r,
d_g.slice(r_offsets, extents), d_r_h_p * h_p);
// backward for update_gate_weight and reset_gate_weight
math::gemm<Place, T>(context.device_context(), true, false, frame_size,
frame_size * 2, batch_size, 1, hidden_prev_data,
frame_size, gate_grad_data, frame_size * 3, 0,
weight_grad_data, frame_size * 2);
// backward for hidden_prev
d_h_p.device(place) = d_r_h_p * r + d_h * u;
math::gemm<Place, T>(context.device_context(), false, true, batch_size,
frame_size, frame_size * 2, 1, gate_grad_data,
frame_size * 3, weight_data, frame_size * 2, 1,
hidden_prev_grad_data, frame_size);
// backward for input
d_x.device(place) = d_g;
// backward for bias
if (bias_grad) {
bias_grad->mutable_data<T>(context.GetPlace());
auto d_b = EigenMatrix<T>::From(*bias_grad);
d_b.device(place) = d_g.sum(Eigen::array<int, 1>({{0}}));
}
}
};
} // namespace operators
} // namespace paddle
......@@ -34,7 +34,7 @@ class SumOp : public framework::OperatorWithKernel {
auto in_dim = x_dims[0];
for (size_t i = 1; i < N; i++) {
auto dim = x_dims[i];
PADDLE_ENFORCE(in_dim == dim, "Input tensors must have same shape");
PADDLE_ENFORCE_EQ(in_dim, dim, "Input tensors must have same shape");
}
ctx->SetOutputDim("Out", in_dim);
ctx->ShareLoD("X", /*->*/ "Out");
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/tensor_array.h"
#include "paddle/operators/cond_op.h"
#include "paddle/operators/dynamic_recurrent_op.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/recurrent_op.h"
#include "paddle/platform/enforce.h"
......@@ -341,6 +342,33 @@ All parameter, weight, gradient are variables in Paddle.
self.set_stepnet(net.Clone());
});
py::class_<operators::DynamicRecurrentOp, OperatorBase>(m,
"DynamicRecurrentOp")
.def_static("create",
[](py::bytes protobin) -> operators::DynamicRecurrentOp * {
OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
auto rnn_op = OpRegistry::CreateOp(desc);
return static_cast<operators::DynamicRecurrentOp *>(
rnn_op.release());
})
.def("set_stepnet",
[](operators::DynamicRecurrentOp &self, const operators::NetOp &net)
-> void { self.SetStepNet(net.Clone()); })
.def("get_state",
[](operators::DynamicRecurrentOp &self, const std::string &name)
-> const TensorArray & { return self.state(name); })
.def("get_step_input",
[](operators::DynamicRecurrentOp &self, const std::string &name)
-> const TensorArray & { return self.step_input(name); })
.def("get_step_output",
[](operators::DynamicRecurrentOp &self, const std::string &name)
-> const TensorArray & { return self.step_output(name); });
// cond_op
py::class_<operators::CondOp, OperatorBase>(m, "CondOp")
.def_static("create",
......
......@@ -219,6 +219,27 @@ class __RecurrentOp__(object):
return core.RecurrentOp.create(proto.SerializeToString())
class __DynamicRecurrentOp__(object):
__proto__ = None
type = "dynamic_recurrent"
def __init__(self):
# cache recurrent_op's proto
if self.__proto__ is None:
for op_proto in get_all_op_protos():
if op_proto.type == self.type:
self.__proto__ = op_proto
def __call__(self, *args, **kwargs):
if self.type not in args and "type" not in kwargs:
kwargs["type"] = self.type
# create proto
create_method = OpDescCreationMethod(self.__proto__)
proto = create_method(*args, **kwargs)
# create rnnop
return core.DynamicRecurrentOp.create(proto.SerializeToString())
class __CondOp__(object):
__proto__ = None
type = "cond"
......@@ -242,4 +263,5 @@ class __CondOp__(object):
Operator = OperatorFactory() # The default global factory
RecurrentOp = __RecurrentOp__()
DynamicRecurrentOp = __DynamicRecurrentOp__()
CondOp = __CondOp__()
......@@ -384,5 +384,33 @@ class TestThresholdedRelu(OpTest):
self.check_grad(['X'], 'Y', max_relative_error=self.relative_error)
class TestHardSigmoid(OpTest):
def setUp(self):
self.op_type = "hard_sigmoid"
self.relative_error = 0.002
X = np.random.uniform(-5, 5, [2, 2]).astype("float32")
slope = 0.2
offset = 0.5
lower_threshold = -offset / slope
upper_threshold = (1 - offset) / slope
self.inputs = {'X': X}
# Same reason as TestAbs
X[np.abs(X - lower_threshold) < self.relative_error] = \
lower_threshold + 0.2
X[np.abs(X - upper_threshold) < self.relative_error] = \
upper_threshold - 0.2
temp = X * slope + offset
self.outputs = {'Y': np.maximum(0.0, np.minimum(1.0, temp))}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.002)
if __name__ == "__main__":
unittest.main()
import logging
import paddle.v2.framework.core as core
import unittest
from paddle.v2.framework.op import Operator, DynamicRecurrentOp
import numpy as np
def create_tensor(scope, name, shape, np_data):
tensor = scope.new_var(name).get_tensor()
tensor.set_dims(shape)
tensor.set(np_data, core.CPUPlace())
return tensor
class DynamicRecurrentOpTest(unittest.TestCase):
'''
Test RNNOp
equation:
h_t = \sigma (W x_t + U h_{t-1})
weights:
- W
- U
vars:
- x
memories:
- h
outputs:
- h
'''
# for siplicity, just one level LoD
lod_py = [[0, 4, 7, 9, 10]]
input_dim = 30
num_sents = len(lod_py[0]) - 1
weight_dim = 15
def forward(self):
self.scope = core.Scope()
self.create_global_variables()
self.create_rnn_op()
self.create_step_net()
ctx = core.DeviceContext.create(core.CPUPlace())
self.rnnop.run(self.scope, ctx)
state = self.rnnop.get_state("h@mem")
print 'state size: ', state.size()
step_inputs = self.rnnop.get_step_input("x")
print "x size ", step_inputs.size()
for i in range(step_inputs.size()):
print "x %d" % i, np.array(step_inputs.read(i).get_dims())
step_outputs = self.rnnop.get_step_output('h@mem')
print 'step_outputs.size ', step_outputs.size()
output = self.scope.find_var("h@mem").get_tensor()
print 'output', np.array(output).shape
def create_global_variables(self):
x = np.random.normal(size=(self.lod_py[0][-1],
self.input_dim)).astype("float32")
W = np.random.normal(size=(self.input_dim,
self.input_dim)).astype("float32")
U = np.random.normal(size=(self.input_dim,
self.input_dim)).astype("float32")
h_boot = np.random.normal(size=(self.num_sents,
self.input_dim)).astype("float32")
# create inlink
x_tensor = create_tensor(self.scope, "x",
[self.num_sents, self.input_dim], x)
x_tensor.set_lod(self.lod_py)
create_tensor(self.scope, "W", [self.input_dim, self.input_dim], W)
create_tensor(self.scope, "U", [self.input_dim, self.input_dim], U)
create_tensor(self.scope, "h_boot", [self.num_sents, self.input_dim],
h_boot)
self.scope.new_var("step_scopes")
self.scope.new_var("h@mem")
def create_rnn_op(self):
# create RNNOp
self.rnnop = DynamicRecurrentOp(
# inputs
inlinks=["x"],
boot_memories=["h_boot"],
step_net="stepnet",
# outputs
outlinks=["h@mem"],
step_scopes="step_scopes",
# attributes
pre_memories=["h@pre"],
memories=["h@mem"])
def create_step_net(self):
stepnet = core.Net.create()
x_fc_op = Operator("mul", X="x", Y="W", Out="Wx")
h_fc_op = Operator("mul", X="h@pre", Y="U", Out="Uh")
sum_op = Operator("sum", X=["Wx", "Uh"], Out="sum")
sig_op = Operator("sigmoid", X="sum", Y="h@mem")
for op in [x_fc_op, h_fc_op, sum_op, sig_op]:
stepnet.append_op(op)
stepnet.complete_add_op(True)
self.rnnop.set_stepnet(stepnet)
def test_forward(self):
print 'test recurrent op forward'
pd_output = self.forward()
print 'pd_output', pd_output
if __name__ == '__main__':
unittest.main()
import math
import unittest
import numpy as np
from op_test import OpTest
class GRUActivationType(OpTest):
identity = 0
sigmoid = 1
tanh = 2
relu = 3
def identity(x):
return x
def sigmoid(x):
return 1. / (1. + np.exp(-x))
def tanh(x):
return 2. * sigmoid(2. * x) - 1.
def relu(x):
return np.maximum(x, 0)
class TestGRUUnitOp(OpTest):
batch_size = 3
frame_size = 5
activate = {
GRUActivationType.identity: identity,
GRUActivationType.sigmoid: sigmoid,
GRUActivationType.tanh: tanh,
GRUActivationType.relu: relu,
}
def set_inputs(self):
batch_size = self.batch_size
frame_size = self.frame_size
self.op_type = 'gru_unit'
self.inputs = {
'Input': np.random.uniform(
-0.1, 0.1, (batch_size, frame_size * 3)).astype('float32'),
'HiddenPrev': np.random.uniform(
-0.1, 0.1, (batch_size, frame_size)).astype('float32'),
'Weight': np.random.uniform(
-1. / math.sqrt(frame_size), 1. / math.sqrt(frame_size),
(frame_size, frame_size * 3)).astype('float32'),
}
self.attrs = {
'activation': GRUActivationType.tanh,
'gate_activation': GRUActivationType.sigmoid
}
def set_outputs(self):
# GRU calculations
batch_size = self.batch_size
frame_size = self.frame_size
x = self.inputs['Input']
h_p = self.inputs['HiddenPrev']
w = self.inputs['Weight']
b = self.inputs['Bias'] if self.inputs.has_key('Bias') else np.zeros(
(1, frame_size * 3))
g = x + np.tile(b, (batch_size, 1))
w_u_r = w.flatten()[:frame_size * frame_size * 2].reshape(
(frame_size, frame_size * 2))
u_r = self.activate[self.attrs['gate_activation']](np.dot(
h_p, w_u_r) + g[:, :frame_size * 2])
u = u_r[:, :frame_size]
r = u_r[:, frame_size:frame_size * 2]
r_h_p = r * h_p
w_c = w.flatten()[frame_size * frame_size * 2:].reshape(
(frame_size, frame_size))
c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) +
g[:, frame_size * 2:])
g = np.hstack((u_r, c))
h = u * h_p + (1 - u) * c
self.outputs = {'Gate': g, 'ResetHiddenPrev': r_h_p, 'Hidden': h}
def setUp(self):
self.set_inputs()
self.set_outputs()
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(
['Input', 'HiddenPrev', 'Weight'], ['Hidden'],
max_relative_error=0.007)
class TestGRUUnitOpWithBias(TestGRUUnitOp):
def set_inputs(self):
batch_size = self.batch_size
frame_size = self.frame_size
super(TestGRUUnitOpWithBias, self).set_inputs()
self.inputs['Bias'] = np.random.uniform(
-0.1, 0.1, (1, frame_size * 3)).astype('float32')
self.attrs = {
'activation': GRUActivationType.identity,
'gate_activation': GRUActivationType.sigmoid
}
def test_check_grad(self):
self.check_grad(
['Input', 'HiddenPrev', 'Weight', 'Bias'], ['Hidden'],
max_relative_error=0.007)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册