未验证 提交 0a32e74d 编写于 作者: Y Yu Yang 提交者: GitHub

Rewrite StaticRNN with Executor (#5224)

* Init commit

* Make executor use ProgramDescBind

* Change Attribute from BlockDesc to BlockDescBind

* Since we will get the program desc in RNN, just BlockDesc is not
  enough.

* Add DeviceContext to Executor API

* Rewrite RNN

* Pass Python

* AddBiasOp does not care num_flatten_dims

* Stash

* Fix MacOS Compile

* Pass RNN forward

* add python test

* refactor test

* Make compile pass

* add gradopmaker

* First draft done

* Polish code

* add grad op maker and grad infershape

* Polish code

* Fix backward.cc bug

* Fix infershape

* Rename function

* add backward test

* simplify recurrent test

* Update

* Pass unittest

* Add comments & refine test

* Add comments

* refactor test

* Complete Unittest

* fix StepScopes enforce

* Remove unused unittest

* no type error

* Update

* Make RNN Pass unittest
上级 8cdb42c2
......@@ -24,7 +24,6 @@
#include "paddle/framework/op_registry.h"
#include "paddle/operators/dynamic_recurrent_op.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/recurrent_op.h"
namespace paddle {
namespace framework {
......@@ -38,7 +37,7 @@ static inline std::unique_ptr<OperatorBase> CreateGradOp(
op_desc.SetType(op.Type());
op_desc.SetAttrMap(op.Attrs());
auto& info = OpInfoMap::Instance().Get(op.Type());
auto grad_descs = info.GradOpMaker()(op_desc, no_grad_set, grad_to_var);
auto grad_descs = info.GradOpMaker()(op_desc, no_grad_set, grad_to_var, {});
std::vector<std::unique_ptr<OperatorBase>> grad_ops;
grad_ops.reserve(grad_descs.size());
std::transform(grad_descs.begin(), grad_descs.end(),
......@@ -220,19 +219,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
});
// process recurrent gradient op as a special operator.
if (forwardOp.Type() == "recurrent") {
// NOTE clean up cycle call somewhere (RNN's stepnet constains itself),
// or this will result in infinite loop.
const auto& rnnop =
*static_cast<const operators::RecurrentOp*>(&forwardOp);
auto rnn_grad_op =
static_cast<operators::RecurrentGradientOp*>(grad_op.get());
const auto& stepnet_op =
*static_cast<const OperatorBase*>(&rnnop.stepnet());
// create stepnet's gradient op
rnn_grad_op->set_stepnet(
BackwardRecursive(stepnet_op, no_grad_names, grad_to_var, uniq_id));
} else if (forwardOp.Type() == "dynamic_recurrent") {
if (forwardOp.Type() == "dynamic_recurrent") {
// NOTE clean up cycle call somewhere (RNN's stepnet constains itself),
// or this will result in infinite loop.
const auto& rnnop =
......@@ -331,7 +318,7 @@ static void CreateGradVarInBlock(
continue;
}
auto pname = FwdName(arg);
auto* param = block_desc->FindVar(pname);
auto* param = block_desc->FindVarRecursive(pname);
auto* grad = block_desc->FindVar(arg);
if (param == nullptr) {
LOG(WARNING) << "Cannot find forward variable of " << arg
......@@ -348,7 +335,9 @@ static void CreateGradVarInBlock(
std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
const OpDescBind* op_desc, std::unordered_set<std::string>* no_grad_vars,
std::unordered_map<std::string, std::string>* grad_to_var) {
std::unordered_map<std::string, std::string>* grad_to_var,
const std::vector<BlockDescBind*>& grad_block =
std::vector<BlockDescBind*>()) {
std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
// All input gradients of forwarding operator do not need to calculate.
const std::vector<std::string>& inputs = op_desc->InputArgumentNames();
......@@ -364,9 +353,10 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
return grad_op_descs; // empty vector
}
grad_op_descs = OpInfoMap::Instance()
.Get(op_desc->Type())
.GradOpMaker()(*op_desc, *no_grad_vars, grad_to_var);
grad_op_descs =
OpInfoMap::Instance()
.Get(op_desc->Type())
.GradOpMaker()(*op_desc, *no_grad_vars, grad_to_var, grad_block);
std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops;
for (auto& desc : grad_op_descs) {
......@@ -400,21 +390,20 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
std::vector<std::unique_ptr<OpDescBind>> backward_descs;
for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) {
std::vector<std::unique_ptr<OpDescBind>> op_grads =
MakeOpGrad(*it, no_grad_vars, grad_to_var);
std::vector<std::unique_ptr<OpDescBind>> op_grads;
if ((*it)->Type() == "recurrent") {
PADDLE_ENFORCE_EQ(
op_grads.size(), static_cast<size_t>(1),
"rnn_op's gradient process should contain only one op.");
int step_block_idx = (*it)->GetBlockAttr("step_block");
auto backward_block_op_descs = MakeBlockBackward(
program_desc, step_block_idx, no_grad_vars, grad_to_var);
BlockDescBind* backward_block = program_desc.AppendBlock(*cur_block);
BlockDescBind* backward_block =
program_desc.AppendBlock(*program_desc.MutableBlock(step_block_idx));
for (auto& ptr : backward_block_op_descs) {
backward_block->AppendAllocatedOp(std::move(ptr));
}
op_grads[0]->SetBlockAttr("step_block", *backward_block);
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
} else {
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var);
}
for (const auto& desc : op_grads) {
......
......@@ -88,6 +88,8 @@ class BlockDescBind {
BlockDesc *Proto();
ProgramDescBind *Program() { return this->prog_; }
private:
void ClearPBOps();
void ClearPBVars();
......
......@@ -108,8 +108,9 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
info->grad_op_maker_ = [](
const OpDescBind& fwd_op,
const std::unordered_set<std::string>& no_grad_set,
std::unordered_map<std::string, std::string>* grad_to_var) {
T maker(fwd_op, no_grad_set, grad_to_var);
std::unordered_map<std::string, std::string>* grad_to_var,
const std::vector<BlockDescBind*>& grad_block) {
T maker(fwd_op, no_grad_set, grad_to_var, grad_block);
return maker();
};
}
......
......@@ -31,7 +31,7 @@ namespace framework {
const std::string kFeedOpType = "feed";
const std::string kFetchOpType = "fetch";
Executor::Executor(const std::vector<platform::Place>& places) {
Executor::Executor(const std::vector<platform::Place>& places) : own_(true) {
PADDLE_ENFORCE_GT(places.size(), 0);
device_contexts_.resize(places.size());
for (size_t i = 0; i < places.size(); i++) {
......@@ -52,8 +52,10 @@ Executor::Executor(const std::vector<platform::Place>& places) {
}
Executor::~Executor() {
for (auto& device_context : device_contexts_) {
delete device_context;
if (own_) {
for (auto& device_context : device_contexts_) {
delete device_context;
}
}
}
......@@ -66,14 +68,18 @@ static void CreateTensor(Variable* var, VarDesc::VarType var_type) {
var->GetMutable<FeedFetchList>();
} else if (var_type == VarDesc::FETCH_LIST) {
var->GetMutable<FeedFetchList>();
} else if (var_type == VarDesc::STEP_SCOPES) {
var->GetMutable<std::vector<framework::Scope>>();
} else {
PADDLE_THROW(
"Variable type must be "
"LoDTensor/SelectedRows/FEED_MINIBATCH/FETCH_LIST.");
"Variable type %d is not in "
"[LoDTensor, SelectedRows, FEED_MINIBATCH, FETCH_LIST]",
var_type);
}
}
void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id) {
void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id,
bool create_local_scope) {
// TODO(tonyyang-svail):
// - only runs on the first device (i.e. no interdevice communication)
// - will change to use multiple blocks for RNN op and Cond Op
......@@ -81,29 +87,42 @@ void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id) {
auto& block = pdesc.Block(block_id);
auto& device = device_contexts_[0];
Scope& local_scope = scope->NewScope();
for (auto& var : block.AllVars()) {
if (var->Persistable()) {
auto* ptr = scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " global, which pointer is " << ptr;
} else {
auto* ptr = local_scope.Var(var->Name());
Scope* local_scope = scope;
if (create_local_scope) {
local_scope = &scope->NewScope();
for (auto& var : block.AllVars()) {
if (var->Persistable()) {
auto* ptr = scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " global, which pointer is " << ptr;
} else {
auto* ptr = local_scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " locally, which pointer is " << ptr;
}
}
} else {
for (auto& var : block.AllVars()) {
auto* ptr = local_scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " locally, which pointer is " << ptr;
VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
<< ptr;
}
}
for (auto& op_desc : block.AllOps()) {
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
op->Run(local_scope, *device);
op->Run(*local_scope, *device);
}
if (create_local_scope) {
scope->DeleteScope(local_scope);
}
scope->DeleteScope(&local_scope);
}
Executor::Executor(const platform::DeviceContext& device)
: device_contexts_({&device}), own_(false) {}
} // namespace framework
} // namespace paddle
......@@ -25,6 +25,7 @@ namespace framework {
class Executor {
public:
explicit Executor(const std::vector<platform::Place>& places);
explicit Executor(const platform::DeviceContext& devices);
~Executor();
/* @Brief
......@@ -34,10 +35,11 @@ class Executor {
* ProgramDesc
* Scope
*/
void Run(const ProgramDescBind&, Scope*, int);
void Run(const ProgramDescBind&, Scope*, int, bool create_local_scope = true);
private:
std::vector<platform::DeviceContext*> device_contexts_;
std::vector<const platform::DeviceContext*> device_contexts_;
bool own_;
};
} // namespace framework
......
......@@ -15,6 +15,7 @@
#pragma once
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/framework/op_desc.h"
#include "paddle/framework/operator.h"
......@@ -26,8 +27,13 @@ class GradOpDescMakerBase {
explicit GradOpDescMakerBase(
const OpDescBind& fwd_op,
const std::unordered_set<std::string>& no_grad_set,
std::unordered_map<std::string, std::string>* grad_to_var)
: fwd_op_(fwd_op), no_grad_set_(no_grad_set), grad_to_var_(grad_to_var) {}
std::unordered_map<std::string, std::string>* grad_to_var,
const std::vector<BlockDescBind*>& grad_block =
std::vector<BlockDescBind*>())
: fwd_op_(fwd_op),
no_grad_set_(no_grad_set),
grad_to_var_(grad_to_var),
grad_block_(grad_block) {}
virtual ~GradOpDescMakerBase() = default;
virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;
......@@ -102,6 +108,9 @@ class GradOpDescMakerBase {
const OpDescBind& fwd_op_;
const std::unordered_set<std::string>& no_grad_set_;
std::unordered_map<std::string, std::string>* grad_to_var_;
protected:
std::vector<BlockDescBind*> grad_block_;
};
class SingleGradOpDescMaker : public GradOpDescMakerBase {
......
......@@ -327,6 +327,19 @@ void OpDescBind::InferShape(const BlockDescBind &block) const {
PADDLE_ENFORCE(static_cast<bool>(infer_shape),
"%s's infer_shape has not been registered", this->Type());
CompileTimeInferShapeContext ctx(*this, block);
if (VLOG_IS_ON(10)) {
std::ostringstream sout;
auto inames = this->InputArgumentNames();
sout << " From [";
std::copy(inames.begin(), inames.end(),
std::ostream_iterator<std::string>(sout, ", "));
sout << "] to [";
auto onames = this->OutputArgumentNames();
std::copy(onames.begin(), onames.end(),
std::ostream_iterator<std::string>(sout, ", "));
sout << "]";
VLOG(10) << sout.str();
}
infer_shape(&ctx);
}
......
......@@ -126,7 +126,7 @@ OperatorBase::OperatorBase(const std::string& type,
std::vector<std::string> OperatorBase::InputVars() const {
std::vector<std::string> ret_val;
for (auto& o : outputs_) {
for (auto& o : inputs_) {
ret_val.reserve(ret_val.size() + o.second.size());
ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
}
......@@ -394,7 +394,19 @@ class RuntimeInferShapeContext : public InferShapeContext {
void OperatorWithKernel::Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const {
VLOG(3) << "Running operator " << this->Type();
if (VLOG_IS_ON(1)) {
auto inputs = this->InputVars();
auto outputs = this->OutputVars(true);
std::ostringstream sout;
sout << "Run operator " << this->Type() << " From [";
std::ostream_iterator<std::string> out_it(sout, ",");
std::copy(inputs.begin(), inputs.end(), out_it);
sout << "] to [";
std::copy(outputs.begin(), outputs.end(), out_it);
sout << "]";
VLOG(1) << sout.str();
}
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx);
......
......@@ -47,8 +47,12 @@ Variable* Scope::Var(const std::string& name) {
return v;
}
Variable* Scope::Var() {
return Var(string::Sprintf("%p.%d", this, vars_.size()));
Variable* Scope::Var(std::string* name) {
auto var_name = string::Sprintf("%p.%d", this, vars_.size());
if (name != nullptr) {
*name = var_name;
}
return Var(var_name);
}
Variable* Scope::FindVar(const std::string& name) const {
......
......@@ -49,7 +49,7 @@ class Scope {
Variable* Var(const std::string& name);
/// Create a variable with a scope-unique name.
Variable* Var();
Variable* Var(std::string* name = nullptr);
/// Find a variable in the scope or any of its ancestors. Returns
/// nullptr if cannot find.
......
......@@ -125,7 +125,7 @@ class Tensor {
* @param[in] end_idx The index of the end row(exclusive) to slice.
* The index number begins from 0.
*/
inline Tensor Slice(const int& begin_idx, const int& end_idx) const;
inline Tensor Slice(int begin_idx, int end_idx) const;
platform::Place place() const {
PADDLE_ENFORCE_NOT_NULL(
......
......@@ -228,7 +228,7 @@ inline void Tensor::CopyFromVector(const std::vector<T>& src,
#endif
}
inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
inline Tensor Tensor::Slice(int begin_idx, int end_idx) const {
check_memory_size();
PADDLE_ENFORCE_GE(begin_idx, 0,
"The start row index must be greater than 0.");
......
......@@ -29,6 +29,7 @@ class OpDescBind;
class BlockDescBind;
class BlockDesc;
class InferShapeContext;
class BlockDescBind;
using VariableNameMap = std::map<std::string, std::vector<std::string>>;
......@@ -46,7 +47,8 @@ using OpCreator = std::function<OperatorBase*(
using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDescBind>>(
const OpDescBind&, const std::unordered_set<std::string>& /*no_grad_set*/,
std::unordered_map<std::string, std::string>* /*grad_to_var*/)>;
std::unordered_map<std::string, std::string>* /*grad_to_var*/,
const std::vector<BlockDescBind*>& grad_block)>;
using InferVarTypeFN = std::function<void(const OpDescBind& /*op_desc*/,
BlockDescBind* /*block*/)>;
......
......@@ -131,9 +131,10 @@ add_subdirectory(math)
add_subdirectory(nccl)
set(DEPS_OPS
recurrent_op
cond_op
cross_entropy_op
recurrent_op
dynamic_recurrent_op
softmax_with_cross_entropy_op
sum_op
pool_op
......@@ -142,9 +143,6 @@ set(DEPS_OPS
sequence_conv_op
lstm_op)
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor net_op)
op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
op_library(cross_entropy_op DEPS cross_entropy)
op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
......@@ -156,7 +154,9 @@ op_library(nccl_op DEPS nccl_common)
endif()
op_library(sequence_conv_op DEPS context_project)
op_library(lstm_op DEPS sequence2batch lstm_compute)
op_library(dynamic_recurrent_op SRCS dynamic_recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS net_op tensor_array)
op_library(recurrent_op SRCS recurrent_op.cc DEPS executor)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS})
op_library(${src})
......@@ -168,8 +168,9 @@ cc_test(gather_test SRCS gather_test.cc DEPS tensor)
cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory)
cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc DEPS dynamic_recurrent_op recurrent_op tensor_array)
cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc
rnn/recurrent_op_utils.cc
DEPS dynamic_recurrent_op)
if(WITH_GPU)
nv_test(nccl_op_test SRCS nccl_op_test.cu DEPS nccl_op gpu_info device_context)
endif()
......
......@@ -29,9 +29,14 @@ class MulOpShapeInference : public framework::InferShapeBase {
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
int x_num_col_dims = ctx->Attrs().Get<int>("x_num_col_dims");
int y_num_col_dims = ctx->Attrs().Get<int>("y_num_col_dims");
VLOG(3) << "mul operator x.shape=" << x_dims << " y.shape=" << y_dims
<< " x_num_col_dims=" << x_num_col_dims
<< " y_num_col_dims=" << y_num_col_dims;
PADDLE_ENFORCE_GT(
x_dims.size(), x_num_col_dims,
"The input tensor X's rank of MulOp should be larger than "
......
......@@ -12,181 +12,618 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/recurrent_op.h"
#include <cstring>
#include <sstream>
#include <vector>
#include "paddle/framework/executor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
namespace paddle {
namespace operators {
constexpr char kInputs[] = "inputs";
constexpr char kInitialStates[] = "initial_states";
constexpr char kParameters[] = "parameters";
constexpr char kOutputs[] = "outputs";
constexpr char kStepScopes[] = "step_scopes";
constexpr char kExStates[] = "ex_states";
constexpr char kStates[] = "states";
constexpr char kStepBlock[] = "step_block";
constexpr char kReverse[] = "reverse";
constexpr char kIsTrain[] = "is_train";
#define GRAD_SUFFIX "@GRAD"
constexpr char kInputGrads[] = "inputs" GRAD_SUFFIX;
constexpr char kOutputGrads[] = "outputs" GRAD_SUFFIX;
constexpr char kParamGrads[] = "parameters" GRAD_SUFFIX;
constexpr char kInitStateGrads[] = "initial_states" GRAD_SUFFIX;
using Scope = framework::Scope;
using Variable = framework::Variable;
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
void RecurrentAlgorithm::Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const {
auto* input0 = scope.FindVar(arg_->inlinks[0]);
PADDLE_ENFORCE_NOT_NULL(input0);
size_t seq_len = input0->GetMutable<LoDTensor>()->dims()[0];
PADDLE_ENFORCE_GT(seq_len, 0);
CreateScopes(scope, seq_len);
auto& step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len);
InitMemories(step_scopes[0]);
for (size_t step_id = 0; step_id < seq_len; step_id++) {
if (step_id > 0) {
rnn::LinkMemories(step_scopes, arg_->states, step_id, -1);
using StepScopeVar = std::vector<framework::Scope *>;
// StepScopes manages scopes inside RNN.
// StepScopes::CurScope() get the current scope
// StepScopes::ExScope() get the ex-scope, or scope in previous time step.
// StepScopes::Next() move to next time step.
//
// if is_train = False, then
// there are two scopes for the RNN and just support forward.
// else
// the len(scopes) == seq_len
//
// if is_backward = True, then
// reversely access scopes
// else
// access scopes from begin to end.
class StepScopes {
public:
StepScopes(const framework::Scope &parent, StepScopeVar *scopes,
bool is_train, size_t seq_len, bool is_backward = false)
: counter_(is_backward ? seq_len - 1 : 0UL),
scopes_(scopes),
is_train_(is_train),
is_backward_(is_backward) {
size_t num_step_scopes = is_train ? seq_len : 2;
PADDLE_ENFORCE(is_train || !is_backward,
"Cannot backward when is not training");
if (!is_backward_) {
PADDLE_ENFORCE(scopes->empty());
scopes->reserve(static_cast<size_t>(num_step_scopes));
for (size_t i = 0; i < num_step_scopes; ++i) {
scopes->emplace_back(&parent.NewScope());
}
}
(*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
}
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len, dev_ctx);
}
void RecurrentAlgorithm::CreateScopes(const Scope& scope,
size_t seq_len) const {
// TODO(superjom) Only two scopes are needed for inference, this case will be
// supported later.
auto* step_scopes_var = scope.FindVar(arg_->step_scopes);
PADDLE_ENFORCE(step_scopes_var != nullptr, "");
auto* step_scopes = step_scopes_var->GetMutable<std::vector<Scope*>>();
// Now all variables in scope must be created outside of op.
PADDLE_ENFORCE_NOT_NULL(stepnet_);
PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(),
"step_unit_ op has no outputs");
if (seq_len > step_scopes->size()) {
for (size_t i = step_scopes->size(); i < seq_len; ++i) {
auto& step_scope = scope.NewScope();
// create step net's temp inputs
for (auto& input : (*stepnet_)->Inputs()) {
// the weight are located in parent scope
for (auto& var_name : input.second) {
if (!step_scope.FindVar(var_name)) {
step_scope.Var(var_name)->GetMutable<LoDTensor>();
}
}
framework::Scope &CurScope() { return GetScope(counter_); }
framework::Scope &ExScope() {
auto &scope = GetScope(is_backward_ ? counter_ + 1 : counter_ - 1);
return scope;
}
void Next() {
if (is_backward_) {
--counter_;
} else {
++counter_;
}
}
private:
framework::Scope &GetScope(size_t scope_id) const {
if (!is_train_) {
scope_id %= 2;
}
PADDLE_ENFORCE_LT(scope_id, scopes_->size());
return *(*scopes_)[scope_id];
}
size_t counter_;
StepScopeVar *scopes_;
bool is_train_;
bool is_backward_;
};
// Base class for RecurrentOp/RecurrentGradOp
// Some common protected functions for RecurrentOp/RecurrentGradOp
class RecurrentBase : public framework::OperatorBase {
public:
RecurrentBase(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
protected:
// Get SequenceLength from Scope
// The sequence length is got from input tensor. The input tensor's
// dimension should be [SEQ_LEN, ..., ...]. The first of the tensor's shape
// is SEQ_LEN. The second of the tensor's shape could be the batch size or
// nested sequence length.
int64_t GetSequenceLength(const framework::Scope &scope) const {
// Dim format SEQ_LEN, BATCH_SIZE, ...
int64_t seq_len = -1;
auto &all_inputs = Inputs(kInputs);
PADDLE_ENFORCE(!all_inputs.empty());
for (auto &iname : all_inputs) {
auto *var = scope.FindVar(iname);
PADDLE_ENFORCE(var != nullptr);
PADDLE_ENFORCE(var->IsType<framework::LoDTensor>());
auto &dim = var->Get<framework::LoDTensor>().dims();
if (seq_len == -1) {
seq_len = dim[0];
} else {
PADDLE_ENFORCE_EQ(seq_len, dim[0]);
}
}
return seq_len;
}
// for src_tensor, dst_tensor in zip(map(src_scope.FindVar, src_vars),
// map(dst_scope.Var, dst_vars)):
// dst_tensor.ShareDataWith(src_tensor)
static void LinkTensor(const framework::Scope &src_scope,
const std::vector<std::string> &src_vars,
framework::Scope *dst_scope,
const std::vector<std::string> &dst_vars) {
LinkTensorWithCallback(
src_scope, src_vars, dst_scope, dst_vars,
[&](const framework::Tensor &src, framework::Tensor *dst) {
dst->ShareDataWith(src);
});
}
// for src_tensor, dst_tensor in zip(map(src_scope.FindVar, src_vars),
// map(dst_scope.Var, dst_vars)):
// callback(src_tensor, &dst_tensor)
template <typename Callback>
static void LinkTensorWithCallback(const framework::Scope &src_scope,
const std::vector<std::string> &src_vars,
framework::Scope *dst_scope,
const std::vector<std::string> &dst_vars,
Callback callback) {
PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size());
for (size_t i = 0; i < dst_vars.size(); ++i) {
VLOG(10) << "Link " << src_vars[i] << " to " << dst_vars[i];
AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback);
}
}
// for src_tensor, dst_tensor in zip(map(src_scope.FindVar, src_vars),
// map(dst_scope.FindVar, dst_vars)):
// callback(src_tensor, &dst_tensor)
template <typename Callback>
static void LinkTensorWithCallback(const framework::Scope &src_scope,
const std::vector<std::string> &src_vars,
const framework::Scope &dst_scope,
const std::vector<std::string> &dst_vars,
Callback callback) {
PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size());
for (size_t i = 0; i < dst_vars.size(); ++i) {
VLOG(10) << "Link " << src_vars[i] << " to " << dst_vars[i];
AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback);
}
}
// (seq_len, shape) -> return [seq_len] + list(shape)
static framework::DDim PrependDims(size_t seq_len,
const framework::DDim &src) {
auto dims = framework::vectorize(src);
dims.insert(dims.begin(), static_cast<int64_t>(seq_len));
return framework::make_ddim(dims);
}
private:
template <typename Callback>
static void AccessTensor(const framework::Scope &src_scope,
const std::string &src_var_name,
framework::Scope *dst_scope,
const std::string &dst_var_name, Callback callback) {
auto *src_var = src_scope.FindVar(src_var_name);
PADDLE_ENFORCE(src_var != nullptr);
auto &src_tensor = src_var->Get<framework::LoDTensor>();
auto *dst_var = dst_scope->Var(dst_var_name);
auto *dst_tensor = dst_var->GetMutable<framework::LoDTensor>();
callback(src_tensor, dst_tensor);
}
template <typename Callback>
static void AccessTensor(const framework::Scope &src_scope,
const std::string &src_var_name,
const framework::Scope &dst_scope,
const std::string &dst_var_name, Callback callback) {
auto *src_var = src_scope.FindVar(src_var_name);
PADDLE_ENFORCE(src_var != nullptr);
auto &src_tensor = src_var->Get<framework::LoDTensor>();
auto *dst_var = dst_scope.FindVar(dst_var_name);
PADDLE_ENFORCE(dst_var != nullptr);
auto *dst_tensor = dst_var->GetMutable<framework::LoDTensor>();
callback(src_tensor, dst_tensor);
}
};
class RecurrentOp : public RecurrentBase {
public:
RecurrentOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: RecurrentBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto seq_len = static_cast<size_t>(this->GetSequenceLength(scope));
VLOG(3) << "Static RNN input sequence length = " << seq_len;
StepScopes scopes = CreateStepScopes(scope, seq_len);
auto reverse = Attr<bool>(kReverse);
framework::Executor executor(dev_ctx);
auto *block = Attr<framework::BlockDescBind *>(kStepBlock);
auto *program = block->Program();
for (size_t i = 0; i < seq_len; ++i) {
size_t seq_offset = reverse ? seq_len - i - 1 : i;
VLOG(3) << "Recurrent operate at the time step " << seq_offset;
auto &cur_scope = scopes.CurScope();
// Link outside::input --> inside::input
// inside::input = outside::input[seq_offset: seq_offset+1]
LinkTensorWithCallback(
scope, Inputs(kInputs), &cur_scope, Inputs(kInputs),
[&seq_offset](const framework::Tensor &outside,
framework::Tensor *inside) {
inside->ShareDataWith(outside.Slice(seq_offset, seq_offset + 1));
auto dims = framework::vectorize(inside->dims());
dims.erase(dims.begin());
inside->Resize(framework::make_ddim(dims));
});
if (i == 0) {
// Link initial states --> ex_states
LinkTensor(scope, Inputs(kInitialStates), &cur_scope,
Attr<std::vector<std::string>>(kExStates));
} else {
auto &ex_scope = scopes.ExScope();
// Link ex_scope::state --> cur_scope::ex_state
LinkTensor(ex_scope, Attr<std::vector<std::string>>(kStates),
&cur_scope, Attr<std::vector<std::string>>(kExStates));
}
// Every inputs are linked now, execute!
executor.Run(*program, &cur_scope, block->ID(),
false /*create_local_scope*/);
// Copy inside::output -> outside::output
// outside::output[seq_offset: seq_offset + 1] = inside::output
this->LinkTensorWithCallback(
cur_scope, Outputs(kOutputs), scope, Outputs(kOutputs),
[&](const framework::LoDTensor &src_tensor,
framework::LoDTensor *dst_tensor) {
if (i == 0) { // create output tensor at begin
dst_tensor->Resize(PrependDims(seq_len, src_tensor.dims()));
dst_tensor->mutable_data(dev_ctx.GetPlace(), src_tensor.type());
}
auto dst_out = dst_tensor->Slice(seq_offset, seq_offset + 1);
// Explicit copy output since the local RNN scope can be destroyed
// early.
dst_out.CopyFrom(src_tensor, dev_ctx.GetPlace(), dev_ctx);
});
scopes.Next();
}
}
private:
StepScopes CreateStepScopes(const framework::Scope &scope,
size_t seq_len) const {
auto *var = scope.FindVar(Output(kStepScopes));
PADDLE_ENFORCE(var != nullptr);
return StepScopes(scope, var->GetMutable<StepScopeVar>(),
Attr<bool>(kIsTrain), seq_len);
}
};
class RecurrentGradOp : public RecurrentBase {
public:
RecurrentGradOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: RecurrentBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto seq_len = static_cast<size_t>(GetSequenceLength(scope));
StepScopes scopes = CreateStepScopes(scope, seq_len);
auto reverse = Attr<bool>(kReverse);
framework::Executor executor(dev_ctx);
auto *block = Attr<framework::BlockDescBind *>(kStepBlock);
auto *program = block->Program();
for (size_t step_id = 0; step_id < seq_len; ++step_id) {
size_t seq_offset = reverse ? step_id : seq_len - step_id - 1;
VLOG(3) << "Recurrent backward operate at the time step " << seq_offset;
auto &cur_scope = scopes.CurScope();
// Link outside::output_grads --> inside::output_grads
// inside::output_grad = outside::output_grad[seq_offset:seq_offset+1]
LinkTensorWithCallback(
scope, Inputs(kOutputGrads), &cur_scope, Inputs(kOutputGrads),
[&](const framework::Tensor &outside, framework::Tensor *inside) {
inside->ShareDataWith(outside.Slice(seq_offset, seq_offset + 1));
auto dims = framework::vectorize(inside->dims());
dims.erase(dims.begin());
inside->Resize(framework::make_ddim(dims));
});
auto og_set = List2Set(Inputs(kOutputGrads));
if (VLOG_IS_ON(10)) {
std::ostringstream sout;
std::copy(og_set.begin(), og_set.end(),
std::ostream_iterator<std::string>(sout, ","));
VLOG(10) << " RNN output gradients = [" << sout.str() << "]";
}
// Link states
// if cur_scope::cur_state_grad in out_grads:
// cur_scope::cur_state_grad += ex_scope::ex_state_grad
// else:
// ex_scope::ex_state_grad --> cur_scope::cur_state_grad
if (step_id != 0) { // not at beginning
auto &ex_scope = scopes.ExScope();
auto ex_state_grads =
GradVarLists(Attr<std::vector<std::string>>(kExStates));
auto cur_state_grads =
GradVarLists(Attr<std::vector<std::string>>(kStates));
PADDLE_ENFORCE_EQ(ex_state_grads.size(), cur_state_grads.size());
for (size_t i = 0; i < ex_state_grads.size(); ++i) {
auto &cur_grad = cur_state_grads[i];
auto &ex_grad = ex_state_grads[i];
auto &ex_tensor =
ex_scope.FindVar(ex_grad)->Get<framework::LoDTensor>();
VLOG(10) << " RNN link " << cur_grad << " from " << ex_grad;
auto *cur_grad_var = cur_scope.Var(cur_grad);
auto cur_grad_tensor =
cur_grad_var->GetMutable<framework::LoDTensor>();
cur_grad_tensor->CopyFrom(ex_tensor, dev_ctx.GetPlace(), dev_ctx);
}
}
// create stepnet's outputs
for (const auto& output : (*stepnet_)->Outputs()) {
for (auto& var_name : output.second) {
step_scope.Var(var_name);
VLOG(5) << "Recurrent memory linking finished ";
// Run step block with cur_scope
executor.Run(*program, &cur_scope, block->ID(),
false /*create_local_scope*/);
VLOG(5) << "executor.Run finished ";
auto local_var_names = LocalVarNames(cur_scope);
// Accumulate params
// if (step == 0):
// outside::param_grad = 0.0
// outside::param_grad += inside::param_grad
{
auto &pg_names = Outputs(kParamGrads);
auto &p_names = Inputs(kParameters);
PADDLE_ENFORCE_EQ(pg_names.size(), p_names.size());
for (size_t prog_id = 0; prog_id < pg_names.size(); ++prog_id) {
auto inside_grad_name = framework::GradVarName(p_names[prog_id]);
// If does not compute gradient of that variable inside rnn, just
// continue
if (local_var_names.find(inside_grad_name) == local_var_names.end()) {
continue;
}
// zero gradient variable in step 0
if (step_id == 0) {
auto &inside_tensor = cur_scope.FindVar(inside_grad_name)
->Get<framework::LoDTensor>();
framework::AttributeMap attrs;
attrs["data_type"] = framework::ToDataType(inside_tensor.type());
attrs["shape"] = framework::vectorize2int(inside_tensor.dims());
attrs["value"] = 0.0f;
auto zero_op = framework::OpRegistry::CreateOp(
"fill_constant", {}, {{"Out", {pg_names[prog_id]}}}, attrs);
zero_op->Run(scope, dev_ctx);
}
// sum gradient
auto *outside_var = scope.FindVar(pg_names[prog_id]);
PADDLE_ENFORCE(outside_var != nullptr);
auto &outside_tensor =
*outside_var->GetMutable<framework::LoDTensor>();
std::string result_var_name;
auto *local_result_var = cur_scope.Var(&result_var_name);
auto &local_result_tensor =
*local_result_var->GetMutable<framework::LoDTensor>();
local_result_tensor.ShareDataWith(outside_tensor);
auto sum_op = framework::OpRegistry::CreateOp(
"sum", {{"X", {result_var_name, inside_grad_name}}},
{{"Out", {result_var_name}}}, {});
sum_op->Run(cur_scope, dev_ctx);
}
}
step_scopes->emplace_back(&step_scope);
VLOG(5) << "Accumulate Parameter finished ";
// Copy input gradient from inside to outside
// outside::input_grad[seq_offset: seq_offset + 1] = inside::input_grad
LinkTensorWithCallback(
cur_scope, GradVarLists(Inputs(kInputs)), scope, Outputs(kInputGrads),
[&](const framework::LoDTensor &inside,
framework::LoDTensor *outside) {
if (inside.memory_size() == 0) { // IG is not created.
return;
}
if (step_id == 0) { // alloc memory
outside->Resize(PrependDims(seq_len, inside.dims()));
outside->mutable_data(dev_ctx.GetPlace(), inside.type());
}
auto dst = outside->Slice(seq_offset, seq_offset + 1);
dst.CopyFrom(inside, dev_ctx.GetPlace(), dev_ctx);
});
VLOG(5) << "Link outside gradient finished ";
if (step_id + 1 == seq_len) { // at_end
// copy initialize states gradient from inside to outside
LinkTensorWithCallback(
cur_scope, GradVarLists(Attr<std::vector<std::string>>(kExStates)),
scope, Outputs(kInitStateGrads),
[&](const framework::LoDTensor &inside,
framework::LoDTensor *outside) {
outside->Resize(inside.dims());
outside->mutable_data(dev_ctx.GetPlace(), inside.type());
outside->CopyFrom(inside, dev_ctx.GetPlace(), dev_ctx);
});
VLOG(5) << "Link initialize state gradient finished ";
}
scopes.Next();
}
}
}
void RecurrentAlgorithm::InitMemories(Scope* step_scope) const {
for (auto& attr : arg_->states) {
auto* pre_mem = step_scope->Var(attr.pre_var)->GetMutable<LoDTensor>();
PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr,
"memory [%s]'s boot variable [%s] not exists", attr.var,
attr.boot_var);
auto* boot_mem =
step_scope->FindVar(attr.boot_var)->GetMutable<LoDTensor>();
pre_mem->Resize(boot_mem->dims());
PADDLE_ENFORCE_EQ(pre_mem->dims().size(), 2);
pre_mem->ShareDataWith(*boot_mem);
}
}
const rnn::ArgumentName RecurrentOp::kArgName{
"step_net", "step_scopes", "inputs", "outputs",
"states", "ex_states", "initial_states"};
const rnn::ArgumentName RecurrentGradientOp::kArgName{
"step_net", "step_scopes@GRAD", "outputs@GRAD", "inputs@GRAD",
"states", "ex_states", "initial_states@GRAD"};
RecurrentOp::RecurrentOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {
rnn::InitArgument(kArgName, &arg_, *this);
alg_.Init(&arg_, &stepnet_);
}
class RecurrentAlgorithmProtoAndCheckerMaker
: public framework::OpProtoAndCheckerMaker {
private:
StepScopes CreateStepScopes(const framework::Scope &scope,
size_t seq_len) const {
auto *var = scope.FindVar(Input(kStepScopes));
PADDLE_ENFORCE(var != nullptr);
return StepScopes(scope, var->GetMutable<StepScopeVar>(),
Attr<bool>(kIsTrain), seq_len, true /*is_backward*/);
}
std::unordered_set<std::string> List2Set(
const std::vector<std::string> &list) const {
std::unordered_set<std::string> local_var_name_set;
local_var_name_set.reserve(list.size());
for (auto &each : list) {
local_var_name_set.insert(each);
}
return local_var_name_set;
}
std::unordered_set<std::string> LocalVarNames(
const framework::Scope &scope) const {
return this->List2Set(scope.GetAllNames(false));
}
static std::vector<std::string> GradVarLists(
const std::vector<std::string> &var_names) {
std::vector<std::string> retv;
retv.reserve(var_names.size());
std::transform(var_names.begin(), var_names.end(), std::back_inserter(retv),
framework::GradVarName);
return retv;
}
};
class RecurrentOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
RecurrentAlgorithmProtoAndCheckerMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
RecurrentOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
const auto& name = RecurrentOp::kArgName;
// inputs and outputs stored in proto
AddInput(name.inlinks,
"the inputs that need to be segmented for each step.")
AddInput(kInputs, "rnn inputs").AsDuplicable();
AddInput(kInitialStates, "rnn initial states").AsDuplicable();
AddInput(kParameters,
"Parameters are used by step block as its input. However, the "
"inputs is not a sequence tensor. Every time step, each operator "
"in step block just use the parameter directly")
.AsDuplicable();
AddInput(name.initial_states, "variables to initialize states.")
AddOutput(kOutputs,
"The output sequence of RNN. The sequence length must be same")
.AsDuplicable();
AddOutput(kStepScopes,
"StepScopes contains all local variables in each time step.");
AddAttr<std::vector<std::string>>(kExStates,
string::Sprintf(
R"DOC(The ex-state variable names.
The ex-state means the state value in the ex-timestep or the previous time step
[%s, %s, %s] must be the same order)DOC",
kExStates, kStates, kInitStateGrads));
AddAttr<std::vector<std::string>>(
kStates,
string::Sprintf(
"The state variable names. [%s, %s, %s] must be the same order",
kExStates, kStates, kInitStateGrads));
AddAttr<framework::BlockDescBind *>(kStepBlock,
"The step block inside RNN");
AddAttr<bool>(kReverse, R"DOC(Calculate RNN reversely or not.
By default reverse=False
AddOutput(name.outlinks, "the outputs that need to concated for all steps.")
.AsDuplicable();
AddOutput(name.step_scopes, "step scopes");
Assume the input data is [A, B, C, D]
if reverse is False:
the computation of RNN is like
A B C D
| | | |
v v v v
rnn -----> rnn -----> rnn ----> rnn
| | | |
v v v v
o o o o
if reverse is True
the computation of RNN is like
A B C D
| | | |
v v v v
rnn <----- rnn <----- rnn <---- rnn
| | | |
v v v v
o o o o
)DOC").SetDefault(false);
AddAttr<bool>(kIsTrain, "").SetDefault(true);
AddComment(R"DOC(Static Length Recurrent Operator
The static length recurrent operator can only operate on fix sized sequence
data, i.e. in each mini-batch, the sequence length of all inputs are same.
)DOC");
}
};
class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
// Attributes stored in AttributeMap
AddAttr<std::vector<std::string>>(name.ex_states, "names of pre-states");
AddAttr<std::vector<std::string>>(name.states, "names of states");
protected:
virtual std::unique_ptr<framework::OpDescBind> Apply() const {
auto *grad = new framework::OpDescBind();
grad->SetType("recurrent_grad");
for (auto &input_param : this->InputNames()) {
grad->SetInput(input_param, this->Input(input_param));
grad->SetOutput(framework::GradVarName(input_param),
this->InputGrad(input_param));
}
for (auto &output_param : this->OutputNames()) {
if (output_param == kStepScopes) {
grad->SetInput(output_param, this->Output(output_param));
grad->SetInput(framework::GradVarName(output_param),
this->Output(output_param));
} else {
grad->SetInput(output_param, this->Output(output_param));
grad->SetInput(framework::GradVarName(output_param),
this->OutputGrad(output_param));
}
}
grad->SetAttrMap(this->Attrs());
grad->SetBlockAttr(kStepBlock, *grad_block_[0]);
AddComment("This is a recurrent group operator.");
return std::unique_ptr<framework::OpDescBind>(grad);
}
};
void RecurrentGradientAlgorithm::Run(
const Scope& scope, const platform::DeviceContext& dev_ctx) const {
auto* input0 = scope.FindVar(arg_->inlinks[0]);
PADDLE_ENFORCE_NOT_NULL(input0);
size_t seq_len = input0->GetMutable<LoDTensor>()->dims()[0];
auto& step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len);
for (int step_id = seq_len - 1; step_id >= 0; --step_id) {
if (static_cast<size_t>(step_id) != seq_len - 1) {
rnn::LinkMemories(step_scopes, arg_->states, step_id, 1);
class RecurrentGradOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
std::vector<std::string> input{kInputs, kInitialStates};
std::vector<std::string> output{kOutputs};
for (auto &s : input) {
PADDLE_ENFORCE(ctx->HasInputs(s));
PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName(s)));
}
for (auto &s : output) {
PADDLE_ENFORCE(ctx->HasInputs(s));
}
for (auto &s : input) {
ctx->SetOutputsDim(framework::GradVarName(s), ctx->GetInputsDim(s));
}
(*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
}
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len, dev_ctx);
LinkBootMemoryGradients(step_scopes[0]);
}
void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
Scope* step_scope) const {
for (auto& attr : arg_->states) {
PADDLE_ENFORCE(step_scope->FindVar(attr.var) != nullptr,
"memory variable [%s] does not exists", attr.var);
PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr,
"boot variable [%s] does not exists", attr.boot_var);
auto* mem_grad = step_scope->Var(attr.var)->GetMutable<LoDTensor>();
auto* boot_mem_grad =
step_scope->Var(attr.boot_var)->GetMutable<LoDTensor>();
boot_mem_grad->Resize(mem_grad->dims());
boot_mem_grad->ShareDataWith(*mem_grad);
}
}
RecurrentGradientOp::RecurrentGradientOp(
const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {
rnn::InitArgument(kArgName, &arg_, *this, true /*is grad*/);
alg_.Init(&arg_, &stepnet_);
}
if (ctx->HasInputs(kParameters)) {
PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName(kParameters)));
ctx->SetOutputsDim(framework::GradVarName(kParameters),
ctx->GetInputsDim(kParameters));
}
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP(recurrent, paddle::operators::RecurrentOp,
paddle::operators::RecurrentAlgorithmProtoAndCheckerMaker,
recurrent_grad, paddle::operators::RecurrentGradientOp);
REGISTER_OPERATOR(recurrent, paddle::operators::RecurrentOp,
paddle::operators::RecurrentOpProtoMaker,
paddle::operators::RecurrentGradOpDescMaker);
REGISTER_OPERATOR(recurrent_grad, paddle::operators::RecurrentGradOp,
paddle::operators::RecurrentGradOpShapeInference);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/operator.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/rnn/recurrent_op_utils.h"
namespace paddle {
namespace operators {
// The sequence format in RecurrentOp is Tensor<seq_len, batch_size, dim> now.
// TODO(Superjom)
// 1. No-padding computing for sequences with indifinite length in one batch.
// 2. Hierarchical RNN for sequence with sub-sequence.
// 3. Internal Memory.
// 4. More Complex RNN architecture, such as Gated Feedback RNN.
// Refer to: https://arxiv.org/pdf/1502.02367.pdf
class RecurrentAlgorithm {
public:
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const;
void Init(rnn::Argument* arg,
std::unique_ptr<framework::OperatorBase>* stepnet) {
PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before.");
arg_ = arg;
stepnet_ = stepnet;
}
protected:
/*
* The step scopes will be stored in the father scope as a variable.
*
* NOTE the scopes are reused in both the forward and backward, so just
* create once and expand its size if more steps need.
*/
void CreateScopes(const framework::Scope& scope, size_t seq_len) const;
const std::vector<framework::Scope*>& GetStepScopes(
const framework::Scope& scope) const {
return *scope.FindVar(arg_->step_scopes)
->GetMutable<std::vector<framework::Scope*>>();
}
void InitMemories(framework::Scope* step_scopes) const;
private:
std::unique_ptr<framework::OperatorBase>* stepnet_;
rnn::Argument* arg_;
};
class RecurrentGradientAlgorithm {
/**
* RNN's backward alogorithm.
*
* To accelerate the development of RecurrentGradientOp, we decouple RNN's
* algorithm and `OperatorBase`'s implementation, the former contains the core
* implementation of a RNN, and will keep stable even if the framework changes
* a
* lot, and the latter is a wrapper acts like an dapter for it to make RNN an
* operator.
*/
public:
void Init(rnn::Argument* arg,
std::unique_ptr<framework::OperatorBase>* stepnet) {
PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before.");
arg_ = std::move(arg);
stepnet_ = stepnet;
}
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const;
void LinkBootMemoryGradients(framework::Scope* step_scopes) const;
protected:
inline const std::vector<framework::Scope*>& GetStepScopes(
const framework::Scope& scope) const {
return *scope.FindVar(arg_->step_scopes)
->GetMutable<std::vector<framework::Scope*>>();
}
private:
rnn::Argument* arg_;
std::unique_ptr<framework::OperatorBase>* stepnet_;
};
class RecurrentOp : public framework::OperatorBase {
public:
RecurrentOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs);
RecurrentOp(const RecurrentOp& o)
: framework::OperatorBase(
static_cast<const framework::OperatorBase&>(o)) {
// TODO(yuyang18): Implement copy ctor well.
PADDLE_THROW("Not implemented");
}
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
alg_.Run(scope, dev_ctx);
}
void set_stepnet(std::unique_ptr<OperatorBase> net) {
stepnet_ = std::move(net);
}
const OperatorBase& stepnet() const { return *stepnet_; }
static const rnn::ArgumentName kArgName;
private:
RecurrentAlgorithm alg_;
rnn::Argument arg_;
std::unique_ptr<OperatorBase> stepnet_;
};
class RecurrentGradientOp : public framework::OperatorBase {
public:
RecurrentGradientOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs);
RecurrentGradientOp(const RecurrentGradientOp& o)
: framework::OperatorBase(
static_cast<const framework::OperatorBase&>(o)) {
// TODO(yuyang18): Implement Copy ctor.
PADDLE_THROW("Not Implemented");
}
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
alg_.Run(scope, dev_ctx);
}
static const rnn::ArgumentName kArgName;
/*
* set a stepnet that is created according to a RecurrentOp's stepnet.
*/
void set_stepnet(std::unique_ptr<OperatorBase> net) {
stepnet_ = std::move(net);
}
const OperatorBase& stepnet() const { return *stepnet_; }
private:
RecurrentGradientAlgorithm alg_;
std::unique_ptr<OperatorBase> stepnet_;
rnn::Argument arg_;
};
} // namespace operators
} // namespace paddle
......@@ -133,11 +133,10 @@ class RNNMemoryHelperGradOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
auto x_grad_name = framework::GradVarName("X");
auto out_grad_name = framework::GradVarName("Out");
PADDLE_ENFORCE(ctx->HasInput(out_grad_name), "");
PADDLE_ENFORCE(ctx->HasOutput(x_grad_name), "");
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim(out_grad_name));
ctx->ShareLoD(out_grad_name, /*->*/ x_grad_name);
PADDLE_ENFORCE(ctx->HasInput("X"), "");
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ x_grad_name);
}
};
......
......@@ -29,22 +29,27 @@ template <typename Place, typename T>
class SumKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto& in_vars = context.MultiInputVar("X");
auto in_vars = context.MultiInputVar("X");
int N = in_vars.size();
auto out_var = context.OutputVar("Out");
bool in_place = out_var == in_vars[0];
if (out_var->IsType<framework::LoDTensor>()) {
auto* out = context.Output<Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
auto result = EigenVector<T>::Flatten(*out);
math::SetConstant<Place, T> constant_functor;
constant_functor(context.device_context(), out, 0.0);
if (!in_place) {
math::SetConstant<Place, T> constant_functor;
constant_functor(context.device_context(), out, 0.0);
}
math::SelectedRowsAddToTensor<Place, T> functor;
auto place = context.GetEigenDevice<Place>();
for (int i = 0; i < N; i++) {
// If in_place, just skip the first tensor
for (int i = in_place ? 1 : 0; i < N; i++) {
if (in_vars[i]->IsType<framework::LoDTensor>()) {
auto& in_t = in_vars[i]->Get<framework::LoDTensor>();
auto in = EigenVector<T>::Flatten(in_t);
......@@ -57,6 +62,7 @@ class SumKernel : public framework::OpKernel<T> {
}
}
} else if (out_var->IsType<framework::SelectedRows>()) {
PADDLE_ENFORCE(!in_place, "SelectedRows not support inplace sum now");
auto* out = context.Output<SelectedRows>("Out");
auto* out_value = out->mutable_value();
......
......@@ -28,7 +28,6 @@ limitations under the License. */
#include "paddle/operators/cond_op.h"
#include "paddle/operators/dynamic_recurrent_op.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/recurrent_op.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
#include "paddle/pybind/exception.h"
......@@ -428,25 +427,6 @@ All parameter, weight, gradient are variables in Paddle.
return self.UnstackShared(source);
});
// recurrent_op
py::class_<operators::RecurrentOp, OperatorBase>(m, "RecurrentOp")
.def_static(
"create",
[](py::bytes protobin) -> operators::RecurrentOp * {
OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
auto rnn_op = OpRegistry::CreateOp(desc);
return static_cast<operators::RecurrentOp *>(rnn_op.release());
})
.def("set_stepnet", [](operators::RecurrentOp &self,
const operators::NetOp &net) -> void {
self.set_stepnet(net.Clone());
});
py::class_<operators::DynamicRecurrentOp, OperatorBase>(m,
"DynamicRecurrentOp")
.def_static("create",
......
......@@ -62,7 +62,7 @@ class Executor(object):
outputs={'Out': [fetch_var]},
attrs={'col': i})
self.executor.run(program.desc, scope, 0)
self.executor.run(program.desc, scope, 0, True)
return [
core.get_fetch_variable(scope, fetch_var_name, i)
for i in xrange(len(fetch_list))
......
......@@ -270,7 +270,8 @@ class Operator(object):
self.desc.check_attrs()
no_kernel_op_set = {
'feed', 'fetch', 'save', 'load', 'rnn_memory_helper_grad'
'feed', 'fetch', 'save', 'load', 'recurrent',
'rnn_memory_helper_grad'
}
if type not in no_kernel_op_set:
self.desc.infer_var_type(self.block.desc)
......
from paddle.v2.framework.layer_helper import LayerHelper, unique_name
import paddle.v2.framework.core as core
from paddle.v2.framework.framework import OpProtoHolder, Variable, Program
from paddle.v2.framework.framework import OpProtoHolder, Variable, Program, \
Operator
from paddle.v2.framework.initializer import ConstantInitializer
import re
......@@ -32,7 +33,6 @@ def fc(input,
param_shape = [
reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1)
] + [size]
w = helper.create_parameter(
attr=param_attr, shape=param_shape, dtype=dtype)
tmp = helper.create_tmp_variable(dtype)
......@@ -88,8 +88,17 @@ def data(name,
program=None,
init_program=None):
helper = LayerHelper('data', **locals())
shape = list(shape)
for i in xrange(len(shape)):
if shape[i] is None:
shape[i] = -1
append_batch_size = False
elif shape[i] < 0:
append_batch_size = False
if append_batch_size:
shape = [-1] + shape # append batch size as -1
return helper.create_global_variable(
name=name, shape=shape, dtype=data_type, type=type)
......@@ -165,6 +174,9 @@ _create_op_func_('mul')
_create_op_func_('elementwise_add')
_create_op_func_('dropout')
_create_op_func_('reshape')
_create_op_func_('elementwise_add')
_create_op_func_('sigmoid')
_create_op_func_('scale')
def cast(x, data_type, program=None):
......@@ -193,7 +205,7 @@ def concat(input, axis, program=None, init_program=None):
def sums(input, program=None, init_program=None):
helper = LayerHelper('sum', **locals())
out = helper.create_tmp_variable(dtype=helper.input_dtype())
helper.append_op(type='sum', inputs={'X': [input]}, outputs={'Out': out})
helper.append_op(type='sum', inputs={'X': input}, outputs={'Out': out})
return out
......@@ -346,7 +358,7 @@ def conv2d(input,
'paddings': padding,
'groups': groups})
pre_act = helper.append_bias_op(pre_bias)
pre_act = helper.append_bias_op(pre_bias, 1)
return helper.append_activation(pre_act)
......@@ -518,6 +530,8 @@ class StaticRNNGuard(BlockGuard):
return super(StaticRNNGuard, self).__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None:
return False
self.rnn.status = StaticRNN.AFTER_RNN_BLOCK
self.rnn.complete_rnn_op()
return super(StaticRNNGuard, self).__exit__(exc_type, exc_val, exc_tb)
......@@ -577,7 +591,7 @@ class StaticRNN(object):
outputs={'Out': [boot_var]},
attrs={
'value': init_value,
'shape': boot_var.shape,
'shape': [40] + list(boot_var.shape[1:]),
'data_type': boot_var.data_type
})
......@@ -596,14 +610,14 @@ class StaticRNN(object):
if not isinstance(x, Variable):
raise TypeError("step input takes a Variable")
if self.seq_len is None:
self.seq_len = x.shape[1]
elif self.seq_len != x.shape[1]:
self.seq_len = x.shape[0]
elif self.seq_len != x.shape[0]:
raise ValueError("Static RNN only take fix seq_len input")
ipt = self.helper.create_variable(
name=x.name,
dtype=x.data_type,
shape=[-1] + list(x.shape[2:]),
shape=list(x.shape[1:]),
type=x.type)
self.inputs.append(ipt)
return ipt
......@@ -613,10 +627,17 @@ class StaticRNN(object):
if not isinstance(o, Variable):
raise TypeError("step output takes a Variable")
tmp_o = self.helper.create_tmp_variable(dtype=o.data_type)
self.helper.append_op(
type='rnn_memory_helper',
inputs={'X': [o]},
outputs={'Out': tmp_o},
attrs={'data_type': o.data_type})
out_var = self.parent_block().create_var(
name=o.name,
shape=[-1, self.seq_len] + list(o.shape[1:]),
dtype=o.data_type)
name=tmp_o.name,
shape=[self.seq_len] + list(tmp_o.shape),
dtype=tmp_o.data_type)
self.outputs.append(out_var)
......@@ -647,6 +668,68 @@ class StaticRNN(object):
return self.outputs
def complete_rnn_op(self):
# TODO(yuyang18): Create RNN Op here.
# Implement this method after RNN op complete.
pass
program = self.helper.program
rnn_block = program.current_block()
parent_block = self.parent_block()
local_inputs = set()
for op in rnn_block.ops:
assert isinstance(op, Operator)
for oname in op.output_names:
for out_var_name in op.output(oname):
local_inputs.add(out_var_name)
for var in self.inputs:
local_inputs.add(var.name)
for m in self.memories:
local_inputs.add(m)
params = list()
for op in rnn_block.ops:
assert isinstance(op, Operator)
for iname in op.input_names:
for in_var_name in op.input(iname):
if in_var_name not in local_inputs:
params.append(in_var_name)
parameters = [parent_block.var(name) for name in params]
step_scope = parent_block.create_var(
type=core.VarDesc.VarType.STEP_SCOPES)
inlinks = [parent_block.var(i.name) for i in self.inputs]
outlinks = self.outputs
boot_memories = []
pre_memories = []
memories = []
for _, mem in self.memories.iteritems():
boot_memories.append(mem.init)
pre_memories.append(mem.pre_mem.name)
mem_var = rnn_block.var(mem.mem.name)
assert isinstance(mem_var, Variable)
new_mem = self.helper.create_tmp_variable(dtype=mem_var.data_type)
rnn_block.append_op(
type='rnn_memory_helper',
inputs={'X': [mem_var]},
outputs={'Out': [new_mem]},
attrs={'data_type': mem_var.data_type})
memories.append(new_mem.name)
parent_block.append_op(
type='recurrent',
inputs={
'inputs': inlinks,
'initial_states': boot_memories,
'parameters': parameters
},
outputs={'outputs': outlinks,
'step_scopes': [step_scope]},
attrs={
'ex_states': pre_memories,
'states': memories,
'step_block': rnn_block
})
import logging
import paddle.v2.framework.core as core
import unittest
import numpy as np
from paddle.v2.framework.op import Operator, RecurrentOp
from op_test import get_numeric_gradient
def py_sigmoid(x):
return 1. / (1. + np.exp(-x))
import logging
from op_test import get_numeric_gradient
from paddle.v2.framework.layers import *
from paddle.v2.framework.framework import Program
from paddle.v2.framework.executor import Executor
from paddle.v2.framework.backward import append_backward_ops
import numpy as np
import paddle.v2.framework.core as core
class PySimpleRNN(object):
'''
A simple implementation of RNN based on numpy, to futhur test RecurrentOp's alogorithm
'''
def __init__(self, input_dim=30, batch_size=50, weight_dim=15, sent_len=11):
self.x = np.random.normal(size=(sent_len, batch_size,
input_dim)).astype("float32")
self.W = np.random.normal(size=(input_dim, input_dim)).astype("float32")
self.U = np.random.normal(size=(input_dim, input_dim)).astype("float32")
self.h_boot = np.random.normal(size=(batch_size,
input_dim)).astype("float32")
class PyRNNBase(object):
def __init__(self, input_shape, output_shape):
self.x = np.ones(shape=input_shape).astype("float32")
self.y = np.zeros(shape=output_shape).astype("float32")
# memories
self.mems = [
np.zeros(shape=(batch_size, input_dim)).astype("float32")
for i in range(sent_len)
]
def step(self):
pass
def forward(self):
xs = self.segment_inputs()
for step_id in range(self.x.shape[0]):
self.step(step_id, xs[step_id])
return self.concat_outputs()
self.step(step_id, self.x[step_id])
return np.array([np.mean(self.y)])
def segment_inputs(self):
return [self.x[i] for i in range(self.x.shape[0])]
def concat_outputs(self):
return np.array(self.mems).astype("float32")
class PySimpleRNN1(PyRNNBase):
def __init__(self, input_shape, output_shape):
super(PySimpleRNN1, self).__init__(input_shape, output_shape)
seq_len, batch_size, input_dim = input_shape
self.h_boot = np.random.normal(size=(batch_size,
input_dim)).astype("float32")
self.scale = 1.0 / 2.0
men_dim = (seq_len, batch_size, input_dim)
self.mems = np.zeros(shape=men_dim).astype("float32")
def step(self, step_id, x):
if step_id == 0:
pre_mem = self.h_boot
else:
pre_mem = self.mems[step_id - 1]
self.mems[step_id] = (pre_mem + x) * self.scale
self.y[step_id] = self.mems[step_id]
class PySimpleRNN2(PyRNNBase):
def __init__(self, input_shape, output_shape):
super(PySimpleRNN2, self).__init__(input_shape, output_shape)
seq_len, batch_size, input_dim = input_shape
self.W = np.random.normal(size=(input_dim, input_dim)).astype("float32")
self.U = np.random.normal(size=(input_dim, input_dim)).astype("float32")
self.h_boot = np.ones(shape=(batch_size, input_dim)).astype("float32")
men_dim = (seq_len, batch_size, input_dim)
self.mems = np.zeros(shape=men_dim).astype("float32")
def step(self, step_id, x):
'''
run a step
'''
mem = self.mems[step_id]
if step_id > 0:
pre_mem = self.mems[step_id - 1]
else:
......@@ -53,108 +69,124 @@ class PySimpleRNN(object):
xW = np.matmul(x, self.W).astype("float32")
hU = np.matmul(pre_mem, self.U).astype("float32")
sum = xW + hU
self.mems[step_id] = py_sigmoid(sum)
def py_sigmoid(x):
return 1. / (1. + np.exp(-x))
class PySimpleRNNTest(unittest.TestCase):
def setUp(self):
self.rnn = PySimpleRNN()
def test_forward(self):
output = self.rnn.forward()
self.mems[step_id] = py_sigmoid(xW + hU)
self.y[step_id] = self.mems[step_id]
def create_tensor(scope, name, shape, np_data):
tensor = scope.var(name).get_tensor()
tensor.set_dims(shape)
tensor.set(np_data, core.CPUPlace())
def create_tensor(np_data, place):
tensor = core.LoDTensor()
tensor.set(np_data, place)
return tensor
class RecurrentOpTest(unittest.TestCase):
class RecurrentOpTest1(unittest.TestCase):
'''
Test RNNOp
equation:
h_t = \sigma (W x_t + U h_{t-1})
weights:
- W
- U
h_t = ( x_t + h_{t-1} ) / scale
vars:
- x
memories:
- h
outputs:
- h
- h
'''
input_dim = 30
batch_size = 50
weight_dim = 15
sent_len = 11
input_dim = 2
batch_size = 1
sent_len = 1
def init_program(self):
self.program = Program()
self.init_program = Program()
self.p_info = {
"program": self.program,
"init_program": self.init_program
}
self.place = core.CPUPlace()
def setUp(self):
self.py_rnn = PySimpleRNN(self.input_dim, self.batch_size,
self.weight_dim, self.sent_len)
self.init_program()
self.data_field = {"x", "h_boot"}
def forward(self):
self.scope = core.Scope()
self.create_global_variables()
self.create_rnn_op()
self.create_step_net()
ctx = core.DeviceContext.create(core.CPUPlace())
self.rnnop.run(self.scope, ctx)
return np.array(self.scope.find_var("h@mem").get_tensor()).astype(
"float32")
def create_global_variables(self):
# create inlink
x_np_data = self.py_rnn.x
create_tensor(self.scope, "x",
[self.sent_len, self.batch_size, self.input_dim],
x_np_data)
W_np_data = self.py_rnn.W
create_tensor(self.scope, "W", [self.input_dim, self.input_dim],
W_np_data)
U_np_data = self.py_rnn.U
create_tensor(self.scope, "U", [self.input_dim, self.input_dim],
U_np_data)
h_boot_np_data = self.py_rnn.h_boot
create_tensor(self.scope, "h_boot", [self.batch_size, self.input_dim],
h_boot_np_data)
self.scope.var("step_scopes")
self.scope.var("h@mem")
self.input_shape = (self.sent_len, self.batch_size, self.input_dim)
self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
self.py_rnn = PySimpleRNN1(self.input_shape, self.output_shape)
self.output = mean(x=self.create_rnn_op(), **self.p_info)
def create_rnn_op(self):
# create RNNOp
self.rnnop = RecurrentOp(
# inputs
inputs=["x"],
initial_states=["h_boot"],
step_net="stepnet",
# outputs
outputs=["h@mem"],
step_scopes="step_scopes",
# attributes
ex_states=["h@pre"],
states=["h@mem"])
def create_step_net(self):
stepnet = core.Net.create()
x_fc_op = Operator("mul", X="x", Y="W", Out="Wx")
h_fc_op = Operator("mul", X="h@pre", Y="U", Out="Uh")
sum_op = Operator("sum", X=["Wx", "Uh"], Out="sum")
sig_op = Operator("sigmoid", X="sum", Y="h@mem")
for op in [x_fc_op, h_fc_op, sum_op, sig_op]:
stepnet.append_op(op)
stepnet.complete_add_op(True)
self.rnnop.set_stepnet(stepnet)
def test_forward(self):
x = data(
shape=[self.sent_len, self.batch_size, self.input_dim],
data_type='float32',
name='x',
append_batch_size=False,
**self.p_info)
h_boot = data(
shape=[self.input_dim],
data_type='float32',
name='h_boot',
**self.p_info)
rnn = StaticRNN(program=self.program)
with rnn.step():
h_pre = rnn.memory(init=h_boot)
x_t = rnn.step_input(x)
h = scale(
x=elementwise_add(
x=h_pre, y=x_t, **self.p_info),
scale=self.py_rnn.scale,
**self.p_info)
rnn.update_memory(h_pre, h)
rnn.output(h)
return rnn()
def forward(self):
self.feed_map = {
x: create_tensor(getattr(self.py_rnn, x), self.place)
for x in self.data_field
}
exe = Executor(self.place)
out = exe.run(self.program,
feed=self.feed_map,
fetch_list=[self.output])
return np.array(out[0])
def backward(self):
self.feed_map = {
x: create_tensor(getattr(self.py_rnn, x), self.place)
for x in self.data_field
}
fetch_list = [
self.program.global_block().var(x + "@GRAD")
for x in self.data_field
]
exe = Executor(self.place)
return exe.run(self.program, feed=self.feed_map, fetch_list=fetch_list)
def test_backward(self):
self.check_forward()
append_backward_ops(self.output)
ana_grad = [np.array(x) for x in self.backward()]
num_grad = self.get_numerical_gradient()
for idx, name in enumerate(self.data_field):
self.assertEqual(num_grad[idx].shape, ana_grad[idx].shape)
self.assertTrue(
np.isclose(
num_grad[idx], ana_grad[idx], rtol=0.1).all())
def check_forward(self):
print 'test recurrent op forward'
pd_output = self.forward()
py_output = self.py_rnn.forward()
......@@ -164,44 +196,190 @@ class RecurrentOpTest(unittest.TestCase):
self.assertEqual(pd_output.shape, py_output.shape)
self.assertTrue(np.isclose(pd_output, py_output, rtol=0.1).all())
def get_numerical_gradient(self, delta=0.005):
dloss_dout = 1.0
feed_list = [getattr(self.py_rnn, x) for x in self.data_field]
grad_list = [np.zeros_like(x) for x in feed_list]
for feed, grad in zip(feed_list, grad_list):
for f, g in np.nditer([feed, grad], op_flags=['readwrite']):
o = float(f)
f[...] = o + delta
y_pos = self.forward()
class RecurrentGradientOpTest(unittest.TestCase):
def create_forward_op(self):
self.forward_op = RecurrentOp(
# inputs
inputs=["x"],
initial_states=["h_boot"],
step_net="stepnet",
# outputs
outputs=["h"],
step_scopes="step_scopes",
# attributes
ex_states=["h@pre"],
states=["h@alias"])
# create a stepnet for RNN
stepnet = core.Net.create()
x_fc_op = Operator("mul", X="x@alias", Y="W", Out="Wx")
h_fc_op = Operator("mul", X="h@pre", Y="U", Out="Uh")
sum_op = Operator("sum", X=["Wx", "Uh"], Out="sum")
sig_op = Operator("sigmoid", X="sum", Y="h@alias")
for op in [x_fc_op, h_fc_op, sum_op, sig_op]:
stepnet.append_op(op)
stepnet.complete_add_op(True)
self.forward_op.set_stepnet(stepnet)
def create_gradient_op(self):
a = set()
backward_op = core.RecurrentOp.backward(self.forward_op, a)
def test_grad(self):
self.create_forward_op()
self.create_gradient_op()
f[...] = o - delta
y_neg = self.forward()
f[...] = o
dout_dfeed = (y_pos - y_neg) / (delta * 2)
g[...] = dout_dfeed[0]
return grad_list
class RecurrentOpTest2(RecurrentOpTest1):
'''
Test RNNOp
equation:
h_t = \sigma (W x_t + U h_{t-1})
weights:
- W
- U
vars:
- x
memories:
- h
outputs:
- h
'''
input_dim = 2
batch_size = 10
sent_len = 2
def setUp(self):
self.init_program()
self.data_field = {"x", "h_boot", "W", "U"}
self.input_shape = (self.sent_len, self.batch_size, self.input_dim)
self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
self.py_rnn = PySimpleRNN2(self.input_shape, self.output_shape)
self.output = mean(x=self.create_rnn_op(), **self.p_info)
def create_rnn_op(self):
x = data(
shape=[self.sent_len, self.batch_size, self.input_dim],
data_type='float32',
name='x',
append_batch_size=False,
**self.p_info)
h_boot = data(
shape=[self.input_dim],
data_type='float32',
name='h_boot',
**self.p_info)
rnn = StaticRNN(program=self.program)
with rnn.step():
h_pre = rnn.memory(init=h_boot)
x_t = rnn.step_input(x)
temp_l = fc(input=x_t,
size=self.input_dim,
param_attr={'name': 'W'},
bias_attr=False,
**self.p_info)
temp_r = fc(input=h_pre,
size=self.input_dim,
param_attr={'name': 'U'},
bias_attr=False,
**self.p_info)
h = sigmoid(
x=elementwise_add(
x=temp_l, y=temp_r, **self.p_info),
**self.p_info)
rnn.update_memory(h_pre, h)
rnn.output(h)
return rnn()
class RecurrentOpTest3(RecurrentOpTest1):
'''
Test RNNOp with two memories
equation:
h_1 = h_pre_1
h_2 = h_pre_2
y = h_1 + h_2
vars:
- x
memories:
- h_1, h_2
outputs:
- y
'''
class PySimpleRNN3(PyRNNBase):
def __init__(self, input_shape, output_shape):
super(RecurrentOpTest3.PySimpleRNN3, self).__init__(input_shape,
output_shape)
seq_len, batch_size, input_dim = input_shape
self.h_boot1 = np.random.normal(size=(batch_size,
input_dim)).astype("float32")
self.h_boot2 = np.random.normal(size=(batch_size,
input_dim)).astype("float32")
men_dim = (seq_len, batch_size, input_dim)
self.mems1 = np.zeros(shape=men_dim).astype("float32")
self.mems2 = np.zeros(shape=men_dim).astype("float32")
def step(self, step_id, x):
if step_id == 0:
pre_mem1 = self.h_boot1
pre_mem2 = self.h_boot2
else:
pre_mem1 = self.mems1[step_id - 1]
pre_mem2 = self.mems2[step_id - 1]
self.mems1[step_id] = pre_mem1
self.mems2[step_id] = pre_mem2
self.y[step_id] = self.mems1[step_id] + self.mems2[step_id] + x
input_dim = 1
batch_size = 1
sent_len = 2
def setUp(self):
self.init_program()
self.data_field = {"x", "h_boot1", "h_boot2"}
self.input_shape = (self.sent_len, self.batch_size, self.input_dim)
self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
self.py_rnn = RecurrentOpTest3.PySimpleRNN3(self.input_shape,
self.output_shape)
self.output = mean(x=self.create_rnn_op(), **self.p_info)
def create_rnn_op(self):
x = data(
shape=[self.sent_len, self.batch_size, self.input_dim],
data_type='float32',
name='x',
append_batch_size=False,
**self.p_info)
h_boot1 = data(
shape=[self.batch_size, self.input_dim],
data_type='float32',
name='h_boot1',
append_batch_size=False,
**self.p_info)
h_boot2 = data(
shape=[self.batch_size, self.input_dim],
data_type='float32',
name='h_boot2',
append_batch_size=False,
**self.p_info)
rnn = StaticRNN(program=self.program)
with rnn.step():
h_pre1 = rnn.memory(init=h_boot1)
h_pre2 = rnn.memory(init=h_boot2)
x_t = rnn.step_input(x)
mem1 = scale(x=h_pre1, scale=1.0, **self.p_info)
mem2 = scale(x=h_pre2, scale=1.0, **self.p_info)
out = sums(input=[mem1, x_t, mem2], **self.p_info)
rnn.update_memory(h_pre1, mem1)
rnn.update_memory(h_pre2, mem2)
rnn.output(out)
return rnn()
if __name__ == '__main__':
exit(
0
) # FIXME(qijun): https://github.com/PaddlePaddle/Paddle/issues/5101#issuecomment-339814957
unittest.main()
import unittest
from paddle.v2.framework.layers import *
from paddle.v2.framework.framework import g_program
class TestRNN(unittest.TestCase):
def test_rnn(self):
img = data(
shape=[
80, # sequence length
22, # image height
22
], # image width
data_type='float32',
name='image')
hidden = fc(input=img, size=100, act='sigmoid', num_flatten_dims=2)
self.assertEqual((-1, 80, 100), hidden.shape)
hidden = fc(input=hidden, size=100, act='sigmoid', num_flatten_dims=2)
self.assertEqual((-1, 80, 100), hidden.shape)
rnn = StaticRNN()
with rnn.step():
hidden = rnn.step_input(hidden)
self.assertEqual((-1, 100), hidden.shape)
memory = rnn.memory(shape=(-1, 32), dtype='float32', init_value=0.0)
rnn_out = fc(input=[hidden, memory], size=32, act='sigmoid')
self.assertEqual((-1, 32), rnn_out.shape)
rnn.update_memory(memory, rnn_out)
rnn.output(rnn_out)
out = rnn()
self.assertEqual((-1, 80, 32), out.shape)
print g_program
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册