Unverified · Commit 735eba29 · Authored by dzhwinter, committed by GitHub

Feature/operator run place (#6783)

* "change operator interface"

* "move devicepool to device_context"

* "fix operator test"

* "fix op_registry Run interface"

* "net op passed. Need to fix nccl multi-Context"

* "add nccl group function"

* "add nccl group function"

* "fix gpu count exceed 32 error"

* "fix recurrent op, nccl op"

* "change the other operators interface with Place"

* "fix typo"

* "fix pybind"

* "fix device in python side"

* "fix pybind failed"

* "add init for test"

* "fix CI"
Parent b8de1401
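Taken together, the commits above change `OperatorBase::Run(const Scope&, const platform::DeviceContext&)` into `Run(const Scope&, const platform::Place&)`; operators that still need a device context borrow one from the `platform::DeviceContextPool` singleton that `InitDevices` creates. A minimal sketch of an operator written against the new interface (the op name `DummyPlaceOp` is hypothetical; the headers and calls mirror the hunks below):

```cpp
// Sketch only -- illustrates the new Run(scope, place) contract used throughout this diff.
#include "paddle/framework/op_registry.h"
#include "paddle/platform/device_context.h"

namespace paddle {
namespace operators {

class DummyPlaceOp : public framework::OperatorBase {  // hypothetical example op
 public:
  using framework::OperatorBase::OperatorBase;

  // New interface: callers hand over a Place instead of a DeviceContext.
  void Run(const framework::Scope &scope,
           const platform::Place &place) const override {
    // Borrow the DeviceContext matching this Place from the global pool.
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
    auto &dev_ctx = *pool.Borrow(place);
    dev_ctx.Wait();  // e.g. synchronize device work before returning
  }
};

}  // namespace operators
}  // namespace paddle
```

The hunks below apply this interface change to the framework (Executor, OperatorBase, OperatorWithKernel) and to every operator that overrides Run directly.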
@@ -291,10 +291,10 @@ public:
  }
  void Run(const framework::Scope& scope,
-          const platform::DeviceContext& dev_ctx) const override {
+          const platform::Place& place) const override {
    PADDLE_ENFORCE(symbols_ready_, "operators and variables should be created first.");
    for (auto& op : runtime_table_.ops()) {
-      op->Run(scope, dev_ctx);
+      op->Run(scope, place);
    }
  }
...
@@ -30,7 +30,7 @@ cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog shape_inference)
-cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
+cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry init)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
@@ -59,5 +59,5 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
cc_library(selected_rows SRCS selected_rows.cc DEPS tensor)
cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)
-cc_library(init SRCS init.cc DEPS gflags executor place stringpiece)
+cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece)
cc_test(init_test SRCS init_test.cc DEPS init)
@@ -33,13 +33,7 @@ namespace framework {
const std::string kFeedOpType = "feed";
const std::string kFetchOpType = "fetch";
-DeviceContextPool* DeviceContextPool::pool = nullptr;
-Executor::Executor(const std::vector<platform::Place>& places) {
-  DeviceContextPool& pool = DeviceContextPool::Get();
-  auto borrowed_contexts = pool.Borrow(places);
-  device_contexts_.swap(borrowed_contexts);
-}
+Executor::Executor(const platform::Place& place) : place_(place) {}
static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
  if (var_type == proto::VarDesc::LOD_TENSOR) {
@@ -71,7 +65,6 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
  // - will change to use multiple blocks for RNN op and Cond Op
  PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), pdesc.Size());
  auto& block = pdesc.Block(block_id);
-  auto& device = device_contexts_[0];
  Scope* local_scope = scope;
  if (create_vars) {
@@ -107,7 +100,7 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
  for (auto& op_desc : block.AllOps()) {
    auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
    VLOG(3) << op->DebugString();
-    op->Run(*local_scope, *device);
+    op->Run(*local_scope, place_);
  }
  if (create_local_scope) {
    scope->DeleteScope(local_scope);
...
@@ -14,9 +14,6 @@ limitations under the License. */
#pragma once
-#include <map>
-#include <unordered_map>
#include "paddle/framework/op_info.h"
#include "paddle/framework/program_desc.h"
#include "paddle/framework/scope.h"
@@ -26,96 +23,13 @@ limitations under the License. */
namespace paddle {
namespace framework {
-class DeviceContextPool {
- public:
-  static DeviceContextPool& Get() {
-    PADDLE_ENFORCE_NOT_NULL(pool, "Need to Create DeviceContextPool first!");
-    return *pool;
-  }
-  static DeviceContextPool& Create(const std::vector<platform::Place>& places) {
-    if (pool == nullptr) {
-      pool = new DeviceContextPool(places);
-    }
-    return *pool;
-  }
-  const platform::DeviceContext* Borrow(const platform::Place& place) {
-    auto range = device_contexts_.equal_range(place);
-    if (range.first == range.second) {
-      PADDLE_THROW(
-          "'Place' is not supported, Please re-compile with WITH_GPU "
-          "option");
-    }
-    return range.first->second;
-  }
-  std::vector<const platform::DeviceContext*> Borrow(
-      const std::vector<platform::Place>& places) {
-    PADDLE_ENFORCE_GT(places.size(), 0);
-    PADDLE_ENFORCE_LE(places.size(), device_contexts_.size());
-    std::vector<const platform::DeviceContext*> borrowed_contexts;
-    for (auto& place : places) {
-      auto range = device_contexts_.equal_range(place);
-      if (range.first == range.second) {
-        PADDLE_THROW(
-            "'Place' is not supported, Please re-compile with WITH_GPU "
-            "option");
-      }
-      // TODO(dzhwinter) : assign the first found device. Will enhanced later.
-      // device load balancer maybe useful here.
-      borrowed_contexts.emplace_back(range.first->second);
-    }
-    return borrowed_contexts;
-  }
-  explicit DeviceContextPool(const std::vector<platform::Place>& places) {
-    PADDLE_ENFORCE_GT(places.size(), 0);
-    for (size_t i = 0; i < places.size(); i++) {
-      if (platform::is_cpu_place(places[i])) {
-        device_contexts_.emplace(
-            places[i], new platform::CPUDeviceContext(
-                           boost::get<platform::CPUPlace>(places[i])));
-      } else if (platform::is_gpu_place(places[i])) {
-#ifdef PADDLE_WITH_CUDA
-        device_contexts_.emplace(
-            places[i], new platform::CUDADeviceContext(
-                           boost::get<platform::GPUPlace>(places[i])));
-#else
-        PADDLE_THROW(
-            "'GPUPlace' is not supported, Please re-compile with WITH_GPU "
-            "option");
-#endif
-      }
-    }
-  }
-  ~DeviceContextPool() {}
- private:
-  static DeviceContextPool* pool;
-  struct Hash {
-    std::hash<int> hash_;
-    size_t operator()(const platform::Place& place) const {
-      return hash_(place.which());
-    }
-  };
-  std::unordered_multimap<const platform::Place, const platform::DeviceContext*,
-                          Hash>
-      device_contexts_;
-  DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
-};
class Executor {
 public:
  // TODO(dzhwinter) : Do not rely on this function, it will be removed
  explicit Executor(const platform::DeviceContext& device)
-      : Executor(std::vector<platform::Place>({device.GetPlace()})) {}
+      : Executor(device.GetPlace()) {}
-  explicit Executor(const platform::Place& place)
-      : Executor(std::vector<platform::Place>({place})) {}
-  explicit Executor(const std::vector<platform::Place>& places);
+  explicit Executor(const platform::Place& place);
  /* @Brief
   * Runtime evaluation of the given ProgramDesc under certain Scope
@@ -128,7 +42,7 @@ class Executor {
                 bool create_vars = true);
 private:
-  std::vector<const platform::DeviceContext*> device_contexts_;
+  const platform::Place place_;
};
} // namespace framework
...
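With `DeviceContextPool` removed from the executor header (callers now reach it as `platform::DeviceContextPool`, created once by `InitDevices`), an `Executor` is constructed from a single `Place`. A hedged usage sketch of the driver side, assuming a `ProgramDesc` built elsewhere and the `Run` signature shown in the hunks above:

```cpp
// Sketch only -- how the new Executor / Place / DeviceContextPool pieces fit together.
#include "paddle/framework/executor.h"
#include "paddle/framework/init.h"
#include "paddle/framework/scope.h"

void RunBlockZeroOnCPU(const paddle::framework::ProgramDesc &program) {
  // Builds the platform::DeviceContextPool for the listed devices.
  paddle::framework::InitDevices({"CPU"});

  paddle::platform::CPUPlace place;
  paddle::framework::Executor exec(place);  // Executor now stores a Place, not contexts

  paddle::framework::Scope scope;
  // Run block 0; a local scope is created (and deleted) inside Run.
  exec.Run(program, &scope, 0 /* block_id */, true /* create_local_scope */);
}
```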
@@ -14,8 +14,8 @@
#include <algorithm>
#include <string>
-#include "paddle/framework/executor.h"
#include "paddle/framework/init.h"
+#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
#include "paddle/string/piece.h"
@@ -48,7 +48,7 @@ bool InitDevices(const std::vector<std::string> &devices) {
  std::vector<platform::Place> places;
  for (auto &device : devices) {
    auto p = string::Piece(device);
-    if (string::Find(p, ':', 0) == string::Piece::npos) {
+    if (string::HasPrefix(p, "CPU")) {
      places.emplace_back(platform::CPUPlace());
    } else if (string::HasPrefix(p, "GPU")) {
#ifdef PADDLE_WITH_CUDA
@@ -69,10 +69,9 @@ bool InitDevices(const std::vector<std::string> &devices) {
        return platform::is_cpu_place(place);
      }) == places.end()) {
    places.emplace_back(platform::CPUPlace());
-    LOG(WARNING) << "Not specified any device, use CPU by Default.";
+    LOG(WARNING) << "Not specified CPU device, create CPU by Default.";
  }
-  DeviceContextPool::Create(places);
-  return true;
+  platform::DeviceContextPool::Create(places);
  return true;
}
...
@@ -23,5 +23,9 @@ TEST(Init, InitDevices) {
#ifdef PADDLE_WITH_CUDA
  std::vector<std::string> ds2 = {"CPU", "GPU:0", "GPU:1"};
  ASSERT_EQ(InitDevices(ds2), true);
+  // test re-init
+  std::vector<std::string> ds3 = {"GPU:0", "GPU:1"};
+  ASSERT_EQ(InitDevices(ds3), true);
#endif
}
@@ -8,8 +8,7 @@ namespace framework {
class CosineOp : public OperatorBase {
 public:
  using OperatorBase::OperatorBase;
-  void Run(const Scope& scope,
-           const platform::DeviceContext& dev_ctx) const override {}
+  void Run(const Scope& scope, const platform::Place& place) const override {}
};
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
@@ -28,8 +27,7 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class MyTestOp : public OperatorBase {
 public:
  using OperatorBase::OperatorBase;
-  void Run(const Scope& scope,
-           const platform::DeviceContext& dev_ctx) const override {}
+  void Run(const Scope& scope, const platform::Place& place) const override {}
};
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
@@ -76,8 +74,8 @@ TEST(OpRegistry, CreateOp) {
  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
  paddle::framework::Scope scope;
-  paddle::platform::CPUDeviceContext dev_ctx;
-  op->Run(scope, dev_ctx);
+  paddle::platform::CPUPlace cpu_place;
+  op->Run(scope, cpu_place);
  float scale_get = op->Attr<float>("scale");
  ASSERT_EQ(scale_get, scale);
}
@@ -117,8 +115,8 @@ TEST(OpRegistry, DefaultValue) {
  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
  paddle::framework::Scope scope;
-  paddle::platform::CPUDeviceContext dev_ctx;
-  op->Run(scope, dev_ctx);
+  paddle::platform::CPUPlace cpu_place;
+  op->Run(scope, cpu_place);
  ASSERT_EQ(op->Attr<float>("scale"), 1.0);
}
@@ -167,9 +165,9 @@ TEST(OpRegistry, CustomChecker) {
  attr->set_type(paddle::framework::proto::AttrType::INT);
  attr->set_i(4);
  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
-  paddle::platform::CPUDeviceContext dev_ctx;
+  paddle::platform::CPUPlace cpu_place;
  paddle::framework::Scope scope;
-  op->Run(scope, dev_ctx);
+  op->Run(scope, cpu_place);
  int test_attr = op->Attr<int>("test_attr");
  ASSERT_EQ(test_attr, 4);
}
...
@@ -12,10 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
+#include "paddle/framework/operator.h"
#include <algorithm>
#include <atomic>
+#include "paddle/framework/executor.h"
#include "paddle/framework/lod_tensor_array.h"
-#include "paddle/framework/operator.h"
#include "paddle/framework/shape_inference.h"
#include "paddle/framework/var_type.h"
@@ -388,11 +390,11 @@ class RuntimeInferShapeContext : public InferShapeContext {
};
void OperatorWithKernel::Run(const Scope& scope,
-                             const platform::DeviceContext& dev_ctx) const {
+                             const platform::Place& place) const {
  RuntimeInferShapeContext infer_shape_ctx(*this, scope);
  this->InferShape(&infer_shape_ctx);
-  ExecutionContext ctx(*this, scope, dev_ctx);
+  platform::DeviceContextPool& pool = platform::DeviceContextPool::Get();
+  auto dev_ctx = pool.Borrow(place);
  // check if op[type] has kernel registered.
  auto& all_op_kernels = AllOpKernels();
@@ -404,6 +406,8 @@ void OperatorWithKernel::Run(const Scope& scope,
  // check if op[type] have kernel for kernel_key
  OpKernelMap& kernels = kernels_iter->second;
+  ExecutionContext ctx(*this, scope, *dev_ctx);
  auto kernel_key = GetKernelType(ctx);
  auto kernel_iter = kernels.find(kernel_key);
...
@@ -83,8 +83,7 @@ class OperatorBase {
  virtual std::string DebugString() const;
  /// Net will call this function to Run an op.
-  virtual void Run(const Scope& scope,
-                   const platform::DeviceContext& dev_ctx) const = 0;
+  virtual void Run(const Scope& scope, const platform::Place& place) const = 0;
  virtual bool IsNetOp() const { return false; }
@@ -159,8 +158,7 @@ class OperatorBase {
class NOP : public OperatorBase {
 public:
  using OperatorBase::OperatorBase;
-  void Run(const Scope& scope,
-           const platform::DeviceContext& dev_ctx) const override {}
+  void Run(const Scope& scope, const platform::Place& place) const override {}
  std::unique_ptr<OperatorBase> Clone() const override {
    return std::unique_ptr<OperatorBase>(new NOP(*this));
  }
@@ -383,8 +381,7 @@ class OperatorWithKernel : public OperatorBase {
                     const VariableNameMap& outputs, const AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}
-  void Run(const Scope& scope,
-           const platform::DeviceContext& dev_ctx) const final;
+  void Run(const Scope& scope, const platform::Place& place) const final;
  static std::unordered_map<std::string /* op_type */, OpKernelMap>&
  AllOpKernels() {
...
@@ -11,11 +11,12 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
+#include "paddle/framework/operator.h"
#include "gtest/gtest.h"
+#include "paddle/framework/init.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/op_registry.h"
-#include "paddle/framework/operator.h"
namespace paddle {
namespace framework {
@@ -27,8 +28,7 @@ class OpWithoutKernelTest : public OperatorBase {
  OpWithoutKernelTest(const std::string& type, const VariableNameMap& inputs,
                      const VariableNameMap& outputs, const AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs), x(1) {}
-  void Run(const Scope& scope,
-           const platform::DeviceContext& dev_ctx) const override {
+  void Run(const Scope& scope, const platform::Place& place) const override {
    ++op_run_num;
    ASSERT_EQ(static_cast<int>(inputs_.size()), 1);
    ASSERT_EQ(static_cast<int>(outputs_.size()), 1);
@@ -41,10 +41,9 @@ class OpWithoutKernelTest : public OperatorBase {
  int x{0};
};
-class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
+class OpWithoutKernelCheckerMaker : public OpProtoAndCheckerMaker {
 public:
-  OpeWithoutKernelTestProtoAndCheckerMaker(OpProto* proto,
-                                           OpAttrChecker* op_checker)
+  OpWithoutKernelCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("input", "input of test op");
    AddOutput("output", "output of test op");
@@ -65,11 +64,12 @@ static void BuildVar(const std::string& param_name,
  }
}
-REGISTER_OP_WITHOUT_GRADIENT(
-    test_operator, paddle::framework::OpWithoutKernelTest,
-    paddle::framework::OpeWithoutKernelTestProtoAndCheckerMaker);
+REGISTER_OP_WITHOUT_GRADIENT(test_operator,
+                             paddle::framework::OpWithoutKernelTest,
+                             paddle::framework::OpWithoutKernelCheckerMaker);
TEST(OperatorBase, all) {
+  paddle::framework::InitDevices({"CPU"});
  paddle::framework::proto::OpDesc op_desc;
  op_desc.set_type("test_operator");
  BuildVar("input", {"IN1"}, op_desc.add_inputs());
@@ -80,13 +80,13 @@ TEST(OperatorBase, all) {
  attr->set_type(paddle::framework::proto::AttrType::FLOAT);
  attr->set_f(3.14);
-  paddle::platform::CPUDeviceContext device_context;
+  paddle::platform::CPUPlace cpu_place;
  paddle::framework::Scope scope;
  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
  scope.Var("OUT1");
  ASSERT_EQ(paddle::framework::op_run_num, 0);
-  op->Run(scope, device_context);
+  op->Run(scope, cpu_place);
  ASSERT_EQ(paddle::framework::op_run_num, 1);
}
@@ -123,7 +123,6 @@ template <typename T1, typename T2>
class CPUKernelTest : public OpKernel<float> {
 public:
  void Compute(const ExecutionContext& ctx) const {
-    std::cout << "this is cpu kernel" << std::endl;
    std::cout << ctx.op().DebugString() << std::endl;
    cpu_kernel_run_num++;
    ASSERT_EQ(ctx.op().Input("x"), "IN1");
@@ -195,6 +194,7 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel,
// test with single input
TEST(OpKernel, all) {
+  paddle::framework::InitDevices({"CPU"});
  paddle::framework::proto::OpDesc op_desc;
  op_desc.set_type("op_with_kernel");
  BuildVar("x", {"IN1"}, op_desc.add_inputs());
@@ -205,12 +205,12 @@ TEST(OpKernel, all) {
  attr->set_type(paddle::framework::proto::AttrType::FLOAT);
  attr->set_f(3.14);
-  paddle::platform::CPUDeviceContext cpu_device_context;
+  paddle::platform::CPUPlace cpu_place;
  paddle::framework::Scope scope;
  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
  ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0);
-  op->Run(scope, cpu_device_context);
+  op->Run(scope, cpu_place);
  ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
}
@@ -224,7 +224,9 @@ REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel,
TEST(OpKernel, multi_inputs) {
  using namespace paddle::framework;
+  paddle::framework::InitDevices({"CPU"});
  proto::OpDesc op_desc;
  op_desc.set_type("op_multi_inputs_with_kernel");
  BuildVar("xs", {"x0", "x1", "x2"}, op_desc.add_inputs());
  BuildVar("k", {"k0"}, op_desc.add_inputs());
@@ -235,7 +237,7 @@ TEST(OpKernel, multi_inputs) {
  attr->set_type(paddle::framework::proto::AttrType::FLOAT);
  attr->set_f(3.14);
-  paddle::platform::CPUDeviceContext cpu_device_context;
+  paddle::platform::CPUPlace cpu_place;
  paddle::framework::Scope scope;
  scope.Var("x0")->GetMutable<LoDTensor>();
  scope.Var("x1")->GetMutable<LoDTensor>();
@@ -245,7 +247,7 @@ TEST(OpKernel, multi_inputs) {
  scope.Var("y1")->GetMutable<LoDTensor>();
  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
-  op->Run(scope, cpu_device_context);
+  op->Run(scope, cpu_place);
}
class OperatorClone : public paddle::framework::OperatorBase {
@@ -257,10 +259,11 @@ class OperatorClone : public paddle::framework::OperatorBase {
                const paddle::framework::AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}
  void Run(const paddle::framework::Scope& scope,
-           const paddle::platform::DeviceContext& dev_ctx) const override {}
+           const paddle::platform::Place& place) const override {}
};
TEST(Operator, Clone) {
+  paddle::framework::InitDevices({"CPU"});
  OperatorClone a("ABC", paddle::framework::VariableNameMap{},
                  paddle::framework::VariableNameMap{},
                  paddle::framework::AttributeMap{});
...
@@ -15,6 +15,7 @@
#pragma once
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
+#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
@@ -27,11 +28,16 @@ class ArrayOp : public framework::OperatorBase {
 protected:
  size_t GetOffset(const framework::Scope &scope,
-                   const platform::DeviceContext &dev_ctx) const {
+                   const platform::Place &place) const {
    auto *i = scope.FindVar(Input("I"));
    PADDLE_ENFORCE(i != nullptr, "I must be set");
    auto &i_tensor = i->Get<framework::LoDTensor>();
    PADDLE_ENFORCE_EQ(i_tensor.numel(), 1);
+    // get device context from pool
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
+    auto &dev_ctx = *pool.Borrow(place);
    size_t offset;
    if (platform::is_gpu_place(i_tensor.place())) {
      // FIXME: Avoid copy from GPU to CPU
...
@@ -12,10 +12,12 @@
See the License for the specific language governing permissions and
limitations under the License. */
+#include <numeric>
#include "paddle/framework/lod_rank_table.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
#include "paddle/memory/memcpy.h"
+#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
@@ -30,7 +32,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
                     const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}
  void Run(const framework::Scope &scope,
-           const platform::DeviceContext &dev_ctx) const override {
+           const platform::Place &dev_place) const override {
    auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>();
    auto &rank_table =
        scope.FindVar(Input("RankTable"))->Get<framework::LoDRankTable>();
@@ -103,6 +105,10 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
          continue;
        }
        auto slice = out->Slice(out_offset, out_offset + len);
+        platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
+        auto &dev_ctx = *pool.Borrow(place);
        framework::CopyFrom(x[x_idx].Slice(start_offset, end_offset), place,
                            dev_ctx, &slice);
        out_offset += len;
...
@@ -15,6 +15,7 @@
#include "paddle/framework/data_type.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/var_type.h"
+#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
@@ -71,7 +72,7 @@ class AssignOp : public framework::OperatorBase {
           const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}
  void Run(const framework::Scope &scope,
-           const platform::DeviceContext &dev_ctx) const override {
+           const platform::Place &place) const override {
    auto *x = scope.FindVar(Input("X"));
    if (x == nullptr) {
      return;
@@ -80,6 +81,10 @@ class AssignOp : public framework::OperatorBase {
    PADDLE_ENFORCE(
        out != nullptr,
        "The Output(Out) should not be null if the Input(X) is set.");
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
+    auto &dev_ctx = *pool.Borrow(place);
    framework::VisitVarType(*x, AssignFunctor(out, dev_ctx));
  }
};
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/beam_search_decode_op.h"
+#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
@@ -55,7 +56,10 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
                    const framework::AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}
  void Run(const framework::Scope& scope,
-           const platform::DeviceContext& dev_ctx) const override {
+           const platform::Place& dev_place) const override {
+    platform::DeviceContextPool& pool = platform::DeviceContextPool::Get();
+    auto& dev_ctx = *pool.Borrow(dev_place);
    framework::ExecutionContext ctx(*this, scope, dev_ctx);
    const LoDTensorArray* ids = ctx.Input<LoDTensorArray>("Ids");
...
@@ -189,7 +189,7 @@ class BeamSearchOp : public framework::OperatorBase {
  }
  void Run(const framework::Scope& scope,
-           const platform::DeviceContext& dev_ctx) const override {
+           const platform::Place& dev_place) const override {
    LOG(INFO) << "run beam search op";
    auto ids_var = scope.FindVar(Input("ids"));
    auto scores_var = scope.FindVar(Input("scores"));
...
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/cond_op.h"
#include "paddle/operators/gather.h"
#include "paddle/operators/scatter.h"
+#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
@@ -193,12 +193,15 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
  }
}
-void CondOp::Run(const Scope& scope,
-                 const platform::DeviceContext& dev_ctx) const {
+void CondOp::Run(const Scope& scope, const platform::Place& place) const {
+  // get device context from pool
+  platform::DeviceContextPool& pool = platform::DeviceContextPool::Get();
+  auto& dev_ctx = *pool.Borrow(place);
  PrepareDataForSubnet(scope, dev_ctx);
  std::vector<framework::Scope*>& sub_scopes = GetSubScopes(scope);
  for (int i = 0; i < BRANCH_NUM; ++i) {
-    sub_net_op_[i]->Run(*sub_scopes[i], dev_ctx);
+    sub_net_op_[i]->Run(*sub_scopes[i], place);
  }
  MergeDataFromSubnet(scope, dev_ctx);
}
...
@@ -78,7 +78,7 @@ class CondOp : public framework::OperatorBase {
  }
  void Run(const framework::Scope& scope,
-           const platform::DeviceContext& dev_ctx) const override;
+           const platform::Place& place) const override;
 private:
  const int TRUE_BRANCH = 0;
...
@@ -51,7 +51,7 @@ class ConditionalBlockOp : public ConditionalOp {
                      const framework::AttributeMap &attrs)
      : ConditionalOp(type, inputs, outputs, attrs) {}
  void Run(const framework::Scope &scope,
-           const platform::DeviceContext &dev_ctx) const override {
+           const platform::Place &dev_place) const override {
    auto xs = InputTensors(scope);
    bool need_run = std::all_of(
        xs.begin(), xs.end(),
@@ -65,8 +65,8 @@ class ConditionalBlockOp : public ConditionalOp {
      scopes->front() = &scope.NewScope();
      auto &cur_scope = *scopes->front();
+      framework::Executor exec(dev_place);
      auto *block = Attr<framework::BlockDesc *>("sub_block");
-      framework::Executor exec(dev_ctx);
      exec.Run(*block->Program(), &cur_scope, block->ID(), false);
    }
  }
@@ -104,7 +104,7 @@ class ConditionalBlockGradOp : public ConditionalOp {
                          const framework::AttributeMap &attrs)
      : ConditionalOp(type, inputs, outputs, attrs) {}
  void Run(const framework::Scope &scope,
-           const platform::DeviceContext &dev_ctx) const override {
+           const platform::Place &dev_place) const override {
    auto xs = this->InputTensors(scope);
    bool need_run = std::all_of(
        xs.begin(), xs.end(),
@@ -116,21 +116,21 @@ class ConditionalBlockGradOp : public ConditionalOp {
      auto &scopes = scope_var->Get<std::vector<framework::Scope *>>();
      framework::Scope &cur_scope = *scopes[0];
+      framework::Executor exec(dev_place);
      auto *block = Attr<framework::BlockDesc *>("sub_block");
-      framework::Executor exec(dev_ctx);
      exec.Run(*block->Program(), &cur_scope, block->ID(), false);
-      AssignLocalGradientToGlobal(dev_ctx, cur_scope, Inputs("Params"),
+      AssignLocalGradientToGlobal(dev_place, cur_scope, Inputs("Params"),
                                  Outputs(framework::GradVarName("Params")));
-      AssignLocalGradientToGlobal(dev_ctx, cur_scope, Inputs("X"),
+      AssignLocalGradientToGlobal(dev_place, cur_scope, Inputs("X"),
                                  Outputs(framework::GradVarName("X")));
    }
  }
 private:
  void AssignLocalGradientToGlobal(
-      const platform::DeviceContext &dev_ctx, const framework::Scope &cur_scope,
+      const platform::Place &place, const framework::Scope &cur_scope,
      const std::vector<std::string> &p_names,
      const std::vector<std::string> &pg_names) const {
    for (size_t i = 0; i < p_names.size(); ++i) {
@@ -144,7 +144,7 @@ class ConditionalBlockGradOp : public ConditionalOp {
      auto assign = framework::OpRegistry::CreateOp(
          "assign", {{"X", {new_in_grad_name}}}, {{"Out", {out_grad_name}}},
          framework::AttributeMap{});
-      assign->Run(cur_scope, dev_ctx);
+      assign->Run(cur_scope, place);
      cur_scope.Rename(new_in_grad_name, in_grad_name);
    }
  }
...
@@ -25,7 +25,7 @@ class FeedOp : public framework::OperatorBase {
         const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}
  void Run(const framework::Scope &scope,
-           const platform::DeviceContext &dev_ctx) const override {
+           const platform::Place &place) const override {
    auto feed_var_name = Input("X");
    auto *feed_var = scope.FindVar(feed_var_name);
@@ -47,7 +47,12 @@ class FeedOp : public framework::OperatorBase {
    auto &feed_list = feed_var->Get<framework::FeedFetchList>();
    auto &feed_item = feed_list.at(static_cast<size_t>(col));
    auto *out_item = out_var->GetMutable<framework::FeedFetchType>();
-    framework::CopyFrom(feed_item, dev_ctx.GetPlace(), dev_ctx, out_item);
+    // get device context from pool
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
+    auto &dev_ctx = *pool.Borrow(place);
+    framework::CopyFrom(feed_item, place, dev_ctx, out_item);
    out_item->set_lod(feed_item.lod());
  }
};
...
@@ -14,6 +14,7 @@
#include "paddle/framework/feed_fetch_type.h"
#include "paddle/framework/op_registry.h"
+#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
@@ -26,7 +27,7 @@ class FetchOp : public framework::OperatorBase {
      : OperatorBase(type, inputs, outputs, attrs) {}
  void Run(const framework::Scope &scope,
-           const platform::DeviceContext &dev_ctx) const override {
+           const platform::Place &place) const override {
    auto fetch_var_name = Input("X");
    auto *fetch_var = scope.FindVar(fetch_var_name);
    PADDLE_ENFORCE(fetch_var != nullptr,
@@ -51,6 +52,9 @@ class FetchOp : public framework::OperatorBase {
    // FIXME(yuyang18): Should we assume the fetch operator always generate
    // CPU outputs?
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
+    auto &dev_ctx = *pool.Borrow(place);
    CopyFrom(src_item, platform::CPUPlace(), dev_ctx, &dst_item);
    dev_ctx.Wait();
    dst_item.set_lod(src_item.lod());
...
@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/framework/data_type.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
+#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
@@ -33,7 +34,7 @@ class FillConstantOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;
  void Run(const framework::Scope &scope,
-           const platform::DeviceContext &dev_ctx) const override {
+           const platform::Place &dev_place) const override {
    auto data_type =
        static_cast<framework::proto::DataType>(Attr<int>("dtype"));
    auto value = Attr<float>("value");
@@ -45,8 +46,11 @@ class FillConstantOp : public framework::OperatorBase {
      auto cpu = platform::CPUPlace();
      out.mutable_data(cpu, framework::ToTypeIndex(data_type));
    } else {
-      out.mutable_data(dev_ctx.GetPlace(), framework::ToTypeIndex(data_type));
+      out.mutable_data(dev_place, framework::ToTypeIndex(data_type));
    }
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
+    auto &dev_ctx = *pool.Borrow(dev_place);
    math::set_constant(dev_ctx, &out, value);
  }
};
...
@@ -15,6 +15,7 @@
#include "paddle/framework/data_type.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/detail/safe_ref.h"
+#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
@@ -42,7 +43,7 @@ class FillOp : public framework::OperatorBase {
         const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}
  void Run(const framework::Scope &scope,
-           const platform::DeviceContext &dev_ctx) const override {
+           const platform::Place &place) const override {
    auto &out =
        detail::Ref(detail::Ref(scope.FindVar(Output("Out")),
                                "Cannot find variable %s", Output("Out"))
@@ -51,12 +52,11 @@ class FillOp : public framework::OperatorBase {
    auto dtype = static_cast<framework::proto::DataType>(Attr<int>("dtype"));
    platform::CPUPlace cpu;
    auto force_cpu = Attr<bool>("force_cpu");
-    out.mutable_data(force_cpu ? cpu : dev_ctx.GetPlace(),
-                     framework::ToTypeIndex(dtype));
+    out.mutable_data(force_cpu ? cpu : place, framework::ToTypeIndex(dtype));
    framework::LoDTensor tensor;
-    if (force_cpu || platform::is_cpu_place(dev_ctx.GetPlace())) {
+    if (force_cpu || platform::is_cpu_place(place)) {
      tensor.ShareDataWith(out);
    } else {
      // Always make tensor in CPU memory.
@@ -67,9 +67,11 @@ class FillOp : public framework::OperatorBase {
    framework::VisitDataType(
        dtype, FillOpVisitor(&tensor, Attr<std::vector<float>>("value")));
-    if (!force_cpu && platform::is_gpu_place(dev_ctx.GetPlace())) {
+    if (!force_cpu && platform::is_gpu_place(place)) {
      // Copy tensor to out
-      framework::CopyFrom(tensor, dev_ctx.GetPlace(), dev_ctx, &out);
+      platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
+      auto &dev_ctx = *pool.Borrow(place);
+      framework::CopyFrom(tensor, place, dev_ctx, &out);
    }
  }
};
...
@@ -52,7 +52,7 @@ class IncrementOp : public framework::OperatorBase {
      : OperatorBase(type, inputs, outputs, attrs) {}
  void Run(const framework::Scope &scope,
-           const platform::DeviceContext &dev_ctx) const override {
+           const platform::Place &place) const override {
    auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
    auto &out =
        *scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
...
@@ -29,7 +29,7 @@ class IsEmptyOp : public framework::OperatorBase {
      : OperatorBase(type, inputs, outputs, attrs) {}
  void Run(const framework::Scope &scope,
-           const platform::DeviceContext &dev_ctx) const override {
+           const platform::Place &place) const override {
    // get input
    auto *var = scope.FindVar(Input(kInput));
    PADDLE_ENFORCE_NOT_NULL(var);
...
@@ -11,10 +11,10 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
+#include <fstream>
#include "paddle/framework/op_registry.h"
+#include "paddle/platform/device_context.h"
-#include <fstream>
namespace paddle {
namespace operators {
@@ -26,7 +26,7 @@ class LoadOp : public framework::OperatorBase {
         const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}
  void Run(const framework::Scope &scope,
-           const platform::DeviceContext &dev_ctx) const override {
+           const platform::Place &place) const override {
    auto filename = Attr<std::string>("file_path");
    std::ifstream fin(filename);
    PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
@@ -40,7 +40,9 @@ class LoadOp : public framework::OperatorBase {
    auto *tensor = out_var->GetMutable<framework::LoDTensor>();
    framework::DeserializeFromStream(fin, tensor);
-    auto place = dev_ctx.GetPlace();
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
+    auto &dev_ctx = *pool.Borrow(place);
    if (platform::is_gpu_place(place)) {
      // copy CPU to GPU
      framework::LoDTensor cpu_tensor;
...
@@ -26,7 +26,7 @@ class LoDArrayLengthOp : public framework::OperatorBase {
                    const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}
  void Run(const framework::Scope &scope,
-           const platform::DeviceContext &dev_ctx) const override {
+           const platform::Place &place) const override {
    auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>();
    auto &out =
        *scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
...
@@ -24,7 +24,7 @@ class LoDRankTableOp : public framework::OperatorBase {
                 const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}
  void Run(const framework::Scope &scope,
-           const platform::DeviceContext &dev_ctx) const override {
+           const platform::Place &dev_place) const override {
    auto x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
    auto *out =
        scope.FindVar(Output("Out"))->GetMutable<framework::LoDRankTable>();
...
@@ -15,6 +15,7 @@
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/detail/safe_ref.h"
+#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
@@ -32,7 +33,7 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
                     const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}
  void Run(const framework::Scope &scope,
-           const platform::DeviceContext &dev_ctx) const override {
+           const platform::Place &place) const override {
    auto &x = detail::Ref(scope.FindVar(Input("X")), "Cannot find input %s",
                          Input("X"))
                  .Get<framework::LoDTensor>();
@@ -86,6 +87,10 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
        // out[i][offset: offset+len] = x[each_range.begin: each_range.end]
        auto slice = out[i].Slice(static_cast<int>(offset),
                                  static_cast<int>(offset + len));
+        platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
+        auto &dev_ctx = *pool.Borrow(place);
        framework::CopyFrom(x.Slice(static_cast<int>(each_range.begin),
                                    static_cast<int>(each_range.end)),
                            x.place(), dev_ctx, &slice);
...
@@ -28,7 +28,7 @@ class MaxSeqenceLenOp : public framework::OperatorBase {
      : OperatorBase(type, inputs, outputs, attrs) {}
  void Run(const framework::Scope &scope,
-           const platform::DeviceContext &dev_ctx) const override {
+           const platform::Place &dev_place) const override {
    auto &rank_table =
        scope.FindVar(Input("RankTable"))->Get<framework::LoDRankTable>();
    auto *out =
...
@@ -28,7 +28,11 @@ class MergeLoDTensorOp : public framework::OperatorBase {
                   const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}
  void Run(const framework::Scope &scope,
-           const platform::DeviceContext &dev_ctx) const override {
+           const platform::Place &dev_place) const override {
+    // get device context from pool
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
+    auto &dev_ctx = *pool.Borrow(dev_place);
    auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
    auto &mask = scope.FindVar(Input("Mask"))->Get<framework::LoDTensor>();
    auto &in_true = scope.FindVar(Input("InTrue"))->Get<framework::LoDTensor>();
...
...@@ -24,7 +24,7 @@ class NCCLInitOp : public framework::OperatorBase { ...@@ -24,7 +24,7 @@ class NCCLInitOp : public framework::OperatorBase {
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override { const platform::Place &place) const override {
const auto &name = Output("Communicator"); const auto &name = Output("Communicator");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name), PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name),
"Can not find variable '%s' in the scope.", name); "Can not find variable '%s' in the scope.", name);
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <vector> #include <vector>
#include "paddle/framework/block_desc.h" #include "paddle/framework/block_desc.h"
#include "paddle/framework/init.h"
#include "paddle/framework/op_desc.h" #include "paddle/framework/op_desc.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/framework/program_desc.h" #include "paddle/framework/program_desc.h"
...@@ -49,7 +50,7 @@ const f::DDim kDims = {100, 100}; ...@@ -49,7 +50,7 @@ const f::DDim kDims = {100, 100};
class NCCLTester : public ::testing::Test { class NCCLTester : public ::testing::Test {
public: public:
virtual void SetUp() override { virtual void SetUp() override {
cpu_ctx = new p::CPUDeviceContext(p::CPUPlace()); paddle::platform::CPUPlace cpu_place;
for (size_t i = 0; i < gpu_list.size(); ++i) { for (size_t i = 0; i < gpu_list.size(); ++i) {
p::GPUPlace place(i); p::GPUPlace place(i);
dev_ctxs.emplace_back(new p::CUDADeviceContext(place)); dev_ctxs.emplace_back(new p::CUDADeviceContext(place));
...@@ -65,6 +66,7 @@ class NCCLTester : public ::testing::Test { ...@@ -65,6 +66,7 @@ class NCCLTester : public ::testing::Test {
} }
void NCCLInitOp() { void NCCLInitOp() {
paddle::platform::CPUPlace cpu_place;
std::unique_ptr<f::OpDesc> op1(new f::OpDesc); std::unique_ptr<f::OpDesc> op1(new f::OpDesc);
op1->SetType("ncclInit"); op1->SetType("ncclInit");
...@@ -76,7 +78,7 @@ class NCCLTester : public ::testing::Test { ...@@ -76,7 +78,7 @@ class NCCLTester : public ::testing::Test {
auto op = f::OpRegistry::CreateOp(*op1); auto op = f::OpRegistry::CreateOp(*op1);
VLOG(1) << "invoke NCCLInitOp."; VLOG(1) << "invoke NCCLInitOp.";
op->Run(g_scope, *cpu_ctx); op->Run(g_scope, cpu_place);
VLOG(1) << "NCCLInitOp finished."; VLOG(1) << "NCCLInitOp finished.";
} }
...@@ -111,13 +113,12 @@ class NCCLTester : public ::testing::Test { ...@@ -111,13 +113,12 @@ class NCCLTester : public ::testing::Test {
VLOG(1) << "Device : " << gpu_id << " invoke " << op_desc.Type(); VLOG(1) << "Device : " << gpu_id << " invoke " << op_desc.Type();
VLOG(1) << " send_tensor : " << send_tensor->numel() VLOG(1) << " send_tensor : " << send_tensor->numel()
<< " recv_tensor : " << recv_tensor->numel(); << " recv_tensor : " << recv_tensor->numel();
op->Run(*scope, *ctx); op->Run(*scope, place);
VLOG(1) << "Device : " << gpu_id << " finished " << op_desc.Type(); VLOG(1) << "Device : " << gpu_id << " finished " << op_desc.Type();
} }
public: public:
std::vector<p::DeviceContext *> dev_ctxs; std::vector<p::DeviceContext *> dev_ctxs;
p::DeviceContext *cpu_ctx;
f::Scope g_scope; f::Scope g_scope;
std::mutex mu; std::mutex mu;
}; };
...@@ -131,14 +132,14 @@ TEST(NCCL, ncclInitOp) { ...@@ -131,14 +132,14 @@ TEST(NCCL, ncclInitOp) {
op_desc->SetAttr("gpus", {gpu_list}); op_desc->SetAttr("gpus", {gpu_list});
f::Scope g_scope; f::Scope g_scope;
std::unique_ptr<p::DeviceContext> ctx(new p::CPUDeviceContext(p::CPUPlace())); paddle::platform::CPUPlace cpu_place;
auto *var = g_scope.Var("x1"); auto *var = g_scope.Var("x1");
var->GetMutable<p::Communicator>(); var->GetMutable<p::Communicator>();
auto op = f::OpRegistry::CreateOp(*op_desc); auto op = f::OpRegistry::CreateOp(*op_desc);
VLOG(1) << "invoke NCCLInitOp."; VLOG(1) << "invoke NCCLInitOp.";
op->Run(g_scope, *ctx.get()); op->Run(g_scope, cpu_place);
VLOG(1) << "NCCLInitOp finished."; VLOG(1) << "NCCLInitOp finished.";
} }
...@@ -294,9 +295,18 @@ int main(int argc, char **argv) { ...@@ -294,9 +295,18 @@ int main(int argc, char **argv) {
return 0; return 0;
} }
for (int i = 0; i < dev_count; ++i) { std::vector<paddle::platform::Place> places;
places.emplace_back(paddle::platform::CPUPlace());
int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; ++i) {
places.emplace_back(paddle::platform::GPUPlace(i));
gpu_list.emplace_back(i); gpu_list.emplace_back(i);
} }
VLOG(0) << " DeviceCount " << count;
paddle::platform::DeviceContextPool::Create(places);
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
// device context should be released before scope. // device context should be released before scope.
......
...@@ -65,9 +65,9 @@ class NetOp : public framework::OperatorBase { ...@@ -65,9 +65,9 @@ class NetOp : public framework::OperatorBase {
* will be used. * will be used.
*/ */
void Run(const framework::Scope& scope, void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const override { const platform::Place& place) const override {
for (auto& op : ops_) { for (auto& op : ops_) {
op->Run(scope, dev_ctx); op->Run(scope, place);
} }
} }
......
...@@ -13,8 +13,7 @@ class TestOp : public framework::OperatorBase { ...@@ -13,8 +13,7 @@ class TestOp : public framework::OperatorBase {
public: public:
using framework::OperatorBase::OperatorBase; using framework::OperatorBase::OperatorBase;
DEFINE_OP_CLONE_METHOD(TestOp); DEFINE_OP_CLONE_METHOD(TestOp);
void Run(const Scope& scope, void Run(const Scope& scope, const platform::Place& place) const override {
const platform::DeviceContext& dev_ctx) const override {
++run_cnt; ++run_cnt;
} }
}; };
......
...@@ -227,14 +227,15 @@ class RecurrentOp : public RecurrentBase { ...@@ -227,14 +227,15 @@ class RecurrentOp : public RecurrentBase {
: RecurrentBase(type, inputs, outputs, attrs) {} : RecurrentBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override { const platform::Place &place) const override {
auto seq_len = static_cast<size_t>(this->GetSequenceLength(scope)); auto seq_len = static_cast<size_t>(this->GetSequenceLength(scope));
VLOG(3) << "Static RNN input sequence length = " << seq_len; VLOG(3) << "Static RNN input sequence length = " << seq_len;
StepScopes scopes = CreateStepScopes(scope, seq_len); StepScopes scopes = CreateStepScopes(scope, seq_len);
auto reverse = Attr<bool>(kReverse); auto reverse = Attr<bool>(kReverse);
framework::Executor executor(dev_ctx); framework::Executor executor(place);
auto *block = Attr<framework::BlockDesc *>(kStepBlock); auto *block = Attr<framework::BlockDesc *>(kStepBlock);
auto *program = block->Program(); auto *program = block->Program();
for (size_t i = 0; i < seq_len; ++i) { for (size_t i = 0; i < seq_len; ++i) {
...@@ -270,6 +271,10 @@ class RecurrentOp : public RecurrentBase { ...@@ -270,6 +271,10 @@ class RecurrentOp : public RecurrentBase {
executor.Run(*program, &cur_scope, block->ID(), executor.Run(*program, &cur_scope, block->ID(),
false /*create_local_scope*/); false /*create_local_scope*/);
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
// Copy inside::output -> outside::output // Copy inside::output -> outside::output
// outside::output[seq_offset: seq_offset + 1] = inside::output // outside::output[seq_offset: seq_offset + 1] = inside::output
this->LinkTensorWithCallback( this->LinkTensorWithCallback(
...@@ -278,14 +283,13 @@ class RecurrentOp : public RecurrentBase { ...@@ -278,14 +283,13 @@ class RecurrentOp : public RecurrentBase {
framework::LoDTensor *dst_tensor) { framework::LoDTensor *dst_tensor) {
if (i == 0) { // create output tensor at begin if (i == 0) { // create output tensor at begin
dst_tensor->Resize(PrependDims(seq_len, src_tensor.dims())); dst_tensor->Resize(PrependDims(seq_len, src_tensor.dims()));
dst_tensor->mutable_data(dev_ctx.GetPlace(), src_tensor.type()); dst_tensor->mutable_data(place, src_tensor.type());
} }
auto dst_out = dst_tensor->Slice(seq_offset, seq_offset + 1); auto dst_out = dst_tensor->Slice(seq_offset, seq_offset + 1);
// Explicit copy output since the local RNN scope can be destroyed // Explicit copy output since the local RNN scope can be destroyed
// early. // early.
framework::CopyFrom(src_tensor, dev_ctx.GetPlace(), dev_ctx, framework::CopyFrom(src_tensor, place, dev_ctx, &dst_out);
&dst_out);
}); });
scopes.Next(); scopes.Next();
...@@ -311,15 +315,20 @@ class RecurrentGradOp : public RecurrentBase { ...@@ -311,15 +315,20 @@ class RecurrentGradOp : public RecurrentBase {
: RecurrentBase(type, inputs, outputs, attrs) {} : RecurrentBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override { const platform::Place &place) const override {
auto seq_len = static_cast<size_t>(GetSequenceLength(scope)); auto seq_len = static_cast<size_t>(GetSequenceLength(scope));
StepScopes scopes = CreateStepScopes(scope, seq_len); StepScopes scopes = CreateStepScopes(scope, seq_len);
auto reverse = Attr<bool>(kReverse); auto reverse = Attr<bool>(kReverse);
framework::Executor executor(dev_ctx); framework::Executor executor(place);
auto *block = Attr<framework::BlockDesc *>(kStepBlock); auto *block = Attr<framework::BlockDesc *>(kStepBlock);
auto *program = block->Program(); auto *program = block->Program();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
for (size_t step_id = 0; step_id < seq_len; ++step_id) { for (size_t step_id = 0; step_id < seq_len; ++step_id) {
size_t seq_offset = reverse ? step_id : seq_len - step_id - 1; size_t seq_offset = reverse ? step_id : seq_len - step_id - 1;
VLOG(3) << "Recurrent backward operate at the time step " << seq_offset; VLOG(3) << "Recurrent backward operate at the time step " << seq_offset;
...@@ -366,8 +375,7 @@ class RecurrentGradOp : public RecurrentBase { ...@@ -366,8 +375,7 @@ class RecurrentGradOp : public RecurrentBase {
auto *cur_grad_var = cur_scope.Var(cur_grad); auto *cur_grad_var = cur_scope.Var(cur_grad);
auto cur_grad_tensor = auto cur_grad_tensor =
cur_grad_var->GetMutable<framework::LoDTensor>(); cur_grad_var->GetMutable<framework::LoDTensor>();
framework::CopyFrom(ex_tensor, dev_ctx.GetPlace(), dev_ctx, framework::CopyFrom(ex_tensor, place, dev_ctx, cur_grad_tensor);
cur_grad_tensor);
} }
} }
...@@ -410,7 +418,7 @@ class RecurrentGradOp : public RecurrentBase { ...@@ -410,7 +418,7 @@ class RecurrentGradOp : public RecurrentBase {
auto zero_op = framework::OpRegistry::CreateOp( auto zero_op = framework::OpRegistry::CreateOp(
"fill_constant", framework::VariableNameMap{}, "fill_constant", framework::VariableNameMap{},
{{"Out", {pg_names[param_id]}}}, attrs); {{"Out", {pg_names[param_id]}}}, attrs);
zero_op->Run(scope, dev_ctx); zero_op->Run(scope, place);
} }
auto new_inside_name = cur_scope.Rename(inside_grad_name); auto new_inside_name = cur_scope.Rename(inside_grad_name);
...@@ -419,7 +427,7 @@ class RecurrentGradOp : public RecurrentBase { ...@@ -419,7 +427,7 @@ class RecurrentGradOp : public RecurrentBase {
auto sum_op = framework::OpRegistry::CreateOp( auto sum_op = framework::OpRegistry::CreateOp(
"sum", {{"X", {pg_names[param_id], new_inside_name}}}, "sum", {{"X", {pg_names[param_id], new_inside_name}}},
{{"Out", {pg_names[param_id]}}}, framework::AttributeMap{}); {{"Out", {pg_names[param_id]}}}, framework::AttributeMap{});
sum_op->Run(cur_scope, dev_ctx); sum_op->Run(cur_scope, place);
cur_scope.Rename(new_inside_name, inside_grad_name); cur_scope.Rename(new_inside_name, inside_grad_name);
} }
...@@ -437,11 +445,11 @@ class RecurrentGradOp : public RecurrentBase { ...@@ -437,11 +445,11 @@ class RecurrentGradOp : public RecurrentBase {
} }
if (step_id == 0) { // alloc memory if (step_id == 0) { // alloc memory
outside->Resize(PrependDims(seq_len, inside.dims())); outside->Resize(PrependDims(seq_len, inside.dims()));
outside->mutable_data(dev_ctx.GetPlace(), inside.type()); outside->mutable_data(place, inside.type());
} }
auto dst = outside->Slice(seq_offset, seq_offset + 1); auto dst = outside->Slice(seq_offset, seq_offset + 1);
framework::CopyFrom(inside, dev_ctx.GetPlace(), dev_ctx, &dst); framework::CopyFrom(inside, place, dev_ctx, &dst);
}); });
VLOG(5) << "Link outside gradient finished "; VLOG(5) << "Link outside gradient finished ";
...@@ -453,8 +461,8 @@ class RecurrentGradOp : public RecurrentBase { ...@@ -453,8 +461,8 @@ class RecurrentGradOp : public RecurrentBase {
[&](const framework::LoDTensor &inside, [&](const framework::LoDTensor &inside,
framework::LoDTensor *outside) { framework::LoDTensor *outside) {
outside->Resize(inside.dims()); outside->Resize(inside.dims());
outside->mutable_data(dev_ctx.GetPlace(), inside.type()); outside->mutable_data(place, inside.type());
framework::CopyFrom(inside, dev_ctx.GetPlace(), dev_ctx, outside); framework::CopyFrom(inside, place, dev_ctx, outside);
}); });
VLOG(5) << "Link initialize state gradient finished "; VLOG(5) << "Link initialize state gradient finished ";
} }
......
...@@ -73,7 +73,7 @@ class RecvOp : public framework::OperatorBase { ...@@ -73,7 +73,7 @@ class RecvOp : public framework::OperatorBase {
} }
void Run(const framework::Scope &scope, void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override { const platform::Place &dev_place) const override {
// FIXME(typhoonzero): no new scopes for every run. // FIXME(typhoonzero): no new scopes for every run.
framework::Scope &recv_scope = scope.NewScope(); framework::Scope &recv_scope = scope.NewScope();
rpc_service_->SetScope(&recv_scope); rpc_service_->SetScope(&recv_scope);
...@@ -113,7 +113,9 @@ class RecvOp : public framework::OperatorBase { ...@@ -113,7 +113,9 @@ class RecvOp : public framework::OperatorBase {
auto *var = recv_scope.Var(grad_var_name); auto *var = recv_scope.Var(grad_var_name);
auto *tensor = var->GetMutable<framework::LoDTensor>(); auto *tensor = var->GetMutable<framework::LoDTensor>();
// FIXME(typhoonzero): do not copy // FIXME(typhoonzero): do not copy
framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor); platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
framework::CopyFrom(v.second, place, dev_ctx, tensor);
} }
rpc_service_->Reset(); rpc_service_->Reset();
...@@ -121,7 +123,7 @@ class RecvOp : public framework::OperatorBase { ...@@ -121,7 +123,7 @@ class RecvOp : public framework::OperatorBase {
framework::proto::ProgramDesc program_desc; framework::proto::ProgramDesc program_desc;
program_desc.ParseFromString(program_str); program_desc.ParseFromString(program_str);
framework::ProgramDesc program(program_desc); framework::ProgramDesc program(program_desc);
framework::Executor executor(dev_ctx); framework::Executor executor(place);
// Run sub graph to get optimized tensor // Run sub graph to get optimized tensor
try { try {
executor.Run(program, &recv_scope, 0, /*global_block*/ executor.Run(program, &recv_scope, 0, /*global_block*/
......
...@@ -12,9 +12,10 @@ ...@@ -12,9 +12,10 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <paddle/framework/lod_rank_table.h> #include "paddle/framework/lod_rank_table.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/detail/safe_ref.h" #include "paddle/operators/detail/safe_ref.h"
#include "paddle/platform/device_context.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -53,7 +54,7 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase { ...@@ -53,7 +54,7 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase {
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override { const platform::Place &place) const override {
auto &x = auto &x =
detail::Ref(scope.FindVar(Input("X")), detail::Ref(scope.FindVar(Input("X")),
"Cannot find input lod tensor variable %s", Input("X")) "Cannot find input lod tensor variable %s", Input("X"))
...@@ -69,11 +70,11 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase { ...@@ -69,11 +70,11 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase {
out.Resize(x.dims()); out.Resize(x.dims());
out.mutable_data(x.place(), x.type()); out.mutable_data(x.place(), x.type());
this->process(dev_ctx, x, rank_table, &out); this->process(place, x, rank_table, &out);
} }
protected: protected:
virtual void process(const platform::DeviceContext &dev_ctx, virtual void process(const platform::Place &place,
const framework::LoDTensor &x, const framework::LoDTensor &x,
const framework::LoDRankTable &rank_table, const framework::LoDRankTable &rank_table,
framework::LoDTensor *out) const = 0; framework::LoDTensor *out) const = 0;
...@@ -104,7 +105,7 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase { ...@@ -104,7 +105,7 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase {
return absolute_table; return absolute_table;
} }
size_t CopyTensorAndLod(const platform::DeviceContext &dev_ctx, size_t CopyTensorAndLod(const platform::Place &place,
const AbsoluteRankTableItem &item, const AbsoluteRankTableItem &item,
const framework::LoDTensor &x, const framework::LoDTensor &x,
framework::LoDTensor *out, size_t out_offset) const { framework::LoDTensor *out, size_t out_offset) const {
...@@ -130,6 +131,8 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase { ...@@ -130,6 +131,8 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase {
auto x_sliced = x.Slice(x_offset, x_offset + len); auto x_sliced = x.Slice(x_offset, x_offset + len);
auto out_sliced = out->Slice(out_offset, out_offset + len); auto out_sliced = out->Slice(out_offset, out_offset + len);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
framework::CopyFrom(x_sliced, out_sliced.place(), dev_ctx, &out_sliced); framework::CopyFrom(x_sliced, out_sliced.place(), dev_ctx, &out_sliced);
out_offset += len; out_offset += len;
return out_offset; return out_offset;
...@@ -145,8 +148,7 @@ class ReorderLoDTensorByRankTableOp : public ReorderLoDTensorByRankTableBase { ...@@ -145,8 +148,7 @@ class ReorderLoDTensorByRankTableOp : public ReorderLoDTensorByRankTableBase {
: ReorderLoDTensorByRankTableBase(type, inputs, outputs, attrs) {} : ReorderLoDTensorByRankTableBase(type, inputs, outputs, attrs) {}
protected: protected:
void process(const platform::DeviceContext &dev_ctx, void process(const platform::Place &place, const framework::LoDTensor &x,
const framework::LoDTensor &x,
const framework::LoDRankTable &rank_table, const framework::LoDRankTable &rank_table,
framework::LoDTensor *out) const override { framework::LoDTensor *out) const override {
auto absolute_table = GetAbsoluteOffsetAndLengthByLoDRankTable(x); auto absolute_table = GetAbsoluteOffsetAndLengthByLoDRankTable(x);
...@@ -154,7 +156,7 @@ class ReorderLoDTensorByRankTableOp : public ReorderLoDTensorByRankTableBase { ...@@ -154,7 +156,7 @@ class ReorderLoDTensorByRankTableOp : public ReorderLoDTensorByRankTableBase {
out->mutable_lod()->clear(); out->mutable_lod()->clear();
for (auto &item : rank_table.items()) { for (auto &item : rank_table.items()) {
PADDLE_ENFORCE_LT(item.index, absolute_table.size()); PADDLE_ENFORCE_LT(item.index, absolute_table.size());
out_offset = CopyTensorAndLod(dev_ctx, absolute_table[item.index], x, out, out_offset = CopyTensorAndLod(place, absolute_table[item.index], x, out,
out_offset); out_offset);
} }
} }
...@@ -192,8 +194,7 @@ class ReorderLoDTensorByRankGradOp : public ReorderLoDTensorByRankTableBase { ...@@ -192,8 +194,7 @@ class ReorderLoDTensorByRankGradOp : public ReorderLoDTensorByRankTableBase {
: ReorderLoDTensorByRankTableBase(type, inputs, outputs, attrs) {} : ReorderLoDTensorByRankTableBase(type, inputs, outputs, attrs) {}
protected: protected:
void process(const platform::DeviceContext &dev_ctx, void process(const platform::Place &place, const framework::LoDTensor &x,
const framework::LoDTensor &x,
const framework::LoDRankTable &rank_table, const framework::LoDRankTable &rank_table,
framework::LoDTensor *out) const override { framework::LoDTensor *out) const override {
auto absolute_table = GetAbsoluteOffsetAndLengthByLoDRankTable(x); auto absolute_table = GetAbsoluteOffsetAndLengthByLoDRankTable(x);
...@@ -214,7 +215,7 @@ class ReorderLoDTensorByRankGradOp : public ReorderLoDTensorByRankTableBase { ...@@ -214,7 +215,7 @@ class ReorderLoDTensorByRankGradOp : public ReorderLoDTensorByRankTableBase {
// Copy TensorAndLod // Copy TensorAndLod
size_t out_offset = 0; size_t out_offset = 0;
for (auto &offset : offsets) { for (auto &offset : offsets) {
out_offset = this->CopyTensorAndLod(dev_ctx, absolute_table[offset.first], out_offset = this->CopyTensorAndLod(place, absolute_table[offset.first],
x, out, out_offset); x, out, out_offset);
} }
} }
......
...@@ -25,7 +25,7 @@ class RNNMemoryHelperOp : public framework::OperatorBase { ...@@ -25,7 +25,7 @@ class RNNMemoryHelperOp : public framework::OperatorBase {
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override { const platform::Place &dev_place) const override {
auto mem_var_name = Input("X"); auto mem_var_name = Input("X");
auto *mem_var = scope.FindVar(mem_var_name); auto *mem_var = scope.FindVar(mem_var_name);
PADDLE_ENFORCE(mem_var != nullptr, PADDLE_ENFORCE(mem_var != nullptr,
...@@ -77,7 +77,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase { ...@@ -77,7 +77,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override { const platform::Place &dev_place) const override {
auto out_grad_var_name = Input(framework::GradVarName("Out")); auto out_grad_var_name = Input(framework::GradVarName("Out"));
auto *out_grad_var = scope.FindVar(out_grad_var_name); auto *out_grad_var = scope.FindVar(out_grad_var_name);
...@@ -100,7 +100,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase { ...@@ -100,7 +100,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
auto zero_op = framework::OpRegistry::CreateOp( auto zero_op = framework::OpRegistry::CreateOp(
"fill_constant", {}, {{"Out", {in_grad_var_name}}}, attrs); "fill_constant", {}, {{"Out", {in_grad_var_name}}}, attrs);
zero_op->Run(scope, dev_ctx); zero_op->Run(scope, dev_place);
} else { } else {
auto &out_grad_tensor = out_grad_var->Get<framework::LoDTensor>(); auto &out_grad_tensor = out_grad_var->Get<framework::LoDTensor>();
auto *in_grad_tensor = in_grad_var->GetMutable<framework::LoDTensor>(); auto *in_grad_tensor = in_grad_var->GetMutable<framework::LoDTensor>();
......
...@@ -21,7 +21,7 @@ USE_NO_KERNEL_OP(load); ...@@ -21,7 +21,7 @@ USE_NO_KERNEL_OP(load);
TEST(SaveLoadOp, CPU) { TEST(SaveLoadOp, CPU) {
paddle::framework::Scope scope; paddle::framework::Scope scope;
paddle::platform::CPUPlace place; paddle::platform::CPUPlace place;
paddle::platform::CPUDeviceContext ctx(place);
auto var = scope.Var("test_var"); auto var = scope.Var("test_var");
auto tensor = var->GetMutable<paddle::framework::LoDTensor>(); auto tensor = var->GetMutable<paddle::framework::LoDTensor>();
tensor->Resize({10, 10}); tensor->Resize({10, 10});
...@@ -42,13 +42,13 @@ TEST(SaveLoadOp, CPU) { ...@@ -42,13 +42,13 @@ TEST(SaveLoadOp, CPU) {
auto save_op = paddle::framework::OpRegistry::CreateOp( auto save_op = paddle::framework::OpRegistry::CreateOp(
"save", {{"X", {"test_var"}}}, {}, attrs); "save", {{"X", {"test_var"}}}, {}, attrs);
save_op->Run(scope, ctx); save_op->Run(scope, place);
auto load_var = scope.Var("out_var"); auto load_var = scope.Var("out_var");
auto target = load_var->GetMutable<paddle::framework::LoDTensor>(); auto target = load_var->GetMutable<paddle::framework::LoDTensor>();
auto load_op = paddle::framework::OpRegistry::CreateOp( auto load_op = paddle::framework::OpRegistry::CreateOp(
"load", {}, {{"Out", {"out_var"}}}, attrs); "load", {}, {{"Out", {"out_var"}}}, attrs);
load_op->Run(scope, ctx); load_op->Run(scope, place);
int* actual = target->data<int>(); int* actual = target->data<int>();
for (int64_t i = 0; i < tensor->numel(); ++i) { for (int64_t i = 0; i < tensor->numel(); ++i) {
EXPECT_EQ(expect[i], actual[i]); EXPECT_EQ(expect[i], actual[i]);
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "paddle/framework/framework.pb.h" #include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/platform/device_context.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -62,7 +63,7 @@ class SaveOp : public framework::OperatorBase { ...@@ -62,7 +63,7 @@ class SaveOp : public framework::OperatorBase {
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override { const platform::Place &place) const override {
auto filename = Attr<std::string>("file_path"); auto filename = Attr<std::string>("file_path");
auto overwrite = Attr<bool>("overwrite"); auto overwrite = Attr<bool>("overwrite");
...@@ -88,6 +89,11 @@ class SaveOp : public framework::OperatorBase { ...@@ -88,6 +89,11 @@ class SaveOp : public framework::OperatorBase {
"SaveOp only support LoDTensor, %s has wrong type", iname); "SaveOp only support LoDTensor, %s has wrong type", iname);
auto &tensor = var->Get<framework::LoDTensor>(); auto &tensor = var->Get<framework::LoDTensor>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
framework::SerializeToStream(fout, tensor, dev_ctx); framework::SerializeToStream(fout, tensor, dev_ctx);
} }
}; };
......
...@@ -27,11 +27,11 @@ class ShrinkRNNMemoryOp : public ArrayOp { ...@@ -27,11 +27,11 @@ class ShrinkRNNMemoryOp : public ArrayOp {
: ArrayOp(type, inputs, outputs, attrs) {} : ArrayOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override { const platform::Place &place) const override {
auto *x_var = scope.FindVar(Input("X")); auto *x_var = scope.FindVar(Input("X"));
PADDLE_ENFORCE(x_var != nullptr, "Input X must be set"); PADDLE_ENFORCE(x_var != nullptr, "Input X must be set");
auto &x_tensor = x_var->Get<framework::LoDTensor>(); auto &x_tensor = x_var->Get<framework::LoDTensor>();
size_t offset = this->GetOffset(scope, dev_ctx); size_t offset = this->GetOffset(scope, place);
auto *rank_table_var = scope.FindVar(Input("RankTable")); auto *rank_table_var = scope.FindVar(Input("RankTable"));
PADDLE_ENFORCE(rank_table_var != nullptr, "RankTable must be set"); PADDLE_ENFORCE(rank_table_var != nullptr, "RankTable must be set");
auto &rank_table = rank_table_var->Get<framework::LoDRankTable>(); auto &rank_table = rank_table_var->Get<framework::LoDRankTable>();
...@@ -93,7 +93,7 @@ class ShrinkRNNMemoryGradOp : public ArrayOp { ...@@ -93,7 +93,7 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
: ArrayOp(type, inputs, outputs, attrs) {} : ArrayOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override { const platform::Place &place) const override {
auto *dout_var = scope.FindVar(Input(framework::GradVarName("Out"))); auto *dout_var = scope.FindVar(Input(framework::GradVarName("Out")));
auto *dx_var = scope.FindVar(Output(framework::GradVarName("X"))); auto *dx_var = scope.FindVar(Output(framework::GradVarName("X")));
PADDLE_ENFORCE(dx_var != nullptr, "Input Gradient should not be nullptr"); PADDLE_ENFORCE(dx_var != nullptr, "Input Gradient should not be nullptr");
...@@ -105,6 +105,10 @@ class ShrinkRNNMemoryGradOp : public ArrayOp { ...@@ -105,6 +105,10 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
dx_tensor.Resize(x_tensor.dims()); dx_tensor.Resize(x_tensor.dims());
dx_tensor.mutable_data(x_tensor.place(), x_tensor.type()); dx_tensor.mutable_data(x_tensor.place(), x_tensor.type());
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
if (dout_var == nullptr) { // dx_tensor fill zero if (dout_var == nullptr) { // dx_tensor fill zero
math::set_constant(dev_ctx, &dx_tensor, 0.0f); math::set_constant(dev_ctx, &dx_tensor, 0.0f);
} else { } else {
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/memory/memcpy.h" #include "paddle/memory/memcpy.h"
#include "paddle/platform/device_context.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -33,7 +34,7 @@ class SplitLoDTensorOp : public framework::OperatorBase { ...@@ -33,7 +34,7 @@ class SplitLoDTensorOp : public framework::OperatorBase {
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override { const platform::Place &dev_place) const override {
auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>(); auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
auto &mask = scope.FindVar(Input("Mask"))->Get<framework::LoDTensor>(); auto &mask = scope.FindVar(Input("Mask"))->Get<framework::LoDTensor>();
auto *out_true = auto *out_true =
...@@ -44,6 +45,9 @@ class SplitLoDTensorOp : public framework::OperatorBase { ...@@ -44,6 +45,9 @@ class SplitLoDTensorOp : public framework::OperatorBase {
auto &x_lod = x.lod(); auto &x_lod = x.lod();
auto &mask_dim = mask.dims(); auto &mask_dim = mask.dims();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(dev_place);
std::unique_ptr<framework::LoDTensor> cpu_mask{new framework::LoDTensor()}; std::unique_ptr<framework::LoDTensor> cpu_mask{new framework::LoDTensor()};
if (platform::is_cpu_place(mask.place())) { if (platform::is_cpu_place(mask.place())) {
cpu_mask->ShareDataWith(mask); cpu_mask->ShareDataWith(mask);
......
...@@ -25,11 +25,11 @@ class WriteToArrayOp : public ArrayOp { ...@@ -25,11 +25,11 @@ class WriteToArrayOp : public ArrayOp {
: ArrayOp(type, inputs, outputs, attrs) {} : ArrayOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override { const platform::Place &place) const override {
auto *x = scope.FindVar(Input("X")); auto *x = scope.FindVar(Input("X"));
if (x == nullptr) return; if (x == nullptr) return;
auto &x_tensor = x->Get<framework::LoDTensor>(); auto &x_tensor = x->Get<framework::LoDTensor>();
size_t offset = GetOffset(scope, dev_ctx); size_t offset = GetOffset(scope, place);
auto *out = auto *out =
scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensorArray>(); scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensorArray>();
if (offset >= out->size()) { if (offset >= out->size()) {
...@@ -39,7 +39,11 @@ class WriteToArrayOp : public ArrayOp { ...@@ -39,7 +39,11 @@ class WriteToArrayOp : public ArrayOp {
} }
if (x_tensor.memory_size() > 0) { if (x_tensor.memory_size() > 0) {
auto *out_tensor = &out->at(offset); auto *out_tensor = &out->at(offset);
CopyFrom(x_tensor, dev_ctx.GetPlace(), dev_ctx, out_tensor);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
CopyFrom(x_tensor, place, dev_ctx, out_tensor);
out_tensor->set_lod(x_tensor.lod()); out_tensor->set_lod(x_tensor.lod());
} else { } else {
VLOG(10) << "WARNING: The input tensor 'x_tensor' holds no memory, so " VLOG(10) << "WARNING: The input tensor 'x_tensor' holds no memory, so "
...@@ -119,17 +123,18 @@ class ReadFromArrayOp : public ArrayOp { ...@@ -119,17 +123,18 @@ class ReadFromArrayOp : public ArrayOp {
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: ArrayOp(type, inputs, outputs, attrs) {} : ArrayOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override { const platform::Place &place) const override {
auto *x = scope.FindVar(Input("X")); auto *x = scope.FindVar(Input("X"));
PADDLE_ENFORCE(x != nullptr, "X must be set"); PADDLE_ENFORCE(x != nullptr, "X must be set");
auto &x_array = x->Get<framework::LoDTensorArray>(); auto &x_array = x->Get<framework::LoDTensorArray>();
auto *out = scope.FindVar(Output("Out")); auto *out = scope.FindVar(Output("Out"));
PADDLE_ENFORCE(out != nullptr, "Out must be set"); PADDLE_ENFORCE(out != nullptr, "Out must be set");
auto *out_tensor = out->GetMutable<framework::LoDTensor>(); auto *out_tensor = out->GetMutable<framework::LoDTensor>();
size_t offset = GetOffset(scope, dev_ctx); size_t offset = GetOffset(scope, place);
if (offset < x_array.size()) { if (offset < x_array.size()) {
framework::CopyFrom(x_array[offset], dev_ctx.GetPlace(), dev_ctx, platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
out_tensor); auto &dev_ctx = *pool.Borrow(place);
framework::CopyFrom(x_array[offset], place, dev_ctx, out_tensor);
out_tensor->set_lod(x_array[offset].lod()); out_tensor->set_lod(x_array[offset].lod());
} else { } else {
VLOG(10) << "offset " << offset << " >= " << x_array.size(); VLOG(10) << "offset " << offset << " >= " << x_array.size();
......
...@@ -40,13 +40,14 @@ class WhileOp : public framework::OperatorBase { ...@@ -40,13 +40,14 @@ class WhileOp : public framework::OperatorBase {
: framework::OperatorBase(type, inputs, outputs, attrs) {} : framework::OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override { const platform::Place &dev_place) const override {
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition))); PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition)));
auto &cond = scope.FindVar(Input(kCondition))->Get<LoDTensor>(); auto &cond = scope.FindVar(Input(kCondition))->Get<LoDTensor>();
PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1})); PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1}));
framework::Executor executor(dev_ctx); framework::Executor executor(dev_place);
auto *block = Attr<framework::BlockDesc *>(kStepBlock); auto *block = Attr<framework::BlockDesc *>(kStepBlock);
auto *program = block->Program(); auto *program = block->Program();
auto step_scopes = auto step_scopes =
...@@ -97,8 +98,8 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -97,8 +98,8 @@ class WhileGradOp : public framework::OperatorBase {
: framework::OperatorBase(type, inputs, outputs, attrs) {} : framework::OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override { const platform::Place &dev_place) const override {
framework::Executor executor(dev_ctx); framework::Executor executor(dev_place);
auto *block = Attr<framework::BlockDesc *>(kStepBlock); auto *block = Attr<framework::BlockDesc *>(kStepBlock);
auto *program = block->Program(); auto *program = block->Program();
...@@ -189,7 +190,7 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -189,7 +190,7 @@ class WhileGradOp : public framework::OperatorBase {
auto zero_op = framework::OpRegistry::CreateOp( auto zero_op = framework::OpRegistry::CreateOp(
"fill_constant", framework::VariableNameMap{}, "fill_constant", framework::VariableNameMap{},
{{"Out", {pg_names[param_id]}}}, attrs); {{"Out", {pg_names[param_id]}}}, attrs);
zero_op->Run(scope, dev_ctx); zero_op->Run(scope, dev_place);
} }
} }
...@@ -197,7 +198,7 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -197,7 +198,7 @@ class WhileGradOp : public framework::OperatorBase {
auto sum_op = framework::OpRegistry::CreateOp( auto sum_op = framework::OpRegistry::CreateOp(
"sum", {{"X", {pg_names[param_id], new_inside_name}}}, "sum", {{"X", {pg_names[param_id], new_inside_name}}},
{{"Out", {pg_names[param_id]}}}, framework::AttributeMap{}); {{"Out", {pg_names[param_id]}}}, framework::AttributeMap{});
sum_op->Run(cur_scope, dev_ctx); sum_op->Run(cur_scope, dev_place);
cur_scope.Rename(new_inside_name, inside_grad_name); cur_scope.Rename(new_inside_name, inside_grad_name);
} }
} }
......
...@@ -25,7 +25,7 @@ ENDIF() ...@@ -25,7 +25,7 @@ ENDIF()
# avoiding cycle dependencies # avoiding cycle dependencies
cc_library(device_context SRCS device_context.cc DEPS memory buddy_allocator cc_library(device_context SRCS device_context.cc DEPS memory buddy_allocator
system_allocator memory_block meta_data meta_cache place eigen3 ${GPU_CTX_DEPS}) system_allocator memory_block meta_data meta_cache place eigen3 ${GPU_CTX_DEPS})
nv_test(device_context_test SRCS device_context_test.cc DEPS device_context gpu_info) nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info)
nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda) nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda)
nv_test(transform_test SRCS transform_test.cu DEPS paddle_memory place device_context) nv_test(transform_test SRCS transform_test.cu DEPS paddle_memory place device_context)
......
...@@ -15,6 +15,59 @@ limitations under the License. */ ...@@ -15,6 +15,59 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace platform { namespace platform {
DeviceContextPool* DeviceContextPool::pool = nullptr;
const platform::DeviceContext* DeviceContextPool::Borrow(
const platform::Place& place) {
auto it = device_contexts_.find(place);
if (it == device_contexts_.end()) {
PADDLE_THROW(
"'Place' is not supported, Please re-compile with WITH_GPU "
"option");
}
return it->second;
}
std::vector<const platform::DeviceContext*> DeviceContextPool::Borrow(
const std::vector<platform::Place>& places) {
PADDLE_ENFORCE_GT(places.size(), 0);
PADDLE_ENFORCE_LE(places.size(), device_contexts_.size());
std::vector<const platform::DeviceContext*> borrowed_contexts;
for (auto& place : places) {
auto it = device_contexts_.find(place);
if (it != device_contexts_.end()) {
borrowed_contexts.emplace_back(it->second);
} else {
PADDLE_THROW(
"'Place' is not supported, Please re-compile with WITH_GPU "
"option");
}
}
return borrowed_contexts;
}
DeviceContextPool::DeviceContextPool(
const std::vector<platform::Place>& places) {
PADDLE_ENFORCE_GT(places.size(), 0);
for (size_t i = 0; i < places.size(); i++) {
if (platform::is_cpu_place(places[i])) {
device_contexts_.emplace(places[i],
new platform::CPUDeviceContext(
boost::get<platform::CPUPlace>(places[i])));
} else if (platform::is_gpu_place(places[i])) {
#ifdef PADDLE_WITH_CUDA
device_contexts_.emplace(places[i],
new platform::CUDADeviceContext(
boost::get<platform::GPUPlace>(places[i])));
#else
PADDLE_THROW(
"'GPUPlace' is not supported, Please re-compile with WITH_GPU "
"option");
#endif
}
}
}
CPUDeviceContext::CPUDeviceContext() { CPUDeviceContext::CPUDeviceContext() {
eigen_device_.reset(new Eigen::DefaultDevice()); eigen_device_.reset(new Eigen::DefaultDevice());
} }
......
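The pool implemented above is a process-wide singleton: Create() registers one context per Place exactly once (typically from framework::InitDevices or a test's main()), and later code reaches it through Get() followed by Borrow(). A small sketch of that lifecycle, assuming CUDA is compiled in (the snippet is illustrative and not part of this diff):

    // Register contexts once, before any operator runs.
    std::vector<paddle::platform::Place> places;
    places.emplace_back(paddle::platform::CPUPlace());
    #ifdef PADDLE_WITH_CUDA
    for (int i = 0; i < paddle::platform::GetCUDADeviceCount(); ++i) {
      places.emplace_back(paddle::platform::GPUPlace(i));
    }
    #endif
    paddle::platform::DeviceContextPool::Create(places);

    // Anywhere later: borrow the context that matches a place.
    auto &pool = paddle::platform::DeviceContextPool::Get();
    const paddle::platform::DeviceContext *cpu_ctx =
        pool.Borrow(paddle::platform::CPUPlace());  // throws if unregistered

Borrow() returns a non-owning pointer; the pool holds the contexts for the lifetime of the process and callers never delete them.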
...@@ -11,8 +11,8 @@ limitations under the License. */ ...@@ -11,8 +11,8 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/platform/enforce.h" #include <memory>
#include "paddle/platform/place.h" #include <unordered_map>
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include "paddle/platform/dynload/cublas.h" #include "paddle/platform/dynload/cublas.h"
...@@ -20,10 +20,13 @@ limitations under the License. */ ...@@ -20,10 +20,13 @@ limitations under the License. */
#include "paddle/platform/gpu_info.h" #include "paddle/platform/gpu_info.h"
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif #endif
#include <memory>
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h" #include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor" #include "unsupported/Eigen/CXX11/Tensor"
#include "glog/logging.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -105,5 +108,51 @@ class CUDNNDeviceContext : public CUDADeviceContext { ...@@ -105,5 +108,51 @@ class CUDNNDeviceContext : public CUDADeviceContext {
#endif #endif
/*! \brief device context pool singleton */
class DeviceContextPool {
public:
explicit DeviceContextPool(const std::vector<platform::Place>& places);
static DeviceContextPool& Get() {
PADDLE_ENFORCE_NOT_NULL(pool, "Need to Create DeviceContextPool first!");
return *pool;
}
/*! \brief Create should only be called by Init function */
static DeviceContextPool& Create(const std::vector<platform::Place>& places) {
if (pool == nullptr) {
pool = new DeviceContextPool(places);
}
return *pool;
}
/*! \brief Return handle of single device context. */
const platform::DeviceContext* Borrow(const platform::Place& place);
/*! \brief Return handle of multi-device context. */
std::vector<const platform::DeviceContext*> Borrow(
const std::vector<platform::Place>& places);
~DeviceContextPool() {}
private:
static DeviceContextPool* pool;
struct Hash {
std::hash<int> hash_;
size_t operator()(const platform::Place& place) const {
int pre_hash = place.which()
<< (sizeof(int) * 8 - NUM_PLACE_TYPE_LIMIT_IN_BIT);
if (platform::is_gpu_place(place)) {
pre_hash += boost::get<platform::GPUPlace>(place).GetDeviceId();
}
return hash_(pre_hash);
}
};
std::unordered_map<const platform::Place, const platform::DeviceContext*,
Hash>
device_contexts_;
DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
};
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
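Because device_contexts_ is keyed by Place, the Hash functor above packs the variant index of the place into the high bits and adds the GPU device id for GPU places, so CPUPlace and each GPUPlace(i) map to distinct keys. The vector overload of Borrow() hands back contexts in the same order as the requested places; a short usage sketch (illustrative only, not part of this diff):

    std::vector<paddle::platform::Place> gpu_places;
    for (int i = 0; i < paddle::platform::GetCUDADeviceCount(); ++i) {
      gpu_places.emplace_back(paddle::platform::GPUPlace(i));
    }
    auto &pool = paddle::platform::DeviceContextPool::Get();
    // One context per requested GPUPlace, in the same order as gpu_places.
    std::vector<const paddle::platform::DeviceContext *> ctxs =
        pool.Borrow(gpu_places);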
...@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/platform/device_context.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/platform/device_context.h"
#include "glog/logging.h"
TEST(Device, Init) { TEST(Device, Init) {
using paddle::platform::DeviceContext; using paddle::platform::DeviceContext;
...@@ -62,3 +64,54 @@ TEST(Device, CUDNNDeviceContext) { ...@@ -62,3 +64,54 @@ TEST(Device, CUDNNDeviceContext) {
} }
} }
} }
TEST(Device, DeviceContextPool) {
using paddle::platform::DeviceContextPool;
using paddle::platform::CUDADeviceContext;
using paddle::platform::Place;
using paddle::platform::CPUPlace;
using paddle::platform::GPUPlace;
DeviceContextPool& pool = DeviceContextPool::Get();
auto cpu_dev_ctx1 = pool.Borrow(CPUPlace());
auto cpu_dev_ctx2 = pool.Borrow(CPUPlace());
EXPECT_TRUE(cpu_dev_ctx2 == cpu_dev_ctx1);
std::vector<Place> gpu_places;
int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; ++i) {
gpu_places.emplace_back(GPUPlace(i));
}
auto dev_ctxs = pool.Borrow(gpu_places);
for (size_t i = 0; i < dev_ctxs.size(); ++i) {
auto* dev_ctx = static_cast<const CUDADeviceContext*>(dev_ctxs[i]);
// check same as GPUPlace(i)
GPUPlace place = boost::get<GPUPlace>(dev_ctx->GetPlace());
EXPECT_EQ(place.GetDeviceId(), static_cast<int>(i));
}
}
int main(int argc, char** argv) {
int dev_count = paddle::platform::GetCUDADeviceCount();
if (dev_count <= 1) {
LOG(WARNING) << "Cannot test multi-gpu DeviceContextPool, because the CUDA "
"device count is "
<< dev_count;
return 0;
}
std::vector<paddle::platform::Place> places;
places.emplace_back(paddle::platform::CPUPlace());
int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; ++i) {
places.emplace_back(paddle::platform::GPUPlace(i));
}
VLOG(0) << " DeviceCount " << count;
paddle::platform::DeviceContextPool::Create(places);
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
...@@ -63,6 +63,8 @@ extern void LoadNCCLDSO(); ...@@ -63,6 +63,8 @@ extern void LoadNCCLDSO();
__macro(ncclAllReduce); \ __macro(ncclAllReduce); \
__macro(ncclBcast); \ __macro(ncclBcast); \
__macro(ncclAllGather); \ __macro(ncclAllGather); \
__macro(ncclGroupStart); \
__macro(ncclGroupEnd); \
__macro(ncclReduce); \ __macro(ncclReduce); \
__macro(ncclGetErrorString); __macro(ncclGetErrorString);
......
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include "paddle/platform/macros.h"
#include "paddle/string/printf.h" #include "paddle/string/printf.h"
#include "paddle/string/to_string.h" #include "paddle/string/to_string.h"
......
...@@ -12,17 +12,19 @@ ...@@ -12,17 +12,19 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <thrust/device_vector.h>
#include <memory>
#include <vector>
#include "glog/logging.h" #include "glog/logging.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/framework/init.h"
#include "paddle/platform/device_context.h" #include "paddle/platform/device_context.h"
#include "paddle/platform/dynload/nccl.h" #include "paddle/platform/dynload/nccl.h"
#include "paddle/platform/enforce.h" #include "paddle/platform/enforce.h"
#include "paddle/platform/gpu_info.h" #include "paddle/platform/gpu_info.h"
#include <thrust/device_vector.h>
#include <memory>
#include <vector>
static int dev_count = 0; static int dev_count = 0;
namespace paddle { namespace paddle {
...@@ -31,7 +33,8 @@ namespace platform { ...@@ -31,7 +33,8 @@ namespace platform {
TEST(NCCL, init) { TEST(NCCL, init) {
std::vector<ncclComm_t> comms; std::vector<ncclComm_t> comms;
comms.resize(dev_count); comms.resize(dev_count);
dynload::ncclCommInitAll(comms.data(), dev_count, nullptr); PADDLE_ENFORCE(dynload::ncclCommInitAll(comms.data(), dev_count, nullptr));
for (int i = 0; i < dev_count; ++i) { for (int i = 0; i < dev_count; ++i) {
dynload::ncclCommDestroy(comms[i]); dynload::ncclCommDestroy(comms[i]);
} }
...@@ -131,6 +134,18 @@ int main(int argc, char** argv) { ...@@ -131,6 +134,18 @@ int main(int argc, char** argv) {
<< dev_count; << dev_count;
return 0; return 0;
} }
std::vector<paddle::platform::Place> places;
places.emplace_back(paddle::platform::CPUPlace());
int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; ++i) {
places.emplace_back(paddle::platform::GPUPlace(i));
}
VLOG(0) << " DeviceCount " << count;
paddle::platform::DeviceContextPool::Create(places);
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS(); return RUN_ALL_TESTS();
} }
...@@ -60,12 +60,14 @@ struct IsGPUPlace : public boost::static_visitor<bool> { ...@@ -60,12 +60,14 @@ struct IsGPUPlace : public boost::static_visitor<bool> {
bool operator()(const CPUPlace &) const { return false; } bool operator()(const CPUPlace &) const { return false; }
bool operator()(const MKLDNNPlace &) const { return false; } bool operator()(const MKLDNNPlace &) const { return false; }
bool operator()(const GPUPlace &gpu) const { return true; } bool operator()(const GPUPlace &gpu) const { return true; }
bool operator()(const CUDNNPlace &) const { return true; }
}; };
struct IsMKLDNNPlace : public boost::static_visitor<bool> { struct IsMKLDNNPlace : public boost::static_visitor<bool> {
bool operator()(const MKLDNNPlace &) const { return true; } bool operator()(const MKLDNNPlace &) const { return true; }
bool operator()(const CPUPlace &) const { return false; } bool operator()(const CPUPlace &) const { return false; }
bool operator()(const GPUPlace &) const { return false; } bool operator()(const GPUPlace &) const { return false; }
bool operator()(const CUDNNPlace &) const { return false; }
}; };
// Define the max number of Place in bit length. i.e., the max number of places // Define the max number of Place in bit length. i.e., the max number of places
......
...@@ -360,10 +360,10 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -360,10 +360,10 @@ All parameter, weight, gradient are variables in Paddle.
}) })
.def("run", .def("run",
[](OperatorBase &self, const Scope &scope, [](OperatorBase &self, const Scope &scope,
const platform::DeviceContext &dev_ctx) { const platform::CPUPlace &place) { self.Run(scope, place); })
self.Run(scope, dev_ctx); .def("run",
dev_ctx.Wait(); [](OperatorBase &self, const Scope &scope,
}) const platform::GPUPlace &place) { self.Run(scope, place); })
.def("type", .def("type",
[](const OperatorBase &op) -> std::string { return op.Type(); }) [](const OperatorBase &op) -> std::string { return op.Type(); })
.def("outputs", .def("outputs",
...@@ -417,7 +417,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -417,7 +417,7 @@ All parameter, weight, gradient are variables in Paddle.
}); });
py::class_<framework::Executor>(m, "Executor") py::class_<framework::Executor>(m, "Executor")
.def(py::init<std::vector<platform::Place> &>()) .def(py::init<const platform::Place &>())
.def("run", &Executor::Run); .def("run", &Executor::Run);
m.def("unique_integer", UniqueIntegerGenerator); m.def("unique_integer", UniqueIntegerGenerator);
......
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
#pragma once #pragma once
#include <string> #include <string>
#include "paddle/framework/executor.h"
#include "paddle/framework/tensor.h" #include "paddle/framework/tensor.h"
#include "paddle/memory/memcpy.h" #include "paddle/memory/memcpy.h"
#include "paddle/platform/device_context.h"
#include "pybind11/numpy.h" #include "pybind11/numpy.h"
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
...@@ -63,8 +63,7 @@ struct CastToPyBufferImpl<true, I, ARGS...> { ...@@ -63,8 +63,7 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
auto *dst_ptr = static_cast<void *>(dst_tensor.mutable_data<CUR_TYPE>( auto *dst_ptr = static_cast<void *>(dst_tensor.mutable_data<CUR_TYPE>(
tensor.dims(), platform::CPUPlace())); tensor.dims(), platform::CPUPlace()));
framework::DeviceContextPool &pool = platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
framework::DeviceContextPool::Get();
auto dev_ctx = static_cast<const platform::CUDADeviceContext *>( auto dev_ctx = static_cast<const platform::CUDADeviceContext *>(
pool.Borrow(tensor.place())); pool.Borrow(tensor.place()));
...@@ -138,7 +137,7 @@ void PyCUDATensorSetFromArray( ...@@ -138,7 +137,7 @@ void PyCUDATensorSetFromArray(
self.Resize(framework::make_ddim(dims)); self.Resize(framework::make_ddim(dims));
auto *dst = self.mutable_data<T>(place); auto *dst = self.mutable_data<T>(place);
framework::DeviceContextPool &pool = framework::DeviceContextPool::Get(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto dev_ctx = auto dev_ctx =
static_cast<const platform::CUDADeviceContext *>(pool.Borrow(place)); static_cast<const platform::CUDADeviceContext *>(pool.Borrow(place));
paddle::platform::GpuMemcpyAsync(dst, array.data(), sizeof(T) * array.size(), paddle::platform::GpuMemcpyAsync(dst, array.data(), sizeof(T) * array.size(),
......
...@@ -6,7 +6,6 @@ if(WITH_TESTING) ...@@ -6,7 +6,6 @@ if(WITH_TESTING)
add_library(paddle_test_util STATIC TestUtil.cpp) add_library(paddle_test_util STATIC TestUtil.cpp)
add_dependencies(paddle_test_util paddle_proto ${external_project_dependencies}) add_dependencies(paddle_test_util paddle_proto ${external_project_dependencies})
if(NOT MOBILE_INFERENCE) if(NOT MOBILE_INFERENCE)
add_library(paddle_gtest_main STATIC paddle_gtest_main.cc) cc_library(paddle_gtest_main SRCS paddle_gtest_main.cc DEPS init paddle_memory gtest gflags)
add_dependencies(paddle_gtest_main paddle_memory gtest gflags)
endif() endif()
endif() endif()
...@@ -13,8 +13,10 @@ See the License for the specific language governing permissions and ...@@ -13,8 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <cstring> #include <cstring>
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/framework/init.h"
#include "paddle/memory/memory.h" #include "paddle/memory/memory.h"
int main(int argc, char** argv) { int main(int argc, char** argv) {
...@@ -32,8 +34,11 @@ int main(int argc, char** argv) { ...@@ -32,8 +34,11 @@ int main(int argc, char** argv) {
google::ParseCommandLineFlags(&new_argc, &new_argv_address, false); google::ParseCommandLineFlags(&new_argc, &new_argv_address, false);
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
paddle::memory::Used(paddle::platform::CPUPlace()); paddle::memory::Used(paddle::platform::CPUPlace());
std::vector<std::string> devs = {"CPU"};
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
paddle::memory::Used(paddle::platform::GPUPlace(0)); paddle::memory::Used(paddle::platform::GPUPlace(0));
devs.push_back("GPU:0");
#endif #endif
paddle::framework::InitDevices(devs);
return RUN_ALL_TESTS(); return RUN_ALL_TESTS();
} }
...@@ -42,5 +42,10 @@ def __read_gflags_from_env__(): ...@@ -42,5 +42,10 @@ def __read_gflags_from_env__():
core.init_gflags([sys.argv[0]] + core.init_gflags([sys.argv[0]] +
["--tryfromenv=" + ",".join(read_env_flags)]) ["--tryfromenv=" + ",".join(read_env_flags)])
if core.is_compile_gpu():
core.init_devices(["CPU", "GPU:0"])
else:
core.init_devices(["CPU"])
__read_gflags_from_env__() __read_gflags_from_env__()
...@@ -47,13 +47,14 @@ class Executor(object): ...@@ -47,13 +47,14 @@ class Executor(object):
act_places.append(p) act_places.append(p)
# TODO(dzhwinter) : consider that our fluid tests all written in # TODO(dzhwinter) : consider that our fluid tests all written in
# GPUPlace(gpu_id), this will be changed in next PR. # GPUPlace(gpu_id), this will be changed in the future
if core.is_compile_gpu(): if core.is_compile_gpu():
core.init_devices(["CPU", "GPU:0"]) core.init_devices(["CPU", "GPU:0"])
else: else:
core.init_devices(["CPU"]) core.init_devices(["CPU"])
self.executor = core.Executor(act_places) # TODO(dzhwinter) : only use the first place
self.executor = core.Executor(act_places[0])
self.places = places self.places = places
def aslodtensor(self, data): def aslodtensor(self, data):
......
...@@ -90,12 +90,10 @@ def get_numeric_gradient(scope, ...@@ -90,12 +90,10 @@ def get_numeric_gradient(scope,
def product(dim): def product(dim):
return reduce(lambda a, b: a * b, dim, 1) return reduce(lambda a, b: a * b, dim, 1)
ctx = core.DeviceContext.create(core.CPUPlace())
def get_output(): def get_output():
sum = [] sum = []
for output_name in output_names: for output_name in output_names:
op.run(scope, ctx) op.run(scope, core.CPUPlace())
sum.append( sum.append(
np.array(scope.find_var(output_name).get_tensor()).mean()) np.array(scope.find_var(output_name).get_tensor()).mean())
return np.array(sum).mean() return np.array(sum).mean()
......
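The op_test hunk above is the pattern repeated in every Python test change below: Operator.run now takes a Place, and the matching DeviceContext is resolved from the pool inside the operator. A before/after sketch; the "scale" operator and the variable names are illustrative assumptions, not part of this diff:

```python
# Minimal sketch, assuming the paddle.v2.fluid test helpers used in this diff.
import numpy as np
import paddle.v2.fluid.core as core
from paddle.v2.fluid.op import Operator

scope = core.Scope()
place = core.CPUPlace()

# Create an input tensor and an (empty) output variable in the scope.
scope.var("X").get_tensor().set(np.ones((2, 3)).astype("float32"), place)
scope.var("Out").get_tensor()

op = Operator("scale", X="X", Out="Out", scale=2.0)

# Old interface, removed by this commit:
#   ctx = core.DeviceContext.create(place)
#   op.run(scope, ctx)
# New interface: pass the place; the operator looks up its own DeviceContext.
op.run(scope, place)
print(np.array(scope.find_var("Out").get_tensor()))
```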
...@@ -113,8 +113,7 @@ class TestSparseAdagradOp(unittest.TestCase): ...@@ -113,8 +113,7 @@ class TestSparseAdagradOp(unittest.TestCase):
LearningRate='LearningRate', LearningRate='LearningRate',
epsilon=2.0) epsilon=2.0)
ctx = core.DeviceContext.create(place) adagrad_op.run(scope, place)
adagrad_op.run(scope, ctx)
# get and compare moment result # get and compare moment result
moment_result_array = np.array(moment) moment_result_array = np.array(moment)
......
...@@ -296,8 +296,7 @@ class TestBatchNormOp(OpTest): ...@@ -296,8 +296,7 @@ class TestBatchNormOp(OpTest):
momentum=momentum, momentum=momentum,
epsilon=epsilon) epsilon=epsilon)
ctx = core.DeviceContext.create(place) batch_norm_op.run(scope, place)
batch_norm_op.run(scope, ctx)
# check forward result # check forward result
self.__assert_close(y_tensor, y_out, "y_out") self.__assert_close(y_tensor, y_out, "y_out")
...@@ -320,7 +319,7 @@ class TestBatchNormOp(OpTest): ...@@ -320,7 +319,7 @@ class TestBatchNormOp(OpTest):
["y_out", "mean", "variance", "saved_mean", "saved_variance"], ["y_out", "mean", "variance", "saved_mean", "saved_variance"],
place, place,
feed_dict={"y_out": y_grad}) feed_dict={"y_out": y_grad})
batch_norm_op_grad.run(scope, ctx) batch_norm_op_grad.run(scope, place)
x_grad_tensor = create_or_get_tensor(scope, x_grad_tensor = create_or_get_tensor(scope,
grad_var_name("x_val"), None, grad_var_name("x_val"), None,
......
...@@ -57,8 +57,7 @@ class TestBeamSearchDecodeOp(unittest.TestCase): ...@@ -57,8 +57,7 @@ class TestBeamSearchDecodeOp(unittest.TestCase):
SentenceIds="sentence_ids", SentenceIds="sentence_ids",
SentenceScores="sentence_scores") SentenceScores="sentence_scores")
ctx = core.DeviceContext.create(self.cpu_place) beam_search_decode_op.run(self.scope, self.cpu_place)
beam_search_decode_op.run(self.scope, ctx)
expected_lod = [[0, 4, 8], [0, 1, 3, 6, 9, 10, 13, 16, 19]] expected_lod = [[0, 4, 8], [0, 1, 3, 6, 9, 10, 13, 16, 19]]
self.assertEqual(sentence_ids.lod(), expected_lod) self.assertEqual(sentence_ids.lod(), expected_lod)
......
...@@ -14,7 +14,6 @@ def create_tensor(scope, name, np_data): ...@@ -14,7 +14,6 @@ def create_tensor(scope, name, np_data):
class BeamSearchOpTester(unittest.TestCase): class BeamSearchOpTester(unittest.TestCase):
def setUp(self): def setUp(self):
self.scope = core.Scope() self.scope = core.Scope()
self.ctx = core.DeviceContext.create(core.CPUPlace())
self._create_ids() self._create_ids()
self._create_scores() self._create_scores()
self._create_pre_ids() self._create_pre_ids()
...@@ -32,7 +31,7 @@ class BeamSearchOpTester(unittest.TestCase): ...@@ -32,7 +31,7 @@ class BeamSearchOpTester(unittest.TestCase):
level=0, level=0,
beam_size=2, beam_size=2,
end_id=0, ) end_id=0, )
op.run(self.scope, self.ctx) op.run(self.scope, core.CPUPlace())
selected_ids = self.scope.find_var("selected_ids").get_tensor() selected_ids = self.scope.find_var("selected_ids").get_tensor()
print 'selected_ids', np.array(selected_ids) print 'selected_ids', np.array(selected_ids)
print 'lod', selected_ids.lod() print 'lod', selected_ids.lod()
......
...@@ -65,8 +65,7 @@ class TestCondOp(unittest.TestCase): ...@@ -65,8 +65,7 @@ class TestCondOp(unittest.TestCase):
self.create_global_variables() self.create_global_variables()
self.create_cond_op() self.create_cond_op()
self.create_sub_net() self.create_sub_net()
ctx = core.DeviceContext.create(core.CPUPlace()) self.condop.run(self.scope, core.CPUPlace())
self.condop.run(self.scope, ctx)
return np.array(self.scope.find_var("Out").get_tensor()) return np.array(self.scope.find_var("Out").get_tensor())
def create_global_variables(self): def create_global_variables(self):
......
...@@ -24,7 +24,6 @@ class TestGaussianRandomOp(unittest.TestCase): ...@@ -24,7 +24,6 @@ class TestGaussianRandomOp(unittest.TestCase):
def gaussian_random_test(self, place): def gaussian_random_test(self, place):
context = core.DeviceContext.create(place)
program = fluid.Program() program = fluid.Program()
block = program.global_block() block = program.global_block()
vout = block.create_var(name="Out") vout = block.create_var(name="Out")
......
...@@ -33,8 +33,7 @@ class TestIsEmptyOp(unittest.TestCase): ...@@ -33,8 +33,7 @@ class TestIsEmptyOp(unittest.TestCase):
def one_case(self, input, target): def one_case(self, input, target):
op = Operator(type="is_empty", X=input, Out="out") op = Operator(type="is_empty", X=input, Out="out")
ctx = core.DeviceContext.create(core.CPUPlace()) op.run(self.scope, core.CPUPlace())
op.run(self.scope, ctx)
out = self.scope.var("out").get_tensor() out = self.scope.var("out").get_tensor()
self.assertEqual(np.array(out)[0], target) self.assertEqual(np.array(out)[0], target)
......
...@@ -55,8 +55,7 @@ class TestSparseSGDOp(unittest.TestCase): ...@@ -55,8 +55,7 @@ class TestSparseSGDOp(unittest.TestCase):
Grad='Grad', Grad='Grad',
ParamOut='Param', ParamOut='Param',
LearningRate='LearningRate') LearningRate='LearningRate')
ctx = core.DeviceContext.create(place) sgd_op.run(scope, place)
sgd_op.run(scope, ctx)
# get and compare result # get and compare result
result_array = np.array(param) result_array = np.array(param)
......
...@@ -26,7 +26,6 @@ class TestUniformRandomOp(unittest.TestCase): ...@@ -26,7 +26,6 @@ class TestUniformRandomOp(unittest.TestCase):
self.uniform_random_test(place=core.GPUPlace(0)) self.uniform_random_test(place=core.GPUPlace(0))
def uniform_random_test(self, place): def uniform_random_test(self, place):
context = core.DeviceContext.create(place)
program = fluid.Program() program = fluid.Program()
block = program.global_block() block = program.global_block()
vout = block.create_var(name="Out") vout = block.create_var(name="Out")
......