提交 d5114c60 编写于 作者: J Jacek Czaja

- Reviewers suggesstions to fused_embedding_fc_lstm_op

上级 7ab5626d
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h" #include "paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h"
#include <algorithm>
#include <string> #include <string>
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
...@@ -98,17 +99,17 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -98,17 +99,17 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
// Copy only gate biases values (only actual bias data, not peephole // Copy only gate biases values (only actual bias data, not peephole
// weights) // weights)
std::vector<float> combined_biases(n, 0.0f); std::vector<float> combined_biases;
memcpy(&combined_biases[0], lstm_bias_tensor.data<float>(), combined_biases.reserve(n);
n * sizeof(float)); std::copy_n(lstm_bias_tensor.data<float>(), n,
std::back_inserter(combined_biases));
if (with_fc_bias) { if (with_fc_bias) {
// Add FC-bias with LSTM-bias (into GEMM result to be) // Add FC-bias with LSTM-bias (into GEMM result to be)
auto* fc_bias_var = scope->FindVar(fc_bias->Name()); auto* fc_bias_var = scope->FindVar(fc_bias->Name());
const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>(); const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>();
for (int i = 0; i < fc_bias_tensor.numel(); i++) { for (int i = 0; i < fc_bias_tensor.numel(); i++) {
combined_biases[i] = combined_biases[i] += fc_bias_tensor.data<float>()[i];
lstm_bias_tensor.data<float>()[i] + fc_bias_tensor.data<float>()[i];
} }
} }
......
...@@ -63,10 +63,6 @@ void FusedEmbeddingFCLSTMOp::InferShape( ...@@ -63,10 +63,6 @@ void FusedEmbeddingFCLSTMOp::InferShape(
auto embeddings_dims = ctx->GetInputDim("Embeddings"); auto embeddings_dims = ctx->GetInputDim("Embeddings");
PADDLE_ENFORCE_EQ(embeddings_dims.size(), 2, PADDLE_ENFORCE_EQ(embeddings_dims.size(), 2,
"The rank of Input(Embeddings) should be 2."); "The rank of Input(Embeddings) should be 2.");
// PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
// "The first dimension of Input(Embeddings) "
// "should be %d.",
// x_dims[1]);
auto wh_dims = ctx->GetInputDim("WeightH"); auto wh_dims = ctx->GetInputDim("WeightH");
int frame_size = wh_dims[1] / 4; int frame_size = wh_dims[1] / 4;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册