Commit a747df02 authored by H hjchen2

Fix gru_unit bug

Parent b6993ba3
@@ -14,13 +14,26 @@ limitations under the License. */
#pragma once
#include <vector>
#include "framework/tensor.h"
#include "memory/t_malloc.h"
#include "tensor.h"
namespace paddle_mobile {
namespace framework {
void TensorCopy(const Tensor &src, Tensor *dst);
void TensorCopy(const Tensor& src, Tensor* dst);
template <typename T>
void TensorFromVector(const std::vector<T>& src, Tensor* dst);
template <typename T>
void TensorFromVector(const std::vector<T>& src, Tensor* dst) {
auto src_ptr = static_cast<const void*>(src.data());
dst->Resize({static_cast<int64_t>(src.size())});
auto dst_ptr = static_cast<void*>(dst->mutable_data<T>());
auto size = src.size() * sizeof(T);
memory::Copy(dst_ptr, src_ptr, size);
}
} // namespace framework
} // namespace paddle_mobile
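As a quick illustration (not part of the commit), the new TensorFromVector overload defined above could be exercised as in the sketch below; the include path "framework/tensor_util.h" is an assumption about where this header lives.

#include <vector>
#include "framework/tensor.h"
#include "framework/tensor_util.h"  // assumed path of the header shown above

// Hypothetical usage: copy host data from a std::vector into a framework::Tensor.
void FillTensorFromHostData() {
  std::vector<float> src = {1.f, 2.f, 3.f, 4.f};
  paddle_mobile::framework::Tensor dst;
  // Resizes dst to {4} and copies src.size() * sizeof(float) bytes into it.
  paddle_mobile::framework::TensorFromVector<float>(src, &dst);
}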
@@ -124,12 +124,12 @@ void BeamSearchDecoder<T>::ConvertSentenceVectorToLodTensor(
id_tensor->set_lod(lod);
id_tensor->Resize({static_cast<int64_t>(id_data.size())});
id_tensor->mutable_data<int64_t>();
// framework::TensorFromVector<int64_t>(id_data, cpu_ctx, id_tensor);
framework::TensorFromVector<int64_t>(id_data, id_tensor);
score_tensor->set_lod(lod);
score_tensor->Resize({static_cast<int64_t>(score_data.size())});
score_tensor->mutable_data<T>();
// framework::TensorFromVector<T>(score_data, cpu_ctx, score_tensor);
framework::TensorFromVector<T>(score_data, score_tensor);
}
template <typename T>
......
@@ -27,30 +27,39 @@ namespace operators {
template <typename P>
void GruUnitCompute(const GruUnitParam<CPU>& param) {
// inputs
auto* input = param.InputInput();
auto* hidden_prev = param.InputHiddenPrev();
auto* weight = param.InputWeight();
auto* bias = param.InputBias();
// outputs
auto* gate = param.OutGate();
gate->mutable_data<P>();
auto* reset_hidden_prev = param.OutResetHiddenPrev();
reset_hidden_prev->mutable_data<P>();
auto* hidden = param.OutHidden();
hidden->mutable_data<P>();
// add bias
if (bias) {
math::RowwiseAdd<CPU, float> add_bias;
add_bias(*gate, *bias, gate);
add_bias(*input, *bias, gate);
}
int batch_size = input->dims()[0];
int frame_size = hidden_prev->dims()[1];
const P* weight_data = weight->data<P>();
math::GRUMetaValue<P> gru_value;
gru_value.gate_weight = const_cast<P*>(weight_data);
gru_value.state_weight =
const_cast<P*>(weight_data + 2 * frame_size * frame_size);
gru_value.output_value = hidden->data<P>();
gru_value.prev_out_value = const_cast<P*>(hidden_prev->data<P>());
gru_value.output_value = hidden->data<P>();
gru_value.gate_value = gate->data<P>();
gru_value.reset_output_value = reset_hidden_prev->data<P>();
auto active_node = math::GetActivationType(param.Activation());
auto active_gate = math::GetActivationType(param.GateActivation());
math::GRUUnitFunctor<CPU, float>::compute(gru_value, frame_size, batch_size,
......
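For orientation (not part of the commit), the sketch below is a naive single-row reference for one GRU step under the buffer layout implied by the pointer arithmetic above: the gate buffer holds 3 * frame_size values per row and the state (candidate) weights start at offset 2 * frame_size * frame_size. The gate ordering (update, reset, candidate) and the final interpolation h = u * h_prev + (1 - u) * c are assumptions about the functor's convention, not taken from GRUUnitFunctor itself.

#include <cmath>

// Hypothetical reference for one GRU step on a single row.
// gate_in: input (+ bias), 3 * frame_size values
// prev:    previous hidden state, frame_size values
// weight:  frame_size x 2*frame_size gate weights followed by
//          frame_size x frame_size state weights (hence the offset below)
void GruStepReference(const float* gate_in, const float* prev,
                      const float* weight, float* gate, float* reset_prev,
                      float* hidden, int frame_size) {
  auto sigmoid = [](float x) { return 1.f / (1.f + std::exp(-x)); };
  // Update/reset gates: gate[:2f] = sigmoid(gate_in[:2f] + prev * W_ur).
  for (int j = 0; j < 2 * frame_size; ++j) {
    float acc = gate_in[j];
    for (int i = 0; i < frame_size; ++i) {
      acc += prev[i] * weight[i * 2 * frame_size + j];
    }
    gate[j] = sigmoid(acc);
  }
  // reset_prev = r (elementwise) prev.
  for (int i = 0; i < frame_size; ++i) {
    reset_prev[i] = gate[frame_size + i] * prev[i];
  }
  // Candidate: c = tanh(gate_in[2f:] + reset_prev * W_c).
  const float* state_weight = weight + 2 * frame_size * frame_size;
  for (int j = 0; j < frame_size; ++j) {
    float acc = gate_in[2 * frame_size + j];
    for (int i = 0; i < frame_size; ++i) {
      acc += reset_prev[i] * state_weight[i * frame_size + j];
    }
    gate[2 * frame_size + j] = std::tanh(acc);
  }
  // Output: h = u * prev + (1 - u) * c  (one common convention).
  for (int i = 0; i < frame_size; ++i) {
    float u = gate[i];
    float c = gate[2 * frame_size + i];
    hidden[i] = u * prev[i] + (1.f - u) * c;
  }
}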
@@ -22,7 +22,7 @@ limitations under the License. */
#include "common/common.h"
#include "common/log.h"
#include "framework/ddim.h"
#include "framework/tensor.h"
#include "framework/lod_tensor.h"
static const char *g_ocr = "../models/ocr";
static const char *g_mobilenet_ssd = "../models/mobilenet+ssd";
@@ -66,9 +66,10 @@ static const char *g_yolo_img = "../images/in_put_1_3_416_416_2";
static const char *g_super_img = "../images/mingren_input_data";
static const char *g_mobilenet_img = "../images/image";
using namespace paddle_mobile;
using paddle_mobile::framework::DDim;
using paddle_mobile::framework::LoDTensor;
using paddle_mobile::framework::Tensor;
using namespace paddle_mobile;
template <typename T>
void SetupTensor(paddle_mobile::framework::Tensor *input,
......