diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc
index 972dcf5494e9acd47e7ff615db45f056a43724a6..b326b583199a9eb8588de2c51157d98972815167 100644
--- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc
+++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc
@@ -158,6 +158,7 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
       ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
     }
     ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+    ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
   }

 protected:
diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h
index 07ff8f947e59d2954783e2ba537bfce3cb320f22..b73a32af89e882ac02623dd1d312f400a78fc47a 100644
--- a/paddle/fluid/operators/hierarchical_sigmoid_op.h
+++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h
@@ -185,7 +185,6 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
         ctx.Output<framework::SelectedRows>(framework::GradVarName("W"));
     w_grad->set_rows(real_rows);
     // Build a map of id -> row_index to speed up finding the index of one id
-    w_grad->SyncIndex();
     w_grad->set_height(w.dims()[0]);
     auto* w_grad_value = w_grad->mutable_value();
     framework::DDim temp_dim(w.dims());
diff --git a/paddle/fluid/operators/math/matrix_bit_code.cc b/paddle/fluid/operators/math/matrix_bit_code.cc
index 71b9293eeded77553ca06a8574cca3941fa36b6a..5a6e64b6f87d33249f0153e5f391deaf78e53de5 100644
--- a/paddle/fluid/operators/math/matrix_bit_code.cc
+++ b/paddle/fluid/operators/math/matrix_bit_code.cc
@@ -89,6 +89,8 @@ template <typename T>
 void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
                                   const framework::Tensor& weight,
                                   const framework::Tensor& input) {
+  auto blas =
+      GetBlas<platform::CPUDeviceContext, T>(platform::CPUDeviceContext());
   size_t num_samples = tmat->dims()[0];
   size_t tmat_width = tmat->dims()[1];
   size_t input_width = input.dims()[1];
@@ -99,13 +101,12 @@ void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
   for (size_t i = 0; i < num_samples; ++i) {
     auto code = code_table_->get_code(i);
     int code_length = code->get_length();
+    const T* input_row = input_value + input_width * i;
     for (int j = 0; j < code_length; ++j) {
       size_t index = code->calc_index(j);
+      const T* weight_row = weight_value + weight_width * index;
       T sum = static_cast<T>(0.0);
-      for (size_t k = 0; k < input_width; ++k) {
-        sum += weight_value[weight_width * index + k] *
-               input_value[input_width * i + k];
-      }
+      sum = blas.DOT(input_width, weight_row, input_row);
       tmat_value[i * tmat_width + j] += sum;
     }
   }
@@ -115,6 +116,8 @@ template <typename T>
 void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
                                             framework::Tensor* weight,
                                             const framework::Tensor& input) {
+  auto blas =
+      GetBlas<platform::CPUDeviceContext, T>(platform::CPUDeviceContext());
   size_t num_samples = tmat.dims()[0];
   size_t input_width = input.dims()[1];
   size_t tmat_width = tmat.dims()[1];
@@ -122,16 +125,25 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
   auto tmat_value = tmat.data<T>();
   auto weight_value = weight->data<T>();
   auto input_value = input.data<T>();
+
+  std::unordered_map<int, std::vector<std::pair<T, const T*>>> ops;
+
   for (size_t i = 0; i < num_samples; ++i) {
     auto code = code_table_->get_code(i);
     int code_length = code->get_length();
+    const T* input_value_row = input_value + input_width * i;
+    const T* tmat_row = tmat_value + i * tmat_width;
     for (int j = 0; j < code_length; ++j) {
-      size_t index = code->calc_index(j);
-
-      for (size_t k = 0; k < input_width; ++k) {
-        weight_value[weight_width * index + k] +=
-            tmat_value[i * tmat_width + j] * input_value[input_width * i + k];
-      }
+      ops[code->calc_index(j)].emplace_back(tmat_row[j], input_value_row);
+    }
+  }
+  for (auto& op : ops) {
+    auto& op_in_row = op.second;
+    for (auto& pair : op_in_row) {
+      auto& scale = pair.first;
+      auto* input_row = pair.second;
+      T* weight_row = weight_value + op.first * weight_width;
+      blas.AXPY(input_width, scale, input_row, weight_row);
     }
   }
 }
@@ -140,6 +152,8 @@ template <typename T>
 void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
                                             framework::SelectedRows* weight,
                                             const framework::Tensor& input) {
+  auto blas =
+      GetBlas<platform::CPUDeviceContext, T>(platform::CPUDeviceContext());
   size_t num_samples = tmat.dims()[0];
   size_t input_width = input.dims()[1];
   size_t tmat_width = tmat.dims()[1];
@@ -147,17 +161,28 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
   auto tmat_value = tmat.data<T>();
   auto weight_value = weight->mutable_value()->data<T>();
   auto input_value = input.data<T>();
+
+  std::unordered_map<int, std::vector<std::pair<T, const T*>>> ops;
+  ops.reserve(weight->rows().size());
+
   for (size_t i = 0; i < num_samples; ++i) {
     auto code = code_table_->get_code(i);
     int code_length = code->get_length();
+    const T* input_value_row = input_value + input_width * i;
+    const T* tmat_row = tmat_value + i * tmat_width;
     for (int j = 0; j < code_length; ++j) {
-      size_t index = code->calc_index(j);
-      for (size_t k = 0; k < input_width; ++k) {
-        int64_t row_index = weight->GetIndexFromId(static_cast<int64_t>(index));
-        weight_value[row_index * weight_width + k] +=
-            tmat_value[i * tmat_width + j] * input_value[input_width * i + k];
-      }
+      ops[code->calc_index(j)].emplace_back(tmat_row[j], input_value_row);
+    }
+  }
+
+  for (auto& row : weight->rows()) {
+    auto& op_in_row = ops[row];
+    for (auto& pair : op_in_row) {
+      auto& scale = pair.first;
+      auto* input_row = pair.second;
+      blas.AXPY(input_width, scale, input_row, weight_value);
     }
+    weight_value += weight_width;
   }
 }
diff --git a/paddle/fluid/operators/math/matrix_bit_code.h b/paddle/fluid/operators/math/matrix_bit_code.h
index c30bb52641e865efe57659a551bc4b493634c6b9..35ca73802b48982ddf3ed7485b56f50221c9f28c 100644
--- a/paddle/fluid/operators/math/matrix_bit_code.h
+++ b/paddle/fluid/operators/math/matrix_bit_code.h
@@ -13,10 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #pragma once
+#include <unordered_map>
+#include <utility>
+#include <vector>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/platform/device_context.h"

 #if defined(_WIN32)
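
---

Note on the MulGradWeight hunks above: the old scalar triple loop is replaced by a two-pass scheme. Pass one buckets every (scale, input-row) contribution by destination weight row; pass two flushes each bucket with one blas.AXPY per contribution, so each destination row is resolved once and updated in a contiguous burst. For the SelectedRows overload this also eliminates the per-element GetIndexFromId() lookup (and with it the SyncIndex() call removed in hierarchical_sigmoid_op.h), because iterating weight->rows() in order visits every destination row exactly once. Below is a minimal standalone sketch of the same two-pass idea, not Paddle code: the hand-rolled axpy() stands in for blas.AXPY, and the toy inputs/codes are made up for illustration.

#include <cstddef>
#include <iostream>
#include <unordered_map>
#include <utility>
#include <vector>

// Stand-in for blas.AXPY(n, alpha, x, y): y[0..n) += alpha * x[0..n).
void axpy(std::size_t n, float alpha, const float* x, float* y) {
  for (std::size_t i = 0; i < n; ++i) y[i] += alpha * x[i];
}

int main() {
  const std::size_t input_width = 4, num_rows = 3;
  // Toy data: two input samples, three weight rows, weights start at zero.
  std::vector<float> input = {1, 2, 3, 4,    // sample 0
                              5, 6, 7, 8};   // sample 1
  std::vector<float> weight(num_rows * input_width, 0.f);

  // (destination row, scale) pairs per sample; in the patch these come
  // from code->calc_index(j) and tmat_row[j].
  std::vector<std::vector<std::pair<int, float>>> codes = {
      {{0, 0.5f}, {2, 1.0f}},    // sample 0
      {{0, 2.0f}, {1, -1.0f}}};  // sample 1

  // Pass 1: bucket all (scale, input_row) contributions by weight row.
  std::unordered_map<int, std::vector<std::pair<float, const float*>>> ops;
  for (std::size_t i = 0; i < codes.size(); ++i) {
    const float* input_row = input.data() + input_width * i;
    for (auto& c : codes[i]) ops[c.first].emplace_back(c.second, input_row);
  }

  // Pass 2: flush each bucket; every weight row is touched in one pass.
  for (auto& op : ops) {
    float* weight_row = weight.data() + op.first * input_width;
    for (auto& pair : op.second)
      axpy(input_width, pair.first, pair.second, weight_row);
  }

  // Expected: row 0 = 0.5*s0 + 2*s1, row 1 = -s1, row 2 = s0.
  for (std::size_t r = 0; r < num_rows; ++r) {
    for (std::size_t k = 0; k < input_width; ++k)
      std::cout << weight[r * input_width + k] << ' ';
    std::cout << '\n';
  }
}

Any C++11 compiler builds this. The payoff of the grouping is locality and vectorization: each weight row's memory is written in one contiguous AXPY-sized burst per contribution instead of element-by-element across interleaved (sample, code) iterations.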