Unverified commit 5c1eda19, authored by Ruibiao Chen, committed by GitHub

Remove InterpretercoreInferShapeContext (#51209)

* Remove InterpretercoreInferShapeContext

* Fix lod errors
Parent f1f2a253
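This commit replaces the interpreter core's private InterpretercoreInferShapeContext with the shared framework::RuntimeInferShapeContext everywhere it was used. A minimal sketch of the resulting call-site pattern, using only names from the first hunk below (op, scope_, var_name, and new_var_name are assumed to be prepared by the surrounding executor code):

RuntimeContext runtime_context({}, {});
runtime_context.inputs["X"] = {scope_->FindVar(var_name)};
runtime_context.outputs["Out"] = {scope_->Var(new_var_name)};
RuntimeInferShapeContext infer_shape_ctx(*op, runtime_context);
op.get()->Info().infer_shape_(&infer_shape_ctx);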
......@@ -129,7 +129,7 @@ void DataTranferHelper::RunAndConstructOpFuncNode(
RuntimeContext runtime_context({}, {});
runtime_context.inputs["X"] = {scope_->FindVar(var_name)};
runtime_context.outputs["Out"] = {scope_->Var(new_var_name)};
InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context);
RuntimeInferShapeContext infer_shape_ctx(*op, runtime_context);
op.get()->Info().infer_shape_(&infer_shape_ctx);
// 2. choose kernel
......
......@@ -824,8 +824,7 @@ void BuildOpFuncList(const platform::Place& place,
if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) &&
op->Attr<bool>(kAllKernelsMustComputeRuntimeShape))) {
VLOG(4) << "infer shape";
InterpretercoreInferShapeContext infer_shape_ctx(*op,
runtime_context);
RuntimeInferShapeContext infer_shape_ctx(*op, runtime_context);
// TODO(Aurelius84): In case of control flow ops, they are NOT
// inherited from OperatorWithKernel.
op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
......
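Ops that declare kAllKernelsMustComputeRuntimeShape compute their output shapes inside the kernel itself, so BuildOpFuncList skips the static InferShape pass for them. A condensed sketch of the guard above, with all names taken from the hunk:

if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) &&
      op->Attr<bool>(kAllKernelsMustComputeRuntimeShape))) {
  // Only ops that rely on static shape inference get an infer-shape context.
  RuntimeInferShapeContext infer_shape_ctx(*op, runtime_context);
  op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
}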
......@@ -24,525 +24,6 @@
namespace paddle {
namespace framework {
InterpretercoreInferShapeContext::InterpretercoreInferShapeContext(
const OperatorBase& op, const RuntimeContext& ctx)
: op_(op), ctx_(ctx), can_skip_lod_(false) {}
bool InterpretercoreInferShapeContext::HasInput(const std::string& name) const {
// has only one input
const auto& ins = ctx_.inputs;
auto it = ins.find(name);
if (it == ins.end()) {
return false;
}
const auto& in = it->second;
if (in.size() == 0) return false;
PADDLE_ENFORCE_EQ(
in.size(),
1UL,
platform::errors::InvalidArgument(
"Input %s should not contain more than one inputs.", name));
return in[0] != nullptr;
}
bool InterpretercoreInferShapeContext::HasOutput(
const std::string& name) const {
// has only one output
const auto& outs = ctx_.outputs;
auto it = outs.find(name);
if (it == outs.end()) {
return false;
}
const auto& out = it->second;
if (out.size() == 0) {
return false;
}
PADDLE_ENFORCE_EQ(
out.size(),
1UL,
platform::errors::InvalidArgument(
"Output %s should not contain more than one outputs.", name));
return out[0] != nullptr;
}
bool InterpretercoreInferShapeContext::HasAttr(const std::string& name) const {
return op_.HasAttr(name);
}
bool InterpretercoreInferShapeContext::HasInputs(
const std::string& name) const {
const auto& ins = ctx_.inputs;
auto it = ins.find(name);
if (it == ins.end() || it->second.empty()) {
return false;
}
for (auto& input : it->second) {
if (input == nullptr) {
return false;
}
}
return true;
}
bool InterpretercoreInferShapeContext::HasOutputs(const std::string& name,
bool allow_null) const {
const auto& outs = ctx_.outputs;
auto it = outs.find(name);
if (it == outs.end() || it->second.empty()) {
return false;
}
if (!allow_null) {
for (auto& output : it->second) {
if (output == nullptr) return false;
}
}
return true;
}
AttrReader InterpretercoreInferShapeContext::Attrs() const {
return AttrReader(op_.Attrs(), op_.RuntimeAttrs());
}
std::vector<std::string> InterpretercoreInferShapeContext::Inputs(
const std::string& name) const {
return op_.Inputs(name);
}
std::vector<std::string> InterpretercoreInferShapeContext::Outputs(
const std::string& name) const {
return op_.Outputs(name);
}
std::string InterpretercoreInferShapeContext::GetInputNameByIdx(
size_t idx) const {
auto& op_proto =
paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
PADDLE_ENFORCE_LT(idx,
op_proto->inputs().size(),
platform::errors::OutOfRange(
"The index should be less than the size of inputs of "
"operator %s, but got index is %d and size is %d",
op_.Type(),
idx,
op_proto->inputs().size()));
return op_proto->inputs()[idx].name();
}
std::string InterpretercoreInferShapeContext::GetOutputNameByIdx(
size_t idx) const {
auto& op_proto =
paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
PADDLE_ENFORCE_LT(idx,
op_proto->outputs().size(),
platform::errors::OutOfRange(
"The index should be less than the size of outputs of "
"operator %s, but got index is %d and size is %d",
op_.Type(),
idx,
op_proto->outputs().size()));
return op_proto->outputs()[idx].name();
}
void InterpretercoreInferShapeContext::ShareDim(const std::string& in,
const std::string& out,
size_t i,
size_t j) {
auto in_it = ctx_.inputs.find(in);
auto out_it = ctx_.outputs.find(out);
PADDLE_ENFORCE_NE(in_it,
ctx_.inputs.end(),
platform::errors::NotFound("Input %s does not exist.", in));
PADDLE_ENFORCE_NE(
out_it,
ctx_.outputs.end(),
platform::errors::NotFound("Output %s does not exist.", out));
PADDLE_ENFORCE_LT(i,
in_it->second.size(),
platform::errors::InvalidArgument(
"The index of input dimension is out of range, "
"excepted index less than %zu, but received %zu.",
in_it->second.size(),
i));
PADDLE_ENFORCE_LT(j,
out_it->second.size(),
platform::errors::InvalidArgument(
"The index of output dimension is out of range, "
"excepted index less than %zu, but received %zu.",
out_it->second.size(),
j));
Variable* in_var = in_it->second[i];
Variable* out_var = out_it->second[j];
PADDLE_ENFORCE_EQ(
in_var->Type(),
out_var->Type(),
platform::errors::InvalidArgument(
"The type of input (%s) and output (%s) are inconsistent.", in, out));
if (in_var->IsType<phi::SelectedRows>()) {
auto& in_sele_rows = in_var->Get<phi::SelectedRows>();
auto out_sele_rows = out_var->GetMutable<phi::SelectedRows>();
out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims());
out_sele_rows->set_rows(in_sele_rows.rows());
out_sele_rows->set_height(in_sele_rows.height());
} else if (in_var->IsType<phi::DenseTensor>()) {
auto& in_lod_tensor = in_var->Get<phi::DenseTensor>();
auto* out_lod_tensor = out_var->GetMutable<phi::DenseTensor>();
out_lod_tensor->Resize(in_lod_tensor.dims());
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, the input type of ShareDim only can be phi::DenseTensor "
"or SelectedRows."));
}
}
void InterpretercoreInferShapeContext::ShareAllLoD(
const std::string& in, const std::string& out) const {
auto in_it = ctx_.inputs.find(in);
auto out_it = ctx_.outputs.find(out);
PADDLE_ENFORCE_NE(in_it,
ctx_.inputs.end(),
platform::errors::NotFound(
"Input [%s] found error in Op [%s]", in, op_.Type()));
PADDLE_ENFORCE_NE(out_it,
ctx_.outputs.end(),
platform::errors::NotFound(
"Output [%s] found error in Op [%s]", out, op_.Type()));
auto& in_var_list = in_it->second;
auto& out_var_list = out_it->second;
PADDLE_ENFORCE_EQ(
in_var_list.size(),
out_var_list.size(),
platform::errors::PreconditionNotMet(
"Op [%s]: Input var size should be equal with output var size",
op_.Type()));
auto& out_var_names = op_.Outputs(out);
for (size_t i = 0; i < in_var_list.size(); ++i) {
if (out_var_names[i] == framework::kEmptyVarName) {
continue;
}
Variable* in_var = in_var_list[i];
if (!in_var->IsType<phi::DenseTensor>()) return;
Variable* out_var = out_var_list[i];
PADDLE_ENFORCE_EQ(
out_var->IsType<phi::DenseTensor>(),
true,
platform::errors::PreconditionNotMet(
"The %d-th output of Output(%s) must be phi::DenseTensor.",
i,
out_var_names[i]));
auto& in_tensor = in_var->Get<phi::DenseTensor>();
auto* out_tensor = out_var->GetMutable<phi::DenseTensor>();
out_tensor->set_lod(in_tensor.lod());
#ifdef PADDLE_WITH_MKLDNN
if (in_tensor.layout() != DataLayout::ONEDNN)
#endif
out_tensor->set_layout(in_tensor.layout());
}
}
void InterpretercoreInferShapeContext::ShareLoD(const std::string& in,
const std::string& out,
size_t i,
size_t j) const {
if (can_skip_lod_) {
return;
}
auto in_it = ctx_.inputs.find(in);
auto out_it = ctx_.outputs.find(out);
PADDLE_ENFORCE_NE(in_it,
ctx_.inputs.end(),
platform::errors::NotFound("Input %s does not exist.", in));
PADDLE_ENFORCE_NE(
out_it,
ctx_.outputs.end(),
platform::errors::NotFound("Output %s does not exist.", out));
PADDLE_ENFORCE_LT(i,
in_it->second.size(),
platform::errors::InvalidArgument(
"The index of input dimension is out of range, "
"excepted index less than %zu, but received %zu.",
in_it->second.size(),
i));
PADDLE_ENFORCE_LT(j,
out_it->second.size(),
platform::errors::InvalidArgument(
"The index of output dimension is out of range, "
"excepted index less than %zu, but received %zu.",
out_it->second.size(),
j));
Variable* in_var = in_it->second.at(i);
if (!in_var->IsType<phi::DenseTensor>()) return;
Variable* out_var = out_it->second.at(j);
PADDLE_ENFORCE_EQ(
out_var->IsType<phi::DenseTensor>(),
true,
platform::errors::InvalidArgument(
"The %zu-th output of Output(%s) must be phi::DenseTensor.", j, out));
auto& in_tensor = in_var->Get<phi::DenseTensor>();
auto* out_tensor = out_var->GetMutable<phi::DenseTensor>();
out_tensor->set_lod(in_tensor.lod());
// TODO(dzhwinter) : reuse ShareLoD in most operators.
// Need to call ShareLayout explicitly in sequence related ops.
// Shall we have a better method to share info between in/out phi::DenseTensor?
#ifdef PADDLE_WITH_MKLDNN
// Fix me: ugly workaround below
// Correct solution:
// set_layout() should NOT be called here (i.e. ShareLoD). Instead,
// layout of output tensor should be set "manually" in Compute()
// of each OPKernel. The reason layout should NOT be shared between
// input and output "automatically" (now by InferShape()->ShareLoD())
// is that layout transform may occur after InferShape().
// Workaround:
// Skip set_layout() when input layout is kMKLDNN
// This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN
// OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called
// in Compute()
if (in_tensor.layout() != DataLayout::ONEDNN)
#endif
out_tensor->set_layout(in_tensor.layout());
}
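// Illustration of the workaround described above (hypothetical kernel, not
// part of this diff): a ONEDNN kernel sets its output layout explicitly in
// Compute() instead of relying on InferShape()->ShareLoD() to propagate it.
//
//   void SomeOneDNNKernel::Compute(const ExecutionContext& ctx) {
//     auto* out = ctx.Output<phi::DenseTensor>("Out");
//     out->set_layout(phi::DataLayout::ONEDNN);  // set manually in Compute()
//   }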
int32_t InterpretercoreInferShapeContext::GetLoDLevel(const std::string& in,
size_t i) const {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"GetLoDLevel is only used in compile time. The calculation of "
"output's actual lod is different among operators so that should be "
"set in the runtime kernel."));
}
void InterpretercoreInferShapeContext::SetLoDLevel(const std::string& out,
int32_t lod_level,
size_t j) const {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"SetLoDLevel is only used in compile time. The calculation of "
"output's actual lod is different among operators so that should be "
"set in the runtime kernel."));
}
bool InterpretercoreInferShapeContext::IsRuntime() const { return true; }
bool InterpretercoreInferShapeContext::IsRunMKLDNNKernel() const {
try {
auto& op_with_kernel = dynamic_cast<const OperatorWithKernel&>(op_);
return ((op_with_kernel.kernel_type()) &&
(op_with_kernel.kernel_type()->data_layout_ ==
phi::DataLayout::ONEDNN));
} catch (std::bad_cast& exp) {
return false;
}
}
// TODO(paddle-dev): Can this be a template?
paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize>
InterpretercoreInferShapeContext::GetInputVarPtrs(
const std::string& name) const {
const std::vector<Variable*>& vars = InputVars(name);
paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
}
paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
InterpretercoreInferShapeContext::GetOutputVarPtrs(
const std::string& name) const {
const std::vector<Variable*>& vars = OutputVars(name);
paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
}
DDim InterpretercoreInferShapeContext::GetInputDim(
const std::string& name) const {
const std::vector<Variable*>& vars = InputVars(name);
PADDLE_ENFORCE_EQ(
vars.size(),
1UL,
platform::errors::InvalidArgument(
"Input(%s) should hold one element, but now it holds %zu elements.",
name,
vars.size()));
return this->GetDim(vars[0]);
}
std::vector<DDim> InterpretercoreInferShapeContext::GetInputsDim(
const std::string& name) const {
const std::vector<Variable*>& vars = InputVars(name);
return GetDims(vars);
}
proto::VarType::Type InterpretercoreInferShapeContext::GetInputVarType(
const std::string& name) const {
return GetVarType(InputVars(name).at(0));
}
std::vector<proto::VarType::Type>
InterpretercoreInferShapeContext::GetInputsVarType(
const std::string& name) const {
return GetVarTypes(InputVars(name));
}
std::vector<proto::VarType::Type>
InterpretercoreInferShapeContext::GetOutputsVarType(
const std::string& name) const {
return GetVarTypes(OutputVars(name));
}
void InterpretercoreInferShapeContext::SetOutputDim(const std::string& name,
const DDim& dim) {
auto& vars = OutputVars(name);
PADDLE_ENFORCE_EQ(
vars.size(),
1UL,
platform::errors::InvalidArgument("Output(%s) should hold one element, "
"but now it holds %zu elements.",
name,
vars.size()));
SetDim(vars[0], dim);
}
void InterpretercoreInferShapeContext::SetOutputsDim(
const std::string& name, const std::vector<DDim>& dims) {
auto& vars = OutputVars(name);
SetDims(vars, dims);
}
const phi::ArgumentMappingFn*
InterpretercoreInferShapeContext::GetPhiArgumentMappingFn() const {
return phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_.Type());
}
const phi::KernelSignature*
InterpretercoreInferShapeContext::GetPhiDefaultKernelSignature() const {
return &phi::DefaultKernelSignatureMap::Instance().Get(op_.Type());
}
void InterpretercoreInferShapeContext::SetSkipLoD(bool skip) {
can_skip_lod_ = skip;
}
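// Hypothetical usage sketch (not part of this diff): a caller that manages
// lod itself can switch off lod sharing before running InferShape, which
// triggers the can_skip_lod_ early return in ShareLoD above.
//
//   infer_shape_ctx.SetSkipLoD(true);
//   op.Info().infer_shape_(&infer_shape_ctx);  // ShareLoD becomes a no-op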
DDim InterpretercoreInferShapeContext::GetDim(Variable* var) const {
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::InvalidArgument("Input variable is nullptr."));
if (var->IsType<phi::DenseTensor>()) {
return var->Get<phi::DenseTensor>().dims();
} else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().GetCompleteDims();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only phi::DenseTensor or SelectedRows support 'GetDim', but input "
"Variable's type is %s.",
ToTypeName(var->Type())));
}
}
std::vector<DDim> InterpretercoreInferShapeContext::GetDims(
const std::vector<Variable*>& vars) const {
std::vector<DDim> ret;
ret.reserve(vars.size());
std::transform(
vars.begin(), vars.end(), std::back_inserter(ret), [this](Variable* var) {
return this->GetDim(var);
});
return ret;
}
std::vector<DDim> InterpretercoreInferShapeContext::GetRepeatedDims(
const std::string& name) const {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"GetRepeatedDims method only ban be used in compile time."));
}
void InterpretercoreInferShapeContext::SetDim(Variable* var, const DDim& dim) {
if (var->IsType<phi::DenseTensor>()) {
var->GetMutable<phi::DenseTensor>()->Resize(dim);
} else if (var->IsType<phi::SelectedRows>()) {
var->GetMutable<phi::SelectedRows>()->set_height(dim[0]);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Variable type error, expect phi::DenseTensor or SelectedRows, but "
"received "
"(%s).",
ToTypeName(var->Type())));
}
}
void InterpretercoreInferShapeContext::SetDims(
const std::vector<Variable*>& vars, const std::vector<DDim>& dims) {
size_t length = vars.size();
PADDLE_ENFORCE_EQ(length,
dims.size(),
platform::errors::InvalidArgument(
"The number of input variables do not match the "
"number of input dimensions, the number of variables "
"is %zu, the number of dimensions is %zu.",
length,
dims.size()));
for (size_t i = 0; i < length; ++i) {
if (vars[i] == nullptr) {
continue;
}
SetDim(vars[i], dims[i]);
}
}
void InterpretercoreInferShapeContext::SetRepeatedDims(
const std::string& name, const std::vector<DDim>& dims) {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"SetRepeatedDims method only can be used in compile time."));
}
std::vector<proto::VarType::Type> InterpretercoreInferShapeContext::GetVarTypes(
const std::vector<Variable*>& vars) const {
std::vector<proto::VarType::Type> retv;
retv.resize(vars.size());
std::transform(
vars.begin(),
vars.end(),
retv.begin(),
std::bind(std::mem_fn(&InterpretercoreInferShapeContext::GetVarType),
this,
std::placeholders::_1));
return retv;
}
proto::VarType::Type InterpretercoreInferShapeContext::GetVarType(
Variable* var) const {
return ToVarType(var->Type());
}
const std::vector<Variable*>& InterpretercoreInferShapeContext::InputVars(
const std::string& name) const {
auto it = ctx_.inputs.find(name);
PADDLE_ENFORCE_NE(
it,
ctx_.inputs.end(),
platform::errors::NotFound(
"Operator (%s) does not have the input (%s).", op_.Type(), name));
return it->second;
}
const std::vector<Variable*>& InterpretercoreInferShapeContext::OutputVars(
const std::string& name) const {
auto it = ctx_.outputs.find(name);
PADDLE_ENFORCE_NE(
it,
ctx_.outputs.end(),
platform::errors::NotFound(
"Operator (%s) does not have the outputs (%s).", op_.Type(), name));
return it->second;
}
VariableScope::VariableScope(Scope* scope) {
// for @EMPTY@ variable
name2id_[kEmptyVarName] = kEmptyVarIndex;
......@@ -747,7 +228,7 @@ void Instruction::ResetContext(const VariableValueMap& in_vars,
const VariableValueMap& out_vars) {
runtime_ctx_.reset(new RuntimeContext(in_vars, out_vars));
infershape_ctx_.reset(
new InterpretercoreInferShapeContext(*OpBase(), *runtime_ctx_.get()));
new RuntimeInferShapeContext(*OpBase(), *runtime_ctx_.get()));
// NOTE: Because execution_ctx_ is constructed by `scope&`, so we fake an
// empty here to avoid illegal local reference.
static framework::Scope scope_;
......@@ -760,7 +241,7 @@ void Instruction::ResetContextWithScope(const VariableValueMap& in_vars,
const framework::Scope& scope) {
runtime_ctx_.reset(new RuntimeContext(in_vars, out_vars));
infershape_ctx_.reset(
new InterpretercoreInferShapeContext(*OpBase(), *runtime_ctx_.get()));
new RuntimeInferShapeContext(*OpBase(), *runtime_ctx_.get()));
execution_ctx_.reset(
new ExecutionContext(*OpBase(), scope, dev_ctx_, *runtime_ctx_.get()));
}
......@@ -769,8 +250,8 @@ std::shared_ptr<RuntimeContext> Instruction::InnerRuntimeContext() const {
return runtime_ctx_;
}
std::shared_ptr<InterpretercoreInferShapeContext>
Instruction::InnerInferShapeContext() const {
std::shared_ptr<RuntimeInferShapeContext> Instruction::InnerInferShapeContext()
const {
return infershape_ctx_;
}
......
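With the shared type, an Instruction rebuilds its cached contexts as sketched below (a condensed restatement of the hunk above; runtime_ctx_ and infershape_ctx_ are the Instruction's shared_ptr members):

runtime_ctx_.reset(new RuntimeContext(in_vars, out_vars));
infershape_ctx_.reset(
    new RuntimeInferShapeContext(*OpBase(), *runtime_ctx_.get()));
std::shared_ptr<RuntimeInferShapeContext> ctx = InnerInferShapeContext();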
......@@ -44,115 +44,6 @@ constexpr const char* kH2DStream = "H2DStream";
constexpr int kEmptyVarIndex = 0;
class InterpretercoreInferShapeContext : public InferShapeContext {
public:
InterpretercoreInferShapeContext(const OperatorBase& op,
const RuntimeContext& ctx);
bool HasInput(const std::string& name) const override;
bool HasOutput(const std::string& name) const override;
bool HasAttr(const std::string& name) const override;
bool HasInputs(const std::string& name) const override;
bool HasOutputs(const std::string& name,
bool allow_null = false) const override;
AttrReader Attrs() const override;
std::vector<std::string> Inputs(const std::string& name) const override;
std::vector<std::string> Outputs(const std::string& name) const override;
std::string GetInputNameByIdx(size_t idx) const override;
std::string GetOutputNameByIdx(size_t idx) const override;
void ShareDim(const std::string& in,
const std::string& out,
size_t i = 0,
size_t j = 0) override;
void ShareAllLoD(const std::string& in,
const std::string& out) const override;
void ShareLoD(const std::string& in,
const std::string& out,
size_t i = 0,
size_t j = 0) const override;
int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override;
void SetLoDLevel(const std::string& out,
int32_t lod_level,
size_t j = 0) const override;
bool IsRuntime() const override;
bool IsRunMKLDNNKernel() const override;
// TODO(paddle-dev): Can this be a template?
paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize>
GetInputVarPtrs(const std::string& name) const override;
paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
GetOutputVarPtrs(const std::string& name) const override;
DDim GetInputDim(const std::string& name) const override;
std::vector<DDim> GetInputsDim(const std::string& name) const override;
proto::VarType::Type GetInputVarType(const std::string& name) const override;
std::vector<proto::VarType::Type> GetInputsVarType(
const std::string& name) const override;
std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string& name) const override;
void SetOutputDim(const std::string& name, const DDim& dim) override;
void SetOutputsDim(const std::string& name,
const std::vector<DDim>& dims) override;
const phi::ArgumentMappingFn* GetPhiArgumentMappingFn() const override;
const phi::KernelSignature* GetPhiDefaultKernelSignature() const override;
void SetSkipLoD(bool skip);
protected:
DDim GetDim(Variable* var) const;
std::vector<DDim> GetDims(const std::vector<Variable*>& vars) const;
std::vector<DDim> GetRepeatedDims(const std::string& name) const override;
void SetDim(Variable* var, const DDim& dim);
void SetDims(const std::vector<Variable*>& vars,
const std::vector<DDim>& dims);
void SetRepeatedDims(const std::string& name,
const std::vector<DDim>& dims) override;
std::vector<proto::VarType::Type> GetVarTypes(
const std::vector<Variable*>& vars) const;
proto::VarType::Type GetVarType(Variable* var) const;
private:
const std::vector<Variable*>& InputVars(const std::string& name) const;
const std::vector<Variable*>& OutputVars(const std::string& name) const;
const OperatorBase& op_;
const RuntimeContext& ctx_;
bool can_skip_lod_;
};
struct OpKernelFunc {
OpKernelComputeFunc compute_func_;
};
......@@ -260,7 +151,6 @@ enum class OpFuncType {
kGpuSync, // GPU or other device kernel without asynchronous operation
kGpuAsync // GPU or other device kernel with asynchronous operation
};
class RuntimeInferShapeContext;
struct OpFuncNode {
int stream_priority_{0}; // lower value, higher priority
......@@ -357,8 +247,7 @@ class Instruction {
std::shared_ptr<RuntimeContext> InnerRuntimeContext() const;
std::shared_ptr<InterpretercoreInferShapeContext> InnerInferShapeContext()
const;
std::shared_ptr<RuntimeInferShapeContext> InnerInferShapeContext() const;
std::shared_ptr<ExecutionContext> InnerExecutionContext() const;
......@@ -390,7 +279,7 @@ class Instruction {
const platform::DeviceContext& dev_ctx_; // not owned
std::shared_ptr<RuntimeContext> runtime_ctx_;
std::shared_ptr<InterpretercoreInferShapeContext> infershape_ctx_;
std::shared_ptr<RuntimeInferShapeContext> infershape_ctx_;
std::shared_ptr<ExecutionContext> execution_ctx_;
std::vector<size_t> gc_check_vars_;
......
......@@ -25,7 +25,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_call_stack.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/raw_tensor.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/unused_var_check.h"
#include "paddle/fluid/framework/var_type.h"
......@@ -214,625 +213,235 @@ RuntimeContext::RuntimeContext(const VariableNameMap& innames,
}
}
void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
try {
VLOG(4) << place << " " << DebugStringEx(&scope);
if (platform::is_gpu_place(place)) {
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
PADDLE_THROW(platform::errors::Unavailable(
"Cannot run operator on place %s, please recompile paddle or "
"reinstall Paddle with CUDA support.",
place));
#else
auto dev_id = place.device;
platform::SetDeviceId(dev_id);
#endif
} else if (platform::is_xpu_place(place)) {
#ifndef PADDLE_WITH_XPU
PADDLE_THROW(platform::errors::Unavailable(
"Cannot run operator on place %s, please recompile paddle or "
"reinstall Paddle with XPU support.",
place));
#else
auto dev_id = place.device;
platform::SetXPUDeviceId(dev_id);
#endif
} else if (platform::is_npu_place(place)) {
#ifndef PADDLE_WITH_ASCEND_CL
PADDLE_THROW(platform::errors::Unavailable(
"Cannot run operator on place %s, please recompile paddle or "
"reinstall Paddle with NPU support.",
place));
#else
auto dev_id = place.device;
platform::SetNPUDeviceId(dev_id);
#endif
} else if (platform::is_mlu_place(place)) {
#ifndef PADDLE_WITH_MLU
PADDLE_THROW(platform::errors::Unavailable(
"Cannot run operator on place %s, please recompile paddle or "
"reinstall Paddle with MLU support.",
place));
#else
auto dev_id = place.device;
platform::SetMLUDeviceId(dev_id);
#endif
} else if (platform::is_custom_place(place)) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
PADDLE_THROW(platform::errors::Unavailable(
"Cannot run operator on place %s, please recompile paddle or "
"reinstall Paddle with CustomDevice support.",
place));
#else
phi::DeviceManager::SetDevice(place);
#endif
}
{
// TODO(wangchaochaohu): refine code to use only one RecordEvent
// in order to record different op type cost time
// and different op name cost time, we set two events.
platform::RecordEvent op_type_record_event(
Type(), platform::TracerEventType::Operator, 1);
auto op_name = platform::OpName(outputs_, Type());
platform::RecordEvent op_name_record_event(
op_name,
platform::TracerEventType::Operator,
FLAGS_enable_host_event_recorder_hook ? 20 : 1,
platform::EventRole::kUniqueOp);
RunImpl(scope, place);
}
VLOG(3) << GetExecutionPlace(place) << " " << DebugStringEx(&scope);
} catch (platform::EnforceNotMet& exception) {
framework::InsertCallStackInfo(Type(), Attrs(), &exception);
throw std::move(exception);
} catch (platform::EOFException&) {
std::rethrow_exception(std::current_exception());
} catch (std::exception& ex) {
LOG(WARNING) << Type() << " raises an exception "
<< platform::demangle(typeid(ex).name()) << ", " << ex.what();
std::rethrow_exception(std::current_exception());
} catch (...) {
LOG(WARNING) << Type() << " raises an unknown exception";
std::rethrow_exception(std::current_exception());
}
}
bool OperatorBase::HasInputs(const std::string& name) const {
return inputs_.find(name) != inputs_.end();
}
std::string OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name);
PADDLE_ENFORCE_LE(
ins.size(),
1UL,
platform::errors::InvalidArgument(
"Operator %s's input %s should contain only one variable.",
type_,
name));
return ins.empty() ? kEmptyVarName : ins[0];
}
const std::vector<std::string>& OperatorBase::Inputs(
const std::string& name) const {
auto it = inputs_.find(name);
PADDLE_ENFORCE_NE(
it,
inputs_.end(),
platform::errors::NotFound(
"Operator %s does not have the input %s.", type_, name));
return it->second;
}
RuntimeInferShapeContext::RuntimeInferShapeContext(const OperatorBase& op,
const RuntimeContext& ctx)
: op_(op), ctx_(ctx) {}
bool RuntimeInferShapeContext::HasInput(const std::string& name) const {
// has only one input
const auto& ins = ctx_.inputs;
auto it = ins.find(name);
if (it == ins.end()) {
return false;
}
const auto& in = it->second;
if (in.size() == 0) return false;
PADDLE_ENFORCE_EQ(
in.size(),
1UL,
platform::errors::InvalidArgument(
"Input %s should not contain more than one input.", name));
return in[0] != nullptr;
}
bool OperatorBase::HasOutputs(const std::string& name) const {
if (outputs_.find(name) != outputs_.end()) {
return true;
} else {
return false;
}
}
std::string OperatorBase::Output(const std::string& name) const {
auto& outs = Outputs(name);
PADDLE_ENFORCE_LE(
outs.size(),
1UL,
platform::errors::InvalidArgument(
"Operator %s's output %s should contain only one variable.",
type_,
name));
return outs.empty() ? kEmptyVarName : outs[0];
}
const std::vector<std::string>& OperatorBase::Outputs(
const std::string& name) const {
auto it = outputs_.find(name);
PADDLE_ENFORCE_NE(
it,
outputs_.end(),
platform::errors::NotFound(
"Operator %s does not have an output called %s.", type_, name));
return it->second;
}
bool RuntimeInferShapeContext::HasOutput(const std::string& name) const {
// has only one output
const auto& outs = ctx_.outputs;
auto it = outs.find(name);
if (it == outs.end()) {
return false;
}
const auto& out = it->second;
if (out.size() == 0) {
return false;
}
PADDLE_ENFORCE_EQ(
out.size(),
1UL,
platform::errors::InvalidArgument(
"Output %s should not contain more than one output.", name));
return out[0] != nullptr;
}
bool RuntimeInferShapeContext::HasAttr(const std::string& name) const {
return op_.HasAttr(name);
}
std::string OperatorBase::DebugStringEx(const Scope* scope) const {
std::stringstream ss;
ss << "Op(" << type_ << "), inputs:{";
const std::unordered_set<std::string>* no_need_buffer_vars = nullptr;
if (info_ && info_->NoNeedBufferVarsInferer()) {
no_need_buffer_vars =
&(Info().NoNeedBufferVarsInferer()(Inputs(), Outputs(), Attrs()));
if (no_need_buffer_vars->empty()) no_need_buffer_vars = nullptr;
}
for (auto it = inputs_.begin(); it != inputs_.end();) {
auto& input = *it;
bool is_no_need_buffer_var =
(no_need_buffer_vars && no_need_buffer_vars->count(input.first) > 0);
ss << input.first << "[";
for (size_t i = 0; i < input.second.size(); ++i) {
auto var_name = input.second[i];
ss << var_name;
if (scope) {
if (!VarInited(*scope, var_name)) {
ss << "[uninited]";
} else {
int row_size = GetRowSize(*scope, var_name);
if (row_size >= 0) {
ss << "[row_size=" << row_size << "]";
}
std::string dtype = is_no_need_buffer_var
? "unknown_dtype"
: GetDtype(*scope, var_name);
std::string place = is_no_need_buffer_var
? "unknown_place"
: GetPlace(*scope, var_name);
ss << ":" << dtype;
ss << "[" << GetDimsDebug(*scope, var_name, true) << "]";
ss << "(" << GetLoDDebug(*scope, var_name) << ")";
ss << "(" << place << ")";
}
}
if (i != input.second.size() - 1) {
ss << ", ";
}
}
ss << "]";
++it;
if (it != inputs_.end()) {
ss << ", ";
}
}
ss << "}, outputs:{";
for (auto it = outputs_.begin(); it != outputs_.end();) {
auto& output = *it;
ss << output.first << "[";
for (size_t i = 0; i < output.second.size(); ++i) {
auto var_name = output.second[i];
ss << var_name;
if (scope) {
if (!VarInited(*scope, var_name)) {
ss << "[uninited]";
} else {
int row_size = GetRowSize(*scope, output.second[i]);
if (row_size >= 0) {
ss << "[row_size=" << row_size << "]";
}
std::string dtype = GetDtype(*scope, output.second[i]);
ss << ":" << dtype;
ss << "[" << GetDimsDebug(*scope, var_name, true) << "]";
ss << "(" << GetLoDDebug(*scope, var_name) << ")";
ss << "(" << GetPlace(*scope, var_name) << ")";
}
}
if (i != output.second.size() - 1) {
ss << ", ";
}
}
ss << "]";
++it;
if (it != outputs_.end()) {
ss << ", ";
}
}
ss << "}.";
return ss.str();
}
OperatorBase::OperatorBase(const std::string& type,
const VariableNameMap& inputs,
const VariableNameMap& outputs,
const AttributeMap& attrs)
: type_(type),
inputs_(inputs),
outputs_(outputs),
attrs_(attrs),
// NOTE(zjl): why op_info may be nullptr?
info_(OpInfoMap::Instance().GetNullable(type)) {
// In dygraph mode, all the OperatorBase will be constructed by function:
// framework::OpRegistry::CreateOp(type, {}, {}, {}, false).
// Inputs, outputs and attrs will be set to empty map
// to improve the execution efficiency of dygraph.
if (inputs_.size() > 0 || outputs_.size() > 0) {
GenerateTemporaryNames();
CheckAllInputOutputSet();
}
// In OperatorBase level, all attributes with VarDesc type will be considered
// as Input.
for (auto& attr : FilterAttrVar(attrs)) {
VLOG(3) << "found Attribute with Variable type: " << attr.first;
inputs_[attr.first] = std::move(AttrVarNames(attr.second));
attrs_.erase(attr.first);
}
}
bool RuntimeInferShapeContext::HasInputs(const std::string& name) const {
const auto& ins = ctx_.inputs;
auto it = ins.find(name);
if (it == ins.end() || it->second.empty()) {
return false;
}
for (auto& input : it->second) {
if (input == nullptr) {
return false;
}
}
return true;
}
bool RuntimeInferShapeContext::HasOutputs(const std::string& name,
bool allow_null) const {
const auto& outs = ctx_.outputs;
auto it = outs.find(name);
if (it == outs.end() || it->second.empty()) {
return false;
}
if (!allow_null) {
for (auto& output : it->second) {
if (output == nullptr) return false;
}
}
return true;
}
AttrReader RuntimeInferShapeContext::Attrs() const {
return AttrReader(op_.Attrs(), op_.RuntimeAttrs());
}
std::vector<std::string> OperatorBase::InputVars() const {
std::vector<std::string> ret_val;
for (auto& o : inputs_) {
ret_val.reserve(ret_val.size() + o.second.size());
ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
}
return ret_val;
}
std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
std::vector<std::string> ret_val;
if (has_intermediate) {
// push all outputs into ret_val
for (auto& o : outputs_) {
ret_val.reserve(ret_val.size() + o.second.size());
ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
}
return ret_val;
}
auto& info = Info();
// get all OpProto::Var for outputs
for (auto& o : info.Proto().outputs()) {
// ignore all intermediate output
if (o.intermediate()) continue;
auto out = outputs_.find(o.name());
if (out != outputs_.end()) {
ret_val.reserve(ret_val.size() + out->second.size());
ret_val.insert(ret_val.end(), out->second.begin(), out->second.end());
}
}
return ret_val;
}
void OperatorBase::CheckAllInputOutputSet() const {
if (info_ == nullptr || info_->proto_ == nullptr) return;
for (auto& in : info_->Proto().inputs()) {
if (!in.dispensable() && !in.extra()) {
PADDLE_ENFORCE_NE(
inputs_.find(in.name()),
inputs_.end(),
platform::errors::NotFound(
"Operator %s's input (%s) is not set.", Type(), in.name()));
}
}
for (auto& out : info_->Proto().outputs()) {
if (!out.dispensable() && !out.extra() && !out.intermediate()) {
PADDLE_ENFORCE_NE(
outputs_.find(out.name()),
outputs_.end(),
platform::errors::NotFound(
"Operator %s's output (%s) is not set.", Type(), out.name()));
}
}
}
void OperatorBase::GenerateTemporaryNames() {
static std::atomic<size_t> gUniqId(0UL);
for (auto& output : outputs_) {
for (auto& output_name : output.second) {
if (output_name == kTempVarName) {
output_name += type_;
output_name += "@";
output_name += std::to_string(gUniqId.fetch_add(1));
}
}
}
}
const phi::DenseTensor* GetLoDTensorOrSelectedRowsValueFromVar(
const Variable& var) {
if (var.IsType<phi::DenseTensor>()) {
return static_cast<const phi::DenseTensor*>(&(var.Get<phi::DenseTensor>()));
} else if (var.IsType<phi::SelectedRows>()) {
return &(var.Get<phi::SelectedRows>().value());
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Variable type is %s, expect phi::DenseTensor or SelectedRows.",
ToTypeName(var.Type())));
}
}
phi::DenseTensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var) {
if (var->IsType<phi::DenseTensor>()) {
return var->GetMutable<phi::DenseTensor>();
} else if (var->IsType<phi::SelectedRows>()) {
return var->GetMutable<phi::SelectedRows>()->mutable_value();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Variable type is %s, expect phi::DenseTensor or SelectedRows.",
ToTypeName(var->Type())));
}
}
std::vector<std::string> RuntimeInferShapeContext::Inputs(
const std::string& name) const {
return op_.Inputs(name);
}
std::vector<std::string> RuntimeInferShapeContext::Outputs(
const std::string& name) const {
return op_.Outputs(name);
}
std::string RuntimeInferShapeContext::GetInputNameByIdx(size_t idx) const {
auto& op_proto =
paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
PADDLE_ENFORCE_LT(idx,
op_proto->inputs().size(),
platform::errors::OutOfRange(
"The index should be less than the size of inputs of "
"operator %s, but got index is %d and size is %d",
op_.Type(),
idx,
op_proto->inputs().size()));
return op_proto->inputs()[idx].name();
}
std::string RuntimeInferShapeContext::GetOutputNameByIdx(size_t idx) const {
auto& op_proto =
paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
PADDLE_ENFORCE_LT(idx,
op_proto->outputs().size(),
platform::errors::OutOfRange(
"The index should be less than the size of outputs of "
"operator %s, but got index is %d and size is %d",
op_.Type(),
idx,
op_proto->outputs().size()));
return op_proto->outputs()[idx].name();
}
void RuntimeInferShapeContext::ShareDim(const std::string& in,
const std::string& out,
size_t i,
size_t j) {
auto in_it = ctx_.inputs.find(in);
auto out_it = ctx_.outputs.find(out);
PADDLE_ENFORCE_NE(in_it,
ctx_.inputs.end(),
platform::errors::NotFound("Input %s does not exist.", in));
PADDLE_ENFORCE_NE(
out_it,
ctx_.outputs.end(),
platform::errors::NotFound("Output %s does not exist.", out));
PADDLE_ENFORCE_LT(i,
in_it->second.size(),
platform::errors::InvalidArgument(
"The index of input dimension is out of range, "
"expected index less than %zu, but received %zu.",
in_it->second.size(),
i));
PADDLE_ENFORCE_LT(j,
out_it->second.size(),
platform::errors::InvalidArgument(
"The index of output dimension is out of range, "
"expected index less than %zu, but received %zu.",
out_it->second.size(),
j));
Variable* in_var = in_it->second[i];
Variable* out_var = out_it->second[j];
PADDLE_ENFORCE_EQ(
in_var->Type(),
out_var->Type(),
platform::errors::InvalidArgument(
"The type of input (%s) and output (%s) are inconsistent.", in, out));
if (in_var->IsType<phi::SelectedRows>()) {
auto& in_sele_rows = in_var->Get<phi::SelectedRows>();
auto out_sele_rows = out_var->GetMutable<phi::SelectedRows>();
out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims());
out_sele_rows->set_rows(in_sele_rows.rows());
out_sele_rows->set_height(in_sele_rows.height());
} else if (in_var->IsType<phi::DenseTensor>()) {
auto& in_lod_tensor = in_var->Get<phi::DenseTensor>();
auto* out_lod_tensor = out_var->GetMutable<phi::DenseTensor>();
out_lod_tensor->Resize(in_lod_tensor.dims());
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, the input type of ShareDim can only be phi::DenseTensor "
"or SelectedRows."));
}
}
void RuntimeInferShapeContext::ShareAllLoD(const std::string& in,
const std::string& out) const {
auto in_it = ctx_.inputs.find(in);
auto out_it = ctx_.outputs.find(out);
PADDLE_ENFORCE_NE(in_it,
ctx_.inputs.end(),
platform::errors::NotFound(
"Input [%s] not found in Op [%s]", in, op_.Type()));
PADDLE_ENFORCE_NE(out_it,
ctx_.outputs.end(),
platform::errors::NotFound(
"Output [%s] not found in Op [%s]", out, op_.Type()));
auto& in_var_list = in_it->second;
auto& out_var_list = out_it->second;
PADDLE_ENFORCE_EQ(
in_var_list.size(),
out_var_list.size(),
platform::errors::PreconditionNotMet(
"Op [%s]: Input var size should be equal to output var size",
op_.Type()));
auto& out_var_names = op_.Outputs(out);
for (size_t i = 0; i < in_var_list.size(); ++i) {
if (out_var_names[i] == framework::kEmptyVarName) {
continue;
}
Variable* in_var = in_var_list[i];
if (!in_var->IsType<phi::DenseTensor>()) return;
Variable* out_var = out_var_list[i];
PADDLE_ENFORCE_EQ(
out_var->IsType<phi::DenseTensor>(),
true,
platform::errors::PreconditionNotMet(
"The %d-th output of Output(%s) must be phi::DenseTensor.",
i,
out_var_names[i]));
auto& in_tensor = in_var->Get<phi::DenseTensor>();
auto* out_tensor = out_var->GetMutable<phi::DenseTensor>();
out_tensor->set_lod(in_tensor.lod());
#ifdef PADDLE_WITH_MKLDNN
if (in_tensor.layout() != DataLayout::ONEDNN)
#endif
out_tensor->set_layout(in_tensor.layout());
}
}
OperatorWithKernel::OperatorWithKernel(const std::string& type,
const VariableNameMap& inputs,
const VariableNameMap& outputs,
const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
OperatorWithKernel::~OperatorWithKernel() = default;
bool ExecutionContext::HasInput(const std::string& name) const {
auto* var = InputVar(name);
return var != nullptr;
}
bool ExecutionContext::HasInputs(const std::string& name) const {
const auto& ins = ctx_.inputs;
auto it = ins.find(name);
if (it == ins.end() || it->second.empty()) {
return false;
}
for (const auto* input : it->second) {
if (input == nullptr) {
return false;
}
}
return true;
}
bool ExecutionContext::HasOutput(const std::string& name) const {
auto* var = OutputVar(name);
return var != nullptr;
}
const Variable* ExecutionContext::InputVar(const std::string& name) const {
LogVarUsageIfUnusedVarCheckEnabled(name);
auto it = ctx_.inputs.find(name);
if (it == ctx_.inputs.end()) return nullptr;
PADDLE_ENFORCE_LE(
it->second.size(),
1UL,
platform::errors::InvalidArgument(
"Operator %s's input %s should contain only one variable.",
op_.Type(),
name));
return it->second.empty() ? nullptr : it->second[0];
}
Variable* ExecutionContext::OutputVar(const std::string& name) const {
auto it = ctx_.outputs.find(name);
if (it == ctx_.outputs.end()) return nullptr;
PADDLE_ENFORCE_LE(
it->second.size(),
1UL,
platform::errors::InvalidArgument(
"Operator %s's output %s should contain only one variable.",
op_.Type(),
name));
return it->second.empty() ? nullptr : it->second[0];
}
template <>
const std::vector<const phi::DenseTensor*>
ExecutionContext::MultiInput<phi::DenseTensor>(const std::string& name) const {
LogVarUsageIfUnusedVarCheckEnabled(name);
auto vars = MultiInputVar(name);
if (vars.size() == 0) {
return {};
}
std::vector<const phi::DenseTensor*> res;
res.reserve(vars.size());
std::transform(vars.begin(),
vars.end(),
std::back_inserter(res),
[&](const Variable* var) -> const phi::DenseTensor* {
if (var == nullptr) return nullptr;
PADDLE_ENFORCE_EQ(
var->IsType<phi::DenseTensor>(),
true,
platform::errors::InvalidArgument(
"Input variable should be phi::DenseTensor, "
"but the received type is %s.",
ToTypeName(var->Type())));
return &(var->Get<phi::DenseTensor>());
});
return res;
}
template <>
std::vector<phi::DenseTensor*> ExecutionContext::MultiOutput<phi::DenseTensor>(
const std::string& name) const {
auto vars = MultiOutputVar(name);
if (vars.size() == 0) {
return {};
}
std::vector<phi::DenseTensor*> res;
res.reserve(vars.size());
std::transform(vars.begin(),
vars.end(),
std::back_inserter(res),
[&](Variable* var) -> phi::DenseTensor* {
return var == nullptr ? nullptr
: var->GetMutable<phi::DenseTensor>();
});
return res;
}
bool OpSupportGPU(const std::string& op_type) {
// check in new Function kernel first
bool has_phi_kernel = false;
auto& kernel_factory = phi::KernelFactory::Instance();
auto kernel_key_map =
kernel_factory.SelectKernelMap(phi::TransToPhiKernelName(op_type));
for (auto& kernel : kernel_key_map) {
has_phi_kernel = true;
if (platform::is_gpu_place(phi::TransToPhiPlace(kernel.first.backend()))) {
return true;
}
}
auto& all_kernels = OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op_type);
if (it != all_kernels.end()) {
for (auto& kern_pair : it->second) {
if (platform::is_gpu_place(kern_pair.first.place_)) {
return true;
}
}
} else {
if (has_phi_kernel) {
// if the op has a phi kernel but neither a phi GPU kernel nor a fluid GPU
// kernel is found, this op doesn't support GPU
return false;
} else {
// All control operators must support GPU
return true;
}
}
return false;
}
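// Usage sketch for the helper above ("relu" is an arbitrary example op type,
// not taken from this diff):
//
//   bool gpu_ok = framework::OpSupportGPU("relu");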
class RuntimeInferShapeContext : public InferShapeContext {
public:
RuntimeInferShapeContext(const OperatorBase& op, const RuntimeContext& ctx)
: op_(op), ctx_(ctx) {}
bool HasInput(const std::string& name) const override {
// has only one input
const auto& ins = ctx_.inputs;
auto it = ins.find(name);
if (it == ins.end()) {
return false;
}
const auto& in = it->second;
if (in.size() == 0) return false;
PADDLE_ENFORCE_EQ(
in.size(),
1UL,
platform::errors::InvalidArgument(
"Input %s should not contain more than one inputs.", name));
return in[0] != nullptr;
}
bool HasOutput(const std::string& name) const override {
// has only one output
const auto& outs = ctx_.outputs;
auto it = outs.find(name);
if (it == outs.end()) {
return false;
}
const auto& out = it->second;
if (out.size() == 0) {
return false;
}
PADDLE_ENFORCE_EQ(
out.size(),
1UL,
platform::errors::InvalidArgument(
"Output %s should not contain more than one outputs.", name));
return out[0] != nullptr;
}
bool HasAttr(const std::string& name) const override {
return op_.HasAttr(name);
}
bool HasInputs(const std::string& name) const override {
const auto& ins = ctx_.inputs;
auto it = ins.find(name);
if (it == ins.end() || it->second.empty()) {
return false;
}
for (auto& input : it->second) {
if (input == nullptr) {
return false;
}
}
return true;
}
bool HasOutputs(const std::string& name,
bool allow_null = false) const override {
const auto& outs = ctx_.outputs;
auto it = outs.find(name);
if (it == outs.end() || it->second.empty()) {
return false;
}
if (!allow_null) {
for (auto& output : it->second) {
if (output == nullptr) return false;
}
}
return true;
}
AttrReader Attrs() const override {
return AttrReader(op_.Attrs(), op_.RuntimeAttrs());
}
std::vector<std::string> Inputs(const std::string& name) const override {
return op_.Inputs(name);
}
std::vector<std::string> Outputs(const std::string& name) const override {
return op_.Outputs(name);
}
std::string GetInputNameByIdx(size_t idx) const override {
auto& op_proto =
paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
PADDLE_ENFORCE_LT(idx,
op_proto->inputs().size(),
platform::errors::OutOfRange(
"The index should be less than the size of inputs of "
"operator %s, but got index is %d and size is %d",
op_.Type(),
idx,
op_proto->inputs().size()));
return op_proto->inputs()[idx].name();
}
std::string GetOutputNameByIdx(size_t idx) const override {
auto& op_proto =
paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
PADDLE_ENFORCE_LT(
idx,
op_proto->outputs().size(),
platform::errors::OutOfRange(
"The index should be less than the size of outputs of "
"operator %s, but got index is %d and size is %d",
op_.Type(),
idx,
op_proto->outputs().size()));
return op_proto->outputs()[idx].name();
}
void ShareDim(const std::string& in,
const std::string& out,
size_t i = 0,
size_t j = 0) override {
auto in_it = ctx_.inputs.find(in);
auto out_it = ctx_.outputs.find(out);
PADDLE_ENFORCE_NE(
in_it,
ctx_.inputs.end(),
platform::errors::NotFound("Input %s does not exist.", in));
PADDLE_ENFORCE_NE(
......@@ -854,124 +463,14 @@ class RuntimeInferShapeContext : public InferShapeContext {
out_it->second.size(),
j));
Variable* in_var = in_it->second[i];
Variable* out_var = out_it->second[j];
PADDLE_ENFORCE_EQ(
in_var->Type(),
out_var->Type(),
platform::errors::InvalidArgument(
"The type of input (%s) and output (%s) are inconsistent.",
in,
out));
if (in_var->IsType<phi::SelectedRows>()) {
auto& in_sele_rows = in_var->Get<phi::SelectedRows>();
auto out_sele_rows = out_var->GetMutable<phi::SelectedRows>();
out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims());
out_sele_rows->set_rows(in_sele_rows.rows());
out_sele_rows->set_height(in_sele_rows.height());
} else if (in_var->IsType<phi::DenseTensor>()) {
auto& in_lod_tensor = in_var->Get<phi::DenseTensor>();
auto* out_lod_tensor = out_var->GetMutable<phi::DenseTensor>();
out_lod_tensor->Resize(in_lod_tensor.dims());
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, the input type of ShareDim can only be phi::DenseTensor "
"or SelectedRows."));
}
}
void ShareAllLoD(const std::string& in,
const std::string& out) const override {
auto in_it = ctx_.inputs.find(in);
auto out_it = ctx_.outputs.find(out);
PADDLE_ENFORCE_NE(in_it,
ctx_.inputs.end(),
platform::errors::NotFound(
"Input [%s] not found in Op [%s]", in, op_.Type()));
PADDLE_ENFORCE_NE(
out_it,
ctx_.outputs.end(),
platform::errors::NotFound(
"Output [%s] not found in Op [%s]", out, op_.Type()));
auto& in_var_list = in_it->second;
auto& out_var_list = out_it->second;
PADDLE_ENFORCE_EQ(
in_var_list.size(),
out_var_list.size(),
platform::errors::PreconditionNotMet(
"Op [%s]: Input var size should be equal to output var size",
op_.Type()));
auto& out_var_names = op_.Outputs(out);
for (size_t i = 0; i < in_var_list.size(); ++i) {
if (out_var_names[i] == framework::kEmptyVarName) {
continue;
}
Variable* in_var = in_var_list[i];
if (!in_var->IsType<phi::DenseTensor>()) return;
Variable* out_var = out_var_list[i];
PADDLE_ENFORCE_EQ(
out_var->IsType<phi::DenseTensor>(),
true,
platform::errors::PreconditionNotMet(
"The %d-th output of Output(%s) must be phi::DenseTensor.",
i,
out_var_names[i]));
auto& in_tensor = in_var->Get<phi::DenseTensor>();
auto* out_tensor = out_var->GetMutable<phi::DenseTensor>();
out_tensor->set_lod(in_tensor.lod());
#ifdef PADDLE_WITH_MKLDNN
if (in_tensor.layout() != DataLayout::ONEDNN)
#endif
out_tensor->set_layout(in_tensor.layout());
}
}
void ShareLoD(const std::string& in,
const std::string& out,
size_t i = 0,
size_t j = 0) const override {
auto in_it = ctx_.inputs.find(in);
auto out_it = ctx_.outputs.find(out);
PADDLE_ENFORCE_NE(
in_it,
ctx_.inputs.end(),
platform::errors::NotFound("Input %s does not exist.", in));
PADDLE_ENFORCE_NE(
out_it,
ctx_.outputs.end(),
platform::errors::NotFound("Output %s does not exist.", out));
PADDLE_ENFORCE_LT(i,
in_it->second.size(),
platform::errors::InvalidArgument(
"The index of input dimension is out of range, "
"expected index less than %zu, but received %zu.",
in_it->second.size(),
i));
PADDLE_ENFORCE_LT(j,
out_it->second.size(),
platform::errors::InvalidArgument(
"The index of output dimension is out of range, "
"expected index less than %zu, but received %zu.",
out_it->second.size(),
j));
Variable* in_var = in_it->second.at(i);
if (!in_var->IsType<phi::DenseTensor>()) return;
Variable* out_var = out_it->second.at(j);
PADDLE_ENFORCE_EQ(
out_var->IsType<phi::DenseTensor>(),
true,
platform::errors::InvalidArgument(
"The %zu-th output of Output(%s) must be phi::DenseTensor.",
j,
out));
auto& in_tensor = in_var->Get<phi::DenseTensor>();
auto* out_tensor = out_var->GetMutable<phi::DenseTensor>();
out_tensor->set_lod(in_tensor.lod());
void RuntimeInferShapeContext::ShareLoD(const std::string& in,
const std::string& out,
size_t i,
size_t j) const {
if (can_skip_lod_) {
return;
}
auto in_it = ctx_.inputs.find(in);
auto out_it = ctx_.outputs.find(out);
PADDLE_ENFORCE_NE(in_it,
ctx_.inputs.end(),
platform::errors::NotFound("Input %s does not exist.", in));
PADDLE_ENFORCE_NE(
out_it,
ctx_.outputs.end(),
platform::errors::NotFound("Output %s does not exist.", out));
PADDLE_ENFORCE_LT(i,
in_it->second.size(),
platform::errors::InvalidArgument(
"The index of input dimension is out of range, "
"expected index less than %zu, but received %zu.",
in_it->second.size(),
i));
PADDLE_ENFORCE_LT(j,
out_it->second.size(),
platform::errors::InvalidArgument(
"The index of output dimension is out of range, "
"expected index less than %zu, but received %zu.",
out_it->second.size(),
j));
Variable* in_var = in_it->second.at(i);
if (!in_var->IsType<phi::DenseTensor>()) return;
Variable* out_var = out_it->second.at(j);
PADDLE_ENFORCE_EQ(
out_var->IsType<phi::DenseTensor>(),
true,
platform::errors::InvalidArgument(
"The %zu-th output of Output(%s) must be phi::DenseTensor.", j, out));
auto& in_tensor = in_var->Get<phi::DenseTensor>();
auto* out_tensor = out_var->GetMutable<phi::DenseTensor>();
out_tensor->set_lod(in_tensor.lod());
......@@ -995,57 +494,58 @@ class RuntimeInferShapeContext : public InferShapeContext {
if (in_tensor.layout() != DataLayout::ONEDNN)
#endif
out_tensor->set_layout(in_tensor.layout());
}
}
int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override {
int32_t RuntimeInferShapeContext::GetLoDLevel(const std::string& in,
size_t i) const {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"GetLoDLevel is only used in compile time. The calculation of "
"output's actual lod is different among operators so that should be "
"set in the runtime kernel."));
}
}
void SetLoDLevel(const std::string& out,
void RuntimeInferShapeContext::SetLoDLevel(const std::string& out,
int32_t lod_level,
size_t j = 0) const override {
size_t j) const {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"SetLoDLevel is only used in compile time. The calculation of "
"output's actual lod is different among operators so that should be "
"set in the runtime kernel."));
}
}
bool IsRuntime() const override { return true; }
bool RuntimeInferShapeContext::IsRuntime() const { return true; }
bool IsRunMKLDNNKernel() const override {
bool RuntimeInferShapeContext::IsRunMKLDNNKernel() const {
try {
auto& op_with_kernel = dynamic_cast<const OperatorWithKernel&>(op_);
return ((op_with_kernel.kernel_type()) &&
(op_with_kernel.kernel_type()->data_layout_ ==
phi::DataLayout::ONEDNN));
} catch (const std::bad_cast& exp) {
} catch (std::bad_cast& exp) {
return false;
}
}
}
// TODO(paddle-dev): Can this be a template?
paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize>
GetInputVarPtrs(const std::string& name) const override {
// TODO(paddle-dev): Can this be a template?
paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize>
RuntimeInferShapeContext::GetInputVarPtrs(const std::string& name) const {
const std::vector<Variable*>& vars = InputVars(name);
paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
}
}
paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
GetOutputVarPtrs(const std::string& name) const override {
paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
RuntimeInferShapeContext::GetOutputVarPtrs(const std::string& name) const {
const std::vector<Variable*>& vars = OutputVars(name);
paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
}
}
DDim GetInputDim(const std::string& name) const override {
DDim RuntimeInferShapeContext::GetInputDim(const std::string& name) const {
const std::vector<Variable*>& vars = InputVars(name);
PADDLE_ENFORCE_EQ(
vars.size(),
......@@ -1055,164 +555,665 @@ class RuntimeInferShapeContext : public InferShapeContext {
name,
vars.size()));
return this->GetDim(vars[0]);
}
}
std::vector<DDim> GetInputsDim(const std::string& name) const override {
std::vector<DDim> RuntimeInferShapeContext::GetInputsDim(
const std::string& name) const {
const std::vector<Variable*>& vars = InputVars(name);
return GetDims(vars);
}
proto::VarType::Type RuntimeInferShapeContext::GetInputVarType(
const std::string& name) const {
return GetVarType(InputVars(name).at(0));
}
std::vector<proto::VarType::Type> RuntimeInferShapeContext::GetInputsVarType(
const std::string& name) const {
return GetVarTypes(InputVars(name));
}
std::vector<proto::VarType::Type> RuntimeInferShapeContext::GetOutputsVarType(
const std::string& name) const {
return GetVarTypes(OutputVars(name));
}
void RuntimeInferShapeContext::SetOutputDim(const std::string& name,
const DDim& dim) {
auto& vars = OutputVars(name);
PADDLE_ENFORCE_EQ(
vars.size(),
1UL,
platform::errors::InvalidArgument("Output(%s) should hold one element, "
"but now it holds %zu elements.",
name,
vars.size()));
SetDim(vars[0], dim);
}
void RuntimeInferShapeContext::SetOutputsDim(const std::string& name,
const std::vector<DDim>& dims) {
auto& vars = OutputVars(name);
SetDims(vars, dims);
}
const phi::ArgumentMappingFn*
RuntimeInferShapeContext::GetPhiArgumentMappingFn() const {
return phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_.Type());
}
const phi::KernelSignature*
RuntimeInferShapeContext::GetPhiDefaultKernelSignature() const {
return &phi::DefaultKernelSignatureMap::Instance().Get(op_.Type());
}
void RuntimeInferShapeContext::SetSkipLoD(bool skip) { can_skip_lod_ = skip; }
DDim RuntimeInferShapeContext::GetDim(Variable* var) const {
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::InvalidArgument("Input variable is nullptr."));
if (var->IsType<phi::DenseTensor>()) {
return var->Get<phi::DenseTensor>().dims();
} else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().GetCompleteDims();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only phi::DenseTensor or SelectedRows support 'GetDim', but input "
"Variable's type is %s.",
ToTypeName(var->Type())));
}
}
std::vector<DDim> RuntimeInferShapeContext::GetDims(
const std::vector<Variable*>& vars) const {
std::vector<DDim> ret;
ret.reserve(vars.size());
std::transform(
vars.begin(), vars.end(), std::back_inserter(ret), [this](Variable* var) {
return this->GetDim(var);
});
return ret;
}
std::vector<DDim> RuntimeInferShapeContext::GetRepeatedDims(
const std::string& name) const {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"GetRepeatedDims method only ban be used in compile time."));
}
void RuntimeInferShapeContext::SetDim(Variable* var, const DDim& dim) {
if (var->IsType<phi::DenseTensor>()) {
var->GetMutable<phi::DenseTensor>()->Resize(dim);
} else if (var->IsType<phi::SelectedRows>()) {
var->GetMutable<phi::SelectedRows>()->set_height(dim[0]);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Variable type error, expect phi::DenseTensor or SelectedRows, but "
"received "
"(%s).",
ToTypeName(var->Type())));
}
}
void RuntimeInferShapeContext::SetDims(const std::vector<Variable*>& vars,
const std::vector<DDim>& dims) {
size_t length = vars.size();
PADDLE_ENFORCE_EQ(length,
dims.size(),
platform::errors::InvalidArgument(
"The number of input variables do not match the "
"number of input dimensions, the number of variables "
"is %zu, the number of dimensions is %zu.",
length,
dims.size()));
for (size_t i = 0; i < length; ++i) {
if (vars[i] == nullptr) {
continue;
}
SetDim(vars[i], dims[i]);
}
}
void RuntimeInferShapeContext::SetRepeatedDims(const std::string& name,
const std::vector<DDim>& dims) {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"SetRepeatedDims method only can be used in compile time."));
}
std::vector<proto::VarType::Type> RuntimeInferShapeContext::GetVarTypes(
const std::vector<Variable*>& vars) const {
std::vector<proto::VarType::Type> retv;
retv.resize(vars.size());
std::transform(vars.begin(),
vars.end(),
retv.begin(),
std::bind(std::mem_fn(&RuntimeInferShapeContext::GetVarType),
this,
std::placeholders::_1));
return retv;
}
proto::VarType::Type RuntimeInferShapeContext::GetVarType(Variable* var) const {
return ToVarType(var->Type());
}
const std::vector<Variable*>& RuntimeInferShapeContext::InputVars(
const std::string& name) const {
auto it = ctx_.inputs.find(name);
PADDLE_ENFORCE_NE(
it,
ctx_.inputs.end(),
platform::errors::NotFound(
"Operator (%s) does not have the input (%s).", op_.Type(), name));
return it->second;
}
const std::vector<Variable*>& RuntimeInferShapeContext::OutputVars(
const std::string& name) const {
auto it = ctx_.outputs.find(name);
PADDLE_ENFORCE_NE(
it,
ctx_.outputs.end(),
platform::errors::NotFound(
"Operator (%s) does not have the outputs (%s).", op_.Type(), name));
return it->second;
}
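For orientation only (not part of this diff): a minimal sketch of how the accessors defined above compose at runtime. The helper name InferIdentityShape and the slot names "X"/"Out" are illustrative, assuming an op that simply forwards its input shape.

// Illustrative sketch: assumes the fluid framework headers and a
// RuntimeContext whose "X" and "Out" slots are already populated.
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace framework {

// Hypothetical helper, not a Paddle API: forwards the input shape.
void InferIdentityShape(const OperatorBase& op, const RuntimeContext& ctx) {
  RuntimeInferShapeContext infer_shape_ctx(op, ctx);
  if (infer_shape_ctx.HasInput("X") && infer_shape_ctx.HasOutput("Out")) {
    // GetInputDim/SetOutputDim resolve through GetDim/SetDim above, so they
    // work for both phi::DenseTensor and phi::SelectedRows variables.
    infer_shape_ctx.SetOutputDim("Out", infer_shape_ctx.GetInputDim("X"));
  }
}

}  // namespace framework
}  // namespace paddle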
void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
try {
VLOG(4) << place << " " << DebugStringEx(&scope);
if (platform::is_gpu_place(place)) {
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
PADDLE_THROW(platform::errors::Unavailable(
"Cannot run operator on place %s, please recompile paddle or "
"reinstall Paddle with CUDA support.",
place));
#else
auto dev_id = place.device;
platform::SetDeviceId(dev_id);
#endif
} else if (platform::is_xpu_place(place)) {
#ifndef PADDLE_WITH_XPU
PADDLE_THROW(platform::errors::Unavailable(
"Cannot run operator on place %s, please recompile paddle or "
"reinstall Paddle with XPU support.",
place));
#else
auto dev_id = place.device;
platform::SetXPUDeviceId(dev_id);
#endif
} else if (platform::is_npu_place(place)) {
#ifndef PADDLE_WITH_ASCEND_CL
PADDLE_THROW(platform::errors::Unavailable(
"Cannot run operator on place %s, please recompile paddle or "
"reinstall Paddle with NPU support.",
place));
#else
auto dev_id = place.device;
platform::SetNPUDeviceId(dev_id);
#endif
} else if (platform::is_mlu_place(place)) {
#ifndef PADDLE_WITH_MLU
PADDLE_THROW(platform::errors::Unavailable(
"Cannot run operator on place %s, please recompile paddle or "
"reinstall Paddle with MLU support.",
place));
#else
auto dev_id = place.device;
platform::SetMLUDeviceId(dev_id);
#endif
} else if (platform::is_custom_place(place)) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
PADDLE_THROW(platform::errors::Unavailable(
"Cannot run operator on place %s, please recompile paddle or "
"reinstall Paddle with CustomDevice support.",
place));
#else
phi::DeviceManager::SetDevice(place);
#endif
}
{
      // TODO(wangchaochaohu): refine the code to use only one RecordEvent.
      // To record both the per-op-type cost and the per-op-name cost, we
      // currently set up two events.
platform::RecordEvent op_type_record_event(
Type(), platform::TracerEventType::Operator, 1);
auto op_name = platform::OpName(outputs_, Type());
platform::RecordEvent op_name_record_event(
op_name,
platform::TracerEventType::Operator,
FLAGS_enable_host_event_recorder_hook ? 20 : 1,
platform::EventRole::kUniqueOp);
RunImpl(scope, place);
}
VLOG(3) << GetExecutionPlace(place) << " " << DebugStringEx(&scope);
} catch (platform::EnforceNotMet& exception) {
framework::InsertCallStackInfo(Type(), Attrs(), &exception);
throw std::move(exception);
} catch (platform::EOFException&) {
std::rethrow_exception(std::current_exception());
} catch (std::exception& ex) {
LOG(WARNING) << Type() << " raises an exception "
<< platform::demangle(typeid(ex).name()) << ", " << ex.what();
std::rethrow_exception(std::current_exception());
} catch (...) {
LOG(WARNING) << Type() << " raises an unknown exception";
std::rethrow_exception(std::current_exception());
}
}
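As a hedged usage sketch (not from this diff): driving the dispatch above from caller code; the op is assumed to be fully constructed and its variables resolvable from the scope.

// Sketch: running an operator through OperatorBase::Run on CPU.
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/place.h"

void RunOnce(paddle::framework::OperatorBase* op,
             paddle::framework::Scope* scope) {
  paddle::platform::CPUPlace place;
  // Run() selects the device for GPU/XPU/NPU/MLU/custom places, wraps the
  // call in profiler RecordEvents, forwards to RunImpl(), and annotates any
  // EnforceNotMet exception with the op's call stack before rethrowing.
  op->Run(*scope, place);
}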
bool OperatorBase::HasInputs(const std::string& name) const {
return inputs_.find(name) != inputs_.end();
}
std::string OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name);
PADDLE_ENFORCE_LE(
ins.size(),
1UL,
platform::errors::InvalidArgument(
"Operator %s's input %s should contain only one variable.",
type_,
name));
return ins.empty() ? kEmptyVarName : ins[0];
}
const std::vector<std::string>& OperatorBase::Inputs(
const std::string& name) const {
auto it = inputs_.find(name);
PADDLE_ENFORCE_NE(
it,
inputs_.end(),
platform::errors::NotFound(
"Operator %s does not have the input %s.", type_, name));
return it->second;
}
bool OperatorBase::HasOutputs(const std::string& name) const {
  return outputs_.find(name) != outputs_.end();
}
std::string OperatorBase::Output(const std::string& name) const {
auto& outs = Outputs(name);
PADDLE_ENFORCE_LE(
outs.size(),
1UL,
platform::errors::InvalidArgument(
"Operator %s's output %s should contain only one variable.",
type_,
name));
return outs.empty() ? kEmptyVarName : outs[0];
}
const std::vector<std::string>& OperatorBase::Outputs(
const std::string& name) const {
auto it = outputs_.find(name);
PADDLE_ENFORCE_NE(
it,
outputs_.end(),
platform::errors::NotFound(
"Operator %s does not have an output called %s.", type_, name));
return it->second;
}
std::string OperatorBase::DebugStringEx(const Scope* scope) const {
std::stringstream ss;
ss << "Op(" << type_ << "), inputs:{";
const std::unordered_set<std::string>* no_need_buffer_vars = nullptr;
if (info_ && info_->NoNeedBufferVarsInferer()) {
no_need_buffer_vars =
&(Info().NoNeedBufferVarsInferer()(Inputs(), Outputs(), Attrs()));
if (no_need_buffer_vars->empty()) no_need_buffer_vars = nullptr;
}
for (auto it = inputs_.begin(); it != inputs_.end();) {
auto& input = *it;
bool is_no_need_buffer_var =
(no_need_buffer_vars && no_need_buffer_vars->count(input.first) > 0);
ss << input.first << "[";
for (size_t i = 0; i < input.second.size(); ++i) {
auto var_name = input.second[i];
ss << var_name;
if (scope) {
if (!VarInited(*scope, var_name)) {
ss << "[uninited]";
} else {
int row_size = GetRowSize(*scope, var_name);
if (row_size >= 0) {
ss << "[row_size=" << row_size << "]";
}
std::string dtype = is_no_need_buffer_var
? "unknown_dtype"
: GetDtype(*scope, var_name);
std::string place = is_no_need_buffer_var
? "unknown_place"
: GetPlace(*scope, var_name);
ss << ":" << dtype;
ss << "[" << GetDimsDebug(*scope, var_name, true) << "]";
ss << "(" << GetLoDDebug(*scope, var_name) << ")";
ss << "(" << place << ")";
}
}
if (i != input.second.size() - 1) {
ss << ", ";
}
}
ss << "]";
++it;
if (it != inputs_.end()) {
ss << ", ";
}
}
ss << "}, outputs:{";
for (auto it = outputs_.begin(); it != outputs_.end();) {
auto& output = *it;
ss << output.first << "[";
for (size_t i = 0; i < output.second.size(); ++i) {
auto var_name = output.second[i];
ss << var_name;
if (scope) {
if (!VarInited(*scope, var_name)) {
ss << "[uninited]";
} else {
int row_size = GetRowSize(*scope, output.second[i]);
if (row_size >= 0) {
ss << "[row_size=" << row_size << "]";
}
std::string dtype = GetDtype(*scope, output.second[i]);
ss << ":" << dtype;
ss << "[" << GetDimsDebug(*scope, var_name, true) << "]";
ss << "(" << GetLoDDebug(*scope, var_name) << ")";
ss << "(" << GetPlace(*scope, var_name) << ")";
}
}
if (i != output.second.size() - 1) {
ss << ", ";
}
}
ss << "]";
++it;
if (it != outputs_.end()) {
ss << ", ";
}
}
ss << "}.";
return ss.str();
}
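A hedged illustration of the summary string assembled above; the operator, variable names, and values below are invented, but the layout follows the stream operations in DebugStringEx.

// Sketch: dumping an operator's I/O state. With a scope, the result
// resembles
//   Op(my_op), inputs:{X[x0:float[4, 8]({})(Place(cpu))]},
//   outputs:{Out[out0[uninited]]}.
// where row_size appears only for SelectedRows variables.
void DumpOpState(const paddle::framework::OperatorBase& op,
                 const paddle::framework::Scope& scope) {
  VLOG(3) << op.DebugStringEx(&scope);  // pass nullptr to skip var details
}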
OperatorBase::OperatorBase(const std::string& type,
const VariableNameMap& inputs,
const VariableNameMap& outputs,
const AttributeMap& attrs)
: type_(type),
inputs_(inputs),
outputs_(outputs),
attrs_(attrs),
// NOTE(zjl): why op_info may be nullptr?
info_(OpInfoMap::Instance().GetNullable(type)) {
// In dygraph mode, all the OperatorBase will be constructed by function:
// framework::OpRegistry::CreateOp(type, {}, {}, {}, false).
// Inputs, outputs and attrs will be set to empty map
// to improve the execution efficiency of dygraph.
if (inputs_.size() > 0 || outputs_.size() > 0) {
GenerateTemporaryNames();
CheckAllInputOutputSet();
}
// In OperatorBase level, all attributes with VarDesc type will be considered
// as Input.
for (auto& attr : FilterAttrVar(attrs)) {
VLOG(3) << "found Attribute with Variable type: " << attr.first;
inputs_[attr.first] = std::move(AttrVarNames(attr.second));
attrs_.erase(attr.first);
}
}
std::vector<std::string> OperatorBase::InputVars() const {
std::vector<std::string> ret_val;
for (auto& o : inputs_) {
ret_val.reserve(ret_val.size() + o.second.size());
ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
}
return ret_val;
}
std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
std::vector<std::string> ret_val;
if (has_intermediate) {
// push all outputs into ret_val
for (auto& o : outputs_) {
ret_val.reserve(ret_val.size() + o.second.size());
ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
}
return ret_val;
}
auto& info = Info();
// get all OpProto::Var for outputs
for (auto& o : info.Proto().outputs()) {
// ignore all intermediate output
if (o.intermediate()) continue;
auto out = outputs_.find(o.name());
if (out != outputs_.end()) {
ret_val.reserve(ret_val.size() + out->second.size());
ret_val.insert(ret_val.end(), out->second.begin(), out->second.end());
}
}
return ret_val;
}
void OperatorBase::CheckAllInputOutputSet() const {
if (info_ == nullptr || info_->proto_ == nullptr) return;
for (auto& in : info_->Proto().inputs()) {
if (!in.dispensable() && !in.extra()) {
PADDLE_ENFORCE_NE(
inputs_.find(in.name()),
inputs_.end(),
platform::errors::NotFound(
"Operator %s's input (%s) is not set.", Type(), in.name()));
}
  }
  for (auto& out : info_->Proto().outputs()) {
    if (!out.dispensable() && !out.extra() && !out.intermediate()) {
      PADDLE_ENFORCE_NE(
          outputs_.find(out.name()),
          outputs_.end(),
          platform::errors::NotFound(
              "Operator %s's output (%s) is not set.", Type(), out.name()));
    }
  }
}
void OperatorBase::GenerateTemporaryNames() {
  static std::atomic<size_t> gUniqId(0UL);
  for (auto& output : outputs_) {
    for (auto& output_name : output.second) {
      if (output_name == kTempVarName) {
        output_name += type_;
        output_name += "@";
        output_name += std::to_string(gUniqId.fetch_add(1));
      }
    }
  }
}
  proto::VarType::Type GetInputVarType(const std::string& name) const override {
    return GetVarType(InputVars(name).at(0));
  }
  std::vector<proto::VarType::Type> GetInputsVarType(
      const std::string& name) const override {
    return GetVarTypes(InputVars(name));
  }
  std::vector<proto::VarType::Type> GetOutputsVarType(
      const std::string& name) const override {
    return GetVarTypes(OutputVars(name));
  }
  void SetOutputDim(const std::string& name, const DDim& dim) override {
    auto& vars = OutputVars(name);
    PADDLE_ENFORCE_EQ(
        vars.size(),
        1UL,
        platform::errors::InvalidArgument("Output(%s) should hold one element, "
                                          "but now it holds %zu elements.",
                                          name,
                                          vars.size()));
    SetDim(vars[0], dim);
  }
  void SetOutputsDim(const std::string& name,
                     const std::vector<DDim>& dims) override {
    auto& vars = OutputVars(name);
    SetDims(vars, dims);
  }
  const phi::ArgumentMappingFn* GetPhiArgumentMappingFn() const override {
    return phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_.Type());
  }
  const phi::KernelSignature* GetPhiDefaultKernelSignature() const override {
    return &phi::DefaultKernelSignatureMap::Instance().Get(op_.Type());
  }
const phi::DenseTensor* GetLoDTensorOrSelectedRowsValueFromVar(
const Variable& var) {
if (var.IsType<phi::DenseTensor>()) {
return static_cast<const phi::DenseTensor*>(&(var.Get<phi::DenseTensor>()));
} else if (var.IsType<phi::SelectedRows>()) {
return &(var.Get<phi::SelectedRows>().value());
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Variable type is %s, expect phi::DenseTensor or SelectedRows.",
ToTypeName(var.Type())));
}
}
 protected:
  DDim GetDim(Variable* var) const {
    PADDLE_ENFORCE_NOT_NULL(
        var, platform::errors::InvalidArgument("Input variable is nullptr."));
    if (var->IsType<phi::DenseTensor>()) {
      return var->Get<phi::DenseTensor>().dims();
    } else if (var->IsType<phi::SelectedRows>()) {
      return var->Get<phi::SelectedRows>().GetCompleteDims();
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Only phi::DenseTensor or SelectedRows support 'GetDim', but input "
          "Variable's type is %s.",
          ToTypeName(var->Type())));
    }
  }
  std::vector<DDim> GetDims(const std::vector<Variable*>& vars) const {
    std::vector<DDim> ret;
    ret.reserve(vars.size());
    std::transform(vars.begin(),
                   vars.end(),
                   std::back_inserter(ret),
                   [this](Variable* var) { return this->GetDim(var); });
    return ret;
  }
  std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "GetRepeatedDims method can only be used at compile time."));
  }
  void SetDim(Variable* var, const DDim& dim) {
    if (var->IsType<phi::DenseTensor>()) {
      var->GetMutable<phi::DenseTensor>()->Resize(dim);
    } else if (var->IsType<phi::SelectedRows>()) {
      var->GetMutable<phi::SelectedRows>()->set_height(dim[0]);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Variable type error, expect phi::DenseTensor or SelectedRows, but "
          "received (%s).",
          ToTypeName(var->Type())));
    }
  }
  void SetDims(const std::vector<Variable*>& vars,
               const std::vector<DDim>& dims) {
    size_t length = vars.size();
    PADDLE_ENFORCE_EQ(length,
                      dims.size(),
                      platform::errors::InvalidArgument(
                          "The number of input variables does not match the "
                          "number of input dimensions, the number of variables "
                          "is %zu, the number of dimensions is %zu.",
                          length,
                          dims.size()));
    for (size_t i = 0; i < length; ++i) {
      if (vars[i] == nullptr) {
        continue;
      }
      SetDim(vars[i], dims[i]);
    }
  }
  void SetRepeatedDims(const std::string& name,
                       const std::vector<DDim>& dims) override {
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "SetRepeatedDims method can only be used at compile time."));
  }
  std::vector<proto::VarType::Type> GetVarTypes(
      const std::vector<Variable*>& vars) const {
    std::vector<proto::VarType::Type> retv;
    retv.resize(vars.size());
    std::transform(vars.begin(),
                   vars.end(),
                   retv.begin(),
                   std::bind(std::mem_fn(&RuntimeInferShapeContext::GetVarType),
                             this,
                             std::placeholders::_1));
    return retv;
  }
  proto::VarType::Type GetVarType(Variable* var) const {
    return ToVarType(var->Type());
  }
phi::DenseTensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var) {
  if (var->IsType<phi::DenseTensor>()) {
    return var->GetMutable<phi::DenseTensor>();
  } else if (var->IsType<phi::SelectedRows>()) {
    return var->GetMutable<phi::SelectedRows>()->mutable_value();
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Variable type is %s, expect phi::DenseTensor or SelectedRows.",
        ToTypeName(var->Type())));
  }
}
OperatorWithKernel::OperatorWithKernel(const std::string& type,
                                       const VariableNameMap& inputs,
                                       const VariableNameMap& outputs,
                                       const AttributeMap& attrs)
    : OperatorBase(type, inputs, outputs, attrs) {}
OperatorWithKernel::~OperatorWithKernel() = default;
bool ExecutionContext::HasInput(const std::string& name) const {
  auto* var = InputVar(name);
  return var != nullptr;
}
bool ExecutionContext::HasInputs(const std::string& name) const {
  const auto& ins = ctx_.inputs;
  auto it = ins.find(name);
  if (it == ins.end() || it->second.empty()) {
    return false;
  }
  for (const auto* input : it->second) {
    if (input == nullptr) {
      return false;
    }
  }
  return true;
}
bool ExecutionContext::HasOutput(const std::string& name) const {
  auto* var = OutputVar(name);
  return var != nullptr;
}
const Variable* ExecutionContext::InputVar(const std::string& name) const {
  LogVarUsageIfUnusedVarCheckEnabled(name);
  auto it = ctx_.inputs.find(name);
  if (it == ctx_.inputs.end()) return nullptr;
  PADDLE_ENFORCE_LE(
      it->second.size(),
      1UL,
      platform::errors::InvalidArgument(
          "Operator %s's input %s should contain only one variable.",
          op_.Type(),
          name));
  return it->second.empty() ? nullptr : it->second[0];
}
Variable* ExecutionContext::OutputVar(const std::string& name) const {
  auto it = ctx_.outputs.find(name);
  if (it == ctx_.outputs.end()) return nullptr;
  PADDLE_ENFORCE_LE(
      it->second.size(),
      1UL,
      platform::errors::InvalidArgument(
          "Operator %s's output %s should contain only one variable.",
          op_.Type(),
          name));
  return it->second.empty() ? nullptr : it->second[0];
}
template <>
const std::vector<const phi::DenseTensor*>
ExecutionContext::MultiInput<phi::DenseTensor>(const std::string& name) const {
  LogVarUsageIfUnusedVarCheckEnabled(name);
  auto vars = MultiInputVar(name);
  if (vars.size() == 0) {
    return {};
  }
  std::vector<const phi::DenseTensor*> res;
  res.reserve(vars.size());
  std::transform(vars.begin(),
                 vars.end(),
                 std::back_inserter(res),
                 [&](const Variable* var) -> const phi::DenseTensor* {
                   if (var == nullptr) return nullptr;
                   PADDLE_ENFORCE_EQ(
                       var->IsType<phi::DenseTensor>(),
                       true,
                       platform::errors::InvalidArgument(
                           "Input variable should be phi::DenseTensor, "
                           "but the received type is %s.",
                           ToTypeName(var->Type())));
                   return &(var->Get<phi::DenseTensor>());
                 });
  return res;
}
template <>
std::vector<phi::DenseTensor*> ExecutionContext::MultiOutput<phi::DenseTensor>(
    const std::string& name) const {
  auto vars = MultiOutputVar(name);
  if (vars.size() == 0) {
    return {};
  }
  std::vector<phi::DenseTensor*> res;
  res.reserve(vars.size());
  std::transform(vars.begin(),
                 vars.end(),
                 std::back_inserter(res),
                 [&](Variable* var) -> phi::DenseTensor* {
                   return var == nullptr ? nullptr
                                         : var->GetMutable<phi::DenseTensor>();
                 });
  return res;
}
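A minimal kernel-side sketch (the helper and the slot names "X"/"Out" are illustrative) showing how the two specializations above are typically consumed:

// Hypothetical kernel fragment, not a Paddle API.
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace framework {

void ResizeOutputsLikeInputs(const ExecutionContext& ctx) {
  auto xs = ctx.MultiInput<phi::DenseTensor>("X");     // entries may be null
  auto outs = ctx.MultiOutput<phi::DenseTensor>("Out");
  for (size_t i = 0; i < xs.size() && i < outs.size(); ++i) {
    if (xs[i] != nullptr && outs[i] != nullptr) {
      outs[i]->Resize(xs[i]->dims());  // sets shape only; no allocation here
    }
  }
}

}  // namespace framework
}  // namespace paddle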
 private:
  const std::vector<Variable*>& InputVars(const std::string& name) const {
    auto it = ctx_.inputs.find(name);
    PADDLE_ENFORCE_NE(
        it,
        ctx_.inputs.end(),
        platform::errors::NotFound(
            "Operator (%s) does not have the input (%s).", op_.Type(), name));
    return it->second;
  }
  const std::vector<Variable*>& OutputVars(const std::string& name) const {
    auto it = ctx_.outputs.find(name);
    PADDLE_ENFORCE_NE(
        it,
        ctx_.outputs.end(),
        platform::errors::NotFound(
            "Operator (%s) does not have the outputs (%s).", op_.Type(), name));
    return it->second;
  }
  const OperatorBase& op_;
  const RuntimeContext& ctx_;
};
bool OpSupportGPU(const std::string& op_type) {
  // Check the new phi function kernels first.
  bool has_phi_kernel = false;
  auto& kernel_factory = phi::KernelFactory::Instance();
  auto kernel_key_map =
      kernel_factory.SelectKernelMap(phi::TransToPhiKernelName(op_type));
  for (auto& kernel : kernel_key_map) {
    has_phi_kernel = true;
    if (platform::is_gpu_place(phi::TransToPhiPlace(kernel.first.backend()))) {
      return true;
    }
  }
  auto& all_kernels = OperatorWithKernel::AllOpKernels();
  auto it = all_kernels.find(op_type);
  if (it != all_kernels.end()) {
    for (auto& kern_pair : it->second) {
      if (platform::is_gpu_place(kern_pair.first.place_)) {
        return true;
      }
    }
  } else {
    if (has_phi_kernel) {
      // The op has phi kernels, but neither a phi GPU kernel nor a fluid GPU
      // kernel was found, so it does not support GPU.
      return false;
    } else {
      // All control operators must support GPU.
      return true;
    }
  }
  return false;
}
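A short sketch of the intended use (the function name and fallback policy are illustrative): callers can probe OpSupportGPU before committing to a CUDA place.

// Sketch: fall back to CPU when neither a phi nor a fluid GPU kernel exists.
#include <string>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/place.h"

paddle::platform::Place PickPlace(const std::string& op_type,
                                  const paddle::platform::Place& preferred) {
  if (paddle::platform::is_gpu_place(preferred) &&
      !paddle::framework::OpSupportGPU(op_type)) {
    return paddle::platform::CPUPlace();
  }
  return preferred;
}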
struct OperatorWithKernel::CacheImpl {
static const char kNotAllowInferShapeCahce[];
......
......@@ -34,6 +34,7 @@ limitations under the License. */
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/unused_var_check.h"
#include "paddle/fluid/memory/malloc.h"
......@@ -47,7 +48,6 @@ limitations under the License. */
namespace paddle {
namespace framework {
class InferShapeContext;
class OpInfo;
class Scope;
class Variable;
......@@ -146,6 +146,114 @@ class RuntimeContext {
VariableValueMap outputs;
};
class RuntimeInferShapeContext : public InferShapeContext {
public:
RuntimeInferShapeContext(const OperatorBase& op, const RuntimeContext& ctx);
bool HasInput(const std::string& name) const override;
bool HasOutput(const std::string& name) const override;
bool HasAttr(const std::string& name) const override;
bool HasInputs(const std::string& name) const override;
bool HasOutputs(const std::string& name,
bool allow_null = false) const override;
AttrReader Attrs() const override;
std::vector<std::string> Inputs(const std::string& name) const override;
std::vector<std::string> Outputs(const std::string& name) const override;
std::string GetInputNameByIdx(size_t idx) const override;
std::string GetOutputNameByIdx(size_t idx) const override;
void ShareDim(const std::string& in,
const std::string& out,
size_t i = 0,
size_t j = 0) override;
void ShareAllLoD(const std::string& in,
const std::string& out) const override;
void ShareLoD(const std::string& in,
const std::string& out,
size_t i = 0,
size_t j = 0) const override;
int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override;
void SetLoDLevel(const std::string& out,
int32_t lod_level,
size_t j = 0) const override;
bool IsRuntime() const override;
bool IsRunMKLDNNKernel() const override;
// TODO(paddle-dev): Can this be template?
paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize>
GetInputVarPtrs(const std::string& name) const override;
paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
GetOutputVarPtrs(const std::string& name) const override;
DDim GetInputDim(const std::string& name) const override;
std::vector<DDim> GetInputsDim(const std::string& name) const override;
proto::VarType::Type GetInputVarType(const std::string& name) const override;
std::vector<proto::VarType::Type> GetInputsVarType(
const std::string& name) const override;
std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string& name) const override;
void SetOutputDim(const std::string& name, const DDim& dim) override;
void SetOutputsDim(const std::string& name,
const std::vector<DDim>& dims) override;
const phi::ArgumentMappingFn* GetPhiArgumentMappingFn() const override;
const phi::KernelSignature* GetPhiDefaultKernelSignature() const override;
void SetSkipLoD(bool skip);
protected:
DDim GetDim(Variable* var) const;
std::vector<DDim> GetDims(const std::vector<Variable*>& vars) const;
std::vector<DDim> GetRepeatedDims(const std::string& name) const override;
void SetDim(Variable* var, const DDim& dim);
void SetDims(const std::vector<Variable*>& vars,
const std::vector<DDim>& dims);
void SetRepeatedDims(const std::string& name,
const std::vector<DDim>& dims) override;
std::vector<proto::VarType::Type> GetVarTypes(
const std::vector<Variable*>& vars) const;
proto::VarType::Type GetVarType(Variable* var) const;
private:
const std::vector<Variable*>& InputVars(const std::string& name) const;
const std::vector<Variable*>& OutputVars(const std::string& name) const;
const OperatorBase& op_;
const RuntimeContext& ctx_;
bool can_skip_lod_{false};
};
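Not part of this diff: a sketch of the typical consumer of this interface, an operator's InferShape override (the op and its slot names are hypothetical). At runtime such an override receives the RuntimeInferShapeContext declared above; at compile time it receives a compile-time context instead, which is why GetRepeatedDims/SetRepeatedDims throw in this class.

// Hypothetical operator written against the InferShapeContext interface.
class IdentityLikeOp : public OperatorWithKernel {
 public:
  using OperatorWithKernel::OperatorWithKernel;

  void InferShape(InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"),
                      true,
                      platform::errors::NotFound(
                          "Input(X) of identity_like op is not set."));
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    ctx->ShareLoD("X", /*out=*/"Out");  // at runtime, copies LoD across vars
  }
};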
/**
* OperatorBase has the basic elements that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User
......