提交 87efa600 编写于 作者: Q Qiao Longfei 提交者: GitHub

add some check to operator.run (#4544)

* fix cond_op_test and add some check to operator.run

* tmp

* optimize kernel check
上级 240ed5e6
......@@ -245,5 +245,12 @@ std::vector<Tensor*> InferShapeContext::MultiOutput<Tensor>(
return res;
}
std::ostream& operator<<(std::ostream& os,
const OperatorWithKernel::OpKernelKey& kernel_key) {
os << "place[" << kernel_key.place_ << "]:data_type[" << kernel_key.data_type_
<< "]";
return os;
}
} // namespace framework
} // namespace paddle
......@@ -478,9 +478,25 @@ class OperatorWithKernel : public OperatorBase {
this->InferShape(&infer_shape_ctx);
ExecutionContext ctx(*this, scope, dev_ctx);
auto& opKernel = AllOpKernels().at(type_).at(
OpKernelKey(IndicateDataType(ctx), dev_ctx));
opKernel->Compute(ctx);
// check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_);
if (kernels_iter == all_op_kernels.end()) {
PADDLE_THROW("op[%s] has no kernel", type_);
}
// check if op[type] have kernel for kernel_key
OpKernelMap& kernels = kernels_iter->second;
auto kernel_key = OpKernelKey(IndicateDataType(ctx), dev_ctx);
auto kernel_iter = kernels.find(kernel_key);
if (kernel_iter == kernels.end()) {
PADDLE_THROW("op[%s] has no kernel with kernel_key[%s]", type_,
kernel_key);
}
kernel_iter->second->Compute(ctx);
}
static std::unordered_map<std::string /* op_type */, OpKernelMap>&
......@@ -529,5 +545,8 @@ class OperatorWithKernel : public OperatorBase {
}
};
std::ostream& operator<<(std::ostream& os,
const OperatorWithKernel::OpKernelKey& kernel_key);
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册