Commit 83f4bc4e, authored by: T tensor-tang

Follow review comments and refine the code

Parent commit: 9838bacb
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h" #include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
namespace paddle { namespace paddle {
...@@ -97,6 +98,8 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, ...@@ -97,6 +98,8 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
op_desc.SetOutput("BatchedInput", {"blstm_0.tmp_2"}); op_desc.SetOutput("BatchedInput", {"blstm_0.tmp_2"});
op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse")); op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse"));
op_desc.SetAttr("use_peepholes", lstm_n->Op()->GetAttr("use_peepholes")); op_desc.SetAttr("use_peepholes", lstm_n->Op()->GetAttr("use_peepholes"));
// TODO(TJ): get from attr
op_desc.SetAttr("use_seq", true);
#define TMP_NAME(x) "at.new.tmp." #x #define TMP_NAME(x) "at.new.tmp." #x
#define OP_SET_OUT(x) op_desc.SetOutput(#x, {TMP_NAME(x)}) #define OP_SET_OUT(x) op_desc.SetOutput(#x, {TMP_NAME(x)})
...@@ -134,7 +137,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, ...@@ -134,7 +137,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
auto fc_no_bias_handler = [&]( auto fc_no_bias_handler = [&](
const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
#define GET_NODE(name__) \ #define GET_NODE(name__) \
std::string name__##key = name_scope + "/" + #name__; \ std::string name__##key = name_scope + "/" + #name__; \
auto* name__##n = pattern->RetrieveNode(name__##key); \ auto* name__##n = pattern->RetrieveNode(name__##key); \
......
...@@ -16,14 +16,10 @@ limitations under the License. */ ...@@ -16,14 +16,10 @@ limitations under the License. */
#include <string> #include <string>
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/fc_compute.h" #include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
DEFINE_bool(seq_mode, true, "Use sequence mode");
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -110,7 +106,7 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -110,7 +106,7 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
ctx->ShareLoD("X", "Cell"); ctx->ShareLoD("X", "Cell");
int xx_width; int xx_width;
if (FLAGS_seq_mode) { if (ctx->Attrs().Get<bool>("use_seq")) {
xx_width = wx_dims[1]; xx_width = wx_dims[1];
} else { } else {
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
...@@ -189,6 +185,10 @@ void FusionLSTMOpMaker::Make() { ...@@ -189,6 +185,10 @@ void FusionLSTMOpMaker::Make() {
"(bool, defalut: False) " "(bool, defalut: False) "
"whether to compute reversed LSTM.") "whether to compute reversed LSTM.")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("use_seq",
"(bool, defalut: True) "
"whether to use seq mode to compute.")
.SetDefault(true);
AddAttr<std::string>("gate_activation", AddAttr<std::string>("gate_activation",
"(string, default: sigmoid)" "(string, default: sigmoid)"
"The activation for input gate, forget gate and output " "The activation for input gate, forget gate and output "
...@@ -264,8 +264,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -264,8 +264,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
const int N = x_lod[0].size() - 1; // batch size const int N = x_lod[0].size() - 1; // batch size
const T* x_data = x->data<T>(); const T* x_data = x->data<T>();
const T* h0_data = h0 ? h0->data<T>() : NULL; const T* h0_data = h0 ? h0->data<T>() : nullptr;
const T* c0_data = c0 ? c0->data<T>() : NULL; const T* c0_data = c0 ? c0->data<T>() : nullptr;
const T* wx_data = wx->data<T>(); const T* wx_data = wx->data<T>();
const T* wh_data = wh->data<T>(); const T* wh_data = wh->data<T>();
T* xx_data = xx->mutable_data<T>(ctx.GetPlace()); T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
...@@ -295,8 +295,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -295,8 +295,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
for (int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
int bid = is_reverse ? N - 1 - i : i; int bid = is_reverse ? N - 1 - i : i;
int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
const T* prev_c_data = NULL; const T* prev_c_data = nullptr;
const T* prev_h_data = NULL; const T* prev_h_data = nullptr;
int tstart = 0; int tstart = 0;
if (h0_data) { if (h0_data) {
prev_h_data = h0_data + bid * D; prev_h_data = h0_data + bid * D;
...@@ -351,8 +351,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -351,8 +351,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
void BatchCompute(const framework::ExecutionContext& ctx) const { void BatchCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = platform::CPUDeviceContext; using DeviceContext = platform::CPUDeviceContext;
INIT_BASE_INPUT_OUTPUT INIT_BASE_INPUT_OUTPUT
if (x->lod()[0].size() == 2) { // batch size == 1 if (x->lod()[0].size() == 2) {
SeqCompute(ctx); SeqCompute(ctx);
return;
} }
INIT_BASE_SIZES INIT_BASE_SIZES
INIT_VEC_FUNC INIT_VEC_FUNC
...@@ -396,8 +397,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -396,8 +397,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
reordered_c0->Resize({max_bs, D}); reordered_c0->Resize({max_bs, D});
int tstart = 0; int tstart = 0;
T* prev_h_data = NULL; T* prev_h_data = nullptr;
T* prev_c_data = NULL; T* prev_c_data = nullptr;
if (h0) { if (h0) {
// reorder h0, c0 // reorder h0, c0
T* reordered_h0_data = reordered_h0->mutable_data<T>(place); T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
...@@ -489,7 +490,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -489,7 +490,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
} }
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
if (FLAGS_seq_mode) { if (ctx.Attr<bool>("use_seq")) {
SeqCompute(ctx); SeqCompute(ctx);
} else { } else {
BatchCompute(ctx); BatchCompute(ctx);
......
Markdown is supported in comments.
Upload progress: 0%.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.