未验证 提交 3e840842 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #15053 from panyx0718/imperative_hold

refactor to avoid scope.
...@@ -16,7 +16,6 @@ limitations under the License. */ ...@@ -16,7 +16,6 @@ limitations under the License. */
#include <glog/logging.h> #include <glog/logging.h>
#include <algorithm> #include <algorithm>
#include "paddle/fluid/framework/data_transform.h" #include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
...@@ -1041,12 +1040,11 @@ Scope* OperatorWithKernel::PrepareData( ...@@ -1041,12 +1040,11 @@ Scope* OperatorWithKernel::PrepareData(
proto::VarType::Type OperatorWithKernel::IndicateDataType( proto::VarType::Type OperatorWithKernel::IndicateDataType(
const ExecutionContext& ctx) const { const ExecutionContext& ctx) const {
auto& scope = ctx.scope();
int data_type = -1; int data_type = -1;
std::string last_input_name;
for (auto& input : this->inputs_) { for (auto& input : this->inputs_) {
for (auto& ipt_name : input.second) { const std::vector<const Variable*> vars = ctx.MultiInputVar(input.first);
auto* var = scope.FindVar(ipt_name); for (size_t i = 0; i < vars.size(); ++i) {
const Variable* var = vars[i];
if (var != nullptr) { if (var != nullptr) {
const Tensor* t = nullptr; const Tensor* t = nullptr;
if (var->IsType<Tensor>()) { if (var->IsType<Tensor>()) {
...@@ -1057,15 +1055,14 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType( ...@@ -1057,15 +1055,14 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
t = &(var->Get<SelectedRows>().value()); t = &(var->Get<SelectedRows>().value());
} }
if (t != nullptr) { if (t != nullptr) {
PADDLE_ENFORCE(t->IsInitialized(), "Input %s is not initialized", PADDLE_ENFORCE(t->IsInitialized(), "Input %s(%lu)is not initialized",
ipt_name); input.first, i);
int tmp = static_cast<int>(t->type()); int tmp = static_cast<int>(t->type());
PADDLE_ENFORCE( PADDLE_ENFORCE(
tmp == data_type || data_type == -1, tmp == data_type || data_type == -1,
"DataType of Paddle Op %s must be the same. Get %s(%d) != %s(%d)", "DataType of Paddle Op %s must be the same. Get (%d) != (%d)",
Type(), last_input_name, data_type, ipt_name, tmp); Type(), data_type, tmp);
data_type = tmp; data_type = tmp;
last_input_name = ipt_name;
} }
} }
} }
......
...@@ -81,6 +81,10 @@ class RuntimeContext { ...@@ -81,6 +81,10 @@ class RuntimeContext {
RuntimeContext(const VariableNameMap& innames, RuntimeContext(const VariableNameMap& innames,
const VariableNameMap& outnames, const Scope& scope); const VariableNameMap& outnames, const Scope& scope);
RuntimeContext(const VariableValueMap& invars,
const VariableValueMap& outvars)
: inputs(invars), outputs(outvars) {}
VariableValueMap inputs; VariableValueMap inputs;
VariableValueMap outputs; VariableValueMap outputs;
}; };
...@@ -447,8 +451,9 @@ class OperatorWithKernel : public OperatorBase { ...@@ -447,8 +451,9 @@ class OperatorWithKernel : public OperatorBase {
void RuntimeInferShape(const Scope& scope, const platform::Place& place, void RuntimeInferShape(const Scope& scope, const platform::Place& place,
const RuntimeContext& ctx) const override; const RuntimeContext& ctx) const override;
protected:
virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const; virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;
protected:
virtual OpKernelType GetKernelTypeForVar( virtual OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor, const std::string& var_name, const Tensor& tensor,
const OpKernelType& expected_kernel_type) const; const OpKernelType& expected_kernel_type) const;
......
...@@ -42,13 +42,9 @@ void AddTo(Variable* src, Variable* dst) { ...@@ -42,13 +42,9 @@ void AddTo(Variable* src, Variable* dst) {
class Autograd { class Autograd {
public: public:
explicit Autograd(framework::Scope* scope) : scope_(scope) {} Autograd() {}
void RunBackward(VarBase* var) { void RunBackward(VarBase* var) {
PADDLE_ENFORCE(var->pre_op_->op_desc_);
// TODO(panyx0718): Only create for vars that "require_grad"
(*var->pre_op_->output_vars_)[var->pre_op_out_idx_]->grads_ = var->grads_;
std::deque<OpBase*> ready; std::deque<OpBase*> ready;
ready.push_back(var->pre_op_); ready.push_back(var->pre_op_);
...@@ -57,18 +53,22 @@ class Autograd { ...@@ -57,18 +53,22 @@ class Autograd {
while (!ready.empty()) { while (!ready.empty()) {
OpBase* ready_op = ready.front(); OpBase* ready_op = ready.front();
ready.pop_front(); ready.pop_front();
std::vector<Variable*> input_grads = ready_op->ApplyGrad(scope_); std::map<std::string, std::vector<VarBase*>> input_grads =
ready_op->ApplyGrad();
for (size_t i = 0; i < input_grads.size(); ++i) {
if (!input_grads[i]) continue; for (auto it : input_grads) {
OpBase* pre_op = ready_op->pre_ops_->at(i); const std::vector<VarBase*>& ingrads = it.second;
if (!pre_op) continue; for (size_t i = 0; i < ingrads.size(); ++i) {
if (!ingrads[i]) continue;
dep_counts[pre_op] -= 1; OpBase* pre_op = ready_op->pre_ops_[it.first][i];
PADDLE_ENFORCE(dep_counts[pre_op] >= 0); if (!pre_op) continue;
bool pre_op_ready = dep_counts[pre_op] == 0;
if (pre_op_ready) { dep_counts[pre_op] -= 1;
ready.push_back(pre_op); PADDLE_ENFORCE(dep_counts[pre_op] >= 0);
bool pre_op_ready = dep_counts[pre_op] == 0;
if (pre_op_ready) {
ready.push_back(pre_op);
}
} }
} }
} }
...@@ -85,138 +85,84 @@ class Autograd { ...@@ -85,138 +85,84 @@ class Autograd {
while (!queue.empty()) { while (!queue.empty()) {
OpBase* candidate = queue.front(); OpBase* candidate = queue.front();
queue.pop_front(); queue.pop_front();
for (OpBase* pre_op : *(candidate->pre_ops_)) { for (auto it : candidate->pre_ops_) {
if (!pre_op) continue; for (OpBase* pre_op : it.second) {
if (visited.find(pre_op) == visited.end()) { if (!pre_op) continue;
visited.insert(pre_op); if (visited.find(pre_op) == visited.end()) {
queue.push_back(pre_op); visited.insert(pre_op);
queue.push_back(pre_op);
}
ret[pre_op] += 1;
} }
ret[pre_op] += 1;
} }
} }
return ret; return ret;
} }
framework::Scope* scope_;
}; };
framework::Variable* CreateVariable(const std::string& name,
const framework::DDim& dim, float val,
framework::Scope* scope,
bool random_name = true) {
std::string varname = name;
if (random_name) {
std::mt19937 rng;
rng.seed(std::random_device()());
std::uniform_int_distribution<std::mt19937::result_type> dist6(
1, std::numeric_limits<int>::max());
int id = dist6(rng);
varname = string::Sprintf("%s@%d", varname, id);
}
VLOG(3) << "creating var " << varname;
framework::Variable* var = scope->Var(varname);
framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
float* data = tensor->mutable_data<float>(dim, platform::CPUPlace());
std::fill(data, data + tensor->numel(), val);
return var;
}
framework::LoDTensor& VarBase::Grad() { framework::LoDTensor& VarBase::Grad() {
VLOG(3) << "get var grad " << var_desc_->Name(); VLOG(3) << "get var grad " << var_desc_->Name();
return *grads_->GetMutable<framework::LoDTensor>(); return *grads_->GetMutable<framework::LoDTensor>();
} }
void VarBase::ApplyGrad(framework::Scope* scope, Variable* grad) { std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
VLOG(3) << "apply var grad " << var_desc_->Name() << " " if (!grad_op_desc_) {
<< grad->Get<framework::LoDTensor>().data<float>()[0]; VLOG(3) << "op with no grad: " << op_desc_->Type();
if (!grads_) { return {};
grads_ =
CreateVariable(string::Sprintf("%s@IGrad", var_desc_->Name()),
var_->Get<framework::LoDTensor>().dims(), 0.0, scope);
} }
AddTo(grad, grads_);
VLOG(3) << "grad_ after apply var grad " << var_desc_->Name() << " "
<< grads_->Get<framework::LoDTensor>().data<float>()[0];
}
std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
VLOG(3) << "op grad " << grad_op_desc_->Type(); VLOG(3) << "op grad " << grad_op_desc_->Type();
for (const std::string& grad_invar : grad_op_desc_->InputArgumentNames()) { std::vector<std::unique_ptr<framework::Variable>> tmp_vars;
if (grad_to_var_->find(grad_invar) == grad_to_var_->end()) { std::map<std::string, std::vector<framework::Variable*>> grad_outputs;
// grad op inputs can be forward inputs, so not in grad_to_var. for (auto it : grad_output_vars_) {
continue; auto& outputs = grad_outputs[it.first];
} for (size_t i = 0; i < it.second.size(); ++i) {
VLOG(3) << "op grad in var " << grad_invar; tmp_vars.emplace_back(new framework::Variable());
block_->FindRecursiveOrCreateVar(grad_invar); outputs.push_back(tmp_vars.back().get());
framework::Variable* var = scope->Var(grad_invar); outputs.back()->GetMutable<framework::LoDTensor>();
const std::string& invar = grad_to_var_->at(grad_invar);
for (VarBase* varbase : *output_vars_) {
// Use the accumulated grads_ by sharing the input with grads_.
if (varbase->var_desc_->Name() == invar) {
var->GetMutable<framework::LoDTensor>()->ShareDataWith(
varbase->grads_->Get<framework::LoDTensor>());
break;
}
} }
} }
for (const std::string& outvar : grad_op_desc_->OutputArgumentNames()) { framework::RuntimeContext ctx(grad_input_vars_, grad_outputs);
VLOG(3) << "grad outvar " << outvar;
block_->FindRecursiveOrCreateVar(outvar); // No need to do static infer shape here.
framework::Variable* var = scope->Var(outvar); // grad_op_desc_->InferShape(*block_);
if (!var->IsInitialized()) {
framework::VarDesc* var_desc = block_->FindVar(outvar);
if (var_desc->GetType() == framework::proto::VarType::LOD_TENSOR) {
var->GetMutable<framework::LoDTensor>();
} else {
LOG(ERROR) << "tracer doesn't support yet";
}
}
}
grad_op_desc_->InferShape(*block_);
grad_op_desc_->InferVarType(block_); grad_op_desc_->InferVarType(block_);
std::unique_ptr<framework::OperatorBase> opbase = std::unique_ptr<framework::OperatorBase> opbase =
framework::OpRegistry::CreateOp(*grad_op_desc_); framework::OpRegistry::CreateOp(*grad_op_desc_);
framework::OperatorWithKernel* op_kernel =
opbase->Run(*scope, platform::CPUPlace()); dynamic_cast<framework::OperatorWithKernel*>(opbase.get());
PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
// `ret` matches exactly with `input_vars_` of forward op.
std::vector<Variable*> ret; framework::Scope scope;
for (size_t i = 0; i < input_vars_->size(); ++i) { platform::CPUPlace place;
bool found = false; PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place);
VarBase* origin_var = (*input_vars_)[i]; p.op.RuntimeInferShape(scope, place, ctx);
for (const std::string& outvar : grad_op_desc_->OutputArgumentNames()) { p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx));
Variable* var = scope->FindVar(outvar);
std::string orig_var = grad_to_var_->at(outvar); for (auto it : grad_output_vars_) {
if (origin_var->var_desc_->Name() != orig_var) { auto& outputs = grad_outputs[it.first];
continue; auto& origin_outputs = it.second;
} for (size_t i = 0; i < outputs.size(); ++i) {
VLOG(3) << "apply grad " << outvar << " with origin " << orig_var; framework::Variable* orig_grad = origin_outputs[i];
origin_var->ApplyGrad(scope, var); AddTo(outputs[i], orig_grad);
found = true;
ret.push_back(var);
// TODO(panyx0718): There might be another outvar with the same name.
// In that case, it doesn't matter the first one or the second one is
// used.
break;
}
if (!found) {
ret.push_back(nullptr);
} }
} }
return ret; return input_vars_;
} }
void VarBase::RunBackward(framework::Scope* scope) { void VarBase::RunBackward() {
grads_ = CreateVariable(framework::GradVarName(var_desc_->Name()),
var_->Get<framework::LoDTensor>().dims(), 1.0, scope,
false);
if (!pre_op_) return; if (!pre_op_) return;
Autograd(scope).RunBackward(this);
auto grads_t = grads_->GetMutable<framework::LoDTensor>();
float* data = grads_t->mutable_data<float>(platform::CPUPlace());
std::fill(data, data + grads_t->numel(), 1.0);
PADDLE_ENFORCE(
grads_ ==
pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_);
Autograd().RunBackward(this);
} }
} // namespace imperative } // namespace imperative
......
...@@ -14,17 +14,69 @@ ...@@ -14,17 +14,69 @@
#pragma once #pragma once
#include <map>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
class PreparedOp {
public:
PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx,
framework::OperatorWithKernel::OpKernelFunc func,
platform::DeviceContext* dev_ctx)
: op(op), ctx(ctx), func(func), dev_ctx(dev_ctx) {}
static PreparedOp Prepare(const framework::RuntimeContext& ctx,
const framework::OperatorWithKernel& op,
const platform::Place& place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
// check if op[type] has kernel registered.
auto& all_op_kernels = op.AllOpKernels();
auto kernels_iter = all_op_kernels.find(op.Type());
if (kernels_iter == all_op_kernels.end()) {
PADDLE_THROW(
"There are no kernels which are registered in the %s operator.",
op.Type());
}
framework::OperatorWithKernel::OpKernelMap& kernels = kernels_iter->second;
auto expected_kernel_key = op.GetExpectedKernelType(
framework::ExecutionContext(op, framework::Scope(), *dev_ctx, ctx));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
auto kernel_iter = kernels.find(expected_kernel_key);
#ifdef PADDLE_WITH_MKLDNN
// workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
if (kernel_iter == kernels.end() &&
expected_kernel_key.library_type_ == framework::LibraryType::kMKLDNN) {
VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one";
expected_kernel_key.library_type_ = framework::LibraryType::kPlain;
expected_kernel_key.data_layout_ = framework::DataLayout::kAnyLayout;
kernel_iter = kernels.find(expected_kernel_key);
}
#endif
if (kernel_iter == kernels.end()) {
PADDLE_THROW("op %s does not have kernel for %s", op.Type(),
KernelTypeToString(expected_kernel_key));
}
return PreparedOp(op, ctx, kernel_iter->second, dev_ctx);
}
const framework::OperatorBase& op;
const framework::RuntimeContext& ctx;
framework::OperatorWithKernel::OpKernelFunc func;
platform::DeviceContext* dev_ctx;
};
class OpBase; class OpBase;
class VarBase { class VarBase {
...@@ -33,18 +85,26 @@ class VarBase { ...@@ -33,18 +85,26 @@ class VarBase {
: pre_op_(nullptr), : pre_op_(nullptr),
pre_op_out_idx_(-1), pre_op_out_idx_(-1),
var_desc_(nullptr), var_desc_(nullptr),
var_(nullptr), var_(new framework::Variable()),
grads_(nullptr) {} grads_(new framework::Variable()) {}
virtual ~VarBase() {} virtual ~VarBase() {
if (var_) {
void ApplyGrad(framework::Scope* scope, framework::Variable* grad); delete var_;
var_ = nullptr;
}
if (grads_) {
delete grads_;
grads_ = nullptr;
}
}
void RunBackward(framework::Scope* scope); void RunBackward();
framework::LoDTensor& Grad(); framework::LoDTensor& Grad();
OpBase* pre_op_; OpBase* pre_op_;
std::string pre_op_out_name_;
int pre_op_out_idx_; int pre_op_out_idx_;
framework::VarDesc* var_desc_; framework::VarDesc* var_desc_;
...@@ -54,35 +114,24 @@ class VarBase { ...@@ -54,35 +114,24 @@ class VarBase {
class OpBase { class OpBase {
public: public:
OpBase() OpBase() : op_desc_(nullptr), grad_op_desc_(nullptr) {}
: input_vars_(new std::vector<VarBase*>()),
output_vars_(new std::vector<VarBase*>()),
pre_ops_(new std::vector<OpBase*>()),
pre_ops_out_idx_(new std::vector<int>()),
op_desc_(nullptr),
grad_op_desc_(nullptr) {}
virtual ~OpBase() { virtual ~OpBase() {
delete input_vars_;
delete output_vars_;
delete pre_ops_;
delete pre_ops_out_idx_;
if (grad_op_desc_) delete grad_op_desc_; if (grad_op_desc_) delete grad_op_desc_;
if (grad_to_var_) delete grad_to_var_;
} }
std::vector<framework::Variable*> ApplyGrad(framework::Scope* scope); std::map<std::string, std::vector<VarBase*>> ApplyGrad();
std::vector<VarBase*>* input_vars_;
std::vector<VarBase*>* output_vars_;
std::vector<OpBase*>* pre_ops_;
std::vector<int>* pre_ops_out_idx_;
framework::OpDesc* op_desc_; framework::OpDesc* op_desc_;
framework::OpDesc* grad_op_desc_; framework::OpDesc* grad_op_desc_;
std::unordered_map<std::string, std::string>* grad_to_var_;
std::map<std::string, std::vector<VarBase*>> input_vars_;
std::map<std::string, std::vector<VarBase*>> output_vars_;
std::map<std::string, std::vector<OpBase*>> pre_ops_;
std::map<std::string, std::vector<int>> pre_ops_out_idx_;
std::map<std::string, std::vector<framework::Variable*>> grad_input_vars_;
std::map<std::string, std::vector<framework::Variable*>> grad_output_vars_;
framework::BlockDesc* block_; framework::BlockDesc* block_;
}; };
......
...@@ -20,7 +20,6 @@ ...@@ -20,7 +20,6 @@
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/imperative/engine.h" #include "paddle/fluid/imperative/engine.h"
#include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/layer.h"
...@@ -41,22 +40,28 @@ void CreateGradOp(const framework::OpDesc& op_desc, ...@@ -41,22 +40,28 @@ void CreateGradOp(const framework::OpDesc& op_desc,
*grad_op_desc = grad_op_descs[0].release(); *grad_op_desc = grad_op_descs[0].release();
} }
void InitVar(framework::Variable* var, framework::Variable* grad_var) {
auto& var_t = var->Get<framework::LoDTensor>();
float* data =
grad_var->GetMutable<framework::LoDTensor>()->mutable_data<float>(
var_t.dims(), platform::CPUPlace());
std::fill(data, data + var_t.numel(), 0.0);
}
class Tracer { class Tracer {
public: public:
explicit Tracer(framework::BlockDesc* root_block, explicit Tracer(framework::BlockDesc* root_block,
framework::BlockDesc* startup_block) framework::BlockDesc* startup_block)
: root_block_(root_block), startup_block_(startup_block) { : root_block_(root_block), startup_block_(startup_block) {}
root_scope_ = new framework::Scope();
scopes_[root_block_] = root_scope_;
scopes_[startup_block_] = root_scope_;
}
virtual ~Tracer() { delete root_scope_; } virtual ~Tracer() {}
void Trace(OpBase* op, const std::vector<VarBase*>& inputs, void Trace(OpBase* op,
const std::vector<VarBase*>& outputs, const std::map<std::string, std::vector<VarBase*>>& inputs,
const std::map<std::string, std::vector<VarBase*>>& outputs,
framework::BlockDesc* block) { framework::BlockDesc* block) {
framework::Scope* scope = GetScope(block); std::map<std::string, VarBase*> vars;
framework::OpDesc* op_desc = op->op_desc_; framework::OpDesc* op_desc = op->op_desc_;
VLOG(3) << "tracer tracing " << op_desc->Type(); VLOG(3) << "tracer tracing " << op_desc->Type();
op_desc->InferShape(*block); op_desc->InferShape(*block);
...@@ -64,77 +69,113 @@ class Tracer { ...@@ -64,77 +69,113 @@ class Tracer {
std::unique_ptr<framework::OperatorBase> op_base = std::unique_ptr<framework::OperatorBase> op_base =
framework::OpRegistry::CreateOp(*op_desc); framework::OpRegistry::CreateOp(*op_desc);
*op->input_vars_ = inputs; framework::VariableValueMap invars_map;
for (VarBase* input : inputs) { framework::VariableValueMap outvars_map;
const std::string vname = input->var_desc_->Name();
framework::Variable* var = scope->Var(vname); op->input_vars_ = inputs;
input->var_ = var; for (auto it : op->input_vars_) {
if (!var->IsInitialized()) { auto& invars = invars_map[it.first];
framework::VarDesc* var_desc = block->FindVar(vname); for (VarBase* inp : it.second) {
if (var_desc->GetType() == framework::proto::VarType::LOD_TENSOR) { PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr",
var->GetMutable<framework::LoDTensor>(); op->op_desc_->Type(), inp->var_desc_->Name());
invars.push_back(inp->var_);
vars[inp->var_desc_->Name()] = inp;
if (inp->pre_op_) {
op->pre_ops_[it.first].push_back(inp->pre_op_);
op->pre_ops_out_idx_[it.first].push_back(inp->pre_op_out_idx_);
} else { } else {
LOG(ERROR) << "tracer doesn't support yet"; op->pre_ops_[it.first].push_back(nullptr);
} }
VLOG(3) << "input vname " << inp->var_desc_->Name() << " "
<< inp->var_->IsInitialized();
} }
if (input->pre_op_) {
op->pre_ops_->push_back(input->pre_op_);
op->pre_ops_out_idx_->push_back(input->pre_op_out_idx_);
} else {
op->pre_ops_->push_back(nullptr);
}
VLOG(3) << "input vname " << vname << " "
<< var->Get<framework::LoDTensor>().dims().size();
} }
*op->output_vars_ = outputs; op->output_vars_ = outputs;
for (size_t i = 0; i < outputs.size(); ++i) { for (auto it : op->output_vars_) {
const std::string vname = outputs[i]->var_desc_->Name(); auto& outvars = outvars_map[it.first];
framework::Variable* var = scope->Var(vname); const std::vector<VarBase*>& outputs = it.second;
if (!var->IsInitialized()) { for (size_t i = 0; i < outputs.size(); ++i) {
framework::VarDesc* var_desc = block->FindVar(vname); VarBase* out = outputs[i];
outvars.push_back(out->var_);
vars[out->var_desc_->Name()] = out;
framework::VarDesc* var_desc = block->FindVar(out->var_desc_->Name());
if (var_desc->GetType() == framework::proto::VarType::LOD_TENSOR) { if (var_desc->GetType() == framework::proto::VarType::LOD_TENSOR) {
var->GetMutable<framework::LoDTensor>(); out->var_->GetMutable<framework::LoDTensor>();
} else { } else {
LOG(ERROR) << "tracer doesn't support yet"; LOG(ERROR) << "tracer doesn't support yet";
} }
out->pre_op_ = op;
out->pre_op_out_name_ = it.first;
out->pre_op_out_idx_ = i;
VLOG(3) << "output vname " << out->var_desc_->Name() << " "
<< out->var_->IsInitialized();
} }
outputs[i]->var_ = var;
outputs[i]->pre_op_ = op;
outputs[i]->pre_op_out_idx_ = i;
} }
VLOG(3) << "tracer running " << op_desc->Type(); VLOG(3) << "tracer running " << op_desc->Type();
op_base->Run(*scope, platform::CPUPlace()); framework::RuntimeContext ctx(invars_map, outvars_map);
// TODO(panyx0718): Cache p.
framework::OperatorWithKernel* op_kernel =
dynamic_cast<framework::OperatorWithKernel*>(op_base.get());
PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
framework::Scope scope;
platform::CPUPlace place;
PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place);
p.op.RuntimeInferShape(scope, place, ctx);
p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx));
if (block == startup_block_) { if (block == startup_block_) {
op->grad_op_desc_ = nullptr; op->grad_op_desc_ = nullptr;
op->grad_to_var_ = nullptr;
} else { } else {
framework::OpDesc* grad_op_desc; framework::OpDesc* grad_op_desc;
auto grad_to_var = new std::unordered_map<std::string, std::string>(); auto grad_to_var = new std::unordered_map<std::string, std::string>();
CreateGradOp(*op_desc, {}, {block}, &grad_op_desc, grad_to_var); CreateGradOp(*op_desc, {}, {block}, &grad_op_desc, grad_to_var);
op->grad_op_desc_ = grad_op_desc; op->grad_op_desc_ = grad_op_desc;
op->grad_to_var_ = grad_to_var;
}
op->block_ = block;
}
framework::Scope* GetScope(framework::BlockDesc* block) { for (auto it : grad_op_desc->Inputs()) {
if (scopes_.find(block) != scopes_.end()) { auto& grad_in_vars = op->grad_input_vars_[it.first];
return scopes_.at(block); for (const std::string& grad_invar : it.second) {
block->FindRecursiveOrCreateVar(grad_invar);
auto var_it = grad_to_var->find(grad_invar);
if (var_it == grad_to_var->end()) {
auto fwd_var_it = vars.find(grad_invar);
PADDLE_ENFORCE(fwd_var_it != vars.end());
grad_in_vars.push_back(fwd_var_it->second->var_);
} else {
VarBase* var = vars[var_it->second];
if (!var->grads_->IsInitialized()) {
InitVar(var->var_, var->grads_);
}
grad_in_vars.push_back(var->grads_);
}
}
}
for (auto it : grad_op_desc->Outputs()) {
auto& grad_out_vars = op->grad_output_vars_[it.first];
for (const std::string& grad_outvar : it.second) {
block->FindRecursiveOrCreateVar(grad_outvar);
auto var_it = grad_to_var->find(grad_outvar);
PADDLE_ENFORCE(var_it != grad_to_var->end());
VarBase* var = vars[var_it->second];
if (!var->grads_->IsInitialized()) {
InitVar(var->var_, var->grads_);
}
grad_out_vars.push_back(var->grads_);
}
}
} }
framework::BlockDesc* parent_block = block->ParentBlock(); op->block_ = block;
PADDLE_ENFORCE(scopes_.find(parent_block) != scopes_.end());
framework::Scope* scope = &scopes_[parent_block]->NewScope();
scopes_[block] = scope;
return scope;
} }
private: private:
std::map<framework::BlockDesc*, framework::Scope*> scopes_;
framework::BlockDesc* root_block_; framework::BlockDesc* root_block_;
framework::BlockDesc* startup_block_; framework::BlockDesc* startup_block_;
framework::Scope* root_scope_;
}; };
} // namespace imperative } // namespace imperative
......
...@@ -12,68 +12,40 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,68 +12,40 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/operators/fill_constant_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class FillConstantInferShape : public framework::InferShapeBase { class FillConstantOp : public framework::OperatorWithKernel {
public: public:
void operator()(framework::InferShapeContext *ctx) const override { using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FillConstantOp should not be null."); "Output(Out) of FillConstantOp should not be null.");
auto &shape = ctx->Attrs().Get<std::vector<int64_t>>("shape"); auto& shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
ctx->SetOutputDim("Out", framework::make_ddim(shape)); ctx->SetOutputDim("Out", framework::make_ddim(shape));
} }
};
class FillConstantOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
auto data_type =
static_cast<framework::proto::VarType::Type>(Attr<int>("dtype"));
auto value = Attr<float>("value");
auto force_cpu = Attr<bool>("force_cpu");
framework::Tensor *tensor = nullptr;
auto &out_var = *scope.FindVar(Output("Out")); protected:
framework::OpKernelType GetExpectedKernelType(
if (out_var.IsType<framework::LoDTensor>()) { const framework::ExecutionContext& ctx) const override {
tensor = out_var.GetMutable<framework::LoDTensor>(); return framework::OpKernelType(
tensor->Resize(framework::make_ddim(Attr<std::vector<int64_t>>("shape"))); framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
} else if (out_var.IsType<framework::SelectedRows>()) { ctx.GetPlace());
tensor = out_var.GetMutable<framework::SelectedRows>()->mutable_value();
tensor->Resize(framework::make_ddim(Attr<std::vector<int64_t>>("shape")));
} else {
PADDLE_THROW(
"fill constant op's output only"
"supports SelectedRows and LoDTensor");
}
if (force_cpu) {
auto cpu = platform::CPUPlace();
tensor->mutable_data(cpu, data_type);
} else {
tensor->mutable_data(dev_place, data_type);
}
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
math::set_constant(dev_ctx, tensor, value);
} }
}; };
class FillConstantOpVarTypeInference : public framework::VarTypeInference { class FillConstantOpVarTypeInference : public framework::VarTypeInference {
public: public:
void operator()(const framework::OpDesc &op_desc, void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc *block) const override {} framework::BlockDesc* block) const override {
auto data_type = static_cast<framework::proto::VarType::Type>(
boost::get<int>(op_desc.GetAttr("dtype")));
auto& out_var_name = op_desc.Output("Out").front();
block->Var(out_var_name)->SetDataType(data_type);
}
}; };
class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker { class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -107,7 +79,13 @@ Fill up a variable with specified constant value. ...@@ -107,7 +79,13 @@ Fill up a variable with specified constant value.
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(fill_constant, ops::FillConstantOp,
ops::FillConstantInferShape, ops::FillConstantOpMaker, REGISTER_OPERATOR(fill_constant, ops::FillConstantOp, ops::FillConstantOpMaker,
paddle::framework::EmptyGradOpMaker, ops::FillConstantOpVarTypeInference,
ops::FillConstantOpVarTypeInference); paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(fill_constant, ops::FillConstantKernel<float>,
ops::FillConstantKernel<double>,
ops::FillConstantKernel<int64_t>,
ops::FillConstantKernel<int>,
ops::FillConstantKernel<paddle::platform::float16>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fill_constant_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(fill_constant, ops::FillConstantKernel<float>,
ops::FillConstantKernel<double>,
ops::FillConstantKernel<int64_t>,
ops::FillConstantKernel<int>,
ops::FillConstantKernel<paddle::platform::float16>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
template <typename T>
class FillConstantKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override {
auto data_type =
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
auto value = ctx.Attr<float>("value");
auto force_cpu = ctx.Attr<bool>("force_cpu");
framework::Tensor *tensor = nullptr;
framework::Variable *out_var = ctx.OutputVar("Out");
if (out_var->IsType<framework::LoDTensor>()) {
tensor = out_var->GetMutable<framework::LoDTensor>();
tensor->Resize(
framework::make_ddim(ctx.Attr<std::vector<int64_t>>("shape")));
} else if (out_var->IsType<framework::SelectedRows>()) {
tensor = out_var->GetMutable<framework::SelectedRows>()->mutable_value();
tensor->Resize(
framework::make_ddim(ctx.Attr<std::vector<int64_t>>("shape")));
} else {
PADDLE_THROW(
"fill constant op's output only"
"supports SelectedRows and LoDTensor");
}
if (force_cpu) {
tensor->mutable_data(platform::CPUPlace(), data_type);
} else {
tensor->mutable_data(ctx.GetPlace(), data_type);
}
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(ctx.GetPlace());
math::set_constant(dev_ctx, tensor, value);
}
};
} // namespace operators
} // namespace paddle
...@@ -14,7 +14,6 @@ limitations under the License. */ ...@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/fluid/pybind/imperative.h" #include "paddle/fluid/pybind/imperative.h"
#include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/imperative/tracer.h" #include "paddle/fluid/imperative/tracer.h"
namespace paddle { namespace paddle {
...@@ -28,9 +27,7 @@ void BindTracer(pybind11::module *m) { ...@@ -28,9 +27,7 @@ void BindTracer(pybind11::module *m) {
framework::BlockDesc *startup_block) { framework::BlockDesc *startup_block) {
new (&self) imperative::Tracer(root_block, startup_block); new (&self) imperative::Tracer(root_block, startup_block);
}) })
.def("trace", &imperative::Tracer::Trace) .def("trace", &imperative::Tracer::Trace);
.def("get_scope", &imperative::Tracer::GetScope,
pybind11::return_value_policy::reference);
} }
} // namespace pybind } // namespace pybind
......
...@@ -124,9 +124,7 @@ PYBIND11_MODULE(core, m) { ...@@ -124,9 +124,7 @@ PYBIND11_MODULE(core, m) {
py::class_<imperative::VarBase, PyVarBase>(m, "VarBase", R"DOC()DOC") py::class_<imperative::VarBase, PyVarBase>(m, "VarBase", R"DOC()DOC")
.def(py::init<>()) .def(py::init<>())
.def("_run_backward", .def("_run_backward",
[](imperative::VarBase &self, framework::Scope *scope) { [](imperative::VarBase &self) { self.RunBackward(); })
self.RunBackward(scope);
})
.def("_grad", &imperative::VarBase::Grad) .def("_grad", &imperative::VarBase::Grad)
.def_property( .def_property(
"desc", "desc",
...@@ -134,7 +132,13 @@ PYBIND11_MODULE(core, m) { ...@@ -134,7 +132,13 @@ PYBIND11_MODULE(core, m) {
[](imperative::VarBase &self, framework::VarDesc *var_desc) { [](imperative::VarBase &self, framework::VarDesc *var_desc) {
self.var_desc_ = var_desc; self.var_desc_ = var_desc;
}, },
py::return_value_policy::reference); py::return_value_policy::reference)
.def_property("var",
[](const imperative::VarBase &self) { return self.var_; },
[](imperative::VarBase &self, framework::Variable *var) {
self.var_ = var;
},
py::return_value_policy::reference);
py::class_<imperative::OpBase, PyOpBase>(m, "OpBase", R"DOC()DOC") py::class_<imperative::OpBase, PyOpBase>(m, "OpBase", R"DOC()DOC")
.def(py::init<>()) .def(py::init<>())
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from __future__ import print_function from __future__ import print_function
import collections import collections
from collections import defaultdict
import contextlib import contextlib
import os import os
import re import re
...@@ -369,13 +370,11 @@ class Variable(object): ...@@ -369,13 +370,11 @@ class Variable(object):
self._ivar.desc = self.desc self._ivar.desc = self.desc
def _numpy(self): def _numpy(self):
scope = _imperative_tracer().get_scope(self.block.desc) tensor = self._ivar.var.get_tensor()
tensor = core.get_variable_tensor(scope, self.desc.name())
return np.array(tensor) return np.array(tensor)
def _backward(self): def _backward(self):
scope = _imperative_tracer().get_scope(self.block.desc) self._ivar._run_backward()
self._ivar._run_backward(scope)
def _gradient(self): def _gradient(self):
return np.array(self._ivar._grad()) return np.array(self._ivar._grad())
...@@ -692,20 +691,20 @@ class Operator(object): ...@@ -692,20 +691,20 @@ class Operator(object):
if _in_imperative_mode(): if _in_imperative_mode():
self.iop = core.OpBase() self.iop = core.OpBase()
self.iop.desc = self.desc self.iop.desc = self.desc
self.inputs = [] self.inputs = defaultdict(list)
if inputs is not None: if inputs is not None:
for inp in inputs.values(): for k, v in six.iteritems(inputs):
if isinstance(inp, Variable): if isinstance(v, Variable):
self.inputs.append(inp) self.inputs[k].append(v._ivar)
elif isinstance(inp, list) or isinstance(inp, tuple): elif isinstance(v, list) or isinstance(v, tuple):
self.inputs.extend(inp[:]) self.inputs[k].extend([var._ivar for var in v])
self.outputs = [] self.outputs = defaultdict(list)
if outputs is not None: if outputs is not None:
for out in outputs.values(): for k, v in six.iteritems(outputs):
if isinstance(out, Variable): if isinstance(v, Variable):
self.outputs.append(out) self.outputs[k].append(v._ivar)
elif isinstance(out, list) or isinstance(out, tuple): elif isinstance(v, list) or isinstance(v, tuple):
self.outputs.extend(out[:]) self.outputs[k].extend([var._ivar for var in v])
def _has_kernel(self, op_type): def _has_kernel(self, op_type):
return op_type not in self.OP_WITHOUT_KERNEL_SET return op_type not in self.OP_WITHOUT_KERNEL_SET
...@@ -1273,8 +1272,7 @@ class Block(object): ...@@ -1273,8 +1272,7 @@ class Block(object):
op_desc = self.desc.append_op() op_desc = self.desc.append_op()
op = Operator(block=self, desc=op_desc, *args, **kwargs) op = Operator(block=self, desc=op_desc, *args, **kwargs)
if _in_imperative_mode(): if _in_imperative_mode():
_imperative_tracer().trace(op.iop, [v._ivar for v in op.inputs], _imperative_tracer().trace(op.iop, op.inputs, op.outputs, self.desc)
[v._ivar for v in op.outputs], self.desc)
self.ops.append(op) self.ops.append(op)
return op return op
...@@ -1325,8 +1323,7 @@ class Block(object): ...@@ -1325,8 +1323,7 @@ class Block(object):
op_desc = self.desc._prepend_op() op_desc = self.desc._prepend_op()
op = Operator(self, op_desc, *args, **kwargs) op = Operator(self, op_desc, *args, **kwargs)
if _in_imperative_mode(): if _in_imperative_mode():
_imperative_tracer().trace(op.iop, [v._ivar for v in op.inputs], _imperative_tracer().trace(op.iop, op.inputs, op.outputs, self.desc)
[v._ivar for v in op.outputs], self.desc)
self.ops.insert(0, op) self.ops.insert(0, op)
return op return op
......
...@@ -46,8 +46,7 @@ def to_variable(value, block=None): ...@@ -46,8 +46,7 @@ def to_variable(value, block=None):
name=None, name=None,
shape=value.shape, shape=value.shape,
dtype=value.dtype) dtype=value.dtype)
scope = framework._imperative_tracer().get_scope(block.desc) var = py_var._ivar.var
var = scope.var(py_var.name)
tensor = var.get_tensor() tensor = var.get_tensor()
tensor.set(value, core.CPUPlace()) tensor.set(value, core.CPUPlace())
return py_var return py_var
......
...@@ -20,7 +20,7 @@ import six ...@@ -20,7 +20,7 @@ import six
import sys import sys
import numpy as np import numpy as np
from .framework import Variable, Parameter, default_main_program, default_startup_program, dtype_is_floating from .framework import Variable, Parameter, default_main_program, default_startup_program, dtype_is_floating, _in_imperative_mode
from . import unique_name from . import unique_name
from paddle.fluid.initializer import Constant, Xavier from paddle.fluid.initializer import Constant, Xavier
from paddle.fluid.imperative import base from paddle.fluid.imperative import base
...@@ -313,11 +313,22 @@ class LayerHelper(object): ...@@ -313,11 +313,22 @@ class LayerHelper(object):
param = self._create_weight_normalize(attr, shape, dtype) param = self._create_weight_normalize(attr, shape, dtype)
WeightNormParamAttr.params_with_weight_norm.append(param) WeightNormParamAttr.params_with_weight_norm.append(param)
return param return param
if _in_imperative_mode():
self.startup_program.global_block().create_parameter( self.main_program.global_block().create_parameter(
dtype=dtype, shape=shape, **attr._to_kwargs(with_initializer=True)) dtype=dtype, shape=shape, **attr._to_kwargs())
return self.main_program.global_block().create_parameter( # In imperative mode, we want the returned parameter to be
dtype=dtype, shape=shape, **attr._to_kwargs()) # initialized so that it can be used imperatively.
return self.startup_program.global_block().create_parameter(
dtype=dtype,
shape=shape,
**attr._to_kwargs(with_initializer=True))
else:
self.startup_program.global_block().create_parameter(
dtype=dtype,
shape=shape,
**attr._to_kwargs(with_initializer=True))
return self.main_program.global_block().create_parameter(
dtype=dtype, shape=shape, **attr._to_kwargs())
def get_parameter(self, name): def get_parameter(self, name):
param = self.main_program.global_block().var(name) param = self.main_program.global_block().var(name)
......
...@@ -38,7 +38,9 @@ class MyLayer(fluid.imperative.PyLayer): ...@@ -38,7 +38,9 @@ class MyLayer(fluid.imperative.PyLayer):
def forward(self, inputs): def forward(self, inputs):
x = fluid.layers.relu(inputs[0]) x = fluid.layers.relu(inputs[0])
self._x_for_debug = x self._x_for_debug = x
return [fluid.layers.elementwise_mul(x, x)] x = fluid.layers.elementwise_mul(x, x)
x = fluid.layers.reduce_sum(x)
return [x]
class MLP(fluid.imperative.PyLayer): class MLP(fluid.imperative.PyLayer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册