提交 f52b514d 编写于 作者: X Xin Pan

call kernel

上级 4e80e04f
...@@ -179,8 +179,8 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) { ...@@ -179,8 +179,8 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
VLOG(3) << place << " " << DebugStringEx(&scope); VLOG(3) << place << " " << DebugStringEx(&scope);
} }
void OperatorBase::Run(const RuntimeContext& ctx, void OperatorBase::RunPrepared(const RuntimeContext& ctx,
const platform::Place& place) { const platform::Place& place) {
RunImplPrepared(ctx, place); RunImplPrepared(ctx, place);
} }
...@@ -1092,7 +1092,9 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType( ...@@ -1092,7 +1092,9 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
const ExecutionContext& ctx) const { const ExecutionContext& ctx) const {
int data_type = -1; int data_type = -1;
for (auto& input : this->inputs_) { for (auto& input : this->inputs_) {
for (const Variable* var : ctx.MultiInputVar(input.first)) { const std::vector<const Variable*> vars = ctx.MultiInputVar(input.first);
for (size_t i = 0; i < vars.size(); ++i) {
const Variable* var = vars[i];
if (var != nullptr) { if (var != nullptr) {
const Tensor* t = nullptr; const Tensor* t = nullptr;
if (var->IsType<Tensor>()) { if (var->IsType<Tensor>()) {
...@@ -1103,7 +1105,8 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType( ...@@ -1103,7 +1105,8 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
t = &(var->Get<SelectedRows>().value()); t = &(var->Get<SelectedRows>().value());
} }
if (t != nullptr) { if (t != nullptr) {
PADDLE_ENFORCE(t->IsInitialized(), "Input is not initialized"); PADDLE_ENFORCE(t->IsInitialized(), "Input %s(%lu)is not initialized",
input.first, i);
int tmp = static_cast<int>(t->type()); int tmp = static_cast<int>(t->type());
PADDLE_ENFORCE( PADDLE_ENFORCE(
tmp == data_type || data_type == -1, tmp == data_type || data_type == -1,
......
...@@ -105,7 +105,7 @@ class OperatorBase { ...@@ -105,7 +105,7 @@ class OperatorBase {
/// Executor will call this interface function to Run an op. /// Executor will call this interface function to Run an op.
// The implementation should be written at RunImpl // The implementation should be written at RunImpl
void Run(const Scope& scope, const platform::Place& place); void Run(const Scope& scope, const platform::Place& place);
void Run(const RuntimeContext& ctx, const platform::Place& place); void RunPrepared(const RuntimeContext& ctx, const platform::Place& place);
// FIXME(typhoonzero): this is only used for recv_op to stop event_loop. // FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
virtual void Stop() {} virtual void Stop() {}
...@@ -457,8 +457,9 @@ class OperatorWithKernel : public OperatorBase { ...@@ -457,8 +457,9 @@ class OperatorWithKernel : public OperatorBase {
void RuntimeInferShape(const Scope& scope, const platform::Place& place, void RuntimeInferShape(const Scope& scope, const platform::Place& place,
const RuntimeContext& ctx) const override; const RuntimeContext& ctx) const override;
protected:
virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const; virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;
protected:
virtual OpKernelType GetKernelTypeForVar( virtual OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor, const std::string& var_name, const Tensor& tensor,
const OpKernelType& expected_kernel_type) const; const OpKernelType& expected_kernel_type) const;
......
...@@ -45,12 +45,6 @@ class Autograd { ...@@ -45,12 +45,6 @@ class Autograd {
Autograd() {} Autograd() {}
void RunBackward(VarBase* var) { void RunBackward(VarBase* var) {
PADDLE_ENFORCE(var->pre_op_->op_desc_);
PADDLE_ENFORCE(
var->grads_ ==
var->pre_op_->output_vars_[var->pre_op_out_name_][var->pre_op_out_idx_]
->grads_);
std::deque<OpBase*> ready; std::deque<OpBase*> ready;
ready.push_back(var->pre_op_); ready.push_back(var->pre_op_);
...@@ -66,7 +60,7 @@ class Autograd { ...@@ -66,7 +60,7 @@ class Autograd {
const std::vector<VarBase*>& ingrads = it.second; const std::vector<VarBase*>& ingrads = it.second;
for (size_t i = 0; i < ingrads.size(); ++i) { for (size_t i = 0; i < ingrads.size(); ++i) {
if (!ingrads[i]) continue; if (!ingrads[i]) continue;
OpBase* pre_op = (*ready_op->pre_ops_)[it.first][i]; OpBase* pre_op = ready_op->pre_ops_[it.first][i];
if (!pre_op) continue; if (!pre_op) continue;
dep_counts[pre_op] -= 1; dep_counts[pre_op] -= 1;
...@@ -91,7 +85,7 @@ class Autograd { ...@@ -91,7 +85,7 @@ class Autograd {
while (!queue.empty()) { while (!queue.empty()) {
OpBase* candidate = queue.front(); OpBase* candidate = queue.front();
queue.pop_front(); queue.pop_front();
for (auto it : *(candidate->pre_ops_)) { for (auto it : candidate->pre_ops_) {
for (OpBase* pre_op : it.second) { for (OpBase* pre_op : it.second) {
if (!pre_op) continue; if (!pre_op) continue;
if (visited.find(pre_op) == visited.end()) { if (visited.find(pre_op) == visited.end()) {
...@@ -138,11 +132,13 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() { ...@@ -138,11 +132,13 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
} }
VLOG(3) << "op grad " << grad_op_desc_->Type(); VLOG(3) << "op grad " << grad_op_desc_->Type();
std::vector<std::unique_ptr<framework::Variable>> tmp_vars;
std::map<std::string, std::vector<framework::Variable*>> grad_outputs; std::map<std::string, std::vector<framework::Variable*>> grad_outputs;
for (auto it : grad_output_vars_) { for (auto it : grad_output_vars_) {
auto& outputs = grad_outputs[it.first]; auto& outputs = grad_outputs[it.first];
for (size_t i = 0; i < it.second.size(); ++i) { for (size_t i = 0; i < it.second.size(); ++i) {
outputs.push_back(new framework::Variable()); tmp_vars.emplace_back(new framework::Variable());
outputs.push_back(tmp_vars.back().get());
outputs.back()->GetMutable<framework::LoDTensor>(); outputs.back()->GetMutable<framework::LoDTensor>();
} }
} }
...@@ -155,7 +151,15 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() { ...@@ -155,7 +151,15 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
std::unique_ptr<framework::OperatorBase> opbase = std::unique_ptr<framework::OperatorBase> opbase =
framework::OpRegistry::CreateOp(*grad_op_desc_); framework::OpRegistry::CreateOp(*grad_op_desc_);
opbase->Run(ctx, platform::CPUPlace()); framework::OperatorWithKernel* op_kernel =
dynamic_cast<framework::OperatorWithKernel*>(opbase.get());
PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
framework::Scope scope;
platform::CPUPlace place;
PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place);
p.op.RuntimeInferShape(scope, place, ctx);
p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx));
for (auto it : grad_output_vars_) { for (auto it : grad_output_vars_) {
auto& outputs = grad_outputs[it.first]; auto& outputs = grad_outputs[it.first];
...@@ -169,11 +173,15 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() { ...@@ -169,11 +173,15 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
} }
void VarBase::RunBackward() { void VarBase::RunBackward() {
if (!pre_op_) return;
auto grads_t = grads_->GetMutable<framework::LoDTensor>(); auto grads_t = grads_->GetMutable<framework::LoDTensor>();
float* data = grads_t->mutable_data<float>(platform::CPUPlace()); float* data = grads_t->mutable_data<float>(platform::CPUPlace());
std::fill(data, data + grads_t->numel(), 1.0); std::fill(data, data + grads_t->numel(), 1.0);
if (!pre_op_) return; PADDLE_ENFORCE(
grads_ ==
pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_);
Autograd().RunBackward(this); Autograd().RunBackward(this);
} }
......
...@@ -25,6 +25,59 @@ ...@@ -25,6 +25,59 @@
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
class PreparedOp {
public:
PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx,
framework::OperatorWithKernel::OpKernelFunc func,
platform::DeviceContext* dev_ctx)
: op(op), ctx(ctx), func(func), dev_ctx(dev_ctx) {}
static PreparedOp Prepare(const framework::RuntimeContext& ctx,
const framework::OperatorWithKernel& op,
const platform::Place& place) {
framework::Scope dummy_scope;
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
// check if op[type] has kernel registered.
auto& all_op_kernels = op.AllOpKernels();
auto kernels_iter = all_op_kernels.find(op.Type());
if (kernels_iter == all_op_kernels.end()) {
PADDLE_THROW(
"There are no kernels which are registered in the %s operator.",
op.Type());
}
framework::OperatorWithKernel::OpKernelMap& kernels = kernels_iter->second;
auto expected_kernel_key = op.GetExpectedKernelType(
framework::ExecutionContext(op, dummy_scope, *dev_ctx, ctx));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
auto kernel_iter = kernels.find(expected_kernel_key);
#ifdef PADDLE_WITH_MKLDNN
// workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
if (kernel_iter == kernels.end() &&
expected_kernel_key.library_type_ == framework::LibraryType::kMKLDNN) {
VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one";
expected_kernel_key.library_type_ = framework::LibraryType::kPlain;
expected_kernel_key.data_layout_ = framework::DataLayout::kAnyLayout;
kernel_iter = kernels.find(expected_kernel_key);
}
#endif
if (kernel_iter == kernels.end()) {
PADDLE_THROW("op %s does not have kernel for %s", op.Type(),
KernelTypeToString(expected_kernel_key));
}
return PreparedOp(op, ctx, kernel_iter->second, dev_ctx);
}
const framework::OperatorBase& op;
const framework::RuntimeContext& ctx;
framework::OperatorWithKernel::OpKernelFunc func;
platform::DeviceContext* dev_ctx;
};
class OpBase; class OpBase;
class VarBase { class VarBase {
...@@ -62,30 +115,22 @@ class VarBase { ...@@ -62,30 +115,22 @@ class VarBase {
class OpBase { class OpBase {
public: public:
OpBase() OpBase() : op_desc_(nullptr), grad_op_desc_(nullptr) {}
: pre_ops_(new std::map<std::string, std::vector<OpBase*>>()),
pre_ops_out_idx_(new std::map<std::string, std::vector<int>>()),
op_desc_(nullptr),
grad_op_desc_(nullptr) {}
virtual ~OpBase() { virtual ~OpBase() {
delete pre_ops_;
delete pre_ops_out_idx_;
if (grad_op_desc_) delete grad_op_desc_; if (grad_op_desc_) delete grad_op_desc_;
if (grad_to_var_) delete grad_to_var_;
} }
std::map<std::string, std::vector<VarBase*>> ApplyGrad(); std::map<std::string, std::vector<VarBase*>> ApplyGrad();
framework::OpDesc* op_desc_;
framework::OpDesc* grad_op_desc_;
std::map<std::string, std::vector<VarBase*>> input_vars_; std::map<std::string, std::vector<VarBase*>> input_vars_;
std::map<std::string, std::vector<VarBase*>> output_vars_; std::map<std::string, std::vector<VarBase*>> output_vars_;
std::map<std::string, std::vector<OpBase*>>* pre_ops_; std::map<std::string, std::vector<OpBase*>> pre_ops_;
std::map<std::string, std::vector<int>>* pre_ops_out_idx_; std::map<std::string, std::vector<int>> pre_ops_out_idx_;
framework::OpDesc* op_desc_;
framework::OpDesc* grad_op_desc_;
std::unordered_map<std::string, std::string>* grad_to_var_;
std::map<std::string, std::vector<framework::Variable*>> grad_input_vars_; std::map<std::string, std::vector<framework::Variable*>> grad_input_vars_;
std::map<std::string, std::vector<framework::Variable*>> grad_output_vars_; std::map<std::string, std::vector<framework::Variable*>> grad_output_vars_;
framework::BlockDesc* block_; framework::BlockDesc* block_;
......
...@@ -82,10 +82,10 @@ class Tracer { ...@@ -82,10 +82,10 @@ class Tracer {
invars.push_back(inp->var_); invars.push_back(inp->var_);
vars[inp->var_desc_->Name()] = inp; vars[inp->var_desc_->Name()] = inp;
if (inp->pre_op_) { if (inp->pre_op_) {
(*op->pre_ops_)[it.first].push_back(inp->pre_op_); op->pre_ops_[it.first].push_back(inp->pre_op_);
(*op->pre_ops_out_idx_)[it.first].push_back(inp->pre_op_out_idx_); op->pre_ops_out_idx_[it.first].push_back(inp->pre_op_out_idx_);
} else { } else {
(*op->pre_ops_)[it.first].push_back(nullptr); op->pre_ops_[it.first].push_back(nullptr);
} }
VLOG(3) << "input vname " << inp->var_desc_->Name() << " " VLOG(3) << "input vname " << inp->var_desc_->Name() << " "
<< inp->var_->IsInitialized(); << inp->var_->IsInitialized();
...@@ -118,24 +118,33 @@ class Tracer { ...@@ -118,24 +118,33 @@ class Tracer {
VLOG(3) << "tracer running " << op_desc->Type(); VLOG(3) << "tracer running " << op_desc->Type();
framework::RuntimeContext ctx(invars_map, outvars_map); framework::RuntimeContext ctx(invars_map, outvars_map);
op_base->Run(ctx, platform::CPUPlace()); // op_base->RunPrepared(ctx, platform::CPUPlace());
// TODO(panyx0718): Cache p.
framework::OperatorWithKernel* op_kernel =
dynamic_cast<framework::OperatorWithKernel*>(op_base.get());
PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
framework::Scope scope;
platform::CPUPlace place;
PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place);
p.op.RuntimeInferShape(scope, place, ctx);
p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx));
if (block == startup_block_) { if (block == startup_block_) {
op->grad_op_desc_ = nullptr; op->grad_op_desc_ = nullptr;
op->grad_to_var_ = nullptr;
} else { } else {
framework::OpDesc* grad_op_desc; framework::OpDesc* grad_op_desc;
auto grad_to_var = new std::unordered_map<std::string, std::string>(); auto grad_to_var = new std::unordered_map<std::string, std::string>();
CreateGradOp(*op_desc, {}, {block}, &grad_op_desc, grad_to_var); CreateGradOp(*op_desc, {}, {block}, &grad_op_desc, grad_to_var);
op->grad_op_desc_ = grad_op_desc; op->grad_op_desc_ = grad_op_desc;
op->grad_to_var_ = grad_to_var;
for (auto it : grad_op_desc->Inputs()) { for (auto it : grad_op_desc->Inputs()) {
auto& grad_in_vars = op->grad_input_vars_[it.first]; auto& grad_in_vars = op->grad_input_vars_[it.first];
for (const std::string& grad_invar : it.second) { for (const std::string& grad_invar : it.second) {
block->FindRecursiveOrCreateVar(grad_invar); block->FindRecursiveOrCreateVar(grad_invar);
auto var_it = op->grad_to_var_->find(grad_invar); auto var_it = grad_to_var->find(grad_invar);
if (var_it == op->grad_to_var_->end()) { if (var_it == grad_to_var->end()) {
auto fwd_var_it = vars.find(grad_invar); auto fwd_var_it = vars.find(grad_invar);
PADDLE_ENFORCE(fwd_var_it != vars.end()); PADDLE_ENFORCE(fwd_var_it != vars.end());
grad_in_vars.push_back(fwd_var_it->second->var_); grad_in_vars.push_back(fwd_var_it->second->var_);
...@@ -152,8 +161,8 @@ class Tracer { ...@@ -152,8 +161,8 @@ class Tracer {
auto& grad_out_vars = op->grad_output_vars_[it.first]; auto& grad_out_vars = op->grad_output_vars_[it.first];
for (const std::string& grad_outvar : it.second) { for (const std::string& grad_outvar : it.second) {
block->FindRecursiveOrCreateVar(grad_outvar); block->FindRecursiveOrCreateVar(grad_outvar);
auto var_it = op->grad_to_var_->find(grad_outvar); auto var_it = grad_to_var->find(grad_outvar);
PADDLE_ENFORCE(var_it != op->grad_to_var_->end()); PADDLE_ENFORCE(var_it != grad_to_var->end());
VarBase* var = vars[var_it->second]; VarBase* var = vars[var_it->second];
if (!var->grads_->IsInitialized()) { if (!var->grads_->IsInitialized()) {
InitVar(var->var_, var->grads_); InitVar(var->var_, var->grads_);
......
...@@ -86,4 +86,5 @@ REGISTER_OPERATOR(fill_constant, ops::FillConstantOp, ops::FillConstantOpMaker, ...@@ -86,4 +86,5 @@ REGISTER_OPERATOR(fill_constant, ops::FillConstantOp, ops::FillConstantOpMaker,
REGISTER_OP_CPU_KERNEL(fill_constant, ops::FillConstantKernel<float>, REGISTER_OP_CPU_KERNEL(fill_constant, ops::FillConstantKernel<float>,
ops::FillConstantKernel<double>, ops::FillConstantKernel<double>,
ops::FillConstantKernel<int64_t>); ops::FillConstantKernel<int64_t>,
ops::FillConstantKernel<int>);
...@@ -316,6 +316,8 @@ class LayerHelper(object): ...@@ -316,6 +316,8 @@ class LayerHelper(object):
if _in_imperative_mode(): if _in_imperative_mode():
self.main_program.global_block().create_parameter( self.main_program.global_block().create_parameter(
dtype=dtype, shape=shape, **attr._to_kwargs()) dtype=dtype, shape=shape, **attr._to_kwargs())
# In imperative mode, we want the returned parameter to be
# initialized so that it can be used imperatively.
return self.startup_program.global_block().create_parameter( return self.startup_program.global_block().create_parameter(
dtype=dtype, dtype=dtype,
shape=shape, shape=shape,
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys
import contextlib import contextlib
import unittest import unittest
import numpy as np import numpy as np
...@@ -82,12 +81,10 @@ class TestImperative(unittest.TestCase): ...@@ -82,12 +81,10 @@ class TestImperative(unittest.TestCase):
with new_program_scope(): with new_program_scope():
inp = fluid.layers.data( inp = fluid.layers.data(
name="inp", shape=[3], append_batch_size=False) name="inp", shape=[3], append_batch_size=False)
x = fluid.layers.relu(inp) l = MyLayer()
x_for_debug = x x = l(inp)[0]
x = fluid.layers.elementwise_mul(x, x)
x = fluid.layers.reduce_sum(x)
param_grads = fluid.backward.append_backward( param_grads = fluid.backward.append_backward(
x, parameter_list=[x_for_debug.name])[0] x, parameter_list=[l._x_for_debug.name])[0]
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
static_out, static_grad = exe.run( static_out, static_grad = exe.run(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册