diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index a62afe248baa2ab57edabfaac05bb858a1a1280c..86e1713b021e391f2cd00e38e6070336379901b7 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -703,8 +703,6 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope, void OperatorWithKernel::RunImpl(const Scope& scope, const platform::Place& place) const { - RuntimeInferShapeContext infer_shape_ctx(*this, scope); - this->InferShape(&infer_shape_ctx); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); @@ -758,6 +756,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, dev_ctx = pool.Get(expected_kernel_key.place_); } + RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope); + this->InferShape(&infer_shape_ctx); kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx)); if (!transfered_inplace_vars.empty()) {