未验证 提交 beb32dfe 编写于 作者: Y Yiqun Liu 提交者: GitHub

[X86] Optimize gru and softmax (#2877) (#2884)

test=develop test=release/2.3
上级 2ef97e68
......@@ -142,14 +142,13 @@ void StrideScal(const T* a, const T* x, T* y, int n, int stride);
// remain is the product of dimension shapes after the axis dimension
template <typename T>
void Softmax(const T* x, T* y, int n, int bs, int remain = 1) {
std::vector<T> entities(bs);
for (int i = 0; i < bs; ++i) {
entities[i] = x[i * n];
T entity = x[i * n];
for (int c = 1; c < n; ++c) {
entities[i] = x[i * n + c] > entities[i] ? x[i * n + c] : entities[i];
entity = x[i * n + c] > entity ? x[i * n + c] : entity;
}
for (int c = 0; c < n; ++c) {
y[i * n + c] = x[i * n + c] - entities[i];
y[i * n + c] = x[i * n + c] - entity;
}
}
VExp(y, y, n * bs);
......
......@@ -110,11 +110,7 @@ void set_constant(const lite::Context<Target>& context,
lite::Tensor* tensor,
float value) {
TensorSetConstantWithTarget<Target> func(context, tensor, value);
// #ifdef PADDLE_WITH_CUDA
// tensor->target().apply_visitor(func);
// #else
func();
// #endif
}
template <typename T>
......@@ -123,7 +119,7 @@ struct RowwiseAdd<lite::TargetType::kX86, T> {
const lite::Tensor& input,
const lite::Tensor& vector,
lite::Tensor* output) {
auto in_dims = input.dims();
const auto& in_dims = input.dims();
auto size = input.numel() / in_dims[0];
PADDLE_ENFORCE_EQ(vector.numel(), size);
PADDLE_ENFORCE_EQ(output->dims(), in_dims);
......
......@@ -48,6 +48,10 @@ inline void ReorderInitState(const lite::Context<TARGET(kX86)>& context,
row_shuffle(context, src, index_lod, dst, indexed_src);
}
static inline int64_t CalculateSeqWidth(const DDim& dims) {
return dims.count(1, dims.size());
}
template <typename T>
class GRUCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
......@@ -65,15 +69,16 @@ class GRUCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
auto* bias = param.bias;
auto* batch_gate = param.batch_gate;
batch_gate->mutable_data<T>();
auto* batch_reset_hidden_prev = param.batch_reset_hidden_prev;
batch_reset_hidden_prev->mutable_data<T>();
auto* batch_hidden = param.batch_hidden;
batch_hidden->mutable_data<T>();
T* batch_gate_ptr = batch_gate->mutable_data<T>();
T* batch_reset_hidden_prev_ptr = batch_reset_hidden_prev->mutable_data<T>();
T* batch_hidden_ptr = batch_hidden->mutable_data<T>();
auto* hidden = param.hidden;
hidden->mutable_data<T>();
auto hidden_dims = hidden->dims();
const auto& hidden_dims = hidden->dims();
lite::x86::math::LoDTensor2BatchFunctor<TARGET(kX86), T> to_batch;
to_batch(context, *input, batch_gate, true, is_reverse);
......@@ -90,19 +95,23 @@ class GRUCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
const_cast<T*>(weight_data + 2 * frame_size * frame_size);
Tensor ordered_h0;
std::vector<size_t> order(batch_gate->lod()[2]);
if (h0) {
// Since the batch computing for GRU reorders the input sequences
// according to their length. The initialized cell state also needs
// to reorder.
const std::vector<size_t>& order(batch_gate->lod()[2]);
ReorderInitState<T>(context, *h0, order, &ordered_h0, true);
gru_value.prev_out_value = ordered_h0.mutable_data<T>();
} else {
gru_value.prev_out_value = nullptr;
}
auto batch_starts = batch_gate->lod()[0];
const auto& batch_starts = batch_gate->lod()[0];
size_t seq_len = batch_starts.size() - 1;
int64_t batch_gate_width = CalculateSeqWidth(batch_gate->dims());
int64_t batch_reset_hidden_prev_width =
CalculateSeqWidth(batch_reset_hidden_prev->dims());
int64_t batch_hidden_width = CalculateSeqWidth(batch_hidden->dims());
auto active_node =
lite::x86::math::detail::GetActivationType(param.activation);
auto active_gate =
......@@ -145,13 +154,10 @@ class GRUCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
int64_t bend = static_cast<int64_t>(batch_starts[n + 1]);
int64_t cur_batch_size = bend - bstart;
Tensor gate_t = batch_gate->Slice<T>(bstart, bend);
Tensor reset_hidden_prev_t =
batch_reset_hidden_prev->Slice<T>(bstart, bend);
Tensor hidden_t = batch_hidden->Slice<T>(bstart, bend);
gru_value.output_value = hidden_t.mutable_data<T>();
gru_value.gate_value = gate_t.mutable_data<T>();
gru_value.reset_output_value = reset_hidden_prev_t.mutable_data<T>();
gru_value.output_value = batch_hidden_ptr + bstart * batch_hidden_width;
gru_value.gate_value = batch_gate_ptr + bstart * batch_gate_width;
gru_value.reset_output_value = batch_reset_hidden_prev_ptr +
bstart * batch_reset_hidden_prev_width;
if (gru_value.prev_out_value) {
blas.GEMM_COMPUTE(CblasNoTrans,
......@@ -188,13 +194,10 @@ class GRUCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
int64_t bend = static_cast<int64_t>(batch_starts[n + 1]);
int64_t cur_batch_size = bend - bstart;
Tensor gate_t = batch_gate->Slice<T>(bstart, bend);
Tensor reset_hidden_prev_t =
batch_reset_hidden_prev->Slice<T>(bstart, bend);
Tensor hidden_t = batch_hidden->Slice<T>(bstart, bend);
gru_value.output_value = hidden_t.mutable_data<T>();
gru_value.gate_value = gate_t.mutable_data<T>();
gru_value.reset_output_value = reset_hidden_prev_t.mutable_data<T>();
gru_value.output_value = batch_hidden_ptr + bstart * batch_hidden_width;
gru_value.gate_value = batch_gate_ptr + bstart * batch_gate_width;
gru_value.reset_output_value = batch_reset_hidden_prev_ptr +
bstart * batch_reset_hidden_prev_width;
lite::x86::math::GRUUnitFunctor<TARGET(kX86), T>::compute(
context,
......
......@@ -55,24 +55,33 @@ class SoftmaxCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
auto& context = ctx_->As<X86Context>();
CHECK(param.output);
CHECK(param.x);
param.output->mutable_data<T>();
const int rank = param.x->dims().size();
auto* x = param.x;
auto* output = param.output;
output->mutable_data<T>();
const int rank = x->dims().size();
const int axis = CanonicalAxis(param.axis, rank);
int axis_dim = param.x->dims()[axis];
const int n = SizeToAxis(axis, param.x->dims());
const int d = SizeFromAxis(axis, param.x->dims());
int axis_dim = x->dims()[axis];
if (rank == 2 && axis == 1) {
lite::x86::math::SoftmaxFunctor<lite::TargetType::kX86, T, true>()(
context, axis_dim, x, output);
} else {
const int n = SizeToAxis(axis, x->dims());
const int d = SizeFromAxis(axis, x->dims());
DDim shape(std::vector<DDim::value_type>{n, d});
DDim x_dims = x->dims();
DDim out_dims = output->dims();
Tensor input_2d;
Tensor out_2d;
input_2d.ShareDataWith(*param.x);
input_2d.Resize(shape);
out_2d.ShareDataWith(*param.output);
out_2d.Resize(shape);
DDim shape_2d(std::vector<DDim::value_type>{n, d});
x->Resize(shape_2d);
output->Resize(shape_2d);
lite::x86::math::SoftmaxFunctor<lite::TargetType::kX86, T, true>()(
context, axis_dim, &input_2d, &out_2d);
lite::x86::math::SoftmaxFunctor<lite::TargetType::kX86, T, true>()(
context, axis_dim, x, output);
x->Resize(x_dims);
output->Resize(out_dims);
}
}
virtual ~SoftmaxCompute() = default;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册