Commit b653ed05 authored by tangwei12

add prefetch and remove SelectedRows of bias

Parent 7fa2e821
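In short: the bias gradient of nce_op is no longer produced as SelectedRows (with is_sparse, only the weight gradient is), the dense bias-gradient path in NCEGradKernel moves under the !is_sparse branch, fluid.layers.nce gains a remote_prefetch attribute gated by the PADDLE_ENABLE_REMOTE_PREFETCH environment variable, the distribute transpiler now collects remote-prefetch sparse ops regardless of is_distributed, and a TestRemoteNce case is added to the transpiler tests.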
paddle/fluid/operators/nce_op.cc
@@ -243,24 +243,20 @@ class NCEOpGradVarTypeInference : public framework::VarTypeInference {
   void operator()(const framework::OpDesc &op_desc,
                   framework::BlockDesc *block) const override {
     auto weight_grad = op_desc.Output(framework::GradVarName("Weight")).front();
-    auto bias_grad = op_desc.Output(framework::GradVarName("Bias")).front();
 
     auto attr = op_desc.GetAttr("is_sparse");
     bool is_sparse = boost::get<bool>(attr);
     if (is_sparse) {
-      VLOG(3) << "nce_op_grad op " << weight_grad << " and " << bias_grad
+      VLOG(3) << "nce_op_grad op " << weight_grad << " and "
              << " is set to SelectedRows";
       block->Var(weight_grad)
           ->SetType(framework::proto::VarType::SELECTED_ROWS);
-      block->Var(bias_grad)->SetType(framework::proto::VarType::SELECTED_ROWS);
     } else {
-      VLOG(3) << "nce_op_grad op " << weight_grad << " and " << bias_grad
+      VLOG(3) << "nce_op_grad op " << weight_grad << " and "
              << " is set to LoDTensor";
       block->Var(weight_grad)->SetType(framework::proto::VarType::LOD_TENSOR);
-      block->Var(bias_grad)->SetType(framework::proto::VarType::LOD_TENSOR);
     }
 
     block->Var(weight_grad)->SetDataType(block->Var("Input")->GetDataType());
-    block->Var(bias_grad)->SetDataType(block->Var("Input")->GetDataType());
   }
 };
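With the VarTypeInference change above, is_sparse now only switches the weight gradient to SelectedRows; the bias gradient always stays a dense LoDTensor. A minimal sketch of how this could be checked from Python (fluid 1.x style; the parameter names nce_w/nce_b and the "@GRAD" suffix convention are assumptions of this sketch, not part of the commit):

import paddle.fluid as fluid

main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
    feats = fluid.layers.data(name='feats', shape=[10], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    cost = fluid.layers.nce(input=feats, label=label, num_total_classes=20,
                            param_attr='nce_w', bias_attr='nce_b',
                            num_neg_samples=5, is_sparse=True)
    loss = fluid.layers.mean(cost)
    fluid.backward.append_backward(loss)

block = main.global_block()
# Weight gradient: expected SELECTED_ROWS when is_sparse=True.
print(block.var('nce_w@GRAD').type)
# Bias gradient: expected LOD_TENSOR (dense) after this commit.
print(block.var('nce_b@GRAD').type)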
paddle/fluid/operators/nce_op.h
@@ -297,18 +297,19 @@ class NCEGradKernel : public framework::OpKernel<T> {
       sample_grad_data[i] *= d_out_data[sample_idx];
     }
 
-    // get d_bias
-    auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias"));
-    if (d_bias != nullptr) {
-      T *d_bias_data = d_bias->mutable_data<T>(context.GetPlace());
-      std::fill(d_bias_data, d_bias_data + d_bias->numel(), 0.0);
-      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
-        d_bias_data[sample_labels_data[i]] += sample_grad_data[i];
-      }
-    }
     bool is_sparse = context.Attr<bool>("is_sparse");
 
     if (!is_sparse) {
+      // get d_bias
+      auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias"));
+      if (d_bias != nullptr) {
+        T *d_bias_data = d_bias->mutable_data<T>(context.GetPlace());
+        std::fill(d_bias_data, d_bias_data + d_bias->numel(), 0.0);
+        for (int64_t i = 0; i < sample_labels->numel(); ++i) {
+          d_bias_data[sample_labels_data[i]] += sample_grad_data[i];
+        }
+      }
+
       // get d_w
       auto d_w = context.Output<Tensor>(framework::GradVarName("Weight"));
       if (d_w != nullptr) {
@@ -330,34 +331,6 @@ class NCEGradKernel : public framework::OpKernel<T> {
       std::set<T> st(labels.begin(), labels.end());
       labels.assign(st.begin(), st.end());
 
-      auto *bias_var = context.InputVar("Bias");
-      DDim bias_dim;
-      if (bias_var->IsType<LoDTensor>()) {
-        bias_dim = context.Input<LoDTensor>("Bias")->dims();
-      } else if (bias_var->IsType<SelectedRows>()) {
-        auto *table_t = context.Input<SelectedRows>("Bias");
-        bias_dim = table_t->value().dims();
-      } else {
-        PADDLE_THROW(
-            "The parameter Bias of a NCE_OP "
-            "must be either LoDTensor or SelectedRows");
-      }
-
-      auto d_bias =
-          context.Output<SelectedRows>(framework::GradVarName("Bias"));
-
-      d_bias->set_rows(labels);
-      d_bias->set_height(bias_dim[0]);
-
-      d_bias->mutable_value()->Resize(
-          {static_cast<int64_t>(labels.size()), bias_dim[1]});
-      T *d_bias_data =
-          d_bias->mutable_value()->mutable_data<T>(context.GetPlace());
-      std::fill(d_bias_data, d_bias_data + labels.size(), 0.0);
-      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
-        d_bias_data[d_bias->Index(sample_labels_data[i])] +=
-            sample_grad_data[i];
-      }
-
       auto *table_var = context.InputVar("Weight");
       DDim table_dim;
       if (table_var->IsType<LoDTensor>()) {
python/paddle/fluid/layers/nn.py
@@ -24,7 +24,7 @@ from ..initializer import Normal, Constant
 from ..framework import Variable, OpProtoHolder
 from ..param_attr import ParamAttr
 from .layer_function_generator import autodoc, templatedoc, _generate_doc_string_
-from .tensor import concat
+from .tensor import concat, assign
 from . import utils
 from .. import unique_name
 from functools import reduce
@@ -4770,12 +4770,17 @@ def nce(input,
     else:
         num_neg_samples = int(num_neg_samples)
 
+    remote_prefetch = False
+    if os.environ.get('PADDLE_ENABLE_REMOTE_PREFETCH'):
+        remote_prefetch = True
+
     attrs = {
         'num_total_classes': int(num_total_classes),
         'num_neg_samples': num_neg_samples,
         'seed': seed,
         'sampler': sampler,
-        'is_sparse': is_sparse
+        'is_sparse': is_sparse,
+        'remote_prefetch': remote_prefetch
     }
 
     helper.append_op(
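Note that remote_prefetch is resolved from the environment when the layer is constructed, not when the program runs, so the variable has to be exported before fluid.layers.nce is called. A small usage sketch (data names and shapes are illustrative):

import os
# Must be set before the nce layer is built.
os.environ['PADDLE_ENABLE_REMOTE_PREFETCH'] = "1"

import paddle.fluid as fluid

feats = fluid.layers.data(name='feats', shape=[10], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# The resulting nce op carries remote_prefetch=True in its attrs.
cost = fluid.layers.nce(input=feats, label=label, num_total_classes=20,
                        num_neg_samples=5, is_sparse=True)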
python/paddle/fluid/tests/unittests/test_dist_transpiler.py
@@ -14,14 +14,15 @@
 from __future__ import print_function
 
+import traceback
 import math
+import collections
+import six
+
 import unittest
 import numpy as np
 import paddle.fluid as fluid
 from paddle.fluid.transpiler.distribute_transpiler import delete_ops
-import traceback
-import collections
-import six
 
 
 class TranspilerTest(unittest.TestCase):
@@ -823,5 +824,55 @@ class TestRemoteLookupTable(TestDistLookupTableBase):
         self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
 
 
+# test for remote prefetch
+class TestRemoteNce(TestDistLookupTableBase):
+    def network_with_table(self, is_sparse, is_distributed):
+        num_total_classes = 20
+        sampler = "uniform"
+        nid_freq_arr = np.random.dirichlet(np.ones(20) * 1000).astype('float32')
+
+        input = fluid.layers.data(name="input", shape=[10], dtype="float32")
+        label = fluid.layers.data(name="label", shape=[1], dtype="int64")
+
+        w_param = fluid.default_main_program().global_block().create_parameter(
+            shape=[num_total_classes, 10],
+            dtype='float32',
+            name='nce_w',
+            initializer=fluid.initializer.ConstantInitializer())
+        b_param = fluid.default_main_program().global_block().create_parameter(
+            shape=[num_total_classes, 1],
+            dtype='float32',
+            name='nce_b',
+            initializer=fluid.initializer.ConstantInitializer())
+
+        cost = fluid.layers.nce(input=input,
+                                label=label,
+                                num_total_classes=num_total_classes,
+                                sampler=sampler,
+                                custom_dist=nid_freq_arr.tolist(),
+                                sample_weight=None,
+                                param_attr='nce_w',
+                                bias_attr='nce_b',
+                                seed=1,
+                                num_neg_samples=5,
+                                is_sparse=is_sparse)
+        avg_cost = fluid.layers.mean(cost)
+        # optimizer
+        optimizer = fluid.optimizer.Adam(learning_rate=0.003)
+        optimizer.minimize(avg_cost)
+
+    def net_conf(self):
+        import os
+        os.environ['PADDLE_ENABLE_REMOTE_PREFETCH'] = "1"
+        self.network_with_table(is_sparse=True, is_distributed=False)
+
+    def transpiler_test_impl(self):
+        trainer, _ = self.get_trainer()
+        for op in trainer.blocks[0].ops:
+            if op.type == "recv":
+                pass
+
+
 if __name__ == "__main__":
     unittest.main()
python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -242,8 +242,7 @@ class DistributeTranspiler(object):
         sparse_update_ops = []
         sparse_update_op_types = ["lookup_table", "nce"]
         for op in main_program.global_block().ops:
             if op.type in sparse_update_op_types and op.attr(
-                    'remote_prefetch') is True and not op.attr(
-                        'is_distributed'):
+                    'remote_prefetch') is True:
                 sparse_update_ops.append(op)
         return sparse_update_ops
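The dropped is_distributed check means any lookup_table or nce op with remote_prefetch set now qualifies as a sparse-update op, distributed tables included. The new predicate, restated as a standalone sketch (the function name here is illustrative, not the transpiler's actual method name):

def collect_remote_sparse_update_ops(main_program):
    # Every lookup_table/nce op with remote_prefetch=True is collected,
    # regardless of its is_distributed attribute.
    sparse_update_op_types = ["lookup_table", "nce"]
    sparse_update_ops = []
    for op in main_program.global_block().ops:
        if op.type in sparse_update_op_types and op.attr(
                'remote_prefetch') is True:
            sparse_update_ops.append(op)
    return sparse_update_ops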