提交 1f9426fd 编写于 作者: Y Yancey1989

add backward

上级 2ce56940
...@@ -44,9 +44,9 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> { ...@@ -44,9 +44,9 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
auto pre_out_mat = EigenMatrix<T>::From(pre_out); auto pre_out_mat = EigenMatrix<T>::From(pre_out);
int64_t batch_size = ins[0]->dims()[0]; int64_t batch_size = ins[0]->dims()[0];
int64_t size = ins.size(); int64_t code_length = math::FindLastSet(num_classes - 1);
std::vector<int64_t> pre_out_dims({batch_size, size}); std::vector<int64_t> pre_out_dims({batch_size, code_length});
pre_out.mutable_data<T>(framework::make_ddim(pre_out_dims), ctx.GetPlace()); pre_out.mutable_data<T>(framework::make_ddim(pre_out_dims), ctx.GetPlace());
std::vector<int64_t> sum_dims({batch_size, 1UL}); std::vector<int64_t> sum_dims({batch_size, 1UL});
sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace()); sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
...@@ -64,8 +64,11 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> { ...@@ -64,8 +64,11 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
pre_out_mat.abs().cwiseMax(static_cast<T>(40.0)); pre_out_mat.abs().cwiseMax(static_cast<T>(40.0));
math::SumByBitCode<T>(num_classes, *label, *out, pre_out, math::SumByBitCode<T>(num_classes, *label, *out, pre_out,
static_cast<T>(-1)); static_cast<T>(-1));
// softrelu
pre_out_mat.device(place) = (static_cast<T>(1) + pre_out_mat.exp()).log(); // softrelu with threshold is 40.0
pre_out_mat.device(place) =
pre_out_mat.abs().cwiseMax(static_cast<T>(40.0));
pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
row_sum(device_ctx, pre_out, &sum); row_sum(device_ctx, pre_out, &sum);
col_sum(device_ctx, *out, &sum); col_sum(device_ctx, *out, &sum);
...@@ -75,7 +78,46 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> { ...@@ -75,7 +78,46 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
template <typename Place, typename T> template <typename Place, typename T>
class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> { class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override {} void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto ins_grad =
ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
auto params = ctx.MultiOutput<framework::Tensor>(
framework::GradVarName("Parameters"));
auto* bias = ctx.Output<framework::Tensor>(framework::GradVarName("Bias"));
auto* label =
ctx.Output<framework::Tensor>(framework::GradVarName("Label"));
size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
framework::Tensor pre_out;
auto place = ctx.GetEigenDevice<Place>();
auto& dev_ctx = ctx.device_context();
int64_t batch_size = ins_grad.size();
int64_t code_length = math::FindLastSet(num_classes - 1);
auto pre_out_mat = EigenMatrix<T>::From(pre_out);
// init pre_out matrix with {1.0}
std::vector<int64_t> pre_out_dims({batch_size, code_length});
pre_out.mutable_data<T>(framework::make_ddim(pre_out_dims), ctx.GetPlace());
math::SetConstant<Place, T> set;
set(dev_ctx, &pre_out, static_cast<T>(1.0));
// softrelu derivative
pre_out_mat.device(place) =
pre_out_mat * (static_cast<T>(1.0) - static_cast<T>(1.0) / pre_out_mat);
math::SubByBitCode<T>(num_classes, *label, pre_out);
if (bias) {
math::AddByBitCodeGrad<T>(num_classes, *label, pre_out, *bias);
}
for (size_t i = 0; i < ins_grad.size(); ++i) {
math::MulByBitCodeGradWeight<T>(num_classes, *label, pre_out, *params[i],
*ins[i]);
math::MulByBitCodeGradError<T>(num_classes, *label, pre_out, *params[i],
*ins_grad[i]);
}
}
}; };
} // namespace operators } // namespace operators
......
...@@ -69,19 +69,23 @@ static void AddByBitCodeT(Op op, CodeTable code_table, ...@@ -69,19 +69,23 @@ static void AddByBitCodeT(Op op, CodeTable code_table,
} }
} }
/* For j < codeLength:
a(i, j) += b(0, index(i, j))
*/
template <typename T> template <typename T>
void AddByBitCode(size_t num_classes, const framework::Tensor& codes, void AddByBitCode(size_t num_classes, const framework::Tensor& codes,
framework::Tensor& a, const framework::Tensor& b) { framework::Tensor& tmat, const framework::Tensor& vec) {
auto op = [](T& t, T& v) { t += v; }; auto op = [](T& t, T& v) { t += v; };
AddByBitCodeT<T>(op, SimpleCodeTable(num_classes), codes, a, b); AddByBitCodeT<T>(op, SimpleCodeTable(num_classes), codes, tmat, vec);
}
/* Gradient counterpart of AddByBitCode.
   For j < code_length:
     vec(0, index(i, j)) += tmat(i, j)
   `tmat` is only read; `vec` is the accumulation target.
*/
template <typename T>
void AddByBitCodeGrad(size_t num_classes, const framework::Tensor& codes,
                      const framework::Tensor& tmat, framework::Tensor& vec) {
  // The functor receives (tmat element, vec element) and adds the former into
  // the latter -- the mirror image of the functor used by AddByBitCode.
  auto op = [](T& t, T& v) { v += t; };
  // Dispatch to the functor-taking helper AddByBitCodeT, as AddByBitCode
  // does. AddByBitCode itself accepts neither a functor nor a code table, so
  // it cannot be called with these arguments.
  AddByBitCodeT<T>(op, SimpleCodeTable(num_classes), codes, tmat, vec);
}
template <class CodeTable, typename T> template <class CodeTable, typename T>
void SumByBitCodeT(CodeTable code_table, const framework::Tensor& codes, void SumByBitCodeT(CodeTable code_table, const framework::Tensor& codes,
framework::Tensor& tmat, framework::Tensor& sum, framework::Tensor& tmat, const framework::Tensor& sum,
const T& scale_sum) { const T& scale_sum) {
size_t max_code_length = code_table.get_max_code_length(); size_t max_code_length = code_table.get_max_code_length();
size_t num_samples = tmat.dims()[0]; size_t num_samples = tmat.dims()[0];
...@@ -142,8 +146,61 @@ void MulByBitCode(size_t num_classes, const framework::Tensor& codes, ...@@ -142,8 +146,61 @@ void MulByBitCode(size_t num_classes, const framework::Tensor& codes,
} }
t += sum; t += sum;
}; };
MulByBitCode(op, SimpleCodeTable(num_classes), codes, tmat, weight, input); MulByBitCodeT<T>(op, SimpleCodeTable(num_classes), codes, tmat, weight,
input);
}
/* Weight-gradient step of MulByBitCode.
   Per the header comment:
     weight.row(index(i, j)) += tmat(i, j) * input.row(i)
   `weight` is the accumulation target; `tmat` and `input` are only read.
*/
template <typename T>
void MulByBitCodeGradWeight(size_t num_classes, const framework::Tensor& codes,
                            const framework::Tensor& tmat,
                            framework::Tensor& weight,
                            const framework::Tensor& input) {
  // Row-level update handed to MulByBitCodeT: dst += coeff * src over `width`
  // elements.
  auto accumulate_into_weight = [](const T coeff, T* dst, const T* src,
                                   size_t width) {
    for (size_t col = 0; col != width; ++col) {
      dst[col] += coeff * src[col];
    }
  };
  MulByBitCodeT<T>(accumulate_into_weight, SimpleCodeTable(num_classes), codes,
                   tmat, weight, input);
}
/* Input-gradient step of MulByBitCode.
   Per the header comment, for j < code_length:
     input.row(i) += tmat(i, j) * weight.row(index(i, j))
   `input` is the accumulation target; `tmat` and `weight` are only read.
*/
template <typename T>
void MulByBitCodeGradError(size_t num_classes, const framework::Tensor& codes,
                           const framework::Tensor& tmat,
                           const framework::Tensor& weight,
                           framework::Tensor& input) {
  // Row-level update handed to MulByBitCodeT: dst += coeff * src over `width`
  // elements, with the weight row as source and the input row as target.
  auto accumulate_into_input = [](const T coeff, const T* src, T* dst,
                                  size_t width) {
    for (size_t col = 0; col != width; ++col) {
      dst[col] += coeff * src[col];
    }
  };
  MulByBitCodeT<T>(accumulate_into_input, SimpleCodeTable(num_classes), codes,
                   tmat, weight, input);
}
/* For every sample i and every j below that sample's code length:
     tmat(i, j) -= 1   when bit j of the sample's code is set.
   `code_table` maps codes[i] to a code object exposing get_length() and
   calc_bit(j); `tmat` is modified in place.
*/
template <class CodeTable, typename T>
void SubByBitCodeT(CodeTable code_table, const framework::Tensor& codes,
                   framework::Tensor& tmat) {
  const size_t num_samples = tmat.dims()[0];
  const size_t o_width = tmat.dims()[1];
  for (size_t i = 0; i < num_samples; ++i) {
    // NOTE(review): the label buffer is read with element type T -- confirm
    // that `codes` is really stored as T.
    auto code = code_table(codes.data<T>()[i]);
    const int code_length = code.get_length();
    // NOTE(review): nothing here checks code_length <= o_width; presumably
    // the caller sizes tmat to the maximum code length -- verify.
    for (int j = 0; j < code_length; ++j) {
      if (code.calc_bit(j)) {
        tmat.data<T>()[i * o_width + j] -= 1;
      }
    }
  }
}
/* For j < code_length:
     tmat(i, j) -= bit(i, j)
   Uses the SimpleCodeTable built from num_classes; `tmat` is modified in
   place.
*/
template <typename T>
void SubByBitCode(size_t num_classes, const framework::Tensor& codes,
                  framework::Tensor& tmat) {
  // SubByBitCodeT is declared as template <class CodeTable, typename T>, and T
  // does not appear in its parameter list. Both template arguments are named
  // here: writing only <T> would bind T to CodeTable and leave the element
  // type undeducible.
  SubByBitCodeT<SimpleCodeTable, T>(SimpleCodeTable(num_classes), codes, tmat);
}
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -59,27 +59,57 @@ struct SimpleCodeTable { ...@@ -59,27 +59,57 @@ struct SimpleCodeTable {
int max_code_length_; int max_code_length_;
}; };
/* For j < codeLength /* For j < code_length
tmat(i, j) += vec(0, index(i, j)) tmat(i, j) += vec(0, index(i, j))
*/ */
template <typename T> template <typename T>
void AddByBitCode(size_t num_classes, const framework::Tensor& codes, void AddByBitCode(size_t num_classes, const framework::Tensor& codes,
framework::Tensor& tmat, const framework::Tensor& vec); framework::Tensor& tmat, const framework::Tensor& vec);
/* Gradient counterpart of AddByBitCode: tmat is read, vec is accumulated into.
For j < code_length
vec(0, index(i, j)) += tmat(i, j)
*/
template <typename T>
void AddByBitCodeGrad(size_t num_classes, const framework::Tensor& codes,
const framework::Tensor& tmat, framework::Tensor& vec);
/* For j < code_length
sum(i, 0) = \sum_j bit(i, j) * tmat(i, j) sum(i, 0) = \sum_j bit(i, j) * tmat(i, j)
*/ */
template <typename T> template <typename T>
void SumByBitCode(size_t num_classes, const framework::Tensor& codes, void SumByBitCode(size_t num_classes, const framework::Tensor& codes,
framework::Tensor& tmat, framework::Tensor& sum, T scale_sum); framework::Tensor& tmat, framework::Tensor& sum, T scale_sum);
/* For j < codeLength /* For j < code_length
input.row(i) += tmat(i, j) * weight.row(index(i, j)) input.row(i) += tmat(i, j) * weight.row(index(i, j))
*/ */
template <typename T> template <typename T>
void MulByBitCode(size_t num_classes, const framework::Tensor& codes, void MulByBitCode(size_t num_classes, const framework::Tensor& codes,
framework::Tensor& tmat, const framework::Tensor& weight, framework::Tensor& tmat, const framework::Tensor& weight,
const framework::Tensor& input); const framework::Tensor& input);
/* Weight-gradient step: weight is the only tensor that is written.
For index(i, j) >= 0:
weight.row(index(i, j)) += tmat(i, j) * input.row(i)
*/
template <typename T>
void MulByBitCodeGradWeight(size_t num_classes, const framework::Tensor& codes,
const framework::Tensor& tmat,
framework::Tensor& weight,
const framework::Tensor& input);
/* Input-gradient step: input is the only tensor that is written.
For j < code_length
input.row(i) += tmat(i, j) * weight.row(index(i, j))
*/
template <typename T>
void MulByBitCodeGradError(size_t num_classes, const framework::Tensor& codes,
const framework::Tensor& tmat,
const framework::Tensor& weight,
framework::Tensor& input);
/* Subtracts each sample's code bits from tmat in place.
For j < code_length
tmat(i, j) -= bit(i, j)
*/
template <typename T>
void SubByBitCode(size_t num_classes, const framework::Tensor& codes,
framework::Tensor& tmat);
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册