From 80ce7edbb79244f6946cf38e233d3914ef40ddf5 Mon Sep 17 00:00:00 2001
From: Yancey1989
Date: Tue, 9 Jan 2018 20:26:37 +0800
Subject: [PATCH] make forward correct

---
 paddle/operators/hierarchical_sigmoid_op.cc        |   4 +-
 paddle/operators/hierarchical_sigmoid_op.h         |  35 ++--
 paddle/operators/math/math_function_impl.h         |   8 +-
 paddle/operators/math/matrix_bit_code.cc           | 156 ++++++++----------
 paddle/operators/math/matrix_bit_code.h            |  25 ++-
 python/paddle/v2/fluid/tests/op_test.py            |   9 +-
 .../paddle/v2/fluid/tests/test_hsigmoid_op.py      |  70 +++++++-
 7 files changed, 170 insertions(+), 137 deletions(-)

diff --git a/paddle/operators/hierarchical_sigmoid_op.cc b/paddle/operators/hierarchical_sigmoid_op.cc
index bc6ceb98747..e2ba65d6f9c 100644
--- a/paddle/operators/hierarchical_sigmoid_op.cc
+++ b/paddle/operators/hierarchical_sigmoid_op.cc
@@ -70,7 +70,7 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
@@ -96,7 +96,7 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
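Note: renaming GetKernelType to GetActualKernelType tracks the updated
OperatorWithKernel interface, and the body pins the kernel's data type to the
"X" input. That matters because this op mixes dtypes: X, W, and Bias are float
tensors while Ids is int64, so inferring a common dtype across all inputs would
be ambiguous. A small numpy illustration of that contract (the shapes and names
below are assumptions for illustration, not the op's registration code):

    import numpy as np

    # Hypothetical input set for one hsigmoid call; only 'X' determines
    # the kernel dtype, because 'Ids' is integer-typed by design.
    inputs = {
        "X": np.zeros((4, 8), dtype=np.float32),     # features
        "W": np.zeros((4, 5, 8), dtype=np.float32),  # per-sample weights
        "Ids": np.zeros((4,), dtype=np.int64),       # target class ids
    }
    kernel_dtype = inputs["X"].dtype   # float32; the int64 'Ids' is ignored
    assert kernel_dtype == np.float32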
diff --git a/paddle/operators/hierarchical_sigmoid_op.h b/paddle/operators/hierarchical_sigmoid_op.h
index 1b8d21c095e..f5b1b97169c 100644
--- a/paddle/operators/hierarchical_sigmoid_op.h
+++ b/paddle/operators/hierarchical_sigmoid_op.h
@@ -49,34 +49,31 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
     auto& device_ctx = ctx.template device_context<DeviceContext>();
     math::RowwiseSum<DeviceContext, T> row_sum;
-    math::MatrixBitCodeFunctor<T> bit_code;
+    math::MatrixBitCodeFunctor<T> bit_code(num_classes, ids->data<int64_t>());
 
     std::vector<int64_t> sum_dims({batch_size, 1UL});
     sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
     auto sum_mat = EigenMatrix<T>::From(sum);
     out->mutable_data<T>(ctx.GetPlace());
     auto out_mat = framework::EigenVector<T>::Flatten(*out);
-
     if (bias) {
-      bit_code.Add(num_classes, ids->data<int64_t>(), pre_out, *bias);
+      bit_code.Add(pre_out, *bias);
     }
-    for (int i = 0; i < in->dims()[0]; ++i) {
-      bit_code.Mul(num_classes, ids->data<int64_t>(), pre_out,
-                   w->Slice(i, i + 1), in->Slice(i, i + 1));
+    for (int64_t i = 0; i < batch_size; ++i) {
+      auto w_i = w->Slice(i, i + 1);
+      bit_code.Mul(pre_out, w_i, *in);
     }
     // clip the matrix with (-40, 40)
     Transform<DeviceContext> trans;
     trans(ctx.template device_context<DeviceContext>(), pre_out_data,
           pre_out_data + pre_out.numel(), pre_out_data,
           ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
-    bit_code.Sum(num_classes, ids->data<int64_t>(), pre_out, *out,
-                 static_cast<T>(-1));
+    bit_code.Sum(pre_out, *out, static_cast<T>(-1));
     // softrelu with threshold is 40.0
     trans(ctx.template device_context<DeviceContext>(), pre_out_data,
           pre_out_data + pre_out.numel(), pre_out_data,
           ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
     pre_out_mat.device(place) =
         (static_cast<T>(1.0) + pre_out_mat.exp()).log();
-
     row_sum(device_ctx, pre_out, &sum);
     out_mat.device(place) = sum_mat + out_mat;
   }
@@ -103,28 +100,26 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
     auto pre_out_mat = EigenMatrix<T>::From(pre_out);
     // init pre_out matrix with {1.0}
     math::SetConstant<DeviceContext, T> one;
-    math::MatrixBitCodeFunctor<T> bit_code;
+    math::MatrixBitCodeFunctor<T> bit_code(num_classes, ids->data<int64_t>());
     one(device_ctx, &pre_out, static_cast<T>(1.0));
     // softrelu derivative
     pre_out_mat.device(place) =
         pre_out_mat * (static_cast<T>(1.0) - static_cast<T>(1.0) / pre_out_mat);
-    bit_code.Sub(num_classes, ids->data<int64_t>(), pre_out);
+    bit_code.Sub(pre_out);
     if (bias) {
       bias->mutable_data<T>(ctx.GetPlace());
-      bit_code.AddGrad(num_classes, ids->data<int64_t>(), pre_out, *bias);
+      bit_code.AddGrad(pre_out, *bias);
     }
     in_grad->mutable_data<T>(ctx.GetPlace());
     w->mutable_data<T>(ctx.GetPlace());
-    for (int i = 0; i < in_grad->dims()[0]; ++i) {
-      auto p_sliced = w->Slice(i, i + 1);
-      auto in_sliced = in->Slice(i, i + 1);
-      auto in_grad_sliced = in_grad->Slice(i, i + 1);
-      bit_code.MulGradWeight(num_classes, ids->data<int64_t>(), pre_out,
-                             p_sliced, in_sliced);
-      bit_code.MulGradError(num_classes, ids->data<int64_t>(), pre_out,
-                            p_sliced, in_grad_sliced);
+    for (int i = 0; i < batch_size; ++i) {
+      auto w_i = w->Slice(i, i + 1);
+      // auto in_i = in->Slice(i, i + 1);
+      // auto in_grad_i = in_grad->Slice(i, i + 1);
+      bit_code.MulGradWeight(pre_out, w_i, *in);
+      bit_code.MulGradError(pre_out, w_i, *in_grad);
     }
   }
 };
diff --git a/paddle/operators/math/math_function_impl.h b/paddle/operators/math/math_function_impl.h
index 98722ff5d29..63fb7182dfd 100644
--- a/paddle/operators/math/math_function_impl.h
+++ b/paddle/operators/math/math_function_impl.h
@@ -62,13 +62,13 @@ void ColwiseSum<DeviceContext, T>::operator()(const DeviceContext& context,
 template <typename DeviceContext, typename T>
 void RowwiseSum<DeviceContext, T>::operator()(const DeviceContext& context,
                                               const framework::Tensor& input,
-                                              framework::Tensor* vector) {
+                                              framework::Tensor* out) {
   auto in_dims = input.dims();
   auto size = input.numel() / in_dims[1];
-  PADDLE_ENFORCE_EQ(vector->numel(), size);
+  PADDLE_ENFORCE_EQ(out->numel(), size);
 
-  auto in = framework::EigenMatrix<T>::From(input);
-  auto vec = framework::EigenMatrix<T>::From(*vector);
+  auto in = framework::EigenMatrix<T>::From(input);
+  auto vec = framework::EigenVector<T>::Flatten(*out);
 
   vec.device(*context.eigen_device()) = in.sum(Eigen::array<int, 1>({{1}}));
 }
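Note: the forward kernel above finishes with row_sum(device_ctx, pre_out, &sum),
and the math_function_impl.h hunk fixes RowwiseSum to write a flattened vector
whose numel() equals input.numel() / dims[1]. A numpy stand-in for that
contract (illustrative only, not the Eigen implementation):

    import numpy as np

    def rowwise_sum(x):
        # reduce along dim 1; output size must equal x.size // x.shape[1],
        # which is what the PADDLE_ENFORCE_EQ above checks
        out = x.sum(axis=1)
        assert out.size == x.size // x.shape[1]
        return out

    pre_out = np.log(1.0 + np.exp(np.random.randn(4, 3)))  # after softrelu
    print(rowwise_sum(pre_out).shape)                       # (4,)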
diff --git a/paddle/operators/math/matrix_bit_code.cc b/paddle/operators/math/matrix_bit_code.cc
index b192183b101..34f5f6ef61b 100644
--- a/paddle/operators/math/matrix_bit_code.cc
+++ b/paddle/operators/math/matrix_bit_code.cc
@@ -22,7 +22,7 @@ namespace math {
  * CodeTable class should support 3 functions:
  *
  * size_t size()
- *   return the number of codes
+ *   return the number of ids
  *
  * int getMaxCodeLength()
  *   return the maximal code length
@@ -45,56 +45,47 @@ namespace math {
  *
  */
 
-/*
-   for i:
-     for j < codeLength:
-       op(a(i, j), b(0, index(i, j)))
-*/
-template <class CodeTable, class Op, typename T>
-static void AddByBitCodeT(Op op, CodeTable code_table, const int64_t* codes,
-                          const framework::Tensor& tmat,
-                          const framework::Tensor& vec) {
-  size_t num_sample = tmat.dims()[0];
-  size_t width = vec.dims()[1];
-  for (size_t i = 0; i < num_sample; ++i) {
-    auto code = code_table(static_cast<size_t>(codes[i]));
+template <typename T>
+void MatrixBitCodeFunctor<T>::Add(framework::Tensor& tmat,
+                                  const framework::Tensor& vec) {
+  SimpleCodeTable code_table(num_classes_);
+  size_t batch_size = tmat.dims()[0];
+  size_t width = tmat.dims()[1];
+  for (size_t i = 0; i < batch_size; ++i) {
+    auto code = code_table(static_cast<size_t>(ids_[i]));
     int code_length = code.get_length();
     for (int j = 0; j < code_length; ++j) {
       size_t index = code.calc_index(j);
-      auto t = tmat.data<T>()[i * width + j];
-      auto v = vec.data<T>()[index];
-      op(t, v);
+      tmat.data<T>()[i * width + j] += vec.data<T>()[index];
     }
   }
 }
 
-template <class CodeTable, typename T>
-void SubByBitCodeT(CodeTable code_table, const int64_t* codes,
-                   framework::Tensor& tmat) {
-  // size_t max_code_length = code_table.get_max_code_length();
-  size_t num_samples = tmat.dims()[0];
-  size_t o_width = tmat.dims()[1];
-  for (size_t i = 0; i < num_samples; ++i) {
-    auto code = code_table(static_cast<size_t>(codes[i]));
+template <typename T>
+void MatrixBitCodeFunctor<T>::AddGrad(framework::Tensor& tmat,
+                                      framework::Tensor& vec) {
+  SimpleCodeTable code_table(num_classes_);
+  size_t batch_size = tmat.dims()[0];
+  size_t width = tmat.dims()[1];
+  for (size_t i = 0; i < batch_size; ++i) {
+    auto code = code_table(static_cast<size_t>(ids_[i]));
     int code_length = code.get_length();
     for (int j = 0; j < code_length; ++j) {
-      if (code.calc_bit(j)) {
-        tmat.data<T>()[i * o_width + j] -= 1;
-      }
+      size_t index = code.calc_index(j);
+      vec.data<T>()[index] += tmat.data<T>()[i * width + j];
     }
   }
 }
 
-template <class CodeTable, typename T>
-void SumByBitCodeT(CodeTable code_table, const int64_t* codes,
-                   framework::Tensor& tmat, framework::Tensor& sum,
-                   const T& scale_sum) {
-  // size_t max_code_length = code_table.get_max_code_length();
+template <typename T>
+void MatrixBitCodeFunctor<T>::Sum(framework::Tensor& tmat,
+                                  framework::Tensor& sum, T scale_sum) {
+  SimpleCodeTable code_table(num_classes_);
   size_t num_samples = tmat.dims()[0];
   size_t o_width = tmat.dims()[1];
   for (size_t i = 0; i < num_samples; ++i) {
     T sm = static_cast<T>(0.0);
-    auto code = code_table(static_cast<size_t>(codes[i]));
+    auto code = code_table(static_cast<size_t>(ids_[i]));
     int code_length = code.get_length();
     for (int j = 0; j < code_length; ++j) {
       if (code.calc_bit(j)) {
@@ -106,116 +97,99 @@ void SumByBitCodeT(CodeTable code_table, const int64_t* codes,
 }
 
 template <typename T>
-void MatrixBitCodeFunctor<T>::Add(size_t num_classes, const int64_t* codes,
-                                  framework::Tensor& tmat,
-                                  const framework::Tensor& vec) {
-  auto op = [](T& t, const T& v) { t += v; };
-  AddByBitCodeT<T>(op, SimpleCodeTable(num_classes), codes, tmat, vec);
-}
-
-template <typename T>
-void MatrixBitCodeFunctor<T>::AddGrad(size_t num_classes, const int64_t* codes,
-                                      framework::Tensor& tmat,
-                                      framework::Tensor& vec) {
-  auto op = [](T& t, T& v) { v += t; };
-  AddByBitCodeT<T>(op, SimpleCodeTable(num_classes), codes, tmat, vec);
-}
-
-template <typename T>
-void MatrixBitCodeFunctor<T>::Sum(size_t num_classes, const int64_t* codes,
-                                  framework::Tensor& tmat,
-                                  framework::Tensor& sum, T scale_sum) {
-  SumByBitCodeT<T>(SimpleCodeTable(num_classes), codes, tmat, sum, scale_sum);
-}
-
-template <typename T>
-void MatrixBitCodeFunctor<T>::Mul(size_t num_classes, const int64_t* codes,
-                                  framework::Tensor& tmat,
+void MatrixBitCodeFunctor<T>::Mul(framework::Tensor& tmat,
                                   const framework::Tensor& weight,
                                   const framework::Tensor& input) {
+  SimpleCodeTable code_table(num_classes_);
   size_t num_samples = tmat.dims()[0];
   size_t tmat_width = tmat.dims()[1];
   size_t input_width = input.dims()[1];
-  size_t weight_width = weight.dims()[1];
-  auto tmat_p = tmat.data<T>();
-  auto weight_p = weight.data<T>();
-  auto input_p = input.data<T>();
-  auto code_table = SimpleCodeTable(num_classes);
+  size_t weight_width = weight.dims()[2];
+  auto tmat_value = tmat.data<T>();
+  auto weight_value = weight.data<T>();
+  auto input_value = input.data<T>();
   for (size_t i = 0; i < num_samples; ++i) {
-    auto code = code_table(static_cast<size_t>(codes[i]));
+    auto code = code_table(static_cast<size_t>(ids_[i]));
     int code_length = code.get_length();
     for (int j = 0; j < code_length; ++j) {
       size_t index = code.calc_index(j);
       T sum = static_cast<T>(0.0);
       for (size_t k = 0; k < input_width; ++k) {
-        sum +=
-            weight_p[weight_width * index + k] * input_p[input_width * i + k];
+        sum += weight_value[weight_width * index + k] *
+               input_value[input_width * i + k];
       }
-      tmat_p[i * tmat_width + j] += sum;
+      tmat_value[i * tmat_width + j] += sum;
     }
   }
 }
 
 template <typename T>
-void MatrixBitCodeFunctor<T>::MulGradWeight(size_t num_classes,
-                                            const int64_t* codes,
-                                            const framework::Tensor& tmat,
+void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
                                             framework::Tensor& weight,
                                             const framework::Tensor& input) {
+  SimpleCodeTable code_table(num_classes_);
   size_t num_samples = tmat.dims()[0];
   size_t input_width = input.dims()[1];
   size_t weight_width = weight.dims()[1];
-  auto tmat_p = tmat.data<T>();
-  auto weight_p = weight.data<T>();
-  auto input_p = input.data<T>();
-  auto code_table = SimpleCodeTable(num_classes);
+  auto tmat_value = tmat.data<T>();
+  auto weight_value = weight.data<T>();
+  auto input_value = input.data<T>();
   for (size_t i = 0; i < num_samples; ++i) {
-    auto code = code_table(static_cast<size_t>(codes[i]));
+    auto code = code_table(static_cast<size_t>(ids_[i]));
     int code_length = code.get_length();
     for (int j = 0; j < code_length; ++j) {
       size_t index = code.calc_index(j);
       for (size_t k = 0; k < input_width; ++k) {
-        weight_p[weight_width * index * k] +=
-            tmat_p[i * weight_width * j] * input_p[input_width * i + k];
+        weight_value[weight_width * index * k] +=
+            tmat_value[i * weight_width * j] * input_value[input_width * i + k];
       }
     }
   }
 }
 
 template <typename T>
-void MatrixBitCodeFunctor<T>::MulGradError(size_t num_classes,
-                                           const int64_t* codes,
-                                           const framework::Tensor& tmat,
+void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
                                            const framework::Tensor& weight,
                                            framework::Tensor& input) {
+  SimpleCodeTable code_table(num_classes_);
   size_t num_samples = tmat.dims()[0];
   size_t input_width = input.dims()[1];
   size_t weight_width = weight.dims()[1];
-  auto tmat_p = tmat.data<T>();
-  auto weight_p = weight.data<T>();
-  auto input_p = input.data<T>();
-  auto code_table = SimpleCodeTable(num_classes);
+  auto tmat_value = tmat.data<T>();
+  auto weight_value = weight.data<T>();
+  auto input_value = input.data<T>();
   for (size_t i = 0; i < num_samples; ++i) {
-    auto code = code_table(static_cast<size_t>(codes[i]));
+    auto code = code_table(static_cast<size_t>(ids_[i]));
     int code_length = code.get_length();
     for (int j = 0; j < code_length; ++j) {
       size_t index = code.calc_index(j);
       for (size_t k = 0; k < input_width; ++k) {
-        input_p[weight_width * index * k] +=
-            tmat_p[i * weight_width * j] * weight_p[weight_width * i + k];
+        input_value[weight_width * index * k] +=
+            tmat_value[i * weight_width * j] *
+            weight_value[weight_width * i + k];
       }
     }
   }
 }
 
 template <typename T>
-void MatrixBitCodeFunctor<T>::Sub(size_t num_classes, const int64_t* codes,
-                                  framework::Tensor& tmat) {
-  SubByBitCodeT<T>(SimpleCodeTable(num_classes), codes, tmat);
+void MatrixBitCodeFunctor<T>::Sub(framework::Tensor& tmat) {
+  SimpleCodeTable code_table(num_classes_);
+  size_t num_samples = tmat.dims()[0];
+  size_t o_width = tmat.dims()[1];
+  for (size_t i = 0; i < num_samples; ++i) {
+    auto code = code_table(static_cast<size_t>(ids_[i]));
+    int code_length = code.get_length();
+    for (int j = 0; j < code_length; ++j) {
+      if (code.calc_bit(j)) {
+        tmat.data<T>()[i * o_width + j] -= 1;
+      }
+    }
+  }
 }
 
 template class MatrixBitCodeFunctor<float>;
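Note: every method above walks the same implicit binary tree: sample i's code
word is num_classes_ + ids_[i], calc_index(j) is the internal node visited at
bit j, and calc_bit(j) is the branch taken there. A numpy model of Add and Sub
under that reading (it matches the Python CodeTable in the test diff further
down; illustrative, not the C++ SimpleCodeTable):

    import math
    import numpy as np

    def code_length(c):
        return int(math.floor(math.log(c, 2)))

    def calc_index(c, j):
        return (c >> (j + 1)) - 1        # internal node visited at bit j

    def add(tmat, vec, ids, num_classes):
        # tmat(i, j) += vec(0, index(i, j)), as in MatrixBitCodeFunctor::Add
        for i, idx in enumerate(ids):
            c = num_classes + int(idx)
            for j in range(code_length(c)):
                tmat[i, j] += vec[0, calc_index(c, j)]

    def sub(tmat, ids, num_classes):
        # tmat(i, j) -= bit(i, j), as in MatrixBitCodeFunctor::Sub
        for i, idx in enumerate(ids):
            c = num_classes + int(idx)
            for j in range(code_length(c)):
                if c & (1 << j):
                    tmat[i, j] -= 1

    tmat = np.zeros((2, 3))
    add(tmat, np.arange(5, dtype=float).reshape(1, 5), [1, 4], 6)
    sub(tmat, [1, 4], 6)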
diff --git a/paddle/operators/math/matrix_bit_code.h b/paddle/operators/math/matrix_bit_code.h
index d2ebf182c86..43c676f5cc5 100644
--- a/paddle/operators/math/matrix_bit_code.h
+++ b/paddle/operators/math/matrix_bit_code.h
@@ -63,46 +63,45 @@ struct SimpleCodeTable {
 template <typename T>
 class MatrixBitCodeFunctor {
  public:
+  explicit MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids)
+      : num_classes_(num_classes), ids_(ids) {}
   /* For j < code_length
        tmat(i, j) += vec(0, index(i, j))
   */
-  void Add(size_t num_classes, const int64_t* codes, framework::Tensor& tmat,
-           const framework::Tensor& vec);
+  void Add(framework::Tensor& tmat, const framework::Tensor& vec);
 
   /* For j < code_length
        vec(0, index(i, j)) += tmat(i, j)
   */
-  void AddGrad(size_t num_classes, const int64_t* codes,
-               framework::Tensor& tmat, framework::Tensor& vec);
+  void AddGrad(framework::Tensor& tmat, framework::Tensor& vec);
 
   /* For j < code_length
        sum(i, 0) = \sum_j bit(i, j) * tmat(i, j)
   */
-  void Sum(size_t num_classes, const int64_t* codes, framework::Tensor& tmat,
-           framework::Tensor& sum, T scale_sum);
+  void Sum(framework::Tensor& tmat, framework::Tensor& sum, T scale_sum);
 
   /* For j < code_length
        tmat(i, j) -= bit(i, j)
   */
-  void Sub(size_t num_classes, const int64_t* codes, framework::Tensor& tmat);
+  void Sub(framework::Tensor& tmat);
 
   /* For j < code_length
        input.row(i) += tmat(i, j) * weight.row(index(i, j))
   */
-  void Mul(size_t num_classes, const int64_t* codes, framework::Tensor& tmat,
-           const framework::Tensor& weight, const framework::Tensor& input);
+  void Mul(framework::Tensor& tmat, const framework::Tensor& weight,
+           const framework::Tensor& input);
 
   /* For index(i, j) >= 0:
        weight.row(index(i, j)) += tmat(i, j) * input.row(i)
   */
-  void MulGradWeight(size_t num_classes, const int64_t* codes,
-                     const framework::Tensor& tmat, framework::Tensor& weight,
+  void MulGradWeight(const framework::Tensor& tmat, framework::Tensor& weight,
                      const framework::Tensor& input);
 
   /* For j < code_length
       input.row(i) += tmat(i, j) * weight.row(index(i, j))
   */
-  void MulGradError(size_t num_classes, const int64_t* codes,
-                    const framework::Tensor& tmat,
+  void MulGradError(const framework::Tensor& tmat,
                     const framework::Tensor& weight, framework::Tensor& input);
+  size_t num_classes_;
+  const int64_t* ids_;
 };
 
 }  // namespace math
 }  // namespace operators
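Note: the interface change is mechanical but deliberate: num_classes and the
ids pointer are invariant for the lifetime of one kernel invocation, so binding
them in the constructor leaves each method with only the tensors it transforms.
One caveat is that the functor keeps a raw ids pointer, so it must not outlive
the tensor that owns those ids. A Python stand-in for the shape of the refactor
(illustrative only, not the C++ class):

    import math

    class MatrixBitCode(object):
        # (num_classes, ids) bind once; every method then takes only the
        # tensors it transforms, mirroring the header change above.
        def __init__(self, num_classes, ids):
            self.num_classes = num_classes
            self.ids = ids

        def code_length(self, i):
            c = self.num_classes + int(self.ids[i])
            return int(math.floor(math.log(c, 2)))

    bit_code = MatrixBitCode(6, [1, 4])
    print([bit_code.code_length(i) for i in range(2)])   # [2, 3]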
diff --git a/python/paddle/v2/fluid/tests/op_test.py b/python/paddle/v2/fluid/tests/op_test.py
index acc42fd3b3d..b77d2b1268f 100644
--- a/python/paddle/v2/fluid/tests/op_test.py
+++ b/python/paddle/v2/fluid/tests/op_test.py
@@ -49,6 +49,7 @@ def create_op(scope, op_type, inputs, outputs, attrs):
     for attr_name in Operator.get_op_attr_names(op_type):
         if attr_name in attrs:
             kwargs[attr_name] = attrs[attr_name]
+
     return Operator(op_type, **kwargs)
 
 
@@ -104,8 +105,6 @@ def get_numeric_gradient(scope,
         tensor_to_check_dtype = np.float32
     elif tensor_to_check_dtype == core.DataType.FP64:
         tensor_to_check_dtype = np.float64
-    elif tensor_to_check_dtype == core.DataType.INT64:
-        tensor_to_check_dtype = np.int64
     else:
         raise ValueError("Not supported data type " + str(
             tensor_to_check_dtype))
@@ -115,8 +114,6 @@ def get_numeric_gradient(scope,
     def __get_elem__(tensor, i):
         if tensor_to_check_dtype == np.float32:
             return tensor.get_float_element(i)
-        elif tensor_to_check_dtype == np.int64:
-            return tensor.get_int64_element(i)
         else:
             return tensor.get_double_element(i)
 
@@ -356,11 +353,13 @@ class OpTest(unittest.TestCase):
         op_attrs = self.attrs if hasattr(self, "attrs") else dict()
         self.op = create_op(self.scope, self.op_type, op_inputs, op_outputs,
                             op_attrs)
+
         if no_grad_set is None:
             no_grad_set = set()
 
         if not type(output_names) is list:
             output_names = [output_names]
+
         numeric_grads = user_defined_grads or [
             get_numeric_gradient(
                 self.scope,
@@ -456,7 +455,9 @@ class OpTest(unittest.TestCase):
         # infer variable type and infer shape in compile-time
         op.desc.infer_var_type(block.desc)
         op.desc.infer_shape(block.desc)
+
         mean_inputs = map(block.var, output_names)
+
         if len(mean_inputs) == 1:
             loss = block.create_var(dtype=mean_inputs[0].dtype, shape=[1])
             op = block.append_op(
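Note: the op_test.py hunks drop the int64 branches from numeric gradient
checking, presumably because finite differences are only meaningful for float
tensors; an integer Ids tensor cannot be perturbed by a small delta. A minimal
sketch of the central-difference scheme involved (a simplified stand-in for
get_numeric_gradient, assuming a scalar-valued f):

    import numpy as np

    def numeric_grad(f, x, delta=0.005):
        # central differences: only sensible for float tensors, which is
        # why the int64 paths above were removed
        grad = np.zeros_like(x)
        flat_x, flat_g = x.reshape(-1), grad.reshape(-1)
        for i in range(flat_x.size):
            orig = flat_x[i]
            flat_x[i] = orig + delta
            f_pos = f(x)
            flat_x[i] = orig - delta
            f_neg = f(x)
            flat_x[i] = orig
            flat_g[i] = (f_pos - f_neg) / (2.0 * delta)
        return grad

    x = np.ones((2, 3), dtype=np.float64)
    print(numeric_grad(lambda t: (t * t).sum(), x))   # ~= 2 * x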
diff --git a/python/paddle/v2/fluid/tests/test_hsigmoid_op.py b/python/paddle/v2/fluid/tests/test_hsigmoid_op.py
index b6d961b6318..41e95e43639 100644
--- a/python/paddle/v2/fluid/tests/test_hsigmoid_op.py
+++ b/python/paddle/v2/fluid/tests/test_hsigmoid_op.py
@@ -1,6 +1,71 @@
 import unittest
 import numpy as np
 from op_test import OpTest
+import math
+
+
+def find_latest_set(num):
+    return 1 + int(math.floor(math.log(num, 2)))
+
+
+class CodeTable(object):
+    def __init__(self, num_classes, code):
+        self.c = num_classes + code
+
+    def cal_index(self, bit):
+        return (self.c >> (bit + 1)) - 1
+
+    def get_length(self):
+        return find_latest_set(self.c) - 1
+
+    def cal_bit(self, bit):
+        return self.c & (1 << bit)
+
+
+def hsigmoid(x, w, ids, bias, num_classes):
+    # code_length = find_latest_set(num_classes - 1)
+    # initialize pre_output with dims={batch_size, code_length}
+    batch_size = x.shape[0]
+    code_length = find_latest_set(num_classes - 1)
+    code_table = [0 for _ in range(code_length)]
+    pre_output = np.zeros((batch_size, code_length))
+    pre_sum = np.zeros((batch_size, 1))
+    out = np.zeros((batch_size, 1)).astype("float32")
+    # pre_output += code(bias)
+    for i in xrange(batch_size):
+        code_table = CodeTable(num_classes, ids[i])
+        length = code_table.get_length()
+        for j in xrange(length):
+            idx = code_table.cal_index(j)
+            pre_output[i][j] += bias[0][idx]
+    # pre_output += code(w) * x
+    for i in xrange(batch_size):
+        for j in xrange(batch_size):
+            code_table = CodeTable(num_classes, ids[j])
+            length = code_table.get_length()
+            for k in xrange(length):
+                idx = code_table.cal_index(k)
+                sum = 0.0
+                for l in xrange(x.shape[1]):
+                    sum += w[i][idx][l] * x[j][l]
+                pre_output[j][k] += sum
+    # clip to [-40.0, 40.0]
+    pre_output = np.clip(pre_output, -40.0, 40.0)
+    # out(i, 0) = \sum_j bit(i, j) * pre_output(i, j)
+    for i in xrange(batch_size):
+        code_table = CodeTable(num_classes, ids[i])
+        length = code_table.get_length()
+        sum = 0.0
+        for j in xrange(length):
+            if code_table.cal_bit(j):
+                sum += pre_output[i][j]
+        out[i] = -1.0 * sum
+    # soft relu
+    pre_output = np.clip(pre_output, -40.0, 40.0)
+    pre_output = np.log(1 + np.exp(pre_output))
+    pre_sum = pre_output.sum(1).reshape((batch_size, 1))
+    out += pre_sum
+    return out
 
 
 class TestHSigmoidOp(OpTest):
@@ -16,9 +81,8 @@ class TestHSigmoidOp(OpTest):
         bias = np.random.random((1, num_classes - 1)).astype("float32")
         self.inputs = {'X': x, 'W': w, 'Ids': ids, 'Bias': bias}
         self.attrs = {'num_classes': num_classes}
-        self.outputs = {
-            'Out': np.random.random((batch_size, 1)).astype("float32")
-        }
+        out = hsigmoid(x, w, ids, bias, num_classes)
+        self.outputs = {'Out': out}
 
     def test_check_output(self):
         self.check_output()
--
GitLab
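Note: to sanity-check the reference implementation by hand, something like the
following exercises hsigmoid with shapes consistent with how the test indexes
its inputs (W carries a batch-sized first dimension here, one id per sample;
the concrete sizes are assumptions for illustration):

    import numpy as np
    # assumes hsigmoid() and its helpers from test_hsigmoid_op.py are in scope

    batch_size, num_classes, dim = 4, 6, 8
    x = np.random.random((batch_size, dim)).astype("float32")
    w = np.random.random((batch_size, num_classes - 1, dim)).astype("float32")
    ids = np.random.randint(0, num_classes, batch_size).astype("int64")
    bias = np.random.random((1, num_classes - 1)).astype("float32")

    out = hsigmoid(x, w, ids, bias, num_classes)
    print(out.shape)   # (batch_size, 1)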