From 1f9426fd47d4cc3911e9f0a2f23274d69dd104e8 Mon Sep 17 00:00:00 2001
From: Yancey1989
Date: Wed, 29 Nov 2017 20:17:22 +0800
Subject: [PATCH] add backward

---
 paddle/operators/hierarchical_sigmoid_op.h | 52 ++++++++++++++--
 paddle/operators/math/matrix_bit_code.cc   | 71 +++++++++++++++++++---
 paddle/operators/math/matrix_bit_code.h    | 36 ++++++++++-
 3 files changed, 144 insertions(+), 15 deletions(-)

diff --git a/paddle/operators/hierarchical_sigmoid_op.h b/paddle/operators/hierarchical_sigmoid_op.h
index baf655f2141..186c7679323 100644
--- a/paddle/operators/hierarchical_sigmoid_op.h
+++ b/paddle/operators/hierarchical_sigmoid_op.h
@@ -44,9 +44,9 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
     auto pre_out_mat = EigenMatrix<T>::From(pre_out);
     int64_t batch_size = ins[0]->dims()[0];
-    int64_t size = ins.size();
+    int64_t code_length = math::FindLastSet(num_classes - 1);
-    std::vector<int64_t> pre_out_dims({batch_size, size});
+    std::vector<int64_t> pre_out_dims({batch_size, code_length});
     pre_out.mutable_data<T>(framework::make_ddim(pre_out_dims), ctx.GetPlace());
     std::vector<int64_t> sum_dims({batch_size, 1UL});
     sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
@@ -64,8 +64,11 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
         pre_out_mat.abs().cwiseMax(static_cast<T>(40.0));
     math::SumByBitCode<T>(num_classes, *label, *out, pre_out,
                           static_cast<T>(-1));
-    // softrelu
-    pre_out_mat.device(place) = (static_cast<T>(1) + pre_out_mat.exp()).log();
+
+    // softrelu with a threshold of 40.0
+    pre_out_mat.device(place) =
+        pre_out_mat.abs().cwiseMax(static_cast<T>(40.0));
+    pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
     row_sum(device_ctx, pre_out, &sum);
     col_sum(device_ctx, *out, &sum);
@@ -75,7 +78,46 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
 template <typename Place, typename T>
 class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {}
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto ins = ctx.MultiInput<framework::Tensor>("X");
+    auto ins_grad =
+        ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
+    auto params = ctx.MultiOutput<framework::Tensor>(
+        framework::GradVarName("Parameters"));
+    auto* bias = ctx.Output<framework::Tensor>(framework::GradVarName("Bias"));
+    auto* label =
+        ctx.Output<framework::Tensor>(framework::GradVarName("Label"));
+    size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
+
+    framework::Tensor pre_out;
+    auto place = ctx.GetEigenDevice<Place>();
+    auto& dev_ctx = ctx.device_context();
+    int64_t batch_size = ins_grad.size();
+    int64_t code_length = math::FindLastSet(num_classes - 1);
+    auto pre_out_mat = EigenMatrix<T>::From(pre_out);
+
+    // init pre_out matrix with {1.0}
+    std::vector<int64_t> pre_out_dims({batch_size, code_length});
+    pre_out.mutable_data<T>(framework::make_ddim(pre_out_dims), ctx.GetPlace());
+    math::SetConstant<Place, T> set;
+    set(dev_ctx, &pre_out, static_cast<T>(1.0));
+    // softrelu derivative
+    pre_out_mat.device(place) =
+        pre_out_mat * (static_cast<T>(1.0) - static_cast<T>(1.0) / pre_out_mat);
+
+    math::SubByBitCode<T>(num_classes, *label, pre_out);
+
+    if (bias) {
+      math::AddByBitCodeGrad<T>(num_classes, *label, pre_out, *bias);
+    }
+
+    for (size_t i = 0; i < ins_grad.size(); ++i) {
+      math::MulByBitCodeGradWeight<T>(num_classes, *label, pre_out, *params[i],
+                                      *ins[i]);
+      math::MulByBitCodeGradError<T>(num_classes, *label, pre_out, *params[i],
+                                     *ins_grad[i]);
+    }
+  }
 };
 
 }  // namespace operators
diff --git a/paddle/operators/math/matrix_bit_code.cc b/paddle/operators/math/matrix_bit_code.cc
index 8f68e2f79dd..996e0b819f6 100644
--- a/paddle/operators/math/matrix_bit_code.cc
+++ b/paddle/operators/math/matrix_bit_code.cc
@@ -69,19 +69,23 @@ static void AddByBitCodeT(Op op, CodeTable code_table,
   }
 }
 
-/* For j < codeLength:
-   a(i, j) += b(0, index(i, j))
-*/
 template <typename T>
 void AddByBitCode(size_t num_classes, const framework::Tensor& codes,
-                  framework::Tensor& a, const framework::Tensor& b) {
+                  framework::Tensor& tmat, const framework::Tensor& vec) {
   auto op = [](T& t, T& v) { t += v; };
-  AddByBitCodeT<T>(op, SimpleCodeTable(num_classes), codes, a, b);
+  AddByBitCodeT<T>(op, SimpleCodeTable(num_classes), codes, tmat, vec);
+}
+
+template <typename T>
+void AddByBitCodeGrad(size_t num_classes, const framework::Tensor& codes,
+                      const framework::Tensor& tmat, framework::Tensor& vec) {
+  auto op = [](T& t, T& v) { v += t; };
+  AddByBitCodeT<T>(op, SimpleCodeTable(num_classes), codes, tmat, vec);
 }
 
 template <typename T, class CodeTable>
 void SumByBitCodeT(CodeTable code_table, const framework::Tensor& codes,
-                   framework::Tensor& tmat, framework::Tensor& sum,
+                   framework::Tensor& tmat, const framework::Tensor& sum,
                    const T& scale_sum) {
   size_t max_code_length = code_table.get_max_code_length();
   size_t num_samples = tmat.dims()[0];
@@ -142,8 +146,61 @@ void MulByBitCode(size_t num_classes, const framework::Tensor& codes,
     }
     t += sum;
   };
-  MulByBitCode(op, SimpleCodeTable(num_classes), codes, tmat, weight, input);
+  MulByBitCodeT<T>(op, SimpleCodeTable(num_classes), codes, tmat, weight,
+                   input);
+}
+
+template <typename T>
+void MulByBitCodeGradWeight(size_t num_classes, const framework::Tensor& codes,
+                            const framework::Tensor& tmat,
+                            framework::Tensor& weight,
+                            const framework::Tensor& input) {
+  auto op = [](const T t, T* weight_row, const T* input_row,
+               size_t input_dim) {
+    for (size_t k = 0; k < input_dim; ++k) {
+      weight_row[k] += t * input_row[k];
+    }
+  };
+  MulByBitCodeT<T>(op, SimpleCodeTable(num_classes), codes, tmat, weight,
+                   input);
 }
+
+template <typename T>
+void MulByBitCodeGradError(size_t num_classes, const framework::Tensor& codes,
+                           const framework::Tensor& tmat,
+                           const framework::Tensor& weight,
+                           framework::Tensor& input) {
+  auto op = [](const T t, const T* weight_row, T* input_row,
+               size_t input_dim) {
+    for (size_t k = 0; k < input_dim; ++k) {
+      input_row[k] += t * weight_row[k];
+    }
+  };
+  MulByBitCodeT<T>(op, SimpleCodeTable(num_classes), codes, tmat, weight,
+                   input);
+}
+
+template <typename T, class CodeTable>
+void SubByBitCodeT(CodeTable code_table, const framework::Tensor& codes,
+                   framework::Tensor& tmat) {
+  size_t max_code_length = code_table.get_max_code_length();
+  size_t num_samples = tmat.dims()[0];
+  size_t o_width = tmat.dims()[1];
+  for (size_t i = 0; i < num_samples; ++i) {
+    auto code = code_table(codes.data<T>()[i]);
+    int code_length = code.get_length();
+    for (int j = 0; j < code_length; ++j) {
+      if (code.calc_bit(j)) {
+        tmat.data<T>()[i * o_width + j] -= 1;
+      }
+    }
+  }
+}
+
+template <typename T>
+void SubByBitCode(size_t num_classes, const framework::Tensor& codes,
+                  framework::Tensor& tmat) {
+  SubByBitCodeT<T>(SimpleCodeTable(num_classes), codes, tmat);
+}
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/math/matrix_bit_code.h b/paddle/operators/math/matrix_bit_code.h
index 7bef5077b9b..43c9d43d89d 100644
--- a/paddle/operators/math/matrix_bit_code.h
+++ b/paddle/operators/math/matrix_bit_code.h
@@ -59,27 +59,57 @@ struct SimpleCodeTable {
   int max_code_length_;
 };
 
-/* For j < codeLength
+/* For j < code_length
    tmat(i, j) += vec(0, index(i, j))
 */
 template <typename T>
 void AddByBitCode(size_t num_classes, const framework::Tensor& codes,
                   framework::Tensor& tmat, const framework::Tensor& vec);
-/* For j < codeLength
+/* For j < code_length
+   vec(0, index(i, j)) += tmat(i, j)
+*/
+template <typename T>
+void AddByBitCodeGrad(size_t num_classes, const framework::Tensor& codes,
+                      const framework::Tensor& tmat, framework::Tensor& vec);
+/* For j < code_length
    sum(i, 0) = \sum_j bit(i, j) * tmat(i, j)
 */
 template <typename T>
 void SumByBitCode(size_t num_classes, const framework::Tensor& codes,
                   framework::Tensor& tmat, framework::Tensor& sum,
                   T scale_sum);
-/* For j < codeLength
+/* For j < code_length
    input.row(i) += tmat(i, j) * weight.row(index(i, j))
 */
 template <typename T>
 void MulByBitCode(size_t num_classes, const framework::Tensor& codes,
                   framework::Tensor& tmat, const framework::Tensor& weight,
                   const framework::Tensor& input);
+
+/* For index(i, j) >= 0:
+   weight.row(index(i, j)) += tmat(i, j) * input.row(i)
+*/
+template <typename T>
+void MulByBitCodeGradWeight(size_t num_classes, const framework::Tensor& codes,
+                            const framework::Tensor& tmat,
+                            framework::Tensor& weight,
+                            const framework::Tensor& input);
+/* For j < code_length
+   input.row(i) += tmat(i, j) * weight.row(index(i, j))
+*/
+template <typename T>
+void MulByBitCodeGradError(size_t num_classes, const framework::Tensor& codes,
+                           const framework::Tensor& tmat,
+                           const framework::Tensor& weight,
+                           framework::Tensor& input);
+
+/* For j < code_length
+   tmat(i, j) -= bit(i, j)
+*/
+template <typename T>
+void SubByBitCode(size_t num_classes, const framework::Tensor& codes,
+                  framework::Tensor& tmat);
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
--
GitLab
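
Note on the "softrelu" comments in the patch: the forward kernel stores pre_out = log(1 + exp(x)) (softplus, clipped with the constant 40.0), and the backward kernel's "softrelu derivative" step corresponds to differentiating that mapping. Expressed through the stored output y, the exact derivative is dy/dx = 1 - exp(-y) = sigmoid(x). The standalone sketch below only checks that identity numerically; it is plain C++ for illustration, does not use the Paddle framework, and the names softrelu and softrelu_grad_from_output are invented for this note.

    // Minimal sketch of the softplus ("softrelu") identity used above.
    //   forward:  y = log(1 + exp(x)), with x clipped to [-threshold, threshold]
    //   backward: dy/dx = exp(x) / (1 + exp(x)) = 1 - exp(-y) = sigmoid(x)
    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    // Forward pass: softplus with the input clipped to [-threshold, threshold].
    double softrelu(double x, double threshold = 40.0) {
      x = std::min(std::max(x, -threshold), threshold);
      return std::log(1.0 + std::exp(x));
    }

    // Derivative of softplus written in terms of its own output y = softrelu(x).
    double softrelu_grad_from_output(double y) { return 1.0 - std::exp(-y); }

    int main() {
      for (double x : {-2.0, 0.0, 3.0}) {
        double y = softrelu(x);
        double dydx = softrelu_grad_from_output(y);
        double sigmoid = 1.0 / (1.0 + std::exp(-x));  // reference value
        // dy/dx recovered from y matches sigmoid(x) up to rounding.
        std::printf("x=%5.2f  y=%8.5f  dy/dx=%8.5f  sigmoid(x)=%8.5f\n",
                    x, y, dydx, sigmoid);
      }
      return 0;
    }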