Commit dccd013b authored by Yancey1989

refine distribute transpiler

Parent e393c86c
@@ -55,12 +55,16 @@ class LookupSparseTableOp : public framework::OperatorBase {
                    "The type of Out var should be LoDTensor.");
     PADDLE_ENFORCE(w_var->IsType<framework::SelectedRows>(),
                    "The type of W var should be SelectedRows.");
-    PADDLE_ENFORCE(ids_var->IsType<framework::SelectedRows>(),
-                   "The type of Ids var should be SelectedRows.");
-    auto &ids_t = ids_var->Get<framework::SelectedRows>();
+    PADDLE_ENFORCE(ids_var->IsType<framework::LoDTensor>(),
+                   "The type of Ids var should be LoDTensor.");
+    auto &ids_t = ids_var->Get<framework::LoDTensor>();
     auto out_t = out_var->GetMutable<framework::LoDTensor>();
     auto w_t = w_var->GetMutable<framework::SelectedRows>();
-    auto keys = ids_t.rows();
+    std::vector<int64_t> keys;
+    keys.resize(ids_t.numel());
+    for (int64_t i = 0; i < ids_t.numel(); ++i) {
+      keys[i] = ids_t.data<int64_t>()[i];
+    }

     // TODO(Yancey1989): support CUDA Place for the sparse table
     platform::CPUPlace cpu;
@@ -68,7 +72,6 @@ class LookupSparseTableOp : public framework::OperatorBase {
     out_shape[0] = keys.size();
     out_t->Resize(out_shape);
     out_t->mutable_data(cpu, w_t->value().type());
-
     PADDLE_ENFORCE_EQ(framework::ToDataType(w_t->value().type()),
                       framework::proto::VarType::FP32,
                       "The sparse table only supports FP32.");
......
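With this change, lookup_sparse_table takes Ids as a dense int64 LoDTensor and flattens its values into the key vector used to probe the SelectedRows table. A minimal NumPy sketch of that key extraction (ids_to_keys is an illustrative name, not a Paddle API):

import numpy as np

def ids_to_keys(ids):
    # Flatten a dense int64 id tensor into the list of lookup keys,
    # mirroring the element-wise copy loop added above.
    return np.asarray(ids, dtype=np.int64).ravel().tolist()

assert ids_to_keys([[3, 1], [4, 1]]) == [3, 1, 4, 1]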
@@ -48,6 +48,24 @@ class SGDOp : public framework::OperatorWithKernel {
   }
 };

+class SGDOpInferVarType : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const override {
+    auto input_var = op_desc.Input("Param")[0];
+    for (auto& out_var : op_desc.Output("ParamOut")) {
+      if (block->FindRecursiveOrCreateVar(input_var).GetType() ==
+          framework::proto::VarType::SELECTED_ROWS) {
+        block->FindRecursiveOrCreateVar(out_var).SetType(
+            framework::proto::VarType::SELECTED_ROWS);
+      } else {
+        block->FindRecursiveOrCreateVar(out_var).SetType(
+            framework::proto::VarType::LOD_TENSOR);
+      }
+    }
+  }
+};
+
 class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   SGDOpMaker(OpProto* proto, OpAttrChecker* op_checker)
@@ -74,5 +92,6 @@ $$param\_out = param - learning\_rate * grad$$
 }  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OP_WITHOUT_GRADIENT(sgd, ops::SGDOp, ops::SGDOpMaker);
+REGISTER_OPERATOR(sgd, ops::SGDOp, ops::SGDOpMaker,
+                  paddle::framework::EmptyGradOpMaker, ops::SGDOpInferVarType);
 REGISTER_OP_CPU_KERNEL(sgd, ops::SGDOpKernel<float>, ops::SGDOpKernel<double>);
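Switching from REGISTER_OP_WITHOUT_GRADIENT to REGISTER_OPERATOR with an EmptyGradOpMaker lets SGDOpInferVarType be attached to the op: ParamOut now inherits Param's variable type, so a SelectedRows parameter such as the sparse embedding table stays SelectedRows after the update instead of being retyped as a dense LoDTensor. A framework-free sketch of the rule (the names below are illustrative, not Paddle API):

LOD_TENSOR, SELECTED_ROWS = "LOD_TENSOR", "SELECTED_ROWS"

def infer_param_out_type(param_type):
    # ParamOut mirrors Param: sparse stays sparse, everything else
    # falls back to the dense LoDTensor type.
    return SELECTED_ROWS if param_type == SELECTED_ROWS else LOD_TENSOR

assert infer_param_out_type(SELECTED_ROWS) == SELECTED_ROWS
assert infer_param_out_type(LOD_TENSOR) == LOD_TENSOR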
@@ -116,11 +116,31 @@ uniform distribution.
         .SetDefault(framework::proto::VarType::FP32);
   }
 };

+class UniformRandomOpVarTypeInference : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const override {
+    auto out_var_name = op_desc.Output("Out").front();
+    if (block->FindRecursiveOrCreateVar(out_var_name).GetType() ==
+        framework::proto::VarType::SELECTED_ROWS) {
+      block->FindRecursiveOrCreateVar(out_var_name)
+          .SetType(framework::proto::VarType::SELECTED_ROWS);
+    } else {
+      block->FindRecursiveOrCreateVar(out_var_name)
+          .SetType(framework::proto::VarType::LOD_TENSOR);
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle

-REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp,
-                             paddle::operators::UniformRandomOpMaker);
+REGISTER_OPERATOR(uniform_random, paddle::operators::UniformRandomOp,
+                  paddle::operators::UniformRandomOpMaker,
+                  paddle::framework::EmptyGradOpMaker,
+                  paddle::operators::UniformRandomOpVarTypeInference);
 REGISTER_OP_CPU_KERNEL(uniform_random,
                        paddle::operators::CPUUniformRandomKernel<float>,
                        paddle::operators::CPUUniformRandomKernel<double>);
......
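UniformRandomOpVarTypeInference applies the same idea to the initializer: an Out variable that the program already declares as SELECTED_ROWS, such as the sparse table being filled on the pserver, keeps that type rather than being forced to LOD_TENSOR. A self-contained toy of the look-up-and-preserve behavior (a plain dict stands in for BlockDesc; all names are illustrative):

LOD_TENSOR, SELECTED_ROWS = "LOD_TENSOR", "SELECTED_ROWS"

def infer_out_type(block, out_name):
    # Keep a pre-declared SELECTED_ROWS output; default to LOD_TENSOR,
    # matching the branch in UniformRandomOpVarTypeInference above.
    declared = block.get(out_name, LOD_TENSOR)
    return SELECTED_ROWS if declared == SELECTED_ROWS else LOD_TENSOR

block = {"emb_table": SELECTED_ROWS}
assert infer_out_type(block, "emb_table") == SELECTED_ROWS
assert infer_out_type(block, "fc_w") == LOD_TENSOR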
@@ -650,7 +650,7 @@ class DistributeTranspiler:
                 shape=trainer_out.shape,
                 dtype=trainer_out.dtype)
             prefetch_block.append_op(
-                type=LOOKUP_TABLE_TYPE,
+                type="lookup_sparse_table",
                 inputs={'Ids': pserver_ids,
                         "W": table_var},
                 outputs={"Out": pserver_out},
@@ -674,9 +674,16 @@ class DistributeTranspiler:
         # STEP: create table optimize block
         # create table param and grad var in pserver program
-        param_var = _clone_var(
-            pserver_program.global_block(),
-            self.origin_program.global_block().vars[self.table_name])
+        # The table param must be SELECTED_ROWS on the pserver, so declare
+        # it directly instead of cloning the dense origin variable.
+        origin_param_var = self.origin_program.global_block().vars[
+            self.table_name]
+        param_var = pserver_program.global_block().create_var(
+            name=origin_param_var.name,
+            shape=origin_param_var.shape,
+            dtype=origin_param_var.dtype,
+            type=core.VarDesc.VarType.SELECTED_ROWS,
+            persistable=True)
         grad_var = _clone_var(
             pserver_program.global_block(),
             self.origin_program.global_block().vars[framework.grad_var_name(
......
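Taken together, the transpiler changes route the distributed lookup table through the new sparse path: the prefetch block queries it with lookup_sparse_table, and the pserver program declares the table as a persistable SELECTED_ROWS variable instead of cloning the dense original. A hedged trainer-side sketch of the setup that exercises this path (fluid-era API from roughly this period; the exact signatures are assumptions, not verified against this revision):

import paddle.fluid as fluid

# An embedding marked is_distributed=True is the lookup table the
# transpiler splits out onto the parameter servers.
ids = fluid.layers.data(name="ids", shape=[1], dtype="int64")
emb = fluid.layers.embedding(
    input=ids,
    size=[1000000, 64],  # too large to replicate densely on trainers
    is_distributed=True,
    param_attr=fluid.ParamAttr(name="emb_table"))
loss = fluid.layers.reduce_mean(emb)
fluid.optimizer.SGD(learning_rate=0.1).minimize(loss)

t = fluid.DistributeTranspiler()
t.transpile(trainer_id=0, pservers="127.0.0.1:6174", trainers=1)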