Commit fb9c08f0 authored by Yancey1989

make forward work

Parent 28630dd8
@@ -207,7 +207,7 @@ set(DEPS_OPS
     gru_op
     adagrad_op
     sgd_op
-    hierarchical_sigmoid_op)
+    hierarchical_sigmoid_op
     save_op
     load_op
     send_op
......
@@ -60,19 +60,48 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->hasInput("X"), "Input(X) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Parameters"),
+                   "Input(Parameters)"
+                   "should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
     const int64_t batch_size = ctx->GetInputDim("X")[0];
-    std::vector<int64_t> output_shape({batch_size, num_classes_ - 1});
+    std::vector<int64_t> output_shape({batch_size, 1});
     ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
   }
+
+ protected:
+  framework::OpKernelType GetKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+        ctx.GetPlace());
+  }
 };

 class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-  void InferShape(framework::InferShapeContext* ctx) const override {}
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Parameters"),
+                   "Input(Parameters)"
+                   "should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Label"),
+                   "Input(Label)"
+                   "should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Parameters")),
+                   "Input(Parameters@Grad should not be null.)");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")));
+  }
+
+ protected:
+  framework::OpKernelType GetKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+        ctx.GetPlace());
+  }
 };

 class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -98,7 +127,8 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("Out",
               "(Tensor, required) The output of hierarchical sigmoid operator."
              "the shape is [N, 1]");
-    AddAttr<int>("num_classes", "(int, required)", "The number of classes");
+    AddAttr<int>("num_classes", "(int, required)", "The number of classes")
+        .SetDefault(2);
     AddComment(R"DOC(
 The hierarchical sigmoid operator organize the classes into a binary tree.
 At each node, a sigmoid function is used to caculate the probability of
@@ -116,9 +146,9 @@ namespace ops = paddle::operators;
 REGISTER_OP(hierarchical_sigmoid, ops::HierarchicalSigmoidOp,
             ops::HierarchicalSigmoidOpMaker, hierarchical_sigmoid_grad,
             ops::HierarchicalSigmoidGradOp);
-REGISTER_OP_CPU_KERNEL(
-    hierarchical_sigmoid,
-    ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUPlace, float>);
-REGISTER_OP_CPU_KERNEL(
-    hierarchical_sigmoid_grad,
-    ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(hierarchical_sigmoid,
+                       ops::HierarchicalSigmoidOpKernel<
+                           paddle::platform::CPUDeviceContext, float>);
+REGISTER_OP_CPU_KERNEL(hierarchical_sigmoid_grad,
+                       ops::HierarchicalSigmoidGradOpKernel<
+                           paddle::platform::CPUDeviceContext, float>);
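For reference, the quantity the forward kernel below appears to accumulate per sample is the standard hierarchical-softmax cost. Writing s(i, j) = Parameters[i, index(i, j)] . X[i] + Bias[0, index(i, j)] for the score of the j-th tree node on the path to sample i's label, and bit(i, j) for the branch taken there, the output reduces to

    Out[i] = sum_j ( log(1 + exp(s(i, j))) - bit(i, j) * s(i, j) )

that is, the bit-code sum taken with scale -1 plus the row sum of softrelu(pre_out). This reading is an inference from the kernel code, not something stated in the commit itself.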
@@ -14,8 +14,10 @@ limitations under the License. */
 #pragma once
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/clip_op.h"
 #include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/matrix_bit_code.h"
+#include "paddle/platform/transform.h"

 namespace paddle {
 namespace operators {
@@ -23,60 +25,64 @@ namespace operators {
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+using platform::Transform;

-template <typename Place, typename T>
+template <typename DeviceContext, typename T>
 class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* in = ctx.Input<framework::Tensor>("X");
-    auto* param = ctx.Input<framework::Tensor>("Parameter");
+    auto* params = ctx.Input<framework::Tensor>("Parameters");
     auto* label = ctx.Input<framework::Tensor>("Label");
     auto* bias = ctx.Input<framework::Tensor>("Bias");
     auto* out = ctx.Output<framework::Tensor>("Out");
     size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
-    framework::Tensor sum;
-    framework::Tensor pre_out;
-    auto place = ctx.GetEigenDevice<Place>();
-    auto& device_ctx = ctx.device_context();
-    math::ColwiseSum<Place, T> col_sum;
-    math::RowwiseSum<Place, T> row_sum;
-    auto pre_out_mat = EigenMatrix<T>::From(pre_out);
-    int64_t batch_size = ins[0]->dims()[0];
-    int64_t code_length = math::FindLastSet(num_classes - 1);
-    std::vector<int64_t> pre_out_dims({batch_size, code_length});
-    pre_out.mutable_data<T>(framework::make_ddim(pre_out_dims), ctx.GetPlace());
+    int64_t code_length = math::FindLastSet(num_classes - 1);
+    int64_t batch_size = in->dims()[0];
+    auto* ids = label->data<int64_t>();
+    framework::Tensor pre_out;
+    framework::Tensor sum;
+    auto pre_out_data = pre_out.mutable_data<T>(
+        framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
+    auto pre_out_mat = EigenMatrix<T>::From(pre_out);
+    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
+    auto& device_ctx = ctx.template device_context<DeviceContext>();
+    math::RowwiseSum<DeviceContext, T> row_sum;
+    math::MatrixBitCodeFunctor<T> bit_code;
     std::vector<int64_t> sum_dims({batch_size, 1UL});
     sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
+    auto sum_mat = EigenMatrix<T>::From(sum);
     out->mutable_data<T>(ctx.GetPlace());
+    auto out_mat = framework::EigenVector<T>::Flatten(*out);
     if (bias) {
-      math::AddByBitCode<T>(num_classes, *label, pre_out, *bias);
+      bit_code.Add(num_classes, ids, pre_out, *bias);
     }
-    for (size_t i = 0; i < in.dims()[0]; ++i) {
-      math::MulByBitCode<T>(num_classes, *label, pre_out,
-                            *params->Slice(i, i + 1), *in->Slice(i, i + 1));
+    for (int i = 0; i < in->dims()[0]; ++i) {
+      bit_code.Mul(num_classes, ids, pre_out, params->Slice(i, i + 1),
+                   in->Slice(i, i + 1));
     }
     // clip the matrix with (-40, 40)
-    pre_out_mat.device(place) =
-        pre_out_mat.abs().cwiseMax(static_cast<T>(40.0));
-    math::SumByBitCode<T>(num_classes, *label, *out, pre_out,
-                          static_cast<T>(-1));
+    Transform<DeviceContext> trans;
+    trans(ctx.template device_context<DeviceContext>(), pre_out_data,
+          pre_out_data + pre_out.numel(), pre_out_data,
+          ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
+    bit_code.Sum(num_classes, ids, pre_out, *out, static_cast<T>(-1));
     // softrelu with threshold is 40.0
-    pre_out_mat.device(place) =
-        pre_out_mat.abs().cwiseMax(static_cast<T>(40.0));
+    trans(ctx.template device_context<DeviceContext>(), pre_out_data,
+          pre_out_data + pre_out.numel(), pre_out_data,
+          ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
     pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
     row_sum(device_ctx, pre_out, &sum);
-    col_sum(device_ctx, *out, &sum);
+    out_mat.device(place) = sum_mat + out_mat;
   }
 };

-template <typename Place, typename T>
+template <typename DeviceContext, typename T>
 class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -85,37 +91,40 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
     auto* params =
         ctx.Output<framework::Tensor>(framework::GradVarName("Parameters"));
     auto* bias = ctx.Output<framework::Tensor>(framework::GradVarName("Bias"));
-    auto* label =
-        ctx.Output<framework::Tensor>(framework::GradVarName("Label"));
+    auto* label = ctx.Input<framework::Tensor>("Label");
     size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
+    int64_t code_length = math::FindLastSet(num_classes - 1);
+    int64_t batch_size = in->dims()[0];
     framework::Tensor pre_out;
-    auto place = ctx.GetEigenDevice<Place>();
-    auto& dev_ctx = ctx.device_context();
-    int64_t batch_size = in_grad.dims()[0];
-    int64_t code_length = math::FindLastSet(num_classes - 1);
+    pre_out.mutable_data<T>(framework::make_ddim({batch_size, code_length}),
+                            ctx.GetPlace());
+    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
+    auto& device_ctx = ctx.template device_context<DeviceContext>();
     auto pre_out_mat = EigenMatrix<T>::From(pre_out);
+    auto* ids = label->data<int64_t>();
     // init pre_out matrix with {1.0}
-    std::vector<int64_t> pre_out_dims({batch_size, code_length});
-    pre_out.mutable_data<T>(framework::make_ddim(pre_out_dims), ctx.GetPlace());
-    math::SetConstant<Place, T> set;
-    set(dev_ctx, &pre_out, static_cast<T>(1.0));
+    math::SetConstant<DeviceContext, T> one;
+    math::MatrixBitCodeFunctor<T> bit_code;
+    one(device_ctx, &pre_out, static_cast<T>(1.0));
     // softrelu derivative
     pre_out_mat.device(place) =
         pre_out_mat * (static_cast<T>(1.0) - static_cast<T>(1.0) / pre_out_mat);
-    math::SubByBitCode<T>(num_classes, *label, pre_out);
+    bit_code.Sub(num_classes, ids, pre_out);
     if (bias) {
-      math::AddByBitCodeGrad<T>(num_classes, *label, pre_out, *bias);
+      bit_code.AddGrad(num_classes, ids, pre_out, *bias);
     }
-    for (size_t i = 0; i < in_grad.dims()[0]; ++i) {
-      math::MulByBitCodeGradWeight<T>(num_classes, *label, pre_out, *params[i],
-                                      *in[i]->Slice(i, i + 1));
-      math::MulByBitCodeGradError<T>(num_classes, *label, pre_out, *params[i],
-                                     *ins_grad[i]->Slice(i, i + 1));
+    for (int i = 0; i < in_grad->dims()[0]; ++i) {
+      auto p_sliced = params->Slice(i, i + 1);
+      auto in_sliced = in->Slice(i, i + 1);
+      auto in_grad_sliced = in_grad->Slice(i, i + 1);
+      bit_code.MulGradWeight(num_classes, ids, pre_out, p_sliced, in_sliced);
+      bit_code.MulGradError(num_classes, ids, pre_out, p_sliced,
+                            in_grad_sliced);
     }
   }
 };
......
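A minimal NumPy sketch of the same forward sequence (per-sample parameter slice, bias add, dot products along the code path, clipping, softrelu, row sum) may help when reading the kernel above. The helper below is illustrative only: the bit layout of SimpleCodeTable (code = label + num_classes, node index (code >> (j + 1)) - 1, branch bit (code >> j) & 1) is an assumption, not something shown in this diff.

    import numpy as np

    def hsigmoid_forward_sketch(x, params, bias, label, num_classes):
        # x: [batch, dim], params: [batch, num_classes - 1, dim],
        # bias: [1, num_classes - 1], label: [batch] of int class ids.
        batch = x.shape[0]
        out = np.zeros((batch, 1), dtype=x.dtype)
        for i in range(batch):
            c = int(label[i]) + num_classes          # assumed SimpleCodeTable code
            length = c.bit_length() - 1              # FindLastSet(c) - 1
            for j in range(length):
                idx = (c >> (j + 1)) - 1             # inner node visited at step j
                bit = (c >> j) & 1                   # branch taken at step j
                s = params[i, idx].dot(x[i]) + bias[0, idx]   # pre_out(i, j)
                s = float(np.clip(s, -40.0, 40.0))            # clip, as in the kernel
                out[i, 0] += np.log1p(np.exp(s)) - bit * s    # softrelu(s) - bit * s
        return out

With the shapes used by the Python test at the end of this commit (batch 5, 6 classes, embedding size 10) this returns a [5, 1] array; whether it matches the kernel element-for-element depends on the real SimpleCodeTable layout.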
@@ -27,7 +27,7 @@ else()
     cc_library(context_project SRCS context_project.cc DEPS device_context math_function)
     cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
     cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
-    cc_library(matrix_bit_code SRCS matrix_bit_code.cc)
+    cc_library(matrix_bit_code SRCS matrix_bit_code.cc DEPS device_context)
     cc_library(maxouting SRCS maxouting.cc DEPS device_context)
     cc_library(unpooling SRCS unpooling.cc DEPS device_context)
     cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
......
@@ -302,12 +302,12 @@ void set_constant(const platform::DeviceContext& context,
 #endif
 }

-template struct RowwiseAdd<platform::CPUPlace, float>;
-template struct RowwiseAdd<platform::CPUPlace, double>;
-template struct ColwiseSum<platform::CPUPlace, float>;
-template struct ColwiseSum<platform::CPUPlace, double>;
-template struct RowwiseSum<platform::CPUPlace, float>;
-template struct RowwiseSum<platform::CPUPlace, double>;
+template struct RowwiseAdd<platform::CPUDeviceContext, float>;
+template struct RowwiseAdd<platform::CPUDeviceContext, double>;
+template struct ColwiseSum<platform::CPUDeviceContext, float>;
+template struct ColwiseSum<platform::CPUDeviceContext, double>;
+template struct RowwiseSum<platform::CPUDeviceContext, float>;
+template struct RowwiseSum<platform::CPUDeviceContext, double>;

 }  // namespace math
 }  // namespace operators
......
@@ -128,10 +128,10 @@ struct ColwiseSum {
                   framework::Tensor* vec);
 };

-template <typename Place, typename T>
+template <typename DeviceContext, typename T>
 struct RowwiseSum {
-  void operator()(const platform::DeviceContext& context,
-                  const framework::Tensor& input, framework::Tensor* vec);
+  void operator()(const DeviceContext& context, const framework::Tensor& input,
+                  framework::Tensor* vec);
 };

 }  // namespace math
......
@@ -79,19 +79,19 @@ void ColwiseSum<DeviceContext, T>::operator()(const DeviceContext& context,
       in.sum(Eigen::array<int, 1>({{0}})).reshape(shape);
 }

-template <typename Place, typename T>
-void RowwiseSum<Place, T>::operator()(const platform::DeviceContext& context,
-                                      const framework::Tensor& input,
-                                      framework::Tensor* vector) {
+template <typename DeviceContext, typename T>
+void RowwiseSum<DeviceContext, T>::operator()(const DeviceContext& context,
+                                              const framework::Tensor& input,
+                                              framework::Tensor* vector) {
   auto in_dims = input.dims();
   auto size = input.numel() / in_dims[1];
   PADDLE_ENFORCE_EQ(vector->numel(), size);

-  auto in = framework::EigenMatrix<T>::From(input);
-  auto vec = framework::EigenMatrix<T>::From(*vector);
+  auto in = framework::EigenMatrix<T, Eigen::ColMajor>::From(input);
+  auto vec = framework::EigenMatrix<T, Eigen::ColMajor>::From(*vector);
   Eigen::array<int, 2> shape({{static_cast<int>(size), 1}});
-  vec.reshape(shape).device(*context.GetEigenDevice<Place>()) =
-      in.sum(Eigen::array<int, 1>({{0}})).reshape(shape);
+  vec.reshape(shape).device(*context.eigen_device()) =
+      in.sum(Eigen::array<int, 1>({{1}})).reshape(shape);
 }

 }  // namespace math
 }  // namespace operators
......
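The RowwiseSum change above (ColMajor maps, reduction over dimension 1) is what the hsigmoid kernel relies on to collapse pre_out over the code length: one sum per input row. In NumPy terms the intended behaviour is simply:

    import numpy as np

    def rowwise_sum(x):
        # x: [n, m] -> [n, 1], one sum per row
        return x.sum(axis=1, keepdims=True)

    print(rowwise_sum(np.array([[1., 2., 3.], [4., 5., 6.]])))  # [[ 6.] [15.]]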
@@ -50,50 +50,52 @@ namespace math {
      for j < codeLength:
        op(a(i, j), b(0, index(i, j)))
 */
-template <class CodeTable, class Op, typename T>
-static void AddByBitCodeT(Op op, CodeTable code_table,
-                          const framework::Tensor& codes,
-                          framework::Tensor& tmat,
+template <typename T, class CodeTable, class Op>
+static void AddByBitCodeT(Op op, CodeTable code_table, const int64_t* codes,
+                          const framework::Tensor& tmat,
                           const framework::Tensor& vec) {
-  size_t num_classes = code_table.size();
-  size_t max_code_length = code_table.get_max_code_length();
   size_t num_sample = tmat.dims()[0];
   size_t width = vec.dims()[1];
   for (size_t i = 0; i < num_sample; ++i) {
-    auto code = code_table(codes.data<T>()[i]);
+    auto code = code_table(static_cast<size_t>(codes[i]));
     int code_length = code.get_length();
-    for (int j = 0; j < code_length; + j) {
+    for (int j = 0; j < code_length; ++j) {
       size_t index = code.calc_index(j);
-      op(tmat.data<T>()[i * width + j], vec.data<T>()[index]);
+      auto t = tmat.data<T>()[i * width + j];
+      auto v = vec.data<T>()[index];
+      op(t, v);
     }
   }
 }

-template <typename T>
-void AddByBitCode(size_t num_classes, const framework::Tensor& codes,
-                  framework::Tensor& tmat, const framework::Tensor& vec) {
-  auto op = [](T& t, T& v) { t += v; };
-  AddByBitCodeT<T>(op, SimpleCodeTable(num_classes), codes, tmat, vec);
-}
-
-template <typename T>
-void AddByBitCodeGrad(size_t num_classes, const framework::Tensor& codes,
-                      const framework::Tensor& tmat, framework::Tensor& vec) {
-  auto op = [](T& t, T& v) { v += t; };
-  AddByBitCode<T>(op, SimpleCodeTable(num_classes), codes, tmat, vec);
+template <typename T, class CodeTable>
+void SubByBitCodeT(CodeTable code_table, const int64_t* codes,
+                   framework::Tensor& tmat) {
+  // size_t max_code_length = code_table.get_max_code_length();
+  size_t num_samples = tmat.dims()[0];
+  size_t o_width = tmat.dims()[1];
+  for (size_t i = 0; i < num_samples; ++i) {
+    auto code = code_table(static_cast<size_t>(codes[i]));
+    int code_length = code.get_length();
+    for (int j = 0; j < code_length; ++j) {
+      if (code.calc_bit(j)) {
+        tmat.data<T>()[i * o_width + j] -= 1;
+      }
+    }
+  }
 }

-template <class CodeTable, typename T>
-void SumByBitCodeT(CodeTable code_table, const framework::Tensor& codes,
-                   framework::Tensor& tmat, const framework::Tensor& sum,
+template <typename T, class CodeTable>
+void SumByBitCodeT(CodeTable code_table, const int64_t* codes,
+                   framework::Tensor& tmat, framework::Tensor& sum,
                    const T& scale_sum) {
-  size_t max_code_length = code_table.get_max_code_length();
+  // size_t max_code_length = code_table.get_max_code_length();
   size_t num_samples = tmat.dims()[0];
   size_t o_width = tmat.dims()[1];
   for (size_t i = 0; i < num_samples; ++i) {
-    T sm = 0;
-    auto code = code_table(codes.data<T>()[i]);
+    T sm = static_cast<T>(0.0);
+    auto code = code_table(static_cast<size_t>(codes[i]));
     int code_length = code.get_length();
     for (int j = 0; j < code_length; ++j) {
       if (code.calc_bit(j)) {
@@ -103,105 +105,124 @@ void SumByBitCodeT(CodeTable code_table, const framework::Tensor& codes,
     sum.data<T>()[i] = scale_sum * sm;
   }
 }
-/* For j < codeLength:
-     sum(i, 0) = \sum_j bit(i, j) * input(i, j)
-*/
 template <typename T>
-void SumByBitCode(size_t num_classes, const framework::Tensor& codes,
-                  framework::Tensor& tmat, framework::Tensor& sum,
-                  T scale_sum) {
-  SumByBitCodeT(SimpleCodeTable(num_classes), codes, tmat, scale_sum);
+void MatrixBitCodeFunctor<T>::Add(size_t num_classes, const int64_t* codes,
+                                  framework::Tensor& tmat,
+                                  const framework::Tensor& vec) {
+  auto op = [](T& t, const T& v) { t += v; };
+  AddByBitCodeT<T>(op, SimpleCodeTable(num_classes), codes, tmat, vec);
 }

-template <class Op, class CodeTable, typename T>
-void MulByBitCodeT(Op op, CodeTable code_table, const framework::Tensor& codes,
-                   framework::Tensor& tmat, framework::Tensor& weight,
-                   framework::Tensor& input) {
-  size_t num_classes = code_table.size();
-  size_t max_code_length = code_table.get_max_code_length();
-  size_t num_samples = tmat.dims()[0];
-  size_t input_dim = input.dims()[1];
-  size_t o_width = tmat.dims()[1];
-  for (size_t i = 0; i < num_samples; ++i) {
-    auto code = code_table(codes.data<T>()[i]);
-    int code_length = code.get_length();
-    for (int j = 0; j < code_length; ++j) {
-      size_t index = code.calc_index(j);
-      op(tmat.data<T>()[i * o_width + j],
-         weight.data<T>() + index * weight.dims()[1],
-         input.data<T>() + i * input.dims()[1], input_dim);
-    }
-  }
-}
-
-template <typename T>
-void MulByBitCode(size_t num_classes, const framework::Tensor& codes,
-                  framework::Tensor& tmat, const framework::Tensor& weight,
-                  const framework::Tensor& input) {
-  auto op = [](T& t, const T* weight_row, const T* input_row,
-               size_t input_dim) {
-    T sum = 0;
-    for (size_t k = 0; k < input_dim; ++k) {
-      sum += weight_row[k] * input_row[k];
-    }
-    t += sum;
-  };
-  MulByBitCodeT<T>(op, SimpleCodeTable(num_classes), codes, tmat, weight,
-                   input);
-}
-
-template <typename T>
-void MulByBitCodeGradWeight(size_t num_classes, const framework::Tensor& codes,
-                            const framework::Tensor& tmat,
-                            framework::Tensor& weight,
-                            const framework::Tensor& input) {
-  auto op = [](const T t, T* weight_row, const T* input_row, size_t input_dim) {
-    for (size_t k = 0; k < input_dim; ++k) {
-      weight_row[k] += t * input_row[k];
-    }
-  };
-  MulByBitCodeT<T>(op, SimpleCodeTable(num_classes), codes, tmat, weight,
-                   input);
-}
-
-template <typename T>
-void MulByBitCodeGradError(size_t num_classes, const framework::Tensor& codes,
-                           const framework::Tensor& tmat,
-                           const framework::Tensor& weight,
-                           framework::Tensor& input) {
-  auto op = [](const T t, const T* weight_row, T* input_row, size_t input_dim) {
-    for (size_t k = 0; k < input_dim; ++k) {
-      input_row[k] += t * weight_row[k];
-    }
-  };
-  MulByBitCodeT<T>(op, SimpleCodeTable(num_classes), codes, tmat, weight,
-                   input);
-}
-
-template <class CodeTable, typename T>
-void SubByBitCodeT(CodeTable code_table, const framework::Tensor& codes,
-                   framework::Tensor& tmat) {
-  size_t max_code_length = code_table.get_max_code_length();
-  size_t num_samples = tmat.dims()[0];
-  size_t o_width = tmat.dims()[1];
-  for (size_t i = 0; i < num_samples; ++i) {
-    auto code = code_table(codes.data<T>()[i]);
-    int code_length = code.get_length();
-    for (int j = 0; j < code_length; ++j) {
-      if (code.calc_bit(j)) {
-        tmat.data<T>()[i * o_width + j] -= 1;
-      }
-    }
-  }
-}
-
+template <typename T>
+void MatrixBitCodeFunctor<T>::AddGrad(size_t num_classes, const int64_t* codes,
+                                      framework::Tensor& tmat,
+                                      framework::Tensor& vec) {
+  auto op = [](T& t, T& v) { v += t; };
+  AddByBitCodeT<T>(op, SimpleCodeTable(num_classes), codes, tmat, vec);
+}
+
+template <typename T>
+void MatrixBitCodeFunctor<T>::Sum(size_t num_classes, const int64_t* codes,
+                                  framework::Tensor& tmat,
+                                  framework::Tensor& sum, T scale_sum) {
+  SumByBitCodeT<T>(SimpleCodeTable(num_classes), codes, tmat, sum, scale_sum);
+}
+
+template <typename T>
+void MatrixBitCodeFunctor<T>::Mul(size_t num_classes, const int64_t* codes,
+                                  framework::Tensor& tmat,
+                                  const framework::Tensor& weight,
+                                  const framework::Tensor& input) {
+  size_t num_samples = tmat.dims()[0];
+  size_t tmat_width = tmat.dims()[1];
+  size_t input_width = input.dims()[1];
+  size_t weight_width = weight.dims()[1];
+  auto tmat_p = tmat.data<T>();
+  auto weight_p = weight.data<T>();
+  auto input_p = input.data<T>();
+  auto code_table = SimpleCodeTable(num_classes);
+  for (size_t i = 0; i < num_samples; ++i) {
+    auto code = code_table(static_cast<size_t>(codes[i]));
+    int code_length = code.get_length();
+    for (int j = 0; j < code_length; ++j) {
+      size_t index = code.calc_index(j);
+      T sum = static_cast<T>(0.0);
+      for (size_t k = 0; k < input_width; ++k) {
+        sum +=
+            weight_p[weight_width * index + k] * input_p[input_width * i + k];
+      }
+      std::cout << sum << std::endl;
+      tmat_p[i * tmat_width + j] += sum;
+    }
+  }
+}
+
+template <typename T>
+void MatrixBitCodeFunctor<T>::MulGradWeight(size_t num_classes,
+                                            const int64_t* codes,
+                                            const framework::Tensor& tmat,
+                                            framework::Tensor& weight,
+                                            const framework::Tensor& input) {
+  size_t num_samples = tmat.dims()[0];
+  size_t input_width = input.dims()[1];
+  size_t weight_width = weight.dims()[1];
+  auto tmat_p = tmat.data<T>();
+  auto weight_p = weight.data<T>();
+  auto input_p = input.data<T>();
+  auto code_table = SimpleCodeTable(num_classes);
+  for (size_t i = 0; i < num_samples; ++i) {
+    auto code = code_table(static_cast<size_t>(codes[i]));
+    int code_length = code.get_length();
+    for (int j = 0; j < code_length; ++j) {
+      size_t index = code.calc_index(j);
+      for (size_t k = 0; k < input_width; ++k) {
+        weight_p[weight_width * index * k] +=
+            tmat_p[i * weight_width * j] * input_p[input_width * i + k];
+      }
+    }
+  }
+}
+
+template <typename T>
+void MatrixBitCodeFunctor<T>::MulGradError(size_t num_classes,
+                                           const int64_t* codes,
+                                           const framework::Tensor& tmat,
+                                           const framework::Tensor& weight,
+                                           framework::Tensor& input) {
+  size_t num_samples = tmat.dims()[0];
+  size_t input_width = input.dims()[1];
+  size_t weight_width = weight.dims()[1];
+  auto tmat_p = tmat.data<T>();
+  auto weight_p = weight.data<T>();
+  auto input_p = input.data<T>();
+  auto code_table = SimpleCodeTable(num_classes);
+  for (size_t i = 0; i < num_samples; ++i) {
+    auto code = code_table(static_cast<size_t>(codes[i]));
+    int code_length = code.get_length();
+    for (int j = 0; j < code_length; ++j) {
+      size_t index = code.calc_index(j);
+      for (size_t k = 0; k < input_width; ++k) {
+        input_p[weight_width * index * k] +=
+            tmat_p[i * weight_width * j] * weight_p[weight_width * i + k];
+      }
+    }
+  }
+}
+
 template <typename T>
-void SubByBitCode(size_t num_classes, const framework::Tensor& codes,
-                  framework::Tensor& tmat) {
+void MatrixBitCodeFunctor<T>::Sub(size_t num_classes, const int64_t* codes,
+                                  framework::Tensor& tmat) {
   SubByBitCodeT<T>(SimpleCodeTable(num_classes), codes, tmat);
 }
+
+template class MatrixBitCodeFunctor<float>;
+template class MatrixBitCodeFunctor<double>;
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
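All of the functors above go through SimpleCodeTable, whose definition is not part of this diff. A small Python sketch of the layout these call sites appear to assume (code = class id + num_classes, node index (code >> (j + 1)) - 1, branch bit (code >> j) & 1, length FindLastSet(code) - 1); treat the exact bit layout as an assumption:

    class SimpleCode:
        # Hypothetical mirror of SimpleCodeTable's per-class code; the real
        # struct lives in matrix_bit_code.h and is not shown in this commit.
        def __init__(self, class_id, num_classes):
            self.c = class_id + num_classes

        def length(self):
            return self.c.bit_length() - 1      # FindLastSet(c) - 1

        def calc_index(self, j):
            return (self.c >> (j + 1)) - 1      # inner node visited at step j

        def calc_bit(self, j):
            return (self.c >> j) & 1            # branch taken at step j

    # Example: with num_classes = 6, class 3 gets code 9 = 0b1001,
    # path length 3, node indices [3, 1, 0], bits [1, 0, 0].
    code = SimpleCode(3, 6)
    print(code.length(),
          [code.calc_index(j) for j in range(code.length())],
          [code.calc_bit(j) for j in range(code.length())])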
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/tensor.h"
+#include "paddle/platform/device_context.h"

 namespace paddle {
 namespace operators {
@@ -59,57 +60,50 @@ struct SimpleCodeTable {
   int max_code_length_;
 };

-/* For j < code_length
-     tmat(i, j) += vec(0, index(i, j))
-*/
 template <typename T>
-void AddByBitCode(size_t num_classes, const framework::Tensor& codes,
-                  framework::Tensor& tmat, const framework::Tensor& vec);
-
-/* For j < code_length
-     vec(0, index(i, j)) += tmat(i, j)
-*/
-template <typename T>
-void AddByBitCodeGrad(size_t num_classes, const framework::Tensor& codes,
-                      const framework::Tensor& tmat, framework::Tensor& vec);
-
-/* For j < code_length
-     sum(i, 0) = \sum_j bit(i, j) * tmat(i, j)
-*/
-template <typename T>
-void SumByBitCode(size_t num_classes, const framework::Tensor& codes,
-                  framework::Tensor& tmat, framework::Tensor& sum, T scale_sum);
-
-/* For j < code_length
-     input.row(i) += tmat(i, j) * weight.row(index(i, j))
-*/
-template <typename T>
-void MulByBitCode(size_t num_classes, const framework::Tensor& codes,
-                  framework::Tensor& tmat, const framework::Tensor& weight,
-                  const framework::Tensor& input);
-
-/* For index(i, j) >= 0:
-     weight.row(index(i, j)) += tmat(i, j) * input.row(i)
-*/
-template <typename T>
-void MulByBitCodeGradWeight(size_t num_classes, const framework::Tensor& codes,
-                            const framework::Tensor& tmat,
-                            framework::Tensor& weight,
-                            const framework::Tensor& input);
-
-/* For j < code_length
-     input.row(i) += tmat(i, j) * weight.row(index(i, j))
-*/
-template <typename T>
-void MulByBitCodeGradError(size_t num_classes, const framework::Tensor& codes,
-                           const framework::Tensor& tmat,
-                           const framework::Tensor& weight,
-                           framework::Tensor& input);
-
-/* For j < code_length
-     tmat(i, j) -= bit(i, j)
-*/
-template <typename T>
-void SubByBitCode(size_t num_classes, const framework::Tensor& codes,
-                  framework::Tensor& tmat);
+class MatrixBitCodeFunctor {
+ public:
+  /* For j < code_length
+       tmat(i, j) += vec(0, index(i, j))
+  */
+  void Add(size_t num_classes, const int64_t* codes, framework::Tensor& tmat,
+           const framework::Tensor& vec);
+
+  /* For j < code_length
+       vec(0, index(i, j)) += tmat(i, j)
+  */
+  void AddGrad(size_t num_classes, const int64_t* codes,
+               framework::Tensor& tmat, framework::Tensor& vec);
+
+  /* For j < code_length
+       sum(i, 0) = \sum_j bit(i, j) * tmat(i, j)
+  */
+  void Sum(size_t num_classes, const int64_t* codes, framework::Tensor& tmat,
+           framework::Tensor& sum, T scale_sum);
+
+  /* For j < code_length
+       tmat(i, j) -= bit(i, j)
+  */
+  void Sub(size_t num_classes, const int64_t* codes, framework::Tensor& tmat);
+
+  /* For j < code_length
+       input.row(i) += tmat(i, j) * weight.row(index(i, j))
+  */
+  void Mul(size_t num_classes, const int64_t* codes, framework::Tensor& tmat,
+           const framework::Tensor& weight, const framework::Tensor& input);
+
+  /* For index(i, j) >= 0:
+       weight.row(index(i, j)) += tmat(i, j) * input.row(i)
+  */
+  void MulGradWeight(size_t num_classes, const int64_t* codes,
+                     const framework::Tensor& tmat, framework::Tensor& weight,
+                     const framework::Tensor& input);
+
+  /* For j < code_length
+       input.row(i) += tmat(i, j) * weight.row(index(i, j))
+  */
+  void MulGradError(size_t num_classes, const int64_t* codes,
+                    const framework::Tensor& tmat,
+                    const framework::Tensor& weight, framework::Tensor& input);
+};

 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
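The Mul / MulGradWeight / MulGradError triple declared above follows the usual forward/backward pairing for a row-gathered matrix product. A NumPy sketch of that relationship, using a hypothetical index(i, j) lookup in place of the bit-code table:

    import numpy as np

    def mul(tmat, weight, x, index):
        # Forward (Mul): tmat[i, j] += <weight[index(i, j)], x[i]>
        n, length = tmat.shape
        for i in range(n):
            for j in range(length):
                tmat[i, j] += weight[index(i, j)].dot(x[i])

    def mul_grad(tmat, weight, x, index, d_weight, d_x):
        # MulGradWeight: weight[index(i, j)] += tmat[i, j] * x[i]
        # MulGradError:  x[i]                += tmat[i, j] * weight[index(i, j)]
        n, length = tmat.shape
        for i in range(n):
            for j in range(length):
                d_weight[index(i, j)] += tmat[i, j] * x[i]
                d_x[i] += tmat[i, j] * weight[index(i, j)]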
@@ -126,6 +126,8 @@ PYBIND11_PLUGIN(core) {
       .def("shape", [](Tensor &self) { return vectorize(self.dims()); })
       .def("set_float_element", TensorSetElement<float>)
       .def("get_float_element", TensorGetElement<float>)
+      .def("set_int64_element", TensorSetElement<int64_t>)
+      .def("get_int64_element", TensorGetElement<int64_t>)
       .def("set_double_element", TensorSetElement<double>)
       .def("get_double_element", TensorGetElement<double>)
       .def("dtype", [](Tensor &self) { return ToDataType(self.type()); });
......
@@ -49,7 +49,6 @@ def create_op(scope, op_type, inputs, outputs, attrs):
     for attr_name in Operator.get_op_attr_names(op_type):
         if attr_name in attrs:
             kwargs[attr_name] = attrs[attr_name]
     return Operator(op_type, **kwargs)
@@ -107,6 +106,8 @@ def get_numeric_gradient(scope,
         tensor_to_check_dtype = np.float32
     elif tensor_to_check_dtype == core.DataType.FP64:
         tensor_to_check_dtype = np.float64
+    elif tensor_to_check_dtype == core.DataType.INT64:
+        tensor_to_check_dtype = np.int64
     else:
         raise ValueError("Not supported data type " + str(
             tensor_to_check_dtype))
@@ -116,12 +117,16 @@ def get_numeric_gradient(scope,
     def __get_elem__(tensor, i):
         if tensor_to_check_dtype == np.float32:
             return tensor.get_float_element(i)
+        elif tensor_to_check_dtype == np.int64:
+            return tensor.get_int64_element(i)
         else:
             return tensor.get_double_element(i)

     def __set_elem__(tensor, i, e):
         if tensor_to_check_dtype == np.float32:
             tensor.set_float_element(i, e)
+        elif tensor_to_check_dtype == np.int64:
+            tensor.set_int64_element(i, e)
         else:
             tensor.set_double_element(i, e)
@@ -355,13 +360,11 @@ class OpTest(unittest.TestCase):
         op_attrs = self.attrs if hasattr(self, "attrs") else dict()
         self.op = create_op(self.scope, self.op_type, op_inputs, op_outputs,
                             op_attrs)
         if no_grad_set is None:
             no_grad_set = set()
         if not type(output_names) is list:
             output_names = [output_names]
         numeric_grads = user_defined_grads or [
             get_numeric_gradient(
                 self.scope,
@@ -457,9 +460,7 @@ class OpTest(unittest.TestCase):
         # infer variable type and infer shape in compile-time
         op.desc.infer_var_type(block.desc)
         op.desc.infer_shape(block.desc)
         mean_inputs = map(block.var, output_names)
         if len(mean_inputs) == 1:
             loss = block.create_var(dtype=mean_inputs[0].dtype, shape=[1])
             op = block.append_op(
......
@@ -5,15 +5,15 @@ from op_test import OpTest

 class TestHSigmoidOp(OpTest):
     def setUp(self):
-        self.op_type = "hierarchical_sigmoid_op"
+        self.op_type = "hierarchical_sigmoid"
         num_classes = 6
         embded_size = 10
         batch_size = 5
         x = np.random.random((batch_size, embded_size)).astype("float32")
         parameter = np.random.random(
             (batch_size, num_classes - 1, embded_size)).astype("float32")
-        label = np.random.randint(0, num_classes, batch_size).astype("int64")
-        bias = np.random.random((1, num_classes - 1))
+        label = np.random.randint(0, num_classes, batch_size)
+        bias = np.random.random((1, num_classes - 1)).astype("float32")
         self.inputs = {
             'X': x,
             'Parameters': parameter,
@@ -21,13 +21,18 @@ class TestHSigmoidOp(OpTest):
             'Bias': bias
         }
         self.attrs = {'num_classes': num_classes}
-        self.outputs = {'Out': label}
+        self.outputs = {
+            'Out': np.random.random((batch_size, 1)).astype("float32")
+        }

     def test_check_output(self):
         self.check_output()

     def test_check_grad(self):
-        self.check_grad(['x0'], 'Out')
+        self.check_grad(
+            ['X', 'Parameters', 'Label', 'Bias'],
+            'Out',
+            no_grad_set=set(['Label']))

 if __name__ == '__main__':
......