recurrent_op_utils.h 2.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#pragma once

#include <cstddef>
#include <string>
#include <vector>

#include "paddle/framework/operator.h"

namespace paddle {
namespace operators {
namespace rnn {

D
dongzhihong 已提交
25 26
// Shorthand so the declarations in this header can write `Scope` instead of
// `framework::Scope`.
using Scope = framework::Scope;

27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
/**
 * Memory of a RNN (same as the role of `Memory` in PaddlePaddle).
 *
 * Memory attributes cached by this op, dims will be inferred from
 * boot memories in father scope. Other attributes are copied from Op's proto
 * attributes.
 *
 * All three members are variable *names* (plain strings), not the variables
 * themselves.
 */
struct MemoryAttr {
  // name of current state variable
  std::string var;
  // name of previous step's state variable
  std::string pre_var;
  // name of the variable used to init this memory (same role as `boot_layer`
  // in PaddlePaddle), which is stored in the father's scope.
  std::string boot_var;
};

struct Argument {
  std::string step_net;
  std::string step_scopes;
Y
Yan Chunwei 已提交
47 48
  std::vector<std::string> inlinks;
  std::vector<std::string> outlinks;
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
  std::vector<rnn::MemoryAttr> memories;
};

/**
 * One string per field of `Argument`, plus separate entries for the
 * previous-step and boot memories.
 *
 * Passed to InitArgument() together with the `Argument` to fill.
 * NOTE(review): presumably these strings are the keys under which the
 * corresponding values are looked up on the operator -- confirm against the
 * definition of InitArgument().
 */
struct ArgumentName {
  std::string step_net;
  std::string step_scopes;
  std::string inlinks;
  std::string outlinks;
  std::string memories;       // the memory name
  std::string pre_memories;   // the previous memory name
  std::string boot_memories;  // the boot memory name
};

/**
 * Prepare inputs for each step net.
 *
 * Declaration only; the definition lives outside this header.
 *
 * @param step_scopes       scopes to operate on (presumably one per time
 *                          step -- confirm against the definition).
 * @param inlinks           names of the input-link variables.
 * @param seq_len           sequence length (presumably the number of steps
 *                          the inputs are split into -- confirm).
 * @param infer_shape_mode  NOTE(review): looks like a switch between
 *                          shape-inference-only and real execution -- confirm.
 */
void SegmentInputs(const std::vector<Scope*>& step_scopes,
                   const std::vector<std::string>& inlinks,
                   const size_t seq_len, bool infer_shape_mode);
68 69 70 71 72

/**
 * Process outputs of step nets and merge to variables.
 *
 * Declaration only; the definition lives outside this header.
 *
 * @param step_scopes       scopes to read from (presumably one per time
 *                          step -- confirm against the definition).
 * @param outlinks          names of the output-link variables.
 * @param seq_len           sequence length (presumably the number of step
 *                          outputs being merged -- confirm).
 * @param infer_shape_mode  NOTE(review): looks like a switch between
 *                          shape-inference-only and real execution -- confirm.
 */
void ConcatOutputs(const std::vector<Scope*>& step_scopes,
                   const std::vector<std::string>& outlinks,
                   const size_t seq_len, bool infer_shape_mode);
75 76 77 78 79 80

/**
 * Link memories between step scopes.
 *
 * Declaration only; the definition lives outside this header.
 *
 * @param step_scopes       scopes to operate on (presumably one per time
 *                          step -- confirm against the definition).
 * @param memories          memory descriptors (variable names) to link.
 * @param step_id           index of the step being processed.
 * @param offset            signed offset; NOTE(review): presumably selects
 *                          the scope at `step_id + offset` as the one to link
 *                          against -- confirm.
 * @param infer_shape_mode  NOTE(review): looks like a switch between
 *                          shape-inference-only and real execution -- confirm.
 */
void LinkMemories(const std::vector<Scope*>& step_scopes,
                  const std::vector<MemoryAttr>& memories, const size_t step_id,
                  const int offset, bool infer_shape_mode);

void InitArgument(const ArgumentName& name, Argument* arg,
D
dongzhihong 已提交
81
                  const framework::OperatorBase& op);
82 83 84 85

}  // namespace rnn
}  // namespace operators
}  // namespace paddle