未验证 提交 21c0f874 编写于 作者: J Jiabin Yang 提交者: GitHub

Merge pull request #14728 from JiabinYang/optimize_hs_op

Optimize hs op
......@@ -158,6 +158,7 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
}
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
}
protected:
......
......@@ -185,7 +185,6 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
ctx.Output<framework::SelectedRows>(framework::GradVarName("W"));
w_grad->set_rows(real_rows);
// Build a map of id -> row_index to speed up finding the index of one id
w_grad->SyncIndex();
w_grad->set_height(w.dims()[0]);
auto* w_grad_value = w_grad->mutable_value();
framework::DDim temp_dim(w.dims());
......
......@@ -89,6 +89,8 @@ template <typename T>
void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
const framework::Tensor& weight,
const framework::Tensor& input) {
auto blas =
GetBlas<platform::CPUDeviceContext, T>(platform::CPUDeviceContext());
size_t num_samples = tmat->dims()[0];
size_t tmat_width = tmat->dims()[1];
size_t input_width = input.dims()[1];
......@@ -99,13 +101,12 @@ void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table_->get_code(i);
int code_length = code->get_length();
const T* input_row = input_value + input_width * i;
for (int j = 0; j < code_length; ++j) {
size_t index = code->calc_index(j);
const T* weight_row = weight_value + weight_width * index;
T sum = static_cast<T>(0.0);
for (size_t k = 0; k < input_width; ++k) {
sum += weight_value[weight_width * index + k] *
input_value[input_width * i + k];
}
sum = blas.DOT(input_width, weight_row, input_row);
tmat_value[i * tmat_width + j] += sum;
}
}
......@@ -115,6 +116,8 @@ template <typename T>
void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
framework::Tensor* weight,
const framework::Tensor& input) {
auto blas =
GetBlas<platform::CPUDeviceContext, T>(platform::CPUDeviceContext());
size_t num_samples = tmat.dims()[0];
size_t input_width = input.dims()[1];
size_t tmat_width = tmat.dims()[1];
......@@ -122,16 +125,25 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
auto tmat_value = tmat.data<T>();
auto weight_value = weight->data<T>();
auto input_value = input.data<T>();
std::unordered_map<int, std::vector<std::pair<T, const T*>>> ops;
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table_->get_code(i);
int code_length = code->get_length();
const T* input_value_row = input_value + input_width * i;
const T* tmat_row = tmat_value + i * tmat_width;
for (int j = 0; j < code_length; ++j) {
size_t index = code->calc_index(j);
for (size_t k = 0; k < input_width; ++k) {
weight_value[weight_width * index + k] +=
tmat_value[i * tmat_width + j] * input_value[input_width * i + k];
}
ops[code->calc_index(j)].emplace_back(tmat_row[j], input_value_row);
}
}
for (auto& op : ops) {
auto& op_in_row = op.second;
for (auto& pair : op_in_row) {
auto& scale = pair.first;
auto* input_row = pair.second;
T* weight_row = weight_value + op.first * weight_width;
blas.AXPY(input_width, scale, input_row, weight_row);
}
}
}
......@@ -140,6 +152,8 @@ template <typename T>
void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
framework::SelectedRows* weight,
const framework::Tensor& input) {
auto blas =
GetBlas<platform::CPUDeviceContext, T>(platform::CPUDeviceContext());
size_t num_samples = tmat.dims()[0];
size_t input_width = input.dims()[1];
size_t tmat_width = tmat.dims()[1];
......@@ -147,17 +161,28 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
auto tmat_value = tmat.data<T>();
auto weight_value = weight->mutable_value()->data<T>();
auto input_value = input.data<T>();
std::unordered_map<int, std::vector<std::pair<T, const T*>>> ops;
ops.reserve(weight->rows().size());
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table_->get_code(i);
int code_length = code->get_length();
const T* input_value_row = input_value + input_width * i;
const T* tmat_row = tmat_value + i * tmat_width;
for (int j = 0; j < code_length; ++j) {
size_t index = code->calc_index(j);
for (size_t k = 0; k < input_width; ++k) {
int64_t row_index = weight->GetIndexFromId(static_cast<int64_t>(index));
weight_value[row_index * weight_width + k] +=
tmat_value[i * tmat_width + j] * input_value[input_width * i + k];
}
ops[code->calc_index(j)].emplace_back(tmat_row[j], input_value_row);
}
}
for (auto& row : weight->rows()) {
auto& op_in_row = ops[row];
for (auto& pair : op_in_row) {
auto& scale = pair.first;
auto* input_row = pair.second;
blas.AXPY(input_width, scale, input_row, weight_value);
}
weight_value += weight_width;
}
}
......
......@@ -13,10 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/device_context.h"
#if defined(_WIN32)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册