未验证 提交 2803cf57 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #14868 from reyoung/feature/refine_w2v

Feature/refine w2v
...@@ -18,8 +18,8 @@ ENDIF() ...@@ -18,8 +18,8 @@ ENDIF()
INCLUDE(python_module) INCLUDE(python_module)
FIND_PACKAGE(PythonInterp ${PY_VERSION}) FIND_PACKAGE(PythonInterp ${PY_VERSION} REQUIRED)
FIND_PACKAGE(PythonLibs ${PY_VERSION}) FIND_PACKAGE(PythonLibs ${PY_VERSION} REQUIRED)
if(WIN32) if(WIN32)
execute_process(COMMAND "${PYTHON_EXECUTABLE}" "-c" execute_process(COMMAND "${PYTHON_EXECUTABLE}" "-c"
...@@ -79,6 +79,5 @@ IF(PYTHONINTERP_FOUND) ...@@ -79,6 +79,5 @@ IF(PYTHONINTERP_FOUND)
"please use pip to upgrade protobuf. pip install -U protobuf") "please use pip to upgrade protobuf. pip install -U protobuf")
ENDIF() ENDIF()
ENDIF(PYTHONINTERP_FOUND) ENDIF(PYTHONINTERP_FOUND)
INCLUDE_DIRECTORIES(${PYTHON_INCLUDE_DIR}) INCLUDE_DIRECTORIES(${PYTHON_INCLUDE_DIR})
INCLUDE_DIRECTORIES(${PYTHON_NUMPY_INCLUDE_DIR}) INCLUDE_DIRECTORIES(${PYTHON_NUMPY_INCLUDE_DIR})
...@@ -150,19 +150,27 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> { ...@@ -150,19 +150,27 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
label.data<int64_t>())); label.data<int64_t>()));
} }
auto& place = *ctx.template device_context<DeviceContext>().eigen_device(); // softrelu derivative
auto pre_out_mat = EigenMatrix<T>::From(pre_out);
auto pre_out_grad_mat = EigenMatrix<T>::From(pre_out_grad);
auto out_grad_mat = EigenMatrix<T>::From(out_grad);
Eigen::array<int, 2> bcast{1, static_cast<int>(pre_out_grad.dims()[1])}; auto blas = math::GetBlas<DeviceContext, T>(ctx);
// softrelu derivative auto* pre_out_grad_data = pre_out_grad.data<T>();
pre_out_grad_mat.device(place) = auto* pre_out_data = pre_out.data<T>();
static_cast<T>(1.0) - static_cast<T>(1.0) / pre_out_mat.exp(); auto n = pre_out.numel();
blas.VEXP(n, pre_out_data, pre_out_grad_data);
blas.VINV(n, pre_out_grad_data, pre_out_grad_data);
for (int64_t i = 0; i < n; ++i) {
pre_out_grad_data[i] = 1.0 - pre_out_grad_data[i];
}
bit_code->Sub(&pre_out_grad); // the gradient of clip(w * x + b) bit_code->Sub(&pre_out_grad); // the gradient of clip(w * x + b)
pre_out_grad_mat.device(place) = auto* out_grad_data = out_grad.data<T>();
pre_out_grad_mat * out_grad_mat.broadcast(bcast);
int64_t dim0 = pre_out_grad.dims()[0];
int64_t dim1 = pre_out_grad.dims()[1];
for (int64_t i = 0; i < dim0; ++i) {
T tmp = out_grad_data[i];
blas.SCAL(dim1, tmp, pre_out_grad_data + i * dim1);
}
// TODO(guosheng): multiply pre_out_grad with subgradient of clipping to // TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
// be consistent with the clipping in forward. // be consistent with the clipping in forward.
......
...@@ -181,6 +181,9 @@ class Blas { ...@@ -181,6 +181,9 @@ class Blas {
const framework::Tensor& mat_b, const MatDescriptor& dim_b, const framework::Tensor& mat_b, const MatDescriptor& dim_b,
T alpha, framework::Tensor* mat_out, T beta) const; T alpha, framework::Tensor* mat_out, T beta) const;
template <typename T>
void VINV(int n, const T* a, T* y) const;
private: private:
const DeviceContext& context_; const DeviceContext& context_;
}; };
...@@ -282,6 +285,11 @@ class BlasT : private Blas<DeviceContext> { ...@@ -282,6 +285,11 @@ class BlasT : private Blas<DeviceContext> {
Base()->template BatchedGEMM<T>(args...); Base()->template BatchedGEMM<T>(args...);
} }
template <typename... ARGS>
void VINV(ARGS... args) const {
Base()->template VINV<T>(args...);
}
private: private:
const Blas<DeviceContext>* Base() const { const Blas<DeviceContext>* Base() const {
return static_cast<const Blas<DeviceContext>*>(this); return static_cast<const Blas<DeviceContext>*>(this);
......
...@@ -118,6 +118,11 @@ struct CBlas<float> { ...@@ -118,6 +118,11 @@ struct CBlas<float> {
static void VPOW(ARGS... args) { static void VPOW(ARGS... args) {
platform::dynload::vsPowx(args...); platform::dynload::vsPowx(args...);
} }
template <typename... ARGS>
static void VINV(ARGS... args) {
platform::dynload::vsInv(args...);
}
}; };
template <> template <>
...@@ -213,6 +218,11 @@ struct CBlas<double> { ...@@ -213,6 +218,11 @@ struct CBlas<double> {
static void VPOW(ARGS... args) { static void VPOW(ARGS... args) {
platform::dynload::vdPowx(args...); platform::dynload::vdPowx(args...);
} }
template <typename... ARGS>
static void VINV(ARGS... args) {
platform::dynload::vdInv(args...);
}
}; };
#else #else
...@@ -603,6 +613,17 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a, ...@@ -603,6 +613,17 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
dim_a.stride_, dim_b.stride_); dim_a.stride_, dim_b.stride_);
} }
} }
template <typename DeviceContext>
template <typename T>
void Blas<DeviceContext>::VINV(int n, const T *a, T *y) const {
#ifdef PADDLE_WITH_MKLML
CBlas<T>::VINV(n, a, y);
#else
for (int i = 0; i < n; ++i) {
y[i] = 1.0 / a[i];
}
#endif
}
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -14,195 +14,334 @@ limitations under the License. */ ...@@ -14,195 +14,334 @@ limitations under the License. */
#include "paddle/fluid/operators/math/matrix_bit_code.h" #include "paddle/fluid/operators/math/matrix_bit_code.h"
#include <iostream> #include <iostream>
#include <map>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
template <typename T> template <typename T>
void MatrixBitCodeFunctor<T>::Add(const framework::Tensor& vec, struct MatrixBitCodeFunctorAdd : public boost::static_visitor<void> {
framework::Tensor* tmat) { const framework::Tensor &vec_;
size_t batch_size = tmat->dims()[0]; framework::Tensor *tmat_;
size_t width = tmat->dims()[1];
MatrixBitCodeFunctorAdd(const framework::Tensor &vec, framework::Tensor *tmat)
: vec_(vec), tmat_(tmat) {}
template <typename CodeTable>
void operator()(const CodeTable &code_table) {
size_t batch_size = tmat_->dims()[0];
size_t width = tmat_->dims()[1];
auto *tmat_data = tmat_->data<T>();
auto *vec_data = vec_.data<T>();
for (size_t i = 0; i < batch_size; ++i) { for (size_t i = 0; i < batch_size; ++i) {
auto code = code_table_->get_code(i); auto code = code_table.get_code(i);
int code_length = code->get_length(); int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) { for (int j = 0; j < code_length; ++j) {
size_t index = code->calc_index(j); size_t index = code.calc_index(j);
tmat->data<T>()[i * width + j] += vec.data<T>()[index]; tmat_data[i * width + j] += vec_data[index];
} }
} }
}
};
template <typename T>
void MatrixBitCodeFunctor<T>::Add(const framework::Tensor &vec,
framework::Tensor *tmat) {
MatrixBitCodeFunctorAdd<T> func(vec, tmat);
code_table_.apply_visitor(func);
} }
template <typename T> template <typename T>
void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat, struct MatrixBitCodeFunctorAddGrad : public boost::static_visitor<void> {
framework::Tensor* vec) { const framework::Tensor &tmat_;
size_t batch_size = tmat.dims()[0]; framework::Tensor *vec_;
size_t width = tmat.dims()[1]; MatrixBitCodeFunctorAddGrad(const framework::Tensor &tmat,
framework::Tensor *vec)
: tmat_(tmat), vec_(vec) {}
template <typename CodeTable>
void operator()(const CodeTable &table) {
size_t batch_size = tmat_.dims()[0];
size_t width = tmat_.dims()[1];
auto *vec_data = vec_->data<T>();
auto *tmat_data = tmat_.data<T>();
for (size_t i = 0; i < batch_size; ++i) { for (size_t i = 0; i < batch_size; ++i) {
auto code = code_table_->get_code(i); auto code = table.get_code(i);
int code_length = code->get_length(); int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) { for (int j = 0; j < code_length; ++j) {
size_t index = code->calc_index(j); size_t index = code.calc_index(j);
vec->data<T>()[index] += tmat.data<T>()[i * width + j]; vec_data[index] += tmat_data[i * width + j];
}
} }
} }
};
template <typename T>
void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor &tmat,
framework::Tensor *vec) {
MatrixBitCodeFunctorAddGrad<T> func(tmat, vec);
code_table_.apply_visitor(func);
} }
template <typename T> template <typename T>
void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat, struct MatrixBitCodeFunctorSelectedRowsAddGrad
framework::SelectedRows* vec) { : public boost::static_visitor<void> {
size_t batch_size = tmat.dims()[0]; const framework::Tensor &tmat_;
size_t width = tmat.dims()[1]; framework::SelectedRows *vec_;
MatrixBitCodeFunctorSelectedRowsAddGrad(const framework::Tensor &tmat,
framework::SelectedRows *vec)
: tmat_(tmat), vec_(vec) {}
template <typename CodeTable>
void operator()(const CodeTable &code_table) {
size_t batch_size = tmat_.dims()[0];
size_t width = tmat_.dims()[1];
auto *vec_data = vec_->mutable_value()->template data<T>();
auto *tmat_data = tmat_.data<T>();
for (size_t i = 0; i < batch_size; ++i) { for (size_t i = 0; i < batch_size; ++i) {
auto code = code_table_->get_code(i); auto code = code_table.get_code(i);
int code_length = code->get_length(); int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) { for (int j = 0; j < code_length; ++j) {
size_t index = code->calc_index(j); size_t index = code.calc_index(j);
int64_t row_index = vec->GetIndexFromId(static_cast<int64_t>(index)); int64_t row_index = vec_->GetIndexFromId(static_cast<int64_t>(index));
vec->mutable_value()->data<T>()[row_index] += vec_data[row_index] += tmat_data[i * width + j];
tmat.data<T>()[i * width + j]; }
} }
} }
};
template <typename T>
void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor &tmat,
framework::SelectedRows *vec) {
MatrixBitCodeFunctorSelectedRowsAddGrad<T> func(tmat, vec);
code_table_.apply_visitor(func);
} }
template <typename T> template <typename T>
void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor& tmat, struct MatrixBitCodeFunctorSum : public boost::static_visitor<void> {
framework::Tensor* sum, T scale_sum) { const framework::Tensor &tmat_;
size_t num_samples = tmat.dims()[0]; framework::Tensor *sum_;
size_t o_width = tmat.dims()[1]; T scale_sum_;
MatrixBitCodeFunctorSum(const framework::Tensor &tmat, framework::Tensor *sum,
T scale_sum)
: tmat_(tmat), sum_(sum), scale_sum_(scale_sum) {}
template <typename CodeTable>
void operator()(const CodeTable &code_table) {
size_t num_samples = tmat_.dims()[0];
size_t o_width = tmat_.dims()[1];
auto *tmat_data = tmat_.data<T>();
auto *sum_data = sum_->data<T>();
for (size_t i = 0; i < num_samples; ++i) { for (size_t i = 0; i < num_samples; ++i) {
T sm = static_cast<T>(0.0); T sm = static_cast<T>(0.0);
auto code = code_table_->get_code(i); auto code = code_table.get_code(i);
int code_length = code->get_length(); int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) { for (int j = 0; j < code_length; ++j) {
if (code->calc_bit(j)) { if (code.calc_bit(j)) {
// calc_bit starts from right most bit, while data in tmat[i] is in the // calc_bit starts from right most bit, while data in tmat[i] is in
// the
// reverse order. // reverse order.
sm += tmat.data<T>()[i * o_width + j]; sm += tmat_data[i * o_width + j];
} }
} }
sum->data<T>()[i] = scale_sum * sm; sum_data[i] = scale_sum_ * sm;
} }
}
};
template <typename T>
void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor &tmat,
framework::Tensor *sum, T scale_sum) {
MatrixBitCodeFunctorSum<T> func(tmat, sum, scale_sum);
code_table_.apply_visitor(func);
} }
template <typename T> template <typename T>
void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat, struct MatrixBitCodeFunctorMul : public boost::static_visitor<void> {
const framework::Tensor& weight, framework::Tensor *tmat_;
const framework::Tensor& input) { const framework::Tensor &weight_;
const framework::Tensor &input_;
MatrixBitCodeFunctorMul(framework::Tensor *tmat,
const framework::Tensor &weight,
const framework::Tensor &input)
: tmat_(tmat), weight_(weight), input_(input) {}
template <typename CodeTable>
void operator()(const CodeTable &code_table) {
auto blas = auto blas =
GetBlas<platform::CPUDeviceContext, T>(platform::CPUDeviceContext()); GetBlas<platform::CPUDeviceContext, T>(platform::CPUDeviceContext());
size_t num_samples = tmat->dims()[0]; size_t num_samples = tmat_->dims()[0];
size_t tmat_width = tmat->dims()[1]; size_t tmat_width = tmat_->dims()[1];
size_t input_width = input.dims()[1]; size_t input_width = input_.dims()[1];
size_t weight_width = weight.dims()[1]; size_t weight_width = weight_.dims()[1];
auto tmat_value = tmat->data<T>(); auto tmat_value = tmat_->data<T>();
auto weight_value = weight.data<T>(); auto weight_value = weight_.data<T>();
auto input_value = input.data<T>(); auto input_value = input_.data<T>();
for (size_t i = 0; i < num_samples; ++i) { for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table_->get_code(i); auto code = code_table.get_code(i);
int code_length = code->get_length(); int code_length = code.get_length();
const T* input_row = input_value + input_width * i; const T *input_row = input_value + input_width * i;
for (int j = 0; j < code_length; ++j) { for (int j = 0; j < code_length; ++j) {
size_t index = code->calc_index(j); size_t index = code.calc_index(j);
const T* weight_row = weight_value + weight_width * index; const T *weight_row = weight_value + weight_width * index;
T sum = static_cast<T>(0.0); T sum = blas.DOT(input_width, weight_row, input_row);
sum = blas.DOT(input_width, weight_row, input_row);
tmat_value[i * tmat_width + j] += sum; tmat_value[i * tmat_width + j] += sum;
} }
} }
}
};
template <typename T>
void MatrixBitCodeFunctor<T>::Mul(framework::Tensor *tmat,
const framework::Tensor &weight,
const framework::Tensor &input) {
MatrixBitCodeFunctorMul<T> func(tmat, weight, input);
code_table_.apply_visitor(func);
} }
template <typename T, size_t N>
class ReservedVector : public std::vector<T> {
public:
ReservedVector() { this->reserve(N); }
};
template <typename T> template <typename T>
void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat, struct MatrixBitCodeFunctorMulGradWeight : public boost::static_visitor<void> {
framework::Tensor* weight, const framework::Tensor &tmat_;
const framework::Tensor& input) { framework::Tensor *weight_;
const framework::Tensor &input_;
MatrixBitCodeFunctorMulGradWeight(const framework::Tensor &tmat,
framework::Tensor *weight,
const framework::Tensor &input)
: tmat_(tmat), weight_(weight), input_(input) {}
template <typename CodeTable>
void operator()(const CodeTable &code_table) {
auto blas = auto blas =
GetBlas<platform::CPUDeviceContext, T>(platform::CPUDeviceContext()); GetBlas<platform::CPUDeviceContext, T>(platform::CPUDeviceContext());
size_t num_samples = tmat.dims()[0]; size_t num_samples = tmat_.dims()[0];
size_t input_width = input.dims()[1]; size_t input_width = input_.dims()[1];
size_t tmat_width = tmat.dims()[1]; size_t tmat_width = tmat_.dims()[1];
size_t weight_width = weight->dims()[1]; size_t weight_width = weight_->dims()[1];
auto tmat_value = tmat.data<T>(); auto tmat_value = tmat_.data<T>();
auto weight_value = weight->data<T>(); auto weight_value = weight_->data<T>();
auto input_value = input.data<T>(); auto input_value = input_.data<T>();
std::unordered_map<int, std::vector<std::pair<T, const T*>>> ops;
std::map<int, ReservedVector<std::pair<T, const T *>, 8u>> ops;
for (size_t i = 0; i < num_samples; ++i) { for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table_->get_code(i); auto code = code_table.get_code(i);
int code_length = code->get_length(); int code_length = code.get_length();
const T* input_value_row = input_value + input_width * i; const T *input_value_row = input_value + input_width * i;
const T* tmat_row = tmat_value + i * tmat_width; const T *tmat_row = tmat_value + i * tmat_width;
for (int j = 0; j < code_length; ++j) { for (int j = 0; j < code_length; ++j) {
ops[code->calc_index(j)].emplace_back(tmat_row[j], input_value_row); ops[code.calc_index(j)].emplace_back(tmat_row[j], input_value_row);
} }
} }
for (auto& op : ops) { for (auto &op : ops) {
auto& op_in_row = op.second; auto &op_in_row = op.second;
for (auto& pair : op_in_row) { for (auto &pair : op_in_row) {
auto& scale = pair.first; auto &scale = pair.first;
auto* input_row = pair.second; auto *input_row = pair.second;
T* weight_row = weight_value + op.first * weight_width; T *weight_row = weight_value + op.first * weight_width;
blas.AXPY(input_width, scale, input_row, weight_row); blas.AXPY(input_width, scale, input_row, weight_row);
} }
} }
}
};
template <typename T>
void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor &tmat,
framework::Tensor *weight,
const framework::Tensor &input) {
MatrixBitCodeFunctorMulGradWeight<T> func(tmat, weight, input);
code_table_.apply_visitor(func);
} }
template <typename T> template <typename T>
void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat, struct MatrixBitCodeFunctorMulGradWeightSR
framework::SelectedRows* weight, : public boost::static_visitor<void> {
const framework::Tensor& input) { const framework::Tensor &tmat_;
framework::SelectedRows *weight_;
const framework::Tensor &input_;
MatrixBitCodeFunctorMulGradWeightSR(const framework::Tensor &tmat,
framework::SelectedRows *weight,
const framework::Tensor &input)
: tmat_(tmat), weight_(weight), input_(input) {}
template <typename CodeTable>
void operator()(const CodeTable &code_table) {
auto blas = auto blas =
GetBlas<platform::CPUDeviceContext, T>(platform::CPUDeviceContext()); GetBlas<platform::CPUDeviceContext, T>(platform::CPUDeviceContext());
size_t num_samples = tmat.dims()[0]; size_t num_samples = tmat_.dims()[0];
size_t input_width = input.dims()[1]; size_t input_width = input_.dims()[1];
size_t tmat_width = tmat.dims()[1]; size_t tmat_width = tmat_.dims()[1];
size_t weight_width = weight->value().dims()[1]; size_t weight_width = weight_->value().dims()[1];
auto tmat_value = tmat.data<T>(); auto tmat_value = tmat_.data<T>();
auto weight_value = weight->mutable_value()->data<T>(); auto weight_value = weight_->mutable_value()->data<T>();
auto input_value = input.data<T>(); auto input_value = input_.data<T>();
std::unordered_map<int, std::vector<std::pair<T, const T*>>> ops; std::unordered_map<int, std::vector<std::pair<T, const T *>>> ops;
ops.reserve(weight->rows().size()); ops.reserve(weight_->rows().size());
for (size_t i = 0; i < num_samples; ++i) { for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table_->get_code(i); auto code = code_table.get_code(i);
int code_length = code->get_length(); int code_length = code.get_length();
const T* input_value_row = input_value + input_width * i; const T *input_value_row = input_value + input_width * i;
const T* tmat_row = tmat_value + i * tmat_width; const T *tmat_row = tmat_value + i * tmat_width;
for (int j = 0; j < code_length; ++j) { for (int j = 0; j < code_length; ++j) {
ops[code->calc_index(j)].emplace_back(tmat_row[j], input_value_row); ops[code.calc_index(j)].emplace_back(tmat_row[j], input_value_row);
} }
} }
for (auto& row : weight->rows()) { for (auto &row : weight_->rows()) {
auto& op_in_row = ops[row]; auto &op_in_row = ops[row];
for (auto& pair : op_in_row) { for (auto &pair : op_in_row) {
auto& scale = pair.first; auto &scale = pair.first;
auto* input_row = pair.second; auto *input_row = pair.second;
blas.AXPY(input_width, scale, input_row, weight_value); blas.AXPY(input_width, scale, input_row, weight_value);
} }
weight_value += weight_width; weight_value += weight_width;
} }
}
};
template <typename T>
void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor &tmat,
framework::SelectedRows *weight,
const framework::Tensor &input) {
MatrixBitCodeFunctorMulGradWeightSR<T> func(tmat, weight, input);
code_table_.apply_visitor(func);
} }
template <typename T> template <typename T>
void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat, struct MatrixBitCodeFunctorMulGradError : public boost::static_visitor<void> {
const framework::Tensor& weight, const framework::Tensor &tmat_;
framework::Tensor* input) { const framework::Tensor &weight_;
size_t num_samples = tmat.dims()[0]; framework::Tensor *input_;
size_t tmat_width = tmat.dims()[1];
size_t input_width = input->dims()[1]; MatrixBitCodeFunctorMulGradError(const framework::Tensor &tmat,
size_t weight_width = weight.dims()[1]; const framework::Tensor &weight,
auto tmat_value = tmat.data<T>(); framework::Tensor *input)
auto weight_value = weight.data<T>(); : tmat_(tmat), weight_(weight), input_(input) {}
auto input_value = input->data<T>(); template <typename CodeTable>
void operator()(const CodeTable &code_table) {
size_t num_samples = tmat_.dims()[0];
size_t tmat_width = tmat_.dims()[1];
size_t input_width = input_->dims()[1];
size_t weight_width = weight_.dims()[1];
auto tmat_value = tmat_.data<T>();
auto weight_value = weight_.data<T>();
auto input_value = input_->data<T>();
for (size_t i = 0; i < num_samples; ++i) { for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table_->get_code(i); auto code = code_table.get_code(i);
int code_length = code->get_length(); int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) { for (int j = 0; j < code_length; ++j) {
size_t index = code->calc_index(j); size_t index = code.calc_index(j);
for (size_t k = 0; k < input_width; ++k) { for (size_t k = 0; k < input_width; ++k) {
input_value[input_width * i + k] += input_value[input_width * i + k] +=
...@@ -211,21 +350,44 @@ void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat, ...@@ -211,21 +350,44 @@ void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
} }
} }
} }
}
};
template <typename T>
void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor &tmat,
const framework::Tensor &weight,
framework::Tensor *input) {
MatrixBitCodeFunctorMulGradError<T> func(tmat, weight, input);
code_table_.apply_visitor(func);
} }
template <typename T> template <typename T>
void MatrixBitCodeFunctor<T>::Sub(framework::Tensor* tmat) { struct MatrixBitCodeFunctorSub : public boost::static_visitor<void> {
size_t num_samples = tmat->dims()[0]; framework::Tensor *tmat_;
size_t o_width = tmat->dims()[1];
explicit MatrixBitCodeFunctorSub(framework::Tensor *tmat) : tmat_(tmat) {}
template <typename CodeTable>
void operator()(const CodeTable &code_table) {
size_t num_samples = tmat_->dims()[0];
size_t o_width = tmat_->dims()[1];
auto *tmat_data = tmat_->data<T>();
for (size_t i = 0; i < num_samples; ++i) { for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table_->get_code(i); auto code = code_table.get_code(i);
int code_length = code->get_length(); int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) { for (int j = 0; j < code_length; ++j) {
if (code->calc_bit(j)) { if (code.calc_bit(j)) {
tmat->data<T>()[i * o_width + j] -= 1; tmat_data[i * o_width + j] -= 1;
}
} }
} }
} }
};
template <typename T>
void MatrixBitCodeFunctor<T>::Sub(framework::Tensor *tmat) {
MatrixBitCodeFunctorSub<T> func(tmat);
code_table_.apply_visitor(func);
} }
template class MatrixBitCodeFunctor<float>; template class MatrixBitCodeFunctor<float>;
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <map>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include <vector> #include <vector>
...@@ -22,6 +23,7 @@ limitations under the License. */ ...@@ -22,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/variant.h"
#if defined(_WIN32) #if defined(_WIN32)
#include <intrin.h> #include <intrin.h>
...@@ -98,24 +100,7 @@ inline int clz(const T& value) { ...@@ -98,24 +100,7 @@ inline int clz(const T& value) {
inline size_t FindLastSet(size_t x) { return sizeof(size_t) * 8 - clz(x); } inline size_t FindLastSet(size_t x) { return sizeof(size_t) * 8 - clz(x); }
#endif // !_WIN32 #endif // !_WIN32
// set a code interface to create multiple code class SimpleCode {
class Code {
public:
virtual ~Code() {}
virtual size_t calc_index(int bit) const = 0;
virtual bool calc_bit(int bit) const = 0;
virtual int get_length() const = 0;
};
// set a CodeTable interface to create multiple code table
class CodeTable {
public:
virtual std::unique_ptr<Code> get_code(int64_t code) const = 0;
virtual size_t size() const = 0;
virtual int get_max_code_length() const = 0;
virtual ~CodeTable() {}
};
class SimpleCode : public Code {
public: public:
SimpleCode(size_t code, size_t num_classes, const int64_t* ids) SimpleCode(size_t code, size_t num_classes, const int64_t* ids)
: c_(static_cast<size_t>(ids[code]) + num_classes) {} : c_(static_cast<size_t>(ids[code]) + num_classes) {}
...@@ -137,16 +122,16 @@ class SimpleCode : public Code { ...@@ -137,16 +122,16 @@ class SimpleCode : public Code {
}; };
template <typename T> template <typename T>
class CustomCode : public Code { class CustomCode {
public: public:
CustomCode(const framework::Tensor& ptable, const framework::Tensor& pcode, CustomCode(const framework::Tensor& ptable, const framework::Tensor& pcode,
const int64_t* ids, int index) const int64_t* ids, int index) {
: ids_(ids), index_(index) { seq_len_ = ptable.dims()[1];
ptable_ = ptable.Slice(index, index + 1); ptable_data_ = ptable.data<T>() + seq_len_ * index;
pcode_ = pcode.Slice(index, index + 1); pcode_data_ = pcode.data<T>() + seq_len_ * index;
} }
/** /**
* Here the id of root shoud be 1 rather than 0, thus the encoding of class c * Here the id of root should be 1 rather than 0, thus the encoding of class c
* is `c + num_classes` and all siblings can get the same weight indice using * is `c + num_classes` and all siblings can get the same weight indice using
* prefixes. * prefixes.
* Weight index is the prefixes of encoding, thus leave out the right most * Weight index is the prefixes of encoding, thus leave out the right most
...@@ -154,36 +139,37 @@ class CustomCode : public Code { ...@@ -154,36 +139,37 @@ class CustomCode : public Code {
* Binary classification path is the suffixes of encoding, thus leave out the * Binary classification path is the suffixes of encoding, thus leave out the
* left most bit in calc_bit. * left most bit in calc_bit.
*/ */
size_t calc_index(int bit) const { return ptable_.data<T>()[bit]; } size_t calc_index(int bit) const { return ptable_data_[bit]; }
bool calc_bit(int bit) const { return pcode_.data<T>()[bit]; } bool calc_bit(int bit) const { return pcode_data_[bit]; }
int get_length() const {
int length = 0;
for (int i = 0; i < static_cast<int>(ptable_.dims()[1]); i++) { // NOTE: this function is not thread-safe.
if (ptable_.data<T>()[i] >= 0) { int get_length() const {
length++; if (length_ < 0) {
} else { auto len = seq_len_;
return length; length_ =
} static_cast<int>(std::find_if(ptable_data_, ptable_data_ + len,
[](const T& val) { return val < 0; }) -
ptable_data_);
} }
return length; return length_;
} }
private: private:
framework::Tensor ptable_; int64_t seq_len_;
framework::Tensor pcode_; const T* ptable_data_;
const int64_t* ids_; const T* pcode_data_;
const int index_; mutable int length_{-1};
}; };
class SimpleCodeTable : public CodeTable { class SimpleCodeTable {
public: public:
SimpleCodeTable(size_t num_classes, const int64_t* ids) SimpleCodeTable(size_t num_classes, const int64_t* ids)
: num_classes_(num_classes), ids_(ids) {} : num_classes_(num_classes), ids_(ids) {}
std::unique_ptr<Code> get_code(int64_t code) const {
std::unique_ptr<Code> coder(new SimpleCode(code, num_classes_, ids_)); SimpleCode get_code(int64_t code) const {
return coder; return SimpleCode(code, num_classes_, ids_);
} }
size_t size() const { return num_classes_; } size_t size() const { return num_classes_; }
int get_max_code_length() const { return FindLastSet(num_classes_ - 1); } int get_max_code_length() const { return FindLastSet(num_classes_ - 1); }
...@@ -193,15 +179,14 @@ class SimpleCodeTable : public CodeTable { ...@@ -193,15 +179,14 @@ class SimpleCodeTable : public CodeTable {
}; };
template <typename T> template <typename T>
class CustomCodeTable : public CodeTable { class CustomCodeTable {
public: public:
CustomCodeTable(const framework::Tensor& ptable, CustomCodeTable(const framework::Tensor& ptable,
const framework::Tensor& pcode, const int64_t* ids) const framework::Tensor& pcode, const int64_t* ids)
: ptable_(ptable), pcode_(pcode), ids_(ids) {} : ptable_(ptable), pcode_(pcode), ids_(ids) {}
std::unique_ptr<Code> get_code(int64_t code) const { CustomCode<T> get_code(int64_t code) const {
std::unique_ptr<Code> coder(new CustomCode<T>(ptable_, pcode_, ids_, code)); return CustomCode<T>(ptable_, pcode_, ids_, code);
return coder;
} }
size_t size() const { return static_cast<size_t>(ptable_.dims()[1]); } size_t size() const { return static_cast<size_t>(ptable_.dims()[1]); }
...@@ -215,19 +200,21 @@ class CustomCodeTable : public CodeTable { ...@@ -215,19 +200,21 @@ class CustomCodeTable : public CodeTable {
const int64_t* ids_; const int64_t* ids_;
}; };
using CodeTable = boost::variant<SimpleCodeTable, CustomCodeTable<int64_t>>;
template <typename T> template <typename T>
class MatrixBitCodeFunctor { class MatrixBitCodeFunctor {
public: public:
MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids) MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids)
: num_classes_(num_classes), : num_classes_(num_classes),
ids_(ids), ids_(ids),
code_table_(new SimpleCodeTable(num_classes, ids)) {} code_table_(SimpleCodeTable(num_classes, ids)) {}
MatrixBitCodeFunctor(const framework::Tensor& ptable, MatrixBitCodeFunctor(const framework::Tensor& ptable,
const framework::Tensor& pcode, const int64_t* ids) const framework::Tensor& pcode, const int64_t* ids)
: num_classes_(static_cast<size_t>(ptable.dims()[1])), : num_classes_(static_cast<size_t>(ptable.dims()[1])),
ids_(ids), ids_(ids),
code_table_(new CustomCodeTable<int64_t>(ptable, pcode, ids)) {} code_table_(CustomCodeTable<int64_t>(ptable, pcode, ids)) {}
/* For j < code_length /* For j < code_length
tmat(i, j) += vec(0, index(i, j)) tmat(i, j) += vec(0, index(i, j))
*/ */
...@@ -277,7 +264,7 @@ class MatrixBitCodeFunctor { ...@@ -277,7 +264,7 @@ class MatrixBitCodeFunctor {
size_t num_classes_; size_t num_classes_;
const int64_t* ids_; const int64_t* ids_;
std::unique_ptr<CodeTable> code_table_; CodeTable code_table_;
}; };
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -82,6 +82,8 @@ extern void* mklml_dso_handle; ...@@ -82,6 +82,8 @@ extern void* mklml_dso_handle;
__macro(vdSqr); \ __macro(vdSqr); \
__macro(vsPowx); \ __macro(vsPowx); \
__macro(vdPowx); \ __macro(vdPowx); \
__macro(vsInv); \
__macro(vdInv); \
__macro(MKL_Set_Num_Threads) __macro(MKL_Set_Num_Threads)
MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP); MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册