Commit f52b514d authored by Xin Pan

call kernel

Parent 4e80e04f
@@ -179,8 +179,8 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
   VLOG(3) << place << " " << DebugStringEx(&scope);
 }
 
-void OperatorBase::Run(const RuntimeContext& ctx,
-                       const platform::Place& place) {
+void OperatorBase::RunPrepared(const RuntimeContext& ctx,
+                               const platform::Place& place) {
   RunImplPrepared(ctx, place);
 }
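For context, a minimal sketch of how the renamed entry point is driven; it uses only calls visible in this commit, and invars_map/outvars_map are assumed to be pre-built name-to-variable maps (the Tracer below constructs them this way):

  framework::RuntimeContext ctx(invars_map, outvars_map);
  op_base->RunPrepared(ctx, platform::CPUPlace());  // forwards to RunImplPrepared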
@@ -1092,7 +1092,9 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
     const ExecutionContext& ctx) const {
   int data_type = -1;
   for (auto& input : this->inputs_) {
-    for (const Variable* var : ctx.MultiInputVar(input.first)) {
+    const std::vector<const Variable*> vars = ctx.MultiInputVar(input.first);
+    for (size_t i = 0; i < vars.size(); ++i) {
+      const Variable* var = vars[i];
       if (var != nullptr) {
         const Tensor* t = nullptr;
         if (var->IsType<Tensor>()) {
@@ -1103,7 +1105,8 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
           t = &(var->Get<SelectedRows>().value());
         }
         if (t != nullptr) {
-          PADDLE_ENFORCE(t->IsInitialized(), "Input is not initialized");
+          PADDLE_ENFORCE(t->IsInitialized(), "Input %s(%lu) is not initialized",
+                         input.first, i);
           int tmp = static_cast<int>(t->type());
           PADDLE_ENFORCE(
               tmp == data_type || data_type == -1,
......
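The loop rewrite above exists so the failure message can name both the input slot and the offending index. A standalone sketch of the same check pattern with hypothetical names (plain C++, not Paddle API):

  #include <cstdio>
  #include <string>
  #include <vector>

  // Reports the slot name plus the position of the first bad (here: null)
  // entry, mirroring the "Input %s(%lu) is not initialized" message above.
  bool CheckInitialized(const std::string& slot,
                        const std::vector<const int*>& vars) {
    for (size_t i = 0; i < vars.size(); ++i) {
      if (vars[i] == nullptr) {
        std::fprintf(stderr, "Input %s(%zu) is not initialized\n",
                     slot.c_str(), i);
        return false;
      }
    }
    return true;
  }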
@@ -105,7 +105,7 @@ class OperatorBase {
   /// Executor will call this interface function to Run an op.
   //  The implementation should be written in RunImpl.
   void Run(const Scope& scope, const platform::Place& place);
-  void Run(const RuntimeContext& ctx, const platform::Place& place);
+  void RunPrepared(const RuntimeContext& ctx, const platform::Place& place);
 
   // FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
   virtual void Stop() {}
@@ -457,8 +457,9 @@ class OperatorWithKernel : public OperatorBase {
   void RuntimeInferShape(const Scope& scope, const platform::Place& place,
                          const RuntimeContext& ctx) const override;
 
- protected:
   virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;
+
+ protected:
   virtual OpKernelType GetKernelTypeForVar(
       const std::string& var_name, const Tensor& tensor,
       const OpKernelType& expected_kernel_type) const;
......
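Moving GetExpectedKernelType out from under the protected: label is what lets code outside the operator class ask for the kernel key; the imperative PreparedOp::Prepare below makes exactly this call:

  auto expected_kernel_key = op.GetExpectedKernelType(
      framework::ExecutionContext(op, dummy_scope, *dev_ctx, ctx));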
@@ -45,12 +45,6 @@ class Autograd {
   Autograd() {}
 
   void RunBackward(VarBase* var) {
-    PADDLE_ENFORCE(var->pre_op_->op_desc_);
-    PADDLE_ENFORCE(
-        var->grads_ ==
-        var->pre_op_->output_vars_[var->pre_op_out_name_][var->pre_op_out_idx_]
-            ->grads_);
-
     std::deque<OpBase*> ready;
     ready.push_back(var->pre_op_);
@@ -66,7 +60,7 @@ class Autograd {
       const std::vector<VarBase*>& ingrads = it.second;
       for (size_t i = 0; i < ingrads.size(); ++i) {
         if (!ingrads[i]) continue;
-        OpBase* pre_op = (*ready_op->pre_ops_)[it.first][i];
+        OpBase* pre_op = ready_op->pre_ops_[it.first][i];
         if (!pre_op) continue;
         dep_counts[pre_op] -= 1;
@@ -91,7 +85,7 @@ class Autograd {
     while (!queue.empty()) {
       OpBase* candidate = queue.front();
       queue.pop_front();
-      for (auto it : *(candidate->pre_ops_)) {
+      for (auto it : candidate->pre_ops_) {
         for (OpBase* pre_op : it.second) {
           if (!pre_op) continue;
           if (visited.find(pre_op) == visited.end()) {
@@ -138,11 +132,13 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
   }
 
   VLOG(3) << "op grad " << grad_op_desc_->Type();
 
+  std::vector<std::unique_ptr<framework::Variable>> tmp_vars;
   std::map<std::string, std::vector<framework::Variable*>> grad_outputs;
   for (auto it : grad_output_vars_) {
     auto& outputs = grad_outputs[it.first];
     for (size_t i = 0; i < it.second.size(); ++i) {
-      outputs.push_back(new framework::Variable());
+      tmp_vars.emplace_back(new framework::Variable());
+      outputs.push_back(tmp_vars.back().get());
       outputs.back()->GetMutable<framework::LoDTensor>();
     }
   }
@@ -155,7 +151,15 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
 
   std::unique_ptr<framework::OperatorBase> opbase =
       framework::OpRegistry::CreateOp(*grad_op_desc_);
-  opbase->Run(ctx, platform::CPUPlace());
+  framework::OperatorWithKernel* op_kernel =
+      dynamic_cast<framework::OperatorWithKernel*>(opbase.get());
+  PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
+
+  framework::Scope scope;
+  platform::CPUPlace place;
+  PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place);
+  p.op.RuntimeInferShape(scope, place, ctx);
+  p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx));
 
   for (auto it : grad_output_vars_) {
     auto& outputs = grad_outputs[it.first];
@@ -169,11 +173,15 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
 }
 
 void VarBase::RunBackward() {
-  if (!pre_op_) return;
   auto grads_t = grads_->GetMutable<framework::LoDTensor>();
   float* data = grads_t->mutable_data<float>(platform::CPUPlace());
   std::fill(data, data + grads_t->numel(), 1.0);
 
+  if (!pre_op_) return;
+  PADDLE_ENFORCE(
+      grads_ ==
+      pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_);
+
   Autograd().RunBackward(this);
 }
......
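Two of the layer.cc changes above are ownership fixes: grad output Variables used to be allocated with new and never freed. A minimal sketch of the unique_ptr pattern now used, assuming only the standard library (Variable here is a stand-in, not the framework type):

  #include <memory>
  #include <vector>

  struct Variable {};  // stand-in for framework::Variable

  int main() {
    std::vector<std::unique_ptr<Variable>> tmp_vars;  // sole owner
    std::vector<Variable*> outputs;                   // non-owning views
    for (int i = 0; i < 3; ++i) {
      tmp_vars.emplace_back(new Variable());
      outputs.push_back(tmp_vars.back().get());
    }
    return 0;
  }  // tmp_vars goes out of scope here and frees every Variable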
@@ -25,6 +25,59 @@
 namespace paddle {
 namespace imperative {
 
+class PreparedOp {
+ public:
+  PreparedOp(const framework::OperatorBase& op,
+             const framework::RuntimeContext& ctx,
+             framework::OperatorWithKernel::OpKernelFunc func,
+             platform::DeviceContext* dev_ctx)
+      : op(op), ctx(ctx), func(func), dev_ctx(dev_ctx) {}
+
+  static PreparedOp Prepare(const framework::RuntimeContext& ctx,
+                            const framework::OperatorWithKernel& op,
+                            const platform::Place& place) {
+    framework::Scope dummy_scope;
+    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+    auto* dev_ctx = pool.Get(place);
+
+    // Check if op[type] has a kernel registered.
+    auto& all_op_kernels = op.AllOpKernels();
+    auto kernels_iter = all_op_kernels.find(op.Type());
+    if (kernels_iter == all_op_kernels.end()) {
+      PADDLE_THROW(
+          "There are no kernels which are registered in the %s operator.",
+          op.Type());
+    }
+
+    framework::OperatorWithKernel::OpKernelMap& kernels = kernels_iter->second;
+
+    auto expected_kernel_key = op.GetExpectedKernelType(
+        framework::ExecutionContext(op, dummy_scope, *dev_ctx, ctx));
+    VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
+
+    auto kernel_iter = kernels.find(expected_kernel_key);
+#ifdef PADDLE_WITH_MKLDNN
+    // workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
+    if (kernel_iter == kernels.end() &&
+        expected_kernel_key.library_type_ == framework::LibraryType::kMKLDNN) {
+      VLOG(3) << "missing MKLDNN kernel: falling back to PLAIN one";
+      expected_kernel_key.library_type_ = framework::LibraryType::kPlain;
+      expected_kernel_key.data_layout_ = framework::DataLayout::kAnyLayout;
+      kernel_iter = kernels.find(expected_kernel_key);
+    }
+#endif
+    if (kernel_iter == kernels.end()) {
+      PADDLE_THROW("op %s does not have kernel for %s", op.Type(),
+                   KernelTypeToString(expected_kernel_key));
+    }
+    return PreparedOp(op, ctx, kernel_iter->second, dev_ctx);
+  }
+
+  const framework::OperatorBase& op;
+  const framework::RuntimeContext& ctx;
+  framework::OperatorWithKernel::OpKernelFunc func;
+  platform::DeviceContext* dev_ctx;
+};
+
 class OpBase;
 
 class VarBase {
@@ -62,30 +115,22 @@ class VarBase {
 class OpBase {
  public:
-  OpBase()
-      : pre_ops_(new std::map<std::string, std::vector<OpBase*>>()),
-        pre_ops_out_idx_(new std::map<std::string, std::vector<int>>()),
-        op_desc_(nullptr),
-        grad_op_desc_(nullptr) {}
+  OpBase() : op_desc_(nullptr), grad_op_desc_(nullptr) {}
 
   virtual ~OpBase() {
-    delete pre_ops_;
-    delete pre_ops_out_idx_;
-
     if (grad_op_desc_) delete grad_op_desc_;
     if (grad_to_var_) delete grad_to_var_;
   }
 
   std::map<std::string, std::vector<VarBase*>> ApplyGrad();
 
+  framework::OpDesc* op_desc_;
+  framework::OpDesc* grad_op_desc_;
+
   std::map<std::string, std::vector<VarBase*>> input_vars_;
   std::map<std::string, std::vector<VarBase*>> output_vars_;
-  std::map<std::string, std::vector<OpBase*>>* pre_ops_;
-  std::map<std::string, std::vector<int>>* pre_ops_out_idx_;
-  framework::OpDesc* op_desc_;
+  std::map<std::string, std::vector<OpBase*>> pre_ops_;
+  std::map<std::string, std::vector<int>> pre_ops_out_idx_;
 
-  framework::OpDesc* grad_op_desc_;
   std::unordered_map<std::string, std::string>* grad_to_var_;
 
   std::map<std::string, std::vector<framework::Variable*>> grad_input_vars_;
   std::map<std::string, std::vector<framework::Variable*>> grad_output_vars_;
   framework::BlockDesc* block_;
......
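PreparedOp above is the imperative replacement for scope-based kernel dispatch: look the kernel up once, then run infer-shape and the kernel functor directly against the RuntimeContext. A condensed sketch of the intended call sequence, assembled from the ApplyGrad and Tracer hunks (scope is a dummy; all inputs and outputs travel in ctx):

  framework::OperatorWithKernel* op_kernel =
      dynamic_cast<framework::OperatorWithKernel*>(op_base.get());
  PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
  framework::Scope scope;
  platform::CPUPlace place;
  PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place);
  p.op.RuntimeInferShape(scope, place, ctx);
  p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx));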
@@ -82,10 +82,10 @@ class Tracer {
       invars.push_back(inp->var_);
 
       vars[inp->var_desc_->Name()] = inp;
       if (inp->pre_op_) {
-        (*op->pre_ops_)[it.first].push_back(inp->pre_op_);
-        (*op->pre_ops_out_idx_)[it.first].push_back(inp->pre_op_out_idx_);
+        op->pre_ops_[it.first].push_back(inp->pre_op_);
+        op->pre_ops_out_idx_[it.first].push_back(inp->pre_op_out_idx_);
       } else {
-        (*op->pre_ops_)[it.first].push_back(nullptr);
+        op->pre_ops_[it.first].push_back(nullptr);
       }
       VLOG(3) << "input vname " << inp->var_desc_->Name() << " "
               << inp->var_->IsInitialized();
@@ -118,24 +118,33 @@ class Tracer {
 
     VLOG(3) << "tracer running " << op_desc->Type();
     framework::RuntimeContext ctx(invars_map, outvars_map);
-    op_base->Run(ctx, platform::CPUPlace());
+    // op_base->RunPrepared(ctx, platform::CPUPlace());
+    // TODO(panyx0718): Cache p.
+    framework::OperatorWithKernel* op_kernel =
+        dynamic_cast<framework::OperatorWithKernel*>(op_base.get());
+    PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
+    framework::Scope scope;
+    platform::CPUPlace place;
+    PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place);
+    p.op.RuntimeInferShape(scope, place, ctx);
+    p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx));
 
     if (block == startup_block_) {
       op->grad_op_desc_ = nullptr;
       op->grad_to_var_ = nullptr;
     } else {
       framework::OpDesc* grad_op_desc;
       auto grad_to_var = new std::unordered_map<std::string, std::string>();
       CreateGradOp(*op_desc, {}, {block}, &grad_op_desc, grad_to_var);
       op->grad_op_desc_ = grad_op_desc;
       op->grad_to_var_ = grad_to_var;
 
       for (auto it : grad_op_desc->Inputs()) {
         auto& grad_in_vars = op->grad_input_vars_[it.first];
         for (const std::string& grad_invar : it.second) {
           block->FindRecursiveOrCreateVar(grad_invar);
-          auto var_it = op->grad_to_var_->find(grad_invar);
-          if (var_it == op->grad_to_var_->end()) {
+          auto var_it = grad_to_var->find(grad_invar);
+          if (var_it == grad_to_var->end()) {
             auto fwd_var_it = vars.find(grad_invar);
             PADDLE_ENFORCE(fwd_var_it != vars.end());
             grad_in_vars.push_back(fwd_var_it->second->var_);
@@ -152,8 +161,8 @@ class Tracer {
         auto& grad_out_vars = op->grad_output_vars_[it.first];
         for (const std::string& grad_outvar : it.second) {
           block->FindRecursiveOrCreateVar(grad_outvar);
-          auto var_it = op->grad_to_var_->find(grad_outvar);
-          PADDLE_ENFORCE(var_it != op->grad_to_var_->end());
+          auto var_it = grad_to_var->find(grad_outvar);
+          PADDLE_ENFORCE(var_it != grad_to_var->end());
           VarBase* var = vars[var_it->second];
           if (!var->grads_->IsInitialized()) {
             InitVar(var->var_, var->grads_);
......
@@ -86,4 +86,5 @@ REGISTER_OPERATOR(fill_constant, ops::FillConstantOp, ops::FillConstantOpMaker,
 REGISTER_OP_CPU_KERNEL(fill_constant, ops::FillConstantKernel<float>,
                        ops::FillConstantKernel<double>,
-                       ops::FillConstantKernel<int64_t>);
+                       ops::FillConstantKernel<int64_t>,
+                       ops::FillConstantKernel<int>);
@@ -316,6 +316,8 @@ class LayerHelper(object):
         if _in_imperative_mode():
             self.main_program.global_block().create_parameter(
                 dtype=dtype, shape=shape, **attr._to_kwargs())
+            # In imperative mode, we want the returned parameter to be
+            # initialized so that it can be used imperatively.
             return self.startup_program.global_block().create_parameter(
                 dtype=dtype,
                 shape=shape,
......
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import sys
 import contextlib
 import unittest
 import numpy as np
@@ -82,12 +81,10 @@ class TestImperative(unittest.TestCase):
         with new_program_scope():
             inp = fluid.layers.data(
                 name="inp", shape=[3], append_batch_size=False)
-            x = fluid.layers.relu(inp)
-            x_for_debug = x
-            x = fluid.layers.elementwise_mul(x, x)
-            x = fluid.layers.reduce_sum(x)
+            l = MyLayer()
+            x = l(inp)[0]
             param_grads = fluid.backward.append_backward(
-                x, parameter_list=[x_for_debug.name])[0]
+                x, parameter_list=[l._x_for_debug.name])[0]
             exe = fluid.Executor(fluid.CPUPlace())
             static_out, static_grad = exe.run(
......