// recurrent_op_utils.h
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#pragma once

#include <string>
#include <vector>

#include "paddle/framework/operator.h"

namespace paddle {
namespace operators {
namespace rnn {

D
dongzhihong 已提交
25 26
using Scope = framework::Scope;

/**
 * Memory of an RNN (same role as `Memory` in PaddlePaddle).
 *
 * Memory attributes cached by this op. Dims will be inferred from the boot
 * memories in the parent scope; other attributes are copied from the op's
 * proto attributes.
 */
struct StateAttr {
  // Name of the current state variable.
  std::string var;
  // Name of the previous step's state variable.
  std::string pre_var;
  // Name of the variable used to initialize this memory (same role as
  // `boot_layer` in PaddlePaddle); it is stored in the parent scope.
  std::string boot_var;
};

struct Argument {
  std::string step_net;
  std::string step_scopes;
S
superjom 已提交
47 48
  std::vector<std::string> inlinks;
  std::vector<std::string> outlinks;
49
  std::vector<rnn::StateAttr> states;
50 51 52 53 54 55 56
};

/**
 * Names from which a recurrent op's Argument is built.
 *
 * NOTE(review): presumably each field names an op input, output or attribute
 * that InitArgument resolves into the matching Argument field — confirm
 * against the definition.
 */
struct ArgumentName {
  std::string step_net;
  std::string step_scopes;
  std::string inlinks;
  std::string outlinks;
  std::string states;          // the memory name
  std::string ex_states;       // the previous memory name
  std::string initial_states;  // the boot memory name
};

/**
 * Prepare inputs for each step net.
 */
void SegmentInputs(const std::vector<Scope*>& step_scopes,
S
superjom 已提交
66
                   const std::vector<std::string>& inlinks,
Q
qiaolongfei 已提交
67
                   const size_t seq_len);
68 69 70 71 72

/**
 * Process outputs of step nets and merge to variables.
 */
void ConcatOutputs(const std::vector<Scope*>& step_scopes,
S
superjom 已提交
73
                   const std::vector<std::string>& outlinks,
74
                   const size_t seq_len, const platform::DeviceContext& ctx);
75 76

void LinkMemories(const std::vector<Scope*>& step_scopes,
77
                  const std::vector<StateAttr>& memories, const size_t step_id,
Q
qiaolongfei 已提交
78
                  const int offset);
79 80

void InitArgument(const ArgumentName& name, Argument* arg,
S
superjom 已提交
81
                  const framework::OperatorBase& op, bool is_grad = false);
82 83 84 85

}  // namespace rnn
}  // namespace operators
}  // namespace paddle