提交 e834eb87 编写于 作者: T typhoonzero

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into grpc_benchmark

......@@ -170,6 +170,18 @@ sequence_pool
:noindex:
sequence_first_step
-------------------
.. autofunction:: paddle.v2.fluid.layers.sequence_first_step
:noindex:
sequence_last_step
------------------
.. autofunction:: paddle.v2.fluid.layers.sequence_last_step
:noindex:
pool2d
------
.. autofunction:: paddle.v2.fluid.layers.pool2d
......
......@@ -291,10 +291,10 @@ public:
}
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
const platform::Place& place) const override {
PADDLE_ENFORCE(symbols_ready_, "operators and variables should be created first.");
for (auto& op : runtime_table_.ops()) {
op->Run(scope, dev_ctx);
op->Run(scope, place);
}
}
......
......@@ -30,7 +30,7 @@ cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog shape_inference)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry init)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
......@@ -59,5 +59,8 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
cc_library(selected_rows SRCS selected_rows.cc DEPS tensor)
cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)
cc_library(init SRCS init.cc DEPS gflags executor place stringpiece)
cc_test(threadpool_test SRCS threadpool_test.cc)
cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece)
cc_test(init_test SRCS init_test.cc DEPS init)
cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context)
......@@ -14,6 +14,9 @@ limitations under the License. */
#pragma once
#include <iostream>
#include "paddle/platform/enforce.h"
namespace paddle {
namespace framework {
......@@ -33,5 +36,23 @@ inline DataLayout StringToDataLayout(const std::string& str) {
}
}
inline std::string DataLayoutToString(const DataLayout& data_layout) {
switch (data_layout) {
case kNHWC:
return "NHWC";
case kNCHW:
return "NCHW";
case kAnyLayout:
return "ANY_LAYOUT";
default:
PADDLE_THROW("unknown DataLayou %d", data_layout);
}
}
inline std::ostream& operator<<(std::ostream& out, DataLayout l) {
out << DataLayoutToString(l);
return out;
}
} // namespace framework
} // namespace paddle
......@@ -33,13 +33,7 @@ namespace framework {
const std::string kFeedOpType = "feed";
const std::string kFetchOpType = "fetch";
DeviceContextPool* DeviceContextPool::pool = nullptr;
Executor::Executor(const std::vector<platform::Place>& places) {
DeviceContextPool& pool = DeviceContextPool::Get();
auto borrowed_contexts = pool.Borrow(places);
device_contexts_.swap(borrowed_contexts);
}
Executor::Executor(const platform::Place& place) : place_(place) {}
static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
if (var_type == proto::VarDesc::LOD_TENSOR) {
......@@ -71,7 +65,6 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
// - will change to use multiple blocks for RNN op and Cond Op
PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), pdesc.Size());
auto& block = pdesc.Block(block_id);
auto& device = device_contexts_[0];
Scope* local_scope = scope;
if (create_vars) {
......@@ -107,7 +100,7 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
for (auto& op_desc : block.AllOps()) {
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
VLOG(3) << op->DebugString();
op->Run(*local_scope, *device);
op->Run(*local_scope, place_);
}
if (create_local_scope) {
scope->DeleteScope(local_scope);
......
......@@ -14,9 +14,6 @@ limitations under the License. */
#pragma once
#include <map>
#include <unordered_map>
#include "paddle/framework/op_info.h"
#include "paddle/framework/program_desc.h"
#include "paddle/framework/scope.h"
......@@ -26,96 +23,13 @@ limitations under the License. */
namespace paddle {
namespace framework {
class DeviceContextPool {
public:
static DeviceContextPool& Get() {
PADDLE_ENFORCE_NOT_NULL(pool, "Need to Create DeviceContextPool first!");
return *pool;
}
static DeviceContextPool& Create(const std::vector<platform::Place>& places) {
if (pool == nullptr) {
pool = new DeviceContextPool(places);
}
return *pool;
}
const platform::DeviceContext* Borrow(const platform::Place& place) {
auto range = device_contexts_.equal_range(place);
if (range.first == range.second) {
PADDLE_THROW(
"'Place' is not supported, Please re-compile with WITH_GPU "
"option");
}
return range.first->second;
}
std::vector<const platform::DeviceContext*> Borrow(
const std::vector<platform::Place>& places) {
PADDLE_ENFORCE_GT(places.size(), 0);
PADDLE_ENFORCE_LE(places.size(), device_contexts_.size());
std::vector<const platform::DeviceContext*> borrowed_contexts;
for (auto& place : places) {
auto range = device_contexts_.equal_range(place);
if (range.first == range.second) {
PADDLE_THROW(
"'Place' is not supported, Please re-compile with WITH_GPU "
"option");
}
// TODO(dzhwinter) : assign the first found device. Will enhanced later.
// device load balancer maybe useful here.
borrowed_contexts.emplace_back(range.first->second);
}
return borrowed_contexts;
}
explicit DeviceContextPool(const std::vector<platform::Place>& places) {
PADDLE_ENFORCE_GT(places.size(), 0);
for (size_t i = 0; i < places.size(); i++) {
if (platform::is_cpu_place(places[i])) {
device_contexts_.emplace(
places[i], new platform::CPUDeviceContext(
boost::get<platform::CPUPlace>(places[i])));
} else if (platform::is_gpu_place(places[i])) {
#ifdef PADDLE_WITH_CUDA
device_contexts_.emplace(
places[i], new platform::CUDADeviceContext(
boost::get<platform::GPUPlace>(places[i])));
#else
PADDLE_THROW(
"'GPUPlace' is not supported, Please re-compile with WITH_GPU "
"option");
#endif
}
}
}
~DeviceContextPool() {}
private:
static DeviceContextPool* pool;
struct Hash {
std::hash<int> hash_;
size_t operator()(const platform::Place& place) const {
return hash_(place.which());
}
};
std::unordered_multimap<const platform::Place, const platform::DeviceContext*,
Hash>
device_contexts_;
DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
};
class Executor {
public:
// TODO(dzhwinter) : Do not rely on this function, it will be removed
explicit Executor(const platform::DeviceContext& device)
: Executor(std::vector<platform::Place>({device.GetPlace()})) {}
explicit Executor(const platform::Place& place)
: Executor(std::vector<platform::Place>({place})) {}
: Executor(device.GetPlace()) {}
explicit Executor(const std::vector<platform::Place>& places);
explicit Executor(const platform::Place& place);
/* @Brief
* Runtime evaluation of the given ProgramDesc under certain Scope
......@@ -128,7 +42,7 @@ class Executor {
bool create_vars = true);
private:
std::vector<const platform::DeviceContext*> device_contexts_;
const platform::Place place_;
};
} // namespace framework
......
......@@ -14,8 +14,8 @@
#include <algorithm>
#include <string>
#include "paddle/framework/executor.h"
#include "paddle/framework/init.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
#include "paddle/string/piece.h"
......@@ -48,7 +48,7 @@ bool InitDevices(const std::vector<std::string> &devices) {
std::vector<platform::Place> places;
for (auto &device : devices) {
auto p = string::Piece(device);
if (string::Find(p, ':', 0) == string::Piece::npos) {
if (string::HasPrefix(p, "CPU")) {
places.emplace_back(platform::CPUPlace());
} else if (string::HasPrefix(p, "GPU")) {
#ifdef PADDLE_WITH_CUDA
......@@ -69,10 +69,9 @@ bool InitDevices(const std::vector<std::string> &devices) {
return platform::is_cpu_place(place);
}) == places.end()) {
places.emplace_back(platform::CPUPlace());
LOG(WARNING) << "Not specified any device, use CPU by Default.";
LOG(WARNING) << "Not specified CPU device, create CPU by Default.";
}
DeviceContextPool::Create(places);
return true;
platform::DeviceContextPool::Create(places);
return true;
}
......
......@@ -23,5 +23,9 @@ TEST(Init, InitDevices) {
#ifdef PADDLE_WITH_CUDA
std::vector<std::string> ds2 = {"CPU", "GPU:0", "GPU:1"};
ASSERT_EQ(InitDevices(ds2), true);
// test re-init
std::vector<std::string> ds3 = {"GPU:0", "GPU:1"};
ASSERT_EQ(InitDevices(ds3), true);
#endif
}
......@@ -20,7 +20,25 @@ namespace framework {
// For more details about the design of LibraryType, Please refer to
// https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/operator_kernel_type.md#library
enum LibraryType { kPlain = 0; kMKLDNN = 1; kCUDNN = 2; }
enum LibraryType { kPlain = 0, kMKLDNN = 1, kCUDNN = 2 };
inline std::string LibraryTypeToString(const LibraryType& library_type) {
switch (library_type) {
case kPlain:
return "PLAIN";
case kMKLDNN:
return "MKLDNN";
case kCUDNN:
return "CUDNN";
default:
PADDLE_THROW("unknown LibraryType %d", library_type);
}
}
inline std::ostream& operator<<(std::ostream& out, LibraryType l) {
out << LibraryTypeToString(l);
return out;
}
} // namespace
} // framework
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/data_layout.h"
#include "paddle/framework/data_type.h"
#include "paddle/framework/library_type.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
namespace paddle {
namespace framework {
struct OpKernelType {
struct Hash {
size_t operator()(const OpKernelType& key) const {
int place = key.place_.which() + (1 << LEFT_SHIFT);
int data_type =
static_cast<int>(key.data_type_) + (1 << (LEFT_SHIFT + 1));
int data_layout =
static_cast<int>(key.data_layout_) + (1 << (LEFT_SHIFT + 2));
int library_type =
static_cast<int>(key.library_type_) + (1 << (LEFT_SHIFT + 3));
std::hash<int> hasher;
return hasher(place + data_type + data_layout + library_type);
}
};
// place, data_type, library_type kinds less than 2^8
constexpr static int LEFT_SHIFT = 8;
proto::DataType data_type_;
DataLayout data_layout_;
platform::Place place_;
LibraryType library_type_;
OpKernelType(proto::DataType data_type, platform::Place place,
DataLayout data_layout = DataLayout::kAnyLayout,
LibraryType library_type = LibraryType::kPlain)
: data_type_(data_type),
data_layout_(data_layout),
place_(place),
library_type_(library_type) {}
OpKernelType(proto::DataType data_type,
const platform::DeviceContext& dev_ctx,
DataLayout data_layout = DataLayout::kAnyLayout,
LibraryType library_type = LibraryType::kPlain)
: data_type_(data_type),
data_layout_(data_layout),
place_(dev_ctx.GetPlace()),
library_type_(library_type) {}
bool operator==(const OpKernelType& o) const {
return platform::places_are_same_class(place_, o.place_) &&
data_type_ == o.data_type_ && data_layout_ == o.data_layout_ &&
library_type_ == o.library_type_;
}
};
inline std::ostream& operator<<(std::ostream& os,
const OpKernelType& kernel_key) {
os << "data_type[" << kernel_key.data_type_ << "]:data_layout["
<< kernel_key.data_layout_ << "]:place[" << kernel_key.place_
<< "]:library_type[" << kernel_key.library_type_ << "]";
return os;
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_kernel_type.h"
#include <gtest/gtest.h>
#include <iostream>
TEST(OpKernelType, ToString) {
using OpKernelType = paddle::framework::OpKernelType;
using DataType = paddle::framework::proto::DataType;
using CPUPlace = paddle::platform::CPUPlace;
using DataLayout = paddle::framework::DataLayout;
using LibraryType = paddle::framework::LibraryType;
OpKernelType op_kernel_type(DataType::FP32, CPUPlace(), DataLayout::kNCHW,
LibraryType::kCUDNN);
std::ostringstream stream;
stream << op_kernel_type;
ASSERT_EQ(
stream.str(),
"data_type[5]:data_layout[NCHW]:place[CPUPlace]:library_type[CUDNN]");
}
TEST(OpKernelType, Hash) {
using OpKernelType = paddle::framework::OpKernelType;
using DataType = paddle::framework::proto::DataType;
using CPUPlace = paddle::platform::CPUPlace;
using GPUPlace = paddle::platform::GPUPlace;
using DataLayout = paddle::framework::DataLayout;
using LibraryType = paddle::framework::LibraryType;
OpKernelType op_kernel_type_1(DataType::FP32, CPUPlace(), DataLayout::kNCHW,
LibraryType::kCUDNN);
OpKernelType op_kernel_type_2(DataType::FP32, GPUPlace(0), DataLayout::kNCHW,
LibraryType::kCUDNN);
OpKernelType::Hash hasher;
ASSERT_NE(hasher(op_kernel_type_1), hasher(op_kernel_type_2));
}
\ No newline at end of file
......@@ -61,17 +61,6 @@ struct OperatorRegistrar : public Registrar {
class OpRegistry {
public:
template <typename OpType, typename ProtoMakerType, typename GradOpType>
static void RegisterOp(const std::string& op_type,
const std::string& grad_op_type) {
OperatorRegistrar<OpType, ProtoMakerType> reg(op_type.c_str());
reg.info.grad_op_type_ = grad_op_type;
// register gradient op
if (!grad_op_type.empty()) {
OperatorRegistrar<GradOpType> grad_reg(grad_op_type.c_str());
}
}
static std::unique_ptr<OperatorBase> CreateOp(const std::string& type,
const VariableNameMap& inputs,
const VariableNameMap& outputs,
......
......@@ -8,8 +8,7 @@ namespace framework {
class CosineOp : public OperatorBase {
public:
using OperatorBase::OperatorBase;
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
void Run(const Scope& scope, const platform::Place& place) const override {}
};
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
......@@ -28,8 +27,7 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class MyTestOp : public OperatorBase {
public:
using OperatorBase::OperatorBase;
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
void Run(const Scope& scope, const platform::Place& place) const override {}
};
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
......@@ -76,8 +74,8 @@ TEST(OpRegistry, CreateOp) {
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::Scope scope;
paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx);
paddle::platform::CPUPlace cpu_place;
op->Run(scope, cpu_place);
float scale_get = op->Attr<float>("scale");
ASSERT_EQ(scale_get, scale);
}
......@@ -117,8 +115,8 @@ TEST(OpRegistry, DefaultValue) {
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::Scope scope;
paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx);
paddle::platform::CPUPlace cpu_place;
op->Run(scope, cpu_place);
ASSERT_EQ(op->Attr<float>("scale"), 1.0);
}
......@@ -167,9 +165,9 @@ TEST(OpRegistry, CustomChecker) {
attr->set_type(paddle::framework::proto::AttrType::INT);
attr->set_i(4);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::platform::CPUDeviceContext dev_ctx;
paddle::platform::CPUPlace cpu_place;
paddle::framework::Scope scope;
op->Run(scope, dev_ctx);
op->Run(scope, cpu_place);
int test_attr = op->Attr<int>("test_attr");
ASSERT_EQ(test_attr, 4);
}
......
......@@ -12,10 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/operator.h"
#include <algorithm>
#include <atomic>
#include "paddle/framework/executor.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/shape_inference.h"
#include "paddle/framework/var_type.h"
......@@ -240,12 +242,6 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
return res;
}
std::ostream& operator<<(std::ostream& os, const OpKernelType& kernel_key) {
os << "place[" << kernel_key.place_ << "]:data_type[" << kernel_key.data_type_
<< "]";
return os;
}
bool OpSupportGPU(const std::string& op_type) {
auto& all_kernels = OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op_type);
......@@ -388,11 +384,11 @@ class RuntimeInferShapeContext : public InferShapeContext {
};
void OperatorWithKernel::Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const {
const platform::Place& place) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx);
ExecutionContext ctx(*this, scope, dev_ctx);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Get();
auto dev_ctx = pool.Borrow(place);
// check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels();
......@@ -404,6 +400,8 @@ void OperatorWithKernel::Run(const Scope& scope,
// check if op[type] have kernel for kernel_key
OpKernelMap& kernels = kernels_iter->second;
ExecutionContext ctx(*this, scope, *dev_ctx);
auto kernel_key = GetKernelType(ctx);
auto kernel_iter = kernels.find(kernel_key);
......
......@@ -23,15 +23,14 @@ limitations under the License. */
#include "glog/logging.h" // For VLOG
#include "paddle/framework/attribute.h"
#include "paddle/framework/block_desc.h"
#include "paddle/framework/data_type.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/op_kernel_type.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
#include "paddle/platform/variant.h"
#include "paddle/utils/Error.h"
......@@ -83,8 +82,7 @@ class OperatorBase {
virtual std::string DebugString() const;
/// Net will call this function to Run an op.
virtual void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const = 0;
virtual void Run(const Scope& scope, const platform::Place& place) const = 0;
virtual bool IsNetOp() const { return false; }
......@@ -159,8 +157,7 @@ class OperatorBase {
class NOP : public OperatorBase {
public:
using OperatorBase::OperatorBase;
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
void Run(const Scope& scope, const platform::Place& place) const override {}
std::unique_ptr<OperatorBase> Clone() const override {
return std::unique_ptr<OperatorBase>(new NOP(*this));
}
......@@ -345,34 +342,6 @@ class OpKernel : public OpKernelBase {
using ELEMENT_TYPE = T;
};
struct OpKernelType {
struct Hash {
std::hash<int> hash_;
size_t operator()(const OpKernelType& key) const {
int place = key.place_.which();
int data_type = static_cast<int>(key.data_type_);
int pre_hash = data_type << NUM_PLACE_TYPE_LIMIT_IN_BIT |
(place & ((1 << NUM_PLACE_TYPE_LIMIT_IN_BIT) - 1));
return hash_(pre_hash);
}
};
platform::Place place_;
proto::DataType data_type_;
OpKernelType(proto::DataType data_type, platform::Place place)
: place_(place), data_type_(data_type) {}
OpKernelType(proto::DataType data_type,
const platform::DeviceContext& dev_ctx)
: place_(dev_ctx.GetPlace()), data_type_(data_type) {}
bool operator==(const OpKernelType& o) const {
return platform::places_are_same_class(place_, o.place_) &&
data_type_ == o.data_type_;
}
};
class OperatorWithKernel : public OperatorBase {
public:
using OpKernelMap =
......@@ -383,8 +352,7 @@ class OperatorWithKernel : public OperatorBase {
const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const final;
void Run(const Scope& scope, const platform::Place& place) const final;
static std::unordered_map<std::string /* op_type */, OpKernelMap>&
AllOpKernels() {
......@@ -413,8 +381,6 @@ class OperatorWithKernel : public OperatorBase {
proto::DataType IndicateDataType(const ExecutionContext& ctx) const;
};
std::ostream& operator<<(std::ostream& os, const OpKernelType& kernel_key);
extern bool OpSupportGPU(const std::string& op_type);
} // namespace framework
......
......@@ -11,11 +11,12 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/operator.h"
#include "gtest/gtest.h"
#include "paddle/framework/init.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace framework {
......@@ -27,8 +28,7 @@ class OpWithoutKernelTest : public OperatorBase {
OpWithoutKernelTest(const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs), x(1) {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
void Run(const Scope& scope, const platform::Place& place) const override {
++op_run_num;
ASSERT_EQ(static_cast<int>(inputs_.size()), 1);
ASSERT_EQ(static_cast<int>(outputs_.size()), 1);
......@@ -41,10 +41,9 @@ class OpWithoutKernelTest : public OperatorBase {
int x{0};
};
class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class OpWithoutKernelCheckerMaker : public OpProtoAndCheckerMaker {
public:
OpeWithoutKernelTestProtoAndCheckerMaker(OpProto* proto,
OpAttrChecker* op_checker)
OpWithoutKernelCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of test op");
AddOutput("output", "output of test op");
......@@ -65,11 +64,12 @@ static void BuildVar(const std::string& param_name,
}
}
REGISTER_OP_WITHOUT_GRADIENT(
test_operator, paddle::framework::OpWithoutKernelTest,
paddle::framework::OpeWithoutKernelTestProtoAndCheckerMaker);
REGISTER_OP_WITHOUT_GRADIENT(test_operator,
paddle::framework::OpWithoutKernelTest,
paddle::framework::OpWithoutKernelCheckerMaker);
TEST(OperatorBase, all) {
paddle::framework::InitDevices({"CPU"});
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("test_operator");
BuildVar("input", {"IN1"}, op_desc.add_inputs());
......@@ -80,13 +80,13 @@ TEST(OperatorBase, all) {
attr->set_type(paddle::framework::proto::AttrType::FLOAT);
attr->set_f(3.14);
paddle::platform::CPUDeviceContext device_context;
paddle::platform::CPUPlace cpu_place;
paddle::framework::Scope scope;
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
scope.Var("OUT1");
ASSERT_EQ(paddle::framework::op_run_num, 0);
op->Run(scope, device_context);
op->Run(scope, cpu_place);
ASSERT_EQ(paddle::framework::op_run_num, 1);
}
......@@ -123,7 +123,6 @@ template <typename T1, typename T2>
class CPUKernelTest : public OpKernel<float> {
public:
void Compute(const ExecutionContext& ctx) const {
std::cout << "this is cpu kernel" << std::endl;
std::cout << ctx.op().DebugString() << std::endl;
cpu_kernel_run_num++;
ASSERT_EQ(ctx.op().Input("x"), "IN1");
......@@ -195,6 +194,7 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel,
// test with single input
TEST(OpKernel, all) {
paddle::framework::InitDevices({"CPU"});
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("op_with_kernel");
BuildVar("x", {"IN1"}, op_desc.add_inputs());
......@@ -205,12 +205,12 @@ TEST(OpKernel, all) {
attr->set_type(paddle::framework::proto::AttrType::FLOAT);
attr->set_f(3.14);
paddle::platform::CPUDeviceContext cpu_device_context;
paddle::platform::CPUPlace cpu_place;
paddle::framework::Scope scope;
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0);
op->Run(scope, cpu_device_context);
op->Run(scope, cpu_place);
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
}
......@@ -224,7 +224,9 @@ REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel,
TEST(OpKernel, multi_inputs) {
using namespace paddle::framework;
paddle::framework::InitDevices({"CPU"});
proto::OpDesc op_desc;
op_desc.set_type("op_multi_inputs_with_kernel");
BuildVar("xs", {"x0", "x1", "x2"}, op_desc.add_inputs());
BuildVar("k", {"k0"}, op_desc.add_inputs());
......@@ -235,7 +237,7 @@ TEST(OpKernel, multi_inputs) {
attr->set_type(paddle::framework::proto::AttrType::FLOAT);
attr->set_f(3.14);
paddle::platform::CPUDeviceContext cpu_device_context;
paddle::platform::CPUPlace cpu_place;
paddle::framework::Scope scope;
scope.Var("x0")->GetMutable<LoDTensor>();
scope.Var("x1")->GetMutable<LoDTensor>();
......@@ -245,7 +247,7 @@ TEST(OpKernel, multi_inputs) {
scope.Var("y1")->GetMutable<LoDTensor>();
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
op->Run(scope, cpu_device_context);
op->Run(scope, cpu_place);
}
class OperatorClone : public paddle::framework::OperatorBase {
......@@ -257,10 +259,11 @@ class OperatorClone : public paddle::framework::OperatorBase {
const paddle::framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const paddle::framework::Scope& scope,
const paddle::platform::DeviceContext& dev_ctx) const override {}
const paddle::platform::Place& place) const override {}
};
TEST(Operator, Clone) {
paddle::framework::InitDevices({"CPU"});
OperatorClone a("ABC", paddle::framework::VariableNameMap{},
paddle::framework::VariableNameMap{},
paddle::framework::AttributeMap{});
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <condition_variable>
#include <cstdio>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include "paddle/platform/call_once.h"
#include "paddle/platform/enforce.h"
namespace paddle {
namespace framework {
typedef std::function<void()> Task;
class ThreadPool {
public:
/**
* @brief Get a instance of threadpool, the thread number will
* be specified as the number of hardware thread contexts
*/
static ThreadPool* GetInstance() {
std::call_once(init_flag, &ThreadPool::Init);
return threadpool.get();
}
~ThreadPool() {
{
// notify all threads to stop running
running_ = false;
scheduled_.notify_all();
}
for (auto& t : threads_) {
t->join();
t.reset(nullptr);
}
}
int GetNumThreads() const { return num_threads_; }
int GetAvailable() {
std::unique_lock<std::mutex> lock(mutex_);
return available_;
}
/**
* @brief Push a function to the queue, and will be scheduled and
* executed if a thread is available.
* @param[in] Task will be pushed to the task queue.
*/
void Run(const Task& fn) {
std::unique_lock<std::mutex> lock(mutex_);
tasks_.push(fn);
lock.unlock();
scheduled_.notify_one();
}
/**
* @brief Wait until all the tasks are completed.
*/
void Wait() {
std::unique_lock<std::mutex> lock(mutex_);
completed_.wait(lock, [=] { return Done() == true; });
}
private:
ThreadPool& operator=(const ThreadPool&) = delete;
ThreadPool(const ThreadPool&) = delete;
ThreadPool(int num_threads)
: num_threads_(num_threads), available_(num_threads), running_(true) {
threads_.resize(num_threads);
for (auto& thread : threads_) {
// TODO(Yancey1989): binding the thread on the specify CPU number
thread.reset(new std::thread(std::bind(&ThreadPool::TaskLoop, this)));
}
}
/**
* @brief If the task queue is empty and avaialbe
* is equal to the number of threads, means that
* all tasks are completed.
*
* Note: this function is not thread-safe.
*
* @return true if all tasks are completed.
*/
bool Done() { return tasks_.empty() && available_ == num_threads_; }
void TaskLoop() {
while (running_) {
std::unique_lock<std::mutex> lock(mutex_);
scheduled_.wait(lock, [=] { return !tasks_.empty() || !running_; });
if (!running_) {
break;
}
// pop a task from the task queue
auto task = tasks_.front();
tasks_.pop();
--available_;
lock.unlock();
// run the task
task();
{
std::unique_lock<std::mutex> lock(mutex_);
++available_;
if (Done()) {
completed_.notify_all();
}
}
}
}
static void Init() {
if (threadpool.get() == nullptr) {
// TODO(Yancey1989): specify the max threads number
int num_threads = std::thread::hardware_concurrency();
PADDLE_ENFORCE_GT(num_threads, 0);
threadpool.reset(new ThreadPool(num_threads));
}
}
private:
static std::unique_ptr<ThreadPool> threadpool;
static std::once_flag init_flag;
int num_threads_;
int available_;
bool running_;
std::queue<Task> tasks_;
std::vector<std::unique_ptr<std::thread>> threads_;
std::mutex mutex_;
std::condition_variable scheduled_;
std::condition_variable completed_;
};
std::unique_ptr<ThreadPool> ThreadPool::threadpool(nullptr);
std::once_flag ThreadPool::init_flag;
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "threadpool.h"
#include <gtest/gtest.h>
#include <atomic>
#include <chrono>
#include <map>
#include <thread>
namespace framework = paddle::framework;
void do_sum(framework::ThreadPool* pool, std::atomic<int>& sum, int cnt) {
for (int i = 0; i < cnt; ++i) {
pool->Run([&sum]() { sum.fetch_add(1); });
}
}
TEST(ThreadPool, ConcurrentInit) {
framework::ThreadPool* pool;
int concurrent_cnt = 50;
std::vector<std::thread> threads;
for (int i = 0; i < concurrent_cnt; ++i) {
std::thread t([&pool]() { pool = framework::ThreadPool::GetInstance(); });
threads.push_back(std::move(t));
}
for (auto& t : threads) {
t.join();
}
}
TEST(ThreadPool, ConcurrentStart) {
framework::ThreadPool* pool = framework::ThreadPool::GetInstance();
std::atomic<int> sum(0);
std::vector<std::thread> threads;
int concurrent_cnt = 50;
// sum = (n * (n + 1)) / 2
for (int i = 1; i <= concurrent_cnt; ++i) {
std::thread t(do_sum, pool, std::ref(sum), i);
threads.push_back(std::move(t));
}
for (auto& t : threads) {
t.join();
}
pool->Wait();
EXPECT_EQ(sum, ((concurrent_cnt + 1) * concurrent_cnt) / 2);
}
......@@ -15,6 +15,7 @@
#pragma once
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
......@@ -27,11 +28,16 @@ class ArrayOp : public framework::OperatorBase {
protected:
size_t GetOffset(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const {
const platform::Place &place) const {
auto *i = scope.FindVar(Input("I"));
PADDLE_ENFORCE(i != nullptr, "I must be set");
auto &i_tensor = i->Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(i_tensor.numel(), 1);
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
size_t offset;
if (platform::is_gpu_place(i_tensor.place())) {
// FIXME: Avoid copy from GPU to CPU
......
......@@ -12,10 +12,12 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include <numeric>
#include "paddle/framework/lod_rank_table.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
#include "paddle/memory/memcpy.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
......@@ -30,7 +32,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &dev_place) const override {
auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>();
auto &rank_table =
scope.FindVar(Input("RankTable"))->Get<framework::LoDRankTable>();
......@@ -103,6 +105,10 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
continue;
}
auto slice = out->Slice(out_offset, out_offset + len);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
framework::CopyFrom(x[x_idx].Slice(start_offset, end_offset), place,
dev_ctx, &slice);
out_offset += len;
......
......@@ -15,6 +15,7 @@
#include "paddle/framework/data_type.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/var_type.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
......@@ -71,7 +72,7 @@ class AssignOp : public framework::OperatorBase {
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &place) const override {
auto *x = scope.FindVar(Input("X"));
if (x == nullptr) {
return;
......@@ -80,6 +81,10 @@ class AssignOp : public framework::OperatorBase {
PADDLE_ENFORCE(
out != nullptr,
"The Output(Out) should not be null if the Input(X) is set.");
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
framework::VisitVarType(*x, AssignFunctor(out, dev_ctx));
}
};
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/beam_search_decode_op.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
......@@ -55,7 +56,10 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
const platform::Place& dev_place) const override {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Get();
auto& dev_ctx = *pool.Borrow(dev_place);
framework::ExecutionContext ctx(*this, scope, dev_ctx);
const LoDTensorArray* ids = ctx.Input<LoDTensorArray>("Ids");
......
......@@ -189,7 +189,7 @@ class BeamSearchOp : public framework::OperatorBase {
}
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
const platform::Place& dev_place) const override {
LOG(INFO) << "run beam search op";
auto ids_var = scope.FindVar(Input("ids"));
auto scores_var = scope.FindVar(Input("scores"));
......
......@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/cond_op.h"
#include "paddle/operators/gather.h"
#include "paddle/operators/scatter.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
......@@ -193,12 +193,15 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
}
}
void CondOp::Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const {
void CondOp::Run(const Scope& scope, const platform::Place& place) const {
// get device context from pool
platform::DeviceContextPool& pool = platform::DeviceContextPool::Get();
auto& dev_ctx = *pool.Borrow(place);
PrepareDataForSubnet(scope, dev_ctx);
std::vector<framework::Scope*>& sub_scopes = GetSubScopes(scope);
for (int i = 0; i < BRANCH_NUM; ++i) {
sub_net_op_[i]->Run(*sub_scopes[i], dev_ctx);
sub_net_op_[i]->Run(*sub_scopes[i], place);
}
MergeDataFromSubnet(scope, dev_ctx);
}
......
......@@ -78,7 +78,7 @@ class CondOp : public framework::OperatorBase {
}
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const override;
const platform::Place& place) const override;
private:
const int TRUE_BRANCH = 0;
......
......@@ -51,7 +51,7 @@ class ConditionalBlockOp : public ConditionalOp {
const framework::AttributeMap &attrs)
: ConditionalOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &dev_place) const override {
auto xs = InputTensors(scope);
bool need_run = std::all_of(
xs.begin(), xs.end(),
......@@ -65,8 +65,8 @@ class ConditionalBlockOp : public ConditionalOp {
scopes->front() = &scope.NewScope();
auto &cur_scope = *scopes->front();
framework::Executor exec(dev_place);
auto *block = Attr<framework::BlockDesc *>("sub_block");
framework::Executor exec(dev_ctx);
exec.Run(*block->Program(), &cur_scope, block->ID(), false);
}
}
......@@ -104,7 +104,7 @@ class ConditionalBlockGradOp : public ConditionalOp {
const framework::AttributeMap &attrs)
: ConditionalOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &dev_place) const override {
auto xs = this->InputTensors(scope);
bool need_run = std::all_of(
xs.begin(), xs.end(),
......@@ -116,21 +116,21 @@ class ConditionalBlockGradOp : public ConditionalOp {
auto &scopes = scope_var->Get<std::vector<framework::Scope *>>();
framework::Scope &cur_scope = *scopes[0];
framework::Executor exec(dev_place);
auto *block = Attr<framework::BlockDesc *>("sub_block");
framework::Executor exec(dev_ctx);
exec.Run(*block->Program(), &cur_scope, block->ID(), false);
AssignLocalGradientToGlobal(dev_ctx, cur_scope, Inputs("Params"),
AssignLocalGradientToGlobal(dev_place, cur_scope, Inputs("Params"),
Outputs(framework::GradVarName("Params")));
AssignLocalGradientToGlobal(dev_ctx, cur_scope, Inputs("X"),
AssignLocalGradientToGlobal(dev_place, cur_scope, Inputs("X"),
Outputs(framework::GradVarName("X")));
}
}
private:
void AssignLocalGradientToGlobal(
const platform::DeviceContext &dev_ctx, const framework::Scope &cur_scope,
const platform::Place &place, const framework::Scope &cur_scope,
const std::vector<std::string> &p_names,
const std::vector<std::string> &pg_names) const {
for (size_t i = 0; i < p_names.size(); ++i) {
......@@ -144,7 +144,7 @@ class ConditionalBlockGradOp : public ConditionalOp {
auto assign = framework::OpRegistry::CreateOp(
"assign", {{"X", {new_in_grad_name}}}, {{"Out", {out_grad_name}}},
framework::AttributeMap{});
assign->Run(cur_scope, dev_ctx);
assign->Run(cur_scope, place);
cur_scope.Rename(new_in_grad_name, in_grad_name);
}
}
......
......@@ -25,7 +25,7 @@ class FeedOp : public framework::OperatorBase {
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &place) const override {
auto feed_var_name = Input("X");
auto *feed_var = scope.FindVar(feed_var_name);
......@@ -47,7 +47,12 @@ class FeedOp : public framework::OperatorBase {
auto &feed_list = feed_var->Get<framework::FeedFetchList>();
auto &feed_item = feed_list.at(static_cast<size_t>(col));
auto *out_item = out_var->GetMutable<framework::FeedFetchType>();
framework::CopyFrom(feed_item, dev_ctx.GetPlace(), dev_ctx, out_item);
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
framework::CopyFrom(feed_item, place, dev_ctx, out_item);
out_item->set_lod(feed_item.lod());
}
};
......
......@@ -14,6 +14,7 @@
#include "paddle/framework/feed_fetch_type.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
......@@ -26,7 +27,7 @@ class FetchOp : public framework::OperatorBase {
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &place) const override {
auto fetch_var_name = Input("X");
auto *fetch_var = scope.FindVar(fetch_var_name);
PADDLE_ENFORCE(fetch_var != nullptr,
......@@ -51,6 +52,9 @@ class FetchOp : public framework::OperatorBase {
// FIXME(yuyang18): Should we assume the fetch operator always generate
// CPU outputs?
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
CopyFrom(src_item, platform::CPUPlace(), dev_ctx, &dst_item);
dev_ctx.Wait();
dst_item.set_lod(src_item.lod());
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/framework/data_type.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
......@@ -33,7 +34,7 @@ class FillConstantOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &dev_place) const override {
auto data_type =
static_cast<framework::proto::DataType>(Attr<int>("dtype"));
auto value = Attr<float>("value");
......@@ -45,8 +46,11 @@ class FillConstantOp : public framework::OperatorBase {
auto cpu = platform::CPUPlace();
out.mutable_data(cpu, framework::ToTypeIndex(data_type));
} else {
out.mutable_data(dev_ctx.GetPlace(), framework::ToTypeIndex(data_type));
out.mutable_data(dev_place, framework::ToTypeIndex(data_type));
}
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(dev_place);
math::set_constant(dev_ctx, &out, value);
}
};
......
......@@ -15,6 +15,7 @@
#include "paddle/framework/data_type.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/detail/safe_ref.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
......@@ -42,7 +43,7 @@ class FillOp : public framework::OperatorBase {
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &place) const override {
auto &out =
detail::Ref(detail::Ref(scope.FindVar(Output("Out")),
"Cannot find variable %s", Output("Out"))
......@@ -51,12 +52,11 @@ class FillOp : public framework::OperatorBase {
auto dtype = static_cast<framework::proto::DataType>(Attr<int>("dtype"));
platform::CPUPlace cpu;
auto force_cpu = Attr<bool>("force_cpu");
out.mutable_data(force_cpu ? cpu : dev_ctx.GetPlace(),
framework::ToTypeIndex(dtype));
out.mutable_data(force_cpu ? cpu : place, framework::ToTypeIndex(dtype));
framework::LoDTensor tensor;
if (force_cpu || platform::is_cpu_place(dev_ctx.GetPlace())) {
if (force_cpu || platform::is_cpu_place(place)) {
tensor.ShareDataWith(out);
} else {
// Always make tensor in CPU memory.
......@@ -67,9 +67,11 @@ class FillOp : public framework::OperatorBase {
framework::VisitDataType(
dtype, FillOpVisitor(&tensor, Attr<std::vector<float>>("value")));
if (!force_cpu && platform::is_gpu_place(dev_ctx.GetPlace())) {
if (!force_cpu && platform::is_gpu_place(place)) {
// Copy tensor to out
framework::CopyFrom(tensor, dev_ctx.GetPlace(), dev_ctx, &out);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
framework::CopyFrom(tensor, place, dev_ctx, &out);
}
}
};
......
......@@ -52,7 +52,7 @@ class IncrementOp : public framework::OperatorBase {
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &place) const override {
auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
auto &out =
*scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
......
......@@ -29,7 +29,7 @@ class IsEmptyOp : public framework::OperatorBase {
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &place) const override {
// get input
auto *var = scope.FindVar(Input(kInput));
PADDLE_ENFORCE_NOT_NULL(var);
......
......@@ -11,10 +11,10 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fstream>
#include "paddle/framework/op_registry.h"
#include <fstream>
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
......@@ -26,7 +26,7 @@ class LoadOp : public framework::OperatorBase {
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &place) const override {
auto filename = Attr<std::string>("file_path");
std::ifstream fin(filename);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
......@@ -40,7 +40,9 @@ class LoadOp : public framework::OperatorBase {
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
framework::DeserializeFromStream(fin, tensor);
auto place = dev_ctx.GetPlace();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
if (platform::is_gpu_place(place)) {
// copy CPU to GPU
framework::LoDTensor cpu_tensor;
......
......@@ -26,7 +26,7 @@ class LoDArrayLengthOp : public framework::OperatorBase {
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &place) const override {
auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>();
auto &out =
*scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
......
......@@ -24,7 +24,7 @@ class LoDRankTableOp : public framework::OperatorBase {
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &dev_place) const override {
auto x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
auto *out =
scope.FindVar(Output("Out"))->GetMutable<framework::LoDRankTable>();
......
......@@ -15,6 +15,7 @@
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/detail/safe_ref.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
......@@ -32,7 +33,7 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &place) const override {
auto &x = detail::Ref(scope.FindVar(Input("X")), "Cannot find input %s",
Input("X"))
.Get<framework::LoDTensor>();
......@@ -86,6 +87,10 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
// out[i][offset: offset+len] = x[each_range.begin: each_range.end]
auto slice = out[i].Slice(static_cast<int>(offset),
static_cast<int>(offset + len));
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
framework::CopyFrom(x.Slice(static_cast<int>(each_range.begin),
static_cast<int>(each_range.end)),
x.place(), dev_ctx, &slice);
......
......@@ -94,8 +94,8 @@ class ColwiseSum<platform::CPUDeviceContext, T> {
T* out_buf = out->mutable_data<T>(out->place());
const T* in_buf = input.data<T>();
for (size_t i = 0; i < height; ++i) {
for (size_t j = 0; j < size; ++j) {
for (size_t i = 0; i < static_cast<size_t>(height); ++i) {
for (size_t j = 0; j < static_cast<size_t>(size); ++j) {
if (i == 0) {
out_buf[j] = in_buf[i * size + j];
} else {
......
......@@ -28,7 +28,7 @@ class MaxSeqenceLenOp : public framework::OperatorBase {
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &dev_place) const override {
auto &rank_table =
scope.FindVar(Input("RankTable"))->Get<framework::LoDRankTable>();
auto *out =
......
......@@ -28,7 +28,11 @@ class MergeLoDTensorOp : public framework::OperatorBase {
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &dev_place) const override {
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(dev_place);
auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
auto &mask = scope.FindVar(Input("Mask"))->Get<framework::LoDTensor>();
auto &in_true = scope.FindVar(Input("InTrue"))->Get<framework::LoDTensor>();
......
......@@ -113,7 +113,7 @@ This operator is used to perform matrix multiplication for input $X$ and $Y$.
The equation is:
$$Out = X * Y$$
$$Out = X * Y$$
Both the input $X$ and $Y$ can carry the LoD (Level of Details) information,
or not. But the output only shares the LoD information with input $X$.
......
......@@ -24,7 +24,7 @@ class NCCLInitOp : public framework::OperatorBase {
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &place) const override {
const auto &name = Output("Communicator");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name),
"Can not find variable '%s' in the scope.", name);
......
......@@ -22,6 +22,7 @@
#include <vector>
#include "paddle/framework/block_desc.h"
#include "paddle/framework/init.h"
#include "paddle/framework/op_desc.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/program_desc.h"
......@@ -49,7 +50,7 @@ const f::DDim kDims = {100, 100};
class NCCLTester : public ::testing::Test {
public:
virtual void SetUp() override {
cpu_ctx = new p::CPUDeviceContext(p::CPUPlace());
paddle::platform::CPUPlace cpu_place;
for (size_t i = 0; i < gpu_list.size(); ++i) {
p::GPUPlace place(i);
dev_ctxs.emplace_back(new p::CUDADeviceContext(place));
......@@ -65,6 +66,7 @@ class NCCLTester : public ::testing::Test {
}
void NCCLInitOp() {
paddle::platform::CPUPlace cpu_place;
std::unique_ptr<f::OpDesc> op1(new f::OpDesc);
op1->SetType("ncclInit");
......@@ -76,7 +78,7 @@ class NCCLTester : public ::testing::Test {
auto op = f::OpRegistry::CreateOp(*op1);
VLOG(1) << "invoke NCCLInitOp.";
op->Run(g_scope, *cpu_ctx);
op->Run(g_scope, cpu_place);
VLOG(1) << "NCCLInitOp finished.";
}
......@@ -111,13 +113,12 @@ class NCCLTester : public ::testing::Test {
VLOG(1) << "Device : " << gpu_id << " invoke " << op_desc.Type();
VLOG(1) << " send_tensor : " << send_tensor->numel()
<< " recv_tensor : " << recv_tensor->numel();
op->Run(*scope, *ctx);
op->Run(*scope, place);
VLOG(1) << "Device : " << gpu_id << " finished " << op_desc.Type();
}
public:
std::vector<p::DeviceContext *> dev_ctxs;
p::DeviceContext *cpu_ctx;
f::Scope g_scope;
std::mutex mu;
};
......@@ -131,14 +132,14 @@ TEST(NCCL, ncclInitOp) {
op_desc->SetAttr("gpus", {gpu_list});
f::Scope g_scope;
std::unique_ptr<p::DeviceContext> ctx(new p::CPUDeviceContext(p::CPUPlace()));
paddle::platform::CPUPlace cpu_place;
auto *var = g_scope.Var("x1");
var->GetMutable<p::Communicator>();
auto op = f::OpRegistry::CreateOp(*op_desc);
VLOG(1) << "invoke NCCLInitOp.";
op->Run(g_scope, *ctx.get());
op->Run(g_scope, cpu_place);
VLOG(1) << "NCCLInitOp finished.";
}
......@@ -294,9 +295,18 @@ int main(int argc, char **argv) {
return 0;
}
for (int i = 0; i < dev_count; ++i) {
std::vector<paddle::platform::Place> places;
places.emplace_back(paddle::platform::CPUPlace());
int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; ++i) {
places.emplace_back(paddle::platform::GPUPlace(i));
gpu_list.emplace_back(i);
}
VLOG(0) << " DeviceCount " << count;
paddle::platform::DeviceContextPool::Create(places);
testing::InitGoogleTest(&argc, argv);
// device context should be release before scope.
......
......@@ -65,9 +65,9 @@ class NetOp : public framework::OperatorBase {
* will be used.
*/
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
const platform::Place& place) const override {
for (auto& op : ops_) {
op->Run(scope, dev_ctx);
op->Run(scope, place);
}
}
......
......@@ -13,8 +13,7 @@ class TestOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
DEFINE_OP_CLONE_METHOD(TestOp);
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
void Run(const Scope& scope, const platform::Place& place) const override {
++run_cnt;
}
};
......
......@@ -154,13 +154,14 @@ class PositiveNegativePairOpMaker : public framework::OpProtoAndCheckerMaker {
"Noting that reducing on the first dim will make the LoD info lost.")
.SetDefault(0);
AddComment(R"DOC(
PositiveNegativePairOp can be used to evaluate Learning To Rank(LTR)
model performance.
Within some context, e.g. the "query", a LTR model generates scores
for a list of items, which gives a partial order of the items.
PositiveNegativePairOp takes a list of reference rank order
(Input("Label")) and the model generated scores (Input(Score)) as
inputs and counts the pairs that ranked correctly and incorrectly.
PositiveNegativePairOp can be used to evaluate Learning To Rank(LTR) model's
performance.
Within some context, e.g. the "query", a LTR model generates scores for a list
of items, which gives a partial order of the items. PositiveNegativePairOp
takes a list of reference rank order (Input("Label")) and the model generated
scores (Input(Score)) as inputs and counts the pairs that ranked correctly
and incorrectly.
)DOC");
}
};
......
......@@ -227,14 +227,15 @@ class RecurrentOp : public RecurrentBase {
: RecurrentBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &place) const override {
auto seq_len = static_cast<size_t>(this->GetSequenceLength(scope));
VLOG(3) << "Static RNN input sequence length = " << seq_len;
StepScopes scopes = CreateStepScopes(scope, seq_len);
auto reverse = Attr<bool>(kReverse);
framework::Executor executor(dev_ctx);
framework::Executor executor(place);
auto *block = Attr<framework::BlockDesc *>(kStepBlock);
auto *program = block->Program();
for (size_t i = 0; i < seq_len; ++i) {
......@@ -270,6 +271,10 @@ class RecurrentOp : public RecurrentBase {
executor.Run(*program, &cur_scope, block->ID(),
false /*create_local_scope*/);
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
// Copy inside::output -> outside::output
// outside::output[seq_offset: seq_offset + 1] = inside::output
this->LinkTensorWithCallback(
......@@ -278,14 +283,13 @@ class RecurrentOp : public RecurrentBase {
framework::LoDTensor *dst_tensor) {
if (i == 0) { // create output tensor at begin
dst_tensor->Resize(PrependDims(seq_len, src_tensor.dims()));
dst_tensor->mutable_data(dev_ctx.GetPlace(), src_tensor.type());
dst_tensor->mutable_data(place, src_tensor.type());
}
auto dst_out = dst_tensor->Slice(seq_offset, seq_offset + 1);
// Explicit copy output since the local RNN scope can be destroyed
// early.
framework::CopyFrom(src_tensor, dev_ctx.GetPlace(), dev_ctx,
&dst_out);
framework::CopyFrom(src_tensor, place, dev_ctx, &dst_out);
});
scopes.Next();
......@@ -311,15 +315,20 @@ class RecurrentGradOp : public RecurrentBase {
: RecurrentBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &place) const override {
auto seq_len = static_cast<size_t>(GetSequenceLength(scope));
StepScopes scopes = CreateStepScopes(scope, seq_len);
auto reverse = Attr<bool>(kReverse);
framework::Executor executor(dev_ctx);
framework::Executor executor(place);
auto *block = Attr<framework::BlockDesc *>(kStepBlock);
auto *program = block->Program();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
for (size_t step_id = 0; step_id < seq_len; ++step_id) {
size_t seq_offset = reverse ? step_id : seq_len - step_id - 1;
VLOG(3) << "Recurrent backward operate at the time step " << seq_offset;
......@@ -366,8 +375,7 @@ class RecurrentGradOp : public RecurrentBase {
auto *cur_grad_var = cur_scope.Var(cur_grad);
auto cur_grad_tensor =
cur_grad_var->GetMutable<framework::LoDTensor>();
framework::CopyFrom(ex_tensor, dev_ctx.GetPlace(), dev_ctx,
cur_grad_tensor);
framework::CopyFrom(ex_tensor, place, dev_ctx, cur_grad_tensor);
}
}
......@@ -410,7 +418,7 @@ class RecurrentGradOp : public RecurrentBase {
auto zero_op = framework::OpRegistry::CreateOp(
"fill_constant", framework::VariableNameMap{},
{{"Out", {pg_names[param_id]}}}, attrs);
zero_op->Run(scope, dev_ctx);
zero_op->Run(scope, place);
}
auto new_inside_name = cur_scope.Rename(inside_grad_name);
......@@ -419,7 +427,7 @@ class RecurrentGradOp : public RecurrentBase {
auto sum_op = framework::OpRegistry::CreateOp(
"sum", {{"X", {pg_names[param_id], new_inside_name}}},
{{"Out", {pg_names[param_id]}}}, framework::AttributeMap{});
sum_op->Run(cur_scope, dev_ctx);
sum_op->Run(cur_scope, place);
cur_scope.Rename(new_inside_name, inside_grad_name);
}
......@@ -437,11 +445,11 @@ class RecurrentGradOp : public RecurrentBase {
}
if (step_id == 0) { // alloc memory
outside->Resize(PrependDims(seq_len, inside.dims()));
outside->mutable_data(dev_ctx.GetPlace(), inside.type());
outside->mutable_data(place, inside.type());
}
auto dst = outside->Slice(seq_offset, seq_offset + 1);
framework::CopyFrom(inside, dev_ctx.GetPlace(), dev_ctx, &dst);
framework::CopyFrom(inside, place, dev_ctx, &dst);
});
VLOG(5) << "Link outside gradient finished ";
......@@ -453,8 +461,8 @@ class RecurrentGradOp : public RecurrentBase {
[&](const framework::LoDTensor &inside,
framework::LoDTensor *outside) {
outside->Resize(inside.dims());
outside->mutable_data(dev_ctx.GetPlace(), inside.type());
framework::CopyFrom(inside, dev_ctx.GetPlace(), dev_ctx, outside);
outside->mutable_data(place, inside.type());
framework::CopyFrom(inside, place, dev_ctx, outside);
});
VLOG(5) << "Link initialize state gradient finished ";
}
......
......@@ -73,7 +73,7 @@ class RecvOp : public framework::OperatorBase {
}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &dev_place) const override {
// FIXME(typhoonzero): no new scopes for every run.
framework::Scope &recv_scope = scope.NewScope();
rpc_service_->SetScope(&recv_scope);
......@@ -113,7 +113,9 @@ class RecvOp : public framework::OperatorBase {
auto *var = recv_scope.Var(grad_var_name);
auto *tensor = var->GetMutable<framework::LoDTensor>();
// FIXME(typhoonzero): do not copy
framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
framework::CopyFrom(v.second, place, dev_ctx, tensor);
}
rpc_service_->Reset();
......@@ -121,7 +123,7 @@ class RecvOp : public framework::OperatorBase {
framework::proto::ProgramDesc program_desc;
program_desc.ParseFromString(program_str);
framework::ProgramDesc program(program_desc);
framework::Executor executor(dev_ctx);
framework::Executor executor(place);
// Run sub graph to get optimized tensor
try {
executor.Run(program, &recv_scope, 0, /*global_block*/
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/lod_rank_table.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/detail/safe_ref.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
class ReorderLoDTensorByRankTableOpProtoMaker
: public framework::OpProtoAndCheckerMaker {
public:
ReorderLoDTensorByRankTableOpProtoMaker(OpProto *proto,
OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(LoDTensor) the input lod tensor need to be reordered.");
AddInput("RankTable",
"(LoDRankTable) the rank table that input need follow");
AddOutput("Out", "(LoDTensor) reordered lod tensor");
AddComment(R"DOC(ReorderLoDTensorByRankTable
Reorder the input X by the rank of `RankTable`. If `RankTable` is ordered by
index [3, 0, 2, 1]. Input X will reorder its sequence, the third sequence of
X will be the first sequence of Output.
NOTE: The RankTable does not need to be calculated by X.
For example:
The X = [Seq0, Seq1, Seq2, Seq3]. The indices of RankTable are [3, 0, 2, 1].
The Out = [Seq3, Seq0, Seq2, Seq1] with correct LoD information.
)DOC");
}
};
class ReorderLoDTensorByRankTableBase : public framework::OperatorBase {
public:
ReorderLoDTensorByRankTableBase(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::Place &place) const override {
auto &x =
detail::Ref(scope.FindVar(Input("X")),
"Cannot find input lod tensor variable %s", Input("X"))
.Get<framework::LoDTensor>();
auto &rank_table = detail::Ref(scope.FindVar(Input("RankTable")),
"Cannot find input rank table variable %s",
Input("RankTable"))
.Get<framework::LoDRankTable>();
auto &out =
*detail::Ref(scope.FindVar(Output("Out")),
"Cannot find output lod tensor variable %s", Output("Out"))
.GetMutable<framework::LoDTensor>();
out.Resize(x.dims());
out.mutable_data(x.place(), x.type());
this->process(place, x, rank_table, &out);
}
protected:
virtual void process(const platform::Place &place,
const framework::LoDTensor &x,
const framework::LoDRankTable &rank_table,
framework::LoDTensor *out) const = 0;
struct AbsoluteRankTableItem {
size_t offset; // the absolute/accumulated offset.
size_t length; // the length
framework::LoD lod;
};
std::vector<AbsoluteRankTableItem> GetAbsoluteOffsetAndLengthByLoDRankTable(
const framework::LoDTensor &x) const {
std::vector<AbsoluteRankTableItem> absolute_table;
size_t level = 0;
size_t size = x.lod()[level].size();
for (size_t i = 0; i < size - 1; ++i) {
auto lod_offset =
framework::GetSubLoDAndAbsoluteOffset(x.lod(), i, i + 1, level);
auto &offset = lod_offset.second;
absolute_table.emplace_back();
absolute_table.back().length = offset.second - offset.first;
absolute_table.back().offset = offset.first;
absolute_table.back().lod = lod_offset.first;
}
return absolute_table;
}
size_t CopyTensorAndLod(const platform::Place &place,
const AbsoluteRankTableItem &item,
const framework::LoDTensor &x,
framework::LoDTensor *out, size_t out_offset) const {
auto &out_lod = *out->mutable_lod();
auto len = item.length;
auto x_offset = item.offset;
if (out_lod.empty()) {
for (size_t i = 0; i < item.lod.size(); ++i) {
out_lod.push_back(std::vector<size_t>({0}));
}
}
for (size_t i = 0; i < out_lod.size(); ++i) {
auto &out_v = out_lod[i];
auto &new_lod_v = item.lod[i];
for (auto &detail : new_lod_v) {
out_v.push_back(out_v.back() + detail);
}
}
auto x_sliced = x.Slice(x_offset, x_offset + len);
auto out_sliced = out->Slice(out_offset, out_offset + len);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
framework::CopyFrom(x_sliced, out_sliced.place(), dev_ctx, &out_sliced);
out_offset += len;
return out_offset;
}
};
class ReorderLoDTensorByRankTableOp : public ReorderLoDTensorByRankTableBase {
public:
ReorderLoDTensorByRankTableOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: ReorderLoDTensorByRankTableBase(type, inputs, outputs, attrs) {}
protected:
void process(const platform::Place &place, const framework::LoDTensor &x,
const framework::LoDRankTable &rank_table,
framework::LoDTensor *out) const override {
auto absolute_table = GetAbsoluteOffsetAndLengthByLoDRankTable(x);
size_t out_offset = 0;
out->mutable_lod()->clear();
for (auto &item : rank_table.items()) {
PADDLE_ENFORCE_LT(item.index, absolute_table.size());
out_offset = CopyTensorAndLod(place, absolute_table[item.index], x, out,
out_offset);
}
}
};
class IdentityInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
context->SetOutputDim("Out", context->GetInputDim("X"));
}
};
class ReorderLodTensorByRankGradOpMaker
: public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto *grad_op = new framework::OpDesc();
grad_op->SetType("reorder_lod_tensor_by_rank_grad");
grad_op->SetInput("X", OutputGrad("Out"));
grad_op->SetOutput("Out", InputGrad("X"));
grad_op->SetInput("RankTable", Input("RankTable"));
return std::unique_ptr<framework::OpDesc>(grad_op);
}
};
class ReorderLoDTensorByRankGradOp : public ReorderLoDTensorByRankTableBase {
public:
ReorderLoDTensorByRankGradOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: ReorderLoDTensorByRankTableBase(type, inputs, outputs, attrs) {}
protected:
void process(const platform::Place &place, const framework::LoDTensor &x,
const framework::LoDRankTable &rank_table,
framework::LoDTensor *out) const override {
auto absolute_table = GetAbsoluteOffsetAndLengthByLoDRankTable(x);
// offsets = enumerate([item.index for item in rank_table.items()])
std::vector<std::pair<size_t, size_t>> offsets;
offsets.reserve(rank_table.items().size());
for (size_t i = 0; i < rank_table.items().size(); ++i) {
offsets.push_back({i, rank_table.items()[i].index});
}
// offsets.sort(key=lambda x: x[1])
std::sort(
offsets.begin(), offsets.end(),
[](const std::pair<size_t, size_t> &a,
const std::pair<size_t, size_t> &b) { return a.second < b.second; });
// Copy TensorAndLod
size_t out_offset = 0;
for (auto &offset : offsets) {
out_offset = this->CopyTensorAndLod(place, absolute_table[offset.first],
x, out, out_offset);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(reorder_lod_tensor_by_rank,
ops::ReorderLoDTensorByRankTableOp,
ops::ReorderLodTensorByRankGradOpMaker,
ops::ReorderLoDTensorByRankTableOpProtoMaker,
ops::IdentityInferShape);
REGISTER_OPERATOR(reorder_lod_tensor_by_rank_grad,
ops::ReorderLoDTensorByRankGradOp, ops::IdentityInferShape);
......@@ -25,7 +25,7 @@ class RNNMemoryHelperOp : public framework::OperatorBase {
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &dev_place) const override {
auto mem_var_name = Input("X");
auto *mem_var = scope.FindVar(mem_var_name);
PADDLE_ENFORCE(mem_var != nullptr,
......@@ -77,7 +77,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &dev_place) const override {
auto out_grad_var_name = Input(framework::GradVarName("Out"));
auto *out_grad_var = scope.FindVar(out_grad_var_name);
......@@ -100,7 +100,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
auto zero_op = framework::OpRegistry::CreateOp(
"fill_constant", {}, {{"Out", {in_grad_var_name}}}, attrs);
zero_op->Run(scope, dev_ctx);
zero_op->Run(scope, dev_place);
} else {
auto &out_grad_tensor = out_grad_var->Get<framework::LoDTensor>();
auto *in_grad_tensor = in_grad_var->GetMutable<framework::LoDTensor>();
......
......@@ -21,7 +21,7 @@ USE_NO_KERNEL_OP(load);
TEST(SaveLoadOp, CPU) {
paddle::framework::Scope scope;
paddle::platform::CPUPlace place;
paddle::platform::CPUDeviceContext ctx(place);
auto var = scope.Var("test_var");
auto tensor = var->GetMutable<paddle::framework::LoDTensor>();
tensor->Resize({10, 10});
......@@ -42,13 +42,13 @@ TEST(SaveLoadOp, CPU) {
auto save_op = paddle::framework::OpRegistry::CreateOp(
"save", {{"X", {"test_var"}}}, {}, attrs);
save_op->Run(scope, ctx);
save_op->Run(scope, place);
auto load_var = scope.Var("out_var");
auto target = load_var->GetMutable<paddle::framework::LoDTensor>();
auto load_op = paddle::framework::OpRegistry::CreateOp(
"load", {}, {{"Out", {"out_var"}}}, attrs);
load_op->Run(scope, ctx);
load_op->Run(scope, place);
int* actual = target->data<int>();
for (int64_t i = 0; i < tensor->numel(); ++i) {
EXPECT_EQ(expect[i], actual[i]);
......
......@@ -21,6 +21,7 @@
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
......@@ -62,7 +63,7 @@ class SaveOp : public framework::OperatorBase {
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &place) const override {
auto filename = Attr<std::string>("file_path");
auto overwrite = Attr<bool>("overwrite");
......@@ -88,6 +89,11 @@ class SaveOp : public framework::OperatorBase {
"SaveOp only support LoDTensor, %s has wrong type", iname);
auto &tensor = var->Get<framework::LoDTensor>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
framework::SerializeToStream(fout, tensor, dev_ctx);
}
};
......
......@@ -27,11 +27,11 @@ class ShrinkRNNMemoryOp : public ArrayOp {
: ArrayOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &place) const override {
auto *x_var = scope.FindVar(Input("X"));
PADDLE_ENFORCE(x_var != nullptr, "Input X must be set");
auto &x_tensor = x_var->Get<framework::LoDTensor>();
size_t offset = this->GetOffset(scope, dev_ctx);
size_t offset = this->GetOffset(scope, place);
auto *rank_table_var = scope.FindVar(Input("RankTable"));
PADDLE_ENFORCE(rank_table_var != nullptr, "RankTable must be set");
auto &rank_table = rank_table_var->Get<framework::LoDRankTable>();
......@@ -93,7 +93,7 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
: ArrayOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &place) const override {
auto *dout_var = scope.FindVar(Input(framework::GradVarName("Out")));
auto *dx_var = scope.FindVar(Output(framework::GradVarName("X")));
PADDLE_ENFORCE(dx_var != nullptr, "Input Gradient should not be nullptr");
......@@ -105,6 +105,10 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
dx_tensor.Resize(x_tensor.dims());
dx_tensor.mutable_data(x_tensor.place(), x_tensor.type());
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
if (dout_var == nullptr) { // dx_tensor fill zero
math::set_constant(dev_ctx, &dx_tensor, 0.0f);
} else {
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/memory/memcpy.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
......@@ -33,7 +34,7 @@ class SplitLoDTensorOp : public framework::OperatorBase {
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &dev_place) const override {
auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
auto &mask = scope.FindVar(Input("Mask"))->Get<framework::LoDTensor>();
auto *out_true =
......@@ -44,6 +45,9 @@ class SplitLoDTensorOp : public framework::OperatorBase {
auto &x_lod = x.lod();
auto &mask_dim = mask.dims();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(dev_place);
std::unique_ptr<framework::LoDTensor> cpu_mask{new framework::LoDTensor()};
if (platform::is_cpu_place(mask.place())) {
cpu_mask->ShareDataWith(mask);
......
......@@ -25,11 +25,11 @@ class WriteToArrayOp : public ArrayOp {
: ArrayOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &place) const override {
auto *x = scope.FindVar(Input("X"));
if (x == nullptr) return;
auto &x_tensor = x->Get<framework::LoDTensor>();
size_t offset = GetOffset(scope, dev_ctx);
size_t offset = GetOffset(scope, place);
auto *out =
scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensorArray>();
if (offset >= out->size()) {
......@@ -39,7 +39,11 @@ class WriteToArrayOp : public ArrayOp {
}
if (x_tensor.memory_size() > 0) {
auto *out_tensor = &out->at(offset);
CopyFrom(x_tensor, dev_ctx.GetPlace(), dev_ctx, out_tensor);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
CopyFrom(x_tensor, place, dev_ctx, out_tensor);
out_tensor->set_lod(x_tensor.lod());
} else {
VLOG(10) << "WARNING: The input tensor 'x_tensor' holds no memory, so "
......@@ -119,17 +123,18 @@ class ReadFromArrayOp : public ArrayOp {
const framework::AttributeMap &attrs)
: ArrayOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &place) const override {
auto *x = scope.FindVar(Input("X"));
PADDLE_ENFORCE(x != nullptr, "X must be set");
auto &x_array = x->Get<framework::LoDTensorArray>();
auto *out = scope.FindVar(Output("Out"));
PADDLE_ENFORCE(out != nullptr, "Out must be set");
auto *out_tensor = out->GetMutable<framework::LoDTensor>();
size_t offset = GetOffset(scope, dev_ctx);
size_t offset = GetOffset(scope, place);
if (offset < x_array.size()) {
framework::CopyFrom(x_array[offset], dev_ctx.GetPlace(), dev_ctx,
out_tensor);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
framework::CopyFrom(x_array[offset], place, dev_ctx, out_tensor);
out_tensor->set_lod(x_array[offset].lod());
} else {
VLOG(10) << "offset " << offset << " >= " << x_array.size();
......
......@@ -70,18 +70,19 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
Transpose Operator.
The input tensor will be permuted according to the axis values given.
The op functions similar to how numpy.transpose works in python.
For example:
>> input = numpy.arange(6).reshape((2,3))
>> input
array([[0, 1, 2],
[3, 4, 5]])
>> axis = [1, 0]
>> output = input.transpose(axis)
>> output
array([[0, 3],
[1, 4],
[2, 5]])
The op functions is similar to how numpy.transpose works in python.
For example: input = numpy.arange(6).reshape((2,3))
the input is:
array([[0, 1, 2],
[3, 4, 5]])
given axis is: [1, 0]
output = input.transpose(axis)
then the output is:
array([[0, 3],
[1, 4],
[2, 5]])
So, given a input tensor of shape(N, C, H, W) and the axis is {0, 2, 3, 1},
the output tensor shape will be (N, H, W, C)
......
......@@ -53,16 +53,14 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
"(string), unpooling type, can be \"max\" for max-unpooling ")
.InEnum({"max"});
AddComment(R"DOC(
"Input shape: $(N, C_{in}, H_{in}, W_{in})$
Output shape: $(N, C_{out}, H_{out}, W_{out})$
Where
$$
H_{out} = (H_{in}−1) * strides[0] − 2 * paddings[0] + ksize[0] \\
W_{out} = (W_{in}−1) * strides[1] − 2 * paddings[1] + ksize[1]
$$
Paper: http://www.matthewzeiler.com/wp-content/uploads/2017
/07/iccv2011.pdf
)DOC");
Input shape is: $(N, C_{in}, H_{in}, W_{in})$, Output shape is:
$(N, C_{out}, H_{out}, W_{out})$, where
$$
H_{out} = (H_{in}−1) * strides[0] − 2 * paddings[0] + ksize[0] \\
W_{out} = (W_{in}−1) * strides[1] − 2 * paddings[1] + ksize[1]
$$
Paper: http://www.matthewzeiler.com/wp-content/uploads/2017/07/iccv2011.pdf
)DOC");
}
};
......
......@@ -40,13 +40,14 @@ class WhileOp : public framework::OperatorBase {
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &dev_place) const override {
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition)));
auto &cond = scope.FindVar(Input(kCondition))->Get<LoDTensor>();
PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1}));
framework::Executor executor(dev_ctx);
framework::Executor executor(dev_place);
auto *block = Attr<framework::BlockDesc *>(kStepBlock);
auto *program = block->Program();
auto step_scopes =
......@@ -97,8 +98,8 @@ class WhileGradOp : public framework::OperatorBase {
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
framework::Executor executor(dev_ctx);
const platform::Place &dev_place) const override {
framework::Executor executor(dev_place);
auto *block = Attr<framework::BlockDesc *>(kStepBlock);
auto *program = block->Program();
......@@ -189,7 +190,7 @@ class WhileGradOp : public framework::OperatorBase {
auto zero_op = framework::OpRegistry::CreateOp(
"fill_constant", framework::VariableNameMap{},
{{"Out", {pg_names[param_id]}}}, attrs);
zero_op->Run(scope, dev_ctx);
zero_op->Run(scope, dev_place);
}
}
......@@ -197,7 +198,7 @@ class WhileGradOp : public framework::OperatorBase {
auto sum_op = framework::OpRegistry::CreateOp(
"sum", {{"X", {pg_names[param_id], new_inside_name}}},
{{"Out", {pg_names[param_id]}}}, framework::AttributeMap{});
sum_op->Run(cur_scope, dev_ctx);
sum_op->Run(cur_scope, dev_place);
cur_scope.Rename(new_inside_name, inside_grad_name);
}
}
......
......@@ -25,7 +25,7 @@ ENDIF()
# avoiding cycle dependencies
cc_library(device_context SRCS device_context.cc DEPS memory buddy_allocator
system_allocator memory_block meta_data meta_cache place eigen3 ${GPU_CTX_DEPS})
nv_test(device_context_test SRCS device_context_test.cc DEPS device_context gpu_info)
nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info)
nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda)
nv_test(transform_test SRCS transform_test.cu DEPS paddle_memory place device_context)
......
......@@ -15,6 +15,59 @@ limitations under the License. */
namespace paddle {
namespace platform {
DeviceContextPool* DeviceContextPool::pool = nullptr;
const platform::DeviceContext* DeviceContextPool::Borrow(
const platform::Place& place) {
auto it = device_contexts_.find(place);
if (it == device_contexts_.end()) {
PADDLE_THROW(
"'Place' is not supported, Please re-compile with WITH_GPU "
"option");
}
return it->second;
}
std::vector<const platform::DeviceContext*> DeviceContextPool::Borrow(
const std::vector<platform::Place>& places) {
PADDLE_ENFORCE_GT(places.size(), 0);
PADDLE_ENFORCE_LE(places.size(), device_contexts_.size());
std::vector<const platform::DeviceContext*> borrowed_contexts;
for (auto& place : places) {
auto it = device_contexts_.find(place);
if (it != device_contexts_.end()) {
borrowed_contexts.emplace_back(it->second);
} else {
PADDLE_THROW(
"'Place' is not supported, Please re-compile with WITH_GPU "
"option");
}
}
return borrowed_contexts;
}
DeviceContextPool::DeviceContextPool(
const std::vector<platform::Place>& places) {
PADDLE_ENFORCE_GT(places.size(), 0);
for (size_t i = 0; i < places.size(); i++) {
if (platform::is_cpu_place(places[i])) {
device_contexts_.emplace(places[i],
new platform::CPUDeviceContext(
boost::get<platform::CPUPlace>(places[i])));
} else if (platform::is_gpu_place(places[i])) {
#ifdef PADDLE_WITH_CUDA
device_contexts_.emplace(places[i],
new platform::CUDADeviceContext(
boost::get<platform::GPUPlace>(places[i])));
#else
PADDLE_THROW(
"'GPUPlace' is not supported, Please re-compile with WITH_GPU "
"option");
#endif
}
}
}
CPUDeviceContext::CPUDeviceContext() {
eigen_device_.reset(new Eigen::DefaultDevice());
}
......
......@@ -11,8 +11,8 @@ limitations under the License. */
#pragma once
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
#include <memory>
#include <unordered_map>
#ifdef PADDLE_WITH_CUDA
#include "paddle/platform/dynload/cublas.h"
......@@ -20,10 +20,13 @@ limitations under the License. */
#include "paddle/platform/gpu_info.h"
#define EIGEN_USE_GPU
#endif
#include <memory>
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"
#include "glog/logging.h"
namespace paddle {
namespace platform {
......@@ -105,5 +108,51 @@ class CUDNNDeviceContext : public CUDADeviceContext {
#endif
/*! \brief device context pool singleton */
class DeviceContextPool {
public:
explicit DeviceContextPool(const std::vector<platform::Place>& places);
static DeviceContextPool& Get() {
PADDLE_ENFORCE_NOT_NULL(pool, "Need to Create DeviceContextPool first!");
return *pool;
}
/*! \brief Create should only called by Init function */
static DeviceContextPool& Create(const std::vector<platform::Place>& places) {
if (pool == nullptr) {
pool = new DeviceContextPool(places);
}
return *pool;
}
/*! \brief Return handle of single device context. */
const platform::DeviceContext* Borrow(const platform::Place& place);
/*! \brief Return handle of multi-device context. */
std::vector<const platform::DeviceContext*> Borrow(
const std::vector<platform::Place>& places);
~DeviceContextPool() {}
private:
static DeviceContextPool* pool;
constexpr static int LEFT_SHIFT = 8;
struct Hash {
std::hash<int> hash_;
size_t operator()(const platform::Place& place) const {
int pre_hash = place.which() + (1 << LEFT_SHIFT);
if (platform::is_gpu_place(place)) {
pre_hash += boost::get<platform::GPUPlace>(place).GetDeviceId();
}
return hash_(pre_hash);
}
};
std::unordered_map<const platform::Place, const platform::DeviceContext*,
Hash>
device_contexts_;
DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
};
} // namespace platform
} // namespace paddle
......@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/platform/device_context.h"
#include "gtest/gtest.h"
#include "paddle/platform/device_context.h"
#include "glog/logging.h"
TEST(Device, Init) {
using paddle::platform::DeviceContext;
......@@ -62,3 +64,54 @@ TEST(Device, CUDNNDeviceContext) {
}
}
}
TEST(Device, DeviceContextPool) {
using paddle::platform::DeviceContextPool;
using paddle::platform::CUDADeviceContext;
using paddle::platform::Place;
using paddle::platform::CPUPlace;
using paddle::platform::GPUPlace;
DeviceContextPool& pool = DeviceContextPool::Get();
auto cpu_dev_ctx1 = pool.Borrow(CPUPlace());
auto cpu_dev_ctx2 = pool.Borrow(CPUPlace());
EXPECT_TRUE(cpu_dev_ctx2 == cpu_dev_ctx1);
std::vector<Place> gpu_places;
int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; ++i) {
gpu_places.emplace_back(GPUPlace(i));
}
auto dev_ctxs = pool.Borrow(gpu_places);
for (size_t i = 0; i < dev_ctxs.size(); ++i) {
auto* dev_ctx = static_cast<const CUDADeviceContext*>(dev_ctxs[i]);
// check same as GPUPlace(i)
GPUPlace place = boost::get<GPUPlace>(dev_ctx->GetPlace());
EXPECT_EQ(place.GetDeviceId(), static_cast<int>(i));
}
}
int main(int argc, char** argv) {
int dev_count = paddle::platform::GetCUDADeviceCount();
if (dev_count <= 1) {
LOG(WARNING) << "Cannot test multi-gpu DeviceContextPool, because the CUDA "
"device count is "
<< dev_count;
return 0;
}
std::vector<paddle::platform::Place> places;
places.emplace_back(paddle::platform::CPUPlace());
int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; ++i) {
places.emplace_back(paddle::platform::GPUPlace(i));
}
VLOG(0) << " DeviceCount " << count;
paddle::platform::DeviceContextPool::Create(places);
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
......@@ -63,6 +63,8 @@ extern void LoadNCCLDSO();
__macro(ncclAllReduce); \
__macro(ncclBcast); \
__macro(ncclAllGather); \
__macro(ncclGroupStart); \
__macro(ncclGroupEnd); \
__macro(ncclReduce); \
__macro(ncclGetErrorString);
......
......@@ -22,6 +22,7 @@ limitations under the License. */
#include <stdexcept>
#include <string>
#include "paddle/platform/macros.h"
#include "paddle/string/printf.h"
#include "paddle/string/to_string.h"
......
......@@ -12,17 +12,19 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_vector.h>
#include <memory>
#include <vector>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/framework/init.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/dynload/nccl.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/gpu_info.h"
#include <thrust/device_vector.h>
#include <memory>
#include <vector>
static int dev_count = 0;
namespace paddle {
......@@ -31,7 +33,8 @@ namespace platform {
TEST(NCCL, init) {
std::vector<ncclComm_t> comms;
comms.resize(dev_count);
dynload::ncclCommInitAll(comms.data(), dev_count, nullptr);
PADDLE_ENFORCE(dynload::ncclCommInitAll(comms.data(), dev_count, nullptr));
for (int i = 0; i < dev_count; ++i) {
dynload::ncclCommDestroy(comms[i]);
}
......@@ -131,6 +134,18 @@ int main(int argc, char** argv) {
<< dev_count;
return 0;
}
std::vector<paddle::platform::Place> places;
places.emplace_back(paddle::platform::CPUPlace());
int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; ++i) {
places.emplace_back(paddle::platform::GPUPlace(i));
}
VLOG(0) << " DeviceCount " << count;
paddle::platform::DeviceContextPool::Create(places);
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
......@@ -60,26 +60,18 @@ struct IsGPUPlace : public boost::static_visitor<bool> {
bool operator()(const CPUPlace &) const { return false; }
bool operator()(const MKLDNNPlace &) const { return false; }
bool operator()(const GPUPlace &gpu) const { return true; }
bool operator()(const CUDNNPlace &) const { return true; }
};
struct IsMKLDNNPlace : public boost::static_visitor<bool> {
bool operator()(const MKLDNNPlace &) const { return true; }
bool operator()(const CPUPlace &) const { return false; }
bool operator()(const GPUPlace &) const { return false; }
bool operator()(const CUDNNPlace &) const { return false; }
};
// Define the max number of Place in bit length. i.e., the max number of places
// should be less equal than 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
#define NUM_PLACE_TYPE_LIMIT_IN_BIT 4
typedef boost::variant<CUDNNPlace, GPUPlace, CPUPlace, MKLDNNPlace> Place;
// static check number of place types is less equal than
// 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
BOOST_MPL_ASSERT((boost::mpl::less_equal<
Place::types::size,
boost::mpl::long_<1 << NUM_PLACE_TYPE_LIMIT_IN_BIT>>));
void set_place(const Place &);
const Place &get_place();
......
......@@ -360,10 +360,10 @@ All parameter, weight, gradient are variables in Paddle.
})
.def("run",
[](OperatorBase &self, const Scope &scope,
const platform::DeviceContext &dev_ctx) {
self.Run(scope, dev_ctx);
dev_ctx.Wait();
})
const platform::CPUPlace &place) { self.Run(scope, place); })
.def("run",
[](OperatorBase &self, const Scope &scope,
const platform::GPUPlace &place) { self.Run(scope, place); })
.def("type",
[](const OperatorBase &op) -> std::string { return op.Type(); })
.def("outputs",
......@@ -417,7 +417,7 @@ All parameter, weight, gradient are variables in Paddle.
});
py::class_<framework::Executor>(m, "Executor")
.def(py::init<std::vector<platform::Place> &>())
.def(py::init<const platform::Place &>())
.def("run", &Executor::Run);
m.def("unique_integer", UniqueIntegerGenerator);
......
......@@ -14,9 +14,9 @@
#pragma once
#include <string>
#include "paddle/framework/executor.h"
#include "paddle/framework/tensor.h"
#include "paddle/memory/memcpy.h"
#include "paddle/platform/device_context.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
......@@ -63,8 +63,7 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
auto *dst_ptr = static_cast<void *>(dst_tensor.mutable_data<CUR_TYPE>(
tensor.dims(), platform::CPUPlace()));
framework::DeviceContextPool &pool =
framework::DeviceContextPool::Get();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto dev_ctx = static_cast<const platform::CUDADeviceContext *>(
pool.Borrow(tensor.place()));
......@@ -138,7 +137,7 @@ void PyCUDATensorSetFromArray(
self.Resize(framework::make_ddim(dims));
auto *dst = self.mutable_data<T>(place);
framework::DeviceContextPool &pool = framework::DeviceContextPool::Get();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto dev_ctx =
static_cast<const platform::CUDADeviceContext *>(pool.Borrow(place));
paddle::platform::GpuMemcpyAsync(dst, array.data(), sizeof(T) * array.size(),
......
......@@ -5,11 +5,3 @@ configure_file(submit_local.sh.in
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/paddle DESTINATION bin
PERMISSIONS OWNER_EXECUTE OWNER_WRITE OWNER_READ
GROUP_EXECUTE GROUP_READ WORLD_EXECUTE WORLD_READ)
configure_file(tools/usage_stat/usage.sh
paddle_usage
@ONLY)
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/paddle_usage DESTINATION opt/paddle/bin
PERMISSIONS OWNER_EXECUTE OWNER_WRITE OWNER_READ
GROUP_EXECUTE GROUP_READ WORLD_EXECUTE WORLD_READ)
......@@ -165,9 +165,6 @@ case "$1" in
"make_diagram")
python -m paddle.utils.make_model_diagram ${@:2}
;;
"usage")
$PADDLE_BIN_PATH/paddle_usage ${@:2}
;;
"version")
version
;;
......
#!/bin/bash
ARGPARSE=`getopt -o u:vin:l:e: --long git-user:,help,dry-run,task-name:,log-file:,exit-code: -- "$@"`
KEEP_ANONYMOUS="A_USER_DOES_NOT_TELL_US"
# paddle config home dir, same as paddle
PADDLE_CONF_HOME="$HOME/.config/paddle"
# api url, mirror url(s) will be append later
PD_URLS="http://api.paddlepaddle.org/version"
usage()
{
echo "Usage: `basename $0` [options]"
echo "Options:"
echo " -e, --exit-code=EXIT_CODE The train/predict process's exit code"
echo " -l, --log-file=LOG_FILE_PATH Read which log file to get the duration of process"
echo " -n, --task-name=TASK_NAME The name of demo or example"
echo " -u, --git-user=GITHUB_USER provide contact info, like username or email"
echo " -v, -i Verbose output and interact with user when necessary"
echo " --help display this help message"
}
eval set -- "${ARGPARSE}"
while true; do
case "$1" in
-l|--log-file)
log_file=$2
shift 2
;;
-e|--exit-code)
exit_code=$2
shift 2
;;
-u|--git-user)
github_user=$2
shift 2
;;
-n|--task-name)
task=$2
shift 2
;;
-v|-i)
v=1
shift
;;
--dry-run)
dry_run=1
shift
;;
--)
shift
break
;;
--help)
usage
exit 0
;;
*)
echo "Invalid option $1"
usage
exit 1
;;
esac
done
# parse the log_file to get the time costs
if [ -s "${log_file}" ]; then
duration=`awk 'BEGIN{day=0;last_sec=0;min_sec=0;max_sec=0;}
{if(index($2,":")==3){
t=substr($2,1,8);
sec=day*86400+substr(t,1,2)*3600+substr(t,4,2)*60+substr(t,7,2);
if(sec<last_sec-600){day+=1;sec+=86400;}
last_sec=sec;
if(min_sec==0 || min_sec>sec){min_sec=sec;}
if(max_sec==0 || max_sec<sec){max_sec=sec;}
}}
END{print max_sec-min_sec}' ${log_file}`
else
duration=-1
fi
if [ "${v}" = "1" ]; then echo "duration: ${duration}"; fi
# try find the user/email if not given
if [ -z "${github_user}" ]; then
# search for cached username
if [ -s "${PADDLE_CONF_HOME}/github_user" ]; then
if [ "${v}" = "1" ]; then echo "read github_user from cache..."; fi
github_user=`cat ${PADDLE_CONF_HOME}/github_user`
else
# search the github-user from git config
if [ "${v}" = "1" ]; then echo "read github_user from git..."; fi
git_username=`git config --get user.name 2>/dev/null`
git_url=`git config --get remote.origin.url 2>/dev/null`
if [ "`echo ${git_url} | cut -b 1-19`" = "https://github.com/" ]; then
# under a git url, like https://github.com/user_xxx/proj_yyy.git
if [ "${v}" = "1" ]; then echo " from github url..."; fi
github_user=`echo ${git_url} | cut -d "/" -f 4`
if [ "${github_user}" = "PaddlePaddle" ]; then
github_user=
fi
fi
if [ -n "${git_username}" -a -z "${github_user}" ]; then
if [ "${v}" = "1" ]; then echo " from global git username..."; fi
github_user=${git_username}
fi
fi
fi
# allow user to set the user name, if it's not found
if [ -z "${github_user}" -a "${v}" = "1" ]; then
read -p "Please input your github username or email, or just return to keep this feedback anonymous:"
github_user=${REPLY}
if [ -z "${github_user}" ]; then
# empty input, consider as one anonymous user
github_user="${KEEP_ANONYMOUS}"
fi
fi
if [ -n "${github_user}" -a -z "${dry_run}" ]; then
# valid user and not in dry-run mode, then save to cache
mkdir -p ${PADDLE_CONF_HOME}
echo "${github_user}" >${PADDLE_CONF_HOME}/github_user
fi
if [ "${v}" = "1" ]; then echo "username: ${github_user}"; fi
if [ "${github_user}" = "${KEEP_ANONYMOUS}" ]; then
# anonymous user should keep the var empty.
github_user=
fi
# read local paddle version
paddle_version=`paddle version | grep PaddlePaddle | head -n1 | cut -d " " -f 2 | cut -d "," -f 1`
if [ "${v}" = "1" ]; then echo "version:${paddle_version}"; fi
# read local system time
system_time=`date "+%Y%m%d%H%M%S"`
if [ "${v}" = "1" ]; then echo "system time:${system_time}"; fi
# make empty job_name as default value.
if [ -z "${task}" ]; then
task="(unknown_task)"
fi
if [ "${v}" = "1" ]; then echo "task: ${task}"; fi
# concat the curl command
params="content={\"data_type\":\"usage\",\
\"system_time\":${system_time},\"paddle_version\":\"${paddle_version}\",\
\"github_user\":\"${github_user}\",\"job_name\":\"${task}\",\
\"duration\":${duration},\"exit_code\":\"${exit_code}\"\
}&type=1"
curl_cmd_prefix="curl -m 5 -X POST -d ${params}\
-b ${PADDLE_CONF_HOME}/paddle.cookie -c ${PADDLE_CONF_HOME}/paddle.cookie "
if [ "${dry_run}" = "1" ]; then
first_url=`echo ${PD_URLS} | cut -d " " -f 1`
echo "(dry-run mode)curl command: ${curl_cmd_prefix} ${first_url}"
exit 0
else
for u in ${PD_URLS}; do
curl_cmd="${curl_cmd_prefix} ${u}"
if [ "${v}" = "1" ]; then echo "run: ${curl_cmd}"; fi
${curl_cmd} >/dev/null 2>&1
if [ $? -eq 0 ]; then
if [ "${v}" = "1" ]; then echo "upload OK!"; fi
exit 0
else
if [ "${v}" = "1" ]; then echo "upload failed...try next"; fi
fi
done
if [ "${v}" = "1" ]; then echo "all urls tried but all failed...exit"; fi
exit 1
fi
......@@ -6,7 +6,6 @@ if(WITH_TESTING)
add_library(paddle_test_util STATIC TestUtil.cpp)
add_dependencies(paddle_test_util paddle_proto ${external_project_dependencies})
if(NOT MOBILE_INFERENCE)
add_library(paddle_gtest_main STATIC paddle_gtest_main.cc)
add_dependencies(paddle_gtest_main paddle_memory gtest gflags)
cc_library(paddle_gtest_main SRCS paddle_gtest_main.cc DEPS init paddle_memory gtest gflags)
endif()
endif()
......@@ -13,8 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <cstring>
#include "gflags/gflags.h"
#include "gtest/gtest.h"
#include "paddle/framework/init.h"
#include "paddle/memory/memory.h"
int main(int argc, char** argv) {
......@@ -32,8 +34,11 @@ int main(int argc, char** argv) {
google::ParseCommandLineFlags(&new_argc, &new_argv_address, false);
testing::InitGoogleTest(&argc, argv);
paddle::memory::Used(paddle::platform::CPUPlace());
std::vector<std::string> devs = {"CPU"};
#ifdef PADDLE_WITH_CUDA
paddle::memory::Used(paddle::platform::GPUPlace(0));
devs.push_back("GPU:0");
#endif
paddle::framework::InitDevices(devs);
return RUN_ALL_TESTS();
}
......@@ -42,5 +42,10 @@ def __read_gflags_from_env__():
core.init_gflags([sys.argv[0]] +
["--tryfromenv=" + ",".join(read_env_flags)])
if core.is_compile_gpu():
core.init_devices(["CPU", "GPU:0"])
else:
core.init_devices(["CPU"])
__read_gflags_from_env__()
......@@ -47,13 +47,14 @@ class Executor(object):
act_places.append(p)
# TODO(dzhwinter) : consider that our fluid tests all written in
# GPUPlace(gpu_id), this will be changed in next PR.
# GPUPlace(gpu_id), this will be changed in the future
if core.is_compile_gpu():
core.init_devices(["CPU", "GPU:0"])
else:
core.init_devices(["CPU"])
self.executor = core.Executor(act_places)
# TODO(dzhwinter) : only use the first place
self.executor = core.Executor(act_places[0])
self.places = places
def aslodtensor(self, data):
......
......@@ -393,7 +393,10 @@ class Operator(object):
% (in_proto.name, len(in_args)))
in_arg_names = []
for arg in in_args:
in_arg_names.append(arg.name)
if isinstance(arg, basestring):
in_arg_names.append(arg)
else:
in_arg_names.append(arg.name)
self.desc.set_input(in_proto.name, in_arg_names)
else:
self.desc.set_input(in_proto.name, [])
......
......@@ -194,3 +194,9 @@ class LayerHelper(object):
else:
# For integer and boolean types, initialize with all zeros
return Constant()
def is_instance(self, param_name, cls):
param = self.kwargs.get(param_name, None)
if not isinstance(param, cls):
raise TypeError("The input {0} parameter of method {1} must be {2}",
param_name, self.layer_type, cls.__name__)
......@@ -3,6 +3,7 @@ from ..framework import Program, Variable, Operator
from .. import core
from tensor import assign, fill_constant
import contextlib
from ..registry import autodoc
__all__ = [
'split_lod_tensor', 'merge_lod_tensor', 'BlockGuard', 'StaticRNNGuard',
......@@ -10,7 +11,7 @@ __all__ = [
'max_sequence_len', 'topk', 'lod_tensor_to_array', 'array_to_lod_tensor',
'increment', 'array_write', 'create_array', 'less_than', 'array_read',
'shrink_memory', 'array_length', 'IfElse', 'DynamicRNN', 'ConditionalBlock',
'StaticRNN'
'StaticRNN', 'reorder_lod_tensor_by_rank'
]
......@@ -1082,3 +1083,18 @@ class DynamicRNN(object):
if self.status != DynamicRNN.IN_RNN:
raise ValueError("{0} can only be invoked inside rnn block.".format(
method))
@autodoc
def reorder_lod_tensor_by_rank(x, rank_table):
helper = LayerHelper('reorder_lod_tensor_by_rank', **locals())
helper.is_instance('x', Variable)
helper.is_instance('rank_table', Variable)
out = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(
type='reorder_lod_tensor_by_rank',
inputs={'X': [x],
'RankTable': [rank_table]},
outputs={'Out': [out]})
return out
......@@ -13,7 +13,8 @@ __all__ = [
'crf_decoding', 'cos_sim', 'cross_entropy', 'square_error_cost', 'accuracy',
'chunk_eval', 'sequence_conv', 'conv2d', 'sequence_pool', 'pool2d',
'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'sequence_expand',
'lstm_unit', 'reduce_sum', 'reduce_mean'
'lstm_unit', 'reduce_sum', 'reduce_mean', 'sequence_first_step',
'sequence_last_step'
]
......@@ -574,9 +575,53 @@ def conv2d(input,
def sequence_pool(input, pool_type, **kwargs):
"""
This function add the operator for sequence pooling.
This is applied on top of the input using pool_type mentioned
in the parameters.
This function add the operator for sequence pooling.
It pools features of all time-steps of each instance, and is applied
on top of the input using pool_type mentioned in the parameters.
It supports four pool_type:
- average: :math:`Out[i] = \\frac{\sum_i X_i}{N}`
- sum: :math:`Out[i] = \sum_jX_{ij}`
- sqrt: :math:`Out[i] = \\frac{\sum_jX_{ij}}{\sqrt{len(X_i)}}`
- max: :math:`Out[i] = max(X_i)`
.. code-block:: text
x is a 1-level LoDTensor:
x.lod = [[0, 2, 5, 7]]
x.data = [1, 3, 2, 4, 6, 5, 1]
x.dims = [7, 1]
then output is a Tensor:
out.dim = [3, 1]
with condition len(x.lod[-1]) - 1 == out.dims[0]
for different pool_type:
average: out.data = [2, 4, 3], where 2=(1+3)/2, 4=(2+4+6)/3, 3=(5+1)/2
sum : out.data = [4, 12, 6], where 4=1+3, 12=2+4+6, 6=5+1
sqrt : out.data = [2.82, 6.93, 4.24], where 2.82=(1+3)/sqrt(2),
6.93=(2+4+6)/sqrt(3), 4.24=(5+1)/sqrt(2)
max : out.data = [3, 6, 5], where 3=max(1,3), 6=max(2,4,6), 5=max(5,1)
Args:
input(variable): The input variable which is a LoDTensor.
pool_type (string): The pooling type of sequence_pool.
It supports average, sum, sqrt and max.
Returns:
The sequence pooling variable which is a Tensor.
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[7, 1],
dtype='float32', lod_level=1)
avg_x = fluid.layers.sequence_pool(input=x, pool_type='average')
sum_x = fluid.layers.sequence_pool(input=x, pool_type='sum')
sqrt_x = fluid.layers.sequence_pool(input=x, pool_type='sqrt')
max_x = fluid.layers.sequence_pool(input=x, pool_type='max')
"""
helper = LayerHelper('sequence_pool', input=input, **kwargs)
dtype = helper.input_dtype()
......@@ -593,6 +638,72 @@ def sequence_pool(input, pool_type, **kwargs):
return pool_out
def sequence_first_step(input, **kwargs):
"""
This funciton get the first step of sequence.
.. code-block:: text
x is a 1-level LoDTensor:
x.lod = [[0, 2, 5, 7]]
x.data = [1, 3, 2, 4, 6, 5, 1]
x.dims = [7, 1]
then output is a Tensor:
out.dim = [3, 1]
with condition len(x.lod[-1]) - 1 == out.dims[0]
out.data = [1, 2, 5], where 1=first(1,3), 2=first(2,4,6), 5=first(5,1)
Args:
input(variable): The input variable which is a LoDTensor.
Returns:
The sequence's first step variable which is a Tensor.
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[7, 1],
dtype='float32', lod_level=1)
x_first_step = fluid.layers.sequence_first_step(input=x)
"""
return sequence_pool(input=input, pool_type="first")
def sequence_last_step(input, **kwargs):
"""
This funciton get the last step of sequence.
.. code-block:: text
x is a 1-level LoDTensor:
x.lod = [[0, 2, 5, 7]]
x.data = [1, 3, 2, 4, 6, 5, 1]
x.dims = [7, 1]
then output is a Tensor:
out.dim = [3, 1]
with condition len(x.lod[-1]) - 1 == out.dims[0]
out.data = [3, 6, 1], where 3=last(1,3), 6=last(2,4,6), 1=last(5,1)
Args:
input(variable): The input variable which is a LoDTensor.
Returns:
The sequence's last step variable which is a Tensor.
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[7, 1],
dtype='float32', lod_level=1)
x_last_step = fluid.layers.sequence_last_step(input=x)
"""
return sequence_pool(input=input, pool_type="last")
def pool2d(input,
pool_size,
pool_type,
......
......@@ -8,7 +8,7 @@ import proto.framework_pb2 as framework_pb2
from framework import OpProtoHolder, Variable, Program, Operator
from paddle.v2.fluid.layer_helper import LayerHelper, unique_name
__all__ = ['deprecated', 'register_layer']
__all__ = ['deprecated', 'register_layer', 'autodoc']
def _convert_(name):
......@@ -175,12 +175,18 @@ def deprecated(func_or_class):
"""
Wrap func with deprecated warning
"""
warnings.simplefilter('always', DeprecationWarning) #turn off filter
warnings.simplefilter('always', DeprecationWarning) # turn off filter
warnings.warn(
"Call to deprecated function {}.".format(func.__name__),
category=DeprecationWarning,
stacklevel=2)
warnings.simplefilter('default', DeprecationWarning) #reset filter
warnings.simplefilter('default', DeprecationWarning) # reset filter
return func(*args, **kwargs)
return func_wrapper
def autodoc(func):
func.__doc__ = _generate_doc_string_(OpProtoHolder.instance().get_op_proto(
func.__name__))
return func
......@@ -33,7 +33,7 @@ def encoder_decoder():
fc1 = fluid.layers.fc(input=src_embedding, size=hidden_dim * 4, act='tanh')
lstm_hidden0, lstm_0 = layers.dynamic_lstm(input=fc1, size=hidden_dim * 4)
encoder_out = layers.sequence_pool(input=lstm_hidden0, pool_type="last")
encoder_out = layers.sequence_last_step(input=lstm_hidden0)
# decoder
trg_language_word = layers.data(
......
......@@ -125,10 +125,11 @@ def model():
# need cos sim
inference = layers.cos_sim(X=usr_combined_features, Y=mov_combined_features)
scale_infer = layers.scale(x=inference, scale=5.0)
label = layers.data(name='score', shape=[1], dtype='float32')
square_cost = layers.square_error_cost(input=inference, label=label)
square_cost = layers.square_error_cost(input=scale_infer, label=label)
avg_cost = layers.mean(x=square_cost)
......
......@@ -90,12 +90,10 @@ def get_numeric_gradient(scope,
def product(dim):
return reduce(lambda a, b: a * b, dim, 1)
ctx = core.DeviceContext.create(core.CPUPlace())
def get_output():
sum = []
for output_name in output_names:
op.run(scope, ctx)
op.run(scope, core.CPUPlace())
sum.append(
np.array(scope.find_var(output_name).get_tensor()).mean())
return np.array(sum).mean()
......
......@@ -113,8 +113,7 @@ class TestSparseAdagradOp(unittest.TestCase):
LearningRate='LearningRate',
epsilon=2.0)
ctx = core.DeviceContext.create(place)
adagrad_op.run(scope, ctx)
adagrad_op.run(scope, place)
# get and compare moment result
moment_result_array = np.array(moment)
......
......@@ -296,8 +296,7 @@ class TestBatchNormOp(OpTest):
momentum=momentum,
epsilon=epsilon)
ctx = core.DeviceContext.create(place)
batch_norm_op.run(scope, ctx)
batch_norm_op.run(scope, place)
# check forward result
self.__assert_close(y_tensor, y_out, "y_out")
......@@ -320,7 +319,7 @@ class TestBatchNormOp(OpTest):
["y_out", "mean", "variance", "saved_mean", "saved_variance"],
place,
feed_dict={"y_out": y_grad})
batch_norm_op_grad.run(scope, ctx)
batch_norm_op_grad.run(scope, place)
x_grad_tensor = create_or_get_tensor(scope,
grad_var_name("x_val"), None,
......
......@@ -57,8 +57,7 @@ class TestBeamSearchDecodeOp(unittest.TestCase):
SentenceIds="sentence_ids",
SentenceScores="sentence_scores")
ctx = core.DeviceContext.create(self.cpu_place)
beam_search_decode_op.run(self.scope, ctx)
beam_search_decode_op.run(self.scope, self.cpu_place)
expected_lod = [[0, 4, 8], [0, 1, 3, 6, 9, 10, 13, 16, 19]]
self.assertEqual(sentence_ids.lod(), expected_lod)
......
......@@ -14,7 +14,6 @@ def create_tensor(scope, name, np_data):
class BeamSearchOpTester(unittest.TestCase):
def setUp(self):
self.scope = core.Scope()
self.ctx = core.DeviceContext.create(core.CPUPlace())
self._create_ids()
self._create_scores()
self._create_pre_ids()
......@@ -32,7 +31,7 @@ class BeamSearchOpTester(unittest.TestCase):
level=0,
beam_size=2,
end_id=0, )
op.run(self.scope, self.ctx)
op.run(self.scope, core.CPUPlace())
selected_ids = self.scope.find_var("selected_ids").get_tensor()
print 'selected_ids', np.array(selected_ids)
print 'lod', selected_ids.lod()
......
......@@ -65,8 +65,7 @@ class TestCondOp(unittest.TestCase):
self.create_global_variables()
self.create_cond_op()
self.create_sub_net()
ctx = core.DeviceContext.create(core.CPUPlace())
self.condop.run(self.scope, ctx)
self.condop.run(self.scope, core.CPUPlace())
return np.array(self.scope.find_var("Out").get_tensor())
def create_global_variables(self):
......
......@@ -63,8 +63,7 @@ class TestDynRNN(unittest.TestCase):
all_timesteps = fluid.layers.array_to_lod_tensor(
x=out, table=rank_table)
last = fluid.layers.sequence_pool(
input=all_timesteps, pool_type='last')
last = fluid.layers.sequence_last_step(input=all_timesteps)
logits = fluid.layers.fc(input=last, size=1, act=None)
loss = fluid.layers.sigmoid_cross_entropy_with_logits(
x=logits, label=label)
......@@ -101,7 +100,7 @@ class TestDynRNN(unittest.TestCase):
rnn.update_memory(mem, out_)
rnn.output(out_)
last = fluid.layers.sequence_pool(input=rnn(), pool_type='last')
last = fluid.layers.sequence_last_step(input=rnn())
logits = fluid.layers.fc(input=last, size=1, act=None)
label = fluid.layers.data(name='label', shape=[1], dtype='float32')
loss = fluid.layers.sigmoid_cross_entropy_with_logits(
......
......@@ -24,7 +24,6 @@ class TestGaussianRandomOp(unittest.TestCase):
def gaussian_random_test(self, place):
context = core.DeviceContext.create(place)
program = fluid.Program()
block = program.global_block()
vout = block.create_var(name="Out")
......
......@@ -33,8 +33,7 @@ class TestIsEmptyOp(unittest.TestCase):
def one_case(self, input, target):
op = Operator(type="is_empty", X=input, Out="out")
ctx = core.DeviceContext.create(core.CPUPlace())
op.run(self.scope, ctx)
op.run(self.scope, core.CPUPlace())
out = self.scope.var("out").get_tensor()
self.assertEqual(np.array(out)[0], target)
......
import unittest
import paddle.v2.fluid as fluid
import numpy
class TestReorderLoDTensor(unittest.TestCase):
def test_reorder(self):
dat = fluid.layers.data(name='input', shape=[1], lod_level=2)
dat.stop_gradient = False
rank_dat = fluid.layers.data(name='ref', shape=[1], lod_level=1)
table = fluid.layers.lod_rank_table(rank_dat)
new_dat = fluid.layers.reorder_lod_tensor_by_rank(
x=dat, rank_table=table)
loss = fluid.layers.mean(x=new_dat)
fluid.backward.append_backward_ops(loss=loss)
cpu = fluid.CPUPlace()
exe = fluid.Executor(cpu)
exe.run(fluid.default_startup_program())
ref = fluid.Tensor()
ref_lod = [0, 3, 4, 7, 8, 14]
ref.set_lod([ref_lod])
ref.set(numpy.random.random(size=[14, 1]).astype('float32'), cpu)
input = fluid.Tensor()
lod_level_0 = numpy.random.randint(low=1, high=5, size=5)
lod_level_0 = [0] + numpy.cumsum(lod_level_0).tolist()
lod_level_1 = numpy.random.randint(low=1, high=5, size=lod_level_0[-1])
lod_level_1 = [0] + numpy.cumsum(lod_level_1).tolist()
input.set_lod([lod_level_0, lod_level_1])
input.set(
numpy.random.random(size=[lod_level_1[-1], 1]).astype('float32'),
cpu)
ig = exe.run(fluid.default_main_program(),
feed={'input': input,
'ref': ref},
fetch_list=['input@GRAD'],
return_numpy=False)[0]
self.assertAlmostEqual(numpy.array(ig).sum(), 1.0, delta=0.001)
self.assertEqual(input.lod(), ig.lod())
if __name__ == '__main__':
unittest.main()
......@@ -55,8 +55,7 @@ class TestSparseSGDOp(unittest.TestCase):
Grad='Grad',
ParamOut='Param',
LearningRate='LearningRate')
ctx = core.DeviceContext.create(place)
sgd_op.run(scope, ctx)
sgd_op.run(scope, place)
# get and compare result
result_array = np.array(param)
......
......@@ -26,7 +26,6 @@ class TestUniformRandomOp(unittest.TestCase):
self.uniform_random_test(place=core.GPUPlace(0))
def uniform_random_test(self, place):
context = core.DeviceContext.create(place)
program = fluid.Program()
block = program.global_block()
vout = block.create_var(name="Out")
......
......@@ -79,8 +79,7 @@ if '${CMAKE_SYSTEM_PROCESSOR}' not in ['arm', 'armv7-a', 'aarch64']:
# the prefix is sys.prefix which should always be usr
paddle_bin_dir = 'opt/paddle/bin'
paddle_bins = ['${PADDLE_BINARY_DIR}/paddle/scripts/paddle_usage',
'${PADDLE_BINARY_DIR}/paddle/trainer/paddle_trainer',
paddle_bins = ['${PADDLE_BINARY_DIR}/paddle/trainer/paddle_trainer',
'${PADDLE_BINARY_DIR}/paddle/trainer/paddle_merge_model',
'${PADDLE_BINARY_DIR}/paddle/pserver/paddle_pserver_main',
'${PADDLE_BINARY_DIR}/paddle/scripts/paddle']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册