Commit a747df02 authored by: H hjchen2

Fix gru_unit bug

Parent b6993ba3
...@@ -14,13 +14,26 @@ limitations under the License. */ ...@@ -14,13 +14,26 @@ limitations under the License. */
#pragma once #pragma once
#include <vector> #include <vector>
#include "framework/tensor.h"
#include "memory/t_malloc.h" #include "memory/t_malloc.h"
#include "tensor.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
void TensorCopy(const Tensor &src, Tensor *dst); void TensorCopy(const Tensor& src, Tensor* dst);
// Copies the contents of a std::vector<T> into a Tensor.
//
// The destination tensor is resized to a 1-D shape of length src.size(),
// its buffer is (re)allocated for element type T via mutable_data<T>(),
// and the vector's raw bytes are copied in with memory::Copy.
//
// @param src  source vector; may be empty, in which case dst is resized
//             to {0} and no copy is performed.
// @param dst  destination tensor; must be non-null.
template <typename T>
void TensorFromVector(const std::vector<T>& src, Tensor* dst) {
  dst->Resize({static_cast<int64_t>(src.size())});
  auto* dst_ptr = static_cast<void*>(dst->mutable_data<T>());
  const size_t size = src.size() * sizeof(T);
  // Guard the empty case: std::vector::data() may return nullptr for an
  // empty vector, and a memcpy-style copy from a null pointer is undefined
  // behavior even when the byte count is zero.
  if (size > 0) {
    memory::Copy(dst_ptr, static_cast<const void*>(src.data()), size);
  }
}
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -124,12 +124,12 @@ void BeamSearchDecoder<T>::ConvertSentenceVectorToLodTensor( ...@@ -124,12 +124,12 @@ void BeamSearchDecoder<T>::ConvertSentenceVectorToLodTensor(
id_tensor->set_lod(lod); id_tensor->set_lod(lod);
id_tensor->Resize({static_cast<int64_t>(id_data.size())}); id_tensor->Resize({static_cast<int64_t>(id_data.size())});
id_tensor->mutable_data<int64_t>(); id_tensor->mutable_data<int64_t>();
// framework::TensorFromVector<int64_t>(id_data, cpu_ctx, id_tensor); framework::TensorFromVector<int64_t>(id_data, id_tensor);
score_tensor->set_lod(lod); score_tensor->set_lod(lod);
score_tensor->Resize({static_cast<int64_t>(score_data.size())}); score_tensor->Resize({static_cast<int64_t>(score_data.size())});
score_tensor->mutable_data<T>(); score_tensor->mutable_data<T>();
// framework::TensorFromVector<T>(score_data, cpu_ctx, score_tensor); framework::TensorFromVector<T>(score_data, score_tensor);
} }
template <typename T> template <typename T>
......
...@@ -27,30 +27,39 @@ namespace operators { ...@@ -27,30 +27,39 @@ namespace operators {
template <typename P> template <typename P>
void GruUnitCompute(const GruUnitParam<CPU>& param) { void GruUnitCompute(const GruUnitParam<CPU>& param) {
// inputs
auto* input = param.InputInput(); auto* input = param.InputInput();
auto* hidden_prev = param.InputHiddenPrev(); auto* hidden_prev = param.InputHiddenPrev();
auto* weight = param.InputWeight(); auto* weight = param.InputWeight();
auto* bias = param.InputBias(); auto* bias = param.InputBias();
// outputs
auto* gate = param.OutGate(); auto* gate = param.OutGate();
gate->mutable_data<P>();
auto* reset_hidden_prev = param.OutResetHiddenPrev(); auto* reset_hidden_prev = param.OutResetHiddenPrev();
reset_hidden_prev->mutable_data<P>();
auto* hidden = param.OutHidden(); auto* hidden = param.OutHidden();
hidden->mutable_data<P>();
// add bias
if (bias) { if (bias) {
math::RowwiseAdd<CPU, float> add_bias; math::RowwiseAdd<CPU, float> add_bias;
add_bias(*gate, *bias, gate); add_bias(*input, *bias, gate);
} }
int batch_size = input->dims()[0]; int batch_size = input->dims()[0];
int frame_size = hidden_prev->dims()[1]; int frame_size = hidden_prev->dims()[1];
const P* weight_data = weight->data<P>(); const P* weight_data = weight->data<P>();
math::GRUMetaValue<P> gru_value; math::GRUMetaValue<P> gru_value;
gru_value.gate_weight = const_cast<P*>(weight_data); gru_value.gate_weight = const_cast<P*>(weight_data);
gru_value.state_weight = gru_value.state_weight =
const_cast<P*>(weight_data + 2 * frame_size * frame_size); const_cast<P*>(weight_data + 2 * frame_size * frame_size);
gru_value.output_value = hidden->data<P>();
gru_value.prev_out_value = const_cast<P*>(hidden_prev->data<P>()); gru_value.prev_out_value = const_cast<P*>(hidden_prev->data<P>());
gru_value.output_value = hidden->data<P>();
gru_value.gate_value = gate->data<P>(); gru_value.gate_value = gate->data<P>();
gru_value.reset_output_value = reset_hidden_prev->data<P>(); gru_value.reset_output_value = reset_hidden_prev->data<P>();
auto active_node = math::GetActivationType(param.Activation()); auto active_node = math::GetActivationType(param.Activation());
auto active_gate = math::GetActivationType(param.GateActivation()); auto active_gate = math::GetActivationType(param.GateActivation());
math::GRUUnitFunctor<CPU, float>::compute(gru_value, frame_size, batch_size, math::GRUUnitFunctor<CPU, float>::compute(gru_value, frame_size, batch_size,
......
...@@ -22,7 +22,7 @@ limitations under the License. */ ...@@ -22,7 +22,7 @@ limitations under the License. */
#include "common/common.h" #include "common/common.h"
#include "common/log.h" #include "common/log.h"
#include "framework/ddim.h" #include "framework/ddim.h"
#include "framework/tensor.h" #include "framework/lod_tensor.h"
static const char *g_ocr = "../models/ocr"; static const char *g_ocr = "../models/ocr";
static const char *g_mobilenet_ssd = "../models/mobilenet+ssd"; static const char *g_mobilenet_ssd = "../models/mobilenet+ssd";
...@@ -66,9 +66,10 @@ static const char *g_yolo_img = "../images/in_put_1_3_416_416_2"; ...@@ -66,9 +66,10 @@ static const char *g_yolo_img = "../images/in_put_1_3_416_416_2";
static const char *g_super_img = "../images/mingren_input_data"; static const char *g_super_img = "../images/mingren_input_data";
static const char *g_mobilenet_img = "../images/image"; static const char *g_mobilenet_img = "../images/image";
using namespace paddle_mobile;
using paddle_mobile::framework::DDim; using paddle_mobile::framework::DDim;
using paddle_mobile::framework::LoDTensor;
using paddle_mobile::framework::Tensor; using paddle_mobile::framework::Tensor;
using namespace paddle_mobile;
template <typename T> template <typename T>
void SetupTensor(paddle_mobile::framework::Tensor *input, void SetupTensor(paddle_mobile::framework::Tensor *input,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.