diff --git a/doc/api/v2/fluid/nets.rst b/doc/api/v2/fluid/nets.rst
index f6b1cb4ba10659fb336899f08376c265c67290f1..500019bc507f859c4c91de5d322a82eb1e78e2de 100644
--- a/doc/api/v2/fluid/nets.rst
+++ b/doc/api/v2/fluid/nets.rst
@@ -26,8 +26,8 @@ glu
:noindex:
-dot_product_attention
----------------------
-.. autofunction:: paddle.v2.fluid.nets.dot_product_attention
+scaled_dot_product_attention
+----------------------------
+.. autofunction:: paddle.v2.fluid.nets.scaled_dot_product_attention
:noindex:
diff --git a/doc/design/dist_refactor/distributed_architecture.md b/doc/design/dist_refactor/distributed_architecture.md
index 3a741f95866fb6c301ca9097af7916281f2278cf..9368c5780dc922953f38bf0f86d9f797a4a8a6fe 100644
--- a/doc/design/dist_refactor/distributed_architecture.md
+++ b/doc/design/dist_refactor/distributed_architecture.md
@@ -152,12 +152,12 @@ for data in train_reader():
`JobDesc` object describes the distributed job resource specification to run on
the cluster environment.
-
+
`RemoteExecutor.run` sends the `ProgramDesc` and
[TrainingJob](https://github.com/PaddlePaddle/cloud/blob/develop/doc/autoscale/README.md#training-job-resource)
to a server in the cluster which executes `RemoteExecutor.listen`. This server is responsible
-to start the final Kubernetes Jobs to run the different role of `ProgramDesc`.
+to start the final Kubernetes Jobs to run the different roles of `ProgramDesc` from `ConfigMap`.
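
A rough sketch of the client-side flow this paragraph describes, using the `JobDesc`/`RemoteExecutor` names from this design doc; the constructor arguments are illustrative assumptions, not a final API:

```python
# Hypothetical sketch only: JobDesc fields and RemoteExecutor arguments are
# assumptions based on the surrounding design text.
exe = RemoteExecutor(
    placement="CPU",
    job_desc=JobDesc(num_trainer=4, num_ps=2))  # cluster resource spec
for data in train_reader():
    # run() ships the ProgramDesc plus the TrainingJob resource spec to the
    # server started by RemoteExecutor.listen, which launches Kubernetes Jobs.
    exe.run(fluid.default_main_program(), feed=feeder.feed(data))
```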
### Placement Algorithm
diff --git a/doc/design/dist_refactor/src/remote_executor.graffle b/doc/design/dist_refactor/src/remote_executor.graffle
index ce2c18fee5687732053c48af9c8c290a994a8090..41b2067311694b56d211a4f32d1b76884eeffd2d 100644
Binary files a/doc/design/dist_refactor/src/remote_executor.graffle and b/doc/design/dist_refactor/src/remote_executor.graffle differ
diff --git a/doc/design/dist_refactor/src/remote_executor.png b/doc/design/dist_refactor/src/remote_executor.png
index 6be4b1841b99efdb59557975485d0387f422308c..744e2fb2e0f1bbe058e991ba7b2a09000965ee79 100644
Binary files a/doc/design/dist_refactor/src/remote_executor.png and b/doc/design/dist_refactor/src/remote_executor.png differ
diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index b83007ac3bc6ed8713ca65fddabccfd292a2732f..8d9260811a8c9274dcaade9b090bab727d1952ca 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -74,7 +74,8 @@ cc_library(backward SRCS backward.cc DEPS net_op)
cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context fill_constant_op)
cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)
-cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward glog lod_rank_table)
+cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
+framework_proto backward glog lod_rank_table profiler)
cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc
index bd58c0a7f8161f6b45f2b500f3685e4028d97e96..c28ffefdd0872238299cdbb0653ee17cdad61699 100644
--- a/paddle/framework/executor.cc
+++ b/paddle/framework/executor.cc
@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/place.h"
+#include "paddle/platform/profiler.h"
DECLARE_bool(do_memory_benchmark);
DEFINE_bool(check_nan_inf, false,
@@ -117,6 +118,10 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
for (auto& op_desc : block.AllOps()) {
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
VLOG(4) << op->DebugStringEx(local_scope);
+
+ platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+ platform::RecordEvent record_event(op->Type(), pool.Get(place_));
+
op->Run(*local_scope, place_);
VLOG(3) << op->DebugStringEx(local_scope);
if (FLAGS_do_memory_benchmark) {
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index ba83667ebc9a89c37f77a7f71e6df90b54723cc0..aab02f16849582db4b41087046b810463a855e1a 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -991,8 +991,10 @@ TEST(Layer, SequenceLastInstanceLayer) {
"seqlastins",
"non-seq",
-1); // hasSubseq seqlastins to non-seq
- testDegradeLayer(
- true, "seqlastins", "seq", -1); // hasSubseq seqlastins to seq
+ testDegradeLayer(true,
+ "seqlastins",
+ "seq",
+ -1); // hasSubseq seqlastins to seq
}
TEST(Layer, AverageLayer) {
@@ -1001,8 +1003,10 @@ TEST(Layer, AverageLayer) {
"average",
"non-seq",
5); // seq average to a shorten seq, stride window = 5
- testDegradeLayer(
- true, "average", "non-seq", -1); // hasSubseq average to non-seq
+ testDegradeLayer(true,
+ "average",
+ "non-seq",
+ -1); // hasSubseq average to non-seq
testDegradeLayer(true, "average", "seq", -1); // hasSubseq average to seq
}
@@ -1287,8 +1291,9 @@ TEST(Layer, PoolLayer) {
testPoolLayer("cudnn-avg-pool", /* trans= */ false, /* useGpu= */ true);
testPoolLayer2("cudnn-max-pool", /* trans= */ false, /* useGpu= */ true);
testPoolLayer2("cudnn-avg-pool", /* trans= */ false, /* useGpu= */ true);
- testPoolLayer2(
- "cudnn-avg-incl-pad-pool", /* trans= */ false, /* useGpu= */ true);
+ testPoolLayer2("cudnn-avg-incl-pad-pool",
+ /* trans= */ false,
+ /* useGpu= */ true);
testPoolLayer("max-pool-with-mask", /* trans= */ false, /* useGpu= */ true);
#endif
}
@@ -2431,18 +2436,21 @@ TEST(Layer, test3DDeConvLayer) {
}
TEST(Layer, ScaleShiftLayer) {
- const size_t batchSize = 16;
- const size_t size = 32;
- TestConfig config;
- config.layerConfig.set_type("scale_shift");
- config.layerConfig.set_size(size);
- config.biasSize = 1;
- config.inputDefs.push_back(
- {INPUT_DATA, "input", /* dim= */ size, /* paraSize= */ 1});
- config.layerConfig.add_inputs();
- for (auto useGpu : {false, true}) {
- testLayerGrad(config, "scale_shift", batchSize, false, useGpu, false);
- }
+ // FIXME: Disable ScaleShiftLayer because it is not stable.
+ // https://github.com/PaddlePaddle/Paddle/issues/7781
+ return;
+ // const size_t batchSize = 16;
+ // const size_t size = 32;
+ // TestConfig config;
+ // config.layerConfig.set_type("scale_shift");
+ // config.layerConfig.set_size(size);
+ // config.biasSize = 1;
+ // config.inputDefs.push_back(
+ // {INPUT_DATA, "input", /* dim= */ size, /* paraSize= */ 1});
+ // config.layerConfig.add_inputs();
+ // for (auto useGpu : {false, true}) {
+ // testLayerGrad(config, "scale_shift", batchSize, false, useGpu, false);
+ // }
}
TEST(Layer, ScaleSubRegionLayer) {
diff --git a/paddle/operators/detail/grpc_client.cc b/paddle/operators/detail/grpc_client.cc
index 1e41587c418fb0ce4e452d5c6735c54e2d42f798..d699dabf2fb982f267c4869180efaf0e600eb46c 100644
--- a/paddle/operators/detail/grpc_client.cc
+++ b/paddle/operators/detail/grpc_client.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "grpc_client.h"
+#include "paddle/framework/threadpool.h"
namespace paddle {
namespace operators {
namespace detail {
@@ -22,25 +23,32 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out) {
- sendrecv::VariableMessage req;
- auto* var = scope.FindVar(var_name);
- SerializeToMessage(var_name, var, ctx, &req);
-
- // varhandle
- VarHandle var_h;
- var_h.ep = ep;
- var_h.scope = &scope;
- var_h.name = var_name;
- var_h.ctx = &ctx;
-
- // stub context
- auto ch = GetChannel(ep);
- SendProcessor* s = new SendProcessor(ch);
- s->Prepare(var_h, time_out);
- s->response_call_back_ = NULL;
-
- auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
- rpc->Finish(&s->reply_, &s->status_, (void*)s);
+ const platform::DeviceContext* p_ctx = &ctx;
+ const std::string ep_val = ep;
+ const std::string var_name_val = var_name;
+ const framework::Scope* p_scope = &scope;
+ const auto ch = GetChannel(ep_val);
+
+ framework::Async([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, this] {
+ auto* var = p_scope->FindVar(var_name_val);
+ sendrecv::VariableMessage req;
+ SerializeToMessage(var_name_val, var, *p_ctx, &req);
+
+ // varhandle
+ VarHandle var_h;
+ var_h.ep = ep_val;
+ var_h.scope = p_scope;
+ var_h.name = var_name_val;
+ var_h.ctx = p_ctx;
+
+ // stub context
+ SendProcessor* s = new SendProcessor(ch);
+ s->Prepare(var_h, time_out);
+ s->response_call_back_ = NULL;
+
+ auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
+ rpc->Finish(&s->reply_, &s->status_, (void*)s);
+ });
req_count_++;
@@ -50,8 +58,6 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
void ProcGetResponse(const VarHandle& var_h,
const sendrecv::VariableMessage& ret_msg) {
auto* outvar = var_h.scope->FindVar(var_h.name);
-
- std::istringstream iss(ret_msg.serialized());
DeserializeFromMessage(ret_msg, *var_h.ctx, outvar);
}
@@ -60,24 +66,31 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out) {
- sendrecv::VariableMessage req;
- req.set_varname(var_name);
-
- // varhandle
- VarHandle var_h;
- var_h.ep = ep;
- var_h.scope = &scope;
- var_h.name = var_name;
- var_h.ctx = &ctx;
-
- // stub context
- auto ch = GetChannel(ep);
- GetProcessor* s = new GetProcessor(ch);
- s->Prepare(var_h, time_out);
- s->response_call_back_ = ProcGetResponse;
-
- auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
- rpc->Finish(&s->reply_, &s->status_, (void*)s);
+ const platform::DeviceContext* p_ctx = &ctx;
+ const std::string ep_val = ep;
+ const std::string var_name_val = var_name;
+ const framework::Scope* p_scope = &scope;
+ const auto ch = GetChannel(ep_val);
+
+ framework::Async([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, this] {
+ sendrecv::VariableMessage req;
+ req.set_varname(var_name_val);
+
+ // varhandle
+ VarHandle var_h;
+ var_h.ep = ep_val;
+ var_h.scope = p_scope;
+ var_h.name = var_name_val;
+ var_h.ctx = p_ctx;
+
+ // stub context
+ GetProcessor* s = new GetProcessor(ch);
+ s->Prepare(var_h, time_out);
+ s->response_call_back_ = ProcGetResponse;
+
+ auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
+ rpc->Finish(&s->reply_, &s->status_, (void*)s);
+ });
req_count_++;
@@ -85,19 +98,31 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
}
bool RPCClient::Wait() {
- bool ok = true;
+ if (req_count_ <= 0) {
+ return true;
+ }
- while (true) {
- if (req_count_ <= 0) {
- break;
- }
+ std::vector<bool> a(req_count_);
+ std::vector<std::future<void>> waits(req_count_);
- if (!Proceed()) {
+ for (int i = 0; i < req_count_; i++) {
+ waits[i] = framework::Async([i, &a, this] { a[i] = Proceed(); });
+ }
+
+ for (int i = 0; i < req_count_; i++) {
+ waits[i].wait();
+ }
+
+ int last_req_count = req_count_;
+ req_count_ = 0;
+
+ for (int i = 0; i < last_req_count; i++) {
+ if (!a[i]) {
return false;
}
}
- return ok;
+ return true;
}
bool RPCClient::Proceed() {
@@ -124,7 +149,6 @@ bool RPCClient::Proceed() {
c->Process();
delete c;
- req_count_--;
return true;
}
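
The rewritten `Wait()` above fans each pending RPC out to the framework's thread pool and then joins on the returned futures before checking the per-request results. A minimal Python analogue of that submit-then-join pattern (illustrative only; `proceed` stands in for `RPCClient::Proceed`):

```python
# Illustrative analogue of the new RPCClient::Wait() logic: launch one task
# per outstanding request, join all futures, then fail if any request failed.
from concurrent.futures import ThreadPoolExecutor

def wait_all(proceed, req_count):
    with ThreadPoolExecutor() as pool:
        futures = [pool.submit(proceed) for _ in range(req_count)]
        results = [f.result() for f in futures]  # blocks until all complete
    return all(results)
```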
diff --git a/paddle/operators/im2sequence_op.h b/paddle/operators/im2sequence_op.h
index aeb810015134babc132909b3e820fa8391233b1c..f33aec71a92a65ec0e4114530d70e36c9dc1be04 100644
--- a/paddle/operators/im2sequence_op.h
+++ b/paddle/operators/im2sequence_op.h
@@ -79,7 +79,7 @@ class Im2SequenceKernel : public framework::OpKernel<T> {
framework::LoD lod(1);
lod[0].reserve(batch_size + 1);
for (int i = 0, offset = 0; i < batch_size + 1; ++i) {
- lod[0][i] = offset;
+ lod[0].push_back(offset);
offset += output_height * output_width;
}
out->set_lod(lod);
diff --git a/paddle/operators/reshape_op.cc b/paddle/operators/reshape_op.cc
index 58e8fd6124d8c076337ae9bb2f5103e7a3cb7ff0..b9743a5df1092917d13a50aa20ea7e7c52b8d151 100644
--- a/paddle/operators/reshape_op.cc
+++ b/paddle/operators/reshape_op.cc
@@ -90,14 +90,10 @@ Reshape Operator.
Reshape Input(X) into the shape specified by Attr(shape).
An example:
-Given a 2-D tensor X with 2 rows and 2 columns
-
- [[1, 2], [3, 4]]
+Given a 2-D tensor X with 2 rows and 2 columns : [[1, 2], [3, 4]]
and target shape = [1, 4], the reshape operator will transform
-the tensor X into a 2-D tensor:
-
- [[1, 2, 3, 4]]
+the tensor X into a 2-D tensor: [[1, 2, 3, 4]]
One dimension in the target shape can be set -1, representing that its
size is unknown. In this case, the real dimension will be inferred from
diff --git a/paddle/platform/profiler.cc b/paddle/platform/profiler.cc
index 7e2e2d968ef877f6aa8b87ab8f044e89574dffa9..2a8afc940393baaaa939471f50f2d5c63edd6a84 100644
--- a/paddle/platform/profiler.cc
+++ b/paddle/platform/profiler.cc
@@ -47,16 +47,16 @@ inline uint64_t GetTimeInNsec() {
}
Event::Event(EventKind kind, std::string name, uint32_t thread_id,
- DeviceContext* dev_ctx)
+ const DeviceContext* dev_ctx)
: kind_(kind), name_(name), thread_id_(thread_id), has_cuda_(false) {
#ifdef PADDLE_WITH_CUDA
- auto* cuda_dev_ctx = static_cast<CUDADeviceContext*>(dev_ctx);
- if (cuda_dev_ctx) {
+ has_cuda_ = dev_ctx ? platform::is_gpu_place(dev_ctx->GetPlace()) : false;
+ if (has_cuda_) {
+ auto* cuda_dev_ctx = static_cast<const CUDADeviceContext*>(dev_ctx);
PADDLE_ENFORCE(cudaGetDevice(&device_));
PADDLE_ENFORCE(cudaEventCreate(&event_));
auto stream = cuda_dev_ctx->stream();
PADDLE_ENFORCE(cudaEventRecord(event_, stream));
- has_cuda_ = true;
}
#endif
cpu_ns_ = GetTimeInNsec();
@@ -114,19 +114,20 @@ inline EventList& GetEventList() {
return *g_event_list;
}
-void Mark(const std::string& name, DeviceContext* dev_ctx) {
+void Mark(const std::string& name, const DeviceContext* dev_ctx) {
GetEventList().Record(EventKind::kMark, name, g_thread_id, dev_ctx);
}
-void PushEvent(const std::string& name, DeviceContext* dev_ctx) {
+void PushEvent(const std::string& name, const DeviceContext* dev_ctx) {
GetEventList().Record(EventKind::kPushRange, name, g_thread_id, dev_ctx);
}
-void PopEvent(const std::string& name, DeviceContext* dev_ctx) {
+void PopEvent(const std::string& name, const DeviceContext* dev_ctx) {
GetEventList().Record(EventKind::kPopRange, name, g_thread_id, dev_ctx);
}
-RecordEvent::RecordEvent(const std::string& name, DeviceContext* dev_ctx) {
+RecordEvent::RecordEvent(const std::string& name,
+ const DeviceContext* dev_ctx) {
if (g_state == ProfilerState::kDisabled) return;
dev_ctx_ = dev_ctx;
name_ = name;
@@ -155,6 +156,7 @@ void EnableProfiler(ProfilerState state) {
DeviceContext* dev_ctx = new CUDADeviceContext(CUDAPlace(d));
Mark("_cuda_startup_", dev_ctx);
dev_ctx->Wait();
+ delete dev_ctx;
});
}
}
@@ -163,14 +165,17 @@ void EnableProfiler(ProfilerState state) {
Mark("_start_profiler_", nullptr);
}
-std::vector<std::vector<Event>> DisableProfiler() {
- PADDLE_ENFORCE(g_state != ProfilerState::kDisabled,
- "Can't disable profiling, since it's not starting.");
- // Mark the profiling stop.
- Mark("_stop_profiler_", nullptr);
- g_state = ProfilerState::kDisabled;
- std::vector<std::vector<Event>> result;
+void ResetProfiler() {
std::lock_guard<std::mutex> guard(g_all_event_lists_mutex);
+ for (auto it = g_all_event_lists.begin(); it != g_all_event_lists.end();
+ ++it) {
+ (*it)->Clear();
+ }
+}
+
+std::vector<std::vector<Event>> GetAllEvents() {
+ std::lock_guard<std::mutex> guard(g_all_event_lists_mutex);
+ std::vector<std::vector<Event>> result;
for (auto it = g_all_event_lists.begin(); it != g_all_event_lists.end();
++it) {
result.emplace_back((*it)->Reduce());
@@ -178,6 +183,18 @@ std::vector<std::vector<Event>> DisableProfiler() {
return result;
}
+void DisableProfiler(EventSortingKey sorted_key) {
+ PADDLE_ENFORCE(g_state != ProfilerState::kDisabled,
+ "Can't disable profiling, since it's not starting.");
+ // Mark the profiling stop.
+ Mark("_stop_profiler_", nullptr);
+ g_state = ProfilerState::kDisabled;
+
+ std::vector<std::vector<Event>> all_events = GetAllEvents();
+ ParseEvents(all_events, sorted_key);
+ ResetProfiler();
+}
+
void ParseEvents(std::vector<std::vector<Event>>& events,
EventSortingKey sorted_by) {
if (g_profiler_place == "") return;
@@ -291,12 +308,12 @@ void ParseEvents(std::vector<std::vector<Event>>& events,
}
// Print report
- PrintProfilingReport(events_table, sorted_domain, max_name_width + 4, 12);
+ PrintProfiler(events_table, sorted_domain, max_name_width + 4, 12);
}
-void PrintProfilingReport(std::vector<std::vector<EventItem>>& events_table,
- std::string& sorted_domain, const size_t name_width,
- const size_t data_width) {
+void PrintProfiler(std::vector<std::vector<EventItem>>& events_table,
+ std::string& sorted_domain, const size_t name_width,
+ const size_t data_width) {
// Output header information
std::cout << "\n------------------------->"
<< " Profiling Report "
diff --git a/paddle/platform/profiler.h b/paddle/platform/profiler.h
index 6df48ef8806e865f473b4317ac0283863c3c6f64..8de1e6ad296d1e15c1659ccf431f1d5013eb608c 100644
--- a/paddle/platform/profiler.h
+++ b/paddle/platform/profiler.h
@@ -29,7 +29,7 @@ class Event {
// The DeviceContext is used to get the cuda stream.
// If CPU profiling mode, can pass nullptr.
Event(EventKind kind, std::string name, uint32_t thread_id,
- DeviceContext* dev_ctx);
+ const DeviceContext* dev_ctx);
std::string kind() const;
std::string name() const { return name_; }
@@ -84,6 +84,8 @@ struct EventList {
return result;
}
+ void Clear() { event_blocks.clear(); }
+
std::forward_list<std::vector<Event>> event_blocks;
};
@@ -93,29 +95,26 @@ enum ProfilerState {
kCUDA, // GPU profiling state
};
-void Mark(const std::string& name, DeviceContext* dev_ctx);
+void Mark(const std::string& name, const DeviceContext* dev_ctx);
-void PushEvent(const std::string& name, DeviceContext* dev_ctx);
+void PushEvent(const std::string& name, const DeviceContext* dev_ctx);
-void PopEvent(const std::string& name, DeviceContext* dev_ctx);
+void PopEvent(const std::string& name, const DeviceContext* dev_ctx);
struct RecordEvent {
- explicit RecordEvent(const std::string& name, DeviceContext* dev_ctx);
+ explicit RecordEvent(const std::string& name, const DeviceContext* dev_ctx);
~RecordEvent();
// The device context is used by Event to get the current cuda stream.
- DeviceContext* dev_ctx_;
+ const DeviceContext* dev_ctx_;
// Event name
std::string name_;
};
-// Enable the profiling function.
-void EnableProfiler(ProfilerState state);
-
// Return the event list of all threads. Assuming the returned value is named
// event_lists, event_lists[i][j] represents the j-th Event of i-th thread.
-std::vector<std::vector<Event>> DisableProfiler();
+std::vector<std::vector<Event>> GetAllEvents();
// The information of each event given in the profiling report
struct EventItem {
@@ -130,13 +129,22 @@ struct EventItem {
// Candidate keys to sort the profiling report
enum EventSortingKey { kDefault, kCalls, kTotal, kMin, kMax, kAve };
+// Enable the profiling function.
+void EnableProfiler(ProfilerState state);
+
+// Clear the g_all_event_lists, which holds the event lists of all threads.
+void ResetProfiler();
+
+void DisableProfiler(EventSortingKey sorted_key);
+
// Parse the event list and output the profiling report
void ParseEvents(std::vector<std::vector<Event>>&,
EventSortingKey sorted_by = EventSortingKey::kDefault);
// Print results
-void PrintProfilingReport(std::vector<std::vector<EventItem>>& events_table,
- std::string& sorted_domain, const size_t name_width,
- const size_t data_width);
+void PrintProfiler(std::vector<std::vector<EventItem>>& events_table,
+ std::string& sorted_domain, const size_t name_width,
+ const size_t data_width);
+
} // namespace platform
} // namespace paddle
diff --git a/paddle/platform/profiler_test.cc b/paddle/platform/profiler_test.cc
index 13dea713c71e147ed5dd8d090e92d86c96256c09..81f10c91342f76910cc780b0ebd0c0df04e9d7bf 100644
--- a/paddle/platform/profiler_test.cc
+++ b/paddle/platform/profiler_test.cc
@@ -103,18 +103,14 @@ TEST(RecordEvent, RecordEvent) {
// Bad Usage:
PushEvent("event_without_pop", dev_ctx);
PopEvent("event_without_push", dev_ctx);
- std::vector<std::vector<Event>> events = paddle::platform::DisableProfiler();
- // Will remove parsing-related code from test later
- ParseEvents(events, EventSortingKey::kTotal);
+ std::vector<std::vector<Event>> events = paddle::platform::GetAllEvents();
int cuda_startup_count = 0;
int start_profiler_count = 0;
- int stop_profiler_count = 0;
for (size_t i = 0; i < events.size(); ++i) {
for (size_t j = 0; j < events[i].size(); ++j) {
if (events[i][j].name() == "_cuda_startup_") ++cuda_startup_count;
if (events[i][j].name() == "_start_profiler_") ++start_profiler_count;
- if (events[i][j].name() == "_stop_profiler_") ++stop_profiler_count;
if (events[i][j].name() == "push") {
EXPECT_EQ(events[i][j + 1].name(), "pop");
#ifdef PADDLE_WITH_CUDA
@@ -127,5 +123,7 @@ TEST(RecordEvent, RecordEvent) {
}
EXPECT_EQ(cuda_startup_count % 5, 0);
EXPECT_EQ(start_profiler_count, 1);
- EXPECT_EQ(stop_profiler_count, 1);
+
+ // Will remove parsing-related code from test later
+ DisableProfiler(EventSortingKey::kTotal);
}
diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt
index 7b374307071d2da91a677361b404448f1a3816b0..e78673e0baa03496faab13d069b3bd456660bad6 100644
--- a/paddle/pybind/CMakeLists.txt
+++ b/paddle/pybind/CMakeLists.txt
@@ -1,7 +1,7 @@
if(WITH_PYTHON)
cc_library(paddle_pybind SHARED
SRCS pybind.cc exception.cc protobuf.cc const_value.cc
- DEPS pybind python backward proto_desc paddle_memory executor prune init
+ DEPS pybind python backward proto_desc paddle_memory executor prune init profiler
${GLOB_OP_LIB})
if(NOT APPLE AND NOT ANDROID)
target_link_libraries(paddle_pybind rt)
diff --git a/paddle/pybind/protobuf.h b/paddle/pybind/protobuf.h
index 089183accc08c3c486a7ae78ccfe060853ec54f5..9e747e9ea60fd95c74937daa283bc7a9eb9368c0 100644
--- a/paddle/pybind/protobuf.h
+++ b/paddle/pybind/protobuf.h
@@ -17,6 +17,7 @@ limitations under the License. */
#include <Python.h>
#include <fstream>
#include <vector>
+#include "paddle/platform/variant.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index c5d70bc9f91bc92b28a546cc79b08a9fda150050..b4fd2a8989632e1aad99ee777ec26ba1146fa1e7 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -30,6 +30,7 @@ limitations under the License. */
#include "paddle/operators/net_op.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
+#include "paddle/platform/profiler.h"
#include "paddle/pybind/const_value.h"
#include "paddle/pybind/exception.h"
#include "paddle/pybind/pybind.h"
@@ -52,7 +53,7 @@ static size_t UniqueIntegerGenerator(const std::string &prefix) {
return generators[prefix].fetch_add(1);
}
-bool IsCompileGPU() {
+bool IsCompiledWithCUDA() {
#ifndef PADDLE_WITH_CUDA
return false;
#else
@@ -430,7 +431,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("init_glog", framework::InitGLOG);
m.def("init_devices", &framework::InitDevices);
- m.def("is_compile_gpu", IsCompileGPU);
+ m.def("is_compiled_with_cuda", IsCompiledWithCUDA);
m.def("set_feed_variable", framework::SetFeedVariable);
m.def("get_fetch_variable", framework::GetFetchVariable);
@@ -476,6 +477,24 @@ All parameter, weight, gradient are variables in Paddle.
m.def("nvprof_stop", platform::CudaProfilerStop);
#endif
+ py::enum_(m, "ProfilerState", py::arithmetic())
+ .value("kDisabled", platform::ProfilerState::kDisabled)
+ .value("kCPU", platform::ProfilerState::kCPU)
+ .value("kCUDA", platform::ProfilerState::kCUDA)
+ .export_values();
+
+ py::enum_(m, "EventSortingKey", py::arithmetic())
+ .value("kDefault", platform::EventSortingKey::kDefault)
+ .value("kCalls", platform::EventSortingKey::kCalls)
+ .value("kTotal", platform::EventSortingKey::kTotal)
+ .value("kMin", platform::EventSortingKey::kMin)
+ .value("kMax", platform::EventSortingKey::kMax)
+ .value("kAve", platform::EventSortingKey::kAve)
+ .export_values();
+
+ m.def("enable_profiler", platform::EnableProfiler);
+ m.def("disable_profiler", platform::DisableProfiler);
+ m.def("reset_profiler", platform::ResetProfiler);
return m.ptr();
}
} // namespace pybind
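
With these bindings in place, the profiler can be driven from Python directly through `core`, using only the functions and enum values registered above; a minimal sketch (a program/executor is assumed to be set up elsewhere):

```python
# Minimal sketch driving the newly bound profiler API from Python.
import paddle.v2.fluid.core as core

core.enable_profiler(core.ProfilerState.kCPU)
# ... run some warm-up operators here ...
core.reset_profiler()            # drop events recorded so far
# ... run the iterations that should be measured ...
core.disable_profiler(core.EventSortingKey.kTotal)  # prints report, clears events
```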
diff --git a/python/paddle/v2/fluid/__init__.py b/python/paddle/v2/fluid/__init__.py
index 1f041c74597637a7b74e9690a60b6cd8fdd21cf8..787416aed1acf81138df06110317614dfe77fb48 100644
--- a/python/paddle/v2/fluid/__init__.py
+++ b/python/paddle/v2/fluid/__init__.py
@@ -89,7 +89,7 @@ def __bootstrap__():
read_env_flags = [
'use_pinned_memory', 'check_nan_inf', 'do_memory_benchmark'
]
- if core.is_compile_gpu():
+ if core.is_compiled_with_cuda():
read_env_flags += ['fraction_of_gpu_memory_to_use', 'op_sync']
core.init_gflags([sys.argv[0]] +
["--tryfromenv=" + ",".join(read_env_flags)])
diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py
index ae81d68bafd22db5d9f7ab0f9cc0dcdb204493e1..29243c90e872ca4a7d1ce6f84f6297b865655da1 100644
--- a/python/paddle/v2/fluid/backward.py
+++ b/python/paddle/v2/fluid/backward.py
@@ -178,7 +178,7 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
if _all_in_set_(
filter(lambda name: name.find(core.grad_var_suffix()) != -1,
op_desc.input_arg_names()), no_grad_set):
- no_grad_set.union(out_arg_names)
+ no_grad_set.update(out_arg_names)
return True
return False
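
The one-line fix above matters because `set.union` returns a new set while `set.update` mutates in place; the original call silently discarded its result:

```python
# Why the fix works: set.union builds a NEW set and leaves the receiver
# untouched, while set.update mutates it in place.
no_grad_set = {"x@GRAD"}
no_grad_set.union({"y@GRAD"})    # result thrown away; no_grad_set unchanged
assert no_grad_set == {"x@GRAD"}
no_grad_set.update({"y@GRAD"})   # mutates in place, as the fix intends
assert no_grad_set == {"x@GRAD", "y@GRAD"}
```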
diff --git a/python/paddle/v2/fluid/io.py b/python/paddle/v2/fluid/io.py
index 376d6013a38923014fa35e964e58d7f56bf80546..5b02d2495d1ebe9e82e7f847e5bd07548901c7fc 100644
--- a/python/paddle/v2/fluid/io.py
+++ b/python/paddle/v2/fluid/io.py
@@ -15,6 +15,7 @@
import os
import cPickle as pickle
+from paddle.v2.fluid.evaluator import Evaluator
from paddle.v2.fluid.framework import Program, Parameter, default_main_program, Variable
from . import core
@@ -187,8 +188,14 @@ def get_inference_program(target_vars, main_program=None):
main_program = default_main_program()
if not isinstance(target_vars, list):
target_vars = [target_vars]
-
- pruned_program = main_program.prune(targets=target_vars)
+ vars = []
+ for var in target_vars:
+ if isinstance(var, Evaluator):
+ vars.extend(var.states)
+ vars.extend(var.metrics)
+ else:
+ vars.append(var)
+ pruned_program = main_program.prune(targets=vars)
inference_program = pruned_program.inference_optimize()
return inference_program
diff --git a/python/paddle/v2/fluid/layer_helper.py b/python/paddle/v2/fluid/layer_helper.py
index 8c481444e9bf895d29ed4e4952e825c2eaafc915..7d9ae53d94b6c82890150346f138e48a0dfbf15c 100644
--- a/python/paddle/v2/fluid/layer_helper.py
+++ b/python/paddle/v2/fluid/layer_helper.py
@@ -111,6 +111,7 @@ class LayerHelper(object):
is_bias=False,
default_initializer=None):
# Deepcopy the attr so that parameters can be shared in program
+ attr = copy.deepcopy(attr)
assert isinstance(attr, ParamAttr)
suffix = 'b' if is_bias else 'w'
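
The added `deepcopy` keeps a caller-supplied attribute object from being mutated when one `ParamAttr` is shared across several layers; a toy illustration of the hazard, with a plain dict standing in for `ParamAttr`:

```python
# Toy illustration: without the deepcopy, the first layer's mutation would
# leak into every later layer sharing the same attr object.
import copy

def create_parameter(attr, name):
    attr = copy.deepcopy(attr)  # work on a private copy, as the fix does
    attr["name"] = name
    return attr

shared_attr = {"name": None}
w0 = create_parameter(shared_attr, "fc_0.w")
w1 = create_parameter(shared_attr, "fc_1.w")
assert shared_attr["name"] is None  # the shared object is untouched
```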
diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py
index f834eac755fd1c5261acd36610c696988b969f09..477ae7cea972fea265dbcb538295ce36a7b6fe55 100644
--- a/python/paddle/v2/fluid/layers/nn.py
+++ b/python/paddle/v2/fluid/layers/nn.py
@@ -111,16 +111,17 @@ def fc(input,
into a 2-dimensional matrix. The parameter
`num_flatten_dims` determines how the input tensor
is flattened: the first `num_flatten_dims`
- dimensions will be flatten to form the first
- dimension of the final matrix (height of the
- matrix), and the rest `rank(X) - num_flatten_dims`
- dimensions are flattened to form the second
- dimension of the final matrix (width of the matrix).
- For example, suppose `X` is a 6-dimensional tensor
- with a shape [2, 3, 4, 5, 6], and
- `num_flatten_dims` = 3. Then, the flattened matrix
- will have a shape [2 x 3 x 4, 5 x 6] = [24, 30].
- By default, `num_flatten_dims` is set to 1.
+ (inclusive, index starts from 1) dimensions will
+ be flattened to form the first dimension of the
+ final matrix (height of the matrix), and the rest
+ `rank(X) - num_flatten_dims` dimensions are
+ flattened to form the second dimension of the
+ final matrix (width of the matrix). For example,
+ suppose `X` is a 6-dimensional tensor with a shape
+ [2, 3, 4, 5, 6], and `num_flatten_dims` = 3. Then,
+ the flattened matrix will have a shape
+ [2 x 3 x 4, 5 x 6] = [24, 30]. By default,
+ `num_flatten_dims` is set to 1.
param_attr(ParamAttr|list): The parameter attribute for learnable
parameters/weights of the fully connected
layer.
@@ -161,6 +162,7 @@ def fc(input,
param_shape = [
reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1)
] + [size]
+
w = helper.create_parameter(
attr=param_attr, shape=param_shape, dtype=dtype, is_bias=False)
tmp = helper.create_tmp_variable(dtype)
@@ -531,8 +533,10 @@ def gru_unit(input,
size (integer): The input dimension value.
weight (ParamAttr): The weight parameters for gru unit. Default: None
bias (ParamAttr): The bias parameters for gru unit. Default: None
- activation (string): The activation type for cell (actNode). Default: 'tanh'
- gate_activation (string): The activation type for gates (actGate). Default: 'sigmoid'
+ activation (string): The activation type for cell (actNode).
+ Default: 'tanh'
+ gate_activation (string): The activation type for gates (actGate).
+ Default: 'sigmoid'
Returns:
tuple: The hidden value, reset-hidden value and gate values.
@@ -671,8 +675,9 @@ def cross_entropy(input, label, **kwargs):
"""
**Cross Entropy Layer**
- This layer computes the cross entropy between `input` and `label`. It supports
- both standard cross-entropy and soft-label cross-entropy loss computation.
+ This layer computes the cross entropy between `input` and `label`. It
+ supports both standard cross-entropy and soft-label cross-entropy loss
+ computation.
1) One-hot cross-entropy:
`soft_label = False`, `Label[i, 0]` indicates the class index for sample i:
@@ -699,23 +704,28 @@ def cross_entropy(input, label, **kwargs):
Args:
input (Variable|list): a 2-D tensor with shape [N x D], where N is the
- batch size and D is the number of classes. This input is a probability
- computed by the previous operator, which is almost always the result
- of a softmax operator.
+ batch size and D is the number of classes. This
+ input is a probability computed by the previous
+ operator, which is almost always the result of
+ a softmax operator.
label (Variable|list): the ground truth which is a 2-D tensor. When
- `soft_label` is set to `False`, `label` is a tensor with shape
- [N x 1]. When `soft_label` is set to `True`, `label` is a
- tensor with shape [N x D].
- soft_label (bool, via `**kwargs`): a flag indicating whether to interpretate
- the given labels as soft labels, default `False`.
+ `soft_label` is set to `False`, `label` is a
+ tensor with shape [N x 1]. When
+ `soft_label` is set to `True`, `label` is a
+ tensor with shape [N x D].
+ soft_label (bool, via `**kwargs`): a flag indicating whether to
+ interpret the given labels as soft
+ labels, default `False`.
Returns:
A 2-D tensor with shape [N x 1], the cross entropy loss.
Raises:
- `ValueError`: 1) the 1st dimension of `input` and `label` are not equal; 2) when \
- `soft_label == True`, and the 2nd dimension of `input` and `label` are not \
- equal; 3) when `soft_label == False`, and the 2nd dimension of `label` is not 1.
+ `ValueError`: 1) the 1st dimension of `input` and `label` are not equal.
+ 2) when `soft_label == True`, and the 2nd dimension of
+ `input` and `label` are not equal.
+ 3) when `soft_label == False`, and the 2nd dimension of
+ `label` is not 1.
Examples:
.. code-block:: python
@@ -738,7 +748,9 @@ def square_error_cost(input, label, **kwargs):
"""
**Square error cost layer**
- This layer accepts input predictions and target label and returns the squared error cost.
+ This layer accepts input predictions and target label and returns the
+ squared error cost.
+
For predictions, :math:`X`, and target labels, :math:`Y`, the equation is:
.. math::
@@ -756,8 +768,8 @@ def square_error_cost(input, label, **kwargs):
label(Variable): Label tensor, has target labels.
Returns:
- Variable: The tensor variable storing the element-wise squared error difference \
- of input and label.
+ Variable: The tensor variable storing the element-wise squared error
+ difference of input and label.
Examples:
.. code-block:: python
@@ -853,7 +865,8 @@ def chunk_eval(input,
"chunk_scheme": chunk_scheme,
"excluded_chunk_types": excluded_chunk_types or []
})
- return precision, recall, f1_score, num_infer_chunks, num_label_chunks, num_correct_chunks
+ return (precision, recall, f1_score, num_infer_chunks, num_label_chunks,
+ num_correct_chunks)
def sequence_conv(input,
@@ -911,13 +924,14 @@ def conv2d(input,
**Convlution2D Layer**
The convolution2D layer calculates the output based on the input, filter
- and strides, paddings, dilations, groups parameters. Input(Input) and Output(Output)
- are in NCHW format. Where N is batch size, C is the number of channels, H is the height
- of the feature, and W is the width of the feature.
+ and strides, paddings, dilations, groups parameters. Input(Input) and
+ Output(Output) are in NCHW format. Where N is batch size, C is the number of
+ channels, H is the height of the feature, and W is the width of the feature.
The details of convolution layer, please refer UFLDL's `convolution,
<http://ufldl.stanford.edu/tutorial/supervised/FeatureExtractionUsingConvolution/>`_ .
- If bias attribution and activation type are provided, bias is added to the output of the convolution,
- and the corresponding activation function is applied to the final result.
+ If bias attribute and activation type are provided, bias is added to the
+ output of the convolution, and the corresponding activation function is
+ applied to the final result.
For each input :math:`X`, the equation is:
@@ -932,7 +946,8 @@ def conv2d(input,
* :math:`\\ast`: Convolution operation.
* :math:`b`: Bias value, a 2-D tensor with shape [M, 1].
* :math:`\\sigma`: Activation function.
- * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
+ * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be
+ different.
Example:
@@ -977,17 +992,20 @@ def conv2d(input,
act(str): Activation type. Default: None
Returns:
- Variable: The tensor variable storing the convolution and \
+ Variable: The tensor variable storing the convolution and
non-linearity activation result.
Raises:
- ValueError: If the shapes of input, filter_size, stride, padding and groups mismatch.
+ ValueError: If the shapes of input, filter_size, stride, padding and
+ groups mismatch.
Examples:
.. code-block:: python
- data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32')
- conv2d = fluid.layers.conv2d(input=data, num_filters=2, filter_size=3, act="relu")
+ data = fluid.layers.data(
+ name='data', shape=[3, 32, 32], dtype='float32')
+ conv2d = fluid.layers.conv2d(
+ input=data, num_filters=2, filter_size=3, act="relu")
"""
if stride is None:
stride = [1, 1]
@@ -1350,7 +1368,8 @@ def conv2d_transpose(input,
H is the height of the feature, and W is the width of the feature.
Parameters(dilations, strides, paddings) are two elements. These two elements
represent height and width, respectively. The details of convolution transpose
- layer, please refer to the following explanation and references `therein `_.
+ layer, please refer to the following explanation and references
+ `therein <http://www.matthewzeiler.com/wp-content/uploads/2017/07/cvpr2010.pdf>`_.
For each input :math:`X`, the equation is:
@@ -1363,7 +1382,8 @@ def conv2d_transpose(input,
* :math:`X`: Input value, a tensor with NCHW format.
* :math:`W`: Filter value, a tensor with MCHW format.
* :math:`\\ast` : Convolution transpose operation.
- * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
+ * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be
+ different.
Example:
@@ -1404,7 +1424,8 @@ def conv2d_transpose(input,
dilation(int|tuple): The dilation size. If dilation is a tuple, it must
contain two integers, (dilation_H, dilation_W). Otherwise, the
dilation_H = dilation_W = dilation. Default: dilation = 1.
- param_attr(ParamAttr): The parameters to the Conv2d_transpose Layer. Default: None
+ param_attr(ParamAttr): The parameters to the Conv2d_transpose Layer.
+ Default: None
use_cudnn(bool): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True
name(str|None): A name for this layer(optional). If set None, the layer
@@ -1414,13 +1435,16 @@ def conv2d_transpose(input,
Variable: The tensor variable storing the convolution transpose result.
Raises:
- ValueError: If the shapes of input, filter_size, stride, padding and groups mismatch.
+ ValueError: If the shapes of input, filter_size, stride, padding and
+ groups mismatch.
Examples:
.. code-block:: python
- data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32')
- conv2d_transpose = fluid.layers.conv2d_transpose(input=data, num_filters=2, filter_size=3)
+ data = fluid.layers.data(
+ name='data', shape=[3, 32, 32], dtype='float32')
+ conv2d_transpose = fluid.layers.conv2d_transpose(
+ input=data, num_filters=2, filter_size=3)
"""
helper = LayerHelper("conv2d_transpose", **locals())
if not isinstance(input, Variable):
@@ -1644,10 +1668,10 @@ def lstm_unit(x_t,
tuple: The hidden value and cell value of lstm unit.
Raises:
- ValueError: The ranks of **x_t**, **hidden_t_prev** and **cell_t_prev**\
- not be 2 or the 1st dimensions of **x_t**, **hidden_t_prev** \
- and **cell_t_prev** not be the same or the 2nd dimensions of \
- **hidden_t_prev** and **cell_t_prev** not be the same.
+ ValueError: The ranks of **x_t**, **hidden_t_prev** and **cell_t_prev**
+ are not 2, or the 1st dimensions of **x_t**, **hidden_t_prev**
+ and **cell_t_prev** are not the same, or the 2nd dimensions of
+ **hidden_t_prev** and **cell_t_prev** are not the same.
Examples:
@@ -1979,7 +2003,7 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None):
data = fluid.layers.data(name="data",
shape=(3, 17, 13),
dtype="float32")
- fc = fluid.layers.l2_normalize(x=data, axis=1)
+ normed = fluid.layers.l2_normalize(x=data, axis=1)
"""
if len(x.shape) == 1: axis = 0
@@ -2031,9 +2055,10 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None):
def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
"""
- Applies matrix multiplication to two tensors. Currently, the input
- tensors' rank can be any, but when the rank of anyone inputs is
- bigger than 3, this two inputs' rank should be equal.
+ Applies matrix multiplication to two tensors.
+
+ Currently, the input tensors can be of any rank, but when the rank of
+ either input is bigger than 3, the two inputs' ranks should be equal.
The actual behavior depends on the shapes of :math:`x`, :math:`y` and the
flag values of :attr:`transpose_x`, :attr:`transpose_y`. Specifically:
@@ -2074,25 +2099,56 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
# Examples to clarify shapes of the inputs and output
# x: [B, ..., M, K], y: [B, ..., K, N]
fluid.layers.matmul(x, y) # out: [B, ..., M, N]
+
# x: [B, M, K], y: [B, K, N]
fluid.layers.matmul(x, y) # out: [B, M, N]
+
# x: [B, M, K], y: [K, N]
fluid.layers.matmul(x, y) # out: [B, M, N]
- # x: [B, M, K], y: [K]
- fluid.layers.matmul(x, y) # out: [B, M]
+
# x: [M, K], y: [K, N]
fluid.layers.matmul(x, y) # out: [M, N]
+
+ # x: [B, M, K], y: [K]
+ fluid.layers.matmul(x, y) # out: [B, M]
+
# x: [K], y: [K]
fluid.layers.matmul(x, y) # out: [1]
- # x: [M], y: [N]
+ # x: [M], y: [N]
fluid.layers.matmul(x, y, True, True) # out: [M, N]
"""
+
+ def __check_input(x, y):
+ if len(y.shape) > len(x.shape):
+ raise ValueError(
+ "Invalid inputs for matmul. "
+ "x's rank should be always greater than or equal to y'rank.")
+
+ x_shape = list(x.shape)
+ y_shape = list(y.shape)
+ if len(x_shape) == 1:
+ x_shape = [1] + x_shape
+ if len(y_shape) == 1:
+ y_shape = y_shape + [1]
+
+ # check the inner 2 dimensions
+ if transpose_x:
+ x_shape[-2], x_shape[-1] = x_shape[-1], x_shape[-2]
+ if transpose_y:
+ y_shape[-2], y_shape[-1] = y_shape[-1], y_shape[-2]
+ if x_shape[-1] != y_shape[-2]:
+ raise ValueError("Invalid inputs for matmul.")
+
+ if len(y_shape) > 2:
+ for i, dim_x in enumerate(x_shape[:-2]):
+ if dim_x != y_shape[i]:
+ raise ValueError("Invalid inputs for matmul.")
+
+ __check_input(x, y)
+
helper = LayerHelper('matmul', **locals())
- assert max(len(x.shape), len(y.shape)) <= 3 or len(x.shape) == len(
- y.
- shape), 'Inputs\' rank should be equal or their rank should be less 4.'
- out = helper.create_tmp_variable(dtype=helper.input_dtype())
+ out = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(
type='matmul',
inputs={'X': x,
@@ -2109,13 +2165,26 @@ def edit_distance(input,
ignored_tokens=None,
name=None):
"""
- EditDistance operator computes the edit distances between a batch of hypothesis strings and their references. Edit distance, also called Levenshtein distance, measures how dissimilar two strings are by counting the minimum number of operations to transform one string into anthor. Here the operations include insertion, deletion, and substitution. For example, given hypothesis string A = "kitten" and reference B = "sitting", the edit distance is 3 for A will be transformed into B at least after two substitutions and one insertion:
+ EditDistance operator computes the edit distances between a batch of
+ hypothesis strings and their references. Edit distance, also called
+ Levenshtein distance, measures how dissimilar two strings are by counting
+ the minimum number of operations to transform one string into another.
+ Here the operations include insertion, deletion, and substitution.
+
+ For example, given hypothesis string A = "kitten" and reference
+ B = "sitting", the edit distance is 3 for A will be transformed into B
+ at least after two substitutions and one insertion:
- "kitten" -> "sitten" -> "sittin" -> "sitting"
+ "kitten" -> "sitten" -> "sittin" -> "sitting"
- Input(Hyps) is a LoDTensor consisting of all the hypothesis strings with the total number denoted by `batch_size`, and the separation is specified by the LoD information. And the `batch_size` reference strings are arranged in order in the same way in the LoDTensor Input(Refs).
+ Input(Hyps) is a LoDTensor consisting of all the hypothesis strings with
+ the total number denoted by `batch_size`, and the separation is specified
+ by the LoD information. And the `batch_size` reference strings are arranged
+ in order in the same way in the LoDTensor Input(Refs).
- Output(Out) contains the `batch_size` results and each stands for the edit stance for a pair of strings respectively. If Attr(normalized) is true, the edit distance will be divided by the length of reference string.
+ Output(Out) contains the `batch_size` results and each stands for the edit
+ distance for a pair of strings respectively. If Attr(normalized) is true,
+ the edit distance will be divided by the length of reference string.
Args:
@@ -2123,9 +2192,11 @@ def edit_distance(input,
label(Variable): The indices for reference strings.
- normalized(bool): Indicated whether to normalize the edit distance by the length of reference string.
+ normalized(bool): Indicates whether to normalize the edit distance by
+ the length of reference string.
- ignored_tokens(list of int): Tokens that should be removed before calculating edit distance.
+ ignored_tokens(list of int): Tokens that should be removed before
+ calculating edit distance.
Returns:
Variable: sequence-to-sequence edit distance in shape [batch_size, 1].
@@ -2176,8 +2247,10 @@ def edit_distance(input,
def ctc_greedy_decoder(input, blank, name=None):
"""
This op is used to decode sequences by greedy policy by below steps:
- 1. Get the indexes of max value for each row in input. a.k.a. numpy.argmax(input, axis=0).
- 2. For each sequence in result of step1, merge repeated tokens between two blanks and delete all blanks.
+ 1. Get the indexes of max value for each row in input, a.k.a.
+ numpy.argmax(input, axis=0).
+ 2. For each sequence in result of step1, merge repeated tokens between two
+ blanks and delete all blanks.
A simple example as below:
@@ -2207,9 +2280,16 @@ def ctc_greedy_decoder(input, blank, name=None):
Args:
- input(Variable): (LoDTensor), the probabilities of variable-length sequences, which is a 2-D Tensor with LoD information. It's shape is [Lp, num_classes + 1], where Lp is the sum of all input sequences' length and num_classes is the true number of classes. (not including the blank label).
+ input(Variable): (LoDTensor), the probabilities of
+ variable-length sequences, which is a 2-D Tensor with
+ LoD information. Its shape is [Lp, num_classes + 1],
+ where Lp is the sum of all input sequences' length and
+ num_classes is the true number of classes (not
+ including the blank label).
- blank(int): the blank label index of Connectionist Temporal Classification (CTC) loss, which is in thehalf-opened interval [0, num_classes + 1).
+ blank(int): the blank label index of Connectionist Temporal
+ Classification (CTC) loss, which is in the half-open
+ interval [0, num_classes + 1).
Returns:
Variable: CTC greedy decode result.
@@ -2277,8 +2357,10 @@ def warpctc(input, label, blank=0, norm_by_times=False, **kwargs):
Examples:
.. code-block:: python
- y = layers.data(name='y', shape=[11, 8], dtype='float32', lod_level=1)
- y_predict = layers.data(name='y_predict', shape=[11, 1], dtype='float32')
+ y = layers.data(
+ name='y', shape=[11, 8], dtype='float32', lod_level=1)
+ y_predict = layers.data(
+ name='y_predict', shape=[11, 1], dtype='float32')
cost = layers.warpctc(input=y_predict, label=y)
"""
@@ -2432,6 +2514,12 @@ def transpose(x, perm, name=None):
raise ValueError(
"Input(perm) is the permutation of dimensions of Input(input). "
"It's length shoud be equal to Input(input)'s rank.")
+ for idx, dim in enumerate(perm):
+ if dim >= len(x.shape):
+ raise ValueError(
+ "Each element in perm should be less than x's rank. "
+ "%d-th element in perm is %d which accesses x's rank %d." %
+ (idx, perm[idx], len(x.shape)))
helper = LayerHelper('transpose', **locals())
out = helper.create_tmp_variable(x.dtype)
@@ -2540,7 +2628,8 @@ def im2sequence(input, filter_size=1, stride=1, padding=0, name=None):
.. code-block:: python
- output = fluid.layers.im2sequence(input=layer, stride=[1, 1], filter_size=[2, 2])
+ output = fluid.layers.im2sequence(
+ input=layer, stride=[1, 1], filter_size=[2, 2])
"""
diff --git a/python/paddle/v2/fluid/nets.py b/python/paddle/v2/fluid/nets.py
index 6146e3711d3c62d22591b2855d73b5791e4b47d0..cb63d43709e23ae04c4d23457bbb79e6f7f0ce3c 100644
--- a/python/paddle/v2/fluid/nets.py
+++ b/python/paddle/v2/fluid/nets.py
@@ -11,14 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import layers
__all__ = [
"simple_img_conv_pool",
"sequence_conv_pool",
"glu",
- "dot_product_attention",
+ "scaled_dot_product_attention",
]
@@ -160,7 +159,11 @@ def glu(input, dim=-1):
return out
-def dot_product_attention(querys, keys, values):
+def scaled_dot_product_attention(queries,
+ keys,
+ values,
+ num_heads=1,
+ dropout_rate=0.):
"""
The dot-product attention.
@@ -174,39 +177,162 @@ def dot_product_attention(querys, keys, values):
.. math::
- Attention(Q, K, V)= softmax(QK^\mathrm{T})V
+ Attention(Q, K, V)= softmax(QK^\mathrm{T})V
Refer to `Attention Is All You Need
<https://arxiv.org/pdf/1706.03762.pdf>`_.
- Note that batch data containing sequences with different lengths is not
- supported by this because of the (batch) matrix multipication.
-
Args:
- query (Variable): The input variable which is a Tensor or LoDTensor.
- key (Variable): The input variable which is a Tensor or LoDTensor.
- value (Variable): The input variable which is a Tensor or LoDTensor.
+
+ queries (Variable): The input variable which should be a 3-D Tensor.
+ keys (Variable): The input variable which should be a 3-D Tensor.
+ values (Variable): The input variable which should be a 3-D Tensor.
+ num_heads (int): Head number to compute the scaled dot product
+ attention. Default value is 1.
+ dropout_rate (float): The dropout rate to drop the attention weight.
+ Default value is 0.
Returns:
- tuple: The Tensor variables representing the output and attention scores.
+
+ Variable: A 3-D Tensor computed by multi-head scaled dot product
+ attention.
+
+ Raises:
+
+ ValueError: If input queries, keys, values are not 3-D Tensors.
+
+ NOTE:
+ 1. When num_heads > 1, three linear projections are learned respectively
+ to map input queries, keys and values into queries', keys' and values'.
+ queries', keys' and values' have the same shapes with queries, keys
+ and values.
+
+ 2. When num_heads == 1, scaled_dot_product_attention has no learnable
+ parameters.
Examples:
.. code-block:: python
- # Suppose q, k, v are tensor variables with the following shape:
+ # Suppose q, k, v are Tensors with the following shape:
# q: [3, 5, 9], k: [3, 6, 9], v: [3, 6, 10]
- out, attn_scores = fluid.nets.dot_product_attention(q, k, v)
- out.shape # [3, 5, 10]
- attn_scores.shape # [3, 5, 6]
+
+ contexts = fluid.nets.scaled_dot_product_attention(q, k, v)
+ contexts.shape # [3, 5, 10]
"""
- assert keys.shape[-2] == values.shape[
- -2], 'The shapes of keys and values mismatch.'
- assert querys.shape[-1] == keys.shape[
- -1], 'The shapes of querys and keys mismatch.'
- product = layers.matmul(x=querys, y=keys, transpose_y=True)
- attn_scores = layers.reshape(
+ if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
+ raise ValueError(
+ "Inputs quries, keys and values should all be 3-D tensors.")
+
+ if queries.shape[-1] != keys.shape[-1]:
+ raise ValueError(
+ "The hidden size of queries and keys should be the same.")
+ if keys.shape[-2] != values.shape[-2]:
+ raise ValueError(
+ "The max sequence length in query batch and in key batch "
+ "should be the same.")
+ if keys.shape[-1] % num_heads != 0:
+ raise ValueError("The hidden size of keys (%d) must be divisible "
+ "by the number of attention heads (%d)." %
+ (keys.shape[-1], num_heads))
+ if values.shape[-1] % num_heads != 0:
+ raise ValueError("The hidden size of values (%d) must be divisible "
+ "by the number of attention heads (%d)." %
+ (values.shape[-1], num_heads))
+
+ def __compute_qkv(queries, keys, values, num_heads):
+ """
+ Add linear projection to queries, keys, and values.
+
+ Args:
+ queries(Tensor): a 3-D input Tensor.
+ keys(Tensor): a 3-D input Tensor.
+ values(Tensor): a 3-D input Tensor.
+ num_heads(int): The number of heads. Linearly project the inputs
+ ONLY when num_heads > 1.
+
+ Returns:
+ Tensor: linearly projected output Tensors: queries', keys' and
+ values'. They have the same shapes with queries, keys and
+ values.
+ """
+
+ if num_heads == 1:
+ return queries, keys, values
+
+ q = layers.fc(input=queries, size=queries.shape[-1], num_flatten_dims=2)
+ k = layers.fc(input=keys, size=keys.shape[-1], num_flatten_dims=2)
+ v = layers.fc(input=values, size=values.shape[-1], num_flatten_dims=2)
+ return q, k, v
+
+ def __split_heads(x, num_heads):
+ """
+ Reshape the last dimension of input tensor x so that it becomes two
+ dimensions.
+
+ Args:
+ x(Tensor): a 3-D input Tensor.
+ num_heads(int): The number of heads.
+
+ Returns:
+ Tensor: a 4-D Tensor with shape [batch_size, num_heads,
+ max_sequence_length, hidden_size // num_heads], or the
+ unchanged input x when num_heads == 1.
+ """
+ if num_heads == 1:
+ return x
+
+ hidden_size = x.shape[-1]
+ # reshape the 3-D input: [batch_size, max_sequence_length, hidden_dim]
+ # into a 4-D output:
+ # [batch_size, max_sequence_length, num_heads, hidden_size_per_head].
+ reshaped = layers.reshape(
+ x=x,
+ shape=list(x.shape[:-1]) + [num_heads, hidden_size // num_heads])
+
+ # permute the dimensions into:
+ # [batch_size, num_heads, max_sequence_len, hidden_size_per_head]
+ return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
+
+ def __combine_heads(x):
+ """
+ Reshape the last two dimensions of input tensor x so that it becomes
+ one dimension.
+
+ Args:
+ x(Tensor): a 4-D input Tensor with shape
+ [bs, num_heads, max_sequence_length, hidden_dim].
+
+ Returns:
+ Tensor: a Tensor with shape
+ [bs, max_sequence_length, num_heads * hidden_dim].
+ """
+
+ if len(x.shape) == 3: return x
+ if len(x.shape) != 4:
+ raise ValueError("Input(x) should be a 4-D Tensor.")
+
+ trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
+ return layers.reshape(
+ x=trans_x,
+ shape=map(int, [
+ trans_x.shape[0], trans_x.shape[1],
+ trans_x.shape[2] * trans_x.shape[3]
+ ]))
+
+ q, k, v = __compute_qkv(queries, keys, values, num_heads)
+
+ q = __split_heads(q, num_heads)
+ k = __split_heads(k, num_heads)
+ v = __split_heads(v, num_heads)
+
+ key_dim_per_head = keys.shape[-1] // num_heads
+ scaled_q = layers.scale(x=q, scale=key_dim_per_head**-0.5)
+ product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
+
+ weights = layers.reshape(
x=layers.reshape(
- x=product, shape=[-1, product.shape[-1]], act='softmax'),
+ x=product, shape=[-1, product.shape[-1]], act="softmax"),
shape=product.shape)
- out = layers.matmul(attn_scores, values)
- return out, attn_scores
+ if dropout_rate:
+ weights = layers.dropout(weights, dropout_prob=dropout_rate, is_test=False)
+ ctx_multiheads = layers.matmul(weights, v)
+ return __combine_heads(ctx_multiheads)
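
For reference, single-head scaled dot-product attention (`num_heads == 1`, no dropout) reduces to a few NumPy lines; this is a sanity-check sketch of the layer's semantics, not part of the patch:

```python
# NumPy reference for the num_heads == 1, dropout_rate == 0 case.
import numpy as np

def ref_attention(q, k, v):
    # q: [batch, len_q, d_k], k: [batch, len_k, d_k], v: [batch, len_k, d_v]
    scores = np.matmul(q, k.transpose(0, 2, 1)) / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return np.matmul(weights, v)                  # [batch, len_q, d_v]

q = np.random.rand(3, 5, 9).astype("float32")
k = np.random.rand(3, 6, 9).astype("float32")
v = np.random.rand(3, 6, 10).astype("float32")
assert ref_attention(q, k, v).shape == (3, 5, 10)
```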
diff --git a/python/paddle/v2/fluid/profiler.py b/python/paddle/v2/fluid/profiler.py
index 29e0d54a3ac9622e5505c8e5de38616d9c636e67..51c1c8aa705513825b46fb936c6c99090c50fb7d 100644
--- a/python/paddle/v2/fluid/profiler.py
+++ b/python/paddle/v2/fluid/profiler.py
@@ -63,3 +63,58 @@ def cuda_profiler(output_file, output_mode=None, config=None):
# Disables profiler collection.
core.nvprof_stop()
os.remove(config_file)
+
+
+def reset_profiler():
+ """The profiler clear interface.
+ reset_profiler will clear the previous time record.
+ """
+ core.reset_profiler()
+
+
+@contextmanager
+def profiler(state, sorted_key=None):
+ """The profiler interface.
+ Different from cuda_profiler, this profiler can be used to profile both CPU
+ and GPU programs. By default, it records the CPU and GPU operator kernels;
+ if you want to profile other programs, you can refer to the profiling
+ tutorial to add more records.
+
+ Args:
+ state (string) : The profiling state, which should be 'CPU' or 'GPU',
+ telling the profiler to use CPU timer or GPU timer for profiling.
+ Although users may have already specified the execution place
+ (CPUPlace/CUDAPlace) in the beginning, for flexibility the profiler
+ would not inherit this place.
+ sorted_key (string) : If None, the profiling results will be printed
+ in the order of first end time of events. Otherwise, the profiling
+ results will be sorted by this flag. This flag should be one
+ of 'calls', 'total', 'max', 'min' or 'ave'.
+ The `calls` means sorting by the number of calls.
+ The `total` means sorting by the total execution time.
+ The `max` means sorting by the maximum execution time.
+ The `min` means sorting by the minimum execution time.
+ The `ave` means sorting by the average execution time.
+ """
+
+ if state not in ['CPU', 'GPU']:
+ raise ValueError("The state must be 'CPU' or 'GPU'.")
+ prof_state = core.ProfilerState.kCUDA if state == "GPU" else core.ProfilerState.kCPU
+ core.enable_profiler(prof_state)
+ yield
+
+ if sorted_key is not None and sorted_key not in [
+ 'calls', 'total', 'max', 'min', 'ave']:
+ raise ValueError("The sorted_key must be None or one of 'calls', "
+ "'total', 'max', 'min' or 'ave'.")
+ sorted_key = 'default' if sorted_key is None else sorted_key
+ key_map = {
+ 'default': core.EventSortingKey.kDefault,
+ 'calls': core.EventSortingKey.kCalls,
+ 'total': core.EventSortingKey.kTotal,
+ 'max': core.EventSortingKey.kMax,
+ 'min': core.EventSortingKey.kMin,
+ 'ave': core.EventSortingKey.kAve,
+ }
+ # TODO(qingqing) : redirect C++ ostream to Python stream.
+ # with core.ostream_redirect(stdout=True, stderr=True):
+ core.disable_profiler(key_map[sorted_key])
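
Putting the new context manager to work, a minimal usage sketch; `exe`, `program` and `feed` are assumed to be set up elsewhere:

```python
# Minimal usage sketch of the new profiler context manager.
import paddle.v2.fluid.profiler as profiler

with profiler.profiler('CPU', sorted_key='total'):
    for batch in range(10):
        exe.run(program, feed=feed)
# On exit the C++ profiler prints a report sorted by total kernel time.
```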
diff --git a/python/paddle/v2/fluid/tests/op_test.py b/python/paddle/v2/fluid/tests/op_test.py
index 56f54de86f680653fbd97a7ce1d3f547d1657587..3f6d7070c2987d0557c60db84a2c679cd2cfe36b 100644
--- a/python/paddle/v2/fluid/tests/op_test.py
+++ b/python/paddle/v2/fluid/tests/op_test.py
@@ -334,7 +334,7 @@ class OpTest(unittest.TestCase):
def check_output(self, atol=1e-5):
places = [core.CPUPlace()]
- if core.is_compile_gpu() and core.op_support_gpu(self.op_type):
+ if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type):
places.append(core.CUDAPlace(0))
for place in places:
self.check_output_with_place(place, atol)
@@ -367,7 +367,7 @@ class OpTest(unittest.TestCase):
max_relative_error=0.005,
user_defined_grads=None):
places = [core.CPUPlace()]
- if core.is_compile_gpu() and core.op_support_gpu(self.op_type):
+ if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type):
places.append(core.CUDAPlace(0))
for place in places:
self.check_grad_with_place(place, inputs_to_check, output_names,
diff --git a/python/paddle/v2/fluid/tests/test_adagrad_op.py b/python/paddle/v2/fluid/tests/test_adagrad_op.py
index 86b0567ce123b00bace639fb8fe76cf3894abd6d..3556bcf8ba0d7f16b1d9bf50e46aebde83de2e25 100644
--- a/python/paddle/v2/fluid/tests/test_adagrad_op.py
+++ b/python/paddle/v2/fluid/tests/test_adagrad_op.py
@@ -180,7 +180,7 @@ class TestSparseAdagradOp(unittest.TestCase):
def test_sparse_adagrad(self):
places = [core.CPUPlace()]
- if core.is_compile_gpu():
+ if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self.check_with_place(place)
diff --git a/python/paddle/v2/fluid/tests/test_adam_op.py b/python/paddle/v2/fluid/tests/test_adam_op.py
index 10580adca714beeb7571312b8fdc4235ecaaccfe..df1fa8983c1984a9bb9f204aded148c17d3d609d 100644
--- a/python/paddle/v2/fluid/tests/test_adam_op.py
+++ b/python/paddle/v2/fluid/tests/test_adam_op.py
@@ -305,7 +305,7 @@ class TestSparseAdamOp(unittest.TestCase):
def test_sparse_sgd(self):
places = [core.CPUPlace()]
- if core.is_compile_gpu():
+ if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self.check_with_place(place)
diff --git a/python/paddle/v2/fluid/tests/test_batch_norm_op.py b/python/paddle/v2/fluid/tests/test_batch_norm_op.py
index 371bd426781b457582e74c33c80c46b5d56946fa..cf13166f255c782bdcec622d58d073a0943c8e1e 100644
--- a/python/paddle/v2/fluid/tests/test_batch_norm_op.py
+++ b/python/paddle/v2/fluid/tests/test_batch_norm_op.py
@@ -352,7 +352,7 @@ class TestBatchNormOp(OpTest):
print "op test backward passed: ", str(place), data_layout
places = [core.CPUPlace()]
- if core.is_compile_gpu() and core.op_support_gpu("batch_norm"):
+ if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
places.append(core.CUDAPlace(0))
for place in places:
diff --git a/python/paddle/v2/fluid/tests/test_gaussian_random_op.py b/python/paddle/v2/fluid/tests/test_gaussian_random_op.py
index 82842534d4ac7ad8b0a8e0d877c6a638fb53cadc..79beb8b1fcef610bc2f3e8d18da4345baa9b99c3 100644
--- a/python/paddle/v2/fluid/tests/test_gaussian_random_op.py
+++ b/python/paddle/v2/fluid/tests/test_gaussian_random_op.py
@@ -33,7 +33,7 @@ class TestGaussianRandomOp(unittest.TestCase):
self.gaussian_random_test(place=fluid.CPUPlace())
def test_gpu(self):
- if core.is_compile_gpu():
+ if core.is_compiled_with_cuda():
self.gaussian_random_test(place=fluid.CUDAPlace(0))
def gaussian_random_test(self, place):
diff --git a/python/paddle/v2/fluid/tests/test_iou_similarity_op.py b/python/paddle/v2/fluid/tests/test_iou_similarity_op.py
old mode 100755
new mode 100644
diff --git a/python/paddle/v2/fluid/tests/test_multihead_attention.py b/python/paddle/v2/fluid/tests/test_multihead_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..54ec3e3d6e53f35d6a518ef659853e1a13c1711f
--- /dev/null
+++ b/python/paddle/v2/fluid/tests/test_multihead_attention.py
@@ -0,0 +1,98 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import paddle.v2.fluid as fluid
+import paddle.v2.fluid.core as core
+import numpy as np
+
+
+class TestMultiheadAttention(unittest.TestCase):
+ def gen_random_input(self):
+ """Generate random input data.
+ """
+ # batch_size, max_sequence_length, hidden dimension
+ self.input_shape = (3, 13, 16)
+ self.queries = np.random.random(size=self.input_shape).astype("float32")
+ self.keys = np.random.random(size=self.input_shape).astype("float32")
+
+ def set_program(self):
+ """Build the test program.
+ """
+ queries = fluid.layers.data(
+ name="queries",
+ shape=self.input_shape,
+ dtype="float32",
+ append_batch_size=False)
+ queries.stop_gradient = False
+ keys = fluid.layers.data(
+ name="keys",
+ shape=self.input_shape,
+ dtype="float32",
+ append_batch_size=False)
+ keys.stop_gradient = False
+
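+        # The hidden size (16) is split across num_heads (8) attention heads,
+        # so it must be divisible by num_heads; here each head has size 2.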
+ contexts = fluid.nets.scaled_dot_product_attention(
+ queries=queries,
+ keys=keys,
+ values=keys,
+ num_heads=8,
+ dropout_rate=0.)
+ out = fluid.layers.reduce_sum(contexts, dim=None)
+ fluid.backward.append_backward(loss=out)
+
+ self.fetch_list = [contexts]
+
+ def run_program(self):
+ """Run the test program.
+ """
+ places = [core.CPUPlace()]
+        if core.is_compiled_with_cuda():
+ places.append(core.CUDAPlace(0))
+
+ for place in places:
+ self.set_inputs(place)
+ exe = fluid.Executor(place)
+
+ exe.run(fluid.default_startup_program())
+ output = exe.run(fluid.default_main_program(),
+ feed=self.inputs,
+ fetch_list=self.fetch_list,
+ return_numpy=True)
+ self.op_output = output
+
+ def set_inputs(self, place):
+ """Set the randomly generated data to the test program.
+ """
+ self.inputs = {}
+ queries = fluid.Tensor()
+ queries.set(self.queries, place)
+
+ keys = fluid.Tensor()
+ keys.set(self.keys, place)
+
+ self.inputs["keys"] = keys
+ self.inputs["queries"] = queries
+
+ def test_multihead_attention(self):
+ self.gen_random_input()
+
+ self.set_program()
+ self.run_program()
+
+        # FIXME(caoying): add more meaningful unit tests.
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/python/paddle/v2/fluid/tests/test_normalization_wrapper.py b/python/paddle/v2/fluid/tests/test_normalization_wrapper.py
index 57f14f6b9cc9c7cf9ae93274cf3d7763350e6e10..6b71f2a923f0cf0744d6b2190aa35830dcf15f24 100644
--- a/python/paddle/v2/fluid/tests/test_normalization_wrapper.py
+++ b/python/paddle/v2/fluid/tests/test_normalization_wrapper.py
@@ -46,7 +46,7 @@ class TestNormalization(unittest.TestCase):
"""Run the test program.
"""
places = [core.CPUPlace()]
- if core.is_compile_gpu():
+ if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
diff --git a/python/paddle/v2/fluid/tests/test_op_support_gpu.py b/python/paddle/v2/fluid/tests/test_op_support_gpu.py
index 34939818126b1d747fb76861bbd691894fb3759b..7de02a8fda22a3db82a2e0b5e6fa9c9f2718fa12 100644
--- a/python/paddle/v2/fluid/tests/test_op_support_gpu.py
+++ b/python/paddle/v2/fluid/tests/test_op_support_gpu.py
@@ -18,7 +18,8 @@ import paddle.v2.fluid.core as core
class TestOpSupportGPU(unittest.TestCase):
def test_case(self):
- self.assertEqual(core.is_compile_gpu(), core.op_support_gpu("sum"))
+ self.assertEqual(core.is_compiled_with_cuda(),
+ core.op_support_gpu("sum"))
if __name__ == '__main__':
diff --git a/python/paddle/v2/fluid/tests/test_parallel_op.py b/python/paddle/v2/fluid/tests/test_parallel_op.py
index dfde492c7cd930615c030bb0c8e5a2cf36ff59a8..367cc8b1aaf0aff24c685031f33d35becb9eb7ef 100644
--- a/python/paddle/v2/fluid/tests/test_parallel_op.py
+++ b/python/paddle/v2/fluid/tests/test_parallel_op.py
@@ -53,7 +53,7 @@ class BaseParallelForTest(unittest.TestCase):
fetch=fetch,
place=cpu,
use_parallel=True)
- if fluid.core.is_compile_gpu():
+ if fluid.core.is_compiled_with_cuda():
gpu = fluid.CUDAPlace(0)
result_gpu = self._run_test_impl_(
callback=callback,
@@ -159,7 +159,7 @@ class ParallelOpTest(BaseParallelForTest):
def test_simple_fc(self):
self.run_test(
- callback=ParallelOpTest.__network__,
+ callback=self.__network__,
feed={
'img': numpy.random.random(size=(51, 784)).astype('float32')
},
@@ -167,10 +167,35 @@ class ParallelOpTest(BaseParallelForTest):
def test_fc_with_tiny_data(self):
self.run_test(
- callback=ParallelOpTest.__network__,
+ callback=self.__network__,
feed={'img': numpy.random.random(size=(1, 784)).astype('float32')},
fetch=['fc1.w@GRAD'])
+class ParallelOpTestMultipleInput(BaseParallelForTest):
+ @staticmethod
+ def __network__():
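+        # BaseParallelForTest drives this generator: it first yields the list
+        # of input variables to feed, then yields the loss to fetch.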
+ x = fluid.layers.data(
+ shape=[784], dtype='float32', name='img1', stop_gradient=False)
+ y = fluid.layers.data(
+ shape=[784], dtype='float32', name='img2', stop_gradient=False)
+ yield [x, y]
+ x = x + y
+ hidden1 = fluid.layers.fc(input=x, size=200, param_attr='fc1.w')
+ hidden2 = fluid.layers.fc(input=hidden1, size=200, param_attr='fc2.w')
+ hidden3 = fluid.layers.fc(input=hidden2, size=200, param_attr='fc3.w')
+ loss = fluid.layers.mean(x=hidden3)
+ yield loss
+
+ def test_simple_fc(self):
+ self.run_test(
+ callback=self.__network__,
+ feed={
+ 'img1': numpy.random.random(size=(51, 784)).astype('float32'),
+ 'img2': numpy.random.random(size=(51, 784)).astype('float32')
+ },
+ fetch=['fc1.w@GRAD', 'fc2.w@GRAD', 'fc3.w@GRAD'])
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/python/paddle/v2/fluid/tests/test_profiler.py b/python/paddle/v2/fluid/tests/test_profiler.py
index abf8881b6786416f56f93e498761a4791b35d7c3..09b2d08401878448b4b3f3c6c03193e255e9ffeb 100644
--- a/python/paddle/v2/fluid/tests/test_profiler.py
+++ b/python/paddle/v2/fluid/tests/test_profiler.py
@@ -13,16 +13,17 @@
# limitations under the License.
import unittest
+import os
import numpy as np
import paddle.v2.fluid as fluid
import paddle.v2.fluid.profiler as profiler
import paddle.v2.fluid.layers as layers
-import os
+import paddle.v2.fluid.core as core
class TestProfiler(unittest.TestCase):
def test_nvprof(self):
- if not fluid.core.is_compile_gpu():
+ if not fluid.core.is_compiled_with_cuda():
return
epoc = 8
dshape = [4, 3, 28, 28]
@@ -40,6 +41,50 @@ class TestProfiler(unittest.TestCase):
exe.run(fluid.default_main_program(), feed={'data': input})
os.remove(output_file)
+ def net_profiler(self, state):
+ if state == 'GPU' and not core.is_compiled_with_cuda():
+ return
+ startup_program = fluid.Program()
+ main_program = fluid.Program()
+
+ with fluid.program_guard(main_program, startup_program):
+ image = fluid.layers.data(name='x', shape=[784], dtype='float32')
+ hidden1 = fluid.layers.fc(input=image, size=128, act='relu')
+ hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu')
+ predict = fluid.layers.fc(input=hidden2, size=10, act='softmax')
+ label = fluid.layers.data(name='y', shape=[1], dtype='int64')
+ cost = fluid.layers.cross_entropy(input=predict, label=label)
+ avg_cost = fluid.layers.mean(x=cost)
+ accuracy = fluid.evaluator.Accuracy(input=predict, label=label)
+
+ optimizer = fluid.optimizer.Momentum(learning_rate=0.001, momentum=0.9)
+ opts = optimizer.minimize(avg_cost, startup_program=startup_program)
+
+ place = fluid.CPUPlace() if state == 'CPU' else fluid.CUDAPlace(0)
+ exe = fluid.Executor(place)
+ exe.run(startup_program)
+
+ accuracy.reset(exe)
+ with profiler.profiler(state, 'total') as prof:
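+            # Treat the first two iterations as warm-up: reset_profiler below
+            # clears their records so they do not skew the reported times.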
+            for i in range(10):
+                if i == 2:
+ profiler.reset_profiler()
+ x = np.random.random((32, 784)).astype("float32")
+ y = np.random.randint(0, 10, (32, 1)).astype("int64")
+
+ outs = exe.run(main_program,
+ feed={'x': x,
+ 'y': y},
+ fetch_list=[avg_cost] + accuracy.metrics)
+ acc = np.array(outs[1])
+ pass_acc = accuracy.eval(exe)
+
+ def test_cpu_profiler(self):
+ self.net_profiler('CPU')
+
+ def test_cuda_profiler(self):
+ self.net_profiler('GPU')
+
if __name__ == '__main__':
unittest.main()
diff --git a/python/paddle/v2/fluid/tests/test_reorder_lod_tensor.py b/python/paddle/v2/fluid/tests/test_reorder_lod_tensor.py
index 74cd6de9e6fde70c001bb2189c4976cdd8e34633..0a223bac0ce8fd626881cef983c7cd960f2c5ba8 100644
--- a/python/paddle/v2/fluid/tests/test_reorder_lod_tensor.py
+++ b/python/paddle/v2/fluid/tests/test_reorder_lod_tensor.py
@@ -45,7 +45,7 @@ class TestReorderLoDTensor(unittest.TestCase):
outputs = []
input_grads = []
places = [core.CPUPlace()]
- if core.is_compile_gpu():
+ if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self.set_inputs(place)
diff --git a/python/paddle/v2/fluid/tests/test_sgd_op.py b/python/paddle/v2/fluid/tests/test_sgd_op.py
index f87927968b0fdb00ec207ff1d52be9e0d81af139..ba2ca1683f9f6d72bbd1550df89c7424d223a1d9 100644
--- a/python/paddle/v2/fluid/tests/test_sgd_op.py
+++ b/python/paddle/v2/fluid/tests/test_sgd_op.py
@@ -91,7 +91,7 @@ class TestSparseSGDOp(unittest.TestCase):
def test_sparse_sgd(self):
places = [core.CPUPlace()]
- if core.is_compile_gpu():
+ if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self.check_with_place(place)
diff --git a/python/paddle/v2/fluid/tests/test_split_selected_rows_op.py b/python/paddle/v2/fluid/tests/test_split_selected_rows_op.py
index 37c6587c4151a89563f93cab35d63b2419ef88ab..343aa20066146ae08462a92f1efaa20c4d4b5ed8 100644
--- a/python/paddle/v2/fluid/tests/test_split_selected_rows_op.py
+++ b/python/paddle/v2/fluid/tests/test_split_selected_rows_op.py
@@ -21,7 +21,7 @@ from paddle.v2.fluid.op import Operator
class TestSpliteSelectedRows(unittest.TestCase):
def get_places(self):
places = [core.CPUPlace()]
- if core.is_compile_gpu():
+ if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
return places
diff --git a/python/paddle/v2/fluid/tests/test_uniform_random_op.py b/python/paddle/v2/fluid/tests/test_uniform_random_op.py
index b2a39f975eb461292dc2e7be332a26931684bf90..94cf416fad8f02cdea8017ae1350fa264ce644b1 100644
--- a/python/paddle/v2/fluid/tests/test_uniform_random_op.py
+++ b/python/paddle/v2/fluid/tests/test_uniform_random_op.py
@@ -36,7 +36,7 @@ class TestUniformRandomOp(unittest.TestCase):
self.uniform_random_test(place=core.CPUPlace())
def test_gpu(self):
- if core.is_compile_gpu():
+ if core.is_compiled_with_cuda():
self.uniform_random_test(place=core.CUDAPlace(0))
def uniform_random_test(self, place):