Commit 4d9feb35 authored by Xin Pan

support multi grad ops

test=develop
Parent 22db82c0
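For orientation before the diff: each traced OpBase previously held a single grad_op_desc_, and CreateGradOp rejected operators whose GradOpMaker produced more than one gradient op ("Only support 1 grad op now."). After this commit the op stores std::vector<framework::OpDesc*> grad_op_descs_ plus one input/output variable map per gradient op, and the backward pass loops over them. The standalone sketch below illustrates only that shape; the struct and names are simplified stand-ins, not Paddle's actual types:

#include <cstdio>
#include <string>
#include <vector>

// Illustrative stand-in for framework::OpDesc.
struct GradOpDesc {
  std::string type;
};

// Stand-in for OpBase: the forward op now owns a list of grad op
// descriptions instead of a single pointer, and ApplyGrad walks the list.
struct OpNode {
  std::vector<GradOpDesc*> grad_op_descs;  // was: GradOpDesc* grad_op_desc

  ~OpNode() {
    for (GradOpDesc* d : grad_op_descs) delete d;  // mirrors ~OpBase() in layer.h
  }

  void ApplyGrad() const {
    for (size_t k = 0; k < grad_op_descs.size(); ++k) {
      // In the real code this builds a RuntimeContext from the k-th
      // grad_input_vars_/grad_output_vars_ maps and runs the k-th grad kernel.
      std::printf("running grad op %zu: %s\n", k, grad_op_descs[k]->type.c_str());
    }
  }
};

int main() {
  OpNode op;
  // A forward op whose gradient is expressed as two separate ops.
  op.grad_op_descs.push_back(new GradOpDesc{"grad_op_0"});
  op.grad_op_descs.push_back(new GradOpDesc{"grad_op_1"});
  op.ApplyGrad();
  return 0;
}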
@@ -204,21 +204,26 @@ framework::LoDTensor& VarBase::GradValue() {
 }
 
 std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
-  if (!grad_op_desc_ && backward_id_ <= 0) {
+  if (grad_op_descs_.empty() && backward_id_ <= 0) {
     LOG(WARNING) << "op with no grad: " << op_desc_->Type();
     return {};
   }
 
-  std::map<std::string, std::vector<framework::Variable*>> grad_outputs;
+  std::vector<framework::VariableValueMap> grad_outputs;
   if (backward_id_ > 0) {
+    grad_outputs.resize(1);
     VLOG(3) << "py_layer_grad";
-    grad_outputs[framework::GradVarName(PyLayer::kFwdOut)] = PyLayer::ApplyGrad(
-        backward_id_,
-        grad_input_vars_[framework::GradVarName(PyLayer::kFwdInp)]);
+    grad_outputs[0][framework::GradVarName(PyLayer::kFwdOut)] =
+        PyLayer::ApplyGrad(
+            backward_id_,
+            grad_input_vars_[0][framework::GradVarName(PyLayer::kFwdInp)]);
   } else {
-    VLOG(3) << "op grad " << grad_op_desc_->Type();
-    for (auto it : grad_output_vars_) {
-      auto& outputs = grad_outputs[it.first];
+    grad_outputs.resize(grad_op_descs_.size());
+    for (size_t k = 0; k < grad_op_descs_.size(); ++k) {
+      framework::OpDesc* grad_op_desc = grad_op_descs_[k];
+      VLOG(3) << "op grad " << grad_op_desc->Type();
+      for (auto it : grad_output_vars_[k]) {
+        auto& outputs = grad_outputs[k][it.first];
         for (size_t i = 0; i < it.second.size(); ++i) {
           // Allocate a new variable
           Variable* tmp_var = new framework::Variable();
@@ -227,14 +232,14 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
         }
       }
 
-    framework::RuntimeContext ctx(grad_input_vars_, grad_outputs);
+      framework::RuntimeContext ctx(grad_input_vars_[k], grad_outputs[k]);
 
       // No need to do compile time infer shape here.
       // grad_op_desc_->InferShape(*block_);
-    grad_op_desc_->InferVarType(block_);
+      grad_op_desc->InferVarType(block_);
 
-    std::unique_ptr<framework::OperatorBase> opbase =
-        framework::OpRegistry::CreateOp(*grad_op_desc_);
+      std::unique_ptr<framework::OperatorBase> opbase =
+          framework::OpRegistry::CreateOp(*grad_op_desc);
       framework::OperatorWithKernel* op_kernel =
           dynamic_cast<framework::OperatorWithKernel*>(opbase.get());
       PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
@@ -244,9 +249,11 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
       p.op.RuntimeInferShape(scope, place_, ctx);
       p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx));
     }
+  }
 
-  for (auto it : grad_output_vars_) {
-    auto& outputs = grad_outputs[it.first];
+  for (size_t k = 0; k < grad_output_vars_.size(); ++k) {
+    for (auto it : grad_output_vars_[k]) {
+      auto& outputs = grad_outputs[k][it.first];
       auto& origin_outputs = it.second;
       PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size());
 
@@ -257,6 +264,8 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
        delete grad;
      }
    }
+  }
 
   return input_vars_;
 }
......
@@ -184,12 +184,13 @@ class OpBase {
   OpBase()
       : op_desc_(nullptr),
         forward_id_(-1),
-        grad_op_desc_(nullptr),
         backward_id_(-1),
         place_(platform::CPUPlace()) {}
 
   virtual ~OpBase() {
-    if (grad_op_desc_) delete grad_op_desc_;
+    for (framework::OpDesc* desc : grad_op_descs_) {
+      delete desc;
+    }
   }
 
   std::map<std::string, std::vector<VarBase*>> ApplyGrad();
@@ -198,9 +199,9 @@ class OpBase {
   // For pure python PyLayer, use `forward_id_`, otherwise, use op_desc_.
   framework::OpDesc* op_desc_;
   int forward_id_;
-  // When has backward, one of `grad_op_desc_` or `backward_id_` is set,
+  // When has backward, one of `grad_op_descs_` or `backward_id_` is set,
   // not both.
-  framework::OpDesc* grad_op_desc_;
+  std::vector<framework::OpDesc*> grad_op_descs_;
   int backward_id_;
 
   platform::Place place_;
@@ -210,8 +211,8 @@ class OpBase {
   OpBasePtrMap pre_ops_;
   std::map<std::string, std::vector<int>> pre_ops_out_idx_;
 
-  framework::VariableValueMap grad_input_vars_;
-  framework::VariableValueMap grad_output_vars_;
+  std::vector<framework::VariableValueMap> grad_input_vars_;
+  std::vector<framework::VariableValueMap> grad_output_vars_;
 
   framework::BlockDesc* block_;
 };
......
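A note on the member change above: grad_input_vars_ and grad_output_vars_ go from a single framework::VariableValueMap (slot name to list of variables) to one such map per gradient op, which is why every access in the .cc files gains a leading index. A tiny self-contained illustration of the resulting double indexing, with int standing in for framework::Variable* and made-up slot names:

#include <cassert>
#include <map>
#include <string>
#include <vector>

// Stand-in for framework::VariableValueMap: slot name -> variables.
using VariableValueMap = std::map<std::string, std::vector<int>>;

int main() {
  // Before: VariableValueMap grad_input_vars;              // one map per op
  // After:  std::vector<VariableValueMap> grad_input_vars; // one map per grad op
  std::vector<VariableValueMap> grad_input_vars(2);

  grad_input_vars[0]["X"] = {1, 2};      // slots of the first grad op
  grad_input_vars[1]["Out@GRAD"] = {3};  // slots of the second grad op

  // Lookups now take the grad-op index first, then the slot name.
  assert(grad_input_vars[0]["X"].size() == 2);
  assert(grad_input_vars[1]["Out@GRAD"].front() == 3);
  return 0;
}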
@@ -24,15 +24,16 @@ namespace imperative {
 void CreateGradOp(const framework::OpDesc& op_desc,
                   const std::unordered_set<std::string>& no_grad_set,
                   const std::vector<framework::BlockDesc*>& grad_sub_block,
-                  framework::OpDesc** grad_op_desc,
+                  std::vector<framework::OpDesc*>* grad_op_descs,
                   std::unordered_map<std::string, std::string>* grad_to_var) {
-  std::vector<std::unique_ptr<framework::OpDesc>> grad_op_descs =
+  PADDLE_ENFORCE(grad_op_descs->empty());
+  std::vector<std::unique_ptr<framework::OpDesc>> descs =
       framework::OpInfoMap::Instance()
           .Get(op_desc.Type())
           .GradOpMaker()(op_desc, no_grad_set, grad_to_var, grad_sub_block);
-  PADDLE_ENFORCE(grad_op_descs.size() == 1, "Only support 1 grad op now.");
-  // TODO(panyx0718): Leak?
-  *grad_op_desc = grad_op_descs[0].release();
+  for (auto& desc : descs) {
+    grad_op_descs->emplace_back(desc.release());
+  }
 }
 
 void InitVar(framework::Variable* var, framework::Variable* grad_var,
@@ -138,15 +139,16 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
           prepared_op.op, scope, *prepared_op.dev_ctx, prepared_op.ctx));
 
   if (!stop_gradient) {
-    framework::OpDesc* grad_op_desc;
-    // TODO(panyx): Is this leaked?
     std::unique_ptr<std::unordered_map<std::string, std::string>> grad_to_var(
         new std::unordered_map<std::string, std::string>());
-    CreateGradOp(*op_desc, {}, {block}, &grad_op_desc, grad_to_var.get());
-    op->grad_op_desc_ = grad_op_desc;
+    CreateGradOp(*op_desc, {}, {block}, &op->grad_op_descs_, grad_to_var.get());
 
-    for (auto it : grad_op_desc->Inputs()) {
-      auto& grad_in_vars = op->grad_input_vars_[it.first];
+    op->grad_input_vars_.resize(op->grad_op_descs_.size());
+    op->grad_output_vars_.resize(op->grad_op_descs_.size());
+    for (size_t i = 0; i < op->grad_op_descs_.size(); ++i) {
+      framework::OpDesc* grad_op_desc = op->grad_op_descs_[i];
+      for (auto it : grad_op_desc->Inputs()) {
+        auto& grad_in_vars = op->grad_input_vars_[i][it.first];
         for (const std::string& grad_invar : it.second) {
           block->FindRecursiveOrCreateVar(grad_invar);
           auto var_it = grad_to_var->find(grad_invar);
@@ -168,7 +170,7 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
       }
 
       for (auto it : grad_op_desc->Outputs()) {
-        auto& grad_out_vars = op->grad_output_vars_[it.first];
+        auto& grad_out_vars = op->grad_output_vars_[i][it.first];
         for (const std::string& grad_outvar : it.second) {
           block->FindRecursiveOrCreateVar(grad_outvar);
           auto var_it = grad_to_var->find(grad_outvar);
@@ -178,12 +180,14 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
                          op_desc->Type());
           VarBase* var = vars[var_it->second];
           if (!var->grads_->var_->IsInitialized()) {
-            InitVar(var->var_, var->grads_->var_, prepared_op.GetDeviceContext());
+            InitVar(var->var_, var->grads_->var_,
+                    prepared_op.GetDeviceContext());
           }
           grad_out_vars.push_back(var->grads_->var_);
         }
       }
     }
+  }
 
   op->block_ = block;
 }
@@ -209,10 +213,12 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
     out->TrackPreOp(op, PyLayer::kFwdOut, i, stop_gradient);
   }
   if (!stop_gradient) {
+    op->grad_input_vars_.resize(1);
+    op->grad_output_vars_.resize(1);
     auto& grad_input_vars =
-        op->grad_input_vars_[framework::GradVarName(PyLayer::kFwdInp)];
+        op->grad_input_vars_[0][framework::GradVarName(PyLayer::kFwdInp)];
     auto& grad_output_vars =
-        op->grad_output_vars_[framework::GradVarName(PyLayer::kFwdOut)];
+        op->grad_output_vars_[0][framework::GradVarName(PyLayer::kFwdOut)];
     for (const VarBase* inp : inputs) {
       grad_input_vars.push_back(inp->var_);
......
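One consequence worth noting across the two C++ files above: the old "// TODO(panyx0718): Leak?" paths disappear. CreateGradOp now releases each generated OpDesc out of its unique_ptr into the caller-supplied vector (op->grad_op_descs_), and ~OpBase() in layer.h deletes them. A small self-contained sketch of that ownership handoff, using stand-in types rather than Paddle's:

#include <memory>
#include <string>
#include <vector>

// Stand-in for framework::OpDesc.
struct Desc {
  std::string type;
};

// Stand-in for the GradOpMaker call: returns owning pointers.
std::vector<std::unique_ptr<Desc>> MakeGradDescs() {
  std::vector<std::unique_ptr<Desc>> descs;
  descs.emplace_back(new Desc{"a_grad"});
  descs.emplace_back(new Desc{"b_grad"});
  return descs;
}

// Stand-in for OpBase: owns the raw pointers it receives.
struct OpNode {
  std::vector<Desc*> grad_op_descs;
  ~OpNode() {
    for (Desc* d : grad_op_descs) delete d;  // matches the ~OpBase() in the diff
  }
};

int main() {
  OpNode op;
  // Ownership moves out of the unique_ptrs into the node, mirroring
  // CreateGradOp's `grad_op_descs->emplace_back(desc.release())`.
  for (auto& d : MakeGradDescs()) {
    op.grad_op_descs.push_back(d.release());
  }
  return 0;
}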
@@ -67,6 +67,21 @@ class MLP(fluid.imperative.Layer):
 
 
 class TestImperative(unittest.TestCase):
+    def test_sum_op(self):
+        with fluid.imperative.guard():
+            inputs = []
+            for _ in range(10):
+                inputs.append(
+                    fluid.imperative.base.to_variable(
+                        np.ones([2, 2], np.float32)))
+            sys.stderr.write('%s\n' % inputs[0].dtype)
+            ret = fluid.layers.sums(inputs)
+            sys.stderr.write('%s\n' % ret.dtype)
+            loss = fluid.layers.reduce_sum(ret)
+            sys.stderr.write('%s\n' % loss.dtype)
+            loss._backward()
+            sys.stderr.write('%s %s\n' % (ret._numpy(), inputs[0]._gradient()))
+
     def test_layer(self):
         with fluid.imperative.guard():
             cl = core.Layer()
......