From fb9c08f0438fe0a25d0d2517e5e770ea5a22555d Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Mon, 25 Dec 2017 08:51:04 +0800 Subject: [PATCH] make forward work --- paddle/operators/CMakeLists.txt | 2 +- paddle/operators/hierarchical_sigmoid_op.cc | 50 ++++- paddle/operators/hierarchical_sigmoid_op.h | 95 ++++---- paddle/operators/math/CMakeLists.txt | 2 +- paddle/operators/math/math_function.cc | 12 +- paddle/operators/math/math_function.h | 6 +- paddle/operators/math/math_function_impl.h | 16 +- paddle/operators/math/matrix_bit_code.cc | 211 ++++++++++-------- paddle/operators/math/matrix_bit_code.h | 86 ++++--- paddle/pybind/pybind.cc | 2 + python/paddle/v2/fluid/tests/op_test.py | 11 +- .../paddle/v2/fluid/tests/test_hsigmoid_op.py | 15 +- 12 files changed, 285 insertions(+), 223 deletions(-) diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index d79f19e670d..5fb14cc6d4a 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -207,7 +207,7 @@ set(DEPS_OPS gru_op adagrad_op sgd_op - hierarchical_sigmoid_op) + hierarchical_sigmoid_op save_op load_op send_op diff --git a/paddle/operators/hierarchical_sigmoid_op.cc b/paddle/operators/hierarchical_sigmoid_op.cc index 063f8576e66..fa816d9215d 100644 --- a/paddle/operators/hierarchical_sigmoid_op.cc +++ b/paddle/operators/hierarchical_sigmoid_op.cc @@ -60,19 +60,48 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->hasInput("X"), "Input(X) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Parameters"), + "Input(Parameters)" + "should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null."); const int64_t batch_size = ctx->GetInputDim("X")[0]; - std::vector output_shape({batch_size, num_classes_ - 1}); + std::vector output_shape({batch_size, 1}); ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); } + + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.GetPlace()); + } }; class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override {} + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Parameters"), + "Input(Parameters)" + "should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), + "Input(Label)" + "should not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Parameters")), + "Input(Parameters@Grad should not be null.)"); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X"))); + } + + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.GetPlace()); + } }; class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { @@ -98,7 +127,8 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "(Tensor, required) The output of hierarchical sigmoid operator." "the shape is [N, 1]"); - AddAttr("num_classes", "(int, required)", "The number of classes"); + AddAttr("num_classes", "(int, required)", "The number of classes") + .SetDefault(2); AddComment(R"DOC( The hierarchical sigmoid operator organize the classes into a binary tree. At each node, a sigmoid function is used to caculate the probability of @@ -116,9 +146,9 @@ namespace ops = paddle::operators; REGISTER_OP(hierarchical_sigmoid, ops::HierarchicalSigmoidOp, ops::HierarchicalSigmoidOpMaker, hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp); -REGISTER_OP_CPU_KERNEL( - hierarchical_sigmoid, - ops::HierarchicalSigmoidOpKernel); -REGISTER_OP_CPU_KERNEL( - hierarchical_sigmoid_grad, - ops::HierarchicalSigmoidGradOpKernel); +REGISTER_OP_CPU_KERNEL(hierarchical_sigmoid, + ops::HierarchicalSigmoidOpKernel< + paddle::platform::CPUDeviceContext, float>); +REGISTER_OP_CPU_KERNEL(hierarchical_sigmoid_grad, + ops::HierarchicalSigmoidGradOpKernel< + paddle::platform::CPUDeviceContext, float>); diff --git a/paddle/operators/hierarchical_sigmoid_op.h b/paddle/operators/hierarchical_sigmoid_op.h index e3f0bcacd8b..531fd9f7fc0 100644 --- a/paddle/operators/hierarchical_sigmoid_op.h +++ b/paddle/operators/hierarchical_sigmoid_op.h @@ -14,8 +14,10 @@ limitations under the License. */ #pragma once #include "paddle/framework/op_registry.h" +#include "paddle/operators/clip_op.h" #include "paddle/operators/math/math_function.h" #include "paddle/operators/math/matrix_bit_code.h" +#include "paddle/platform/transform.h" namespace paddle { namespace operators { @@ -23,60 +25,64 @@ namespace operators { template using EigenMatrix = framework::EigenMatrix; +using platform::Transform; -template +template class HierarchicalSigmoidOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* in = ctx.Input("X"); - auto* param = ctx.Input("Parameter"); + auto* params = ctx.Input("Parameters"); auto* label = ctx.Input("Label"); auto* bias = ctx.Input("Bias"); auto* out = ctx.Output("Out"); size_t num_classes = static_cast(ctx.Attr("num_classes")); - framework::Tensor sum; + int64_t code_length = math::FindLastSet(num_classes - 1); + int64_t batch_size = in->dims()[0]; + auto* ids = label->data(); framework::Tensor pre_out; - auto place = ctx.GetEigenDevice(); - auto& device_ctx = ctx.device_context(); - math::ColwiseSum col_sum; - math::RowwiseSum row_sum; - + framework::Tensor sum; + auto pre_out_data = pre_out.mutable_data( + framework::make_ddim({batch_size, code_length}), ctx.GetPlace()); auto pre_out_mat = EigenMatrix::From(pre_out); - int64_t batch_size = ins[0]->dims()[0]; - int64_t code_length = math::FindLastSet(num_classes - 1); - std::vector pre_out_dims({batch_size, code_length}); - pre_out.mutable_data(framework::make_ddim(pre_out_dims), ctx.GetPlace()); + auto& place = *ctx.template device_context().eigen_device(); + auto& device_ctx = ctx.template device_context(); + math::RowwiseSum row_sum; + math::MatrixBitCodeFunctor bit_code; + std::vector sum_dims({batch_size, 1UL}); sum.mutable_data(framework::make_ddim(sum_dims), ctx.GetPlace()); + auto sum_mat = EigenMatrix::From(sum); out->mutable_data(ctx.GetPlace()); + auto out_mat = framework::EigenVector::Flatten(*out); if (bias) { - math::AddByBitCode(num_classes, *label, pre_out, *bias); + bit_code.Add(num_classes, ids, pre_out, *bias); } - - for (size_t i = 0; i < in.dims()[0]; ++i) { - math::MulByBitCode(num_classes, *label, pre_out, - *params->Slice(i, i + 1), *in->Slice(i, i + 1)); + for (int i = 0; i < in->dims()[0]; ++i) { + bit_code.Mul(num_classes, ids, pre_out, params->Slice(i, i + 1), + in->Slice(i, i + 1)); } // clip the matrix with (-40, 40) - pre_out_mat.device(place) = - pre_out_mat.abs().cwiseMax(static_cast(40.0)); - math::SumByBitCode(num_classes, *label, *out, pre_out, - static_cast(-1)); - + Transform trans; + trans(ctx.template device_context(), pre_out_data, + pre_out_data + pre_out.numel(), pre_out_data, + ClipFunctor(static_cast(-40.0), static_cast(40.0))); + bit_code.Sum(num_classes, ids, pre_out, *out, static_cast(-1)); // softrelu with threshold is 40.0 - pre_out_mat.device(place) = - pre_out_mat.abs().cwiseMax(static_cast(40.0)); + trans(ctx.template device_context(), pre_out_data, + pre_out_data + pre_out.numel(), pre_out_data, + ClipFunctor(static_cast(-40.0), static_cast(40.0))); pre_out_mat.device(place) = (static_cast(1.0) + pre_out_mat.exp()).log(); row_sum(device_ctx, pre_out, &sum); - col_sum(device_ctx, *out, &sum); + out_mat.device(place) = sum_mat + out_mat; } }; -template +template class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -85,37 +91,40 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { auto* params = ctx.Output(framework::GradVarName("Parameters")); auto* bias = ctx.Output(framework::GradVarName("Bias")); - auto* label = - ctx.Output(framework::GradVarName("Label")); + auto* label = ctx.Input("Label"); size_t num_classes = static_cast(ctx.Attr("num_classes")); + int64_t code_length = math::FindLastSet(num_classes - 1); + int64_t batch_size = in->dims()[0]; framework::Tensor pre_out; - auto place = ctx.GetEigenDevice(); - auto& dev_ctx = ctx.device_context(); - int64_t batch_size = in_grad.dims()[0]; - int64_t code_length = math::FindLastSet(num_classes - 1); + pre_out.mutable_data(framework::make_ddim({batch_size, code_length}), + ctx.GetPlace()); + auto& place = *ctx.template device_context().eigen_device(); + auto& device_ctx = ctx.template device_context(); auto pre_out_mat = EigenMatrix::From(pre_out); + auto* ids = label->data(); // init pre_out matrix with {1.0} - std::vector pre_out_dims({batch_size, code_length}); - pre_out.mutable_data(framework::make_ddim(pre_out_dims), ctx.GetPlace()); - math::SetConstant set; - set(dev_ctx, &pre_out, static_cast(1.0)); + math::SetConstant one; + math::MatrixBitCodeFunctor bit_code; + one(device_ctx, &pre_out, static_cast(1.0)); // softrelu derivative pre_out_mat.device(place) = pre_out_mat * (static_cast(1.0) - static_cast(1.0) / pre_out_mat); - math::SubByBitCode(num_classes, *label, pre_out); + bit_code.Sub(num_classes, ids, pre_out); if (bias) { - math::AddByBitCodeGrad(num_classes, *label, pre_out, *bias); + bit_code.AddGrad(num_classes, ids, pre_out, *bias); } - for (size_t i = 0; i < in_grad.dims()[0]; ++i) { - math::MulByBitCodeGradWeight(num_classes, *label, pre_out, *params[i], - *in[i]->Slice(i, i + 1)); - math::MulByBitCodeGradError(num_classes, *label, pre_out, *params[i], - *ins_grad[i]->Slice(i, i + 1)); + for (int i = 0; i < in_grad->dims()[0]; ++i) { + auto p_sliced = params->Slice(i, i + 1); + auto in_sliced = in->Slice(i, i + 1); + auto in_grad_sliced = in_grad->Slice(i, i + 1); + bit_code.MulGradWeight(num_classes, ids, pre_out, p_sliced, in_sliced); + bit_code.MulGradError(num_classes, ids, pre_out, p_sliced, + in_grad_sliced); } } }; diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 6467d8ddb33..82ba24f35b9 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -27,7 +27,7 @@ else() cc_library(context_project SRCS context_project.cc DEPS device_context math_function) cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context) cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions) - cc_library(matrix_bit_code SRCS matrix_bit_code.cc) + cc_library(matrix_bit_code SRCS matrix_bit_code.cc DEPS device_context) cc_library(maxouting SRCS maxouting.cc DEPS device_context) cc_library(unpooling SRCS unpooling.cc DEPS device_context) cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function) diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index ead0fe19713..474fd0b0a94 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -302,12 +302,12 @@ void set_constant(const platform::DeviceContext& context, #endif } -template struct RowwiseAdd; -template struct RowwiseAdd; -template struct ColwiseSum; -template struct ColwiseSum; -template struct RowwiseSum; -template struct RowwiseSum; +template struct RowwiseAdd; +template struct RowwiseAdd; +template struct ColwiseSum; +template struct ColwiseSum; +template struct RowwiseSum; +template struct RowwiseSum; } // namespace math } // namespace operators diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index 51e0fd9ad77..b49294e6216 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -128,10 +128,10 @@ struct ColwiseSum { framework::Tensor* vec); }; -template +template struct RowwiseSum { - void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor* vec); + void operator()(const DeviceContext& context, const framework::Tensor& input, + framework::Tensor* vec); }; } // namespace math diff --git a/paddle/operators/math/math_function_impl.h b/paddle/operators/math/math_function_impl.h index 4d5e8481018..2b3b6c335b3 100644 --- a/paddle/operators/math/math_function_impl.h +++ b/paddle/operators/math/math_function_impl.h @@ -79,19 +79,19 @@ void ColwiseSum::operator()(const DeviceContext& context, in.sum(Eigen::array({{0}})).reshape(shape); } -template -void RowwiseSum::operator()(const platform::DeviceContext& context, - const framework::Tensor& input, - framework::Tensor* vector) { +template +void RowwiseSum::operator()(const DeviceContext& context, + const framework::Tensor& input, + framework::Tensor* vector) { auto in_dims = input.dims(); auto size = input.numel() / in_dims[1]; PADDLE_ENFORCE_EQ(vector->numel(), size); - auto in = framework::EigenMatrix::From(input); - auto vec = framework::EigenMatrix::From(*vector); + auto in = framework::EigenMatrix::From(input); + auto vec = framework::EigenMatrix::From(*vector); Eigen::array shape({{static_cast(size), 1}}); - vec.reshape(shape).device(*context.GetEigenDevice()) = - in.sum(Eigen::array({{0}})).reshape(shape); + vec.reshape(shape).device(*context.eigen_device()) = + in.sum(Eigen::array({{1}})).reshape(shape); } } // namespace math } // namespace operators diff --git a/paddle/operators/math/matrix_bit_code.cc b/paddle/operators/math/matrix_bit_code.cc index df988510547..9e3836b06dc 100644 --- a/paddle/operators/math/matrix_bit_code.cc +++ b/paddle/operators/math/matrix_bit_code.cc @@ -50,50 +50,52 @@ namespace math { for j < codeLength: op(a(i, j), b(0, index(i, j))) */ -template -static void AddByBitCodeT(Op op, CodeTable code_table, - const framework::Tensor& codes, - framework::Tensor& tmat, +template +static void AddByBitCodeT(Op op, CodeTable code_table, const int64_t* codes, + const framework::Tensor& tmat, const framework::Tensor& vec) { - size_t num_classes = code_table.size(); - size_t max_code_length = code_table.get_max_code_length(); size_t num_sample = tmat.dims()[0]; size_t width = vec.dims()[1]; for (size_t i = 0; i < num_sample; ++i) { - auto code = code_table(codes.data()[i]); + auto code = code_table(static_cast(codes[i])); int code_length = code.get_length(); - for (int j = 0; j < code_length; + j) { + for (int j = 0; j < code_length; ++j) { size_t index = code.calc_index(j); - op(tmat.data()[i * width + j], vec.data()[index]); + auto t = tmat.data()[i * width + j]; + auto v = vec.data()[index]; + op(t, v); } } } -template -void AddByBitCode(size_t num_classes, const framework::Tensor& codes, - framework::Tensor& tmat, const framework::Tensor& vec) { - auto op = [](T& t, T& v) { t += v; }; - AddByBitCodeT(op, SimpleCodeTable(num_classes), codes, tmat, vec); -} - -template -void AddByBitCodeGrad(size_t num_classes, const framework::Tensor& codes, - const framework::Tensor& tmat, framework::Tensor& vec) { - auto op = [](T& t, T& v) { v += t; }; - AddByBitCode(op, SimpleCodeTable(num_classes), codes, tmat, vec); +template +void SubByBitCodeT(CodeTable code_table, const int64_t* codes, + framework::Tensor& tmat) { + // size_t max_code_length = code_table.get_max_code_length(); + size_t num_samples = tmat.dims()[0]; + size_t o_width = tmat.dims()[1]; + for (size_t i = 0; i < num_samples; ++i) { + auto code = code_table(static_cast(codes[i])); + int code_length = code.get_length(); + for (int j = 0; j < code_length; ++j) { + if (code.calc_bit(j)) { + tmat.data()[i * o_width + j] -= 1; + } + } + } } -template -void SumByBitCodeT(CodeTable code_table, const framework::Tensor& codes, - framework::Tensor& tmat, const framework::Tensor& sum, +template +void SumByBitCodeT(CodeTable code_table, const int64_t* codes, + framework::Tensor& tmat, framework::Tensor& sum, const T& scale_sum) { - size_t max_code_length = code_table.get_max_code_length(); + // size_t max_code_length = code_table.get_max_code_length(); size_t num_samples = tmat.dims()[0]; size_t o_width = tmat.dims()[1]; for (size_t i = 0; i < num_samples; ++i) { - T sm = 0; - auto code = code_table(codes.data()[i]); + T sm = static_cast(0.0); + auto code = code_table(static_cast(codes[i])); int code_length = code.get_length(); for (int j = 0; j < code_length; ++j) { if (code.calc_bit(j)) { @@ -103,105 +105,124 @@ void SumByBitCodeT(CodeTable code_table, const framework::Tensor& codes, sum.data()[i] = scale_sum * sm; } } -/* For j < codeLength: - sum(i, 0) = \sum_j bit(i, j) * input(i, j) -*/ + template -void SumByBitCode(size_t num_classes, const framework::Tensor& codes, - framework::Tensor& tmat, framework::Tensor& sum, - T scale_sum) { - SumByBitCodeT(SimpleCodeTable(num_classes), codes, tmat, scale_sum); +void MatrixBitCodeFunctor::Add(size_t num_classes, const int64_t* codes, + framework::Tensor& tmat, + const framework::Tensor& vec) { + auto op = [](T& t, const T& v) { t += v; }; + AddByBitCodeT(op, SimpleCodeTable(num_classes), codes, tmat, vec); } -template -void MulByBitCodeT(Op op, CodeTable code_table, const framework::Tensor& codes, - framework::Tensor& tmat, framework::Tensor& weight, - framework::Tensor& input) { - size_t num_classes = code_table.size(); - size_t max_code_length = code_table.get_max_code_length(); - size_t num_samples = tmat.dims()[0]; - size_t input_dim = input.dims()[1]; - size_t o_width = tmat.dims()[1]; +template +void MatrixBitCodeFunctor::AddGrad(size_t num_classes, const int64_t* codes, + framework::Tensor& tmat, + framework::Tensor& vec) { + auto op = [](T& t, T& v) { v += t; }; + AddByBitCodeT(op, SimpleCodeTable(num_classes), codes, tmat, vec); +} +template +void MatrixBitCodeFunctor::Sum(size_t num_classes, const int64_t* codes, + framework::Tensor& tmat, + framework::Tensor& sum, T scale_sum) { + SumByBitCodeT(SimpleCodeTable(num_classes), codes, tmat, sum, scale_sum); +} + +template +void MatrixBitCodeFunctor::Mul(size_t num_classes, const int64_t* codes, + framework::Tensor& tmat, + const framework::Tensor& weight, + const framework::Tensor& input) { + size_t num_samples = tmat.dims()[0]; + size_t tmat_width = tmat.dims()[1]; + size_t input_width = input.dims()[1]; + size_t weight_width = weight.dims()[1]; + auto tmat_p = tmat.data(); + auto weight_p = weight.data(); + auto input_p = input.data(); + auto code_table = SimpleCodeTable(num_classes); for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table(codes.data()[i]); + auto code = code_table(static_cast(codes[i])); int code_length = code.get_length(); for (int j = 0; j < code_length; ++j) { size_t index = code.calc_index(j); - op(tmat.data()[i * o_width + j], - weight.data() + index * weight.dims()[1], - input.data() + i * input.dims()[1], input_dim); + + T sum = static_cast(0.0); + for (size_t k = 0; k < input_width; ++k) { + sum += + weight_p[weight_width * index + k] * input_p[input_width * i + k]; + } + std::cout << sum << std::endl; + tmat_p[i * tmat_width + j] += sum; } } } template -void MulByBitCode(size_t num_classes, const framework::Tensor& codes, - framework::Tensor& tmat, const framework::Tensor& weight, - const framework::Tensor& input) { - auto op = [](T& t, const T* weight_row, const T* input_row, - size_t input_dim) { - T sum = 0; - for (size_t k = 0; k < input_dim; ++k) { - sum += weight_row[k] * input_row[k]; - } - t += sum; - }; - MulByBitCodeT(op, SimpleCodeTable(num_classes), codes, tmat, weight, - input); -} +void MatrixBitCodeFunctor::MulGradWeight(size_t num_classes, + const int64_t* codes, + const framework::Tensor& tmat, + framework::Tensor& weight, + const framework::Tensor& input) { + size_t num_samples = tmat.dims()[0]; + size_t input_width = input.dims()[1]; + size_t weight_width = weight.dims()[1]; + auto tmat_p = tmat.data(); + auto weight_p = weight.data(); + auto input_p = input.data(); + auto code_table = SimpleCodeTable(num_classes); + for (size_t i = 0; i < num_samples; ++i) { + auto code = code_table(static_cast(codes[i])); + int code_length = code.get_length(); + for (int j = 0; j < code_length; ++j) { + size_t index = code.calc_index(j); -template -void MulByBitCodeGradWeight(size_t num_classes, const framework::Tensor& codes, - const framework::Tensor& tmat, - framework::Tensor& weight, - const framework::Tensor& input) { - auto op = [](const T t, T* weight_row, const T* input_row, size_t input_dim) { - for (size_t k = 0; k < input_dim; ++k) { - weight_row[k] += t * input_row[k]; + for (size_t k = 0; k < input_width; ++k) { + weight_p[weight_width * index * k] += + tmat_p[i * weight_width * j] * input_p[input_width * i + k]; + } } - }; - MulByBitCodeT(op, SimpleCodeTable(num_classes), codes, tmat, weight, - input); + } } template -void MulByBitCodeGradError(size_t num_classes, const framework::Tensor& codes, - const framework::Tensor& tmat, - const framework::Tensor& weight, - framework::Tensor& input) { - auto op = [](const T t, const T* weight_row, T* input_row, size_t input_dim) { - for (size_t k = 0; k < input_dim; ++k) { - input_row[k] += t * weight_row[k]; - } - }; - MulByBitCodeT(op, SimpleCodeTable(num_classes), codes, tmat, weight, - input); -} - -template -void SubByBitCodeT(CodeTable code_table, const framework::Tensor& codes, - framework::Tensor& tmat) { - size_t max_code_length = code_table.get_max_code_length(); +void MatrixBitCodeFunctor::MulGradError(size_t num_classes, + const int64_t* codes, + const framework::Tensor& tmat, + const framework::Tensor& weight, + framework::Tensor& input) { size_t num_samples = tmat.dims()[0]; - size_t o_width = tmat.dims()[1]; + size_t input_width = input.dims()[1]; + size_t weight_width = weight.dims()[1]; + auto tmat_p = tmat.data(); + auto weight_p = weight.data(); + auto input_p = input.data(); + auto code_table = SimpleCodeTable(num_classes); + for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table(codes.data()[i]); + auto code = code_table(static_cast(codes[i])); int code_length = code.get_length(); for (int j = 0; j < code_length; ++j) { - if (code.calc_bit(j)) { - tmat.data()[i * o_width + j] -= 1; + size_t index = code.calc_index(j); + + for (size_t k = 0; k < input_width; ++k) { + input_p[weight_width * index * k] += + tmat_p[i * weight_width * j] * weight_p[weight_width * i + k]; } } } } template -void SubByBitCode(size_t num_classes, const framework::Tensor& codes, - framework::Tensor& tmat) { +void MatrixBitCodeFunctor::Sub(size_t num_classes, const int64_t* codes, + framework::Tensor& tmat) { SubByBitCodeT(SimpleCodeTable(num_classes), codes, tmat); } +template class MatrixBitCodeFunctor; +template class MatrixBitCodeFunctor; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/matrix_bit_code.h b/paddle/operators/math/matrix_bit_code.h index 43c9d43d89d..d2ebf182c86 100644 --- a/paddle/operators/math/matrix_bit_code.h +++ b/paddle/operators/math/matrix_bit_code.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include "paddle/framework/eigen.h" #include "paddle/framework/tensor.h" +#include "paddle/platform/device_context.h" namespace paddle { namespace operators { @@ -59,57 +60,50 @@ struct SimpleCodeTable { int max_code_length_; }; -/* For j < code_length - tmat(i, j) += vec(0, index(i, j)) -*/ template -void AddByBitCode(size_t num_classes, const framework::Tensor& codes, - framework::Tensor& tmat, const framework::Tensor& vec); +class MatrixBitCodeFunctor { + public: + /* For j < code_length + tmat(i, j) += vec(0, index(i, j)) + */ + void Add(size_t num_classes, const int64_t* codes, framework::Tensor& tmat, + const framework::Tensor& vec); -/* For j < code_length - vec(0, index(i, j)) += tmat(i, j) -*/ -template -void AddByBitCodeGrad(size_t num_classes, const framework::Tensor& codes, - const framework::Tensor& tmat, framework::Tensor& vec); -/* For j < code_length + /* For j < code_length + vec(0, index(i, j)) += tmat(i, j) + */ + void AddGrad(size_t num_classes, const int64_t* codes, + framework::Tensor& tmat, framework::Tensor& vec); + + /* For j < code_length sum(i, 0) = \sum_j bit(i, j) * tmat(i, j) -*/ -template -void SumByBitCode(size_t num_classes, const framework::Tensor& codes, - framework::Tensor& tmat, framework::Tensor& sum, T scale_sum); + */ + void Sum(size_t num_classes, const int64_t* codes, framework::Tensor& tmat, + framework::Tensor& sum, T scale_sum); -/* For j < code_length - input.row(i) += tmat(i, j) * weight.row(index(i, j)) -*/ -template -void MulByBitCode(size_t num_classes, const framework::Tensor& codes, - framework::Tensor& tmat, const framework::Tensor& weight, - const framework::Tensor& input); + /* For j < code_length + tmat(i, j) -= bit(i, j) + */ + void Sub(size_t num_classes, const int64_t* codes, framework::Tensor& tmat); + /* For j < code_length + input.row(i) += tmat(i, j) * weight.row(index(i, j)) + */ + void Mul(size_t num_classes, const int64_t* codes, framework::Tensor& tmat, + const framework::Tensor& weight, const framework::Tensor& input); -/* For index(i, j) >= 0: - weight.row(index(i, j)) += tmat(i, j) * input.row(i) -*/ -template -void MulByBitCodeGradWeight(size_t num_classes, const framework::Tensor& codes, - const framework::Tensor& tmat, - framework::Tensor& weight, - const framework::Tensor& input); -/* For j < code_length - input.row(i) += tmat(i, j) * weight.row(index(i, j)) -*/ -template -void MulByBitCodeGradError(size_t num_classes, const framework::Tensor& codes, - const framework::Tensor& tmat, - const framework::Tensor& weight, - framework::Tensor& input); - -/* For j < code_length - tmat(i, j) -= bit(i, j) -*/ -template -void SubByBitCode(size_t num_classes, const framework::Tensor& codes, - framework::Tensor& tmat); + /* For index(i, j) >= 0: + weight.row(index(i, j)) += tmat(i, j) * input.row(i) + */ + void MulGradWeight(size_t num_classes, const int64_t* codes, + const framework::Tensor& tmat, framework::Tensor& weight, + const framework::Tensor& input); + /* For j < code_length + input.row(i) += tmat(i, j) * weight.row(index(i, j)) + */ + void MulGradError(size_t num_classes, const int64_t* codes, + const framework::Tensor& tmat, + const framework::Tensor& weight, framework::Tensor& input); +}; } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index c16d3e0cbe0..a05fcd0451c 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -126,6 +126,8 @@ PYBIND11_PLUGIN(core) { .def("shape", [](Tensor &self) { return vectorize(self.dims()); }) .def("set_float_element", TensorSetElement) .def("get_float_element", TensorGetElement) + .def("set_int64_element", TensorSetElement) + .def("get_int64_element", TensorGetElement) .def("set_double_element", TensorSetElement) .def("get_double_element", TensorGetElement) .def("dtype", [](Tensor &self) { return ToDataType(self.type()); }); diff --git a/python/paddle/v2/fluid/tests/op_test.py b/python/paddle/v2/fluid/tests/op_test.py index e83c4a06220..edf68075bec 100644 --- a/python/paddle/v2/fluid/tests/op_test.py +++ b/python/paddle/v2/fluid/tests/op_test.py @@ -49,7 +49,6 @@ def create_op(scope, op_type, inputs, outputs, attrs): for attr_name in Operator.get_op_attr_names(op_type): if attr_name in attrs: kwargs[attr_name] = attrs[attr_name] - return Operator(op_type, **kwargs) @@ -107,6 +106,8 @@ def get_numeric_gradient(scope, tensor_to_check_dtype = np.float32 elif tensor_to_check_dtype == core.DataType.FP64: tensor_to_check_dtype = np.float64 + elif tensor_to_check_dtype == core.DataType.INT64: + tensor_to_check_dtype = np.int64 else: raise ValueError("Not supported data type " + str( tensor_to_check_dtype)) @@ -116,12 +117,16 @@ def get_numeric_gradient(scope, def __get_elem__(tensor, i): if tensor_to_check_dtype == np.float32: return tensor.get_float_element(i) + elif tensor_to_check_dtype == np.int64: + return tensor.get_int64_element(i) else: return tensor.get_double_element(i) def __set_elem__(tensor, i, e): if tensor_to_check_dtype == np.float32: tensor.set_float_element(i, e) + elif tensor_to_check_dtype == np.int64: + tensor.set_int64_element(i, e) else: tensor.set_double_element(i, e) @@ -355,13 +360,11 @@ class OpTest(unittest.TestCase): op_attrs = self.attrs if hasattr(self, "attrs") else dict() self.op = create_op(self.scope, self.op_type, op_inputs, op_outputs, op_attrs) - if no_grad_set is None: no_grad_set = set() if not type(output_names) is list: output_names = [output_names] - numeric_grads = user_defined_grads or [ get_numeric_gradient( self.scope, @@ -457,9 +460,7 @@ class OpTest(unittest.TestCase): # infer variable type and infer shape in compile-time op.desc.infer_var_type(block.desc) op.desc.infer_shape(block.desc) - mean_inputs = map(block.var, output_names) - if len(mean_inputs) == 1: loss = block.create_var(dtype=mean_inputs[0].dtype, shape=[1]) op = block.append_op( diff --git a/python/paddle/v2/fluid/tests/test_hsigmoid_op.py b/python/paddle/v2/fluid/tests/test_hsigmoid_op.py index 25c13aabe97..194d5e315fc 100644 --- a/python/paddle/v2/fluid/tests/test_hsigmoid_op.py +++ b/python/paddle/v2/fluid/tests/test_hsigmoid_op.py @@ -5,15 +5,15 @@ from op_test import OpTest class TestHSigmoidOp(OpTest): def setUp(self): - self.op_type = "hierarchical_sigmoid_op" + self.op_type = "hierarchical_sigmoid" num_classes = 6 embded_size = 10 batch_size = 5 x = np.random.random((batch_size, embded_size)).astype("float32") parameter = np.random.random( (batch_size, num_classes - 1, embded_size)).astype("float32") - label = np.random.randint(0, num_classes, batch_size).astype("int64") - bias = np.random.random((1, num_classes - 1)) + label = np.random.randint(0, num_classes, batch_size) + bias = np.random.random((1, num_classes - 1)).astype("float32") self.inputs = { 'X': x, 'Parameters': parameter, @@ -21,13 +21,18 @@ class TestHSigmoidOp(OpTest): 'Bias': bias } self.attrs = {'num_classes': num_classes} - self.outputs = {'Out': label} + self.outputs = { + 'Out': np.random.random((batch_size, 1)).astype("float32") + } def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['x0'], 'Out') + self.check_grad( + ['X', 'Parameters', 'Label', 'Bias'], + 'Out', + no_grad_set=set(['Label'])) if __name__ == '__main__': -- GitLab