Unverified · Commit 512329b0 authored by pangyoki, committed by GitHub

add asExtra for nce op (#35474)

* add asExtra for nce op

* fix unittest error in macos

* remove asExtra for is_test
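For context: in PaddlePaddle's operator registration, `.AsExtra()` tags an output or attribute as auxiliary so it is kept out of the op's core definition, while `.AsIntermediate()` marks an output that exists only to carry data from the forward kernel to the backward kernel. A minimal sketch of the chaining pattern this commit applies — the maker class and names below are hypothetical, for illustration only; the real chains are in the nce op diff that follows:

```cpp
#include "paddle/fluid/framework/op_registry.h"

// Hypothetical maker, for illustration only: shows how .AsExtra() is
// chained onto existing AddOutput/AddAttr registrations, the same
// pattern this commit applies to the nce op.
class ExampleOpMaker : public paddle::framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddOutput("ExampleOut", "an intermediate, forward-to-backward tensor")
        .AsIntermediate()  // not a user-visible result
        .AsExtra();        // excluded from the op's core definition
    AddAttr<int>("example_attr", "an auxiliary attribute")
        .SetDefault(0)
        .AsExtra();
  }
};
```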
Parent 4beaa754
@@ -33,10 +33,13 @@ class NCEOp : public framework::OperatorWithKernel {
     OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "nce");
     OP_INOUT_CHECK(ctx->HasOutput("Cost"), "Output", "Cost", "nce");
+    bool is_test = ctx->Attrs().Get<bool>("is_test");
+    if (!is_test) {
     OP_INOUT_CHECK(ctx->HasOutput("SampleLogits"), "Output", "SampleLogits",
                    "nce");
     OP_INOUT_CHECK(ctx->HasOutput("SampleLabels"), "Output", "SampleLabels",
                    "nce");
+    }
     auto x_dims = ctx->GetInputDim("Input");
     auto label_dims = ctx->GetInputDim("Label");
@@ -89,6 +92,7 @@ class NCEOp : public framework::OperatorWithKernel {
     out_dims.push_back(1);
     ctx->SetOutputDim("Cost", framework::make_ddim(out_dims));
+    if (!is_test) {
     // set dims of output(SampleOut)
     std::vector<int64_t> sample_out_dims;
     sample_out_dims.push_back(x_dims[0]);
@@ -97,6 +101,7 @@ class NCEOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("SampleLogits", framework::make_ddim(sample_out_dims));
     ctx->SetOutputDim("SampleLabels", framework::make_ddim(sample_out_dims));
     }
+  }

  protected:
   framework::OpKernelType GetExpectedKernelType(
@@ -162,14 +167,16 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
               "Given X is the dot product of input tensor and sampled labels' "
               "weights."
               "Then 'SampleLogits' is sigmoid(X).")
-        .AsIntermediate();
+        .AsIntermediate()
+        .AsExtra();
     AddOutput("SampleLabels",
               "An intermediate tensor of shape[batch_size, num_neg_samples + "
               "num_pos_samples]."
               "This tensor is output of forward kernel and used in backward "
               "kernel to compute grads."
               "")
-        .AsIntermediate();
+        .AsIntermediate()
+        .AsExtra();
     AddAttr<int>("num_total_classes",
                  "Total number of classes in all samples.");
@@ -189,28 +196,38 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
     // for parameter prefetch
     AddAttr<bool>("remote_prefetch", "").SetDefault(false);
-    AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
+    AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.")
+        .SetDefault(0)
+        .AsExtra();
     AddAttr<std::vector<int64_t>>("height_sections",
                                   "Height for each output SelectedRows.")
-        .SetDefault(std::vector<int64_t>({}));
+        .SetDefault(std::vector<int64_t>({}))
+        .AsExtra();
     AddAttr<std::vector<std::string>>(
         "epmap",
         "(string vector, default 127.0.0.1:6164)"
         "Server endpoints in the order of input variables for mapping")
-        .SetDefault({});
+        .SetDefault({})
+        .AsExtra();
     AddAttr<std::vector<std::string>>(
         "table_names",
         "(string vector, the split table names that will be fetched from "
         "parameter server)"
         "in the order of input variables for mapping")
-        .SetDefault({});
+        .SetDefault({})
+        .AsExtra();
     AddAttr<std::vector<int>>("custom_neg_classes",
                               "This attribute only be used in unitest. Classes "
                               "in this list wiil be used as negative classes "
                               "for every samples. Under normal conditions, "
                               "user should avoid setting this attribute.")
-        .SetDefault({});
+        .SetDefault({})
+        .AsExtra();
+    AddAttr<bool>("is_test",
+                  "(bool, default false) Set to true for inference "
+                  "only, false for training.")
+        .SetDefault(false);
     AddComment(R"DOC(
 Compute and return the noise-contrastive estimation training loss. See
 `Noise-contrastive estimation: A new estimation principle for unnormalized
......
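Taken together, the two guarded regions above make the sampling outputs optional at inference time: with `is_test` set, `InferShape` neither requires nor shapes `SampleLogits`/`SampleLabels`. A condensed view of the new control flow, reassembled from the hunks above (not the verbatim function body; `out_dims` and `sample_out_dims` are built by the surrounding code shown in the diff):

```cpp
// Condensed from NCEOp::InferShape after this commit: the sampling
// outputs are only checked and shaped in training mode.
bool is_test = ctx->Attrs().Get<bool>("is_test");
if (!is_test) {
  OP_INOUT_CHECK(ctx->HasOutput("SampleLogits"), "Output", "SampleLogits",
                 "nce");
  OP_INOUT_CHECK(ctx->HasOutput("SampleLabels"), "Output", "SampleLabels",
                 "nce");
}
ctx->SetOutputDim("Cost", framework::make_ddim(out_dims));  // always shaped
if (!is_test) {
  ctx->SetOutputDim("SampleLogits", framework::make_ddim(sample_out_dims));
  ctx->SetOutputDim("SampleLabels", framework::make_ddim(sample_out_dims));
}
```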
@@ -41,7 +41,7 @@ using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 template <typename DeviceContext, typename T>
 void PrepareSamples(const framework::ExecutionContext &context,
-                    Sampler *sampler) {
+                    Sampler *sampler, Tensor *sample_labels) {
   auto label = context.Input<Tensor>("Label");
   const int64_t *label_data = label->data<int64_t>();
   auto label_dims = label->dims();
@@ -49,7 +49,6 @@ void PrepareSamples(const framework::ExecutionContext &context,
   std::vector<int> custom_neg_classes =
       context.Attr<std::vector<int>>("custom_neg_classes");
-  auto sample_labels = context.Output<Tensor>("SampleLabels");
   auto sample_labels_dims = sample_labels->dims();
   int64_t *sample_labels_data =
       sample_labels->mutable_data<int64_t>(context.GetPlace());
@@ -82,6 +81,7 @@ class NCEKernel : public framework::OpKernel<T> {
     int seed = context.Attr<int>("seed");
     int num_total_classes = context.Attr<int>("num_total_classes");
     int num_neg_samples = context.Attr<int>("num_neg_samples");
+    bool is_test = context.Attr<bool>("is_test");
     Sampler *sampler;
     switch (sampler_type) {
@@ -139,8 +139,29 @@ class NCEKernel : public framework::OpKernel<T> {
       }
     }
-    PrepareSamples<DeviceContext, T>(context, sampler);
-    auto sample_labels = context.Output<Tensor>("SampleLabels");
+    std::vector<int64_t> sample_out_dims;
+    auto label = context.Input<Tensor>("Label");
+    Tensor *sample_labels;
+    Tensor *sample_out;
+    Tensor sample_labels_tmp, sample_out_tmp;
+    if (is_test) {
+      // set dims of output(SampleOut)
+      int num_true_classes = label->dims().size() == 2 ? label->dims()[1] : 1;
+      sample_out_dims.push_back((context.Input<Tensor>("Input"))->dims()[0]);
+      sample_out_dims.push_back(
+          (num_true_classes == -1) ? -1 : (num_neg_samples + num_true_classes));
+      sample_labels = &sample_labels_tmp;
+      sample_labels->Resize(framework::make_ddim(sample_out_dims));
+      sample_out = &sample_out_tmp;
+      sample_out->Resize(framework::make_ddim(sample_out_dims));
+    } else {
+      sample_labels = context.Output<Tensor>("SampleLabels");
+      sample_out = context.Output<Tensor>("SampleLogits");
+    }
+    PrepareSamples<DeviceContext, T>(context, sampler, sample_labels);
     const int64_t *sample_labels_data = sample_labels->data<int64_t>();
     for (int x = 0; x < sample_labels->numel(); x++) {
@@ -152,9 +173,7 @@ class NCEKernel : public framework::OpKernel<T> {
                               x, sample_labels_data[x]));
     }
-    auto sample_out = context.Output<Tensor>("SampleLogits");
     T *sample_out_data = sample_out->mutable_data<T>(context.GetPlace());
-    auto label = context.Input<Tensor>("Label");
     auto sample_weight = context.Input<Tensor>("SampleWeight");
     const T *sample_weight_data = nullptr;
     if (sample_weight != nullptr) {
......
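The kernel-side counterpart of the change: in test mode the sampling buffers are stack-scoped `Tensor` temporaries rather than declared op outputs, so an inference program never has to allocate `SampleLogits`/`SampleLabels`. A condensed sketch of that selection logic, assembled from the hunk above (not verbatim; `is_test`, `sample_out_dims`, `context`, and `sampler` come from the surrounding `Compute` body shown in the diff):

```cpp
// Condensed from NCEKernel::Compute after this commit (not verbatim).
Tensor sample_labels_tmp, sample_out_tmp;  // scratch space for inference
Tensor *sample_labels = nullptr;
Tensor *sample_out = nullptr;
if (is_test) {
  // Local temporaries: sized here, freed when Compute returns.
  sample_labels = &sample_labels_tmp;
  sample_out = &sample_out_tmp;
  sample_labels->Resize(framework::make_ddim(sample_out_dims));
  sample_out->Resize(framework::make_ddim(sample_out_dims));
} else {
  // Training: write into the declared outputs so backward can read them.
  sample_labels = context.Output<Tensor>("SampleLabels");
  sample_out = context.Output<Tensor>("SampleLogits");
}
PrepareSamples<DeviceContext, T>(context, sampler, sample_labels);
```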
@@ -77,7 +77,8 @@ class TestNCE(OpTest):
             'custom_neg_classes': list(range(num_neg_samples)),
             'seed': 0,
             'sampler': 0,
-            'is_sparse': is_sparse
+            'is_sparse': is_sparse,
+            'is_test': self.is_test
         }
         self.inputs = {
             'Input': input,
@@ -87,6 +88,9 @@ class TestNCE(OpTest):
             'SampleWeight': sample_weight
         }

+    def set_is_test(self):
+        self.is_test = False
+
     def set_data(self):
         self.generate_data(5, 25, 100, 1, 2, False)
@@ -95,6 +99,9 @@ class TestNCE(OpTest):
                           self.inputs['Bias'], self.inputs['SampleWeight'],
                           self.inputs['Label'], self.attrs['num_total_classes'],
                           self.attrs['num_neg_samples'])
+        if self.is_test:
+            self.outputs = {'Cost': out[0]}
+        else:
-        self.outputs = {
-            'Cost': out[0],
-            'SampleLogits': out[1],
+            self.outputs = {
+                'Cost': out[0],
+                'SampleLogits': out[1],
@@ -103,6 +110,7 @@ class TestNCE(OpTest):
     def setUp(self):
         self.op_type = 'nce'
+        self.set_is_test()
         self.set_data()
         self.compute()
@@ -119,6 +127,15 @@ class TestNCECase1Tensor(TestNCE):
         self.generate_data(10, 20, 100, 2, 5, False)

+
+class TestNCETensorIsTest(TestNCE):
+    # if is_test = True, there's no need to calculate grad
+    def set_is_test(self):
+        self.is_test = True
+
+    def test_check_grad(self):
+        pass
+
 class TestNCECase1SelectedRows(unittest.TestCase):
     def setUp(self):
         self.base_lr = 0.0001
......