提交 c8801e10 编写于 作者: J JiabinYang

grad diff problem to be fixed and need api spec change to be done

上级 f37bd035
......@@ -133,7 +133,8 @@ class SelectedRows {
// SelectedRows are simply concated when adding together. Until a
// SelectedRows add a Tensor, will the duplicate rows be handled.
Vector<int64_t> rows_;
std::unordered_map<int64_t, int64_t> id_to_index_;
std::unordered_map<int64_t, int64_t>
id_to_index_; // should not be used when ids has duplicate member
std::unique_ptr<Tensor> value_{nullptr};
int64_t height_;
std::unique_ptr<RWLock> rwlock_{nullptr};
......
......@@ -91,10 +91,19 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("W",
"(Tensor, required), The parameters of hierarchical "
"sigmoid operator, each of them is a 2-D tensor, the shape is"
"[num_classes - 1, D].");
"[K, D]. Which K is the num of non-leaf node in Path Tree");
AddInput("Label",
"(Tensor, required), The labels of training data. It's a"
"tensor with shape [N, 1].");
AddInput("PTable",
"(Tensor, optional), The Path Table from root to current word"
"it should have shape like [N, L], L is the length of the Path")
.AsDispensable();
AddInput("PCode",
"(Tensor, optional), The Code on each Node of the Path from root "
"to current word"
"it should have shape like [N, L], L is the length of the Path")
.AsDispensable();
AddInput("Bias",
"(Tensor, optional), The bias is a tensor with shape"
"[1, num_classes - 1].");
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <iostream>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/matrix_bit_code.h"
......@@ -34,12 +35,21 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* w = ctx.Input<framework::Tensor>("W");
auto* path = ctx.Input<framework::Tensor>("PTable");
auto* code = ctx.Input<framework::Tensor>("PCode");
auto* label = ctx.Input<framework::Tensor>("Label");
auto* bias = ctx.Input<framework::Tensor>("Bias");
auto* out = ctx.Output<framework::Tensor>("Out");
auto* pre_out = ctx.Output<framework::Tensor>("PreOut");
size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
int64_t code_length = math::FindLastSet(num_classes - 1);
bool is_custom = false;
if (path) {
is_custom = true;
} else {
is_custom = false;
}
int64_t code_length =
path ? path->dims()[1] : math::FindLastSet(num_classes - 1);
int64_t batch_size = in->dims()[0];
framework::Tensor sum;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
......@@ -52,7 +62,15 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
zero(dev_ctx, pre_out, static_cast<T>(0.0));
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
math::RowwiseSum<DeviceContext, T> row_sum;
math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());
std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code;
if (!is_custom) {
bit_code.reset(new math::MatrixBitCodeFunctor<T>(num_classes,
label->data<int64_t>()));
} else {
bit_code.reset(new math::MatrixBitCodeFunctor<T>(path, code,
label->data<int64_t>()));
}
std::vector<int64_t> sum_dims({batch_size, 1UL});
sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
......@@ -60,15 +78,15 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
out->mutable_data<T>(ctx.GetPlace());
auto out_mat = framework::EigenVector<T>::Flatten(*out);
if (bias) {
bit_code.Add(pre_out, *bias);
bit_code->Add(pre_out, *bias);
}
bit_code.Mul(pre_out, *w, *in);
bit_code->Mul(pre_out, *w, *in);
// clip to [-40, 40]
Transform<DeviceContext> trans;
trans(ctx.template device_context<DeviceContext>(), pre_out_data,
pre_out_data + pre_out->numel(), pre_out_data,
ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
bit_code.Sum(*pre_out, out, static_cast<T>(-1));
bit_code->Sum(*pre_out, out, static_cast<T>(-1));
// use softrelu to calculate cross entropy
pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
row_sum(dev_ctx, *pre_out, &sum);
......@@ -86,6 +104,8 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* w = ctx.Input<framework::Tensor>("W");
auto* path = ctx.Input<framework::Tensor>("PTable");
auto* code = ctx.Input<framework::Tensor>("PCode");
auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* w_grad = ctx.Output<framework::Tensor>(framework::GradVarName("W"));
auto* bias_grad =
......@@ -105,7 +125,22 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
zero(dev_ctx, w_grad, static_cast<T>(0.0));
size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());
bool is_custom = false;
if (path) {
is_custom = true;
} else {
is_custom = false;
}
std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code;
if (!is_custom) {
bit_code.reset(new math::MatrixBitCodeFunctor<T>(num_classes,
label->data<int64_t>()));
} else {
bit_code.reset(new math::MatrixBitCodeFunctor<T>(path, code,
label->data<int64_t>()));
}
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
......@@ -116,7 +151,7 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
// softrelu derivative
pre_out_grad_mat.device(place) =
static_cast<T>(1.0) - static_cast<T>(1.0) / pre_out_mat.exp();
bit_code.Sub(&pre_out_grad); // the gradient of clip(w * x + b)
bit_code->Sub(&pre_out_grad); // the gradient of clip(w * x + b)
pre_out_grad_mat.device(place) =
pre_out_grad_mat * out_grad_mat.broadcast(bcast);
// TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
......@@ -124,10 +159,10 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
if (bias_grad) {
bias_grad->mutable_data<T>(ctx.GetPlace());
zero(dev_ctx, bias_grad, static_cast<T>(0.0));
bit_code.AddGrad(pre_out_grad, bias_grad);
bit_code->AddGrad(pre_out_grad, bias_grad);
}
bit_code.MulGradWeight(pre_out_grad, w_grad, *in);
bit_code.MulGradError(pre_out_grad, *w, in_grad);
bit_code->MulGradWeight(pre_out_grad, w_grad, *in);
bit_code->MulGradError(pre_out_grad, *w, in_grad);
}
};
......
......@@ -21,14 +21,13 @@ namespace math {
template <typename T>
void MatrixBitCodeFunctor<T>::Add(framework::Tensor* tmat,
const framework::Tensor& vec) {
SimpleCodeTable code_table(num_classes_);
size_t batch_size = tmat->dims()[0];
size_t width = tmat->dims()[1];
for (size_t i = 0; i < batch_size; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
auto code = code_table->get_code(i);
int code_length = code->get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
size_t index = code->calc_index(j);
tmat->data<T>()[i * width + j] += vec.data<T>()[index];
}
}
......@@ -37,14 +36,13 @@ void MatrixBitCodeFunctor<T>::Add(framework::Tensor* tmat,
template <typename T>
void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat,
framework::Tensor* vec) {
SimpleCodeTable code_table(num_classes_);
size_t batch_size = tmat.dims()[0];
size_t width = tmat.dims()[1];
for (size_t i = 0; i < batch_size; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
auto code = code_table->get_code(i);
int code_length = code->get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
size_t index = code->calc_index(j);
vec->data<T>()[index] += tmat.data<T>()[i * width + j];
}
}
......@@ -53,15 +51,14 @@ void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat,
template <typename T>
void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor& tmat,
framework::Tensor* sum, T scale_sum) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat.dims()[0];
size_t o_width = tmat.dims()[1];
for (size_t i = 0; i < num_samples; ++i) {
T sm = static_cast<T>(0.0);
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
auto code = code_table->get_code(i);
int code_length = code->get_length();
for (int j = 0; j < code_length; ++j) {
if (code.calc_bit(j)) {
if (code->calc_bit(j)) {
// calc_bit starts from right most bit, while data in tmat[i] is in the
// reverse order.
sm += tmat.data<T>()[i * o_width + j];
......@@ -75,7 +72,6 @@ template <typename T>
void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
const framework::Tensor& weight,
const framework::Tensor& input) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat->dims()[0];
size_t tmat_width = tmat->dims()[1];
size_t input_width = input.dims()[1];
......@@ -84,10 +80,10 @@ void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
auto weight_value = weight.data<T>();
auto input_value = input.data<T>();
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
auto code = code_table->get_code(i);
int code_length = code->get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
size_t index = code->calc_index(j);
T sum = static_cast<T>(0.0);
for (size_t k = 0; k < input_width; ++k) {
sum += weight_value[weight_width * index + k] *
......@@ -102,7 +98,6 @@ template <typename T>
void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
framework::Tensor* weight,
const framework::Tensor& input) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat.dims()[0];
size_t input_width = input.dims()[1];
size_t tmat_width = tmat.dims()[1];
......@@ -111,10 +106,10 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
auto weight_value = weight->data<T>();
auto input_value = input.data<T>();
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
auto code = code_table->get_code(i);
int code_length = code->get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
size_t index = code->calc_index(j);
for (size_t k = 0; k < input_width; ++k) {
weight_value[weight_width * index + k] +=
......@@ -128,7 +123,6 @@ template <typename T>
void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
const framework::Tensor& weight,
framework::Tensor* input) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat.dims()[0];
size_t tmat_width = tmat.dims()[1];
size_t input_width = input->dims()[1];
......@@ -138,10 +132,10 @@ void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
auto input_value = input->data<T>();
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
auto code = code_table->get_code(i);
int code_length = code->get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
size_t index = code->calc_index(j);
for (size_t k = 0; k < input_width; ++k) {
input_value[input_width * i + k] +=
......@@ -154,14 +148,13 @@ void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
template <typename T>
void MatrixBitCodeFunctor<T>::Sub(framework::Tensor* tmat) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat->dims()[0];
size_t o_width = tmat->dims()[1];
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
auto code = code_table->get_code(i);
int code_length = code->get_length();
for (int j = 0; j < code_length; ++j) {
if (code.calc_bit(j)) {
if (code->calc_bit(j)) {
tmat->data<T>()[i * o_width + j] -= 1;
}
}
......
......@@ -93,9 +93,27 @@ inline int clz(const T& value) {
inline size_t FindLastSet(size_t x) { return sizeof(size_t) * 8 - clz(x); }
#endif // !_WIN32
}
// Interface for the code of a single sample. It exists so that more than one
// kind of code can be created and used through the same calls.
class Code {
 public:
  virtual ~Code() = default;

  // Number of positions in this code; callers loop over bit in
  // [0, get_length()).
  virtual int get_length() const = 0;

  // Index associated with position `bit` (callers use it to address the
  // weight rows / bias entries).
  virtual size_t calc_index(int bit) const = 0;

  // Binary value stored at position `bit`.
  virtual bool calc_bit(int bit) const = 0;
};
// Interface for a table that hands out Code objects. It exists so that more
// than one kind of code table can be created and used through the same calls.
class CodeTable {
 public:
  virtual ~CodeTable() = default;

  // Builds the Code object selected by `code`; the caller owns the result.
  virtual std::unique_ptr<Code> get_code(int64_t code) const = 0;

  // Size reported by the table (its meaning is defined by the implementation).
  virtual size_t size() const = 0;

  // Upper bound on get_length() of the codes produced by this table.
  virtual int get_max_code_length() const = 0;
};
struct SimpleCode {
SimpleCode(size_t code, size_t num_classes) : c_(code + num_classes) {}
class SimpleCode : public Code {
public:
SimpleCode(size_t code, size_t num_classes, const int64_t* ids)
: c_(static_cast<size_t>(ids[code]) + num_classes) {}
/**
* Here the id of root shoud be 1 rather than 0, thus the encoding of class c
* is `c + num_classes` and all siblings can get the same weight indice using
......@@ -105,31 +123,111 @@ struct SimpleCode {
* Binary classification path is the suffixes of encoding, thus leave out the
* left most bit in calc_bit.
*/
inline size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; }
inline bool calc_bit(int bit) const { return c_ & (1 << bit); }
inline int get_length() const { return FindLastSet(c_) - 1; }
size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; }
// Tests bit `bit` of the encoding c_. The mask is built as size_t so the
// shift is carried out at the width of c_; with a plain `1 << bit` the shift
// happens in int and is undefined once `bit` reaches the width of int.
bool calc_bit(int bit) const {
  return c_ & (static_cast<size_t>(1) << bit);
}
int get_length() const { return FindLastSet(c_) - 1; }
private:
size_t c_;
};
struct SimpleCodeTable {
explicit SimpleCodeTable(size_t num_classes) : num_classes_(num_classes) {}
SimpleCode operator()(size_t code) const {
return SimpleCode(code, num_classes_);
/**
 * Code of one sample whose path is given explicitly by two tensors instead of
 * being derived from the class id.
 *
 * Row `index` of `ptable` holds the index returned by calc_index for every
 * position of the path; the first entry equal to -1 terminates the path (see
 * get_length). Row `index` of `pcode` holds the value returned by calc_bit for
 * the same positions. Both tensors are read with the row width of `ptable`,
 * i.e. they are assumed to have the same shape, and their elements are read
 * as type R.
 */
template <typename R>
class CustomCode : public Code {
 public:
  CustomCode(const framework::Tensor* ptable, const framework::Tensor* pcode,
             const int64_t* ids, const int index)
      : ptable_(ptable), pcode_(pcode), ids_(ids), index_(index) {}

  // Index stored at position `bit` of this sample's row in ptable.
  size_t calc_index(int bit) const override {
    return ptable_->data<R>()[offset(bit)];
  }

  // Value stored at position `bit` of this sample's row in pcode.
  bool calc_bit(int bit) const override {
    return pcode_->data<R>()[offset(bit)];
  }

  // Number of leading entries of this sample's ptable row that differ from
  // -1; at most the row width.
  int get_length() const override {
    const int64_t width = ptable_->dims()[1];
    const R* row = ptable_->data<R>() + static_cast<int64_t>(index_) * width;
    int length = 0;
    while (length < width && row[length] != -1) {
      ++length;
    }
    return length;
  }

 private:
  // Flat offset of (index_, bit). Computed in 64 bits so that
  // index_ * width cannot overflow a 32-bit int for large tensors.
  int64_t offset(int bit) const {
    return static_cast<int64_t>(index_) * ptable_->dims()[1] + bit;
  }

  const framework::Tensor* ptable_;
  const framework::Tensor* pcode_;
  const int64_t* ids_;
  const int index_;
};
class SimpleCodeTable : public CodeTable {
public:
explicit SimpleCodeTable(size_t num_classes, const int64_t* ids)
: num_classes_(num_classes), ids_(ids) {}
std::unique_ptr<Code> get_code(int64_t code) const {
std::unique_ptr<Code> coder(new SimpleCode(code, num_classes_, ids_));
return coder;
}
size_t size() const { return num_classes_; }
int get_max_code_length() const { return FindLastSet(num_classes_ - 1); }
private:
size_t num_classes_;
const int64_t* ids_;
};
template <typename R>
class CustomCodeTable : public CodeTable {
public:
explicit CustomCodeTable(const framework::Tensor* ptable,
const framework::Tensor* pcode, const int64_t* ids)
: ptable_(ptable), pcode_(pcode), ids_(ids) {}
std::unique_ptr<Code> get_code(int64_t code) const {
std::unique_ptr<Code> coder(new CustomCode<R>(ptable_, pcode_, ids_, code));
return coder;
}
size_t size() const { return static_cast<size_t>(ptable_->dims()[1]); }
int get_max_code_length() const {
return static_cast<size_t>(ptable_->dims()[1]);
}
private:
const framework::Tensor* ptable_;
const framework::Tensor* pcode_;
const int64_t* ids_;
};
template <typename T>
class MatrixBitCodeFunctor {
public:
explicit MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids)
: num_classes_(num_classes), ids_(ids) {}
: num_classes_(num_classes),
ids_(ids),
code_table(new SimpleCodeTable(num_classes, ids)) {}
explicit MatrixBitCodeFunctor(const framework::Tensor* ptable,
const framework::Tensor* pcode,
const int64_t* ids)
: num_classes_(static_cast<size_t>(ptable->dims()[1])),
ids_(ids),
code_table(new CustomCodeTable<int64_t>(ptable, pcode, ids)) {}
/* For j < code_length
tmat(i, j) += vec(0, index(i, j))
*/
......@@ -168,6 +266,7 @@ class MatrixBitCodeFunctor {
size_t num_classes_;
const int64_t* ids_;
std::unique_ptr<CodeTable> code_table;
};
} // namespace math
} // namespace operators
......
......@@ -4349,6 +4349,8 @@ def nce(input,
def hsigmoid(input,
label,
num_classes,
ptabl=None,
pcode=None,
param_attr=None,
bias_attr=None,
name=None):
......@@ -4372,6 +4374,12 @@ def hsigmoid(input,
label (Variable): The tensor variable contains labels of training data.
It's a tensor with shape is :math:`[N \\times 1]`.
num_classes: (int), The number of classes, must not be less than 2.
ptable: (Variable|None) Stores, for each sample in the batch, the path from the sample to the root;
it should be in leaf -> root order.
ptable should have the same shape as pcode. For each sample i, ptable[i] is an np.array-like
structure whose elements are the indexes of the sample's parent nodes in the weight matrix.
pcode: (Variable|None) Stores the code of each sample in the batch. Each code consists of
the codes of the sample's parent nodes; it should be in leaf -> root order.
param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
of hsigmoid. If it is set to None or one attribute of ParamAttr, hsigmoid
will create ParamAttr as param_attr. If the Initializer of the param_attr
......@@ -4403,12 +4411,25 @@ def hsigmoid(input,
dim = input.shape[1]
if num_classes < 2:
raise ValueError("num_classes must not be less than 2.")
if (ptable is not None) and (pcode is None):
raise ValueError("pcode should not be None when ptable has been set")
elif (ptable is None) and (pcode is not None):
raise ValueError("ptable should not be None when pcode has been set")
else:
pass
weights = helper.create_parameter(
attr=helper.param_attr,
shape=[num_classes - 1, dim],
is_bias=False,
dtype=input.dtype)
inputs = {"X": input, "W": weights, "Label": label}
inputs = {
"X": input,
"W": weights,
"PTable": ptable,
"PCode": pcode,
"Label": label
}
if helper.bias_attr:
bias = helper.create_parameter(
attr=helper.bias_attr,
......
......@@ -138,8 +138,11 @@ class OpTest(unittest.TestCase):
cls.dtype = "float32"
cls.outputs = {}
np.random.seed(123)
random.seed(124)
# np.random.seed(123)
# random.seed(124)
np.random.seed(190)
random.seed(200)
@classmethod
def tearDownClass(cls):
......
......@@ -40,6 +40,29 @@ class CodeTable(object):
return self.c & (1 << bit)
class CodeTableWithCustomTree(object):
    """Reference view of one sample's row in a custom path/code table.

    ptable and pcode are indexed as table[index][bit]; `index` selects the
    sample's row.
    """

    def __init__(self, ptable, pcode, index):
        self.ptable_ = ptable
        self.pcode_ = pcode
        self.index_ = index

    def cal_index(self, bit):
        """Return the entry at position `bit` of this sample's ptable row."""
        path = self.ptable_[self.index_]
        return path[bit]

    def get_length(self):
        """Count the leading non-negative entries of this sample's ptable row."""
        path = self.ptable_[self.index_]
        count = 0
        # Stop at the first negative entry, or at the end of the row.
        while count < len(path) and path[count] >= 0:
            count += 1
        return count

    def cal_bit(self, bit):
        """Return the entry at position `bit` of this sample's pcode row."""
        codes = self.pcode_[self.index_]
        return codes[bit]
def hsigmoid(x, w, label, bias, num_classes):
batch_size = x.shape[0]
code_length = find_latest_set(num_classes - 1)
......@@ -48,10 +71,12 @@ def hsigmoid(x, w, label, bias, num_classes):
pre_sum = np.zeros((batch_size, 1))
out = np.zeros((batch_size, 1)).astype("float32")
for i in range(batch_size):
#print("\n leaf {leaf}: \n".format(leaf = label[i]))
code_table = CodeTable(num_classes, label[i])
length = code_table.get_length()
for j in range(length):
idx = code_table.cal_index(j)
#print("index {index} ".format(index = j))
pre_output[i][j] += bias[0][idx]
for i in range(batch_size):
code_table = CodeTable(num_classes, label[i])
......@@ -63,10 +88,12 @@ def hsigmoid(x, w, label, bias, num_classes):
pre_output = np.clip(pre_output, -40.0, 40.0)
# out(i, 0) = \sum_j bit(i, j) * preout(i, j)
for i in range(batch_size):
#print("\n leaf {leaf}: \n".format(leaf = label[i]))
code_table = CodeTable(num_classes, label[i])
length = code_table.get_length()
sum = 0.0
for j in range(length):
#print("bit {bit} ".format(bit = code_table.cal_bit(j)))
if code_table.cal_bit(j):
sum += pre_output[i][j]
out[i] = -1.0 * sum
......@@ -77,25 +104,101 @@ def hsigmoid(x, w, label, bias, num_classes):
return pre_output, out
class TestHSigmoidOp(OpTest):
def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes):
    """NumPy reference for hierarchical sigmoid with a user-supplied tree.

    Args:
        x: input, indexed as x[i] per sample; x.shape[0] is the batch size.
        w: weights, indexed as w[idx] with idx taken from ptable.
        ptable: per-sample path rows; a negative entry ends the path.
        pcode: per-sample code rows, read at the same positions as ptable.
        label: unused here; kept so the signature parallels hsigmoid().
        bias: indexed as bias[0][idx]; may be None, in which case no bias is
            added.
        num_classes: unused here; kept so the signature parallels hsigmoid().

    Returns:
        (pre_output, out): pre_output has shape (batch_size, len(ptable[0])),
        out is float32 with shape (batch_size, 1).
    """
    batch_size = x.shape[0]
    code_length = len(ptable[0])
    pre_output = np.zeros((batch_size, code_length))
    out = np.zeros((batch_size, 1)).astype("float32")
    codes = [
        CodeTableWithCustomTree(ptable, pcode, i) for i in range(batch_size)
    ]

    # pre_output(i, j) = bias(0, idx(i, j)) + w[idx(i, j)] . x[i]
    for i, code in enumerate(codes):
        for j in range(code.get_length()):
            idx = code.cal_index(j)
            if bias is not None:
                pre_output[i][j] += bias[0][idx]
            pre_output[i][j] += np.dot(w[idx], x[i])

    # clip[-40.0, 40.0]
    pre_output = np.clip(pre_output, -40.0, 40.0)

    # out(i, 0) = -\sum_j bit(i, j) * preout(i, j)
    for i, code in enumerate(codes):
        path_sum = 0.0
        for j in range(code.get_length()):
            if code.cal_bit(j):
                path_sum += pre_output[i][j]
        out[i] = -1.0 * path_sum

    # soft relu; note it is applied to every column, so positions past the end
    # of a path (still 0 here) each contribute log(2) to the row sum.
    pre_output = np.log(1 + np.exp(pre_output))
    pre_sum = pre_output.sum(1).reshape((batch_size, 1))
    out += pre_sum
    return pre_output, out
# class TestHSigmoidOp(OpTest):
# def setUp(self):
# self.op_type = "hierarchical_sigmoid"
# num_classes = 6
# feature_size = 8
# batch_size = 7
# x = np.random.random((batch_size, feature_size)).astype("float32")
# w = np.random.random((num_classes - 1, feature_size)).astype("float32")
# label = np.random.randint(0, num_classes, (batch_size, 1))
# bias = np.random.random((1, num_classes - 1)).astype("float32")
# self.attrs = {'num_classes': num_classes}
# self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias}
# pre_output, out = hsigmoid(x, w, label, bias, num_classes)
# self.outputs = {'PreOut': pre_output, 'Out': out}
# def test_check_output(self):
# self.check_output()
# def test_check_grad(self):
# self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label'))
class TestHSigmoidOpWithCostumTree(OpTest):
    """Checks hierarchical_sigmoid when PTable/PCode describe a custom tree.

    Expected outputs come from the NumPy reference hsigmoidWithCustomTree.
    """

    def setUp(self):
        self.op_type = "hierarchical_sigmoid"
        # using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample
        num_classes = 6
        feature_size = 8
        batch_size = 4
        x = np.random.random((batch_size, feature_size)).astype("float32") * 10
        w = np.random.random(
            (num_classes - 1, feature_size)).astype("float32") * 10
        label = np.array([0, 1, 4, 5])
        # np.array to store 1,2,5,6s' non-leaf path (root -> leaf); rows are
        # padded with -1 to a common width
        ptable = np.array([(0, 2, -1, -1, -1), (0, 1, 3, -1, -1),
                           (0, 1, 4, -1, -1), (0, 2, -1, -1, -1)])
        # np.array to store the code at each node of the paths above, padded
        # with -1 to the same width as ptable
        pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1),
                          (1, 0, 0, -1, -1), (0, 1, -1, -1, -1)])
        bias = np.random.random((1, num_classes - 1)).astype("float32")
        self.attrs = {'num_classes': num_classes}
        self.inputs = {
            'X': x,
            'W': w,
            'PTable': ptable,
            'PCode': pcode,
            'Label': label,
            'Bias': bias
        }
        pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label,
                                                 bias, num_classes)
        self.outputs = {'PreOut': pre_output, 'Out': out}

    def test_check_output(self):
        print("checking output in CostumTree")
        self.check_output()

    def test_check_grad(self):
        print("checking outputGrad in CostumTree")
        # set(['Label']) holds the variable name itself; set('Label') would
        # instead hold the single characters 'L', 'a', 'b', 'e', 'l'.
        self.check_grad(
            ['Bias', 'X', 'W'], ['Out'], no_grad_set=set(['Label']))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册