Unverified commit 512329b0, authored by pangyoki, committed by GitHub

add asExtra for nce op (#35474)

* add asExtra for nce op

* fix unittest error in macos

* remove asExtra for is_test
Parent commit: 4beaa754
......@@ -33,10 +33,13 @@ class NCEOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "nce");
OP_INOUT_CHECK(ctx->HasOutput("Cost"), "Output", "Cost", "nce");
bool is_test = ctx->Attrs().Get<bool>("is_test");
if (!is_test) {
OP_INOUT_CHECK(ctx->HasOutput("SampleLogits"), "Output", "SampleLogits",
"nce");
OP_INOUT_CHECK(ctx->HasOutput("SampleLabels"), "Output", "SampleLabels",
"nce");
}
auto x_dims = ctx->GetInputDim("Input");
auto label_dims = ctx->GetInputDim("Label");
......@@ -89,6 +92,7 @@ class NCEOp : public framework::OperatorWithKernel {
out_dims.push_back(1);
ctx->SetOutputDim("Cost", framework::make_ddim(out_dims));
if (!is_test) {
// set dims of output(SampleOut)
std::vector<int64_t> sample_out_dims;
sample_out_dims.push_back(x_dims[0]);
......@@ -97,6 +101,7 @@ class NCEOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("SampleLogits", framework::make_ddim(sample_out_dims));
ctx->SetOutputDim("SampleLabels", framework::make_ddim(sample_out_dims));
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
......@@ -162,14 +167,16 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
"Given X is the dot product of input tensor and sampled labels' "
"weights."
"Then 'SampleLogits' is sigmoid(X).")
.AsIntermediate();
.AsIntermediate()
.AsExtra();
AddOutput("SampleLabels",
"An intermediate tensor of shape[batch_size, num_neg_samples + "
"num_pos_samples]."
"This tensor is output of forward kernel and used in backward "
"kernel to compute grads."
"")
.AsIntermediate();
.AsIntermediate()
.AsExtra();
AddAttr<int>("num_total_classes",
"Total number of classes in all samples.");
......@@ -189,28 +196,38 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
// for parameter prefetch
AddAttr<bool>("remote_prefetch", "").SetDefault(false);
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.")
.SetDefault(0)
.AsExtra();
AddAttr<std::vector<int64_t>>("height_sections",
"Height for each output SelectedRows.")
.SetDefault(std::vector<int64_t>({}));
.SetDefault(std::vector<int64_t>({}))
.AsExtra();
AddAttr<std::vector<std::string>>(
"epmap",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints in the order of input variables for mapping")
.SetDefault({});
.SetDefault({})
.AsExtra();
AddAttr<std::vector<std::string>>(
"table_names",
"(string vector, the split table names that will be fetched from "
"parameter server)"
"in the order of input variables for mapping")
.SetDefault({});
.SetDefault({})
.AsExtra();
AddAttr<std::vector<int>>("custom_neg_classes",
"This attribute only be used in unitest. Classes "
"in this list wiil be used as negative classes "
"for every samples. Under normal conditions, "
"user should avoid setting this attribute.")
.SetDefault({});
.SetDefault({})
.AsExtra();
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference "
"only, false for training.")
.SetDefault(false);
AddComment(R"DOC(
Compute and return the noise-contrastive estimation training loss. See
`Noise-contrastive estimation: A new estimation principle for unnormalized
......
......@@ -41,7 +41,7 @@ using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename DeviceContext, typename T>
void PrepareSamples(const framework::ExecutionContext &context,
Sampler *sampler) {
Sampler *sampler, Tensor *sample_labels) {
auto label = context.Input<Tensor>("Label");
const int64_t *label_data = label->data<int64_t>();
auto label_dims = label->dims();
......@@ -49,7 +49,6 @@ void PrepareSamples(const framework::ExecutionContext &context,
std::vector<int> custom_neg_classes =
context.Attr<std::vector<int>>("custom_neg_classes");
auto sample_labels = context.Output<Tensor>("SampleLabels");
auto sample_labels_dims = sample_labels->dims();
int64_t *sample_labels_data =
sample_labels->mutable_data<int64_t>(context.GetPlace());
......@@ -82,6 +81,7 @@ class NCEKernel : public framework::OpKernel<T> {
int seed = context.Attr<int>("seed");
int num_total_classes = context.Attr<int>("num_total_classes");
int num_neg_samples = context.Attr<int>("num_neg_samples");
bool is_test = context.Attr<bool>("is_test");
Sampler *sampler;
switch (sampler_type) {
......@@ -139,8 +139,29 @@ class NCEKernel : public framework::OpKernel<T> {
}
}
PrepareSamples<DeviceContext, T>(context, sampler);
auto sample_labels = context.Output<Tensor>("SampleLabels");
std::vector<int64_t> sample_out_dims;
auto label = context.Input<Tensor>("Label");
Tensor *sample_labels;
Tensor *sample_out;
Tensor sample_labels_tmp, sample_out_tmp;
if (is_test) {
// set dims of output(SampleOut)
int num_true_classes = label->dims().size() == 2 ? label->dims()[1] : 1;
sample_out_dims.push_back((context.Input<Tensor>("Input"))->dims()[0]);
sample_out_dims.push_back(
(num_true_classes == -1) ? -1 : (num_neg_samples + num_true_classes));
sample_labels = &sample_labels_tmp;
sample_labels->Resize(framework::make_ddim(sample_out_dims));
sample_out = &sample_out_tmp;
sample_out->Resize(framework::make_ddim(sample_out_dims));
} else {
sample_labels = context.Output<Tensor>("SampleLabels");
sample_out = context.Output<Tensor>("SampleLogits");
}
PrepareSamples<DeviceContext, T>(context, sampler, sample_labels);
const int64_t *sample_labels_data = sample_labels->data<int64_t>();
for (int x = 0; x < sample_labels->numel(); x++) {
......@@ -152,9 +173,7 @@ class NCEKernel : public framework::OpKernel<T> {
x, sample_labels_data[x]));
}
auto sample_out = context.Output<Tensor>("SampleLogits");
T *sample_out_data = sample_out->mutable_data<T>(context.GetPlace());
auto label = context.Input<Tensor>("Label");
auto sample_weight = context.Input<Tensor>("SampleWeight");
const T *sample_weight_data = nullptr;
if (sample_weight != nullptr) {
......
......@@ -77,7 +77,8 @@ class TestNCE(OpTest):
'custom_neg_classes': list(range(num_neg_samples)),
'seed': 0,
'sampler': 0,
'is_sparse': is_sparse
'is_sparse': is_sparse,
'is_test': self.is_test
}
self.inputs = {
'Input': input,
......@@ -87,6 +88,9 @@ class TestNCE(OpTest):
'SampleWeight': sample_weight
}
# Hook for subclasses: selects whether the nce op runs in inference mode.
# Default is training mode (is_test=False), which makes the test also
# produce/check the SampleLogits and SampleLabels outputs.
# NOTE(review): method-body indentation was lost in extraction — in the
# real file this is nested inside the TestNCE class.
def set_is_test(self):
self.is_test = False
def set_data(self):
self.generate_data(5, 25, 100, 1, 2, False)
......@@ -95,6 +99,9 @@ class TestNCE(OpTest):
self.inputs['Bias'], self.inputs['SampleWeight'],
self.inputs['Label'], self.attrs['num_total_classes'],
self.attrs['num_neg_samples'])
if self.is_test:
self.outputs = {'Cost': out[0]}
else:
self.outputs = {
'Cost': out[0],
'SampleLogits': out[1],
......@@ -103,6 +110,7 @@ class TestNCE(OpTest):
def setUp(self):
self.op_type = 'nce'
self.set_is_test()
self.set_data()
self.compute()
......@@ -119,6 +127,15 @@ class TestNCECase1Tensor(TestNCE):
self.generate_data(10, 20, 100, 2, 5, False)
# Variant of TestNCE that exercises the nce op with is_test=True
# (inference-only mode); per the compute() path above, only the 'Cost'
# output is produced in that mode.
# NOTE(review): class-body indentation was lost in extraction — in the
# real file the methods below are nested inside the class.
class TestNCETensorIsTest(TestNCE):
# if is_test = True, there's no need to calculate grad
def set_is_test(self):
self.is_test = True
# Override the gradient check with a no-op: no backward pass exists
# (or is needed) when the op runs with is_test=True.
def test_check_grad(self):
pass
class TestNCECase1SelectedRows(unittest.TestCase):
def setUp(self):
self.base_lr = 0.0001
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register