Commit 4b5b19a5 authored by: P phlrain

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_some_yaml_config

......@@ -90,9 +90,6 @@ InterpreterCore::~InterpreterCore() {
// cancel gc's thread
gc_.reset(nullptr);
exception_notifier_->UnregisterEvent();
completion_notifier_->UnregisterEvent();
async_work_queue_.reset(nullptr);
}
......
......@@ -19,37 +19,79 @@
namespace paddle {
namespace framework {
constexpr EventsWaiter::EventId kEmptyEventId = 0;
EventsWaiter::EventsWaiter()
: trigger_event_(nullptr), counter_(0), waiting_(false), cv_(1) {}
: trigger_event_(kEmptyEventId),
counter_(0),
eof_(true),
waiting_(false),
cv_(1) {}
std::shared_ptr<EventsWaiter::EventNotifier> EventsWaiter::RegisterEvent(
const std::string& name, EventChecker checker) {
auto counter = counter_.fetch_add(1);
auto id = std::hash<std::string>()(name + std::to_string(counter));
EventId id = kEmptyEventId;
EventInfo* evt = nullptr;
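// Generate a unique, non-empty event id; retry on hash collision with an existing event.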
do {
auto counter = counter_.fetch_add(1);
id = std::hash<std::string>()(name + std::to_string(counter));
if (id == kEmptyEventId) {
continue;
}
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
if (events_.count(id) > 0) {
continue;
}
evt = &(events_[id]);
} while (evt == nullptr);
evt->id = id;
evt->name = name;
evt->type = TriggerType::LevelTriggered;
evt->checker = std::move(checker);
eof_.store(false, std::memory_order_relaxed);
VLOG(10) << "Register event id:" << id << " name:" << name;
auto notifier = std::shared_ptr<EventNotifier>(new EventNotifier(id, this));
EventInfo evt{id, name, TriggerType::LevelTriggered, std::move(checker)};
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
events_[id] = std::move(evt);
return notifier;
}
std::shared_ptr<EventsWaiter::EventNotifier> EventsWaiter::RegisterEvent(
const std::string& name) {
auto counter = counter_.fetch_add(1);
auto id = std::hash<std::string>()(name + std::to_string(counter));
EventId id = kEmptyEventId;
EventInfo* evt = nullptr;
do {
auto counter = counter_.fetch_add(1);
id = std::hash<std::string>()(name + std::to_string(counter));
if (id == kEmptyEventId) {
continue;
}
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
if (events_.count(id) > 0) {
continue;
}
evt = &(events_[id]);
} while (evt == nullptr);
evt->id = id;
evt->name = name;
evt->type = TriggerType::EdgeTriggered;
evt->checker = []() { return false; };
eof_.store(false, std::memory_order_relaxed);
VLOG(10) << "Register event id:" << id << " name:" << name;
auto notifier = std::shared_ptr<EventNotifier>(new EventNotifier(id, this));
EventInfo evt{id, name, TriggerType::EdgeTriggered, []() { return false; }};
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
events_[id] = std::move(evt);
return notifier;
}
void EventsWaiter::UnregisterEvent(const EventId& id) {
VLOG(10) << "Unregister event id:" << id;
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
events_.erase(id);
{
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
deleted_events_.insert(id);
if (deleted_events_.size() == events_.size()) {
eof_.store(true, std::memory_order_relaxed);
}
}
if (eof_.load(std::memory_order_relaxed)) {
cv_.Notify(true);
}
}
std::string EventsWaiter::WaitEvent() {
......@@ -61,42 +103,60 @@ std::string EventsWaiter::WaitEvent() {
PADDLE_THROW(
platform::errors::ResourceExhausted("Another thread is waiting."));
}
auto w = cv_.GetWaiter(0);
cv_.Prewait();
std::string* triggered = trigger_event_;
if (triggered == nullptr) {
EventId triggered = trigger_event_;
while (triggered == kEmptyEventId && !eof_) {
cv_.Prewait();
// double check
triggered = trigger_event_;
// checkers
{
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
for (auto& kv : events_) {
auto& evt = kv.second;
if (TriggerType::LevelTriggered == evt.type && evt.checker()) {
triggered = new std::string(evt.name);
break;
if (triggered == kEmptyEventId) {
{
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
for (auto& kv : events_) {
auto& evt = kv.second;
if (TriggerType::LevelTriggered == evt.type && evt.checker()) {
triggered = evt.id;
break;
}
}
}
}
if (triggered != nullptr) {
std::string* prev = nullptr;
if (!trigger_event_.compare_exchange_strong(prev, triggered,
std::memory_order_seq_cst,
std::memory_order_relaxed)) {
delete triggered;
triggered = prev;
if (triggered != kEmptyEventId) {
EventId prev = kEmptyEventId;
if (!trigger_event_.compare_exchange_strong(
prev, triggered, std::memory_order_seq_cst,
std::memory_order_relaxed)) {
triggered = prev;
}
}
}
if (triggered != kEmptyEventId || eof_) {
cv_.CancelWait();
} else {
cv_.CommitWait(w);
triggered = trigger_event_;
}
}
if (triggered) {
cv_.CancelWait();
} else {
cv_.CommitWait(w);
triggered = trigger_event_;
trigger_event_.store(kEmptyEventId, std::memory_order_relaxed);
waiting_.store(false, std::memory_order_relaxed);
std::string evt_name =
triggered == kEmptyEventId ? "NoEventNotifier" : GetEventName(triggered);
VLOG(10) << "Consume event id:" << triggered << ", name:" << evt_name;
// lazy deletion
{
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
if (deleted_events_.size() > 0) {
for (auto evt : deleted_events_) {
events_.erase(evt);
}
deleted_events_.clear();
}
}
trigger_event_.store(nullptr, std::memory_order_relaxed);
waiting_.store(false);
auto trigger_event = *triggered;
delete triggered;
return trigger_event;
return evt_name;
}
int EventsWaiter::Clear() {
......@@ -106,32 +166,33 @@ int EventsWaiter::Clear() {
std::memory_order_relaxed)) {
return -1;
}
trigger_event_.store(nullptr, std::memory_order_relaxed);
trigger_event_.store(kEmptyEventId, std::memory_order_relaxed);
waiting_.store(false);
return 0;
}
void EventsWaiter::TriggerEvent(const EventId& id) {
VLOG(10) << "Try to trigger event id:" << id;
std::string* trigger_event = new std::string;
{
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
auto iter = events_.find(id);
if (iter == events_.end()) {
delete trigger_event;
return;
}
*trigger_event = iter->second.name;
EventId prev = kEmptyEventId;
if (!trigger_event_.compare_exchange_strong(
prev, id, std::memory_order_seq_cst, std::memory_order_relaxed)) {
VLOG(10) << "Event id:" << prev << " is pending";
return;
}
std::string* prev = nullptr;
if (!trigger_event_.compare_exchange_strong(prev, trigger_event,
VLOG(10) << "Triggered event id:" << id;
cv_.Notify(true);
}
void EventsWaiter::CancelEvent(const EventId& id) {
VLOG(10) << "Try to cancel event id:" << id;
EventId prev = id;
if (!trigger_event_.compare_exchange_strong(prev, kEmptyEventId,
std::memory_order_seq_cst,
std::memory_order_relaxed)) {
delete trigger_event;
VLOG(10) << "Event id:" << prev << " is pending";
return;
}
VLOG(10) << "Triggered event id:" << id << " name:" << *trigger_event;
cv_.Notify(true);
VLOG(10) << "Cancelled event id:" << id;
}
std::string EventsWaiter::GetEventName(const EventId& id) {
......
......@@ -19,6 +19,7 @@
#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/new_executor/workqueue/event_count.h"
#include "paddle/fluid/memory/allocation/spin_lock.h"
......@@ -37,13 +38,12 @@ class EventsWaiter {
// Make sure EventsWaiter has a longer lifetime than EventNotifier.
class EventNotifier {
public:
void NotifyEvent() { waiter_.TriggerEvent(id_); }
~EventNotifier() { waiter_.UnregisterEvent(id_); }
void UnregisterEvent() { waiter_.UnregisterEvent(id_); }
void NotifyEvent() { waiter_.TriggerEvent(id_); }
EventId GetEventId() { return id_; }
void CancelEvent() { waiter_.CancelEvent(id_); }
// return "Unregistered" if the corresponding event was unregistered.
std::string GetEventName() { return waiter_.GetEventName(id_); }
private:
......@@ -97,12 +97,16 @@ class EventsWaiter {
void TriggerEvent(const EventId& id);
void CancelEvent(const EventId& id);
std::string GetEventName(const EventId& id);
std::unordered_map<EventId, EventInfo> events_;
std::unordered_set<EventId> deleted_events_;
paddle::memory::SpinLock events_lock_;
std::atomic<std::string*> trigger_event_;
std::atomic<EventId> trigger_event_;
std::atomic<uint64_t> counter_;
std::atomic<bool> eof_;
std::atomic<bool> waiting_;
EventCount cv_;
};
......
......@@ -31,11 +31,8 @@ class WorkQueueImpl : public WorkQueue {
public:
explicit WorkQueueImpl(const WorkQueueOptions& options) : WorkQueue(options) {
if (options_.track_task && options.events_waiter != nullptr) {
empty_notifier_ = options.events_waiter->RegisterEvent(kQueueEmptyEvent);
void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker));
TaskTracker* tracker = reinterpret_cast<TaskTracker*>(storage);
empty_notifier_ = options.events_waiter->RegisterEvent(
kQueueEmptyEvent,
[tracker]() { return tracker->PendingTaskNum() == 0; });
tracker_ = new (storage) TaskTracker(*empty_notifier_.get());
}
if (options_.detached == false && options.events_waiter != nullptr) {
......@@ -47,9 +44,6 @@ class WorkQueueImpl : public WorkQueue {
}
virtual ~WorkQueueImpl() {
if (empty_notifier_) {
empty_notifier_->UnregisterEvent();
}
delete queue_;
if (tracker_ != nullptr) {
tracker_->~TaskTracker();
......@@ -57,7 +51,6 @@ class WorkQueueImpl : public WorkQueue {
}
if (destruct_notifier_) {
destruct_notifier_->NotifyEvent();
destruct_notifier_->UnregisterEvent();
}
}
......@@ -124,14 +117,12 @@ WorkQueueGroupImpl::WorkQueueGroupImpl(
const auto& options = queues_options_[idx];
if (options.track_task && tracker_ == nullptr &&
options.events_waiter != nullptr) {
empty_notifier_ = options.events_waiter->RegisterEvent(kQueueEmptyEvent);
void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker));
TaskTracker* tracker = reinterpret_cast<TaskTracker*>(storage);
empty_notifier_ = options.events_waiter->RegisterEvent(
kQueueEmptyEvent,
[tracker]() { return tracker->PendingTaskNum() == 0; });
tracker_ = new (storage) TaskTracker(*empty_notifier_.get());
}
if (options.detached == false && options.events_waiter != nullptr) {
if (options.detached == false && options.events_waiter != nullptr &&
!destruct_notifier_) {
destruct_notifier_ =
options.events_waiter->RegisterEvent(kQueueDestructEvent);
}
......@@ -141,9 +132,6 @@ WorkQueueGroupImpl::WorkQueueGroupImpl(
}
WorkQueueGroupImpl::~WorkQueueGroupImpl() {
if (empty_notifier_) {
empty_notifier_->UnregisterEvent();
}
for (auto queue : queues_) {
queue->~NonblockingThreadPool();
}
......@@ -154,7 +142,6 @@ WorkQueueGroupImpl::~WorkQueueGroupImpl() {
free(queues_storage_);
if (destruct_notifier_) {
destruct_notifier_->NotifyEvent();
destruct_notifier_->UnregisterEvent();
}
}
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/new_executor/workqueue/workqueue.h"
#include <atomic>
#include <thread>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h"
......@@ -26,11 +27,12 @@ TEST(WorkQueueUtils, TestEventsWaiter) {
EXPECT_EQ(events_waiter.WaitEvent(), "test_register_lt");
EXPECT_EQ(notifier->GetEventName(), "test_register_lt");
EXPECT_EQ(events_waiter.WaitEvent(), "test_register_lt");
notifier->UnregisterEvent();
EXPECT_EQ(notifier->GetEventName(), "Unregistered");
notifier.reset();
notifier = events_waiter.RegisterEvent("test_register_et");
notifier->NotifyEvent();
EXPECT_EQ(events_waiter.WaitEvent(), "test_register_et");
notifier->NotifyEvent();
notifier->CancelEvent();
}
TEST(WorkQueue, TestSingleThreadedWorkQueue) {
......@@ -106,8 +108,13 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) {
EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum);
// Cancel
work_queue->Cancel();
// Wait kQueueDestructEvent
std::thread waiter_thread([&events_waiter]() {
EXPECT_EQ(events_waiter.WaitEvent(),
paddle::framework::kQueueDestructEvent);
});
work_queue.reset();
EXPECT_EQ(events_waiter.WaitEvent(), paddle::framework::kQueueDestructEvent);
waiter_thread.join();
}
TEST(WorkQueue, TestWorkQueueGroup) {
......@@ -154,10 +161,15 @@ TEST(WorkQueue, TestWorkQueueGroup) {
// WaitQueueGroupEmpty
events_waiter.WaitEvent();
EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum + kLoopNum);
EXPECT_EQ(handle.get(), random_num);
// Cancel
queue_group->Cancel();
events_waiter.WaitEvent();
// Wait kQueueDestructEvent
std::thread waiter_thread([&events_waiter]() {
EXPECT_EQ(events_waiter.WaitEvent(),
paddle::framework::kQueueDestructEvent);
EXPECT_EQ(events_waiter.WaitEvent(), "NoEventNotifier");
});
queue_group.reset();
EXPECT_EQ(events_waiter.WaitEvent(), paddle::framework::kQueueDestructEvent);
EXPECT_EQ(handle.get(), random_num);
waiter_thread.join();
}
......@@ -81,7 +81,13 @@ class TaskTracker {
~TaskTracker() = default;
void AddCounter() { num_tasks_.fetch_add(1, std::memory_order_relaxed); }
void AddCounter() {
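// On the 0 -> 1 transition the queue has pending tasks again, so cancel any pending queue-empty notification.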
if (0 == num_tasks_.fetch_add(1, std::memory_order_relaxed)) {
if (notifier_ != nullptr) {
notifier_->CancelEvent();
}
}
}
void SubCounter() {
if (1 == num_tasks_.fetch_sub(1, std::memory_order_relaxed)) {
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/assign_pos_op.h"
namespace paddle {
namespace operators {
class AssignPosOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("cum_count"), "Input", "cum_count",
"AssignPos");
OP_INOUT_CHECK(ctx->HasInput("eff_num_len"), "Input", "eff_num_len",
"AssignPos");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "AssignPos");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "AssignPos");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto cum_count_dtype =
OperatorWithKernel::IndicateVarDataType(ctx, "cum_count");
auto X_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
PADDLE_ENFORCE_EQ(cum_count_dtype, X_dtype,
platform::errors::InvalidArgument(
"The dtype of the cum_count and X should be same"));
PADDLE_ENFORCE_EQ(cum_count_dtype, framework::proto::VarType::INT64,
platform::errors::InvalidArgument(
"The dtype of the cum_count_dtype, eff_num_len and "
"X should be same as int64"));
return framework::OpKernelType(cum_count_dtype, ctx.device_context());
}
};
class AssignPosOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "numbers to scatter.");
AddInput("cum_count", "The cumulative sum count of numbers.");
AddInput("eff_num_len",
"The effective numbers of numbers should be scattered.");
AddOutput("Out", "Assemble numbers in the order of counters.");
AddComment(R"DOC(
assign_pos_op Operator.
Assign pos decides which tokens should be fetched and assigned to
each counter, in order.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(assign_pos, ops::AssignPosOp,
ops::AssignPosOpMaker);
REGISTER_OP_CPU_KERNEL(assign_pos, ops::AssignPosOpCPUKernel<int>,
ops::AssignPosOpCPUKernel<int64_t>);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/assign_pos_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/float16.h"
DECLARE_bool(avoid_op_randomness);
namespace paddle {
namespace operators {
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}
template <typename T>
__global__ void AssignPos(T* cum_count, const T* numbers, T* out,
int64_t limit) {
CUDA_KERNEL_LOOP(i, limit) {
int number_idx = numbers[i];
if (number_idx > -1) {
int p = platform::CudaAtomicAdd(cum_count + number_idx, -1);
out[p - 1] = i;
}
}
}
template <typename T>
class AssignPosCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
// assign pos decides which tokens should be fetched and assigned to each
// counter, in order.
auto cum_count = context.Input<LoDTensor>(
"cum_count"); // (counter number) int32 | int64
auto numbers =
context.Input<LoDTensor>("X"); // (batch_size * seq_len, topk) int32
auto eff_num_len =
context.Input<LoDTensor>("eff_num_len"); // (sum(cum_count))
auto out = context.Output<LoDTensor>("Out"); // (cum_count) value ranges
// from 0 to batch_size *
// seq_len * topk
auto place = context.GetPlace();
auto numel = numbers->numel();
T* cum_data = const_cast<T*>(cum_count->data<T>());
auto cum_size = cum_count->numel();
framework::Tensor cpu_eff_num_len;
int64_t cpu_eff_num_len_data = 0;
if (platform::is_cpu_place(eff_num_len->place())) {
cpu_eff_num_len_data = eff_num_len->data<T>()[0];
} else {
framework::TensorCopySync(*eff_num_len, platform::CPUPlace(),
&cpu_eff_num_len);
cpu_eff_num_len_data = cpu_eff_num_len.data<T>()[0];
}
const auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
framework::DDim out_dims = phi::make_ddim({cpu_eff_num_len_data});
auto out_data = out->mutable_data<T>(out_dims, place);
const T* num_data = numbers->data<T>();
int blocks = NumBlocks(numel);
int threads = kNumCUDAThreads;
AssignPos<T><<<blocks, threads, 0, dev_ctx.stream()>>>(cum_data, num_data,
out_data, numel);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(assign_pos, ops::AssignPosCUDAKernel<int64_t>);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
template <typename T>
class AssignPosOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_THROW(platform::errors::Unavailable(
"Do not support assign pos op for cpu kernel now."));
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -22,8 +22,7 @@ class NumberCountOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("gate_idx"), "Input", "gate_idx",
"NumberCount");
OP_INOUT_CHECK(ctx->HasInput("numbers"), "Input", "numbers", "NumberCount");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "number_count",
"NumberCount");
}
......@@ -31,25 +30,24 @@ class NumberCountOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
// the dtype of the gate_idx should be same as int64
auto gate_idx_dtype =
OperatorWithKernel::IndicateVarDataType(ctx, "gate_idx");
// the dtype of the numbers should be same as int64
auto number_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "numbers");
PADDLE_ENFORCE_EQ(gate_idx_dtype, framework::proto::VarType::INT64,
PADDLE_ENFORCE_EQ(number_dtype, framework::proto::VarType::INT64,
platform::errors::InvalidArgument(
"The dtype of the gate_idx_dtype should be int64"));
return framework::OpKernelType(gate_idx_dtype, ctx.GetPlace());
"The dtype of the number_dtype should be int64"));
return framework::OpKernelType(number_dtype, ctx.GetPlace());
}
};
class NumberCountOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("gate_idx", "(Tensor) The input gate index tensor.");
AddOutput("Out", "(Tensor) The output expert count tensor.");
AddAttr<int>("upper_range", "(int), The number of experts.");
AddInput("numbers", "(Tensor) The input gate index tensor.");
AddOutput("Out", "(Tensor) The output number count tensor.");
AddAttr<int>("upper_range", "(int), The number of different numbers.");
AddComment(R"DOC(number_count Operator.count gate indices.)DOC");
AddComment(R"DOC(number_count Operator.count numbers.)DOC");
}
};
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -38,7 +38,7 @@ __global__ void initialize_zero_kernel(T* data, const int length) {
}
template <typename T>
__global__ void NumberCount(const T* gate_idx, T* number_count,
__global__ void NumberCount(const T* numbers, T* number_count,
int64_t batch_size, int upper_range) {
int res_tmp[PERTHREAD_EXPERTS] = {0};
int expert_min = blockIdx.x * PERTHREAD_EXPERTS;
......@@ -47,7 +47,7 @@ __global__ void NumberCount(const T* gate_idx, T* number_count,
expert_max = upper_range;
}
for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
T idx = gate_idx[i];
T idx = numbers[i];
if (idx == -1) {
continue;
}
......@@ -76,18 +76,18 @@ template <typename T>
class NumberCountOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto gate_idx = context.Input<LoDTensor>("gate_idx");
auto numbers = context.Input<LoDTensor>("numbers");
auto upper_range = context.Attr<int>("upper_range");
auto number_count = context.Output<LoDTensor>("Out");
int64_t batch_size = gate_idx->numel();
int64_t batch_size = numbers->numel();
auto place = context.GetPlace();
const auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
framework::DDim out_dims = phi::make_ddim({upper_range});
auto out_data = number_count->mutable_data<T>(out_dims, place);
const T* gate_data = gate_idx->data<T>();
const T* gate_data = numbers->data<T>();
initialize_zero_kernel<
T><<<GET_BLOCKS(upper_range), CUDA_NUM_THREADS, 0, dev_ctx.stream()>>>(
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......
......@@ -251,6 +251,7 @@ set MSVC_STATIC_CRT=ON
set ON_INFER=ON
set WITH_TENSORRT=ON
set WITH_INFERENCE_API_TEST=ON
set vcvars64_dir="D:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Auxiliary\Build\vcvars64.bat"
call :cmake || goto cmake_error
call :build || goto build_error
......@@ -314,11 +315,9 @@ echo ========================================
rem set vs language to english to block showIncludes, this requires vs to have the English language package installed.
set VSLANG=1033
rem Configure the environment for 64-bit builds. 'DISTUTILS_USE_SDK' indicates that the user has selected the compiler.
echo %task_name%|findstr wincheck_inference >nul && (
call "D:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Auxiliary\Build\vcvars64.bat"
) || (
call "C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\VC\Auxiliary\Build\vcvars64.bat"
)
if not defined vcvars64_dir set vcvars64_dir="C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\VC\Auxiliary\Build\vcvars64.bat"
call %vcvars64_dir%
set DISTUTILS_USE_SDK=1
rem Windows 10 Kit bin dir
set PATH=C:\Program Files (x86)\Windows Kits\10\bin\10.0.17763.0\x64;%PATH%
......@@ -381,18 +380,8 @@ echo echo ${md5_content}^>md5.txt >> cache.sh
%cache_dir%\tools\busybox64.exe bash cache.sh
set /p md5=< md5.txt
echo %task_name%|findstr build >nul && (
set THIRD_PARTY_HOME=%cache_dir:\=/%/third_party
set THIRD_PARTY_PATH=!THIRD_PARTY_HOME!/%md5%
echo %task_name% is a whl-build task, will only reuse local third_party cache.
goto :cmake_impl
) || (
echo %task_name% is a PR-CI-Windows task, will try to reuse bos and local third_party cache both.
)
if "%WITH_GPU%"=="ON" (
for /F %%# in ('dir /b /d "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\"') do set cuda_version=%%#
set cuda_version=!cuda_version:~-4!
set cuda_version=%CUDA_TOOLKIT_ROOT_DIR:~-4%
set sub_dir=cuda!cuda_version:.=!
) else (
set sub_dir=cpu
......@@ -400,6 +389,13 @@ if "%WITH_GPU%"=="ON" (
set THIRD_PARTY_HOME=%cache_dir:\=/%/third_party/%sub_dir%
set THIRD_PARTY_PATH=%THIRD_PARTY_HOME%/%md5%
echo %task_name%|findstr build >nul && (
echo %task_name% is a whl-build task, will only reuse local third_party cache.
goto :cmake_impl
) || (
echo %task_name% is a PR-CI-Windows task, will try to reuse bos and local third_party cache both.
)
if not exist %THIRD_PARTY_PATH% (
echo There is no usable third_party cache in %THIRD_PARTY_PATH%, will download from bos.
pip install wget
......
......@@ -1024,6 +1024,9 @@ function generate_api_spec() {
${PADDLE_ROOT}/paddle/fluid/op_use_default_grad_maker_${spec_kind}.spec
deactivate
cd ${PADDLE_ROOT}/build
rm -rf ${PADDLE_ROOT}/build/.check_api_workspace
}
function check_approvals_of_unittest() {
......
......@@ -37,9 +37,45 @@ from ..meta_optimizers import HybridParallelOptimizer, HeterParallelOptimizer
from paddle import _C_ops
from paddle.fluid import core
from paddle.fluid.dygraph import to_variable
from paddle.distributed.fleet.utils.recompute import RecomputeFunction
from paddle.fluid.dygraph.varbase_patch_methods import _grad_scalar
__all__ = []
_grad_scalar = None
class _RecomputeModelWrapper(paddle.nn.Layer):
def __init__(self, model, segments=2, preserve_rng_state=True):
super(_RecomputeModelWrapper, self).__init__()
assert isinstance(model, paddle.nn.Sequential), (
"The model passed to RecomputeModelWrapper must be of type "
"paddle.nn.Sequential.")
self._model = model
self._segments = segments
self._preserve_rng_state = preserve_rng_state
self._layers = list(model.children())
self._segment_size = len(self._layers) // segments
def _run_func(self, begin, end):
def do_run(input):
for i in range(begin, end):
input = self._layers[i](input)
return input
return do_run
def _checkpoint(self, func, *args, **kwargs):
return RecomputeFunction.apply(func, self._preserve_rng_state, *args)
def forward(self, input):
end = 0
for begin in range(0, self._segment_size * (self._segments - 1),
self._segment_size):
end = begin + self._segment_size
input = self._checkpoint(self._run_func(begin, end), input)
return self._run_func(end, len(self._layers))(input)
def apply_ir_passes(main_program, startup_program, config):
build_strategy = config._user_defined_strategy.build_strategy._copy()
......@@ -952,6 +988,41 @@ class Fleet(object):
if self.worker_num() <= 1:
return model
amp_enable = False
recompute_enable = False
strategy = self._user_defined_strategy
if strategy.amp == True:
amp_enable = True
amp_level = "O2" if strategy.amp_configs['use_pure_fp16'] else "O1"
if amp_level.upper() == "O2":
model = paddle.amp.decorate(
models=model,
optimizers=None,
level="O2",
master_weight=None,
save_dtype=None)
init_loss_scaling = strategy.amp_configs['init_loss_scaling']
incr_ratio = strategy.amp_configs['incr_ratio']
decr_ratio = strategy.amp_configs['decr_ratio']
incr_every_n_steps = strategy.amp_configs['incr_every_n_steps']
decr_every_n_nan_or_inf = strategy.amp_configs[
'decr_every_n_nan_or_inf']
use_dynamic_loss_scaling = strategy.amp_configs[
'use_dynamic_loss_scaling']
global _grad_scalar
_grad_scalar = paddle.amp.GradScaler(
init_loss_scaling=init_loss_scaling,
incr_ratio=incr_ratio,
decr_ratio=decr_ratio,
incr_every_n_steps=incr_every_n_steps,
decr_every_n_nan_or_inf=decr_every_n_nan_or_inf,
use_dynamic_loss_scaling=use_dynamic_loss_scaling)
if strategy.recompute == True:
recompute_enable = True
model = _RecomputeModelWrapper(model)
if self._user_defined_strategy.heter_ccl_mode == True:
distributed_model = paddle.DataParallel(
model,
......@@ -964,7 +1035,7 @@ class Fleet(object):
return distributed_model
if self._hcg.get_parallel_mode() == ParallelMode.SHARDING_PARALLEL:
distributed_model = ShardingParallel(
model = ShardingParallel(
model, self._hcg, strategy=self._user_defined_strategy)
elif self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL:
......@@ -975,22 +1046,23 @@ class Fleet(object):
assert self.sharding_degree == self._hcg.get_sharding_parallel_world_size(
)
broadcast_sharding_parameters(model, self._hcg)
distributed_model = paddle.DataParallel(
model = paddle.DataParallel(
model,
comm_buffer_size=self._user_defined_strategy.
fuse_grad_size_in_MB,
last_comm_buffer_size=self._user_defined_strategy.
last_comm_group_size_MB,
find_unused_parameters=self._user_defined_strategy.
find_unused_parameters)
find_unused_parameters,
static_graph=True if recompute_enable else False)
elif self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL:
distributed_model = TensorParallel(
model = TensorParallel(
model, self._hcg, strategy=self._user_defined_strategy)
elif self._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL:
distributed_model = PipelineParallel(
model = PipelineParallel(
model, self._hcg, strategy=self._user_defined_strategy)
return distributed_model
return model
@dygraph_only
def state_dict(self):
......
......@@ -40,9 +40,9 @@ def launch():
- ``--rank``: The rank of the node, can be auto assigned by master. Default ``--rank=-1``.
- ``--log_level``: The log levl to set for logging.setLevel. Default ``--log_level=INFO``.
- ``--log_level``: The log level to set for logging.setLevel which can be CRITICAL/ERROR/WARNING/INFO/DEBUG/NOTSET, case insensitive. The rank 0 log will not print in the terminal by default, while you can enable it by adding --log_level=debug. Default ``--log_level=INFO``.
- ``--nnodes``: The number of nodes for a distributed job, it can be a range in elastic mode, e.g., ``--nnnodes=2:3``. Default ``--nnodes=1``.
- ``--nnodes``: The number of nodes for a distributed job, it can be a range in elastic mode, e.g., ``--nnodes=2:3``. Default ``--nnodes=1``.
- ``--nproc_per_node``: The number of processes to launch on a node. In gpu training, it should be less than or equal to the number of gpus of your system. e.g., ``--nproc_per_node=8``
......@@ -93,9 +93,11 @@ def launch():
Returns:
``None``
- ``None``
Examples 0 (master, ip/port auto detection):
.. code-block:: bash
:name: code-block-example-bash0
# For training on multi node, run the following command in one of the nodes
......@@ -171,7 +173,7 @@ def launch():
.. code-block:: bash
:name: code-block-example-bash5
# To simulate distributed environment using single node, e.g., 2 servers and 4 workers, each worker use single gpu.
# To simulate distributed environment using single node, e.g., 2 servers and 4 workers, each worker use single gpu.
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch --server_num=2 --worker_num=4 train.py --lr=0.01
......@@ -226,7 +228,7 @@ def launch():
python -m paddle.distributed.launch --master etcd://10.0.0.1:2379 --nnodes 2:4 train.py
# once the number of nodes changes between 2:4 during training, the strategy holds
"""
# initialize the context to run
......
......@@ -17,11 +17,11 @@ from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import in_dygraph_mode
def _number_count(gate_idx, upper_range):
def _number_count(numbers, upper_range):
"""
calculate the expert count according to the gate index.
Args:
gate_idx (Tensor): Tensor. The input gate index whose data type should be int32 or int64.
numbers (Tensor): Tensor. The input gate index whose data type should be int32 or int64.
upper_range (int): The number of the experts.
Returns:
out (Tensor): The output expert count.
......@@ -30,26 +30,75 @@ def _number_count(gate_idx, upper_range):
# required: distributed
import paddle
gate_idx = [
numbers = [
[0, 2],
[0, 2]
]
upper_range = 6
gate_idx = paddle.to_tensor(gate_idx, dtype="int32")
number_count = paddle.distributed.utils.number_count(gate_idx, upper_range)
numbers = paddle.to_tensor(numbers, dtype="int32")
number_count = paddle.distributed.utils.number_count(numbers, upper_range)
print(number_count) # the result: [2, 0, 2, 0, 0, 0]
"""
if in_dygraph_mode():
return core.ops.number_count(gate_idx, 'upper_range', upper_range)
return core.ops.number_count(numbers, 'upper_range', upper_range)
else:
op_type = 'number_count'
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(dtype=gate_idx.dtype)
out = helper.create_variable_for_type_inference(dtype=numbers.dtype)
helper.append_op(
type=op_type,
inputs={'gate_idx': gate_idx},
inputs={'numbers': numbers},
outputs={'Out': out},
attrs={'upper_range': upper_range})
return out
def _assign_pos(x, cum_count):
"""
Assign pos decides which tokens should be fetched and assigned to
each expert, in order.
Args:
x (Tensor): Tensor. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32 or int64.
cum_count (Tensor): The cumulative sum tokens of counters. Every element in the list must be a Tensor whose
data type should be int64.
Returns:
out (Tensor): Assemble numbers in the order of counters.
Examples:
.. code-block:: python
# required: distributed
import paddle
number_count = [2, 0, 2, 0]
numbers = [
[0, 2],
[0, 2]
]
number_count = paddle.to_tensor(number_count)
numbers = paddle.to_tensor(numbers, dtype="int32")
num_cum = paddle.cumsum(number_count)
pos = paddle.distributed.utils.assign_pos(x=numbers, cum_count=num_cum)
print(pos) # the result: (2, 0, 3, 1)
"""
if in_dygraph_mode():
return core.ops.assign_pos(x, cum_count, cum_count[-1])
else:
op_type = 'assign_pos'
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(dtype=cum_count.dtype)
helper.append_op(
type=op_type,
inputs={
'X': [x],
'cum_count': [cum_count],
"eff_num_len": [cum_count[-1]]
},
outputs={'Out': [out]})
return out
......@@ -31,6 +31,8 @@ import paddle.utils.deprecated as deprecated
import paddle.profiler as profiler
from paddle import _C_ops
_grad_scalar = None
class TensorHookRemoveHelper(object):
"""
......@@ -265,6 +267,9 @@ def monkey_patch_varbase():
grad_tensor = []
else:
grad_tensor = [grad_tensor]
if _grad_scalar:
# When using amp with Fleet DistributedStrategy, we do loss scaling implicitly.
self = _grad_scalar.scale(self)
if paddle.is_compiled_with_xpu() or paddle.is_compiled_with_npu():
# TODO(liuyuhui): Currently only for xpu. Will be removed in the future.
scaled_loss = scale_loss(self)
......
......@@ -939,6 +939,7 @@ if (WITH_DISTRIBUTE)
set_tests_properties(test_dist_fleet_infer PROPERTIES TIMEOUT 200)
set_tests_properties(test_dist_fleet_raw_program_optimizer PROPERTIES TIMEOUT 120)
set_tests_properties(test_dist_fleet_raw_program_optimizer_fuse_allreduce PROPERTIES TIMEOUT 60)
set_tests_properties(test_dist_dygraph_apis PROPERTIES TIMEOUT 120)
endif()
if (WITH_DISTRIBUTE AND NOT APPLE)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import random
import numpy as np
import os
import shutil
import paddle
import paddle.nn as nn
from paddle.fluid import core
import datetime
from datetime import timedelta
import paddle.fluid.core as core
from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.dygraph.parallel import ParallelEnv
class TestDygraphFleetAPI(unittest.TestCase):
def setUp(self):
paddle.seed(2022)
random.seed(2022)
np.random.seed(2022)
self.config()
def config(self):
self.dtype = "float32"
self.shape = (2, 10, 5)
def test_dygraph_fleet_api(self):
import paddle.distributed.fleet as fleet
import paddle.distributed as dist
strategy = fleet.DistributedStrategy()
strategy.amp = True
strategy.recompute = True
fleet.init(is_collective=True, strategy=strategy)
net = paddle.nn.Sequential(
paddle.nn.Linear(10, 1), paddle.nn.Linear(1, 2))
net = dist.fleet.distributed_model(net)
data = np.random.uniform(-1, 1, [30, 10]).astype('float32')
data = paddle.to_tensor(data)
net(data)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import op_test
import numpy as np
import unittest
import paddle
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
from paddle.fluid.backward import append_backward
from paddle.distributed.models.moe import utils
def assign_pos(x, _cum_count):
cum_count = np.copy(_cum_count)
x = x.reshape(-1)
res = np.zeros((cum_count[-1], ), dtype=np.int64)
for i, idx in enumerate(x):
p = cum_count[idx]
cum_count[idx] -= 1
if p >= 1:
res[p - 1] = i
return res
def count(x, upper_num):
res = np.zeros((upper_num, )).astype(int)
for i in x.reshape(-1):
if i >= 0 and i < len(res):
res[i] += 1
return res
# Why define the assert function specially?
# Because assign_pos_op is a multithreaded op, the order of numbers in each
# counter (bin) is random, but the set of numbers in each counter (bin) is deterministic.
np_allclose = np.allclose
def assert_allclose(res, out, cum_count):
c0 = 0
for c in cum_count:
if c == c0:
continue
data1 = np.copy(res[c0:c])
data2 = np.copy(out[c0:c])
data1.sort()
data2.sort()
assert np_allclose(data2, data1)
c0 = c
return True
def get_redefined_allclose(cum_count):
def redefined_allclose(x, y, *args, **kwargs):
return assert_allclose(x, y, cum_count)
return redefined_allclose
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestAssignPosOpInt64(op_test.OpTest):
def setUp(self):
x = np.random.randint(0, 16, size=(100, 2)).astype("int64")
y = count(x, 16)
cum_count = np.cumsum(y).astype(x.dtype)
self.op_type = "assign_pos"
self.inputs = {
'X': x,
"cum_count": cum_count,
"eff_num_len": np.array([cum_count[-1]])
}
self.outputs = {'Out': assign_pos(x, cum_count)}
self.cum_count = cum_count
def test_forward(self):
np.allclose = get_redefined_allclose(self.cum_count)
self.check_output_with_place(paddle.CUDAPlace(0))
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestAssignPosAPI(unittest.TestCase):
def setUp(self):
self.x = np.random.randint(0, 16, size=(100, 2)).astype("int64")
y = count(self.x, 16)
self.cum_count = np.cumsum(y).astype(self.x.dtype)
self.out = assign_pos(self.x, self.cum_count)
self.place = paddle.CUDAPlace(0)
def test_api_static(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('x', self.x.shape, dtype="int64")
cum_count = paddle.fluid.data(
'cum_count', self.cum_count.shape, dtype="int64")
out = utils._assign_pos(x, cum_count)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'x': self.x,
"cum_count": self.cum_count},
fetch_list=[out])
assert_allclose(res[0], self.out, self.cum_count)
def test_api_dygraph(self):
paddle.disable_static()
x = paddle.to_tensor(self.x)
cum_count = paddle.to_tensor(self.cum_count).astype(x.dtype)
out = utils._assign_pos(x, cum_count)
assert_allclose(out.numpy(), self.out, self.cum_count)
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
......@@ -241,9 +241,21 @@ class TestClassCenterSampleAPIError1(unittest.TestCase):
remapped_label, sampled_class_index = paddle.nn.functional.class_center_sample(
label, self.num_classes, self.num_samples)
print(remapped_label, sampled_class_index)
def test_group_value():
for place in self.places:
with paddle.fluid.dygraph.guard(place):
label_np = np.random.randint(
0,
self.num_classes, (self.batch_size, ),
dtype=self.dtype)
label = paddle.to_tensor(label_np)
remapped_label, sampled_class_index = paddle.nn.functional.class_center_sample(
label, self.num_classes, self.num_samples, group=True)
self.assertRaises(ValueError, test_empty_label)
self.assertRaises(ValueError, test_group_value)
if __name__ == '__main__':
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
from test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestDygraphFleetApi(TestMultipleGpus):
def test_dygraph_fleet_api(self):
self.run_mnist_2gpu('dygraph_fleet_api.py')
if __name__ == "__main__":
unittest.main()
......@@ -400,8 +400,31 @@ class TestMarginCrossEntropyOpAPIError(unittest.TestCase):
return_softmax=True,
reduction=None)
def test_group_value():
for place in self.places:
with paddle.fluid.dygraph.guard(place):
labels_np = np.random.randint(
0, self.num_class, (self.batch_dim, ), dtype="int64")
logits_np = np.random.uniform(
-0.99, 0.99,
[self.batch_dim, self.num_class]).astype(self.dtype)
labels = paddle.to_tensor(labels_np)
logits = paddle.to_tensor(logits_np)
loss, softmax = paddle.nn.functional.margin_cross_entropy(
logits,
labels,
margin1=self.margin1,
margin2=self.margin2,
margin3=self.margin3,
scale=self.scale,
return_softmax=True,
reduction=None,
group=True)
self.assertRaises(ValueError, test_dim)
self.assertRaises(NotImplementedError, test_label_type)
self.assertRaises(ValueError, test_group_value)
if __name__ == '__main__':
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -26,8 +26,8 @@ from paddle.fluid.backward import append_backward
from paddle.distributed.models.moe import utils
def count(x, upper_range):
res = np.zeros((upper_range, )).astype(int)
def count(x, upper_num):
res = np.zeros((upper_num, )).astype(int)
for i in x.reshape(-1):
if i >= 0 and i < len(res):
res[i] += 1
......@@ -36,14 +36,14 @@ def count(x, upper_range):
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestExpertCountOpInt64(op_test.OpTest):
class TestNumberCountOpInt64(op_test.OpTest):
def setUp(self):
expert_num = 16
upper_num = 16
self.op_type = "number_count"
x = np.random.randint(-1, expert_num, size=(1000, 2)).astype('int64')
self.inputs = {'gate_idx': x}
self.outputs = {'Out': count(x, expert_num)}
self.attrs = {"upper_range": expert_num}
x = np.random.randint(-1, upper_num, size=(1000, 2)).astype('int64')
self.inputs = {'numbers': x}
self.outputs = {'Out': count(x, upper_num)}
self.attrs = {"upper_range": upper_num}
def test_forward(self):
self.check_output_with_place(paddle.CUDAPlace(0))
......@@ -51,19 +51,19 @@ class TestExpertCountOpInt64(op_test.OpTest):
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestExpertCountAPI(unittest.TestCase):
class TestNumberCountAPI(unittest.TestCase):
def setUp(self):
self.upper_range = 320
self.upper_num = 320
self.x = np.random.randint(
-1, self.upper_range, size=(6000, 200)).astype('int64')
self.out = count(self.x, self.upper_range)
-1, self.upper_num, size=(6000, 200)).astype('int64')
self.out = count(self.x, self.upper_num)
self.place = paddle.CUDAPlace(0)
def test_api_static(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('x', self.x.shape, dtype="int64")
out = utils._number_count(x, self.upper_range)
out = utils._number_count(x, self.upper_num)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'x': self.x}, fetch_list=[out])
assert np.allclose(res, self.out)
......@@ -71,7 +71,7 @@ class TestExpertCountAPI(unittest.TestCase):
def test_api_dygraph(self):
paddle.disable_static()
x = paddle.to_tensor(self.x)
out = utils._number_count(x, self.upper_range)
out = utils._number_count(x, self.upper_num)
assert np.allclose(out.numpy(), self.out)
......
......@@ -1651,16 +1651,21 @@ def class_center_sample(label, num_classes, num_samples, group=None):
.. hint::
If the number of the positive class centers is greater than the input num_samples, it keeps all the positive
class centers and the shape of sampled_class_center will be [num_positive_class_centers].
The API supports CPU, single GPU and multi GPU.
For data parallel mode, set ``group=False``.
For model parallel mode, set ``group=None`` or the group instance returned by paddle.distributed.new_group.
Args:
label (Tensor): 1-D tensor with shape [N], each label in [0, num_classes)
num_classes (int): A positive integer to specify the number of classes at local rank.
Note that num_classes of each GPU can be different.
num_samples (int): A positive integer to specify the number of class center to sample.
group (Group, optional): The abstract representation of group.
See paddle.distributed.collective.Group. Default is ``None``.
group (Group, optional): The group instance returned by paddle.distributed.new_group
or ``None`` for the global default group, or ``False`` for data parallel (no communication across ranks).
Default is ``None``.
Returns:
Tuple of two ``Tensor`` : (remapped_label, sampled_class_center), remapped label using sampled class center,
......@@ -1733,18 +1738,25 @@ def class_center_sample(label, num_classes, num_samples, group=None):
#Tensor(shape=[7], dtype=int64, place=CUDAPlace(1), stop_gradient=True,
# [0, 1, 2, 3, 5, 7, 8])
"""
if group is not None and not group.is_member():
if not (group == False or group is None or hasattr(group, 'is_member')):
raise ValueError(
'Expected group is False, None or instance of paddle.distributed.collective.Group \
(got group: {})'.format(group))
return
if hasattr(group, 'is_member') and not group.is_member():
return
ring_id = 0 if group is None else group.id
ring_id = 0
rank = 0
nranks = 1
if core.is_compiled_with_dist():
parallel_env = paddle.distributed.ParallelEnv()
global_rank = parallel_env.rank
rank = global_rank if group is None else group.get_group_rank(
global_rank)
nranks = parallel_env.world_size if group is None else group.nranks
if group != False:
if core.is_compiled_with_dist():
parallel_env = paddle.distributed.ParallelEnv()
global_rank = parallel_env.rank
rank = global_rank if group is None else group.get_group_rank(
global_rank)
nranks = parallel_env.world_size if group is None else group.nranks
if num_samples > num_classes:
raise ValueError(
......
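A minimal usage sketch for the data-parallel case described in the class_center_sample hint above, assuming a single process where ``group=False`` skips the cross-rank communication (sizes below are arbitrary):
.. code-block:: python
import paddle
num_classes, num_samples = 20, 6
# arbitrary labels in [0, num_classes)
label = paddle.randint(low=0, high=num_classes, shape=[10], dtype='int64')
# group=False: sample class centers locally, without communicating across ranks
remapped_label, sampled_class_center = paddle.nn.functional.class_center_sample(
    label, num_classes, num_samples, group=False)
print(remapped_label.shape)    # [10], labels remapped into the sampled centers
print(sampled_class_center)    # indices of the sampled class centers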
......@@ -1119,14 +1119,19 @@ def margin_cross_entropy(logits,
r"""
.. math::
L=-\\frac{1}{N}\sum^N_{i=1}\log\\frac{e^{s(cos(m_{1}\\theta_{y_i}+m_{2})-m_{3})}}{e^{s(cos(m_{1}\\theta_{y_i}+m_{2})-m_{3})}+\sum^n_{j=1,j\\neq y_i} e^{scos\\theta_{y_i}}}
L=-\frac{1}{N}\sum^N_{i=1}\log\frac{e^{s(cos(m_{1}\theta_{y_i}+m_{2})-m_{3})}}{e^{s(cos(m_{1}\theta_{y_i}+m_{2})-m_{3})}+\sum^n_{j=1,j\neq y_i} e^{scos\theta_{y_i}}}
where the :math:`\\theta_{y_i}` is the angle between the feature :math:`x` and
where the :math:`\theta_{y_i}` is the angle between the feature :math:`x` and
the representation of class :math:`i`. The details of ArcFace loss
could be referred to https://arxiv.org/abs/1801.07698.
.. hint::
The API supports model parallel and single GPU. And logits.shape[-1] can be different at each rank.
The API supports single GPU and multi GPU, and does not support CPU.
For data parallel mode, set ``group=False``.
For model parallel mode, set ``group=None`` or the group instance returned by paddle.distributed.new_group.
And logits.shape[-1] can be different at each rank.
Args:
logits (Tensor): shape[N, local_num_classes], the output of the normalized X multiply the normalized W.
......@@ -1136,8 +1141,9 @@ def margin_cross_entropy(logits,
margin2 (float, optional): m2 of margin loss, default value is `0.5`.
margin3 (float, optional): m3 of margin loss, default value is `0.0`.
scale (float, optional): s of margin loss, default value is `64.0`.
group (Group, optional): The abstract representation of group, see paddle.distributed.collective.Group.
Default `None`.
group (Group, optional): The group instance returned by paddle.distributed.new_group
or ``None`` for the global default group, or ``False`` for data parallel (no communication across ranks).
Default is ``None``.
return_softmax (bool, optional): Whether return softmax probability. Default value is `False`.
reduction (str, optional): The candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'mean'``, return the average of loss;
......@@ -1296,24 +1302,32 @@ def margin_cross_entropy(logits,
"""
assert reduction in ['mean', 'sum', 'none', None]
if group is not None and not group.is_member():
if not (group == False or group is None or hasattr(group, 'is_member')):
raise ValueError(
'Expected group is False, None or instance of paddle.distributed.collective.Group \
(got group: {})'.format(group))
return
ring_id = 0 if group is None else group.id
if hasattr(group, 'is_member') and not group.is_member():
return
ring_id = 0
rank = 0
nranks = 1
if core.is_compiled_with_dist():
parallel_env = paddle.distributed.ParallelEnv()
global_rank = parallel_env.rank
rank = global_rank if group is None else group.get_group_rank(
global_rank)
nranks = parallel_env.world_size if group is None else group.nranks
if group != False:
ring_id = 0 if group is None else group.id
if core.is_compiled_with_dist():
parallel_env = paddle.distributed.ParallelEnv()
global_rank = parallel_env.rank
rank = global_rank if group is None else group.get_group_rank(
global_rank)
nranks = parallel_env.world_size if group is None else group.nranks
input_dims = len(list(logits.shape))
label_dims = len(list(label.shape))
if input_dims - 1 != label_dims and input_dims != label_dims:
raise ValueError(
'Expected nput_dims - 1 = label_dims or input_dims == label_dims\
'Expected input_dims - 1 = label_dims or input_dims == label_dims\
(got nput_dims{}, label_dims{})'.format(input_dims, label_dims))
if input_dims - 1 == label_dims:
label = paddle.unsqueeze(label, axis=-1)
......
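A minimal usage sketch of the data-parallel case described in the margin_cross_entropy hint above, assuming a single process with a CUDA device and ``group=False`` (shapes and margins are arbitrary; the op has no CPU kernel):
.. code-block:: python
import paddle
import paddle.nn.functional as F
paddle.set_device('gpu')
batch_size, num_class = 4, 10
logits = paddle.rand([batch_size, num_class])
label = paddle.randint(low=0, high=num_class, shape=[batch_size], dtype='int64')
# group=False: compute the margin softmax cross entropy locally, no cross-rank communication
loss, softmax = F.margin_cross_entropy(
    logits, label,
    margin1=1.0, margin2=0.5, margin3=0.0, scale=64.0,
    group=False, return_softmax=True, reduction=None)
print(loss.shape)      # per-sample loss (reduction=None)
print(softmax.shape)   # [batch_size, num_class]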
......@@ -1124,16 +1124,19 @@ class SimpleRNN(RNNBase):
Using key word arguments to construct is recommended.
Parameters:
input_size (int): The input size for the first layer's cell.
hidden_size (int): The hidden size for each layer's cell.
num_layers (int, optional): Number of layers. Defaults to 1.
input_size (int): The input size of :math:`x` for the first layer's cell.
hidden_size (int): The hidden size of :math:`h` for each layer's cell.
num_layers (int, optional): Number of recurrent layers. Defaults to 1.
direction (str, optional): The direction of the network. It can be "forward"
or "bidirect"(or "bidirectional"). When "bidirect", the way to merge
outputs of forward and backward is concatenating. Defaults to "forward".
time_major (bool, optional): Whether the first dimension of the input means the
time steps. Defaults to False.
dropout (float, optional): The droput probability. Dropout is applied to the
input of each layer except for the first layer. Defaults to 0.
time_major (bool, optional): Whether the first dimension of the input
means the time steps. If time_major is True, the shape of Tensor is
[time_steps, batch_size, input_size], otherwise [batch_size, time_steps, input_size].
Defaults to False. `time_steps` means the length of the input sequence.
dropout (float, optional): The dropout probability. Dropout is applied
to the input of each layer except for the first layer. The range of
dropout is from 0 to 1. Defaults to 0.
activation (str, optional): The activation in each SimpleRNN cell. It can be
`tanh` or `relu`. Defaults to `tanh`.
weight_ih_attr (ParamAttr, optional): The parameter attribute for
......@@ -1148,13 +1151,13 @@ class SimpleRNN(RNNBase):
None). For more information, please refer to :ref:`api_guide_Name`.
Inputs:
- **inputs** (Tensor): the input sequence. If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`, else, the shape is `[batch_size, time_steps, hidden_size]`.
- **inputs** (Tensor): the input sequence. If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`, else, the shape is `[batch_size, time_steps, hidden_size]`. `time_steps` means the length of the input sequence.
- **initial_states** (Tensor, optional): the initial state. The shape is `[num_layers * num_directions, batch_size, hidden_size]`. If initial_state is not given, zero initial states are used.
- **sequence_length** (Tensor, optional): shape `[batch_size]`, dtype: int64 or int32. The valid lengths of input sequences. Defaults to None. If `sequence_length` is not None, the inputs are treated as padded sequences. In each input sequence, elements whose time step index are not less than the valid length are treated as paddings.
Returns:
- **outputs** (Tensor): the output sequence. If `time_major` is True, the shape is `[time_steps, batch_size, num_directions * hidden_size]`, else, the shape is `[batch_size, time_steps, num_directions * hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" else 1.
- **outputs** (Tensor): the output sequence. If `time_major` is True, the shape is `[time_steps, batch_size, num_directions * hidden_size]`, else, the shape is `[batch_size, time_steps, num_directions * hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" else 1. `time_steps` means the length of the output sequence.
- **final_states** (Tensor): final states. The shape is `[num_layers * num_directions, batch_size, hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" (the index of forward states are 0, 2, 4, 6... and the index of backward states are 1, 3, 5, 7...), else 1.
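A small sketch of the shape convention spelled out in the SimpleRNN notes above, assuming the default ``time_major=False`` and a unidirectional stack:
.. code-block:: python
import paddle
# input_size=16, hidden_size=32, num_layers=2
rnn = paddle.nn.SimpleRNN(16, 32, 2)
# [batch_size, time_steps, input_size] because time_major defaults to False
x = paddle.randn([4, 10, 16])
y, h = rnn(x)
print(y.shape)   # [4, 10, 32]: [batch_size, time_steps, num_directions * hidden_size]
print(h.shape)   # [2, 4, 32]:  [num_layers * num_directions, batch_size, hidden_size]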
......@@ -1242,16 +1245,19 @@ class LSTM(RNNBase):
Using key word arguments to construct is recommended.
Parameters:
input_size (int): The input size for the first layer's cell.
hidden_size (int): The hidden size for each layer's cell.
num_layers (int, optional): Number of layers. Defaults to 1.
input_size (int): The input size of :math:`x` for the first layer's cell.
hidden_size (int): The hidden size of :math:`h` for each layer's cell.
num_layers (int, optional): Number of recurrent layers. Defaults to 1.
direction (str, optional): The direction of the network. It can be "forward"
or "bidirect"(or "bidirectional"). When "bidirect", the way to merge
outputs of forward and backward is concatenating. Defaults to "forward".
time_major (bool, optional): Whether the first dimension of the input
means the time steps. Defaults to False.
means the time steps. If time_major is True, the shape of the Tensor is
[time_steps, batch_size, input_size], otherwise [batch_size, time_steps, input_size].
Defaults to False. `time_steps` means the length of the input sequence.
dropout (float, optional): The dropout probability. Dropout is applied
to the input of each layer except for the first layer. Defaults to 0.
to the input of each layer except for the first layer. The range of
dropout is from 0 to 1. Defaults to 0.
weight_ih_attr (ParamAttr, optional): The parameter attribute for
`weight_ih` of each cell. Default: None.
weight_hh_attr (ParamAttr, optional): The parameter attribute for
......@@ -1264,13 +1270,13 @@ class LSTM(RNNBase):
None). For more information, please refer to :ref:`api_guide_Name`.
Inputs:
- **inputs** (Tensor): the input sequence. If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`, else, the shape is `[batch_size, time_steps, input_size]`.
- **inputs** (Tensor): the input sequence. If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`, else, the shape is `[batch_size, time_steps, input_size]`. `time_steps` means the length of the input sequence.
- **initial_states** (list|tuple, optional): the initial state, a list/tuple of (h, c), the shape of each is `[num_layers * num_directions, batch_size, hidden_size]`. If initial_state is not given, zero initial states are used.
- **sequence_length** (Tensor, optional): shape `[batch_size]`, dtype: int64 or int32. The valid lengths of input sequences. Defaults to None. If `sequence_length` is not None, the inputs are treated as padded sequences. In each input sequence, elements whose time step index is not less than the valid length are treated as paddings.
Returns:
- **outputs** (Tensor): the output sequence. If `time_major` is True, the shape is `[time_steps, batch_size, num_directions * hidden_size]`; if `time_major` is False, the shape is `[batch_size, time_steps, num_directions * hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" else 1.
- **outputs** (Tensor): the output sequence. If `time_major` is True, the shape is `[time_steps, batch_size, num_directions * hidden_size]`; if `time_major` is False, the shape is `[batch_size, time_steps, num_directions * hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" else 1. `time_steps` means the length of the output sequence.
- **final_states** (tuple): the final state, a tuple of two tensors, h and c. The shape of each is `[num_layers * num_directions, batch_size, hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" (the index of forward states are 0, 2, 4, 6... and the index of backward states are 1, 3, 5, 7...), else 1.
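A minimal usage sketch of the LSTM interface documented above; it assumes the positional constructor order `(input_size, hidden_size, num_layers)` and passes `initial_states` as the `(h, c)` tuple described in Inputs:
.. code-block:: python

    import paddle

    # input_size=16, hidden_size=32, num_layers=2
    rnn = paddle.nn.LSTM(16, 32, 2)

    x = paddle.randn((4, 23, 16))       # [batch_size, time_steps, input_size]
    prev_h = paddle.randn((2, 4, 32))   # [num_layers * num_directions, batch_size, hidden_size]
    prev_c = paddle.randn((2, 4, 32))

    y, (h, c) = rnn(x, (prev_h, prev_c))
    # outputs y: [4, 23, 32]; final_states h and c: [2, 4, 32] each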
......@@ -1349,16 +1355,19 @@ class GRU(RNNBase):
Using keyword arguments to construct is recommended.
Parameters:
input_size (int): The input size for the first layer's cell.
hidden_size (int): The hidden size for each layer's cell.
num_layers (int, optional): Number of layers. Defaults to 1.
input_size (int): The input size of :math:`x` for the first layer's cell.
hidden_size (int): The hidden size of :math:`h` for each layer's cell.
num_layers (int, optional): Number of recurrent layers. Defaults to 1.
direction (str, optional): The direction of the network. It can be "forward"
or "bidirect"(or "bidirectional"). When "bidirect", the way to merge
outputs of forward and backward is concatenating. Defaults to "forward".
time_major (bool, optional): Whether the first dimension of the input
means the time steps. Defaults to False.
means the time steps. If time_major is True, the shape of the Tensor is
[time_steps, batch_size, input_size], otherwise [batch_size, time_steps, input_size].
Defaults to False. `time_steps` means the length of the input sequence.
dropout (float, optional): The dropout probability. Dropout is applied
to the input of each layer except for the first layer. Defaults to 0.
to the input of each layer except for the first layer. The range of
dropout is from 0 to 1. Defaults to 0.
weight_ih_attr (ParamAttr, optional): The parameter attribute for
`weight_ih` of each cell. Default: None.
weight_hh_attr (ParamAttr, optional): The parameter attribute for
......@@ -1371,13 +1380,13 @@ class GRU(RNNBase):
None). For more information, please refer to :ref:`api_guide_Name`.
Inputs:
- **inputs** (Tensor): the input sequence. If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`, else, the shape is `[batch_size, time_steps, input_size]`.
- **inputs** (Tensor): the input sequence. If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`, else, the shape is `[batch_size, time_steps, input_size]`. `time_steps` means the length of the input sequence.
- **initial_states** (Tensor, optional): the initial state. The shape is `[num_layers * num_directions, batch_size, hidden_size]`. If initial_state is not given, zero initial states are used. Defaults to None.
- **sequence_length** (Tensor, optional): shape `[batch_size]`, dtype: int64 or int32. The valid lengths of input sequences. Defaults to None. If `sequence_length` is not None, the inputs are treated as padded sequences. In each input sequence, elements whose time step index is not less than the valid length are treated as paddings.
Returns:
- **outputs** (Tensor): the output sequence. If `time_major` is True, the shape is `[time_steps, batch_size, num_directions * hidden_size]`, else, the shape is `[batch_size, time_steps, num_directions * hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" else 1.
- **outputs** (Tensor): the output sequence. If `time_major` is True, the shape is `[time_steps, batch_size, num_directions * hidden_size]`, else, the shape is `[batch_size, time_steps, num_directions * hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" else 1. `time_steps` means the length of the output sequence.
- **final_states** (Tensor): final states. The shape is `[num_layers * num_directions, batch_size, hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" (the index of forward states are 0, 2, 4, 6... and the index of backward states are 1, 3, 5, 7...), else 1.
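A minimal usage sketch of the GRU interface documented above, again assuming the positional constructor order `(input_size, hidden_size, num_layers)` and the default `time_major=False`:
.. code-block:: python

    import paddle

    # input_size=16, hidden_size=32, num_layers=2
    rnn = paddle.nn.GRU(16, 32, 2)

    x = paddle.randn((4, 23, 16))       # [batch_size, time_steps, input_size]
    prev_h = paddle.randn((2, 4, 32))   # [num_layers * num_directions, batch_size, hidden_size]

    y, h = rnn(x, prev_h)
    # outputs y: [4, 23, 32], final_states h: [2, 4, 32]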
......
......@@ -55,8 +55,10 @@ class WMT14(Dataset):
:attr:`data_file` is not set. Default True
Returns:
Dataset: instance of WMT14 dataset
Dataset: Instance of WMT14 dataset. The instance of dataset has 3 fields:
- src_ids (np.array) - The sequence of token ids of source language.
- trg_ids (np.array) - The sequence of token ids of target language.
- trg_ids_next (np.array) - The next sequence of token ids of target language.
Examples:
.. code-block:: python
......
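To make the three returned fields concrete, a short sketch of consuming one sample; the `mode` and `dict_size` keyword arguments are assumed from the class's parameter list rather than shown in this excerpt:
.. code-block:: python

    from paddle.text.datasets import WMT14

    # mode and dict_size are assumed constructor arguments
    wmt14 = WMT14(mode='train', dict_size=50)

    # each sample carries the three fields listed under Returns
    src_ids, trg_ids, trg_ids_next = wmt14[0]
    print(src_ids.shape, trg_ids.shape, trg_ids_next.shape)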
......@@ -71,7 +71,10 @@ class WMT16(Dataset):
:attr:`data_file` is not set. Default True
Returns:
Dataset: instance of WMT16 dataset
Dataset: Instance of WMT16 dataset. The instance of dataset has 3 fields:
- src_ids (np.array) - The sequence of token ids of source language.
- trg_ids (np.array) - The sequence of token ids of target language.
- trg_ids_next (np.array) - The next sequence of token ids of target language.
Examples:
......
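Similarly for WMT16, a short sketch of reading the three fields; the `src_dict_size` and `trg_dict_size` keyword arguments are assumed and not shown in this excerpt:
.. code-block:: python

    from paddle.text.datasets import WMT16

    # src_dict_size and trg_dict_size are assumed constructor arguments
    wmt16 = WMT16(mode='train', src_dict_size=50, trg_dict_size=50)

    src_ids, trg_ids, trg_ids_next = wmt16[0]
    print(src_ids.shape, trg_ids.shape, trg_ids_next.shape)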