提交 4d9feb35 编写于 作者: X Xin Pan

support multi grad ops

test=develop
上级 22db82c0
...@@ -204,59 +204,68 @@ framework::LoDTensor& VarBase::GradValue() { ...@@ -204,59 +204,68 @@ framework::LoDTensor& VarBase::GradValue() {
} }
std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() { std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
if (!grad_op_desc_ && backward_id_ <= 0) { if (grad_op_descs_.empty() && backward_id_ <= 0) {
LOG(WARNING) << "op with no grad: " << op_desc_->Type(); LOG(WARNING) << "op with no grad: " << op_desc_->Type();
return {}; return {};
} }
std::map<std::string, std::vector<framework::Variable*>> grad_outputs; std::vector<framework::VariableValueMap> grad_outputs;
if (backward_id_ > 0) { if (backward_id_ > 0) {
grad_outputs.resize(1);
VLOG(3) << "py_layer_grad"; VLOG(3) << "py_layer_grad";
grad_outputs[framework::GradVarName(PyLayer::kFwdOut)] = PyLayer::ApplyGrad( grad_outputs[0][framework::GradVarName(PyLayer::kFwdOut)] =
backward_id_, PyLayer::ApplyGrad(
grad_input_vars_[framework::GradVarName(PyLayer::kFwdInp)]); backward_id_,
grad_input_vars_[0][framework::GradVarName(PyLayer::kFwdInp)]);
} else { } else {
VLOG(3) << "op grad " << grad_op_desc_->Type(); grad_outputs.resize(grad_op_descs_.size());
for (auto it : grad_output_vars_) { for (size_t k = 0; k < grad_op_descs_.size(); ++k) {
auto& outputs = grad_outputs[it.first]; framework::OpDesc* grad_op_desc = grad_op_descs_[k];
for (size_t i = 0; i < it.second.size(); ++i) { VLOG(3) << "op grad " << grad_op_desc->Type();
// Allocate a new variable for (auto it : grad_output_vars_[k]) {
Variable* tmp_var = new framework::Variable(); auto& outputs = grad_outputs[k][it.first];
tmp_var->GetMutable<framework::LoDTensor>(); for (size_t i = 0; i < it.second.size(); ++i) {
outputs.push_back(tmp_var); // Allocate a new variable
Variable* tmp_var = new framework::Variable();
tmp_var->GetMutable<framework::LoDTensor>();
outputs.push_back(tmp_var);
}
} }
}
framework::RuntimeContext ctx(grad_input_vars_, grad_outputs); framework::RuntimeContext ctx(grad_input_vars_[k], grad_outputs[k]);
// No need to do compile time infer shape here. // No need to do compile time infer shape here.
// grad_op_desc_->InferShape(*block_); // grad_op_desc_->InferShape(*block_);
grad_op_desc_->InferVarType(block_); grad_op_desc->InferVarType(block_);
std::unique_ptr<framework::OperatorBase> opbase = std::unique_ptr<framework::OperatorBase> opbase =
framework::OpRegistry::CreateOp(*grad_op_desc_); framework::OpRegistry::CreateOp(*grad_op_desc);
framework::OperatorWithKernel* op_kernel = framework::OperatorWithKernel* op_kernel =
dynamic_cast<framework::OperatorWithKernel*>(opbase.get()); dynamic_cast<framework::OperatorWithKernel*>(opbase.get());
PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel"); PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
framework::Scope scope; framework::Scope scope;
PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_); PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
p.op.RuntimeInferShape(scope, place_, ctx); p.op.RuntimeInferShape(scope, place_, ctx);
p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx)); p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx));
}
} }
for (auto it : grad_output_vars_) { for (size_t k = 0; k < grad_output_vars_.size(); ++k) {
auto& outputs = grad_outputs[it.first]; for (auto it : grad_output_vars_[k]) {
auto& origin_outputs = it.second; auto& outputs = grad_outputs[k][it.first];
PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size()); auto& origin_outputs = it.second;
PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size());
for (size_t i = 0; i < outputs.size(); ++i) {
framework::Variable* grad = outputs[i]; for (size_t i = 0; i < outputs.size(); ++i) {
framework::Variable* orig_grad = origin_outputs[i]; framework::Variable* grad = outputs[i];
AddTo(grad, orig_grad, place_); framework::Variable* orig_grad = origin_outputs[i];
delete grad; AddTo(grad, orig_grad, place_);
delete grad;
}
} }
} }
return input_vars_; return input_vars_;
} }
......
...@@ -184,12 +184,13 @@ class OpBase { ...@@ -184,12 +184,13 @@ class OpBase {
OpBase() OpBase()
: op_desc_(nullptr), : op_desc_(nullptr),
forward_id_(-1), forward_id_(-1),
grad_op_desc_(nullptr),
backward_id_(-1), backward_id_(-1),
place_(platform::CPUPlace()) {} place_(platform::CPUPlace()) {}
virtual ~OpBase() { virtual ~OpBase() {
if (grad_op_desc_) delete grad_op_desc_; for (framework::OpDesc* desc : grad_op_descs_) {
delete desc;
}
} }
std::map<std::string, std::vector<VarBase*>> ApplyGrad(); std::map<std::string, std::vector<VarBase*>> ApplyGrad();
...@@ -198,9 +199,9 @@ class OpBase { ...@@ -198,9 +199,9 @@ class OpBase {
// For pure python PyLayer, use `forward_id_`, otherwise, use op_desc_. // For pure python PyLayer, use `forward_id_`, otherwise, use op_desc_.
framework::OpDesc* op_desc_; framework::OpDesc* op_desc_;
int forward_id_; int forward_id_;
// When has backward, one of `grad_op_desc_` or `backward_id_` is set, // When has backward, one of `grad_op_descs_` or `backward_id_` is set,
// not both. // not both.
framework::OpDesc* grad_op_desc_; std::vector<framework::OpDesc*> grad_op_descs_;
int backward_id_; int backward_id_;
platform::Place place_; platform::Place place_;
...@@ -210,8 +211,8 @@ class OpBase { ...@@ -210,8 +211,8 @@ class OpBase {
OpBasePtrMap pre_ops_; OpBasePtrMap pre_ops_;
std::map<std::string, std::vector<int>> pre_ops_out_idx_; std::map<std::string, std::vector<int>> pre_ops_out_idx_;
framework::VariableValueMap grad_input_vars_; std::vector<framework::VariableValueMap> grad_input_vars_;
framework::VariableValueMap grad_output_vars_; std::vector<framework::VariableValueMap> grad_output_vars_;
framework::BlockDesc* block_; framework::BlockDesc* block_;
}; };
......
...@@ -24,15 +24,16 @@ namespace imperative { ...@@ -24,15 +24,16 @@ namespace imperative {
void CreateGradOp(const framework::OpDesc& op_desc, void CreateGradOp(const framework::OpDesc& op_desc,
const std::unordered_set<std::string>& no_grad_set, const std::unordered_set<std::string>& no_grad_set,
const std::vector<framework::BlockDesc*>& grad_sub_block, const std::vector<framework::BlockDesc*>& grad_sub_block,
framework::OpDesc** grad_op_desc, std::vector<framework::OpDesc*>* grad_op_descs,
std::unordered_map<std::string, std::string>* grad_to_var) { std::unordered_map<std::string, std::string>* grad_to_var) {
std::vector<std::unique_ptr<framework::OpDesc>> grad_op_descs = PADDLE_ENFORCE(grad_op_descs->empty());
std::vector<std::unique_ptr<framework::OpDesc>> descs =
framework::OpInfoMap::Instance() framework::OpInfoMap::Instance()
.Get(op_desc.Type()) .Get(op_desc.Type())
.GradOpMaker()(op_desc, no_grad_set, grad_to_var, grad_sub_block); .GradOpMaker()(op_desc, no_grad_set, grad_to_var, grad_sub_block);
PADDLE_ENFORCE(grad_op_descs.size() == 1, "Only support 1 grad op now."); for (auto& desc : descs) {
// TODO(panyx0718): Leak? grad_op_descs->emplace_back(desc.release());
*grad_op_desc = grad_op_descs[0].release(); }
} }
void InitVar(framework::Variable* var, framework::Variable* grad_var, void InitVar(framework::Variable* var, framework::Variable* grad_var,
...@@ -138,49 +139,52 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, ...@@ -138,49 +139,52 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
prepared_op.op, scope, *prepared_op.dev_ctx, prepared_op.ctx)); prepared_op.op, scope, *prepared_op.dev_ctx, prepared_op.ctx));
if (!stop_gradient) { if (!stop_gradient) {
framework::OpDesc* grad_op_desc;
// TODO(panyx): Is this leaked?
std::unique_ptr<std::unordered_map<std::string, std::string>> grad_to_var( std::unique_ptr<std::unordered_map<std::string, std::string>> grad_to_var(
new std::unordered_map<std::string, std::string>()); new std::unordered_map<std::string, std::string>());
CreateGradOp(*op_desc, {}, {block}, &grad_op_desc, grad_to_var.get()); CreateGradOp(*op_desc, {}, {block}, &op->grad_op_descs_, grad_to_var.get());
op->grad_op_desc_ = grad_op_desc;
op->grad_input_vars_.resize(op->grad_op_descs_.size());
for (auto it : grad_op_desc->Inputs()) { op->grad_output_vars_.resize(op->grad_op_descs_.size());
auto& grad_in_vars = op->grad_input_vars_[it.first]; for (size_t i = 0; i < op->grad_op_descs_.size(); ++i) {
for (const std::string& grad_invar : it.second) { framework::OpDesc* grad_op_desc = op->grad_op_descs_[i];
block->FindRecursiveOrCreateVar(grad_invar); for (auto it : grad_op_desc->Inputs()) {
auto var_it = grad_to_var->find(grad_invar); auto& grad_in_vars = op->grad_input_vars_[i][it.first];
if (var_it == grad_to_var->end()) { for (const std::string& grad_invar : it.second) {
auto fwd_var_it = vars.find(grad_invar); block->FindRecursiveOrCreateVar(grad_invar);
PADDLE_ENFORCE(fwd_var_it != vars.end()); auto var_it = grad_to_var->find(grad_invar);
// Forward inputs or outputs. if (var_it == grad_to_var->end()) {
grad_in_vars.push_back(fwd_var_it->second->var_); auto fwd_var_it = vars.find(grad_invar);
} else { PADDLE_ENFORCE(fwd_var_it != vars.end());
// Forward inputs or outputs.
grad_in_vars.push_back(fwd_var_it->second->var_);
} else {
VarBase* var = vars[var_it->second];
if (!var->grads_->var_->IsInitialized()) {
InitVar(var->var_, var->grads_->var_,
prepared_op.GetDeviceContext());
}
// Douts.
grad_in_vars.push_back(var->grads_->var_);
}
}
}
for (auto it : grad_op_desc->Outputs()) {
auto& grad_out_vars = op->grad_output_vars_[i][it.first];
for (const std::string& grad_outvar : it.second) {
block->FindRecursiveOrCreateVar(grad_outvar);
auto var_it = grad_to_var->find(grad_outvar);
PADDLE_ENFORCE(var_it != grad_to_var->end(),
"Could not found the grad op output var, should this "
"operator %s's stop gradient be True",
op_desc->Type());
VarBase* var = vars[var_it->second]; VarBase* var = vars[var_it->second];
if (!var->grads_->var_->IsInitialized()) { if (!var->grads_->var_->IsInitialized()) {
InitVar(var->var_, var->grads_->var_, InitVar(var->var_, var->grads_->var_,
prepared_op.GetDeviceContext()); prepared_op.GetDeviceContext());
} }
// Douts. grad_out_vars.push_back(var->grads_->var_);
grad_in_vars.push_back(var->grads_->var_);
}
}
}
for (auto it : grad_op_desc->Outputs()) {
auto& grad_out_vars = op->grad_output_vars_[it.first];
for (const std::string& grad_outvar : it.second) {
block->FindRecursiveOrCreateVar(grad_outvar);
auto var_it = grad_to_var->find(grad_outvar);
PADDLE_ENFORCE(var_it != grad_to_var->end(),
"Could not found the grad op output var, should this "
"operator %s's stop gradient be True",
op_desc->Type());
VarBase* var = vars[var_it->second];
if (!var->grads_->var_->IsInitialized()) {
InitVar(var->var_, var->grads_->var_, prepared_op.GetDeviceContext());
} }
grad_out_vars.push_back(var->grads_->var_);
} }
} }
} }
...@@ -209,10 +213,12 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op, ...@@ -209,10 +213,12 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
out->TrackPreOp(op, PyLayer::kFwdOut, i, stop_gradient); out->TrackPreOp(op, PyLayer::kFwdOut, i, stop_gradient);
} }
if (!stop_gradient) { if (!stop_gradient) {
op->grad_input_vars_.resize(1);
op->grad_output_vars_.resize(1);
auto& grad_input_vars = auto& grad_input_vars =
op->grad_input_vars_[framework::GradVarName(PyLayer::kFwdInp)]; op->grad_input_vars_[0][framework::GradVarName(PyLayer::kFwdInp)];
auto& grad_output_vars = auto& grad_output_vars =
op->grad_output_vars_[framework::GradVarName(PyLayer::kFwdOut)]; op->grad_output_vars_[0][framework::GradVarName(PyLayer::kFwdOut)];
for (const VarBase* inp : inputs) { for (const VarBase* inp : inputs) {
grad_input_vars.push_back(inp->var_); grad_input_vars.push_back(inp->var_);
......
...@@ -67,6 +67,21 @@ class MLP(fluid.imperative.Layer): ...@@ -67,6 +67,21 @@ class MLP(fluid.imperative.Layer):
class TestImperative(unittest.TestCase): class TestImperative(unittest.TestCase):
def test_sum_op(self):
with fluid.imperative.guard():
inputs = []
for _ in range(10):
inputs.append(
fluid.imperative.base.to_variable(
np.ones([2, 2], np.float32)))
sys.stderr.write('%s\n' % inputs[0].dtype)
ret = fluid.layers.sums(inputs)
sys.stderr.write('%s\n' % ret.dtype)
loss = fluid.layers.reduce_sum(ret)
sys.stderr.write('%s\n' % loss.dtype)
loss._backward()
sys.stderr.write('%s %s\n' % (ret._numpy(), inputs[0]._gradient()))
def test_layer(self): def test_layer(self):
with fluid.imperative.guard(): with fluid.imperative.guard():
cl = core.Layer() cl = core.Layer()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册