未验证 提交 5046869e 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #8287 from tonyyang-svail/operator_set_device

Correctly handle cuda place for operators
...@@ -25,7 +25,10 @@ namespace framework { ...@@ -25,7 +25,10 @@ namespace framework {
class CosineOp : public OperatorBase { class CosineOp : public OperatorBase {
public: public:
using OperatorBase::OperatorBase; using OperatorBase::OperatorBase;
void Run(const Scope& scope, const platform::Place& place) const override {}
private:
void RunImpl(const Scope& scope,
const platform::Place& place) const override {}
}; };
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...@@ -44,7 +47,10 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { ...@@ -44,7 +47,10 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class MyTestOp : public OperatorBase { class MyTestOp : public OperatorBase {
public: public:
using OperatorBase::OperatorBase; using OperatorBase::OperatorBase;
void Run(const Scope& scope, const platform::Place& place) const override {}
private:
void RunImpl(const Scope& scope,
const platform::Place& place) const override {}
}; };
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
......
...@@ -64,6 +64,18 @@ static LoD GetLoD(const Scope& scope, const std::string& name) { ...@@ -64,6 +64,18 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
} }
} }
void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
if (platform::is_gpu_place(place)) {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW("Cannot run operator on place %s", place);
#else
auto dev_id = boost::get<platform::CUDAPlace>(place).device;
platform::SetDeviceId(dev_id);
#endif
}
RunImpl(scope, place);
}
std::string OperatorBase::Input(const std::string& name) const { std::string OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name); auto& ins = Inputs(name);
PADDLE_ENFORCE_LE(ins.size(), 1UL, PADDLE_ENFORCE_LE(ins.size(), 1UL,
...@@ -479,7 +491,7 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -479,7 +491,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
const Scope& scope_; const Scope& scope_;
}; };
void OperatorWithKernel::Run(const Scope& scope, void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place) const { const platform::Place& place) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope); RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx); this->InferShape(&infer_shape_ctx);
......
...@@ -89,8 +89,9 @@ class OperatorBase { ...@@ -89,8 +89,9 @@ class OperatorBase {
std::string DebugString() const { return DebugStringEx(nullptr); } std::string DebugString() const { return DebugStringEx(nullptr); }
/// Net will call this function to Run an op. /// Net will call this interface function to Run an op.
virtual void Run(const Scope& scope, const platform::Place& place) const = 0; // The implementation should be written at RunImpl
void Run(const Scope& scope, const platform::Place& place);
// FIXME(typhoonzero): this is only used for recv_op to stop event_loop. // FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
virtual void Stop() {} virtual void Stop() {}
...@@ -144,6 +145,8 @@ class OperatorBase { ...@@ -144,6 +145,8 @@ class OperatorBase {
private: private:
void GenerateTemporaryNames(); void GenerateTemporaryNames();
void CheckAllInputOutputSet() const; void CheckAllInputOutputSet() const;
virtual void RunImpl(const Scope& scope,
const platform::Place& place) const = 0;
}; };
// Macro for define a clone method. // Macro for define a clone method.
...@@ -168,10 +171,13 @@ class OperatorBase { ...@@ -168,10 +171,13 @@ class OperatorBase {
class NOP : public OperatorBase { class NOP : public OperatorBase {
public: public:
using OperatorBase::OperatorBase; using OperatorBase::OperatorBase;
void Run(const Scope& scope, const platform::Place& place) const override {}
std::unique_ptr<OperatorBase> Clone() const override { std::unique_ptr<OperatorBase> Clone() const override {
return std::unique_ptr<OperatorBase>(new NOP(*this)); return std::unique_ptr<OperatorBase>(new NOP(*this));
} }
private:
void RunImpl(const Scope& scope,
const platform::Place& place) const override {}
}; };
class ExecutionContext { class ExecutionContext {
...@@ -363,8 +369,6 @@ class OperatorWithKernel : public OperatorBase { ...@@ -363,8 +369,6 @@ class OperatorWithKernel : public OperatorBase {
const VariableNameMap& outputs, const AttributeMap& attrs) const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const Scope& scope, const platform::Place& place) const final;
static std::unordered_map<std::string /* op_type */, OpKernelMap>& static std::unordered_map<std::string /* op_type */, OpKernelMap>&
AllOpKernels() { AllOpKernels() {
static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels; static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
...@@ -393,6 +397,7 @@ class OperatorWithKernel : public OperatorBase { ...@@ -393,6 +397,7 @@ class OperatorWithKernel : public OperatorBase {
// indicate kernel DataType by input data. Defaultly all input data must be // indicate kernel DataType by input data. Defaultly all input data must be
// same. // same.
proto::DataType IndicateDataType(const ExecutionContext& ctx) const; proto::DataType IndicateDataType(const ExecutionContext& ctx) const;
void RunImpl(const Scope& scope, const platform::Place& place) const final;
}; };
extern bool OpSupportGPU(const std::string& op_type); extern bool OpSupportGPU(const std::string& op_type);
......
...@@ -28,7 +28,10 @@ class OpWithoutKernelTest : public OperatorBase { ...@@ -28,7 +28,10 @@ class OpWithoutKernelTest : public OperatorBase {
OpWithoutKernelTest(const std::string& type, const VariableNameMap& inputs, OpWithoutKernelTest(const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs) const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs), x(1) {} : OperatorBase(type, inputs, outputs, attrs), x(1) {}
void Run(const Scope& scope, const platform::Place& place) const override {
private:
void RunImpl(const Scope& scope,
const platform::Place& place) const override {
++op_run_num; ++op_run_num;
ASSERT_EQ(static_cast<int>(inputs_.size()), 1); ASSERT_EQ(static_cast<int>(inputs_.size()), 1);
ASSERT_EQ(static_cast<int>(outputs_.size()), 1); ASSERT_EQ(static_cast<int>(outputs_.size()), 1);
...@@ -259,7 +262,9 @@ class OperatorClone : public paddle::framework::OperatorBase { ...@@ -259,7 +262,9 @@ class OperatorClone : public paddle::framework::OperatorBase {
const paddle::framework::VariableNameMap& outputs, const paddle::framework::VariableNameMap& outputs,
const paddle::framework::AttributeMap& attrs) const paddle::framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const paddle::framework::Scope& scope,
private:
void RunImpl(const paddle::framework::Scope& scope,
const paddle::platform::Place& place) const override {} const paddle::platform::Place& place) const override {}
}; };
......
...@@ -31,7 +31,9 @@ class ArrayToLoDTensorOp : public framework::OperatorBase { ...@@ -31,7 +31,9 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override { const platform::Place &dev_place) const override {
auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>(); auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>();
auto &rank_table = auto &rank_table =
......
...@@ -71,7 +71,9 @@ class AssignOp : public framework::OperatorBase { ...@@ -71,7 +71,9 @@ class AssignOp : public framework::OperatorBase {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto *x = scope.FindVar(Input("X")); auto *x = scope.FindVar(Input("X"));
if (x == nullptr) { if (x == nullptr) {
......
...@@ -55,7 +55,9 @@ class BeamSearchDecodeOp : public framework::OperatorBase { ...@@ -55,7 +55,9 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
const framework::VariableNameMap& outputs, const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs) const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope& scope,
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(dev_place); auto& dev_ctx = *pool.Get(dev_place);
......
...@@ -204,7 +204,8 @@ class BeamSearchOp : public framework::OperatorBase { ...@@ -204,7 +204,8 @@ class BeamSearchOp : public framework::OperatorBase {
PADDLE_THROW("Not Implemented"); PADDLE_THROW("Not Implemented");
} }
void Run(const framework::Scope& scope, private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
auto ids_var = scope.FindVar(Input("ids")); auto ids_var = scope.FindVar(Input("ids"));
auto scores_var = scope.FindVar(Input("scores")); auto scores_var = scope.FindVar(Input("scores"));
......
...@@ -193,7 +193,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope, ...@@ -193,7 +193,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
} }
} }
void CondOp::Run(const Scope& scope, const platform::Place& place) const { void CondOp::RunImpl(const Scope& scope, const platform::Place& place) const {
// get device context from pool // get device context from pool
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(place); auto& dev_ctx = *pool.Get(place);
......
...@@ -77,7 +77,8 @@ class CondOp : public framework::OperatorBase { ...@@ -77,7 +77,8 @@ class CondOp : public framework::OperatorBase {
sub_net_op_[FALSE_BRANCH] = std::move(net); sub_net_op_[FALSE_BRANCH] = std::move(net);
} }
void Run(const framework::Scope& scope, private:
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override; const platform::Place& place) const override;
private: private:
......
...@@ -65,7 +65,9 @@ class ConditionalBlockOp : public ConditionalOp { ...@@ -65,7 +65,9 @@ class ConditionalBlockOp : public ConditionalOp {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: ConditionalOp(type, inputs, outputs, attrs) {} : ConditionalOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override { const platform::Place &dev_place) const override {
auto xs = InputTensors(scope); auto xs = InputTensors(scope);
...@@ -128,7 +130,9 @@ class ConditionalBlockGradOp : public ConditionalOp { ...@@ -128,7 +130,9 @@ class ConditionalBlockGradOp : public ConditionalOp {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: ConditionalOp(type, inputs, outputs, attrs) {} : ConditionalOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override { const platform::Place &dev_place) const override {
auto xs = this->InputTensors(scope); auto xs = this->InputTensors(scope);
......
...@@ -106,7 +106,9 @@ template <typename T> ...@@ -106,7 +106,9 @@ template <typename T>
class CreateRandomDataGeneratorOp : public framework::OperatorBase { class CreateRandomDataGeneratorOp : public framework::OperatorBase {
public: public:
using framework::OperatorBase::OperatorBase; using framework::OperatorBase::OperatorBase;
void Run(const framework::Scope& scope,
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
const auto& shape_concat = Attr<std::vector<int>>("shape_concat"); const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
const auto& ranks = Attr<std::vector<int>>("ranks"); const auto& ranks = Attr<std::vector<int>>("ranks");
...@@ -155,7 +157,9 @@ class CreateRandomDataGeneratorOpMaker ...@@ -155,7 +157,9 @@ class CreateRandomDataGeneratorOpMaker
class CreateShuffleReaderOp : public framework::OperatorBase { class CreateShuffleReaderOp : public framework::OperatorBase {
public: public:
using framework::OperatorBase::OperatorBase; using framework::OperatorBase::OperatorBase;
void Run(const framework::Scope& scope,
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
...@@ -187,7 +191,9 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -187,7 +191,9 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
class CreateBatchReaderOp : public framework::OperatorBase { class CreateBatchReaderOp : public framework::OperatorBase {
public: public:
using framework::OperatorBase::OperatorBase; using framework::OperatorBase::OperatorBase;
void Run(const framework::Scope& scope,
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
......
...@@ -24,7 +24,9 @@ class FeedOp : public framework::OperatorBase { ...@@ -24,7 +24,9 @@ class FeedOp : public framework::OperatorBase {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto feed_var_name = Input("X"); auto feed_var_name = Input("X");
auto *feed_var = scope.FindVar(feed_var_name); auto *feed_var = scope.FindVar(feed_var_name);
......
...@@ -26,7 +26,8 @@ class FetchOp : public framework::OperatorBase { ...@@ -26,7 +26,8 @@ class FetchOp : public framework::OperatorBase {
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto fetch_var_name = Input("X"); auto fetch_var_name = Input("X");
auto *fetch_var = scope.FindVar(fetch_var_name); auto *fetch_var = scope.FindVar(fetch_var_name);
......
...@@ -33,7 +33,9 @@ class FillConstantInferShape : public framework::InferShapeBase { ...@@ -33,7 +33,9 @@ class FillConstantInferShape : public framework::InferShapeBase {
class FillConstantOp : public framework::OperatorBase { class FillConstantOp : public framework::OperatorBase {
public: public:
using framework::OperatorBase::OperatorBase; using framework::OperatorBase::OperatorBase;
void Run(const framework::Scope &scope,
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override { const platform::Place &dev_place) const override {
auto data_type = auto data_type =
static_cast<framework::proto::DataType>(Attr<int>("dtype")); static_cast<framework::proto::DataType>(Attr<int>("dtype"));
......
...@@ -42,7 +42,9 @@ class FillOp : public framework::OperatorBase { ...@@ -42,7 +42,9 @@ class FillOp : public framework::OperatorBase {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto &out = auto &out =
detail::Ref(detail::Ref(scope.FindVar(Output("Out")), detail::Ref(detail::Ref(scope.FindVar(Output("Out")),
......
...@@ -37,7 +37,9 @@ class GetPlacesOp : public framework::OperatorBase { ...@@ -37,7 +37,9 @@ class GetPlacesOp : public framework::OperatorBase {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
bool is_gpu; bool is_gpu;
if (Attr<std::string>("device_type") == "AUTO") { if (Attr<std::string>("device_type") == "AUTO") {
......
...@@ -51,7 +51,8 @@ class IncrementOp : public framework::OperatorBase { ...@@ -51,7 +51,8 @@ class IncrementOp : public framework::OperatorBase {
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>(); auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
auto &out = auto &out =
......
...@@ -28,7 +28,8 @@ class IsEmptyOp : public framework::OperatorBase { ...@@ -28,7 +28,8 @@ class IsEmptyOp : public framework::OperatorBase {
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
// get input // get input
auto *var = scope.FindVar(Input(kInput)); auto *var = scope.FindVar(Input(kInput));
......
...@@ -26,7 +26,9 @@ class LoadCombineOp : public framework::OperatorBase { ...@@ -26,7 +26,9 @@ class LoadCombineOp : public framework::OperatorBase {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto filename = Attr<std::string>("file_path"); auto filename = Attr<std::string>("file_path");
......
...@@ -25,7 +25,9 @@ class LoadOp : public framework::OperatorBase { ...@@ -25,7 +25,9 @@ class LoadOp : public framework::OperatorBase {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto filename = Attr<std::string>("file_path"); auto filename = Attr<std::string>("file_path");
std::ifstream fin(filename); std::ifstream fin(filename);
......
...@@ -25,7 +25,9 @@ class LoDArrayLengthOp : public framework::OperatorBase { ...@@ -25,7 +25,9 @@ class LoDArrayLengthOp : public framework::OperatorBase {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>(); auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>();
auto &out = auto &out =
......
...@@ -23,7 +23,9 @@ class LoDRankTableOp : public framework::OperatorBase { ...@@ -23,7 +23,9 @@ class LoDRankTableOp : public framework::OperatorBase {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override { const platform::Place &dev_place) const override {
auto x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>(); auto x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
auto *out = auto *out =
......
...@@ -32,7 +32,9 @@ class LoDTensorToArrayOp : public framework::OperatorBase { ...@@ -32,7 +32,9 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto &x = detail::Ref(scope.FindVar(Input("X")), "Cannot find input %s", auto &x = detail::Ref(scope.FindVar(Input("X")), "Cannot find input %s",
Input("X")) Input("X"))
......
...@@ -27,7 +27,8 @@ class MaxSeqenceLenOp : public framework::OperatorBase { ...@@ -27,7 +27,8 @@ class MaxSeqenceLenOp : public framework::OperatorBase {
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override { const platform::Place &dev_place) const override {
auto &rank_table = auto &rank_table =
scope.FindVar(Input("RankTable"))->Get<framework::LoDRankTable>(); scope.FindVar(Input("RankTable"))->Get<framework::LoDRankTable>();
......
...@@ -27,7 +27,9 @@ class MergeLoDTensorOp : public framework::OperatorBase { ...@@ -27,7 +27,9 @@ class MergeLoDTensorOp : public framework::OperatorBase {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override { const platform::Place &dev_place) const override {
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
......
...@@ -26,7 +26,8 @@ class NCCLInitOp : public framework::OperatorBase { ...@@ -26,7 +26,8 @@ class NCCLInitOp : public framework::OperatorBase {
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
const auto &name = Output("Communicator"); const auto &name = Output("Communicator");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name), PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name),
......
...@@ -57,20 +57,6 @@ class NetOp : public framework::OperatorBase { ...@@ -57,20 +57,6 @@ class NetOp : public framework::OperatorBase {
this->CompleteAddOp(); this->CompleteAddOp();
} }
/**
* @brief Run the network.
*
* Run all the operators with the `scope`, if no scope is provided, default
* scope will be used instead. If no OpContext is provicded, default context
* will be used.
*/
void Run(const framework::Scope& scope,
const platform::Place& place) const override {
for (auto& op : ops_) {
op->Run(scope, place);
}
}
bool SupportGPU() const override { bool SupportGPU() const override {
for (auto& op : ops_) { for (auto& op : ops_) {
if (!op->SupportGPU()) { if (!op->SupportGPU()) {
...@@ -117,6 +103,20 @@ class NetOp : public framework::OperatorBase { ...@@ -117,6 +103,20 @@ class NetOp : public framework::OperatorBase {
std::vector<std::unique_ptr<framework::OperatorBase>> ops_; std::vector<std::unique_ptr<framework::OperatorBase>> ops_;
private: private:
/**
* @brief Run the network.
*
* Run all the operators with the `scope`, if no scope is provided, default
* scope will be used instead. If no OpContext is provicded, default context
* will be used.
*/
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
for (auto& op : ops_) {
op->Run(scope, place);
}
}
bool add_op_done_{false}; bool add_op_done_{false};
std::set<std::string> intermediate_outputs_; std::set<std::string> intermediate_outputs_;
......
...@@ -26,7 +26,10 @@ class TestOp : public framework::OperatorBase { ...@@ -26,7 +26,10 @@ class TestOp : public framework::OperatorBase {
public: public:
using framework::OperatorBase::OperatorBase; using framework::OperatorBase::OperatorBase;
DEFINE_OP_CLONE_METHOD(TestOp); DEFINE_OP_CLONE_METHOD(TestOp);
void Run(const Scope& scope, const platform::Place& place) const override {
private:
void RunImpl(const Scope& scope,
const platform::Place& place) const override {
++run_cnt; ++run_cnt;
} }
}; };
......
...@@ -118,7 +118,8 @@ class ParallelDoOp : public framework::OperatorBase { ...@@ -118,7 +118,8 @@ class ParallelDoOp : public framework::OperatorBase {
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {} : framework::OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
...@@ -207,7 +208,8 @@ class ParallelDoGradOp : public framework::OperatorBase { ...@@ -207,7 +208,8 @@ class ParallelDoGradOp : public framework::OperatorBase {
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {} : framework::OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto *block = Attr<framework::BlockDesc *>(kParallelBlock); auto *block = Attr<framework::BlockDesc *>(kParallelBlock);
auto *program = block->Program(); auto *program = block->Program();
......
...@@ -130,7 +130,8 @@ class TensorPrintOp : public framework::OperatorBase { ...@@ -130,7 +130,8 @@ class TensorPrintOp : public framework::OperatorBase {
PADDLE_THROW("Not implemented."); PADDLE_THROW("Not implemented.");
} }
void Run(const framework::Scope& scope, private:
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override { const platform::Place& place) const override {
const framework::Variable* in_var_ptr = nullptr; const framework::Variable* in_var_ptr = nullptr;
std::string phase = kForward; std::string phase = kForward;
......
...@@ -54,7 +54,9 @@ class ReadInferVarType : public framework::VarTypeInference { ...@@ -54,7 +54,9 @@ class ReadInferVarType : public framework::VarTypeInference {
class ReadOp : public framework::OperatorBase { class ReadOp : public framework::OperatorBase {
public: public:
using framework::OperatorBase::OperatorBase; using framework::OperatorBase::OperatorBase;
void Run(const framework::Scope& scope,
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
framework::ReaderHolder* reader = framework::ReaderHolder* reader =
scope.FindVar(Input("Reader"))->GetMutable<framework::ReaderHolder>(); scope.FindVar(Input("Reader"))->GetMutable<framework::ReaderHolder>();
......
...@@ -226,7 +226,8 @@ class RecurrentOp : public RecurrentBase { ...@@ -226,7 +226,8 @@ class RecurrentOp : public RecurrentBase {
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: RecurrentBase(type, inputs, outputs, attrs) {} : RecurrentBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto seq_len = static_cast<size_t>(this->GetSequenceLength(scope)); auto seq_len = static_cast<size_t>(this->GetSequenceLength(scope));
VLOG(3) << "Static RNN input sequence length = " << seq_len; VLOG(3) << "Static RNN input sequence length = " << seq_len;
...@@ -315,7 +316,8 @@ class RecurrentGradOp : public RecurrentBase { ...@@ -315,7 +316,8 @@ class RecurrentGradOp : public RecurrentBase {
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: RecurrentBase(type, inputs, outputs, attrs) {} : RecurrentBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto seq_len = static_cast<size_t>(GetSequenceLength(scope)); auto seq_len = static_cast<size_t>(GetSequenceLength(scope));
StepScopes scopes = CreateStepScopes(scope, seq_len); StepScopes scopes = CreateStepScopes(scope, seq_len);
......
...@@ -75,7 +75,9 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase { ...@@ -75,7 +75,9 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto &x = auto &x =
detail::Ref(scope.FindVar(Input("X")), detail::Ref(scope.FindVar(Input("X")),
......
...@@ -24,7 +24,9 @@ class RNNMemoryHelperOp : public framework::OperatorBase { ...@@ -24,7 +24,9 @@ class RNNMemoryHelperOp : public framework::OperatorBase {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override { const platform::Place &dev_place) const override {
auto mem_var_name = Input("X"); auto mem_var_name = Input("X");
auto *mem_var = scope.FindVar(mem_var_name); auto *mem_var = scope.FindVar(mem_var_name);
...@@ -76,7 +78,9 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase { ...@@ -76,7 +78,9 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override { const platform::Place &dev_place) const override {
auto out_grad_var_name = Input(framework::GradVarName("Out")); auto out_grad_var_name = Input(framework::GradVarName("Out"));
auto *out_grad_var = scope.FindVar(out_grad_var_name); auto *out_grad_var = scope.FindVar(out_grad_var_name);
......
...@@ -63,7 +63,9 @@ class SaveCombineOp : public framework::OperatorBase { ...@@ -63,7 +63,9 @@ class SaveCombineOp : public framework::OperatorBase {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto filename = Attr<std::string>("file_path"); auto filename = Attr<std::string>("file_path");
auto overwrite = Attr<bool>("overwrite"); auto overwrite = Attr<bool>("overwrite");
......
...@@ -62,7 +62,9 @@ class SaveOp : public framework::OperatorBase { ...@@ -62,7 +62,9 @@ class SaveOp : public framework::OperatorBase {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto filename = Attr<std::string>("file_path"); auto filename = Attr<std::string>("file_path");
auto overwrite = Attr<bool>("overwrite"); auto overwrite = Attr<bool>("overwrite");
......
...@@ -27,7 +27,8 @@ class ShrinkRNNMemoryOp : public ArrayOp { ...@@ -27,7 +27,8 @@ class ShrinkRNNMemoryOp : public ArrayOp {
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: ArrayOp(type, inputs, outputs, attrs) {} : ArrayOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto *x_var = scope.FindVar(Input("X")); auto *x_var = scope.FindVar(Input("X"));
PADDLE_ENFORCE(x_var != nullptr, "Input X must be set"); PADDLE_ENFORCE(x_var != nullptr, "Input X must be set");
...@@ -108,7 +109,8 @@ class ShrinkRNNMemoryGradOp : public ArrayOp { ...@@ -108,7 +109,8 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: ArrayOp(type, inputs, outputs, attrs) {} : ArrayOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto *dout_var = scope.FindVar(Input(framework::GradVarName("Out"))); auto *dout_var = scope.FindVar(Input(framework::GradVarName("Out")));
auto *dx_var = scope.FindVar(Output(framework::GradVarName("X"))); auto *dx_var = scope.FindVar(Output(framework::GradVarName("X")));
......
...@@ -33,7 +33,9 @@ class SplitLoDTensorOp : public framework::OperatorBase { ...@@ -33,7 +33,9 @@ class SplitLoDTensorOp : public framework::OperatorBase {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override { const platform::Place &dev_place) const override {
auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>(); auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
auto &mask = scope.FindVar(Input("Mask"))->Get<framework::LoDTensor>(); auto &mask = scope.FindVar(Input("Mask"))->Get<framework::LoDTensor>();
......
...@@ -24,7 +24,8 @@ class WriteToArrayOp : public ArrayOp { ...@@ -24,7 +24,8 @@ class WriteToArrayOp : public ArrayOp {
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: ArrayOp(type, inputs, outputs, attrs) {} : ArrayOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto *x = scope.FindVar(Input("X")); auto *x = scope.FindVar(Input("X"));
if (x == nullptr) return; if (x == nullptr) return;
...@@ -122,7 +123,9 @@ class ReadFromArrayOp : public ArrayOp { ...@@ -122,7 +123,9 @@ class ReadFromArrayOp : public ArrayOp {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: ArrayOp(type, inputs, outputs, attrs) {} : ArrayOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto *x = scope.FindVar(Input("X")); auto *x = scope.FindVar(Input("X"));
PADDLE_ENFORCE(x != nullptr, "X must be set"); PADDLE_ENFORCE(x != nullptr, "X must be set");
......
...@@ -39,7 +39,8 @@ class WhileOp : public framework::OperatorBase { ...@@ -39,7 +39,8 @@ class WhileOp : public framework::OperatorBase {
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {} : framework::OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override { const platform::Place &dev_place) const override {
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition))); PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition)));
auto &cond = scope.FindVar(Input(kCondition))->Get<LoDTensor>(); auto &cond = scope.FindVar(Input(kCondition))->Get<LoDTensor>();
...@@ -99,7 +100,8 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -99,7 +100,8 @@ class WhileGradOp : public framework::OperatorBase {
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {} : framework::OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override { const platform::Place &dev_place) const override {
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册