Unverified · Commit 5c1eda19 · Authored by: Ruibiao Chen · Committed by: GitHub

Remove InterpretercoreInferShapeContext (#51209)

* Remove InterpretercoreInferShapeContext

* Fix lod errors
Parent f1f2a253
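This commit merges two functionally identical runtime shape-inference contexts: the new executor's InterpretercoreInferShapeContext is deleted, and RuntimeInferShapeContext moves from a file-local class in operator.cc into operator.h, picking up the SetSkipLoD flag that only the interpreter-core version had. Every call site then follows one pattern; below is a minimal sketch of that pattern (the helper name RunInferShape and the surrounding setup are illustrative, not part of the patch):

// Sketch only: the call-site pattern this diff standardizes on. The caller
// is assumed to have built `op` and `runtime_context` already.
#include "paddle/fluid/framework/operator.h"

void RunInferShape(const paddle::framework::OperatorBase& op,
                   const paddle::framework::RuntimeContext& runtime_context) {
  // One context type for all runtime shape inference after this change:
  paddle::framework::RuntimeInferShapeContext infer_shape_ctx(op,
                                                              runtime_context);
  op.Info().infer_shape_(&infer_shape_ctx);
}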
@@ -129,7 +129,7 @@ void DataTranferHelper::RunAndConstructOpFuncNode(
  RuntimeContext runtime_context({}, {});
  runtime_context.inputs["X"] = {scope_->FindVar(var_name)};
  runtime_context.outputs["Out"] = {scope_->Var(new_var_name)};
-  InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context);
+  RuntimeInferShapeContext infer_shape_ctx(*op, runtime_context);
  op.get()->Info().infer_shape_(&infer_shape_ctx);
  // 2. choose kernel
...
@@ -824,8 +824,7 @@ void BuildOpFuncList(const platform::Place& place,
    if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) &&
          op->Attr<bool>(kAllKernelsMustComputeRuntimeShape))) {
      VLOG(4) << "infer shape";
-      InterpretercoreInferShapeContext infer_shape_ctx(*op,
-                                                       runtime_context);
+      RuntimeInferShapeContext infer_shape_ctx(*op, runtime_context);
      // TODO(Aurelius84): In case of control flow ops, they are NOT
      // inherited from OperatorWithKernel.
      op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
...
@@ -24,525 +24,6 @@
namespace paddle {
namespace framework {
InterpretercoreInferShapeContext::InterpretercoreInferShapeContext(
const OperatorBase& op, const RuntimeContext& ctx)
: op_(op), ctx_(ctx), can_skip_lod_(false) {}
bool InterpretercoreInferShapeContext::HasInput(const std::string& name) const {
// has only one input
const auto& ins = ctx_.inputs;
auto it = ins.find(name);
if (it == ins.end()) {
return false;
}
const auto& in = it->second;
if (in.size() == 0) return false;
PADDLE_ENFORCE_EQ(
in.size(),
1UL,
platform::errors::InvalidArgument(
"Input %s should not contain more than one inputs.", name));
return in[0] != nullptr;
}
bool InterpretercoreInferShapeContext::HasOutput(
const std::string& name) const {
// has only one output
const auto& outs = ctx_.outputs;
auto it = outs.find(name);
if (it == outs.end()) {
return false;
}
const auto& out = it->second;
if (out.size() == 0) {
return false;
}
PADDLE_ENFORCE_EQ(
out.size(),
1UL,
platform::errors::InvalidArgument(
"Output %s should not contain more than one outputs.", name));
return out[0] != nullptr;
}
bool InterpretercoreInferShapeContext::HasAttr(const std::string& name) const {
return op_.HasAttr(name);
}
bool InterpretercoreInferShapeContext::HasInputs(
const std::string& name) const {
const auto& ins = ctx_.inputs;
auto it = ins.find(name);
if (it == ins.end() || it->second.empty()) {
return false;
}
for (auto& input : it->second) {
if (input == nullptr) {
return false;
}
}
return true;
}
bool InterpretercoreInferShapeContext::HasOutputs(const std::string& name,
bool allow_null) const {
const auto& outs = ctx_.outputs;
auto it = outs.find(name);
if (it == outs.end() || it->second.empty()) {
return false;
}
if (!allow_null) {
for (auto& output : it->second) {
if (output == nullptr) return false;
}
}
return true;
}
AttrReader InterpretercoreInferShapeContext::Attrs() const {
return AttrReader(op_.Attrs(), op_.RuntimeAttrs());
}
std::vector<std::string> InterpretercoreInferShapeContext::Inputs(
const std::string& name) const {
return op_.Inputs(name);
}
std::vector<std::string> InterpretercoreInferShapeContext::Outputs(
const std::string& name) const {
return op_.Outputs(name);
}
std::string InterpretercoreInferShapeContext::GetInputNameByIdx(
size_t idx) const {
auto& op_proto =
paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
PADDLE_ENFORCE_LT(idx,
op_proto->inputs().size(),
platform::errors::OutOfRange(
"The index should be less than the size of inputs of "
"operator %s, but got index is %d and size is %d",
op_.Type(),
idx,
op_proto->inputs().size()));
return op_proto->inputs()[idx].name();
}
std::string InterpretercoreInferShapeContext::GetOutputNameByIdx(
size_t idx) const {
auto& op_proto =
paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
PADDLE_ENFORCE_LT(idx,
op_proto->outputs().size(),
platform::errors::OutOfRange(
"The index should be less than the size of outputs of "
"operator %s, but got index is %d and size is %d",
op_.Type(),
idx,
op_proto->outputs().size()));
return op_proto->outputs()[idx].name();
}
void InterpretercoreInferShapeContext::ShareDim(const std::string& in,
const std::string& out,
size_t i,
size_t j) {
auto in_it = ctx_.inputs.find(in);
auto out_it = ctx_.outputs.find(out);
PADDLE_ENFORCE_NE(in_it,
ctx_.inputs.end(),
platform::errors::NotFound("Input %s does not exist.", in));
PADDLE_ENFORCE_NE(
out_it,
ctx_.outputs.end(),
platform::errors::NotFound("Output %s does not exist.", out));
PADDLE_ENFORCE_LT(i,
in_it->second.size(),
platform::errors::InvalidArgument(
"The index of input dimension is out of range, "
"excepted index less than %zu, but received %zu.",
in_it->second.size(),
i));
PADDLE_ENFORCE_LT(j,
out_it->second.size(),
platform::errors::InvalidArgument(
"The index of output dimension is out of range, "
"excepted index less than %zu, but received %zu.",
out_it->second.size(),
j));
Variable* in_var = in_it->second[i];
Variable* out_var = out_it->second[j];
PADDLE_ENFORCE_EQ(
in_var->Type(),
out_var->Type(),
platform::errors::InvalidArgument(
"The type of input (%s) and output (%s) are inconsistent.", in, out));
if (in_var->IsType<phi::SelectedRows>()) {
auto& in_sele_rows = in_var->Get<phi::SelectedRows>();
auto out_sele_rows = out_var->GetMutable<phi::SelectedRows>();
out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims());
out_sele_rows->set_rows(in_sele_rows.rows());
out_sele_rows->set_height(in_sele_rows.height());
} else if (in_var->IsType<phi::DenseTensor>()) {
auto& in_lod_tensor = in_var->Get<phi::DenseTensor>();
auto* out_lod_tensor = out_var->GetMutable<phi::DenseTensor>();
out_lod_tensor->Resize(in_lod_tensor.dims());
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, the input type of ShareDim only can be phi::DenseTensor "
"or SelectedRows."));
}
}
void InterpretercoreInferShapeContext::ShareAllLoD(
const std::string& in, const std::string& out) const {
auto in_it = ctx_.inputs.find(in);
auto out_it = ctx_.outputs.find(out);
PADDLE_ENFORCE_NE(in_it,
ctx_.inputs.end(),
platform::errors::NotFound(
"Input [%s] found error in Op [%s]", in, op_.Type()));
PADDLE_ENFORCE_NE(out_it,
ctx_.outputs.end(),
platform::errors::NotFound(
"Output [%s] found error in Op [%s]", out, op_.Type()));
auto& in_var_list = in_it->second;
auto& out_var_list = out_it->second;
PADDLE_ENFORCE_EQ(
in_var_list.size(),
out_var_list.size(),
platform::errors::PreconditionNotMet(
"Op [%s]: Input var size should be equal with output var size",
op_.Type()));
auto& out_var_names = op_.Outputs(out);
for (size_t i = 0; i < in_var_list.size(); ++i) {
if (out_var_names[i] == framework::kEmptyVarName) {
continue;
}
Variable* in_var = in_var_list[i];
if (!in_var->IsType<phi::DenseTensor>()) return;
Variable* out_var = out_var_list[i];
PADDLE_ENFORCE_EQ(
out_var->IsType<phi::DenseTensor>(),
true,
platform::errors::PreconditionNotMet(
"The %d-th output of Output(%s) must be phi::DenseTensor.",
i,
out_var_names[i]));
auto& in_tensor = in_var->Get<phi::DenseTensor>();
auto* out_tensor = out_var->GetMutable<phi::DenseTensor>();
out_tensor->set_lod(in_tensor.lod());
#ifdef PADDLE_WITH_MKLDNN
if (in_tensor.layout() != DataLayout::ONEDNN)
#endif
out_tensor->set_layout(in_tensor.layout());
}
}
void InterpretercoreInferShapeContext::ShareLoD(const std::string& in,
const std::string& out,
size_t i,
size_t j) const {
if (can_skip_lod_) {
return;
}
auto in_it = ctx_.inputs.find(in);
auto out_it = ctx_.outputs.find(out);
PADDLE_ENFORCE_NE(in_it,
ctx_.inputs.end(),
platform::errors::NotFound("Input %s does not exist.", in));
PADDLE_ENFORCE_NE(
out_it,
ctx_.outputs.end(),
platform::errors::NotFound("Output %s does not exist.", out));
PADDLE_ENFORCE_LT(i,
in_it->second.size(),
platform::errors::InvalidArgument(
"The index of input dimension is out of range, "
"excepted index less than %zu, but received %zu.",
in_it->second.size(),
i));
PADDLE_ENFORCE_LT(j,
out_it->second.size(),
platform::errors::InvalidArgument(
"The index of output dimension is out of range, "
"excepted index less than %zu, but received %zu.",
out_it->second.size(),
j));
Variable* in_var = in_it->second.at(i);
if (!in_var->IsType<phi::DenseTensor>()) return;
Variable* out_var = out_it->second.at(j);
PADDLE_ENFORCE_EQ(
out_var->IsType<phi::DenseTensor>(),
true,
platform::errors::InvalidArgument(
"The %zu-th output of Output(%s) must be phi::DenseTensor.", j, out));
auto& in_tensor = in_var->Get<phi::DenseTensor>();
auto* out_tensor = out_var->GetMutable<phi::DenseTensor>();
out_tensor->set_lod(in_tensor.lod());
// TODO(dzhwinter) : reuse ShareLoD in most operators.
// Need to call ShareLayout explicitly in sequence related ops.
// Shall we have a better method to share info between in/out phi::DenseTensor?
#ifdef PADDLE_WITH_MKLDNN
// Fix me: ugly workaround below
// Correct solution:
// set_layout() should NOT be called here (i.e. ShareLoD). Instead,
// layout of output tensor should be set "manually" in Compute()
// of each OPKernel. The reason layout should NOT be shared between
// input and output "automatically" (now by InferShape()->ShareLoD())
// is that layout transform may occur after InferShape().
// Workaround:
// Skip set_layout() when input layout is kMKLDNN
// This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN
// OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called
// in Compute()
if (in_tensor.layout() != DataLayout::ONEDNN)
#endif
out_tensor->set_layout(in_tensor.layout());
}
int32_t InterpretercoreInferShapeContext::GetLoDLevel(const std::string& in,
size_t i) const {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"GetLoDLevel is only used in compile time. The calculation of "
"output's actual lod is different among operators so that should be "
"set in the runtime kernel."));
}
void InterpretercoreInferShapeContext::SetLoDLevel(const std::string& out,
int32_t lod_level,
size_t j) const {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"SetLoDLevel is only used in compile time. The calculation of "
"output's actual lod is different among operators so that should be "
"set in the runtime kernel."));
}
bool InterpretercoreInferShapeContext::IsRuntime() const { return true; }
bool InterpretercoreInferShapeContext::IsRunMKLDNNKernel() const {
try {
auto& op_with_kernel = dynamic_cast<const OperatorWithKernel&>(op_);
return ((op_with_kernel.kernel_type()) &&
(op_with_kernel.kernel_type()->data_layout_ ==
phi::DataLayout::ONEDNN));
} catch (std::bad_cast& exp) {
return false;
}
}
// TODO(paddle-dev): Can this be a template?
paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize>
InterpretercoreInferShapeContext::GetInputVarPtrs(
const std::string& name) const {
const std::vector<Variable*>& vars = InputVars(name);
paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
}
paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
InterpretercoreInferShapeContext::GetOutputVarPtrs(
const std::string& name) const {
const std::vector<Variable*>& vars = OutputVars(name);
paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
}
DDim InterpretercoreInferShapeContext::GetInputDim(
const std::string& name) const {
const std::vector<Variable*>& vars = InputVars(name);
PADDLE_ENFORCE_EQ(
vars.size(),
1UL,
platform::errors::InvalidArgument(
"Input(%s) should hold one element, but now it holds %zu elements.",
name,
vars.size()));
return this->GetDim(vars[0]);
}
std::vector<DDim> InterpretercoreInferShapeContext::GetInputsDim(
const std::string& name) const {
const std::vector<Variable*>& vars = InputVars(name);
return GetDims(vars);
}
proto::VarType::Type InterpretercoreInferShapeContext::GetInputVarType(
const std::string& name) const {
return GetVarType(InputVars(name).at(0));
}
std::vector<proto::VarType::Type>
InterpretercoreInferShapeContext::GetInputsVarType(
const std::string& name) const {
return GetVarTypes(InputVars(name));
}
std::vector<proto::VarType::Type>
InterpretercoreInferShapeContext::GetOutputsVarType(
const std::string& name) const {
return GetVarTypes(OutputVars(name));
}
void InterpretercoreInferShapeContext::SetOutputDim(const std::string& name,
const DDim& dim) {
auto& vars = OutputVars(name);
PADDLE_ENFORCE_EQ(
vars.size(),
1UL,
platform::errors::InvalidArgument("Output(%s) should hold one element, "
"but now it holds %zu elements.",
name,
vars.size()));
SetDim(vars[0], dim);
}
void InterpretercoreInferShapeContext::SetOutputsDim(
const std::string& name, const std::vector<DDim>& dims) {
auto& vars = OutputVars(name);
SetDims(vars, dims);
}
const phi::ArgumentMappingFn*
InterpretercoreInferShapeContext::GetPhiArgumentMappingFn() const {
return phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_.Type());
}
const phi::KernelSignature*
InterpretercoreInferShapeContext::GetPhiDefaultKernelSignature() const {
return &phi::DefaultKernelSignatureMap::Instance().Get(op_.Type());
}
void InterpretercoreInferShapeContext::SetSkipLoD(bool skip) {
can_skip_lod_ = skip;
}
DDim InterpretercoreInferShapeContext::GetDim(Variable* var) const {
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::InvalidArgument("Input variable is nullptr."));
if (var->IsType<phi::DenseTensor>()) {
return var->Get<phi::DenseTensor>().dims();
} else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().GetCompleteDims();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only phi::DenseTensor or SelectedRows support 'GetDim', but input "
"Variable's type is %s.",
ToTypeName(var->Type())));
}
}
std::vector<DDim> InterpretercoreInferShapeContext::GetDims(
const std::vector<Variable*>& vars) const {
std::vector<DDim> ret;
ret.reserve(vars.size());
std::transform(
vars.begin(), vars.end(), std::back_inserter(ret), [this](Variable* var) {
return this->GetDim(var);
});
return ret;
}
std::vector<DDim> InterpretercoreInferShapeContext::GetRepeatedDims(
const std::string& name) const {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"GetRepeatedDims method only ban be used in compile time."));
}
void InterpretercoreInferShapeContext::SetDim(Variable* var, const DDim& dim) {
if (var->IsType<phi::DenseTensor>()) {
var->GetMutable<phi::DenseTensor>()->Resize(dim);
} else if (var->IsType<phi::SelectedRows>()) {
var->GetMutable<phi::SelectedRows>()->set_height(dim[0]);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Variable type error, expect phi::DenseTensor or SelectedRows, but "
"received "
"(%s).",
ToTypeName(var->Type())));
}
}
void InterpretercoreInferShapeContext::SetDims(
const std::vector<Variable*>& vars, const std::vector<DDim>& dims) {
size_t length = vars.size();
PADDLE_ENFORCE_EQ(length,
dims.size(),
platform::errors::InvalidArgument(
"The number of input variables do not match the "
"number of input dimensions, the number of variables "
"is %zu, the number of dimensions is %zu.",
length,
dims.size()));
for (size_t i = 0; i < length; ++i) {
if (vars[i] == nullptr) {
continue;
}
SetDim(vars[i], dims[i]);
}
}
void InterpretercoreInferShapeContext::SetRepeatedDims(
const std::string& name, const std::vector<DDim>& dims) {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"SetRepeatedDims method only can be used in compile time."));
}
std::vector<proto::VarType::Type> InterpretercoreInferShapeContext::GetVarTypes(
const std::vector<Variable*>& vars) const {
std::vector<proto::VarType::Type> retv;
retv.resize(vars.size());
std::transform(
vars.begin(),
vars.end(),
retv.begin(),
std::bind(std::mem_fn(&InterpretercoreInferShapeContext::GetVarType),
this,
std::placeholders::_1));
return retv;
}
proto::VarType::Type InterpretercoreInferShapeContext::GetVarType(
Variable* var) const {
return ToVarType(var->Type());
}
const std::vector<Variable*>& InterpretercoreInferShapeContext::InputVars(
const std::string& name) const {
auto it = ctx_.inputs.find(name);
PADDLE_ENFORCE_NE(
it,
ctx_.inputs.end(),
platform::errors::NotFound(
"Operator (%s) does not have the input (%s).", op_.Type(), name));
return it->second;
}
const std::vector<Variable*>& InterpretercoreInferShapeContext::OutputVars(
const std::string& name) const {
auto it = ctx_.outputs.find(name);
PADDLE_ENFORCE_NE(
it,
ctx_.outputs.end(),
platform::errors::NotFound(
"Operator (%s) does not have the outputs (%s).", op_.Type(), name));
return it->second;
}
VariableScope::VariableScope(Scope* scope) {
  // for @EMPTY@ variable
  name2id_[kEmptyVarName] = kEmptyVarIndex;
@@ -747,7 +228,7 @@ void Instruction::ResetContext(const VariableValueMap& in_vars,
                               const VariableValueMap& out_vars) {
  runtime_ctx_.reset(new RuntimeContext(in_vars, out_vars));
  infershape_ctx_.reset(
-      new InterpretercoreInferShapeContext(*OpBase(), *runtime_ctx_.get()));
+      new RuntimeInferShapeContext(*OpBase(), *runtime_ctx_.get()));
  // NOTE: execution_ctx_ is constructed from a `Scope&`, so we fake an
  // empty scope here to avoid an illegal local reference.
  static framework::Scope scope_;
@@ -760,7 +241,7 @@ void Instruction::ResetContextWithScope(const VariableValueMap& in_vars,
                                        const framework::Scope& scope) {
  runtime_ctx_.reset(new RuntimeContext(in_vars, out_vars));
  infershape_ctx_.reset(
-      new InterpretercoreInferShapeContext(*OpBase(), *runtime_ctx_.get()));
+      new RuntimeInferShapeContext(*OpBase(), *runtime_ctx_.get()));
  execution_ctx_.reset(
      new ExecutionContext(*OpBase(), scope, dev_ctx_, *runtime_ctx_.get()));
}
@@ -769,8 +250,8 @@ std::shared_ptr<RuntimeContext> Instruction::InnerRuntimeContext() const {
  return runtime_ctx_;
}
-std::shared_ptr<InterpretercoreInferShapeContext>
-Instruction::InnerInferShapeContext() const {
+std::shared_ptr<RuntimeInferShapeContext> Instruction::InnerInferShapeContext()
+    const {
  return infershape_ctx_;
}
...
@@ -44,115 +44,6 @@ constexpr const char* kH2DStream = "H2DStream";
constexpr int kEmptyVarIndex = 0;
class InterpretercoreInferShapeContext : public InferShapeContext {
public:
InterpretercoreInferShapeContext(const OperatorBase& op,
const RuntimeContext& ctx);
bool HasInput(const std::string& name) const override;
bool HasOutput(const std::string& name) const override;
bool HasAttr(const std::string& name) const override;
bool HasInputs(const std::string& name) const override;
bool HasOutputs(const std::string& name,
bool allow_null = false) const override;
AttrReader Attrs() const override;
std::vector<std::string> Inputs(const std::string& name) const override;
std::vector<std::string> Outputs(const std::string& name) const override;
std::string GetInputNameByIdx(size_t idx) const override;
std::string GetOutputNameByIdx(size_t idx) const override;
void ShareDim(const std::string& in,
const std::string& out,
size_t i = 0,
size_t j = 0) override;
void ShareAllLoD(const std::string& in,
const std::string& out) const override;
void ShareLoD(const std::string& in,
const std::string& out,
size_t i = 0,
size_t j = 0) const override;
int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override;
void SetLoDLevel(const std::string& out,
int32_t lod_level,
size_t j = 0) const override;
bool IsRuntime() const override;
bool IsRunMKLDNNKernel() const override;
// TODO(paddle-dev): Can this be a template?
paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize>
GetInputVarPtrs(const std::string& name) const override;
paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
GetOutputVarPtrs(const std::string& name) const override;
DDim GetInputDim(const std::string& name) const override;
std::vector<DDim> GetInputsDim(const std::string& name) const override;
proto::VarType::Type GetInputVarType(const std::string& name) const override;
std::vector<proto::VarType::Type> GetInputsVarType(
const std::string& name) const override;
std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string& name) const override;
void SetOutputDim(const std::string& name, const DDim& dim) override;
void SetOutputsDim(const std::string& name,
const std::vector<DDim>& dims) override;
const phi::ArgumentMappingFn* GetPhiArgumentMappingFn() const override;
const phi::KernelSignature* GetPhiDefaultKernelSignature() const override;
void SetSkipLoD(bool skip);
protected:
DDim GetDim(Variable* var) const;
std::vector<DDim> GetDims(const std::vector<Variable*>& vars) const;
std::vector<DDim> GetRepeatedDims(const std::string& name) const override;
void SetDim(Variable* var, const DDim& dim);
void SetDims(const std::vector<Variable*>& vars,
const std::vector<DDim>& dims);
void SetRepeatedDims(const std::string& name,
const std::vector<DDim>& dims) override;
std::vector<proto::VarType::Type> GetVarTypes(
const std::vector<Variable*>& vars) const;
proto::VarType::Type GetVarType(Variable* var) const;
private:
const std::vector<Variable*>& InputVars(const std::string& name) const;
const std::vector<Variable*>& OutputVars(const std::string& name) const;
const OperatorBase& op_;
const RuntimeContext& ctx_;
bool can_skip_lod_;
};
struct OpKernelFunc {
  OpKernelComputeFunc compute_func_;
};
@@ -260,7 +151,6 @@ enum class OpFuncType {
  kGpuSync,  // GPU or other device kernel without asynchronous operation
  kGpuAsync  // GPU or other device kernel with asynchronous operation
};
-class RuntimeInferShapeContext;
struct OpFuncNode {
  int stream_priority_{0};  // lower value, higher priority
@@ -357,8 +247,7 @@ class Instruction {
  std::shared_ptr<RuntimeContext> InnerRuntimeContext() const;
-  std::shared_ptr<InterpretercoreInferShapeContext> InnerInferShapeContext()
-      const;
+  std::shared_ptr<RuntimeInferShapeContext> InnerInferShapeContext() const;
  std::shared_ptr<ExecutionContext> InnerExecutionContext() const;
@@ -390,7 +279,7 @@ class Instruction {
  const platform::DeviceContext& dev_ctx_;  // not owned
  std::shared_ptr<RuntimeContext> runtime_ctx_;
-  std::shared_ptr<InterpretercoreInferShapeContext> infershape_ctx_;
+  std::shared_ptr<RuntimeInferShapeContext> infershape_ctx_;
  std::shared_ptr<ExecutionContext> execution_ctx_;
  std::vector<size_t> gc_check_vars_;
...
@@ -25,7 +25,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_call_stack.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/raw_tensor.h"
-#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/unused_var_check.h"
#include "paddle/fluid/framework/var_type.h"
@@ -214,6 +213,512 @@ RuntimeContext::RuntimeContext(const VariableNameMap& innames,
  }
}
RuntimeInferShapeContext::RuntimeInferShapeContext(const OperatorBase& op,
const RuntimeContext& ctx)
: op_(op), ctx_(ctx) {}
bool RuntimeInferShapeContext::HasInput(const std::string& name) const {
// has only one input
const auto& ins = ctx_.inputs;
auto it = ins.find(name);
if (it == ins.end()) {
return false;
}
const auto& in = it->second;
if (in.size() == 0) return false;
PADDLE_ENFORCE_EQ(
in.size(),
1UL,
platform::errors::InvalidArgument(
"Input %s should not contain more than one inputs.", name));
return in[0] != nullptr;
}
bool RuntimeInferShapeContext::HasOutput(const std::string& name) const {
// has only one output
const auto& outs = ctx_.outputs;
auto it = outs.find(name);
if (it == outs.end()) {
return false;
}
const auto& out = it->second;
if (out.size() == 0) {
return false;
}
PADDLE_ENFORCE_EQ(
out.size(),
1UL,
platform::errors::InvalidArgument(
"Output %s should not contain more than one outputs.", name));
return out[0] != nullptr;
}
bool RuntimeInferShapeContext::HasAttr(const std::string& name) const {
return op_.HasAttr(name);
}
bool RuntimeInferShapeContext::HasInputs(const std::string& name) const {
const auto& ins = ctx_.inputs;
auto it = ins.find(name);
if (it == ins.end() || it->second.empty()) {
return false;
}
for (auto& input : it->second) {
if (input == nullptr) {
return false;
}
}
return true;
}
bool RuntimeInferShapeContext::HasOutputs(const std::string& name,
bool allow_null) const {
const auto& outs = ctx_.outputs;
auto it = outs.find(name);
if (it == outs.end() || it->second.empty()) {
return false;
}
if (!allow_null) {
for (auto& output : it->second) {
if (output == nullptr) return false;
}
}
return true;
}
AttrReader RuntimeInferShapeContext::Attrs() const {
return AttrReader(op_.Attrs(), op_.RuntimeAttrs());
}
std::vector<std::string> RuntimeInferShapeContext::Inputs(
const std::string& name) const {
return op_.Inputs(name);
}
std::vector<std::string> RuntimeInferShapeContext::Outputs(
const std::string& name) const {
return op_.Outputs(name);
}
std::string RuntimeInferShapeContext::GetInputNameByIdx(size_t idx) const {
auto& op_proto =
paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
PADDLE_ENFORCE_LT(idx,
op_proto->inputs().size(),
platform::errors::OutOfRange(
"The index should be less than the size of inputs of "
"operator %s, but got index is %d and size is %d",
op_.Type(),
idx,
op_proto->inputs().size()));
return op_proto->inputs()[idx].name();
}
std::string RuntimeInferShapeContext::GetOutputNameByIdx(size_t idx) const {
auto& op_proto =
paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
PADDLE_ENFORCE_LT(idx,
op_proto->outputs().size(),
platform::errors::OutOfRange(
"The index should be less than the size of outputs of "
"operator %s, but got index is %d and size is %d",
op_.Type(),
idx,
op_proto->outputs().size()));
return op_proto->outputs()[idx].name();
}
void RuntimeInferShapeContext::ShareDim(const std::string& in,
const std::string& out,
size_t i,
size_t j) {
auto in_it = ctx_.inputs.find(in);
auto out_it = ctx_.outputs.find(out);
PADDLE_ENFORCE_NE(in_it,
ctx_.inputs.end(),
platform::errors::NotFound("Input %s does not exist.", in));
PADDLE_ENFORCE_NE(
out_it,
ctx_.outputs.end(),
platform::errors::NotFound("Output %s does not exist.", out));
PADDLE_ENFORCE_LT(i,
in_it->second.size(),
platform::errors::InvalidArgument(
"The index of input dimension is out of range, "
"excepted index less than %zu, but received %zu.",
in_it->second.size(),
i));
PADDLE_ENFORCE_LT(j,
out_it->second.size(),
platform::errors::InvalidArgument(
"The index of output dimension is out of range, "
"excepted index less than %zu, but received %zu.",
out_it->second.size(),
j));
Variable* in_var = in_it->second[i];
Variable* out_var = out_it->second[j];
PADDLE_ENFORCE_EQ(
in_var->Type(),
out_var->Type(),
platform::errors::InvalidArgument(
"The type of input (%s) and output (%s) are inconsistent.", in, out));
if (in_var->IsType<phi::SelectedRows>()) {
auto& in_sele_rows = in_var->Get<phi::SelectedRows>();
auto out_sele_rows = out_var->GetMutable<phi::SelectedRows>();
out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims());
out_sele_rows->set_rows(in_sele_rows.rows());
out_sele_rows->set_height(in_sele_rows.height());
} else if (in_var->IsType<phi::DenseTensor>()) {
auto& in_lod_tensor = in_var->Get<phi::DenseTensor>();
auto* out_lod_tensor = out_var->GetMutable<phi::DenseTensor>();
out_lod_tensor->Resize(in_lod_tensor.dims());
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, the input type of ShareDim only can be phi::DenseTensor "
"or SelectedRows."));
}
}
void RuntimeInferShapeContext::ShareAllLoD(const std::string& in,
const std::string& out) const {
auto in_it = ctx_.inputs.find(in);
auto out_it = ctx_.outputs.find(out);
PADDLE_ENFORCE_NE(in_it,
ctx_.inputs.end(),
platform::errors::NotFound(
"Input [%s] found error in Op [%s]", in, op_.Type()));
PADDLE_ENFORCE_NE(out_it,
ctx_.outputs.end(),
platform::errors::NotFound(
"Output [%s] found error in Op [%s]", out, op_.Type()));
auto& in_var_list = in_it->second;
auto& out_var_list = out_it->second;
PADDLE_ENFORCE_EQ(
in_var_list.size(),
out_var_list.size(),
platform::errors::PreconditionNotMet(
"Op [%s]: Input var size should be equal with output var size",
op_.Type()));
auto& out_var_names = op_.Outputs(out);
for (size_t i = 0; i < in_var_list.size(); ++i) {
if (out_var_names[i] == framework::kEmptyVarName) {
continue;
}
Variable* in_var = in_var_list[i];
if (!in_var->IsType<phi::DenseTensor>()) return;
Variable* out_var = out_var_list[i];
PADDLE_ENFORCE_EQ(
out_var->IsType<phi::DenseTensor>(),
true,
platform::errors::PreconditionNotMet(
"The %d-th output of Output(%s) must be phi::DenseTensor.",
i,
out_var_names[i]));
auto& in_tensor = in_var->Get<phi::DenseTensor>();
auto* out_tensor = out_var->GetMutable<phi::DenseTensor>();
out_tensor->set_lod(in_tensor.lod());
#ifdef PADDLE_WITH_MKLDNN
if (in_tensor.layout() != DataLayout::ONEDNN)
#endif
out_tensor->set_layout(in_tensor.layout());
}
}
void RuntimeInferShapeContext::ShareLoD(const std::string& in,
const std::string& out,
size_t i,
size_t j) const {
if (can_skip_lod_) {
return;
}
auto in_it = ctx_.inputs.find(in);
auto out_it = ctx_.outputs.find(out);
PADDLE_ENFORCE_NE(in_it,
ctx_.inputs.end(),
platform::errors::NotFound("Input %s does not exist.", in));
PADDLE_ENFORCE_NE(
out_it,
ctx_.outputs.end(),
platform::errors::NotFound("Output %s does not exist.", out));
PADDLE_ENFORCE_LT(i,
in_it->second.size(),
platform::errors::InvalidArgument(
"The index of input dimension is out of range, "
"excepted index less than %zu, but received %zu.",
in_it->second.size(),
i));
PADDLE_ENFORCE_LT(j,
out_it->second.size(),
platform::errors::InvalidArgument(
"The index of output dimension is out of range, "
"excepted index less than %zu, but received %zu.",
out_it->second.size(),
j));
Variable* in_var = in_it->second.at(i);
if (!in_var->IsType<phi::DenseTensor>()) return;
Variable* out_var = out_it->second.at(j);
PADDLE_ENFORCE_EQ(
out_var->IsType<phi::DenseTensor>(),
true,
platform::errors::InvalidArgument(
"The %zu-th output of Output(%s) must be phi::DenseTensor.", j, out));
auto& in_tensor = in_var->Get<phi::DenseTensor>();
auto* out_tensor = out_var->GetMutable<phi::DenseTensor>();
out_tensor->set_lod(in_tensor.lod());
// TODO(dzhwinter) : reuse ShareLoD in most operators.
// Need to call ShareLayout explicitly in sequence related ops.
// Shall we have a better method to share info between in/out phi::DenseTensor?
#ifdef PADDLE_WITH_MKLDNN
// Fix me: ugly workaround below
// Correct solution:
// set_layout() should NOT be called here (i.e. ShareLoD). Instead,
// layout of output tensor should be set "manually" in Compute()
// of each OPKernel. The reason layout should NOT be shared between
// input and output "automatically" (now by InferShape()->ShareLoD())
// is that layout transform may occur after InferShape().
// Workaround:
// Skip set_layout() when input layout is kMKLDNN
// This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN
// OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called
// in Compute()
if (in_tensor.layout() != DataLayout::ONEDNN)
#endif
out_tensor->set_layout(in_tensor.layout());
}
int32_t RuntimeInferShapeContext::GetLoDLevel(const std::string& in,
size_t i) const {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"GetLoDLevel is only used in compile time. The calculation of "
"output's actual lod is different among operators so that should be "
"set in the runtime kernel."));
}
void RuntimeInferShapeContext::SetLoDLevel(const std::string& out,
int32_t lod_level,
size_t j) const {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"SetLoDLevel is only used in compile time. The calculation of "
"output's actual lod is different among operators so that should be "
"set in the runtime kernel."));
}
bool RuntimeInferShapeContext::IsRuntime() const { return true; }
bool RuntimeInferShapeContext::IsRunMKLDNNKernel() const {
try {
auto& op_with_kernel = dynamic_cast<const OperatorWithKernel&>(op_);
return ((op_with_kernel.kernel_type()) &&
(op_with_kernel.kernel_type()->data_layout_ ==
phi::DataLayout::ONEDNN));
} catch (std::bad_cast& exp) {
return false;
}
}
// TODO(paddle-dev): Can this be a template?
paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize>
RuntimeInferShapeContext::GetInputVarPtrs(const std::string& name) const {
const std::vector<Variable*>& vars = InputVars(name);
paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
}
paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
RuntimeInferShapeContext::GetOutputVarPtrs(const std::string& name) const {
const std::vector<Variable*>& vars = OutputVars(name);
paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
}
DDim RuntimeInferShapeContext::GetInputDim(const std::string& name) const {
const std::vector<Variable*>& vars = InputVars(name);
PADDLE_ENFORCE_EQ(
vars.size(),
1UL,
platform::errors::InvalidArgument(
"Input(%s) should hold one element, but now it holds %zu elements.",
name,
vars.size()));
return this->GetDim(vars[0]);
}
std::vector<DDim> RuntimeInferShapeContext::GetInputsDim(
const std::string& name) const {
const std::vector<Variable*>& vars = InputVars(name);
return GetDims(vars);
}
proto::VarType::Type RuntimeInferShapeContext::GetInputVarType(
const std::string& name) const {
return GetVarType(InputVars(name).at(0));
}
std::vector<proto::VarType::Type> RuntimeInferShapeContext::GetInputsVarType(
const std::string& name) const {
return GetVarTypes(InputVars(name));
}
std::vector<proto::VarType::Type> RuntimeInferShapeContext::GetOutputsVarType(
const std::string& name) const {
return GetVarTypes(OutputVars(name));
}
void RuntimeInferShapeContext::SetOutputDim(const std::string& name,
const DDim& dim) {
auto& vars = OutputVars(name);
PADDLE_ENFORCE_EQ(
vars.size(),
1UL,
platform::errors::InvalidArgument("Output(%s) should hold one element, "
"but now it holds %zu elements.",
name,
vars.size()));
SetDim(vars[0], dim);
}
void RuntimeInferShapeContext::SetOutputsDim(const std::string& name,
const std::vector<DDim>& dims) {
auto& vars = OutputVars(name);
SetDims(vars, dims);
}
const phi::ArgumentMappingFn*
RuntimeInferShapeContext::GetPhiArgumentMappingFn() const {
return phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_.Type());
}
const phi::KernelSignature*
RuntimeInferShapeContext::GetPhiDefaultKernelSignature() const {
return &phi::DefaultKernelSignatureMap::Instance().Get(op_.Type());
}
void RuntimeInferShapeContext::SetSkipLoD(bool skip) { can_skip_lod_ = skip; }
DDim RuntimeInferShapeContext::GetDim(Variable* var) const {
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::InvalidArgument("Input variable is nullptr."));
if (var->IsType<phi::DenseTensor>()) {
return var->Get<phi::DenseTensor>().dims();
} else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().GetCompleteDims();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only phi::DenseTensor or SelectedRows support 'GetDim', but input "
"Variable's type is %s.",
ToTypeName(var->Type())));
}
}
std::vector<DDim> RuntimeInferShapeContext::GetDims(
const std::vector<Variable*>& vars) const {
std::vector<DDim> ret;
ret.reserve(vars.size());
std::transform(
vars.begin(), vars.end(), std::back_inserter(ret), [this](Variable* var) {
return this->GetDim(var);
});
return ret;
}
std::vector<DDim> RuntimeInferShapeContext::GetRepeatedDims(
const std::string& name) const {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"GetRepeatedDims method only ban be used in compile time."));
}
void RuntimeInferShapeContext::SetDim(Variable* var, const DDim& dim) {
if (var->IsType<phi::DenseTensor>()) {
var->GetMutable<phi::DenseTensor>()->Resize(dim);
} else if (var->IsType<phi::SelectedRows>()) {
var->GetMutable<phi::SelectedRows>()->set_height(dim[0]);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Variable type error, expect phi::DenseTensor or SelectedRows, but "
"received "
"(%s).",
ToTypeName(var->Type())));
}
}
void RuntimeInferShapeContext::SetDims(const std::vector<Variable*>& vars,
const std::vector<DDim>& dims) {
size_t length = vars.size();
PADDLE_ENFORCE_EQ(length,
dims.size(),
platform::errors::InvalidArgument(
"The number of input variables do not match the "
"number of input dimensions, the number of variables "
"is %zu, the number of dimensions is %zu.",
length,
dims.size()));
for (size_t i = 0; i < length; ++i) {
if (vars[i] == nullptr) {
continue;
}
SetDim(vars[i], dims[i]);
}
}
void RuntimeInferShapeContext::SetRepeatedDims(const std::string& name,
const std::vector<DDim>& dims) {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"SetRepeatedDims method only can be used in compile time."));
}
std::vector<proto::VarType::Type> RuntimeInferShapeContext::GetVarTypes(
const std::vector<Variable*>& vars) const {
std::vector<proto::VarType::Type> retv;
retv.resize(vars.size());
std::transform(vars.begin(),
vars.end(),
retv.begin(),
std::bind(std::mem_fn(&RuntimeInferShapeContext::GetVarType),
this,
std::placeholders::_1));
return retv;
}
proto::VarType::Type RuntimeInferShapeContext::GetVarType(Variable* var) const {
return ToVarType(var->Type());
}
const std::vector<Variable*>& RuntimeInferShapeContext::InputVars(
const std::string& name) const {
auto it = ctx_.inputs.find(name);
PADDLE_ENFORCE_NE(
it,
ctx_.inputs.end(),
platform::errors::NotFound(
"Operator (%s) does not have the input (%s).", op_.Type(), name));
return it->second;
}
const std::vector<Variable*>& RuntimeInferShapeContext::OutputVars(
const std::string& name) const {
auto it = ctx_.outputs.find(name);
PADDLE_ENFORCE_NE(
it,
ctx_.outputs.end(),
platform::errors::NotFound(
"Operator (%s) does not have the outputs (%s).", op_.Type(), name));
return it->second;
}
void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
  try {
    VLOG(4) << place << " " << DebugStringEx(&scope);
@@ -710,510 +1215,6 @@ bool OpSupportGPU(const std::string& op_type) {
  return false;
}
class RuntimeInferShapeContext : public InferShapeContext {
public:
RuntimeInferShapeContext(const OperatorBase& op, const RuntimeContext& ctx)
: op_(op), ctx_(ctx) {}
bool HasInput(const std::string& name) const override {
// has only one input
const auto& ins = ctx_.inputs;
auto it = ins.find(name);
if (it == ins.end()) {
return false;
}
const auto& in = it->second;
if (in.size() == 0) return false;
PADDLE_ENFORCE_EQ(
in.size(),
1UL,
platform::errors::InvalidArgument(
"Input %s should not contain more than one inputs.", name));
return in[0] != nullptr;
}
bool HasOutput(const std::string& name) const override {
// has only one output
const auto& outs = ctx_.outputs;
auto it = outs.find(name);
if (it == outs.end()) {
return false;
}
const auto& out = it->second;
if (out.size() == 0) {
return false;
}
PADDLE_ENFORCE_EQ(
out.size(),
1UL,
platform::errors::InvalidArgument(
"Output %s should not contain more than one outputs.", name));
return out[0] != nullptr;
}
bool HasAttr(const std::string& name) const override {
return op_.HasAttr(name);
}
bool HasInputs(const std::string& name) const override {
const auto& ins = ctx_.inputs;
auto it = ins.find(name);
if (it == ins.end() || it->second.empty()) {
return false;
}
for (auto& input : it->second) {
if (input == nullptr) {
return false;
}
}
return true;
}
bool HasOutputs(const std::string& name,
bool allow_null = false) const override {
const auto& outs = ctx_.outputs;
auto it = outs.find(name);
if (it == outs.end() || it->second.empty()) {
return false;
}
if (!allow_null) {
for (auto& output : it->second) {
if (output == nullptr) return false;
}
}
return true;
}
AttrReader Attrs() const override {
return AttrReader(op_.Attrs(), op_.RuntimeAttrs());
}
std::vector<std::string> Inputs(const std::string& name) const override {
return op_.Inputs(name);
}
std::vector<std::string> Outputs(const std::string& name) const override {
return op_.Outputs(name);
}
std::string GetInputNameByIdx(size_t idx) const override {
auto& op_proto =
paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
PADDLE_ENFORCE_LT(idx,
op_proto->inputs().size(),
platform::errors::OutOfRange(
"The index should be less than the size of inputs of "
"operator %s, but got index is %d and size is %d",
op_.Type(),
idx,
op_proto->inputs().size()));
return op_proto->inputs()[idx].name();
}
std::string GetOutputNameByIdx(size_t idx) const override {
auto& op_proto =
paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
PADDLE_ENFORCE_LT(
idx,
op_proto->outputs().size(),
platform::errors::OutOfRange(
"The index should be less than the size of outputs of "
"operator %s, but got index is %d and size is %d",
op_.Type(),
idx,
op_proto->outputs().size()));
return op_proto->outputs()[idx].name();
}
void ShareDim(const std::string& in,
const std::string& out,
size_t i = 0,
size_t j = 0) override {
auto in_it = ctx_.inputs.find(in);
auto out_it = ctx_.outputs.find(out);
PADDLE_ENFORCE_NE(
in_it,
ctx_.inputs.end(),
platform::errors::NotFound("Input %s does not exist.", in));
PADDLE_ENFORCE_NE(
out_it,
ctx_.outputs.end(),
platform::errors::NotFound("Output %s does not exist.", out));
PADDLE_ENFORCE_LT(i,
in_it->second.size(),
platform::errors::InvalidArgument(
"The index of input dimension is out of range, "
"excepted index less than %zu, but received %zu.",
in_it->second.size(),
i));
PADDLE_ENFORCE_LT(j,
out_it->second.size(),
platform::errors::InvalidArgument(
"The index of output dimension is out of range, "
"excepted index less than %zu, but received %zu.",
out_it->second.size(),
j));
Variable* in_var = in_it->second[i];
Variable* out_var = out_it->second[j];
PADDLE_ENFORCE_EQ(
in_var->Type(),
out_var->Type(),
platform::errors::InvalidArgument(
"The type of input (%s) and output (%s) are inconsistent.",
in,
out));
if (in_var->IsType<phi::SelectedRows>()) {
auto& in_sele_rows = in_var->Get<phi::SelectedRows>();
auto out_sele_rows = out_var->GetMutable<phi::SelectedRows>();
out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims());
out_sele_rows->set_rows(in_sele_rows.rows());
out_sele_rows->set_height(in_sele_rows.height());
} else if (in_var->IsType<phi::DenseTensor>()) {
auto& in_lod_tensor = in_var->Get<phi::DenseTensor>();
auto* out_lod_tensor = out_var->GetMutable<phi::DenseTensor>();
out_lod_tensor->Resize(in_lod_tensor.dims());
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, the input type of ShareDim only can be phi::DenseTensor "
"or SelectedRows."));
}
}
void ShareAllLoD(const std::string& in,
const std::string& out) const override {
auto in_it = ctx_.inputs.find(in);
auto out_it = ctx_.outputs.find(out);
PADDLE_ENFORCE_NE(in_it,
ctx_.inputs.end(),
platform::errors::NotFound(
"Input [%s] found error in Op [%s]", in, op_.Type()));
PADDLE_ENFORCE_NE(
out_it,
ctx_.outputs.end(),
platform::errors::NotFound(
"Output [%s] found error in Op [%s]", out, op_.Type()));
auto& in_var_list = in_it->second;
auto& out_var_list = out_it->second;
PADDLE_ENFORCE_EQ(
in_var_list.size(),
out_var_list.size(),
platform::errors::PreconditionNotMet(
"Op [%s]: Input var size should be equal with output var size",
op_.Type()));
auto& out_var_names = op_.Outputs(out);
for (size_t i = 0; i < in_var_list.size(); ++i) {
if (out_var_names[i] == framework::kEmptyVarName) {
continue;
}
Variable* in_var = in_var_list[i];
if (!in_var->IsType<phi::DenseTensor>()) return;
Variable* out_var = out_var_list[i];
PADDLE_ENFORCE_EQ(
out_var->IsType<phi::DenseTensor>(),
true,
platform::errors::PreconditionNotMet(
"The %d-th output of Output(%s) must be phi::DenseTensor.",
i,
out_var_names[i]));
auto& in_tensor = in_var->Get<phi::DenseTensor>();
auto* out_tensor = out_var->GetMutable<phi::DenseTensor>();
out_tensor->set_lod(in_tensor.lod());
#ifdef PADDLE_WITH_MKLDNN
if (in_tensor.layout() != DataLayout::ONEDNN)
#endif
out_tensor->set_layout(in_tensor.layout());
}
}
void ShareLoD(const std::string& in,
const std::string& out,
size_t i = 0,
size_t j = 0) const override {
auto in_it = ctx_.inputs.find(in);
auto out_it = ctx_.outputs.find(out);
PADDLE_ENFORCE_NE(
in_it,
ctx_.inputs.end(),
platform::errors::NotFound("Input %s does not exist.", in));
PADDLE_ENFORCE_NE(
out_it,
ctx_.outputs.end(),
platform::errors::NotFound("Output %s does not exist.", out));
PADDLE_ENFORCE_LT(i,
in_it->second.size(),
platform::errors::InvalidArgument(
"The index of input dimension is out of range, "
"excepted index less than %zu, but received %zu.",
in_it->second.size(),
i));
PADDLE_ENFORCE_LT(j,
out_it->second.size(),
platform::errors::InvalidArgument(
"The index of output dimension is out of range, "
"excepted index less than %zu, but received %zu.",
out_it->second.size(),
j));
Variable* in_var = in_it->second.at(i);
if (!in_var->IsType<phi::DenseTensor>()) return;
Variable* out_var = out_it->second.at(j);
PADDLE_ENFORCE_EQ(
out_var->IsType<phi::DenseTensor>(),
true,
platform::errors::InvalidArgument(
"The %zu-th output of Output(%s) must be phi::DenseTensor.",
j,
out));
auto& in_tensor = in_var->Get<phi::DenseTensor>();
auto* out_tensor = out_var->GetMutable<phi::DenseTensor>();
out_tensor->set_lod(in_tensor.lod());
// TODO(dzhwinter) : reuse ShareLoD in most operators.
// Need to call ShareLayout explicitly in sequence related ops.
// Shall we have a better method to share info between in/out phi::DenseTensor?
#ifdef PADDLE_WITH_MKLDNN
// Fix me: ugly workaround below
// Correct solution:
// set_layout() should NOT be called here (i.e. ShareLoD). Instead,
// layout of output tensor should be set "manually" in Compute()
// of each OPKernel. The reason layout should NOT be shared between
// input and output "automatically" (now by InferShape()->ShareLoD())
// is that layout transform may occur after InferShape().
// Workaround:
// Skip set_layout() when input layout is kMKLDNN
// This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN
// OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called
// in Compute()
if (in_tensor.layout() != DataLayout::ONEDNN)
#endif
out_tensor->set_layout(in_tensor.layout());
}
int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"GetLoDLevel is only used in compile time. The calculation of "
"output's actual lod is different among operators so that should be "
"set in the runtime kernel."));
}
void SetLoDLevel(const std::string& out,
int32_t lod_level,
size_t j = 0) const override {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"SetLoDLevel is only used in compile time. The calculation of "
"output's actual lod is different among operators so that should be "
"set in the runtime kernel."));
}
bool IsRuntime() const override { return true; }
bool IsRunMKLDNNKernel() const override {
try {
auto& op_with_kernel = dynamic_cast<const OperatorWithKernel&>(op_);
return ((op_with_kernel.kernel_type()) &&
(op_with_kernel.kernel_type()->data_layout_ ==
phi::DataLayout::ONEDNN));
} catch (const std::bad_cast& exp) {
return false;
}
}
// TODO(paddle-dev): Can this be a template?
paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize>
GetInputVarPtrs(const std::string& name) const override {
const std::vector<Variable*>& vars = InputVars(name);
paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
}
paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
GetOutputVarPtrs(const std::string& name) const override {
const std::vector<Variable*>& vars = OutputVars(name);
paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
}
DDim GetInputDim(const std::string& name) const override {
const std::vector<Variable*>& vars = InputVars(name);
PADDLE_ENFORCE_EQ(
vars.size(),
1UL,
platform::errors::InvalidArgument(
"Input(%s) should hold one element, but now it holds %zu elements.",
name,
vars.size()));
return this->GetDim(vars[0]);
}
std::vector<DDim> GetInputsDim(const std::string& name) const override {
const std::vector<Variable*>& vars = InputVars(name);
return GetDims(vars);
}
proto::VarType::Type GetInputVarType(const std::string& name) const override {
return GetVarType(InputVars(name).at(0));
}
std::vector<proto::VarType::Type> GetInputsVarType(
const std::string& name) const override {
return GetVarTypes(InputVars(name));
}
std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string& name) const override {
return GetVarTypes(OutputVars(name));
}
void SetOutputDim(const std::string& name, const DDim& dim) override {
auto& vars = OutputVars(name);
PADDLE_ENFORCE_EQ(
vars.size(),
1UL,
platform::errors::InvalidArgument("Output(%s) should hold one element, "
"but now it holds %zu elements.",
name,
vars.size()));
SetDim(vars[0], dim);
}
void SetOutputsDim(const std::string& name,
const std::vector<DDim>& dims) override {
auto& vars = OutputVars(name);
SetDims(vars, dims);
}
const phi::ArgumentMappingFn* GetPhiArgumentMappingFn() const override {
return phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_.Type());
}
const phi::KernelSignature* GetPhiDefaultKernelSignature() const override {
return &phi::DefaultKernelSignatureMap::Instance().Get(op_.Type());
}
protected:
DDim GetDim(Variable* var) const {
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::InvalidArgument("Input variable is nullptr."));
if (var->IsType<phi::DenseTensor>()) {
return var->Get<phi::DenseTensor>().dims();
} else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().GetCompleteDims();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only phi::DenseTensor or SelectedRows support 'GetDim', but input "
"Variable's type is %s.",
ToTypeName(var->Type())));
}
}
std::vector<DDim> GetDims(const std::vector<Variable*>& vars) const {
std::vector<DDim> ret;
ret.reserve(vars.size());
std::transform(vars.begin(),
vars.end(),
std::back_inserter(ret),
[this](Variable* var) { return this->GetDim(var); });
return ret;
}
std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"GetRepeatedDims method only ban be used in compile time."));
}
void SetDim(Variable* var, const DDim& dim) {
if (var->IsType<phi::DenseTensor>()) {
var->GetMutable<phi::DenseTensor>()->Resize(dim);
} else if (var->IsType<phi::SelectedRows>()) {
var->GetMutable<phi::SelectedRows>()->set_height(dim[0]);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Variable type error, expect phi::DenseTensor or SelectedRows, but "
"received "
"(%s).",
ToTypeName(var->Type())));
}
}
void SetDims(const std::vector<Variable*>& vars,
const std::vector<DDim>& dims) {
size_t length = vars.size();
PADDLE_ENFORCE_EQ(length,
dims.size(),
platform::errors::InvalidArgument(
"The number of input variables do not match the "
"number of input dimensions, the number of variables "
"is %zu, the number of dimensions is %zu.",
length,
dims.size()));
for (size_t i = 0; i < length; ++i) {
if (vars[i] == nullptr) {
continue;
}
SetDim(vars[i], dims[i]);
}
}
void SetRepeatedDims(const std::string& name,
const std::vector<DDim>& dims) override {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"SetRepeatedDims method only can be used in compile time."));
}
std::vector<proto::VarType::Type> GetVarTypes(
const std::vector<Variable*>& vars) const {
std::vector<proto::VarType::Type> retv;
retv.resize(vars.size());
std::transform(vars.begin(),
vars.end(),
retv.begin(),
std::bind(std::mem_fn(&RuntimeInferShapeContext::GetVarType),
this,
std::placeholders::_1));
return retv;
}
proto::VarType::Type GetVarType(Variable* var) const {
return ToVarType(var->Type());
}
private:
const std::vector<Variable*>& InputVars(const std::string& name) const {
auto it = ctx_.inputs.find(name);
PADDLE_ENFORCE_NE(
it,
ctx_.inputs.end(),
platform::errors::NotFound(
"Operator (%s) does not have the input (%s).", op_.Type(), name));
return it->second;
}
const std::vector<Variable*>& OutputVars(const std::string& name) const {
auto it = ctx_.outputs.find(name);
PADDLE_ENFORCE_NE(
it,
ctx_.outputs.end(),
platform::errors::NotFound(
"Operator (%s) does not have the outputs (%s).", op_.Type(), name));
return it->second;
}
const OperatorBase& op_;
const RuntimeContext& ctx_;
};
struct OperatorWithKernel::CacheImpl {
  static const char kNotAllowInferShapeCahce[];
  explicit CacheImpl(phi::KernelContext* kernel_ctx,
...
@@ -34,6 +34,7 @@ limitations under the License. */
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
+#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/unused_var_check.h"
#include "paddle/fluid/memory/malloc.h"
@@ -47,7 +48,6 @@ limitations under the License. */
namespace paddle {
namespace framework {
-class InferShapeContext;
class OpInfo;
class Scope;
class Variable;
@@ -146,6 +146,114 @@ class RuntimeContext {
  VariableValueMap outputs;
};
class RuntimeInferShapeContext : public InferShapeContext {
public:
RuntimeInferShapeContext(const OperatorBase& op, const RuntimeContext& ctx);
bool HasInput(const std::string& name) const override;
bool HasOutput(const std::string& name) const override;
bool HasAttr(const std::string& name) const override;
bool HasInputs(const std::string& name) const override;
bool HasOutputs(const std::string& name,
bool allow_null = false) const override;
AttrReader Attrs() const override;
std::vector<std::string> Inputs(const std::string& name) const override;
std::vector<std::string> Outputs(const std::string& name) const override;
std::string GetInputNameByIdx(size_t idx) const override;
std::string GetOutputNameByIdx(size_t idx) const override;
void ShareDim(const std::string& in,
const std::string& out,
size_t i = 0,
size_t j = 0) override;
void ShareAllLoD(const std::string& in,
const std::string& out) const override;
void ShareLoD(const std::string& in,
const std::string& out,
size_t i = 0,
size_t j = 0) const override;
int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override;
void SetLoDLevel(const std::string& out,
int32_t lod_level,
size_t j = 0) const override;
bool IsRuntime() const override;
bool IsRunMKLDNNKernel() const override;
// TODO(paddle-dev): Can this be a template?
paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize>
GetInputVarPtrs(const std::string& name) const override;
paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
GetOutputVarPtrs(const std::string& name) const override;
DDim GetInputDim(const std::string& name) const override;
std::vector<DDim> GetInputsDim(const std::string& name) const override;
proto::VarType::Type GetInputVarType(const std::string& name) const override;
std::vector<proto::VarType::Type> GetInputsVarType(
const std::string& name) const override;
std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string& name) const override;
void SetOutputDim(const std::string& name, const DDim& dim) override;
void SetOutputsDim(const std::string& name,
const std::vector<DDim>& dims) override;
const phi::ArgumentMappingFn* GetPhiArgumentMappingFn() const override;
const phi::KernelSignature* GetPhiDefaultKernelSignature() const override;
void SetSkipLoD(bool skip);
protected:
DDim GetDim(Variable* var) const;
std::vector<DDim> GetDims(const std::vector<Variable*>& vars) const;
std::vector<DDim> GetRepeatedDims(const std::string& name) const override;
void SetDim(Variable* var, const DDim& dim);
void SetDims(const std::vector<Variable*>& vars,
const std::vector<DDim>& dims);
void SetRepeatedDims(const std::string& name,
const std::vector<DDim>& dims) override;
std::vector<proto::VarType::Type> GetVarTypes(
const std::vector<Variable*>& vars) const;
proto::VarType::Type GetVarType(Variable* var) const;
private:
const std::vector<Variable*>& InputVars(const std::string& name) const;
const std::vector<Variable*>& OutputVars(const std::string& name) const;
const OperatorBase& op_;
const RuntimeContext& ctx_;
bool can_skip_lod_{false};
};
/**
 * OperatorBase has the basic elements that Net will call to do computation.
 * Only CreateOperator from OpRegistry will new Operator directly. User
...
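The "Fix lod errors" half of the commit is carried by the can_skip_lod_ member: the old RuntimeInferShapeContext in operator.cc had no such flag, so merging the classes could have dropped the interpreter core's ability to suppress LoD propagation. The merged class keeps SetSkipLoD(), defaults can_skip_lod_ to false, and makes ShareLoD() return immediately when the flag is set. A hedged usage sketch (the operator and runtime-context construction are assumed, not shown in the patch):

// Sketch: suppressing LoD propagation during runtime shape inference.
// After SetSkipLoD(true), ShareLoD() is a no-op, while dims and types
// are still inferred and set as usual.
paddle::framework::RuntimeInferShapeContext infer_shape_ctx(op, runtime_ctx);
infer_shape_ctx.SetSkipLoD(true);  // default is false (see operator.h above)
op.Info().infer_shape_(&infer_shape_ctx);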