diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index a719da2560291dbc7e98aadfae41d4692d8afcad..93ec763424d73b79321fa81dce3ea3a36e5a59a8 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -185,7 +185,8 @@ set(DEPS_OPS
     tensor_array_read_write_op
     gru_op
     adagrad_op
-    sgd_op)
+    sgd_op
+    hierarchical_sigmoid_op)
 
 op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
@@ -203,6 +204,7 @@ op_library(lod_rank_table_op SRCS lod_rank_table_op.cc DEPS lod_rank_table)
 op_library(lod_tensor_to_array_op SRCS lod_tensor_to_array_op.cc DEPS lod_rank_table_op)
 op_library(array_to_lod_tensor_op SRCS array_to_lod_tensor_op.cc DEPS lod_rank_table_op)
 op_library(tensor_array_read_write_op SRCS tensor_array_read_write_op.cc)
+op_library(hierarchical_sigmoid_op DEPS matrix_bit_code)
 if(WITH_GPU)
   op_library(nccl_op DEPS nccl_common)
 endif()
diff --git a/paddle/operators/hierarchical_sigmoid_op.cc b/paddle/operators/hierarchical_sigmoid_op.cc
index 9b7af92662e403c94964b49be206a73e2b26faad..f81f3d34d1931ce2e3231a3fd60b6dda434e86dd 100644
--- a/paddle/operators/hierarchical_sigmoid_op.cc
+++ b/paddle/operators/hierarchical_sigmoid_op.cc
@@ -85,12 +85,16 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
              "(TensorArray, required) The input array. Each Tensor has the "
              "same shape with [N * D].")
         .AsDuplicable();
+    AddInput("Parameters",
+             "(Tensor, required), The parameters of hierarchical "
+             "sigmoid operator, each of them is a 2-D tensor.")
+        .AsDuplicable();
     AddInput("Label",
              "(Tensor, required), The labels of training data. It's a"
              "1-D tensor.");
     AddInput("Bias",
              "(Tensor, optional), The bias is a 1-D tensor, "
-             "which is applied to the output");
+             "which is applied to the output.");
     AddOutput(
         "Out",
         "(Tensor, required) The output of hierarchical sigmoid operator.");
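Before the kernel in hierarchical_sigmoid_op.h below: chained together, AddByBitCode, MulByBitCode, the clip, SumByBitCode, the softrelu, and the row-wise sum compute, for each sample i, the negative log-likelihood of that sample's bit code,

    out(i) = \sum_j [log(1 + exp(z(i, j))) - bit(i, j) * z(i, j)],

where z(i, j) = bias(index(i, j)) + <W.row(index(i, j)), x_i> is pre_out(i, j). The standalone sketch below spells out that per-sample formula in plain C++; the values are made up for illustration and none of it is part of the patch.

// Hedged sketch: the per-sample loss that the kernel assembles from its
// bit-code helpers. z[j] plays the role of pre_out(i, j); bit[j] is the
// branch taken at step j of the sample's code. Values are invented.
#include <cmath>
#include <cstdio>

int main() {
  const int code_length = 3;
  double z[code_length] = {0.7, -1.2, 2.0};  // bias + <W.row(index), x>
  int bit[code_length] = {1, 0, 1};          // from calc_bit(j)

  double loss = 0.0;
  for (int j = 0; j < code_length; ++j) {
    // SumByBitCode contributes -bit[j] * z[j]; softrelu plus the row-wise
    // sum contribute log(1 + exp(z[j])). Together: -log P(bit[j] | x).
    loss += std::log(1.0 + std::exp(z[j])) - bit[j] * z[j];
  }
  std::printf("per-sample loss = %f\n", loss);  // ~0.403 + 0.263 + 0.127
  return 0;
}

Each term equals -log sigmoid(z) when the bit is 1 and -log sigmoid(-z) when it is 0, so minimizing out(i) pushes every branch decision along the sample's path toward its observed bit.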
diff --git a/paddle/operators/hierarchical_sigmoid_op.h b/paddle/operators/hierarchical_sigmoid_op.h
index 11a553a40392c783c3b74efe72c820ca5314852c..baf655f2141804616acf0fd9373f095d35a19246 100644
--- a/paddle/operators/hierarchical_sigmoid_op.h
+++ b/paddle/operators/hierarchical_sigmoid_op.h
@@ -14,28 +14,61 @@ limitations under the License.
 */
 #pragma once
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/matrix_bit_code.h"
 
 namespace paddle {
 namespace operators {
 
-template <typename Place, typename T>
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
+template <typename Place, typename T>
 class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto ins = ctx.MultiInput<framework::Tensor>("X");
+    auto params = ctx.MultiInput<framework::Tensor>("Parameters");
     auto* label = ctx.Input<framework::Tensor>("Label");
     auto* bias = ctx.Input<framework::Tensor>("Bias");
+    auto* out = ctx.Output<framework::Tensor>("Out");
     size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
+
+    framework::Tensor sum;
+    framework::Tensor pre_out;
+    auto place = ctx.GetEigenDevice<Place>();
+    auto& device_ctx = ctx.device_context();
+    math::ColwiseSum<Place, T> col_sum;
+    math::RowwiseSum<Place, T> row_sum;
+
     int64_t batch_size = ins[0]->dims()[0];
     int64_t size = ins.size();
-    framework::Tensor pre_out;
+
     std::vector<int64_t> pre_out_dims({batch_size, size});
     pre_out.mutable_data<T>(framework::make_ddim(pre_out_dims), ctx.GetPlace());
+    // From() reads the dims, so build the Eigen view only after mutable_data.
+    auto pre_out_mat = EigenMatrix<T>::From(pre_out);
+    std::vector<int64_t> sum_dims({batch_size, 1UL});
+    sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
+    out->mutable_data<T>(ctx.GetPlace());
 
-    if (bias != NULL) {
+    if (bias) {
       math::AddByBitCode<T>(num_classes, *label, pre_out, *bias);
     }
+
+    for (size_t i = 0; i < ins.size(); ++i) {
+      math::MulByBitCode<T>(num_classes, *label, pre_out, *params[i], *ins[i]);
+    }
+    // clip pre_out into (-40, 40) so the exp() below cannot overflow
+    pre_out_mat.device(place) = pre_out_mat.cwiseMax(static_cast<T>(-40.0))
+                                    .cwiseMin(static_cast<T>(40.0));
+    math::SumByBitCode<T>(num_classes, *label, pre_out, *out,
+                          static_cast<T>(-1));
+    // softrelu: log(1 + exp(x))
+    pre_out_mat.device(place) = (static_cast<T>(1) + pre_out_mat.exp()).log();
+
+    row_sum(device_ctx, pre_out, &sum);
+    col_sum(device_ctx, *out, &sum);
   }
 };
diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc
index 2e333a8cde721f8e65dbf2cf5e3aac6272172cc0..3bc0945fe3a7aee50f0d4501926af11fa469a55e 100644
--- a/paddle/operators/math/math_function.cc
+++ b/paddle/operators/math/math_function.cc
@@ -314,6 +314,8 @@ template struct RowwiseAdd<platform::CPUPlace, float>;
 template struct RowwiseAdd<platform::CPUPlace, double>;
 template struct ColwiseSum<platform::CPUPlace, float>;
 template struct ColwiseSum<platform::CPUPlace, double>;
+template struct RowwiseSum<platform::CPUPlace, float>;
+template struct RowwiseSum<platform::CPUPlace, double>;
 
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu
index 58356a4b7783241ca0292829bf05dc1a8ed80c6c..1a226821f717affeccd691c8248bb00ca78d72ac 100644
--- a/paddle/operators/math/math_function.cu
+++ b/paddle/operators/math/math_function.cu
@@ -298,6 +298,8 @@ template struct RowwiseAdd<platform::GPUPlace, float>;
 template struct RowwiseAdd<platform::GPUPlace, double>;
 template struct ColwiseSum<platform::GPUPlace, float>;
 template struct ColwiseSum<platform::GPUPlace, double>;
+template struct RowwiseSum<platform::GPUPlace, float>;
+template struct RowwiseSum<platform::GPUPlace, double>;
 
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h
index ffb99f53808c4316ede96b04e57aec4dae4134de..c21a20fc326024905fe2f0623c86047ea92c7b2d 100644
--- a/paddle/operators/math/math_function.h
+++ b/paddle/operators/math/math_function.h
@@ -130,6 +130,12 @@ struct ColwiseSum {
                   const framework::Tensor& input, framework::Tensor* vec);
 };
 
+template <typename Place, typename T>
+struct RowwiseSum {
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor* vec);
+};
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
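A quick sanity check on the RowwiseSum declared above, since its implementation in the next file is easy to get wrong by copy-pasting ColwiseSum: a row-wise sum reduces dimension 1 (one result per row), whereas ColwiseSum reduces dimension 0. A minimal standalone Eigen sketch with made-up values, independent of Paddle:

// Hedged sketch: RowwiseSum semantics in bare Eigen. Reducing dimension 1
// of a (rows x cols) tensor yields one sum per row.
#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  Eigen::Tensor<float, 2, Eigen::RowMajor> in(2, 3);
  in.setValues({{1, 2, 3}, {4, 5, 6}});

  Eigen::array<int, 1> rowwise_dims = {{1}};  // dimension 0 would give column sums
  Eigen::Tensor<float, 1, Eigen::RowMajor> row_sums = in.sum(rowwise_dims);

  std::cout << row_sums(0) << " " << row_sums(1) << "\n";  // prints "6 15"
  return 0;
}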
diff --git a/paddle/operators/math/math_function_impl.h b/paddle/operators/math/math_function_impl.h
index 4dc17a4e525c52b8f696277274a7ad00a6b00a08..8c1971fc611cf0678d1185d14fbebbeed7d8a594 100644
--- a/paddle/operators/math/math_function_impl.h
+++ b/paddle/operators/math/math_function_impl.h
@@ -78,6 +78,20 @@ void ColwiseSum<Place, T>::operator()(const platform::DeviceContext& context,
       in.sum(Eigen::array<int, 1>({{0}})).reshape(shape);
 }
 
+template <typename Place, typename T>
+void RowwiseSum<Place, T>::operator()(const platform::DeviceContext& context,
+                                      const framework::Tensor& input,
+                                      framework::Tensor* vector) {
+  auto in_dims = input.dims();
+  auto size = input.numel() / in_dims[1];
+  PADDLE_ENFORCE_EQ(vector->numel(), size);
+
+  auto in = framework::EigenMatrix<T>::From(input);
+  auto vec = framework::EigenMatrix<T>::From(*vector);
+  Eigen::array<int, 2> shape({{static_cast<int>(size), 1}});
+  // a row-wise sum reduces dimension 1 (ColwiseSum reduces dimension 0)
+  vec.reshape(shape).device(*context.GetEigenDevice<Place>()) =
+      in.sum(Eigen::array<int, 1>({{1}})).reshape(shape);
+}
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/math/matrix_bit_code.cc b/paddle/operators/math/matrix_bit_code.cc
index 30c2ffc2cfdacc89a01721265137eb3cfb496af8..8f68e2f79ddf0ed9f8253efa247aeadbe8b20afd 100644
--- a/paddle/operators/math/matrix_bit_code.cc
+++ b/paddle/operators/math/matrix_bit_code.cc
@@ -53,18 +53,18 @@ namespace math {
 template <typename T, class CodeTable, class Op>
 static void AddByBitCodeT(Op op, CodeTable code_table,
                           const framework::Tensor& codes, framework::Tensor& a,
-                          framework::Tensor& b) {
+                          const framework::Tensor& b) {
   size_t num_classes = code_table.size();
   size_t max_code_length = code_table.get_max_code_length();
-  size_t num_sample = a.dims()[0].size();
-  size_t width = a.dims()[1].size();
+  size_t num_sample = a.dims()[0];
+  size_t width = a.dims()[1];
 
   for (size_t i = 0; i < num_sample; ++i) {
-    auto code = code_table(codes.data<T>()[i]) int code_length =
-        code.get_length();
+    auto code = code_table(codes.data<T>()[i]);
+    int code_length = code.get_length();
     for (int j = 0; j < code_length; ++j) {
       size_t index = code.calc_index(j);
-      op(a.data()[i * width + j], b.data()[index]);
+      op(a.data<T>()[i * width + j], b.data<T>()[index]);
     }
   }
 }
@@ -79,6 +79,71 @@ void AddByBitCode(size_t num_classes, const framework::Tensor& codes,
   AddByBitCodeT<T>(op, SimpleCodeTable(num_classes), codes, a, b);
 }
 
+template <typename T, class CodeTable>
+void SumByBitCodeT(CodeTable code_table, const framework::Tensor& codes,
+                   framework::Tensor& tmat, framework::Tensor& sum,
+                   const T& scale_sum) {
+  size_t max_code_length = code_table.get_max_code_length();
+  size_t num_samples = tmat.dims()[0];
+  size_t o_width = tmat.dims()[1];
+  for (size_t i = 0; i < num_samples; ++i) {
+    T sm = 0;
+    auto code = code_table(codes.data<T>()[i]);
+    int code_length = code.get_length();
+    for (int j = 0; j < code_length; ++j) {
+      if (code.calc_bit(j)) {
+        sm += tmat.data<T>()[i * o_width + j];
+      }
+    }
+    sum.data<T>()[i] = scale_sum * sm;
+  }
+}
+/* For j < codeLength:
+   sum(i, 0) = \sum_j bit(i, j) * tmat(i, j)
+*/
+template <typename T>
+void SumByBitCode(size_t num_classes, const framework::Tensor& codes,
+                  framework::Tensor& tmat, framework::Tensor& sum,
+                  T scale_sum) {
+  SumByBitCodeT<T>(SimpleCodeTable(num_classes), codes, tmat, sum, scale_sum);
+}
+
+template <typename T, class CodeTable, class Op>
+void MulByBitCodeT(Op op, CodeTable code_table, const framework::Tensor& codes,
+                   framework::Tensor& tmat, const framework::Tensor& weight,
+                   const framework::Tensor& input) {
+  size_t num_classes = code_table.size();
+  size_t max_code_length = code_table.get_max_code_length();
+  size_t num_samples = tmat.dims()[0];
+  size_t input_dim = input.dims()[1];
+  size_t o_width = tmat.dims()[1];
+
+  for (size_t i = 0; i < num_samples; ++i) {
+    auto code = code_table(codes.data<T>()[i]);
+    int code_length = code.get_length();
+    for (int j = 0; j < code_length; ++j) {
+      size_t index = code.calc_index(j);
+      op(tmat.data<T>()[i * o_width + j],
+         weight.data<T>() + index * weight.dims()[1],
+         input.data<T>() + i * input.dims()[1], input_dim);
+    }
+  }
+}
+
+template <typename T>
+void MulByBitCode(size_t num_classes, const framework::Tensor& codes,
+                  framework::Tensor& tmat, const framework::Tensor& weight,
+                  const framework::Tensor& input) {
+  auto op = [](T& t, const T* weight_row, const T* input_row,
+               size_t input_dim) {
+    T sum = 0;
+    for (size_t k = 0; k < input_dim; ++k) {
+      sum += weight_row[k] * input_row[k];
+    }
+    t += sum;
+  };
+  MulByBitCodeT<T>(op, SimpleCodeTable(num_classes), codes, tmat, weight,
+                   input);
+}
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/math/matrix_bit_code.h b/paddle/operators/math/matrix_bit_code.h
index bb0599aa17716eb593b05ef7526c663bef6a2c98..7bef5077b9bbb160a82d968d16b484d4f46e8cae 100644
--- a/paddle/operators/math/matrix_bit_code.h
+++ b/paddle/operators/math/matrix_bit_code.h
@@ -59,10 +59,27 @@ struct SimpleCodeTable {
   int max_code_length_;
 };
 
+/* For j < codeLength
+   tmat(i, j) += vec(0, index(i, j))
+*/
 template <typename T>
 void AddByBitCode(size_t num_classes, const framework::Tensor& codes,
-                  framework::Tensor& a, const framework::Tensor& b);
+                  framework::Tensor& tmat, const framework::Tensor& vec);
+
+/* For j < codeLength
+   sum(i, 0) = \sum_j bit(i, j) * tmat(i, j)
+*/
+template <typename T>
+void SumByBitCode(size_t num_classes, const framework::Tensor& codes,
+                  framework::Tensor& tmat, framework::Tensor& sum,
+                  T scale_sum);
+
+/* For j < codeLength
+   tmat(i, j) += weight.row(index(i, j)) * input.row(i)
+*/
+template <typename T>
+void MulByBitCode(size_t num_classes, const framework::Tensor& codes,
+                  framework::Tensor& tmat, const framework::Tensor& weight,
+                  const framework::Tensor& input);
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
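All of the *ByBitCode helpers walk the same implicit complete binary tree that SimpleCodeTable encodes: class c is identified with tree node c + num_classes, and the bits of that node id spell out the path from the root; calc_index(j) names the internal node (and hence the row of Parameters) used at step j, and calc_bit(j) the branch taken there. The sketch below is a standalone reimplementation of that numbering for illustration; it assumes the conventional c + num_classes scheme and is not the code behind matrix_bit_code.h.

// Hedged sketch: the bit-code numbering assumed above. A class id becomes
// a tree node id; shifting that id right walks up toward the root.
#include <cstddef>
#include <cstdio>

struct SimpleCode {
  SimpleCode(size_t code, size_t num_classes) : c_(code + num_classes) {}
  // Internal node visited at step `bit` (a row index into Parameters).
  size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; }
  // Branch taken at step `bit`: true = one child, false = the other.
  bool calc_bit(int bit) const { return c_ & (1ULL << bit); }
  // Path length = position of the highest set bit of the node id.
  int get_length() const {
    int len = 0;
    for (size_t c = c_; c > 1; c >>= 1) ++len;
    return len;
  }
  size_t c_;
};

int main() {
  const size_t num_classes = 6;
  for (size_t c = 0; c < num_classes; ++c) {
    SimpleCode code(c, num_classes);
    std::printf("class %zu:", c);
    for (int j = 0; j < code.get_length(); ++j) {
      std::printf(" (node %zu, bit %d)", code.calc_index(j),
                  code.calc_bit(j) ? 1 : 0);
    }
    std::printf("\n");
  }
  return 0;
}

For num_classes = 6, class 0 maps to node 6 (binary 110): a path of length two through internal nodes 2 and 0, which are exactly the rows of Parameters and the entries of pre_out that AddByBitCode, MulByBitCode, and SumByBitCode touch for that sample.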