Commit 014e50c2 authored by JiabinYang

test=develop

Parent ba9ff508
@@ -533,6 +533,12 @@ class CPUVector : public std::vector<T, std::allocator<T>> {
     return os;
   }

+  size_t size() const noexcept {
+    size_t size =
+        static_cast<size_t>(std::vector<T, std::allocator<T>>::size());
+    return size;
+  }
+
   T &operator[](size_t id) { return this->at(id); }
   const T &operator[](size_t id) const { return this->at(id); }
...
@@ -70,13 +70,14 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
     const int64_t batch_size = ctx->GetInputDim("X")[0];
     std::vector<int64_t> output_shape({batch_size, 1});
     ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
+    ctx->ShareLoD("X", /*->*/ "Out");
   }

  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
         ctx.GetPlace());
   }
 };
...
@@ -86,32 +87,34 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("X",
-             "(Tensor, required) The input tensor with shape [N, D], "
+             "(LoDTensor, required) The input tensor with shape [N, D], "
              "where N is the size of mini-batch, and D is the feature size.");
     AddInput("W",
-             "(Tensor, required), The parameters of hierarchical "
+             "(LoDTensor, required), The parameters of hierarchical "
              "sigmoid operator, each of them is a 2-D tensor, the shape is"
              "[K, D]. Which K is the num of non-leaf node in Path Tree");
     AddInput("Label",
-             "(Tensor, required), The labels of training data. It's a"
+             "(LoDTensor, required), The labels of training data. It's a"
              "tensor with shape [N, 1].");
     AddInput("PTable",
-             "(Tensor, optional), The Path Table from root to current word"
+             "(LoDTensor, optional), The Path Table from root to current word"
              "it should have shape like [N, L], L is the length of the Path")
         .AsDispensable();
-    AddInput("PCode",
-             "(Tensor, optional), The Code on each Node of the Path from root "
-             "to current word"
-             "it should have shape like [N, L], L is the length of the Path")
+    AddInput(
+        "PCode",
+        "(LoDTensor, optional), The Code on each Node of the Path from root "
+        "to current word"
+        "it should have shape like [N, L], L is the length of the Path")
         .AsDispensable();
     AddInput("Bias",
-             "(Tensor, optional), The bias is a tensor with shape"
+             "(LoDTensor, optional), The bias is a tensor with shape"
              "[1, num_classes - 1].");
-    AddOutput("Out",
-              "(Tensor, required) The output of hierarchical sigmoid operator."
-              "The shape is [N, 1].");
+    AddOutput(
+        "Out",
+        "(LoDTensor, required) The output of hierarchical sigmoid operator."
+        "The shape is [N, 1].");
     AddOutput("PreOut",
-              "(Tensor, required) A intermedia 2-D tensor with shape "
+              "(LoDTensor, required) A intermedia 2-D tensor with shape "
               "[batch_size, code_length], where code_length represents the "
               "maximum path length from root to leaf nodes.")
         .AsIntermediate();
...
@@ -124,6 +127,10 @@ belonging to the right branch. This idea is from
 "F. Morin, Y. Bengio (AISTATS 05):
 Hierarchical Probabilistic Neural Network Language Model."
 )DOC");
+    AddAttr<bool>("is_sparse",
+                  "(boolean, default false) "
+                  "Sparse update.")
+        .SetDefault(false);
   }
 };
...
@@ -133,6 +140,8 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@Grad) should not be null");
     PADDLE_ENFORCE(ctx->HasInput("PreOut"),
                    "Input(Preout) should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")),
...
@@ -142,7 +151,9 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
       ctx->SetOutputDim(framework::GradVarName("Bias"),
                         ctx->GetInputDim("Bias"));
     }
-    ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
+    if (!ctx->Attrs().Get<bool>("is_sparse")) {
+      ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
+    }
     ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
   }
...
@@ -150,11 +161,33 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
         ctx.GetPlace());
   }
 };

+class HierarchicalSigmoidGradOpGradVarTypeInference
+    : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const override {
+    auto out_var_name = op_desc.Output(framework::GradVarName("W")).front();
+    auto attr = op_desc.GetAttr("is_sparse");
+    bool is_sparse = boost::get<bool>(attr);
+    if (is_sparse) {
+      VLOG(3) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
+              << " is set to SelectedRows";
+      block->Var(out_var_name)
+          ->SetType(framework::proto::VarType::SELECTED_ROWS);
+    } else {
+      VLOG(3) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
+              << " is set to LoDTensor";
+      block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR);
+    }
+    block->Var(out_var_name)->SetDataType(block->Var("W")->GetDataType());
+  }
+};
+
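Note: the type-inference hook above decides, at graph-construction time, whether W@GRAD is a dense LoDTensor or a SelectedRows variable. SelectedRows is Fluid's sparse gradient container: a list of row indices plus a dense value tensor holding only those rows. A minimal NumPy sketch of the idea (illustrative only, not the framework's implementation):

    import numpy as np

    # Dense W is [K, D]; a sparse gradient touching rows {1, 3} stores only
    # a [2, D] value block plus the row index list.
    K, D = 5, 3
    rows = [1, 3]                     # analogous to w_grad->set_rows(...)
    value = np.ones((len(rows), D))   # analogous to w_grad->mutable_value()

    # Scattering back to a dense [K, D] gradient, for comparison:
    dense = np.zeros((K, D))
    dense[rows] += value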
 }  // namespace operators
 }  // namespace paddle

...
@@ -162,7 +195,8 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp,
                   ops::HierarchicalSigmoidOpMaker<int>,
                   paddle::framework::DefaultGradOpDescMaker<true>);
-REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp);
+REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp,
+                  ops::HierarchicalSigmoidGradOpGradVarTypeInference);
 REGISTER_OP_CPU_KERNEL(
     hierarchical_sigmoid,
     ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, float>,
...
@@ -14,9 +14,10 @@ limitations under the License. */
 #pragma once

 #include <iostream>
+#include <set>
 #include <vector>

+#include "paddle/fluid/framework/mixed_vector.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/operators/clip_op.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/matrix_bit_code.h"
...
@@ -29,18 +30,37 @@ template <typename T, int MajorType = Eigen::RowMajor,
 using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 using platform::Transform;

+std::vector<int64_t> cal_rows(const framework::LoDTensor* path) {
+  std::set<int64_t> tmp;
+  std::vector<int64_t> rows;
+  rows.clear();
+  for (size_t i = 0; i < static_cast<size_t>(path->dims()[0]); i++) {
+    for (size_t j = 0; j < static_cast<size_t>(path->dims()[1]); j++) {
+      int64_t temp =
+          path->data<int64_t>()[i * static_cast<size_t>(path->dims()[1]) + j];
+      if (temp >= 0) {
+        tmp.insert(temp);
+      }
+    }
+  }
+  for (std::set<int64_t>::iterator it = tmp.begin(); it != tmp.end(); ++it) {
+    rows.push_back(*it);
+  }
+  return rows;
+}
+
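Note: cal_rows collects the sorted set of distinct non-negative node ids appearing anywhere in the PTable batch; those are exactly the rows of W touched by this mini-batch. A NumPy sketch of the same computation (illustrative only, example values taken from the unit-test data below):

    import numpy as np

    # Each PTable row lists the non-leaf node ids on one sample's
    # root-to-leaf path, padded with -1.
    ptable = np.array([(0, 2, -1, -1, -1), (0, 1, 3, -1, -1)])

    rows = np.unique(ptable[ptable >= 0])   # -> array([0, 1, 2, 3])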
 template <typename DeviceContext, typename T>
 class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* in = ctx.Input<framework::Tensor>("X");
-    auto* w = ctx.Input<framework::Tensor>("W");
-    auto* path = ctx.Input<framework::Tensor>("PTable");
-    auto* code = ctx.Input<framework::Tensor>("PCode");
-    auto* label = ctx.Input<framework::Tensor>("Label");
-    auto* bias = ctx.Input<framework::Tensor>("Bias");
-    auto* out = ctx.Output<framework::Tensor>("Out");
-    auto* pre_out = ctx.Output<framework::Tensor>("PreOut");
+    auto* in = ctx.Input<framework::LoDTensor>("X");
+    auto* w = ctx.Input<framework::LoDTensor>("W");
+    auto* path = ctx.Input<framework::LoDTensor>("PTable");
+    auto* code = ctx.Input<framework::LoDTensor>("PCode");
+    auto* label = ctx.Input<framework::LoDTensor>("Label");
+    auto* bias = ctx.Input<framework::LoDTensor>("Bias");
+    auto* out = ctx.Output<framework::LoDTensor>("Out");
+    auto* pre_out = ctx.Output<framework::LoDTensor>("PreOut");
     size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
     bool is_custom = false;
     if (path) {
...
@@ -51,7 +71,7 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
     int64_t code_length =
         path ? path->dims()[1] : math::FindLastSet(num_classes - 1);
     int64_t batch_size = in->dims()[0];
-    framework::Tensor sum;
+    framework::LoDTensor sum;
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     auto* pre_out_data = pre_out->mutable_data<T>(
         framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
...
@@ -102,27 +122,26 @@ template <typename DeviceContext, typename T>
 class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* in = ctx.Input<framework::Tensor>("X");
-    auto* w = ctx.Input<framework::Tensor>("W");
-    auto* path = ctx.Input<framework::Tensor>("PTable");
-    auto* code = ctx.Input<framework::Tensor>("PCode");
-    auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
-    auto* w_grad = ctx.Output<framework::Tensor>(framework::GradVarName("W"));
+    auto* in = ctx.Input<framework::LoDTensor>("X");
+    auto* w = ctx.Input<framework::LoDTensor>("W");
+    auto* path = ctx.Input<framework::LoDTensor>("PTable");
+    auto* code = ctx.Input<framework::LoDTensor>("PCode");
+    auto* in_grad =
+        ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+    bool is_sparse = ctx.Attr<bool>("is_sparse");
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    math::SetConstant<DeviceContext, T> zero;
     auto* bias_grad =
-        ctx.Output<framework::Tensor>(framework::GradVarName("Bias"));
-    auto* label = ctx.Input<framework::Tensor>("Label");
-    auto* pre_out = ctx.Input<framework::Tensor>("PreOut");
+        ctx.Output<framework::LoDTensor>(framework::GradVarName("Bias"));
+    auto* label = ctx.Input<framework::LoDTensor>("Label");
+    auto* pre_out = ctx.Input<framework::LoDTensor>("PreOut");
     auto* out_grad =
-        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
-    framework::Tensor pre_out_grad;
+        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+    framework::LoDTensor pre_out_grad;

     pre_out_grad.mutable_data<T>(pre_out->dims(), ctx.GetPlace());
     in_grad->mutable_data<T>(ctx.GetPlace());
-    w_grad->mutable_data<T>(ctx.GetPlace());
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    math::SetConstant<DeviceContext, T> zero;
     zero(dev_ctx, in_grad, static_cast<T>(0.0));
-    zero(dev_ctx, w_grad, static_cast<T>(0.0));
     size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
...
@@ -162,7 +181,28 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
       zero(dev_ctx, bias_grad, static_cast<T>(0.0));
       bit_code->AddGrad(pre_out_grad, bias_grad);
     }
-    bit_code->MulGradWeight(pre_out_grad, w_grad, *in);
+    if (!is_sparse) {
+      auto* w_grad =
+          ctx.Output<framework::LoDTensor>(framework::GradVarName("W"));
+      w_grad->mutable_data<T>(ctx.GetPlace());
+      zero(dev_ctx, w_grad, static_cast<T>(0.0));
+      bit_code->MulGradWeight(pre_out_grad, w_grad, *in);
+    } else {
+      framework::Vector<int64_t> real_rows = cal_rows(path);
+      auto* w_grad =
+          ctx.Output<framework::SelectedRows>(framework::GradVarName("W"));
+      w_grad->set_rows(real_rows);
+      // build ids -> rows index map
+      w_grad->SyncIndex();
+      auto* w_grad_value = w_grad->mutable_value();
+      framework::DDim temp_dim(w->dims());
+      set(temp_dim, 0, real_rows.size());
+      w_grad_value->mutable_data<T>(temp_dim, ctx.GetPlace());
+      zero(dev_ctx, w_grad_value, static_cast<T>(0.0));
+      bit_code->MulGradWeight(pre_out_grad, w_grad, *in);
+    }
     bit_code->MulGradError(pre_out_grad, *w, in_grad);
   }
 };
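Note: in the sparse branch the W gradient is shaped [len(real_rows), D] rather than [K, D], and SyncIndex builds the id -> local-row map that MulGradWeight then uses to accumulate into the compact value tensor. A NumPy sketch of the resulting layout (illustrative only; real_rows as computed by cal_rows above):

    import numpy as np

    K, D = 5, 8                  # dense W is [K, D]
    real_rows = [0, 1, 2, 3]     # distinct node ids touched by the batch
    id_to_local = {r: i for i, r in enumerate(real_rows)}  # SyncIndex analogue

    w_grad_value = np.zeros((len(real_rows), D))  # compact [len(rows), D] block
    # A scatter-accumulate then targets w_grad_value[id_to_local[node_id]]
    # instead of row node_id of a dense [K, D] tensor.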
...
@@ -19,8 +19,8 @@ namespace operators {
 namespace math {

 template <typename T>
-void MatrixBitCodeFunctor<T>::Add(framework::Tensor* tmat,
-                                  const framework::Tensor& vec) {
+void MatrixBitCodeFunctor<T>::Add(framework::LoDTensor* tmat,
+                                  const framework::LoDTensor& vec) {
   size_t batch_size = tmat->dims()[0];
   size_t width = tmat->dims()[1];
   for (size_t i = 0; i < batch_size; ++i) {
...
@@ -34,8 +34,8 @@ void MatrixBitCodeFunctor<T>::Add(framework::Tensor* tmat,
 }

 template <typename T>
-void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat,
-                                      framework::Tensor* vec) {
+void MatrixBitCodeFunctor<T>::AddGrad(const framework::LoDTensor& tmat,
+                                      framework::LoDTensor* vec) {
   size_t batch_size = tmat.dims()[0];
   size_t width = tmat.dims()[1];
   for (size_t i = 0; i < batch_size; ++i) {
...
@@ -49,8 +49,8 @@ void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat,
 }

 template <typename T>
-void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor& tmat,
-                                  framework::Tensor* sum, T scale_sum) {
+void MatrixBitCodeFunctor<T>::Sum(const framework::LoDTensor& tmat,
+                                  framework::LoDTensor* sum, T scale_sum) {
   size_t num_samples = tmat.dims()[0];
   size_t o_width = tmat.dims()[1];
   for (size_t i = 0; i < num_samples; ++i) {
...
@@ -69,9 +69,9 @@ void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor& tmat,
 }

 template <typename T>
-void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
-                                  const framework::Tensor& weight,
-                                  const framework::Tensor& input) {
+void MatrixBitCodeFunctor<T>::Mul(framework::LoDTensor* tmat,
+                                  const framework::LoDTensor& weight,
+                                  const framework::LoDTensor& input) {
   size_t num_samples = tmat->dims()[0];
   size_t tmat_width = tmat->dims()[1];
   size_t input_width = input.dims()[1];
...
@@ -95,9 +95,9 @@ void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
 }

 template <typename T>
-void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
-                                            framework::Tensor* weight,
-                                            const framework::Tensor& input) {
+void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::LoDTensor& tmat,
+                                            framework::LoDTensor* weight,
+                                            const framework::LoDTensor& input) {
   size_t num_samples = tmat.dims()[0];
   size_t input_width = input.dims()[1];
   size_t tmat_width = tmat.dims()[1];
...
@@ -119,37 +119,38 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
   }
 }
-// template <typename T>
-// void MatrixBitCodeFunctor<T>::MulGradSparseWeight(const framework::Tensor&
-// tmat,
-//                                             framework::SelectedRows* weight,
-//                                             const framework::Tensor& input) {
-//   size_t num_samples = tmat.dims()[0];
-//   size_t input_width = input.dims()[1];
-//   size_t tmat_width = tmat.dims()[1];
-//   size_t weight_width = weight->dims()[1];
-//   auto tmat_value = tmat.data<T>();
-//   auto weight_value = weight->data<T>();
-//   auto input_value = input.data<T>();
-//   for (size_t i = 0; i < num_samples; ++i) {
-//     auto code = code_table->get_code(i);
-//     int code_length = code->get_length();
-//     for (int j = 0; j < code_length; ++j) {
-//       // size_t index = code->calc_index(j);
-//
-//       for (size_t k = 0; k < input_width; ++k) {
-//         weight_value[j * weight_width + k] +=
-//             tmat_value[i * tmat_width + j] * input_value[input_width * i +
-//             k];
-//       }
-//     }
-//   }
-// }
+template <typename T>
+void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::LoDTensor& tmat,
+                                            framework::SelectedRows* weight,
+                                            const framework::LoDTensor& input) {
+  size_t num_samples = tmat.dims()[0];
+  size_t input_width = input.dims()[1];
+  size_t tmat_width = tmat.dims()[1];
+  size_t weight_width = weight->value().dims()[1];
+  auto tmat_value = tmat.data<T>();
+  auto weight_value = weight->mutable_value()->data<T>();
+  auto input_value = input.data<T>();
+  for (size_t i = 0; i < num_samples; ++i) {
+    auto code = code_table->get_code(i);
+    int code_length = code->get_length();
+    for (int j = 0; j < code_length; ++j) {
+      size_t index = code->calc_index(j);
+
+      for (size_t k = 0; k < input_width; ++k) {
+        int64_t row_index =
+            weight->AutoGrownIndex(static_cast<int64_t>(index), false);
+
+        weight_value[row_index * weight_width + k] +=
+            tmat_value[i * tmat_width + j] * input_value[input_width * i + k];
+      }
+    }
+  }
+}
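Note: the SelectedRows overload above is the dense MulGradWeight with one extra indirection: the node id from calc_index(j) is translated to a local row of the compact value tensor before accumulating. A NumPy sketch of that inner update (illustrative only; ids stands in for the code_table lookups):

    import numpy as np

    num_samples, D = 2, 3
    tmat = np.ones((num_samples, 2))   # gradient w.r.t. pre_out
    x = np.ones((num_samples, D))      # input rows
    ids = np.array([[0, 2], [0, 1]])   # node id per (sample, code position)

    rows = np.unique(ids)              # SelectedRows row list
    id_to_local = {r: i for i, r in enumerate(rows)}  # AutoGrownIndex analogue
    w_grad = np.zeros((len(rows), D))

    for i in range(num_samples):
        for j in range(ids.shape[1]):
            w_grad[id_to_local[ids[i, j]]] += tmat[i, j] * x[i]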
 template <typename T>
-void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
-                                           const framework::Tensor& weight,
-                                           framework::Tensor* input) {
+void MatrixBitCodeFunctor<T>::MulGradError(const framework::LoDTensor& tmat,
+                                           const framework::LoDTensor& weight,
+                                           framework::LoDTensor* input) {
   size_t num_samples = tmat.dims()[0];
   size_t tmat_width = tmat.dims()[1];
   size_t input_width = input->dims()[1];
...
@@ -174,7 +175,7 @@ void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
 }

 template <typename T>
-void MatrixBitCodeFunctor<T>::Sub(framework::Tensor* tmat) {
+void MatrixBitCodeFunctor<T>::Sub(framework::LoDTensor* tmat) {
   size_t num_samples = tmat->dims()[0];
   size_t o_width = tmat->dims()[1];
   for (size_t i = 0; i < num_samples; ++i) {
...
@@ -14,6 +14,8 @@ limitations under the License. */
 #pragma once

 #include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/platform/device_context.h"
...
@@ -134,8 +136,9 @@ class SimpleCode : public Code {
 template <typename R>
 class CustomCode : public Code {
  public:
-  CustomCode(const framework::Tensor* ptable, const framework::Tensor* pcode,
-             const int64_t* ids, const int index)
+  CustomCode(const framework::LoDTensor* ptable,
+             const framework::LoDTensor* pcode, const int64_t* ids,
+             const int index)
       : ptable_(ptable), pcode_(pcode), ids_(ids), index_(index) {}
   /**
    * Here the id of root shoud be 1 rather than 0, thus the encoding of class c
...
@@ -169,8 +172,8 @@ class CustomCode : public Code {
   }

  private:
-  const framework::Tensor* ptable_;
-  const framework::Tensor* pcode_;
+  const framework::LoDTensor* ptable_;
+  const framework::LoDTensor* pcode_;
   const int64_t* ids_;
   const int index_;
 };
...
@@ -194,8 +197,9 @@ class SimpleCodeTable : public CodeTable {
 template <typename R>
 class CustomCodeTable : public CodeTable {
  public:
-  explicit CustomCodeTable(const framework::Tensor* ptable,
-                           const framework::Tensor* pcode, const int64_t* ids)
+  explicit CustomCodeTable(const framework::LoDTensor* ptable,
+                           const framework::LoDTensor* pcode,
+                           const int64_t* ids)
       : ptable_(ptable), pcode_(pcode), ids_(ids) {}

   std::unique_ptr<Code> get_code(int64_t code) const {
...
@@ -209,8 +213,8 @@ class CustomCodeTable : public CodeTable {
   }

  private:
-  const framework::Tensor* ptable_;
-  const framework::Tensor* pcode_;
+  const framework::LoDTensor* ptable_;
+  const framework::LoDTensor* pcode_;
   const int64_t* ids_;
 };
...
@@ -222,8 +226,8 @@ class MatrixBitCodeFunctor {
         ids_(ids),
         code_table(new SimpleCodeTable(num_classes, ids)) {}

-  explicit MatrixBitCodeFunctor(const framework::Tensor* ptable,
-                                const framework::Tensor* pcode,
+  explicit MatrixBitCodeFunctor(const framework::LoDTensor* ptable,
+                                const framework::LoDTensor* pcode,
                                 const int64_t* ids)
       : num_classes_(static_cast<size_t>(ptable->dims()[1])),
         ids_(ids),
...
@@ -231,38 +235,47 @@ class MatrixBitCodeFunctor {
   /* For j < code_length
        tmat(i, j) += vec(0, index(i, j))
   */
-  void Add(framework::Tensor* tmat, const framework::Tensor& vec);
+  void Add(framework::LoDTensor* tmat, const framework::LoDTensor& vec);

   /* For j < code_length
        vec(0, index(i, j)) += tmat(i, j)
   */
-  void AddGrad(const framework::Tensor& tmat, framework::Tensor* vec);
+  void AddGrad(const framework::LoDTensor& tmat, framework::LoDTensor* vec);

   /* For j < code_length
        sum(i, 0) = \sum_j bit(i, j) * tmat(i, j)
   */
-  void Sum(const framework::Tensor& tmat, framework::Tensor* sum, T scale_sum);
+  void Sum(const framework::LoDTensor& tmat, framework::LoDTensor* sum,
+           T scale_sum);

   /* For j < code_length
        tmat(i, j) -= bit(i, j)
   */
-  void Sub(framework::Tensor* tmat);
+  void Sub(framework::LoDTensor* tmat);

   /* For j < code_length
        input.row(i) += tmat(i, j) * weight.row(index(i, j))
   */
-  void Mul(framework::Tensor* tmat, const framework::Tensor& weight,
-           const framework::Tensor& input);
+  void Mul(framework::LoDTensor* tmat, const framework::LoDTensor& weight,
+           const framework::LoDTensor& input);

   /* For index(i, j) >= 0:
        weight.row(index(i, j)) += tmat(i, j) * input.row(i)
   */
-  void MulGradWeight(const framework::Tensor& tmat, framework::Tensor* weight,
-                     const framework::Tensor& input);
+  void MulGradWeight(const framework::LoDTensor& tmat,
+                     framework::LoDTensor* weight,
+                     const framework::LoDTensor& input);
+
+  /* For SelectedRows Weight, For index(i, j) >= 0:
+       weight.row(index(i, j)) += tmat(i, j) * input.row(i)
+  */
+  void MulGradWeight(const framework::LoDTensor& tmat,
+                     framework::SelectedRows* weight,
+                     const framework::LoDTensor& input);

   /* For j < code_length
        input.row(i) += tmat(i, j) * weight.row(index(i, j))
   */
-  void MulGradError(const framework::Tensor& tmat,
-                    const framework::Tensor& weight, framework::Tensor* input);
+  void MulGradError(const framework::LoDTensor& tmat,
+                    const framework::LoDTensor& weight,
+                    framework::LoDTensor* input);

   size_t num_classes_;
   const int64_t* ids_;
...
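Note: a small worked example of the PTable/PCode convention that CustomCode reads, using values from the unit-test data below (illustrative only):

    import numpy as np

    # One sample's path visits non-leaf nodes 0 then 2; -1 pads to max length.
    ptable_row = np.array([0, 2, -1, -1, -1])   # node id at each path step
    pcode_row = np.array([0, 0, -1, -1, -1])    # branch label taken at that node
    valid = ptable_row >= 0                     # positions the kernels use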
...
@@ -4355,7 +4355,8 @@ def hsigmoid(input,
              param_attr=None,
              bias_attr=None,
              name=None,
-             is_costum=False):
+             is_costum=False,
+             is_sparse=False):
     """
     The hierarchical sigmoid operator is used to accelerate the training
     process of language model. This operator organizes the classes into a
...
@@ -4394,9 +4395,11 @@ def hsigmoid(input,
             is not set, the bias is initialized zero. Default: None.
         name (str|None): A name for this layer(optional). If set None, the layer
             will be named automatically. Default: None.
+        is_costum (bool): Use a user-defined binary tree instead of the default
+            complete binary tree. Default: False.
+        is_sparse (bool): Use a sparse update for the gradient of W instead of a
+            dense update. Default: False.

     Returns:
-        Out: (Tensor) The cost of hierarchical sigmoid operator. the shape is [N, 1]
+        Out: (LoDTensor) The cost of the hierarchical sigmoid operator, with shape [N, 1].

     Examples:
...
@@ -4466,7 +4469,8 @@ def hsigmoid(input,
         inputs=inputs,
         outputs={"Out": out,
                  "PreOut": pre_out},
-        attrs={"num_classes": num_classes})
+        attrs={"num_classes": num_classes,
+               "is_sparse": is_sparse})
     return out
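Note: a minimal usage sketch of the new flag with the default complete binary tree (assumes the layer's pre-existing num_classes argument; illustrative only):

    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[8], dtype='float32')
    y = fluid.layers.data(name='y', shape=[1], dtype='int64')
    # With is_sparse=True the backward pass emits W@GRAD as SelectedRows,
    # so only the weight rows on the sampled root-to-leaf paths are updated.
    cost = fluid.layers.hsigmoid(
        input=x, label=y, num_classes=6, is_sparse=True)
    avg_cost = fluid.layers.mean(cost)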
...
@@ -16,10 +16,9 @@ from __future__ import print_function

 import unittest
 import numpy as np
+import paddle.fluid.core as core
+import paddle.fluid as fluid
 import math
-# import paddle.fluid as fluid
-# import paddle.fluid.core as core
-# from op_builder import OpBuilder
 from op_test import OpTest

 np.random.seed(100)
...
@@ -141,67 +140,148 @@ def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes):
     return pre_output, out


-class TestHSigmoidOp(OpTest):
-    def setUp(self):
-        self.op_type = "hierarchical_sigmoid"
-        num_classes = 6
-        feature_size = 8
-        batch_size = 4
-        x = np.random.random((batch_size, feature_size)).astype("float32") * 2
-        w = np.random.random(
-            (num_classes - 1, feature_size)).astype("float32") * 2
-        label = np.random.randint(0, num_classes, (batch_size, 1))
-        bias = np.random.random((1, num_classes - 1)).astype("float32")
-        self.attrs = {'num_classes': num_classes}
-        self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias}
-        pre_output, out = hsigmoid(x, w, label, bias, num_classes)
-        self.outputs = {'PreOut': pre_output, 'Out': out}
-
-    def test_check_output(self):
-        self.check_output()
-
-    def test_check_grad(self):
-        self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label'))
-
-
-class TestHSigmoidOpWithCostumTree(OpTest):
-    def setUp(self):
-        self.op_type = "hierarchical_sigmoid"
-        num_classes = 6  #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample
-        feature_size = 8
-        batch_size = 4
-        x = np.random.random((batch_size, feature_size)).astype("float32") * 2
-        w = np.random.random(
-            (num_classes - 1, feature_size)).astype("float32") * 2
-        label = np.array([0, 1, 4, 5])
-        ptable = np.array(
-            [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1),
-             (0, 2, -1, -1,
-              -1)])  #np.array to store 1,2,5,6s' non-leaf path(root -> leaf)
-        pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), (
-            1, 0, 0, -1, -1), (0, 1, -1, -1, -1)])  #np.array to store
-        bias = np.random.random((1, num_classes - 1)).astype("float32")
-        self.attrs = {'num_classes': num_classes}
-        self.inputs = {
-            'X': x,
-            'W': w,
-            'PTable': ptable,
-            'PCode': pcode,
-            'Label': label,
-            'Bias': bias
-        }
-        pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label,
-                                                 bias, num_classes)
-        self.outputs = {'PreOut': pre_output, 'Out': out}
-
-    def test_check_output(self):
-        print("checking output in CostumTree")
-        self.check_output()
-
-    def test_check_grad(self):
-        print("checking outputGrad in CostumTree")
-        self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label'))
-
+# class TestHSigmoidOp(OpTest):
+#     def setUp(self):
+#         self.op_type = "hierarchical_sigmoid"
+#         num_classes = 6
+#         feature_size = 8
+#         batch_size = 4
+#         x = np.random.random((batch_size, feature_size)).astype("float32") * 2
+#         w = np.random.random(
+#             (num_classes - 1, feature_size)).astype("float32") * 2
+#         label = np.random.randint(0, num_classes, (batch_size, 1))
+#         bias = np.random.random((1, num_classes - 1)).astype("float32")
+#         self.attrs = {'num_classes': num_classes, 'is_sparse': False}
+#         self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias}
+#         pre_output, out = hsigmoid(x, w, label, bias, num_classes)
+#         self.outputs = {'PreOut': pre_output, 'Out': out}
+
+#     def test_check_output(self):
+#         self.check_output()
+
+#     def test_check_grad(self):
+#         self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label'))
+
+# class TestHSigmoidOpSparse(OpTest):
+#     def setUp(self):
+#         self.op_type = "hierarchical_sigmoid"
+#         num_classes = 6  #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample
+#         feature_size = 8
+#         batch_size = 4
+#         x = np.random.random((batch_size, feature_size)).astype("float32") * 2
+#         w = np.random.random(
+#             (num_classes - 1, feature_size)).astype("float32") * 2
+#         label = np.array([0, 1, 4, 5])
+#         ptable = np.array(
+#             [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1),
+#              (0, 2, -1, -1,
+#               -1)])  #np.array to store 1,2,5,6s' non-leaf path(root -> leaf)
+#         pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), (
+#             1, 0, 0, -1, -1), (0, 1, -1, -1, -1)])  #np.array to store
+#         bias = np.random.random((1, num_classes - 1)).astype("float32")
+#         self.attrs = {'num_classes': num_classes, 'is_sparse': True}
+#         self.inputs = {
+#             'X': x,
+#             'W': w,
+#             'PTable': ptable,
+#             'PCode': pcode,
+#             'Label': label,
+#             'Bias': bias
+#         }
+#         pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label,
+#                                                  bias, num_classes)
+#         self.outputs = {'PreOut': pre_output, 'Out': out}
+
+#     def test_check_output(self):
+#         print("checking output in CostumTree")
+#         self.check_output()
+
+class TestHSigmoidOpWithSparseGrad(unittest.TestCase):
+    def hs_net_conf(self):
+        emb = fluid.layers.data(name="x", shape=[3], dtype='float32')
+        ptable = fluid.layers.data(name='ptable', shape=[3], dtype='int64')
+        pcode = fluid.layers.data(name='pcode', shape=[3], dtype='int64')
+        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+        data_list = [emb, ptable, pcode, label]
+
+        cost = fluid.layers.hsigmoid(
+            input=emb,
+            label=label,
+            non_leaf_num=4,
+            ptable=ptable,
+            pcode=pcode,
+            is_costum=True,
+            is_sparse=True)
+
+        avg_cost = fluid.layers.reduce_mean(cost)
+
+        return avg_cost, data_list
+
+    def test_training_test(self):
+        x = np.ones((2, 3)).astype("float32")
+        ptable = np.array([(1, 2, -1), (1, 2, -1)])
+        pcode = np.array([(1, 0, -1), (0, 0, -1)])
+        label = np.array([(1, ), (4, )])
+
+        loss, data_list = self.hs_net_conf()
+        optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
+        optimizer.minimize(loss)
+
+        main_program = fluid.default_main_program()
+        place = fluid.CPUPlace()
+        feeder = fluid.DataFeeder(feed_list=data_list, place=place)
+        exe = fluid.Executor(place)
+
+        exe.run(fluid.default_startup_program())
+        for pass_id in range(2):
+            for i in range(10):
+                data = [(x[i % 2], ptable[i % 2], pcode[i % 2], label[i % 2])]
+                loss_val = exe.run(main_program,
+                                   feed=feeder.feed(data),
+                                   fetch_list=[loss])
+                print("loss is: {loss}".format(loss=loss_val))
+
+# class TestHSigmoidOpWithCostumTree(OpTest):
+#     def setUp(self):
+#         self.op_type = "hierarchical_sigmoid"
+#         num_classes = 6  #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample
+#         feature_size = 8
+#         batch_size = 4
+#         x = np.random.random((batch_size, feature_size)).astype("float32") * 2
+#         w = np.random.random(
+#             (num_classes - 1, feature_size)).astype("float32") * 2
+#         label = np.array([0, 1, 4, 5])
+#         ptable = np.array(
+#             [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1),
+#              (0, 2, -1, -1,
+#               -1)])  #np.array to store 1,2,5,6s' non-leaf path(root -> leaf)
+#         pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), (
+#             1, 0, 0, -1, -1), (0, 1, -1, -1, -1)])  #np.array to store
+#         bias = np.random.random((1, num_classes - 1)).astype("float32")
+#         self.attrs = {'num_classes': num_classes, 'is_sparse': False}
+#         self.inputs = {
+#             'X': x,
+#             'W': w,
+#             'PTable': ptable,
+#             'PCode': pcode,
+#             'Label': label,
+#             'Bias': bias
+#         }
+#         pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label,
+#                                                  bias, num_classes)
+#         self.outputs = {'PreOut': pre_output, 'Out': out}
+
+#     def test_check_output(self):
+#         print("checking output in CostumTree")
+#         self.check_output()
+
+#     def test_check_grad(self):
+#         print("checking outputGrad in CostumTree")
+#         self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label'))

 if __name__ == '__main__':
     unittest.main()