diff --git a/paddle/fluid/framework/op_registry_test.cc b/paddle/fluid/framework/op_registry_test.cc index bfbb2cfc2c57c705cf42c65825edcc6dea08cf41..2746168f1dda493368b81820bde2f093d06d7b4e 100644 --- a/paddle/fluid/framework/op_registry_test.cc +++ b/paddle/fluid/framework/op_registry_test.cc @@ -25,7 +25,10 @@ namespace framework { class CosineOp : public OperatorBase { public: using OperatorBase::OperatorBase; - void Run(const Scope& scope, const platform::Place& place) const override {} + + private: + void RunImpl(const Scope& scope, + const platform::Place& place) const override {} }; class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { @@ -44,7 +47,10 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class MyTestOp : public OperatorBase { public: using OperatorBase::OperatorBase; - void Run(const Scope& scope, const platform::Place& place) const override {} + + private: + void RunImpl(const Scope& scope, + const platform::Place& place) const override {} }; class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 61529fe38b15fe2a4bfa0d64159994d6b62fb086..8effbf1bc6298bdcc381e2176411a79da134653f 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -64,6 +64,18 @@ static LoD GetLoD(const Scope& scope, const std::string& name) { } } +void OperatorBase::Run(const Scope& scope, const platform::Place& place) { + if (platform::is_gpu_place(place)) { +#ifndef PADDLE_WITH_CUDA + PADDLE_THROW("Cannot run operator on place %s", place); +#else + auto dev_id = boost::get(place).device; + platform::SetDeviceId(dev_id); +#endif + } + RunImpl(scope, place); +} + std::string OperatorBase::Input(const std::string& name) const { auto& ins = Inputs(name); PADDLE_ENFORCE_LE(ins.size(), 1UL, @@ -479,8 +491,8 @@ class RuntimeInferShapeContext : public InferShapeContext { const Scope& scope_; }; -void OperatorWithKernel::Run(const Scope& scope, - const platform::Place& place) const { +void OperatorWithKernel::RunImpl(const Scope& scope, + const platform::Place& place) const { RuntimeInferShapeContext infer_shape_ctx(*this, scope); this->InferShape(&infer_shape_ctx); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 52300abeb7df346d610d2363335dc9d3330ee39e..708f87dc8632ac500e1050122c5fd5412071fd22 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -89,8 +89,9 @@ class OperatorBase { std::string DebugString() const { return DebugStringEx(nullptr); } - /// Net will call this function to Run an op. - virtual void Run(const Scope& scope, const platform::Place& place) const = 0; + /// Net will call this interface function to Run an op. + // The implementation should be written at RunImpl + void Run(const Scope& scope, const platform::Place& place); // FIXME(typhoonzero): this is only used for recv_op to stop event_loop. virtual void Stop() {} @@ -144,6 +145,8 @@ class OperatorBase { private: void GenerateTemporaryNames(); void CheckAllInputOutputSet() const; + virtual void RunImpl(const Scope& scope, + const platform::Place& place) const = 0; }; // Macro for define a clone method. @@ -168,10 +171,13 @@ class OperatorBase { class NOP : public OperatorBase { public: using OperatorBase::OperatorBase; - void Run(const Scope& scope, const platform::Place& place) const override {} std::unique_ptr Clone() const override { return std::unique_ptr(new NOP(*this)); } + + private: + void RunImpl(const Scope& scope, + const platform::Place& place) const override {} }; class ExecutionContext { @@ -363,8 +369,6 @@ class OperatorWithKernel : public OperatorBase { const VariableNameMap& outputs, const AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const Scope& scope, const platform::Place& place) const final; - static std::unordered_map& AllOpKernels() { static std::unordered_map g_all_op_kernels; @@ -393,6 +397,7 @@ class OperatorWithKernel : public OperatorBase { // indicate kernel DataType by input data. Defaultly all input data must be // same. proto::DataType IndicateDataType(const ExecutionContext& ctx) const; + void RunImpl(const Scope& scope, const platform::Place& place) const final; }; extern bool OpSupportGPU(const std::string& op_type); diff --git a/paddle/fluid/framework/operator_test.cc b/paddle/fluid/framework/operator_test.cc index b90f5538bb620275521cdc11bf47b4014b2a66e2..0732ec5afe8738313e1d73c52c5303a2e8b1e96a 100644 --- a/paddle/fluid/framework/operator_test.cc +++ b/paddle/fluid/framework/operator_test.cc @@ -28,7 +28,10 @@ class OpWithoutKernelTest : public OperatorBase { OpWithoutKernelTest(const std::string& type, const VariableNameMap& inputs, const VariableNameMap& outputs, const AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs), x(1) {} - void Run(const Scope& scope, const platform::Place& place) const override { + + private: + void RunImpl(const Scope& scope, + const platform::Place& place) const override { ++op_run_num; ASSERT_EQ(static_cast(inputs_.size()), 1); ASSERT_EQ(static_cast(outputs_.size()), 1); @@ -259,8 +262,10 @@ class OperatorClone : public paddle::framework::OperatorBase { const paddle::framework::VariableNameMap& outputs, const paddle::framework::AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const paddle::framework::Scope& scope, - const paddle::platform::Place& place) const override {} + + private: + void RunImpl(const paddle::framework::Scope& scope, + const paddle::platform::Place& place) const override {} }; TEST(Operator, Clone) { diff --git a/paddle/fluid/operators/array_to_lod_tensor_op.cc b/paddle/fluid/operators/array_to_lod_tensor_op.cc index bf8e11bd8c047275fe341ead9424d02e98d5d8f4..69464c4cff52400d8a25a692c5df8d2fe06230e4 100644 --- a/paddle/fluid/operators/array_to_lod_tensor_op.cc +++ b/paddle/fluid/operators/array_to_lod_tensor_op.cc @@ -31,8 +31,10 @@ class ArrayToLoDTensorOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { auto &x = scope.FindVar(Input("X"))->Get(); auto &rank_table = scope.FindVar(Input("RankTable"))->Get(); diff --git a/paddle/fluid/operators/assign_op.cc b/paddle/fluid/operators/assign_op.cc index f99f9af4276c0e8928f821ae166d55aed02e8e27..b72e72b12f8a6155b6eb3be1468b8dbc7bd48d4e 100644 --- a/paddle/fluid/operators/assign_op.cc +++ b/paddle/fluid/operators/assign_op.cc @@ -71,8 +71,10 @@ class AssignOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto *x = scope.FindVar(Input("X")); if (x == nullptr) { return; diff --git a/paddle/fluid/operators/beam_search_decode_op.cc b/paddle/fluid/operators/beam_search_decode_op.cc index 7737d4e098ac9a0e56e1db2aee796550e8d71ba3..6d3efcfeb8497a78d56180898e5e3a66e52ff22d 100644 --- a/paddle/fluid/operators/beam_search_decode_op.cc +++ b/paddle/fluid/operators/beam_search_decode_op.cc @@ -55,8 +55,10 @@ class BeamSearchDecodeOp : public framework::OperatorBase { const framework::VariableNameMap& outputs, const framework::AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope& scope, - const platform::Place& dev_place) const override { + + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& dev_ctx = *pool.Get(dev_place); diff --git a/paddle/fluid/operators/beam_search_op.h b/paddle/fluid/operators/beam_search_op.h index 9e2a05a60c30e388093aceddd40e58273364c8f9..bfbe78097d2f20ae4c5efa594d17f931c7ea5920 100644 --- a/paddle/fluid/operators/beam_search_op.h +++ b/paddle/fluid/operators/beam_search_op.h @@ -204,8 +204,9 @@ class BeamSearchOp : public framework::OperatorBase { PADDLE_THROW("Not Implemented"); } - void Run(const framework::Scope& scope, - const platform::Place& dev_place) const override { + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { auto ids_var = scope.FindVar(Input("ids")); auto scores_var = scope.FindVar(Input("scores")); auto pre_ids_var = scope.FindVar(Input("pre_ids")); diff --git a/paddle/fluid/operators/cond_op.cc b/paddle/fluid/operators/cond_op.cc index dd93790d5b52a2ccc8358a94f7ead346d384f191..d63748a61cec0f10269e05bcef3bb0d10345000d 100644 --- a/paddle/fluid/operators/cond_op.cc +++ b/paddle/fluid/operators/cond_op.cc @@ -193,7 +193,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope, } } -void CondOp::Run(const Scope& scope, const platform::Place& place) const { +void CondOp::RunImpl(const Scope& scope, const platform::Place& place) const { // get device context from pool platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& dev_ctx = *pool.Get(place); diff --git a/paddle/fluid/operators/cond_op.h b/paddle/fluid/operators/cond_op.h index 695af4490696b29d2d47f5825ebc0159b39663c0..0bb14bc8c2cfabeeb13e1e1afd51b034742b74f0 100644 --- a/paddle/fluid/operators/cond_op.h +++ b/paddle/fluid/operators/cond_op.h @@ -77,8 +77,9 @@ class CondOp : public framework::OperatorBase { sub_net_op_[FALSE_BRANCH] = std::move(net); } - void Run(const framework::Scope& scope, - const platform::Place& place) const override; + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& place) const override; private: const int TRUE_BRANCH = 0; diff --git a/paddle/fluid/operators/conditional_block_op.cc b/paddle/fluid/operators/conditional_block_op.cc index 30435c6cca0a4fb1d41dce47b8fefeafb6c48a51..228b0998360550348fdd30c842a394e8f8ce5935 100644 --- a/paddle/fluid/operators/conditional_block_op.cc +++ b/paddle/fluid/operators/conditional_block_op.cc @@ -65,8 +65,10 @@ class ConditionalBlockOp : public ConditionalOp { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : ConditionalOp(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { auto xs = InputTensors(scope); bool need_run; @@ -128,8 +130,10 @@ class ConditionalBlockGradOp : public ConditionalOp { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : ConditionalOp(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { auto xs = this->InputTensors(scope); bool need_run; diff --git a/paddle/fluid/operators/create_reader_op.cc b/paddle/fluid/operators/create_reader_op.cc index d1ba51f2c0f13a1b6e4d7ccb93c912703a0b1d86..1393f1a66baaf3b53f797aa61fd42ac3cf54f8db 100644 --- a/paddle/fluid/operators/create_reader_op.cc +++ b/paddle/fluid/operators/create_reader_op.cc @@ -106,8 +106,10 @@ template class CreateRandomDataGeneratorOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; - void Run(const framework::Scope& scope, - const platform::Place& dev_place) const override { + + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { const auto& shape_concat = Attr>("shape_concat"); const auto& ranks = Attr>("ranks"); PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty()); @@ -155,8 +157,10 @@ class CreateRandomDataGeneratorOpMaker class CreateShuffleReaderOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; - void Run(const framework::Scope& scope, - const platform::Place& dev_place) const override { + + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); auto* out = scope.FindVar(Output("Out")) @@ -187,8 +191,10 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker { class CreateBatchReaderOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; - void Run(const framework::Scope& scope, - const platform::Place& dev_place) const override { + + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); auto* out = scope.FindVar(Output("Out")) diff --git a/paddle/fluid/operators/feed_op.cc b/paddle/fluid/operators/feed_op.cc index 0b3f5f0d1d09a932e15936285f5cb226daa86e95..41fa69a0972ef8ad528f2a04b0260c40155ffd3e 100644 --- a/paddle/fluid/operators/feed_op.cc +++ b/paddle/fluid/operators/feed_op.cc @@ -24,8 +24,10 @@ class FeedOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto feed_var_name = Input("X"); auto *feed_var = scope.FindVar(feed_var_name); diff --git a/paddle/fluid/operators/fetch_op.cc b/paddle/fluid/operators/fetch_op.cc index 54e5892016cdb01f50189147a7453b868c5a48c0..6cb5565013dcacac33e828386f1ea8909e831c1a 100644 --- a/paddle/fluid/operators/fetch_op.cc +++ b/paddle/fluid/operators/fetch_op.cc @@ -26,8 +26,9 @@ class FetchOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto fetch_var_name = Input("X"); auto *fetch_var = scope.FindVar(fetch_var_name); PADDLE_ENFORCE(fetch_var != nullptr, diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index d4bf6406e5716a6b65a234d1cd642b64dcc5726f..6dd58d28db23ff3de8a27e898a9b539787d08718 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -33,8 +33,10 @@ class FillConstantInferShape : public framework::InferShapeBase { class FillConstantOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { auto data_type = static_cast(Attr("dtype")); auto value = Attr("value"); diff --git a/paddle/fluid/operators/fill_op.cc b/paddle/fluid/operators/fill_op.cc index 8e318f37cf0bc945597b5aa7b384e53038c97786..0b97c9c2827ac1be4e99c647dbedc2d9b8730e41 100644 --- a/paddle/fluid/operators/fill_op.cc +++ b/paddle/fluid/operators/fill_op.cc @@ -42,8 +42,10 @@ class FillOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto &out = detail::Ref(detail::Ref(scope.FindVar(Output("Out")), "Cannot find variable %s", Output("Out")) diff --git a/paddle/fluid/operators/get_places_op.cc b/paddle/fluid/operators/get_places_op.cc index ba908e472bbc165a244d8543713f1dbf293abb48..ef635048bd4faa2dc0067152f5f7472acbfe47af 100644 --- a/paddle/fluid/operators/get_places_op.cc +++ b/paddle/fluid/operators/get_places_op.cc @@ -37,8 +37,10 @@ class GetPlacesOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { bool is_gpu; if (Attr("device_type") == "AUTO") { is_gpu = platform::is_gpu_place(place); diff --git a/paddle/fluid/operators/increment_op.cc b/paddle/fluid/operators/increment_op.cc index 3d488067b254c37515c6bdb9a4589aad311f344f..de4949584b7b20bec7b31f2ad1b69053ee9ffc0f 100644 --- a/paddle/fluid/operators/increment_op.cc +++ b/paddle/fluid/operators/increment_op.cc @@ -51,8 +51,9 @@ class IncrementOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto &x = scope.FindVar(Input("X"))->Get(); auto &out = *scope.FindVar(Output("Out"))->GetMutable(); diff --git a/paddle/fluid/operators/is_empty_op.cc b/paddle/fluid/operators/is_empty_op.cc index ea424018d66dac85d5a4ad75cbf5199064d52848..dac8505e3f2cb33b35b6184184e4762078a19c49 100644 --- a/paddle/fluid/operators/is_empty_op.cc +++ b/paddle/fluid/operators/is_empty_op.cc @@ -28,8 +28,9 @@ class IsEmptyOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { // get input auto *var = scope.FindVar(Input(kInput)); PADDLE_ENFORCE_NOT_NULL(var); diff --git a/paddle/fluid/operators/load_combine_op.cc b/paddle/fluid/operators/load_combine_op.cc index 1948063d886b79964b1a52d9d82a8e7d2fb0d493..d043702ebae627951927f2dbec893d40f77f0c73 100644 --- a/paddle/fluid/operators/load_combine_op.cc +++ b/paddle/fluid/operators/load_combine_op.cc @@ -26,8 +26,10 @@ class LoadCombineOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto filename = Attr("file_path"); std::ifstream fin(filename); diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc index c9bf5d72b234f96d9eb5a4c275737ac8c18cd63d..9393cccfc66ec930db6ef68bd6f3c5065ceea80e 100644 --- a/paddle/fluid/operators/load_op.cc +++ b/paddle/fluid/operators/load_op.cc @@ -25,8 +25,10 @@ class LoadOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto filename = Attr("file_path"); std::ifstream fin(filename); PADDLE_ENFORCE(static_cast(fin), "Cannot open file %s for load op", diff --git a/paddle/fluid/operators/lod_array_length_op.cc b/paddle/fluid/operators/lod_array_length_op.cc index f11f5a89f5ad5b2f3deed905625aefa1e9d9935b..daa57c20450f1f92cb0bb500e37d0d8c49c05758 100644 --- a/paddle/fluid/operators/lod_array_length_op.cc +++ b/paddle/fluid/operators/lod_array_length_op.cc @@ -25,8 +25,10 @@ class LoDArrayLengthOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto &x = scope.FindVar(Input("X"))->Get(); auto &out = *scope.FindVar(Output("Out"))->GetMutable(); diff --git a/paddle/fluid/operators/lod_rank_table_op.cc b/paddle/fluid/operators/lod_rank_table_op.cc index 0b9426a9f8f0b0b3082667dc7a1414aceb824aca..3264766d6b693244f8dbfa6462b9c7aa13d5b5ec 100644 --- a/paddle/fluid/operators/lod_rank_table_op.cc +++ b/paddle/fluid/operators/lod_rank_table_op.cc @@ -23,8 +23,10 @@ class LoDRankTableOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { auto x = scope.FindVar(Input("X"))->Get(); auto *out = scope.FindVar(Output("Out"))->GetMutable(); diff --git a/paddle/fluid/operators/lod_tensor_to_array_op.cc b/paddle/fluid/operators/lod_tensor_to_array_op.cc index edc32bcec1441e50e24612789727db9a044cde54..d6e24dc976a1ebe2afa182618d09839b105381c1 100644 --- a/paddle/fluid/operators/lod_tensor_to_array_op.cc +++ b/paddle/fluid/operators/lod_tensor_to_array_op.cc @@ -32,8 +32,10 @@ class LoDTensorToArrayOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto &x = detail::Ref(scope.FindVar(Input("X")), "Cannot find input %s", Input("X")) .Get(); diff --git a/paddle/fluid/operators/max_sequence_len_op.cc b/paddle/fluid/operators/max_sequence_len_op.cc index eff8b927e52c94a4e19bb10c644cbaa34a7a0581..cef0dc307dbe97473e9041f51c25eca7cc9a0f1a 100644 --- a/paddle/fluid/operators/max_sequence_len_op.cc +++ b/paddle/fluid/operators/max_sequence_len_op.cc @@ -27,8 +27,9 @@ class MaxSeqenceLenOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { auto &rank_table = scope.FindVar(Input("RankTable"))->Get(); auto *out = diff --git a/paddle/fluid/operators/merge_lod_tensor_op.cc b/paddle/fluid/operators/merge_lod_tensor_op.cc index 255f55334093213df867852e4d222f0e227e8c5d..88e67b6b86a3731cc2caf5529aa4892c6d605a86 100644 --- a/paddle/fluid/operators/merge_lod_tensor_op.cc +++ b/paddle/fluid/operators/merge_lod_tensor_op.cc @@ -27,8 +27,10 @@ class MergeLoDTensorOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { // get device context from pool platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(dev_place); diff --git a/paddle/fluid/operators/nccl_op.cc b/paddle/fluid/operators/nccl_op.cc index 52420ceba0de0323dae000aa301ce7838b3311b6..703e8dd00fc8e613344db11065d6a45afa2a0cc8 100644 --- a/paddle/fluid/operators/nccl_op.cc +++ b/paddle/fluid/operators/nccl_op.cc @@ -26,8 +26,9 @@ class NCCLInitOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { const auto &name = Output("Communicator"); PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name), "Can not find variable '%s' in the scope.", name); diff --git a/paddle/fluid/operators/net_op.h b/paddle/fluid/operators/net_op.h index 14e5909851c4ac08b5f59c5c193c801827b91234..479ba386a70adaff09ae31e24c449fc18a9853b1 100644 --- a/paddle/fluid/operators/net_op.h +++ b/paddle/fluid/operators/net_op.h @@ -57,20 +57,6 @@ class NetOp : public framework::OperatorBase { this->CompleteAddOp(); } - /** - * @brief Run the network. - * - * Run all the operators with the `scope`, if no scope is provided, default - * scope will be used instead. If no OpContext is provicded, default context - * will be used. - */ - void Run(const framework::Scope& scope, - const platform::Place& place) const override { - for (auto& op : ops_) { - op->Run(scope, place); - } - } - bool SupportGPU() const override { for (auto& op : ops_) { if (!op->SupportGPU()) { @@ -117,6 +103,20 @@ class NetOp : public framework::OperatorBase { std::vector> ops_; private: + /** + * @brief Run the network. + * + * Run all the operators with the `scope`, if no scope is provided, default + * scope will be used instead. If no OpContext is provicded, default context + * will be used. + */ + void RunImpl(const framework::Scope& scope, + const platform::Place& place) const override { + for (auto& op : ops_) { + op->Run(scope, place); + } + } + bool add_op_done_{false}; std::set intermediate_outputs_; diff --git a/paddle/fluid/operators/net_op_test.cc b/paddle/fluid/operators/net_op_test.cc index cc20be0c81763abe2adcf09de858ce51e16d77a6..265f15e82ed29824ed65917dbe45e5edf9dc8993 100644 --- a/paddle/fluid/operators/net_op_test.cc +++ b/paddle/fluid/operators/net_op_test.cc @@ -26,7 +26,10 @@ class TestOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; DEFINE_OP_CLONE_METHOD(TestOp); - void Run(const Scope& scope, const platform::Place& place) const override { + + private: + void RunImpl(const Scope& scope, + const platform::Place& place) const override { ++run_cnt; } }; diff --git a/paddle/fluid/operators/parallel_do_op.cc b/paddle/fluid/operators/parallel_do_op.cc index e25df92479943d210d98f02374f377f778f43d2c..d791d11172869d42b08c059b900e729bcc9b5d96 100644 --- a/paddle/fluid/operators/parallel_do_op.cc +++ b/paddle/fluid/operators/parallel_do_op.cc @@ -118,8 +118,9 @@ class ParallelDoOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : framework::OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { // get device context from pool platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); @@ -207,8 +208,9 @@ class ParallelDoGradOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : framework::OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto *block = Attr(kParallelBlock); auto *program = block->Program(); diff --git a/paddle/fluid/operators/print_op.cc b/paddle/fluid/operators/print_op.cc index 3616545309e8c279f61a22e571a5e71335c47f93..4d12fdbb6b62d1d7095d10aa6f33d12598a8e99e 100644 --- a/paddle/fluid/operators/print_op.cc +++ b/paddle/fluid/operators/print_op.cc @@ -130,8 +130,9 @@ class TensorPrintOp : public framework::OperatorBase { PADDLE_THROW("Not implemented."); } - void Run(const framework::Scope& scope, - const platform::Place& place) const override { + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& place) const override { const framework::Variable* in_var_ptr = nullptr; std::string phase = kForward; std::string printed_var_name = ""; diff --git a/paddle/fluid/operators/read_op.cc b/paddle/fluid/operators/read_op.cc index 4d562c291911f54c9d1e8fed2e84035808bffbb7..127df82ff13b89de42e45113a21d6f5e7c2f20ed 100644 --- a/paddle/fluid/operators/read_op.cc +++ b/paddle/fluid/operators/read_op.cc @@ -54,8 +54,10 @@ class ReadInferVarType : public framework::VarTypeInference { class ReadOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; - void Run(const framework::Scope& scope, - const platform::Place& dev_place) const override { + + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { framework::ReaderHolder* reader = scope.FindVar(Input("Reader"))->GetMutable(); if (!reader->HasNext()) { diff --git a/paddle/fluid/operators/recurrent_op.cc b/paddle/fluid/operators/recurrent_op.cc index e4b9b8dab9b0394752d538aa5f59be3c06d0188f..33a744a5b7fef5802569a305d18746f04ed88136 100644 --- a/paddle/fluid/operators/recurrent_op.cc +++ b/paddle/fluid/operators/recurrent_op.cc @@ -226,8 +226,9 @@ class RecurrentOp : public RecurrentBase { const framework::AttributeMap &attrs) : RecurrentBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto seq_len = static_cast(this->GetSequenceLength(scope)); VLOG(3) << "Static RNN input sequence length = " << seq_len; StepScopes scopes = CreateStepScopes(scope, seq_len); @@ -315,8 +316,9 @@ class RecurrentGradOp : public RecurrentBase { const framework::AttributeMap &attrs) : RecurrentBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto seq_len = static_cast(GetSequenceLength(scope)); StepScopes scopes = CreateStepScopes(scope, seq_len); auto reverse = Attr(kReverse); diff --git a/paddle/fluid/operators/reorder_lod_tensor_by_rank_op.cc b/paddle/fluid/operators/reorder_lod_tensor_by_rank_op.cc index 148a65bb4b7fe599f2fdb833c179665e58fe1c41..79ba9e543b892d051995d4bafb0ceaaf09838cd2 100644 --- a/paddle/fluid/operators/reorder_lod_tensor_by_rank_op.cc +++ b/paddle/fluid/operators/reorder_lod_tensor_by_rank_op.cc @@ -75,8 +75,10 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto &x = detail::Ref(scope.FindVar(Input("X")), "Cannot find input lod tensor variable %s", Input("X")) diff --git a/paddle/fluid/operators/rnn_memory_helper_op.cc b/paddle/fluid/operators/rnn_memory_helper_op.cc index 504456c4b069f81319893ae51f57503f5025761a..e9329a0e7e279e2bdd3c45986580c87aa5d0b1fe 100644 --- a/paddle/fluid/operators/rnn_memory_helper_op.cc +++ b/paddle/fluid/operators/rnn_memory_helper_op.cc @@ -24,8 +24,10 @@ class RNNMemoryHelperOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { auto mem_var_name = Input("X"); auto *mem_var = scope.FindVar(mem_var_name); PADDLE_ENFORCE(mem_var != nullptr, @@ -76,8 +78,10 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { auto out_grad_var_name = Input(framework::GradVarName("Out")); auto *out_grad_var = scope.FindVar(out_grad_var_name); diff --git a/paddle/fluid/operators/save_combine_op.cc b/paddle/fluid/operators/save_combine_op.cc index c23de9073ef965b989e98936b2dd07fc6bce7fdc..e3953e4b08082c08e1bbf77a834d4a895b327f83 100644 --- a/paddle/fluid/operators/save_combine_op.cc +++ b/paddle/fluid/operators/save_combine_op.cc @@ -63,8 +63,10 @@ class SaveCombineOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto filename = Attr("file_path"); auto overwrite = Attr("overwrite"); diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index 483cdfa4c3b9e3b9abd3f32bc5e6e5e0b493bd23..85ba8e01182c2cd01aa599ddbce68b6b2d9aa5f4 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -62,8 +62,10 @@ class SaveOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto filename = Attr("file_path"); auto overwrite = Attr("overwrite"); diff --git a/paddle/fluid/operators/shrink_rnn_memory_op.cc b/paddle/fluid/operators/shrink_rnn_memory_op.cc index df50a324fde1637f1f9f64a0b0d4eff8ba3f26d2..7fe0526381d1fc18ad0552c321875af42df0f6dc 100644 --- a/paddle/fluid/operators/shrink_rnn_memory_op.cc +++ b/paddle/fluid/operators/shrink_rnn_memory_op.cc @@ -27,8 +27,9 @@ class ShrinkRNNMemoryOp : public ArrayOp { const framework::AttributeMap &attrs) : ArrayOp(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto *x_var = scope.FindVar(Input("X")); PADDLE_ENFORCE(x_var != nullptr, "Input X must be set"); auto &x_tensor = x_var->Get(); @@ -108,8 +109,9 @@ class ShrinkRNNMemoryGradOp : public ArrayOp { const framework::AttributeMap &attrs) : ArrayOp(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto *dout_var = scope.FindVar(Input(framework::GradVarName("Out"))); auto *dx_var = scope.FindVar(Output(framework::GradVarName("X"))); PADDLE_ENFORCE(dx_var != nullptr, "Input Gradient should not be nullptr"); diff --git a/paddle/fluid/operators/split_lod_tensor_op.cc b/paddle/fluid/operators/split_lod_tensor_op.cc index f821dc54d7bbe697d3642e64dc1628ec7d966592..f9600d99a36f59feddfbb5295b8b21ca6d5034cd 100644 --- a/paddle/fluid/operators/split_lod_tensor_op.cc +++ b/paddle/fluid/operators/split_lod_tensor_op.cc @@ -33,8 +33,10 @@ class SplitLoDTensorOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { auto &x = scope.FindVar(Input("X"))->Get(); auto &mask = scope.FindVar(Input("Mask"))->Get(); auto *out_true = diff --git a/paddle/fluid/operators/tensor_array_read_write_op.cc b/paddle/fluid/operators/tensor_array_read_write_op.cc index 50811fb22491598849216f41a584ae0b68f8f306..704ee964c908c44d84316985429a6551b770e33f 100644 --- a/paddle/fluid/operators/tensor_array_read_write_op.cc +++ b/paddle/fluid/operators/tensor_array_read_write_op.cc @@ -24,8 +24,9 @@ class WriteToArrayOp : public ArrayOp { const framework::AttributeMap &attrs) : ArrayOp(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto *x = scope.FindVar(Input("X")); if (x == nullptr) return; auto &x_tensor = x->Get(); @@ -122,8 +123,10 @@ class ReadFromArrayOp : public ArrayOp { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : ArrayOp(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto *x = scope.FindVar(Input("X")); PADDLE_ENFORCE(x != nullptr, "X must be set"); auto &x_array = x->Get(); diff --git a/paddle/fluid/operators/while_op.cc b/paddle/fluid/operators/while_op.cc index d254c572acff52d967e551c377b3b32b05c92973..a7a05cc5f79da6c1e6945a83f997e54041d2045d 100644 --- a/paddle/fluid/operators/while_op.cc +++ b/paddle/fluid/operators/while_op.cc @@ -39,8 +39,9 @@ class WhileOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : framework::OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition))); auto &cond = scope.FindVar(Input(kCondition))->Get(); PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1})); @@ -99,8 +100,9 @@ class WhileGradOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : framework::OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { // get device context from pool platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(dev_place);