提交 3a37e142 编写于 作者: Y Yancey1989

Merge branch 'develop' of github.com:PaddlePaddle/Paddle into fix_pserver_sub_blocks

...@@ -23,7 +23,7 @@ repos: ...@@ -23,7 +23,7 @@ repos:
- id: clang-format-with-version-check - id: clang-format-with-version-check
name: clang-format name: clang-format
description: Format files with ClangFormat. description: Format files with ClangFormat.
entry: bash ./.clang_format.hook -i entry: bash ./tools/codestyle/clang_format.hook -i
language: system language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto)$ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto)$
- repo: local - repo: local
...@@ -52,7 +52,7 @@ repos: ...@@ -52,7 +52,7 @@ repos:
hooks: hooks:
- id: copyright_checker - id: copyright_checker
name: copyright_checker name: copyright_checker
entry: python ./.copyright.hook entry: python ./tools/codestyle/copyright.hook
language: system language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
exclude: (?!.*third_party)^.*$ | (?!.*book)^.*$ exclude: (?!.*third_party)^.*$ | (?!.*book)^.*$
# Inference High-level APIs
This document describes the high-level inference APIs one can use to easily deploy a Paddle model for an application.
The APIs are described in `paddle_inference_api.h`, just one header file, and two libaries `libpaddle_fluid.so` and `libpaddle_fluid_api.so` are needed.
## PaddleTensor
We provide the `PaddleTensor` data structure is to give a general tensor interface.
The definition is
```c++
struct PaddleTensor {
std::string name; // variable name.
std::vector<int> shape;
PaddleBuf data; // blob of data.
PaddleDType dtype;
};
```
The data is stored in a continuous memory `PaddleBuf`, and tensor's data type is specified by a `PaddleDType`.
The `name` field is used to specify the name of input variable,
that is important when there are multiple inputs and need to distiuish which variable to set.
## engine
The inference APIs has two different underlying implementation, currently there are two valid engines:
- the native engine, which is consists of the native operators and framework,
- the Anakin engine, which is a Anakin library embeded.
The native engine takes a native Paddle model as input, and supports any model that trained by Paddle,
but the Anakin engine can only take the Anakin model as input(user need to manully transform the format first) and currently not all Paddle models are supported.
```c++
enum class PaddleEngineKind {
kNative = 0, // Use the native Fluid facility.
kAnakin, // Use Anakin for inference.
};
```
## PaddlePredictor and how to create one
The main interface is `PaddlePredictor`, there are following methods
- `bool Run(const std::vector<PaddleTensor>& inputs, std::vector<PaddleTensor>* output_data)`
- take inputs and output `output_data`
- `Clone` to clone a predictor from an existing one, with model parameter shared.
There is a factory method to help create a predictor, and the user takes the ownership of this object.
```c++
template <typename ConfigT, PaddleEngineKind engine = PaddleEngineKind::kNative>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor(const ConfigT& config);
```
By specifying the engine kind and config, one can get an specific implementation.
## Reference
- [paddle_inference_api.h](./paddle_inference_api.h)
- [demos](./demo)
...@@ -110,7 +110,6 @@ class PaddlePredictor { ...@@ -110,7 +110,6 @@ class PaddlePredictor {
// The common configs for all the predictors. // The common configs for all the predictors.
struct Config { struct Config {
std::string model_dir; // path to the model directory. std::string model_dir; // path to the model directory.
bool enable_engine{false}; // Enable to execute (part of) the model on
}; };
}; };
......
...@@ -57,6 +57,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( ...@@ -57,6 +57,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
for (auto &p : params) { for (auto &p : params) {
grad_names_.insert(GradVarName(p)); grad_names_.insert(GradVarName(p));
} }
balance_vars_.resize(places_.size(), 0);
} }
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
...@@ -140,11 +141,30 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp( ...@@ -140,11 +141,30 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
checker(op.InputArgumentNames(), recv_vars); checker(op.InputArgumentNames(), recv_vars);
} }
size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
const std::vector<std::string> &var_names) const {
int64_t numel_sum = 0;
for (auto var_name : var_names) {
auto var_desc = all_vars_.at(var_name);
PADDLE_ENFORCE_NOT_NULL(var_desc);
auto dim = framework::make_ddim(var_desc->GetShape());
int64_t numel = framework::product(dim);
PADDLE_ENFORCE_GT(numel, 0);
numel_sum += numel;
}
auto smallest =
std::min_element(std::begin(balance_vars_), std::end(balance_vars_));
size_t dev_id =
static_cast<size_t>(std::distance(std::begin(balance_vars_), smallest));
balance_vars_[dev_id] += numel_sum;
return dev_id;
}
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
const ProgramDesc &program) const { const ProgramDesc &program) const {
std::unordered_map<std::string, VarDesc *> all_vars;
for (auto *var : program.Block(0).AllVars()) { for (auto *var : program.Block(0).AllVars()) {
all_vars[var->Name()] = var; all_vars_.emplace(var->Name(), var);
} }
auto graph = new SSAGraph(); auto graph = new SSAGraph();
...@@ -161,35 +181,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -161,35 +181,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
auto send_vars = FindDistTrainSendVars(program); auto send_vars = FindDistTrainSendVars(program);
auto recv_vars = FindDistTrainRecvVars(program); auto recv_vars = FindDistTrainRecvVars(program);
std::vector<std::unordered_set<std::string>> var_name_on_devices;
std::vector<std::unordered_set<std::string>> bcast_var_name_set; std::vector<std::unordered_set<std::string>> bcast_var_name_set;
var_name_on_devices.resize(places_.size());
bcast_var_name_set.resize(places_.size()); bcast_var_name_set.resize(places_.size());
size_t cur_device_id = 0; size_t cur_device_id = 0;
std::vector<int64_t> balance_grads(places_.size(), 0);
auto get_appropriate_dev = [&](std::string &g_name) -> size_t {
auto var_desc = all_vars.at(g_name);
PADDLE_ENFORCE_NOT_NULL(var_desc);
auto dim = framework::make_ddim(var_desc->GetShape());
int64_t numel = framework::product(dim);
PADDLE_ENFORCE_GE(numel, 0);
auto smallest =
std::min_element(std::begin(balance_grads), std::end(balance_grads));
size_t dev_id =
static_cast<size_t>(std::distance(std::begin(balance_grads), smallest));
balance_grads[dev_id] += numel;
return dev_id;
};
bool is_forwarding = true; bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) { for (auto *op : program.Block(0).AllOps()) {
if (boost::get<int>( if (boost::get<int>(
op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kRPC)) { static_cast<int>(OpRole::kRPC)) {
// append rpc op if program is distributed trainer main program.
// always use the first device
CreateRPCOp(&result, *op); CreateRPCOp(&result, *op);
} else if (IsDistTrainOp(*op, send_vars, recv_vars)) { } else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
CreateDistTrainOp(&result, *op); CreateDistTrainOp(&result, *op);
...@@ -199,15 +200,19 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -199,15 +200,19 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
BuildStrategy::GradientScaleStrategy::kCustomized) { BuildStrategy::GradientScaleStrategy::kCustomized) {
CreateScaleLossGradOp(&result); CreateScaleLossGradOp(&result);
} }
// This assumes the backward generating code will ensure IsScaleLossOp
// is true only for the op that scale the final scalar loss.
// It also assumes backward op will always follow the forward op in
// the block.
is_forwarding = false; is_forwarding = false;
} else { } else {
int op_dev_id = GetOpDeviceID(var_name_on_devices, *op); int op_dev_id = GetOpDeviceID(*op);
if (op_dev_id == -1) { // var on all device if (op_dev_id == -1) { // var on all device
CreateComputationalOps(&result, *op, places_.size()); CreateComputationalOps(&result, *op, places_.size());
} else { } else {
CreateComputationalOp(&result, *op, op_dev_id); CreateComputationalOp(&result, *op, op_dev_id);
for (auto &var_name : op->OutputArgumentNames()) { for (auto &var_name : op->OutputArgumentNames()) {
var_name_on_devices[op_dev_id].emplace(var_name); var_name_on_devices_.emplace(var_name, op_dev_id);
} }
} }
if (!is_forwarding && places_.size() > 1) { if (!is_forwarding && places_.size() > 1) {
...@@ -230,19 +235,22 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -230,19 +235,22 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
switch (strategy_.reduce_) { switch (strategy_.reduce_) {
case BuildStrategy::ReduceStrategy::kReduce: case BuildStrategy::ReduceStrategy::kReduce:
cur_device_id = get_appropriate_dev(g_name); cur_device_id = GetAppropriateDeviceID({g_name});
CreateReduceOp(&result, g_name, cur_device_id); CreateReduceOp(&result, g_name, cur_device_id);
var_name_on_devices[cur_device_id].emplace(g_name); var_name_on_devices_.emplace(g_name, cur_device_id);
bcast_var_name_set[cur_device_id].emplace(p_name); bcast_var_name_set[cur_device_id].emplace(p_name);
break; break;
case BuildStrategy::ReduceStrategy::kAllReduce: case BuildStrategy::ReduceStrategy::kAllReduce:
if (IsSparseGradient(all_vars, g_name)) { if (IsSparseGradient(g_name)) {
CreateReduceOp(&result, g_name, 0); CreateReduceOp(&result, g_name, 0);
CreateBroadcastOp(&result, g_name, 0); CreateBroadcastOp(&result, g_name, 0);
} else { } else {
InsertAllReduceOp(&result, g_name); InsertAllReduceOp(&result, g_name);
} }
break; break;
default:
LOG(FATAL) << "Unknown reduce strategy ";
break;
} }
} }
} catch (boost::bad_get e) { } catch (boost::bad_get e) {
...@@ -261,7 +269,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -261,7 +269,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
} }
/* /*
Dependency graph has been constructed. However, there are still data Dependency graph has been constructed. However, there are still data
harzaeds need to be handled. hazards need to be handled.
*/ */
PolishGraphToSupportDataHazards(&result); PolishGraphToSupportDataHazards(&result);
...@@ -273,11 +281,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -273,11 +281,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
return std::unique_ptr<SSAGraph>(graph); return std::unique_ptr<SSAGraph>(graph);
} }
bool MultiDevSSAGraphBuilder::IsSparseGradient( bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
const std::unordered_map<std::string, VarDesc *> &all_vars, PADDLE_ENFORCE(all_vars_.count(og) != 0);
const std::string &og) const { if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
PADDLE_ENFORCE(all_vars.count(og) != 0);
if (all_vars.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
return true; return true;
} }
return false; return false;
...@@ -363,24 +369,23 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce( ...@@ -363,24 +369,23 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
return is_pg_once; return is_pg_once;
} }
int MultiDevSSAGraphBuilder::GetOpDeviceID( int MultiDevSSAGraphBuilder::GetOpDeviceID(const OpDesc &op) const {
const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
const OpDesc &op) const {
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) { if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
return -1; return -1;
} }
int var_dev_id = -1; for (auto &varname : op.InputArgumentNames()) {
for (auto &var_name : op.InputArgumentNames()) { int dev_id = GetVarDeviceID(varname);
if (var_dev_id != -1) break; if (dev_id != -1) {
for (size_t i = 0; i < var_name_on_devices.size(); ++i) { return dev_id;
if (var_name_on_devices[i].count(var_name)) {
var_dev_id = static_cast<int>(i);
break;
}
} }
} }
return var_dev_id; return -1;
}
int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
auto got = var_name_on_devices_.find(varname);
return got == var_name_on_devices_.end() ? -1 : got->second;
} }
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const { void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
...@@ -449,6 +454,8 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result, ...@@ -449,6 +454,8 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
return var; return var;
} }
// Find the first occurence of `prev_op_name` and make current `op` depend
// on it.
void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op, void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
const std::string &prev_op_name) const { const std::string &prev_op_name) const {
for (auto &prev_op : result->ops_) { for (auto &prev_op : result->ops_) {
...@@ -463,16 +470,66 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op, ...@@ -463,16 +470,66 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
const OpDesc &op) const { const OpDesc &op) const {
CreateComputationalOp(result, op, 0); int op_dev_id = -1;
if (op.Type() == "split_byref") {
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
for (auto &varname : op.InputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
}
for (auto &varname : op.OutputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
} else if (op.Type() == "concat") {
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
} else {
PADDLE_ENFORCE(
"the distribute training related op should be in [split_byref, "
"concat].");
}
PADDLE_ENFORCE(op_dev_id != -1,
"can not find right place for distributed op: %s", op.Type());
CreateComputationalOp(result, op, op_dev_id);
if (op.Type() == "concat") { if (op.Type() == "concat") {
ConnectOp(result, result->ops_.back().get(), "fetch_barrier"); ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
} }
} }
// Create RPC related op handles that connects its in ops and out ops.
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
const OpDesc &op) const { const OpDesc &op) const {
result->ops_.emplace_back( int op_dev_id = -1;
new RPCOpHandle(op, local_scopes_[0], op.Type(), places_[0])); if (op.Type() == "send") {
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
// the variable name which contains .block means it was splited by
// split_byref op
// so that we can balance the variable blocks to all the pserver instances.
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
op.InputArgumentNames()[0].find(".block") == std::string::npos) {
op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
for (auto &varname : op.InputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
}
} else if (op.Type() == "recv") {
op_dev_id = GetAppropriateDeviceID(op.OutputArgumentNames());
for (auto &varname : op.OutputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
} else {
// send_barrier and fetch_barrier op can be scheduled on device 0
op_dev_id = 0;
}
PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
op.Type());
result->ops_.emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id],
op.Type(), places_[op_dev_id]));
if (op.Type() == "send_barrier") { if (op.Type() == "send_barrier") {
ConnectOp(result, result->ops_.back().get(), "send"); ConnectOp(result, result->ops_.back().get(), "send");
...@@ -488,9 +545,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, ...@@ -488,9 +545,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
"send, send_barrier. recv, fetch_barrier]"); "send, send_barrier. recv, fetch_barrier]");
} }
// TODO(Yancey1989): schedule rpc op on different place may CreateOpHandleIOs(result, op, op_dev_id);
// increate throughput
CreateOpHandleIOs(result, op, 0);
} }
bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const { bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
......
...@@ -47,10 +47,11 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -47,10 +47,11 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
#endif #endif
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override; std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
int GetVarDeviceID(const std::string &varname) const;
private: private:
void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op, void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
size_t place_id) const; size_t device_id) const;
private: private:
std::string loss_var_name_; std::string loss_var_name_;
...@@ -96,21 +97,23 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -96,21 +97,23 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::string &og, const std::string &og,
std::unordered_set<std::string> *og_has_been_broadcast) const; std::unordered_set<std::string> *og_has_been_broadcast) const;
int GetOpDeviceID( int GetOpDeviceID(const OpDesc &op) const;
const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
const OpDesc &op) const;
void InsertAllReduceOp(SSAGraph *result, const std::string &og) const; void InsertAllReduceOp(SSAGraph *result, const std::string &og) const;
void CreateBroadcastOp(SSAGraph *result, const std::string &p_name, void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
size_t src_dev_id) const; size_t src_dev_id) const;
bool IsSparseGradient( bool IsSparseGradient(const std::string &og) const;
const std::unordered_map<std::string, VarDesc *> &all_vars,
const std::string &og) const; size_t GetAppropriateDeviceID(
const std::vector<std::string> &var_names) const;
private: private:
BuildStrategy strategy_; BuildStrategy strategy_;
mutable std::unordered_map<std::string, VarDesc *> all_vars_;
mutable std::unordered_map<std::string, int> var_name_on_devices_;
mutable std::vector<int64_t> balance_vars_;
void SetCommunicationContext(OpHandleBase *op_handle, void SetCommunicationContext(OpHandleBase *op_handle,
const platform::Place &p) const; const platform::Place &p) const;
......
...@@ -30,6 +30,7 @@ class SSAGraphBuilder { ...@@ -30,6 +30,7 @@ class SSAGraphBuilder {
SSAGraphBuilder() {} SSAGraphBuilder() {}
virtual ~SSAGraphBuilder() {} virtual ~SSAGraphBuilder() {}
virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0; virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
virtual int GetVarDeviceID(const std::string &var_name) const { return -1; }
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
......
...@@ -96,6 +96,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -96,6 +96,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto cur_ready_vars = ready_vars.PopAll(1, &timeout); auto cur_ready_vars = ready_vars.PopAll(1, &timeout);
if (timeout) { if (timeout) {
std::lock_guard<std::mutex> l(exception_mu_);
if (exception_) { if (exception_) {
auto exp = *exception_; auto exp = *exception_;
exception_.reset(); exception_.reset();
...@@ -199,6 +200,7 @@ void ThreadedSSAGraphExecutor::RunOp( ...@@ -199,6 +200,7 @@ void ThreadedSSAGraphExecutor::RunOp(
ready_var_q->Extend(op->Outputs()); ready_var_q->Extend(op->Outputs());
VLOG(10) << op << " " << op->Name() << "Signal posted"; VLOG(10) << op << " " << op->Name() << "Signal posted";
} catch (platform::EnforceNotMet ex) { } catch (platform::EnforceNotMet ex) {
std::lock_guard<std::mutex> l(exception_mu_);
exception_.reset(new platform::EnforceNotMet(ex)); exception_.reset(new platform::EnforceNotMet(ex));
} catch (...) { } catch (...) {
LOG(FATAL) << "Unknown exception catched"; LOG(FATAL) << "Unknown exception catched";
......
...@@ -56,6 +56,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -56,6 +56,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
platform::DeviceContextPool fetch_ctxs_; platform::DeviceContextPool fetch_ctxs_;
std::mutex exception_mu_;
std::unique_ptr<platform::EnforceNotMet> exception_; std::unique_ptr<platform::EnforceNotMet> exception_;
std::atomic<int> running_ops_; std::atomic<int> running_ops_;
......
...@@ -21,7 +21,7 @@ limitations under the License. */ ...@@ -21,7 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/distributed/grpc_client.h"
#endif #endif
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -49,8 +49,8 @@ Executor::Executor(const platform::Place& place) : place_(place) {} ...@@ -49,8 +49,8 @@ Executor::Executor(const platform::Place& place) : place_(place) {}
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
void Executor::Complete() { void Executor::Complete() {
::paddle::operators::detail::RPCClient::GetInstance< ::paddle::operators::distributed::RPCClient::GetInstance<
::paddle::operators::detail::GRPCClient>() ::paddle::operators::distributed::GRPCClient>()
->SendComplete(); ->SendComplete();
} }
#endif #endif
......
...@@ -110,7 +110,6 @@ ParallelExecutor::ParallelExecutor( ...@@ -110,7 +110,6 @@ ParallelExecutor::ParallelExecutor(
// Step 3. Convert main_program to SSA form and dependency graph. Also, insert // Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp // ncclOp
details::SSAGraphBuilderFactory builder_factory( details::SSAGraphBuilderFactory builder_factory(
member_->places_, loss_var_name, params, member_->local_scopes_, member_->places_, loss_var_name, params, member_->local_scopes_,
build_strategy); build_strategy);
...@@ -122,9 +121,10 @@ ParallelExecutor::ParallelExecutor( ...@@ -122,9 +121,10 @@ ParallelExecutor::ParallelExecutor(
#endif #endif
} }
builder_ = std::move(builder_factory.Create());
member_->executor_.reset(new details::ThreadedSSAGraphExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, exec_strategy, member_->local_scopes_, places,
builder_factory.Create()->Build(main_program))); builder_->Build(main_program)));
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, std::move(var_infos), exec_strategy, member_->local_scopes_, std::move(var_infos),
...@@ -133,10 +133,22 @@ ParallelExecutor::ParallelExecutor( ...@@ -133,10 +133,22 @@ ParallelExecutor::ParallelExecutor(
void ParallelExecutor::BCastParamsToGPUs( void ParallelExecutor::BCastParamsToGPUs(
const std::unordered_set<std::string> &vars) const { const std::unordered_set<std::string> &vars) const {
auto *main_scope = member_->local_scopes_[0]; // the the initialize bcast, all vars would be bcast from device(0), otherwise
// bcast from the specified device.
bool initialize = builder_.get() == nullptr ? true : false;
for (auto &var : vars) { for (auto &var : vars) {
auto *main_var = main_scope->FindVar(var); int var_dev_id =
builder_.get() == nullptr ? -1 : builder_->GetVarDeviceID(var);
if (!initialize && var_dev_id == -1) continue;
framework::Variable *main_var = nullptr;
if (initialize) {
main_var = member_->local_scopes_[0]->FindVar(var);
} else {
main_var = member_->local_scopes_[var_dev_id]->FindVar(var);
}
if (main_var == nullptr || !main_var->IsType<LoDTensor>()) { if (main_var == nullptr || !main_var->IsType<LoDTensor>()) {
continue; continue;
} }
...@@ -151,7 +163,8 @@ void ParallelExecutor::BCastParamsToGPUs( ...@@ -151,7 +163,8 @@ void ParallelExecutor::BCastParamsToGPUs(
for (size_t i = 0; i < member_->places_.size(); ++i) { for (size_t i = 0; i < member_->places_.size(); ++i) {
auto place = member_->places_[i]; auto place = member_->places_[i];
void *buffer; void *buffer;
if (i == 0) {
if ((initialize && i == 0) || (!initialize && i == var_dev_id)) {
buffer = const_cast<void *>(main_tensor.data<void>()); buffer = const_cast<void *>(main_tensor.data<void>());
} else { } else {
auto local_scope = member_->local_scopes_[i]; auto local_scope = member_->local_scopes_[i];
......
...@@ -19,12 +19,14 @@ limitations under the License. */ ...@@ -19,12 +19,14 @@ limitations under the License. */
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/execution_strategy.h" #include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -68,6 +70,7 @@ class ParallelExecutor { ...@@ -68,6 +70,7 @@ class ParallelExecutor {
private: private:
ParallelExecutorPrivate *member_; ParallelExecutorPrivate *member_;
std::unique_ptr<details::SSAGraphBuilder> builder_;
}; };
} // namespace framework } // namespace framework
......
...@@ -184,8 +184,8 @@ else() ...@@ -184,8 +184,8 @@ else()
set(DEPS_OPS ${DEPS_OPS} nccl_op) set(DEPS_OPS ${DEPS_OPS} nccl_op)
endif() endif()
add_subdirectory(detail)
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
add_subdirectory(distributed)
set(DISTRIBUTE_DEPS "") set(DISTRIBUTE_DEPS "")
if(WITH_GRPC) if(WITH_GRPC)
...@@ -195,18 +195,11 @@ if(WITH_DISTRIBUTE) ...@@ -195,18 +195,11 @@ if(WITH_DISTRIBUTE)
endif() endif()
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
op_library(prefetch_op DEPS ${DISTRIBUTE_DEPS}) foreach(dist_op "prefetch_op" "listen_and_serv_op" "send_op" "recv_op" "send_barrier_op" "fetch_barrier_op")
set_source_files_properties(prefetch_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) op_library(${dist_op} DEPS ${DISTRIBUTE_DEPS})
op_library(recv_op DEPS ${DISTRIBUTE_DEPS}) set_source_files_properties(${dist_op}.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(recv_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) endforeach()
op_library(listen_and_serv_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(listen_and_serv_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
op_library(send_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(send_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
op_library(send_barrier_op DEPS ${DISTRIBUTE_DEPS})
op_library(fetch_barrier_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(fetch_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
#set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) #set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
#cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op #cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op
# listen_and_serv_op sum_op executor SERIAL) # listen_and_serv_op sum_op executor SERIAL)
......
...@@ -15,13 +15,13 @@ ...@@ -15,13 +15,13 @@
#pragma once #pragma once
#ifdef PADDLE_WITH_GRPC #ifdef PADDLE_WITH_GRPC
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/distributed/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h" #include "paddle/fluid/operators/distributed/grpc_server.h"
#define RPCSERVER_T detail::AsyncGRPCServer #define RPCSERVER_T distributed::AsyncGRPCServer
#define RPCCLIENT_T detail::GRPCClient #define RPCCLIENT_T distributed::GRPCClient
#else #else
#include "paddle/fluid/operators/detail/brpc_client.h" #include "paddle/fluid/operators/distributed/brpc_client.h"
#include "paddle/fluid/operators/detail/brpc_server.h" #include "paddle/fluid/operators/distributed/brpc_server.h"
#define RPCSERVER_T detail::AsyncBRPCServer #define RPCSERVER_T distributed::AsyncBRPCServer
#define RPCCLIENT_T detail::BRPCClient #define RPCCLIENT_T distributed::BRPCClient
#endif #endif
if(NOT WITH_DISTRIBUTE)
return()
endif()
if(WITH_GRPC) if(WITH_GRPC)
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
......
...@@ -12,12 +12,12 @@ ...@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/detail/brpc_client.h" #include "paddle/fluid/operators/distributed/brpc_client.h"
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
DEFINE_int32(brpc_channel_num, 24, DEFINE_int32(brpc_channel_num, 24,
"Number of channels to send requests connected to one server"); "Number of channels to send requests connected to one server");
...@@ -175,6 +175,6 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) { ...@@ -175,6 +175,6 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
return q; return q;
} }
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -31,13 +31,13 @@ limitations under the License. */ ...@@ -31,13 +31,13 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/rpc_client.h" #include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN #include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
struct ChannelContext { struct ChannelContext {
brpc::Channel channel; brpc::Channel channel;
...@@ -95,6 +95,6 @@ class BRPCClient : public RPCClient { ...@@ -95,6 +95,6 @@ class BRPCClient : public RPCClient {
DISABLE_COPY_AND_ASSIGN(BRPCClient); DISABLE_COPY_AND_ASSIGN(BRPCClient);
}; };
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -12,13 +12,13 @@ ...@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/detail/brpc_server.h" #include "paddle/fluid/operators/distributed/brpc_server.h"
#include "paddle/fluid/operators/detail/request_handler.h" #include "paddle/fluid/operators/distributed/request_handler.h"
namespace sendrecv { namespace sendrecv {
typedef std::unordered_map<std::string, typedef std::unordered_map<std::string,
paddle::operators::detail::RequestHandler*> paddle::operators::distributed::RequestHandler*>
HandlerMap; HandlerMap;
class BRPCServiceImpl : public SendRecvService { class BRPCServiceImpl : public SendRecvService {
...@@ -27,17 +27,17 @@ class BRPCServiceImpl : public SendRecvService { ...@@ -27,17 +27,17 @@ class BRPCServiceImpl : public SendRecvService {
: request_send_h_(nullptr), : request_send_h_(nullptr),
request_get_h_(nullptr), request_get_h_(nullptr),
request_prefetch_h_(nullptr) { request_prefetch_h_(nullptr) {
auto it = rpc_call_map.find(paddle::operators::detail::kRequestSend); auto it = rpc_call_map.find(paddle::operators::distributed::kRequestSend);
if (it != rpc_call_map.end()) { if (it != rpc_call_map.end()) {
request_send_h_ = it->second; request_send_h_ = it->second;
} }
it = rpc_call_map.find(paddle::operators::detail::kRequestSend); it = rpc_call_map.find(paddle::operators::distributed::kRequestSend);
if (it != rpc_call_map.end()) { if (it != rpc_call_map.end()) {
request_get_h_ = it->second; request_get_h_ = it->second;
} }
it = rpc_call_map.find(paddle::operators::detail::kRequestPrefetch); it = rpc_call_map.find(paddle::operators::distributed::kRequestPrefetch);
if (it != rpc_call_map.end()) { if (it != rpc_call_map.end()) {
request_prefetch_h_ = it->second; request_prefetch_h_ = it->second;
} }
...@@ -88,15 +88,15 @@ class BRPCServiceImpl : public SendRecvService { ...@@ -88,15 +88,15 @@ class BRPCServiceImpl : public SendRecvService {
} }
private: private:
paddle::operators::detail::RequestHandler* request_send_h_; paddle::operators::distributed::RequestHandler* request_send_h_;
paddle::operators::detail::RequestHandler* request_get_h_; paddle::operators::distributed::RequestHandler* request_get_h_;
paddle::operators::detail::RequestHandler* request_prefetch_h_; paddle::operators::distributed::RequestHandler* request_prefetch_h_;
}; };
} // namespace sendrecv } // namespace sendrecv
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
void AsyncBRPCServer::StartServer() { void AsyncBRPCServer::StartServer() {
// Instance of your service. // Instance of your service.
...@@ -139,6 +139,6 @@ void AsyncBRPCServer::WaitServerReady() { ...@@ -139,6 +139,6 @@ void AsyncBRPCServer::WaitServerReady() {
VLOG(3) << "AsyncGRPCServer WaitSeverReady"; VLOG(3) << "AsyncGRPCServer WaitSeverReady";
} }
}; // namespace detail }; // namespace distributed
}; // namespace operators }; // namespace operators
}; // namespace paddle }; // namespace paddle
...@@ -19,12 +19,12 @@ limitations under the License. */ ...@@ -19,12 +19,12 @@ limitations under the License. */
#include <string> #include <string>
#include "brpc/server.h" #include "brpc/server.h"
#include "paddle/fluid/operators/detail/rpc_server.h" #include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/distributed/send_recv.pb.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
class AsyncBRPCServer final : public RPCServer { class AsyncBRPCServer final : public RPCServer {
public: public:
...@@ -48,6 +48,6 @@ class AsyncBRPCServer final : public RPCServer { ...@@ -48,6 +48,6 @@ class AsyncBRPCServer final : public RPCServer {
int ready_; int ready_;
}; };
}; // namespace detail }; // namespace distributed
}; // namespace operators }; // namespace operators
}; // namespace paddle }; // namespace paddle
...@@ -17,11 +17,11 @@ limitations under the License. */ ...@@ -17,11 +17,11 @@ limitations under the License. */
// file and did some modifications so that we can send gRPC // file and did some modifications so that we can send gRPC
// requests without too much copying of the tensor data. // requests without too much copying of the tensor data.
#include "paddle/fluid/operators/detail/bytebuffer_stream.h" #include "paddle/fluid/operators/distributed/bytebuffer_stream.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
GrpcByteBufferSource::GrpcByteBufferSource() {} GrpcByteBufferSource::GrpcByteBufferSource() {}
...@@ -83,6 +83,6 @@ google::protobuf::int64 GrpcByteBufferSource::ByteCount() const { ...@@ -83,6 +83,6 @@ google::protobuf::int64 GrpcByteBufferSource::ByteCount() const {
return byte_count_; return byte_count_;
} }
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -106,7 +106,7 @@ class GrpcBufferReader final ...@@ -106,7 +106,7 @@ class GrpcBufferReader final
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
// Source provides a way for a particular RPC implementation to provide // Source provides a way for a particular RPC implementation to provide
// received data to ParseFrom. // received data to ParseFrom.
class Source { class Source {
...@@ -183,6 +183,6 @@ class GrpcByteSource : public Source { ...@@ -183,6 +183,6 @@ class GrpcByteSource : public Source {
char space_[sizeof(Reader)]; char space_[sizeof(Reader)];
}; };
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -12,19 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,19 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/distributed/grpc_client.h"
#include <sys/time.h> #include <sys/time.h>
#include <limits> #include <limits>
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/request_handler.h" #include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
void GRPCClient::InitImpl() { InitEventLoop(); } void GRPCClient::InitImpl() { InitEventLoop(); }
...@@ -276,6 +276,6 @@ std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) { ...@@ -276,6 +276,6 @@ std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
return ch; return ch;
} }
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -38,13 +38,13 @@ limitations under the License. */ ...@@ -38,13 +38,13 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/rpc_client.h" #include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN #include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
struct VarHandle { struct VarHandle {
std::string ep; std::string ep;
...@@ -226,6 +226,6 @@ class GRPCClient : public RPCClient { ...@@ -226,6 +226,6 @@ class GRPCClient : public RPCClient {
DISABLE_COPY_AND_ASSIGN(GRPCClient); DISABLE_COPY_AND_ASSIGN(GRPCClient);
}; };
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -21,8 +21,8 @@ limitations under the License. */ ...@@ -21,8 +21,8 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/variable_response.h" #include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
...@@ -50,7 +50,7 @@ void RunSerdeTestSelectedRows(platform::Place place) { ...@@ -50,7 +50,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
for (int i = 0; i < 564; ++i) rows->push_back(i); for (int i = 0; i < 564; ++i) rows->push_back(i);
::grpc::ByteBuffer msg; ::grpc::ByteBuffer msg;
operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg); operators::distributed::SerializeToByteBuffer("myvar", &var, ctx, &msg);
EXPECT_GT(msg.Length(), static_cast<size_t>(0)); EXPECT_GT(msg.Length(), static_cast<size_t>(0));
// deserialize // deserialize
...@@ -81,10 +81,10 @@ void RunSerdeTestSelectedRows(platform::Place place) { ...@@ -81,10 +81,10 @@ void RunSerdeTestSelectedRows(platform::Place place) {
// deserialize zero-copy // deserialize zero-copy
// framework::Variable var2; // framework::Variable var2;
// operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2); // operators::distributed::DeserializeFromByteBuffer(msg, ctx, &var2);
framework::Scope scope; framework::Scope scope;
scope.Var("myvar"); scope.Var("myvar");
operators::detail::VariableResponse resp(&scope, &ctx); operators::distributed::VariableResponse resp(&scope, &ctx);
EXPECT_EQ(resp.Parse(msg), 0); EXPECT_EQ(resp.Parse(msg), 0);
framework::Variable* var2 = resp.GetVar(); framework::Variable* var2 = resp.GetVar();
...@@ -128,7 +128,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) { ...@@ -128,7 +128,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
math::set_constant(ctx, tensor, 31.9); math::set_constant(ctx, tensor, 31.9);
::grpc::ByteBuffer msg; ::grpc::ByteBuffer msg;
operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg); operators::distributed::SerializeToByteBuffer("myvar", &var, ctx, &msg);
EXPECT_GT(msg.Length(), static_cast<size_t>(0)); EXPECT_GT(msg.Length(), static_cast<size_t>(0));
// deserialize // deserialize
...@@ -171,7 +171,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) { ...@@ -171,7 +171,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
// deserialize zero-copy // deserialize zero-copy
framework::Scope scope; framework::Scope scope;
scope.Var("myvar"); scope.Var("myvar");
operators::detail::VariableResponse resp(&scope, &ctx); operators::distributed::VariableResponse resp(&scope, &ctx);
if (from_type == 0) { if (from_type == 0) {
EXPECT_EQ(resp.Parse(msg), 0); EXPECT_EQ(resp.Parse(msg), 0);
} else { } else {
......
...@@ -15,13 +15,13 @@ limitations under the License. */ ...@@ -15,13 +15,13 @@ limitations under the License. */
#include <limits> #include <limits>
#include <string> #include <string>
#include "paddle/fluid/operators/detail/grpc_server.h" #include "paddle/fluid/operators/distributed/grpc_server.h"
using ::grpc::ServerAsyncResponseWriter; using ::grpc::ServerAsyncResponseWriter;
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
enum CallStatus { PROCESS = 0, FINISH }; enum CallStatus { PROCESS = 0, FINISH };
// reference: // reference:
...@@ -74,7 +74,7 @@ class RequestSend final : public RequestBase { ...@@ -74,7 +74,7 @@ class RequestSend final : public RequestBase {
request_.reset(new VariableResponse(request_handler->scope(), request_.reset(new VariableResponse(request_handler->scope(),
request_handler->dev_ctx(), request_handler->dev_ctx(),
!request_handler->sync_mode())); !request_handler->sync_mode()));
int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable); int method_id = static_cast<int>(distributed::GrpcMethod::kSendVariable);
service_->RequestAsyncUnary( service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_, method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id))); reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
...@@ -106,7 +106,7 @@ class RequestGet final : public RequestBase { ...@@ -106,7 +106,7 @@ class RequestGet final : public RequestBase {
::grpc::ServerCompletionQueue* cq, ::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id) RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable); auto method_id = static_cast<int>(distributed::GrpcMethod::kGetVariable);
service_->RequestAsyncUnary( service_->RequestAsyncUnary(
method_id, &ctx_, &request_, &responder_, cq_, cq_, method_id, &ctx_, &request_, &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id))); reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
...@@ -150,7 +150,8 @@ class RequestPrefetch final : public RequestBase { ...@@ -150,7 +150,8 @@ class RequestPrefetch final : public RequestBase {
local_scope_(nullptr) { local_scope_(nullptr) {
request_.reset(new VariableResponse(request_handler->scope(), request_.reset(new VariableResponse(request_handler->scope(),
request_handler->dev_ctx(), true)); request_handler->dev_ctx(), true));
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable); int method_id =
static_cast<int>(distributed::GrpcMethod::kPrefetchVariable);
service_->RequestAsyncUnary( service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_, method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id))); reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
...@@ -354,6 +355,6 @@ void AsyncGRPCServer::HandleRequest( ...@@ -354,6 +355,6 @@ void AsyncGRPCServer::HandleRequest(
} }
} }
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -29,17 +29,17 @@ limitations under the License. */ ...@@ -29,17 +29,17 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/grpc_service.h" #include "paddle/fluid/operators/distributed/grpc_service.h"
#include "paddle/fluid/operators/detail/request_handler.h" #include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/detail/rpc_server.h" #include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" #include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
class RequestBase; class RequestBase;
...@@ -84,6 +84,6 @@ class AsyncGRPCServer final : public RPCServer { ...@@ -84,6 +84,6 @@ class AsyncGRPCServer final : public RPCServer {
std::map<std::string, std::vector<RequestBase*>> rpc_reqs_; std::map<std::string, std::vector<RequestBase*>> rpc_reqs_;
}; };
}; // namespace detail }; // namespace distributed
}; // namespace operators }; // namespace operators
}; // namespace paddle }; // namespace paddle
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
#include <grpc++/impl/codegen/stub_options.h> #include <grpc++/impl/codegen/stub_options.h>
#include <grpc++/impl/codegen/sync_stream.h> #include <grpc++/impl/codegen/sync_stream.h>
#include <grpc++/support/byte_buffer.h> #include <grpc++/support/byte_buffer.h>
#include "paddle/fluid/operators/detail/variable_response.h" #include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -42,16 +42,17 @@ class ServerContext; ...@@ -42,16 +42,17 @@ class ServerContext;
// Support parsing/unparsing of tensorflow::VariableResponse. // Support parsing/unparsing of tensorflow::VariableResponse.
// Wire-format is identical to RecvVariableResponse. // Wire-format is identical to RecvVariableResponse.
template <> template <>
class SerializationTraits<paddle::operators::detail::VariableResponse> { class SerializationTraits<paddle::operators::distributed::VariableResponse> {
public: public:
static Status Serialize( static Status Serialize(
const paddle::operators::detail::VariableResponse& msg, const paddle::operators::distributed::VariableResponse& msg,
grpc_byte_buffer** bp, bool* own_buffer) { grpc_byte_buffer** bp, bool* own_buffer) {
PADDLE_ENFORCE(false, "SerializationTraits::Serialize not implemented!"); PADDLE_ENFORCE(false, "SerializationTraits::Serialize not implemented!");
return Status(); return Status();
} }
static Status Deserialize(grpc_byte_buffer* buffer, static Status Deserialize(
paddle::operators::detail::VariableResponse* msg, grpc_byte_buffer* buffer,
paddle::operators::distributed::VariableResponse* msg,
int max_message_size = INT_MAX) { int max_message_size = INT_MAX) {
if (buffer == nullptr) { if (buffer == nullptr) {
return Status(StatusCode::INTERNAL, "No payload"); return Status(StatusCode::INTERNAL, "No payload");
...@@ -59,7 +60,7 @@ class SerializationTraits<paddle::operators::detail::VariableResponse> { ...@@ -59,7 +60,7 @@ class SerializationTraits<paddle::operators::detail::VariableResponse> {
Status result = g_core_codegen_interface->ok(); Status result = g_core_codegen_interface->ok();
if (result.ok()) { if (result.ok()) {
paddle::operators::detail::GrpcByteSource source(buffer); paddle::operators::distributed::GrpcByteSource source(buffer);
int ret = msg->Parse(&source); int ret = msg->Parse(&source);
if (ret != 0) { if (ret != 0) {
result = Status(StatusCode::INTERNAL, "VariableResponse parse error"); result = Status(StatusCode::INTERNAL, "VariableResponse parse error");
...@@ -73,7 +74,7 @@ class SerializationTraits<paddle::operators::detail::VariableResponse> { ...@@ -73,7 +74,7 @@ class SerializationTraits<paddle::operators::detail::VariableResponse> {
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
enum class GrpcMethod { enum class GrpcMethod {
kSendVariable, kSendVariable,
...@@ -118,6 +119,6 @@ class GrpcService final { ...@@ -118,6 +119,6 @@ class GrpcService final {
}; };
}; };
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -26,7 +26,7 @@ limitations under the License. */ ...@@ -26,7 +26,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
char* EncodeVarint32(char* dst, uint32_t v) { char* EncodeVarint32(char* dst, uint32_t v) {
// Operate on characters as unsigneds // Operate on characters as unsigneds
...@@ -144,6 +144,6 @@ class ProtoEncodeHelper { ...@@ -144,6 +144,6 @@ class ProtoEncodeHelper {
char* limit_; // Just for CHECKs char* limit_; // Just for CHECKs
}; };
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -31,7 +31,7 @@ ...@@ -31,7 +31,7 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
constexpr char kRequestSend[] = "RequestSend"; constexpr char kRequestSend[] = "RequestSend";
constexpr char kRequestGet[] = "RequestGet"; constexpr char kRequestGet[] = "RequestGet";
...@@ -124,6 +124,6 @@ class RequestHandler { ...@@ -124,6 +124,6 @@ class RequestHandler {
RPCServer* rpc_server_; RPCServer* rpc_server_;
}; };
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -20,12 +20,12 @@ ...@@ -20,12 +20,12 @@
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/detail/rpc_server.h" #include "paddle/fluid/operators/distributed/rpc_server.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
bool RequestSendHandler::Handle(const std::string& varname, bool RequestSendHandler::Handle(const std::string& varname,
framework::Scope* scope, framework::Scope* scope,
...@@ -119,6 +119,6 @@ bool RequestPrefetchHandler::Handle(const std::string& varname, ...@@ -119,6 +119,6 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
return true; return true;
} }
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -28,11 +28,11 @@ ...@@ -28,11 +28,11 @@
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/request_handler.h" #include "paddle/fluid/operators/distributed/request_handler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
class RequestSendHandler final : public RequestHandler { class RequestSendHandler final : public RequestHandler {
public: public:
...@@ -66,6 +66,6 @@ class RequestPrefetchHandler final : public RequestHandler { ...@@ -66,6 +66,6 @@ class RequestPrefetchHandler final : public RequestHandler {
const std::string& out_var_name = "") override; const std::string& out_var_name = "") override;
}; };
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -12,15 +12,15 @@ ...@@ -12,15 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/detail/rpc_client.h" #include "paddle/fluid/operators/distributed/rpc_client.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
std::once_flag RPCClient::init_flag_; std::once_flag RPCClient::init_flag_;
std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr); std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr);
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
class RPCClient { class RPCClient {
public: public:
...@@ -84,6 +84,6 @@ class RPCClient { ...@@ -84,6 +84,6 @@ class RPCClient {
static std::once_flag init_flag_; static std::once_flag init_flag_;
static std::unique_ptr<RPCClient> rpc_client_; static std::unique_ptr<RPCClient> rpc_client_;
}; };
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -17,11 +17,11 @@ ...@@ -17,11 +17,11 @@
#include <limits> #include <limits>
#include <string> #include <string>
#include "paddle/fluid/operators/detail/rpc_server.h" #include "paddle/fluid/operators/distributed/rpc_server.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
void RPCServer::ShutDown() { void RPCServer::ShutDown() {
LOG(INFO) << "RPCServer ShutDown "; LOG(INFO) << "RPCServer ShutDown ";
...@@ -112,6 +112,6 @@ void RPCServer::WaitCond(const std::string& rpc_name) { ...@@ -112,6 +112,6 @@ void RPCServer::WaitCond(const std::string& rpc_name) {
lock, [=] { return (cur_cond_.load() == cond || exit_flag_.load()); }); lock, [=] { return (cur_cond_.load() == cond || exit_flag_.load()); });
} }
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -19,11 +19,11 @@ ...@@ -19,11 +19,11 @@
#include <thread> // NOLINT #include <thread> // NOLINT
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/operators/detail/request_handler.h" #include "paddle/fluid/operators/distributed/request_handler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
class RPCServer { class RPCServer {
public: public:
...@@ -86,6 +86,6 @@ class RPCServer { ...@@ -86,6 +86,6 @@ class RPCServer {
friend class RequestHandler; friend class RequestHandler;
}; };
}; // namespace detail }; // namespace distributed
}; // namespace operators }; // namespace operators
}; // namespace paddle }; // namespace paddle
...@@ -22,18 +22,18 @@ limitations under the License. */ ...@@ -22,18 +22,18 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/detail/rpc_client.h" #include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/detail/rpc_server.h" #include "paddle/fluid/operators/distributed/rpc_server.h"
namespace framework = paddle::framework; namespace framework = paddle::framework;
namespace platform = paddle::platform; namespace platform = paddle::platform;
namespace detail = paddle::operators::detail; namespace distributed = paddle::operators::distributed;
USE_OP(lookup_table); USE_OP(lookup_table);
std::unique_ptr<detail::RPCServer> g_rpc_service; std::unique_ptr<distributed::RPCServer> g_rpc_service;
std::unique_ptr<detail::RequestHandler> g_req_handler; std::unique_ptr<distributed::RequestHandler> g_req_handler;
framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) { framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
auto root_block = program->MutableBlock(0); auto root_block = program->MutableBlock(0);
...@@ -113,19 +113,21 @@ void StartServer() { ...@@ -113,19 +113,21 @@ void StartServer() {
g_req_handler->SetScope(&scope); g_req_handler->SetScope(&scope);
g_req_handler->SetExecutor(&exe); g_req_handler->SetExecutor(&exe);
g_rpc_service->RegisterRPC(detail::kRequestPrefetch, g_req_handler.get()); g_rpc_service->RegisterRPC(distributed::kRequestPrefetch,
g_req_handler.get());
g_req_handler->SetRPCServer(g_rpc_service.get()); g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread( std::thread server_thread(
std::bind(&detail::RPCServer::StartServer, g_rpc_service.get())); std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get()));
server_thread.join(); server_thread.join();
} }
TEST(PREFETCH, CPU) { TEST(PREFETCH, CPU) {
g_req_handler.reset(new detail::RequestPrefetchHandler(true)); g_req_handler.reset(new distributed::RequestPrefetchHandler(true));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1)); g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
detail::RPCClient* client = detail::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
std::thread server_thread(StartServer); std::thread server_thread(StartServer);
g_rpc_service->WaitServerReady(); g_rpc_service->WaitServerReady();
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include <nccl.h> #include <nccl.h>
...@@ -23,14 +23,14 @@ limitations under the License. */ ...@@ -23,14 +23,14 @@ limitations under the License. */
#include "google/protobuf/io/coded_stream.h" #include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/detail/bytebuffer_stream.h" #include "paddle/fluid/operators/distributed/bytebuffer_stream.h"
#include "paddle/fluid/operators/detail/proto_encoder_helper.h" #include "paddle/fluid/operators/distributed/proto_encoder_helper.h"
#include "paddle/fluid/operators/detail/variable_response.h" #include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
using VarMsg = sendrecv::VariableMessage; using VarMsg = sendrecv::VariableMessage;
...@@ -222,11 +222,11 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, ...@@ -222,11 +222,11 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope* scope, const framework::Scope* scope,
framework::Variable** var) { framework::Variable** var) {
operators::detail::VariableResponse resp(scope, &ctx); operators::distributed::VariableResponse resp(scope, &ctx);
PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!"); PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
*var = resp.GetVar(); *var = resp.GetVar();
} }
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -25,12 +25,12 @@ limitations under the License. */ ...@@ -25,12 +25,12 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" #include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/distributed/send_recv.pb.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
typedef void (*DestroyCallback)(void*); typedef void (*DestroyCallback)(void*);
...@@ -61,6 +61,6 @@ inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) { ...@@ -61,6 +61,6 @@ inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) {
} }
} }
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/detail/variable_response.h" #include "paddle/fluid/operators/distributed/variable_response.h"
#include <string> #include <string>
#include <utility> #include <utility>
...@@ -22,12 +22,12 @@ ...@@ -22,12 +22,12 @@
#endif #endif
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
enum WireType { enum WireType {
WIRETYPE_VARINT = 0, WIRETYPE_VARINT = 0,
...@@ -158,13 +158,13 @@ bool VariableResponse::CopySelectRowsTensorData( ...@@ -158,13 +158,13 @@ bool VariableResponse::CopySelectRowsTensorData(
slr->set_height(meta_.slr_height()); slr->set_height(meta_.slr_height());
auto* tensor = slr->mutable_value(); auto* tensor = slr->mutable_value();
tensor->Resize(dims); tensor->Resize(dims);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(static_cast<size_t>(tensor->numel()),
static_cast<size_t>(tensor->numel()),
length / framework::SizeOfType( length / framework::SizeOfType(
paddle::operators::detail::ToTypeIndex(meta_.data_type()))); paddle::operators::distributed::ToTypeIndex(
meta_.data_type())));
void* tensor_data = tensor->mutable_data( void* tensor_data = tensor->mutable_data(
ctx.GetPlace(), ctx.GetPlace(),
paddle::operators::detail::ToTypeIndex(meta_.data_type())); paddle::operators::distributed::ToTypeIndex(meta_.data_type()));
if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) { if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) {
return false; return false;
...@@ -480,6 +480,6 @@ int VariableResponse::Parse(Source* source) { ...@@ -480,6 +480,6 @@ int VariableResponse::Parse(Source* source) {
return 0; return 0;
} }
}; // namespace detail }; // namespace distributed
}; // namespace operators }; // namespace operators
}; // namespace paddle }; // namespace paddle
...@@ -22,17 +22,17 @@ ...@@ -22,17 +22,17 @@
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" #include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "google/protobuf/io/coded_stream.h" #include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/detail/bytebuffer_stream.h" #include "paddle/fluid/operators/distributed/bytebuffer_stream.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
class VariableResponse { class VariableResponse {
public: public:
...@@ -99,6 +99,6 @@ class VariableResponse { ...@@ -99,6 +99,6 @@ class VariableResponse {
sendrecv::VariableMessage meta_; sendrecv::VariableMessage meta_;
}; };
}; // namespace detail }; // namespace distributed
}; // namespace operators }; // namespace operators
}; // namespace paddle }; // namespace paddle
...@@ -42,8 +42,8 @@ class FetchBarrierOp : public framework::OperatorBase { ...@@ -42,8 +42,8 @@ class FetchBarrierOp : public framework::OperatorBase {
// For profiling // For profiling
platform::RecordEvent record_event(Type(), &ctx); platform::RecordEvent record_event(Type(), &ctx);
detail::RPCClient* rpc_client = distributed::RPCClient* rpc_client =
detail::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient::GetInstance<RPCCLIENT_T>();
rpc_client->Wait(); rpc_client->Wait();
......
...@@ -22,7 +22,7 @@ limitations under the License. */ ...@@ -22,7 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/nccl_helper.h"
namespace paddle { namespace paddle {
...@@ -60,7 +60,8 @@ class GenNCCLIdOp : public framework::OperatorBase { ...@@ -60,7 +60,8 @@ class GenNCCLIdOp : public framework::OperatorBase {
std::vector<std::string> endpoint_list = std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("endpoint_list"); Attr<std::vector<std::string>>("endpoint_list");
detail::RPCClient* client = detail::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
for (auto& ep : endpoint_list) { for (auto& ep : endpoint_list) {
VLOG(3) << "sending nccl id to " << ep; VLOG(3) << "sending nccl id to " << ep;
...@@ -80,11 +81,11 @@ class GenNCCLIdOp : public framework::OperatorBase { ...@@ -80,11 +81,11 @@ class GenNCCLIdOp : public framework::OperatorBase {
// NOTE: Can not use unique_ptr here because the default // NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and // deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash. // that will cause a wired crash.
detail::RequestSendHandler rpc_h(true); distributed::RequestSendHandler rpc_h(true);
std::unique_ptr<detail::RPCServer> rpc_service( std::unique_ptr<distributed::RPCServer> rpc_service(
new RPCSERVER_T(endpoint, 1)); new RPCSERVER_T(endpoint, 1));
rpc_service->RegisterRPC(detail::kRequestSend, &rpc_h); rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h);
rpc_h.SetRPCServer(rpc_service.get()); rpc_h.SetRPCServer(rpc_service.get());
framework::ProgramDesc empty_program; framework::ProgramDesc empty_program;
...@@ -95,11 +96,11 @@ class GenNCCLIdOp : public framework::OperatorBase { ...@@ -95,11 +96,11 @@ class GenNCCLIdOp : public framework::OperatorBase {
rpc_h.SetExecutor(&executor); rpc_h.SetExecutor(&executor);
std::thread server_thread( std::thread server_thread(
std::bind(&detail::RPCServer::StartServer, rpc_service.get())); std::bind(&distributed::RPCServer::StartServer, rpc_service.get()));
rpc_service->SetCond(detail::kRequestSend); rpc_service->SetCond(distributed::kRequestSend);
VLOG(3) << "start getting nccl id from trainer 0..."; VLOG(3) << "start getting nccl id from trainer 0...";
rpc_service->WaitBarrier(detail::kRequestSend); rpc_service->WaitBarrier(distributed::kRequestSend);
VLOG(3) << "got nccl id and stop server..."; VLOG(3) << "got nccl id and stop server...";
rpc_service->ShutDown(); rpc_service->ShutDown();
VLOG(3) << "rpc server stopped"; VLOG(3) << "rpc server stopped";
......
...@@ -21,14 +21,14 @@ limitations under the License. */ ...@@ -21,14 +21,14 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h" #include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
void RunServer(std::shared_ptr<detail::RPCServer> service) { void RunServer(std::shared_ptr<distributed::RPCServer> service) {
service->StartServer(); service->StartServer();
VLOG(4) << "RunServer thread end"; VLOG(4) << "RunServer thread end";
} }
...@@ -119,12 +119,12 @@ void ListenAndServOp::RunSyncLoop( ...@@ -119,12 +119,12 @@ void ListenAndServOp::RunSyncLoop(
while (true) { while (true) {
// Get from multiple trainers, we don't care about the order in which // Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient. // the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_->SetCond(detail::kRequestSend); rpc_service_->SetCond(distributed::kRequestSend);
rpc_service_->WaitBarrier(detail::kRequestSend); rpc_service_->WaitBarrier(distributed::kRequestSend);
if (rpc_service_->IsExit()) { if (rpc_service_->IsExit()) {
LOG(WARNING) << "get exit!rpc_processor break!"; LOG(WARNING) << "get exit!rpc_processor break!";
rpc_service_->SetCond(detail::kRequestGet); rpc_service_->SetCond(distributed::kRequestGet);
break; break;
} }
...@@ -152,11 +152,11 @@ void ListenAndServOp::RunSyncLoop( ...@@ -152,11 +152,11 @@ void ListenAndServOp::RunSyncLoop(
recv_scope); recv_scope);
VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)"; VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
rpc_service_->SetCond(detail::kRequestGet); rpc_service_->SetCond(distributed::kRequestGet);
rpc_service_->WaitBarrier(detail::kRequestGet); rpc_service_->WaitBarrier(distributed::kRequestGet);
rpc_service_->ResetBarrierCounter(); rpc_service_->ResetBarrierCounter();
// reset received sparse vars to avoid reuse it in the next mini-batch // reset received sparse vars to avoid reuse it in the next mini-batch
dynamic_cast<detail::RequestSendHandler *>(request_send_handler_.get()) dynamic_cast<distributed::RequestSendHandler *>(request_send_handler_.get())
->ResetSparseVarRecorder(); ->ResetSparseVarRecorder();
} // while(true) } // while(true)
} }
...@@ -213,13 +213,13 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -213,13 +213,13 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
} }
static void FillRequestCtx( static void FillRequestCtx(
detail::RequestHandler *h, framework::Scope *scope, distributed::RequestHandler *h, framework::Scope *scope,
platform::DeviceContext *dev_ctx, framework::Executor *executor, platform::DeviceContext *dev_ctx, framework::Executor *executor,
framework::ProgramDesc *program, framework::ProgramDesc *program,
std::unordered_map<std::string, std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>> std::shared_ptr<framework::ExecutorPrepareContext>>
*prefetch_ctx, *prefetch_ctx,
detail::RPCServer *rpc_server) { distributed::RPCServer *rpc_server) {
h->SetScope(scope); h->SetScope(scope);
h->SetDevCtx(dev_ctx); h->SetDevCtx(dev_ctx);
h->SetExecutor(executor); h->SetExecutor(executor);
...@@ -247,14 +247,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -247,14 +247,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in)); rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));
request_send_handler_.reset(new detail::RequestSendHandler(sync_mode)); request_send_handler_.reset(new distributed::RequestSendHandler(sync_mode));
request_get_handler_.reset(new detail::RequestGetHandler(sync_mode)); request_get_handler_.reset(new distributed::RequestGetHandler(sync_mode));
request_prefetch_handler_.reset( request_prefetch_handler_.reset(
new detail::RequestPrefetchHandler(sync_mode)); new distributed::RequestPrefetchHandler(sync_mode));
rpc_service_->RegisterRPC(detail::kRequestSend, request_send_handler_.get()); rpc_service_->RegisterRPC(distributed::kRequestSend,
rpc_service_->RegisterRPC(detail::kRequestGet, request_get_handler_.get()); request_send_handler_.get());
rpc_service_->RegisterRPC(detail::kRequestPrefetch, rpc_service_->RegisterRPC(distributed::kRequestGet,
request_get_handler_.get());
rpc_service_->RegisterRPC(distributed::kRequestPrefetch,
request_prefetch_handler_.get()); request_prefetch_handler_.get());
auto optimize_blocks = auto optimize_blocks =
......
...@@ -24,8 +24,8 @@ limitations under the License. */ ...@@ -24,8 +24,8 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/request_handler.h" #include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/detail/rpc_server.h" #include "paddle/fluid/operators/distributed/rpc_server.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -33,7 +33,7 @@ namespace operators { ...@@ -33,7 +33,7 @@ namespace operators {
constexpr char kOptimizeBlocks[] = "optimize_blocks"; constexpr char kOptimizeBlocks[] = "optimize_blocks";
constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id"; constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id";
void RunServer(std::shared_ptr<detail::RPCServer> service); void RunServer(std::shared_ptr<distributed::RPCServer> service);
class ListenAndServOp : public framework::OperatorBase { class ListenAndServOp : public framework::OperatorBase {
public: public:
...@@ -62,10 +62,11 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -62,10 +62,11 @@ class ListenAndServOp : public framework::OperatorBase {
const platform::Place& dev_place) const override; const platform::Place& dev_place) const override;
protected: protected:
mutable std::shared_ptr<detail::RPCServer> rpc_service_; mutable std::shared_ptr<distributed::RPCServer> rpc_service_;
mutable std::shared_ptr<detail::RequestHandler> request_send_handler_; mutable std::shared_ptr<distributed::RequestHandler> request_send_handler_;
mutable std::shared_ptr<detail::RequestHandler> request_get_handler_; mutable std::shared_ptr<distributed::RequestHandler> request_get_handler_;
mutable std::shared_ptr<detail::RequestHandler> request_prefetch_handler_; mutable std::shared_ptr<distributed::RequestHandler>
request_prefetch_handler_;
mutable std::shared_ptr<std::thread> server_thread_; mutable std::shared_ptr<std::thread> server_thread_;
}; };
......
...@@ -93,10 +93,10 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> { ...@@ -93,10 +93,10 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
auto cpu_place = boost::get<platform::CPUPlace>(context.GetPlace()); auto cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
// computation // computation
for (size_t k = 0; k < input_rows; ++k) { for (int k = 0; k < input_rows; ++k) {
const T* src_ptr = input.data<T>() + k * input_cols; const T* src_ptr = input.data<T>() + k * input_cols;
int col_idx = 0; int col_idx = 0;
for (int j = 0; j < num; ++j) { for (size_t j = 0; j < num; ++j) {
int col_len = output_cols[j]; int col_len = output_cols[j];
auto* out_tensor = outputs->at(j); auto* out_tensor = outputs->at(j);
if (out_tensor != nullptr) { if (out_tensor != nullptr) {
......
...@@ -295,7 +295,7 @@ class ParallelDoGradOp : public framework::OperatorBase { ...@@ -295,7 +295,7 @@ class ParallelDoGradOp : public framework::OperatorBase {
auto sum_op = framework::OpRegistry::CreateOp( auto sum_op = framework::OpRegistry::CreateOp(
"sum", {{"X", {s, tmp_name}}}, {{"Out", {s}}}, "sum", {{"X", {s, tmp_name}}}, {{"Out", {s}}},
framework::AttributeMap{{"use_mkldnn", {false}}}); framework::AttributeMap{});
VLOG(10) << sum_op->DebugStringEx(sub_scopes[0]); VLOG(10) << sum_op->DebugStringEx(sub_scopes[0]);
sum_op->Run(*sub_scopes[0], places[0]); sum_op->Run(*sub_scopes[0], places[0]);
WaitOnPlace(places[0]); WaitOnPlace(places[0]);
......
...@@ -41,8 +41,8 @@ class PrefetchOp : public framework::OperatorBase { ...@@ -41,8 +41,8 @@ class PrefetchOp : public framework::OperatorBase {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
detail::RPCClient* rpc_client = distributed::RPCClient* rpc_client =
detail::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient::GetInstance<RPCCLIENT_T>();
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
......
...@@ -429,8 +429,7 @@ class RecurrentGradOp : public RecurrentBase { ...@@ -429,8 +429,7 @@ class RecurrentGradOp : public RecurrentBase {
auto sum_op = framework::OpRegistry::CreateOp( auto sum_op = framework::OpRegistry::CreateOp(
"sum", {{"X", {pg_names[param_id], new_inside_name}}}, "sum", {{"X", {pg_names[param_id], new_inside_name}}},
{{"Out", {pg_names[param_id]}}}, {{"Out", {pg_names[param_id]}}}, framework::AttributeMap{});
framework::AttributeMap{{"use_mkldnn", {false}}});
sum_op->Run(cur_scope, place); sum_op->Run(cur_scope, place);
cur_scope.Rename(new_inside_name, inside_grad_name); cur_scope.Rename(new_inside_name, inside_grad_name);
......
...@@ -43,8 +43,8 @@ class RecvOp : public framework::OperatorBase { ...@@ -43,8 +43,8 @@ class RecvOp : public framework::OperatorBase {
// For profiling // For profiling
platform::RecordEvent record_event(Type(), &ctx); platform::RecordEvent record_event(Type(), &ctx);
detail::RPCClient* rpc_client = distributed::RPCClient* rpc_client =
detail::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient::GetInstance<RPCCLIENT_T>();
for (size_t i = 0; i < outs.size(); i++) { for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
......
...@@ -44,8 +44,8 @@ class SendBarrierOp : public framework::OperatorBase { ...@@ -44,8 +44,8 @@ class SendBarrierOp : public framework::OperatorBase {
// For profiling // For profiling
platform::RecordEvent record_event(Type(), &ctx); platform::RecordEvent record_event(Type(), &ctx);
detail::RPCClient* rpc_client = distributed::RPCClient* rpc_client =
detail::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient::GetInstance<RPCCLIENT_T>();
VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode; VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode;
......
...@@ -45,8 +45,8 @@ class SendOp : public framework::OperatorBase { ...@@ -45,8 +45,8 @@ class SendOp : public framework::OperatorBase {
// For profiling // For profiling
platform::RecordEvent record_event(Type(), &ctx); platform::RecordEvent record_event(Type(), &ctx);
detail::RPCClient* rpc_client = distributed::RPCClient* rpc_client =
detail::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient::GetInstance<RPCCLIENT_T>();
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*Licensed under the Apache License, Version 2.0(the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mkldnn.hpp"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/operators/sum_op.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
namespace paddle {
namespace operators {
using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::CPUDeviceContext;
using framework::DataLayout;
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::stream;
using mkldnn::sum;
using mkldnn::reorder;
using platform::to_void_cast;
template <typename T>
class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto in_vars = ctx.MultiInputVar("X");
const int N = in_vars.size();
auto out_var = ctx.OutputVar("Out");
bool in_place = out_var == in_vars[0];
if (out_var->IsType<framework::LoDTensor>()) {
LoDTensor* output = ctx.Output<LoDTensor>("Out");
T* output_data = output->mutable_data<T>(ctx.GetPlace());
std::vector<int> dst_tz = framework::vectorize2int(output->dims());
auto src_tz = dst_tz;
memory::format output_format{memory::format::format_undef};
std::vector<float> scales;
std::vector<memory::primitive_desc> srcs_mpd;
std::vector<mkldnn::memory> srcs_mem;
PADDLE_ENFORCE(in_vars[0]->IsType<LoDTensor>(),
"Input[0] must be LoDTensors");
auto& input0 = in_vars[0]->Get<LoDTensor>();
PADDLE_ENFORCE(input0.layout() == DataLayout::kMKLDNN &&
input0.format() != memory::format::format_undef,
"Wrong layout/format for inputs[0]");
memory::format input_format = input0.format();
if (src_tz.size() == 1 && (input_format == memory::format::nchw ||
input_format == memory::format::nhwc)) {
input_format = memory::format::x;
}
if (src_tz.size() == 2 && (input_format == memory::format::nchw ||
input_format == memory::format::nhwc)) {
input_format = memory::format::nc;
}
for (int i = in_place ? 1 : 0; i < N; i++) {
PADDLE_ENFORCE(in_vars[i]->IsType<LoDTensor>(),
"all inputs must be all LoDTensors");
auto& input = in_vars[i]->Get<LoDTensor>();
PADDLE_ENFORCE(input.layout() == DataLayout::kMKLDNN &&
input.format() != memory::format::format_undef,
"Wrong layout/format for inputs");
if (input.numel() == 0) {
continue;
}
const T* input_data = input.data<T>();
auto src_md =
memory::desc(src_tz, memory::data_type::f32, input_format);
auto src_mpd = memory::primitive_desc(src_md, mkldnn_engine);
auto src_mem = memory(src_mpd, to_void_cast(input_data));
srcs_mpd.push_back(src_mpd);
srcs_mem.push_back(src_mem);
scales.push_back(1.0);
}
auto dst_md =
memory::desc(dst_tz, memory::data_type::f32, memory::format::any);
auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_mpd);
std::shared_ptr<memory> dst_mem;
if (in_place) {
dst_mem.reset(new memory(sum_pd.dst_primitive_desc()));
} else {
dst_mem.reset(new memory(sum_pd.dst_primitive_desc(), output_data));
}
std::vector<mkldnn::primitive::at> inputs;
for (size_t i = 0; i < srcs_mem.size(); ++i) {
inputs.push_back(srcs_mem[i]);
}
auto sum_prim = mkldnn::sum(sum_pd, inputs, *dst_mem);
output_format = (memory::format)platform::GetMKLDNNFormat(sum_pd);
primitive reorder_prim;
std::shared_ptr<memory> target_mem;
if (in_place) {
output_format = input_format;
target_mem.reset(new memory(
{{{src_tz}, memory::data_type::f32, output_format}, mkldnn_engine},
output_data));
reorder_prim = reorder(*dst_mem, *target_mem);
}
std::vector<primitive> pipeline;
pipeline.push_back(sum_prim);
if (in_place) pipeline.push_back(reorder_prim);
stream(stream::kind::eager).submit(pipeline).wait();
output->set_layout(DataLayout::kMKLDNN);
output->set_format(output_format);
} else if (out_var->IsType<framework::SelectedRows>()) {
// TODO(@mozga-intel) Add MKLDNN SelectedRows support
std::unique_ptr<framework::SelectedRows> in0;
if (in_place) {
// If is in_place, we store the input[0] to in0
auto& in_sel0 = in_vars[0]->Get<SelectedRows>();
auto& rows = in_sel0.rows();
in0.reset(new framework::SelectedRows(rows, in_sel0.height()));
in0->mutable_value()->ShareDataWith(in_sel0.value());
}
auto get_selected_row = [&](size_t i) -> const SelectedRows& {
if (i == 0 && in0) {
return *in0.get();
} else {
return in_vars[i]->Get<SelectedRows>();
}
};
auto* out = ctx.Output<SelectedRows>("Out");
out->mutable_rows()->clear();
auto* out_value = out->mutable_value();
// Runtime InferShape
size_t first_dim = 0;
for (int i = 0; i < N; i++) {
auto& sel_row = get_selected_row(i);
first_dim += sel_row.rows().size();
}
auto in_dim =
framework::vectorize(get_selected_row(N - 1).value().dims());
in_dim[0] = static_cast<int64_t>(first_dim);
out_value->Resize(framework::make_ddim(in_dim));
// if all the input sparse vars are empty, no need to
// merge these vars.
if (first_dim == 0UL) {
return;
}
out_value->mutable_data<T>(ctx.GetPlace());
math::SelectedRowsAddTo<CPUDeviceContext, T> functor;
int64_t offset = 0;
for (int i = 0; i < N; i++) {
auto& sel_row = get_selected_row(i);
if (sel_row.rows().size() == 0) {
continue;
}
PADDLE_ENFORCE_EQ(out->height(), sel_row.height());
functor(ctx.template device_context<CPUDeviceContext>(), sel_row,
offset, out);
offset += sel_row.value().numel();
}
} else if (out_var->IsType<framework::LoDTensorArray>()) {
// TODO(@mozga-intel) Add MKLDNN LoDTensorArray support
auto& out_array = *out_var->GetMutable<framework::LoDTensorArray>();
for (size_t i = in_place ? 1 : 0; i < in_vars.size(); ++i) {
PADDLE_ENFORCE(in_vars[i]->IsType<framework::LoDTensorArray>(),
"Only support all inputs are TensorArray");
auto& in_array = in_vars[i]->Get<framework::LoDTensorArray>();
for (size_t i = 0; i < in_array.size(); ++i) {
if (in_array[i].numel() != 0) {
if (i >= out_array.size()) {
out_array.resize(i + 1);
}
if (out_array[i].numel() == 0) {
framework::TensorCopy(in_array[i], in_array[i].place(),
ctx.device_context(), &out_array[i]);
out_array[i].set_lod(in_array[i].lod());
} else {
PADDLE_ENFORCE(out_array[i].lod() == in_array[i].lod());
auto in = EigenVector<T>::Flatten(in_array[i]);
auto result = EigenVector<T>::Flatten(out_array[i]);
result.device(*ctx.template device_context<MKLDNNDeviceContext>()
.eigen_device()) = result + in;
}
}
}
}
} else {
PADDLE_THROW("Unexpected branch, output variable type is %s",
out_var->Type().name());
}
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_KERNEL(sum, MKLDNN, ::paddle::platform::CPUPlace,
paddle::operators::SumMKLDNNOpKernel<float>);
...@@ -18,10 +18,6 @@ limitations under the License. */ ...@@ -18,10 +18,6 @@ limitations under the License. */
#include "paddle/fluid/framework/var_type_inference.h" #include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/detail/safe_ref.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using framework::Tensor; using framework::Tensor;
...@@ -67,18 +63,6 @@ class SumOp : public framework::OperatorWithKernel { ...@@ -67,18 +63,6 @@ class SumOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto x_vars = ctx.MultiInputVar("X"); auto x_vars = ctx.MultiInputVar("X");
framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout{framework::DataLayout::kAnyLayout};
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
#endif
if (x_vars[0]->IsType<framework::LoDTensor>()) { if (x_vars[0]->IsType<framework::LoDTensor>()) {
int dtype = -1; int dtype = -1;
for (auto& x_var : x_vars) { for (auto& x_var : x_vars) {
...@@ -96,27 +80,26 @@ class SumOp : public framework::OperatorWithKernel { ...@@ -96,27 +80,26 @@ class SumOp : public framework::OperatorWithKernel {
"Sum operator should have at least one tensor"); "Sum operator should have at least one tensor");
return framework::OpKernelType( return framework::OpKernelType(
static_cast<framework::proto::VarType::Type>(dtype), ctx.GetPlace(), static_cast<framework::proto::VarType::Type>(dtype),
layout, library); ctx.device_context());
} else if (x_vars[0]->IsType<framework::SelectedRows>()) { } else if (x_vars[0]->IsType<framework::SelectedRows>()) {
for (auto& var : x_vars) { for (auto& var : x_vars) {
auto& value = var->Get<framework::SelectedRows>().value(); auto& value = var->Get<framework::SelectedRows>().value();
if (value.IsInitialized()) { if (value.IsInitialized()) {
return framework::OpKernelType(framework::ToDataType(value.type()), return framework::OpKernelType(framework::ToDataType(value.type()),
ctx.device_context(), layout, library); ctx.device_context());
} }
} }
// if input sparse vars are not initialized, use an default kernel type. // if input sparse vars are not initialized, use an default kernel type.
return framework::OpKernelType(framework::proto::VarType::FP32, return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.device_context(), layout, library); ctx.device_context());
} else if (x_vars[0]->IsType<framework::LoDTensorArray>()) { } else if (x_vars[0]->IsType<framework::LoDTensorArray>()) {
for (auto& x_var : x_vars) { for (auto& x_var : x_vars) {
auto& array = x_var->Get<framework::LoDTensorArray>(); auto& array = x_var->Get<framework::LoDTensorArray>();
for (auto& each : array) { for (auto& each : array) {
if (each.numel() != 0) { if (each.numel() != 0) {
return framework::OpKernelType(framework::ToDataType(each.type()), return framework::OpKernelType(framework::ToDataType(each.type()),
ctx.device_context(), layout, ctx.device_context());
library);
} }
} }
} }
...@@ -133,9 +116,6 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -133,9 +116,6 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "(vector<Tensor>) The input tensors of sum operator.") AddInput("X", "(vector<Tensor>) The input tensors of sum operator.")
.AsDuplicable(); .AsDuplicable();
AddOutput("Out", "(Tensor) The output tensor of sum operator.").Reuse("X"); AddOutput("Out", "(Tensor) The output tensor of sum operator.").Reuse("X");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
Sum operator. Sum operator.
...@@ -152,6 +132,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference { ...@@ -152,6 +132,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
framework::BlockDesc* block) const override { framework::BlockDesc* block) const override {
auto& inputs = op_desc.Input("X"); auto& inputs = op_desc.Input("X");
auto var_type = framework::proto::VarType::SELECTED_ROWS; auto var_type = framework::proto::VarType::SELECTED_ROWS;
for (auto& name : op_desc.Input("X")) { for (auto& name : op_desc.Input("X")) {
VLOG(10) << name << " " VLOG(10) << name << " "
<< block->FindRecursiveOrCreateVar(name).GetType(); << block->FindRecursiveOrCreateVar(name).GetType();
...@@ -225,7 +206,6 @@ namespace ops = paddle::operators; ...@@ -225,7 +206,6 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradMaker, REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradMaker,
ops::SumOpVarTypeInference); ops::SumOpVarTypeInference);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
sum, ops::SumKernel<paddle::platform::CPUDeviceContext, float>, sum, ops::SumKernel<paddle::platform::CPUDeviceContext, float>,
ops::SumKernel<paddle::platform::CPUDeviceContext, double>, ops::SumKernel<paddle::platform::CPUDeviceContext, double>,
......
...@@ -21,7 +21,7 @@ limitations under the License. */ ...@@ -21,7 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h" #include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/operators/math/selected_rows_functor.h"
...@@ -37,11 +37,11 @@ USE_NO_KERNEL_OP(listen_and_serv); ...@@ -37,11 +37,11 @@ USE_NO_KERNEL_OP(listen_and_serv);
namespace f = paddle::framework; namespace f = paddle::framework;
namespace p = paddle::platform; namespace p = paddle::platform;
namespace m = paddle::operators::math; namespace m = paddle::operators::math;
namespace detail = paddle::operators::detail; namespace distributed = paddle::operators::distributed;
namespace string = paddle::string; namespace string = paddle::string;
std::unique_ptr<detail::RPCServer> g_rpc_service; std::unique_ptr<distributed::RPCServer> g_rpc_service;
std::unique_ptr<detail::RequestHandler> g_req_handler; std::unique_ptr<distributed::RequestHandler> g_req_handler;
void StartServer() { void StartServer() {
f::Scope scope; f::Scope scope;
...@@ -57,14 +57,14 @@ void StartServer() { ...@@ -57,14 +57,14 @@ void StartServer() {
g_req_handler->SetProgram(&empty_program); g_req_handler->SetProgram(&empty_program);
g_req_handler->SetExecutor(&executor); g_req_handler->SetExecutor(&executor);
g_rpc_service->RegisterRPC(detail::kRequestSend, g_req_handler.get()); g_rpc_service->RegisterRPC(distributed::kRequestSend, g_req_handler.get());
g_req_handler->SetRPCServer(g_rpc_service.get()); g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread( std::thread server_thread(
std::bind(&detail::RPCServer::StartServer, g_rpc_service.get())); std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get()));
g_rpc_service->SetCond(detail::kRequestSend); g_rpc_service->SetCond(distributed::kRequestSend);
g_rpc_service->WaitBarrier(detail::kRequestSend); g_rpc_service->WaitBarrier(distributed::kRequestSend);
LOG(INFO) << "got nccl id and stop server..."; LOG(INFO) << "got nccl id and stop server...";
g_rpc_service->ShutDown(); g_rpc_service->ShutDown();
...@@ -72,7 +72,7 @@ void StartServer() { ...@@ -72,7 +72,7 @@ void StartServer() {
} }
TEST(SendNcclId, RPCServer) { TEST(SendNcclId, RPCServer) {
g_req_handler.reset(new detail::RequestSendHandler(true)); g_req_handler.reset(new distributed::RequestSendHandler(true));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1)); g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
std::thread server_thread(StartServer); std::thread server_thread(StartServer);
...@@ -91,7 +91,8 @@ TEST(SendNcclId, RPCServer) { ...@@ -91,7 +91,8 @@ TEST(SendNcclId, RPCServer) {
std::string ep = string::Sprintf("127.0.0.1:%d", port); std::string ep = string::Sprintf("127.0.0.1:%d", port);
detail::RPCClient* client = detail::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
LOG(INFO) << "connect to server" << ep; LOG(INFO) << "connect to server" << ep;
client->AsyncSendVar(ep, dev_ctx, scope, NCCL_ID_VARNAME); client->AsyncSendVar(ep, dev_ctx, scope, NCCL_ID_VARNAME);
......
...@@ -203,11 +203,11 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -203,11 +203,11 @@ class WhileGradOp : public framework::OperatorBase {
->set_lod(inside_tensor.lod()); ->set_lod(inside_tensor.lod());
} }
} }
auto new_inside_name = cur_scope.Rename(inside_grad_name); auto new_inside_name = cur_scope.Rename(inside_grad_name);
auto sum_op = framework::OpRegistry::CreateOp( auto sum_op = framework::OpRegistry::CreateOp(
"sum", {{"X", {pg_names[param_id], new_inside_name}}}, "sum", {{"X", {pg_names[param_id], new_inside_name}}},
{{"Out", {pg_names[param_id]}}}, {{"Out", {pg_names[param_id]}}}, framework::AttributeMap{});
framework::AttributeMap{{"use_mkldnn", {false}}});
sum_op->Run(cur_scope, dev_place); sum_op->Run(cur_scope, dev_place);
cur_scope.Rename(new_inside_name, inside_grad_name); cur_scope.Rename(new_inside_name, inside_grad_name);
} }
......
...@@ -99,11 +99,5 @@ inline mkldnn::memory::format GetMKLDNNFormat(const mkldnn::memory memory) { ...@@ -99,11 +99,5 @@ inline mkldnn::memory::format GetMKLDNNFormat(const mkldnn::memory memory) {
memory.get_primitive_desc().desc().data.format); memory.get_primitive_desc().desc().data.format);
} }
inline mkldnn::memory::format GetMKLDNNFormat(
const mkldnn::sum::primitive_desc& memory) {
return static_cast<mkldnn::memory::format>(
memory.dst_primitive_desc().desc().data.format);
}
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -132,9 +132,9 @@ def _addup_repetitive_outputs_(op_descs): ...@@ -132,9 +132,9 @@ def _addup_repetitive_outputs_(op_descs):
for idx, op_desc in enumerate(op_descs): for idx, op_desc in enumerate(op_descs):
for var_name in op_desc.input_arg_names(): for var_name in op_desc.input_arg_names():
if len(renamed_vars[var_name]) > 1: if len(renamed_vars[var_name]) > 1:
pending_sum_ops.append((_create_op_desc_( pending_sum_ops.append(
"sum", {"X": renamed_vars[var_name]}, {"Out": [var_name]}, (_create_op_desc_("sum", {"X": renamed_vars[var_name]},
{"use_mkldnn": False}), idx)) {"Out": [var_name]}, {}), idx))
renamed_vars[var_name] = [var_name] renamed_vars[var_name] = [var_name]
for var_name in op_desc.output_arg_names(): for var_name in op_desc.output_arg_names():
if var_name == core.empty_var_name( if var_name == core.empty_var_name(
...@@ -161,9 +161,8 @@ def _addup_repetitive_outputs_(op_descs): ...@@ -161,9 +161,8 @@ def _addup_repetitive_outputs_(op_descs):
renamed_vars[var_name].append(new_name) renamed_vars[var_name].append(new_name)
for var_name, inputs in renamed_vars.iteritems(): for var_name, inputs in renamed_vars.iteritems():
if len(inputs) > 1: if len(inputs) > 1:
pending_sum_ops.append( pending_sum_ops.append((_create_op_desc_(
(_create_op_desc_("sum", {"X": inputs}, {"Out": [var_name]}, "sum", {"X": inputs}, {"Out": [var_name]}, {}), len(op_descs)))
{"use_mkldnn": False}), len(op_descs)))
# sum_op descs are sorted according to their insert position # sum_op descs are sorted according to their insert position
for p in reversed(pending_sum_ops): for p in reversed(pending_sum_ops):
op_descs.insert(p[1], p[0]) op_descs.insert(p[1], p[0])
......
...@@ -41,7 +41,12 @@ def _clone_var_(block, var): ...@@ -41,7 +41,12 @@ def _clone_var_(block, var):
class Evaluator(object): class Evaluator(object):
""" """
Base Class for all evaluators Warning: better to use the fluid.metrics.* things, more
flexible support via pure Python and Operator, and decoupled
with executor. Short doc are intended to urge new user
start from Metrics.
Base Class for all evaluators.
Args: Args:
name(str): The name of evaluator. such as, "accuracy". Used for generate name(str): The name of evaluator. such as, "accuracy". Used for generate
...@@ -69,6 +74,10 @@ class Evaluator(object): ...@@ -69,6 +74,10 @@ class Evaluator(object):
def reset(self, executor, reset_program=None): def reset(self, executor, reset_program=None):
""" """
reset metric states at the begin of each pass/user specified batch reset metric states at the begin of each pass/user specified batch
Args:
executor(Executor|ParallelExecutor): a executor for executing the reset_program
reset_program(Program): a single Program for reset process
""" """
if reset_program is None: if reset_program is None:
reset_program = Program() reset_program = Program()
...@@ -85,15 +94,16 @@ class Evaluator(object): ...@@ -85,15 +94,16 @@ class Evaluator(object):
def eval(self, executor, eval_program=None): def eval(self, executor, eval_program=None):
""" """
Evaluate the statistics merged by multiple mini-batches. Evaluate the statistics merged by multiple mini-batches.
Args:
executor(Executor|ParallelExecutor): a executor for executing the eval_program
eval_program(Program): a single Program for eval process
""" """
raise NotImplementedError() raise NotImplementedError()
def create_state(self, suffix, dtype, shape): def _create_state(self, suffix, dtype, shape):
""" """
Create state variable. Create state variable.
NOTE: It is not a public API.
Args: Args:
suffix(str): the state suffix. suffix(str): the state suffix.
dtype(str|core.VarDesc.VarType): the state data type dtype(str|core.VarDesc.VarType): the state data type
...@@ -113,9 +123,35 @@ class Evaluator(object): ...@@ -113,9 +123,35 @@ class Evaluator(object):
class ChunkEvaluator(Evaluator): class ChunkEvaluator(Evaluator):
""" """
Warning: This would be deprecated in the future. Please use fluid.metrics.ChunkEvaluator
instead.
Accumulate counter numbers output by chunk_eval from mini-batches and Accumulate counter numbers output by chunk_eval from mini-batches and
compute the precision recall and F1-score using the accumulated counter compute the precision recall and F1-score using the accumulated counter
numbers. numbers.
For some basics of chunking, please refer to
'Chunking with Support Vector Machines <https://aclanthology.info/pdf/N/N01/N01-1025.pdf>'.
Args:
input (Variable): prediction output of the network.
label (Variable): label of the test data set.
chunk_scheme (str): can be IOB/IOE/IOBES and IO. See the chunk_eval op for details.
num_chunk_types (int): the number of chunk type.
excluded_chunk_types (list): A list including chunk type ids, indicating chunk types that are not counted.
Returns:
tuple: tuple containing: precision, recall, f1_score
Examples:
.. code-block:: python
exe = fluid.executor(place)
evaluator = fluid.Evaluator.ChunkEvaluator(input, label)
for epoch in PASS_NUM:
evaluator.reset(exe)
for data in batches:
loss = exe.run(fetch_list=[cost])
distance, instance_error = distance_evaluator.eval(exe)
""" """
def __init__( def __init__(
...@@ -130,11 +166,11 @@ class ChunkEvaluator(Evaluator): ...@@ -130,11 +166,11 @@ class ChunkEvaluator(Evaluator):
if main_program.current_block().idx != 0: if main_program.current_block().idx != 0:
raise ValueError("You can only invoke Evaluator in root block") raise ValueError("You can only invoke Evaluator in root block")
self.num_infer_chunks = self.create_state( self.num_infer_chunks = self._create_state(
dtype='int64', shape=[1], suffix='num_infer_chunks') dtype='int64', shape=[1], suffix='num_infer_chunks')
self.num_label_chunks = self.create_state( self.num_label_chunks = self._create_state(
dtype='int64', shape=[1], suffix='num_label_chunks') dtype='int64', shape=[1], suffix='num_label_chunks')
self.num_correct_chunks = self.create_state( self.num_correct_chunks = self._create_state(
dtype='int64', shape=[1], suffix='num_correct_chunks') dtype='int64', shape=[1], suffix='num_correct_chunks')
precision, recall, f1_score, num_infer_chunks, num_label_chunks, num_correct_chunks = layers.chunk_eval( precision, recall, f1_score, num_infer_chunks, num_label_chunks, num_correct_chunks = layers.chunk_eval(
input=input, input=input,
...@@ -178,6 +214,8 @@ class ChunkEvaluator(Evaluator): ...@@ -178,6 +214,8 @@ class ChunkEvaluator(Evaluator):
class EditDistance(Evaluator): class EditDistance(Evaluator):
""" """
Warning: This would be deprecated in the future. Please use fluid.metrics.EditDistance
instead.
Accumulate edit distance sum and sequence number from mini-batches and Accumulate edit distance sum and sequence number from mini-batches and
compute the average edit_distance and instance error of all batches. compute the average edit_distance and instance error of all batches.
...@@ -188,7 +226,8 @@ class EditDistance(Evaluator): ...@@ -188,7 +226,8 @@ class EditDistance(Evaluator):
ignored_tokens(list of int): Tokens that should be removed before ignored_tokens(list of int): Tokens that should be removed before
calculating edit distance. calculating edit distance.
Example: Examples:
.. code-block:: python
exe = fluid.executor(place) exe = fluid.executor(place)
distance_evaluator = fluid.Evaluator.EditDistance(input, label) distance_evaluator = fluid.Evaluator.EditDistance(input, label)
...@@ -210,11 +249,11 @@ class EditDistance(Evaluator): ...@@ -210,11 +249,11 @@ class EditDistance(Evaluator):
if main_program.current_block().idx != 0: if main_program.current_block().idx != 0:
raise ValueError("You can only invoke Evaluator in root block") raise ValueError("You can only invoke Evaluator in root block")
self.total_distance = self.create_state( self.total_distance = self._create_state(
dtype='float32', shape=[1], suffix='total_distance') dtype='float32', shape=[1], suffix='total_distance')
self.seq_num = self.create_state( self.seq_num = self._create_state(
dtype='int64', shape=[1], suffix='seq_num') dtype='int64', shape=[1], suffix='seq_num')
self.instance_error = self.create_state( self.instance_error = self._create_state(
dtype='int64', shape=[1], suffix='instance_error') dtype='int64', shape=[1], suffix='instance_error')
distances, seq_num = layers.edit_distance( distances, seq_num = layers.edit_distance(
input=input, label=label, ignored_tokens=ignored_tokens) input=input, label=label, ignored_tokens=ignored_tokens)
...@@ -256,9 +295,10 @@ class EditDistance(Evaluator): ...@@ -256,9 +295,10 @@ class EditDistance(Evaluator):
class DetectionMAP(Evaluator): class DetectionMAP(Evaluator):
""" """
Warning: This would be deprecated in the future. Please use fluid.metrics.DetectionMAP
instead.
Calculate the detection mean average precision (mAP). Calculate the detection mean average precision (mAP).
TODO (Dang Qingqing): update the following doc.
The general steps are as follows: The general steps are as follows:
1. calculate the true positive and false positive according to the input 1. calculate the true positive and false positive according to the input
of detection and labels. of detection and labels.
...@@ -293,7 +333,8 @@ class DetectionMAP(Evaluator): ...@@ -293,7 +333,8 @@ class DetectionMAP(Evaluator):
- 11point: the 11-point interpolated average precision. - 11point: the 11-point interpolated average precision.
- integral: the natural integral of the precision-recall curve. - integral: the natural integral of the precision-recall curve.
Example: Examples:
.. code-block:: python
exe = fluid.executor(place) exe = fluid.executor(place)
map_evaluator = fluid.Evaluator.DetectionMAP(input, map_evaluator = fluid.Evaluator.DetectionMAP(input,
...@@ -340,9 +381,10 @@ class DetectionMAP(Evaluator): ...@@ -340,9 +381,10 @@ class DetectionMAP(Evaluator):
evaluate_difficult=evaluate_difficult, evaluate_difficult=evaluate_difficult,
ap_version=ap_version) ap_version=ap_version)
self.create_state(dtype='int32', shape=None, suffix='accum_pos_count') self._create_state(dtype='int32', shape=None, suffix='accum_pos_count')
self.create_state(dtype='float32', shape=None, suffix='accum_true_pos') self._create_state(dtype='float32', shape=None, suffix='accum_true_pos')
self.create_state(dtype='float32', shape=None, suffix='accum_false_pos') self._create_state(
dtype='float32', shape=None, suffix='accum_false_pos')
self.has_state = None self.has_state = None
var = self.helper.create_variable( var = self.helper.create_variable(
......
...@@ -18,7 +18,7 @@ from framework import Program, default_main_program, Variable ...@@ -18,7 +18,7 @@ from framework import Program, default_main_program, Variable
from . import core from . import core
__all__ = [ __all__ = [
'Executor', 'global_scope', 'scope_guard', 'switch_scope', 'fetch_var' 'Executor', 'global_scope', 'scope_guard', '_switch_scope', 'fetch_var'
] ]
g_scope = core.Scope() g_scope = core.Scope()
...@@ -35,7 +35,7 @@ def global_scope(): ...@@ -35,7 +35,7 @@ def global_scope():
return g_scope return g_scope
def switch_scope(scope): def _switch_scope(scope):
global g_scope global g_scope
ex = g_scope ex = g_scope
g_scope = scope g_scope = scope
...@@ -57,12 +57,27 @@ def scope_guard(scope): ...@@ -57,12 +57,27 @@ def scope_guard(scope):
Args: Args:
scope: The new global/default scope. scope: The new global/default scope.
""" """
ex = switch_scope(scope) ex = _switch_scope(scope)
yield yield
switch_scope(ex) _switch_scope(ex)
def as_numpy(tensor): def as_numpy(tensor):
"""
Convert a Tensor to a numpy.ndarray, its only support Tensor without LoD information.
For higher dimensional sequence data, please use LoDTensor directly.
Examples:
>>> import paddle.fluid as fluid
>>> outs = executor.run(...)
>>> np_outs = map(lambda x: as_numpy(x), outs)
>>> ...
Args:
tensor(Variable): a instance of Tensor
Returns:
numpy.ndarray
"""
if isinstance(tensor, list): if isinstance(tensor, list):
return [as_numpy(t) for t in tensor] return [as_numpy(t) for t in tensor]
assert isinstance(tensor, core.LoDTensor) assert isinstance(tensor, core.LoDTensor)
...@@ -186,7 +201,7 @@ def fetch_var(name, scope=None, return_numpy=True): ...@@ -186,7 +201,7 @@ def fetch_var(name, scope=None, return_numpy=True):
return tensor return tensor
def get_program_cache_key(feed, fetch_list): def _get_program_cache_key(feed, fetch_list):
feed_var_names = feed.keys() feed_var_names = feed.keys()
def to_name_str(var): def to_name_str(var):
...@@ -205,6 +220,25 @@ def get_program_cache_key(feed, fetch_list): ...@@ -205,6 +220,25 @@ def get_program_cache_key(feed, fetch_list):
class Executor(object): class Executor(object):
"""
An Executor in Python, only support the single-GPU running. For multi-cards, please refer to
ParallelExecutor.
Python executor takes a program, add feed operators and fetch operators to this program according
to feed map and fetch_list. Feed map provides input data for the program. fetch_list provides
the variables(or names) that user want to get after program run. Note: the executor will run all
operators in the program but not only the operators dependent by the fetch_list.
It store the global variables into the global scope, and create a local scope for the temporary
variables. The local scope contents will be discarded after every minibatch forward/backward finished.
But the global scope variables will be persistent through different runs.
All of ops in program will be running in sequence.
Args:
place(core.CPUPlace|core.CUDAPlace(n)): indicate the executor run on which device
Note: For debugging complicated network in parallel-GPUs, you can test it on the executor.
They has the exactly same arguments, and expected the same results.
"""
def __init__(self, place): def __init__(self, place):
self.place = place self.place = place
p = core.Place() p = core.Place()
...@@ -213,6 +247,23 @@ class Executor(object): ...@@ -213,6 +247,23 @@ class Executor(object):
self.program_caches = dict() self.program_caches = dict()
def as_lodtensor(self, data): def as_lodtensor(self, data):
"""
Convert numpy.ndarray to Tensor, its only support Tensor without LoD information.
For higher dimensional sequence data, please use LoDTensor directly.
Examples:
>>> import paddle.fluid as fluid
>>> exe = fluid.executor(fluid.CPUPlace())
>>> data = np.array(size=(100, 200, 300))
>>> np_outs = map(lambda x: exe.as_lodtensor(x), data)
>>> ...
Args:
data(numpy.ndarray): a instance of array
Returns:
LoDTensor
"""
if isinstance(data, list): if isinstance(data, list):
raise RuntimeError("Some of your feed data hold LoD information. \ raise RuntimeError("Some of your feed data hold LoD information. \
They can not be completely cast from a list of Python \ They can not be completely cast from a list of Python \
...@@ -304,23 +355,47 @@ class Executor(object): ...@@ -304,23 +355,47 @@ class Executor(object):
scope=None, scope=None,
return_numpy=True, return_numpy=True,
use_program_cache=False): use_program_cache=False):
""" Run program by this Executor. Feed data by feed map, fetch result by fetch_list. """
Run program by this Executor. Feed data by feed map, fetch result by fetch_list.
Python executor takes a program, add feed operators and fetch operators to this program according Python executor takes a program, add feed operators and fetch operators to this program according
to feed map and fetch_list. Feed map provides input data for the program. fetch_list provides to feed map and fetch_list. Feed map provides input data for the program. fetch_list provides
the variables(or names) that user want to get after program run. Note: the executor will run all the variables(or names) that user want to get after program run.
Note: the executor will run all
operators in the program but not only the operators dependent by the fetch_list operators in the program but not only the operators dependent by the fetch_list
:param program: the program that need to run, if not provied, then default_main_program will be used. Args:
:param feed: feed variable map, e.g. {"image": ImageData, "label": LableData} program(Program): the program that need to run, if not provied, then default_main_program will be used.
:param fetch_list: a list of variable or variable names that user want to get, run will return them according feed(dict): feed variable map, e.g. {"image": ImageData, "label": LableData}
to this list. fetch_list(list): a list of variable or variable names that user want to get, run will return them according to this list.
:param feed_var_name: the name for the input variable of feed Operator. feed_var_name(str): the name for the input variable of feed Operator.
:param fetch_var_name: the name for the output variable of feed Operator. fetch_var_name(str): the name for the output variable of fetch Operator.
:param scope: the scope used to run this program, you can switch it to different scope. default is global_scope scope(Scope): the scope used to run this program, you can switch it to different scope. default is global_scope
:param return_numpy: if convert the fetched tensor to numpy return_numpy(bool): if convert the fetched tensor to numpy
:param use_program_cache: set use_program_cache to true if program not changed compare to the last step. use_program_cache(bool): set use_program_cache to true if program not changed compare to the last step.
:return: result according to fetch_list.
Returns:
list(numpy.array): fetch result according to fetch_list.
Examples:
>>> data = layers.data(name='X', shape=[1], dtype='float32')
>>> hidden = layers.fc(input=data, size=10)
>>> layers.assign(hidden, out)
>>> loss = layers.mean(out)
>>> adam = fluid.optimizer.Adam()
>>> adam.minimize(loss)
>>> cpu = core.CPUPlace()
>>> exe = Executor(cpu)
>>> exe.run(default_startup_program())
>>> x = numpy.random.random(size=(10, 1)).astype('float32')
>>> outs = exe.run(
>>> feed={'X': x},
>>> fetch_list=[loss.name])
""" """
if feed is None: if feed is None:
feed = {} feed = {}
...@@ -341,7 +416,7 @@ class Executor(object): ...@@ -341,7 +416,7 @@ class Executor(object):
if scope is None: if scope is None:
scope = global_scope() scope = global_scope()
cache_key = get_program_cache_key(feed, fetch_list) cache_key = _get_program_cache_key(feed, fetch_list)
if use_program_cache: if use_program_cache:
cached_program = self._get_program_cache(cache_key) cached_program = self._get_program_cache(cache_key)
if cached_program is None: if cached_program is None:
......
...@@ -28,8 +28,8 @@ import math_op_patch ...@@ -28,8 +28,8 @@ import math_op_patch
from math_op_patch import * from math_op_patch import *
import detection import detection
from detection import * from detection import *
import metric import metric_op
from metric import * from metric_op import *
from learning_rate_scheduler import * from learning_rate_scheduler import *
__all__ = [] __all__ = []
...@@ -41,5 +41,5 @@ __all__ += control_flow.__all__ ...@@ -41,5 +41,5 @@ __all__ += control_flow.__all__
__all__ += ops.__all__ __all__ += ops.__all__
__all__ += device.__all__ __all__ += device.__all__
__all__ += detection.__all__ __all__ += detection.__all__
__all__ += metric.__all__ __all__ += metric_op.__all__
__all__ += learning_rate_scheduler.__all__ __all__ += learning_rate_scheduler.__all__
...@@ -126,7 +126,7 @@ def auc(input, label, curve='ROC', num_thresholds=200): ...@@ -126,7 +126,7 @@ def auc(input, label, curve='ROC', num_thresholds=200):
topk_out, topk_indices = nn.topk(input, k=k) topk_out, topk_indices = nn.topk(input, k=k)
auc_out = helper.create_tmp_variable(dtype="float32") auc_out = helper.create_tmp_variable(dtype="float32")
helper.append_op( helper.append_op(
type="accuracy", type="auc",
inputs={ inputs={
"Out": [topk_out], "Out": [topk_out],
"Indices": [topk_indices], "Indices": [topk_indices],
......
...@@ -198,10 +198,7 @@ def fc(input, ...@@ -198,10 +198,7 @@ def fc(input,
else: else:
pre_bias = helper.create_tmp_variable(dtype) pre_bias = helper.create_tmp_variable(dtype)
helper.append_op( helper.append_op(
type="sum", type="sum", inputs={"X": mul_results}, outputs={"Out": pre_bias})
inputs={"X": mul_results},
outputs={"Out": pre_bias},
attrs={"use_mkldnn": use_mkldnn})
# add bias # add bias
pre_activation = helper.append_bias_op(pre_bias, dim_start=num_flatten_dims) pre_activation = helper.append_bias_op(pre_bias, dim_start=num_flatten_dims)
# add activation # add activation
......
...@@ -230,11 +230,7 @@ def sums(input, out=None): ...@@ -230,11 +230,7 @@ def sums(input, out=None):
helper = LayerHelper('sum', **locals()) helper = LayerHelper('sum', **locals())
if out is None: if out is None:
out = helper.create_tmp_variable(dtype=helper.input_dtype()) out = helper.create_tmp_variable(dtype=helper.input_dtype())
helper.append_op( helper.append_op(type='sum', inputs={'X': input}, outputs={'Out': out})
type='sum',
inputs={'X': input},
outputs={'Out': out},
attrs={'use_mkldnn': False})
return out return out
......
...@@ -23,6 +23,8 @@ import warnings ...@@ -23,6 +23,8 @@ import warnings
__all__ = [ __all__ = [
'MetricBase', 'MetricBase',
'CompositeMetric', 'CompositeMetric',
'Precision',
'Recall',
'Accuracy', 'Accuracy',
'ChunkEvaluator', 'ChunkEvaluator',
'EditDistance', 'EditDistance',
...@@ -46,33 +48,34 @@ def _is_number_or_matrix_(var): ...@@ -46,33 +48,34 @@ def _is_number_or_matrix_(var):
class MetricBase(object): class MetricBase(object):
""" """
Base Class for all evaluators Base Class for all Metrics.
MetricBase define a group of interfaces for the
model evaluation methods. Metrics accumulate metric states between
consecutive minibatches, at every minibatch, use update
interface to add current minibatch value to global states.
Use eval to compute accumative metric value from last reset()
or from scratch on.
If you need to custom a new metric, please inherit from MetricBase and
custom implementation.
Args: Args:
name(str): The name of evaluator. such as, "accuracy". Used for generate name(str): The name of metric instance. such as, "accuracy".
temporary variable name. It needed if you want to distinct different metrics in a model.
Interface:
Note(*) : the states is the attributes who not has _ prefix.
get_config(): print current states and configuration
reset(): clear the states. If the Metrics states type is not (int, float, np.ndarray),
Please override this method.
update(): update states at every minibatch
eval(): get metric evaluation in numpy type.
""" """
def __init__(self, name, **kwargs): def __init__(self, name):
self._name = str(name) if name != None else self.__class__.__name__ self._name = str(name) if name != None else self.__class__.__name__
self._kwargs = kwargs if kwargs != None else dict()
self.reset()
def __str__(self): def __str__(self):
return self._name return self._name
def reset(self): def reset(self):
""" """
states is the attributes who not has _ prefix. reset clear the states of metrics. By default, the states
reset the states of metrics. are the members who do not has _ prefix, reset set them to inital states.
If you violate the implicit name rule, please also custom the reset
interface.
""" """
states = { states = {
attr: value attr: value
...@@ -90,61 +93,231 @@ class MetricBase(object): ...@@ -90,61 +93,231 @@ class MetricBase(object):
setattr(self, attr, None) setattr(self, attr, None)
def get_config(self): def get_config(self):
"""
Get the metric and current states.
The states are the members who do not has "_" prefix.
Args:
None
Returns:
dict: a dict of metric and states
"""
states = { states = {
attr: value attr: value
for attr, value in self.__dict__.iteritems() for attr, value in self.__dict__.iteritems()
if not attr.startswith("_") if not attr.startswith("_")
} }
config = copy.deepcopy(self._kwargs) config = {}
config.update({"name": self._name, "states": copy.deepcopy(states)}) config.update({"name": self._name, "states": copy.deepcopy(states)})
return config return config
def update(self): def update(self, preds, labels):
raise NotImplementedError() """
Updates the metric states at every minibatch.
One user can compute the minibatch metric via pure Python, or
via a c++ operator.
Args:
preds(numpy.array): the predictions of current minibatch
labels(numpy.array): the labels of current minibatch, if the label is one-hot
or soft-label, should custom the corresponding update rule.
"""
raise NotImplementedError(
"Should not use it directly, please extend it.")
def eval(self): def eval(self):
raise NotImplementedError() """
Evalute the current metrics based the accumulated states.
Returns:
float|list(float)|numpy.array: the metrics via Python.
"""
raise NotImplementedError(
"Should not use it directly, please extend it.")
class CompositeMetric(MetricBase): class CompositeMetric(MetricBase):
""" """
Compute multiple metrics in each minibatch. Composite multiple metrics in one instance.
for example, merge F1, accuracy, recall into one Metric. for example, merge F1, accuracy, recall into one Metric.
Examples:
.. code-block:: python
labels = fluid.layers.data(name="data", shape=[1], dtype="int32")
data = fluid.layers.data(name="data", shape=[32, 32], dtype="int32")
pred = fluid.layers.fc(input=data, size=1000, act="tanh")
comp = fluid.metrics.CompositeMetric()
acc = fluid.metrics.Precision()
recall = fluid.metrics.Recall()
comp.add_metric(acc)
comp.add_metric(recall)
for pass in range(PASSES):
comp.reset()
for data in train_reader():
loss, preds, labels = exe.run(fetch_list=[cost, preds, labels])
comp.update(preds=preds, labels=labels)
numpy_acc, numpy_recall = comp.eval()
""" """
def __init__(self, name=None, **kwargs): def __init__(self, name=None):
super(CompositeMetric, self).__init__(name, kwargs) super(CompositeMetric, self).__init__(name)
self._metrics = [] self._metrics = []
def add_metric(self, metric): def add_metric(self, metric):
"""
add one metric instance to CompositeMetric.
Args:
metric: a instance of MetricBase.
"""
if not isinstance(metric, MetricBase): if not isinstance(metric, MetricBase):
raise ValueError("SubMetric should be inherit from MetricBase.") raise ValueError("SubMetric should be inherit from MetricBase.")
self._metrics.append(metric) self._metrics.append(metric)
def update(self, preds, labels):
"""
Update every metrics in sequence.
Args:
preds(numpy.array): the predictions of current minibatch
labels(numpy.array): the labels of current minibatch, if the label is one-hot
or soft-label, should custom the corresponding update rule.
"""
for m in self._metrics:
ans.append(m.update(preds, labels))
def eval(self): def eval(self):
"""
Evaluate every metrics in sequence.
Returns:
list(float|numpy.array): a list of metrics value in Python.
"""
ans = [] ans = []
for m in self._metrics: for m in self._metrics:
ans.append(m.eval()) ans.append(m.eval())
return ans return ans
class Precision(MetricBase):
"""
Precision (also called positive predictive value) is the fraction of
relevant instances among the retrieved instances.
https://en.wikipedia.org/wiki/Evaluation_of_binary_classifiers
Note Precision is different with Accuracy in binary classifiers.
accuracy = true positive / total instances
precision = true positive / all positive instance
Examples:
.. code-block:: python
metric = fluid.metrics.Precision()
for pass in range(PASSES):
metric.reset()
for data in train_reader():
loss, preds, labels = exe.run(fetch_list=[cost, preds, labels])
metric.update(preds=preds, labels=labels)
numpy_precision = metric.eval()
"""
def __init__(self, name=None):
super(Precision, self).__init__(name)
self.tp = 0 # true positive
self.fp = 0 # false positive
def update(self, preds, labels):
if not _is_numpy_(preds):
raise ValueError("The 'preds' must be a numpy ndarray.")
if not _is_numpy_(labels):
raise ValueError("The 'labels' must be a numpy ndarray.")
sample_num = labels[0]
for i in range(sample_num):
pred = preds[i].astype("int32")
label = labels[i]
if label == 1:
if pred == label:
self.tp += 1
else:
self.fp += 1
def eval(self):
ap = self.tp + self.fp
return float(self.tp) / ap if ap != 0 else .0
class Recall(MetricBase):
"""
Recall (also known as sensitivity) is the fraction of
relevant instances that have been retrieved over the
total amount of relevant instances
https://en.wikipedia.org/wiki/Precision_and_recall
Examples:
.. code-block:: python
metric = fluid.metrics.Recall()
for pass in range(PASSES):
metric.reset()
for data in train_reader():
loss, preds, labels = exe.run(fetch_list=[cost, preds, labels])
metric.update(preds=preds, labels=labels)
numpy_recall = metric.eval()
"""
def __init__(self, name=None):
super(Recall, self).__init__(name)
self.tp = 0 # true positive
self.fn = 0 # false negtive
def update(self, preds, labels):
if not _is_numpy_(preds):
raise ValueError("The 'preds' must be a numpy ndarray.")
if not _is_numpy_(labels):
raise ValueError("The 'labels' must be a numpy ndarray.")
sample_num = labels[0]
for i in range(sample_num):
pred = preds[i].astype("int32")
label = labels[i]
if label == 1:
if pred == label:
self.tp += 1
else:
if pred != label:
self.fn += 1
def eval(self):
recall = self.tp + self.fn
return float(self.tp) / recall if recall != 0 else .0
class Accuracy(MetricBase): class Accuracy(MetricBase):
""" """
Accumulate the accuracy from minibatches and compute the average accuracy Accumulate the accuracy from minibatches and compute the average accuracy
for every pass. for every pass.
https://en.wikipedia.org/wiki/Accuracy_and_precision
Args: Args:
name: the metrics name name: the metrics name
Example: Examples:
.. code-block:: python
labels = fluid.layers.data(name="data", shape=[1], dtype="int32")
data = fluid.layers.data(name="data", shape=[32, 32], dtype="int32")
pred = fluid.layers.fc(input=data, size=1000, act="tanh")
minibatch_accuracy = fluid.layers.accuracy(pred, label) minibatch_accuracy = fluid.layers.accuracy(pred, label)
accuracy_evaluator = fluid.metrics.Accuracy() accuracy_evaluator = fluid.metrics.Accuracy()
for epoch in PASS_NUM: for pass in range(PASSES):
accuracy_evaluator.reset() accuracy_evaluator.reset()
for data in batches: for data in train_reader():
batch_size = data[0]
loss = exe.run(fetch_list=[cost, minibatch_accuracy]) loss = exe.run(fetch_list=[cost, minibatch_accuracy])
accuracy_evaluator.update(value=minibatch_accuracy, weight=batches) accuracy_evaluator.update(value=minibatch_accuracy, weight=batch_size)
accuracy = accuracy_evaluator.eval() numpy_acc = accuracy_evaluator.eval()
""" """
def __init__(self, name=None): def __init__(self, name=None):
...@@ -153,6 +326,13 @@ class Accuracy(MetricBase): ...@@ -153,6 +326,13 @@ class Accuracy(MetricBase):
self.weight = .0 self.weight = .0
def update(self, value, weight): def update(self, value, weight):
"""
Update minibatch states.
Args:
value(float|numpy.array): accuracy of one minibatch.
weight(int|float): batch size.
"""
if not _is_number_or_matrix_(value): if not _is_number_or_matrix_(value):
raise ValueError( raise ValueError(
"The 'value' must be a number(int, float) or a numpy ndarray.") "The 'value' must be a number(int, float) or a numpy ndarray.")
...@@ -163,9 +343,8 @@ class Accuracy(MetricBase): ...@@ -163,9 +343,8 @@ class Accuracy(MetricBase):
def eval(self): def eval(self):
if self.weight == 0: if self.weight == 0:
raise ValueError( raise ValueError("There is no data in Accuracy Metrics. \
"There is no data in Accuracy Metrics. Please check layers.accuracy output has added to Accuracy." Please check layers.accuracy output has added to Accuracy.")
)
return self.value / self.weight return self.value / self.weight
...@@ -174,6 +353,25 @@ class ChunkEvaluator(MetricBase): ...@@ -174,6 +353,25 @@ class ChunkEvaluator(MetricBase):
Accumulate counter numbers output by chunk_eval from mini-batches and Accumulate counter numbers output by chunk_eval from mini-batches and
compute the precision recall and F1-score using the accumulated counter compute the precision recall and F1-score using the accumulated counter
numbers. numbers.
For some basics of chunking, please refer to
'Chunking with Support Vector Machines <https://aclanthology.info/pdf/N/N01/N01-1025.pdf>'.
ChunkEvalEvaluator computes the precision, recall, and F1-score of chunk detection,
and supports IOB, IOE, IOBES and IO (also known as plain) tagging schemes.
Examples:
.. code-block:: python
labels = fluid.layers.data(name="data", shape=[1], dtype="int32")
data = fluid.layers.data(name="data", shape=[32, 32], dtype="int32")
pred = fluid.layers.fc(input=data, size=1000, act="tanh")
precision, recall, f1_score, num_infer_chunks, num_label_chunks, num_correct_chunks = layers.chunk_eval(
input=pred,
label=label)
metric = fluid.metrics.ChunkEvaluator()
for data in train_reader():
loss, preds, labels = exe.run(fetch_list=[cost, preds, labels])
metric.update(num_infer_chunks, num_label_chunks, num_correct_chunks)
numpy_precision, numpy_recall, numpy_f1 = metric.eval()
""" """
def __init__(self, name=None): def __init__(self, name=None):
...@@ -183,9 +381,17 @@ class ChunkEvaluator(MetricBase): ...@@ -183,9 +381,17 @@ class ChunkEvaluator(MetricBase):
self.num_correct_chunks = 0 self.num_correct_chunks = 0
def update(self, num_infer_chunks, num_label_chunks, num_correct_chunks): def update(self, num_infer_chunks, num_label_chunks, num_correct_chunks):
"""
Update the states based on the layers.chunk_eval() ouputs.
Args:
num_infer_chunks(int|numpy.array): The number of chunks in Inference on the given minibatch.
num_label_chunks(int|numpy.array): The number of chunks in Label on the given mini-batch.
num_correct_chunks(int|float|numpy.array): The number of chunks both in Inference and Label on the
given mini-batch.
"""
if not _is_number_or_matrix_(num_infer_chunks): if not _is_number_or_matrix_(num_infer_chunks):
raise ValueError( raise ValueError(
"The 'num_infer_chunks' must be a number(int, float) or a numpy ndarray." "The 'num_infer_chunks' must be a number(int) or a numpy ndarray."
) )
if not _is_number_or_matrix_(num_label_chunks): if not _is_number_or_matrix_(num_label_chunks):
raise ValueError( raise ValueError(
...@@ -212,20 +418,27 @@ class ChunkEvaluator(MetricBase): ...@@ -212,20 +418,27 @@ class ChunkEvaluator(MetricBase):
class EditDistance(MetricBase): class EditDistance(MetricBase):
""" """
Edit distance is a way of quantifying how dissimilar two strings
(e.g., words) are to one another by counting the minimum number
of operations required to transform one string into the other.
Refer to https://en.wikipedia.org/wiki/Edit_distance
Accumulate edit distance sum and sequence number from mini-batches and Accumulate edit distance sum and sequence number from mini-batches and
compute the average edit_distance and instance error of all batches. compute the average edit_distance and instance error of all batches.
Args: Args:
name: the metrics name name: the metrics name
Example: Examples:
edit_distance_metrics = fluid.layers.edit_distance(input, label) .. code-block:: python
distances, seq_num = fluid.layers.edit_distance(input, label)
distance_evaluator = fluid.metrics.EditDistance() distance_evaluator = fluid.metrics.EditDistance()
for epoch in PASS_NUM: for epoch in PASS_NUM:
distance_evaluator.reset() distance_evaluator.reset()
for data in batches: for data in batches:
loss = exe.run(fetch_list=[cost] + list(edit_distance_metrics)) loss = exe.run(fetch_list=[cost] + list(edit_distance_metrics))
distance_evaluator.update(*edit_distance_metrics) distance_evaluator.update(distances, seq_num)
distance, instance_error = distance_evaluator.eval() distance, instance_error = distance_evaluator.eval()
In the above example: In the above example:
...@@ -264,16 +477,38 @@ class EditDistance(MetricBase): ...@@ -264,16 +477,38 @@ class EditDistance(MetricBase):
class DetectionMAP(MetricBase): class DetectionMAP(MetricBase):
""" """
Calculate the detection mean average precision (mAP). Calculate the detection mean average precision (mAP).
mAP is the metric to measure the accuracy of object detectors
like Faster R-CNN, SSD, etc.
It is the average of the maximum precisions at different recall values.
Please get more information from the following articles:
https://sanchom.wordpress.com/tag/average-precision/
https://arxiv.org/abs/1512.02325
TODO (Dang Qingqing): update the following doc.
The general steps are as follows: The general steps are as follows:
1. calculate the true positive and false positive according to the input 1. calculate the true positive and false positive according to the input
of detection and labels. of detection and labels.
2. calculate mAP value, support two versions: '11 point' and 'integral'. 2. calculate mAP value, support two versions: '11 point' and 'integral'.
Please get more information from the following articles: Examples:
https://sanchom.wordpress.com/tag/average-precision/ .. code-block:: python
https://arxiv.org/abs/1512.02325
pred = fluid.layers.fc(input=data, size=1000, act="tanh")
batch_map = layers.detection_map(
input,
label,
class_num,
background_label,
overlap_threshold=overlap_threshold,
evaluate_difficult=evaluate_difficult,
ap_version=ap_version)
metric = fluid.metrics.DetectionMAP()
for data in train_reader():
loss, preds, labels = exe.run(fetch_list=[cost, batch_map])
batch_size = data[0]
metric.update(value=batch_map, weight=batch_size)
numpy_map = metric.eval()
""" """
def __init__(self, name=None): def __init__(self, name=None):
...@@ -302,8 +537,9 @@ class DetectionMAP(MetricBase): ...@@ -302,8 +537,9 @@ class DetectionMAP(MetricBase):
class Auc(MetricBase): class Auc(MetricBase):
""" """
Auc Metrics which adapts to binary classification. Auc metric adapts to the binary classification.
Need to note that auc metrics compute the value via Python natively. Refer to https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve
Need to note that auc metric compute the value via Python natively.
If you concern the speed, please use the fluid.layers.auc instead. If you concern the speed, please use the fluid.layers.auc instead.
The `auc` function creates four local variables, `true_positives`, The `auc` function creates four local variables, `true_positives`,
...@@ -322,6 +558,16 @@ class Auc(MetricBase): ...@@ -322,6 +558,16 @@ class Auc(MetricBase):
curve. curve.
"NOTE: only implement the ROC curve type via Python now." "NOTE: only implement the ROC curve type via Python now."
Examples:
.. code-block:: python
pred = fluid.layers.fc(input=data, size=1000, act="tanh")
metric = fluid.metrics.Auc()
for data in train_reader():
loss, preds, labels = exe.run(fetch_list=[cost, preds, labels])
metric.update(preds, labels)
numpy_auc = metric.eval()
""" """
def __init__(self, name, curve='ROC', num_thresholds=200): def __init__(self, name, curve='ROC', num_thresholds=200):
...@@ -334,10 +580,10 @@ class Auc(MetricBase): ...@@ -334,10 +580,10 @@ class Auc(MetricBase):
self.tn_list = np.zeros((num_thresholds, )) self.tn_list = np.zeros((num_thresholds, ))
self.fp_list = np.zeros((num_thresholds, )) self.fp_list = np.zeros((num_thresholds, ))
def update(self, labels, predictions, axis=1): def update(self, preds, labels):
if not _is_numpy_(labels): if not _is_numpy_(labels):
raise ValueError("The 'labels' must be a numpy ndarray.") raise ValueError("The 'labels' must be a numpy ndarray.")
if not _is_numpy_(predictions): if not _is_numpy_(preds):
raise ValueError("The 'predictions' must be a numpy ndarray.") raise ValueError("The 'predictions' must be a numpy ndarray.")
kepsilon = 1e-7 # to account for floating point imprecisions kepsilon = 1e-7 # to account for floating point imprecisions
......
...@@ -20,15 +20,12 @@ from op_test import OpTest ...@@ -20,15 +20,12 @@ from op_test import OpTest
class TestSumOp(OpTest): class TestSumOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "sum" self.op_type = "sum"
self.use_mkldnn = False
self.init_kernel_type()
x0 = np.random.random((3, 4)).astype('float32') x0 = np.random.random((3, 4)).astype('float32')
x1 = np.random.random((3, 4)).astype('float32') x1 = np.random.random((3, 4)).astype('float32')
x2 = np.random.random((3, 4)).astype('float32') x2 = np.random.random((3, 4)).astype('float32')
self.inputs = {"X": [("x0", x0), ("x1", x1), ("x2", x2)]} self.inputs = {"X": [("x0", x0), ("x1", x1), ("x2", x2)]}
y = x0 + x1 + x2 y = x0 + x1 + x2
self.outputs = {'Out': y} self.outputs = {'Out': y}
self.attrs = {'use_mkldnn': self.use_mkldnn}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -36,9 +33,6 @@ class TestSumOp(OpTest): ...@@ -36,9 +33,6 @@ class TestSumOp(OpTest):
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['x0'], 'Out') self.check_grad(['x0'], 'Out')
def init_kernel_type(self):
pass
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -880,8 +880,7 @@ class DistributeTranspiler(object): ...@@ -880,8 +880,7 @@ class DistributeTranspiler(object):
table_opt_block.append_op( table_opt_block.append_op(
type="sum", type="sum",
inputs={"X": pserver_side_table_grad_list}, inputs={"X": pserver_side_table_grad_list},
outputs={"Out": [grad_var]}, outputs={"Out": [grad_var]})
attrs={"use_mkldnn": False})
else: else:
# in async_mode, for table gradient, it also need to be splited to each parameter server # in async_mode, for table gradient, it also need to be splited to each parameter server
origin_grad_name = grad_var.name origin_grad_name = grad_var.name
...@@ -1113,8 +1112,7 @@ class DistributeTranspiler(object): ...@@ -1113,8 +1112,7 @@ class DistributeTranspiler(object):
optimize_block.append_op( optimize_block.append_op(
type="sum", type="sum",
inputs={"X": vars2merge}, inputs={"X": vars2merge},
outputs={"Out": merged_var}, outputs={"Out": merged_var})
attrs={"use_mkldnn": False})
# TODO(panyx0718): What if it's SELECTED_ROWS. # TODO(panyx0718): What if it's SELECTED_ROWS.
if not merged_var.type == core.VarDesc.VarType.SELECTED_ROWS: if not merged_var.type == core.VarDesc.VarType.SELECTED_ROWS:
optimize_block.append_op( optimize_block.append_op(
......
...@@ -12,15 +12,42 @@ ...@@ -12,15 +12,42 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import unittest import sys
import re
from test_sum_op import TestSumOp
def escape(input):
o = input.replace("\n", "")
o = o.replace("\r", "")
return o
class TestMKLDNN(TestSumOp):
def init_kernel_type(self):
self.use_mkldnn = True
def main():
usage = """Usage:
1. Download the Paddle_PR_CI_*.log from TeamCity
2. run: python check_ctest_hung.py Paddle_PR_CI_*.log
3. If there is hung ctest, the result likes:
Diff: set(['test_parallel_executor_crf'])
"""
if len(sys.argv) < 2:
print(usage)
exit(0)
if __name__ == '__main__': logfile = sys.argv[1]
unittest.main() started = set()
passed = set()
with open(logfile, "r") as fn:
for l in fn.readlines():
if l.find("Test ") != -1 and \
l.find("Passed") != -1:
m = re.search("Test\s+#[0-9]*\:\s([a-z0-9_]+)", escape(l))
passed.add(m.group(1))
if l.find("Start ") != -1:
start_parts = escape(l).split(" ")
m = re.search("Start\s+[0-9]+\:\s([a-z0-9_]+)", escape(l))
started.add(m.group(1))
print "Diff: ", started - passed
if __name__ == "__main__":
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册