Commit dccd013b, authored by Yancey1989

refine distribute transpiler

This change makes the distributed lookup-table path work on the parameter server: lookup_sparse_table now takes its Ids input as an int64 LoDTensor instead of SelectedRows, sgd and uniform_random gain var-type inference so their outputs keep the SELECTED_ROWS type when the corresponding variable is SELECTED_ROWS, and the transpiler creates the table parameter on the pserver as a SELECTED_ROWS variable rather than cloning the dense trainer variable.

Parent e393c86c
@@ -55,12 +55,16 @@ class LookupSparseTableOp : public framework::OperatorBase {
                    "The type of Out var should be LodTensor.");
     PADDLE_ENFORCE(w_var->IsType<framework::SelectedRows>(),
                    "The type of W var should be SelectedRows.");
-    PADDLE_ENFORCE(ids_var->IsType<framework::SelectedRows>(),
+    PADDLE_ENFORCE(ids_var->IsType<framework::LoDTensor>(),
                    "The type of Ids var should be SelectedRows.");
-    auto &ids_t = ids_var->Get<framework::SelectedRows>();
+    auto &ids_t = ids_var->Get<framework::LoDTensor>();
     auto out_t = out_var->GetMutable<framework::LoDTensor>();
     auto w_t = w_var->GetMutable<framework::SelectedRows>();
-    auto keys = ids_t.rows();
+    std::vector<int64_t> keys;
+    keys.resize(ids_t.numel());
+    for (size_t i = 0; i < ids_t.numel(); ++i) {
+      keys[i] = ids_t.data<int64_t>()[i];
+    }
 
     // TODO(Yancey1989): support CUDA Place for the sparse table
     platform::CPUPlace cpu;
@@ -68,7 +72,6 @@ class LookupSparseTableOp : public framework::OperatorBase {
     out_shape[0] = keys.size();
     out_t->Resize(out_shape);
     out_t->mutable_data(cpu, w_t->value().type());
-
     PADDLE_ENFORCE_EQ(framework::ToDataType(w_t->value().type()),
                       framework::proto::VarType::FP32,
                       "The sparse table only support FP32");
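With Ids now arriving as a plain int64 LoDTensor, the op copies the id values into the keys vector above instead of reading SelectedRows rows. For context, the pserver-side prefetch block that feeds this op is built roughly as in the sketch below. This is a minimal, illustrative sketch against the fluid Python API of that period; the variable names and shapes are made up, and only the op type and its Ids/W/Out slots come from this commit.

import paddle.fluid as fluid
import paddle.fluid.core as core

program = fluid.Program()
block = program.global_block()

# Ids arrive as an int64 LoDTensor of row ids to look up.
ids = block.create_var(name="prefetch_ids", shape=[-1, 1], dtype="int64")
# The embedding table lives on the pserver as a SELECTED_ROWS variable.
table = block.create_var(
    name="emb_table",
    shape=[100000, 64],
    dtype="float32",
    type=core.VarDesc.VarType.SELECTED_ROWS,
    persistable=True)
# Out receives one float32 row of the table per id.
out = block.create_var(name="prefetch_out", shape=[-1, 64], dtype="float32")

block.append_op(
    type="lookup_sparse_table",
    inputs={"Ids": ids, "W": table},
    outputs={"Out": out})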
@@ -48,6 +48,24 @@ class SGDOp : public framework::OperatorWithKernel {
   }
 };
 
+class SGDOpInferVarType : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const override {
+    auto input_var = op_desc.Input("Param")[0];
+    for (auto& out_var : op_desc.Output("ParamOut")) {
+      if (block->FindRecursiveOrCreateVar(input_var).GetType() ==
+          framework::proto::VarType::SELECTED_ROWS) {
+        block->FindRecursiveOrCreateVar(out_var).SetType(
+            framework::proto::VarType::SELECTED_ROWS);
+      } else {
+        block->FindRecursiveOrCreateVar(out_var).SetType(
+            framework::proto::VarType::LOD_TENSOR);
+      }
+    }
+  }
+};
+
 class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   SGDOpMaker(OpProto* proto, OpAttrChecker* op_checker)
@@ -74,5 +92,6 @@ $$param\_out = param - learning\_rate * grad$$
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_WITHOUT_GRADIENT(sgd, ops::SGDOp, ops::SGDOpMaker);
+REGISTER_OPERATOR(sgd, ops::SGDOp, ops::SGDOpMaker,
+                  paddle::framework::EmptyGradOpMaker, ops::SGDOpInferVarType);
 REGISTER_OP_CPU_KERNEL(sgd, ops::SGDOpKernel<float>, ops::SGDOpKernel<double>);
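The new SGDOpInferVarType matters for the pserver's table optimize block: the parameter being updated there is a SELECTED_ROWS table (see the transpiler change below), so ParamOut must be inferred as SELECTED_ROWS rather than falling back to the LOD_TENSOR default that applies when no VarTypeInference is registered. A minimal sketch of that optimize step, again assuming the fluid Python API of that period with illustrative names and shapes:

import paddle.fluid as fluid
import paddle.fluid.core as core

program = fluid.Program()
block = program.global_block()

# The table parameter and its received gradient are both SELECTED_ROWS.
param = block.create_var(
    name="emb_table", shape=[100000, 64], dtype="float32",
    type=core.VarDesc.VarType.SELECTED_ROWS, persistable=True)
grad = block.create_var(
    name="emb_table@GRAD", shape=[100000, 64], dtype="float32",
    type=core.VarDesc.VarType.SELECTED_ROWS)
lr = block.create_var(
    name="learning_rate", shape=[1], dtype="float32", persistable=True)

# With SGDOpInferVarType registered, ParamOut (here the table itself) is
# inferred as SELECTED_ROWS because Param is SELECTED_ROWS.
block.append_op(
    type="sgd",
    inputs={"Param": param, "Grad": grad, "LearningRate": lr},
    outputs={"ParamOut": param})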
@@ -116,11 +116,31 @@ uniform distribution.
         .SetDefault(framework::proto::VarType::FP32);
   }
 };
+
+class UniformRandomOpVarTypeInference : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const override {
+    auto out_var_name = op_desc.Output("Out").front();
+    if (block->FindRecursiveOrCreateVar(out_var_name).GetType() ==
+        framework::proto::VarType::SELECTED_ROWS) {
+      block->FindRecursiveOrCreateVar(out_var_name)
+          .SetType(framework::proto::VarType::SELECTED_ROWS);
+    } else {
+      block->FindRecursiveOrCreateVar(out_var_name)
+          .SetType(framework::proto::VarType::LOD_TENSOR);
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
-REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp,
-                             paddle::operators::UniformRandomOpMaker);
+REGISTER_OPERATOR(uniform_random, paddle::operators::UniformRandomOp,
+                  paddle::operators::UniformRandomOpMaker,
+                  paddle::framework::EmptyGradOpMaker,
+                  paddle::operators::UniformRandomOpVarTypeInference);
 REGISTER_OP_CPU_KERNEL(uniform_random,
                        paddle::operators::CPUUniformRandomKernel<float>,
                        paddle::operators::CPUUniformRandomKernel<double>);
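UniformRandomOpVarTypeInference applies the same rule to the initializer: when the output variable is already SELECTED_ROWS (the pserver-side table), uniform_random keeps that type instead of letting the default inference reset it to LOD_TENSOR. A minimal sketch under the same assumptions as the previous examples (names, shapes, and attr values are illustrative):

import paddle.fluid as fluid
import paddle.fluid.core as core

startup = fluid.Program()
block = startup.global_block()

# The pserver-side table is created up front as SELECTED_ROWS.
table = block.create_var(
    name="emb_table", shape=[100000, 64], dtype="float32",
    type=core.VarDesc.VarType.SELECTED_ROWS, persistable=True)

# With the new var-type inference, Out keeps its SELECTED_ROWS type; before
# this commit the default inference marked the output as LOD_TENSOR.
block.append_op(
    type="uniform_random",
    outputs={"Out": table},
    attrs={"shape": [100000, 64], "min": -0.05, "max": 0.05, "seed": 0})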
@@ -650,7 +650,7 @@ class DistributeTranspiler:
                 shape=trainer_out.shape,
                 dtype=trainer_out.dtype)
             prefetch_block.append_op(
-                type=LOOKUP_TABLE_TYPE,
+                type="lookup_sparse_table",
                 inputs={'Ids': pserver_ids,
                         "W": table_var},
                 outputs={"Out": pserver_out},
@@ -674,9 +674,17 @@ class DistributeTranspiler:
         # STEP: create table optimize block
         # create table param and grad var in pserver program
-        param_var = _clone_var(
-            pserver_program.global_block(),
-            self.origin_program.global_block().vars[self.table_name])
+        #param_var = _clone_var(
+        #    pserver_program.global_block(),
+        #    self.origin_program.global_block().vars[self.table_name])
+        origin_param_var = self.origin_program.global_block().vars[
+            self.table_name]
+        param_var = pserver_program.global_block().create_var(
+            name=origin_param_var.name,
+            shape=origin_param_var.shape,
+            dtype=origin_param_var.dtype,
+            type=core.VarDesc.VarType.SELECTED_ROWS,
+            persistable=True)
         grad_var = _clone_var(
             pserver_program.global_block(),
             self.origin_program.global_block().vars[framework.grad_var_name(
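End to end, the transpiler path touched here is driven from user code roughly as in the sketch below. This is a rough, illustrative sketch against the fluid API of that period; the model, endpoint, and exact transpile() arguments are assumptions, not part of this commit.

import paddle.fluid as fluid

# A sparse, distributed embedding marks the lookup table that the transpiler
# turns into a pserver-side SELECTED_ROWS table.
ids = fluid.layers.data(name="ids", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
emb = fluid.layers.embedding(
    input=ids, size=[100000, 64], is_sparse=True, is_distributed=True,
    param_attr="emb_table")
pooled = fluid.layers.sequence_pool(input=emb, pool_type="sum")
predict = fluid.layers.fc(input=pooled, size=10, act="softmax")
loss = fluid.layers.mean(
    fluid.layers.cross_entropy(input=predict, label=label))
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

# Transpile for one trainer and one pserver (illustrative endpoint).
t = fluid.DistributeTranspiler()
t.transpile(trainer_id=0, pservers="127.0.0.1:6174", trainers=1)
pserver_prog = t.get_pserver_program("127.0.0.1:6174")

On the generated pserver program, the table parameter is created as a SELECTED_ROWS variable, its prefetch block runs lookup_sparse_table, and its table optimize block updates the table with sgd.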