diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index 341da8befd45abd1a3fc86581be33319a8791567..b22e06cc79b41cdfe2c126ae854865d2a77d6eeb 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -25,7 +25,10 @@ namespace framework { class CosineOp : public OperatorBase { public: using OperatorBase::OperatorBase; - void Run(const Scope& scope, const platform::Place& place) const override {} + + private: + void RunImpl(const Scope& scope, + const platform::Place& place) const override {} }; class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { @@ -44,7 +47,10 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class MyTestOp : public OperatorBase { public: using OperatorBase::OperatorBase; - void Run(const Scope& scope, const platform::Place& place) const override {} + + private: + void RunImpl(const Scope& scope, + const platform::Place& place) const override {} }; class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 52387aabd9d0b41b13814499fb3f0107f42401e7..240a0602c9fe587d807cc80da70d1926031f2aea 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -64,6 +64,18 @@ static LoD GetLoD(const Scope& scope, const std::string& name) { } } +void OperatorBase::Run(const Scope& scope, const platform::Place& place) { + if (platform::is_gpu_place(place)) { +#ifndef PADDLE_WITH_CUDA + PADDLE_THROW("Cannot run operator on place %s", place); +#else + auto dev_id = boost::get(place).device; + platform::SetDeviceId(dev_id); +#endif + } + RunImpl(scope, place); +} + std::string OperatorBase::Input(const std::string& name) const { auto& ins = Inputs(name); PADDLE_ENFORCE_LE(ins.size(), 1UL, @@ -475,8 +487,8 @@ class RuntimeInferShapeContext : public InferShapeContext { const Scope& scope_; }; -void OperatorWithKernel::Run(const Scope& scope, - const platform::Place& place) const { +void OperatorWithKernel::RunImpl(const Scope& scope, + const platform::Place& place) const { RuntimeInferShapeContext infer_shape_ctx(*this, scope); this->InferShape(&infer_shape_ctx); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index c9140f304c89e32a0fa8bd24722cc2e5dbc4e2e1..886e373348c1f4d1d7c8ad1e6557a7436a242de3 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -89,8 +89,9 @@ class OperatorBase { std::string DebugString() const { return DebugStringEx(nullptr); } - /// Net will call this function to Run an op. - virtual void Run(const Scope& scope, const platform::Place& place) const = 0; + /// Net will call this interface function to Run an op. + // The implementation should be written at RunImpl + void Run(const Scope& scope, const platform::Place& place); // FIXME(typhoonzero): this is only used for recv_op to stop event_loop. virtual void Stop() {} @@ -144,6 +145,8 @@ class OperatorBase { private: void GenerateTemporaryNames(); void CheckAllInputOutputSet() const; + virtual void RunImpl(const Scope& scope, + const platform::Place& place) const = 0; }; // Macro for define a clone method. @@ -168,10 +171,13 @@ class OperatorBase { class NOP : public OperatorBase { public: using OperatorBase::OperatorBase; - void Run(const Scope& scope, const platform::Place& place) const override {} std::unique_ptr Clone() const override { return std::unique_ptr(new NOP(*this)); } + + private: + void RunImpl(const Scope& scope, + const platform::Place& place) const override {} }; class ExecutionContext { @@ -363,8 +369,6 @@ class OperatorWithKernel : public OperatorBase { const VariableNameMap& outputs, const AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const Scope& scope, const platform::Place& place) const final; - static std::unordered_map& AllOpKernels() { static std::unordered_map g_all_op_kernels; @@ -393,6 +397,7 @@ class OperatorWithKernel : public OperatorBase { // indicate kernel DataType by input data. Defaultly all input data must be // same. proto::DataType IndicateDataType(const ExecutionContext& ctx) const; + void RunImpl(const Scope& scope, const platform::Place& place) const final; }; extern bool OpSupportGPU(const std::string& op_type); diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index b69d7c7a7406eb3e18d385c568cb9c21b9b4107b..7100e64732651504a7a4ceb042d2a029d955aa98 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -28,7 +28,10 @@ class OpWithoutKernelTest : public OperatorBase { OpWithoutKernelTest(const std::string& type, const VariableNameMap& inputs, const VariableNameMap& outputs, const AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs), x(1) {} - void Run(const Scope& scope, const platform::Place& place) const override { + + private: + void RunImpl(const Scope& scope, + const platform::Place& place) const override { ++op_run_num; ASSERT_EQ(static_cast(inputs_.size()), 1); ASSERT_EQ(static_cast(outputs_.size()), 1); @@ -259,8 +262,10 @@ class OperatorClone : public paddle::framework::OperatorBase { const paddle::framework::VariableNameMap& outputs, const paddle::framework::AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const paddle::framework::Scope& scope, - const paddle::platform::Place& place) const override {} + + private: + void RunImpl(const paddle::framework::Scope& scope, + const paddle::platform::Place& place) const override {} }; TEST(Operator, Clone) { diff --git a/paddle/operators/array_to_lod_tensor_op.cc b/paddle/operators/array_to_lod_tensor_op.cc index ba5c6bd3c681b4ae4f612da96df866227961df3d..3b9ebae1537df2ab8fa7a6ee2e51f7d9249ceef3 100644 --- a/paddle/operators/array_to_lod_tensor_op.cc +++ b/paddle/operators/array_to_lod_tensor_op.cc @@ -31,8 +31,10 @@ class ArrayToLoDTensorOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { auto &x = scope.FindVar(Input("X"))->Get(); auto &rank_table = scope.FindVar(Input("RankTable"))->Get(); diff --git a/paddle/operators/assign_op.cc b/paddle/operators/assign_op.cc index e04aa2d28cff7b106b30304bfa19ba18e2affd21..0d1ce62bd6e642694f19f9878abc715bf2e96f8a 100644 --- a/paddle/operators/assign_op.cc +++ b/paddle/operators/assign_op.cc @@ -71,8 +71,10 @@ class AssignOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto *x = scope.FindVar(Input("X")); if (x == nullptr) { return; diff --git a/paddle/operators/beam_search_decode_op.cc b/paddle/operators/beam_search_decode_op.cc index 72e05607b0b612807d552b4c45b58f9d9ce9c2af..a1b443042579c1e3e38fd009d0496bcaa6a5ce6b 100644 --- a/paddle/operators/beam_search_decode_op.cc +++ b/paddle/operators/beam_search_decode_op.cc @@ -55,8 +55,10 @@ class BeamSearchDecodeOp : public framework::OperatorBase { const framework::VariableNameMap& outputs, const framework::AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope& scope, - const platform::Place& dev_place) const override { + + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& dev_ctx = *pool.Get(dev_place); diff --git a/paddle/operators/beam_search_op.h b/paddle/operators/beam_search_op.h index 7ad85874fcbd6ea48d688b32f2cc982d6b76d3c4..8d62e71565d51fabfda8fd70a7265ad2e2510b1a 100644 --- a/paddle/operators/beam_search_op.h +++ b/paddle/operators/beam_search_op.h @@ -204,8 +204,9 @@ class BeamSearchOp : public framework::OperatorBase { PADDLE_THROW("Not Implemented"); } - void Run(const framework::Scope& scope, - const platform::Place& dev_place) const override { + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { auto ids_var = scope.FindVar(Input("ids")); auto scores_var = scope.FindVar(Input("scores")); auto pre_ids_var = scope.FindVar(Input("pre_ids")); diff --git a/paddle/operators/cond_op.cc b/paddle/operators/cond_op.cc index e333002bfd1ab40c62882f09cd207a12a0939648..28bac0b7bed1f67a8a3a911269eb3db30ceaa210 100644 --- a/paddle/operators/cond_op.cc +++ b/paddle/operators/cond_op.cc @@ -193,7 +193,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope, } } -void CondOp::Run(const Scope& scope, const platform::Place& place) const { +void CondOp::RunImpl(const Scope& scope, const platform::Place& place) const { // get device context from pool platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& dev_ctx = *pool.Get(place); diff --git a/paddle/operators/cond_op.h b/paddle/operators/cond_op.h index 7dcdc47e0b2ff216bea92d083fe5897009384d39..2dc0e23301244bef2734faa30baf8d0480d776a6 100644 --- a/paddle/operators/cond_op.h +++ b/paddle/operators/cond_op.h @@ -77,8 +77,9 @@ class CondOp : public framework::OperatorBase { sub_net_op_[FALSE_BRANCH] = std::move(net); } - void Run(const framework::Scope& scope, - const platform::Place& place) const override; + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& place) const override; private: const int TRUE_BRANCH = 0; diff --git a/paddle/operators/conditional_block_op.cc b/paddle/operators/conditional_block_op.cc index bdcdb85be7a94a748961048ac97e69f2f3b78677..f7572ccfaf52e2202458f632f471d92d43f80f14 100644 --- a/paddle/operators/conditional_block_op.cc +++ b/paddle/operators/conditional_block_op.cc @@ -65,8 +65,10 @@ class ConditionalBlockOp : public ConditionalOp { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : ConditionalOp(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { auto xs = InputTensors(scope); bool need_run; @@ -128,8 +130,10 @@ class ConditionalBlockGradOp : public ConditionalOp { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : ConditionalOp(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { auto xs = this->InputTensors(scope); bool need_run; diff --git a/paddle/operators/create_reader_op.cc b/paddle/operators/create_reader_op.cc index 5ba2a25ab4c679f638e928a9e04c20d683a93630..66fd132b3a3b06800f0f43eb55f2f9699d77b1b1 100644 --- a/paddle/operators/create_reader_op.cc +++ b/paddle/operators/create_reader_op.cc @@ -72,8 +72,10 @@ template class CreateRandomDataGeneratorOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; - void Run(const framework::Scope& scope, - const platform::Place& dev_place) const override { + + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { const auto& shape_concat = Attr>("shape_concat"); const auto& ranks = Attr>("ranks"); PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty()); @@ -120,8 +122,10 @@ class CreateRandomDataGeneratorOpMaker class CreateShuffleReaderOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; - void Run(const framework::Scope& scope, - const platform::Place& dev_place) const override { + + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); auto* out = scope.FindVar(Output("Out")) @@ -152,8 +156,10 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker { class CreateBatchReaderOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; - void Run(const framework::Scope& scope, - const platform::Place& dev_place) const override { + + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); auto* out = scope.FindVar(Output("Out")) diff --git a/paddle/operators/feed_op.cc b/paddle/operators/feed_op.cc index 789d01e0022b5c36957f295265a9dc42649b310f..3f6f8a589d7ad1314a919f371d81ffdd9eabe2e9 100644 --- a/paddle/operators/feed_op.cc +++ b/paddle/operators/feed_op.cc @@ -24,8 +24,10 @@ class FeedOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto feed_var_name = Input("X"); auto *feed_var = scope.FindVar(feed_var_name); diff --git a/paddle/operators/fetch_op.cc b/paddle/operators/fetch_op.cc index 7205ee2a879dfff711ad1cabebe197ef53377a1c..bb4b7356e78584778ffcf45315616ac0c57bcb44 100644 --- a/paddle/operators/fetch_op.cc +++ b/paddle/operators/fetch_op.cc @@ -26,8 +26,9 @@ class FetchOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto fetch_var_name = Input("X"); auto *fetch_var = scope.FindVar(fetch_var_name); PADDLE_ENFORCE(fetch_var != nullptr, diff --git a/paddle/operators/fill_constant_op.cc b/paddle/operators/fill_constant_op.cc index dcd43a30c86b62d79f52ac640f14b295a062146c..ce4e7bf7f24c275c0b42198a6a7a63adc788876e 100644 --- a/paddle/operators/fill_constant_op.cc +++ b/paddle/operators/fill_constant_op.cc @@ -33,8 +33,10 @@ class FillConstantInferShape : public framework::InferShapeBase { class FillConstantOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { auto data_type = static_cast(Attr("dtype")); auto value = Attr("value"); diff --git a/paddle/operators/fill_op.cc b/paddle/operators/fill_op.cc index 4f5a2ed169565771629fe8df7c25cf23bc94e339..bc72a189026d5abdf2fd4e30b9c25d2b62804ad3 100644 --- a/paddle/operators/fill_op.cc +++ b/paddle/operators/fill_op.cc @@ -42,8 +42,10 @@ class FillOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto &out = detail::Ref(detail::Ref(scope.FindVar(Output("Out")), "Cannot find variable %s", Output("Out")) diff --git a/paddle/operators/get_places_op.cc b/paddle/operators/get_places_op.cc index 24fafb23074c358669715f1840246c3520f948c7..a7168a10796487cd366d0a6d3a88db0f221cb655 100644 --- a/paddle/operators/get_places_op.cc +++ b/paddle/operators/get_places_op.cc @@ -37,8 +37,10 @@ class GetPlacesOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { bool is_gpu; if (Attr("device_type") == "AUTO") { is_gpu = platform::is_gpu_place(place); diff --git a/paddle/operators/increment_op.cc b/paddle/operators/increment_op.cc index e0b80cc4e74429dee1b9a25e41b116970ad4de2a..adc7e8f1a469300aaf2578d4c79e2e12d0b6a45f 100644 --- a/paddle/operators/increment_op.cc +++ b/paddle/operators/increment_op.cc @@ -51,8 +51,9 @@ class IncrementOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto &x = scope.FindVar(Input("X"))->Get(); auto &out = *scope.FindVar(Output("Out"))->GetMutable(); diff --git a/paddle/operators/is_empty_op.cc b/paddle/operators/is_empty_op.cc index 492ae48845aa5aa123989e62d07f5ae899af6193..1de3437b0c24dd5cfa5f4575ace7f51a42057919 100644 --- a/paddle/operators/is_empty_op.cc +++ b/paddle/operators/is_empty_op.cc @@ -28,8 +28,9 @@ class IsEmptyOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { // get input auto *var = scope.FindVar(Input(kInput)); PADDLE_ENFORCE_NOT_NULL(var); diff --git a/paddle/operators/load_combine_op.cc b/paddle/operators/load_combine_op.cc index f4be793d7bf1f346c011842c57fb5b5179a697d6..13b1c5da90b62e89cc40c82791e22604318a5de3 100644 --- a/paddle/operators/load_combine_op.cc +++ b/paddle/operators/load_combine_op.cc @@ -26,8 +26,10 @@ class LoadCombineOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto filename = Attr("file_path"); std::ifstream fin(filename); diff --git a/paddle/operators/load_op.cc b/paddle/operators/load_op.cc index f886b423ac7cb89961d1fdb5c6d3776ccafcaf60..88d0cc725d6a22a462bde7a747f5c67543229786 100644 --- a/paddle/operators/load_op.cc +++ b/paddle/operators/load_op.cc @@ -25,8 +25,10 @@ class LoadOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto filename = Attr("file_path"); std::ifstream fin(filename); PADDLE_ENFORCE(static_cast(fin), "Cannot open file %s for load op", diff --git a/paddle/operators/lod_array_length_op.cc b/paddle/operators/lod_array_length_op.cc index d2c52745cfdf8d0fdb168ef2d90e75a515c31015..aa18aa2646102071f584bd84b72cbb9e2b9b98c9 100644 --- a/paddle/operators/lod_array_length_op.cc +++ b/paddle/operators/lod_array_length_op.cc @@ -25,8 +25,10 @@ class LoDArrayLengthOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto &x = scope.FindVar(Input("X"))->Get(); auto &out = *scope.FindVar(Output("Out"))->GetMutable(); diff --git a/paddle/operators/lod_rank_table_op.cc b/paddle/operators/lod_rank_table_op.cc index 692b9bf3710d764eceafda8390eedb8590794ddf..8e05ee63a023b4e0c78017c96eb21983e8f2a446 100644 --- a/paddle/operators/lod_rank_table_op.cc +++ b/paddle/operators/lod_rank_table_op.cc @@ -23,8 +23,10 @@ class LoDRankTableOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { auto x = scope.FindVar(Input("X"))->Get(); auto *out = scope.FindVar(Output("Out"))->GetMutable(); diff --git a/paddle/operators/lod_tensor_to_array_op.cc b/paddle/operators/lod_tensor_to_array_op.cc index 685a807a8acafd36f44161fb17e0e88070d0bf43..0b1d2ffc8f84305456a44e9036c575c17db5226f 100644 --- a/paddle/operators/lod_tensor_to_array_op.cc +++ b/paddle/operators/lod_tensor_to_array_op.cc @@ -32,8 +32,10 @@ class LoDTensorToArrayOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto &x = detail::Ref(scope.FindVar(Input("X")), "Cannot find input %s", Input("X")) .Get(); diff --git a/paddle/operators/max_sequence_len_op.cc b/paddle/operators/max_sequence_len_op.cc index 019150e4914e8bd34a5e8b7d37318aee43942fcc..794a1e56d3eca9cd5503f8cb4660fde2cfc621d9 100644 --- a/paddle/operators/max_sequence_len_op.cc +++ b/paddle/operators/max_sequence_len_op.cc @@ -27,8 +27,9 @@ class MaxSeqenceLenOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { auto &rank_table = scope.FindVar(Input("RankTable"))->Get(); auto *out = diff --git a/paddle/operators/merge_lod_tensor_op.cc b/paddle/operators/merge_lod_tensor_op.cc index 87644d316d42c7d9453a99b759214b24088062df..53ee7d63f339a46c687997f6d8db72c5da70796c 100644 --- a/paddle/operators/merge_lod_tensor_op.cc +++ b/paddle/operators/merge_lod_tensor_op.cc @@ -27,8 +27,10 @@ class MergeLoDTensorOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { // get device context from pool platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(dev_place); diff --git a/paddle/operators/nccl_op.cc b/paddle/operators/nccl_op.cc index 9d51153b0631b988c9297f395672be67e18ee3f9..974ae9d963689f9ed1651c3975cdd0cf8c198123 100644 --- a/paddle/operators/nccl_op.cc +++ b/paddle/operators/nccl_op.cc @@ -26,8 +26,9 @@ class NCCLInitOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { const auto &name = Output("Communicator"); PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name), "Can not find variable '%s' in the scope.", name); diff --git a/paddle/operators/net_op.h b/paddle/operators/net_op.h index b24042f5ef5822eabcada8ed9d21c552579e8064..9ac8f34347ff365a03beee9d0de1780e627770c7 100644 --- a/paddle/operators/net_op.h +++ b/paddle/operators/net_op.h @@ -57,20 +57,6 @@ class NetOp : public framework::OperatorBase { this->CompleteAddOp(); } - /** - * @brief Run the network. - * - * Run all the operators with the `scope`, if no scope is provided, default - * scope will be used instead. If no OpContext is provicded, default context - * will be used. - */ - void Run(const framework::Scope& scope, - const platform::Place& place) const override { - for (auto& op : ops_) { - op->Run(scope, place); - } - } - bool SupportGPU() const override { for (auto& op : ops_) { if (!op->SupportGPU()) { @@ -117,6 +103,20 @@ class NetOp : public framework::OperatorBase { std::vector> ops_; private: + /** + * @brief Run the network. + * + * Run all the operators with the `scope`, if no scope is provided, default + * scope will be used instead. If no OpContext is provicded, default context + * will be used. + */ + void RunImpl(const framework::Scope& scope, + const platform::Place& place) const override { + for (auto& op : ops_) { + op->Run(scope, place); + } + } + bool add_op_done_{false}; std::set intermediate_outputs_; diff --git a/paddle/operators/net_op_test.cc b/paddle/operators/net_op_test.cc index 9358f29f62fc21801f8036400d2baebdfd663a3a..95d21f1516855f70c568bf8a3875c1ce5d0e2da5 100644 --- a/paddle/operators/net_op_test.cc +++ b/paddle/operators/net_op_test.cc @@ -26,7 +26,10 @@ class TestOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; DEFINE_OP_CLONE_METHOD(TestOp); - void Run(const Scope& scope, const platform::Place& place) const override { + + private: + void RunImpl(const Scope& scope, + const platform::Place& place) const override { ++run_cnt; } }; diff --git a/paddle/operators/parallel_do_op.cc b/paddle/operators/parallel_do_op.cc index 89045923f9ff2f33bc112b199c493047440e15c4..b1233c93f8d38e38c5d1a56b3f24bf0357bb0c3f 100644 --- a/paddle/operators/parallel_do_op.cc +++ b/paddle/operators/parallel_do_op.cc @@ -124,8 +124,9 @@ class ParallelDoOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : framework::OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { // get device context from pool platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); @@ -216,8 +217,9 @@ class ParallelDoGradOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : framework::OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto *block = Attr(kParallelBlock); auto *program = block->Program(); diff --git a/paddle/operators/print_op.cc b/paddle/operators/print_op.cc index 8b233d64c904a8870212af33c5839cfc555b5dc8..e869e4d6204b200bea673c168ea4d1d06ce9fd41 100644 --- a/paddle/operators/print_op.cc +++ b/paddle/operators/print_op.cc @@ -130,8 +130,9 @@ class TensorPrintOp : public framework::OperatorBase { PADDLE_THROW("Not implemented."); } - void Run(const framework::Scope& scope, - const platform::Place& place) const override { + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& place) const override { const framework::Variable* in_var_ptr = nullptr; std::string phase = kForward; std::string printed_var_name = ""; diff --git a/paddle/operators/read_op.cc b/paddle/operators/read_op.cc index 3ae454101f585cf412a306fd3198f99fbdb8324d..924b787faa89da231d648c98bcc1ead6f1677da7 100644 --- a/paddle/operators/read_op.cc +++ b/paddle/operators/read_op.cc @@ -54,8 +54,10 @@ class ReadInferVarType : public framework::VarTypeInference { class ReadOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; - void Run(const framework::Scope& scope, - const platform::Place& dev_place) const override { + + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { framework::ReaderHolder* reader = scope.FindVar(Input("Reader"))->GetMutable(); if (!reader->HasNext()) { diff --git a/paddle/operators/recurrent_op.cc b/paddle/operators/recurrent_op.cc index a136c5b447d7a64f783c00c928bf9e248aff6649..19ad7fbb709205026a6c5a9a69a3fdbc3f60cc1d 100644 --- a/paddle/operators/recurrent_op.cc +++ b/paddle/operators/recurrent_op.cc @@ -226,8 +226,9 @@ class RecurrentOp : public RecurrentBase { const framework::AttributeMap &attrs) : RecurrentBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto seq_len = static_cast(this->GetSequenceLength(scope)); VLOG(3) << "Static RNN input sequence length = " << seq_len; StepScopes scopes = CreateStepScopes(scope, seq_len); @@ -315,8 +316,9 @@ class RecurrentGradOp : public RecurrentBase { const framework::AttributeMap &attrs) : RecurrentBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto seq_len = static_cast(GetSequenceLength(scope)); StepScopes scopes = CreateStepScopes(scope, seq_len); auto reverse = Attr(kReverse); diff --git a/paddle/operators/reorder_lod_tensor_by_rank_op.cc b/paddle/operators/reorder_lod_tensor_by_rank_op.cc index 3c30447949421da516213b47178828453671c693..f5c16870b5bec5115f43eee46df0e07a34284338 100644 --- a/paddle/operators/reorder_lod_tensor_by_rank_op.cc +++ b/paddle/operators/reorder_lod_tensor_by_rank_op.cc @@ -75,8 +75,10 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto &x = detail::Ref(scope.FindVar(Input("X")), "Cannot find input lod tensor variable %s", Input("X")) diff --git a/paddle/operators/rnn_memory_helper_op.cc b/paddle/operators/rnn_memory_helper_op.cc index eb55ed6a05b51d7a6c63d16fcf5aff73f6744903..fe88aa1fb5788dc3ee0811e7e063249ee6fd769c 100644 --- a/paddle/operators/rnn_memory_helper_op.cc +++ b/paddle/operators/rnn_memory_helper_op.cc @@ -24,8 +24,10 @@ class RNNMemoryHelperOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { auto mem_var_name = Input("X"); auto *mem_var = scope.FindVar(mem_var_name); PADDLE_ENFORCE(mem_var != nullptr, @@ -76,8 +78,10 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { auto out_grad_var_name = Input(framework::GradVarName("Out")); auto *out_grad_var = scope.FindVar(out_grad_var_name); diff --git a/paddle/operators/save_combine_op.cc b/paddle/operators/save_combine_op.cc index bffa2908bc42d73332f22fa3706d24ab49cd4b38..5ce0bfb914693058dc1a9b81fa274c3f403b4fcc 100644 --- a/paddle/operators/save_combine_op.cc +++ b/paddle/operators/save_combine_op.cc @@ -63,8 +63,10 @@ class SaveCombineOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto filename = Attr("file_path"); auto overwrite = Attr("overwrite"); diff --git a/paddle/operators/save_op.cc b/paddle/operators/save_op.cc index 4b1cbe88836e340c94f797806243a6768410ed3d..c8250d0c3dea409c202168aed346822bd8052a20 100644 --- a/paddle/operators/save_op.cc +++ b/paddle/operators/save_op.cc @@ -62,8 +62,10 @@ class SaveOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto filename = Attr("file_path"); auto overwrite = Attr("overwrite"); diff --git a/paddle/operators/shrink_rnn_memory_op.cc b/paddle/operators/shrink_rnn_memory_op.cc index bf870115a4d7b6f4d578df7707826973d4363ba6..cd96ec5133af21e401dd08c74c70ae189cc24335 100644 --- a/paddle/operators/shrink_rnn_memory_op.cc +++ b/paddle/operators/shrink_rnn_memory_op.cc @@ -27,8 +27,9 @@ class ShrinkRNNMemoryOp : public ArrayOp { const framework::AttributeMap &attrs) : ArrayOp(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto *x_var = scope.FindVar(Input("X")); PADDLE_ENFORCE(x_var != nullptr, "Input X must be set"); auto &x_tensor = x_var->Get(); @@ -108,8 +109,9 @@ class ShrinkRNNMemoryGradOp : public ArrayOp { const framework::AttributeMap &attrs) : ArrayOp(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto *dout_var = scope.FindVar(Input(framework::GradVarName("Out"))); auto *dx_var = scope.FindVar(Output(framework::GradVarName("X"))); PADDLE_ENFORCE(dx_var != nullptr, "Input Gradient should not be nullptr"); diff --git a/paddle/operators/split_lod_tensor_op.cc b/paddle/operators/split_lod_tensor_op.cc index bd93c492015e074afe08ee167025aa6251b369d1..cd833889edfc70206de1155b4807acb7bf7d58eb 100644 --- a/paddle/operators/split_lod_tensor_op.cc +++ b/paddle/operators/split_lod_tensor_op.cc @@ -33,8 +33,10 @@ class SplitLoDTensorOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { auto &x = scope.FindVar(Input("X"))->Get(); auto &mask = scope.FindVar(Input("Mask"))->Get(); auto *out_true = diff --git a/paddle/operators/tensor_array_read_write_op.cc b/paddle/operators/tensor_array_read_write_op.cc index a70be8b8752d12433bb19b9953d80e397858834c..af3d9b7cc35dddb978bc5a4ac4e6458a22b65151 100644 --- a/paddle/operators/tensor_array_read_write_op.cc +++ b/paddle/operators/tensor_array_read_write_op.cc @@ -24,8 +24,9 @@ class WriteToArrayOp : public ArrayOp { const framework::AttributeMap &attrs) : ArrayOp(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto *x = scope.FindVar(Input("X")); if (x == nullptr) return; auto &x_tensor = x->Get(); @@ -122,8 +123,10 @@ class ReadFromArrayOp : public ArrayOp { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : ArrayOp(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto *x = scope.FindVar(Input("X")); PADDLE_ENFORCE(x != nullptr, "X must be set"); auto &x_array = x->Get(); diff --git a/paddle/operators/while_op.cc b/paddle/operators/while_op.cc index a744ebd61595403ee495a2e2c9e84181422e92ff..06b0c77485cf54def8605fc9586091140b86e232 100644 --- a/paddle/operators/while_op.cc +++ b/paddle/operators/while_op.cc @@ -39,8 +39,9 @@ class WhileOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : framework::OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition))); auto &cond = scope.FindVar(Input(kCondition))->Get(); PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1})); @@ -99,8 +100,9 @@ class WhileGradOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : framework::OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { // get device context from pool platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(dev_place);