提交 b387a194 编写于 作者: J JiabinYang

optimize op with blas

上级 4f71a6ee
...@@ -158,6 +158,7 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { ...@@ -158,6 +158,7 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W")); ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
} }
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
} }
protected: protected:
......
...@@ -185,7 +185,6 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> { ...@@ -185,7 +185,6 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
ctx.Output<framework::SelectedRows>(framework::GradVarName("W")); ctx.Output<framework::SelectedRows>(framework::GradVarName("W"));
w_grad->set_rows(real_rows); w_grad->set_rows(real_rows);
// Build a map of id -> row_index to speed up finding the index of one id // Build a map of id -> row_index to speed up finding the index of one id
w_grad->SyncIndex();
w_grad->set_height(w.dims()[0]); w_grad->set_height(w.dims()[0]);
auto* w_grad_value = w_grad->mutable_value(); auto* w_grad_value = w_grad->mutable_value();
framework::DDim temp_dim(w.dims()); framework::DDim temp_dim(w.dims());
......
...@@ -89,6 +89,8 @@ template <typename T> ...@@ -89,6 +89,8 @@ template <typename T>
void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat, void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
const framework::Tensor& weight, const framework::Tensor& weight,
const framework::Tensor& input) { const framework::Tensor& input) {
auto blas =
GetBlas<platform::CPUDeviceContext, T>(platform::CPUDeviceContext());
size_t num_samples = tmat->dims()[0]; size_t num_samples = tmat->dims()[0];
size_t tmat_width = tmat->dims()[1]; size_t tmat_width = tmat->dims()[1];
size_t input_width = input.dims()[1]; size_t input_width = input.dims()[1];
...@@ -99,13 +101,12 @@ void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat, ...@@ -99,13 +101,12 @@ void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
for (size_t i = 0; i < num_samples; ++i) { for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table_->get_code(i); auto code = code_table_->get_code(i);
int code_length = code->get_length(); int code_length = code->get_length();
const T* input_row = input_value + input_width * i;
for (int j = 0; j < code_length; ++j) { for (int j = 0; j < code_length; ++j) {
size_t index = code->calc_index(j); size_t index = code->calc_index(j);
const T* weight_row = weight_value + weight_width * index;
T sum = static_cast<T>(0.0); T sum = static_cast<T>(0.0);
for (size_t k = 0; k < input_width; ++k) { sum = blas.DOT(input_width, weight_row, input_row);
sum += weight_value[weight_width * index + k] *
input_value[input_width * i + k];
}
tmat_value[i * tmat_width + j] += sum; tmat_value[i * tmat_width + j] += sum;
} }
} }
...@@ -115,6 +116,8 @@ template <typename T> ...@@ -115,6 +116,8 @@ template <typename T>
void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat, void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
framework::Tensor* weight, framework::Tensor* weight,
const framework::Tensor& input) { const framework::Tensor& input) {
auto blas =
GetBlas<platform::CPUDeviceContext, T>(platform::CPUDeviceContext());
size_t num_samples = tmat.dims()[0]; size_t num_samples = tmat.dims()[0];
size_t input_width = input.dims()[1]; size_t input_width = input.dims()[1];
size_t tmat_width = tmat.dims()[1]; size_t tmat_width = tmat.dims()[1];
...@@ -122,17 +125,26 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat, ...@@ -122,17 +125,26 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
auto tmat_value = tmat.data<T>(); auto tmat_value = tmat.data<T>();
auto weight_value = weight->data<T>(); auto weight_value = weight->data<T>();
auto input_value = input.data<T>(); auto input_value = input.data<T>();
std::unordered_map<int, std::vector<std::pair<T, const T*>>> ops;
for (size_t i = 0; i < num_samples; ++i) { for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table_->get_code(i); auto code = code_table_->get_code(i);
int code_length = code->get_length(); int code_length = code->get_length();
const T* input_value_row = input_value + input_width * i;
const T* tmat_row = tmat_value + i * tmat_width;
for (int j = 0; j < code_length; ++j) { for (int j = 0; j < code_length; ++j) {
size_t index = code->calc_index(j); ops[code->calc_index(j)].emplace_back(tmat_row[j], input_value_row);
for (size_t k = 0; k < input_width; ++k) {
weight_value[weight_width * index + k] +=
tmat_value[i * tmat_width + j] * input_value[input_width * i + k];
} }
} }
for (auto& op : ops) {
auto& op_in_row = op.second;
for (auto& pair : op_in_row) {
auto& scale = pair.first;
auto* input_row = pair.second;
T* weight_row = weight_value + op.first * weight_width;
blas.AXPY(input_width, scale, input_row, weight_row);
}
} }
} }
...@@ -140,6 +152,8 @@ template <typename T> ...@@ -140,6 +152,8 @@ template <typename T>
void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat, void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
framework::SelectedRows* weight, framework::SelectedRows* weight,
const framework::Tensor& input) { const framework::Tensor& input) {
auto blas =
GetBlas<platform::CPUDeviceContext, T>(platform::CPUDeviceContext());
size_t num_samples = tmat.dims()[0]; size_t num_samples = tmat.dims()[0];
size_t input_width = input.dims()[1]; size_t input_width = input.dims()[1];
size_t tmat_width = tmat.dims()[1]; size_t tmat_width = tmat.dims()[1];
...@@ -147,17 +161,28 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat, ...@@ -147,17 +161,28 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
auto tmat_value = tmat.data<T>(); auto tmat_value = tmat.data<T>();
auto weight_value = weight->mutable_value()->data<T>(); auto weight_value = weight->mutable_value()->data<T>();
auto input_value = input.data<T>(); auto input_value = input.data<T>();
std::unordered_map<int, std::vector<std::pair<T, const T*>>> ops;
ops.reserve(weight->rows().size());
for (size_t i = 0; i < num_samples; ++i) { for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table_->get_code(i); auto code = code_table_->get_code(i);
int code_length = code->get_length(); int code_length = code->get_length();
const T* input_value_row = input_value + input_width * i;
const T* tmat_row = tmat_value + i * tmat_width;
for (int j = 0; j < code_length; ++j) { for (int j = 0; j < code_length; ++j) {
size_t index = code->calc_index(j); ops[code->calc_index(j)].emplace_back(tmat_row[j], input_value_row);
for (size_t k = 0; k < input_width; ++k) { }
int64_t row_index = weight->GetIndexFromId(static_cast<int64_t>(index));
weight_value[row_index * weight_width + k] +=
tmat_value[i * tmat_width + j] * input_value[input_width * i + k];
} }
for (auto& row : weight->rows()) {
auto& op_in_row = ops[row];
for (auto& pair : op_in_row) {
auto& scale = pair.first;
auto* input_row = pair.second;
blas.AXPY(input_width, scale, input_row, weight_value);
} }
weight_value += weight_width;
} }
} }
......
...@@ -13,10 +13,14 @@ See the License for the specific language governing permissions and ...@@ -13,10 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#if defined(_WIN32) #if defined(_WIN32)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册