diff --git a/paddle/fluid/operators/lookup_sparse_table_op.cc b/paddle/fluid/operators/lookup_sparse_table_op.cc
index 88fa59c5fb99adb058c4d6e1c2cc6ac519470942..ff3734b8f0ab7ef32b98a4eee0439ae5336f8c3a 100644
--- a/paddle/fluid/operators/lookup_sparse_table_op.cc
+++ b/paddle/fluid/operators/lookup_sparse_table_op.cc
@@ -55,12 +55,16 @@ class LookupSparseTableOp : public framework::OperatorBase {
                    "The type of Out var should be LodTensor.");
     PADDLE_ENFORCE(w_var->IsType<framework::SelectedRows>(),
                    "The type of W var should be SelectedRows.");
-    PADDLE_ENFORCE(ids_var->IsType<framework::SelectedRows>(),
+    PADDLE_ENFORCE(ids_var->IsType<framework::LoDTensor>(),
                    "The type of Ids var should be SelectedRows.");
-    auto &ids_t = ids_var->Get<framework::SelectedRows>();
+    auto &ids_t = ids_var->Get<framework::LoDTensor>();
     auto out_t = out_var->GetMutable<framework::LoDTensor>();
     auto w_t = w_var->GetMutable<framework::SelectedRows>();
-    auto keys = ids_t.rows();
+    std::vector<int64_t> keys;
+    keys.resize(ids_t.numel());
+    for (size_t i = 0; i < ids_t.numel(); ++i) {
+      keys[i] = ids_t.data<int64_t>()[i];
+    }

     // TODO(Yancey1989): support CUDA Place for the sparse table
     platform::CPUPlace cpu;
@@ -68,7 +72,6 @@
     out_shape[0] = keys.size();
     out_t->Resize(out_shape);
     out_t->mutable_data(cpu, w_t->value().type());
-
     PADDLE_ENFORCE_EQ(framework::ToDataType(w_t->value().type()),
                       framework::proto::VarType::FP32,
                       "The sparse table only support FP32");
diff --git a/paddle/fluid/operators/sgd_op.cc b/paddle/fluid/operators/sgd_op.cc
index 06cb0550ad7d4ad0241a4f439ea9ac16d9714c38..bd04c60ffa5c1e5eb8d2051ce495ab6c685b14b5 100644
--- a/paddle/fluid/operators/sgd_op.cc
+++ b/paddle/fluid/operators/sgd_op.cc
@@ -48,6 +48,24 @@ class SGDOp : public framework::OperatorWithKernel {
   }
 };

+class SGDOpInferVarType : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const override {
+    auto input_var = op_desc.Input("Param")[0];
+    for (auto& out_var : op_desc.Output("ParamOut")) {
+      if (block->FindRecursiveOrCreateVar(input_var).GetType() ==
+          framework::proto::VarType::SELECTED_ROWS) {
+        block->FindRecursiveOrCreateVar(out_var).SetType(
+            framework::proto::VarType::SELECTED_ROWS);
+      } else {
+        block->FindRecursiveOrCreateVar(out_var).SetType(
+            framework::proto::VarType::LOD_TENSOR);
+      }
+    }
+  }
+};
+
 class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   SGDOpMaker(OpProto* proto, OpAttrChecker* op_checker)
@@ -74,5 +92,6 @@ $$param\_out = param - learning\_rate * grad$$
 }  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OP_WITHOUT_GRADIENT(sgd, ops::SGDOp, ops::SGDOpMaker);
+REGISTER_OPERATOR(sgd, ops::SGDOp, ops::SGDOpMaker,
+                  paddle::framework::EmptyGradOpMaker, ops::SGDOpInferVarType);
 REGISTER_OP_CPU_KERNEL(sgd, ops::SGDOpKernel<float>, ops::SGDOpKernel<double>);
diff --git a/paddle/fluid/operators/uniform_random_op.cc b/paddle/fluid/operators/uniform_random_op.cc
index acaefaacdaa593c090d81084fdc1b3665314833f..3b5cf68dd4f28d23e507058337fe55de9b88d3cd 100644
--- a/paddle/fluid/operators/uniform_random_op.cc
+++ b/paddle/fluid/operators/uniform_random_op.cc
@@ -116,11 +116,31 @@ uniform distribution.
         .SetDefault(framework::proto::VarType::FP32);
   }
 };
+
+class UniformRandomOpVarTypeInference : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const override {
+    auto out_var_name = op_desc.Output("Out").front();
+    if (block->FindRecursiveOrCreateVar(out_var_name).GetType() ==
+        framework::proto::VarType::SELECTED_ROWS) {
+      block->FindRecursiveOrCreateVar(out_var_name)
+          .SetType(framework::proto::VarType::SELECTED_ROWS);
+    } else {
+      block->FindRecursiveOrCreateVar(out_var_name)
+          .SetType(framework::proto::VarType::LOD_TENSOR);
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle

-REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp,
-                             paddle::operators::UniformRandomOpMaker);
+REGISTER_OPERATOR(uniform_random, paddle::operators::UniformRandomOp,
+                  paddle::operators::UniformRandomOpMaker,
+                  paddle::framework::EmptyGradOpMaker,
+                  paddle::operators::UniformRandomOpVarTypeInference);
+
 REGISTER_OP_CPU_KERNEL(uniform_random,
                        paddle::operators::CPUUniformRandomKernel<float>,
                        paddle::operators::CPUUniformRandomKernel<double>);
diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py
index d07e0f696e79cfb98efc09a9f40d7961678b6af4..3e437ef799060d7c961d3892ebffcbccd945b03e 100644
--- a/python/paddle/fluid/distribute_transpiler.py
+++ b/python/paddle/fluid/distribute_transpiler.py
@@ -650,7 +650,7 @@ class DistributeTranspiler:
                 shape=trainer_out.shape,
                 dtype=trainer_out.dtype)
             prefetch_block.append_op(
-                type=LOOKUP_TABLE_TYPE,
+                type="lookup_sparse_table",
                 inputs={'Ids': pserver_ids,
                         "W": table_var},
                 outputs={"Out": pserver_out},
@@ -674,9 +674,17 @@ class DistributeTranspiler:

         # STEP: create table optimize block
         # create table param and grad var in pserver program
-        param_var = _clone_var(
-            pserver_program.global_block(),
-            self.origin_program.global_block().vars[self.table_name])
+        #param_var = _clone_var(
+        #    pserver_program.global_block(),
+        #    self.origin_program.global_block().vars[self.table_name])
+        origin_param_var = self.origin_program.global_block().vars[
+            self.table_name]
+        param_var = pserver_program.global_block().create_var(
+            name=origin_param_var.name,
+            shape=origin_param_var.shape,
+            dtype=origin_param_var.dtype,
+            type=core.VarDesc.VarType.SELECTED_ROWS,
+            persistable=True)
         grad_var = _clone_var(
             pserver_program.global_block(),
             self.origin_program.global_block().vars[framework.grad_var_name(