diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc index f0c128d554b296d7fe5c6818d3911aaee5c0adce..73b3051235ee90b31bd65acb22f454fc13d64da9 100644 --- a/paddle/framework/net.cc +++ b/paddle/framework/net.cc @@ -11,7 +11,7 @@ void PlainNet::InferShape(Scope* scope) { } } -void PlainNet::Run(Scope* scope, DeviceContext* ctx) { +void PlainNet::Run(std::shared_ptr scope, DeviceContext* ctx) { for (auto& op : ops_) { op.Run(ctx); } diff --git a/paddle/framework/net.h b/paddle/framework/net.h index b2894320dafdfaf9b8e0bffc8c863a2caae35a61..76992e07282904fd1074bb0ced2367a8d20e3ec2 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -69,7 +69,7 @@ class Net { * environment for ops. `begin` and `end` specify the scope of `ops_` to run, * If no positive indexes are provided, all operators in `ops_` will run. */ - virtual void Run(Scope *scope, DeviceContext *ctx) = 0; + virtual void Run(std::shared_ptr scope, DeviceContext *ctx) = 0; /** * @brief Add an Operator according to `def`. @@ -123,7 +123,7 @@ class PlainNet : public Net { * scope will be used instead. If no OpContext is provicded, default context * will be used. */ - virtual void Run(Scope *scope, DeviceContext *ctx) override; + virtual void Run(std::shared_ptr scope, DeviceContext *ctx) override; /** * @brief Add an operator to this network.