未验证 提交 5c1eda19 编写于 作者: R Ruibiao Chen 提交者: GitHub

Remove InterpretercoreInferShapeContext (#51209)

* Remove InterpretercoreInferShapeContext

* Fix lod errors
上级 f1f2a253
...@@ -129,7 +129,7 @@ void DataTranferHelper::RunAndConstructOpFuncNode( ...@@ -129,7 +129,7 @@ void DataTranferHelper::RunAndConstructOpFuncNode(
RuntimeContext runtime_context({}, {}); RuntimeContext runtime_context({}, {});
runtime_context.inputs["X"] = {scope_->FindVar(var_name)}; runtime_context.inputs["X"] = {scope_->FindVar(var_name)};
runtime_context.outputs["Out"] = {scope_->Var(new_var_name)}; runtime_context.outputs["Out"] = {scope_->Var(new_var_name)};
InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context); RuntimeInferShapeContext infer_shape_ctx(*op, runtime_context);
op.get()->Info().infer_shape_(&infer_shape_ctx); op.get()->Info().infer_shape_(&infer_shape_ctx);
// 2. choose kernel // 2. choose kernel
......
...@@ -824,8 +824,7 @@ void BuildOpFuncList(const platform::Place& place, ...@@ -824,8 +824,7 @@ void BuildOpFuncList(const platform::Place& place,
if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) && if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) &&
op->Attr<bool>(kAllKernelsMustComputeRuntimeShape))) { op->Attr<bool>(kAllKernelsMustComputeRuntimeShape))) {
VLOG(4) << "infer shape"; VLOG(4) << "infer shape";
InterpretercoreInferShapeContext infer_shape_ctx(*op, RuntimeInferShapeContext infer_shape_ctx(*op, runtime_context);
runtime_context);
// TODO(Aurelius84): In case of control flow ops, they are NOT // TODO(Aurelius84): In case of control flow ops, they are NOT
// inheritted from OperatorWithKernel. // inheritted from OperatorWithKernel.
op_with_kernel->Info().infer_shape_(&infer_shape_ctx); op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
......
...@@ -44,115 +44,6 @@ constexpr const char* kH2DStream = "H2DStream"; ...@@ -44,115 +44,6 @@ constexpr const char* kH2DStream = "H2DStream";
constexpr int kEmptyVarIndex = 0; constexpr int kEmptyVarIndex = 0;
// InferShapeContext implementation used by the interpreter core (new
// executor) to run an operator's InferShape function against a prepared
// RuntimeContext. It adapts the (op, runtime ctx) pair to the generic
// InferShapeContext interface; all queries below resolve input/output
// slot names through the bound RuntimeContext.
// NOTE(review): this mirrors RuntimeInferShapeContext declared in
// operator.h — keep the two declarations in sync.
class InterpretercoreInferShapeContext : public InferShapeContext {
 public:
  // Binds the context to an operator and its runtime variable maps.
  // Both references are borrowed; they must outlive this object.
  InterpretercoreInferShapeContext(const OperatorBase& op,
                                   const RuntimeContext& ctx);

  // Presence checks for input/output slots and attributes, by slot name.
  bool HasInput(const std::string& name) const override;
  bool HasOutput(const std::string& name) const override;
  bool HasAttr(const std::string& name) const override;
  bool HasInputs(const std::string& name) const override;
  // allow_null: when true, a slot whose variables are all empty still
  // counts as present (matches the base-class contract).
  bool HasOutputs(const std::string& name,
                  bool allow_null = false) const override;

  AttrReader Attrs() const override;

  // Variable names bound to the given input/output slot.
  std::vector<std::string> Inputs(const std::string& name) const override;
  std::vector<std::string> Outputs(const std::string& name) const override;
  // Slot name lookup by positional index in the op signature.
  std::string GetInputNameByIdx(size_t idx) const override;
  std::string GetOutputNameByIdx(size_t idx) const override;

  // Copies the dim of input `in`[i] to output `out`[j].
  void ShareDim(const std::string& in,
                const std::string& out,
                size_t i = 0,
                size_t j = 0) override;

  // LoD (level-of-detail) propagation between slots.
  void ShareAllLoD(const std::string& in,
                   const std::string& out) const override;
  void ShareLoD(const std::string& in,
                const std::string& out,
                size_t i = 0,
                size_t j = 0) const override;
  int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override;
  void SetLoDLevel(const std::string& out,
                   int32_t lod_level,
                   size_t j = 0) const override;

  // Always a runtime (not compile-time/static-graph) context here;
  // see the definition for the concrete return values.
  bool IsRuntime() const override;
  bool IsRunMKLDNNKernel() const override;

  // TODO(paddle-dev): Can this be template?
  paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize>
  GetInputVarPtrs(const std::string& name) const override;
  paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
  GetOutputVarPtrs(const std::string& name) const override;

  // Dim / var-type accessors for single and multi-variable slots.
  DDim GetInputDim(const std::string& name) const override;
  std::vector<DDim> GetInputsDim(const std::string& name) const override;
  proto::VarType::Type GetInputVarType(const std::string& name) const override;
  std::vector<proto::VarType::Type> GetInputsVarType(
      const std::string& name) const override;
  std::vector<proto::VarType::Type> GetOutputsVarType(
      const std::string& name) const override;

  void SetOutputDim(const std::string& name, const DDim& dim) override;
  void SetOutputsDim(const std::string& name,
                     const std::vector<DDim>& dims) override;

  // Phi kernel dispatch hooks.
  const phi::ArgumentMappingFn* GetPhiArgumentMappingFn() const override;
  const phi::KernelSignature* GetPhiDefaultKernelSignature() const override;

  // When set, LoD propagation (ShareLoD etc.) may be skipped; see the
  // definition for the exact effect.
  void SetSkipLoD(bool skip);

 protected:
  DDim GetDim(Variable* var) const;
  std::vector<DDim> GetDims(const std::vector<Variable*>& vars) const;
  std::vector<DDim> GetRepeatedDims(const std::string& name) const override;

  void SetDim(Variable* var, const DDim& dim);
  void SetDims(const std::vector<Variable*>& vars,
               const std::vector<DDim>& dims);
  void SetRepeatedDims(const std::string& name,
                       const std::vector<DDim>& dims) override;

  std::vector<proto::VarType::Type> GetVarTypes(
      const std::vector<Variable*>& vars) const;
  proto::VarType::Type GetVarType(Variable* var) const;

 private:
  const std::vector<Variable*>& InputVars(const std::string& name) const;
  const std::vector<Variable*>& OutputVars(const std::string& name) const;

  const OperatorBase& op_;     // not owned
  const RuntimeContext& ctx_;  // not owned
  // Default-initialized to false so the flag is never read indeterminate;
  // matches the equivalent member in RuntimeInferShapeContext.
  bool can_skip_lod_{false};
};
struct OpKernelFunc { struct OpKernelFunc {
OpKernelComputeFunc compute_func_; OpKernelComputeFunc compute_func_;
}; };
...@@ -260,7 +151,6 @@ enum class OpFuncType { ...@@ -260,7 +151,6 @@ enum class OpFuncType {
kGpuSync, // GPU or other device kernel without asynchronous operation kGpuSync, // GPU or other device kernel without asynchronous operation
kGpuAsync // GPU or other device kernel with asynchronous operation kGpuAsync // GPU or other device kernel with asynchronous operation
}; };
class RuntimeInferShapeContext;
struct OpFuncNode { struct OpFuncNode {
int stream_priority_{0}; // lower value, higher priority int stream_priority_{0}; // lower value, higher priority
...@@ -357,8 +247,7 @@ class Instruction { ...@@ -357,8 +247,7 @@ class Instruction {
std::shared_ptr<RuntimeContext> InnerRuntimeContext() const; std::shared_ptr<RuntimeContext> InnerRuntimeContext() const;
std::shared_ptr<InterpretercoreInferShapeContext> InnerInferShapeContext() std::shared_ptr<RuntimeInferShapeContext> InnerInferShapeContext() const;
const;
std::shared_ptr<ExecutionContext> InnerExecutionContext() const; std::shared_ptr<ExecutionContext> InnerExecutionContext() const;
...@@ -390,7 +279,7 @@ class Instruction { ...@@ -390,7 +279,7 @@ class Instruction {
const platform::DeviceContext& dev_ctx_; // not owned const platform::DeviceContext& dev_ctx_; // not owned
std::shared_ptr<RuntimeContext> runtime_ctx_; std::shared_ptr<RuntimeContext> runtime_ctx_;
std::shared_ptr<InterpretercoreInferShapeContext> infershape_ctx_; std::shared_ptr<RuntimeInferShapeContext> infershape_ctx_;
std::shared_ptr<ExecutionContext> execution_ctx_; std::shared_ptr<ExecutionContext> execution_ctx_;
std::vector<size_t> gc_check_vars_; std::vector<size_t> gc_check_vars_;
......
此差异已折叠。
...@@ -34,6 +34,7 @@ limitations under the License. */ ...@@ -34,6 +34,7 @@ limitations under the License. */
#include "paddle/fluid/framework/phi_utils.h" #include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows_utils.h" #include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/unused_var_check.h" #include "paddle/fluid/framework/unused_var_check.h"
#include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/memory/malloc.h"
...@@ -47,7 +48,6 @@ limitations under the License. */ ...@@ -47,7 +48,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class InferShapeContext;
class OpInfo; class OpInfo;
class Scope; class Scope;
class Variable; class Variable;
...@@ -146,6 +146,114 @@ class RuntimeContext { ...@@ -146,6 +146,114 @@ class RuntimeContext {
VariableValueMap outputs; VariableValueMap outputs;
}; };
// InferShapeContext implementation backed by a live RuntimeContext: it
// answers shape/type/LoD queries for an operator's input and output slots
// by resolving slot names through the bound RuntimeContext variable maps.
class RuntimeInferShapeContext : public InferShapeContext {
public:
// Binds the context to an operator and its runtime variable maps.
// Both references are borrowed; they must outlive this object.
RuntimeInferShapeContext(const OperatorBase& op, const RuntimeContext& ctx);
// Presence checks for input/output slots and attributes, by slot name.
bool HasInput(const std::string& name) const override;
bool HasOutput(const std::string& name) const override;
bool HasAttr(const std::string& name) const override;
bool HasInputs(const std::string& name) const override;
// allow_null: when true, a slot whose variables are all empty still
// counts as present (matches the base-class contract).
bool HasOutputs(const std::string& name,
bool allow_null = false) const override;
AttrReader Attrs() const override;
// Variable names bound to the given input/output slot.
std::vector<std::string> Inputs(const std::string& name) const override;
std::vector<std::string> Outputs(const std::string& name) const override;
// Slot name lookup by positional index in the op signature.
std::string GetInputNameByIdx(size_t idx) const override;
std::string GetOutputNameByIdx(size_t idx) const override;
// Copies the dim of input `in`[i] to output `out`[j].
void ShareDim(const std::string& in,
const std::string& out,
size_t i = 0,
size_t j = 0) override;
// LoD (level-of-detail) propagation between slots.
void ShareAllLoD(const std::string& in,
const std::string& out) const override;
void ShareLoD(const std::string& in,
const std::string& out,
size_t i = 0,
size_t j = 0) const override;
int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override;
void SetLoDLevel(const std::string& out,
int32_t lod_level,
size_t j = 0) const override;
// Runtime (not static-graph) context; see the definitions for the
// concrete return values.
bool IsRuntime() const override;
bool IsRunMKLDNNKernel() const override;
// TODO(paddle-dev): Can this be template?
paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize>
GetInputVarPtrs(const std::string& name) const override;
paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
GetOutputVarPtrs(const std::string& name) const override;
// Dim / var-type accessors for single and multi-variable slots.
DDim GetInputDim(const std::string& name) const override;
std::vector<DDim> GetInputsDim(const std::string& name) const override;
proto::VarType::Type GetInputVarType(const std::string& name) const override;
std::vector<proto::VarType::Type> GetInputsVarType(
const std::string& name) const override;
std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string& name) const override;
void SetOutputDim(const std::string& name, const DDim& dim) override;
void SetOutputsDim(const std::string& name,
const std::vector<DDim>& dims) override;
// Phi kernel dispatch hooks.
const phi::ArgumentMappingFn* GetPhiArgumentMappingFn() const override;
const phi::KernelSignature* GetPhiDefaultKernelSignature() const override;
// When set, LoD propagation (ShareLoD etc.) may be skipped; see the
// definition for the exact effect.
void SetSkipLoD(bool skip);
protected:
DDim GetDim(Variable* var) const;
std::vector<DDim> GetDims(const std::vector<Variable*>& vars) const;
std::vector<DDim> GetRepeatedDims(const std::string& name) const override;
void SetDim(Variable* var, const DDim& dim);
void SetDims(const std::vector<Variable*>& vars,
const std::vector<DDim>& dims);
void SetRepeatedDims(const std::string& name,
const std::vector<DDim>& dims) override;
std::vector<proto::VarType::Type> GetVarTypes(
const std::vector<Variable*>& vars) const;
proto::VarType::Type GetVarType(Variable* var) const;
private:
const std::vector<Variable*>& InputVars(const std::string& name) const;
const std::vector<Variable*>& OutputVars(const std::string& name) const;
const OperatorBase& op_;   // not owned
const RuntimeContext& ctx_;  // not owned
// Defaults to false so the flag is never read indeterminate.
bool can_skip_lod_{false};
};
/** /**
* OperatorBase has the basic elements that Net will call to do computation. * OperatorBase has the basic elements that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User * Only CreateOperator from OpRegistry will new Operator directly. User
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册