提交 83f4bc4e 编写于 作者: T tensor-tang

follow comment and refine code

上级 9838bacb
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
namespace paddle {
......@@ -97,6 +98,8 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
op_desc.SetOutput("BatchedInput", {"blstm_0.tmp_2"});
op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse"));
op_desc.SetAttr("use_peepholes", lstm_n->Op()->GetAttr("use_peepholes"));
// TODO(TJ): get from attr
op_desc.SetAttr("use_seq", true);
#define TMP_NAME(x) "at.new.tmp." #x
#define OP_SET_OUT(x) op_desc.SetOutput(#x, {TMP_NAME(x)})
......@@ -134,7 +137,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
auto fc_no_bias_handler = [&](
const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
#define GET_NODE(name__) \
std::string name__##key = name_scope + "/" + #name__; \
auto* name__##n = pattern->RetrieveNode(name__##key); \
......
......@@ -16,14 +16,10 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/cpu_info.h"
DEFINE_bool(seq_mode, true, "Use sequence mode");
namespace paddle {
namespace operators {
......@@ -110,7 +106,7 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
ctx->ShareLoD("X", "Cell");
int xx_width;
if (FLAGS_seq_mode) {
if (ctx->Attrs().Get<bool>("use_seq")) {
xx_width = wx_dims[1];
} else {
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
......@@ -189,6 +185,10 @@ void FusionLSTMOpMaker::Make() {
"(bool, defalut: False) "
"whether to compute reversed LSTM.")
.SetDefault(false);
AddAttr<bool>("use_seq",
"(bool, defalut: True) "
"whether to use seq mode to compute.")
.SetDefault(true);
AddAttr<std::string>("gate_activation",
"(string, default: sigmoid)"
"The activation for input gate, forget gate and output "
......@@ -264,8 +264,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
const int N = x_lod[0].size() - 1; // batch size
const T* x_data = x->data<T>();
const T* h0_data = h0 ? h0->data<T>() : NULL;
const T* c0_data = c0 ? c0->data<T>() : NULL;
const T* h0_data = h0 ? h0->data<T>() : nullptr;
const T* c0_data = c0 ? c0->data<T>() : nullptr;
const T* wx_data = wx->data<T>();
const T* wh_data = wh->data<T>();
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
......@@ -295,8 +295,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
for (int i = 0; i < N; ++i) {
int bid = is_reverse ? N - 1 - i : i;
int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
const T* prev_c_data = NULL;
const T* prev_h_data = NULL;
const T* prev_c_data = nullptr;
const T* prev_h_data = nullptr;
int tstart = 0;
if (h0_data) {
prev_h_data = h0_data + bid * D;
......@@ -351,8 +351,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
void BatchCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = platform::CPUDeviceContext;
INIT_BASE_INPUT_OUTPUT
if (x->lod()[0].size() == 2) { // batch size == 1
if (x->lod()[0].size() == 2) {
SeqCompute(ctx);
return;
}
INIT_BASE_SIZES
INIT_VEC_FUNC
......@@ -396,8 +397,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
reordered_c0->Resize({max_bs, D});
int tstart = 0;
T* prev_h_data = NULL;
T* prev_c_data = NULL;
T* prev_h_data = nullptr;
T* prev_c_data = nullptr;
if (h0) {
// reorder h0, c0
T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
......@@ -489,7 +490,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
}
void Compute(const framework::ExecutionContext& ctx) const override {
if (FLAGS_seq_mode) {
if (ctx.Attr<bool>("use_seq")) {
SeqCompute(ctx);
} else {
BatchCompute(ctx);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册