diff --git a/paddle/framework/tensor_array.cc b/paddle/framework/tensor_array.cc
index 7ae16e99cdb8a23f14f0c8b684ba4ec66a4ce074..06459cbfd7b8c19c176452ff73c9f3a81ba1dc03 100644
--- a/paddle/framework/tensor_array.cc
+++ b/paddle/framework/tensor_array.cc
@@ -76,6 +76,17 @@ LoDTensor PackDynamicBatch(const std::vector<LoDTensor>& source,
                            const std::vector<DySeqMeta>& meta, const LoD& lod,
                            size_t level);
 
+std::vector<size_t> GenDyBatchIndice(const DySeqMetaBatch& meta, int batch_id) {
+  // collect the indices that need to be copied into this batch
+  std::vector<size_t> indice;
+  for (const auto& seq : meta) {
+    size_t id = seq.begin + batch_id;
+    if (id >= seq.end) break;
+    indice.push_back(id);
+  }
+  return indice;
+}
+
 }  // namespace detail
 
 const LoDTensor& TensorArray::Read(size_t index) const {
@@ -113,8 +124,8 @@ LoDTensor TensorArray::Pack(size_t level, const std::vector<DySeqMeta>& meta,
   return detail::PackDynamicBatch(values_, meta, lod, level);
 }
 
-std::vector<DySeqMeta> TensorArray::Unpack(const LoDTensor& source, int level,
-                                           bool length_desend) {
+DySeqMetaBatch TensorArray::Unpack(const LoDTensor& source, int level,
+                                   bool length_desend) {
   detail::DynamicBatchUnpacker unpacker(source, level,
                                         length_desend /*descend*/);
 
@@ -129,6 +140,7 @@ std::vector<DySeqMeta> TensorArray::Unpack(const LoDTensor& source, int level,
     Write(batch_id, unpacker.GetBatch(batch_id));
   }
 
+  PADDLE_ENFORCE(!unpacker.meta.empty());
   return unpacker.meta;
 }
 
@@ -218,13 +230,7 @@ LoDTensor DynamicBatchUnpacker::GetBatch(size_t index) {
   PADDLE_ENFORCE(!meta.empty(), "should build meta first");
   LoDTensor result;
 
-  // collect indice need to copy to the batch
-  std::vector<size_t> indice;
-  for (const auto& seq : meta) {
-    size_t id = seq.begin + index;
-    if (id >= seq.end) break;
-    indice.push_back(id);
-  }
+  auto indice = detail::GenDyBatchIndice(meta, index);
   PADDLE_ENFORCE(!indice.empty(), "invalid batch at %d", index);
 
   // copy the indice of records in LoDTensor
@@ -237,9 +243,9 @@ LoDTensor DynamicBatchUnpacker::GetBatch(size_t index) {
   for (size_t i = 0; i < indice.size(); i++) {
     auto index = indice[i];
     auto target = result.Slice(i, i + 1);
-    auto source_ = source->Slice(index, index + 1);
+    auto slice = source->Slice(index, index + 1);
 
-    target.CopyFrom(source_, platform::CPUPlace(),
+    target.CopyFrom(slice, platform::CPUPlace(),
                     platform::CPUDeviceContext());
   }
 
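A small Python model of `detail::GenDyBatchIndice` may make the early `break` easier to verify: `Unpack` sorts the sequences by length in descending order, so once one sequence is exhausted at a given time step, every later one is exhausted too. The function and meta tuples below are illustrative stand-ins, not part of the Paddle API:

```python
# Sketch of the index-gathering logic in detail::GenDyBatchIndice.
# `metas` models DySeqMeta as (begin, end) pairs into the instance-major
# LoDTensor, already sorted by length in descending order.

def gen_dy_batch_indice(metas, batch_id):
    indice = []
    for begin, end in metas:
        idx = begin + batch_id
        if idx >= end:
            # sequences are sorted longest-first, so every later sequence
            # is exhausted at this time step as well; stop early
            break
        indice.append(idx)
    return indice

# Three sequences of lengths 4, 2 and 1: time step 0 takes one instance
# from each, time step 1 only from the first two, and so on.
metas = [(0, 4), (4, 6), (6, 7)]
assert gen_dy_batch_indice(metas, 0) == [0, 4, 6]
assert gen_dy_batch_indice(metas, 1) == [1, 5]
assert gen_dy_batch_indice(metas, 3) == [3]
```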
diff --git a/paddle/framework/tensor_array.h b/paddle/framework/tensor_array.h
index 293da04997304be41810446cb3e866d545805f83..046ecb5221b7ed9d88e5017348ee8fcde23c7677 100644
--- a/paddle/framework/tensor_array.h
+++ b/paddle/framework/tensor_array.h
@@ -34,6 +34,13 @@ struct DySeqMeta {
   size_t ori_idx;
 };
 
+using DySeqMetaBatch = std::vector<DySeqMeta>;
+
+/*
+ * Extract the indices of instances.
+ */
+std::vector<size_t> GenDyBatchIndice(const DySeqMetaBatch &metas, int batch_id);
+
 /*
  * TensorArray is a C-array-like array of tensors, it is meant to be used with
  * dynamic iteration primitives such as while_loop. It is used to segment inputs
@@ -69,7 +76,7 @@ class TensorArray {
    * Recover the original LoD-arranged LoDTensor with the `values`, `level` and
    * `indice_map`.
    */
-  LoDTensor Pack(size_t level, const std::vector<DySeqMeta> &meta,
+  LoDTensor Pack(size_t level, const DySeqMetaBatch &meta,
                  const LoD &lod) const;
 
   /*
@@ -77,8 +84,7 @@ class TensorArray {
    * `values`, if set `desend`, will sort by length in descending order else in
    * ascending order.
    */
-  std::vector<DySeqMeta> Unpack(const LoDTensor &source, int level,
-                                bool length_desend);
+  DySeqMetaBatch Unpack(const LoDTensor &source, int level, bool length_desend);
 
   /*
    * Pack the values into a tensor with rank one higher than each tensor in values
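For a concrete picture of what a `DySeqMetaBatch` holds, here is a rough Python model built from the LoD used by the Python test further down in this patch (`[[0, 4, 7, 9, 10]]`). `DySeqMeta` carries `begin`, `end` and `ori_idx`; `ori_idx` is what lets `Pack` restore the original sequence order later. This is a sketch of the semantics, not the C++ implementation:

```python
lod_level = [0, 4, 7, 9, 10]            # 4 sequences of lengths 4, 3, 2, 1
seqs = [(lod_level[i], lod_level[i + 1], i)
        for i in range(len(lod_level) - 1)]
# sort by length, descending, as Unpack does when length_desend is set
metas = sorted(seqs, key=lambda s: s[1] - s[0], reverse=True)
assert metas == [(0, 4, 0), (4, 7, 1), (7, 9, 2), (9, 10, 3)]
# each tuple models (begin, end, ori_idx); this input happened to be sorted
# already, so ori_idx equals the sorted position here
```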
diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index ad941bde2be3bbbc6d910fff262ea4cb3878f8be..75fcc1cda165197fc4413efc6bbbc440088cb4cd 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -84,8 +84,9 @@ function(op_library TARGET)
     endif()
 
     # pybind USE_NO_KERNEL_OP
+    # HACK: if REGISTER_OP_CPU_KERNEL is present, the operator must have a kernel
     file(READ ${TARGET}.cc TARGET_CONTENT)
-    string(REGEX MATCH "OperatorWithKernel" regex_result "${TARGET_CONTENT}")
+    string(REGEX MATCH "REGISTER_OP_CPU_KERNEL" regex_result "${TARGET_CONTENT}")
     string(REPLACE "_op" "" TARGET "${TARGET}")
     if (${pybind_flag} EQUAL 0 AND regex_result STREQUAL "")
         file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(${TARGET});\n")
diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc
index cba57ba57f5e03c7861897e177cc09aa513e5395..84c3775b4fc2602e5df9bb454d21b318b8fda493 100644
--- a/paddle/operators/activation_op.cc
+++ b/paddle/operators/activation_op.cc
@@ -338,6 +338,38 @@ class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
+template <typename AttrType>
+class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  HardSigmoidOpMaker(framework::OpProto *proto,
+                     framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "Input of HardSigmoid operator");
+    AddOutput("Y", "Output of HardSigmoid operator");
+    AddComment(R"DOC(
+Hard Sigmoid activation operator.
+
+A segment-wise linear approximation of sigmoid [1], which is much faster
+than sigmoid:
+
+hard_sigmoid = max(0, min(1, slope * x + offset))
+
+The slope should be positive. The offset can be either positive or negative.
+The default slope and offset are taken from [1].
+It is recommended to use the defaults for this activation.
+
+References:
+  [1] Noisy Activation Functions
+      (https://arxiv.org/abs/1603.00391)
+
+)DOC");
+    AddAttr<AttrType>("slope", "Slope for linear approximation of sigmoid")
+        .SetDefault(static_cast<AttrType>(0.2));
+    AddAttr<AttrType>("offset", "Offset for linear approximation of sigmoid")
+        .SetDefault(static_cast<AttrType>(0.5));
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -413,6 +445,9 @@ REGISTER_OP(thresholded_relu, ops::ActivationOp, ops::ThresholdedReluOpMaker<float>,
             thresholded_relu_grad, ops::ActivationOpGrad);
 
+REGISTER_OP(hard_sigmoid, ops::ActivationOp, ops::HardSigmoidOpMaker<float>,
+            hard_sigmoid_grad, ops::ActivationOpGrad);
+
 #define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor)  \
   REGISTER_OP_CPU_KERNEL(                                                \
       act_type,                                                          \
diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h
index 502c33be103c465c14f128be38ac62d029f1bfb9..4f4eb44fedc0a89cdcf60fb7177014a11eb96048 100644
--- a/paddle/operators/activation_op.h
+++ b/paddle/operators/activation_op.h
@@ -616,30 +616,63 @@ struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct HardSigmoidFunctor : public BaseActivationFunctor<T> {
+  float slope;
+  float offset;
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"slope", &slope}, {"offset", &offset}};
+  }
+
+  template <typename Device, typename X, typename Y>
+  void operator()(Device d, X x, Y y) const {
+    auto temp = x * static_cast<T>(slope) + static_cast<T>(offset);
+    y.device(d) = temp.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(1));
+  }
+};
+
+template <typename T>
+struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
+  float slope;
+  float offset;
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"slope", &slope}, {"offset", &offset}};
+  }
+
+  template <typename Device, typename X, typename Y, typename dY, typename dX>
+  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
+    dx.device(d) =
+        dy *
+        ((y > static_cast<T>(0)) * (y < static_cast<T>(1))).template cast<T>() *
+        static_cast<T>(slope);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
-#define FOR_EACH_KERNEL_FUNCTOR(__macro) \
-  __macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \
-  __macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \
-  __macro(exp, ExpFunctor, ExpGradFunctor); \
-  __macro(relu, ReluFunctor, ReluGradFunctor); \
-  __macro(tanh, TanhFunctor, TanhGradFunctor); \
-  __macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \
-  __macro(sqrt, SqrtFunctor, SqrtGradFunctor); \
-  __macro(abs, AbsFunctor, AbsGradFunctor); \
-  __macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
-  __macro(log, LogFunctor, LogGradFunctor); \
-  __macro(square, SquareFunctor, SquareGradFunctor); \
-  __macro(brelu, BReluFunctor, BReluGradFunctor); \
-  __macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \
-  __macro(pow, PowFunctor, PowGradFunctor); \
-  __macro(stanh, STanhFunctor, STanhGradFunctor); \
-  __macro(softplus, SoftplusFunctor, SoftplusGradFunctor); \
-  __macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \
-  __macro(relu6, Relu6Functor, Relu6GradFunctor); \
-  __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \
-  __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \
-  __macro(elu, ELUFunctor, ELUGradFunctor); \
-  __macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor); \
+#define FOR_EACH_KERNEL_FUNCTOR(__macro)                             \
+  __macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor);              \
+  __macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor);     \
+  __macro(exp, ExpFunctor, ExpGradFunctor);                          \
+  __macro(relu, ReluFunctor, ReluGradFunctor);                       \
+  __macro(tanh, TanhFunctor, TanhGradFunctor);                       \
+  __macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor);     \
+  __macro(sqrt, SqrtFunctor, SqrtGradFunctor);                       \
+  __macro(abs, AbsFunctor, AbsGradFunctor);                          \
+  __macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor);     \
+  __macro(log, LogFunctor, LogGradFunctor);                          \
+  __macro(square, SquareFunctor, SquareGradFunctor);                 \
+  __macro(brelu, BReluFunctor, BReluGradFunctor);                    \
+  __macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor);          \
+  __macro(pow, PowFunctor, PowGradFunctor);                          \
+  __macro(stanh, STanhFunctor, STanhGradFunctor);                    \
+  __macro(softplus, SoftplusFunctor, SoftplusGradFunctor);           \
+  __macro(softsign, SoftsignFunctor, SoftsignGradFunctor);           \
+  __macro(relu6, Relu6Functor, Relu6GradFunctor);                    \
+  __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor);       \
+  __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor);    \
+  __macro(elu, ELUFunctor, ELUGradFunctor);                          \
+  __macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor);    \
+  __macro(hard_sigmoid, HardSigmoidFunctor, HardSigmoidGradFunctor); \
   __macro(thresholded_relu, ThresholdedReluFunctor, ThresholdedReluGradFunctor);
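A NumPy model of the two functors above, handy for eyeballing the math. The forward pass is y = clip(slope * x + offset, 0, 1); the gradient is `slope` strictly inside the linear segment and zero on the saturated ends, which is exactly what the `(y > 0) * (y < 1)` mask computes. This is a sketch with illustrative names, not the operator itself:

```python
import numpy as np


def hard_sigmoid(x, slope=0.2, offset=0.5):
    # forward: clip the linear segment into [0, 1]
    return np.clip(slope * x + offset, 0.0, 1.0)


def hard_sigmoid_grad(y, dy, slope=0.2):
    # backward: slope inside the open interval (0, 1), zero when saturated
    return dy * ((y > 0) & (y < 1)).astype(y.dtype) * slope


x = np.array([-4.0, -1.0, 0.0, 1.0, 4.0])
y = hard_sigmoid(x)                          # [0. , 0.3, 0.5, 0.7, 1. ]
dx = hard_sigmoid_grad(y, np.ones_like(y))   # [0. , 0.2, 0.2, 0.2, 0. ]
```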
diff --git a/paddle/operators/dynamic_recurrent_op.cc b/paddle/operators/dynamic_recurrent_op.cc
index b919aef8fb62e5b2331c2d842556e0642ea6b095..58a5bf3e3651c963eead6dc0b8a3497c65b0eff2 100644
--- a/paddle/operators/dynamic_recurrent_op.cc
+++ b/paddle/operators/dynamic_recurrent_op.cc
@@ -23,6 +23,7 @@ using framework::Scope;
 using framework::TensorArray;
 using framework::LoDTensor;
 using framework::Variable;
+using framework::DySeqMetaBatch;
 
 namespace detail {
 
@@ -33,6 +34,29 @@ inline void CreateVariables(Scope& scope,
   }
 }
 
+/*
+ * The sequence inputs are reordered when they are split, so the boot states
+ * have to be reordered in the same order.
+ *
+ * NOTE This requires the `pre_state` of the first time step to copy the
+ * `boot_state` rather than reference it, because the content has to be
+ * reordered, while the RNN op must not change the content of its
+ * `boot_state` input variable.
+ */
+template <typename T>
+inline void ReorderBootState(const DySeqMetaBatch& metas,
+                             const LoDTensor& boot_state, LoDTensor* tensor,
+                             const platform::Place& dst_place) {
+  for (size_t seq_id = 0; seq_id < metas.size(); seq_id++) {
+    auto slice = tensor->Slice(seq_id, seq_id + 1);
+    auto boot_slice =
+        boot_state.Slice(metas[seq_id].ori_idx, metas[seq_id].ori_idx + 1);
+    // TODO(superjom) pass in device context as an argument
+    slice.template CopyFrom<T>(boot_slice, dst_place,
+                               platform::CPUDeviceContext());
+  }
+}
+
 }  // namespace detail
 
 class DynamicRecurrentOpProtoAndCheckerMaker
@@ -69,6 +93,7 @@ void DynamicRecurrentOp::Run(const Scope& scope,
   CreateScopes();
   WriteStepInputs();
   InitStates();
+  WriteStepOutputs();
 
   // call stepnet in all the time steps
   for (size_t step = 0; step < cache_.num_steps; step++) {
@@ -76,7 +101,6 @@ void DynamicRecurrentOp::Run(const Scope& scope,
     stepnet_->Run(step_scope, dev_ctx);
   }
 
-  WriteStepOutputs();
   ConcatOutputs();
 }
 
@@ -84,11 +108,11 @@ void DynamicRecurrentOp::SplitInputs() const {
   // TODO(superjom) make level a config
   // TODO(superjom) check all the inputs has the same LoD
   int level = 0;
-  const auto& inlinks = cache_.inlinks;
-  for (const auto& item : inlinks) {
+  for (const auto& item : cache_.inlinks) {
     const auto& var = item.second;
     const auto& tensor = var->Get<LoDTensor>();
     TensorArray& ta = step_inputs_[item.first];
+
     dy_seq_metas_[item.first] =
         ta.Unpack(tensor, level, true /*length_descend*/);
 
@@ -120,17 +144,11 @@ void DynamicRecurrentOp::WriteStepInputs() const {
 }
 
 void DynamicRecurrentOp::WriteStepOutputs() const {
-  for (size_t step = 0; step < cache_.scopes->size(); step++) {
-    auto& scope = cache_.GetScope(step);
-    for (auto& item : step_outputs_) {
-      auto* var = scope.FindVar(item.first);
-      if (var == nullptr) {
-        var = scope.NewVar(item.first);
-      }
-      auto* tensor = var->GetMutable<LoDTensor>();
-      item.second.WriteShared(step, *tensor);
-    }
+  // initialize step outputs
+  for (const auto& item : cache_.outlinks) {
+    step_outputs_.emplace(item.first, TensorArray());
   }
+  PADDLE_ENFORCE_GT(step_outputs_.size(), 0UL);
 }
 
 void DynamicRecurrentOp::CreateScopes() const {
@@ -145,12 +163,18 @@ void DynamicRecurrentOp::CreateScopes() const {
   PADDLE_ENFORCE_NOT_NULL(stepnet_, "stepnet should be set first");
   std::vector<std::string> memories;
   std::vector<std::string> pre_memories;
+  std::vector<std::string> stepnet_outputs;
   std::transform(arg_.memories.begin(), arg_.memories.end(),
                  std::back_inserter(memories),
                  [](const rnn::MemoryAttr& m) { return m.var; });
   std::transform(arg_.memories.begin(), arg_.memories.end(),
                  std::back_inserter(pre_memories),
                  [](const rnn::MemoryAttr& m) { return m.pre_var; });
+  for (const auto& item : stepnet_->Outputs()) {
+    for (const auto& var : item.second) {
+      stepnet_outputs.push_back(var);
+    }
+  }
 
   for (size_t step = 0; step < cache_.num_steps; step++) {
     auto& scope = cache_.GetScope(step);
@@ -158,60 +182,88 @@ void DynamicRecurrentOp::CreateScopes() const {
     detail::CreateVariables(scope, arg_.outlinks);
     detail::CreateVariables(scope, memories);
     detail::CreateVariables(scope, pre_memories);
+    detail::CreateVariables(scope, stepnet_outputs);
   }
 }
 
 void DynamicRecurrentOp::ConcatOutputs() const {
   // TODO(superjom) transform this to a config
   int level = 0;
-  // TODO(superjom) pass in some lod
-  // just a placeholder
-  framework::LoD lod;
+  for (size_t step = 0; step < cache_.num_steps; step++) {
+    auto& scope = cache_.GetScope(step);
+    for (auto& item : step_outputs_) {
+      auto* var = scope.FindVar(item.first);
+      PADDLE_ENFORCE_NOT_NULL(var);
+      auto* tensor = var->GetMutable<LoDTensor>();
+      tensor->mutable_data<float>(platform::CPUPlace());
+      item.second.WriteShared(step, *tensor);
+    }
+  }
+  // all the inlinks' LoDs should be the same, so just take one of them
+  const auto& some_lod =
+      cache_.scope->FindVar(arg_.inlinks.front())->Get<LoDTensor>().lod();
+  const auto& some_meta = dy_seq_metas_[arg_.inlinks.front()];
   for (auto& item : step_outputs_) {
-    auto tensor = item.second.Pack(level, dy_seq_metas_[item.first], lod);
-    auto& output = cache_.outlinks[item.first]->Get<LoDTensor>();
-    const_cast<LoDTensor*>(&output)->ShareDataWith(tensor);
+    auto tensor = item.second.Pack(level, some_meta, some_lod);
+    auto* output = cache_.outlinks[item.first]->GetMutable<LoDTensor>();
+    const_cast<LoDTensor*>(output)->ShareDataWith(tensor);
   }
 }
 
 void DynamicRecurrentOp::InitStates() const {
-  // init the first state
-  // TODO(superjom) parepare the scenerio that boot state not exists
-  for (auto memory : arg_.memories) {
-    auto* boot_state_var = cache_.scope->FindVar(memory.boot_var);
-    PADDLE_ENFORCE_NOT_NULL(boot_state_var);
-    auto& boot_state = boot_state_var->Get<LoDTensor>();
-    const auto& dims = boot_state.dims();
-
-    for (size_t step = 0; step < cache_.num_steps; step++) {
-      auto& cur_scope = cache_.GetScope(step);
-      // link pre-state to boot_state
-      // init state and pre-state
-      auto* pre_state = cur_scope.FindVar(memory.pre_var);
-      PADDLE_ENFORCE_NOT_NULL(pre_state);
-      pre_state->GetMutable<LoDTensor>();
-
-      auto* state = cur_scope.FindVar(memory.var);
-      PADDLE_ENFORCE_NOT_NULL(state);
-      state->GetMutable<LoDTensor>()->Resize(dims);
-      state->GetMutable<LoDTensor>()->mutable_data<float>(
-          platform::CPUPlace());
-
-      if (step == 0) {
-        auto* pre_state_tensor = pre_state->GetMutable<LoDTensor>();
-        pre_state_tensor->Resize(boot_state.dims());
-        pre_state_tensor->ShareDataWith(boot_state);
-      } else {
-        auto& pre_scope = cache_.GetScope(step - 1);
-        auto* state_pre = pre_scope.FindVar(memory.var);
-        PADDLE_ENFORCE_NOT_NULL(state_pre);
-        pre_state->GetMutable<LoDTensor>()->ShareDataWith(
-            *state_pre->GetMutable<LoDTensor>());
-      }
+  for (size_t step = 0; step < cache_.num_steps; step++) {
+    for (const auto& memory : arg_.memories) {
+      CreateState(memory, step);
+      LinkState(memory, step);
     }
   }
 }
 
+void DynamicRecurrentOp::CreateState(const rnn::MemoryAttr& memory,
+                                     size_t step) const {
+  auto& scope = cache_.GetScope(step);
+  auto& state = *cache_.GetTensor(scope, memory.var);
+  auto& boot_state = *cache_.GetTensor(*cache_.scope, memory.boot_var);
+
+  size_t num_instances =
+      step_inputs_[arg_.inlinks.front()].Read(step).dims()[0];
+  auto dims = boot_state.dims();
+  dims[0] = num_instances;
+
+  state.Resize(dims);
+  state.mutable_data<float>(platform::CPUPlace());
+  states_[memory.var].WriteShared(step, state);
+}
+
+void DynamicRecurrentOp::LinkState(const rnn::MemoryAttr& memory,
+                                   size_t step) const {
+  auto& scope = cache_.GetScope(step);
+  auto& state_pre = *cache_.GetTensor(scope, memory.pre_var);
+
+  // all the step_inputs' metas should be the same, so just pick one of them
+  // and use its dyseq meta
+  const auto& some_meta = dy_seq_metas_[arg_.inlinks.front()];
+  size_t num_instances =
+      step_inputs_[arg_.inlinks.front()].Read(step).dims()[0];
+
+  LoDTensor* pre_state{nullptr};
+  if (step == 0) {
+    pre_state = cache_.GetTensor(*cache_.scope, memory.boot_var);
+    pre_state->mutable_data<float>(platform::CPUPlace());
+    // allocate memory
+    state_pre.Resize(pre_state->dims());
+    state_pre.mutable_data<float>(platform::CPUPlace());
+    detail::ReorderBootState<float>(some_meta, *pre_state, &state_pre,
+                                    pre_state->place());
+  } else {
+    pre_state = cache_.GetTensor(cache_.GetScope(step - 1), memory.var);
+  }
+
+  // shrink and share from the previous state
+  auto shrinked_pre_state = pre_state->Slice(0, num_instances);
+  state_pre.ShareDataWith(shrinked_pre_state);
+}
+
 void DynamicRecurrentOp::ArgCache::Init(
     const rnn::ArgumentName& name, const paddle::framework::OperatorBase& op,
     const paddle::framework::Scope& scope, rnn::Argument* arg) {
@@ -261,6 +313,12 @@ Variable* DynamicRecurrentOp::ArgCache::GetVariable(const Scope& scope,
   return var;
 }
 
+LoDTensor* DynamicRecurrentOp::ArgCache::GetTensor(
+    const framework::Scope& scope, const std::string& name) {
+  auto* var = GetVariable(scope, name);
+  return var->GetMutable<LoDTensor>();
+}
+
 const rnn::ArgumentName DynamicRecurrentOp::kArgName{
     "step_net", "step_scopes",  "inlinks",      "outlinks",
     "memories", "pre_memories", "boot_memories"};
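The reordering that `detail::ReorderBootState` performs can be sketched in a few lines of NumPy: row `seq_id` of the step-0 pre-state receives row `metas[seq_id].ori_idx` of the original boot state, so each (now length-sorted) sequence starts from its own boot vector. Names and the example metas below are illustrative only:

```python
import numpy as np


def reorder_boot_state(metas, boot_state):
    # metas: list of (begin, end, ori_idx), sorted by length descending;
    # ori_idx says where each sequence lived before sorting
    return np.stack([boot_state[ori_idx] for _, _, ori_idx in metas])


boot = np.array([[0.0], [1.0], [2.0]])
metas = [(0, 2, 1), (2, 3, 2), (3, 4, 0)]    # hypothetical sorted order
print(reorder_boot_state(metas, boot))        # rows come out as 1, 2, 0
```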
diff --git a/paddle/operators/dynamic_recurrent_op.h b/paddle/operators/dynamic_recurrent_op.h
index 6a2970f27fd5bcb25e924dbc567e254159b55a3e..ec80a1c90eee3a655febe0dd3d6c67c16ec6c64b 100644
--- a/paddle/operators/dynamic_recurrent_op.h
+++ b/paddle/operators/dynamic_recurrent_op.h
@@ -77,6 +77,17 @@ class DynamicRecurrentOp : public framework::OperatorBase {
    */
   void InitStates() const;
 
+  /*
+   * Create state variables for each time step.
+   */
+  void CreateState(const rnn::MemoryAttr& memory, size_t step) const;
+
+  /*
+   * Link the pre-state variable in the current scope to the state variable
+   * in the previous time step (scope).
+   */
+  void LinkState(const rnn::MemoryAttr& memory, size_t step) const;
+
   /*
    * Concatenate outputs in each time step and generate a LoDTensor.
    */
@@ -91,6 +102,16 @@ class DynamicRecurrentOp : public framework::OperatorBase {
   }
   const OperatorBase& GetStepNet() const { return *stepnet_; }
 
+  const framework::TensorArray& state(const std::string& name) const {
+    return states_[name];
+  }
+  const framework::TensorArray& step_input(const std::string& name) const {
+    return step_inputs_[name];
+  }
+  const framework::TensorArray& step_output(const std::string& name) const {
+    return step_outputs_[name];
+  }
+
  protected:
   struct ArgCache {
     framework::Scope const* scope;
@@ -108,6 +129,9 @@ class DynamicRecurrentOp : public framework::OperatorBase {
       return *scopes->at(index);
     }
 
+    framework::LoDTensor* GetTensor(const framework::Scope& scope,
+                                    const std::string& name);
+
    private:
     void InitArgument(const rnn::ArgumentName& name, const OperatorBase& op,
                       rnn::Argument* arg);
@@ -122,7 +146,7 @@ class DynamicRecurrentOp : public framework::OperatorBase {
 
  private:
   std::unique_ptr<NetOp> stepnet_;
-  mutable framework::TensorArray states_;
+  mutable std::map<std::string, framework::TensorArray> states_;
   mutable std::map<std::string, framework::TensorArray> step_inputs_;
   mutable std::map<std::string, framework::TensorArray> step_outputs_;
   mutable std::map<std::string, std::vector<framework::DySeqMeta>>
diff --git a/paddle/operators/dynamic_recurrent_op_test.cc b/paddle/operators/dynamic_recurrent_op_test.cc
index 675a7890f3fa6bb7ab9dbbdb04894b2557214a8a..b849c4541da5d9812f4d86430049c0cbc04f385d 100644
--- a/paddle/operators/dynamic_recurrent_op_test.cc
+++ b/paddle/operators/dynamic_recurrent_op_test.cc
@@ -87,7 +87,6 @@ class DynamicRecurrentOpTestHelper : public ::testing::Test {
     platform::CPUPlace place;
     scope.NewVar("step_scopes");
     CreateVar(scope, "boot_mem", framework::make_ddim({10, 20}), place);
-    // auto* out0 = CreateVar(scope, "out0", framework::make_ddim({10, 20}), place);
     auto* in0 = CreateVar(scope, "in0", framework::make_ddim({10, 8}), place);
 
     // 10 instanes with 4 sentences, length is 4, 3, 2, 1 respectively.
diff --git a/paddle/operators/gru_unit_op.cc b/paddle/operators/gru_unit_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..24f84597cd7301af6521b8c1032e69569ba6f03a
--- /dev/null
+++ b/paddle/operators/gru_unit_op.cc
@@ -0,0 +1,210 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/gru_unit_op.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+class GRUUnitOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Input"),
+                   "Input(%s) of GRUUnitOp should not be null.", "Input");
+    PADDLE_ENFORCE(ctx->HasInput("HiddenPrev"),
+                   "Input(%s) of GRUUnitOp should not be null.", "HiddenPrev");
+    PADDLE_ENFORCE(ctx->HasInput("Weight"),
+                   "Input(%s) of GRUUnitOp should not be null.", "Weight");
+    PADDLE_ENFORCE(ctx->HasOutput("Gate"),
+                   "Output(%s) of GRUUnitOp should not be null.", "Gate");
+    PADDLE_ENFORCE(ctx->HasOutput("ResetHiddenPrev"),
+                   "Output(%s) of GRUUnitOp should not be null.",
+                   "ResetHiddenPrev");
+    PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
+                   "Output(%s) of GRUUnitOp should not be null.", "Hidden");
+    auto input_dims = ctx->GetInputDim("Input");
+    auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev");
+    auto weight_dims = ctx->GetInputDim("Weight");
+    int batch_size = input_dims[0];
+    int input_size = input_dims[1];
+    int frame_size = hidden_prev_dims[1];
+    int weight_height = weight_dims[0];
+    int weight_width = weight_dims[1];
+    PADDLE_ENFORCE_EQ(
+        input_size, frame_size * 3,
+        "The input_size must be 3 times frame_size in GRUUnitOp.");
+    PADDLE_ENFORCE_EQ(
+        weight_height, frame_size,
+        "The shape of Weight matrix must be [frame_size, frame_size * 3].");
+    PADDLE_ENFORCE_EQ(
+        weight_width, frame_size * 3,
+        "The shape of Weight matrix must be [frame_size, frame_size * 3].");
+    auto bias = Input("Bias");
+    if (bias != framework::kEmptyVarName) {
+      auto bias_dims = ctx->GetInputDim("Bias");
+      int bias_height = bias_dims[0];
+      int bias_width = bias_dims[1];
+      PADDLE_ENFORCE_EQ(bias_height, 1,
+                        "The shape of Bias must be [1, frame_size * 3].");
+      PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
+                        "The shape of Bias must be [1, frame_size * 3].");
+    }
+    ctx->SetOutputDim("Gate", {batch_size, frame_size * 3});
+    ctx->SetOutputDim("ResetHiddenPrev", {batch_size, frame_size});
+    ctx->SetOutputDim("Hidden", {batch_size, frame_size});
+  }
+};
+
" + "The first part are weights of the update gate and reset gate " + "with shape [frame_size, frame_size * 2], and the second part are " + "weights of output candidate with shape [frame_size, frame_size]"); + AddInput("Bias", + "(Tensor) Bias vector with shape [1, frame_size * 3] concating " + "bias of the update gate, reset gate and output candidate."); + AddOutput("Gate", + "(Tensor) Matrix with shape [batch_size, frame_size * 3] for the " + "output of update gate, reset gate and output candidate") + .AsIntermediate(); + AddOutput("ResetHiddenPrev", + "(Tensor) Matrix with shape [batch_size, frame_size] for the " + "reseted hidden state of previous time step.") + .AsIntermediate(); + AddOutput("Hidden", + "(Tensor) The GRU hidden state of the current time step " + "with shape [batch_size, frame_size]."); + AddAttr("activation", + "(enum int, default tanh) " + "The activation type used for output candidate {h}_t.") + .SetDefault(tanh) + .InEnum({identity, sigmoid, tanh, relu}); + AddAttr("gate_activation", + "(enum int, default sigmoid) " + "The activation type used in update gate and reset gate.") + .SetDefault(sigmoid) + .InEnum({identity, sigmoid, tanh, relu}); + AddComment(R"DOC( +GRUUnitOp implements part calculations of the GRU unit as following: + +\f[ +update \ gate: u_t = actGate(xu_t + W_u * hidden_prev + bias_u) \\ +reset \ gate: r_t = actGate(xr_t + W_r * hidden_prev + bias_r) \\ +output \ candidate: {h}_t = actNode(xc_t + W_c * dot(r_t, hidden_prev) + bias_c) \\ +output: h_t = dot((1-u_t), {h}_t) + dot(u_t, hidden_prev) +\f] + +The rest of GRU unit can be completed by using FCOp's output as the input of GRUUnitOp. +)DOC"); + } +}; + +class GRUUnitGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(%s) of GRUUnitGradOp should not be null.", "Input"); + PADDLE_ENFORCE(ctx->HasInput("HiddenPrev"), + "Input(%s) of GRUUnitGradOp should not be null.", + "HiddenPrev"); + PADDLE_ENFORCE(ctx->HasInput("Weight"), + "Input(%s) of GRUUnitGradOp should not be null.", "Weight"); + PADDLE_ENFORCE(ctx->HasInput("Gate"), + "Input(%s) of GRUUnitGradOp should not be null.", "Gate"); + PADDLE_ENFORCE(ctx->HasInput("ResetHiddenPrev"), + "Input(%s) of GRUUnitGradOp should not be null.", + "ResetHiddenPrev"); + PADDLE_ENFORCE(ctx->HasInput("Hidden"), + "Input(%s) of GRUUnitGradOp should not be null.", "Hidden"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Gate")), + "Input(%s@GRAD) of GRUUnitGradOp should not be null.", + "Gate"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("ResetHiddenPrev")), + "Input(%s@GRAD) of GRUUnitGradOp should not be null.", + "ResetHiddenPrev"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")), + "Input(%s@GRAD) of GRUUnitGradOp should not be null.", + "Hidden"); + auto input_dims = ctx->GetInputDim("Input"); + auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev"); + auto weight_dims = ctx->GetInputDim("Weight"); + // int batch_size = input_dims[0]; + int input_size = input_dims[1]; + int frame_size = hidden_prev_dims[1]; + int weight_height = weight_dims[0]; + int weight_width = weight_dims[1]; + PADDLE_ENFORCE_EQ( + input_size, frame_size * 3, + "The input_size must be 3 times of frame_size in GRUUnitOp."); + PADDLE_ENFORCE_EQ( + weight_height, frame_size, + "The shape of Weight matrix must be [frame_size, frame_size * 3]."); + 
PADDLE_ENFORCE_EQ( + weight_width, frame_size * 3, + "The shape of Weight matrix must be [frame_size, frame_size * 3]."); + auto bias = Input("Bias"); + if (bias != framework::kEmptyVarName) { + auto bias_dims = ctx->GetInputDim("Bias"); + int bias_height = bias_dims[0]; + int bias_width = bias_dims[1]; + PADDLE_ENFORCE_EQ(bias_height, 1, + "The shape of Bias must be [1, frame_size * 3]."); + PADDLE_ENFORCE_EQ(bias_width, frame_size * 3, + "The shape of Bias must be [1, frame_size * 3]."); + auto bias_grad_name = framework::GradVarName("Bias"); + if (ctx->HasOutput(bias_grad_name)) + ctx->SetOutputDim(bias_grad_name, bias_dims); + } + auto input_grad_name = framework::GradVarName("Input"); + if (ctx->HasOutput(input_grad_name)) + ctx->SetOutputDim(input_grad_name, input_dims); + auto hidden_prev_grad_name = framework::GradVarName("HiddenPrev"); + if (ctx->HasOutput(hidden_prev_grad_name)) + ctx->SetOutputDim(hidden_prev_grad_name, hidden_prev_dims); + auto weight_grad_name = framework::GradVarName("Weight"); + if (ctx->HasOutput(weight_grad_name)) + ctx->SetOutputDim(weight_grad_name, weight_dims); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(gru_unit, ops::GRUUnitOp, ops::GRUUnitOpMaker, gru_unit_grad, + ops::GRUUnitGradOp); +REGISTER_OP_CPU_KERNEL(gru_unit, + ops::GRUUnitKernel); +REGISTER_OP_CPU_KERNEL( + gru_unit_grad, ops::GRUUnitGradKernel); diff --git a/paddle/operators/gru_unit_op.cu b/paddle/operators/gru_unit_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..365f656523ddfb7ec8e2a5b885de74674823325a --- /dev/null +++ b/paddle/operators/gru_unit_op.cu @@ -0,0 +1,22 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/gru_unit_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(gru_unit, + ops::GRUUnitKernel); +REGISTER_OP_GPU_KERNEL( + gru_unit_grad, ops::GRUUnitGradKernel); diff --git a/paddle/operators/gru_unit_op.h b/paddle/operators/gru_unit_op.h new file mode 100644 index 0000000000000000000000000000000000000000..c53e7d9827e0395e6ce613302e732b2797f83cdd --- /dev/null +++ b/paddle/operators/gru_unit_op.h @@ -0,0 +1,230 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/paddle/operators/gru_unit_op.cu b/paddle/operators/gru_unit_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..365f656523ddfb7ec8e2a5b885de74674823325a
--- /dev/null
+++ b/paddle/operators/gru_unit_op.cu
@@ -0,0 +1,22 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/gru_unit_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(gru_unit,
+                       ops::GRUUnitKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    gru_unit_grad, ops::GRUUnitGradKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/gru_unit_op.h b/paddle/operators/gru_unit_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..c53e7d9827e0395e6ce613302e732b2797f83cdd
--- /dev/null
+++ b/paddle/operators/gru_unit_op.h
@@ -0,0 +1,230 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+
+#include "paddle/operators/activation_op.h"
+#include "paddle/operators/math/math_function.h"
+
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
+enum GRUActivationType { identity = 0, sigmoid = 1, tanh = 2, relu = 3 };
+
+template <typename Place, typename T>
+class GRUUnitKernel : public framework::OpKernel<T> {
+ public:
+  template <typename Device, typename X, typename Y>
+  void ActCompute(const int act_type, const Device& d, X x, Y y) const {
+    if (act_type == identity)
+      y.device(d) = x;
+    else if (act_type == sigmoid)
+      SigmoidFunctor<T>()(d, x, y);
+    else if (act_type == tanh)
+      TanhFunctor<T>()(d, x, y);
+    else if (act_type == relu)
+      ReluFunctor<T>()(d, x, y);
+    else
+      PADDLE_THROW("unsupported activation type");
+  }
+
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* input = context.Input<Tensor>("Input");
+    auto* hidden_prev = context.Input<Tensor>("HiddenPrev");
+    auto* weight = context.Input<Tensor>("Weight");
+    auto* bias = context.Input<Tensor>("Bias");
+    auto* gate = context.Output<Tensor>("Gate");
+    gate->mutable_data<T>(context.GetPlace());
+    auto* reset_hidden_prev = context.Output<Tensor>("ResetHiddenPrev");
+    reset_hidden_prev->mutable_data<T>(context.GetPlace());
+    auto* hidden = context.Output<Tensor>("Hidden");
+    hidden->mutable_data<T>(context.GetPlace());
+
+    int batch_size = input->dims()[0];
+    int frame_size = hidden_prev->dims()[1];
+
+    auto x = EigenMatrix<T>::From(*input);
+    auto h_p = EigenMatrix<T>::From(*hidden_prev);
+    auto g = EigenMatrix<T>::From(*gate);
+    auto r_h_p = EigenMatrix<T>::From(*reset_hidden_prev);
+    auto h = EigenMatrix<T>::From(*hidden);
+    auto place = context.GetEigenDevice<Place>();
+
+    // calculate unactivated gate outputs
+    if (bias) {
+      auto b = EigenMatrix<T>::From(*bias);
+      g.device(place) =
+          x + b.reshape(Eigen::array<int, 2>({{1, frame_size * 3}}))
+                  .broadcast(Eigen::array<int, 2>({{batch_size, 1}}));
+    } else {
+      g.device(place) = x;
+    }
+    const T* hidden_prev_data = hidden_prev->data<T>();
+    const T* weight_data = weight->data<T>();
+    T* gate_data = gate->data<T>();
+    T* reset_hidden_prev_data = reset_hidden_prev->data<T>();
+    math::gemm<Place, T>(context.device_context(), false, false, batch_size,
+                         2 * frame_size, frame_size, 1, hidden_prev_data,
+                         frame_size, weight_data, frame_size * 2, 1, gate_data,
+                         frame_size * 3);
+
+    // calculate activated gates
+    Eigen::array<int, 2> extents({{batch_size, frame_size}});
+    Eigen::array<int, 2> u_offsets({{0, 0}});
+    ActCompute(context.Attr<int>("gate_activation"), place,
+               g.slice(u_offsets, extents), g.slice(u_offsets, extents));
+    auto u = g.slice(u_offsets, extents);  // update gate
+    Eigen::array<int, 2> r_offsets({{0, frame_size}});
+    ActCompute(context.Attr<int>("gate_activation"), place,
+               g.slice(r_offsets, extents), g.slice(r_offsets, extents));
+    auto r = g.slice(r_offsets, extents);  // reset gate
+    r_h_p.device(place) = r * h_p;         // reset previous hidden state
+    math::gemm<Place, T>(context.device_context(), false, false, batch_size,
+                         frame_size, frame_size, 1, reset_hidden_prev_data,
+                         frame_size, weight_data + frame_size * frame_size * 2,
+                         frame_size, 1, gate_data + frame_size * 2,
+                         frame_size * 3);
+
+    Eigen::array<int, 2> c_offsets({{0, frame_size * 2}});
+    ActCompute(context.Attr<int>("activation"), place,
+               g.slice(c_offsets, extents), g.slice(c_offsets, extents));
+    auto c = g.slice(c_offsets, extents);  // output candidate
+
+    // calculate final output
+    h.device(place) = u * (h_p - c) + c;
+  }
+};
+
+template <typename Place, typename T>
+class GRUUnitGradKernel : public framework::OpKernel<T> {
+ public:
+  template <typename Device, typename X, typename Y, typename DX, typename DY>
+  void ActGradCompute(const int act_type, const Device& d, X x, Y y, DX dx,
+                      DY dy) const {
+    // x is a dummy and won't be used even in Relu (use y instead)
+    if (act_type == identity)
+      dx.device(d) = dy;
+    else if (act_type == sigmoid)
+      SigmoidGradFunctor<T>()(d, x, y, dy, dx);
+    else if (act_type == tanh)
+      TanhGradFunctor<T>()(d, x, y, dy, dx);
+    else if (act_type == relu)
+      ReluGradFunctor<T>()(d, x, y, dy, dx);
+    else
+      PADDLE_THROW("unsupported activation type");
+  }
+
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* input = context.Input<Tensor>("Input");
+    auto* hidden_prev = context.Input<Tensor>("HiddenPrev");
+    auto* weight = context.Input<Tensor>("Weight");
+    auto* gate = context.Input<Tensor>("Gate");
+    auto* reset_hidden_prev = context.Input<Tensor>("ResetHiddenPrev");
+    auto* hidden_grad = context.Input<Tensor>(framework::GradVarName("Hidden"));
+    auto* input_grad = context.Output<Tensor>(framework::GradVarName("Input"));
+    auto* hidden_prev_grad =
+        context.Output<Tensor>(framework::GradVarName("HiddenPrev"));
+    auto* weight_grad =
+        context.Output<Tensor>(framework::GradVarName("Weight"));
+    auto* bias_grad = context.Output<Tensor>(framework::GradVarName("Bias"));
+    input_grad->mutable_data<T>(context.GetPlace());
+    hidden_prev_grad->mutable_data<T>(context.GetPlace());
+    weight_grad->mutable_data<T>(context.GetPlace());
+    Tensor gate_grad;
+    gate_grad.mutable_data<T>(input->dims(), context.GetPlace());
+    Tensor reset_hidden_prev_grad;
+    reset_hidden_prev_grad.mutable_data<T>(reset_hidden_prev->dims(),
+                                           context.GetPlace());
+
+    int batch_size = input->dims()[0];
+    int frame_size = hidden_prev->dims()[1];
+
+    const T* hidden_prev_data = hidden_prev->data<T>();
+    T* hidden_prev_grad_data = hidden_prev_grad->data<T>();
+    const T* weight_data = weight->data<T>();
+    T* weight_grad_data = weight_grad->data<T>();
+    T* gate_grad_data = gate_grad.data<T>();
+    const T* reset_hidden_prev_data = reset_hidden_prev->data<T>();
+    T* reset_hidden_prev_grad_data = reset_hidden_prev_grad.data<T>();
+
+    auto h_p = EigenMatrix<T>::From(*hidden_prev);
+    auto g = EigenMatrix<T>::From(*gate);
+    auto d_h = EigenMatrix<T>::From(*hidden_grad);
+    auto d_x = EigenMatrix<T>::From(*input_grad);
+    auto d_h_p = EigenMatrix<T>::From(*hidden_prev_grad);
+    auto d_g = EigenMatrix<T>::From(gate_grad);
+    auto d_r_h_p = EigenMatrix<T>::From(reset_hidden_prev_grad);
+    auto place = context.GetEigenDevice<Place>();
+
+    Eigen::array<int, 2> extents({{batch_size, frame_size}});
+    Eigen::array<int, 2> u_offsets({{0, 0}});
+    auto u = g.slice(u_offsets, extents);  // update gate
+    Eigen::array<int, 2> r_offsets({{0, frame_size}});
+    auto r = g.slice(r_offsets, extents);  // reset gate
+    Eigen::array<int, 2> c_offsets({{0, frame_size * 2}});
+    auto c = g.slice(c_offsets, extents);  // output candidate
+
+    // backward for unactivated update gate
+    ActGradCompute(context.Attr<int>("gate_activation"), place, u, u,
+                   d_g.slice(u_offsets, extents), d_h * (h_p - c));
+    // backward for unactivated output candidate
+    ActGradCompute(context.Attr<int>("activation"), place, c, c,
+                   d_g.slice(c_offsets, extents), d_h * (u.constant(T(1)) - u));
+    // backward for reset_hidden_prev
+    math::gemm<Place, T>(context.device_context(), false, true, batch_size,
+                         frame_size, frame_size, 1,
+                         gate_grad_data + frame_size * 2, frame_size * 3,
+                         weight_data + frame_size * frame_size * 2, frame_size,
+                         0, reset_hidden_prev_grad_data, frame_size);
+    // backward for state_weight
+    math::gemm<Place, T>(
+        context.device_context(), true, false, frame_size, frame_size,
+        batch_size, 1, reset_hidden_prev_data, frame_size,
+        gate_grad_data + frame_size * 2, frame_size * 3, 0,
+        weight_grad_data + frame_size * frame_size * 2, frame_size);
+    // backward for unactivated reset gate
+    ActGradCompute(context.Attr<int>("gate_activation"), place, r, r,
+                   d_g.slice(r_offsets, extents), d_r_h_p * h_p);
+    // backward for update_gate_weight and reset_gate_weight
+    math::gemm<Place, T>(context.device_context(), true, false, frame_size,
+                         frame_size * 2, batch_size, 1, hidden_prev_data,
+                         frame_size, gate_grad_data, frame_size * 3, 0,
+                         weight_grad_data, frame_size * 2);
+    // backward for hidden_prev
+    d_h_p.device(place) = d_r_h_p * r + d_h * u;
+    math::gemm<Place, T>(context.device_context(), false, true, batch_size,
+                         frame_size, frame_size * 2, 1, gate_grad_data,
+                         frame_size * 3, weight_data, frame_size * 2, 1,
+                         hidden_prev_grad_data, frame_size);
+    // backward for input
+    d_x.device(place) = d_g;
+    // backward for bias
+    if (bias_grad) {
+      bias_grad->mutable_data<T>(context.GetPlace());
+      auto d_b = EigenMatrix<T>::From(*bias_grad);
+      d_b.device(place) = d_g.sum(Eigen::array<int, 1>({{0}}));
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
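The backward kernel is easiest to audit against the forward output h_t = u * h_prev + (1 - u) * c (elementwise). A sketch of the chain-rule pieces it implements; the gemm calls then push the gate gradients through the packed weight matrix:

```latex
% From h_t = u \odot h_{prev} + (1 - u) \odot c = u \odot (h_{prev} - c) + c:
\begin{aligned}
  d_u &= d_h \odot (h_{prev} - c)
      && \text{matches } \verb|d_h * (h_p - c)| \\
  d_c &= d_h \odot (1 - u)
      && \text{matches } \verb|d_h * (1 - u)| \\
  d_r &= d_{r \odot h_{prev}} \odot h_{prev}
      && \text{matches } \verb|d_r_h_p * h_p| \\
  d_{h_{prev}} &= d_{r \odot h_{prev}} \odot r + d_h \odot u
      && \text{plus the term the last gemm adds through } W_{u,r}^{\top}
\end{aligned}
```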
diff --git a/paddle/operators/sum_op.cc b/paddle/operators/sum_op.cc
index ffb0cb92111bfb8490d35e4f5cfc9e405b0e3250..573487b83590c132d5a4379a4b2762fbc16c04bc 100644
--- a/paddle/operators/sum_op.cc
+++ b/paddle/operators/sum_op.cc
@@ -34,7 +34,7 @@ class SumOp : public framework::OperatorWithKernel {
     auto in_dim = x_dims[0];
     for (size_t i = 1; i < N; i++) {
       auto dim = x_dims[i];
-      PADDLE_ENFORCE(in_dim == dim, "Input tensors must have same shape");
+      PADDLE_ENFORCE_EQ(in_dim, dim, "Input tensors must have same shape");
     }
     ctx->SetOutputDim("Out", in_dim);
     ctx->ShareLoD("X", /*->*/ "Out");
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 0f6e3101e26c5ac249664ce8badc10adc939305f..cc9f7ffe02781cc13105b19bb987207743febdf6 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "paddle/framework/lod_tensor.h"
 #include "paddle/framework/tensor_array.h"
 #include "paddle/operators/cond_op.h"
+#include "paddle/operators/dynamic_recurrent_op.h"
 #include "paddle/operators/net_op.h"
 #include "paddle/operators/recurrent_op.h"
 #include "paddle/platform/enforce.h"
@@ -341,6 +342,33 @@ All parameter, weight, gradient are variables in Paddle.
              self.set_stepnet(net.Clone());
            });
 
+  py::class_<operators::DynamicRecurrentOp, OperatorBase>(m,
+                                                          "DynamicRecurrentOp")
+      .def_static("create",
+                  [](py::bytes protobin) -> operators::DynamicRecurrentOp * {
+                    OpDesc desc;
+                    PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
+                                   "Cannot parse user input to OpDesc");
+                    PADDLE_ENFORCE(desc.IsInitialized(),
+                                   "User OpDesc is not initialized, reason %s",
+                                   desc.InitializationErrorString());
+                    auto rnn_op = OpRegistry::CreateOp(desc);
+                    return static_cast<operators::DynamicRecurrentOp *>(
+                        rnn_op.release());
+                  })
+      .def("set_stepnet",
+           [](operators::DynamicRecurrentOp &self, const operators::NetOp &net)
+               -> void { self.SetStepNet(net.Clone()); })
+      .def("get_state",
+           [](operators::DynamicRecurrentOp &self, const std::string &name)
+               -> const TensorArray & { return self.state(name); })
+      .def("get_step_input",
+           [](operators::DynamicRecurrentOp &self, const std::string &name)
+               -> const TensorArray & { return self.step_input(name); })
+      .def("get_step_output",
+           [](operators::DynamicRecurrentOp &self, const std::string &name)
+               -> const TensorArray & { return self.step_output(name); });
+
   // cond_op
   py::class_<operators::CondOp, OperatorBase>(m, "CondOp")
       .def_static("create",
diff --git a/python/paddle/v2/framework/op.py b/python/paddle/v2/framework/op.py
index 9086a5cc3452b178ec37fe6a3e358eaa4c5d606b..bc771a964adf9f97cbeae87c06ce954c76051150 100644
--- a/python/paddle/v2/framework/op.py
+++ b/python/paddle/v2/framework/op.py
@@ -219,6 +219,27 @@ class __RecurrentOp__(object):
         return core.RecurrentOp.create(proto.SerializeToString())
 
 
+class __DynamicRecurrentOp__(object):
+    __proto__ = None
+    type = "dynamic_recurrent"
+
+    def __init__(self):
+        # cache recurrent_op's proto
+        if self.__proto__ is None:
+            for op_proto in get_all_op_protos():
+                if op_proto.type == self.type:
+                    self.__proto__ = op_proto
+
+    def __call__(self, *args, **kwargs):
+        if self.type not in args and "type" not in kwargs:
+            kwargs["type"] = self.type
+        # create proto
+        create_method = OpDescCreationMethod(self.__proto__)
+        proto = create_method(*args, **kwargs)
+        # create rnn op
+        return core.DynamicRecurrentOp.create(proto.SerializeToString())
+
+
 class __CondOp__(object):
     __proto__ = None
     type = "cond"
@@ -242,4 +263,5 @@ class __CondOp__(object):
 
 Operator = OperatorFactory()  # The default global factory
 RecurrentOp = __RecurrentOp__()
+DynamicRecurrentOp = __DynamicRecurrentOp__()
 CondOp = __CondOp__()
diff --git a/python/paddle/v2/framework/tests/test_activation_op.py b/python/paddle/v2/framework/tests/test_activation_op.py
index 3acd00e35213981fce60504876af1861961ebe12..5831b880e4c5ef881929920e87ac64d6c87a2ab5 100644
--- a/python/paddle/v2/framework/tests/test_activation_op.py
+++ b/python/paddle/v2/framework/tests/test_activation_op.py
@@ -384,5 +384,33 @@ class TestThresholdedRelu(OpTest):
         self.check_grad(['X'], 'Y', max_relative_error=self.relative_error)
 
 
+class TestHardSigmoid(OpTest):
+    def setUp(self):
+        self.op_type = "hard_sigmoid"
+        self.relative_error = 0.002
+
+        X = np.random.uniform(-5, 5, [2, 2]).astype("float32")
+        slope = 0.2
+        offset = 0.5
+        lower_threshold = -offset / slope
+        upper_threshold = (1 - offset) / slope
+
+        self.inputs = {'X': X}
+        # Same reason as TestAbs
+        X[np.abs(X - lower_threshold) < self.relative_error] = \
+            lower_threshold + 0.2
+        X[np.abs(X - upper_threshold) < self.relative_error] = \
+            upper_threshold - 0.2
+
+        temp = X * slope + offset
+        self.outputs = {'Y': np.maximum(0.0, np.minimum(1.0, temp))}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Y', max_relative_error=0.002)
+
+
 if __name__ == "__main__":
     unittest.main()
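A note on why TestHardSigmoid nudges samples away from the two thresholds: hard_sigmoid is not differentiable at x = -offset/slope and x = (1 - offset)/slope, so a finite-difference gradient check that straddles a kink disagrees with the analytic gradient there. A tiny illustration with hypothetical numbers:

```python
import numpy as np

slope, offset, eps = 0.2, 0.5, 1e-3
kink = -offset / slope                       # lower threshold, x = -2.5
x = kink + 1e-5                              # a sample sitting on the kink
num_grad = (np.clip(slope * (x + eps) + offset, 0, 1) -
            np.clip(slope * (x - eps) + offset, 0, 1)) / (2 * eps)
# num_grad ~ 0.1, halfway between the one-sided gradients 0 and 0.2,
# so the check would flag a spurious error unless the sample is moved
```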
diff --git a/python/paddle/v2/framework/tests/test_dynamic_recurrent_op.py b/python/paddle/v2/framework/tests/test_dynamic_recurrent_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4629a3adb9a84470843214c7c6d80acde7228cc
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_dynamic_recurrent_op.py
@@ -0,0 +1,111 @@
+import logging
+import paddle.v2.framework.core as core
+import unittest
+from paddle.v2.framework.op import Operator, DynamicRecurrentOp
+import numpy as np
+
+
+def create_tensor(scope, name, shape, np_data):
+    tensor = scope.new_var(name).get_tensor()
+    tensor.set_dims(shape)
+    tensor.set(np_data, core.CPUPlace())
+    return tensor
+
+
+class DynamicRecurrentOpTest(unittest.TestCase):
+    '''
+    Test DynamicRecurrentOp
+
+    equation:
+        h_t = \sigma (W x_t + U h_{t-1})
+    weights:
+        - W
+        - U
+    vars:
+        - x
+    memories:
+        - h
+    outputs:
+        - h
+    '''
+
+    # for simplicity, just a one-level LoD
+    lod_py = [[0, 4, 7, 9, 10]]
+    input_dim = 30
+    num_sents = len(lod_py[0]) - 1
+    weight_dim = 15
+
+    def forward(self):
+        self.scope = core.Scope()
+        self.create_global_variables()
+        self.create_rnn_op()
+        self.create_step_net()
+        ctx = core.DeviceContext.create(core.CPUPlace())
+        self.rnnop.run(self.scope, ctx)
+        state = self.rnnop.get_state("h@mem")
+        print 'state size: ', state.size()
+
+        step_inputs = self.rnnop.get_step_input("x")
+        print "x size ", step_inputs.size()
+        for i in range(step_inputs.size()):
+            print "x %d" % i, np.array(step_inputs.read(i).get_dims())
+        step_outputs = self.rnnop.get_step_output('h@mem')
+        print 'step_outputs.size ', step_outputs.size()
+        output = self.scope.find_var("h@mem").get_tensor()
+        print 'output', np.array(output).shape
+        return np.array(output)
+
+    def create_global_variables(self):
+        x = np.random.normal(size=(self.lod_py[0][-1],
+                                   self.input_dim)).astype("float32")
+        W = np.random.normal(size=(self.input_dim,
+                                   self.input_dim)).astype("float32")
+        U = np.random.normal(size=(self.input_dim,
+                                   self.input_dim)).astype("float32")
+        h_boot = np.random.normal(size=(self.num_sents,
+                                        self.input_dim)).astype("float32")
+        # create inlink
+        x_tensor = create_tensor(self.scope, "x",
+                                 [self.num_sents, self.input_dim], x)
+        x_tensor.set_lod(self.lod_py)
+        create_tensor(self.scope, "W", [self.input_dim, self.input_dim], W)
+        create_tensor(self.scope, "U", [self.input_dim, self.input_dim], U)
+        create_tensor(self.scope, "h_boot", [self.num_sents, self.input_dim],
+                      h_boot)
+        self.scope.new_var("step_scopes")
+        self.scope.new_var("h@mem")
+
+    def create_rnn_op(self):
+        # create RNNOp
+        self.rnnop = DynamicRecurrentOp(
+            # inputs
+            inlinks=["x"],
+            boot_memories=["h_boot"],
+            step_net="stepnet",
+            # outputs
+            outlinks=["h@mem"],
+            step_scopes="step_scopes",
+            # attributes
+            pre_memories=["h@pre"],
+            memories=["h@mem"])
+
+    def create_step_net(self):
+        stepnet = core.Net.create()
+        x_fc_op = Operator("mul", X="x", Y="W", Out="Wx")
+        h_fc_op = Operator("mul", X="h@pre", Y="U", Out="Uh")
+        sum_op = Operator("sum", X=["Wx", "Uh"], Out="sum")
+        sig_op = Operator("sigmoid", X="sum", Y="h@mem")
+
+        for op in [x_fc_op, h_fc_op, sum_op, sig_op]:
+            stepnet.append_op(op)
+        stepnet.complete_add_op(True)
+        self.rnnop.set_stepnet(stepnet)
+
+    def test_forward(self):
+        print 'test recurrent op forward'
+        pd_output = self.forward()
+        print 'pd_output', pd_output
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/v2/framework/tests/test_gru_unit_op.py b/python/paddle/v2/framework/tests/test_gru_unit_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..57625362d21905d257f46ff5330841a20438773a
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_gru_unit_op.py
@@ -0,0 +1,115 @@
+import math
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class GRUActivationType(OpTest):
+    identity = 0
+    sigmoid = 1
+    tanh = 2
+    relu = 3
+
+
+def identity(x):
+    return x
+
+
+def sigmoid(x):
+    return 1. / (1. + np.exp(-x))
+
+
+def tanh(x):
+    return 2. * sigmoid(2. * x) - 1.
+
+
+def relu(x):
+    return np.maximum(x, 0)
+
+
+class TestGRUUnitOp(OpTest):
+    batch_size = 3
+    frame_size = 5
+    activate = {
+        GRUActivationType.identity: identity,
+        GRUActivationType.sigmoid: sigmoid,
+        GRUActivationType.tanh: tanh,
+        GRUActivationType.relu: relu,
+    }
+
+    def set_inputs(self):
+        batch_size = self.batch_size
+        frame_size = self.frame_size
+        self.op_type = 'gru_unit'
+        self.inputs = {
+            'Input': np.random.uniform(
+                -0.1, 0.1, (batch_size, frame_size * 3)).astype('float32'),
+            'HiddenPrev': np.random.uniform(
+                -0.1, 0.1, (batch_size, frame_size)).astype('float32'),
+            'Weight': np.random.uniform(
+                -1. / math.sqrt(frame_size), 1. / math.sqrt(frame_size),
+                (frame_size, frame_size * 3)).astype('float32'),
+        }
+        self.attrs = {
+            'activation': GRUActivationType.tanh,
+            'gate_activation': GRUActivationType.sigmoid
+        }
+
+    def set_outputs(self):
+        # GRU calculations
+        batch_size = self.batch_size
+        frame_size = self.frame_size
+        x = self.inputs['Input']
+        h_p = self.inputs['HiddenPrev']
+        w = self.inputs['Weight']
+        b = self.inputs['Bias'] if self.inputs.has_key('Bias') else np.zeros(
+            (1, frame_size * 3))
+        g = x + np.tile(b, (batch_size, 1))
+        w_u_r = w.flatten()[:frame_size * frame_size * 2].reshape(
+            (frame_size, frame_size * 2))
+        u_r = self.activate[self.attrs['gate_activation']](np.dot(
+            h_p, w_u_r) + g[:, :frame_size * 2])
+        u = u_r[:, :frame_size]
+        r = u_r[:, frame_size:frame_size * 2]
+        r_h_p = r * h_p
+        w_c = w.flatten()[frame_size * frame_size * 2:].reshape(
+            (frame_size, frame_size))
+        c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) +
+                                                    g[:, frame_size * 2:])
+        g = np.hstack((u_r, c))
+        h = u * h_p + (1 - u) * c
+        self.outputs = {'Gate': g, 'ResetHiddenPrev': r_h_p, 'Hidden': h}
+
+    def setUp(self):
+        self.set_inputs()
+        self.set_outputs()
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(
+            ['Input', 'HiddenPrev', 'Weight'], ['Hidden'],
+            max_relative_error=0.007)
+
+
+class TestGRUUnitOpWithBias(TestGRUUnitOp):
+    def set_inputs(self):
+        batch_size = self.batch_size
+        frame_size = self.frame_size
+        super(TestGRUUnitOpWithBias, self).set_inputs()
+        self.inputs['Bias'] = np.random.uniform(
+            -0.1, 0.1, (1, frame_size * 3)).astype('float32')
+        self.attrs = {
+            'activation': GRUActivationType.identity,
+            'gate_activation': GRUActivationType.sigmoid
+        }
+
+    def test_check_grad(self):
+        self.check_grad(
+            ['Input', 'HiddenPrev', 'Weight', 'Bias'], ['Hidden'],
+            max_relative_error=0.007)
+
+
+if __name__ == '__main__':
+    unittest.main()