diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 25cb15d2cc8c27e5fa1477e60e4428d5823495dd..6e73aaef15e07d3f75bb463b9fcaa8a8fde5c834 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -90,9 +90,6 @@ InterpreterCore::~InterpreterCore() { // cancle gc's thread gc_.reset(nullptr); - exception_notifier_->UnregisterEvent(); - completion_notifier_->UnregisterEvent(); - async_work_queue_.reset(nullptr); } diff --git a/paddle/fluid/framework/new_executor/workqueue/events_waiter.cc b/paddle/fluid/framework/new_executor/workqueue/events_waiter.cc index ac45e7b5fdfe9feb284a0a5e156e6aacbc43f48b..163050ae5a65a7758fdcbe00f5e880acf4262f9d 100644 --- a/paddle/fluid/framework/new_executor/workqueue/events_waiter.cc +++ b/paddle/fluid/framework/new_executor/workqueue/events_waiter.cc @@ -19,37 +19,79 @@ namespace paddle { namespace framework { +constexpr EventsWaiter::EventId kEmptyEventId = 0; + EventsWaiter::EventsWaiter() - : trigger_event_(nullptr), counter_(0), waiting_(false), cv_(1) {} + : trigger_event_(kEmptyEventId), + counter_(0), + eof_(true), + waiting_(false), + cv_(1) {} std::shared_ptr EventsWaiter::RegisterEvent( const std::string& name, EventChecker checker) { - auto counter = counter_.fetch_add(1); - auto id = std::hash()(name + std::to_string(counter)); + EventId id = kEmptyEventId; + EventInfo* evt = nullptr; + do { + auto counter = counter_.fetch_add(1); + id = std::hash()(name + std::to_string(counter)); + if (id == kEmptyEventId) { + continue; + } + std::lock_guard guard(events_lock_); + if (events_.count(id) > 0) { + continue; + } + evt = &(events_[id]); + } while (evt == nullptr); + evt->id = id; + evt->name = name; + evt->type = TriggerType::LevelTriggered; + evt->checker = std::move(checker); + eof_.store(false, std::memory_order_relaxed); VLOG(10) << "Register event id:" << id << " name:" << 
name; auto notifier = std::shared_ptr(new EventNotifier(id, this)); - EventInfo evt{id, name, TriggerType::LevelTriggered, std::move(checker)}; - std::lock_guard guard(events_lock_); - events_[id] = std::move(evt); return notifier; } std::shared_ptr EventsWaiter::RegisterEvent( const std::string& name) { - auto counter = counter_.fetch_add(1); - auto id = std::hash()(name + std::to_string(counter)); + EventId id = kEmptyEventId; + EventInfo* evt = nullptr; + do { + auto counter = counter_.fetch_add(1); + id = std::hash()(name + std::to_string(counter)); + if (id == kEmptyEventId) { + continue; + } + std::lock_guard guard(events_lock_); + if (events_.count(id) > 0) { + continue; + } + evt = &(events_[id]); + } while (evt == nullptr); + evt->id = id; + evt->name = name; + evt->type = TriggerType::EdgeTriggered; + evt->checker = []() { return false; }; + eof_.store(false, std::memory_order_relaxed); VLOG(10) << "Register event id:" << id << " name:" << name; auto notifier = std::shared_ptr(new EventNotifier(id, this)); - EventInfo evt{id, name, TriggerType::EdgeTriggered, []() { return false; }}; - std::lock_guard guard(events_lock_); - events_[id] = std::move(evt); return notifier; } void EventsWaiter::UnregisterEvent(const EventId& id) { VLOG(10) << "Unregister event id:" << id; - std::lock_guard guard(events_lock_); - events_.erase(id); + { + std::lock_guard guard(events_lock_); + deleted_events_.insert(id); + if (deleted_events_.size() == events_.size()) { + eof_.store(true, std::memory_order_relaxed); + } + } + if (eof_.load(std::memory_order_relaxed)) { + cv_.Notify(true); + } } std::string EventsWaiter::WaitEvent() { @@ -61,42 +103,60 @@ std::string EventsWaiter::WaitEvent() { PADDLE_THROW( platform::errors::ResourceExhausted("Another thread is waiting.")); } + auto w = cv_.GetWaiter(0); - cv_.Prewait(); - std::string* triggered = trigger_event_; - if (triggered == nullptr) { + EventId triggered = trigger_event_; + while (triggered == kEmptyEventId && !eof_) { 
+ cv_.Prewait(); + + // double check + triggered = trigger_event_; // checkers - { - std::lock_guard guard(events_lock_); - for (auto& kv : events_) { - auto& evt = kv.second; - if (TriggerType::LevelTriggered == evt.type && evt.checker()) { - triggered = new std::string(evt.name); - break; + if (triggered == kEmptyEventId) { + { + std::lock_guard guard(events_lock_); + for (auto& kv : events_) { + auto& evt = kv.second; + if (TriggerType::LevelTriggered == evt.type && evt.checker()) { + triggered = evt.id; + break; + } } } - } - if (triggered != nullptr) { - std::string* prev = nullptr; - if (!trigger_event_.compare_exchange_strong(prev, triggered, - std::memory_order_seq_cst, - std::memory_order_relaxed)) { - delete triggered; - triggered = prev; + if (triggered != kEmptyEventId) { + EventId prev = kEmptyEventId; + if (!trigger_event_.compare_exchange_strong( + prev, triggered, std::memory_order_seq_cst, + std::memory_order_relaxed)) { + triggered = prev; + } } } + + if (triggered != kEmptyEventId || eof_) { + cv_.CancelWait(); + } else { + cv_.CommitWait(w); + triggered = trigger_event_; + } } - if (triggered) { - cv_.CancelWait(); - } else { - cv_.CommitWait(w); - triggered = trigger_event_; + + trigger_event_.store(kEmptyEventId, std::memory_order_relaxed); + waiting_.store(false, std::memory_order_relaxed); + std::string evt_name = + triggered == kEmptyEventId ? 
"NoEventNotifier" : GetEventName(triggered); + VLOG(10) << "Consume event id:" << triggered << ", name:" << evt_name; + // lazy deletion + { + std::lock_guard guard(events_lock_); + if (deleted_events_.size() > 0) { + for (auto evt : deleted_events_) { + events_.erase(evt); + } + deleted_events_.clear(); + } } - trigger_event_.store(nullptr, std::memory_order_relaxed); - waiting_.store(false); - auto trigger_event = *triggered; - delete triggered; - return trigger_event; + return evt_name; } int EventsWaiter::Clear() { @@ -106,32 +166,33 @@ int EventsWaiter::Clear() { std::memory_order_relaxed)) { return -1; } - trigger_event_.store(nullptr, std::memory_order_relaxed); + trigger_event_.store(kEmptyEventId, std::memory_order_relaxed); waiting_.store(false); return 0; } void EventsWaiter::TriggerEvent(const EventId& id) { VLOG(10) << "Try to trigger event id:" << id; - std::string* trigger_event = new std::string; - { - std::lock_guard guard(events_lock_); - auto iter = events_.find(id); - if (iter == events_.end()) { - delete trigger_event; - return; - } - *trigger_event = iter->second.name; + EventId prev = kEmptyEventId; + if (!trigger_event_.compare_exchange_strong( + prev, id, std::memory_order_seq_cst, std::memory_order_relaxed)) { + VLOG(10) << "Event id:" << prev << " is pending"; + return; } - std::string* prev = nullptr; - if (!trigger_event_.compare_exchange_strong(prev, trigger_event, + VLOG(10) << "Triggered event id:" << id; + cv_.Notify(true); +} + +void EventsWaiter::CancelEvent(const EventId& id) { + VLOG(10) << "Try to cancel event id:" << id; + EventId prev = id; + if (!trigger_event_.compare_exchange_strong(prev, kEmptyEventId, std::memory_order_seq_cst, std::memory_order_relaxed)) { - delete trigger_event; + VLOG(10) << "Event id:" << prev << " is pending"; return; } - VLOG(10) << "Triggered event id:" << id << " name:" << *trigger_event; - cv_.Notify(true); + VLOG(10) << "Cancelled event id:" << id; } std::string EventsWaiter::GetEventName(const 
EventId& id) { diff --git a/paddle/fluid/framework/new_executor/workqueue/events_waiter.h b/paddle/fluid/framework/new_executor/workqueue/events_waiter.h index 5ffed15155d592941c77a846b9df563b81d70c66..9d85f4a27242c9f9c8ed7ffa80879d626527dd35 100644 --- a/paddle/fluid/framework/new_executor/workqueue/events_waiter.h +++ b/paddle/fluid/framework/new_executor/workqueue/events_waiter.h @@ -19,6 +19,7 @@ #include #include #include +#include #include "paddle/fluid/framework/new_executor/workqueue/event_count.h" #include "paddle/fluid/memory/allocation/spin_lock.h" @@ -37,13 +38,12 @@ class EventsWaiter { // Make sure EventsWaiter has a longer lifetime than EventNotifier. class EventNotifier { public: - void NotifyEvent() { waiter_.TriggerEvent(id_); } + ~EventNotifier() { waiter_.UnregisterEvent(id_); } - void UnregisterEvent() { waiter_.UnregisterEvent(id_); } + void NotifyEvent() { waiter_.TriggerEvent(id_); } - EventId GetEventId() { return id_; } + void CancelEvent() { waiter_.CancelEvent(id_); } - // return "Unregistered" if the corresponding event was unregistered. 
std::string GetEventName() { return waiter_.GetEventName(id_); } private: @@ -97,12 +97,16 @@ class EventsWaiter { void TriggerEvent(const EventId& id); + void CancelEvent(const EventId& id); + std::string GetEventName(const EventId& id); std::unordered_map events_; + std::unordered_set deleted_events_; paddle::memory::SpinLock events_lock_; - std::atomic trigger_event_; + std::atomic trigger_event_; std::atomic counter_; + std::atomic eof_; std::atomic waiting_; EventCount cv_; }; diff --git a/paddle/fluid/framework/new_executor/workqueue/workqueue.cc b/paddle/fluid/framework/new_executor/workqueue/workqueue.cc index 596ffb9bfc0c4f624aeaf5874bdf18563d96d14c..881878ebb12a721e7b194036b9d36a89c5404365 100644 --- a/paddle/fluid/framework/new_executor/workqueue/workqueue.cc +++ b/paddle/fluid/framework/new_executor/workqueue/workqueue.cc @@ -31,11 +31,8 @@ class WorkQueueImpl : public WorkQueue { public: explicit WorkQueueImpl(const WorkQueueOptions& options) : WorkQueue(options) { if (options_.track_task && options.events_waiter != nullptr) { + empty_notifier_ = options.events_waiter->RegisterEvent(kQueueEmptyEvent); void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker)); - TaskTracker* tracker = reinterpret_cast(storage); - empty_notifier_ = options.events_waiter->RegisterEvent( - kQueueEmptyEvent, - [tracker]() { return tracker->PendingTaskNum() == 0; }); tracker_ = new (storage) TaskTracker(*empty_notifier_.get()); } if (options_.detached == false && options.events_waiter != nullptr) { @@ -47,9 +44,6 @@ class WorkQueueImpl : public WorkQueue { } virtual ~WorkQueueImpl() { - if (empty_notifier_) { - empty_notifier_->UnregisterEvent(); - } delete queue_; if (tracker_ != nullptr) { tracker_->~TaskTracker(); @@ -57,7 +51,6 @@ class WorkQueueImpl : public WorkQueue { } if (destruct_notifier_) { destruct_notifier_->NotifyEvent(); - destruct_notifier_->UnregisterEvent(); } } @@ -124,14 +117,12 @@ WorkQueueGroupImpl::WorkQueueGroupImpl( const auto& 
options = queues_options_[idx]; if (options.track_task && tracker_ == nullptr && options.events_waiter != nullptr) { + empty_notifier_ = options.events_waiter->RegisterEvent(kQueueEmptyEvent); void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker)); - TaskTracker* tracker = reinterpret_cast(storage); - empty_notifier_ = options.events_waiter->RegisterEvent( - kQueueEmptyEvent, - [tracker]() { return tracker->PendingTaskNum() == 0; }); tracker_ = new (storage) TaskTracker(*empty_notifier_.get()); } - if (options.detached == false && options.events_waiter != nullptr) { + if (options.detached == false && options.events_waiter != nullptr && + !destruct_notifier_) { destruct_notifier_ = options.events_waiter->RegisterEvent(kQueueDestructEvent); } @@ -141,9 +132,6 @@ WorkQueueGroupImpl::WorkQueueGroupImpl( } WorkQueueGroupImpl::~WorkQueueGroupImpl() { - if (empty_notifier_) { - empty_notifier_->UnregisterEvent(); - } for (auto queue : queues_) { queue->~NonblockingThreadPool(); } @@ -154,7 +142,6 @@ WorkQueueGroupImpl::~WorkQueueGroupImpl() { free(queues_storage_); if (destruct_notifier_) { destruct_notifier_->NotifyEvent(); - destruct_notifier_->UnregisterEvent(); } } diff --git a/paddle/fluid/framework/new_executor/workqueue/workqueue_test.cc b/paddle/fluid/framework/new_executor/workqueue/workqueue_test.cc index 97f0282a15837e74e874202cd1891ff62de8d951..d8e09fb6baefe4cfc0e40cf0a1985f98853b9da5 100644 --- a/paddle/fluid/framework/new_executor/workqueue/workqueue_test.cc +++ b/paddle/fluid/framework/new_executor/workqueue/workqueue_test.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/framework/new_executor/workqueue/workqueue.h" #include +#include #include "glog/logging.h" #include "gtest/gtest.h" #include "paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h" @@ -26,11 +27,12 @@ TEST(WorkQueueUtils, TestEventsWaiter) { EXPECT_EQ(events_waiter.WaitEvent(), "test_register_lt"); EXPECT_EQ(notifier->GetEventName(), "test_register_lt"); 
EXPECT_EQ(events_waiter.WaitEvent(), "test_register_lt"); - notifier->UnregisterEvent(); - EXPECT_EQ(notifier->GetEventName(), "Unregistered"); + notifier.reset(); notifier = events_waiter.RegisterEvent("test_register_et"); notifier->NotifyEvent(); EXPECT_EQ(events_waiter.WaitEvent(), "test_register_et"); + notifier->NotifyEvent(); + notifier->CancelEvent(); } TEST(WorkQueue, TestSingleThreadedWorkQueue) { @@ -106,8 +108,13 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) { EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum); // Cancel work_queue->Cancel(); + // Wait kQueueDestructEvent + std::thread waiter_thread([&events_waiter]() { + EXPECT_EQ(events_waiter.WaitEvent(), + paddle::framework::kQueueDestructEvent); + }); work_queue.reset(); - EXPECT_EQ(events_waiter.WaitEvent(), paddle::framework::kQueueDestructEvent); + waiter_thread.join(); } TEST(WorkQueue, TestWorkQueueGroup) { @@ -154,10 +161,15 @@ TEST(WorkQueue, TestWorkQueueGroup) { // WaitQueueGroupEmpty events_waiter.WaitEvent(); EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum + kLoopNum); + EXPECT_EQ(handle.get(), random_num); // Cancel queue_group->Cancel(); - events_waiter.WaitEvent(); + // Wait kQueueDestructEvent + std::thread waiter_thread([&events_waiter]() { + EXPECT_EQ(events_waiter.WaitEvent(), + paddle::framework::kQueueDestructEvent); + EXPECT_EQ(events_waiter.WaitEvent(), "NoEventNotifier"); + }); queue_group.reset(); - EXPECT_EQ(events_waiter.WaitEvent(), paddle::framework::kQueueDestructEvent); - EXPECT_EQ(handle.get(), random_num); + waiter_thread.join(); } diff --git a/paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h b/paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h index eee64df285dcb0aed23a8d4a4c622639cfe0772a..b6e6ede8c334fa58b6bacec9876a287a5bd0b3e0 100644 --- a/paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h +++ b/paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h @@ -81,7 +81,13 @@ class TaskTracker { ~TaskTracker() 
= default; - void AddCounter() { num_tasks_.fetch_add(1, std::memory_order_relaxed); } + void AddCounter() { + if (0 == num_tasks_.fetch_add(1, std::memory_order_relaxed)) { + if (notifier_ != nullptr) { + notifier_->CancelEvent(); + } + } + } void SubCounter() { if (1 == num_tasks_.fetch_sub(1, std::memory_order_relaxed)) { diff --git a/paddle/fluid/operators/assign_pos_op.cc b/paddle/fluid/operators/assign_pos_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..69c0283c7bffb33cd49c1d0374f647828364dc67 --- /dev/null +++ b/paddle/fluid/operators/assign_pos_op.cc @@ -0,0 +1,80 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/assign_pos_op.h" + +namespace paddle { +namespace operators { + +class AssignPosOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("cum_count"), "Input", "cum_count", + "AssignPos"); + OP_INOUT_CHECK(ctx->HasInput("eff_num_len"), "Input", "eff_num_len", + "AssignPos"); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "AssignPos"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "AssignPos"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto cum_count_dtype = + OperatorWithKernel::IndicateVarDataType(ctx, "cum_count"); + auto X_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + + PADDLE_ENFORCE_EQ(cum_count_dtype, X_dtype, + platform::errors::InvalidArgument( + "The dtype of the cum_count and X should be same")); + PADDLE_ENFORCE_EQ(cum_count_dtype, framework::proto::VarType::INT64, + platform::errors::InvalidArgument( + "The dtype of the cum_count_dtype, eff_num_len and " + "X should be same as int64")); + return framework::OpKernelType(cum_count_dtype, ctx.device_context()); + } +}; + +class AssignPosOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "numbers to scatter."); + AddInput("cum_count", "The cumulative sum count of numbers."); + AddInput("eff_num_len", + "The effective numbers of numbers should be scattered."); + AddOutput("Out", "Assemble numbers in the order of counters."); + + AddComment(R"DOC( +assign_pos_op Operator. + +Assign pos decides which tokens should be fetched belong to +specially counter orderingly. 
+ +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_WITHOUT_GRADIENT(assign_pos, ops::AssignPosOp, + ops::AssignPosOpMaker); + +REGISTER_OP_CPU_KERNEL(assign_pos, ops::AssignPosOpCPUKernel, + ops::AssignPosOpCPUKernel); diff --git a/paddle/fluid/operators/assign_pos_op.cu b/paddle/fluid/operators/assign_pos_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..5fa159b94f9834e43db1cb0a419eefd2f60181b0 --- /dev/null +++ b/paddle/fluid/operators/assign_pos_op.cu @@ -0,0 +1,94 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/assign_pos_op.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/fluid/platform/float16.h" + +DECLARE_bool(avoid_op_randomness); + +namespace paddle { +namespace operators { + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaxinumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaxinumNumBlocks); +} + +template +__global__ void AssignPos(T* cum_count, const T* numbers, T* out, + int64_t limit) { + CUDA_KERNEL_LOOP(i, limit) { + int number_idx = numbers[i]; + if (number_idx > -1) { + int p = platform::CudaAtomicAdd(cum_count + number_idx, -1); + out[p - 1] = i; + } + } +} + +template +class AssignPosCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + // assign pos decides which tokens should be fetched belong to specially + // counter orderingly. 
+ auto cum_count = context.Input( + "cum_count"); // (counter number) int32 | int64 + auto numbers = + context.Input("X"); // (batch_size * seq_len, topk) int32 + auto eff_num_len = + context.Input("eff_num_len"); // (sum(cum_count)) + auto out = context.Output("Out"); // (cum_count) value ranges + // from 0 to batch_size * + // seq_len * topk + auto place = context.GetPlace(); + auto numel = numbers->numel(); + T* cum_data = const_cast(cum_count->data()); + auto cum_size = cum_count->numel(); + + framework::Tensor cpu_eff_num_len; + int64_t cpu_eff_num_len_data = 0; + if (platform::is_cpu_place(eff_num_len->place())) { + cpu_eff_num_len_data = eff_num_len->data()[0]; + } else { + framework::TensorCopySync(*eff_num_len, platform::CPUPlace(), + &cpu_eff_num_len); + cpu_eff_num_len_data = cpu_eff_num_len.data()[0]; + } + const auto& dev_ctx = + context.template device_context(); + framework::DDim out_dims = phi::make_ddim({cpu_eff_num_len_data}); + auto out_data = out->mutable_data(out_dims, place); + + const T* num_data = numbers->data(); + + int blocks = NumBlocks(numel); + int threads = kNumCUDAThreads; + + AssignPos<<>>(cum_data, num_data, + out_data, numel); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL(assign_pos, ops::AssignPosCUDAKernel); diff --git a/paddle/fluid/operators/assign_pos_op.h b/paddle/fluid/operators/assign_pos_op.h new file mode 100644 index 0000000000000000000000000000000000000000..1a017415778dd058378536284d1a264944c60927 --- /dev/null +++ b/paddle/fluid/operators/assign_pos_op.h @@ -0,0 +1,35 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; + +template +class AssignPosOpCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_THROW(platform::errors::Unavailable( + "Do not support assign pos op for cpu kernel now.")); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/number_count_op.cc b/paddle/fluid/operators/number_count_op.cc index 8f7a3b82acf19fa79cbf5c632977e6ae533ae12b..3b7406c997aba2885564f82cae4a21fcc59dcbdc 100644 --- a/paddle/fluid/operators/number_count_op.cc +++ b/paddle/fluid/operators/number_count_op.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -22,8 +22,7 @@ class NumberCountOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("gate_idx"), "Input", "gate_idx", - "NumberCount"); + OP_INOUT_CHECK(ctx->HasInput("numbers"), "Input", "numbers", "NumberCount"); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "number_count", "NumberCount"); } @@ -31,25 +30,24 @@ class NumberCountOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - // the dtype of the gate_idx should be same as int64 - auto gate_idx_dtype = - OperatorWithKernel::IndicateVarDataType(ctx, "gate_idx"); + // the dtype of the numbers should be same as int64 + auto number_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "numbers"); - PADDLE_ENFORCE_EQ(gate_idx_dtype, framework::proto::VarType::INT64, + PADDLE_ENFORCE_EQ(number_dtype, framework::proto::VarType::INT64, platform::errors::InvalidArgument( - "The dtype of the gate_idx_dtype should be int64")); - return framework::OpKernelType(gate_idx_dtype, ctx.GetPlace()); + "The dtype of the number_dtype should be int64")); + return framework::OpKernelType(number_dtype, ctx.GetPlace()); } }; class NumberCountOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("gate_idx", "(Tensor) The input gate index tensor."); - AddOutput("Out", "(Tensor) The output expert count tensor."); - AddAttr("upper_range", "(int), The number of experts."); + AddInput("numbers", "(Tensor) The input gate index tensor."); + AddOutput("Out", "(Tensor) The output number count tensor."); + AddAttr("upper_range", "(int), The number of different numbers."); - AddComment(R"DOC(number_count Operator.count gate indices.)DOC"); + AddComment(R"DOC(number_count Operator.count numbers.)DOC"); } }; diff --git 
a/paddle/fluid/operators/number_count_op.cu b/paddle/fluid/operators/number_count_op.cu index 97e4b4f2845ae132c28d3bb71dcc8e73f02e193a..0106c70d8eb53888801e942fc6c7c9ca57644062 100644 --- a/paddle/fluid/operators/number_count_op.cu +++ b/paddle/fluid/operators/number_count_op.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ __global__ void initialize_zero_kernel(T* data, const int length) { } template -__global__ void NumberCount(const T* gate_idx, T* number_count, +__global__ void NumberCount(const T* numbers, T* number_count, int64_t batch_size, int upper_range) { int res_tmp[PERTHREAD_EXPERTS] = {0}; int expert_min = blockIdx.x * PERTHREAD_EXPERTS; @@ -47,7 +47,7 @@ __global__ void NumberCount(const T* gate_idx, T* number_count, expert_max = upper_range; } for (int i = threadIdx.x; i < batch_size; i += blockDim.x) { - T idx = gate_idx[i]; + T idx = numbers[i]; if (idx == -1) { continue; } @@ -76,18 +76,18 @@ template class NumberCountOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto gate_idx = context.Input("gate_idx"); + auto numbers = context.Input("numbers"); auto upper_range = context.Attr("upper_range"); auto number_count = context.Output("Out"); - int64_t batch_size = gate_idx->numel(); + int64_t batch_size = numbers->numel(); auto place = context.GetPlace(); const auto& dev_ctx = context.template device_context(); framework::DDim out_dims = phi::make_ddim({upper_range}); auto out_data = number_count->mutable_data(out_dims, place); - const T* gate_data = gate_idx->data(); + const T* gate_data = numbers->data(); initialize_zero_kernel< T><<>>( diff --git a/paddle/fluid/operators/number_count_op.h 
b/paddle/fluid/operators/number_count_op.h index 95e64946fb8a2156fdb4cbae880ccf2c143447ed..ded7ea6eec54f7ce08ae610274febdbb4f82d292 100644 --- a/paddle/fluid/operators/number_count_op.h +++ b/paddle/fluid/operators/number_count_op.h @@ -1,4 +1,4 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/paddle/scripts/paddle_build.bat b/paddle/scripts/paddle_build.bat index b3500c5724bd2aefc353ba01d92e5deb0af14010..06d8153019f42551a33f5ecc3bfda30c04a07078 100644 --- a/paddle/scripts/paddle_build.bat +++ b/paddle/scripts/paddle_build.bat @@ -251,6 +251,7 @@ set MSVC_STATIC_CRT=ON set ON_INFER=ON set WITH_TENSORRT=ON set WITH_INFERENCE_API_TEST=ON +set vcvars64_dir="D:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Auxiliary\Build\vcvars64.bat" call :cmake || goto cmake_error call :build || goto build_error @@ -314,11 +315,9 @@ echo ======================================== rem set vs language to english to block showIncludes, this need vs has installed English language package. set VSLANG=1033 rem Configure the environment for 64-bit builds. 'DISTUTILS_USE_SDK' indicates that the user has selected the compiler. 
-echo %task_name%|findstr wincheck_inference >nul && ( - call "D:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Auxiliary\Build\vcvars64.bat" -) || ( - call "C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\VC\Auxiliary\Build\vcvars64.bat" -) +if not defined vcvars64_dir set vcvars64_dir="C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\VC\Auxiliary\Build\vcvars64.bat" +call %vcvars64_dir% + set DISTUTILS_USE_SDK=1 rem Windows 10 Kit bin dir set PATH=C:\Program Files (x86)\Windows Kits\10\bin\10.0.17763.0\x64;%PATH% @@ -381,18 +380,8 @@ echo echo ${md5_content}^>md5.txt >> cache.sh %cache_dir%\tools\busybox64.exe bash cache.sh set /p md5=< md5.txt -echo %task_name%|findstr build >nul && ( - set THIRD_PARTY_HOME=%cache_dir:\=/%/third_party - set THIRD_PARTY_PATH=!THIRD_PARTY_HOME!/%md5% - echo %task_name% is a whl-build task, will only reuse local third_party cache. - goto :cmake_impl -) || ( - echo %task_name% is a PR-CI-Windows task, will try to reuse bos and local third_party cache both. -) - if "%WITH_GPU%"=="ON" ( - for /F %%# in ('dir /b /d "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\"') do set cuda_version=%%# - set cuda_version=!cuda_version:~-4! + set cuda_version=%CUDA_TOOLKIT_ROOT_DIR:~-4% set sub_dir=cuda!cuda_version:.=! ) else ( set sub_dir=cpu @@ -400,6 +389,13 @@ if "%WITH_GPU%"=="ON" ( set THIRD_PARTY_HOME=%cache_dir:\=/%/third_party/%sub_dir% set THIRD_PARTY_PATH=%THIRD_PARTY_HOME%/%md5% +echo %task_name%|findstr build >nul && ( + echo %task_name% is a whl-build task, will only reuse local third_party cache. + goto :cmake_impl +) || ( + echo %task_name% is a PR-CI-Windows task, will try to reuse bos and local third_party cache both. +) + if not exist %THIRD_PARTY_PATH% ( echo There is no usable third_party cache in %THIRD_PARTY_PATH%, will download from bos. 
pip install wget diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index f4165d97685f1e6966a3cfd20162155c5399392f..2f437a7d6633b7e552abe32c7b48a33e6c1e3698 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -1024,6 +1024,9 @@ function generate_api_spec() { ${PADDLE_ROOT}/paddle/fluid/op_use_default_grad_maker_${spec_kind}.spec deactivate + + cd ${PADDLE_ROOT}/build + rm -rf ${PADDLE_ROOT}/build/.check_api_workspace } function check_approvals_of_unittest() { diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index f163da4fb999b3b6708ddd846e7e19c2e0c291d1..83a0ff8099cba31c76167d9fe33c28a9ef5ca9f8 100755 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -37,9 +37,45 @@ from ..meta_optimizers import HybridParallelOptimizer, HeterParallelOptimizer from paddle import _C_ops from paddle.fluid import core from paddle.fluid.dygraph import to_variable +from paddle.distributed.fleet.utils.recompute import RecomputeFunction +from paddle.fluid.dygraph.varbase_patch_methods import _grad_scalar __all__ = [] +_grad_scalar = None + + +class _RecomputeModelWrapper(paddle.nn.Layer): + def __init__(self, model, segments=2, preserve_rng_state=True): + super(_RecomputeModelWrapper, self).__init__() + assert isinstance(model, paddle.nn.Sequential), ( + "The model passed to RecomputeModelWrapper must be of type " + "paddle.nn.Sequential.") + self._model = model + self._segments = segments + self._preserve_rng_state = preserve_rng_state + self._layers = list(model.children()) + self._segment_size = len(self._layers) // segments + + def _run_func(self, begin, end): + def do_run(input): + for i in range(begin, end): + input = self._layers[i](input) + return input + + return do_run + + def _checkpoint(self, func, *args, **kwargs): + return RecomputeFunction.apply(func, self._preserve_rng_state, *args) + 
+ def forward(self, input): + end = 0 + for begin in range(0, self._segment_size * (self._segments - 1), + self._segment_size): + end = begin + self._segment_size + input = self._checkpoint(self._run_func(begin, end), input) + return self._run_func(end, len(self._layers))(input) + def apply_ir_passes(main_program, startup_program, config): build_strategy = config._user_defined_strategy.build_strategy._copy() @@ -952,6 +988,41 @@ class Fleet(object): if self.worker_num() <= 1: return model + amp_enable = False + recompute_enable = False + strategy = self._user_defined_strategy + if strategy.amp == True: + amp_enable = True + amp_level = "O2" if strategy.amp_configs['use_pure_fp16'] else "O1" + if amp_level.upper() == "O2": + model = paddle.amp.decorate( + models=model, + optimizers=None, + level="O2", + master_weight=None, + save_dtype=None) + init_loss_scaling = strategy.amp_configs['init_loss_scaling'] + incr_ratio = strategy.amp_configs['incr_ratio'] + decr_ratio = strategy.amp_configs['decr_ratio'] + incr_every_n_steps = strategy.amp_configs['incr_every_n_steps'] + decr_every_n_nan_or_inf = strategy.amp_configs[ + 'decr_every_n_nan_or_inf'] + use_dynamic_loss_scaling = strategy.amp_configs[ + 'use_dynamic_loss_scaling'] + + global _grad_scalar + _grad_scalar = paddle.amp.GradScaler( + init_loss_scaling=init_loss_scaling, + incr_ratio=incr_ratio, + decr_ratio=decr_ratio, + incr_every_n_steps=incr_every_n_steps, + decr_every_n_nan_or_inf=decr_every_n_nan_or_inf, + use_dynamic_loss_scaling=use_dynamic_loss_scaling) + + if strategy.recompute == True: + recompute_enable = True + model = _RecomputeModelWrapper(model) + if self._user_defined_strategy.heter_ccl_mode == True: distributed_model = paddle.DataParallel( model, @@ -964,7 +1035,7 @@ class Fleet(object): return distributed_model if self._hcg.get_parallel_mode() == ParallelMode.SHARDING_PARALLEL: - distributed_model = ShardingParallel( + model = ShardingParallel( model, self._hcg, 
strategy=self._user_defined_strategy) elif self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL: @@ -975,22 +1046,23 @@ class Fleet(object): assert self.sharding_degree == self._hcg.get_sharding_parallel_world_size( ) broadcast_sharding_parameters(model, self._hcg) - distributed_model = paddle.DataParallel( + model = paddle.DataParallel( model, comm_buffer_size=self._user_defined_strategy. fuse_grad_size_in_MB, last_comm_buffer_size=self._user_defined_strategy. last_comm_group_size_MB, find_unused_parameters=self._user_defined_strategy. - find_unused_parameters) + find_unused_parameters, + static_graph=True if recompute_enable else False) elif self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL: - distributed_model = TensorParallel( + model = TensorParallel( model, self._hcg, strategy=self._user_defined_strategy) elif self._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL: - distributed_model = PipelineParallel( + model = PipelineParallel( model, self._hcg, strategy=self._user_defined_strategy) - return distributed_model + return model @dygraph_only def state_dict(self): diff --git a/python/paddle/distributed/launch/main.py b/python/paddle/distributed/launch/main.py index e6febff505e5248e6fc908c62293db2461b3eb32..83a5e18714dff8c75768e4bd4f46d898983b70f0 100644 --- a/python/paddle/distributed/launch/main.py +++ b/python/paddle/distributed/launch/main.py @@ -40,9 +40,9 @@ def launch(): - ``--rank``: The rank of the node, can be auto assigned by master. Default ``--rank=-1``. - - ``--log_level``: The log levl to set for logging.setLevel. Default ``--log_level=INFO``. + - ``--log_level``: The log level to set for logging.setLevel which can be CRITICAL/ERROR/WARNING/INFO/DEBUG/NOTSET, case insensitive. The rank 0 log will not print in the terminal by default, while you can enable it by adding --log_level=debug. Default ``--log_level=INFO``. 
- - ``--nnodes``: The number of nodes for a distributed job, it can be a range in elastic mode, e.g., ``--nnnodes=2:3``. Default ``--nnodes=1``. + - ``--nnodes``: The number of nodes for a distributed job, it can be a range in elastic mode, e.g., ``--nnodes=2:3``. Default ``--nnodes=1``. - ``--nproc_per_node``: The number of processes to launch on a node. In gpu training, it should be less or equal to the gpus number of you system. e.g., ``--nproc_per_node=8`` @@ -93,9 +93,11 @@ def launch(): Returns: - ``None`` + - ``None`` Examples 0 (master, ip/port auto detection): + .. code-block:: bash + :name: code-block-example-bash0 # For training on multi node, run the following command in one of the nodes @@ -171,7 +173,7 @@ def launch(): .. code-block:: bash :name: code-block-example-bash5 - # To simulate distributed environment using single node, e.g., 2 servers and 4 workers, each worker use single gpu. + # To simulate distributed environment using single node, e.g., 2 servers and 4 workers, each worker use single gpu. export CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch --server_num=2 --worker_num=4 train.py --lr=0.01 @@ -226,7 +228,7 @@ def launch(): python -m paddle.distributed.launch --master etcd://10.0.0.1:2379 --nnodes 2:4 train.py # once the number of nodes changes between 2:4 during training, the strategy holds - + """ # initialize the context to run diff --git a/python/paddle/distributed/models/moe/utils.py b/python/paddle/distributed/models/moe/utils.py index fd98c64318c60e2e67af320c51b24e39a3132c43..28cbfb4f4c74a2080fd2700533cf26a988d3fda7 100644 --- a/python/paddle/distributed/models/moe/utils.py +++ b/python/paddle/distributed/models/moe/utils.py @@ -17,11 +17,11 @@ from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.framework import in_dygraph_mode -def _number_count(gate_idx, upper_range): +def _number_count(numbers, upper_range): """ calculate the expert count according to the gate index. 
Args: - gate_idx (Tensor): Tensor. The input gate index whose data type should be int32 or int64. + numbers (Tensor): Tensor. The input gate index whose data type should be int32 or int64. upper_range (int): The number of the experts. Returns: out (Tensor): The output expert count. @@ -30,26 +30,75 @@ def _number_count(gate_idx, upper_range): # required: distributed import paddle - gate_idx = [ + numbers = [ [0, 2], [0, 2] ] upper_range = 6 - gate_idx = paddle.to_tensor(gate_idx, dtype="int32") - number_count = paddle.distributed.utils.number_count(gate_idx, upper_range) + numbers = paddle.to_tensor(numbers, dtype="int32") + number_count = paddle.distributed.utils.number_count(numbers, upper_range) print(number_count) # the result: [2, 0, 2, 0, 0, 0] """ if in_dygraph_mode(): - return core.ops.number_count(gate_idx, 'upper_range', upper_range) + return core.ops.number_count(numbers, 'upper_range', upper_range) else: op_type = 'number_count' helper = LayerHelper(op_type, **locals()) - out = helper.create_variable_for_type_inference(dtype=gate_idx.dtype) + out = helper.create_variable_for_type_inference(dtype=numbers.dtype) helper.append_op( type=op_type, - inputs={'gate_idx': gate_idx}, + inputs={'numbers': numbers}, outputs={'Out': out}, attrs={'upper_range': upper_range}) return out + + +def _assign_pos(x, cum_count): + """ + Assign pos decides which tokens should be fetched belong to + specially expert orderingly. + + Args: + x (Tensor): Tensor. Every element in the list must be a Tensor whose data type + should be float16, float32, float64, int32 or int64. + cum_count (Tensor): The cumulative sum tokens of counters. Every element in the list must be a Tensor whose + data type should be int64. + + Returns: + out (Tensor): Assemble numbers in the order of counters. + + Examples: + .. 
code-block:: python + + # required: distributed + import paddle + number_count = [2, 0, 2, 0] + numbers = [ + [0, 2], + [0, 2] + ] + number_count = paddle.to_tensor(number_count) + numbers = paddle.to_tensor(numbers, dtype="int32") + num_cum = paddle.cumsum(number_count) + pos = paddle.distributed.utils.assign_pos(x=numbers, cum_count=num_cum) + print(pos) # the result: (2, 0, 3, 1) + """ + if in_dygraph_mode(): + return core.ops.assign_pos(x, cum_count, cum_count[-1]) + else: + op_type = 'assign_pos' + + helper = LayerHelper(op_type, **locals()) + out = helper.create_variable_for_type_inference(dtype=cum_count.dtype) + + helper.append_op( + type=op_type, + inputs={ + 'X': [x], + 'cum_count': [cum_count], + "eff_num_len": [cum_count[-1]] + }, + outputs={'Out': [out]}) + return out diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index 878fc1c68e4c193e7056a65fc2c45ac121474125..b8a2d958a7311ea8b81a05727838f9aa2d59e6f9 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -31,6 +31,8 @@ import paddle.utils.deprecated as deprecated import paddle.profiler as profiler from paddle import _C_ops +_grad_scalar = None + class TensorHookRemoveHelper(object): """ @@ -265,6 +267,9 @@ def monkey_patch_varbase(): grad_tensor = [] else: grad_tensor = [grad_tensor] + if _grad_scalar: + # When using amp with Fleet DistributedStrategy, we do loss scaling implicitly. + self = _grad_scalar.scale(self) if paddle.is_compiled_with_xpu() or paddle.is_compiled_with_npu(): # TODO(liuyuhui): Currently only for xpu. Will be removed in the future. 
scaled_loss = scale_loss(self) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 2acf530eea3fbd2d7fbc9cf04d2c792b7175035c..f1a90553283c3e85f29d9842bae6951b02f576f4 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -939,6 +939,7 @@ if (WITH_DISTRIBUTE) set_tests_properties(test_dist_fleet_infer PROPERTIES TIMEOUT 200) set_tests_properties(test_dist_fleet_raw_program_optimizer PROPERTIES TIMEOUT 120) set_tests_properties(test_dist_fleet_raw_program_optimizer_fuse_allreduce PROPERTIES TIMEOUT 60) + set_tests_properties(test_dist_dygraph_apis PROPERTIES TIMEOUT 120) endif() if (WITH_DISTRIBUTE AND NOT APPLE) diff --git a/python/paddle/fluid/tests/unittests/dygraph_fleet_api.py b/python/paddle/fluid/tests/unittests/dygraph_fleet_api.py new file mode 100644 index 0000000000000000000000000000000000000000..2a9d74e4afd4b7c62f57b5fd39856a18fe799619 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_fleet_api.py @@ -0,0 +1,60 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +import random +import numpy as np +import os +import shutil + +import paddle +import paddle.nn as nn +from paddle.fluid import core +import datetime +from datetime import timedelta +import paddle.fluid.core as core +from paddle.fluid.framework import _test_eager_guard +from paddle.fluid.dygraph.parallel import ParallelEnv + + +class TestDygraphFleetAPI(unittest.TestCase): + def setUp(self): + paddle.seed(2022) + random.seed(2022) + np.random.seed(2022) + self.config() + + def config(self): + self.dtype = "float32" + self.shape = (2, 10, 5) + + def test_dygraph_fleet_api(self): + import paddle.distributed.fleet as fleet + import paddle.distributed as dist + strategy = fleet.DistributedStrategy() + strategy.amp = True + strategy.recompute = True + fleet.init(is_collective=True, strategy=strategy) + net = paddle.nn.Sequential( + paddle.nn.Linear(10, 1), paddle.nn.Linear(1, 2)) + net = dist.fleet.distributed_model(net) + data = np.random.uniform(-1, 1, [30, 10]).astype('float32') + data = paddle.to_tensor(data) + net(data) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_assign_pos_op.py b/python/paddle/fluid/tests/unittests/test_assign_pos_op.py new file mode 100644 index 0000000000000000000000000000000000000000..72924f242d211d063b1d547050de79f87f2d8dac --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_assign_pos_op.py @@ -0,0 +1,131 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import op_test +import numpy as np +import unittest +import paddle +import paddle.fluid.core as core +from paddle.fluid.op import Operator +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard +from paddle.fluid.backward import append_backward +from paddle.distributed.models.moe import utils + + +def assign_pos(x, _cum_count): + cum_count = np.copy(_cum_count) + x = x.reshape(-1) + res = np.zeros((cum_count[-1], ), dtype=np.int64) + for i, idx in enumerate(x): + p = cum_count[idx] + cum_count[idx] -= 1 + if p >= 1: + res[p - 1] = i + return res + + +def count(x, upper_num): + res = np.zeros((upper_num, )).astype(int) + for i in x.reshape(-1): + if i >= 0 and i < len(res): + res[i] += 1 + return res + + +# why defining the assert function specially? +# Becasue assign_pos_op is multithread-op, which can make the order of numbers +# in each counter(bin) is random. But the numbers set is certain in each counter(bin). 
+np_allclose = np.allclose + + +def assert_allclose(res, out, cum_count): + c0 = 0 + for c in cum_count: + if c == c0: + continue + data1 = np.copy(res[c0:c]) + data2 = np.copy(out[c0:c]) + data1.sort() + data2.sort() + assert np_allclose(data2, data1) + c0 = c + return True + + +def get_redefined_allclose(cum_count): + def redefined_allclose(x, y, *args, **kwargs): + return assert_allclose(x, y, cum_count) + + return redefined_allclose + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestAssignPosOpInt64(op_test.OpTest): + def setUp(self): + x = np.random.randint(0, 16, size=(100, 2)).astype("int64") + y = count(x, 16) + cum_count = np.cumsum(y).astype(x.dtype) + self.op_type = "assign_pos" + self.inputs = { + 'X': x, + "cum_count": cum_count, + "eff_num_len": np.array([cum_count[-1]]) + } + self.outputs = {'Out': assign_pos(x, cum_count)} + self.cum_count = cum_count + + def test_forward(self): + np.allclose = get_redefined_allclose(self.cum_count) + self.check_output_with_place(paddle.CUDAPlace(0)) + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestAssignPosAPI(unittest.TestCase): + def setUp(self): + self.x = np.random.randint(0, 16, size=(100, 2)).astype("int64") + y = count(self.x, 16) + self.cum_count = np.cumsum(y).astype(self.x.dtype) + self.out = assign_pos(self.x, self.cum_count) + self.place = paddle.CUDAPlace(0) + + def test_api_static(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data('x', self.x.shape, dtype="int64") + cum_count = paddle.fluid.data( + 'cum_count', self.cum_count.shape, dtype="int64") + out = utils._assign_pos(x, cum_count) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'x': self.x, + "cum_count": self.cum_count}, + fetch_list=[out]) + assert_allclose(res[0], self.out, self.cum_count) + + def test_api_dygraph(self): + paddle.disable_static() + x = 
paddle.to_tensor(self.x) + cum_count = paddle.to_tensor(self.cum_count).astype(x.dtype) + + out = utils._assign_pos(x, cum_count) + assert_allclose(out.numpy(), self.out, self.cum_count) + + +if __name__ == '__main__': + paddle.enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_class_center_sample_op.py b/python/paddle/fluid/tests/unittests/test_class_center_sample_op.py index 29cae0eb001806b8b908b4e97e93335533185ee0..eb7d05df492ec5475846112ef83f9cd4a6347011 100644 --- a/python/paddle/fluid/tests/unittests/test_class_center_sample_op.py +++ b/python/paddle/fluid/tests/unittests/test_class_center_sample_op.py @@ -241,9 +241,21 @@ class TestClassCenterSampleAPIError1(unittest.TestCase): remapped_label, sampled_class_index = paddle.nn.functional.class_center_sample( label, self.num_classes, self.num_samples) - print(remapped_label, sampled_class_index) + + def test_group_value(): + for place in self.places: + with paddle.fluid.dygraph.guard(place): + label_np = np.random.randint( + 0, + self.num_classes, (self.batch_size, ), + dtype=self.dtype) + label = paddle.to_tensor(label_np) + + remapped_label, sampled_class_index = paddle.nn.functional.class_center_sample( + label, self.num_classes, self.num_samples, group=True) self.assertRaises(ValueError, test_empty_label) + self.assertRaises(ValueError, test_group_value) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_dist_dygraph_apis.py b/python/paddle/fluid/tests/unittests/test_dist_dygraph_apis.py new file mode 100644 index 0000000000000000000000000000000000000000..8e6fb99ae9355eefdb6de4f3a1bd0b2712535b83 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_dygraph_apis.py @@ -0,0 +1,27 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +from test_parallel_dygraph_dataparallel import TestMultipleGpus + + +class TestDygraphFleetApi(TestMultipleGpus): + def test_dygraph_fleet_api(self): + self.run_mnist_2gpu('dygraph_fleet_api.py') + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_margin_cross_entropy_op.py b/python/paddle/fluid/tests/unittests/test_margin_cross_entropy_op.py index 15730710adf4cf50439c1cdc466eb1ff85ef8a7d..2b511b9eb442ba7c037161a54e8036a6a9cbb81b 100644 --- a/python/paddle/fluid/tests/unittests/test_margin_cross_entropy_op.py +++ b/python/paddle/fluid/tests/unittests/test_margin_cross_entropy_op.py @@ -400,8 +400,31 @@ class TestMarginCrossEntropyOpAPIError(unittest.TestCase): return_softmax=True, reduction=None) + def test_group_value(): + for place in self.places: + with paddle.fluid.dygraph.guard(place): + labels_np = np.random.randint( + 0, self.num_class, (self.batch_dim, ), dtype="int64") + logits_np = np.random.uniform( + -0.99, 0.99, + [self.batch_dim, self.num_class]).astype(self.dtype) + labels = paddle.to_tensor(labels_np) + logits = paddle.to_tensor(logits_np) + + loss, softmax = paddle.nn.functional.margin_cross_entropy( + logits, + labels, + margin1=self.margin1, + margin2=self.margin2, + margin3=self.margin3, + scale=self.scale, + return_softmax=True, + reduction=None, + group=True) + self.assertRaises(ValueError, test_dim) self.assertRaises(NotImplementedError, test_label_type) + self.assertRaises(ValueError, test_group_value) if __name__ == 
'__main__': diff --git a/python/paddle/fluid/tests/unittests/test_number_count_op.py b/python/paddle/fluid/tests/unittests/test_number_count_op.py index 0df9d2a3a41b44c18b7e008a271c10544ec4dfa0..9eb89dfeb0e8d9e4538f3a7004da777eafbb2f34 100644 --- a/python/paddle/fluid/tests/unittests/test_number_count_op.py +++ b/python/paddle/fluid/tests/unittests/test_number_count_op.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,8 +26,8 @@ from paddle.fluid.backward import append_backward from paddle.distributed.models.moe import utils -def count(x, upper_range): - res = np.zeros((upper_range, )).astype(int) +def count(x, upper_num): + res = np.zeros((upper_num, )).astype(int) for i in x.reshape(-1): if i >= 0 and i < len(res): res[i] += 1 @@ -36,14 +36,14 @@ def count(x, upper_range): @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") -class TestExpertCountOpInt64(op_test.OpTest): +class TestNumberCountOpInt64(op_test.OpTest): def setUp(self): - expert_num = 16 + upper_num = 16 self.op_type = "number_count" - x = np.random.randint(-1, expert_num, size=(1000, 2)).astype('int64') - self.inputs = {'gate_idx': x} - self.outputs = {'Out': count(x, expert_num)} - self.attrs = {"upper_range": expert_num} + x = np.random.randint(-1, upper_num, size=(1000, 2)).astype('int64') + self.inputs = {'numbers': x} + self.outputs = {'Out': count(x, upper_num)} + self.attrs = {"upper_range": upper_num} def test_forward(self): self.check_output_with_place(paddle.CUDAPlace(0)) @@ -51,19 +51,19 @@ class TestExpertCountOpInt64(op_test.OpTest): @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") -class TestExpertCountAPI(unittest.TestCase): +class TestNumberCountAPI(unittest.TestCase): def 
setUp(self): - self.upper_range = 320 + self.upper_num = 320 self.x = np.random.randint( - -1, self.upper_range, size=(6000, 200)).astype('int64') - self.out = count(self.x, self.upper_range) + -1, self.upper_num, size=(6000, 200)).astype('int64') + self.out = count(self.x, self.upper_num) self.place = paddle.CUDAPlace(0) def test_api_static(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): x = paddle.fluid.data('x', self.x.shape, dtype="int64") - out = utils._number_count(x, self.upper_range) + out = utils._number_count(x, self.upper_num) exe = paddle.static.Executor(self.place) res = exe.run(feed={'x': self.x}, fetch_list=[out]) assert np.allclose(res, self.out) @@ -71,7 +71,7 @@ class TestExpertCountAPI(unittest.TestCase): def test_api_dygraph(self): paddle.disable_static() x = paddle.to_tensor(self.x) - out = utils._number_count(x, self.upper_range) + out = utils._number_count(x, self.upper_num) assert np.allclose(out.numpy(), self.out) diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 9e78ca6be3f2749e43963f63cdb8b6983f651697..e757fbf53487e0d7cb56e5683a5c1cf0aeb04a52 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -1651,16 +1651,21 @@ def class_center_sample(label, num_classes, num_samples, group=None): .. hint:: If the number of the positive class centers is greater than the input num_samples, it keeps all the positive class centers and the shape of sampled_class_center will be [num_positive_class_centers]. - + The API supports CPU, single GPU and multi GPU. + For data parallel mode, set ``group=False``. + + For model parallel mode, set ``group=None`` or the group instance return by paddle.distributed.new_group. + Args: label (Tensor): 1-D tensor with shape [N], each label in [0, num_classes) num_classes (int): A positive integer to specify the number of classes at local rank. 
Note that num_classes of each GPU can be different. num_samples (int): A positive integer to specify the number of class center to sample. - group (Group, optional): The abstract representation of group. - See paddle.distributed.collective.Group. Default is ``None``. + group (Group, optional): The group instance return by paddle.distributed.new_group + or ``None`` for global default group or ``False`` for data parallel (do not communication cross ranks). + Default is ``None``. Returns: Tuple of two ``Tensor`` : (remapped_label, sampled_class_center), remapped label using sampled class center, @@ -1733,18 +1738,25 @@ def class_center_sample(label, num_classes, num_samples, group=None): #Tensor(shape=[7], dtype=int64, place=CUDAPlace(1), stop_gradient=True, # [0, 1, 2, 3, 5, 7, 8]) """ - if group is not None and not group.is_member(): + if not (group == False or group is None or hasattr(group, 'is_member')): + raise ValueError( + 'Expected group is False, None or instance of paddle.distributed.collective.Group \ + (got group: {})'.format(group)) + return + + if hasattr(group, 'is_member') and not group.is_member(): return - ring_id = 0 if group is None else group.id + ring_id = 0 rank = 0 nranks = 1 - if core.is_compiled_with_dist(): - parallel_env = paddle.distributed.ParallelEnv() - global_rank = parallel_env.rank - rank = global_rank if group is None else group.get_group_rank( - global_rank) - nranks = parallel_env.world_size if group is None else group.nranks + if group != False: + if core.is_compiled_with_dist(): + parallel_env = paddle.distributed.ParallelEnv() + global_rank = parallel_env.rank + rank = global_rank if group is None else group.get_group_rank( + global_rank) + nranks = parallel_env.world_size if group is None else group.nranks if num_samples > num_classes: raise ValueError( diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 10d4073b80c5998df7931fc8addc2507cb606ef2..b4594986f41a721123926376716490d438c51fbd 
100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1119,14 +1119,19 @@ def margin_cross_entropy(logits, r""" .. math:: - L=-\\frac{1}{N}\sum^N_{i=1}\log\\frac{e^{s(cos(m_{1}\\theta_{y_i}+m_{2})-m_{3})}}{e^{s(cos(m_{1}\\theta_{y_i}+m_{2})-m_{3})}+\sum^n_{j=1,j\\neq y_i} e^{scos\\theta_{y_i}}} + L=-\frac{1}{N}\sum^N_{i=1}\log\frac{e^{s(cos(m_{1}\theta_{y_i}+m_{2})-m_{3})}}{e^{s(cos(m_{1}\theta_{y_i}+m_{2})-m_{3})}+\sum^n_{j=1,j\neq y_i} e^{scos\theta_{y_i}}} - where the :math:`\\theta_{y_i}` is the angle between the feature :math:`x` and + where the :math:`\theta_{y_i}` is the angle between the feature :math:`x` and the representation of class :math:`i`. The details of ArcFace loss could be referred to https://arxiv.org/abs/1801.07698. .. hint:: - The API supports model parallel and single GPU. And logits.shape[-1] can be different at each rank. + The API supports single GPU and multi GPU, and don't supports CPU. + + For data parallel mode, set ``group=False``. + + For model parallel mode, set ``group=None`` or the group instance return by paddle.distributed.new_group. + And logits.shape[-1] can be different at each rank. Args: logits (Tensor): shape[N, local_num_classes], the output of the normalized X multiply the normalized W. @@ -1136,8 +1141,9 @@ def margin_cross_entropy(logits, margin2 (float, optional): m2 of margin loss, default value is `0.5`. margin3 (float, optional): m3 of margin loss, default value is `0.0`. scale (float, optional): s of margin loss, default value is `64.0`. - group (Group, optional): The abstract representation of group, see paddle.distributed.collective.Group. - Default `None`. + group (Group, optional): The group instance return by paddle.distributed.new_group + or ``None`` for global default group or ``False`` for data parallel (do not communication cross ranks). + Default is ``None``. return_softmax (bool, optional): Whether return softmax probability. Default value is `False`. 
reduction (str, optional): The candicates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'mean'``, return the average of loss; @@ -1296,24 +1302,32 @@ def margin_cross_entropy(logits, """ assert reduction in ['mean', 'sum', 'none', None] - if group is not None and not group.is_member(): + if not (group == False or group is None or hasattr(group, 'is_member')): + raise ValueError( + 'Expected group is False, None or instance of paddle.distributed.collective.Group \ + (got group: {})'.format(group)) return - ring_id = 0 if group is None else group.id + if hasattr(group, 'is_member') and not group.is_member(): + return + + ring_id = 0 rank = 0 nranks = 1 - if core.is_compiled_with_dist(): - parallel_env = paddle.distributed.ParallelEnv() - global_rank = parallel_env.rank - rank = global_rank if group is None else group.get_group_rank( - global_rank) - nranks = parallel_env.world_size if group is None else group.nranks + if group != False: + ring_id = 0 if group is None else group.id + if core.is_compiled_with_dist(): + parallel_env = paddle.distributed.ParallelEnv() + global_rank = parallel_env.rank + rank = global_rank if group is None else group.get_group_rank( + global_rank) + nranks = parallel_env.world_size if group is None else group.nranks input_dims = len(list(logits.shape)) label_dims = len(list(label.shape)) if input_dims - 1 != label_dims and input_dims != label_dims: raise ValueError( - 'Expected nput_dims - 1 = label_dims or input_dims == label_dims\ + 'Expected input_dims - 1 = label_dims or input_dims == label_dims\ (got nput_dims{}, label_dims{})'.format(input_dims, label_dims)) if input_dims - 1 == label_dims: label = paddle.unsqueeze(label, axis=-1) diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py index 09a0d3cb41cbcb1a867e2e61e37946bf0d059805..b5daa290456e3e9d45947e6578db0ca3b0479cdf 100644 --- a/python/paddle/nn/layer/rnn.py +++ b/python/paddle/nn/layer/rnn.py @@ -1124,16 +1124,19 @@ class 
SimpleRNN(RNNBase): Using key word arguments to construct is recommended. Parameters: - input_size (int): The input size for the first layer's cell. - hidden_size (int): The hidden size for each layer's cell. - num_layers (int, optional): Number of layers. Defaults to 1. + input_size (int): The input size of :math:`x` for the first layer's cell. + hidden_size (int): The hidden size of :math:`h` for each layer's cell. + num_layers (int, optional): Number of recurrent layers. Defaults to 1. direction (str, optional): The direction of the network. It can be "forward" or "bidirect"(or "bidirectional"). When "bidirect", the way to merge outputs of forward and backward is concatenating. Defaults to "forward". - time_major (bool, optional): Whether the first dimension of the input means the - time steps. Defaults to False. - dropout (float, optional): The droput probability. Dropout is applied to the - input of each layer except for the first layer. Defaults to 0. + time_major (bool, optional): Whether the first dimension of the input + means the time steps. If time_major is True, the shape of Tensor is + [time_steps,batch_size,input_size], otherwise [batch_size, time_steps,input_size]. + Defaults to False. `time_steps` means the length of input sequence. + dropout (float, optional): The droput probability. Dropout is applied + to the input of each layer except for the first layer. The range of + dropout from 0 to 1. Defaults to 0. activation (str, optional): The activation in each SimpleRNN cell. It can be `tanh` or `relu`. Defaults to `tanh`. weight_ih_attr (ParamAttr, optional): The parameter attribute for @@ -1148,13 +1151,13 @@ class SimpleRNN(RNNBase): None). For more information, please refer to :ref:`api_guide_Name`. Inputs: - - **inputs** (Tensor): the input sequence. If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`, else, the shape is `[batch_size, time_steps, hidden_size]`. + - **inputs** (Tensor): the input sequence. 
If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`, else, the shape is `[batch_size, time_steps, input_size]`. `time_steps` means the length of the input sequence. - **initial_states** (Tensor, optional): the initial state. The shape is `[num_layers * num_directions, batch_size, hidden_size]`. If initial_state is not given, zero initial states are used. - **sequence_length** (Tensor, optional): shape `[batch_size]`, dtype: int64 or int32. The valid lengths of input sequences. Defaults to None. If `sequence_length` is not None, the inputs are treated as padded sequences. In each input sequence, elements whose time step index are not less than the valid length are treated as paddings. Returns: - - **outputs** (Tensor): the output sequence. If `time_major` is True, the shape is `[time_steps, batch_size, num_directions * hidden_size]`, else, the shape is `[batch_size, time_steps, num_directions * hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" else 1. + - **outputs** (Tensor): the output sequence. If `time_major` is True, the shape is `[time_steps, batch_size, num_directions * hidden_size]`, else, the shape is `[batch_size, time_steps, num_directions * hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" else 1. `time_steps` means the length of the output sequence. - **final_states** (Tensor): final states. The shape is `[num_layers * num_directions, batch_size, hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" (the index of forward states are 0, 2, 4, 6... and the index of backward states are 1, 3, 5, 7...), else 1. @@ -1242,16 +1245,19 @@ class LSTM(RNNBase): Using key word arguments to construct is recommended. Parameters: - input_size (int): The input size for the first layer's cell. - hidden_size (int): The hidden size for each layer's cell. - num_layers (int, optional): Number of layers. Defaults to 1.
+ input_size (int): The input size of :math:`x` for the first layer's cell. + hidden_size (int): The hidden size of :math:`h` for each layer's cell. + num_layers (int, optional): Number of recurrent layers. Defaults to 1. direction (str, optional): The direction of the network. It can be "forward" or "bidirect"(or "bidirectional"). When "bidirect", the way to merge outputs of forward and backward is concatenating. Defaults to "forward". time_major (bool, optional): Whether the first dimension of the input - means the time steps. Defaults to False. + means the time steps. If time_major is True, the shape of Tensor is + [time_steps, batch_size, input_size], otherwise [batch_size, time_steps, input_size]. + Defaults to False. `time_steps` means the length of the input sequence. dropout (float, optional): The droput probability. Dropout is applied - to the input of each layer except for the first layer. Defaults to 0. + to the input of each layer except for the first layer. The range of + dropout is from 0 to 1. Defaults to 0. weight_ih_attr (ParamAttr, optional): The parameter attribute for `weight_ih` of each cell. Default: None. weight_hh_attr (ParamAttr, optional): The parameter attribute for @@ -1264,13 +1270,13 @@ class LSTM(RNNBase): None). For more information, please refer to :ref:`api_guide_Name`. Inputs: - - **inputs** (Tensor): the input sequence. If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`, else, the shape is `[batch_size, time_steps, hidden_size]`. + - **inputs** (Tensor): the input sequence. If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`, else, the shape is `[batch_size, time_steps, input_size]`. `time_steps` means the length of the input sequence. - **initial_states** (list|tuple, optional): the initial state, a list/tuple of (h, c), the shape of each is `[num_layers * num_directions, batch_size, hidden_size]`. If initial_state is not given, zero initial states are used.
- **sequence_length** (Tensor, optional): shape `[batch_size]`, dtype: int64 or int32. The valid lengths of input sequences. Defaults to None. If `sequence_length` is not None, the inputs are treated as padded sequences. In each input sequence, elements whos time step index are not less than the valid length are treated as paddings. Returns: - - **outputs** (Tensor): the output sequence. If `time_major` is True, the shape is `[time_steps, batch_size, num_directions * hidden_size]`, If `time_major` is False, the shape is `[batch_size, time_steps, num_directions * hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" else 1. + - **outputs** (Tensor): the output sequence. If `time_major` is True, the shape is `[time_steps, batch_size, num_directions * hidden_size]`, If `time_major` is False, the shape is `[batch_size, time_steps, num_directions * hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" else 1. `time_steps` means the length of the output sequence. - **final_states** (tuple): the final state, a tuple of two tensors, h and c. The shape of each is `[num_layers * num_directions, batch_size, hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" (the index of forward states are 0, 2, 4, 6... and the index of backward states are 1, 3, 5, 7...), else 1. @@ -1349,16 +1355,19 @@ class GRU(RNNBase): Using key word arguments to construct is recommended. Parameters: - input_size (int): The input size for the first layer's cell. - hidden_size (int): The hidden size for each layer's cell. - num_layers (int, optional): Number of layers. Defaults to 1. + input_size (int): The input size of :math:`x` for the first layer's cell. + hidden_size (int): The hidden size of :math:`h` for each layer's cell. + num_layers (int, optional): Number of recurrent layers. Defaults to 1. direction (str, optional): The direction of the network. It can be "forward" or "bidirect"(or "bidirectional"). 
When "bidirect", the way to merge outputs of forward and backward is concatenating. Defaults to "forward". time_major (bool, optional): Whether the first dimension of the input - means the time steps. Defaults to False. + means the time steps. If time_major is True, the shape of Tensor is + [time_steps, batch_size, input_size], otherwise [batch_size, time_steps, input_size]. + Defaults to False. `time_steps` means the length of the input sequence. dropout (float, optional): The droput probability. Dropout is applied - to the input of each layer except for the first layer. Defaults to 0. + to the input of each layer except for the first layer. The range of + dropout is from 0 to 1. Defaults to 0. weight_ih_attr (ParamAttr, optional): The parameter attribute for `weight_ih` of each cell. Default: None. weight_hh_attr (ParamAttr, optional): The parameter attribute for @@ -1371,13 +1380,13 @@ class GRU(RNNBase): None). For more information, please refer to :ref:`api_guide_Name`. Inputs: - - **inputs** (Tensor): the input sequence. If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`, else, the shape is `[batch_size, time_steps, hidden_size]`. + - **inputs** (Tensor): the input sequence. If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`, else, the shape is `[batch_size, time_steps, input_size]`. `time_steps` means the length of the input sequence. - **initial_states** (Tensor, optional): the initial state. The shape is `[num_layers * num_directions, batch_size, hidden_size]`. If initial_state is not given, zero initial states are used. Defaults to None. - **sequence_length** (Tensor, optional): shape `[batch_size]`, dtype: int64 or int32. The valid lengths of input sequences. Defaults to None. If `sequence_length` is not None, the inputs are treated as padded sequences. In each input sequence, elements whos time step index are not less than the valid length are treated as paddings.
Returns: - - **outputs** (Tensor): the output sequence. If `time_major` is True, the shape is `[time_steps, batch_size, num_directions * hidden_size]`, else, the shape is `[batch_size, time_steps, num_directions * hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" else 1. + - **outputs** (Tensor): the output sequence. If `time_major` is True, the shape is `[time_steps, batch_size, num_directions * hidden_size]`, else, the shape is `[batch_size, time_steps, num_directions * hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" else 1. `time_steps` means the length of the output sequence. - **final_states** (Tensor): final states. The shape is `[num_layers * num_directions, batch_size, hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" (the index of forward states are 0, 2, 4, 6... and the index of backward states are 1, 3, 5, 7...), else 1. diff --git a/python/paddle/text/datasets/wmt14.py b/python/paddle/text/datasets/wmt14.py index 7c8a549e7cb97453a421379e4a440e8a13a23487..a6d49d70ab3e307179536afe14851c57d40c99bc 100644 --- a/python/paddle/text/datasets/wmt14.py +++ b/python/paddle/text/datasets/wmt14.py @@ -55,8 +55,10 @@ class WMT14(Dataset): :attr:`data_file` is not set. Default True Returns: - Dataset: instance of WMT14 dataset - + Dataset: Instance of WMT14 dataset. The instance of dataset has 3 fields: + - src_ids (np.array) - The sequence of token ids of source language. + - trg_ids (np.array) - The sequence of token ids of target language. + - trg_ids_next (np.array) - The next sequence of token ids of target language. Examples: .. code-block:: python diff --git a/python/paddle/text/datasets/wmt16.py b/python/paddle/text/datasets/wmt16.py index f95cbe771cadc834a4de697660caa22a0729521e..5e88023a49d80ccc619754322900b1d53c6760f9 100644 --- a/python/paddle/text/datasets/wmt16.py +++ b/python/paddle/text/datasets/wmt16.py @@ -71,7 +71,10 @@ class WMT16(Dataset): :attr:`data_file` is not set.
Default True Returns: - Dataset: instance of WMT16 dataset + Dataset: Instance of WMT16 dataset. The instance of dataset has 3 fields: + - src_ids (np.array) - The sequence of token ids of source language. + - trg_ids (np.array) - The sequence of token ids of target language. + - trg_ids_next (np.array) - The next sequence of token ids of target language. Examples: