Commit ee13b396, authored by: weixing02

fix some errors

Parent: 8bd148dc
@@ -62,7 +62,7 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Ids"), "Input(Ids) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("PreOut"),
@@ -87,19 +87,18 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
   void Make() override {
     AddInput("X",
              "(Tensor, required) The input Tensor, which the shape is"
-             "[N * D], which N is the size of mini-batch,"
+             "[N, D], which N is the size of mini-batch,"
              "D is the embded size");
     AddInput("W",
              "(Tensor, required), The parameters of hierarchical "
-             "sigmoid operator, each of them is s a 3-D tensor, the shape is"
+             "sigmoid operator, each of them is s a 2-D tensor, the shape is"
              "[num_classes - 1, D]");
-    AddInput("Ids",
+    AddInput("Label",
              "(Tensor, required), The labels of training data. It's a"
              "1-D tensor, which the shape is [1, N]");
     AddInput("Bias",
-             "(Tensor, optional), The bias is a 1-D tensor, "
-             "which is applied to the output, the shape is"
-             "[1, num_classes -1]");
+             "(Tensor, optional), The bias is a tensor with shape"
+             "[1, num_classes - 1]");
     AddOutput("Out",
               "(Tensor, required) The output of hierarchical sigmoid operator."
               "the shape is [N, 1]");
@@ -111,7 +110,7 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(2);
     AddComment(R"DOC(
 The hierarchical sigmoid operator organize the classes into a binary tree.
-At each node, a sigmoid function is used to caculate the probability of
+At each node, a sigmoid function is used to calculate the probability of
 belonging to the right branch. This idea is from
 "F. Morin, Y. Bengio (AISTATS 05):
 Hierarchical Probabilistic Neural Network Language Model."
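For intuition, the per-sample computation this DOC string describes can be sketched in NumPy as below. This is only an illustration, not the operator's kernel; the bit-code layout (binary code of `label + num_classes`) and the `idx`/`bit` formulas are assumed to follow the CodeTable convention used in the Python unit test later in this commit.

    import numpy as np

    def hsigmoid_forward_sketch(x_row, w, label, num_classes, bias=None):
        # Walk the path of tree nodes encoded by (label + num_classes).
        c = label + num_classes
        length = c.bit_length() - 1          # number of nodes on the path
        out = 0.0
        for j in range(length):
            idx = (c >> (j + 1)) - 1         # row of W (and Bias) for this node
            bit = (c >> j) & 1               # branch taken at this node
            z = np.dot(w[idx], x_row)
            if bias is not None:
                z += bias[0][idx]
            # -log sigmoid(+/- z), folded as softrelu(z) - bit * z
            out += np.log(1.0 + np.exp(z)) - bit * z
        return out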
@@ -124,7 +123,7 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Ids"), "Input(Ids) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("PreOut"),
                    "Input(Preout) should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")),
@@ -155,9 +154,14 @@ REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp,
                   ops::HierarchicalSigmoidOpMaker<int>,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp);
-REGISTER_OP_CPU_KERNEL(hierarchical_sigmoid,
-                       ops::HierarchicalSigmoidOpKernel<
-                           paddle::platform::CPUDeviceContext, float>);
-REGISTER_OP_CPU_KERNEL(hierarchical_sigmoid_grad,
-                       ops::HierarchicalSigmoidGradOpKernel<
-                           paddle::platform::CPUDeviceContext, float>);
+REGISTER_OP_CPU_KERNEL(
+    hierarchical_sigmoid,
+    ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext,
+                                     double>);
+REGISTER_OP_CPU_KERNEL(
+    hierarchical_sigmoid_grad,
+    ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
+                                         float>,
+    ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
+                                         double>);
@@ -34,7 +34,7 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* in = ctx.Input<framework::Tensor>("X");
     auto* w = ctx.Input<framework::Tensor>("W");
-    auto* ids = ctx.Input<framework::Tensor>("Ids");
+    auto* label = ctx.Input<framework::Tensor>("Label");
     auto* bias = ctx.Input<framework::Tensor>("Bias");
     auto* out = ctx.Output<framework::Tensor>("Out");
     auto* pre_out = ctx.Output<framework::Tensor>("PreOut");
@@ -50,7 +50,7 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
     zero(dev_ctx, pre_out, static_cast<T>(0.0));
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
     math::RowwiseSum<DeviceContext, T> row_sum;
-    math::MatrixBitCodeFunctor<T> bit_code(num_classes, ids->data<int64_t>());
+    math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());
     std::vector<int64_t> sum_dims({batch_size, 1UL});
     sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
@@ -87,7 +87,7 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
     auto* w_grad = ctx.Output<framework::Tensor>(framework::GradVarName("W"));
     auto* bias_grad =
         ctx.Output<framework::Tensor>(framework::GradVarName("Bias"));
-    auto* ids = ctx.Input<framework::Tensor>("Ids");
+    auto* label = ctx.Input<framework::Tensor>("Label");
     auto* pre_out = ctx.Input<framework::Tensor>("PreOut");
     auto* out_grad =
         ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
@@ -101,9 +101,11 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
     auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
     auto pre_out_grad_mat = EigenMatrix<T>::From(pre_out_grad);
-    math::MatrixBitCodeFunctor<T> bit_code(num_classes, ids->data<int64_t>());
+    math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());
     // softrelu derivative
-    bit_code.OutGrad(&pre_out_grad, *out_grad);
+    Eigen::array<int, 2> bcast({1, static_cast<int>(pre_out_grad.dims()[1])});
+    auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
+    pre_out_grad_mat = out_grad_mat.broadcast(bcast);
     pre_out_grad_mat.device(place) =
         pre_out_grad_mat *
         (static_cast<T>(1.0) - static_cast<T>(1.0) / pre_out_mat.exp());
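The new lines replace the removed OutGrad helper with an Eigen broadcast: the [N, 1] output gradient is tiled across the code_length columns and scaled by the softrelu derivative, sigmoid(z) = 1 - exp(-softrelu(z)). A NumPy sketch of the same arithmetic (shapes assumed, not the kernel itself):

    import numpy as np

    def pre_out_grad_sketch(out_grad, pre_out):
        # out_grad: [N, 1]; pre_out: [N, code_length] holding softrelu values.
        pre_out_grad = np.broadcast_to(out_grad, pre_out.shape).copy()
        # d softrelu(z) / dz = 1 - 1 / exp(softrelu(z)), as in the Eigen code.
        return pre_out_grad * (1.0 - 1.0 / np.exp(pre_out))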
......
@@ -18,32 +18,6 @@ namespace paddle {
 namespace operators {
 namespace math {
-/**
- * CodeTable class should support 3 functions:
- *
- * size_t size()
- *   return the number of ids
- *
- * int getMaxCodeLength()
- *   return the maximal code length
- *
- * Code operator()(size_t i)
- *   return the i-th code. Code class is descriebed below.
- *
- * Code class should support 3 functions:
- *
- * int getLength()
- *   return the length of the code
- *
- * bool calcIndex(int bit)
- *   bit ranges from 0 to getLength() - 1
- *   return the index for the (1+bit) level parent
- *
- * bool calcBit(int bit)
- *   return true if the bit level parent is the right child of (1+bit) level
- *   parent
- *
- */
 template <typename T>
 void MatrixBitCodeFunctor<T>::Add(framework::Tensor* tmat,
                                   const framework::Tensor& vec) {
@@ -192,17 +166,6 @@ void MatrixBitCodeFunctor<T>::Sub(framework::Tensor* tmat) {
   }
 }
-template <typename T>
-void MatrixBitCodeFunctor<T>::OutGrad(framework::Tensor* tmat,
-                                      const framework::Tensor& input) {
-  size_t num_samples = tmat->dims()[0];
-  size_t code_length = tmat->dims()[1];
-  for (size_t i = 0; i < num_samples; ++i)
-    for (size_t j = 0; j < code_length; ++j) {
-      tmat->data<T>()[i * code_length + j] = input.data<T>()[i];
-    }
-}
 template class MatrixBitCodeFunctor<float>;
 template class MatrixBitCodeFunctor<double>;
......
@@ -20,13 +20,39 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 namespace math {
+/**
+ * SimpleCodeTable class should support 3 functions:
+ *
+ * size_t size()
+ *   return the number of ids
+ *
+ * int get_max_code_length()
+ *   return the maximal code length
+ *
+ * SimpleCode operator()(size_t i)
+ *   return the i-th code. Code class is descriebed below.
+ *
+ * SimpleCode class should support 3 functions:
+ *
+ * int get_length()
+ *   return the length of the code
+ *
+ * size_t cal_index(int bit)
+ *   bit ranges from 0 to get_length() - 1
+ *   return the index for the (1+bit) level parent
+ *
+ * bool calc_bit(int bit)
+ *   return true if the bit level parent is the right child of (1+bit) level
+ *   parent
+ *
+ */
 /**
  * return the 1-based index of the highest bit set
  *
  * for x > 0:
  * \f[
- *    findLastSet(x) = 1 + \floor*{\log_{2}x}
+ *    FindLastSet(x) = 1 + \floor*{\log_{2}x}
  * \f]
  */
 inline constexpr size_t FindLastSet(size_t x) {
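A small Python sketch of the semantics this comment documents, mirroring the CodeTable helper used in the Python unit test below; the cal_index formula is an assumption based on that convention, not the C++ SimpleCode implementation itself:

    def find_last_set(x):
        # 1-based index of the highest set bit: 1 + floor(log2(x)) for x > 0
        return x.bit_length()

    class CodeTableSketch(object):
        def __init__(self, num_classes, code):
            self.c = num_classes + code

        def get_length(self):
            return find_last_set(self.c) - 1

        def cal_index(self, bit):
            # index of the (1 + bit) level parent node, in [0, num_classes - 2]
            return (self.c >> (bit + 1)) - 1

        def cal_bit(self, bit):
            # non-zero when the bit-level parent is the right child of its parent
            return self.c & (1 << bit)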
@@ -100,10 +126,6 @@ class MatrixBitCodeFunctor {
    */
   void MulGradError(const framework::Tensor& tmat,
                     const framework::Tensor& weight, framework::Tensor* input);
-  /* For j < code_length
-       tmat(i, j) == input(i)
-   */
-  void OutGrad(framework::Tensor* tmat, const framework::Tensor& input);
   size_t num_classes_;
   const int64_t* ids_;
......
@@ -3571,18 +3571,17 @@ def hsigmoid(input, label, num_classes=2, param_attr=None, bias_attr=None):
         shape=[num_classes - 1, dim],
         is_bias=False,
         dtype=input.dtype)
-    bias = helper.create_parameter(
-        attr=helper.bias_attr,
-        shape=[1, num_classes - 1],
-        is_bias=True,
-        dtype=input.dtype)
+    inputs = {"X": input, "W": weights, "Label": label}
+    if helper.bias_attr:
+        bias = helper.create_parameter(
+            attr=helper.bias_attr,
+            shape=[1, num_classes - 1],
+            is_bias=True,
+            dtype=input.dtype)
+        inputs['Bias'] = bias
     helper.append_op(
         type="hierarchical_sigmoid",
-        inputs={"X": input,
-                "W": weights,
-                "Ids": label,
-                "Bias": bias},
+        inputs=inputs,
         outputs={"Out": out,
                  "PreOut": pre_out},
         attrs={"num_classes": num_classes})
......
@@ -36,7 +36,7 @@ class CodeTable(object):
         return self.c & (1 << bit)
-def hsigmoid(x, w, ids, bias, num_classes):
+def hsigmoid(x, w, label, bias, num_classes):
     global pre_output
     batch_size = x.shape[0]
     code_length = find_latest_set(num_classes - 1)
@@ -45,13 +45,13 @@ def hsigmoid(x, w, ids, bias, num_classes):
     pre_sum = np.zeros((batch_size, 1))
     out = np.zeros((batch_size, 1)).astype("float32")
     for i in range(batch_size):
-        code_table = CodeTable(num_classes, ids[i])
+        code_table = CodeTable(num_classes, label[i])
         length = code_table.get_length()
         for j in range(length):
             idx = code_table.cal_index(j)
             pre_output[i][j] += bias[0][idx]
     for j in range(batch_size):
-        code_table = CodeTable(num_classes, ids[j])
+        code_table = CodeTable(num_classes, label[j])
         length = code_table.get_length()
         for k in range(length):
             idx = code_table.cal_index(k)
@@ -60,10 +60,10 @@ def hsigmoid(x, w, ids, bias, num_classes):
                 sum += w[idx][l] * x[j][l]
             pre_output[j][k] += sum
     # clip[-40.0, 40.0]
-    np.clip(pre_output, -40.0, 40.0)
+    pre_output = np.clip(pre_output, -40.0, 40.0)
     # out(i, 0) = \sum_j bit(i, j) * preout(i, j)
     for i in range(batch_size):
-        code_table = CodeTable(num_classes, ids[i])
+        code_table = CodeTable(num_classes, label[i])
         length = code_table.get_length()
         sum = 0.0
         for j in range(length):
@@ -86,18 +86,18 @@ class TestHSigmoidOp(OpTest):
         batch_size = 1
         x = np.random.random((batch_size, embded_size)).astype("float32")
         w = np.random.random((num_classes - 1, embded_size)).astype("float32")
-        ids = np.random.randint(0, num_classes, batch_size)
+        label = np.random.randint(0, num_classes, batch_size)
         bias = np.random.random((1, num_classes - 1)).astype("float32")
         self.attrs = {'num_classes': num_classes}
-        self.inputs = {'X': x, 'W': w, 'Ids': ids, 'Bias': bias}
-        out = hsigmoid(x, w, ids, bias, num_classes)
+        self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias}
+        out = hsigmoid(x, w, label, bias, num_classes)
         self.outputs = {'PreOut': pre_output, 'Out': out}

     def test_check_output(self):
         self.check_output()

     def test_check_grad(self):
-        self.check_grad(['Bias', 'X', 'W'], 'Out', no_grad_set=set('Ids'))
+        self.check_grad(['Bias', 'X', 'W'], 'Out', no_grad_set=set('Label'))

 if __name__ == '__main__':
......
@@ -176,8 +176,8 @@ class TestBook(unittest.TestCase):
     def test_hsigmoid(self):
         program = Program()
         with program_guard(program):
-            x = layers.data(name='x', shape=[2, 2], dtype='float32')
-            y = layers.data(name='y', shape=[1, 2], dtype='int64')
+            x = layers.data(name='x', shape=[2], dtype='float32')
+            y = layers.data(name='y', shape=[2], dtype='int64')
             self.assertIsNotNone(
                 layers.hsigmoid(
                     input=x, label=y, num_classes=2))
......