diff --git a/paddle/fluid/operators/distributed/parameter_prefetch.cc b/paddle/fluid/operators/distributed/parameter_prefetch.cc index cf14538b1c284d297242197088a66cc156b1762c..aebf6376d16e1e918905f559b2b84396f1a2452e 100644 --- a/paddle/fluid/operators/distributed/parameter_prefetch.cc +++ b/paddle/fluid/operators/distributed/parameter_prefetch.cc @@ -117,6 +117,12 @@ static void MergeMultipleVarsIntoOneBySection( auto& id_tensor = scope->FindVar(id_name)->Get(); auto* out_tensor = scope->FindVar(out_name)->GetMutable(); + + PADDLE_ENFORCE_GT( + out_tensor->numel(), 0, + "When calling this method, the Tensor's numel must larger than zero. " + "Please check Tensor::Resize has been called first."); + auto* out_tensor_data = out_tensor->mutable_data(id_tensor.place()); bool is_on_cpu_place = true; @@ -172,8 +178,9 @@ void prefetch(const std::string& id_name, const std::string& out_name, const std::vector& table_names, const std::vector& epmap, const std::vector& height_sections, - const framework::ExecutionContext& context) { - auto& local_scope = context.scope().NewScope(); + const framework::ExecutionContext& context, + const framework::Scope& scope) { + auto& local_scope = scope.NewScope(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& cpu_ctx = *pool.Get(platform::CPUPlace()); @@ -190,7 +197,7 @@ void prefetch(const std::string& id_name, const std::string& out_name, out_var_names.push_back(out_name + "@" + epmap[i]); } - auto& id_tensor = local_scope.FindVar(id_name)->Get(); + auto& id_tensor = scope.FindVar(id_name)->Get(); std::vector ids_vector; if (platform::is_cpu_place(id_tensor.place())) { auto* id_data = id_tensor.data(); @@ -246,8 +253,7 @@ void prefetch(const std::string& id_name, const std::string& out_name, MergeMultipleVarsIntoOneBySection(id_name, ids_vector, out_name, out_var_names, height_sections, splited_ids, context, &local_scope, &actual_ctx); - - context.scope().DeleteScope(&local_scope); + scope.DeleteScope(&local_scope); } }; // namespace distributed diff --git a/paddle/fluid/operators/distributed/parameter_prefetch.h b/paddle/fluid/operators/distributed/parameter_prefetch.h index 53b0fbfb51f60fa86351cca34fd1665c7802591b..53482c4c40e4446bc3f0fca8d4b3199354f9a130 100644 --- a/paddle/fluid/operators/distributed/parameter_prefetch.h +++ b/paddle/fluid/operators/distributed/parameter_prefetch.h @@ -27,7 +27,8 @@ void prefetch(const std::string& id_name, const std::string& out_name, const std::vector& table_names, const std::vector& epmap, const std::vector& height_sections, - const framework::ExecutionContext& context); + const framework::ExecutionContext& context, + const framework::Scope& scope); }; // namespace distributed }; // namespace operators diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h index 3a73a7637c6d7d3eff7443802a4a52be9149e0ef..a7d0fd4856edc74237151c64f286d468ad86e7ca 100644 --- a/paddle/fluid/operators/lookup_table_op.h +++ b/paddle/fluid/operators/lookup_table_op.h @@ -59,7 +59,8 @@ class LookupTableKernel : public framework::OpKernel { // server #ifdef PADDLE_WITH_DISTRIBUTE operators::distributed::prefetch(id_name, out_name, table_names, epmap, - height_sections, context); + height_sections, context, + context.scope()); #else PADDLE_THROW( "paddle is not compiled with distribute support, can not do " diff --git a/paddle/fluid/operators/nce_op.cc b/paddle/fluid/operators/nce_op.cc index 9f97f7821ddf5f7adf61740599b7f998b0dfa6ed..0a0be24a540e3f234cf387f06466cee3e39c3984 100644 --- a/paddle/fluid/operators/nce_op.cc +++ b/paddle/fluid/operators/nce_op.cc @@ -155,6 +155,24 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("is_sparse", "(boolean, default false) Sparse update.") .SetDefault(false); + // for parameter prefetch + AddAttr("remote_prefetch", "").SetDefault(false); + AddAttr("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0); + AddAttr>("height_sections", + "Height for each output SelectedRows.") + .SetDefault(std::vector({})); + AddAttr>( + "epmap", + "(string vector, default 127.0.0.1:6164)" + "Server endpoints in the order of input variables for mapping") + .SetDefault({}); + AddAttr>( + "table_names", + "(string vector, the splited table names that will be fetched from " + "parameter server)" + "in the order of input variables for mapping") + .SetDefault({}); + AddAttr>("custom_neg_classes", "This attribute only be used in unitest. Classes " "in this list wiil be used as negative classes " @@ -225,24 +243,20 @@ class NCEOpGradVarTypeInference : public framework::VarTypeInference { void operator()(const framework::OpDesc &op_desc, framework::BlockDesc *block) const override { auto weight_grad = op_desc.Output(framework::GradVarName("Weight")).front(); - auto bias_grad = op_desc.Output(framework::GradVarName("Bias")).front(); auto attr = op_desc.GetAttr("is_sparse"); bool is_sparse = boost::get(attr); if (is_sparse) { - VLOG(3) << "nce_op_grad op " << weight_grad << " and " << bias_grad + VLOG(3) << "nce_op_grad op " << weight_grad << " and " << " is set to SelectedRows"; block->Var(weight_grad) ->SetType(framework::proto::VarType::SELECTED_ROWS); - block->Var(bias_grad)->SetType(framework::proto::VarType::SELECTED_ROWS); } else { - VLOG(3) << "nce_op_grad op " << weight_grad << " and " << bias_grad + VLOG(3) << "nce_op_grad op " << weight_grad << " and " << " is set to LoDTensor"; block->Var(weight_grad)->SetType(framework::proto::VarType::LOD_TENSOR); - block->Var(bias_grad)->SetType(framework::proto::VarType::LOD_TENSOR); } block->Var(weight_grad)->SetDataType(block->Var("Input")->GetDataType()); - block->Var(bias_grad)->SetDataType(block->Var("Input")->GetDataType()); } }; diff --git a/paddle/fluid/operators/nce_op.h b/paddle/fluid/operators/nce_op.h index f2ca6ec247fd1ea09b707c2eaaad0548c8aa5757..2c97eef096eb3d23273e362e658cb1b5fc808609 100644 --- a/paddle/fluid/operators/nce_op.h +++ b/paddle/fluid/operators/nce_op.h @@ -15,8 +15,10 @@ limitations under the License. */ #pragma once #include +#include #include #include +#include #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" @@ -24,6 +26,10 @@ limitations under the License. */ #include "paddle/fluid/operators/math/sampler.h" #include "unsupported/Eigen/CXX11/Tensor" +#ifdef PADDLE_WITH_DISTRIBUTE +#include "paddle/fluid/operators/distributed/parameter_prefetch.h" +#endif + namespace paddle { namespace operators { @@ -43,7 +49,6 @@ void PrepareSamples(const framework::ExecutionContext &context, auto label = context.Input("Label"); const int64_t *label_data = label->data(); auto label_dims = label->dims(); - // int num_total_classes = context.Attr("num_total_classes"); // for unitest std::vector custom_neg_classes = context.Attr>("custom_neg_classes"); @@ -144,15 +149,82 @@ class NCEKernel : public framework::OpKernel { } // forward mul auto input_mat = EigenMatrix::From(*(context.Input("Input"))); - auto weight_mat = EigenMatrix::From(*(context.Input("Weight"))); - for (int64_t i = 0; i < sample_labels->numel(); ++i) { - Eigen::Tensor result = - (input_mat.chip(static_cast(i / sample_labels->dims()[1]), 0) * - weight_mat.chip(sample_labels_data[i], 0)) - .sum(); - sample_out_data[i] += result(0); - sample_out_data[i] = (1. / (1. + exp(-sample_out_data[i]))); + + // for remote prefetch + auto epmap = context.Attr>("epmap"); + + if (!epmap.empty()) { + // if epmap is not empty, then the parameter will be fetched from remote + // parameter + // server + + std::vector labels; + for (int64_t i = 0; i < sample_labels->numel(); ++i) { + labels.push_back(sample_labels_data[i]); + } + std::set st(labels.begin(), labels.end()); + labels.assign(st.begin(), st.end()); + + framework::Scope &local_scope = context.scope().NewScope(); + + auto height_sections = context.Attr>("height_sections"); + auto table_names = context.Attr>("table_names"); + + auto *ids = local_scope.Var("Ids@Prefetch"); + auto *x_tensor = ids->GetMutable(); + x_tensor->mutable_data( + framework::make_ddim({static_cast(labels.size()), 1}), + context.GetPlace()); + // copy. + std::memcpy(x_tensor->data(), labels.data(), + labels.size() * sizeof(int64_t)); + + std::vector w_dims = paddle::framework::vectorize2int( + context.Input("Weight")->dims()); + w_dims[0] = static_cast(labels.size()); + + auto *w_tensor = local_scope.Var("Weight@Prefetch") + ->GetMutable(); + w_tensor->Resize(framework::make_ddim(w_dims)); + +#ifdef PADDLE_WITH_DISTRIBUTE + operators::distributed::prefetch("Ids@Prefetch", "Weight@Prefetch", + table_names, epmap, height_sections, + context, local_scope); +#else + PADDLE_THROW( + "paddle is not compiled with distribute support, can not do " + "parameter prefetch!"); +#endif + + auto weight_mat = EigenMatrix::From( + (local_scope.Var("Weight@Prefetch")->Get())); + for (int64_t i = 0; i < sample_labels->numel(); ++i) { + std::vector::iterator it = + std::find(labels.begin(), labels.end(), sample_labels_data[i]); + int idx = std::distance(labels.begin(), it); + + Eigen::Tensor result = + (input_mat.chip(static_cast(i / sample_labels->dims()[1]), 0) * + weight_mat.chip(idx, 0)) + .sum(); + sample_out_data[i] += result(0); + sample_out_data[i] = (1. / (1. + exp(-sample_out_data[i]))); + } + context.scope().DeleteScope(&local_scope); + } else { + auto weight_mat = + EigenMatrix::From(*(context.Input("Weight"))); + for (int64_t i = 0; i < sample_labels->numel(); ++i) { + Eigen::Tensor result = + (input_mat.chip(static_cast(i / sample_labels->dims()[1]), 0) * + weight_mat.chip(sample_labels_data[i], 0)) + .sum(); + sample_out_data[i] += result(0); + sample_out_data[i] = (1. / (1. + exp(-sample_out_data[i]))); + } } + // forward cost for (int64_t i = 0; i < sample_labels->dims()[0]; ++i) { out_data[i] = 0; @@ -240,18 +312,19 @@ class NCEGradKernel : public framework::OpKernel { sample_grad_data[i] *= d_out_data[sample_idx]; } + // get d_bias + auto d_bias = context.Output(framework::GradVarName("Bias")); + if (d_bias != nullptr) { + T *d_bias_data = d_bias->mutable_data(context.GetPlace()); + std::fill(d_bias_data, d_bias_data + d_bias->numel(), 0.0); + for (int64_t i = 0; i < sample_labels->numel(); ++i) { + d_bias_data[sample_labels_data[i]] += sample_grad_data[i]; + } + } + bool is_sparse = context.Attr("is_sparse"); if (!is_sparse) { - // get d_bias - auto d_bias = context.Output(framework::GradVarName("Bias")); - if (d_bias != nullptr) { - T *d_bias_data = d_bias->mutable_data(context.GetPlace()); - std::fill(d_bias_data, d_bias_data + d_bias->numel(), 0.0); - for (int64_t i = 0; i < sample_labels->numel(); ++i) { - d_bias_data[sample_labels_data[i]] += sample_grad_data[i]; - } - } // get d_w auto d_w = context.Output(framework::GradVarName("Weight")); if (d_w != nullptr) { @@ -273,34 +346,6 @@ class NCEGradKernel : public framework::OpKernel { std::set st(labels.begin(), labels.end()); labels.assign(st.begin(), st.end()); - auto *bias_var = context.InputVar("Bias"); - DDim bias_dim; - if (bias_var->IsType()) { - bias_dim = context.Input("Bias")->dims(); - } else if (bias_var->IsType()) { - auto *table_t = context.Input("Bias"); - bias_dim = table_t->value().dims(); - } else { - PADDLE_THROW( - "The parameter Bias of a NCE_OP " - "must be either LoDTensor or SelectedRows"); - } - - auto d_bias = - context.Output(framework::GradVarName("Bias")); - d_bias->set_rows(labels); - d_bias->set_height(bias_dim[0]); - - d_bias->mutable_value()->Resize( - {static_cast(labels.size()), bias_dim[1]}); - T *d_bias_data = - d_bias->mutable_value()->mutable_data(context.GetPlace()); - std::fill(d_bias_data, d_bias_data + labels.size(), 0.0); - for (int64_t i = 0; i < sample_labels->numel(); ++i) { - d_bias_data[d_bias->Index(sample_labels_data[i])] += - sample_grad_data[i]; - } - auto *table_var = context.InputVar("Weight"); DDim table_dim; if (table_var->IsType()) { diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index e25eaaa9fda6add9d8e81d9e6bdfb711cee3648e..37ddfdf7d5817dac0013437b8a9d895aaa333675 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -24,7 +24,7 @@ from ..initializer import Normal, Constant from ..framework import Variable, OpProtoHolder from ..param_attr import ParamAttr from .layer_function_generator import autodoc, templatedoc, _generate_doc_string_ -from .tensor import concat +from .tensor import concat, assign from . import utils from .. import unique_name from functools import reduce @@ -4811,12 +4811,17 @@ def nce(input, else: num_neg_samples = int(num_neg_samples) + remote_prefetch = False + if os.environ.get('PADDLE_ENABLE_REMOTE_PREFETCH'): + remote_prefetch = True + attrs = { 'num_total_classes': int(num_total_classes), 'num_neg_samples': num_neg_samples, 'seed': seed, 'sampler': sampler, - 'is_sparse': is_sparse + 'is_sparse': is_sparse, + 'remote_prefetch': remote_prefetch } helper.append_op( diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index d9ad4e2e2c7b8d0a99d917495fbc8efc6cbd188d..650a745cdc415edf0c7b733c95a72454b1305cfd 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -14,14 +14,15 @@ from __future__ import print_function +import traceback import math +import collections +import six import unittest +import numpy as np + import paddle.fluid as fluid -from paddle.fluid.transpiler.distribute_transpiler import delete_ops -import traceback -import collections -import six class TranspilerTest(unittest.TestCase): @@ -824,5 +825,55 @@ class TestRemoteLookupTable(TestDistLookupTableBase): self.assertEqual([op.type for op in trainer.blocks[0].ops], ops) +# test for remote prefetch +class TestRemoteNce(TestDistLookupTableBase): + def network_with_table(self, is_sparse, is_distributed): + + num_total_classes = 20 + sampler = "uniform" + nid_freq_arr = np.random.dirichlet(np.ones(20) * 1000).astype('float32') + + input = fluid.layers.data(name="input", shape=[10], dtype="float32") + label = fluid.layers.data(name="label", shape=[1], dtype="int64") + + w_param = fluid.default_main_program().global_block().create_parameter( + shape=[num_total_classes, 10], + dtype='float32', + name='nce_w', + initializer=fluid.initializer.ConstantInitializer()) + b_param = fluid.default_main_program().global_block().create_parameter( + shape=[num_total_classes, 1], + dtype='float32', + name='nce_b', + initializer=fluid.initializer.ConstantInitializer()) + + cost = fluid.layers.nce(input=input, + label=label, + num_total_classes=num_total_classes, + sampler=sampler, + custom_dist=nid_freq_arr.tolist(), + sample_weight=None, + param_attr='nce_w', + bias_attr='nce_b', + seed=1, + num_neg_samples=5, + is_sparse=is_sparse) + avg_cost = fluid.layers.mean(cost) + # optimizer + optimizer = fluid.optimizer.Adam(learning_rate=0.003) + optimizer.minimize(avg_cost) + + def net_conf(self): + import os + os.environ['PADDLE_ENABLE_REMOTE_PREFETCH'] = "1" + self.network_with_table(is_sparse=True, is_distributed=False) + + def transpiler_test_impl(self): + trainer, _ = self.get_trainer() + for op in trainer.blocks[0].ops: + if op.type == "recv": + pass + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index d21ec42dccde80fd354a730274edb04f654113c3..378654ab5b1514f7799ef899b99831a9c5cc4e76 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -242,11 +242,10 @@ class DistributeTranspiler(object): def _get_all_remote_sparse_update_op(self, main_program): sparse_update_ops = [] - sparse_update_op_types = ["lookup_table"] + sparse_update_op_types = ["lookup_table", "nce"] for op in main_program.global_block().ops: if op.type in sparse_update_op_types and op.attr( - 'remote_prefetch') is True and not op.attr( - 'is_distributed'): + 'remote_prefetch') is True: sparse_update_ops.append(op) return sparse_update_ops