From ce7e503cbe10dee0f3cad2145bec4559ab89f00f Mon Sep 17 00:00:00 2001
From: Xin Pan
Date: Tue, 25 Dec 2018 14:40:55 +0800
Subject: [PATCH] refactor to avoid scope. test=develop

---
 paddle/fluid/framework/operator.cc             |  60 +++++-
 paddle/fluid/framework/operator.h              |  10 +
 paddle/fluid/imperative/layer.cc               | 188 ++++++++----
 paddle/fluid/imperative/layer.h                |  45 +++--
 paddle/fluid/imperative/tracer.h               | 120 ++++++++---
 paddle/fluid/operators/fill_constant_op.cc     |  35 ++++
 paddle/fluid/pybind/pybind.cc                  |  12 +-
 python/paddle/fluid/framework.py               |  37 ++--
 python/paddle/fluid/imperative/base.py         |   3 +-
 python/paddle/fluid/layer_helper.py            |  21 +-
 python/paddle/fluid/layers/nn.py               |   2 +
 .../fluid/tests/unittests/test_imperative.py   |  13 +-
 12 files changed, 347 insertions(+), 199 deletions(-)

diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index 2e7006ed953..38675d2cac0 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -180,6 +180,11 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
   VLOG(3) << place << " " << DebugStringEx(&scope);
 }
 
+void OperatorBase::Run(const RuntimeContext& ctx,
+                       const platform::Place& place) {
+  RunImpl(ctx, place);
+}
+
 bool OperatorBase::HasInputs(const std::string& name) const {
   return inputs_.find(name) != inputs_.end();
 }
@@ -954,6 +959,51 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   }
 }
 
+void OperatorWithKernel::RunImpl(const RuntimeContext& ctx,
+                                 const platform::Place& place) const {
+  Scope scope;
+  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+  auto* dev_ctx = pool.Get(place);
+
+  // check if op[type] has kernel registered.
+  auto& all_op_kernels = AllOpKernels();
+  auto kernels_iter = all_op_kernels.find(type_);
+  if (kernels_iter == all_op_kernels.end()) {
+    PADDLE_THROW(
+        "There are no kernels which are registered in the %s operator.", type_);
+  }
+
+  OpKernelMap& kernels = kernels_iter->second;
+
+  auto expected_kernel_key = this->GetExpectedKernelType(
+      ExecutionContext(*this, scope, *dev_ctx, ctx));
+  VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
+
+  auto kernel_iter = kernels.find(expected_kernel_key);
+#ifdef PADDLE_WITH_MKLDNN
+  // workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
+  if (kernel_iter == kernels.end() &&
+      expected_kernel_key.library_type_ == LibraryType::kMKLDNN) {
+    VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one";
+    expected_kernel_key.library_type_ = LibraryType::kPlain;
+    expected_kernel_key.data_layout_ = DataLayout::kAnyLayout;
+    kernel_iter = kernels.find(expected_kernel_key);
+  }
+#endif
+  if (kernel_iter == kernels.end()) {
+    PADDLE_THROW("op %s does not have kernel for %s", type_,
+                 KernelTypeToString(expected_kernel_key));
+  }
+
+  if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
+    dev_ctx = pool.Get(expected_kernel_key.place_);
+  }
+
+  RuntimeInferShapeContext infer_shape_ctx(*this, scope, ctx);
+  this->InferShape(&infer_shape_ctx);
+  kernel_iter->second(ExecutionContext(*this, scope, *dev_ctx, ctx));
+}
+
 void OperatorWithKernel::TransferInplaceVarsBack(
     const Scope& scope, const std::vector<std::string>& inplace_vars,
     const Scope& transfer_scope) const {
@@ -1041,12 +1091,9 @@ Scope* OperatorWithKernel::PrepareData(
 
 proto::VarType::Type OperatorWithKernel::IndicateDataType(
     const ExecutionContext& ctx) const {
-  auto& scope = ctx.scope();
   int data_type = -1;
-  std::string last_input_name;
   for (auto& input : this->inputs_) {
-    for (auto& ipt_name : input.second) {
-      auto* var = scope.FindVar(ipt_name);
+    for (const Variable* var : ctx.MultiInputVar(input.first)) {
       if (var != nullptr) {
         const Tensor* t = nullptr;
         if (var->IsType<Tensor>()) {
@@ -1062,10 +1109,9 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
           int tmp = static_cast<int>(t->type());
           PADDLE_ENFORCE(
               tmp == data_type || data_type == -1,
-              "DataType of Paddle Op %s must be the same. Get %s(%d) != %s(%d)",
-              Type(), last_input_name, data_type, ipt_name, tmp);
+              "DataType of Paddle Op %s must be the same. Get (%d) != (%d)",
+              Type(), data_type, tmp);
           data_type = tmp;
-          last_input_name = ipt_name;
         }
       }
     }
diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index bad9716e8b4..446d27efa04 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -81,6 +81,10 @@ class RuntimeContext {
   RuntimeContext(const VariableNameMap& innames,
                  const VariableNameMap& outnames, const Scope& scope);
 
+  RuntimeContext(const VariableValueMap& invars,
+                 const VariableValueMap& outvars)
+      : inputs(invars), outputs(outvars) {}
+
   VariableValueMap inputs;
   VariableValueMap outputs;
 };
@@ -101,6 +105,7 @@ class OperatorBase {
   /// Executor will call this interface function to Run an op.
   //  The implementation should be written at RunImpl
   void Run(const Scope& scope, const platform::Place& place);
+  void Run(const RuntimeContext& ctx, const platform::Place& place);
 
   // FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
   virtual void Stop() {}
@@ -167,6 +172,9 @@ class OperatorBase {
   void CheckAllInputOutputSet() const;
   virtual void RunImpl(const Scope& scope,
                        const platform::Place& place) const = 0;
+
+  virtual void RunImpl(const RuntimeContext& ctx,
+                       const platform::Place& place) const {}
 };
 
 class ExecutionContext {
@@ -458,6 +466,8 @@ class OperatorWithKernel : public OperatorBase {
   // same.
   proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const;
   void RunImpl(const Scope& scope, const platform::Place& place) const final;
+  void RunImpl(const RuntimeContext& ctx,
+               const platform::Place& place) const final;
 
   /**
    * Transfer data from scope to a transfered scope. If there is no data need to
diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc
index 342cb68ab2b..239ff029dba 100644
--- a/paddle/fluid/imperative/layer.cc
+++ b/paddle/fluid/imperative/layer.cc
@@ -31,6 +31,11 @@ using framework::Variable;
 void AddTo(Variable* src, Variable* dst) {
   framework::LoDTensor* dst_tensor = dst->GetMutable<framework::LoDTensor>();
   framework::LoDTensor* src_tensor = src->GetMutable<framework::LoDTensor>();
+
+  VLOG(3) << "apply var grad " << src_tensor->data<float>()[0] << " "
+          << src_tensor->data<float>()[1] << " "
+          << src_tensor->data<float>()[2];
+
   PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(), "%lld vs %lld",
                  dst_tensor->numel(), src_tensor->numel());
   float* dst_data = dst_tensor->mutable_data<float>(platform::CPUPlace());
@@ -38,16 +43,28 @@ void AddTo(Variable* src, Variable* dst) {
   for (size_t i = 0; i < src_tensor->numel(); ++i) {
     dst_data[i] += src_data[i];
   }
+
+  VLOG(3) << "apply var dst grad " << dst_tensor->data<float>()[0] << " "
+          << dst_tensor->data<float>()[1] << " "
+          << dst_tensor->data<float>()[2];
 }
 
 class Autograd {
  public:
-  explicit Autograd(framework::Scope* scope) : scope_(scope) {}
+  Autograd() {}
 
   void RunBackward(VarBase* var) {
     PADDLE_ENFORCE(var->pre_op_->op_desc_);
     // TODO(panyx0718): Only create for vars that "require_grad"
-    (*var->pre_op_->output_vars_)[var->pre_op_out_idx_]->grads_ = var->grads_;
+    LOG(ERROR) << reinterpret_cast<void*>(var->grads_) << " vs "
+               << reinterpret_cast<void*>(
+                      var->pre_op_
+                          ->output_vars_[var->pre_op_out_name_]
+                                        [var->pre_op_out_idx_]
+                          ->grads_);
+    var->pre_op_->output_vars_[var->pre_op_out_name_][var->pre_op_out_idx_]
+        ->grads_->GetMutable<framework::LoDTensor>()
+        ->ShareDataWith(var->grads_->Get<framework::LoDTensor>());
 
     std::deque<OpBase*> ready;
     ready.push_back(var->pre_op_);
@@ -57,18 +74,23 @@ class Autograd {
     while (!ready.empty()) {
       OpBase* ready_op = ready.front();
       ready.pop_front();
-      std::vector<Variable*> input_grads = ready_op->ApplyGrad(scope_);
-
-      for (size_t i = 0; i < input_grads.size(); ++i) {
-        if (!input_grads[i]) continue;
-        OpBase* pre_op = ready_op->pre_ops_->at(i);
-        if (!pre_op) continue;
-
-        dep_counts[pre_op] -= 1;
-        PADDLE_ENFORCE(dep_counts[pre_op] >= 0);
-        bool pre_op_ready = dep_counts[pre_op] == 0;
-        if (pre_op_ready) {
-          ready.push_back(pre_op);
+      std::map<std::string, std::vector<VarBase*>> input_grads =
+          ready_op->ApplyGrad();
+      VLOG(3) << "after apply grad";
+
+      for (auto it : input_grads) {
+        const std::vector<VarBase*>& ingrads = it.second;
+        for (size_t i = 0; i < ingrads.size(); ++i) {
+          if (!ingrads[i]) continue;
+          OpBase* pre_op = (*ready_op->pre_ops_)[it.first][i];
+          if (!pre_op) continue;
+
+          dep_counts[pre_op] -= 1;
+          PADDLE_ENFORCE(dep_counts[pre_op] >= 0);
+          bool pre_op_ready = dep_counts[pre_op] == 0;
+          if (pre_op_ready) {
+            ready.push_back(pre_op);
+          }
         }
       }
     }
@@ -85,26 +107,25 @@ class Autograd {
     while (!queue.empty()) {
       OpBase* candidate = queue.front();
       queue.pop_front();
-      for (OpBase* pre_op : *(candidate->pre_ops_)) {
-        if (!pre_op) continue;
-        if (visited.find(pre_op) == visited.end()) {
-          visited.insert(pre_op);
-          queue.push_back(pre_op);
+      for (auto it : *(candidate->pre_ops_)) {
+        for (OpBase* pre_op : it.second) {
+          if (!pre_op) continue;
+          if (visited.find(pre_op) == visited.end()) {
+            visited.insert(pre_op);
+            queue.push_back(pre_op);
+          }
+          ret[pre_op] += 1;
         }
-        ret[pre_op] += 1;
       }
     }
-
     return ret;
   }
-
-  framework::Scope* scope_;
 };
 
-framework::Variable* CreateVariable(const std::string& name,
-                                    const framework::DDim& dim, float val,
-                                    framework::Scope* scope,
-                                    bool random_name = true) {
+void CreateVariable(const std::string& name, const framework::DDim& dim,
+                    float val, bool random_name, framework::Variable* var) {
+  if (var->IsInitialized()) return;
+
   std::string varname = name;
   if (random_name) {
     std::mt19937 rng;
@@ -116,12 +137,9 @@ framework::Variable* CreateVariable(const std::string& name,
   }
 
   VLOG(3) << "creating var " << varname;
-  framework::Variable* var = scope->Var(varname);
   framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
-
   float* data = tensor->mutable_data<float>(dim, platform::CPUPlace());
   std::fill(data, data + tensor->numel(), val);
-  return var;
 }
 
 framework::LoDTensor& VarBase::Grad() {
@@ -129,94 +147,56 @@ framework::LoDTensor& VarBase::Grad() {
   return *grads_->GetMutable<framework::LoDTensor>();
 }
 
-void VarBase::ApplyGrad(framework::Scope* scope, Variable* grad) {
-  VLOG(3) << "apply var grad " << var_desc_->Name() << " "
-          << grad->Get<framework::LoDTensor>().data<float>()[0];
-  if (!grads_) {
-    grads_ =
-        CreateVariable(string::Sprintf("%s@IGrad", var_desc_->Name()),
-                       var_->Get<framework::LoDTensor>().dims(), 0.0, scope);
+std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
+  if (!grad_op_desc_) {
+    VLOG(3) << "op with no grad: " << op_desc_->Type();
+    return {};
   }
-  AddTo(grad, grads_);
-  VLOG(3) << "grad_ after apply var grad " << var_desc_->Name() << " "
-          << grads_->Get<framework::LoDTensor>().data<float>()[0];
-}
-
-std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
   VLOG(3) << "op grad " << grad_op_desc_->Type();
 
-  for (const std::string& grad_invar : grad_op_desc_->InputArgumentNames()) {
-    if (grad_to_var_->find(grad_invar) == grad_to_var_->end()) {
-      // grad op inputs can be forward inputs, so not in grad_to_var.
-      continue;
-    }
-    VLOG(3) << "op grad in var " << grad_invar;
-    block_->FindRecursiveOrCreateVar(grad_invar);
-    framework::Variable* var = scope->Var(grad_invar);
-    const std::string& invar = grad_to_var_->at(grad_invar);
-    for (VarBase* varbase : *output_vars_) {
-      // Use the accumulated grads_ by sharing the input with grads_.
-      if (varbase->var_desc_->Name() == invar) {
-        var->GetMutable<framework::LoDTensor>()->ShareDataWith(
-            varbase->grads_->Get<framework::LoDTensor>());
-        break;
-      }
+  std::map<std::string, std::vector<framework::Variable*>> grad_outputs;
+  for (auto it : grad_output_vars_) {
+    auto& outputs = grad_outputs[it.first];
+    for (size_t i = 0; i < it.second.size(); ++i) {
+      outputs.push_back(new framework::Variable());
+      outputs.back()->GetMutable<framework::LoDTensor>();
+      /*
+      auto& accum_grad_t = it.second[i]->Get<framework::LoDTensor>();
+      Variable* grad_var = outputs.back();
+      float* data = grad_var->GetMutable<framework::LoDTensor>()
+          ->mutable_data<float>(accum_grad_t.dims(), platform::CPUPlace());
+      std::fill(data, data + accum_grad_t.numel(), 0.0);*/
     }
   }
 
-  for (const std::string& outvar : grad_op_desc_->OutputArgumentNames()) {
-    VLOG(3) << "grad outvar " << outvar;
-    block_->FindRecursiveOrCreateVar(outvar);
-    framework::Variable* var = scope->Var(outvar);
-    if (!var->IsInitialized()) {
-      framework::VarDesc* var_desc = block_->FindVar(outvar);
-      if (var_desc->GetType() == framework::proto::VarType::LOD_TENSOR) {
-        var->GetMutable<framework::LoDTensor>();
-      } else {
-        LOG(ERROR) << "tracer doesn't support yet";
-      }
-    }
-  }
-  grad_op_desc_->InferShape(*block_);
+  framework::RuntimeContext ctx(grad_input_vars_, grad_outputs);
+
+  // grad_op_desc_->InferShape(*block_);
   grad_op_desc_->InferVarType(block_);
+
   std::unique_ptr<framework::OperatorBase> opbase =
       framework::OpRegistry::CreateOp(*grad_op_desc_);
-
-  opbase->Run(*scope, platform::CPUPlace());
-
-  // `ret` matches exactly with `input_vars_` of forward op.
-  std::vector<Variable*> ret;
-  for (size_t i = 0; i < input_vars_->size(); ++i) {
-    bool found = false;
-    VarBase* origin_var = (*input_vars_)[i];
-    for (const std::string& outvar : grad_op_desc_->OutputArgumentNames()) {
-      Variable* var = scope->FindVar(outvar);
-      std::string orig_var = grad_to_var_->at(outvar);
-      if (origin_var->var_desc_->Name() != orig_var) {
-        continue;
-      }
-      VLOG(3) << "apply grad " << outvar << " with origin " << orig_var;
-      origin_var->ApplyGrad(scope, var);
-      found = true;
-      ret.push_back(var);
-      // TODO(panyx0718): There might be another outvar with the same name.
-      // In that case, it doesn't matter the first one or the second one is
-      // used.
-      break;
-    }
-    if (!found) {
-      ret.push_back(nullptr);
+  opbase->Run(ctx, platform::CPUPlace());
+
+  for (auto it : grad_output_vars_) {
+    auto& outputs = grad_outputs[it.first];
+    auto& origin_outputs = it.second;
+    for (size_t i = 0; i < outputs.size(); ++i) {
+      framework::Variable* orig_grad = origin_outputs[i];
+      AddTo(outputs[i], orig_grad);
+      VLOG(3) << "done add to " << grad_op_desc_->Outputs().at(it.first)[i];
     }
   }
-  return ret;
+  return input_vars_;
 }
 
-void VarBase::RunBackward(framework::Scope* scope) {
-  grads_ = CreateVariable(framework::GradVarName(var_desc_->Name()),
-                          var_->Get<framework::LoDTensor>().dims(), 1.0, scope,
-                          false);
+void VarBase::RunBackward() {
+  auto grads_t = grads_->GetMutable<framework::LoDTensor>();
+  float* data = grads_t->mutable_data<float>(platform::CPUPlace());
+  std::fill(data, data + grads_t->numel(), 1.0);
+
   if (!pre_op_) return;
-  Autograd(scope).RunBackward(this);
+  Autograd().RunBackward(this);
 }
 
 }  // namespace imperative
diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h
index 85a71ca83d2..eb5fd553bdc 100644
--- a/paddle/fluid/imperative/layer.h
+++ b/paddle/fluid/imperative/layer.h
@@ -14,11 +14,11 @@
 
 #pragma once
 
+#include <map>
 #include <vector>
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/var_desc.h"
 #include "paddle/fluid/platform/enforce.h"
 
@@ -33,18 +33,26 @@ class VarBase {
       : pre_op_(nullptr),
         pre_op_out_idx_(-1),
         var_desc_(nullptr),
-        var_(nullptr),
-        grads_(nullptr) {}
-
-  virtual ~VarBase() {}
-
-  void ApplyGrad(framework::Scope* scope, framework::Variable* grad);
+        var_(new framework::Variable()),
+        grads_(new framework::Variable()) {}
+
+  virtual ~VarBase() {
+    if (var_) {
+      delete var_;
+      var_ = nullptr;
+    }
+    if (grads_) {
+      delete grads_;
+      grads_ = nullptr;
+    }
+  }
 
-  void RunBackward(framework::Scope* scope);
+  void RunBackward();
 
   framework::LoDTensor& Grad();
 
   OpBase* pre_op_;
+  std::string pre_op_out_name_;
   int pre_op_out_idx_;
 
   framework::VarDesc* var_desc_;
@@ -55,17 +63,12 @@ class VarBase {
 class OpBase {
  public:
   OpBase()
-      : input_vars_(new std::vector<VarBase*>()),
-        output_vars_(new std::vector<VarBase*>()),
-        pre_ops_(new std::vector<OpBase*>()),
-        pre_ops_out_idx_(new std::vector<int>()),
+      : pre_ops_(new std::map<std::string, std::vector<OpBase*>>()),
+        pre_ops_out_idx_(new std::map<std::string, std::vector<int>>()),
         op_desc_(nullptr),
         grad_op_desc_(nullptr) {}
 
   virtual ~OpBase() {
-    delete input_vars_;
-    delete output_vars_;
-
     delete pre_ops_;
     delete pre_ops_out_idx_;
 
@@ -73,16 +76,18 @@ class OpBase {
     if (grad_to_var_) delete grad_to_var_;
   }
 
-  std::vector<Variable*> ApplyGrad(framework::Scope* scope);
+  std::map<std::string, std::vector<VarBase*>> ApplyGrad();
 
-  std::vector<VarBase*>* input_vars_;
-  std::vector<VarBase*>* output_vars_;
-  std::vector<OpBase*>* pre_ops_;
-  std::vector<int>* pre_ops_out_idx_;
+  std::map<std::string, std::vector<VarBase*>> input_vars_;
+  std::map<std::string, std::vector<VarBase*>> output_vars_;
+  std::map<std::string, std::vector<OpBase*>>* pre_ops_;
+  std::map<std::string, std::vector<int>>* pre_ops_out_idx_;
 
   framework::OpDesc* op_desc_;
   framework::OpDesc* grad_op_desc_;
 
   std::unordered_map<std::string, std::string>* grad_to_var_;
+  std::map<std::string, std::vector<framework::Variable*>> grad_input_vars_;
+  std::map<std::string, std::vector<framework::Variable*>> grad_output_vars_;
+
   framework::BlockDesc* block_;
 };
diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h
index 97772dc1101..e7a60621cd5 100644
--- a/paddle/fluid/imperative/tracer.h
+++ b/paddle/fluid/imperative/tracer.h
@@ -41,6 +41,14 @@ void CreateGradOp(const framework::OpDesc& op_desc,
   *grad_op_desc = grad_op_descs[0].release();
 }
 
+void InitVar(framework::Variable* var, framework::Variable* grad_var) {
+  auto& var_t = var->Get<framework::LoDTensor>();
+  float* data =
+      grad_var->GetMutable<framework::LoDTensor>()->mutable_data<float>(
+          var_t.dims(), platform::CPUPlace());
+  std::fill(data, data + var_t.numel(), 0.0);
+}
+
 class Tracer {
  public:
   explicit Tracer(framework::BlockDesc* root_block,
@@ -53,10 +61,13 @@ class Tracer {
 
   virtual ~Tracer() { delete root_scope_; }
 
-  void Trace(OpBase* op, const std::vector<VarBase*>& inputs,
-             const std::vector<VarBase*>& outputs,
+  void Trace(OpBase* op,
+             const std::map<std::string, std::vector<VarBase*>>& inputs,
+             const std::map<std::string, std::vector<VarBase*>>& outputs,
              framework::BlockDesc* block) {
-    framework::Scope* scope = GetScope(block);
+    // framework::Scope* scope = GetScope(block);
+    std::map<std::string, VarBase*> vars;
+
     framework::OpDesc* op_desc = op->op_desc_;
     VLOG(3) << "tracer tracing " << op_desc->Type();
     op_desc->InferShape(*block);
@@ -64,48 +75,60 @@ class Tracer {
     std::unique_ptr<framework::OperatorBase> op_base =
         framework::OpRegistry::CreateOp(*op_desc);
 
-    *op->input_vars_ = inputs;
-    for (VarBase* input : inputs) {
-      const std::string vname = input->var_desc_->Name();
-      framework::Variable* var = scope->Var(vname);
-      input->var_ = var;
-      if (!var->IsInitialized()) {
-        framework::VarDesc* var_desc = block->FindVar(vname);
-        if (var_desc->GetType() == framework::proto::VarType::LOD_TENSOR) {
-          var->GetMutable<framework::LoDTensor>();
+    framework::VariableValueMap invars_map;
+    framework::VariableValueMap outvars_map;
+
+    op->input_vars_ = inputs;
+    for (auto it : op->input_vars_) {
+      auto& invars = invars_map[it.first];
+      for (VarBase* inp : it.second) {
+        PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr",
+                                op->op_desc_->Type(), inp->var_desc_->Name());
+
+        invars.push_back(inp->var_);
+        vars[inp->var_desc_->Name()] = inp;
+        if (inp->pre_op_) {
+          (*op->pre_ops_)[it.first].push_back(inp->pre_op_);
+          (*op->pre_ops_out_idx_)[it.first].push_back(inp->pre_op_out_idx_);
         } else {
-          LOG(ERROR) << "tracer doesn't support yet";
+          (*op->pre_ops_)[it.first].push_back(nullptr);
        }
+        VLOG(3) << "input vname " << inp->var_desc_->Name() << " "
+                << inp->var_->Get<framework::LoDTensor>().dims().size()
+                << reinterpret_cast<void*>(inp->var_);
       }
-      if (input->pre_op_) {
-        op->pre_ops_->push_back(input->pre_op_);
-        op->pre_ops_out_idx_->push_back(input->pre_op_out_idx_);
-      } else {
-        op->pre_ops_->push_back(nullptr);
-      }
-      VLOG(3) << "input vname " << vname << " "
-              << var->Get<framework::LoDTensor>().dims().size();
     }
 
-    *op->output_vars_ = outputs;
-    for (size_t i = 0; i < outputs.size(); ++i) {
-      const std::string vname = outputs[i]->var_desc_->Name();
-      framework::Variable* var = scope->Var(vname);
-      if (!var->IsInitialized()) {
-        framework::VarDesc* var_desc = block->FindVar(vname);
+    op->output_vars_ = outputs;
+    for (auto it : op->output_vars_) {
+      auto& outvars = outvars_map[it.first];
+      const std::vector<VarBase*>& outputs = it.second;
+      for (size_t i = 0; i < outputs.size(); ++i) {
+        VarBase* out = outputs[i];
+        outvars.push_back(out->var_);
+        vars[out->var_desc_->Name()] = out;
+
+        framework::VarDesc* var_desc = block->FindVar(out->var_desc_->Name());
         if (var_desc->GetType() == framework::proto::VarType::LOD_TENSOR) {
-          var->GetMutable<framework::LoDTensor>();
+          out->var_->GetMutable<framework::LoDTensor>();
         } else {
           LOG(ERROR) << "tracer doesn't support yet";
         }
+        out->pre_op_ = op;
+        out->pre_op_out_name_ = it.first;
+        out->pre_op_out_idx_ = i;
+
+        VLOG(3) << "output vname " << out->var_desc_->Name() << " "
+                << out->var_->Get<framework::LoDTensor>().dims().size() << " "
+                << reinterpret_cast<void*>(out->var_) << " "
+                << out->var_->IsInitialized();
       }
-      outputs[i]->var_ = var;
-      outputs[i]->pre_op_ = op;
-      outputs[i]->pre_op_out_idx_ = i;
     }
 
     VLOG(3) << "tracer running " << op_desc->Type();
-    op_base->Run(*scope, platform::CPUPlace());
+    framework::RuntimeContext ctx(invars_map, outvars_map);
+    op_base->Run(ctx, platform::CPUPlace());
+
     if (block == startup_block_) {
       op->grad_op_desc_ = nullptr;
       op->grad_to_var_ = nullptr;
@@ -115,6 +138,39 @@ class Tracer {
       CreateGradOp(*op_desc, {}, {block}, &grad_op_desc, grad_to_var);
       op->grad_op_desc_ = grad_op_desc;
       op->grad_to_var_ = grad_to_var;
+
+      for (auto it : grad_op_desc->Inputs()) {
+        auto& grad_in_vars = op->grad_input_vars_[it.first];
+        for (const std::string& grad_invar : it.second) {
+          block->FindRecursiveOrCreateVar(grad_invar);
+          auto var_it = op->grad_to_var_->find(grad_invar);
+          if (var_it == op->grad_to_var_->end()) {
+            auto fwd_var_it = vars.find(grad_invar);
+            PADDLE_ENFORCE(fwd_var_it != vars.end());
+            grad_in_vars.push_back(fwd_var_it->second->var_);
+          } else {
+            VarBase* var = vars[var_it->second];
+            if (!var->grads_->IsInitialized()) {
+              InitVar(var->var_, var->grads_);
+            }
+            grad_in_vars.push_back(var->grads_);
+          }
+        }
+      }
+      for (auto it : grad_op_desc->Outputs()) {
+        auto& grad_out_vars = op->grad_output_vars_[it.first];
+        for (const std::string& grad_outvar : it.second) {
+          block->FindRecursiveOrCreateVar(grad_outvar);
+          auto var_it = op->grad_to_var_->find(grad_outvar);
+          PADDLE_ENFORCE(var_it != op->grad_to_var_->end());
+          VarBase* var = vars[var_it->second];
+          if (!var->grads_->IsInitialized()) {
+            InitVar(var->var_, var->grads_);
+          }
+          LOG(ERROR) << grad_outvar << " map to " << var->var_desc_->Name();
+          grad_out_vars.push_back(var->grads_);
+        }
+      }
     }
     op->block_ = block;
   }
diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc
index 38cb33e7904..7b04c5d21f4 100644
--- a/paddle/fluid/operators/fill_constant_op.cc
+++ b/paddle/fluid/operators/fill_constant_op.cc
@@ -68,6 +68,41 @@ class FillConstantOp : public framework::OperatorBase {
     auto &dev_ctx = *pool.Get(dev_place);
     math::set_constant(dev_ctx, tensor, value);
   }
+
+  void RunImpl(const framework::RuntimeContext &ctx,
+               const platform::Place &dev_place) const override {
+    auto data_type =
+        static_cast<framework::proto::VarType::Type>(Attr<int>("dtype"));
+    auto value = Attr<float>("value");
+    auto force_cpu = Attr<bool>("force_cpu");
+
+    framework::Tensor *tensor = nullptr;
+
+    auto &out_var = *ctx.outputs.at("Out")[0];
+
+    if (out_var.IsType<framework::LoDTensor>()) {
+      tensor = out_var.GetMutable<framework::LoDTensor>();
+      tensor->Resize(framework::make_ddim(Attr<std::vector<int64_t>>("shape")));
+    } else if (out_var.IsType<framework::SelectedRows>()) {
+      tensor = out_var.GetMutable<framework::SelectedRows>()->mutable_value();
+      tensor->Resize(framework::make_ddim(Attr<std::vector<int64_t>>("shape")));
+    } else {
+      PADDLE_THROW(
+          "fill constant op's output only"
+          "supports SelectedRows and LoDTensor");
+    }
+
+    if (force_cpu) {
+      auto cpu = platform::CPUPlace();
+      tensor->mutable_data(cpu, data_type);
+    } else {
+      tensor->mutable_data(dev_place, data_type);
+    }
+
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+    auto &dev_ctx = *pool.Get(dev_place);
+    math::set_constant(dev_ctx, tensor, value);
+  }
 };
 
 class FillConstantOpVarTypeInference : public framework::VarTypeInference {
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 81d63aace04..2ffdc90d847 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -124,9 +124,7 @@ PYBIND11_MODULE(core, m) {
   py::class_<imperative::VarBase>(m, "VarBase", R"DOC()DOC")
       .def(py::init<>())
       .def("_run_backward",
-           [](imperative::VarBase &self, framework::Scope *scope) {
-             self.RunBackward(scope);
-           })
+           [](imperative::VarBase &self) { self.RunBackward(); })
       .def("_grad", &imperative::VarBase::Grad)
       .def_property(
           "desc",
@@ -134,7 +132,13 @@ PYBIND11_MODULE(core, m) {
           [](imperative::VarBase &self, framework::VarDesc *var_desc) {
             self.var_desc_ = var_desc;
           },
-          py::return_value_policy::reference);
+          py::return_value_policy::reference)
+      .def_property("var",
+                    [](const imperative::VarBase &self) { return self.var_; },
+                    [](imperative::VarBase &self, framework::Variable *var) {
+                      self.var_ = var;
+                    },
+                    py::return_value_policy::reference);
 
   py::class_<imperative::OpBase>(m, "OpBase", R"DOC()DOC")
       .def(py::init<>())
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index de30ed2fc58..823b6d80be1 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -15,6 +15,7 @@
 from __future__ import print_function
 
 import collections
+from collections import defaultdict
 import contextlib
 import os
 import re
@@ -369,13 +370,11 @@ class Variable(object):
             self._ivar.desc = self.desc
 
     def _numpy(self):
-        scope = _imperative_tracer().get_scope(self.block.desc)
-        tensor = core.get_variable_tensor(scope, self.desc.name())
+        tensor = self._ivar.var.get_tensor()
         return np.array(tensor)
 
     def _backward(self):
-        scope = _imperative_tracer().get_scope(self.block.desc)
-        self._ivar._run_backward(scope)
+        self._ivar._run_backward()
 
     def _gradient(self):
         return np.array(self._ivar._grad())
@@ -692,20 +691,20 @@ class Operator(object):
         if _in_imperative_mode():
             self.iop = core.OpBase()
             self.iop.desc = self.desc
-            self.inputs = []
+            self.inputs = defaultdict(list)
             if inputs is not None:
-                for inp in inputs.values():
-                    if isinstance(inp, Variable):
-                        self.inputs.append(inp)
-                    elif isinstance(inp, list) or isinstance(inp, tuple):
-                        self.inputs.extend(inp[:])
-            self.outputs = []
+                for k, v in six.iteritems(inputs):
+                    if isinstance(v, Variable):
+                        self.inputs[k].append(v._ivar)
+                    elif isinstance(v, list) or isinstance(v, tuple):
+                        self.inputs[k].extend([var._ivar for var in v])
+            self.outputs = defaultdict(list)
             if outputs is not None:
-                for out in outputs.values():
-                    if isinstance(out, Variable):
-                        self.outputs.append(out)
-                    elif isinstance(out, list) or isinstance(out, tuple):
-                        self.outputs.extend(out[:])
+                for k, v in six.iteritems(outputs):
+                    if isinstance(v, Variable):
+                        self.outputs[k].append(v._ivar)
+                    elif isinstance(v, list) or isinstance(v, tuple):
+                        self.outputs[k].extend([var._ivar for var in v])
 
     def _has_kernel(self, op_type):
         return op_type not in self.OP_WITHOUT_KERNEL_SET
@@ -1273,8 +1272,7 @@ class Block(object):
         op_desc = self.desc.append_op()
         op = Operator(block=self, desc=op_desc, *args, **kwargs)
         if _in_imperative_mode():
-            _imperative_tracer().trace(op.iop, [v._ivar for v in op.inputs],
-                                       [v._ivar for v in op.outputs], self.desc)
+            _imperative_tracer().trace(op.iop, op.inputs, op.outputs, self.desc)
         self.ops.append(op)
         return op
@@ -1325,8 +1323,7 @@ class Block(object):
         op_desc = self.desc._prepend_op()
         op = Operator(self, op_desc, *args, **kwargs)
         if _in_imperative_mode():
-            _imperative_tracer().trace(op.iop, [v._ivar for v in op.inputs],
-                                       [v._ivar for v in op.outputs], self.desc)
+            _imperative_tracer().trace(op.iop, op.inputs, op.outputs, self.desc)
         self.ops.insert(0, op)
         return op
diff --git a/python/paddle/fluid/imperative/base.py b/python/paddle/fluid/imperative/base.py
index aa48ef71aa6..61e243e288f 100644
--- a/python/paddle/fluid/imperative/base.py
+++ b/python/paddle/fluid/imperative/base.py
@@ -46,8 +46,7 @@ def to_variable(value, block=None):
             name=None,
             shape=value.shape,
             dtype=value.dtype)
-        scope = framework._imperative_tracer().get_scope(block.desc)
-        var = scope.var(py_var.name)
+        var = py_var._ivar.var
         tensor = var.get_tensor()
         tensor.set(value, core.CPUPlace())
         return py_var
diff --git a/python/paddle/fluid/layer_helper.py b/python/paddle/fluid/layer_helper.py
index 74b4a977db6..0a299bc2fbb 100644
--- a/python/paddle/fluid/layer_helper.py
+++ b/python/paddle/fluid/layer_helper.py
@@ -20,7 +20,7 @@ import six
 import sys
 import numpy as np
 
-from .framework import Variable, Parameter, default_main_program, default_startup_program, dtype_is_floating
+from .framework import Variable, Parameter, default_main_program, default_startup_program, dtype_is_floating, _in_imperative_mode
 from . import unique_name
 from paddle.fluid.initializer import Constant, Xavier
 from paddle.fluid.imperative import base
@@ -313,11 +313,20 @@ class LayerHelper(object):
             param = self._create_weight_normalize(attr, shape, dtype)
             WeightNormParamAttr.params_with_weight_norm.append(param)
             return param
-
-        self.startup_program.global_block().create_parameter(
-            dtype=dtype, shape=shape, **attr._to_kwargs(with_initializer=True))
-        return self.main_program.global_block().create_parameter(
-            dtype=dtype, shape=shape, **attr._to_kwargs())
+        if _in_imperative_mode():
+            self.main_program.global_block().create_parameter(
+                dtype=dtype, shape=shape, **attr._to_kwargs())
+            return self.startup_program.global_block().create_parameter(
+                dtype=dtype,
+                shape=shape,
+                **attr._to_kwargs(with_initializer=True))
+        else:
+            self.startup_program.global_block().create_parameter(
+                dtype=dtype,
+                shape=shape,
+                **attr._to_kwargs(with_initializer=True))
+            return self.main_program.global_block().create_parameter(
+                dtype=dtype, shape=shape, **attr._to_kwargs())
 
     def get_parameter(self, name):
         param = self.main_program.global_block().var(name)
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index cc1fdbd2856..d83e2735ffe 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 import numpy as np
 import six
 import os
+import sys
 import inspect
 from ..layer_helper import LayerHelper
 from ..initializer import Normal, Constant
@@ -9682,6 +9683,7 @@ class FC(layers.PyLayer):
             shape=param_shape,
             dtype=self._dtype,
             is_bias=False)
+        sys.stderr.write('created w: %s\n' % self._w.name)
 
     def forward(self, inputs):
         tmp = self._helper.create_variable_for_type_inference(self._dtype)
diff --git a/python/paddle/fluid/tests/unittests/test_imperative.py b/python/paddle/fluid/tests/unittests/test_imperative.py
index 0fe69d1bd4b..6368f9b44a6 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import sys
 import contextlib
 import unittest
 import numpy as np
@@ -38,7 +39,9 @@ class MyLayer(fluid.imperative.PyLayer):
     def forward(self, inputs):
         x = fluid.layers.relu(inputs[0])
         self._x_for_debug = x
-        return [fluid.layers.elementwise_mul(x, x)]
+        x = fluid.layers.elementwise_mul(x, x)
+        x = fluid.layers.reduce_sum(x)
+        return [x]
 
 
 class MLP(fluid.imperative.PyLayer):
@@ -79,10 +82,12 @@ class TestImperative(unittest.TestCase):
         with new_program_scope():
             inp = fluid.layers.data(
                 name="inp", shape=[3], append_batch_size=False)
-            l = MyLayer()
-            x = l(inp)[0]
+            x = fluid.layers.relu(inp)
+            x_for_debug = x
+            x = fluid.layers.elementwise_mul(x, x)
+            x = fluid.layers.reduce_sum(x)
             param_grads = fluid.backward.append_backward(
-                x, parameter_list=[l._x_for_debug.name])[0]
+                x, parameter_list=[x_for_debug.name])[0]
             exe = fluid.Executor(fluid.CPUPlace())
 
             static_out, static_grad = exe.run(
-- 
GitLab
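
For context, a minimal usage sketch (not part of the patch) of the scope-free imperative path this change introduces. It assumes a build of this branch; fluid.imperative.guard() and PyLayer are taken from the existing imperative tests rather than from this diff, while _numpy(), _backward() and _gradient() are the helpers shown in framework.py above. Each layer call is forwarded to Tracer::Trace with per-slot VarBase maps instead of a Scope, and _backward() drives OpBase::ApplyGrad through the new RuntimeContext-based Run path.

import numpy as np
import paddle.fluid as fluid


class MyLayer(fluid.imperative.PyLayer):
    def __init__(self):
        super(MyLayer, self).__init__()

    def forward(self, inputs):
        x = fluid.layers.relu(inputs[0])
        self._x_for_debug = x                  # keep a handle to query its gradient later
        x = fluid.layers.elementwise_mul(x, x)
        return [fluid.layers.reduce_sum(x)]


with fluid.imperative.guard():                 # imperative mode: ops run eagerly via Tracer::Trace
    np_inp = np.array([1.0, 2.0, -1.0], dtype=np.float32)
    l = MyLayer()
    out = l(np_inp)[0]                         # forward pass; tensors live on the C++ VarBase, not a Scope
    print(out._numpy())                        # reads the tensor held by VarBase.var_
    out._backward()                            # runs OpBase::ApplyGrad through the RuntimeContext Run path
    print(l._x_for_debug._gradient())          # gradient accumulated into VarBase.grads_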