Commit 1abd3b3a authored by: Yancey1989

implement forward

Parent 1971f3ce
@@ -185,7 +185,8 @@ set(DEPS_OPS
     tensor_array_read_write_op
     gru_op
     adagrad_op
-    sgd_op)
+    sgd_op
+    hierarchical_sigmoid_op)
 op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
@@ -203,6 +204,7 @@ op_library(lod_rank_table_op SRCS lod_rank_table_op.cc DEPS lod_rank_table)
 op_library(lod_tensor_to_array_op SRCS lod_tensor_to_array_op.cc DEPS lod_rank_table_op)
 op_library(array_to_lod_tensor_op SRCS array_to_lod_tensor_op.cc DEPS lod_rank_table_op)
 op_library(tensor_array_read_write_op SRCS tensor_array_read_write_op.cc)
+op_library(hierarchical_sigmoid_op DEPS matrix_bit_code)
 if(WITH_GPU)
   op_library(nccl_op DEPS nccl_common)
 endif()
...
@@ -85,12 +85,16 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
              "(TensorArray, required) The input array. Each Tensor has the "
              "same shape with [N * D].")
         .AsDuplicable();
+    AddInput("Parameters",
+             "(Tensor, required), The parameters of hierarchical "
+             "sigmoid operator, each of them is a 2-D tensor.")
+        .AsDuplicable();
     AddInput("Label",
              "(Tensor, required), The labels of training data. It's a "
              "1-D tensor.");
     AddInput("Bias",
              "(Tensor, optional), The bias is a 1-D tensor, "
-             "which is applied to the output");
+             "which is applied to the output.");
     AddOutput(
         "Out",
         "(Tensor, required) The output of hierarchical sigmoid operator.");
...
@@ -14,28 +14,61 @@ limitations under the License. */
 #pragma once
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/matrix_bit_code.h"
 
 namespace paddle {
 namespace operators {
 
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
 template <typename Place, typename T>
 class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto ins = ctx.MultiInput<framework::Tensor>("X");
+    auto params = ctx.MultiInput<framework::Tensor>("Parameters");
     auto* label = ctx.Input<framework::Tensor>("Label");
     auto* bias = ctx.Input<framework::Tensor>("Bias");
+    auto* out = ctx.Output<framework::Tensor>("Out");
     size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
+
+    framework::Tensor sum;
+    framework::Tensor pre_out;
+    auto place = ctx.GetEigenDevice<Place>();
+    auto& device_ctx = ctx.device_context();
+    math::ColwiseSum<Place, T> col_sum;
+    math::RowwiseSum<Place, T> row_sum;
+
     int64_t batch_size = ins[0]->dims()[0];
     int64_t size = ins.size();
-    framework::Tensor pre_out;
     std::vector<int64_t> pre_out_dims({batch_size, size});
     pre_out.mutable_data<T>(framework::make_ddim(pre_out_dims), ctx.GetPlace());
+    auto pre_out_mat = EigenMatrix<T>::From(pre_out);
+    std::vector<int64_t> sum_dims({batch_size, 1UL});
+    sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
+    out->mutable_data<T>(ctx.GetPlace());
 
-    if (bias != NULL) {
+    if (bias) {
       math::AddByBitCode<T>(num_classes, *label, pre_out, *bias);
     }
+    for (size_t i = 0; i < ins.size(); ++i) {
+      math::MulByBitCode<T>(num_classes, *label, pre_out, *params[i], *ins[i]);
+    }
+    // clip the matrix to (-40, 40)
+    pre_out_mat.device(place) = pre_out_mat.cwiseMax(static_cast<T>(-40.0))
+                                    .cwiseMin(static_cast<T>(40.0));
+    math::SumByBitCode<T>(num_classes, *label, pre_out, *out,
+                          static_cast<T>(-1));
+    // softrelu
+    pre_out_mat.device(place) = (static_cast<T>(1) + pre_out_mat.exp()).log();
+    row_sum(device_ctx, pre_out, &sum);
+    col_sum(device_ctx, *out, &sum);
   }
 };
...
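For reference, the forward kernel above composes the bit-code primitives in a fixed order: bias add on the code path, dot products against the code-path rows of the parameters, clipping, a negated bit-weighted sum, and a softrelu row sum. The standalone sketch below restates that per-sample arithmetic in plain C++; HSigmoidForwardOneSample, path_index, and path_bit are illustrative names standing in for the code-table calls, and folding the two partial sums the kernel keeps in Out and sum into one returned loss is an assumption about how they are meant to combine.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical standalone sketch (not the operator's API): the per-sample
// forward arithmetic assembled from the bit-code primitives.
// path_index[j] / path_bit[j] stand in for code.calc_index(j) / calc_bit(j);
// weight is row-major [num_nodes x input_dim].
double HSigmoidForwardOneSample(const std::vector<double>& x,
                                const std::vector<double>& weight,
                                const std::vector<double>& bias,
                                const std::vector<std::size_t>& path_index,
                                const std::vector<bool>& path_bit,
                                std::size_t input_dim) {
  double loss = 0.0;
  for (std::size_t j = 0; j < path_index.size(); ++j) {
    const std::size_t idx = path_index[j];
    // pre_out(j) = bias(idx) + w(idx) . x      (AddByBitCode + MulByBitCode)
    double pre = bias[idx];
    for (std::size_t k = 0; k < input_dim; ++k) {
      pre += weight[idx * input_dim + k] * x[k];
    }
    // clip to (-40, 40) for numerical stability
    pre = std::max(-40.0, std::min(40.0, pre));
    // SumByBitCode with scale -1: subtract the bit-weighted activation
    if (path_bit[j]) loss -= pre;
    // softrelu term log(1 + exp(.)), summed over the code path (row_sum)
    loss += std::log(1.0 + std::exp(pre));
  }
  return loss;
}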
@@ -314,6 +314,8 @@ template struct RowwiseAdd<platform::CPUPlace, float>;
 template struct RowwiseAdd<platform::CPUPlace, double>;
 template struct ColwiseSum<platform::CPUPlace, float>;
 template struct ColwiseSum<platform::CPUPlace, double>;
+template struct RowwiseSum<platform::CPUPlace, float>;
+template struct RowwiseSum<platform::CPUPlace, double>;
 
 }  // namespace math
 }  // namespace operators
...
@@ -298,6 +298,8 @@ template struct RowwiseAdd<platform::GPUPlace, float>;
 template struct RowwiseAdd<platform::GPUPlace, double>;
 template struct ColwiseSum<platform::GPUPlace, float>;
 template struct ColwiseSum<platform::GPUPlace, double>;
+template struct RowwiseSum<platform::GPUPlace, float>;
+template struct RowwiseSum<platform::GPUPlace, double>;
 
 }  // namespace math
 }  // namespace operators
...
@@ -130,6 +130,12 @@ struct ColwiseSum {
                   const framework::Tensor& input, framework::Tensor* vec);
 };
 
+template <typename Place, typename T>
+struct RowwiseSum {
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor* vec);
+};
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
...
@@ -78,6 +78,20 @@ void ColwiseSum<Place, T>::operator()(const platform::DeviceContext& context,
       in.sum(Eigen::array<int, 1>({{0}})).reshape(shape);
 }
 
+template <typename Place, typename T>
+void RowwiseSum<Place, T>::operator()(const platform::DeviceContext& context,
+                                      const framework::Tensor& input,
+                                      framework::Tensor* vector) {
+  auto in_dims = input.dims();
+  auto size = input.numel() / in_dims[1];
+  PADDLE_ENFORCE_EQ(vector->numel(), size);
+
+  auto in = framework::EigenMatrix<T>::From(input);
+  auto vec = framework::EigenMatrix<T>::From(*vector);
+  Eigen::array<int, 2> shape({{static_cast<int>(size), 1}});
+  vec.reshape(shape).device(*context.GetEigenDevice<Place>()) =
+      in.sum(Eigen::array<int, 1>({{1}})).reshape(shape);
+}
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
...
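For clarity on what RowwiseSum is expected to compute: the reduction runs over the column index and yields one value per row, the transpose of ColwiseSum. A toy reference follows, independent of the Paddle/Eigen types above (RowwiseSumRef is an illustrative name):

#include <cstddef>
#include <vector>

// Toy reference: row-wise sum of a row-major [rows x cols] matrix --
// sum over the column index, one value per row.
std::vector<double> RowwiseSumRef(const std::vector<double>& m,
                                  std::size_t rows, std::size_t cols) {
  std::vector<double> out(rows, 0.0);
  for (std::size_t i = 0; i < rows; ++i) {
    for (std::size_t j = 0; j < cols; ++j) {
      out[i] += m[i * cols + j];
    }
  }
  return out;
}
// e.g. RowwiseSumRef({1, 2, 3, 4, 5, 6}, 2, 3) yields {6, 15}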
@@ -53,18 +53,18 @@ namespace math {
 template <class CodeTable, class Op, typename T>
 static void AddByBitCodeT(Op op, CodeTable code_table,
                           const framework::Tensor& codes, framework::Tensor& a,
-                          framework::Tensor& b) {
+                          const framework::Tensor& b) {
   size_t num_classes = code_table.size();
   size_t max_code_length = code_table.get_max_code_length();
-  size_t num_sample = a.dims()[0].size();
-  size_t width = a.dims()[1].size();
+  size_t num_sample = a.dims()[0];
+  size_t width = a.dims()[1];
 
   for (size_t i = 0; i < num_sample; ++i) {
-    auto code = code_table(codes.data<T>()[i]) int code_length =
-        code.get_length();
+    auto code = code_table(codes.data<T>()[i]);
+    int code_length = code.get_length();
     for (int j = 0; j < code_length; ++j) {
       size_t index = code.calc_index(j);
-      op(a<T>.data()[i * width + j], b<T>.data()[index]);
+      op(a.data<T>()[i * width + j], b.data<T>()[index]);
     }
   }
 }
@@ -79,6 +79,71 @@ void AddByBitCode(size_t num_classes, const framework::Tensor& codes,
   AddByBitCodeT<T>(op, SimpleCodeTable(num_classes), codes, a, b);
 }
 
+template <class CodeTable, typename T>
+void SumByBitCodeT(CodeTable code_table, const framework::Tensor& codes,
+                   framework::Tensor& tmat, framework::Tensor& sum,
+                   const T& scale_sum) {
+  size_t max_code_length = code_table.get_max_code_length();
+  size_t num_samples = tmat.dims()[0];
+  size_t o_width = tmat.dims()[1];
+  for (size_t i = 0; i < num_samples; ++i) {
+    T sm = 0;
+    auto code = code_table(codes.data<T>()[i]);
+    int code_length = code.get_length();
+    for (int j = 0; j < code_length; ++j) {
+      if (code.calc_bit(j)) {
+        sm += tmat.data<T>()[i * o_width + j];
+      }
+    }
+    sum.data<T>()[i] = scale_sum * sm;
+  }
+}
+
+/* For j < codeLength:
+   sum(i, 0) = \sum_j bit(i, j) * tmat(i, j)
+*/
+template <typename T>
+void SumByBitCode(size_t num_classes, const framework::Tensor& codes,
+                  framework::Tensor& tmat, framework::Tensor& sum,
+                  T scale_sum) {
+  SumByBitCodeT(SimpleCodeTable(num_classes), codes, tmat, sum, scale_sum);
+}
+
+template <typename T, class Op, class CodeTable>
+void MulByBitCodeT(Op op, CodeTable code_table, const framework::Tensor& codes,
+                   framework::Tensor& tmat, const framework::Tensor& weight,
+                   const framework::Tensor& input) {
+  size_t num_classes = code_table.size();
+  size_t max_code_length = code_table.get_max_code_length();
+  size_t num_samples = tmat.dims()[0];
+  size_t input_dim = input.dims()[1];
+  size_t o_width = tmat.dims()[1];
+  for (size_t i = 0; i < num_samples; ++i) {
+    auto code = code_table(codes.data<T>()[i]);
+    int code_length = code.get_length();
+    for (int j = 0; j < code_length; ++j) {
+      size_t index = code.calc_index(j);
+      op(tmat.data<T>()[i * o_width + j],
+         weight.data<T>() + index * weight.dims()[1],
+         input.data<T>() + i * input.dims()[1], input_dim);
+    }
+  }
+}
+
+template <typename T>
+void MulByBitCode(size_t num_classes, const framework::Tensor& codes,
+                  framework::Tensor& tmat, const framework::Tensor& weight,
+                  const framework::Tensor& input) {
+  auto op = [](T& t, const T* weight_row, const T* input_row,
+               size_t input_dim) {
+    T sum = 0;
+    for (size_t k = 0; k < input_dim; ++k) {
+      sum += weight_row[k] * input_row[k];
+    }
+    t += sum;
+  };
+  MulByBitCodeT<T>(op, SimpleCodeTable(num_classes), codes, tmat, weight,
+                   input);
+}
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
...
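The AddByBitCodeT / SumByBitCodeT / MulByBitCodeT templates are duck-typed over the code table: they only use size(), get_max_code_length(), operator()(label), and, on the returned code, get_length(), calc_index(j), and calc_bit(j). A hypothetical minimal table with that interface is sketched below; ToyCode and ToyCodeTable are invented names, and the stored-path layout is for illustration only, not SimpleCodeTable's actual tree encoding.

#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical toy code table satisfying the interface the *ByBitCodeT
// templates rely on. The real SimpleCodeTable derives each path from a
// complete binary tree; here the paths are simply stored pre-computed.
struct ToyCode {
  std::vector<std::size_t> node_index;  // internal node visited at depth j
  std::vector<bool> branch_bit;         // branch taken at depth j
  int get_length() const { return static_cast<int>(node_index.size()); }
  std::size_t calc_index(int j) const { return node_index[j]; }
  bool calc_bit(int j) const { return branch_bit[j]; }
};

struct ToyCodeTable {
  std::vector<ToyCode> codes;  // codes[label] is the path for that label
  ToyCode operator()(std::size_t label) const { return codes[label]; }
  std::size_t size() const { return codes.size(); }
  int get_max_code_length() const {
    int m = 0;
    for (const auto& c : codes) m = std::max(m, c.get_length());
    return m;
  }
};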
@@ -59,10 +59,27 @@ struct SimpleCodeTable {
   int max_code_length_;
 };
 
+/* For j < codeLength:
+   tmat(i, j) += vec(0, index(i, j))
+*/
 template <typename T>
 void AddByBitCode(size_t num_classes, const framework::Tensor& codes,
-                  framework::Tensor& a, const framework::Tensor& b);
+                  framework::Tensor& tmat, const framework::Tensor& vec);
+
+/* For j < codeLength:
+   sum(i, 0) = \sum_j bit(i, j) * tmat(i, j)
+*/
+template <typename T>
+void SumByBitCode(size_t num_classes, const framework::Tensor& codes,
+                  framework::Tensor& tmat, framework::Tensor& sum,
+                  T scale_sum);
+
+/* For j < codeLength:
+   tmat(i, j) += weight.row(index(i, j)) * input.row(i)
+*/
+template <typename T>
+void MulByBitCode(size_t num_classes, const framework::Tensor& codes,
+                  framework::Tensor& tmat, const framework::Tensor& weight,
+                  const framework::Tensor& input);
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle