Unverified commit 287ca7d5, authored by Zeng Jinle, committed by GitHub

MLPerf Optimization for Release/2.2 (#37109)

* add mlperf optimization PRs

* update
Parent 70cb0a54
......@@ -218,7 +218,7 @@ function(op_library TARGET)
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
"sync_batch_norm_op" "sparse_attention_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
"skip_layernorm_op" "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op" "fusion_lstm_op"
"fused_bn_add_activation_op" "fused_attention_op" "fused_feedforward_op")
"fused_bn_add_activation_op" "fused_attention_op" "fused_feedforward_op" "resnet_unit_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()
......
......@@ -143,6 +143,8 @@ struct BuildStrategy {
// Turn off inplace addto by default.
bool enable_addto_{false};
bool allow_cuda_graph_capture_{false};
// FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
// num_trainers is 1, so the current fields of build_strategy doesn't tell if
// it's distributed model.
......
......@@ -130,10 +130,12 @@ FetchResultType FastThreadedSSAGraphExecutor::Run(
}
}
// Wait FetchOps.
- ClearFetchOp(graph_, &fetch_ops);
- for (auto &place : places_) {
- fetch_ctxs_.Get(place)->Wait();
- }
+ if (!fetch_ops.empty()) {
+ ClearFetchOp(graph_, &fetch_ops);
+ for (auto &place : places_) {
+ fetch_ctxs_.Get(place)->Wait();
+ }
+ }
return fetches;
......
......@@ -86,19 +86,28 @@ struct ScaleLossGradFunctor {
}
};
+ std::string ScaleLossGradOpHandle::LossGradName() const {
+ return static_cast<VarHandle *>(this->outputs_[0])->name();
+ }
void ScaleLossGradOpHandle::RunImpl() {
platform::RecordEvent record_event(Name());
// Doesn't wait any event
- std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name();
- auto *tensor =
- local_exec_scopes_[0]->FindVar(var_name)->GetMutable<LoDTensor>();
+ RunOnVar(local_exec_scopes_[0]->FindVar(LossGradName()), true);
+ }
+ void ScaleLossGradOpHandle::RunOnVar(Variable *var, bool record_event) {
+ auto *tensor = var->GetMutable<LoDTensor>();
tensor->Resize(make_ddim({1}));
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
ScaleLossGradFunctor func(coeff_, tensor, place_, out_dtype_,
this->dev_ctxes_.at(place_));
- this->RunAndRecordEvent([&] { framework::VisitDataType(out_dtype_, func); });
+ if (record_event) {
+ this->RunAndRecordEvent(
+ [&] { framework::VisitDataType(out_dtype_, func); });
+ } else {
+ framework::VisitDataType(out_dtype_, func);
+ }
#else
ScaleLossGradFunctor func(coeff_, tensor, place_, out_dtype_, nullptr);
framework::VisitDataType(out_dtype_, func);
......
......@@ -46,6 +46,12 @@ struct ScaleLossGradOpHandle : public OpHandleBase {
std::string Name() const override;
platform::Place GetPlace() const { return place_; }
void RunOnVar(Variable *var, bool record_event = false);
std::string LossGradName() const;
protected:
void RunImpl() override;
......
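Note: the RunOnVar()/LossGradName() interface added above is what later allows the executor to fill the loss-gradient scalar once, before CUDA Graph capture begins, so that no host-to-device fill gets recorded into the graph (see Step 5 of PrepareForCUDAGraphCapture in parallel_executor.cc further down). A minimal standalone sketch of that initialize-once idea, using only the CUDA runtime API; the function name here is illustrative, not Paddle's:

#include <cuda_runtime.h>

// Write a host constant (e.g. the loss-gradient coefficient) into a
// 1-element device tensor once, outside of any captured region. Kernels
// captured afterwards only read the device pointer, so replaying the graph
// never re-issues this host-to-device copy.
void PrefillDeviceScalar(float *dev_scalar, float value, cudaStream_t stream) {
  cudaMemcpyAsync(dev_scalar, &value, sizeof(float), cudaMemcpyHostToDevice,
                  stream);
  cudaStreamSynchronize(stream);  // ensure the fill lands before capture starts
}

This mirrors how PrepareForCUDAGraphCapture calls RunOnVar() once on the loss-gradient variable and then marks the handle with SetSkipRunning(true).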
......@@ -22,7 +22,9 @@
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace framework {
namespace details {
......@@ -49,8 +51,29 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
PrepareLocalExeScopes();
}
static void RunProgramDescs(const ProgramDescs &programs,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places) {
for (auto &program : programs) {
for (auto &op_desc : program.Block(0).AllOps()) {
for (size_t i = 0; i < local_exec_scopes.size(); ++i) {
auto op = OpRegistry::CreateOp(*op_desc);
op->Run(*local_exec_scopes[i], places[i]);
}
}
}
}
FetchResultType ScopeBufferedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors, bool return_merged) {
#ifdef PADDLE_WITH_CUDA
if (platform::IsCUDAGraphCapturing()) {
strategy_.num_iteration_per_drop_scope_ =
std::numeric_limits<size_t>::max();
DropLocalExeScopes(/*need_wait=*/false);
}
#endif
if (drop_scope_counter_ == 0) {
platform::RecordEvent e("InitLocalVars");
InitVariables();
......@@ -84,7 +107,7 @@ FetchResultType ScopeBufferedSSAGraphExecutor::Run(
++drop_scope_counter_;
if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_ ||
DropScopeOrNot()) {
- DropLocalExeScopes();
+ DropLocalExeScopes(!platform::IsCUDAGraphCapturing());
}
if (VLOG_IS_ON(5)) {
......@@ -128,15 +151,7 @@ void ScopeBufferedSSAGraphExecutor::InitVariables() {
if (graph.Has(details::kStartupProgramDescs)) {
auto &program_descs =
graph.Get<details::ProgramDescs>(details::kStartupProgramDescs);
- for (auto &program_desc : program_descs) {
- for (auto &op_desc : program_desc.Block(0).AllOps()) {
- for (size_t i = 0; i < local_exec_scopes_.size(); ++i) {
- auto op = OpRegistry::CreateOp(*op_desc);
- op->Run(*local_exec_scopes_[i], places_[i]);
- }
- }
- }
+ RunProgramDescs(program_descs, local_exec_scopes_, places_);
}
is_initialized_ = true;
}
......@@ -144,23 +159,17 @@ void ScopeBufferedSSAGraphExecutor::InitVariables() {
if (graph.Has(details::kProgramDescs)) {
auto &program_descs =
graph.Get<details::ProgramDescs>(details::kProgramDescs);
- for (auto &program_desc : program_descs) {
- for (auto &op_desc : program_desc.Block(0).AllOps()) {
- for (size_t i = 0; i < local_exec_scopes_.size(); ++i) {
- auto op = OpRegistry::CreateOp(*op_desc);
- op->Run(*local_exec_scopes_[i], places_[i]);
- }
- }
- }
+ RunProgramDescs(program_descs, local_exec_scopes_, places_);
}
}
- void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes() {
+ void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes(bool need_wait) {
platform::RecordEvent drop_scope_event("DropLocalExeScopes");
drop_scope_counter_ = 0;
- for (auto &p : places_) {
- platform::DeviceContextPool::Instance().Get(p)->Wait();
+ if (need_wait) {
+ for (auto &p : places_) {
+ platform::DeviceContextPool::Instance().Get(p)->Wait();
+ }
}
scope_monitor_.ClearHistoryLocalExecScopes();
for (size_t i = 0; i < local_exec_scopes_.size(); ++i) {
......
......@@ -53,7 +53,7 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
FetchResultType Run(const std::vector<std::string>& fetch_tensors,
bool return_merged) override;
- void DropLocalExeScopes();
+ void DropLocalExeScopes(bool need_wait = true);
bool NeedCreateLocalExeScope();
......
......@@ -115,6 +115,7 @@ message BuildStrategy {
optional bool enable_auto_fusion = 11 [ default = false ];
optional bool enable_addto = 12 [ default = false ];
optional bool fix_op_run_order = 13 [ default = false ];
optional bool allow_cuda_graph_capture = 14 [ default = false ];
}
message ExecutionStrategy {
......
......@@ -179,7 +179,8 @@ void InplaceAddToOpPass::Run(Graph *graph) const {
out_var_ptr->GeneratedOp());
// NOTE(zhiqiu): currently, only conv2d_grad supports addto strategy
- if (right_generated_op->Name() != "conv2d_grad") {
+ if (right_generated_op->Name() != "conv2d_grad" &&
+ right_generated_op->Name() != "resnet_unit_grad") {
continue;
}
......@@ -224,11 +225,13 @@ static bool IsValidConv2DGradDataGradNode(const Node &node) {
if (node.inputs.empty()) return false;
auto *generated_op = node.inputs[0];
auto *op_desc = generated_op->Op();
- if (op_desc == nullptr || op_desc->Type() != "conv2d_grad") {
+ if (op_desc == nullptr || (op_desc->Type() != "conv2d_grad" &&
+ op_desc->Type() != "resnet_unit_grad")) {
return false;
}
const auto &outputs = op_desc->Outputs();
- auto iter = outputs.find(GradVarName("Input"));
+ std::string grad_var_name = op_desc->Type() == "conv2d_grad" ? "Input" : "X";
+ auto iter = outputs.find(GradVarName(grad_var_name));
return iter != outputs.end() && !iter->second.empty() &&
iter->second[0] == node.Name() &&
!op_desc->GetAttrIfExists<bool>("use_addto");
......
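Note: the pass change above lets resnet_unit_grad, like conv2d_grad, write its input gradient directly on top of an existing gradient buffer (the use_addto attribute) instead of producing a temporary tensor and summing it in afterwards. A small sketch of the difference, with made-up names and a plain float buffer standing in for a tensor:

#include <vector>

// Hypothetical gradient write-back: either overwrite dx or accumulate into it.
void WriteBackGrad(const std::vector<float> &partial, std::vector<float> *dx,
                   bool use_addto) {
  for (size_t i = 0; i < partial.size(); ++i) {
    if (use_addto) {
      (*dx)[i] += partial[i];  // addto: in-place accumulation, no extra buffer
    } else {
      (*dx)[i] = partial[i];   // otherwise a separate sum op is needed later
    }
  }
}

Since resnet_unit_grad names its input gradient X@GRAD rather than Input@GRAD, the pass also has to look up the right output slot, which is what the grad_var_name branch above does.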
- cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
+ cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle op_graph_view multi_devices_helper)
cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper)
cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_helper)
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h"
......@@ -21,14 +22,23 @@ namespace paddle {
namespace framework {
namespace ir {
template <typename T>
static bool IsMatchedPlaceSingleDeviceOp(details::OpHandleBase *op_base,
const platform::Place &place) {
auto *op = dynamic_cast<T *>(op_base);
return op && op->GetPlace() == place;
}
static bool IsLockAndRecordEventFreeComputationOpHandle(
details::ComputationOpHandle *op, const OpGraphView &graph_view) {
if (!platform::is_gpu_place(op->GetPlace()) &&
!platform::is_xpu_place(op->GetPlace()))
return false;
for (auto &pending_op : graph_view.PendingOps(op)) {
- auto *tmp = dynamic_cast<details::ComputationOpHandle *>(pending_op);
- if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) {
+ if (!IsMatchedPlaceSingleDeviceOp<details::ComputationOpHandle>(
+ pending_op, op->GetPlace()) &&
+ !IsMatchedPlaceSingleDeviceOp<details::ScaleLossGradOpHandle>(
+ pending_op, op->GetPlace())) {
return false;
}
}
......
......@@ -15,8 +15,10 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <mutex>
#include <unordered_map>
#include <vector>
#include "glog/logging.h"
namespace paddle {
namespace framework {
......
......@@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
......@@ -34,6 +35,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/set_reader_device_info_utils.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
#include "paddle/fluid/platform/event.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -43,6 +45,10 @@ limitations under the License. */
DECLARE_double(eager_delete_tensor_gb);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
DECLARE_bool(sync_nccl_allreduce);
#endif
#ifdef WITH_GPERFTOOLS
#include "gperftools/profiler.h"
#endif
......@@ -669,6 +675,7 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
// ncclOp
std::vector<ir::Graph *> async_graphs =
CompileGraphWithBuildStrategy(graph, &graphs, loss_var_name);
PrepareForCUDAGraphCapture(graph);
graph = member_->ApplyMemoryOptimizePass(graph);
async_graphs[0] = graph;
......@@ -882,6 +889,23 @@ void ParallelExecutor::BCastParamsToDevices(
FetchResultType ParallelExecutor::Run(
const std::vector<std::string> &fetch_tensors, bool return_merged) {
VLOG(3) << "enter ParallelExecutor Run";
#ifdef PADDLE_WITH_CUDA
if (platform::IsCUDAGraphCapturing()) {
PADDLE_ENFORCE_EQ(fetch_tensors.empty(), true,
platform::errors::InvalidArgument(
"Cannot fetch data when using CUDA Graph."));
PADDLE_ENFORCE_EQ(
member_->build_strategy_.allow_cuda_graph_capture_, true,
platform::errors::InvalidArgument(
"You must turn on build_strategy.allow_cuda_graph_capture = True "
"to enable CUDA Graph capturing."));
PADDLE_ENFORCE_EQ(
member_->places_[0], platform::CUDAGraphCapturingPlace(),
platform::errors::InvalidArgument("The place to capture CUDAGraph is "
"not the same as the place to run."));
}
#endif
#ifdef WITH_GPERFTOOLS
if (gProfileStarted) {
ProfilerFlush();
......@@ -932,6 +956,16 @@ void ParallelExecutor::SkipMemoryReuse(
void ParallelExecutor::FeedTensorsIntoLocalScopes(
const std::vector<std::unordered_map<std::string, LoDTensor>> &tensors) {
if (platform::IsCUDAGraphCapturing()) {
for (auto &tensor : tensors) {
PADDLE_ENFORCE_EQ(
tensor.empty(), true,
platform::errors::PermissionDenied(
"Feeding data is not permitted when capturing CUDA Graph."));
}
return;
}
if (!member_->AllowPartialFeed()) {
PADDLE_ENFORCE_EQ(tensors.size(), member_->local_scopes_.size(),
platform::errors::Unimplemented(
......@@ -987,6 +1021,14 @@ void ParallelExecutor::FeedTensorsIntoLocalScopes(
void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
const std::unordered_map<std::string, LoDTensor> &tensors) {
if (platform::IsCUDAGraphCapturing()) {
PADDLE_ENFORCE_EQ(
tensors.empty(), true,
platform::errors::PermissionDenied(
"Feeding data is not permitted when capturing CUDA Graph."));
return;
}
size_t num_places = member_->places_.size();
bool allow_partial_feed = member_->AllowPartialFeed();
......@@ -1568,6 +1610,107 @@ const ir::Graph &ParallelExecutor::Graph() const {
return member_->executor_->Graph();
}
void ParallelExecutor::PrepareForCUDAGraphCapture(ir::Graph *graph) {
const auto &build_strategy = member_->build_strategy_;
if (!build_strategy.allow_cuda_graph_capture_) return;
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_EQ(
build_strategy.async_mode_, false,
platform::errors::InvalidArgument(
"Async Executor does not support CUDA Graph capturing."));
PADDLE_ENFORCE_EQ(
platform::IsCUDAGraphCapturing(), false,
platform::errors::PermissionDenied("CUDA Graph is not allowed to capture "
"when running the first batch."));
PADDLE_ENFORCE_EQ(
member_->places_.size(), 1,
platform::errors::InvalidArgument(
"CUDA Graph is only supported when one GPU device is running."));
PADDLE_ENFORCE_EQ(platform::is_gpu_place(member_->places_[0]), true,
platform::errors::InvalidArgument(
"CUDA Graph is only supported on NVIDIA GPU device."));
PADDLE_ENFORCE_EQ(FLAGS_sync_nccl_allreduce, false,
platform::errors::InvalidArgument(
"FLAGS_sync_nccl_allreduce must be False to support "
"CUDA Graph capturing."));
std::unordered_map<std::string, std::vector<VarDesc *>> all_vars;
for (auto &node : graph->Nodes()) {
if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
auto *var_desc = node->Var();
all_vars[var_desc->Name()].emplace_back(var_desc);
}
}
auto mark_var_as_persistable = [&all_vars](const std::string &name) {
auto iter = all_vars.find(name);
if (iter != all_vars.end()) {
for (auto *var_desc : iter->second) {
var_desc->SetPersistable(true);
}
}
};
// Step 1: All fused vars must be persistable.
if (graph->Has(details::kFusedVars)) {
auto &fused_vars = graph->Get<details::FusedVars>(details::kFusedVars);
for (auto &fused_var : fused_vars) {
fused_var.second.persistable_ = true;
mark_var_as_persistable(fused_var.first);
}
}
// Step 2: All pinned vars must be persistable.
if (graph->Has(details::kPinnedVars)) {
auto &pinned_vars = graph->Get<details::PinnedVars>(details::kPinnedVars);
for (auto &pinned_var : pinned_vars) {
mark_var_as_persistable(pinned_var);
}
}
// Step 3: Move all main programs to startup programs to make sure that
// the main programs would only be run once.
if (graph->Has(details::kProgramDescs)) {
auto &startup_programs =
graph->GetOrInit<details::ProgramDescs>(details::kStartupProgramDescs);
auto &main_programs =
graph->Get<details::ProgramDescs>(details::kProgramDescs);
for (auto &main_program : main_programs) {
startup_programs.emplace_back(main_program);
}
graph->Erase(details::kProgramDescs);
}
// Step 4: Mark all vars in startup programs to be persistable.
if (graph->Has(details::kStartupProgramDescs)) {
auto &startup_programs =
graph->GetOrInit<details::ProgramDescs>(details::kStartupProgramDescs);
for (auto &startup_program : startup_programs) {
for (auto &op_desc : startup_program.Block(0).AllOps()) {
for (auto &output : op_desc->OutputArgumentNames()) {
mark_var_as_persistable(output);
}
}
}
}
// Step 5: ScaleLossGrad must be run beforehand to avoid H2D copy.
auto ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph);
auto *scope = member_->local_scopes_[0];
for (auto *op : ops) {
auto *loss_grad_op = dynamic_cast<details::ScaleLossGradOpHandle *>(op);
if (loss_grad_op == nullptr) continue;
auto loss_grad_name = loss_grad_op->LossGradName();
mark_var_as_persistable(loss_grad_name);
loss_grad_op->RunOnVar(scope->Var(loss_grad_name));
loss_grad_op->SetSkipRunning(true);
}
#else
PADDLE_THROW(platform::errors::Unimplemented(
"CUDA Graph is only supported on NVIDIA GPU device."));
#endif
}
} // namespace framework
} // namespace paddle
......
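Note: for context on what PrepareForCUDAGraphCapture sets up, a CUDA Graph records the work submitted to a stream between a begin/end capture pair and then replays it with a single launch; that is why feeds, fetches, scope drops and host-side fills all have to be kept out of the captured region. A minimal capture-and-replay sketch using the raw CUDA runtime API (not Paddle's wrappers in cuda_graph_with_memory_pool.h):

#include <cuda_runtime.h>

// Capture the work issued on `stream` once, then replay it many times.
// Only host-side CUDA runtime calls are used, so this compiles as plain C++.
void CaptureAndReplay(float *dev_data, size_t bytes, cudaStream_t stream) {
  cudaGraph_t graph;
  cudaGraphExec_t graph_exec;

  cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal);
  // Recorded into the graph instead of being executed immediately.
  cudaMemsetAsync(dev_data, 0, bytes, stream);
  cudaStreamEndCapture(stream, &graph);

  cudaGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0);
  for (int step = 0; step < 10; ++step) {
    cudaGraphLaunch(graph_exec, stream);  // replays the recorded work
  }
  cudaStreamSynchronize(stream);

  cudaGraphExecDestroy(graph_exec);
  cudaGraphDestroy(graph);
}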
......@@ -144,6 +144,8 @@ class ParallelExecutor {
void SetReaderOpDeviceInfoOfGraphs(
const std::vector<ir::Graph *> &final_graphs);
void PrepareForCUDAGraphCapture(ir::Graph *graph);
ParallelExecutorPrivate *member_;
std::vector<std::unique_ptr<ir::Graph>> async_graphs_;
std::vector<VariableInfo> var_infos_;
......
......@@ -82,7 +82,11 @@ endif()
cc_library(aligned_allocator SRCS aligned_allocator.cc DEPS allocator)
cc_test(test_aligned_allocator SRCS test_aligned_allocator.cc DEPS aligned_allocator)
cc_library(allocator_strategy SRCS allocator_strategy.cc DEPS gflags ${AllocatorFacadeDeps})
- cc_library(allocator_facade SRCS allocator_facade.cc DEPS allocator_strategy )
+ cc_library(allocator_facade SRCS allocator_facade.cc DEPS allocator_strategy)
if (WITH_GPU)
target_link_libraries(allocator_facade cuda_graph)
endif()
cc_test(retry_allocator_test SRCS retry_allocator_test.cc DEPS retry_allocator locked_allocator cpu_allocator)
if (WITH_TESTING)
......
......@@ -32,6 +32,9 @@
#include "paddle/fluid/memory/allocation/thread_local_allocator.h"
#include "paddle/fluid/platform/gpu_info.h"
#endif
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_graph.h"
#endif
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/platform/xpu/xpu_info.h"
#endif
......@@ -47,17 +50,64 @@ PADDLE_DEFINE_EXPORTED_bool(
"Whether to use system allocator to allocate CPU and GPU memory. "
"Only used for unittests.");
DECLARE_string(allocator_strategy);
namespace paddle {
namespace memory {
namespace allocation {
#ifdef PADDLE_WITH_CUDA
class CUDAGraphAllocator
: public Allocator,
public std::enable_shared_from_this<CUDAGraphAllocator> {
private:
class PrivateAllocation : public Allocation {
public:
PrivateAllocation(CUDAGraphAllocator* allocator,
AllocationPtr underlying_allocation)
: Allocation(underlying_allocation->ptr(),
underlying_allocation->size(),
underlying_allocation->place()),
allocator_(allocator->shared_from_this()),
underlying_allocation_(std::move(underlying_allocation)) {}
private:
std::shared_ptr<Allocator> allocator_;
AllocationPtr underlying_allocation_;
};
explicit CUDAGraphAllocator(const std::shared_ptr<Allocator>& allocator)
: underlying_allocator_(allocator) {}
public:
static std::shared_ptr<Allocator> Create(
const std::shared_ptr<Allocator>& allocator) {
return std::shared_ptr<Allocator>(new CUDAGraphAllocator(allocator));
}
protected:
Allocation* AllocateImpl(size_t size) {
VLOG(10) << "Allocate " << size << " for CUDA Graph";
return new PrivateAllocation(this, underlying_allocator_->Allocate(size));
}
void FreeImpl(Allocation* allocation) {
VLOG(10) << "delete for CUDA Graph";
delete allocation;
}
private:
std::shared_ptr<Allocator> underlying_allocator_;
};
#endif
class AllocatorFacadePrivate {
public:
using AllocatorMap = std::map<platform::Place, std::shared_ptr<Allocator>>;
- AllocatorFacadePrivate() {
- auto strategy = GetAllocatorStrategy();
- switch (strategy) {
+ explicit AllocatorFacadePrivate(bool allow_free_idle_chunk = true) {
+ strategy_ = GetAllocatorStrategy();
+ switch (strategy_) {
case AllocatorStrategy::kNaiveBestFit: {
InitNaiveBestFitCPUAllocator();
#ifdef PADDLE_WITH_XPU
......@@ -91,7 +141,8 @@ class AllocatorFacadePrivate {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount();
++dev_id) {
- InitAutoGrowthCUDAAllocator(platform::CUDAPlace(dev_id));
+ InitAutoGrowthCUDAAllocator(platform::CUDAPlace(dev_id),
+ allow_free_idle_chunk);
}
InitNaiveBestFitCUDAPinnedAllocator();
#endif
......@@ -117,7 +168,7 @@ class AllocatorFacadePrivate {
default: {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported allocator strategy: %d", static_cast<int>(strategy)));
"Unsupported allocator strategy: %d", static_cast<int>(strategy_)));
}
}
InitZeroSizeAllocators();
......@@ -130,11 +181,29 @@ class AllocatorFacadePrivate {
CheckAllocThreadSafe();
}
inline const AllocatorMap& GetAllocatorMap() {
#ifdef PADDLE_WITH_CUDA
if (UNLIKELY(platform::CUDAGraph::IsCapturing())) {
auto id = platform::CUDAGraph::CapturingID();
auto iter = cuda_graph_allocator_map_.find(id);
PADDLE_ENFORCE_NE(
iter, cuda_graph_allocator_map_.end(),
platform::errors::PermissionDenied(
"No memory pool is prepared for CUDA Graph capturing."));
return iter->second->allocators_;
} else {
return allocators_;
}
#else
return allocators_;
#endif
}
inline const std::shared_ptr<Allocator>& GetAllocator(
const platform::Place& place, size_t size) {
const auto& allocators =
(size > 0 ? (UNLIKELY(FLAGS_use_system_allocator) ? system_allocators_
- : allocators_)
+ : GetAllocatorMap())
: zero_size_allocators_);
auto iter = allocators.find(place);
PADDLE_ENFORCE_NE(iter, allocators.end(),
......@@ -145,6 +214,7 @@ class AllocatorFacadePrivate {
private:
void InitSystemAllocators() {
if (!system_allocators_.empty()) return;
system_allocators_[platform::CPUPlace()] = std::make_shared<CPUAllocator>();
#ifdef PADDLE_WITH_XPU
int device_count = platform::GetXPUDeviceCount();
......@@ -183,10 +253,11 @@ class AllocatorFacadePrivate {
allocators_[p] = std::make_shared<ThreadLocalCUDAAllocator>(p);
}
- void InitAutoGrowthCUDAAllocator(platform::CUDAPlace p) {
+ void InitAutoGrowthCUDAAllocator(platform::CUDAPlace p,
+ bool allow_free_idle_chunk) {
auto cuda_allocator = std::make_shared<CUDAAllocator>(p);
allocators_[p] = std::make_shared<AutoGrowthBestFitAllocator>(
- cuda_allocator, platform::GpuMinChunkSize());
+ cuda_allocator, platform::GpuMinChunkSize(), allow_free_idle_chunk);
}
#endif
......@@ -226,6 +297,7 @@ class AllocatorFacadePrivate {
};
void InitZeroSizeAllocators() {
if (!zero_size_allocators_.empty()) return;
std::vector<platform::Place> places;
places.emplace_back(platform::CPUPlace());
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
......@@ -279,12 +351,57 @@ class AllocatorFacadePrivate {
}
}
#ifdef PADDLE_WITH_CUDA
public:
void PrepareMemoryPoolForCUDAGraph(CUDAGraphID id) {
PADDLE_ENFORCE_EQ(strategy_, AllocatorStrategy::kAutoGrowth,
platform::errors::InvalidArgument(
"CUDA Graph is only supported when the "
"FLAGS_allocator_strategy=\"auto_growth\", but got "
"FLAGS_allocator_strategy=\"%s\"",
FLAGS_allocator_strategy));
auto& allocator = cuda_graph_allocator_map_[id];
PADDLE_ENFORCE_EQ(
allocator.get(), nullptr,
platform::errors::InvalidArgument(
"The memory pool of the CUDA Graph with ID %d have been prepared.",
id));
allocator.reset(
new AllocatorFacadePrivate(/*allow_free_idle_chunk=*/false));
for (auto& item : allocator->allocators_) {
auto& old_allocator = item.second;
old_allocator = CUDAGraphAllocator::Create(old_allocator);
}
VLOG(10) << "Prepare memory pool for CUDA Graph with ID " << id;
}
void RemoveMemoryPoolOfCUDAGraph(CUDAGraphID id) {
auto iter = cuda_graph_allocator_map_.find(id);
PADDLE_ENFORCE_NE(iter, cuda_graph_allocator_map_.end(),
platform::errors::InvalidArgument(
"Cannot find CUDA Graph with ID = %d", id));
cuda_graph_allocator_map_.erase(iter);
VLOG(10) << "Remove memory pool of CUDA Graph with ID " << id;
}
#endif
private:
AllocatorMap allocators_;
- AllocatorMap zero_size_allocators_;
- AllocatorMap system_allocators_;
+ #ifdef PADDLE_WITH_CUDA
+ std::unordered_map<CUDAGraphID, std::unique_ptr<AllocatorFacadePrivate>>
+ cuda_graph_allocator_map_;
+ #endif
+ AllocatorStrategy strategy_;
+ static AllocatorMap zero_size_allocators_;
+ static AllocatorMap system_allocators_;
};
AllocatorFacadePrivate::AllocatorMap
AllocatorFacadePrivate::zero_size_allocators_;
AllocatorFacadePrivate::AllocatorMap AllocatorFacadePrivate::system_allocators_;
// Pimpl. Make interface clean.
AllocatorFacade::AllocatorFacade() : m_(new AllocatorFacadePrivate()) {}
// delete m_ may cause core dump when the destructor of python in conflict with
......@@ -316,6 +433,16 @@ const std::shared_ptr<Allocator>& AllocatorFacade::GetAllocator(
return m_->GetAllocator(place, /* A non-zero num to choose allocator_ */ 1);
}
#ifdef PADDLE_WITH_CUDA
void AllocatorFacade::PrepareMemoryPoolForCUDAGraph(CUDAGraphID id) {
return m_->PrepareMemoryPoolForCUDAGraph(id);
}
void AllocatorFacade::RemoveMemoryPoolOfCUDAGraph(CUDAGraphID id) {
return m_->RemoveMemoryPoolOfCUDAGraph(id);
}
#endif
} // namespace allocation
} // namespace memory
} // namespace paddle
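Note: the allocator changes above do two things: the decorator CUDAGraphAllocator makes every allocation hold a shared_ptr back to its allocator, so memory handed out during capture cannot be returned to the regular pool while the graph is alive, and GetAllocatorMap() redirects allocation requests to a per-graph memory pool while capture is in progress. A condensed standalone sketch of the pool-selection part, with simplified types that are not Paddle's classes:

#include <cstdint>
#include <memory>
#include <unordered_map>

struct Pool {};  // stand-in for a per-place allocator map

class PoolRegistry {
 public:
  // Called when capture begins: create a dedicated pool for this graph ID.
  void PrepareForGraph(uint64_t graph_id) {
    pools_[graph_id] = std::make_unique<Pool>();
  }
  // Called when the graph is destroyed: drop the memory it retained.
  void RemoveForGraph(uint64_t graph_id) { pools_.erase(graph_id); }

  // Requests go to the graph's pool while capturing, otherwise to the default.
  Pool *Select(bool capturing, uint64_t graph_id) {
    if (capturing) return pools_.at(graph_id).get();
    return &default_pool_;
  }

 private:
  Pool default_pool_;
  std::unordered_map<uint64_t, std::unique_ptr<Pool>> pools_;
};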
......@@ -18,6 +18,9 @@
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/memory/allocation/npu_pinned_allocator.h"
#endif
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/gpu_info.h"
#endif
#include "paddle/fluid/platform/place.h"
namespace paddle {
......@@ -54,6 +57,11 @@ class AllocatorFacade {
uint64_t Release(const platform::Place& place);
const std::shared_ptr<Allocator>& GetAllocator(const platform::Place& place);
#ifdef PADDLE_WITH_CUDA
void PrepareMemoryPoolForCUDAGraph(CUDAGraphID id);
void RemoveMemoryPoolOfCUDAGraph(CUDAGraphID id);
#endif
// TODO(yy): Allocate a Copy-On-Write allocation?
private:
AllocatorFacade();
......
......@@ -39,11 +39,12 @@ namespace allocation {
AutoGrowthBestFitAllocator::AutoGrowthBestFitAllocator(
const std::shared_ptr<Allocator> &underlying_allocator, size_t alignment,
- size_t chunk_size)
+ size_t chunk_size, bool allow_free_idle_chunk)
: underlying_allocator_(
std::make_shared<AlignedAllocator>(underlying_allocator, alignment)),
alignment_(alignment),
- chunk_size_(std::max(AlignedSize(chunk_size, alignment), alignment)) {}
+ chunk_size_(std::max(AlignedSize(chunk_size, alignment), alignment)),
+ allow_free_idle_chunk_(allow_free_idle_chunk) {}
Allocation *AutoGrowthBestFitAllocator::AllocateImpl(size_t size) {
size = AlignedSize(size, alignment_);
......@@ -139,6 +140,9 @@ void AutoGrowthBestFitAllocator::FreeImpl(Allocation *allocation) {
}
uint64_t AutoGrowthBestFitAllocator::FreeIdleChunks() {
if (!allow_free_idle_chunk_) {
return 0;
}
uint64_t bytes = 0;
for (auto chunk_it = chunks_.begin(); chunk_it != chunks_.end();) {
auto &blocks = chunk_it->blocks_;
......
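Note: the new allow_free_idle_chunk switch is what the per-graph memory pools above are built with (allow_free_idle_chunk = false). The reason is that a replayed CUDA Graph reuses the exact device addresses that were recorded at capture time, so the auto-growth allocator must not hand idle chunks back to the driver between replays. A tiny sketch of the gating with a simplified chunk list, not the real allocator:

#include <cstddef>
#include <list>

struct Chunk {
  void *ptr = nullptr;
  size_t size = 0;
  bool in_use = false;
};

// Returns how many bytes were released. When retain_idle_chunks is true
// (the CUDA Graph case), nothing is freed and recorded addresses stay valid.
size_t FreeIdleChunks(std::list<Chunk> *chunks, bool retain_idle_chunks) {
  if (retain_idle_chunks) return 0;
  size_t freed = 0;
  for (auto it = chunks->begin(); it != chunks->end();) {
    if (!it->in_use) {
      freed += it->size;  // a real allocator would cudaFree(it->ptr) here
      it = chunks->erase(it);
    } else {
      ++it;
    }
  }
  return freed;
}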
......@@ -31,7 +31,7 @@ class AutoGrowthBestFitAllocator : public Allocator {
public:
AutoGrowthBestFitAllocator(
const std::shared_ptr<Allocator> &underlying_allocator, size_t alignment,
- size_t chunk_size = 0);
+ size_t chunk_size = 0, bool allow_free_idle_chunk = true);
bool IsAllocThreadSafe() const override { return true; }
......@@ -86,6 +86,7 @@ class AutoGrowthBestFitAllocator : public Allocator {
std::list<Chunk> chunks_;
size_t alignment_;
size_t chunk_size_;
bool allow_free_idle_chunk_;
SpinLock spinlock_;
};
......
......@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator_kernel_configs.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
#include "paddle/fluid/platform/cudnn_desc.h"
namespace paddle {
namespace operators {
......@@ -480,6 +481,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
static algo_t Find(const ConvArgs& args, bool exhaustive_search,
bool deterministic,
const framework::ExecutionContext& ctx) {
platform::CUDAGraphCaptureModeGuard guard;
auto dtype = platform::CudnnDataType<T>::type;
size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
size_t workspace_size = 0;
......@@ -601,6 +603,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
}
static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) {
platform::CUDAGraphCaptureModeGuard guard;
size_t workspace_size = 0;
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
......
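Note: the CUDAGraphCaptureModeGuard dropped into the cuDNN algorithm-search paths above presumably exists because algorithm search allocates and frees temporary workspace memory, which is not legal to record under the default capture mode. The usual way to tolerate such calls during capture is to switch the calling thread to relaxed capture mode around them; the RAII sketch below is my reconstruction of that idea under this assumption, not Paddle's actual implementation:

#include <cuda_runtime.h>

// Temporarily switch this thread to relaxed stream-capture mode so that calls
// which are disallowed under thread-local/global capture (e.g. workspace
// cudaMalloc/cudaFree during cuDNN algorithm search) do not invalidate an
// ongoing capture. The previous mode is restored on destruction.
class RelaxedCaptureModeGuard {
 public:
  RelaxedCaptureModeGuard() : mode_(cudaStreamCaptureModeRelaxed) {
    cudaThreadExchangeStreamCaptureMode(&mode_);  // mode_ now holds the old mode
  }
  ~RelaxedCaptureModeGuard() {
    cudaThreadExchangeStreamCaptureMode(&mode_);  // swap the old mode back
  }

 private:
  cudaStreamCaptureMode mode_;
};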
......@@ -18,7 +18,8 @@ register_operators(EXCLUDES
fused_bn_add_activation_op
fused_attention_op
fused_feedforward_op
- fused_transformer_op)
+ fused_transformer_op
+ resnet_unit_op)
# fusion_gru_op does not have CUDA kernel
op_library(fusion_gru_op)
......@@ -86,4 +87,11 @@ if (WITH_GPU OR WITH_ROCM)
op_library(fused_attention_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_attention);\n")
endif()
# resnet_unit needs cuDNN 8.0 or above
if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 8000))
op_library(resnet_unit_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(resnet_unit);\n")
cc_test(test_cudnn_norm_conv SRCS cudnn_norm_conv_test.cc DEPS conv_op blas im2col vol2col depthwise_conv eigen_function tensor op_registry device_context generator memory)
cc_test(test_cudnn_bn_add_relu SRCS cudnn_bn_add_relu_test.cc DEPS batch_norm_op fused_bn_add_activation_op tensor op_registry device_context generator memory)
endif()
endif()
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <random>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h"
#include "paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/float16.h"
DECLARE_bool(cudnn_batchnorm_spatial_persistent);
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace op = paddle::operators;
using Tensor = paddle::framework::Tensor;
USE_OP(batch_norm);
USE_CUDA_ONLY_OP(fused_bn_add_activation);
USE_CUDA_ONLY_OP(fused_bn_add_activation_grad);
template <typename T>
void InitRandomTensor(const std::vector<int64_t> &dims,
framework::Tensor *cpu_out) {
T *cpu_out_ptr = cpu_out->mutable_data<T>(framework::make_ddim(dims),
platform::CPUPlace());
std::default_random_engine random(0);
std::uniform_real_distribution<float> dis(-1.0, 1.0);
for (int i = 0; i < cpu_out->numel(); ++i) {
cpu_out_ptr[i] = static_cast<T>(dis(random));
}
}
template <typename T>
void InitConstantTensor(const std::vector<int64_t> &dims, T value,
framework::Tensor *cpu_out) {
T *cpu_out_ptr = cpu_out->mutable_data<T>(framework::make_ddim(dims),
platform::CPUPlace());
for (int i = 0; i < cpu_out->numel(); ++i) {
cpu_out_ptr[i] = value;
}
}
template <typename T>
void CheckOutput(std::string name, const framework::Tensor &cpu_res,
const framework::Tensor &cpu_base, float diff,
bool is_relative_atol = false) {
if (cpu_res.dims().size() == cpu_base.dims().size()) {
EXPECT_EQ(cpu_res.dims(), cpu_base.dims());
} else {
EXPECT_EQ(cpu_res.numel(), cpu_base.numel());
}
const T *cpu_res_ptr = cpu_res.data<T>();
const T *cpu_base_ptr = cpu_base.data<T>();
float max_diff = 0;
int index = 0;
for (int i = 0; i < cpu_res.numel(); ++i) {
float cur_diff;
if (is_relative_atol) {
cur_diff = static_cast<float>(
std::abs((cpu_res_ptr[i] - cpu_base_ptr[i]) / cpu_base_ptr[i]));
EXPECT_LT(static_cast<float>(std::abs((cpu_res_ptr[i] - cpu_base_ptr[i]) /
cpu_base_ptr[i])),
diff);
} else {
cur_diff = static_cast<float>(std::abs(cpu_res_ptr[i] - cpu_base_ptr[i]));
EXPECT_LT(static_cast<float>(std::abs(cpu_res_ptr[i] - cpu_base_ptr[i])),
diff);
}
if (cur_diff > max_diff) {
max_diff = cur_diff;
index = i;
}
}
std::string error_type = is_relative_atol ? "relative" : "absolute";
LOG(INFO) << "[" << name << "] The dims is [" << cpu_res.dims()
<< "], maximum " << error_type << " error is " << max_diff << ": "
<< cpu_res_ptr[index] << " vs " << cpu_base_ptr[index];
}
template <typename T>
void ComputeSumAndSquareSum(const framework::Tensor &cpu_x,
framework::Tensor *cpu_sum,
framework::Tensor *cpu_sum_of_square) {
// x is in NHWC format.
auto dims = cpu_x.dims();
int64_t c = dims[3];
const T *cpu_x_ptr = cpu_x.data<T>();
float *cpu_sum_ptr =
cpu_sum->mutable_data<float>({1, 1, 1, c}, platform::CPUPlace());
float *cpu_sum_square_ptr = cpu_sum_of_square->mutable_data<float>(
{1, 1, 1, c}, platform::CPUPlace());
for (int j = 0; j < c; ++j) {
float tmp_sum = 0.0f;
float tmp_sum_of_squares = 0.0f;
for (int i = 0; i < cpu_x.numel() / c; ++i) {
float tmp_x = static_cast<float>(cpu_x_ptr[i * c + j]);
tmp_sum += tmp_x;
tmp_sum_of_squares += tmp_x * tmp_x;
}
cpu_sum_ptr[j] = tmp_sum;
cpu_sum_square_ptr[j] = tmp_sum_of_squares;
}
}
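Note: ComputeSumAndSquareSum produces exactly the two per-channel statistics that CudnnBNStatsFinalize consumes (CUDNN_PTR_YSUM / CUDNN_PTR_YSQSUM further down). From sum and sum_of_squares over ele_count = N*H*W elements per channel, the finalize step can recover mean = sum / ele_count and var = sum_of_squares / ele_count - mean^2. A small reference sketch of that conversion, for reasoning about the test inputs rather than part of the test itself:

#include <cmath>
#include <cstdint>

// Recover per-channel batch-norm statistics from the accumulated sums.
void StatsFromSums(float sum, float sum_of_squares, int64_t ele_count,
                   double eps, float *mean, float *inv_std) {
  *mean = sum / ele_count;
  float var = sum_of_squares / ele_count - (*mean) * (*mean);
  *inv_std = 1.0f / std::sqrt(var + static_cast<float>(eps));
}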
template <typename T>
void ComputeInplaceAdd(const framework::Tensor &cpu_x,
framework::Tensor *cpu_y) {
EXPECT_EQ(cpu_x.dims(), cpu_y->dims());
const T *cpu_x_ptr = cpu_x.data<T>();
T *cpu_y_ptr = cpu_y->data<T>();
for (int64_t i = 0; i < cpu_x.numel(); ++i) {
cpu_y_ptr[i] += cpu_x_ptr[i];
}
}
template <typename T>
void ComputeInplaceRelu(framework::Tensor *cpu_x) {
T *cpu_x_ptr = cpu_x->data<T>();
for (int64_t i = 0; i < cpu_x->numel(); ++i) {
cpu_x_ptr[i] =
cpu_x_ptr[i] > static_cast<T>(0) ? cpu_x_ptr[i] : static_cast<T>(0);
}
}
void ComputeBatchNormForward(const platform::CUDADeviceContext &ctx,
const Tensor &cpu_x, const Tensor &cpu_scale,
const Tensor &cpu_bias, Tensor *cpu_mean,
Tensor *cpu_var, Tensor *cpu_saved_mean,
Tensor *cpu_saved_var, Tensor *cpu_y,
Tensor *saved_reserve_space) {
framework::Scope scope;
auto *x = scope.Var("X")->GetMutable<framework::LoDTensor>();
auto *scale = scope.Var("Scale")->GetMutable<framework::LoDTensor>();
auto *bias = scope.Var("Bias")->GetMutable<framework::LoDTensor>();
auto *mean = scope.Var("Mean")->GetMutable<framework::LoDTensor>();
auto *var = scope.Var("Variance")->GetMutable<framework::LoDTensor>();
auto *y = scope.Var("Y")->GetMutable<framework::LoDTensor>();
auto *saved_mean = scope.Var("SavedMean")->GetMutable<framework::LoDTensor>();
auto *saved_var =
scope.Var("SavedVariance")->GetMutable<framework::LoDTensor>();
auto *reserve_space =
scope.Var("ReserveSpace")->GetMutable<framework::LoDTensor>();
auto place = ctx.GetPlace();
TensorCopySync(cpu_x, place, x);
TensorCopySync(cpu_scale, place, scale);
TensorCopySync(cpu_bias, place, bias);
TensorCopySync(*cpu_mean, place, mean);
TensorCopySync(*cpu_var, place, var);
int64_t channels = x->dims()[3];
scale->Resize({channels});
bias->Resize({channels});
mean->Resize({channels});
var->Resize({channels});
framework::AttributeMap attrs;
std::string data_layout = "NHWC";
attrs.insert({"data_layout", data_layout});
auto op = framework::OpRegistry::CreateOp(
"batch_norm", {{"X", {"X"}},
{"Scale", {"Scale"}},
{"Bias", {"Bias"}},
{"Mean", {"Mean"}},
{"Variance", {"Variance"}}},
{{"Y", {"Y"}},
{"MeanOut", {"Mean"}},
{"VarianceOut", {"Variance"}},
{"SavedMean", {"SavedMean"}},
{"SavedVariance", {"SavedVariance"}},
{"ReserveSpace", {"ReserveSpace"}}},
attrs);
op->Run(scope, ctx.GetPlace());
TensorCopySync(*y, platform::CPUPlace(), cpu_y);
TensorCopySync(*mean, platform::CPUPlace(), cpu_mean);
TensorCopySync(*var, platform::CPUPlace(), cpu_var);
TensorCopySync(*saved_mean, platform::CPUPlace(), cpu_saved_mean);
TensorCopySync(*saved_var, platform::CPUPlace(), cpu_saved_var);
// reserved_space will stay on GPU and used in grad op.
saved_reserve_space->ShareDataWith(*reserve_space);
}
void ComputeFusedBNAddReluForward(const platform::CUDADeviceContext &ctx,
const Tensor &cpu_x, const Tensor &cpu_z,
const Tensor &cpu_scale,
const Tensor &cpu_bias, Tensor *cpu_mean,
Tensor *cpu_var, Tensor *cpu_saved_mean,
Tensor *cpu_saved_var, Tensor *cpu_y,
Tensor *saved_reserve_space) {
framework::Scope scope;
auto *x = scope.Var("X")->GetMutable<framework::LoDTensor>();
auto *z = scope.Var("Z")->GetMutable<framework::LoDTensor>();
auto *scale = scope.Var("Scale")->GetMutable<framework::LoDTensor>();
auto *bias = scope.Var("Bias")->GetMutable<framework::LoDTensor>();
auto *mean = scope.Var("Mean")->GetMutable<framework::LoDTensor>();
auto *var = scope.Var("Variance")->GetMutable<framework::LoDTensor>();
auto *y = scope.Var("Y")->GetMutable<framework::LoDTensor>();
auto *saved_mean = scope.Var("SavedMean")->GetMutable<framework::LoDTensor>();
auto *saved_var =
scope.Var("SavedVariance")->GetMutable<framework::LoDTensor>();
auto *reserve_space =
scope.Var("ReserveSpace")->GetMutable<framework::LoDTensor>();
auto place = ctx.GetPlace();
TensorCopySync(cpu_x, place, x);
TensorCopySync(cpu_z, place, z);
TensorCopySync(cpu_scale, place, scale);
TensorCopySync(cpu_bias, place, bias);
TensorCopySync(*cpu_mean, place, mean);
TensorCopySync(*cpu_var, place, var);
int64_t channels = x->dims()[3];
scale->Resize({channels});
bias->Resize({channels});
mean->Resize({channels});
var->Resize({channels});
framework::AttributeMap attrs;
auto op = framework::OpRegistry::CreateOp(
"fused_bn_add_activation",
{{"X", {"X"}}, {"Z", {"Z"}}, {"Scale", {"Scale"}}, {"Bias", {"Bias"}}},
{{"Y", {"Y"}},
{"MeanOut", {"Mean"}},
{"VarianceOut", {"Variance"}},
{"SavedMean", {"SavedMean"}},
{"SavedVariance", {"SavedVariance"}},
{"ReserveSpace", {"ReserveSpace"}}},
attrs);
op->Run(scope, ctx.GetPlace());
TensorCopySync(*y, platform::CPUPlace(), cpu_y);
TensorCopySync(*mean, platform::CPUPlace(), cpu_mean);
TensorCopySync(*var, platform::CPUPlace(), cpu_var);
TensorCopySync(*saved_mean, platform::CPUPlace(), cpu_saved_mean);
TensorCopySync(*saved_var, platform::CPUPlace(), cpu_saved_var);
// reserved_space will stay on GPU and used in grad op.
saved_reserve_space->ShareDataWith(*reserve_space);
}
void ComputeFusedBNAddReluBackward(
const platform::CUDADeviceContext &ctx, const Tensor &cpu_dy,
const Tensor &cpu_x, const Tensor &cpu_scale, const Tensor &cpu_bias,
const Tensor &cpu_saved_mean, const Tensor &cpu_saved_var,
const Tensor &cpu_y, const Tensor &saved_reserve_space, Tensor *cpu_dx,
Tensor *cpu_dz, Tensor *cpu_dscale, Tensor *cpu_dbias) {
framework::Scope scope;
auto *x = scope.Var("X")->GetMutable<framework::LoDTensor>();
auto *y = scope.Var("Y")->GetMutable<framework::LoDTensor>();
auto *dy = scope.Var("Y@GRAD")->GetMutable<framework::LoDTensor>();
auto *scale = scope.Var("Scale")->GetMutable<framework::LoDTensor>();
auto *bias = scope.Var("Bias")->GetMutable<framework::LoDTensor>();
auto *saved_mean = scope.Var("SavedMean")->GetMutable<framework::LoDTensor>();
auto *saved_var =
scope.Var("SavedVariance")->GetMutable<framework::LoDTensor>();
auto *reserve_space =
scope.Var("ReserveSpace")->GetMutable<framework::LoDTensor>();
auto *dx = scope.Var("X@GRAD")->GetMutable<framework::LoDTensor>();
auto *dz = scope.Var("Z@GRAD")->GetMutable<framework::LoDTensor>();
auto *dscale = scope.Var("Scale@GRAD")->GetMutable<framework::LoDTensor>();
auto *dbias = scope.Var("Bias@GRAD")->GetMutable<framework::LoDTensor>();
auto place = ctx.GetPlace();
TensorCopySync(cpu_x, place, x);
TensorCopySync(cpu_y, place, y);
TensorCopySync(cpu_dy, place, dy);
TensorCopySync(cpu_scale, place, scale);
TensorCopySync(cpu_bias, place, bias);
TensorCopySync(cpu_saved_mean, place, saved_mean);
TensorCopySync(cpu_saved_var, place, saved_var);
reserve_space->ShareDataWith(saved_reserve_space);
int64_t channels = x->dims()[3];
scale->Resize({channels});
bias->Resize({channels});
saved_mean->Resize({channels});
saved_var->Resize({channels});
framework::AttributeMap attrs;
float momentum = 0.9;
float epsilon = 1e-5;
std::string act_type = "relu";
attrs.insert({"momentum", momentum});
attrs.insert({"epsilon", epsilon});
attrs.insert({"act_type", act_type});
auto op = framework::OpRegistry::CreateOp(
"fused_bn_add_activation_grad", {{"X", {"X"}},
{"Y", {"Y"}},
{"Y@GRAD", {"Y@GRAD"}},
{"Scale", {"Scale"}},
{"Bias", {"Bias"}},
{"SavedMean", {"SavedMean"}},
{"SavedVariance", {"SavedVariance"}},
{"ReserveSpace", {"ReserveSpace"}}},
{{"X@GRAD", {"X@GRAD"}},
{"Z@GRAD", {"Z@GRAD"}},
{"Scale@GRAD", {"Scale@GRAD"}},
{"Bias@GRAD", {"Bias@GRAD"}}},
attrs);
op->Run(scope, ctx.GetPlace());
TensorCopySync(*dx, platform::CPUPlace(), cpu_dx);
TensorCopySync(*dz, platform::CPUPlace(), cpu_dz);
TensorCopySync(*dscale, platform::CPUPlace(), cpu_dscale);
TensorCopySync(*dbias, platform::CPUPlace(), cpu_dbias);
}
template <typename T>
class CudnnBNAddReluTester {
public:
CudnnBNAddReluTester(int batch_size, int height, int width, int channels,
std::string act_type, bool fuse_add, bool has_shortcut) {
batch_size_ = batch_size;
height_ = height;
width_ = width;
channels_ = channels;
ele_count_ = batch_size_ * height_ * width_;
act_type_ = act_type;
fuse_add_ = fuse_add;
has_shortcut_ = has_shortcut;
SetUp();
}
~CudnnBNAddReluTester() {}
void CheckForward(float diff, bool is_relative_atol = false) {
LOG(INFO) << "[CheckForward, diff=" << diff
<< ", is_relative_atol=" << is_relative_atol
<< "] act_type=" << act_type_ << ", fuse_add=" << fuse_add_
<< ", has_shortcut=" << has_shortcut_;
platform::CUDADeviceContext *ctx =
static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(
platform::CUDAPlace(0)));
auto select = [&](Tensor *in) { return has_shortcut_ ? in : nullptr; };
framework::Tensor cpu_mean_base_x;
framework::Tensor cpu_var_base_x;
framework::Tensor cpu_mean_base_z;
framework::Tensor cpu_var_base_z;
if (!has_shortcut_ && fuse_add_ && (act_type_ == "relu")) {
BaselineForwardFusedBNAddRelu(
*ctx, &cpu_mean_base_x, &cpu_var_base_x, &cpu_saved_mean_base_x_,
&cpu_saved_var_base_x_, &cpu_y_base_, &saved_reserve_space_x_);
} else {
BaselineForward(
*ctx, &cpu_mean_base_x, &cpu_var_base_x, &cpu_saved_mean_base_x_,
&cpu_saved_var_base_x_, &cpu_y_base_, &saved_reserve_space_x_,
select(&cpu_mean_base_z), select(&cpu_var_base_z),
select(&cpu_saved_mean_base_z_), select(&cpu_saved_var_base_z_),
select(&saved_reserve_space_z_));
}
framework::Tensor cpu_mean_x;
framework::Tensor cpu_var_x;
framework::Tensor cpu_y;
framework::Tensor cpu_mean_z;
framework::Tensor cpu_var_z;
FusedForward(*ctx, &cpu_mean_x, &cpu_var_x, &cpu_saved_mean_x_,
&cpu_saved_var_x_, &cpu_y, &cpu_bitmask_, select(&cpu_mean_z),
select(&cpu_var_z), select(&cpu_saved_mean_z_),
select(&cpu_saved_var_z_));
CheckOutput<float>("Mean", cpu_mean_x, cpu_mean_base_x, diff,
is_relative_atol);
CheckOutput<float>("Variance", cpu_var_x, cpu_var_base_x, diff,
is_relative_atol);
CheckOutput<float>("SavedMean", cpu_saved_mean_x_, cpu_saved_mean_base_x_,
diff, is_relative_atol);
CheckOutput<float>("SavedVariance", cpu_saved_var_x_, cpu_saved_var_base_x_,
diff, is_relative_atol);
if (has_shortcut_) {
CheckOutput<float>("MeanZ", cpu_mean_z, cpu_mean_base_z, diff,
is_relative_atol);
CheckOutput<float>("VarianceZ", cpu_var_z, cpu_var_base_z, diff,
is_relative_atol);
CheckOutput<float>("SavedMeanZ", cpu_saved_mean_z_,
cpu_saved_mean_base_z_, diff, is_relative_atol);
CheckOutput<float>("SavedVarianceZ", cpu_saved_var_z_,
cpu_saved_var_base_z_, diff, is_relative_atol);
}
CheckOutput<T>("Y", cpu_y, cpu_y_base_, diff, is_relative_atol);
}
void CheckBackward(float diff, bool is_relative_atol = false) {
platform::CUDADeviceContext *ctx =
static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(
platform::CUDAPlace(0)));
framework::Tensor cpu_dx_base;
framework::Tensor cpu_dz_base;
framework::Tensor cpu_dscale_base;
framework::Tensor cpu_dbias_base;
BaselineBackwardFusedBNAddRelu(*ctx, &cpu_dx_base, &cpu_dz_base,
&cpu_dscale_base, &cpu_dbias_base);
framework::Tensor cpu_dx;
framework::Tensor cpu_dz;
framework::Tensor cpu_dscale;
framework::Tensor cpu_dbias;
FusedBackward(*ctx, &cpu_dx, &cpu_dz, &cpu_dscale, &cpu_dbias);
CheckOutput<T>("DX", cpu_dx, cpu_dx_base, diff, is_relative_atol);
CheckOutput<T>("DZ", cpu_dz, cpu_dz_base, diff, is_relative_atol);
CheckOutput<float>("DScale", cpu_dscale, cpu_dscale_base, diff,
is_relative_atol);
CheckOutput<float>("DBias", cpu_dbias, cpu_dbias_base, diff,
is_relative_atol);
}
private:
void SetUp() {
InitRandomTensor<T>({batch_size_, height_, width_, channels_}, &cpu_x_);
InitRandomTensor<float>({channels_}, &cpu_bn_scale_x_);
InitRandomTensor<float>({channels_}, &cpu_bn_bias_x_);
if (has_shortcut_) {
InitRandomTensor<T>({batch_size_, height_, width_, channels_}, &cpu_z_);
InitRandomTensor<float>({channels_}, &cpu_bn_scale_z_);
InitRandomTensor<float>({channels_}, &cpu_bn_bias_z_);
} else {
if (fuse_add_) {
InitRandomTensor<T>({batch_size_, height_, width_, channels_}, &cpu_z_);
}
}
InitRandomTensor<T>({batch_size_, height_, width_, channels_}, &cpu_dy_);
}
void InitMeanVar(Tensor *cpu_mean, Tensor *cpu_var, Tensor *cpu_saved_mean,
Tensor *cpu_saved_var) {
InitConstantTensor<float>({channels_}, static_cast<float>(0.0f), cpu_mean);
InitConstantTensor<float>({channels_}, static_cast<float>(1.0f), cpu_var);
InitConstantTensor<float>({channels_}, static_cast<float>(0.0f),
cpu_saved_mean);
InitConstantTensor<float>({channels_}, static_cast<float>(0.0f),
cpu_saved_var);
}
void BaselineForward(const platform::CUDADeviceContext &ctx,
Tensor *cpu_mean_x, Tensor *cpu_var_x,
Tensor *cpu_saved_mean_x, Tensor *cpu_saved_var_x,
Tensor *cpu_y, Tensor *saved_reserve_space_x,
Tensor *cpu_mean_z = nullptr,
Tensor *cpu_var_z = nullptr,
Tensor *cpu_saved_mean_z = nullptr,
Tensor *cpu_saved_var_z = nullptr,
Tensor *saved_reserve_space_z = nullptr) {
InitMeanVar(cpu_mean_x, cpu_var_x, cpu_saved_mean_x, cpu_saved_var_x);
ComputeBatchNormForward(ctx, cpu_x_, cpu_bn_scale_x_, cpu_bn_bias_x_,
cpu_mean_x, cpu_var_x, cpu_saved_mean_x,
cpu_saved_var_x, cpu_y, saved_reserve_space_x);
if (has_shortcut_) {
framework::Tensor cpu_z_out;
InitMeanVar(cpu_mean_z, cpu_var_z, cpu_saved_mean_z, cpu_saved_var_z);
ComputeBatchNormForward(
ctx, cpu_z_, cpu_bn_scale_z_, cpu_bn_bias_z_, cpu_mean_z, cpu_var_z,
cpu_saved_mean_z, cpu_saved_var_z, &cpu_z_out, saved_reserve_space_z);
ComputeInplaceAdd<T>(cpu_z_out, cpu_y);
} else {
if (fuse_add_) {
ComputeInplaceAdd<T>(cpu_z_, cpu_y);
}
}
if (act_type_ == "relu") {
ComputeInplaceRelu<T>(cpu_y);
}
}
void BaselineForwardFusedBNAddRelu(const platform::CUDADeviceContext &ctx,
Tensor *cpu_mean, Tensor *cpu_var,
Tensor *cpu_saved_mean,
Tensor *cpu_saved_var, Tensor *cpu_y,
Tensor *saved_reserve_space) {
InitMeanVar(cpu_mean, cpu_var, cpu_saved_mean, cpu_saved_var);
ComputeFusedBNAddReluForward(
ctx, cpu_x_, cpu_z_, cpu_bn_scale_x_, cpu_bn_bias_x_, cpu_mean, cpu_var,
cpu_saved_mean, cpu_saved_var, cpu_y, saved_reserve_space);
}
void BaselineBackwardFusedBNAddRelu(const platform::CUDADeviceContext &ctx,
Tensor *cpu_dx, Tensor *cpu_dz,
Tensor *cpu_dscale, Tensor *cpu_dbias) {
ComputeFusedBNAddReluBackward(
ctx, cpu_dy_, cpu_x_, cpu_bn_scale_x_, cpu_bn_bias_x_,
cpu_saved_mean_base_x_, cpu_saved_var_base_x_, cpu_y_base_,
saved_reserve_space_x_, cpu_dx, cpu_dz, cpu_dscale, cpu_dbias);
}
void ComputeFusedBNStatsFinalize(const platform::CUDADeviceContext &ctx,
const Tensor &cpu_x,
const Tensor &cpu_bn_scale,
const Tensor &cpu_bn_bias, Tensor *sum,
Tensor *sum_of_square, Tensor *bn_scale,
Tensor *bn_bias, Tensor *mean, Tensor *var,
Tensor *saved_mean, Tensor *saved_var,
Tensor *equiv_scale, Tensor *equiv_bias) {
framework::Tensor cpu_sum;
framework::Tensor cpu_sum_of_square;
ComputeSumAndSquareSum<T>(cpu_x, &cpu_sum, &cpu_sum_of_square);
auto place = ctx.GetPlace();
TensorCopySync(cpu_sum, place, sum);
TensorCopySync(cpu_sum_of_square, place, sum_of_square);
TensorCopySync(cpu_bn_scale, place, bn_scale);
TensorCopySync(cpu_bn_bias, place, bn_bias);
bn_scale->Resize({1, 1, 1, channels_});
bn_bias->Resize({1, 1, 1, channels_});
// input
mean->Resize({1, 1, 1, channels_});
var->Resize({1, 1, 1, channels_});
// output
equiv_scale->Resize({1, 1, 1, channels_});
equiv_bias->Resize({1, 1, 1, channels_});
saved_mean->Resize({1, 1, 1, channels_});
saved_var->Resize({1, 1, 1, channels_});
auto param_shape = framework::vectorize<int>(bn_scale->dims());
op::CudnnBNStatsFinalize<T> bn_op(ctx, param_shape);
bn_op.Forward(ctx, *sum, *sum_of_square, *bn_scale, *bn_bias, saved_mean,
saved_var, mean, var, equiv_scale, equiv_bias, eps_,
momentum_, ele_count_, true);
}
// Get forward results of CudnnBNStatsFinalize + CudnnScaleBiasAddRelu
void FusedForward(const platform::CUDADeviceContext &ctx, Tensor *cpu_mean_x,
Tensor *cpu_var_x, Tensor *cpu_saved_mean_x,
Tensor *cpu_saved_var_x, Tensor *cpu_y, Tensor *cpu_bitmask,
Tensor *cpu_mean_z = nullptr, Tensor *cpu_var_z = nullptr,
Tensor *cpu_saved_mean_z = nullptr,
Tensor *cpu_saved_var_z = nullptr) {
framework::Tensor x;
framework::Tensor sum_x;
framework::Tensor sum_of_square_x;
framework::Tensor bn_scale_x;
framework::Tensor bn_bias_x;
framework::Tensor z;
framework::Tensor sum_z;
framework::Tensor sum_of_square_z;
framework::Tensor bn_scale_z;
framework::Tensor bn_bias_z;
auto place = ctx.GetPlace();
TensorCopySync(cpu_x_, place, &x);
if (fuse_add_ || has_shortcut_) {
TensorCopySync(cpu_z_, place, &z);
}
framework::Tensor mean_x;
framework::Tensor var_x;
framework::Tensor saved_mean_x;
framework::Tensor saved_var_x;
framework::Tensor equiv_scale_x;
framework::Tensor equiv_bias_x;
framework::Tensor mean_z;
framework::Tensor var_z;
framework::Tensor saved_mean_z;
framework::Tensor saved_var_z;
framework::Tensor equiv_scale_z;
framework::Tensor equiv_bias_z;
framework::Tensor y;
framework::Tensor bitmask;
InitMeanVar(cpu_mean_x, cpu_var_x, cpu_saved_mean_x, cpu_saved_var_x);
TensorCopySync(*cpu_mean_x, place, &mean_x);
TensorCopySync(*cpu_var_x, place, &var_x);
if (has_shortcut_) {
InitMeanVar(cpu_mean_z, cpu_var_z, cpu_saved_mean_z, cpu_saved_var_z);
TensorCopySync(*cpu_mean_z, place, &mean_z);
TensorCopySync(*cpu_var_z, place, &var_z);
}
// 1. BN Stats Finalize
ComputeFusedBNStatsFinalize(ctx, cpu_x_, cpu_bn_scale_x_, cpu_bn_bias_x_,
&sum_x, &sum_of_square_x, &bn_scale_x,
&bn_bias_x, &mean_x, &var_x, &saved_mean_x,
&saved_var_x, &equiv_scale_x, &equiv_bias_x);
if (has_shortcut_) {
ComputeFusedBNStatsFinalize(ctx, cpu_z_, cpu_bn_scale_z_, cpu_bn_bias_z_,
&sum_z, &sum_of_square_z, &bn_scale_z,
&bn_bias_z, &mean_z, &var_z, &saved_mean_z,
&saved_var_z, &equiv_scale_z, &equiv_bias_z);
}
y.Resize(framework::make_ddim({batch_size_, height_, width_, channels_}));
int c = channels_;
int64_t nhw = ele_count_;
int32_t c_int32_elems = ((c + 63) & ~63) / 32;
int32_t nhw_int32_elems = (nhw + 31) & ~31;
bitmask.Resize(framework::make_ddim({nhw_int32_elems, c_int32_elems, 1}));
auto data_shape = framework::vectorize<int>(x.dims());
auto param_shape = framework::vectorize<int>(bn_scale_x.dims());
auto bitmask_shape = framework::vectorize<int>(bitmask.dims());
// 2. Scale Bias + Relu
op::CudnnScaleBiasAddRelu<T> sbar_op(ctx, act_type_, fuse_add_,
has_shortcut_, data_shape, param_shape,
bitmask_shape);
sbar_op.Forward(ctx, x, equiv_scale_x, equiv_bias_x, &z, &equiv_scale_z,
&equiv_bias_z, &y, &bitmask);
TensorCopySync(mean_x, platform::CPUPlace(), cpu_mean_x);
TensorCopySync(var_x, platform::CPUPlace(), cpu_var_x);
TensorCopySync(saved_mean_x, platform::CPUPlace(), cpu_saved_mean_x);
TensorCopySync(saved_var_x, platform::CPUPlace(), cpu_saved_var_x);
if (has_shortcut_) {
TensorCopySync(mean_z, platform::CPUPlace(), cpu_mean_z);
TensorCopySync(var_z, platform::CPUPlace(), cpu_var_z);
TensorCopySync(saved_mean_z, platform::CPUPlace(), cpu_saved_mean_z);
TensorCopySync(saved_var_z, platform::CPUPlace(), cpu_saved_var_z);
}
TensorCopySync(y, platform::CPUPlace(), cpu_y);
TensorCopySync(bitmask, platform::CPUPlace(), cpu_bitmask);
}
// Get backward results of CudnnBNStatsFinalize + CudnnScaleBiasAddRelu
void FusedBackward(const platform::CUDADeviceContext &ctx, Tensor *cpu_dx,
Tensor *cpu_dz, Tensor *cpu_dscale, Tensor *cpu_dbias) {
framework::Tensor dy;
framework::Tensor x;
framework::Tensor bn_scale;
framework::Tensor bn_bias;
framework::Tensor saved_mean;
framework::Tensor saved_var;
framework::Tensor bitmask;
framework::Tensor dx;
framework::Tensor dz;
framework::Tensor dscale;
framework::Tensor dbias;
auto place = ctx.GetPlace();
TensorCopySync(cpu_dy_, place, &dy);
TensorCopySync(cpu_x_, place, &x);
TensorCopySync(cpu_bn_scale_x_, place, &bn_scale);
TensorCopySync(cpu_bn_bias_x_, place, &bn_bias);
TensorCopySync(cpu_saved_mean_x_, place, &saved_mean);
TensorCopySync(cpu_saved_var_x_, place, &saved_var);
TensorCopySync(cpu_bitmask_, place, &bitmask);
bn_scale.Resize({1, 1, 1, channels_});
bn_bias.Resize({1, 1, 1, channels_});
saved_mean.Resize({1, 1, 1, channels_});
saved_var.Resize({1, 1, 1, channels_});
dx.Resize(framework::make_ddim({batch_size_, height_, width_, channels_}));
dz.Resize(framework::make_ddim({batch_size_, height_, width_, channels_}));
dscale.Resize(framework::make_ddim({1, 1, 1, channels_}));
dbias.Resize(framework::make_ddim({1, 1, 1, channels_}));
auto data_shape = framework::vectorize<int>(x.dims());
auto param_shape = framework::vectorize<int>(bn_scale.dims());
auto bitmask_shape = framework::vectorize<int>(bitmask.dims());
std::string act_type = "relu";
op::CudnnScaleBiasAddRelu<T> sbar_op(ctx, act_type, true, false, data_shape,
param_shape, bitmask_shape);
sbar_op.Backward(ctx, dy, x, bn_scale, bn_bias, saved_mean, saved_var,
&bitmask, &dx, &dz, &dscale, &dbias, eps_);
TensorCopySync(dx, platform::CPUPlace(), cpu_dx);
TensorCopySync(dz, platform::CPUPlace(), cpu_dz);
TensorCopySync(dscale, platform::CPUPlace(), cpu_dscale);
TensorCopySync(dbias, platform::CPUPlace(), cpu_dbias);
}
private:
int batch_size_;
int height_;
int width_;
int channels_;
int ele_count_;
std::string act_type_;
bool fuse_add_;
bool has_shortcut_;
// Forward input
framework::Tensor cpu_x_;
framework::Tensor cpu_bn_scale_x_;
framework::Tensor cpu_bn_bias_x_;
framework::Tensor cpu_z_;
framework::Tensor cpu_bn_scale_z_;
framework::Tensor cpu_bn_bias_z_;
// Backward input
framework::Tensor cpu_dy_;
framework::Tensor cpu_bitmask_;
framework::Tensor cpu_saved_mean_x_;
framework::Tensor cpu_saved_var_x_;
framework::Tensor cpu_saved_mean_z_;
framework::Tensor cpu_saved_var_z_;
framework::Tensor cpu_saved_mean_base_x_;
framework::Tensor cpu_saved_var_base_x_;
framework::Tensor saved_reserve_space_x_;
framework::Tensor cpu_saved_mean_base_z_;
framework::Tensor cpu_saved_var_base_z_;
framework::Tensor saved_reserve_space_z_;
framework::Tensor cpu_y_base_;
double eps_ = 1e-5;
float momentum_ = 0.9;
};
TEST(CudnnBNAddReluFp16, BNAdd) {
int batch_size = 4;
int height = 8;
int width = 8;
int channels = 64;
std::string act_type = "";
bool has_shortcut = false;
FLAGS_cudnn_batchnorm_spatial_persistent = true;
for (auto fuse_add : {false, true}) {
CudnnBNAddReluTester<paddle::platform::float16> test(
batch_size, height, width, channels, act_type, fuse_add, has_shortcut);
test.CheckForward(2e-3);
}
}
TEST(CudnnBNAddReluFp16, BNAddRelu) {
int batch_size = 4;
int height = 8;
int width = 8;
int channels = 64;
std::string act_type = "relu";
bool has_shortcut = false;
FLAGS_cudnn_batchnorm_spatial_persistent = true;
for (auto fuse_add : {false, true}) {
CudnnBNAddReluTester<paddle::platform::float16> test(
batch_size, height, width, channels, act_type, fuse_add, has_shortcut);
test.CheckForward(2e-3);
if (fuse_add) {
test.CheckBackward(2e-4);
}
}
}
TEST(CudnnBNAddReluFp16, HasShortcut) {
int batch_size = 4;
int height = 8;
int width = 8;
int channels = 64;
std::string act_type = "";
bool fuse_add = false;
bool has_shortcut = true;
FLAGS_cudnn_batchnorm_spatial_persistent = true;
CudnnBNAddReluTester<paddle::platform::float16> test(
batch_size, height, width, channels, act_type, fuse_add, has_shortcut);
test.CheckForward(5e-3);
}
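Note on the bitmask shape used in FusedForward above: ((c + 63) & ~63) / 32 rounds the channel count up to a multiple of 64 and then counts 32-bit words, and (nhw + 31) & ~31 rounds N*H*W up to a multiple of 32, which presumably matches the layout the fused cuDNN kernels expect for the activation bitmask. The rounding idiom itself is just:

// Round x up to the next multiple of a power-of-two `align`:
// (x + align - 1) & ~(align - 1) bumps x and then clears the low bits.
constexpr int RoundUp(int x, int align) {
  return (x + align - 1) & ~(align - 1);
}

static_assert(RoundUp(64, 64) == 64, "already aligned stays unchanged");
static_assert(RoundUp(65, 64) == 128, "65 rounds up to 128");
static_assert(RoundUp(100, 32) == 128, "100 rounds up to 128");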
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/fused/cudnn_fusion_helper.h"
#include "paddle/fluid/platform/cudnn_desc.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
namespace dynload = platform::dynload;
template <typename T>
using BatchNormParamType =
typename platform::CudnnDataType<T>::BatchNormParamType;
#if CUDNN_VERSION >= 8000
template <typename T>
struct BNStatsFinalizeArgs {
BNStatsFinalizeArgs() {
dtype = platform::CudnnDataType<T>::type;
param_dtype = platform::CudnnDataType<BatchNormParamType<T>>::type;
format = CUDNN_TENSOR_NHWC;
}
void Set(const std::vector<int> &param_shape) {
PADDLE_ENFORCE_EQ(
param_shape.size(), 4U,
platform::errors::InvalidArgument(
"The size of param_shape is expected to 4. But recieved "
"param_shape's size is %d, param_shape is [%s].",
param_shape.size(), framework::make_ddim(param_shape)));
in_desc.set(param_shape, format, param_dtype);
out_desc.set(param_shape, format, dtype);
}
cudnnDataType_t dtype;
cudnnDataType_t param_dtype;
cudnnTensorFormat_t format;
platform::TensorDescriptor in_desc;
platform::TensorDescriptor out_desc;
};
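// CudnnBNStatsFinalize wraps the cuDNN fused BN-finalize ops
// (CUDNN_FUSED_BN_FINALIZE_STATISTICS_TRAINING / _INFERENCE). Given the
// per-channel sum and sum-of-squares produced by the norm-conv forward pass,
// plus the BN scale and bias, it computes the saved mean/invstd, updates the
// running statistics, and emits the equivalent scale/bias consumed by the
// following scale-bias-add-relu op. See Forward() below for the exact inputs
// and outputs.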
template <typename T>
class CudnnBNStatsFinalize {
public:
CudnnBNStatsFinalize(const platform::CUDADeviceContext &ctx,
const std::vector<int> &param_shape)
: train_op_(CUDNN_FUSED_BN_FINALIZE_STATISTICS_TRAINING),
inference_op_(CUDNN_FUSED_BN_FINALIZE_STATISTICS_INFERENCE) {
args_.Set(param_shape);
}
~CudnnBNStatsFinalize() {}
void Forward(const platform::CUDADeviceContext &ctx, const Tensor &sum,
const Tensor &sum_of_squares, const Tensor &scale,
const Tensor &bias, Tensor *saved_mean, Tensor *saved_invstd,
Tensor *running_mean, Tensor *running_var, Tensor *equiv_scale,
Tensor *equiv_bias, double eps, float momentum,
int64_t ele_count, bool is_train) {
auto place = ctx.GetPlace();
if (is_train) {
TrainInit(ctx);
} else {
InferenceInit(ctx);
}
auto &op = is_train ? train_op_ : inference_op_;
// Set variant_param for both inference_op_ and train_op_
float *sum_ptr = const_cast<float *>(sum.data<float>());
float *sum_of_squares_ptr =
const_cast<float *>(sum_of_squares.data<float>());
float *scale_ptr = const_cast<float *>(scale.data<float>());
float *bias_ptr = const_cast<float *>(bias.data<float>());
float *saved_mean_ptr = saved_mean->mutable_data<float>(place);
float *saved_invstd_ptr = saved_invstd->mutable_data<float>(place);
float *running_mean_ptr = running_mean->mutable_data<float>(place);
float *running_var_ptr = running_var->mutable_data<float>(place);
T *equiv_scale_ptr = equiv_scale->mutable_data<T>(place);
T *equiv_bias_ptr = equiv_bias->mutable_data<T>(place);
op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SCALE, scale_ptr);
op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_BIAS, bias_ptr);
op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_RUNNING_MEAN, running_mean_ptr);
op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_RUNNING_VAR, running_var_ptr);
op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQSCALE, equiv_scale_ptr);
op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQBIAS, equiv_bias_ptr);
op.SetOpVariantParamAttrPtr<double>(CUDNN_SCALAR_DOUBLE_BN_EPSILON, &eps);
// Set extra variant_param only for train_op_:
if (is_train) {
op.SetOpVariantParamAttrPtr(CUDNN_PTR_YSUM, sum_ptr);
op.SetOpVariantParamAttrPtr(CUDNN_PTR_YSQSUM, sum_of_squares_ptr);
op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SAVED_MEAN, saved_mean_ptr);
op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SAVED_INVSTD, saved_invstd_ptr);
double avg_factor = 1.0 - momentum;
op.SetOpVariantParamAttrPtr(CUDNN_SCALAR_INT64_T_BN_ACCUMULATION_COUNT,
&ele_count);
op.SetOpVariantParamAttrPtr(CUDNN_SCALAR_DOUBLE_BN_EXP_AVG_FACTOR,
&avg_factor);
}
// fused op execute
auto handle = ctx.cudnn_handle();
op.Execute(handle);
}
private:
void TrainInit(const platform::CUDADeviceContext &ctx) {
// Set constant_param for train op
train_op_.SetOpConstParamAttr(
{CUDNN_PARAM_YSUM_PLACEHOLDER, CUDNN_PARAM_YSQSUM_PLACEHOLDER,
CUDNN_PARAM_BN_SCALE_PLACEHOLDER, CUDNN_PARAM_BN_BIAS_PLACEHOLDER,
CUDNN_PARAM_BN_SAVED_MEAN_PLACEHOLDER,
CUDNN_PARAM_BN_SAVED_INVSTD_PLACEHOLDER,
CUDNN_PARAM_BN_RUNNING_MEAN_PLACEHOLDER,
CUDNN_PARAM_BN_RUNNING_VAR_PLACEHOLDER,
CUDNN_PARAM_BN_EQSCALE_PLACEHOLDER, CUDNN_PARAM_BN_EQBIAS_PLACEHOLDER},
CUDNN_PTR_16B_ALIGNED);
// Set input and output desc for train op
train_op_.SetOpConstParamDesc(
{CUDNN_PARAM_YSTATS_DESC, CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC},
args_.in_desc.desc());
train_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_EQSCALEBIAS_DESC,
args_.out_desc.desc());
// Get workspace
auto handle = ctx.cudnn_handle();
train_op_.SetOpConstParamAttr(CUDNN_PARAM_BN_MODE,
CUDNN_BATCHNORM_SPATIAL_PERSISTENT);
// Check workspace size, also creates plan.
size_t workspace_size_bytes = train_op_.GetWorkspaceSizeInBytes(handle);
PADDLE_ENFORCE_EQ(workspace_size_bytes, 0U,
platform::errors::InvalidArgument(
"Unexpected non-zero workspace size for "
"CudnnBNStatsFinalize."));
train_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE,
static_cast<void *>(nullptr));
train_op_.SetOpVariantParamAttrPtr(
CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &workspace_size_bytes);
}
void InferenceInit(const platform::CUDADeviceContext &ctx) {
// Set constant_param for inference op
inference_op_.SetOpConstParamAttr(
{CUDNN_PARAM_BN_SCALE_PLACEHOLDER, CUDNN_PARAM_BN_BIAS_PLACEHOLDER,
CUDNN_PARAM_BN_RUNNING_MEAN_PLACEHOLDER,
CUDNN_PARAM_BN_RUNNING_VAR_PLACEHOLDER,
CUDNN_PARAM_BN_EQSCALE_PLACEHOLDER, CUDNN_PARAM_BN_EQBIAS_PLACEHOLDER},
CUDNN_PTR_16B_ALIGNED);
// Set input and output desc for inference op
inference_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC,
args_.in_desc.desc());
inference_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_EQSCALEBIAS_DESC,
args_.out_desc.desc());
// Get workspace
auto handle = ctx.cudnn_handle();
inference_op_.SetOpConstParamAttr(CUDNN_PARAM_BN_MODE,
CUDNN_BATCHNORM_SPATIAL_PERSISTENT);
// Check workspace size, also creates plan.
size_t workspace_size_bytes = inference_op_.GetWorkspaceSizeInBytes(handle);
PADDLE_ENFORCE_EQ(workspace_size_bytes, 0U,
platform::errors::InvalidArgument(
"Unexpected non-zero workspace size for "
"CudnnBNStatsFinalize."));
inference_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE,
static_cast<void *>(nullptr));
inference_op_.SetOpVariantParamAttrPtr(
CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &workspace_size_bytes);
}
BNStatsFinalizeArgs<T> args_;
CudnnFusionOp train_op_;
CudnnFusionOp inference_op_;
};
#endif
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/operator_kernel_configs.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace dynload = platform::dynload;
#if CUDNN_VERSION >= 8000
// A wrapper for cuDNN fused_op API.
class CudnnFusionOp {
public:
explicit CudnnFusionOp(cudnnFusedOps_t op_id) : plan_created_(false) {
// New 'fused op' descriptor creation
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnCreateFusedOpsPlan(&op_, op_id));
PADDLE_ENFORCE_CUDA_SUCCESS(
dynload::cudnnCreateFusedOpsConstParamPack(&op_const_params_, op_id));
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnCreateFusedOpsVariantParamPack(
&op_variant_params_, op_id));
}
~CudnnFusionOp() PADDLE_MAY_THROW {
PADDLE_ENFORCE_CUDA_SUCCESS(
dynload::cudnnDestroyFusedOpsVariantParamPack(op_variant_params_));
PADDLE_ENFORCE_CUDA_SUCCESS(
dynload::cudnnDestroyFusedOpsConstParamPack(op_const_params_));
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroyFusedOpsPlan(op_));
}
// Execute fused op
void Execute(cudnnHandle_t cudnn_handle) {
PADDLE_ENFORCE_EQ(
plan_created_, true,
platform::errors::Fatal(
"CudnnFusionOp exec requested without a valid 'plan', need: "
"<set const params>, GetWorkspaceSizeBytes(), Execute()."));
PADDLE_ENFORCE_CUDA_SUCCESS(
dynload::cudnnFusedOpsExecute(cudnn_handle, op_, op_variant_params_));
}
// Set const param pack attribute given a descriptor.
template <typename T>
void SetOpConstParamDesc(cudnnFusedOpsConstParamLabel_t param_label,
T *param_ptr) {
PADDLE_ENFORCE_CUDA_SUCCESS(
dynload::cudnnSetFusedOpsConstParamPackAttribute(
op_const_params_, param_label, param_ptr));
plan_created_ = false;
}
// Set multiple const param pack attributes given a descriptor.
template <typename T>
void SetOpConstParamDesc(
const std::vector<cudnnFusedOpsConstParamLabel_t> &param_labels,
T *param_ptr) {
for (auto param_label : param_labels) {
SetOpConstParamDesc(param_label, param_ptr);
}
}
// Set const param pack attribute given a value of param.
template <typename T>
void SetOpConstParamAttr(cudnnFusedOpsConstParamLabel_t param_label,
T param) {
PADDLE_ENFORCE_CUDA_SUCCESS(
dynload::cudnnSetFusedOpsConstParamPackAttribute(op_const_params_,
param_label, &param));
plan_created_ = false;
}
// Set multiple const param pack attributes given a value of param.
template <typename T>
void SetOpConstParamAttr(
const std::vector<cudnnFusedOpsConstParamLabel_t> &param_labels,
T param) {
for (auto param_label : param_labels) {
SetOpConstParamAttr(param_label, param);
}
}
// Set a variant param pack attribute given a reference to a param.
template <typename T>
void SetOpVariantParamAttrPtr(cudnnFusedOpsVariantParamLabel_t param_label,
T *param_ptr) {
PADDLE_ENFORCE_CUDA_SUCCESS(
dynload::cudnnSetFusedOpsVariantParamPackAttribute(
op_variant_params_, param_label, param_ptr));
}
// Set multiple variant param pack attributes given a reference to a param.
template <typename T>
void SetOpVariantParamAttrPtr(
const std::vector<cudnnFusedOpsVariantParamLabel_t> &param_labels,
const T *param_ptr) {
for (auto param_label : param_labels) {
SetOpVariantParamAttrPtr(param_label, param_ptr);
}
}
// Get the workspace size in bytes; this builds the cuDNN plan and must be
// called before Execute().
size_t GetWorkspaceSizeInBytes(cudnnHandle_t cudnn_handle) {
if (!plan_created_) {
workspace_bytes_ = 0U;
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnMakeFusedOpsPlan(
cudnn_handle, op_, op_const_params_, &workspace_bytes_));
plan_created_ = true;
}
return workspace_bytes_;
}
private:
bool plan_created_;
size_t workspace_bytes_;
cudnnFusedOpsPlan_t op_;
cudnnFusedOpsConstParamPack_t op_const_params_;
cudnnFusedOpsVariantParamPack_t op_variant_params_;
};
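// Typical usage of CudnnFusionOp (illustrative sketch only; the concrete
// const and variant params depend on the chosen cudnnFusedOps_t):
//
//   CudnnFusionOp op(CUDNN_FUSED_SCALE_BIAS_ACTIVATION_CONV_BNSTATS);
//   op.SetOpConstParamDesc(CUDNN_PARAM_XDESC, x_desc);   // tensor/conv descs
//   op.SetOpConstParamAttr(CUDNN_PARAM_BN_MODE,
//                          CUDNN_BATCHNORM_SPATIAL_PERSISTENT);
//   size_t ws = op.GetWorkspaceSizeInBytes(handle);      // builds the plan
//   op.SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, x_ptr); // bind device ptrs
//   op.Execute(handle);                                  // launch fused kernel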
class CudnnFusionOpCache {
public:
static CudnnFusionOpCache &Instance() {
static CudnnFusionOpCache instance;
return instance;
}
framework::AlgorithmsCache<CudnnFusionOp *> *GetForward() {
return &forward_cache_;
}
framework::AlgorithmsCache<CudnnFusionOp *> *GetBackward() {
return &backward_cache_;
}
private:
CudnnFusionOpCache() {}
~CudnnFusionOpCache() {
// Need to delete the memory of cache.
}
CudnnFusionOpCache(const CudnnFusionOpCache &) {}
private:
framework::AlgorithmsCache<CudnnFusionOp *> forward_cache_;
framework::AlgorithmsCache<CudnnFusionOp *> backward_cache_;
};
#endif // CUDNN_VERSION >= 8000
} // namespace operators
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/fused/cudnn_fusion_helper.h"
#include "paddle/fluid/platform/cudnn_desc.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
namespace dynload = platform::dynload;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
#if CUDNN_VERSION >= 8000
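// Round a up to the nearest multiple of b, e.g. RoundUp(100, 512) == 512.
// Used below to align cuDNN workspace sizes to 512-byte boundaries.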
static size_t RoundUp(int64_t a, int64_t b) { return (a + b - 1) / b * b; }
template <typename T>
struct NormConvolutionArgs {
NormConvolutionArgs() {
dtype = platform::CudnnDataType<T>::type;
format = CUDNN_TENSOR_NHWC;
compute_type = platform::CudnnDataType<float>::type;
}
void Set(const platform::CUDADeviceContext &ctx,
const std::vector<int> &input_shape,
const std::vector<int> &filter_shape,
const std::vector<int> &output_shape, int padding, int stride,
int dilation, int group) {
PADDLE_ENFORCE_EQ(
input_shape.size(), 4U,
platform::errors::InvalidArgument(
"The size of input_shape is expected to 4. But recieved "
"input_shape's size is %d, input_shape is [%s].",
input_shape.size(), framework::make_ddim(input_shape)));
PADDLE_ENFORCE_EQ(
filter_shape.size(), 4U,
platform::errors::InvalidArgument(
"The size of filter_shape is expected to 4. But recieved "
"filter_shape's size is %d, filter_shape is [%s].",
filter_shape.size(), framework::make_ddim(filter_shape)));
PADDLE_ENFORCE_EQ(filter_shape[1] == filter_shape[2] &&
(filter_shape[1] == 1 || filter_shape[1] == 3),
true,
platform::errors::InvalidArgument(
"The filter_shape is expected to store as nhwc, and "
"h = w = 1 or 3. But recieved filter_shape is [%s].",
framework::make_ddim(filter_shape)));
PADDLE_ENFORCE_EQ((filter_shape[0] % 32 == 0 && filter_shape[3] % 8 == 0),
true,
platform::errors::InvalidArgument(
"The input channel is expected to be multiple of 8, "
"and the output channel is expected to be multiple "
"of 32. But recieved input channel is %d, output "
"channel is %d.",
filter_shape[3], filter_shape[0]));
PADDLE_ENFORCE_EQ(
output_shape.size(), 4U,
platform::errors::InvalidArgument(
"The size of output_shape is expected to 4. But recieved "
"filter_shape's size is %d, filter_shape is [%s].",
output_shape.size(), framework::make_ddim(output_shape)));
is_support = IsSupport(ctx, filter_shape, stride, dilation, group);
PADDLE_ENFORCE_EQ(
is_support, true,
platform::errors::InvalidArgument(
"Current test is only supported in the platforms with "
"compatiblity greater than or equal to 70 and the kernel size "
"must be equal to 1 or 3. When the kernel size is 1, "
"the stride must be 1 if the compatiblity is equal to 70. "
"Besides, the dilation and group must be equal to 1. But recieved "
"compatiblity is %d, kernel size is %d, stride is %d, "
"dilation is %d, group is %d",
ctx.GetComputeCapability(), filter_shape[1], stride, dilation,
group));
for (size_t i = 0; i < input_shape.size(); ++i) {
in_dims.push_back(input_shape[i]);
}
for (size_t i = 0; i < filter_shape.size(); ++i) {
filter_dims.push_back(filter_shape[i]);
}
paddings = {padding, padding};
strides = {stride, stride};
dilations = {dilation, dilation};
in_desc.set(input_shape, format, dtype);
filter_desc.set(filter_shape, format, dtype, group);
out_desc.set(output_shape, format, dtype);
int output_channel = filter_shape[0];
std::vector<int> stats_shape = {1, 1, 1, output_channel};
out_stats_desc.set(stats_shape, format, compute_type);
conv_desc.set(dtype, paddings, strides, dilations, false, group);
}
bool IsSupport(const platform::CUDADeviceContext &ctx,
const std::vector<int> &filter_shape, int stride, int dilation,
int group) {
int kernel_size = filter_shape[1];
if (dilation != 1 || group != 1) {
return false;
}
if (ctx.GetComputeCapability() == 70) {
if ((kernel_size == 3) || ((kernel_size == 1) && (stride == 1))) {
return true;
}
} else if (ctx.GetComputeCapability() > 70) {
if ((kernel_size == 3) || (kernel_size == 1)) {
return true;
}
}
return false;
}
cudnnDataType_t dtype;
cudnnTensorFormat_t format;
cudnnDataType_t compute_type;
std::vector<int64_t> in_dims;
std::vector<int64_t> filter_dims;
std::vector<int> strides;
std::vector<int> paddings;
std::vector<int> dilations;
platform::TensorDescriptor in_desc;
platform::FilterDescriptor filter_desc;
platform::TensorDescriptor out_desc;
platform::TensorDescriptor out_stats_desc;
platform::ConvolutionDescriptor conv_desc;
bool is_support;
};
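// CudnnNormConvolution wraps CUDNN_FUSED_SCALE_BIAS_ACTIVATION_CONV_BNSTATS:
// one fused kernel that performs the NHWC convolution and simultaneously
// accumulates the per-channel sum and sum-of-squares of the output, which are
// later consumed by CudnnBNStatsFinalize.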
template <typename T>
class CudnnNormConvolution {
public:
CudnnNormConvolution(const platform::CUDADeviceContext &ctx,
const std::vector<int> &input_shape,
const std::vector<int> &filter_shape,
const std::vector<int> &output_shape, const int &padding,
const int &stride, const int &dilation,
const int &group) {
args_.Set(ctx, input_shape, filter_shape, output_shape, padding, stride,
dilation, group);
}
~CudnnNormConvolution() {}
void Forward(const platform::CUDADeviceContext &ctx, const Tensor &input,
const Tensor &filter, Tensor *output, Tensor *sum,
Tensor *sum_of_squares) {
auto cudnn_handle = ctx.cudnn_handle();
auto place = ctx.GetPlace();
CudnnFusionOp *fwd_op = GetForwardOp(ctx);
size_t workspace_size = RoundUp(
static_cast<int64_t>(fwd_op->GetWorkspaceSizeInBytes(cudnn_handle)),
512);
// Set variant_param
// input ptr
T *input_ptr = const_cast<T *>(input.data<T>());
T *filter_ptr = const_cast<T *>(filter.data<T>());
fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, input_ptr);
fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_WDATA, filter_ptr);
fwd_op->SetOpVariantParamAttrPtr(
CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &workspace_size);
// output ptr
T *output_ptr = output->mutable_data<T>(place);
float *sum_ptr = sum->mutable_data<float>(place);
float *sum_of_squares_ptr = sum_of_squares->mutable_data<float>(place);
fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YDATA, output_ptr);
fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YSUM, sum_ptr);
fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YSQSUM, sum_of_squares_ptr);
ctx.cudnn_workspace_handle().RunFunc(
[&](void *workspace_ptr) {
// workspace ptr
fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE, workspace_ptr);
// fused op execute
fwd_op->Execute(cudnn_handle);
},
workspace_size);
}
private:
CudnnFusionOp *GetForwardOp(const platform::CUDADeviceContext &ctx) {
framework::AlgorithmsCache<CudnnFusionOp *> &cache =
*(CudnnFusionOpCache::Instance().GetForward());
CudnnFusionOp *fwd_op = cache.GetAlgorithm(
args_.in_dims, args_.filter_dims, args_.strides, args_.paddings,
args_.dilations, 0, static_cast<int64_t>(args_.dtype), [&]() {
CudnnFusionOp *fwd_op =
new CudnnFusionOp(CUDNN_FUSED_SCALE_BIAS_ACTIVATION_CONV_BNSTATS);
// Set constant_param
fwd_op->SetOpConstParamAttr(
{CUDNN_PARAM_XDATA_PLACEHOLDER, CUDNN_PARAM_WDATA_PLACEHOLDER,
CUDNN_PARAM_YDATA_PLACEHOLDER},
CUDNN_PTR_16B_ALIGNED);
fwd_op->SetOpConstParamAttr(
{CUDNN_PARAM_YSUM_PLACEHOLDER, CUDNN_PARAM_YSQSUM_PLACEHOLDER},
CUDNN_PTR_16B_ALIGNED);
// conv desc
fwd_op->SetOpConstParamDesc(CUDNN_PARAM_CONV_DESC,
args_.conv_desc.desc());
// input desc
fwd_op->SetOpConstParamDesc(CUDNN_PARAM_XDESC, args_.in_desc.desc());
// filter desc
fwd_op->SetOpConstParamDesc(CUDNN_PARAM_WDESC,
args_.filter_desc.desc());
// output desc
fwd_op->SetOpConstParamDesc(CUDNN_PARAM_YDESC, args_.out_desc.desc());
// output_stats desc
fwd_op->SetOpConstParamDesc(CUDNN_PARAM_YSTATS_DESC,
args_.out_stats_desc.desc());
// batch_norm mode
fwd_op->SetOpConstParamAttr(CUDNN_PARAM_BN_MODE,
CUDNN_BATCHNORM_SPATIAL_PERSISTENT);
// Make cudnn fused ops plan
fwd_op->GetWorkspaceSizeInBytes(ctx.cudnn_handle());
return fwd_op;
});
return fwd_op;
}
private:
NormConvolutionArgs<T> args_;
};
template <typename T>
class CudnnNormConvolutionGrad {
public:
CudnnNormConvolutionGrad(const platform::CUDADeviceContext &ctx,
const std::vector<int> &input_shape,
const std::vector<int> &filter_shape,
const std::vector<int> &output_shape,
const int &padding, const int &stride,
const int &dilation, const int &group) {
args_.Set(ctx, input_shape, filter_shape, output_shape, padding, stride,
dilation, group);
dgrad_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
}
~CudnnNormConvolutionGrad() {}
void Backward(const platform::CUDADeviceContext &ctx, const Tensor &input,
const Tensor &filter, const Tensor &output_grad,
Tensor *input_grad, Tensor *filter_grad,
bool use_addto = false) {
auto place = ctx.GetPlace();
T *input_ptr = const_cast<T *>(input.data<T>());
T *filter_ptr = const_cast<T *>(filter.data<T>());
T *output_grad_ptr = const_cast<T *>(output_grad.data<T>());
if (filter_grad) {
T *filter_grad_ptr = filter_grad->mutable_data<T>(place);
BackwardFilter(ctx, output_grad_ptr, input_ptr, filter_grad_ptr);
}
if (input_grad) {
T *input_grad_ptr = input_grad->mutable_data<T>(place);
BackwardData(ctx, output_grad_ptr, filter_ptr, input_grad_ptr, use_addto);
}
}
private:
void BackwardFilter(const platform::CUDADeviceContext &ctx,
T *output_grad_ptr, T *input_ptr, T *filter_grad_ptr) {
auto cudnn_handle = ctx.cudnn_handle();
CudnnFusionOp *wgrad_op = GetBackwardFilterOp(ctx);
size_t workspace_size = RoundUp(
static_cast<int64_t>(wgrad_op->GetWorkspaceSizeInBytes(cudnn_handle)),
512);
wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, input_ptr);
wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_DYDATA, output_grad_ptr);
wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_DWDATA, filter_grad_ptr);
wgrad_op->SetOpVariantParamAttrPtr(
CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &workspace_size);
ctx.cudnn_workspace_handle().RunFunc(
[&](void *workspace_ptr) {
// workspace ptr
wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE,
workspace_ptr);
// fused op execute
wgrad_op->Execute(cudnn_handle);
},
workspace_size);
}
void BackwardData(const platform::CUDADeviceContext &ctx, T *output_grad_ptr,
T *filter_ptr, T *input_grad_ptr, bool use_addto = false) {
auto cudnn_handle = ctx.cudnn_handle();
size_t workspace_size = GetWorkspaceSizeBwdData(ctx);
// Convolution dgrad; when use_addto is true, beta = 1 so the result is
// accumulated into the existing input_grad.
ScalingParamType<T> alpha = 1.0f;
ScalingParamType<T> beta = use_addto ? 1.0f : 0.0f;
ctx.cudnn_workspace_handle().RunFunc(
[&](void *cudnn_workspace_ptr) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnConvolutionBackwardData(
cudnn_handle, &alpha, args_.filter_desc.desc(), filter_ptr,
args_.out_desc.desc(), output_grad_ptr,
args_.conv_desc.desc(), dgrad_algo_, cudnn_workspace_ptr,
workspace_size, &beta, args_.in_desc.desc(), input_grad_ptr));
},
workspace_size);
}
CudnnFusionOp *GetBackwardFilterOp(const platform::CUDADeviceContext &ctx) {
framework::AlgorithmsCache<CudnnFusionOp *> &cache =
*(CudnnFusionOpCache::Instance().GetBackward());
CudnnFusionOp *wgrad_op = cache.GetAlgorithm(
args_.in_dims, args_.filter_dims, args_.strides, args_.paddings,
args_.dilations, 0, static_cast<int64_t>(args_.dtype), [&]() {
CudnnFusionOp *wgrad_op =
new CudnnFusionOp(CUDNN_FUSED_SCALE_BIAS_ACTIVATION_WGRAD);
wgrad_op->SetOpConstParamAttr(
{CUDNN_PARAM_DYDATA_PLACEHOLDER, CUDNN_PARAM_XDATA_PLACEHOLDER,
CUDNN_PARAM_DWDATA_PLACEHOLDER},
CUDNN_PTR_16B_ALIGNED);
// conv desc
wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_CONV_DESC,
args_.conv_desc.desc());
// input desc
wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_XDESC,
args_.in_desc.desc());
// filter desc
wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_DWDESC,
args_.filter_desc.desc());
// output desc
wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_DYDESC,
args_.out_desc.desc());
wgrad_op->SetOpConstParamAttr(CUDNN_PARAM_BN_MODE,
CUDNN_BATCHNORM_SPATIAL_PERSISTENT);
// Make cudnn fused ops plan
wgrad_op->GetWorkspaceSizeInBytes(ctx.cudnn_handle());
return wgrad_op;
});
return wgrad_op;
}
size_t GetWorkspaceSizeBwdData(const platform::CUDADeviceContext &ctx) {
size_t workspace_size = 0U;
auto handle = ctx.cudnn_handle();
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
handle, args_.filter_desc.desc(), args_.out_desc.desc(),
args_.conv_desc.desc(), args_.in_desc.desc(), dgrad_algo_,
&workspace_size));
return RoundUp(workspace_size, 512);
}
private:
NormConvolutionArgs<T> args_;
cudnnConvolutionBwdDataAlgo_t dgrad_algo_;
};
#endif
} // namespace operators
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <random>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/fused/cudnn_norm_conv.cu.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/float16.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace op = paddle::operators;
using Tensor = paddle::framework::Tensor;
USE_OP(conv2d);
USE_OP(conv2d_grad);
USE_OP_DEVICE_KERNEL(conv2d, CUDNN);
USE_OP_DEVICE_KERNEL(conv2d_grad, CUDNN);
template <typename T>
void InitRandomTensor(const std::vector<int64_t> &dims,
framework::Tensor *cpu_out) {
T *cpu_out_ptr = cpu_out->mutable_data<T>(framework::make_ddim(dims),
platform::CPUPlace());
std::default_random_engine random(0);
std::uniform_real_distribution<float> dis(0.0, 1.0);
for (int i = 0; i < cpu_out->numel(); ++i) {
cpu_out_ptr[i] = static_cast<T>(dis(random));
}
}
template <typename T>
void TransposeNchwToNhwc(const framework::Tensor &cpu_in,
framework::Tensor *cpu_out) {
auto in_dims = cpu_in.dims();
EXPECT_EQ(cpu_in.dims().size(), 4);
const T *cpu_in_ptr = cpu_in.data<T>();
T *cpu_out_ptr = cpu_out->mutable_data<T>(
{in_dims[0], in_dims[2], in_dims[3], in_dims[1]}, platform::CPUPlace());
int64_t n = in_dims[0];
int64_t c = in_dims[1];
int64_t hw = in_dims[2] * in_dims[3];
for (int i = 0; i < n; ++i) {
for (int j = 0; j < hw; ++j) {
for (int k = 0; k < c; ++k) {
int dst_idx = i * hw * c + j * c + k;
int src_idx = i * c * hw + k * hw + j;
cpu_out_ptr[dst_idx] = cpu_in_ptr[src_idx];
}
}
}
}
template <typename T>
void CheckOutput(const framework::Tensor &cpu_res,
const framework::Tensor &cpu_base, float diff,
bool is_relative_atol = false) {
EXPECT_EQ(cpu_res.dims(), cpu_base.dims());
const T *cpu_res_ptr = cpu_res.data<T>();
const T *cpu_base_ptr = cpu_base.data<T>();
for (int i = 0; i < cpu_res.numel(); ++i) {
if (is_relative_atol) {
EXPECT_LT(static_cast<float>(std::abs((cpu_res_ptr[i] - cpu_base_ptr[i]) /
cpu_base_ptr[i])),
diff);
} else {
EXPECT_LT(static_cast<float>(std::abs(cpu_res_ptr[i] - cpu_base_ptr[i])),
diff);
}
}
}
// Use Paddle conv2d op results as baseline
void ComputeConv2DForward(const platform::CUDADeviceContext &ctx,
const Tensor &cpu_input, const Tensor &cpu_filter,
Tensor *cpu_output, int stride, int padding) {
framework::Scope scope;
auto *input = scope.Var("Input")->GetMutable<framework::LoDTensor>();
auto *filter = scope.Var("Filter")->GetMutable<framework::LoDTensor>();
auto *output = scope.Var("Output")->GetMutable<framework::LoDTensor>();
auto place = ctx.GetPlace();
TensorCopySync(cpu_input, place, input);
TensorCopySync(cpu_filter, place, filter);
framework::AttributeMap attrs;
bool use_cudnn = true;
std::string data_format = "NHWC";
std::vector<int> strides = {stride, stride};
std::vector<int> paddings = {padding, padding};
attrs.insert({"strides", strides});
attrs.insert({"paddings", paddings});
attrs.insert({"use_cudnn", use_cudnn});
attrs.insert({"data_format", data_format});
auto op = framework::OpRegistry::CreateOp(
"conv2d", {{"Input", {"Input"}}, {"Filter", {"Filter"}}},
{{"Output", {"Output"}}}, attrs);
op->Run(scope, ctx.GetPlace());
TensorCopySync(*output, platform::CPUPlace(), cpu_output);
}
// Use Paddle conv2d_grad op results as baseline
void ComputeConv2DBackward(const platform::CUDADeviceContext &ctx,
const Tensor &cpu_input, const Tensor &cpu_filter,
const Tensor &cpu_output_grad,
framework::Tensor *cpu_input_grad,
framework::Tensor *cpu_filter_grad, int stride,
int padding, int dilation) {
framework::Scope scope;
auto *input = scope.Var("Input")->GetMutable<framework::LoDTensor>();
auto *filter = scope.Var("Filter")->GetMutable<framework::LoDTensor>();
auto *output_grad =
scope.Var("Output@GRAD")->GetMutable<framework::LoDTensor>();
auto *input_grad =
scope.Var("Input@GRAD")->GetMutable<framework::LoDTensor>();
auto *filter_grad =
scope.Var("Filter@GRAD")->GetMutable<framework::LoDTensor>();
auto place = ctx.GetPlace();
TensorCopySync(cpu_input, place, input);
TensorCopySync(cpu_filter, place, filter);
TensorCopySync(cpu_output_grad, place, output_grad);
framework::AttributeMap attrs;
bool use_cudnn = true;
std::string data_format = "NHWC";
std::string padding_algorithm = "EXPLICIT";
std::vector<int> strides = {stride, stride};
std::vector<int> paddings = {padding, padding};
std::vector<int> dilations = {dilation, dilation};
int groups = 1;
bool exhaustive_search = false;
bool use_addto = false;
attrs.insert({"use_cudnn", use_cudnn});
attrs.insert({"data_format", data_format});
attrs.insert({"padding_algorithm", padding_algorithm});
attrs.insert({"strides", strides});
attrs.insert({"paddings", paddings});
attrs.insert({"dilations", dilations});
attrs.insert({"groups", groups});
attrs.insert({"exhaustive_search", exhaustive_search});
attrs.insert({"use_addto", use_addto});
auto op = framework::OpRegistry::CreateOp(
"conv2d_grad", {{"Input", {"Input"}},
{"Filter", {"Filter"}},
{"Output@GRAD", {"Output@GRAD"}}},
{{"Input@GRAD", {"Input@GRAD"}}, {"Filter@GRAD", {"Filter@GRAD"}}},
attrs);
op->Run(scope, ctx.GetPlace());
TensorCopySync(*input_grad, platform::CPUPlace(), cpu_input_grad);
TensorCopySync(*filter_grad, platform::CPUPlace(), cpu_filter_grad);
}
template <typename T>
void ComputeSumAndSquareSum(const framework::Tensor &cpu_out,
framework::Tensor *cpu_sum,
framework::Tensor *cpu_sum_of_square) {
auto dims = cpu_out.dims();
int64_t c = dims[3];
const T *cpu_out_ptr = cpu_out.data<T>();
float *cpu_sum_ptr =
cpu_sum->mutable_data<float>({1, 1, 1, c}, platform::CPUPlace());
float *cpu_sum_square_ptr = cpu_sum_of_square->mutable_data<float>(
{1, 1, 1, c}, platform::CPUPlace());
for (int j = 0; j < c; ++j) {
float tmp_sum = 0.0f;
float tmp_sum_of_squares = 0.0f;
for (int i = 0; i < cpu_out.numel() / c; ++i) {
float tmp_out = static_cast<float>(cpu_out_ptr[i * c + j]);
tmp_sum += tmp_out;
tmp_sum_of_squares += tmp_out * tmp_out;
}
cpu_sum_ptr[j] = tmp_sum;
cpu_sum_square_ptr[j] = tmp_sum_of_squares;
}
}
template <typename T>
class CudnnNormConvolutionTester {
public:
CudnnNormConvolutionTester(int batch_size, int height, int width,
int input_channels, int output_channels,
int kernel_size, int stride) {
batch_size_ = batch_size;
height_ = height;
width_ = width;
input_channels_ = input_channels;
output_channels_ = output_channels;
kernel_size_ = kernel_size;
stride_ = stride;
padding_ = (kernel_size_ - 1) / 2;
out_height_ = (height_ + 2 * padding_ - kernel_size_) / stride_ + 1;
out_width_ = (width_ + 2 * padding_ - kernel_size_) / stride_ + 1;
SetUp();
}
~CudnnNormConvolutionTester() {}
void CheckForward(float diff, bool is_relative_atol = false) {
platform::CUDADeviceContext *ctx =
static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(
platform::CUDAPlace(0)));
framework::Tensor cpu_output_base;
framework::Tensor cpu_sum_base;
framework::Tensor cpu_sum_of_square_base;
BaselineForward(*ctx, &cpu_output_base, &cpu_sum_base,
&cpu_sum_of_square_base);
framework::Tensor cpu_output;
framework::Tensor cpu_sum;
framework::Tensor cpu_sum_of_square;
FusedForward(*ctx, &cpu_output, &cpu_sum, &cpu_sum_of_square);
// Check forward correctness between baseline and results of normconv.
CheckOutput<T>(cpu_output, cpu_output_base, diff, is_relative_atol);
CheckOutput<float>(cpu_sum, cpu_sum_base, diff, is_relative_atol);
CheckOutput<float>(cpu_sum_of_square, cpu_sum_of_square_base, diff,
is_relative_atol);
}
void CheckBackward(float diff, bool is_relative_atol = false) {
platform::CUDADeviceContext *ctx =
static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(
platform::CUDAPlace(0)));
framework::Tensor cpu_input_grad_base;
framework::Tensor cpu_filter_nchw_grad_base;
framework::Tensor cpu_filter_nhwc_grad_base;
BaselineBackward(*ctx, &cpu_input_grad_base, &cpu_filter_nchw_grad_base);
TransposeNchwToNhwc<T>(cpu_filter_nchw_grad_base,
&cpu_filter_nhwc_grad_base);
framework::Tensor cpu_input_grad;
framework::Tensor cpu_filter_nhwc_grad;
FusedBackward(*ctx, &cpu_input_grad, &cpu_filter_nhwc_grad);
// Check backward correctness between baseline and results of normconv.
CheckOutput<T>(cpu_input_grad, cpu_input_grad_base, diff, is_relative_atol);
CheckOutput<T>(cpu_filter_nhwc_grad, cpu_filter_nhwc_grad_base, diff,
is_relative_atol);
}
private:
void SetUp() {
InitRandomTensor<T>({batch_size_, height_, width_, input_channels_},
&cpu_input_);
InitRandomTensor<T>(
{output_channels_, input_channels_, kernel_size_, kernel_size_},
&cpu_filter_nchw_);
// transpose the filter, NCHW -> NHWC
TransposeNchwToNhwc<T>(cpu_filter_nchw_, &cpu_filter_nhwc_);
InitRandomTensor<T>(
{batch_size_, out_height_, out_width_, output_channels_},
&cpu_output_grad_);
}
void BaselineForward(const platform::CUDADeviceContext &ctx,
framework::Tensor *cpu_output_base,
framework::Tensor *cpu_sum_base,
framework::Tensor *cpu_sum_of_square_base) {
ComputeConv2DForward(ctx, cpu_input_, cpu_filter_nchw_, cpu_output_base,
stride_, padding_);
ComputeSumAndSquareSum<T>(*cpu_output_base, cpu_sum_base,
cpu_sum_of_square_base);
}
void BaselineBackward(const platform::CUDADeviceContext &ctx,
framework::Tensor *cpu_input_grad_base,
framework::Tensor *cpu_filter_grad_base) {
ComputeConv2DBackward(ctx, cpu_input_, cpu_filter_nchw_, cpu_output_grad_,
cpu_input_grad_base, cpu_filter_grad_base, stride_,
padding_, dilation_);
}
// get forward results of cudnn_norm_conv
void FusedForward(const platform::CUDADeviceContext &ctx,
framework::Tensor *cpu_output, framework::Tensor *cpu_sum,
framework::Tensor *cpu_sum_of_square) {
framework::Tensor input;
framework::Tensor filter_nhwc;
framework::Tensor output;
framework::Tensor sum;
framework::Tensor sum_of_square;
auto place = ctx.GetPlace();
TensorCopySync(cpu_input_, place, &input);
TensorCopySync(cpu_filter_nhwc_, place, &filter_nhwc);
output.Resize(framework::make_ddim(
{batch_size_, out_height_, out_width_, output_channels_}));
sum.Resize(framework::make_ddim({1, 1, 1, output_channels_}));
sum_of_square.Resize(framework::make_ddim({1, 1, 1, output_channels_}));
auto input_shape = framework::vectorize<int>(input.dims());
auto filter_shape = framework::vectorize<int>(filter_nhwc.dims());
auto output_shape = framework::vectorize<int>(output.dims());
op::CudnnNormConvolution<T> conv_op(ctx, input_shape, filter_shape,
output_shape, padding_, stride_,
dilation_, group_);
conv_op.Forward(ctx, input, filter_nhwc, &output, &sum, &sum_of_square);
TensorCopySync(output, platform::CPUPlace(), cpu_output);
TensorCopySync(sum, platform::CPUPlace(), cpu_sum);
TensorCopySync(sum_of_square, platform::CPUPlace(), cpu_sum_of_square);
}
void FusedBackward(const platform::CUDADeviceContext &ctx,
framework::Tensor *cpu_input_grad,
framework::Tensor *cpu_filter_grad) {
framework::Tensor input;
framework::Tensor filter_nhwc;
framework::Tensor output_grad;
framework::Tensor input_grad;
framework::Tensor filter_grad;
auto place = ctx.GetPlace();
TensorCopySync(cpu_input_, place, &input);
TensorCopySync(cpu_filter_nhwc_, place, &filter_nhwc);
TensorCopySync(cpu_output_grad_, place, &output_grad);
input_grad.Resize(input.dims());
filter_grad.Resize(filter_nhwc.dims());
auto input_shape = framework::vectorize<int>(input.dims());
auto filter_shape = framework::vectorize<int>(filter_nhwc.dims());
auto output_shape = framework::vectorize<int>(output_grad.dims());
op::CudnnNormConvolutionGrad<T> conv_grad_op(ctx, input_shape, filter_shape,
output_shape, padding_,
stride_, dilation_, group_);
conv_grad_op.Backward(ctx, input, filter_nhwc, output_grad, &input_grad,
&filter_grad);
TensorCopySync(input_grad, platform::CPUPlace(), cpu_input_grad);
TensorCopySync(filter_grad, platform::CPUPlace(), cpu_filter_grad);
}
private:
int batch_size_;
int height_;
int width_;
int out_height_;
int out_width_;
int input_channels_;
int output_channels_;
int kernel_size_;
int stride_;
int padding_;
const int dilation_ = 1;
const int group_ = 1;
// Forward input
framework::Tensor cpu_input_;
framework::Tensor cpu_filter_nchw_;
framework::Tensor cpu_filter_nhwc_;
// Backward input
framework::Tensor cpu_output_grad_;
};
// test for fp16, kernel = 1, output_channels = input_channels
TEST(CudnnNormConvFp16, K1S1) {
int batch_size = 4;
int height = 56;
int width = 56;
int input_channels = 32;
int output_channels = 32;
int kernel_size = 1;
int stride = 1;
CudnnNormConvolutionTester<paddle::platform::float16> test(
batch_size, height, width, input_channels, output_channels, kernel_size,
stride);
test.CheckForward(1e-3, true);
test.CheckBackward(1e-3, true);
}
// test for fp16, kernel = 3, output_channels = input_channels
TEST(CudnnNormConvFp16, K3S1) {
int batch_size = 4;
int height = 56;
int width = 56;
int input_channels = 32;
int output_channels = 32;
int kernel_size = 3;
int stride = 1;
CudnnNormConvolutionTester<paddle::platform::float16> test(
batch_size, height, width, input_channels, output_channels, kernel_size,
stride);
test.CheckForward(1e-3, true);
test.CheckBackward(1e-3, true);
}
// test for fp16, kernel = 1, output_channels = input_channels * 4
TEST(CudnnNormConvFp16, K1S1O4) {
int batch_size = 4;
int height = 56;
int width = 56;
int input_channels = 32;
int output_channels = 128;
int kernel_size = 1;
int stride = 1;
CudnnNormConvolutionTester<paddle::platform::float16> test(
batch_size, height, width, input_channels, output_channels, kernel_size,
stride);
test.CheckForward(1e-3, true);
test.CheckBackward(1e-3, true);
}
// test for fp16, kernel = 1, stride = 2, output_channels = input_channels * 4
TEST(CudnnNormConvFp16, K1S2O4) {
int batch_size = 4;
int height = 8;
int width = 8;
int input_channels = 32;
int output_channels = 128;
int kernel_size = 1;
int stride = 2;
CudnnNormConvolutionTester<paddle::platform::float16> test(
batch_size, height, width, input_channels, output_channels, kernel_size,
stride);
platform::CUDADeviceContext *ctx = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));
if (ctx->GetComputeCapability() <= 70) {
ASSERT_THROW(test.CheckForward(1e-3, true),
paddle::platform::EnforceNotMet);
ASSERT_THROW(test.CheckBackward(1e-3), paddle::platform::EnforceNotMet);
} else {
ASSERT_NO_THROW(test.CheckForward(1e-3, true));
ASSERT_NO_THROW(test.CheckBackward(1e-3));
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/fused/cudnn_fusion_helper.h"
#include "paddle/fluid/platform/cudnn_desc.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
namespace dynload = platform::dynload;
template <typename T>
using BatchNormParamType =
typename platform::CudnnDataType<T>::BatchNormParamType;
#if CUDNN_VERSION >= 8000
template <typename T>
struct ScaleBiasAddReluArgs {
ScaleBiasAddReluArgs() {
dtype = platform::CudnnDataType<T>::type;
param_dtype = platform::CudnnDataType<BatchNormParamType<T>>::type;
format = CUDNN_TENSOR_NHWC;
}
void Set(const std::string &act_type, const std::vector<int> &data_shape,
const std::vector<int> &param_shape,
const std::vector<int> &bitmask_shape) {
PADDLE_ENFORCE_EQ(
data_shape.size(), 4U,
platform::errors::InvalidArgument(
"The size of data_shape is expected to 4. But recieved "
"data_shape's size is %d, data_shape is [%s].",
data_shape.size(), framework::make_ddim(data_shape)));
PADDLE_ENFORCE_EQ(
param_shape.size(), 4U,
platform::errors::InvalidArgument(
"The size of param_shape is expected to 4. But recieved "
"param_shape's size is %d, param_shape is [%s].",
param_shape.size(), framework::make_ddim(param_shape)));
PADDLE_ENFORCE_EQ(
bitmask_shape.size(), 3U,
platform::errors::InvalidArgument(
"The size of bitmask_shape is expected to 3. But recieved "
"bitmask_shape's size is %d, bitmask_shape is [%s].",
bitmask_shape.size(), framework::make_ddim(bitmask_shape)));
in_desc.set(data_shape, format, dtype);
out_desc.set(data_shape, format, dtype);
equiv_scale_bias_desc.set(param_shape, format, dtype);
scale_bias_mean_var_desc.set(param_shape, format, param_dtype);
bitmask_desc.set(bitmask_shape, format, CUDNN_DATA_INT32);
// set activation desc
cudnnActivationMode_t mode = CUDNN_ACTIVATION_IDENTITY;
if (act_type != "") {
PADDLE_ENFORCE_EQ(
act_type, "relu",
platform::errors::InvalidArgument(
"Only relu activation supported in normalized convolution."));
mode = CUDNN_ACTIVATION_RELU;
}
double dummy_clip = 0.0;
activation_desc.set(mode, dummy_clip);
}
cudnnDataType_t dtype;
cudnnDataType_t param_dtype;
cudnnTensorFormat_t format;
platform::TensorDescriptor in_desc;
platform::TensorDescriptor out_desc;
platform::TensorDescriptor equiv_scale_bias_desc;
platform::TensorDescriptor scale_bias_mean_var_desc;
platform::TensorDescriptor bitmask_desc;
platform::ActivationDescriptor activation_desc;
};
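// CudnnScaleBiasAddRelu wraps two fused ops:
//  - forward:  CUDNN_FUSED_SCALE_BIAS_ADD_ACTIVATION_GEN_BITMASK computes
//    (roughly) out = act(x * equiv_scale_x + equiv_bias_x [+ z, or
//    z * equiv_scale_z + equiv_bias_z when has_shortcut]) and records a relu
//    bitmask for the backward pass;
//  - backward: CUDNN_FUSED_DACTIVATION_FORK_DBATCHNORM consumes that bitmask
//    and produces dx, dz and the batch-norm dscale/dbias.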
template <typename T>
class CudnnScaleBiasAddRelu {
public:
CudnnScaleBiasAddRelu(const platform::CUDADeviceContext &ctx,
const std::string &act_type, bool fuse_add,
bool has_shortcut, const std::vector<int> &data_shape,
const std::vector<int> &param_shape,
const std::vector<int> &bitmask_shape)
: fwd_op_(CUDNN_FUSED_SCALE_BIAS_ADD_ACTIVATION_GEN_BITMASK),
bwd_op_(CUDNN_FUSED_DACTIVATION_FORK_DBATCHNORM) {
fuse_add_ = fuse_add;
has_shortcut_ = has_shortcut;
args_.Set(act_type, data_shape, param_shape, bitmask_shape);
}
~CudnnScaleBiasAddRelu() {}
void Forward(const platform::CUDADeviceContext &ctx, const Tensor &x,
const Tensor &x_scale, const Tensor &x_bias, const Tensor *z,
const Tensor *z_scale, const Tensor *z_bias, Tensor *out,
Tensor *bitmask) {
ForwardInit(ctx);
auto handle = ctx.cudnn_handle();
auto place = ctx.GetPlace();
auto workspace_handle = ctx.cudnn_workspace_handle();
fwd_workspace_byte_ = fwd_op_.GetWorkspaceSizeInBytes(handle);
// Set variant_param
// input ptr
T *x_ptr = const_cast<T *>(x.data<T>());
T *x_scale_ptr = const_cast<T *>(x_scale.data<T>());
T *x_bias_ptr = const_cast<T *>(x_bias.data<T>());
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, x_ptr);
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQSCALE, x_scale_ptr);
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQBIAS, x_bias_ptr);
if (has_shortcut_) {
T *z_ptr = const_cast<T *>(z->data<T>());
T *z_scale_ptr = const_cast<T *>(z_scale->data<T>());
T *z_bias_ptr = const_cast<T *>(z_bias->data<T>());
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ZDATA, z_ptr);
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_Z_EQSCALE, z_scale_ptr);
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_Z_EQBIAS, z_bias_ptr);
} else {
if (fuse_add_) {
T *z_ptr = const_cast<T *>(z->data<T>());
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ZDATA, z_ptr);
}
}
fwd_op_.SetOpVariantParamAttrPtr(
CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &fwd_workspace_byte_);
// output ptr
T *out_ptr = out->mutable_data<T>(place);
int32_t *bitmask_ptr = bitmask->mutable_data<int32_t>(place);
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_YDATA, out_ptr);
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ACTIVATION_BITMASK, bitmask_ptr);
workspace_handle.RunFunc(
[&](void *workspace_ptr) {
// workspace ptr
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE, workspace_ptr);
// fused op execute
fwd_op_.Execute(handle);
},
fwd_workspace_byte_);
}
void Backward(const platform::CUDADeviceContext &ctx, const Tensor &dy,
const Tensor &x, const Tensor &scale, const Tensor &bias,
const Tensor &saved_mean, const Tensor &saved_invstd,
const Tensor *bitmask, Tensor *dx, Tensor *dz, Tensor *dscale,
Tensor *dbias, double eps) {
BackwardInit(ctx);
auto handle = ctx.cudnn_handle();
auto place = ctx.GetPlace();
auto workspace_handle = ctx.cudnn_workspace_handle();
bwd_workspace_byte_ = bwd_op_.GetWorkspaceSizeInBytes(handle);
// Set variant_param
// input ptr
T *dy_ptr = const_cast<T *>(dy.data<T>());
T *x_ptr = const_cast<T *>(x.data<T>());
float *scale_ptr = const_cast<float *>(scale.data<float>());
float *bias_ptr = const_cast<float *>(bias.data<float>());
float *saved_mean_ptr = const_cast<float *>(saved_mean.data<float>());
float *saved_invstd_ptr = const_cast<float *>(saved_invstd.data<float>());
int32_t *bitmask_ptr =
bitmask ? const_cast<int32_t *>(bitmask->data<int32_t>()) : nullptr;
T *dx_ptr = dx->mutable_data<T>(place);
T *dz_ptr = dz ? dz->mutable_data<T>(place) : nullptr;
float *dscale_ptr = dscale ? dscale->mutable_data<float>(place) : nullptr;
float *dbias_ptr = dbias ? dbias->mutable_data<float>(place) : nullptr;
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, x_ptr);
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_DYDATA, dy_ptr);
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SCALE, scale_ptr);
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_BIAS, bias_ptr);
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SAVED_MEAN, saved_mean_ptr);
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SAVED_INVSTD,
saved_invstd_ptr);
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ACTIVATION_BITMASK, bitmask_ptr);
bwd_op_.SetOpVariantParamAttrPtr(
CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &bwd_workspace_byte_);
// output ptr
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_DXDATA, dx_ptr);
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_DSCALE, dscale_ptr);
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_DBIAS, dbias_ptr);
bwd_op_.SetOpVariantParamAttrPtr<double>(CUDNN_SCALAR_DOUBLE_BN_EPSILON,
&eps);
if (has_shortcut_ || fuse_add_) {
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_DZDATA, dz_ptr);
}
workspace_handle.RunFunc(
[&](void *workspace_ptr) {
// workspace ptr
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE, workspace_ptr);
// fused op execute
bwd_op_.Execute(handle);
},
bwd_workspace_byte_);
}
private:
void ForwardInit(const platform::CUDADeviceContext &ctx) {
// Set constant_param
fwd_op_.SetOpConstParamAttr(
{CUDNN_PARAM_XDATA_PLACEHOLDER, CUDNN_PARAM_BN_EQSCALE_PLACEHOLDER,
CUDNN_PARAM_BN_EQBIAS_PLACEHOLDER, CUDNN_PARAM_YDATA_PLACEHOLDER,
CUDNN_PARAM_ACTIVATION_BITMASK_PLACEHOLDER},
CUDNN_PTR_16B_ALIGNED);
if (has_shortcut_) {
fwd_op_.SetOpConstParamAttr(
{CUDNN_PARAM_ZDATA_PLACEHOLDER, CUDNN_PARAM_BN_Z_EQSCALE_PLACEHOLDER,
CUDNN_PARAM_BN_Z_EQBIAS_PLACEHOLDER},
CUDNN_PTR_16B_ALIGNED);
} else if (fuse_add_) {
fwd_op_.SetOpConstParamAttr(CUDNN_PARAM_ZDATA_PLACEHOLDER,
CUDNN_PTR_16B_ALIGNED);
}
// input desc
fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_XDESC, args_.in_desc.desc());
if (has_shortcut_ || fuse_add_) {
fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_ZDESC, args_.in_desc.desc());
}
// equiv scale/bias desc
fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_EQSCALEBIAS_DESC,
args_.equiv_scale_bias_desc.desc());
if (has_shortcut_) {
fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_Z_EQSCALEBIAS_DESC,
args_.equiv_scale_bias_desc.desc());
}
// output desc
fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_YDESC, args_.out_desc.desc());
// bitmask desc
fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_ACTIVATION_BITMASK_DESC,
args_.bitmask_desc.desc());
// activation desc
fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_ACTIVATION_DESC,
args_.activation_desc.desc());
// others
fwd_op_.SetOpConstParamAttr(CUDNN_PARAM_BN_MODE,
CUDNN_BATCHNORM_SPATIAL_PERSISTENT);
}
void BackwardInit(const platform::CUDADeviceContext &ctx) {
// Set constant_param
bwd_op_.SetOpConstParamAttr(
{CUDNN_PARAM_XDATA_PLACEHOLDER, CUDNN_PARAM_DYDATA_PLACEHOLDER,
CUDNN_PARAM_DXDATA_PLACEHOLDER, CUDNN_PARAM_BN_SCALE_PLACEHOLDER,
CUDNN_PARAM_BN_BIAS_PLACEHOLDER, CUDNN_PARAM_BN_SAVED_MEAN_PLACEHOLDER,
CUDNN_PARAM_BN_SAVED_INVSTD_PLACEHOLDER,
CUDNN_PARAM_BN_DSCALE_PLACEHOLDER, CUDNN_PARAM_BN_DBIAS_PLACEHOLDER,
CUDNN_PARAM_ACTIVATION_BITMASK_PLACEHOLDER},
CUDNN_PTR_16B_ALIGNED);
if (has_shortcut_ || fuse_add_) {
bwd_op_.SetOpConstParamAttr(CUDNN_PARAM_DZDATA_PLACEHOLDER,
CUDNN_PTR_16B_ALIGNED);
}
// input desc
bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_XDESC, args_.in_desc.desc());
bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_DXDESC, args_.in_desc.desc());
if (has_shortcut_ || fuse_add_) {
bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_DZDESC, args_.in_desc.desc());
}
// scale/bias/mean/var desc for backward
bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC,
args_.scale_bias_mean_var_desc.desc());
// output desc
bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_DYDESC, args_.out_desc.desc());
// bitmask desc
bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_ACTIVATION_BITMASK_DESC,
args_.bitmask_desc.desc());
// activation desc
bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_ACTIVATION_DESC,
args_.activation_desc.desc());
// others
bwd_op_.SetOpConstParamAttr(CUDNN_PARAM_BN_MODE,
CUDNN_BATCHNORM_SPATIAL_PERSISTENT);
}
bool fuse_add_ = false;
bool has_shortcut_ = false;
size_t fwd_workspace_byte_;
size_t bwd_workspace_byte_;
ScaleBiasAddReluArgs<T> args_;
CudnnFusionOp fwd_op_;
CudnnFusionOp bwd_op_;
};
#endif
} // namespace operators
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// Shape of bitmask
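// The bitmask records one bit per output element of the fused relu. Channels
// are padded up to a multiple of 64 bits and packed into int32 words (hence
// the division by 32), and the N*H*W extent is padded up to a multiple of 32,
// matching the layout expected by the cuDNN fused kernels.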
static framework::DDim GetBitmaskDims(std::vector<int> out_shape) {
int c = out_shape.back();
int64_t nhw = std::accumulate(out_shape.begin(), out_shape.end(), 1,
std::multiplies<int>()) /
c;
int32_t c_int32_elems = ((c + 63) & ~63) / 32;
int32_t nhw_int32_elems = ((nhw + 31) & ~31);
std::vector<int> bitmask_shape = {nhw_int32_elems, c_int32_elems, 1};
return framework::make_ddim(bitmask_shape);
}
class ResNetUnitOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const {
// Check input
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasInput("FilterX"), "Input", "FilterX",
"ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasInput("ScaleX"), "Input", "ScaleX", "ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasInput("BiasX"), "Input", "BiasX", "ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasInput("MeanX"), "Input", "MeanX", "ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasInput("VarX"), "Input", "VarX", "ResNetUnitOp");
bool fuse_add = ctx->Attrs().Get<bool>("fuse_add");
bool has_shortcut = ctx->Attrs().Get<bool>("has_shortcut");
if (fuse_add || has_shortcut) {
OP_INOUT_CHECK(ctx->HasInput("Z"), "Input", "Z", "ResNetUnitOp");
}
if (has_shortcut) {
OP_INOUT_CHECK(ctx->HasInput("FilterZ"), "Input", "FilterZ",
"ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasInput("ScaleZ"), "Input", "ScaleZ",
"ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasInput("BiasZ"), "Input", "BiasZ", "ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasInput("MeanZ"), "Input", "MeanZ", "ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasInput("VarZ"), "Input", "VarZ", "ResNetUnitOp");
}
// Check output
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasOutput("BitMask"), "Output", "BitMask",
"ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasOutput("ConvX"), "Output", "ConvX", "ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasOutput("SavedMeanX"), "Output", "SavedMeanX",
"ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasOutput("SavedInvstdX"), "Output", "SavedInvstdX",
"ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasOutput("RunningMeanX"), "Output", "RunningMeanX",
"ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasOutput("RunningVarX"), "Output", "RunningVarX",
"ResNetUnitOp");
if (has_shortcut) {
OP_INOUT_CHECK(ctx->HasOutput("ConvZ"), "Output", "ConvZ",
"ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasOutput("SavedMeanZ"), "Output", "SavedMeanZ",
"ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasOutput("SavedInvstdZ"), "Output", "SavedInvstdZ",
"ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasOutput("RunningMeanZ"), "Output", "RunningMeanZ",
"ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasOutput("RunningVarZ"), "Output", "RunningVarZ",
"ResNetUnitOp");
}
// make sure Mean/RunningMean and Var/RunningVar share memory
PADDLE_ENFORCE_EQ(
ctx->Inputs("MeanX")[0], ctx->Outputs("RunningMeanX")[0],
platform::errors::InvalidArgument(
"MeanX and RunningMeanX should share the same memory"));
PADDLE_ENFORCE_EQ(ctx->Inputs("VarX")[0], ctx->Outputs("RunningVarX")[0],
platform::errors::InvalidArgument(
"VarX and RunningVarX should share the same memory"));
if (has_shortcut) {
PADDLE_ENFORCE_EQ(
ctx->Inputs("MeanZ")[0], ctx->Outputs("RunningMeanZ")[0],
platform::errors::InvalidArgument(
"MeanZ and RunningMeanZ should share the same memory"));
PADDLE_ENFORCE_EQ(
ctx->Inputs("VarZ")[0], ctx->Outputs("RunningVarZ")[0],
platform::errors::InvalidArgument(
"VarZ and RunningVarZ should share the same memory"));
}
// Check dims of inputs
const auto x_dims = ctx->GetInputDim("X");
const auto w_dims = ctx->GetInputDim("FilterX");
const auto bn_param_dims = ctx->GetInputDim("ScaleX");
PADDLE_ENFORCE_EQ(x_dims.size(), 4, platform::errors::InvalidArgument(
"The dimensions of input "
"must equal to 4."
"But received: the shape of input "
"= [%s], the dimension of input = "
"[%d]",
x_dims, x_dims.size()));
PADDLE_ENFORCE_EQ(w_dims.size(), 4,
platform::errors::InvalidArgument(
"The dimensions of filter "
"must equal to 4."
"But received: the shape of filter "
"= [%s], the dimension of filter = [%d] ",
w_dims, w_dims.size()));
PADDLE_ENFORCE_EQ(bn_param_dims.size(), 4,
platform::errors::InvalidArgument(
"The dimensions of bn param "
"must equal to 4."
"But received: the shape of bn param "
"= [%s], the dimension of bn param = [%d] ",
bn_param_dims, bn_param_dims.size()));
auto data_format = ctx->Attrs().Get<std::string>("data_format");
PADDLE_ENFORCE_EQ(
data_format, "NHWC",
platform::errors::InvalidArgument("The data format must be NHWC. "
"But received: the data format "
"= [%s]",
data_format));
// Calculate the dims of outputs
int batch = x_dims[0];
int output_channel = w_dims[0];
int filter_size = w_dims[2];
int stride = ctx->Attrs().Get<int>("stride");
int padding = ctx->Attrs().Get<int>("padding");
int out_h = (x_dims[1] + padding * 2 - filter_size) / stride + 1;
int out_w = (x_dims[2] + padding * 2 - filter_size) / stride + 1;
std::vector<int> out_shape = {batch, out_h, out_w, output_channel};
auto y_dims = framework::make_ddim(out_shape);
auto bitmask_dims = GetBitmaskDims(out_shape);
// Set dims of outputs
ctx->SetOutputDim("Y", y_dims);
ctx->SetOutputDim("BitMask", bitmask_dims);
ctx->SetOutputDim("ConvX", y_dims);
ctx->SetOutputDim("SavedMeanX", bn_param_dims);
ctx->SetOutputDim("SavedInvstdX", bn_param_dims);
ctx->SetOutputDim("RunningMeanX", bn_param_dims);
ctx->SetOutputDim("RunningVarX", bn_param_dims);
if (has_shortcut) {
ctx->SetOutputDim("ConvZ", y_dims);
ctx->SetOutputDim("SavedMeanZ", bn_param_dims);
ctx->SetOutputDim("SavedInvstdZ", bn_param_dims);
ctx->SetOutputDim("RunningMeanZ", bn_param_dims);
ctx->SetOutputDim("RunningVarZ", bn_param_dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
// By default, the type of the scale, bias, mean,
// and var tensors should be float when input tensor's dtype is float16.
auto bn_param_type = framework::proto::VarType::FP32;
PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("ScaleX")->type(),
platform::errors::InvalidArgument(
"Scale input should be of float type"));
PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("BiasX")->type(),
platform::errors::InvalidArgument(
"Bias input should be of float type"));
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library);
}
};
class ResNetUnitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "The input 1 tensor");
AddInput("FilterX", "Filter tensor of input 1");
AddInput("ScaleX", "Scale tensor of input 1 used in batchnorm");
AddInput("BiasX", "Bias tensor of input 1 used in batchnorm");
AddInput("MeanX", "Mean tensor of input 1 used in batchnorm");
AddInput("VarX", "Variance tensor of input 1 used in batchnorm");
AddInput("Z", "The input 2 tensor").AsDispensable();
AddInput("FilterZ", "Filter tensor of input 2").AsDispensable();
AddInput("ScaleZ", "Scale tensor of input 2").AsDispensable();
AddInput("BiasZ", "Bias tensor of input 2").AsDispensable();
AddInput("MeanZ", "Mean tensor of input 2").AsDispensable();
AddInput("VarZ", "Variance tensor of input 2").AsDispensable();
AddOutput("Y", "The result of the resnet unit");
AddOutput("BitMask", "The bitmask generated after relu");
AddOutput("ConvX", "The output of input 1 after conv");
AddOutput("SavedMeanX", "Mean of input 1 in the current batch");
AddOutput("SavedInvstdX", "Invstd of input 1 in the current batch");
AddOutput("RunningMeanX", "Shared memory with MeanX");
AddOutput("RunningVarX", "Shared memory with VarX");
AddOutput("ConvZ", "The output of input 2 after conv").AsDispensable();
AddOutput("SavedMeanZ", "Mean of input 1 in the current batch")
.AsDispensable();
AddOutput("SavedInvstdZ", "Invstd of input 1 in the current batch")
.AsDispensable();
AddOutput("RunningMeanZ", "Shared memory with MeanZ").AsDispensable();
AddOutput("RunningVarZ", "Shared memory with VarZ").AsDispensable();
AddAttr<int>("stride", "").SetDefault(1);
AddAttr<int>("stride_z", "").SetDefault(1);
AddAttr<int>("padding", "").SetDefault(0);
AddAttr<int>("dilation", "").SetDefault(1);
AddAttr<int>("group", "").SetDefault(1);
AddAttr<float>("momentum", "").SetDefault(0.9);
AddAttr<float>("epsilon", "").SetDefault(1e-5);
AddAttr<std::string>("data_format", "").SetDefault("NHWC");
AddAttr<bool>("fuse_add", "").SetDefault(false);
AddAttr<bool>("has_shortcut", "").SetDefault(false);
AddAttr<bool>("use_global_stats", "").SetDefault(false);
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddAttr<bool>("use_addto", "").SetDefault(false);
AddAttr<std::string>("act_type", "The activation type to be fused.")
.SetDefault("relu");
AddComment(R"DOC(
Fused operator for the basic unit of a ResNet block.
The implementation is based on the fused-op interface introduced in cuDNN v8.0.
For more details:
https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnFusedOps_t
)DOC");
}
};
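// [Illustrative note] Roughly speaking, the fused forward defined by the op
// above computes (ignoring layout details and cuDNN bookkeeping such as
// BitMask):
//   ConvX = conv(X, FilterX)
//   Y = act(bn_x(ConvX) + bn_z(conv(Z, FilterZ)))  when has_shortcut is true
//   Y = act(bn_x(ConvX) + Z)                       when fuse_add is true
//   Y = act(bn_x(ConvX))                           otherwise
// where act is the activation named by act_type (relu by default) and bn_x /
// bn_z apply batch normalization with the corresponding Scale/Bias/Mean/Var
// inputs.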
class ResNetUnitGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const {
// check input
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasInput("FilterX"), "Input", "FilterX",
"ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasInput("ConvX"), "Input", "ConvX",
"ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasInput("ScaleX"), "Input", "ScaleX",
"ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasInput("BiasX"), "Input", "BiasX",
"ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasInput("SavedMeanX"), "Input", "SavedMeanX",
"ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasInput("SavedInvstdX"), "Input", "SavedInvstdX",
"ResNetUnitGradOp");
bool fuse_add = ctx->Attrs().Get<bool>("fuse_add");
bool has_shortcut = ctx->Attrs().Get<bool>("has_shortcut");
if (fuse_add || has_shortcut) {
OP_INOUT_CHECK(ctx->HasInput("Z"), "Input", "Z", "ResNetUnitGradOp");
}
if (has_shortcut) {
OP_INOUT_CHECK(ctx->HasInput("FilterZ"), "Input", "FilterZ",
"ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasInput("ConvZ"), "Input", "ConvZ",
"ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasInput("ScaleZ"), "Input", "ScaleZ",
"ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasInput("BiasZ"), "Input", "BiasZ",
"ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasInput("SavedMeanZ"), "Input", "SavedMeanZ",
"ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasInput("SavedInvstdZ"), "Input", "SavedInvstdZ",
"ResNetUnitGradOp");
}
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasInput("BitMask"), "Input", "BitMask",
"ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")), "Input",
framework::GradVarName("Y"), "ResNetUnitGradOp");
// check output
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
framework::GradVarName("X"), "ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("FilterX")), "Output",
framework::GradVarName("FilterX"), "ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("ScaleX")), "Output",
framework::GradVarName("ScaleX"), "ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("BiasX")), "Output",
framework::GradVarName("BiasX"), "ResNetUnitGradOp");
if (fuse_add) {
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Z")), "Output",
framework::GradVarName("Z"), "ResNetUnitGradOp");
}
if (has_shortcut) {
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("FilterZ")),
"Output", framework::GradVarName("FilterZ"),
"ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("ScaleZ")), "Output",
framework::GradVarName("ScaleZ"), "ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("BiasZ")), "Output",
framework::GradVarName("BiasZ"), "ResNetUnitGradOp");
}
const auto x_dims = ctx->GetInputDim("X");
const auto filter_x_dims = ctx->GetInputDim("FilterX");
const auto param_dims = ctx->GetInputDim("ScaleX");
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
ctx->SetOutputDim(framework::GradVarName("FilterX"), filter_x_dims);
ctx->SetOutputDim(framework::GradVarName("ScaleX"), param_dims);
ctx->SetOutputDim(framework::GradVarName("BiasX"), param_dims);
if (fuse_add || has_shortcut) {
const auto z_dims = ctx->GetInputDim("Z");
ctx->SetOutputDim(framework::GradVarName("Z"), z_dims);
}
if (has_shortcut) {
const auto filter_z_dims = ctx->GetInputDim("FilterZ");
ctx->SetOutputDim(framework::GradVarName("FilterZ"), filter_z_dims);
ctx->SetOutputDim(framework::GradVarName("ScaleZ"), param_dims);
ctx->SetOutputDim(framework::GradVarName("BiasZ"), param_dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
PADDLE_ENFORCE_NOT_NULL(
ctx.InputVar(framework::GradVarName("Y")),
platform::errors::NotFound(
"Can not find Y@GRAD in the execution context."));
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout, library);
}
};
template <typename T>
class ResNetUnitGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("resnet_unit_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("FilterX", this->Input("FilterX"));
op->SetInput("ConvX", this->Output("ConvX"));
op->SetInput("ScaleX", this->Input("ScaleX"));
op->SetInput("BiasX", this->Input("BiasX"));
op->SetInput("SavedMeanX", this->Output("SavedMeanX"));
op->SetInput("SavedInvstdX", this->Output("SavedInvstdX"));
op->SetInput("Z", this->Input("Z"));
op->SetInput("FilterZ", this->Input("FilterZ"));
op->SetInput("ConvZ", this->Output("ConvZ"));
op->SetInput("ScaleZ", this->Input("ScaleZ"));
op->SetInput("BiasZ", this->Input("BiasZ"));
op->SetInput("SavedMeanZ", this->Output("SavedMeanZ"));
op->SetInput("SavedInvstdZ", this->Output("SavedInvstdZ"));
op->SetInput("Y", this->Output("Y"));
op->SetInput("BitMask", this->Output("BitMask"));
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("FilterX"),
this->InputGrad("FilterX"));
op->SetOutput(framework::GradVarName("ScaleX"), this->InputGrad("ScaleX"));
op->SetOutput(framework::GradVarName("BiasX"), this->InputGrad("BiasX"));
op->SetOutput(framework::GradVarName("Z"), this->InputGrad("Z"));
op->SetOutput(framework::GradVarName("FilterZ"),
this->InputGrad("FilterZ"));
op->SetOutput(framework::GradVarName("ScaleZ"), this->InputGrad("ScaleZ"));
op->SetOutput(framework::GradVarName("BiasZ"), this->InputGrad("BiasZ"));
}
};
class ResNetUnitOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Y"}};
return m;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(resnet_unit, ops::ResNetUnitOp, ops::ResNetUnitOpMaker,
ops::ResNetUnitOpInferVarType,
ops::ResNetUnitGradOpMaker<paddle::framework::OpDesc>,
ops::ResNetUnitGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(resnet_unit_grad, ops::ResNetUnitGradOp);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h"
#include "paddle/fluid/operators/fused/cudnn_norm_conv.cu.h"
#include "paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class ResNetUnitKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("It must use CUDAPlace."));
PADDLE_ENFORCE_EQ(platform::CudnnDataType<T>::type, CUDNN_DATA_HALF,
platform::errors::Unavailable(
"ResNetUnitOp only supports float16 for now."));
// input x
const Tensor *input_x = ctx.Input<Tensor>("X");
const Tensor *filter_x = ctx.Input<Tensor>("FilterX");
const Tensor *scale_x = ctx.Input<Tensor>("ScaleX");
const Tensor *bias_x = ctx.Input<Tensor>("BiasX");
// norm conv
Tensor *conv_out_x = ctx.Output<Tensor>("ConvX");
// bn finalize
Tensor *saved_mean_x = ctx.Output<Tensor>("SavedMeanX");
Tensor *saved_invstd_x = ctx.Output<Tensor>("SavedInvstdX");
Tensor *running_mean_x = ctx.Output<Tensor>("RunningMeanX");
Tensor *running_var_x = ctx.Output<Tensor>("RunningVarX");
// sbar
Tensor *output = ctx.Output<Tensor>("Y");
Tensor *bitmask = ctx.Output<Tensor>("BitMask");
// attrs
int padding = ctx.Attr<int>("padding");
int stride = ctx.Attr<int>("stride");
int stride_z = ctx.Attr<int>("stride_z");
int dilation = ctx.Attr<int>("dilation");
int group = ctx.Attr<int>("group");
double eps = static_cast<double>(ctx.Attr<float>("epsilon"));
double momentum = static_cast<double>(ctx.Attr<float>("momentum"));
bool has_shortcut = ctx.Attr<bool>("has_shortcut");
bool fuse_add = ctx.Attr<bool>("fuse_add");
bool use_global_stats = ctx.Attr<bool>("use_global_stats");
bool is_test = ctx.Attr<bool>("is_test");
bool is_train = !is_test && !use_global_stats;
std::string act_type = ctx.Attr<std::string>("act_type");
auto input_x_shape = framework::vectorize<int>(input_x->dims());
auto filter_x_shape = framework::vectorize<int>(filter_x->dims());
auto param_dims = scale_x->dims();
auto param_shape = framework::vectorize<int>(scale_x->dims());
auto output_shape = framework::vectorize<int>(output->dims());
auto bitmask_shape = framework::vectorize<int>(bitmask->dims());
int output_channel = filter_x_shape[0];
int64_t ele_count =
std::accumulate(output_shape.begin(), output_shape.end(), 1,
std::multiplies<int>()) /
output_channel;
auto place = ctx.GetPlace();
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
// 1. Conv
Tensor sum_x;
Tensor sum_of_squares_x;
sum_x.Resize(param_dims);
sum_of_squares_x.Resize(param_dims);
CudnnNormConvolution<T> conv_x_op(dev_ctx, input_x_shape, filter_x_shape,
output_shape, padding, stride, dilation,
group);
conv_x_op.Forward(dev_ctx, *input_x, *filter_x, conv_out_x, &sum_x,
&sum_of_squares_x);
// 2. BN
Tensor equiv_scale_x;
Tensor equiv_bias_x;
equiv_scale_x.Resize(param_dims);
equiv_bias_x.Resize(param_dims);
CudnnBNStatsFinalize<T> bn_x_op(dev_ctx, param_shape);
bn_x_op.Forward(dev_ctx, sum_x, sum_of_squares_x, *scale_x, *bias_x,
saved_mean_x, saved_invstd_x, running_mean_x, running_var_x,
&equiv_scale_x, &equiv_bias_x, eps, momentum, ele_count,
is_train);
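    // [Illustrative note] CudnnBNStatsFinalize is expected to fold the
    // batch-norm statistics and parameters into per-channel affine terms,
    // roughly:
    //   equiv_scale = scale / sqrt(var + eps)
    //   equiv_bias  = bias - scale * mean / sqrt(var + eps)
    // so that the later ScaleBiasAddRelu stage can apply
    // equiv_scale * conv_out + equiv_bias instead of a full BN pass.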
// 3. scale + bias + add + relu
CudnnScaleBiasAddRelu<T> sbar_op(dev_ctx, act_type, fuse_add, has_shortcut,
output_shape, param_shape, bitmask_shape);
if (has_shortcut) {
// input z
const Tensor *input_z = ctx.Input<Tensor>("Z");
const Tensor *filter_z = ctx.Input<Tensor>("FilterZ");
const Tensor *scale_z = ctx.Input<Tensor>("ScaleZ");
const Tensor *bias_z = ctx.Input<Tensor>("BiasZ");
// norm conv
Tensor *conv_out_z = ctx.Output<Tensor>("ConvZ");
// bn finalize
Tensor *saved_mean_z = ctx.Output<Tensor>("SavedMeanZ");
Tensor *saved_invstd_z = ctx.Output<Tensor>("SavedInvstdZ");
Tensor *running_mean_z = ctx.Output<Tensor>("RunningMeanZ");
Tensor *running_var_z = ctx.Output<Tensor>("RunningVarZ");
auto input_z_shape = framework::vectorize<int>(input_z->dims());
auto filter_z_shape = framework::vectorize<int>(filter_z->dims());
// 3.1 Conv for second input
Tensor sum_z;
Tensor sum_of_squares_z;
sum_z.Resize(param_dims);
sum_of_squares_z.Resize(param_dims);
CudnnNormConvolution<T> conv_z_op(dev_ctx, input_z_shape, filter_z_shape,
output_shape, padding, stride_z,
dilation, group);
conv_z_op.Forward(dev_ctx, *input_z, *filter_z, conv_out_z, &sum_z,
&sum_of_squares_z);
// 3.2 BN for second input
Tensor equiv_scale_z;
Tensor equiv_bias_z;
equiv_scale_z.Resize(param_dims);
equiv_bias_z.Resize(param_dims);
CudnnBNStatsFinalize<T> bn_z_op(dev_ctx, param_shape);
bn_z_op.Forward(dev_ctx, sum_z, sum_of_squares_z, *scale_z, *bias_z,
saved_mean_z, saved_invstd_z, running_mean_z,
running_var_z, &equiv_scale_z, &equiv_bias_z, eps,
momentum, ele_count, is_train);
// 3.3 sbar
sbar_op.Forward(dev_ctx, *conv_out_x, equiv_scale_x, equiv_bias_x,
conv_out_z, &equiv_scale_z, &equiv_bias_z, output,
bitmask);
} else {
const Tensor *input_z = fuse_add ? ctx.Input<Tensor>("Z") : nullptr;
sbar_op.Forward(dev_ctx, *conv_out_x, equiv_scale_x, equiv_bias_x,
input_z, nullptr, nullptr, output, bitmask);
}
}
};
template <typename T>
class ResNetUnitGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("It must use CUDAPlace."));
PADDLE_ENFORCE_EQ(platform::CudnnDataType<T>::type, CUDNN_DATA_HALF,
platform::errors::Unavailable(
"ResNetUnitOp only supports float16 for now."));
const Tensor *y_grad = ctx.Input<Tensor>(framework::GradVarName("Y"));
const Tensor *x = ctx.Input<Tensor>("X");
const Tensor *filter_x = ctx.Input<Tensor>("FilterX");
const Tensor *scale_x = ctx.Input<Tensor>("ScaleX");
const Tensor *bias_x = ctx.Input<Tensor>("BiasX");
const Tensor *saved_mean_x = ctx.Input<Tensor>("SavedMeanX");
const Tensor *saved_invstd_x = ctx.Input<Tensor>("SavedInvstdX");
const Tensor *conv_out_x = ctx.Input<Tensor>("ConvX");
const Tensor *output = ctx.Input<Tensor>("Y");
const Tensor *bitmask = ctx.Input<Tensor>("BitMask");
Tensor *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
Tensor *filter_x_grad =
ctx.Output<Tensor>(framework::GradVarName("FilterX"));
Tensor *scale_x_grad = ctx.Output<Tensor>(framework::GradVarName("ScaleX"));
Tensor *bias_x_grad = ctx.Output<Tensor>(framework::GradVarName("BiasX"));
int padding = ctx.Attr<int>("padding");
int stride = ctx.Attr<int>("stride");
int stride_z = ctx.Attr<int>("stride_z");
int dilation = ctx.Attr<int>("dilation");
int group = ctx.Attr<int>("group");
double eps = static_cast<double>(ctx.Attr<float>("epsilon"));
double momentum = static_cast<double>(ctx.Attr<float>("momentum"));
bool has_shortcut = ctx.Attr<bool>("has_shortcut");
bool fuse_add = ctx.Attr<bool>("fuse_add");
bool use_global_stats = ctx.Attr<bool>("use_global_stats");
std::string act_type = ctx.Attr<std::string>("act_type");
auto x_shape = framework::vectorize<int>(x->dims());
auto filter_x_shape = framework::vectorize<int>(filter_x->dims());
auto param_shape = framework::vectorize<int>(scale_x->dims());
auto output_shape = framework::vectorize<int>(output->dims());
auto bitmask_shape = framework::vectorize<int>(bitmask->dims());
auto place = ctx.GetPlace();
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
// 1. Backward of BN (+ Add + Relu) for x, get conv_out_x_grad,
// scale_x_grad, bias_x_grad
Tensor conv_out_x_grad;
conv_out_x_grad.Resize(conv_out_x->dims());
CudnnScaleBiasAddRelu<T> sbar_x_op(dev_ctx, act_type, fuse_add,
has_shortcut, output_shape, param_shape,
bitmask_shape);
if (has_shortcut) {
// X Z
// | |
// NormConv NormConv
// | |
// BNStatsFinalize BNStatsFinalize
// \ /
// ScaleBiasAddRelu
// |
// Y
const Tensor *z = ctx.Input<Tensor>("Z");
const Tensor *filter_z = ctx.Input<Tensor>("FilterZ");
const Tensor *scale_z = ctx.Input<Tensor>("ScaleZ");
const Tensor *bias_z = ctx.Input<Tensor>("BiasZ");
const Tensor *saved_mean_z = ctx.Input<Tensor>("SavedMeanZ");
const Tensor *saved_invstd_z = ctx.Input<Tensor>("SavedInvstdZ");
const Tensor *conv_out_z = ctx.Input<Tensor>("ConvZ");
Tensor *z_grad = ctx.Output<Tensor>(framework::GradVarName("Z"));
Tensor *filter_z_grad =
ctx.Output<Tensor>(framework::GradVarName("FilterZ"));
Tensor *scale_z_grad =
ctx.Output<Tensor>(framework::GradVarName("ScaleZ"));
Tensor *bias_z_grad = ctx.Output<Tensor>(framework::GradVarName("BiasZ"));
// 1.1 Backward of BN + Add (+ Relu) for x, get conv_out_x_grad,
// scale_x_grad, bias_x_grad and z_grad_temp
Tensor z_grad_temp;
z_grad_temp.Resize(conv_out_z->dims());
sbar_x_op.Backward(dev_ctx, *y_grad, *conv_out_x, *scale_x, *bias_x,
*saved_mean_x, *saved_invstd_x, bitmask,
&conv_out_x_grad, &z_grad_temp, scale_x_grad,
bias_x_grad, eps);
      // 1.2 Backward of BN for z, get conv_out_z_grad, scale_z_grad, bias_z_grad
Tensor conv_out_z_grad;
conv_out_z_grad.Resize(conv_out_z->dims());
CudnnScaleBiasAddRelu<T> sbar_z_op(
dev_ctx, "", false, false, output_shape, param_shape, bitmask_shape);
sbar_z_op.Backward(dev_ctx, z_grad_temp, *conv_out_z, *scale_z, *bias_z,
*saved_mean_z, *saved_invstd_z, nullptr,
&conv_out_z_grad, nullptr, scale_z_grad, bias_z_grad,
eps);
// 1.3 Backward of Conv for z, get z_grad and filter_z_grad
auto z_shape = framework::vectorize<int>(z->dims());
auto filter_z_shape = framework::vectorize<int>(filter_z->dims());
CudnnNormConvolutionGrad<T> conv_z_op(dev_ctx, z_shape, filter_z_shape,
output_shape, padding, stride_z,
dilation, group);
conv_z_op.Backward(dev_ctx, *z, *filter_z, conv_out_z_grad, z_grad,
filter_z_grad);
} else {
// 1.1 Backward of BN (+ Add + Relu) for x, get conv_out_x_grad,
// scale_x_grad, bias_x_grad (and z_grad)
Tensor *z_grad =
fuse_add ? ctx.Output<Tensor>(framework::GradVarName("Z")) : nullptr;
sbar_x_op.Backward(dev_ctx, *y_grad, *conv_out_x, *scale_x, *bias_x,
*saved_mean_x, *saved_invstd_x, bitmask,
&conv_out_x_grad, z_grad, scale_x_grad, bias_x_grad,
eps);
}
// 2. Backward of Conv for x, get x_grad and filter_x_grad
bool use_addto = ctx.Attr<bool>("use_addto");
CudnnNormConvolutionGrad<T> conv_x_op(dev_ctx, x_shape, filter_x_shape,
output_shape, padding, stride,
dilation, group);
conv_x_op.Backward(dev_ctx, *x, *filter_x, conv_out_x_grad, x_grad,
filter_x_grad, use_addto);
}
};
} // namespace operators
} // namespace paddle
#if CUDNN_VERSION >= 8000
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(resnet_unit, ops::ResNetUnitKernel<plat::float16>);
REGISTER_OP_CUDA_KERNEL(resnet_unit_grad,
ops::ResNetUnitGradKernel<plat::float16>);
#endif
......@@ -16,66 +16,141 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/operators/math/pooling.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/fast_divmod.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
#ifdef __HIPCC__
#define POOLING_BLOCK_SIZE 256
#else
#define POOLING_BLOCK_SIZE 512
#endif
namespace paddle {
namespace operators {
namespace math {
struct FastDivModForPooling {
public:
platform::FastDivMod channel;
platform::FastDivMod width;
platform::FastDivMod height;
explicit HOSTDEVICE FastDivModForPooling(const int channels,
const int output_width,
const int output_height) {
channel = platform::FastDivMod(channels);
width = platform::FastDivMod(output_width);
height = platform::FastDivMod(output_height);
}
};
struct FastDivModForPoolingWithMoreStaff {
public:
platform::FastDivMod channel;
platform::FastDivMod width;
platform::FastDivMod height;
platform::FastDivMod ksize_w;
platform::FastDivMod ksize_h;
platform::FastDivMod stride_w;
platform::FastDivMod stride_h;
explicit HOSTDEVICE FastDivModForPoolingWithMoreStaff(
const int channels, const int input_width, const int input_height,
const int ksize_width, const int ksize_height, const int stride_width,
const int stride_height) {
channel = platform::FastDivMod(channels);
width = platform::FastDivMod(input_width);
height = platform::FastDivMod(input_height);
ksize_w = platform::FastDivMod(ksize_width);
ksize_h = platform::FastDivMod(ksize_height);
stride_w = platform::FastDivMod(stride_width);
stride_h = platform::FastDivMod(stride_height);
}
};
template <typename FastDivModForPooling>
__device__ void OffsetPreparationFor4Dimension(
int index, bool channel_last, FastDivModForPooling divmods,
const int pad_width, const int pad_height, const int aux_width,
const int aux_height, int* w_offset, int* h_offset, int* c_offset,
int* stride) {
if (!channel_last) { /* NCHW */
auto input_width_divmod = divmods.width.Divmod(index);
auto input_height_divmod = divmods.height.Divmod(input_width_divmod.val[0]);
auto channel_divmod = divmods.channel.Divmod(input_height_divmod.val[0]);
*w_offset = input_width_divmod.val[1] + pad_width;
*h_offset = input_height_divmod.val[1] + pad_height;
*c_offset = channel_divmod.val[1];
*stride = (channel_divmod.val[0] * divmods.channel.divisor + *c_offset) *
aux_height * aux_width;
} else { /* NHWC */
auto c_divmod = divmods.channel.Divmod(index);
auto input_width_divmod = divmods.width.Divmod(c_divmod.val[0]);
auto input_height_divmod = divmods.height.Divmod(input_width_divmod.val[0]);
*c_offset = c_divmod.val[1];
*w_offset = input_width_divmod.val[1] + pad_width;
*h_offset = input_height_divmod.val[1] + pad_height;
*stride = input_height_divmod.val[0] * aux_height * aux_width *
divmods.channel.divisor;
}
}
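// [Illustrative note] For the NCHW branch above, with divmods built from
// (channels = 3, width = 4, height = 2), zero paddings and aux_width = 4,
// aux_height = 2, index = 30 decomposes as
//   30 = ((batch 1 * 3 + channel 0) * 2 + h 1) * 4 + w 2,
// giving *w_offset = 2, *h_offset = 1, *c_offset = 0 and
// *stride = (1 * 3 + 0) * 2 * 4 = 24, i.e. the start of that (batch, channel)
// plane when aux_width/aux_height equal the tensor's width/height.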
int GetThreadsPerBlock(const platform::CUDADeviceContext& ctx,
int threads_per_block, int64_t numel) {
int sm_count = ctx.GetSMCount();
if (numel / (sm_count << 1) < threads_per_block) {
    // Round the thread count up to a power of 2, so that the number of
    // active blocks is about twice the SM count, for better performance.
threads_per_block = platform::RoundToPowerOfTwo(numel / (sm_count << 1));
} else if (numel / (sm_count << 2) < threads_per_block) {
    // Round the thread count up to a power of 2, so that the number of
    // active blocks is about 4 times the SM count, for better performance.
threads_per_block = platform::RoundToPowerOfTwo(numel / (sm_count << 2));
}
  // The number of threads per block shall be at least 64.
return std::max(64, threads_per_block);
}
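// [Illustrative note] Example of the rounding above, assuming a device with
// sm_count = 80 and threads_per_block = POOLING_BLOCK_SIZE = 512:
//   numel = 40960: 40960 / 160 = 256 < 512, so the block size becomes
//                  RoundToPowerOfTwo(256) = 256;
//   numel = 4000:  4000 / 160 = 25, RoundToPowerOfTwo(25) = 32, which the
//                  final std::max clamps up to the minimum of 64.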
template <typename PoolProcess, typename T>
__global__ void KernelPool2D(const int nthreads, const T* input_data,
const int channels, const int input_height,
const int input_width, const int output_height,
const int output_width, const int ksize_height,
const int ksize_width, const int stride_height,
const int stride_width, const int padding_height,
const int padding_width, PoolProcess pool_process,
bool exclusive, bool adaptive, T* output_data,
bool channel_last = false) {
__global__ void KernelPool2D(
const int nthreads, const T* input_data, const int channels,
const int input_height, const int input_width, const int output_height,
const int output_width, const int ksize_height, const int ksize_width,
const int stride_height, const int stride_width, const int padding_height,
const int padding_width, FastDivModForPooling divmods,
PoolProcess pool_process, bool exclusive, bool adaptive, T* output_data,
bool channel_last = false) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int pw, ph, c, batch_idx;
if (!channel_last) { /*NCHW*/
pw = index % output_width;
ph = (index / output_width) % output_height;
c = (index / output_width / output_height) % channels;
batch_idx = index / output_width / output_height / channels;
} else { /*NHWC*/
c = index % channels;
pw = (index / channels) % output_width;
ph = (index / channels / output_width) % output_height;
batch_idx = index / channels / output_width / output_height;
}
int hstart, hend, wstart, wend;
int w_offset, h_offset, c_offset, input_offset;
OffsetPreparationFor4Dimension<FastDivModForPooling>(
index, channel_last, divmods, 0, 0, input_width, input_height,
&w_offset, &h_offset, &c_offset, &input_offset);
input_data += input_offset;
int hstart, hend;
int wstart, wend;
if (adaptive) {
hstart = AdaptStartIndex(ph, input_height, output_height);
hend = AdaptEndIndex(ph, input_height, output_height);
wstart = AdaptStartIndex(pw, input_width, output_width);
wend = AdaptEndIndex(pw, input_width, output_width);
hstart = AdaptStartIndex(h_offset, input_height, output_height);
hend = AdaptEndIndex(h_offset, input_height, output_height);
wstart = AdaptStartIndex(w_offset, input_width, output_width);
wend = AdaptEndIndex(w_offset, input_width, output_width);
} else {
hstart = ph * stride_height - padding_height;
hstart = h_offset * stride_height - padding_height;
hend = min(hstart + ksize_height, input_height);
hstart = max(hstart, 0);
wstart = pw * stride_width - padding_width;
wstart = w_offset * stride_width - padding_width;
wend = min(wstart + ksize_width, input_width);
wstart = max(wstart, 0);
}
if (!channel_last) {
input_data += (batch_idx * channels + c) * input_height * input_width;
} else {
input_data += batch_idx * input_height * input_width * channels;
}
T ele = pool_process.initial();
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
auto input_idx = channel_last ? (h * input_width + w) * channels + c
: h * input_width + w;
auto input_idx = channel_last
? (h * input_width + w) * channels + c_offset
: h * input_width + w;
pool_process.compute(input_data[input_idx], &ele);
}
}
......@@ -85,91 +160,109 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
output_data[index] = ele;
}
}
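// [Illustrative note] Assuming the usual adaptive-pooling bin boundaries,
// AdaptStartIndex(i, in, out) = floor(i * in / out) and
// AdaptEndIndex(i, in, out) = ceil((i + 1) * in / out); e.g. output row
// h_offset = 2 of output_height = 3 over input_height = 7 covers the input
// rows [4, 7).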
template <typename PoolProcess, typename T>
template <typename T, typename PoolProcess>
__global__ void KernelPool2DGrad(
const int nthreads, const T* input_data, const T* output_data,
const T* output_grad, const int channels, const int input_height,
const int input_width, const int output_height, const int output_width,
const int ksize_height, const int ksize_width, const int stride_height,
const int stride_width, const int padding_height, const int padding_width,
PoolProcess pool_process, bool exclusive, bool adaptive, T* input_grad,
bool channel_last = false) {
const int nthreads, const T* __restrict__ input_data,
    const T* __restrict__ output_data, const T* __restrict__ output_grad,
const int output_width, const int output_height, const int input_width,
const int input_height, const int ksize_width, const int ksize_height,
const int stride_width, const int stride_height, const int padding_width,
const int padding_height, FastDivModForPoolingWithMoreStaff divmods,
PoolProcess pool_process, bool exclusive, bool adaptive,
T* __restrict__ input_grad, bool channel_last = false) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int w_offset, h_offset, offsetC, batch_idx;
if (!channel_last) { /* NCHW */
w_offset = index % input_width + padding_width;
h_offset = (index / input_width) % input_height + padding_height;
offsetC = (index / input_width / input_height) % channels;
batch_idx = index / input_width / input_height / channels;
} else { /* NHWC */
offsetC = index % channels;
w_offset = (index / channels) % input_width + padding_width;
h_offset =
(index / channels / input_width) % input_height + padding_height;
batch_idx = index / channels / input_width / input_height;
T input = static_cast<T>(0);
T input_grad_data = static_cast<T>(0);
int phstart, phend, pwstart, pwend;
int w_offset, h_offset, c_offset, output_offset;
OffsetPreparationFor4Dimension<>(index, channel_last, divmods,
padding_width, padding_height,
output_width, output_height, &w_offset,
&h_offset, &c_offset, &output_offset);
if (pool_process.use_x) {
input = input_data[index];
output_data += output_offset;
}
output_grad += output_offset;
int phstart, phend;
int pwstart, pwend;
if (adaptive) {
phstart = AdaptStartIndex(h_offset, output_height, input_height);
phend = AdaptEndIndex(h_offset, output_height, input_height);
auto tmp_phend = divmods.height.Divmod((h_offset + 1) * output_height);
auto tmp_pwend = divmods.width.Divmod((w_offset + 1) * output_width);
phstart = divmods.height.Div(h_offset * output_height);
pwstart = divmods.width.Div(w_offset * output_width);
phend = tmp_phend.val[1] > 0 ? tmp_phend.val[0] + 1 : tmp_phend.val[0];
pwend = tmp_pwend.val[1] > 0 ? tmp_pwend.val[0] + 1 : tmp_pwend.val[0];
pwstart = AdaptStartIndex(w_offset, output_width, input_width);
pwend = AdaptEndIndex(w_offset, output_width, input_width);
} else {
phstart = (h_offset < ksize_height)
? 0
: (h_offset - ksize_height) / stride_height + 1;
pwstart = (w_offset < ksize_width)
? 0
: (w_offset - ksize_width) / stride_width + 1;
phend = min(h_offset / stride_height + 1, output_height);
pwend = min(w_offset / stride_width + 1, output_width);
}
T gradient = static_cast<T>(0.0);
T input = input_data[index];
int output_stride;
if (!channel_last) {
output_stride =
(batch_idx * channels + offsetC) * output_height * output_width;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
auto ksize_w_divmod = divmods.ksize_w.Divmod(input_width);
auto ksize_h_divmod = divmods.ksize_h.Divmod(input_height);
auto tmp_width = ksize_w_divmod.val[1] > 0 ? ksize_w_divmod.val[0] + 1
: ksize_w_divmod.val[0];
auto tmp_height = ksize_h_divmod.val[1] > 0
? ksize_h_divmod.val[0] + 1
: ksize_h_divmod.val[0];
int pool_size = tmp_height * tmp_width;
int tmp_idx = ph * output_width + pw;
int output_sub_idx =
channel_last ? tmp_idx * divmods.channel.divisor + c_offset
: tmp_idx;
          T output_value = pool_process.use_x ? output_data[output_sub_idx]
                                              : static_cast<T>(0);
          pool_process.compute(input, output_value, output_grad[output_sub_idx],
static_cast<T>(1.0 / pool_size),
&input_grad_data);
}
}
} else {
output_stride = batch_idx * output_height * output_width * channels;
}
output_data += output_stride;
output_grad += output_stride;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
int pool_size;
if (adaptive) {
pool_size = static_cast<int>(ceil(static_cast<double>(input_height) /
ksize_height)) *
static_cast<int>(
ceil(static_cast<double>(input_width) / ksize_width));
} else {
int hstart = ph * stride_height - padding_height;
int wstart = pw * stride_width - padding_width;
int hend = min(hstart + ksize_height, input_height);
int wend = min(wstart + ksize_width, input_width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
pool_size = exclusive ? (hend - hstart) * (wend - wstart)
: ksize_height * ksize_width;
auto stride_height_div = divmods.stride_h.Div(h_offset - ksize_height);
auto stride_width_div = divmods.stride_w.Div(w_offset - ksize_width);
phstart = (h_offset < ksize_height) ? 0 : stride_height_div + 1;
pwstart = (w_offset < ksize_width) ? 0 : stride_width_div + 1;
phend = min(divmods.stride_h.Div(h_offset) + 1, output_height);
pwend = min(divmods.stride_w.Div(w_offset) + 1, output_width);
if (exclusive) {
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
int hstart = ph * stride_height - padding_height;
int wstart = pw * stride_width - padding_width;
int hend = min(hstart + ksize_height, input_height);
int wend = min(wstart + ksize_width, input_width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
int pool_size = (hend - hstart) * (wend - wstart);
int tmp_idx = ph * output_width + pw;
int output_sub_idx =
channel_last ? tmp_idx * divmods.channel.divisor + c_offset
: tmp_idx;
            T output_value = pool_process.use_x ? output_data[output_sub_idx]
                                                : static_cast<T>(0);
            pool_process.compute(
                input, output_value, output_grad[output_sub_idx],
static_cast<T>(1.0 / pool_size), &input_grad_data);
}
}
} else {
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
int pool_size = ksize_height * ksize_width;
int tmp_idx = ph * output_width + pw;
int output_sub_idx =
channel_last ? tmp_idx * divmods.channel.divisor + c_offset
: tmp_idx;
            T output_value = pool_process.use_x ? output_data[output_sub_idx]
                                                : static_cast<T>(0);
            pool_process.compute(
                input, output_value, output_grad[output_sub_idx],
static_cast<T>(1.0 / pool_size), &input_grad_data);
}
}
int output_sub_idx = channel_last
? (ph * output_width + pw) * channels + offsetC
: ph * output_width + pw;
pool_process.compute(input, output_data[output_sub_idx],
output_grad[output_sub_idx],
static_cast<T>(1.0 / pool_size), &gradient);
}
}
input_grad[index] = gradient;
input_grad[index] = input_grad_data;
}
}
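// [Illustrative note] The exclusive/inclusive split above only changes the
// averaging divisor: for a 3x3 average pool with padding 1, an output element
// whose window is clipped to 2x2 at a corner is divided by pool_size = 4 in
// the exclusive branch but by ksize_height * ksize_width = 9 in the inclusive
// branch. MaxPoolGrad is unaffected by the divisor because its compute() only
// propagates gradient where x == y.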
......@@ -180,45 +273,32 @@ __global__ void KernelMaxPool2DGrad(
const int input_width, const int output_height, const int output_width,
const int ksize_height, const int ksize_width, const int stride_height,
const int stride_width, const int padding_height, const int padding_width,
T* input_grad, bool channel_last = false) {
T* input_grad, FastDivModForPooling divmods, bool channel_last = false) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int pw, ph, c, batch_idx;
if (!channel_last) { /* NCHW */
pw = index % output_width;
ph = (index / output_width) % output_height;
c = (index / output_width / output_height) % channels;
batch_idx = index / output_width / output_height / channels;
} else { /* NHWC */
c = index % channels;
pw = (index / channels) % output_width;
ph = (index / channels / output_width) % output_height;
batch_idx = index / channels / output_width / output_height;
}
int hstart = ph * stride_height - padding_height;
int w_offset, h_offset, c_offset, input_offset;
OffsetPreparationFor4Dimension<FastDivModForPooling>(
index, channel_last, divmods, 0, 0, input_width, input_height,
&w_offset, &h_offset, &c_offset, &input_offset);
input_data += input_offset;
input_grad += input_offset;
int hstart = h_offset * stride_height - padding_height;
int hend = min(hstart + ksize_height, input_height);
hstart = max(hstart, 0);
int wstart = pw * stride_width - padding_width;
int wstart = w_offset * stride_width - padding_width;
int wend = min(wstart + ksize_width, input_width);
wstart = max(wstart, 0);
int input_stride;
if (!channel_last) {
input_stride = (batch_idx * channels + c) * input_height * input_width;
} else {
input_stride = batch_idx * input_height * input_width * channels;
}
input_data += input_stride;
input_grad += input_stride;
T ele = output_data[index];
int maxIndex = -1;
bool stop = false;
for (int h = hstart; h < hend && !stop; ++h) {
for (int w = wstart; w < wend && !stop; ++w) {
int input_data_idx = channel_last ? (h * input_width + w) * channels + c
: h * input_width + w;
int input_data_idx = channel_last
? (h * input_width + w) * channels + c_offset
: h * input_width + w;
if (ele == input_data[input_data_idx]) {
maxIndex = input_data_idx;
stop = true;
......@@ -264,10 +344,13 @@ void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
dim3 threads(thread_num, 1);
dim3 grid(blocks, 1);
auto pool_divmods =
FastDivModForPooling(input_channels, output_width, output_height);
KernelPool2D<PoolProcess, T><<<grid, threads, 0, stream>>>(
nthreads, input, input_channels, input_height, input_width, output_height,
output_width, ksize_height, ksize_width, stride_height, stride_width,
padding_height, padding_width, pool_compute, exclusive, adaptive, output);
padding_height, padding_width, pool_divmods, pool_compute, exclusive,
adaptive, output);
}
/*
......@@ -311,11 +394,14 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
int blocks = (nthreads + thread_num - 1) / thread_num;
dim3 threads(thread_num, 1);
dim3 grid(blocks, 1);
auto pool_divmods =
FastDivModForPooling(input_channels, output_width, output_height);
KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, input_channels, input_height, input_width,
output_height, output_width, ksize_height, ksize_width, stride_height,
stride_width, padding_height, padding_width, pool_process, exclusive,
adaptive, output_data);
stride_width, padding_height, padding_width, pool_divmods, pool_process,
exclusive, adaptive, output_data);
}
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input, const std::vector<int>& ksize,
......@@ -357,11 +443,14 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
int blocks = (nthreads + thread_num - 1) / thread_num;
dim3 threads(thread_num, 1);
dim3 grid(blocks, 1);
auto pool_divmods =
FastDivModForPooling(input_channels, output_width, output_height);
KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, input_channels, input_height, input_width,
output_height, output_width, ksize_height, ksize_width, stride_height,
stride_width, padding_height, padding_width, pool_process, exclusive,
adaptive, output_data, channel_last);
stride_width, padding_height, padding_width, pool_divmods, pool_process,
exclusive, adaptive, output_data, channel_last);
}
};
/*
......@@ -402,15 +491,18 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
int nthreads = batch_size * input_channels * input_height * input_width;
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelPool2DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, output_data, output_grad_data, input_channels,
input_height, input_width, output_height, output_width, ksize_height,
ksize_width, stride_height, stride_width, padding_height, padding_width,
pool_process, exclusive, adaptive, input_grad_data);
int blocks = GetThreadsPerBlock(context, POOLING_BLOCK_SIZE, nthreads);
int grids = (nthreads + blocks - 1) / blocks;
auto pool_divmods = FastDivModForPoolingWithMoreStaff(
input_channels, input_width, input_height, ksize_width, ksize_height,
stride_width, stride_height);
KernelPool2DGrad<T, PoolProcess><<<grids, blocks, 0, context.stream()>>>(
nthreads, input_data, output_data, output_grad_data, output_width,
output_height, input_width, input_height, ksize_width, ksize_height,
stride_width, stride_height, padding_width, padding_height,
pool_divmods, pool_process, exclusive, adaptive, input_grad_data);
}
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input,
......@@ -424,7 +516,6 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
bool channel_last = (data_format == "NHWC");
const int batch_size = input.dims()[0];
const int input_channels = channel_last ? input.dims()[3] : input.dims()[1];
const int input_height = channel_last ? input.dims()[1] : input.dims()[2];
const int input_width = channel_last ? input.dims()[2] : input.dims()[3];
......@@ -447,19 +538,22 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
int nthreads = batch_size * input_channels * input_height * input_width;
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelPool2DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, output_data, output_grad_data, input_channels,
input_height, input_width, output_height, output_width, ksize_height,
ksize_width, stride_height, stride_width, padding_height, padding_width,
pool_process, exclusive, adaptive, input_grad_data, channel_last);
int blocks = GetThreadsPerBlock(context, POOLING_BLOCK_SIZE, nthreads);
int grids = (nthreads + blocks - 1) / blocks;
auto pool_divmods = FastDivModForPoolingWithMoreStaff(
input_channels, input_width, input_height, ksize_width, ksize_height,
stride_width, stride_height);
KernelPool2DGrad<T, PoolProcess><<<grids, blocks, 0, context.stream()>>>(
nthreads, input_data, output_data, output_grad_data, output_width,
output_height, input_width, input_height, ksize_width, ksize_height,
stride_width, stride_height, padding_width, padding_height,
pool_divmods, pool_process, exclusive, adaptive, input_grad_data,
channel_last);
}
};
......@@ -505,11 +599,13 @@ class MaxPool2dGradFunctor<platform::CUDADeviceContext, T> {
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
auto pool_divmods =
FastDivModForPooling(input_channels, output_width, output_height);
KernelMaxPool2DGrad<T><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, output_data, output_grad_data, input_channels,
input_height, input_width, output_height, output_width, ksize_height,
ksize_width, stride_height, stride_width, padding_height, padding_width,
input_grad_data);
input_grad_data, pool_divmods);
}
void operator()(
const platform::CUDADeviceContext& context,
......@@ -550,11 +646,14 @@ class MaxPool2dGradFunctor<platform::CUDADeviceContext, T> {
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
auto pool_divmods =
FastDivModForPooling(input_channels, output_width, output_height);
KernelMaxPool2DGrad<T><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, output_data, output_grad_data, input_channels,
input_height, input_width, output_height, output_width, ksize_height,
ksize_width, stride_height, stride_width, padding_height, padding_width,
input_grad_data, channel_last);
input_grad_data, pool_divmods, channel_last);
}
};
......@@ -689,35 +788,40 @@ __global__ void KernelPool3D(
}
}
template <typename PoolProcess, typename T>
template <typename T, typename PoolProcess>
__global__ void KernelPool3DGrad(
const int nthreads, const T* input_data, const T* output_data,
const T* output_grad, const int channels, const int input_depth,
const int input_height, const int input_width, const int output_depth,
const int output_height, const int output_width, const int ksize_depth,
const int ksize_height, const int ksize_width, const int stride_depth,
const int stride_height, const int stride_width, const int padding_depth,
const int padding_height, const int padding_width, PoolProcess pool_process,
bool exclusive, bool adaptive, T* input_grad, bool channel_last = false) {
const int nthreads, const T* __restrict__ input_data,
const T* __restrict__ output_data, const T* __restrict__ output_grad,
const int channels, const int input_depth, const int input_height,
const int input_width, const int output_depth, const int output_height,
const int output_width, const int ksize_depth, const int ksize_height,
const int ksize_width, const int stride_depth, const int stride_height,
const int stride_width, const int padding_depth, const int padding_height,
const int padding_width, PoolProcess pool_process, bool exclusive,
bool adaptive, T* input_grad, bool channel_last = false) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int w_offset, h_offset, d_offset, offsetC, batch_idx;
int w_offset, h_offset, d_offset, c_offset, batch_idx, output_stride;
T input = static_cast<T>(0);
if (!channel_last) { /* "NCDHW" */
w_offset = index % input_width + padding_width;
h_offset = (index / input_width) % input_height + padding_height;
d_offset =
(index / input_width / input_height) % input_depth + padding_depth;
offsetC = (index / input_width / input_height / input_depth) % channels;
c_offset = (index / input_width / input_height / input_depth) % channels;
batch_idx = index / input_width / input_height / input_depth / channels;
output_stride = (batch_idx * channels + c_offset) * output_depth *
output_height * output_width;
} else { /* "NDHWC" */
offsetC = index % channels;
c_offset = index % channels;
w_offset = (index / channels) % input_width + padding_width;
h_offset =
(index / channels / input_width) % input_height + padding_height;
d_offset = (index / channels / input_width / input_height) % input_depth +
padding_depth;
batch_idx = index / channels / input_width / input_height / input_depth;
output_stride =
batch_idx * output_depth * output_height * output_width * channels;
}
int pdstart, pdend;
......@@ -746,20 +850,12 @@ __global__ void KernelPool3DGrad(
phend = min((h_offset) / stride_height + 1, output_height);
pwend = min((w_offset) / stride_width + 1, output_width);
}
T gradient = static_cast<T>(0.0);
T input = input_data[index];
int output_stride;
if (!channel_last) {
output_stride = (batch_idx * channels + offsetC) * output_depth *
output_height * output_width;
} else {
output_stride =
batch_idx * output_depth * output_height * output_width * channels;
if (pool_process.use_x) {
input = input_data[index];
output_data += output_stride;
}
output_data += output_stride;
output_grad += output_stride;
T input_grad_data = static_cast<T>(0.0);
for (int pd = pdstart; pd < pdend; ++pd) {
for (int ph = phstart; ph < phend; ++ph) {
......@@ -792,16 +888,17 @@ __global__ void KernelPool3DGrad(
int output_sub_idx =
channel_last
? ((pd * output_height + ph) * output_width + pw) * channels +
offsetC
c_offset
: (pd * output_height + ph) * output_width + pw;
pool_process.compute(input, output_data[output_sub_idx],
output_grad[output_sub_idx],
static_cast<T>(1.0 / pool_size), &gradient);
          T output_value = pool_process.use_x ? output_data[output_sub_idx]
                                              : static_cast<T>(0);
          pool_process.compute(input, output_value, output_grad[output_sub_idx],
static_cast<T>(1.0 / pool_size),
&input_grad_data);
}
}
}
input_grad[index] = gradient;
input_grad[index] = input_grad_data;
}
}
......@@ -1088,7 +1185,7 @@ class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelPool3DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
KernelPool3DGrad<T, PoolProcess><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, output_data, output_grad_data, input_channels,
input_depth, input_height, input_width, output_depth, output_height,
output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
......@@ -1142,7 +1239,7 @@ class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelPool3DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
KernelPool3DGrad<T, PoolProcess><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, output_data, output_grad_data, input_channels,
input_depth, input_height, input_width, output_depth, output_height,
output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
......@@ -1315,33 +1412,33 @@ __global__ void KernelMaxPool2dWithIdx(
const int input_height, const int input_width, const int output_height,
const int output_width, const int ksize_height, const int ksize_width,
const int stride_height, const int stride_width, const int padding_height,
const int padding_width, bool adaptive, T1* output_data, T2* mask_data) {
const int padding_width, bool adaptive, T1* output_data, T2* mask_data,
FastDivModForPooling divmods) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int pw = index % output_width;
int ph = (index / output_width) % output_height;
int c = (index / output_width / output_height) % channels;
int batch_idx = index / output_width / output_height / channels;
int hstart, hend, wstart, wend;
int w_offset, h_offset, c_offset, input_offset;
OffsetPreparationFor4Dimension<FastDivModForPooling>(
index, false, divmods, 0, 0, input_width, input_height, &w_offset,
&h_offset, &c_offset, &input_offset);
input_data += input_offset;
int hstart, hend;
int wstart, wend;
if (adaptive) {
hstart = AdaptStartIndex(ph, input_height, output_height);
hend = AdaptEndIndex(ph, input_height, output_height);
hstart = AdaptStartIndex(h_offset, input_height, output_height);
hend = AdaptEndIndex(h_offset, input_height, output_height);
wstart = AdaptStartIndex(pw, input_width, output_width);
wend = AdaptEndIndex(pw, input_width, output_width);
wstart = AdaptStartIndex(w_offset, input_width, output_width);
wend = AdaptEndIndex(w_offset, input_width, output_width);
} else {
hstart = ph * stride_height - padding_height;
hstart = h_offset * stride_height - padding_height;
hend = min(hstart + ksize_height, input_height);
hstart = max(hstart, 0);
wstart = pw * stride_width - padding_width;
wstart = w_offset * stride_width - padding_width;
wend = min(wstart + ksize_width, input_width);
wstart = max(wstart, 0);
}
input_data += (batch_idx * channels + c) * input_height * input_width;
T1 ele = -FLT_MAX;
int max_index = -1;
for (int h = hstart; h < hend; ++h) {
......@@ -1365,16 +1462,17 @@ __global__ void KernelMaxPool2DWithIdxGrad(
const int output_height, const int output_width, const int ksize_height,
const int ksize_width, const int stride_height, const int stride_width,
const int padding_height, const int padding_width, bool adaptive,
T1* input_grad) {
T1* input_grad, FastDivModForPooling divmods) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int w_offset = index % input_width;
int h_offset = (index / input_width) % input_height;
int offsetC = (index / input_width / input_height) % channels;
int batch_idx = index / input_width / input_height / channels;
int phstart, phend, pwstart, pwend;
int w_offset, h_offset, c_offset, output_offset;
OffsetPreparationFor4Dimension<FastDivModForPooling>(
index, false, divmods, 0, 0, output_width, output_height, &w_offset,
&h_offset, &c_offset, &output_offset);
mask_data += output_offset;
output_grad += output_offset;
int phstart, phend;
int pwstart, pwend;
if (adaptive) {
phstart = h_offset * output_height / input_height;
phend =
......@@ -1396,20 +1494,15 @@ __global__ void KernelMaxPool2DWithIdxGrad(
pwend = min((w_offset + padding_width) / stride_width + 1, output_width);
}
T1 gradient = 0;
T1 input_grad_data = 0;
int input_current_featuremap_idx = h_offset * input_width + w_offset;
int output_idx =
(batch_idx * channels + offsetC) * output_height * output_width;
mask_data += output_idx;
output_grad += output_idx;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
if (mask_data[ph * output_width + pw] == input_current_featuremap_idx)
gradient += output_grad[ph * output_width + pw];
input_grad_data += output_grad[ph * output_width + pw];
}
}
input_grad[index] = gradient;
input_grad[index] = input_grad_data;
}
}
......@@ -1453,11 +1546,14 @@ class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
int blocks = (nthreads + thread_num - 1) / thread_num;
dim3 threads(thread_num, 1);
dim3 grid(blocks, 1);
auto pool_divmods =
FastDivModForPooling(input_channels, output_width, output_height);
KernelMaxPool2dWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, input_channels, input_height, input_width,
output_height, output_width, ksize_height, ksize_width, stride_height,
stride_width, padding_height, padding_width, adaptive, output_data,
mask_data);
mask_data, pool_divmods);
}
};
......@@ -1497,11 +1593,13 @@ class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
auto pool_divmods =
FastDivModForPooling(input_channels, input_width, input_height);
KernelMaxPool2DWithIdxGrad<T1, T2><<<grid, threads, 0, context.stream()>>>(
nthreads, output_grad_data, mask_data, input_channels, input_height,
input_width, output_height, output_width, ksize_height, ksize_width,
stride_height, stride_width, padding_height, padding_width, adaptive,
input_grad_data);
input_grad_data, pool_divmods);
}
};
......@@ -1590,7 +1688,8 @@ __global__ void KernelMaxPool3DWithIdxGrad(
int w_offset = index % input_width;
int h_offset = (index / input_width) % input_height;
int d_offset = (index / input_width / input_height) % input_depth;
int offsetC = (index / input_width / input_height / input_depth) % channels;
int c_offset =
(index / input_width / input_height / input_depth) % channels;
int batch_idx = index / input_width / input_height / input_depth / channels;
int pdstart, pdend;
......@@ -1625,10 +1724,10 @@ __global__ void KernelMaxPool3DWithIdxGrad(
pwend = min((w_offset + padding_width) / stride_width + 1, output_width);
}
T1 gradient = 0;
T1 input_grad_data = 0;
int input_current_feature_map_idx =
(d_offset * input_height + h_offset) * input_width + w_offset;
int output_idx = (batch_idx * channels + offsetC) * output_depth *
int output_idx = (batch_idx * channels + c_offset) * output_depth *
output_height * output_width;
mask += output_idx;
output_grad += output_idx;
......@@ -1638,12 +1737,12 @@ __global__ void KernelMaxPool3DWithIdxGrad(
for (int pw = pwstart; pw < pwend; ++pw) {
if (mask[(pd * output_height + ph) * output_width + pw] ==
input_current_feature_map_idx)
gradient +=
input_grad_data +=
output_grad[(pd * output_height + ph) * output_width + pw];
}
}
}
input_grad[index] = gradient;
input_grad[index] = input_grad_data;
}
}
......
......@@ -68,8 +68,9 @@ class AvgPool {
template <class T>
class MaxPoolGrad {
public:
DEVICE inline void compute(const T& x, const T& y, const T& dy, T scale,
T* dx) {
static constexpr bool use_x = true;
HOSTDEVICE inline void compute(const T& x, const T& y, const T& dy, T scale,
T* dx) {
*dx += dy * static_cast<T>(x == y);
}
};
......@@ -77,8 +78,9 @@ class MaxPoolGrad {
template <class T>
class AvgPoolGrad {
public:
DEVICE inline void compute(const T& x, const T& y, const T& dy, T scale,
T* dx) {
static constexpr bool use_x = false;
HOSTDEVICE inline void compute(const T& x, const T& y, const T& dy, T scale,
T* dx) {
*dx += (scale * dy);
}
};
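// [Illustrative note] The use_x flag introduced above lets the rewritten
// pooling-grad kernels skip reading the forward tensors for average pooling:
// MaxPoolGrad needs x and y to rebuild the argmax mask (x == y), while
// AvgPoolGrad only scales dy, so the kernels pass a zero output value and
// never dereference input_data/output_data when pool_process.use_x is false.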
......
......@@ -13,46 +13,158 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/lars_momentum_op.h"
#include "paddle/fluid/operators/optimizers/momentum_op.h"
namespace paddle {
namespace operators {
class LarsMomentumOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs("Param"), "Input", "Param", "LarsMomentum");
OP_INOUT_CHECK(ctx->HasInputs("Grad"), "Input", "Grad", "LarsMomentum");
OP_INOUT_CHECK(ctx->HasInputs("Velocity"), "Input", "Velocity",
"LarsMomentum");
OP_INOUT_CHECK(ctx->HasInputs("LearningRate"), "Input", "LearningRate",
"LarsMomentum");
OP_INOUT_CHECK(ctx->HasOutputs("ParamOut"), "Output", "ParamOut",
"LarsMomentum");
OP_INOUT_CHECK(ctx->HasOutputs("VelocityOut"), "Output", "VelocityOut",
"LarsMomentum");
PADDLE_ENFORCE_EQ(
ctx->GetInputsVarType("Param").front(),
framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The input var's type should be LoDTensor, but the received is %s",
ctx->GetInputsVarType("Param").front()));
auto lr_dims = ctx->GetInputsDim("LearningRate");
auto grad_dim = ctx->GetInputsDim("Grad");
auto param_dim = ctx->GetInputsDim("Param");
auto velocity_dim = ctx->GetInputsDim("Velocity");
auto lars_weight_decays =
ctx->Attrs().Get<std::vector<float>>("lars_weight_decay");
auto multi_precision = ctx->Attrs().Get<bool>("multi_precision");
PADDLE_ENFORCE_EQ(
param_dim.size(), grad_dim.size(),
platform::errors::InvalidArgument(
"Input(Param) and Input(Grad) of LarsMomentumOp should have "
"same quantity. But number of Param is [%d] and Grad is [%d].",
param_dim.size(), grad_dim.size()));
PADDLE_ENFORCE_EQ(
param_dim.size(), velocity_dim.size(),
platform::errors::InvalidArgument(
"Input(Param) and Input(Velocity) of LarsMomentumOp should "
"have same quantity. But number of Param is [%d] and Velocity "
"is [%d].",
param_dim.size(), velocity_dim.size()));
PADDLE_ENFORCE_EQ(
lars_weight_decays.size(), grad_dim.size(),
platform::errors::InvalidArgument(
"Attr(Lars_weight_decay) and "
"Input(Grad) of LarsMomentumOp should have same quantity. "
"But number of Lars_weight_decay is [%d] and Grad is [%d].",
lars_weight_decays.size(), grad_dim.size()));
if (multi_precision) {
OP_INOUT_CHECK(ctx->HasInputs("MasterParam"), "Input", "MasterParam",
"LarsMomentumMultiPrecision");
OP_INOUT_CHECK(ctx->HasOutputs("MasterParamOut"), "Output",
"MasterParamOut", "LarsMomentumMultiPrecision");
}
for (size_t i = 0; i < lr_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(framework::product(lr_dims[i]), 1,
platform::errors::InvalidArgument(
"Learning_rate should be a scalar. But Received "
"LearningRate's dim [%s]",
framework::product(lr_dims[i])));
}
for (size_t i = 0; i < param_dim.size(); ++i) {
PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Grad")[i],
framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx->Inputs("Grad")[i].front(),
ctx->GetInputsVarType("Grad")[i]));
PADDLE_ENFORCE_EQ(
param_dim[i], grad_dim[i],
platform::errors::InvalidArgument(
"Input(Param) and Input(Grad) input of LarsMomentumOp shall "
"have same dimension. But Param`s dim is [%s] and Grad's dim "
"is [%s].",
param_dim[i], grad_dim[i]));
PADDLE_ENFORCE_EQ(
param_dim[i], velocity_dim[i],
platform::errors::InvalidArgument(
"Input(Param) and Input(Velocity) of LarsMomentumOp shall have "
"same dimension. But Param dim [%s] differs with Velocity dim "
"[%s].",
param_dim[i], velocity_dim[i]));
}
ctx->SetOutputsDim("ParamOut", param_dim);
ctx->SetOutputsDim("VelocityOut", param_dim);
if (ctx->HasOutputs("MasterParamOut")) {
ctx->SetOutputsDim("MasterParamOut", param_dim);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Param",
"(LoDTensor, default LoDTensor<float>) "
"Input parameter that has to be updated");
"Input parameter that has to be updated")
.AsDuplicable();
AddInput("Grad",
"(LoDTensor, default LoDTensor<float>) "
"Input gradient of the parameter");
"Input gradient of the parameter")
.AsDuplicable();
AddInput("Velocity",
"(LoDTensor, default LoDTensor<float>) "
"Input velocity (corresponding to the parameter) "
"that has to be updated");
"that has to be updated")
.AsDuplicable();
AddInput("LearningRate",
"(LoDTensor, default LoDTensor<float>) "
"Input learning rate");
AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();
"Input learning rate")
.AsDuplicable();
AddInput("MasterParam", "FP32 master weight for AMP.")
.AsDuplicable()
.AsDispensable();
AddOutput("ParamOut",
"(LoDTensor) This output is updated parameter. "
"It shared memory with Input(Param).");
"It shared memory with Input(Param).")
.AsDuplicable();
AddOutput("VelocityOut",
"(LoDTensor) This output is updated velocity. "
"It shared memory with Input(Velocity).");
"It shared memory with Input(Velocity).")
.AsDuplicable();
AddOutput("MasterParamOut",
"The updated FP32 master weight for AMP. "
"It shared memory with Input(MasterParam).")
.AsDuplicable()
.AsDispensable();
AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<float>("lars_coeff", "(float, default 0.001) LARS coefficient.")
.SetDefault(0.001);
AddAttr<float>("lars_weight_decay",
"(float, default 0.0005) LARS weight decay")
.SetDefault(0.0005);
AddAttr<std::vector<float>>(
"lars_weight_decay",
"(std::vector<float>, default 0.0005) LARS weight decay params")
.SetDefault({0.0005});
AddAttr<float>("epsilon",
"(float, default 0.0) epsilon to avoid Division by Zero.")
.SetDefault(0.0);
......@@ -68,10 +180,8 @@ class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
Lars Momentum Optimizer.
This optimizer uses LARS (https://arxiv.org/abs/1708.03888) to optimize each
weight using a local learning rate:
$$
local\_lr = \eta *
\frac{\left \| param \right \|}{\left \| grad \right \| + \beta *\left \| param \right \|} \\
......@@ -79,10 +189,8 @@ velocity = mu * velocity +
local\_lr * (grad + \beta * param) \\
param = param - velocity. \\
$$
Note that we use lars_weight_decay here to decay weights; you may not need to
use L2 regularizers when using LARS. (A standalone host-side sketch of this
update is given right after this class.)
)DOC");
}
};
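The update spelled out in the DOC string above can be sanity-checked with a short host-side sketch. This is a hedged, illustrative rewrite of the per-tensor LARS step on plain float vectors, not the Eigen/CUDA code the operator actually runs; the name LarsUpdate and its signature are invented for this example.
#include <cmath>
#include <vector>
// Minimal host-side LARS step for one parameter tensor, mirroring the formula
// above: local_lr = lr * lars_coeff * ||param|| /
//                   (||grad|| + weight_decay * ||param|| + epsilon),
// velocity = mu * velocity + local_lr * (grad + weight_decay * param),
// param = param - velocity.
void LarsUpdate(std::vector<float>* param, std::vector<float>* velocity,
                const std::vector<float>& grad, float lr, float mu,
                float lars_coeff, float weight_decay, float epsilon) {
  double p_norm = 0.0, g_norm = 0.0;
  for (size_t i = 0; i < param->size(); ++i) {
    p_norm += static_cast<double>((*param)[i]) * (*param)[i];
    g_norm += static_cast<double>(grad[i]) * grad[i];
  }
  p_norm = std::sqrt(p_norm);
  g_norm = std::sqrt(g_norm);
  float local_lr = lr;
  if (weight_decay > 0.f && p_norm > 0.0 && g_norm > 0.0) {
    local_lr = static_cast<float>(lr * lars_coeff * p_norm /
                                  (g_norm + weight_decay * p_norm + epsilon));
  }
  for (size_t i = 0; i < param->size(); ++i) {
    float g = grad[i] + weight_decay * (*param)[i];
    (*velocity)[i] = mu * (*velocity)[i] + local_lr * g;
    (*param)[i] -= (*velocity)[i];
  }
}
Since lars_weight_decay is now a per-parameter vector attribute and the inputs are duplicable, the operator effectively applies this update once per (Param, Grad, Velocity) triple.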
......@@ -96,7 +204,7 @@ class LarsMomentumOpVarTypeInference : public framework::VarTypeInference {
namespace ops = paddle::operators;
REGISTER_OPERATOR(
lars_momentum, ops::MomentumOp, ops::LarsMomentumOpMaker,
lars_momentum, ops::LarsMomentumOp, ops::LarsMomentumOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::LarsMomentumOpVarTypeInference);
......
......@@ -14,7 +14,21 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/operators/optimizers/lars_momentum_op.h"
#include "paddle/fluid/platform/fast_divmod.h"
#if CUDA_VERSION >= 11000
#include <cooperative_groups.h>
#endif
#ifdef __HIPCC__
#define LARS_BLOCK_SIZE 256
#else
#define LARS_BLOCK_SIZE 512
#endif
#define LARS_MAX_MERGED_OPS 60
namespace paddle {
namespace operators {
......@@ -22,124 +36,472 @@ namespace operators {
template <typename T>
using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;
template <typename T, typename MT>
__global__ void MomentumLarsKernel(
const T* p, const T* g, const MT* v,
const MultiPrecisionType<T>* learning_rate, const MT mu, const int64_t num,
const MT lars_coeff, const MT lars_weight_decay,
const MultiPrecisionType<T>* p_norm, const MultiPrecisionType<T>* g_norm,
T* p_out, MT* v_out, const MT epsilon, const MT* master_p, MT* master_p_out,
const MultiPrecisionType<T> rescale_grad) {
const MT lr = static_cast<MT>(learning_rate[0]);
MT local_lr = lr;
const MT p_n = static_cast<MT>(p_norm[0]);
const MT g_n = static_cast<MT>(g_norm[0]);
__device__ __forceinline__ float Sqrt(float x) { return sqrtf(x); }
__device__ __forceinline__ double Sqrt(double x) { return sqrt(x); }
__device__ __forceinline__ float Fma(float x, float y, float z) {
return fmaf(x, y, z);
}
__device__ __forceinline__ double Fma(double x, double y, double z) {
return fma(x, y, z);
}
template <typename T>
class LarsThreadConfig {
public:
int grid_for_norm;
int grid_for_lars;
#if CUDA_VERSION >= 11000
if (lars_weight_decay > static_cast<MT>(0) && p_n > static_cast<MT>(0) &&
g_n > static_cast<MT>(0)) {
local_lr =
lr * lars_coeff * p_n / (g_n + lars_weight_decay * p_n + epsilon);
private:
int grid_stride;
public:
explicit LarsThreadConfig(int64_t numel, int sm_num, int num_blocks_per_sm) {
int grid = (numel + LARS_BLOCK_SIZE - 1) / LARS_BLOCK_SIZE;
grid_for_lars =
std::min(std::min(sm_num * num_blocks_per_sm, grid), LARS_BLOCK_SIZE);
grid_stride = LARS_BLOCK_SIZE * grid_for_lars;
}
CUDA_KERNEL_LOOP(i, num) {
MT grad = static_cast<MT>(g[i]) * static_cast<MT>(rescale_grad);
MT param = master_p ? master_p[i] : static_cast<MT>(p[i]);
MT v_new = v[i] * mu + local_lr * (grad + lars_weight_decay * param);
MT p_new = param - v_new;
int GetRepeatTimes(int64_t numel) {
return (numel + grid_stride - 1) / grid_stride - 1;
}
#else
int repeat_times;
explicit LarsThreadConfig(const int64_t numel) {
int grid = (numel + LARS_BLOCK_SIZE - 1) / LARS_BLOCK_SIZE;
grid_for_norm = std::min(grid, LARS_BLOCK_SIZE);
const int grid_stride = grid_for_norm * LARS_BLOCK_SIZE;
repeat_times = (numel + grid_stride - 1) / grid_stride - 1;
// Read 4 fp16 or float elements at a time, but only 2 double elements at a time.
grid_for_lars =
std::is_same<double, T>::value
? (numel + (LARS_BLOCK_SIZE << 1) - 1) / (LARS_BLOCK_SIZE << 1)
: (numel + (LARS_BLOCK_SIZE << 2) - 1) / (LARS_BLOCK_SIZE << 2);
}
#endif
};
template <typename T, typename MT, int VecSize, bool IsAmp = false>
__device__ inline void VectorizeLarsUpdate(
const T* __restrict__ grad, const MT* param, const MT* velocity,
T* param_out, MT* velocity_out, const MT mu, MT local_lr,
const MT lars_weight_decay, const MT rescale_grad, const int tid,
const int grid_stride, const int numel, MT* master_param_out = nullptr) {
using VecType = paddle::platform::AlignedVector<T, VecSize>;
using VecMType = paddle::platform::AlignedVector<MT, VecSize>;
int main = numel >> (VecSize >> 1);
int tail_offset = main * VecSize;
v_out[i] = v_new;
p_out[i] = static_cast<T>(p_new);
if (master_p_out) master_p_out[i] = p_new;
const VecType* grad_vec = reinterpret_cast<const VecType*>(grad);
const VecMType* param_vec = reinterpret_cast<const VecMType*>(param);
const VecMType* velocity_vec = reinterpret_cast<const VecMType*>(velocity);
VecType* param_out_vec = reinterpret_cast<VecType*>(param_out);
VecMType* velocity_out_vec = reinterpret_cast<VecMType*>(velocity_out);
VecMType* master_param_out_vec;
if (IsAmp) {
master_param_out_vec = reinterpret_cast<VecMType*>(master_param_out);
}
for (int i = tid; i < main; i += grid_stride) {
VecType param_out_tmp;
VecMType velocity_tmp, param_tmp;
VecType grad_data = grad_vec[i];
VecMType param_data = param_vec[i];
VecMType velocity_data = velocity_vec[i];
#pragma unroll
for (int j = 0; j < VecSize; ++j) {
MT grad_val = static_cast<MT>(grad_data[j]) * rescale_grad;
velocity_tmp[j] =
Fma(velocity_data[j], mu,
local_lr * Fma(lars_weight_decay, param_data[j], grad_val));
param_tmp[j] = param_data[j] - velocity_tmp[j];
param_out_tmp[j] = static_cast<T>(param_tmp[j]);
}
param_out_vec[i] = param_out_tmp;
velocity_out_vec[i] = velocity_tmp;
if (IsAmp) {
master_param_out_vec[i] = param_tmp;
}
}
for (int i = tid + tail_offset; i < numel; i += grid_stride) {
MT grad_val = static_cast<MT>(grad[i]) * rescale_grad;
MT param_val = param[i];
MT velocity_tmp = Fma(velocity[i], mu, local_lr * Fma(lars_weight_decay,
param_val, grad_val));
MT param_tmp = param_val - velocity_tmp;
param_out[i] = static_cast<T>(param_tmp);
velocity_out[i] = velocity_tmp;
if (IsAmp) {
master_param_out[i] = param_tmp;
}
}
}
template <typename DeviceContext, typename T>
class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
using MPDType = MultiPrecisionType<T>;
#if CUDA_VERSION >= 11000
/* Once CUDA_VERSION is 11 or above, cooperative_groups can be used without the
--rdc=true compile flag, so the L2_norm kernel can be declared __device__ and
cooperative_groups::grid_group can be used inside it. Otherwise, since adding
this flag has a noticeable cost, the L2_norm kernel shall be declared __global__.*/
// TODO(limingshu): declaration of the cooperative_groups wrapper is invalid on the host.
template <typename T, typename MT>
__forceinline__ __device__ void L2NormKernel(
const cooperative_groups::grid_group* cg,
#else
template <typename T, typename MT>
__global__ void L2NormKernel(
#endif
const T* p_data, const T* __restrict__ g_data, MT* __restrict__ p_buffer,
MT* __restrict__ g_buffer, const int64_t numel, const int repeat_times,
const MT rescale_grad, const int thresh = 0, MT* __restrict__ p_n = nullptr,
MT* __restrict__ g_n = nullptr) {
__shared__ MT s_buffer[2];
int tid = threadIdx.x + blockDim.x * blockIdx.x;
int grid_stride = LARS_BLOCK_SIZE * gridDim.x;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const bool multi_precision = ctx.Attr<bool>("multi_precision");
if (multi_precision) {
InnerCompute<MPDType>(ctx, multi_precision);
MT p_tmp = static_cast<MT>(0);
MT g_tmp = static_cast<MT>(0);
while (tid < numel) {
MT tmp0 = static_cast<MT>(p_data[tid]);
MT tmp1 = static_cast<MT>(g_data[tid]);
p_tmp += (tmp0 * tmp0);
g_tmp += (tmp1 * tmp1);
tid += grid_stride;
}
p_tmp = math::blockReduceSum<MT>(p_tmp, FINAL_MASK);
g_tmp = math::blockReduceSum<MT>(g_tmp, FINAL_MASK);
if (threadIdx.x == 0) {
p_buffer[blockIdx.x] = p_tmp;
g_buffer[blockIdx.x] = g_tmp;
}
#if CUDA_VERSION >= 11000
cg->sync();  // Grid sync after writing partial results to global memory
MT p_part_sum = threadIdx.x < gridDim.x ? p_buffer[threadIdx.x] : 0;
MT g_part_sum = threadIdx.x < gridDim.x ? g_buffer[threadIdx.x] : 0;
MT tmp0 = math::blockReduceSum<MT>(p_part_sum, FINAL_MASK);
MT tmp1 = math::blockReduceSum<MT>(g_part_sum, FINAL_MASK);
if (threadIdx.x == 0) {
s_buffer[0] = tmp0;
s_buffer[1] = tmp1;
}
__syncthreads();
*p_n = Sqrt(s_buffer[0]);
*g_n = rescale_grad * Sqrt(s_buffer[1]);
#endif
}
template <typename T, typename MT>
__forceinline__ __device__ void MomentumUpdate(
const T* param, const T* __restrict__ grad, const MT* velocity,
T* param_out, MT* velocity_out, const MT* master_param,
MT* master_param_out, const MT* __restrict__ learning_rate, const MT mu,
const MT lars_weight_decay, const MT lars_coeff, const MT epsilon,
const MT rescale_grad, const MT param_norm, const MT grad_norm,
const int tid, const int grid_stride, const int64_t numel,
const bool is_amp) {
const MT lr = learning_rate[0];
MT local_lr = lr;
if (lars_weight_decay > static_cast<MT>(0)) {
local_lr = lr * lars_coeff * param_norm /
(fma(lars_weight_decay, param_norm, grad_norm) + epsilon);
}
if (is_amp) {
VectorizeLarsUpdate<T, MT, /*VecSize=*/4, /*IsAmp=*/true>(
grad, master_param, velocity, param_out, velocity_out, mu, local_lr,
lars_weight_decay, rescale_grad, tid, grid_stride, numel,
master_param_out);
} else {
if (std::is_same<T, float>::value ||
std::is_same<T, paddle::platform::float16>::value) {
/* TODO(limingshu): pointer cast may break memory access for fp16 */
VectorizeLarsUpdate<T, MT, /*VecSize=*/4, /*IsAmp=*/false>(
grad, reinterpret_cast<const MT*>(param), velocity, param_out,
velocity_out, mu, local_lr, lars_weight_decay, rescale_grad, tid,
grid_stride, numel);
} else {
InnerCompute<T>(ctx, multi_precision);
VectorizeLarsUpdate<T, MT, /*VecSize=*/2, /*IsAmp=*/false>(
grad, reinterpret_cast<const MT*>(param), velocity, param_out,
velocity_out, mu, local_lr, lars_weight_decay, rescale_grad, tid,
grid_stride, numel);
}
}
}
private:
template <typename MT>
void InnerCompute(const framework::ExecutionContext& ctx,
const bool multi_precision) const {
auto param_out = ctx.Output<framework::LoDTensor>("ParamOut");
auto velocity_out = ctx.Output<framework::LoDTensor>("VelocityOut");
auto param = ctx.Input<framework::LoDTensor>("Param");
auto velocity = ctx.Input<framework::LoDTensor>("Velocity");
auto grad = ctx.Input<framework::LoDTensor>("Grad");
auto learning_rate = ctx.Input<framework::LoDTensor>("LearningRate");
const framework::Tensor* master_param = nullptr;
framework::Tensor* master_param_out = nullptr;
if (multi_precision) {
bool has_master =
ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
PADDLE_ENFORCE_EQ(has_master, true,
platform::errors::InvalidArgument(
"The Input(MasterParam) and Output(MasterParamOut) "
"should not be null when "
"the attr `multi_precision` is true"));
master_param = ctx.Input<framework::Tensor>("MasterParam");
master_param_out = ctx.Output<framework::Tensor>("MasterParamOut");
}
#if CUDA_VERSION >= 11000
template <typename T, typename MT>
struct LarsParamWarpper {
int64_t numel_arr[LARS_MAX_MERGED_OPS];
int repeat_arr[LARS_MAX_MERGED_OPS];
const T* __restrict__ g_arr[LARS_MAX_MERGED_OPS];
const MT* __restrict__ lr_arr[LARS_MAX_MERGED_OPS];
T* __restrict__ p_out_arr[LARS_MAX_MERGED_OPS];
MT* __restrict__ v_out_arr[LARS_MAX_MERGED_OPS];
MT* __restrict__ master_p_out_arr[LARS_MAX_MERGED_OPS];
MT weight_decay_arr[LARS_MAX_MERGED_OPS];
};
const MT* master_p = multi_precision ? master_param->data<MT>() : nullptr;
MT* master_p_out = multi_precision
? master_param_out->mutable_data<MT>(ctx.GetPlace())
: nullptr;
template <typename T, typename MT>
__global__ void MergedMomentumLarsKernel(LarsParamWarpper<T, MT> lars_warpper,
MT* __restrict__ p_buffer,
MT* __restrict__ g_buffer,
const int op_num, const MT mu,
const MT lars_coeff, const MT epsilon,
const MT rescale_grad,
const bool is_amp) {
int grid_stride = gridDim.x * LARS_BLOCK_SIZE;
int tid = threadIdx.x + blockIdx.x * blockDim.x;
const cooperative_groups::grid_group cg = cooperative_groups::this_grid();
for (int i = 0; i < op_num; ++i) {
int numel = lars_warpper.numel_arr[i];
MT param_norm = static_cast<MT>(0);
MT grad_norm = static_cast<MT>(0);
L2NormKernel<T, MT>(&cg, lars_warpper.p_out_arr[i], lars_warpper.g_arr[i],
p_buffer, g_buffer, numel, lars_warpper.repeat_arr[i],
rescale_grad, 0, &param_norm, &grad_norm);
MomentumUpdate<T, MT>(
lars_warpper.p_out_arr[i], lars_warpper.g_arr[i],
lars_warpper.v_out_arr[i], lars_warpper.p_out_arr[i],
lars_warpper.v_out_arr[i], lars_warpper.master_p_out_arr[i],
lars_warpper.master_p_out_arr[i], lars_warpper.lr_arr[i], mu,
lars_warpper.weight_decay_arr[i], lars_coeff, epsilon, rescale_grad,
param_norm, grad_norm, tid, grid_stride, numel, is_amp);
}
}
#endif
T* p_out = param_out->mutable_data<T>(ctx.GetPlace());
MT* v_out = velocity_out->mutable_data<MT>(ctx.GetPlace());
template <typename T, typename MT>
__global__ void MomentumLarsKernel(
const T* param, const T* __restrict__ grad, const MT* velocity,
T* param_out, MT* velocity_out, const MT* master_param,
MT* master_param_out, const MT* __restrict__ learning_rate,
MT* __restrict__ p_buffer, MT* __restrict__ g_buffer, const MT mu,
const MT lars_coeff, const MT lars_weight_decay, const MT epsilon,
const MT rescale_grad, const int repeat_times, const int thresh,
const int64_t numel, const bool is_amp) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
int grid_stride = gridDim.x * LARS_BLOCK_SIZE;
#if CUDA_VERSION >= 11000
const cooperative_groups::grid_group cg = cooperative_groups::this_grid();
MT param_norm = static_cast<MT>(0);
MT grad_norm = static_cast<MT>(0);
L2NormKernel<T, MT>(&cg, param, grad, p_buffer, g_buffer, numel, repeat_times,
rescale_grad, gridDim.x, &param_norm, &grad_norm);
#else
const MT rescale_grad_pow = rescale_grad * rescale_grad;
MT param_part_norm = threadIdx.x < thresh ? p_buffer[threadIdx.x] : 0;
MT grad_part_norm = threadIdx.x < thresh ? g_buffer[threadIdx.x] : 0;
__syncthreads();
MT param_norm = Sqrt(math::blockReduceSum<MT>(param_part_norm, FINAL_MASK));
MT grad_norm = Sqrt(rescale_grad_pow *
math::blockReduceSum<MT>(grad_part_norm, FINAL_MASK));
#endif
MomentumUpdate<T, MT>(param, grad, velocity, param_out, velocity_out,
master_param, master_param_out, learning_rate, mu,
lars_weight_decay, lars_coeff, epsilon, rescale_grad,
param_norm, grad_norm, tid, grid_stride, numel, is_amp);
}
template <typename T, typename MT>
inline void SeparatedLarsMomentumOpCUDAKernel(
const platform::CUDADeviceContext& cuda_ctx, const T* param_data,
T* param_out_data, const MT* velocity_data, MT* velocity_out_data,
const T* grad_data, const MT* lr, MT* p_buffer, MT* g_buffer, const MT mu,
const MT lars_coeff, const MT weight_decay, const MT epsilon,
const MT rescale_grad, const int64_t numel, const MT* master_param_data,
MT* master_out_data, const bool is_amp) {
LarsThreadConfig<T> lars_thread_config(numel);
L2NormKernel<T, MT><<<lars_thread_config.grid_for_norm, LARS_BLOCK_SIZE, 0,
cuda_ctx.stream()>>>(
param_data, grad_data, p_buffer, g_buffer, numel,
lars_thread_config.repeat_times, rescale_grad);
MomentumLarsKernel<T, MT><<<lars_thread_config.grid_for_lars, LARS_BLOCK_SIZE,
0, cuda_ctx.stream()>>>(
param_data, grad_data, velocity_data, param_out_data, velocity_out_data,
master_param_data, master_out_data, lr, p_buffer, g_buffer, mu,
lars_coeff, weight_decay, epsilon, rescale_grad, 0,
lars_thread_config.grid_for_norm, numel, is_amp);
}
template <typename DeviceContext, typename T>
class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
using MT = MultiPrecisionType<T>;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
int num_blocks_per_sm = 0;
bool multi_precision = ctx.Attr<bool>("multi_precision");
auto& cuda_ctx = ctx.template device_context<platform::CUDADeviceContext>();
int sm_num = cuda_ctx.GetSMCount();
framework::Tensor tmp_buffer_t =
ctx.AllocateTmpTensor<MT, platform::CUDADeviceContext>(
{LARS_BLOCK_SIZE << 1}, cuda_ctx);
auto* p_buffer = tmp_buffer_t.mutable_data<MT>(ctx.GetPlace());
auto* g_buffer = p_buffer + LARS_BLOCK_SIZE;
MT mu = static_cast<MT>(ctx.Attr<float>("mu"));
MT lars_coeff = static_cast<MT>(ctx.Attr<float>("lars_coeff"));
MT lars_weight_decay =
static_cast<MT>(ctx.Attr<float>("lars_weight_decay"));
MT epsilon = static_cast<MT>(ctx.Attr<float>("epsilon"));
MPDType rescale_grad =
static_cast<MPDType>(ctx.Attr<float>("rescale_grad"));
auto* p = param->data<T>();
auto* g = grad->data<T>();
auto* v = velocity->data<MT>();
auto* lr = learning_rate->data<MPDType>();
int block = 512;
int grid = (param->numel() + block - 1) / block;
auto eigen_p = framework::EigenVector<T>::Flatten(*param);
auto eigen_g = framework::EigenVector<T>::Flatten(*grad);
// Calculate norms using Eigen and launch the kernel.
framework::Tensor p_norm_t, g_norm_t;
p_norm_t.Resize({1});
g_norm_t.Resize({1});
auto* p_norm_data = p_norm_t.mutable_data<MPDType>(ctx.GetPlace());
auto* g_norm_data = g_norm_t.mutable_data<MPDType>(ctx.GetPlace());
auto ep_norm = framework::EigenScalar<MPDType>::From(p_norm_t);
auto eg_norm = framework::EigenScalar<MPDType>::From(g_norm_t);
auto* place = ctx.template device_context<DeviceContext>().eigen_device();
// Eigen does not support fp16 L2-norm
ep_norm.device(*place) =
eigen_p.template cast<MPDType>().square().sum().sqrt();
eg_norm.device(*place) =
(eigen_g.template cast<MPDType>() * rescale_grad).square().sum().sqrt();
MomentumLarsKernel<
T, MT><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
p, g, v, lr, mu, param->numel(), lars_coeff, lars_weight_decay,
p_norm_data, g_norm_data, p_out, v_out, epsilon, master_p, master_p_out,
rescale_grad);
MT rescale_grad = static_cast<MT>(ctx.Attr<float>("rescale_grad"));
auto weight_decay_arr = ctx.Attr<std::vector<float>>("lars_weight_decay");
auto grad = ctx.MultiInput<framework::LoDTensor>("Grad");
auto param = ctx.MultiInput<framework::LoDTensor>("Param");
auto velocity = ctx.MultiInput<framework::LoDTensor>("Velocity");
auto param_out = ctx.MultiOutput<framework::LoDTensor>("ParamOut");
auto velocity_out = ctx.MultiOutput<framework::LoDTensor>("VelocityOut");
auto learning_rate = ctx.MultiInput<framework::LoDTensor>("LearningRate");
auto master_param = ctx.MultiInput<framework::LoDTensor>("MasterParam");
auto master_param_out =
ctx.MultiOutput<framework::LoDTensor>("MasterParamOut");
int op_num = grad.size();
#if CUDA_VERSION >= 11000
if (op_num > 1) {
LarsParamWarpper<T, MT> lars_warpper;
PADDLE_ENFORCE_LT(
op_num, LARS_MAX_MERGED_OPS,
platform::errors::InvalidArgument(
"The maximum number of merged-ops supported is (%d), but"
"lars op required for trainning this model is (%d)\n",
LARS_MAX_MERGED_OPS, op_num));
/* The lars optimizer consists of the following two steps:
1. Compute the L2 norm of the grad data and of the param data.
2. Update param and velocity using the L2 norm results.
Step 1 and step 2 can be merged through the cudaLaunchCooperativeKernel
API provided by NVIDIA:
- The total thread count must not exceed the physical SM thread limit.
- Blocks launched this way can synchronize with each other via grid sync.
(A standalone sketch of this cooperative-launch pattern is given after this
kernel class.) */
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&num_blocks_per_sm, MergedMomentumLarsKernel<T, MT>, LARS_BLOCK_SIZE,
sizeof(MT) << 1);
size_t total_numel = 0;
for (int i = 0; i < op_num; ++i) {
size_t temp_numel = param[i]->numel();
total_numel += temp_numel;
lars_warpper.numel_arr[i] = temp_numel;
lars_warpper.g_arr[i] = grad[i]->data<T>();
lars_warpper.lr_arr[i] = learning_rate[i]->data<MT>();
lars_warpper.p_out_arr[i] =
param_out[i]->mutable_data<T>(ctx.GetPlace());
lars_warpper.v_out_arr[i] =
velocity_out[i]->mutable_data<MT>(ctx.GetPlace());
lars_warpper.weight_decay_arr[i] = static_cast<MT>(weight_decay_arr[i]);
PADDLE_ENFORCE_EQ(
param[i]->data<T>(), lars_warpper.p_out_arr[i],
platform::errors::InvalidArgument(
"Input(Param) and Output(ParamOut) must be the same Tensors."));
PADDLE_ENFORCE_EQ(velocity[i]->data<MT>(), lars_warpper.v_out_arr[i],
platform::errors::InvalidArgument(
"Input(Velocity) and Output(VelocityOut) must be "
"the same Tensors."));
}
int64_t avg_numel = total_numel / op_num;
LarsThreadConfig<float> lars_thread_config(avg_numel, sm_num,
num_blocks_per_sm);
for (int i = 0; i < op_num; ++i) {
lars_warpper.repeat_arr[i] =
lars_thread_config.GetRepeatTimes(lars_warpper.numel_arr[i]);
}
if (multi_precision) {
for (int i = 0; i < op_num; ++i) {
lars_warpper.master_p_out_arr[i] =
master_param_out[i]->mutable_data<MT>(ctx.GetPlace());
PADDLE_ENFORCE_EQ(master_param[i]->data<MT>(),
lars_warpper.master_p_out_arr[i],
platform::errors::InvalidArgument(
"Input(MasterParam) and Output(MasterParamOut) "
"must be the same Tensors."));
}
}
void* cuda_param[] = {reinterpret_cast<void*>(&lars_warpper),
reinterpret_cast<void*>(&p_buffer),
reinterpret_cast<void*>(&g_buffer),
reinterpret_cast<void*>(&op_num),
reinterpret_cast<void*>(&mu),
reinterpret_cast<void*>(&lars_coeff),
reinterpret_cast<void*>(&epsilon),
reinterpret_cast<void*>(&rescale_grad),
reinterpret_cast<void*>(&multi_precision)};
// Launch threads on all SMs; the blocks cooperate through grid-wide synchronization.
cudaLaunchCooperativeKernel(
reinterpret_cast<void*>(MergedMomentumLarsKernel<T, MT>),
lars_thread_config.grid_for_lars, LARS_BLOCK_SIZE, cuda_param, 0,
cuda_ctx.stream());
} else {
auto* param_data = param[0]->data<T>();
auto* grad_data = grad[0]->data<T>();
auto* velocity_data = velocity[0]->data<MT>();
auto* lr = learning_rate[0]->data<MT>();
auto* param_out_data = param_out[0]->mutable_data<T>(ctx.GetPlace());
auto* velocity_out_data =
velocity_out[0]->mutable_data<MT>(ctx.GetPlace());
const MT* master_param_data =
multi_precision ? master_param[0]->data<MT>() : nullptr;
MT* master_param_out_data =
multi_precision
? master_param_out[0]->mutable_data<MT>(ctx.GetPlace())
: nullptr;
int64_t numel = param[0]->numel();
MT lars_weight_decay = weight_decay_arr[0];
// Figure out how many blocks can be active in each sm.
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&num_blocks_per_sm, MomentumLarsKernel<T, MT>, LARS_BLOCK_SIZE,
sizeof(MT) << 1);
LarsThreadConfig<float> lars_thread_config(numel, sm_num,
num_blocks_per_sm);
int repeat_times = lars_thread_config.GetRepeatTimes(numel);
int thresh = 0;
void* cuda_param[] = {
reinterpret_cast<void*>(&param_data),
reinterpret_cast<void*>(&grad_data),
reinterpret_cast<void*>(&velocity_data),
reinterpret_cast<void*>(&param_out_data),
reinterpret_cast<void*>(&velocity_out_data),
reinterpret_cast<void*>(&master_param_data),
reinterpret_cast<void*>(&master_param_out_data),
reinterpret_cast<void*>(&lr),
reinterpret_cast<void*>(&p_buffer),
reinterpret_cast<void*>(&g_buffer),
reinterpret_cast<void*>(&mu),
reinterpret_cast<void*>(&lars_coeff),
reinterpret_cast<void*>(&lars_weight_decay),
reinterpret_cast<void*>(&epsilon),
reinterpret_cast<void*>(&rescale_grad),
reinterpret_cast<void*>(&repeat_times),
reinterpret_cast<void*>(&thresh), // Just a placeholder
reinterpret_cast<void*>(&numel),
reinterpret_cast<void*>(&multi_precision)};
// Launch threads on all SMs.
cudaLaunchCooperativeKernel(
reinterpret_cast<void*>(MomentumLarsKernel<T, MT>),
lars_thread_config.grid_for_lars, LARS_BLOCK_SIZE, cuda_param, 0,
cuda_ctx.stream());
}
#else
for (int i = 0; i < op_num; ++i) {
const MT* master_param_data =
multi_precision ? master_param[i]->data<MT>() : nullptr;
MT* master_param_out_data =
multi_precision
? master_param_out[i]->mutable_data<MT>(ctx.GetPlace())
: nullptr;
SeparatedLarsMomentumOpCUDAKernel<T, MT>(
cuda_ctx, param[i]->data<T>(),
param_out[i]->mutable_data<T>(ctx.GetPlace()),
velocity[i]->data<MT>(),
velocity_out[i]->mutable_data<MT>(ctx.GetPlace()), grad[i]->data<T>(),
learning_rate[i]->data<MT>(), p_buffer, g_buffer, mu, lars_coeff,
weight_decay_arr[i], epsilon, rescale_grad, param[i]->numel(),
master_param_data, master_param_out_data, multi_precision);
}
#endif
}
};
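For reference, the cooperative-launch pattern described in the comment inside Compute above reduces to a small standalone CUDA sketch: a single kernel performs a reduction (step 1), grid-syncs, and then consumes the reduced value (step 2) without a second launch. The kernel and variable names here (FusedNormalize and friends) are invented for illustration and are not part of the operator; whether grid sync needs extra build flags depends on the CUDA version, as the comments above note.
#include <cooperative_groups.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>
namespace cg = cooperative_groups;
// Step 1 (squared L2 norm) and step 2 (scaling by 1/norm) fused in one kernel;
// the grid-wide sync between them replaces a second kernel launch.
__global__ void FusedNormalize(const float* x, float* y, float* norm_sq, int n) {
  cg::grid_group grid = cg::this_grid();
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  float local = 0.f;
  for (int i = tid; i < n; i += stride) local += x[i] * x[i];
  atomicAdd(norm_sq, local);  // step 1: accumulate the squared norm
  grid.sync();                // every block finishes step 1 before step 2 starts
  float inv_norm = rsqrtf(*norm_sq + 1e-6f);
  for (int i = tid; i < n; i += stride) y[i] = x[i] * inv_norm;  // step 2
}
int main() {
  int supported = 0;
  cudaDeviceGetAttribute(&supported, cudaDevAttrCooperativeLaunch, 0);
  if (!supported) { std::printf("cooperative launch not supported\n"); return 0; }
  int n = 1 << 20, block = 256;
  float *x, *y, *norm_sq;
  cudaMalloc(&x, n * sizeof(float));
  cudaMalloc(&y, n * sizeof(float));
  cudaMalloc(&norm_sq, sizeof(float));
  std::vector<float> host(n, 1.f);
  cudaMemcpy(x, host.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemset(norm_sq, 0, sizeof(float));
  // The whole grid must be resident at once, so size it from occupancy.
  int sm_count = 0, blocks_per_sm = 0;
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, 0);
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, FusedNormalize,
                                                block, 0);
  int grid = sm_count * blocks_per_sm;
  void* args[] = {&x, &y, &norm_sq, &n};
  cudaLaunchCooperativeKernel(reinterpret_cast<void*>(FusedNormalize), grid,
                              block, args, 0, nullptr);
  cudaDeviceSynchronize();
  cudaFree(x); cudaFree(y); cudaFree(norm_sq);
  return 0;
}
This mirrors what MomentumLarsKernel and MergedMomentumLarsKernel above do, except that they use math::blockReduceSum with per-block buffers instead of a single atomicAdd, and handle many parameter tensors per launch.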
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -23,54 +23,48 @@ template <typename T>
class LarsMomentumOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto param_out = ctx.Output<framework::LoDTensor>("ParamOut");
auto velocity_out = ctx.Output<framework::LoDTensor>("VelocityOut");
auto param = ctx.Input<framework::LoDTensor>("Param");
auto velocity = ctx.Input<framework::LoDTensor>("Velocity");
auto learning_rate = ctx.Input<framework::LoDTensor>("LearningRate");
auto* grad_var = ctx.InputVar("Grad");
// only support dense for now.
PADDLE_ENFORCE_EQ(grad_var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.InputNames("Grad").front(),
framework::ToTypeName(grad_var->Type())));
auto grad = ctx.Input<framework::LoDTensor>("Grad");
param_out->mutable_data<T>(ctx.GetPlace());
velocity_out->mutable_data<T>(ctx.GetPlace());
auto param_out = ctx.MultiOutput<framework::LoDTensor>("ParamOut");
auto velocity_out = ctx.MultiOutput<framework::LoDTensor>("VelocityOut");
auto param = ctx.MultiInput<framework::LoDTensor>("Param");
auto velocity = ctx.MultiInput<framework::LoDTensor>("Velocity");
auto learning_rate = ctx.MultiInput<framework::LoDTensor>("LearningRate");
auto grad = ctx.MultiInput<framework::LoDTensor>("Grad");
auto weight_decay_arr = ctx.Attr<std::vector<float>>("lars_weight_decay");
T mu = static_cast<T>(ctx.Attr<float>("mu"));
T lars_coeff = ctx.Attr<float>("lars_coeff");
T lars_weight_decay = ctx.Attr<float>("lars_weight_decay");
T epsilon = ctx.Attr<float>("epsilon");
auto p_out = framework::EigenVector<T>::Flatten(*param_out);
auto v_out = framework::EigenVector<T>::Flatten(*velocity_out);
int op_num = param.size();
for (int i = 0; i < op_num; ++i) {
auto* lr = learning_rate[i]->data<T>();
T lars_weight_decay = weight_decay_arr[i];
param_out[i]->mutable_data<T>(ctx.GetPlace());
velocity_out[i]->mutable_data<T>(ctx.GetPlace());
auto p = framework::EigenVector<T>::Flatten(*param);
auto v = framework::EigenVector<T>::Flatten(*velocity);
auto g = framework::EigenVector<T>::Flatten(*grad);
auto* lr = learning_rate->data<T>();
auto p_out = framework::EigenVector<T>::Flatten(*(param_out[i]));
auto v_out = framework::EigenVector<T>::Flatten(*(velocity_out[i]));
auto p = framework::EigenVector<T>::Flatten(*(param[i]));
auto v = framework::EigenVector<T>::Flatten(*(velocity[i]));
auto g = framework::EigenVector<T>::Flatten(*(grad[i]));
framework::Tensor p_norm_t, g_norm_t;
p_norm_t.Resize({1});
g_norm_t.Resize({1});
p_norm_t.mutable_data<T>(ctx.GetPlace());
g_norm_t.mutable_data<T>(ctx.GetPlace());
auto ep_norm = framework::EigenScalar<T>::From(p_norm_t);
auto eg_norm = framework::EigenScalar<T>::From(g_norm_t);
framework::Tensor p_norm_t, g_norm_t;
p_norm_t.Resize({1});
g_norm_t.Resize({1});
p_norm_t.mutable_data<T>(ctx.GetPlace());
g_norm_t.mutable_data<T>(ctx.GetPlace());
auto ep_norm = framework::EigenScalar<T>::From(p_norm_t);
auto eg_norm = framework::EigenScalar<T>::From(g_norm_t);
ep_norm = p.square().sum().sqrt();
eg_norm = g.square().sum().sqrt();
ep_norm = p.square().sum().sqrt();
eg_norm = g.square().sum().sqrt();
T local_lr = lr[0];
if (lars_weight_decay > 0 && ep_norm(0) > 0 && eg_norm(0) > 0) {
local_lr = lr[0] * lars_coeff * ep_norm(0) /
(eg_norm(0) + lars_weight_decay * ep_norm(0) + epsilon);
T local_lr = lr[0];
if (lars_weight_decay > 0 && ep_norm(0) > 0 && eg_norm(0) > 0) {
local_lr = lr[0] * lars_coeff * ep_norm(0) /
(eg_norm(0) + lars_weight_decay * ep_norm(0) + epsilon);
}
v_out = v * mu + local_lr * (g + lars_weight_decay * p);
p_out = p - v_out;
}
v_out = v * mu + local_lr * (g + lars_weight_decay * p);
p_out = p - v_out;
}
};
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
namespace paddle {
namespace operators {
class MergedMomentumOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto param_dtype =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return framework::OpKernelType(param_dtype, ctx.GetPlace());
}
};
class MergedMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Param",
"(Tensor, default Tensor<float>) "
"Input parameter that has to be updated")
.AsDuplicable();
AddInput("Grad",
"(Tensor, default Tensor<float>) "
"Input gradient of the parameter")
.AsDuplicable();
AddInput("Velocity",
"(Tensor, default Tensor<float>) "
"Input velocity (corresponding to the parameter) "
"that has to be updated")
.AsDuplicable();
AddInput("LearningRate",
"(Tensor, default Tensor<float>) "
"Input learning rate");
AddInput("MasterParam", "FP32 master weight for AMP.")
.AsDispensable()
.AsDuplicable();
AddOutput("ParamOut",
"(Tensor) This output is updated parameter. "
"It shared memory with Input(Param).")
.AsDuplicable();
AddOutput("VelocityOut",
"(Tensor) This output is updated velocity. "
"It shared memory with Input(Velocity).")
.AsDuplicable();
AddOutput("MasterParamOut",
"The updated FP32 master weight for AMP. "
"It shared memory with Input(MasterParam).")
.AsDispensable()
.AsDuplicable();
AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<bool>("multi_precision",
"(bool, default false) "
"Whether to use multi-precision during weight updating.")
.SetDefault(false);
AddAttr<float>(
"rescale_grad",
"(float, default 1.0) Multiply the gradient with `rescale_grad`"
"before updating. Often choose to be `1.0/batch_size`.")
.SetDefault(1.0f);
AddComment(R"DOC(Merged Momentum Optimizer.)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(merged_momentum, ops::MergedMomentumOp,
ops::MergedMomentumOpMaker);
REGISTER_OP_CPU_KERNEL(
merged_momentum, ops::MergedMomentumOpKernel<plat::CPUDeviceContext, float>,
ops::MergedMomentumOpKernel<plat::CPUDeviceContext, double>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
merged_momentum,
ops::MergedMomentumOpKernel<plat::CUDADeviceContext, plat::float16>,
ops::MergedMomentumOpKernel<plat::CUDADeviceContext, float>,
ops::MergedMomentumOpKernel<plat::CUDADeviceContext, double>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace operators {
template <typename MT, uint32_t kParamNum, bool kHasMasterParams>
struct MergedMomentumMasterParams {
MT *PADDLE_RESTRICT master_params[kParamNum];
HOSTDEVICE MT *MasterParam(size_t idx) const { return master_params[idx]; }
HOSTDEVICE void SetMasterParam(size_t idx, MT *p) { master_params[idx] = p; }
};
template <typename MT, uint32_t kParamNum>
struct MergedMomentumMasterParams<MT, kParamNum, false> {
HOSTDEVICE constexpr MT *MasterParam(size_t) const { return nullptr; }
HOSTDEVICE constexpr void SetMasterParam(size_t, MT *) {}
};
template <typename T, typename MT, bool kHasMasterParams,
uint32_t kParamNum = kHasMasterParams ? 55 : 110>
struct MergedMomentumKernelParam
: public MergedMomentumMasterParams<MT, kParamNum, kHasMasterParams> {
static constexpr auto N = kParamNum;
size_t sizes[N];
T *PADDLE_RESTRICT params[N];
const T *PADDLE_RESTRICT grads[N];
MT *PADDLE_RESTRICT velocitys[N];
const MT *PADDLE_RESTRICT lr;
MT mu;
MT rescale_grad;
uint32_t param_num;
HOSTDEVICE void operator()(size_t i) const {
const auto lr_val = *lr;
for (uint32_t idx = 0; idx < param_num; ++idx) {
auto size = sizes[idx];
if (i >= size) continue;
auto param_p = params[idx];
auto grad_p = grads[idx];
auto velocity_p = velocitys[idx];
auto master_param_p = this->MasterParam(idx);
const MT param =
master_param_p ? master_param_p[i] : static_cast<MT>(param_p[i]);
const MT grad = static_cast<MT>(grad_p[i]) * rescale_grad;
const MT velocity = velocity_p[i];
const MT velocity_out = velocity * mu + grad;
const MT param_out = param - lr_val * velocity_out;
velocity_p[i] = velocity_out;
param_p[i] = static_cast<T>(param_out);
if (master_param_p) {
master_param_p[i] = param_out;
}
}
}
};
template <typename DeviceContext, typename T>
class MergedMomentumOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto params = ctx.MultiInput<framework::Tensor>("Param");
auto params_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
size_t n = params.size();
PADDLE_ENFORCE_EQ(
n, params_out.size(),
platform::errors::InvalidArgument(
"Output(ParamOut) number must be equal to Input(Param) number."));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(
params[i], params_out[i],
platform::errors::InvalidArgument(
"Input(Param) and Output(ParamOut) must be the same Tensors."));
}
auto grads = ctx.MultiInput<framework::Tensor>("Grad");
PADDLE_ENFORCE_EQ(
n, grads.size(),
platform::errors::InvalidArgument(
"Input(Grad) number must be equal to Input(Param) number."));
auto velocitys = ctx.MultiInput<framework::Tensor>("Velocity");
PADDLE_ENFORCE_EQ(n, velocitys.size(),
platform::errors::InvalidArgument(
"Input(Velocity) number and Input(Param) number."));
auto velocitys_out = ctx.MultiOutput<framework::Tensor>("VelocityOut");
PADDLE_ENFORCE_EQ(
n, velocitys_out.size(),
platform::errors::InvalidArgument("Output(VelocityOut) number must be "
"equal to Input(Param) number."));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(velocitys[i], velocitys_out[i],
platform::errors::InvalidArgument(
"Input(Velocity) and Output(VelocityOut) must be "
"the same Tensors."));
}
auto master_params = ctx.MultiInput<framework::Tensor>("MasterParam");
auto master_params_out =
ctx.MultiOutput<framework::Tensor>("MasterParamOut");
auto multi_precision = ctx.Attr<bool>("multi_precision");
if (multi_precision) {
PADDLE_ENFORCE_EQ(
n, master_params.size(),
platform::errors::InvalidArgument("Input(MasterParam) number must be "
"equal to Input(Param) number."));
PADDLE_ENFORCE_EQ(n, master_params_out.size(),
platform::errors::InvalidArgument(
"Output(MasterParamOut) number must be equal to "
"Input(MasterParam) number."));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(master_params[i], master_params_out[i],
platform::errors::InvalidArgument(
"Input(MasterParam) and Output(MasterParamOut) "
"must be the same Tensors."));
PADDLE_ENFORCE_NOT_NULL(master_params[i],
platform::errors::InvalidArgument(
"Input(MasterParam) must be provided when "
"multi_precision=True."));
}
} else {
master_params.clear();
master_params_out.clear();
}
auto lr = ctx.Input<framework::Tensor>("LearningRate");
auto mu = ctx.Attr<float>("mu");
auto rescale_grad = ctx.Attr<float>("rescale_grad");
using MPType = typename operators::details::MPTypeTrait<T>::Type;
auto &dev_ctx = ctx.template device_context<DeviceContext>();
#define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision) \
MergedMomentumKernelParam<T, MPType, kMultiPrecision> kernel_params; \
constexpr auto kMaxMergedNum = decltype(kernel_params)::N; \
size_t kernel_num = (n + kMaxMergedNum - 1) / kMaxMergedNum; \
kernel_params.mu = static_cast<MPType>(mu); \
kernel_params.rescale_grad = static_cast<MPType>(rescale_grad); \
kernel_params.lr = lr->data<MPType>(); \
for (size_t i = 0; i < kernel_num; ++i) { \
size_t start = i * kMaxMergedNum; \
size_t end = std::min((i + 1) * kMaxMergedNum, n); \
kernel_params.param_num = static_cast<uint32_t>(end - start); \
size_t max_size = 0; \
for (size_t j = 0; j < kernel_params.param_num; ++j) { \
auto size = static_cast<size_t>(params_out[j + start]->numel()); \
max_size = std::max(max_size, size); \
kernel_params.sizes[j] = size; \
kernel_params.params[j] = params_out[j + start]->data<T>(); \
kernel_params.grads[j] = grads[j + start]->data<T>(); \
kernel_params.velocitys[j] = velocitys_out[j + start]->data<MPType>(); \
kernel_params.SetMasterParam( \
j, kMultiPrecision ? master_params_out[j + start]->data<MPType>() \
: nullptr); \
} \
platform::ForRange<DeviceContext> for_range(dev_ctx, max_size); \
for_range(kernel_params); \
VLOG(10) << "Launch MergedMomentum kernel " << i << " " \
<< kernel_params.param_num; \
}
if (multi_precision) {
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true);
} else {
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(false);
}
#undef PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL
}
};
} // namespace operators
} // namespace paddle
......@@ -173,14 +173,15 @@ class CPUDenseMomentumFunctor {
}
};
template <typename T, typename MT, typename UpdateMethod>
template <typename T, typename MT, RegularizationType kRegType,
typename UpdateMethod>
class DenseMomentumFunctor;
// NOTE(dzh): for performance,
// avoid if/else inside the kernel; implement GPU UseNesterov/NoNesterov as two
// functors.
template <typename T, typename MT>
class DenseMomentumFunctor<T, MT, UseNesterov> {
template <typename T, typename MT, RegularizationType kRegType>
class DenseMomentumFunctor<T, MT, kRegType, UseNesterov> {
private:
const T* param_;
const T* grad_;
......@@ -193,7 +194,6 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
T* param_out_;
MT* velocity_out_;
MT* master_param_out_;
const RegularizationType regularization_flag_;
const MT regularization_coeff_;
public:
......@@ -201,7 +201,6 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
const MultiPrecisionType<MT>* learning_rate,
const MT* master_param, const MT mu,
const MT rescale_grad, const int64_t num,
const RegularizationType regularization_flag,
const MT regularization_coeff, T* param_out,
MT* velocity_out, MT* master_param_out)
: param_(param),
......@@ -215,7 +214,6 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
param_out_(param_out),
velocity_out_(velocity_out),
master_param_out_(master_param_out),
regularization_flag_(regularization_flag),
regularization_coeff_(regularization_coeff) {}
inline HOSTDEVICE void operator()(size_t i) const {
// put memory access in register
......@@ -225,9 +223,9 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
const MT lr = static_cast<MT>(lr_[0]);
const MT velocity = velocity_[i];
grad = regularization_flag_ == RegularizationType::kL2DECAY
? grad + regularization_coeff_ * param
: grad;
if (kRegType == RegularizationType::kL2DECAY) {
grad += regularization_coeff_ * param;
}
MT velocity_out = velocity * mu_ + grad;
MT param_out = param - (grad + velocity_out * mu_) * lr;
......@@ -240,8 +238,8 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
}
};
template <typename T, typename MT>
class DenseMomentumFunctor<T, MT, NoNesterov> {
template <typename T, typename MT, RegularizationType kRegType>
class DenseMomentumFunctor<T, MT, kRegType, NoNesterov> {
private:
const T* param_;
const T* grad_;
......@@ -254,7 +252,6 @@ class DenseMomentumFunctor<T, MT, NoNesterov> {
T* param_out_;
MT* velocity_out_;
MT* master_param_out_;
const RegularizationType regularization_flag_;
const MT regularization_coeff_;
public:
......@@ -262,7 +259,6 @@ class DenseMomentumFunctor<T, MT, NoNesterov> {
const MultiPrecisionType<MT>* learning_rate,
const MT* master_param, const MT mu,
const MT rescale_grad, const int64_t num,
const RegularizationType regularization_flag,
const MT regularization_coeff, T* param_out,
MT* velocity_out, MT* master_param_out)
: param_(param),
......@@ -276,7 +272,6 @@ class DenseMomentumFunctor<T, MT, NoNesterov> {
param_out_(param_out),
velocity_out_(velocity_out),
master_param_out_(master_param_out),
regularization_flag_(regularization_flag),
regularization_coeff_(regularization_coeff) {}
inline HOSTDEVICE void operator()(size_t i) const {
// put memory access in register
......@@ -286,9 +281,9 @@ class DenseMomentumFunctor<T, MT, NoNesterov> {
const MT lr = static_cast<MT>(lr_[0]);
const MT velocity = velocity_[i];
grad = regularization_flag_ == RegularizationType::kL2DECAY
? grad + regularization_coeff_ * param
: grad;
if (kRegType == RegularizationType::kL2DECAY) {
grad += regularization_coeff_ * param;
}
MT velocity_out = velocity * mu_ + grad;
MT param_out = param - lr * velocity_out;
......@@ -522,23 +517,31 @@ class MomentumOpKernel : public framework::OpKernel<T> {
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()),
param->numel());
if (use_nesterov) {
DenseMomentumFunctor<T, MT, UseNesterov> functor(
param->data<T>(), grad->data<T>(), velocity->data<MT>(),
learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,
param->numel(), regularization_flag, regularization_coeff,
param_out->mutable_data<T>(ctx.GetPlace()),
velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
for_range(functor);
#define PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(__nesterov, __reg_type) \
DenseMomentumFunctor<T, MT, __reg_type, __nesterov> functor( \
param->data<T>(), grad->data<T>(), velocity->data<MT>(), \
learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad, \
param->numel(), regularization_coeff, \
param_out->mutable_data<T>(ctx.GetPlace()), \
velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data); \
for_range(functor);
if (use_nesterov) {
if (regularization_flag == RegularizationType::kL2DECAY) {
PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(UseNesterov,
RegularizationType::kL2DECAY);
} else {
PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(UseNesterov,
RegularizationType::kNONE);
}
} else {
DenseMomentumFunctor<T, MT, NoNesterov> functor(
param->data<T>(), grad->data<T>(), velocity->data<MT>(),
learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,
param->numel(), regularization_flag, regularization_coeff,
param_out->mutable_data<T>(ctx.GetPlace()),
velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
for_range(functor);
if (regularization_flag == RegularizationType::kL2DECAY) {
PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(NoNesterov,
RegularizationType::kL2DECAY);
} else {
PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(NoNesterov,
RegularizationType::kNONE);
}
}
}
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
class Pow2DecayWithLinearWarmupOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
auto dim = framework::make_ddim({1});
ctx->SetOutputDim("LearningRateOut", dim);
ctx->SetOutputDim("StepOut", dim);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "LearningRate");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class Pow2DecayWithLinearWarmupOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("LearningRate", "(Tensor) The input learning rate Tensor.");
AddInput("Step", "(Tensor) The input global step Tensor.");
AddOutput("LearningRateOut",
"(Tensor) The output learning rate Tensor. Same with "
"Input(LearningRate).");
AddOutput(
"StepOut",
"(Tensor) The output learning rate Tensor. Same with Input(Step).");
AddAttr<int64_t>("warmup_steps", "(int64_t) The warmup steps.");
AddAttr<int64_t>(
"total_steps",
"(int64_t) The total steps for changing the learning rate.");
AddAttr<float>("base_lr",
"(float) The final learning rate value after warmup.");
AddAttr<float>("end_lr",
"(float) The final learning rate value after total_steps.");
AddComment(R"DOC(
The Pow2DecayWithLinearWarmup learning rate scheduler.
When step_num < warmup_steps, lr = base_lr * step_num / warmup_steps
When warmup_steps <= step_num <= total_steps,
factor = 1 - (step_num - warmup_steps) / (total_steps - warmup_steps)
lr = (base_lr - end_lr) * factor * factor + end_lr
When step_num > total_steps, lr = end_lr
)DOC");
}
};
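As a quick sanity check of the schedule in the DOC string above, here is a hedged host-side sketch; the free function and the sample step values are illustrative only, and the boundary handling (step <= warmup_steps for warmup, step < total_steps for decay) follows the Pow2DecayWithLinearWarmupFunctor defined later in this diff.
#include <cstdio>
// Host-side sketch of the pow2-decay-with-linear-warmup schedule described in
// the comment above.
double Pow2DecayWithLinearWarmup(size_t step, size_t warmup_steps,
                                 size_t total_steps, double base_lr,
                                 double end_lr) {
  if (step <= warmup_steps) {
    return static_cast<double>(step) / warmup_steps * base_lr;
  } else if (step < total_steps) {
    double factor = 1.0 - static_cast<double>(step - warmup_steps) /
                              (total_steps - warmup_steps);
    return (base_lr - end_lr) * factor * factor + end_lr;
  }
  return end_lr;
}
int main() {
  const size_t steps[] = {1, 500, 4000, 20000};
  for (size_t step : steps) {
    std::printf("step %zu -> lr %.6f\n", step,
                Pow2DecayWithLinearWarmup(step, /*warmup_steps=*/1000,
                                          /*total_steps=*/10000,
                                          /*base_lr=*/0.1, /*end_lr=*/0.0));
  }
  return 0;
}
Running it with warmup_steps=1000, total_steps=10000, base_lr=0.1 and end_lr=0.0 gives a linear ramp to 0.1 over the first 1000 steps, a squared-factor decay afterwards, and a constant end_lr beyond step 10000.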
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(pow2_decay_with_linear_warmup,
ops::Pow2DecayWithLinearWarmupOp,
ops::Pow2DecayWithLinearWarmupOpMaker);
REGISTER_OP_CPU_KERNEL(
pow2_decay_with_linear_warmup,
ops::Pow2DecayWithLinearWarmupOpKernel<plat::CPUDeviceContext, double>,
ops::Pow2DecayWithLinearWarmupOpKernel<plat::CPUDeviceContext, float>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
pow2_decay_with_linear_warmup,
ops::Pow2DecayWithLinearWarmupOpKernel<plat::CUDADeviceContext, double>,
ops::Pow2DecayWithLinearWarmupOpKernel<plat::CUDADeviceContext, float>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace operators {
template <typename T, typename AttrT>
struct Pow2DecayWithLinearWarmupFunctor {
template <typename U>
using RestrictPtr = U *PADDLE_RESTRICT;
public:
HOSTDEVICE Pow2DecayWithLinearWarmupFunctor(RestrictPtr<T> lr,
RestrictPtr<int64_t> step,
size_t warmup_steps,
size_t total_steps, AttrT base_lr,
AttrT end_lr)
: lr_(lr),
step_(step),
warmup_steps_(warmup_steps),
total_steps_(total_steps),
base_lr_(base_lr),
end_lr_(end_lr) {}
HOSTDEVICE void operator()(size_t) const {
size_t step = static_cast<size_t>(*step_) + 1;
*step_ = static_cast<int64_t>(step);
if (step <= warmup_steps_) {
auto new_lr = static_cast<double>(step) / warmup_steps_ * base_lr_;
*lr_ = static_cast<T>(new_lr);
} else if (step < total_steps_) {
auto factor = 1 -
static_cast<double>(step - warmup_steps_) /
(total_steps_ - warmup_steps_);
auto new_lr =
static_cast<double>(base_lr_ - end_lr_) * (factor * factor) + end_lr_;
*lr_ = static_cast<T>(new_lr);
} else {
*lr_ = static_cast<T>(end_lr_);
}
}
private:
RestrictPtr<T> lr_;
RestrictPtr<int64_t> step_;
size_t warmup_steps_;
size_t total_steps_;
AttrT base_lr_;
AttrT end_lr_;
};
template <typename DeviceContext, typename T>
class Pow2DecayWithLinearWarmupOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const {
const auto *lr = ctx.Input<framework::Tensor>("LearningRate");
const auto *step = ctx.Input<framework::Tensor>("Step");
auto *lr_out = ctx.Output<framework::Tensor>("LearningRateOut");
auto *step_out = ctx.Output<framework::Tensor>("StepOut");
PADDLE_ENFORCE_EQ(
lr, lr_out, platform::errors::InvalidArgument("Input(LearningRate) and "
"Output(LearningRateOut) "
"must be the same."));
PADDLE_ENFORCE_NOT_NULL(lr,
platform::errors::InvalidArgument(
"Input(LearingRate) should not be nullptr."));
PADDLE_ENFORCE_EQ(step, step_out,
platform::errors::InvalidArgument(
"Input(Step) and Output(StepOut) must be the same."));
PADDLE_ENFORCE_NOT_NULL(step, platform::errors::InvalidArgument(
"Input(Step) should not be nullptr."));
PADDLE_ENFORCE_EQ(
step->IsInitialized(), true,
platform::errors::InvalidArgument("Input(Step) must be initialized."));
auto warmup_steps = static_cast<size_t>(ctx.Attr<int64_t>("warmup_steps"));
auto total_steps = static_cast<size_t>(ctx.Attr<int64_t>("total_steps"));
PADDLE_ENFORCE_LE(warmup_steps, total_steps,
platform::errors::InvalidArgument(
"warmup_steps must not be larger than total_steps."));
auto base_lr = ctx.Attr<float>("base_lr");
auto end_lr = ctx.Attr<float>("end_lr");
auto *lr_data = lr_out->data<T>();
auto *step_data = step_out->data<int64_t>();
auto &dev_ctx = ctx.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, 1);
using AttrT = double;
Pow2DecayWithLinearWarmupFunctor<T, AttrT> functor(
lr_data, step_data, warmup_steps, total_steps,
static_cast<AttrT>(base_lr), static_cast<AttrT>(end_lr));
for_range(functor);
}
};
} // namespace operators
} // namespace paddle
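For reference, a minimal Python sketch of the schedule computed by the functor above: linear warmup to base_lr over warmup_steps, then quadratic decay to end_lr. The helper name and the 1-based step convention are illustrative, not part of the operator's API.
def pow2_warmup_lr(step, warmup_steps, total_steps, base_lr, end_lr):
    # step is 1-based, mirroring the functor's "*step_ + 1" update.
    if step <= warmup_steps:
        return float(step) / warmup_steps * base_lr
    if step < total_steps:
        factor = 1.0 - float(step - warmup_steps) / (total_steps - warmup_steps)
        return (base_lr - end_lr) * factor * factor + end_lr
    return end_lr
# e.g. pow2_warmup_lr(1, 30, 100, 0.02, 0.001) == 0.02 / 30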
......@@ -59,9 +59,14 @@ cc_library(cpu_info SRCS cpu_info.cc DEPS ${CPU_INFO_DEPS})
cc_test(cpu_info_test SRCS cpu_info_test.cc DEPS cpu_info)
IF(WITH_GPU)
nv_library(cuda_graph SRCS cuda_graph.cc DEPS enforce allocator_facade)
nv_library(gpu_info SRCS gpu_info.cc DEPS gflags glog enforce monitor dynload_cuda)
nv_library(cuda_profiler SRCS cuda_profiler.cc DEPS enforce)
nv_library(cuda_graph_with_memory_pool SRCS cuda_graph_with_memory_pool.cc DEPS device_context allocator_facade cuda_graph)
ELSE()
cc_library(cuda_graph_with_memory_pool SRCS cuda_graph_with_memory_pool.cc DEPS device_context allocator_facade)
ENDIF()
IF(WITH_ROCM)
hip_library(gpu_info SRCS gpu_info.cc DEPS gflags glog enforce monitor dynload_cuda)
ENDIF()
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/cuda_graph.h"
namespace paddle {
namespace platform {
std::unique_ptr<CUDAGraph> CUDAGraph::capturing_graph_{nullptr};
void CUDAGraph::Reset() {
if (is_reset_) return;
#if CUDA_VERSION >= 10010
if (graph_) {
PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphDestroy(graph_));
graph_ = nullptr;
}
if (exec_graph_) {
PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphExecDestroy(exec_graph_));
exec_graph_ = nullptr;
}
#endif
  // Callbacks should be called in reverse order because a later-added
  // callback may rely on an earlier-added one.
for (auto iter = callbacks_.rbegin(); iter != callbacks_.rend(); ++iter) {
(*iter)();
}
callbacks_.clear();
is_reset_ = true;
}
void CUDAGraph::Replay() {
#if CUDA_VERSION >= 10010
PADDLE_ENFORCE_EQ(is_reset_, false,
errors::PermissionDenied(
"Cannot replay the CUDA Graph after reset is called."));
PADDLE_ENFORCE_NOT_NULL(exec_graph_,
errors::PermissionDenied(
"CUDA Graph must be captured before replaying."));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphLaunch(exec_graph_, stream_));
#endif
}
void CUDAGraph::BeginCapture(platform::CUDAPlace place, cudaStream_t stream,
cudaStreamCaptureMode mode) {
ThrowErrorIfNotSupportCUDAGraph();
PADDLE_ENFORCE_EQ(
IsCapturing(), false,
errors::PermissionDenied("CUDA Graph can only captured one by one."));
PADDLE_ENFORCE_NOT_NULL(
stream, errors::PermissionDenied(
"CUDA Graph cannot be captured in default CUDA stream 0."));
capturing_graph_.reset(new CUDAGraph());
capturing_graph_->place_ = place;
capturing_graph_->stream_ = stream;
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaStreamBeginCapture(capturing_graph_->stream_, mode));
cudaStreamCaptureStatus status;
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamGetCaptureInfo(
capturing_graph_->stream_, &status, &(capturing_graph_->id_)));
PADDLE_ENFORCE_EQ(IsValidCapturing(), true,
platform::errors::PermissionDenied(
"CUDA Graph should not be invalidated."));
VLOG(10) << "Begin to capture CUDA Graph with ID " << capturing_graph_->id_;
}
std::unique_ptr<CUDAGraph> CUDAGraph::EndCapture() {
ThrowErrorIfNotSupportCUDAGraph();
#if CUDA_VERSION >= 10010
PADDLE_ENFORCE_EQ(IsCapturing(), true,
errors::PermissionDenied("No CUDA Graph is capturing."));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamEndCapture(
capturing_graph_->stream_, &(capturing_graph_->graph_)));
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaGraphInstantiate(&(capturing_graph_->exec_graph_),
capturing_graph_->graph_, nullptr, nullptr, 0));
VLOG(10) << "End to capture CUDA Graph with ID " << capturing_graph_->id_;
return std::move(capturing_graph_);
#endif
}
bool CUDAGraph::IsValidCapturing() {
if (!IsCapturing()) return false;
cudaStreamCaptureStatus status;
CUDAGraphID id;
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaStreamGetCaptureInfo(capturing_graph_->stream_, &status, &id));
return status == cudaStreamCaptureStatusActive;
}
} // namespace platform
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <memory>
#include <mutex>
#include "cuda.h" // NOLINT
#include "cuda_runtime.h" // NOLINT
#include "paddle/fluid/platform/type_defs.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace platform {
#if CUDA_VERSION >= 10010
static void ThrowErrorIfNotSupportCUDAGraph() {}
#else
enum cudaStreamCaptureMode {
cudaStreamCaptureModeGlobal = 0,
cudaStreamCaptureModeThreadLocal = 1,
cudaStreamCaptureModeRelaxed = 2
};
static void ThrowErrorIfNotSupportCUDAGraph() {
PADDLE_THROW(platform::errors::Unimplemented(
"CUDA Graph is only supported when CUDA version >= 10.1"));
}
#endif
// NOTE: Currently, capturing CUDA Graphs in parallel is not supported.
// NOTE: Do not use this class directly; it should be used together with
// the memory pool.
class CUDAGraph {
DISABLE_COPY_AND_ASSIGN(CUDAGraph);
  // Since the constructor throws an error if CUDA_VERSION < 10010,
  // the non-static methods of CUDAGraph need not check CUDA_VERSION
  // again.
CUDAGraph() { ThrowErrorIfNotSupportCUDAGraph(); }
public:
~CUDAGraph() { Reset(); }
CUDAGraphID ID() const { return id_; }
void Replay();
void Reset();
void AddResetCallback(std::function<void()> callback) {
std::lock_guard<std::mutex> guard(mtx_);
callbacks_.push_back(std::move(callback));
}
static void BeginCapture(platform::CUDAPlace place, cudaStream_t stream,
cudaStreamCaptureMode mode);
static std::unique_ptr<CUDAGraph> EndCapture();
static void AddResetCallbackDuringCapturing(std::function<void()> callback) {
capturing_graph_->AddResetCallback(std::move(callback));
}
  // No need for a CUDA_VERSION guard here because capturing_graph_ would
  // always be nullptr (the constructor throws an error).
static bool IsCapturing() { return capturing_graph_ != nullptr; }
static CUDAGraphID CapturingID() { return capturing_graph_->id_; }
static platform::CUDAPlace CapturingPlace() {
return capturing_graph_->place_;
}
  // This API can be used to debug which GPU operation is not
  // supported during CUDA Graph capturing.
static bool IsValidCapturing();
private:
#if CUDA_VERSION >= 10010
cudaGraph_t graph_{nullptr};
cudaGraphExec_t exec_graph_{nullptr};
#endif
cudaStream_t stream_{nullptr};
platform::CUDAPlace place_;
CUDAGraphID id_{0};
std::vector<std::function<void()>> callbacks_;
bool is_reset_{false};
std::mutex mtx_;
static std::unique_ptr<CUDAGraph> capturing_graph_;
};
#if CUDA_VERSION >= 10010
class CUDAGraphCaptureModeGuard {
DISABLE_COPY_AND_ASSIGN(CUDAGraphCaptureModeGuard);
public:
explicit CUDAGraphCaptureModeGuard(
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed) {
if (UNLIKELY(CUDAGraph::IsCapturing())) {
PADDLE_ENFORCE_CUDA_SUCCESS(cudaThreadExchangeStreamCaptureMode(&mode));
// After cudaThreadExchangeStreamCaptureMode is called,
// the variable "mode" would be set to the old capturing mode.
old_mode_ = mode;
}
}
~CUDAGraphCaptureModeGuard() PADDLE_MAY_THROW {
if (UNLIKELY(CUDAGraph::IsCapturing())) {
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaThreadExchangeStreamCaptureMode(&old_mode_));
}
}
private:
cudaStreamCaptureMode old_mode_;
};
#else
class CUDAGraphCaptureModeGuard {
DISABLE_COPY_AND_ASSIGN(CUDAGraphCaptureModeGuard);
public:
explicit CUDAGraphCaptureModeGuard(
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed) {}
};
#endif
} // namespace platform
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace platform {
#ifdef PADDLE_WITH_CUDA
void BeginCUDAGraphCapture(platform::CUDAPlace place,
cudaStreamCaptureMode mode) {
auto *dev_ctx = platform::DeviceContextPool::Instance().GetByPlace(place);
dev_ctx->cudnn_workspace_handle().ResetWorkspace();
auto stream = dev_ctx->stream();
CUDAGraph::BeginCapture(place, stream, mode);
auto id = CUDAGraph::CapturingID();
memory::allocation::AllocatorFacade::Instance().PrepareMemoryPoolForCUDAGraph(
id);
AddResetCallbackIfCapturingCUDAGraph([id] {
memory::allocation::AllocatorFacade::Instance().RemoveMemoryPoolOfCUDAGraph(
id);
});
}
std::unique_ptr<CUDAGraph> EndCUDAGraphCapture() {
auto place = CUDAGraph::CapturingPlace();
auto *dev_ctx = platform::DeviceContextPool::Instance().GetByPlace(place);
dev_ctx->cudnn_workspace_handle().ResetWorkspace();
return CUDAGraph::EndCapture();
}
#endif
} // namespace platform
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_graph.h"
#endif
namespace paddle {
namespace platform {
// NOTE: These APIs are not thread-safe.
#ifdef PADDLE_WITH_CUDA
void BeginCUDAGraphCapture(platform::CUDAPlace place,
cudaStreamCaptureMode mode);
std::unique_ptr<CUDAGraph> EndCUDAGraphCapture();
#endif
inline bool IsCUDAGraphCapturing() {
#ifdef PADDLE_WITH_CUDA
return CUDAGraph::IsCapturing();
#else
return false;
#endif
}
inline platform::CUDAPlace CUDAGraphCapturingPlace() {
#ifdef PADDLE_WITH_CUDA
return CUDAGraph::CapturingPlace();
#else
PADDLE_THROW(platform::errors::Unimplemented(
"CUDA Graph is only supported on NVIDIA GPU device."));
#endif
}
// Add reset callback if CUDA Graph is capturing.
// Otherwise, invoke callback directly.
template <typename Callback>
inline void AddResetCallbackIfCapturingCUDAGraph(Callback &&callback) {
#ifdef PADDLE_WITH_CUDA
if (UNLIKELY(IsCUDAGraphCapturing())) {
return CUDAGraph::AddResetCallbackDuringCapturing(
std::forward<Callback>(callback));
}
#endif
callback();
}
} // namespace platform
} // namespace paddle
......@@ -44,6 +44,9 @@ inline cudnnDataType_t ToCudnnDataType(const T& t) {
inline std::vector<int> TransformDimOrder(const std::vector<int>& dims) {
std::vector<int> transformed_dims(dims.begin(), dims.end());
if (dims.size() < 4) {
return transformed_dims;
}
int H, W, D, C;
if (dims.size() == 4) {
H = dims[1];
......@@ -155,8 +158,8 @@ class TensorDescriptor {
dims_with_group.data(), strides.data()));
}
void set(const Tensor& tensor, const cudnnTensorFormat_t format) {
auto dims = framework::vectorize<int>(tensor.dims());
void set(const std::vector<int>& dims, const cudnnTensorFormat_t format,
const cudnnDataType_t dtype) {
std::vector<int> transformed_dims;
if (format == CUDNN_TENSOR_NHWC) {
transformed_dims = TransformDimOrder(dims);
......@@ -164,8 +167,14 @@ class TensorDescriptor {
transformed_dims = dims;
}
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetTensorNdDescriptorEx(
desc_.get(), format, ToCudnnDataType(tensor.type()),
transformed_dims.size(), transformed_dims.data()));
desc_.get(), format, dtype, transformed_dims.size(),
transformed_dims.data()));
}
void set(const Tensor& tensor, const cudnnTensorFormat_t format) {
auto dims = framework::vectorize<int>(tensor.dims());
auto dtype = ToCudnnDataType(tensor.type());
set(dims, format, dtype);
}
private:
......@@ -191,9 +200,8 @@ class FilterDescriptor {
T* desc() { return desc_.get(); }
T* desc() const { return desc_.get(); }
void set(const Tensor& tensor, const cudnnTensorFormat_t format,
const int groups = 1) {
auto dims = framework::vectorize<int>(tensor.dims());
void set(const std::vector<int>& dims, const cudnnTensorFormat_t format,
const cudnnDataType_t dtype, const int groups = 1) {
std::vector<int> transformed_dims;
if (format == CUDNN_TENSOR_NHWC) {
transformed_dims = TransformDimOrder(dims);
......@@ -204,8 +212,15 @@ class FilterDescriptor {
transformed_dims[1] = transformed_dims[1] / groups;
}
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetFilterNdDescriptor(
desc_.get(), ToCudnnDataType(tensor.type()), format,
transformed_dims.size(), transformed_dims.data()));
desc_.get(), dtype, format, transformed_dims.size(),
transformed_dims.data()));
}
void set(const Tensor& tensor, const cudnnTensorFormat_t format,
const int groups = 1) {
auto dims = framework::vectorize<int>(tensor.dims());
auto dtype = ToCudnnDataType(tensor.type());
set(dims, format, dtype, groups);
}
private:
......
......@@ -180,7 +180,18 @@ CUDNN_DNN_ROUTINE_EACH_AFTER_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif
#if CUDNN_VERSION >= 8000
#define CUDNN_DNN_ROUTINE_EACH_R8(__macro) __macro(cudnnSetRNNDescriptor_v8);
#define CUDNN_DNN_ROUTINE_EACH_R8(__macro) \
__macro(cudnnSetRNNDescriptor_v8); \
__macro(cudnnCreateFusedOpsPlan); \
__macro(cudnnCreateFusedOpsConstParamPack); \
__macro(cudnnCreateFusedOpsVariantParamPack); \
__macro(cudnnDestroyFusedOpsPlan); \
__macro(cudnnDestroyFusedOpsConstParamPack); \
__macro(cudnnDestroyFusedOpsVariantParamPack); \
__macro(cudnnFusedOpsExecute); \
__macro(cudnnSetFusedOpsConstParamPackAttribute); \
__macro(cudnnSetFusedOpsVariantParamPackAttribute); \
__macro(cudnnMakeFusedOpsPlan);
CUDNN_DNN_ROUTINE_EACH_R8(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif
......
......@@ -22,6 +22,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/dynload/miopen.h"
#else
#include "paddle/fluid/platform/cuda_graph.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#endif
#include "paddle/fluid/memory/malloc.h"
......@@ -557,6 +558,7 @@ class RecordedCudaMallocHelper {
#ifdef PADDLE_WITH_HIP
auto result = hipMalloc(ptr, size);
#else
CUDAGraphCaptureModeGuard capture_mode_guard;
auto result = cudaMalloc(ptr, size);
#endif
if (result == gpuSuccess) {
......
......@@ -30,3 +30,9 @@ limitations under the License. */
#define FLT_MAX __FLT_MAX__
#endif // __FLT_MAX__
#endif // PADDLE_WITH_MUSL
#if defined(__NVCC__) || defined(__HIPCC__)
#define PADDLE_RESTRICT __restrict__
#else
#define PADDLE_RESTRICT
#endif
......@@ -36,4 +36,5 @@ using gpuEvent_t = cudaEvent_t;
using gpuDeviceProp = cudaDeviceProp;
#endif
using CUDAGraphID = unsigned long long; // NOLINT
} // namespace paddle
......@@ -7,7 +7,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapp
feed_fetch_method pass generate_pass pass_builder parallel_executor profiler layer tracer engine scope_pool
analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context
gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator
cost_model)
cost_model cuda_graph_with_memory_pool)
if (WITH_PSCORE)
set(PYBIND_DEPS ${PYBIND_DEPS} ps_service)
......
......@@ -125,6 +125,8 @@ limitations under the License. */
#include "paddle/fluid/platform/xpu/xpu_info.h"
#endif
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
#ifdef PADDLE_WITH_CRYPTO
#include "paddle/fluid/pybind/crypto.h"
#endif
......@@ -485,6 +487,17 @@ static int GetNCCLVersion() {
}
#endif
template <typename PlaceType>
static void TensorCopyFrom(framework::Tensor *dst, const framework::Tensor &src,
const PlaceType &place, int64_t batch_size) {
if (batch_size < 0) {
framework::TensorCopy(src, place, dst);
} else {
auto sliced = src.Slice(0, batch_size);
framework::TensorCopy(sliced, place, dst);
}
}
#ifdef PADDLE_WITH_AVX
PYBIND11_MODULE(core_avx, m) {
#else
......@@ -520,6 +533,19 @@ PYBIND11_MODULE(core_noavx, m) {
m.def("nccl_version", &GetNCCLVersion);
#endif
m.def("is_cuda_graph_capturing", &platform::IsCUDAGraphCapturing);
#ifdef PADDLE_WITH_CUDA
py::class_<platform::CUDAGraph>(m, "CUDAGraph")
.def_static("begin_capture",
[](platform::CUDAPlace place, int mode) {
platform::BeginCUDAGraphCapture(
place, static_cast<cudaStreamCaptureMode>(mode));
})
.def_static("end_capture", &platform::EndCUDAGraphCapture)
.def("replay", &platform::CUDAGraph::Replay)
.def("reset", &platform::CUDAGraph::Reset);
#endif
m.def("wait_device", [](const platform::Place &place) {
platform::DeviceContextPool::Instance().Get(place)->Wait();
});
......@@ -721,6 +747,18 @@ PYBIND11_MODULE(core_noavx, m) {
paddle::framework::proto::VarType::Type type) {
return reinterpret_cast<uintptr_t>(self.mutable_data(place, type));
})
.def("_copy_from", &TensorCopyFrom<paddle::platform::CPUPlace>,
py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
.def("_copy_from", &TensorCopyFrom<paddle::platform::XPUPlace>,
py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
.def("_copy_from", &TensorCopyFrom<paddle::platform::CUDAPlace>,
py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
.def("_copy_from", &TensorCopyFrom<paddle::platform::NPUPlace>,
py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
.def("_copy_from", &TensorCopyFrom<paddle::platform::CUDAPinnedPlace>,
py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
.def("_copy_from", &TensorCopyFrom<paddle::platform::Place>,
py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
.def("set", SetTensorFromPyArray<paddle::platform::CPUPlace>,
py::arg("array"), py::arg("place"), py::arg("zero_copy") = false)
.def("set", SetTensorFromPyArray<paddle::platform::XPUPlace>,
......@@ -2301,7 +2339,14 @@ All parameter, weight, gradient are variables in Paddle.
m.def("op_support_gpu", OpSupportGPU);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
m.def("get_cuda_device_count", platform::GetCUDADeviceCount);
m.def("cuda_empty_cache", platform::EmptyCache);
m.def("cuda_empty_cache", [] {
for (int dev_id : platform::GetSelectedDevices()) {
auto *dev_ctx = platform::DeviceContextPool::Instance().GetByPlace(
platform::CUDAPlace(dev_id));
dev_ctx->cudnn_workspace_handle().ResetWorkspace();
}
platform::EmptyCache();
});
m.def("get_device_properties",
[](int id) -> const gpuDeviceProp & {
return platform::GetDeviceProperties(id);
......@@ -3213,6 +3258,13 @@ All parameter, weight, gradient are variables in Paddle.
[](BuildStrategy &self, bool fix_op_run_order) {
self.fix_op_run_order_ = fix_op_run_order;
})
.def_property("allow_cuda_graph_capture",
[](const BuildStrategy &self) {
return self.allow_cuda_graph_capture_;
},
[](BuildStrategy &self, bool allow_cuda_graph_capture) {
self.allow_cuda_graph_capture_ = allow_cuda_graph_capture;
})
.def("_copy",
[](const BuildStrategy &self) {
auto new_bs = self;
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.core import is_compiled_with_cuda, is_compiled_with_rocm, CUDAPlace
if is_compiled_with_cuda() and not is_compiled_with_rocm():
from paddle.fluid.core import CUDAGraph as CoreCUDAGraph
class CUDAGraph:
def __init__(self, place=None, mode="thread_local"):
ALL_MODES = ["global", "thread_local", "relaxed"]
self._graph = None
if place is None:
place = CUDAPlace(0)
self._place = place
assert mode in ALL_MODES
self._mode = ALL_MODES.index(mode)
def capture_begin(self):
CoreCUDAGraph.begin_capture(self._place, self._mode)
def capture_end(self):
self._graph = CoreCUDAGraph.end_capture()
def replay(self):
self._graph.replay()
def reset(self):
self._graph.reset()
else:
class CUDAGraph:
def __init__(self, place=None, mode="thread_local"):
raise NotImplementedError()
def capture_begin(self):
raise NotImplementedError()
def capture_end(self):
raise NotImplementedError()
def replay(self):
raise NotImplementedError()
def reset(self):
raise NotImplementedError()
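A minimal dygraph usage sketch of the wrapper above, assuming a CUDA (non-ROCm) build; the tensor shapes and the captured computation are illustrative only.
import paddle
from paddle.device.cuda.graphs import CUDAGraph
x = paddle.zeros([2, 3], dtype='float32')
g = CUDAGraph(place=paddle.CUDAPlace(0), mode="thread_local")
g.capture_begin()
y = x + 10  # kernels launched here are recorded into the graph
g.capture_end()
x.copy_(paddle.rand([2, 3]), False)  # refresh the captured input in place
g.replay()  # re-runs the recorded kernels, updating y in place
g.reset()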
......@@ -1932,3 +1932,38 @@ def fused_bn_add_act(x,
attrs=attrs)
return batch_norm_out
def pow2_decay_with_linear_warmup(warmup_steps,
total_steps,
base_lr,
end_lr,
dtype='float32',
name=None):
if paddle.fluid.in_dygraph_mode():
raise NotImplementedError(
"pow2_decay_with_linear_warmup does not support dygraph mode yet.")
helper = LayerHelper("pow2_decay_with_linear_warmup", **locals())
lr = helper.create_global_variable(persistable=True, dtype=dtype, shape=[1])
helper.set_variable_initializer(
lr, Constant(value=float(base_lr) / warmup_steps))
step = helper.create_global_variable(
persistable=True, dtype='int64', shape=[1])
helper.set_variable_initializer(step, Constant(value=0))
assert warmup_steps <= total_steps, "warmup_steps cannot be larger than total_steps"
helper.append_op(
type="pow2_decay_with_linear_warmup",
inputs={"LearningRate": lr,
"Step": step},
outputs={"LearningRateOut": lr,
"StepOut": step},
attrs={
"warmup_steps": warmup_steps,
"total_steps": total_steps,
"base_lr": base_lr,
"end_lr": end_lr,
})
return lr
......@@ -127,11 +127,9 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
num_cast_ops = 0
for in_name in op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and op.type in [
'batch_norm', 'fused_bn_add_activation', 'layer_norm'
]:
if in_name not in {'X', 'Z'}:
continue
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(op,
in_name):
continue
for in_var_name in op.input(in_name):
in_var = block._find_var_recursive(in_var_name)
if in_var.type not in _valid_types or in_var.dtype == dest_dtype:
......@@ -184,9 +182,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
op._set_attr('in_dtype', dest_dtype)
if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype == core.VarDesc.VarType.FP16:
for out_name in op.output_names:
if op.type in [
'batch_norm', 'fused_bn_add_activation', 'layer_norm'
] and out_name != 'Y':
if _keep_fp32_output(op, out_name):
continue
for out_var_name in op.output(out_name):
out_var = block.var(out_var_name)
......@@ -401,9 +397,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
keep_fp32_ops.add(op)
continue # processed below
for in_name in op.input_names:
if op.type in {
'batch_norm', 'fused_bn_add_activation', 'layer_norm'
} and in_name not in {'X', 'Z'}:
if _keep_fp32_input(op, in_name):
continue
for in_var_name in op.input(in_name):
in_var = None
......@@ -431,9 +425,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
format(op.type, in_var_name, in_var.dtype))
for out_name in op.output_names:
if op.type in {
'batch_norm', 'fused_bn_add_activation', 'layer_norm'
} and out_name != 'Y':
if _keep_fp32_output(op, out_name):
continue
for out_var_name in op.output(out_name):
out_var = None
......
......@@ -1041,9 +1041,15 @@ class Executor(object):
lr_value = lr_sheduler()
lr_var = program._program.global_block().vars[lr_sheduler._var_name]
lr_tensor = _as_lodtensor(lr_value, core.CPUPlace(), lr_var.dtype)
exe.feed_and_split_tensor_into_local_scopes({
lr_sheduler._var_name: lr_tensor
})
if core.is_cuda_graph_capturing():
warnings.warn(
"Caution!!! When capturing CUDA Graph, the learning rate scheduler would not "
"take any effect! Please set the learning rate manually before each batch!"
)
else:
exe.feed_and_split_tensor_into_local_scopes({
lr_sheduler._var_name: lr_tensor
})
fetch_var_names = list(map(_to_name_str, fetch_list))
tensors = exe.run(fetch_var_names, return_merged)._move_to_list()
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import core
import numpy as np
def get_var_and_memory_size(block, var_name, batch_size=None):
var = block._find_var_recursive(var_name)
assert var is not None, "Variable {} cannot be found".format(var_name)
assert var.type == core.VarDesc.VarType.LOD_TENSOR, "Variable {} is not Tensor".format(
var_name)
shape = list(var.shape)
if not shape:
return var, 0
has_none = False
for i, s in enumerate(shape):
if s is None or s < 0:
assert not has_none
shape[i] = batch_size
has_none = True
assert all(
[s >= 0 for s in shape]), "shape {} is not deterministic".format(shape)
mem_size = int(np.prod(shape)) * core.size_of_dtype(var.dtype)
return var, mem_size
def pre_allocate_memory(size, place):
t = core.LoDTensor()
t._set_dims([size])
t._mutable_data(place, core.VarDesc.VarType.INT8)
del t
# NOTE: does not consider inplace yet.
def get_max_memory_info(program, batch_size=None):
    assert program.num_blocks == 1, "only supports analyzing a program with a single block"
cur_tmp_mem = 0
max_tmp_mem = 0
max_persistable_mem = 0
visited_vars = set()
alived_vars = []
block = program.global_block()
gc_vars = core._get_eager_deletion_vars(program.desc, [])[0]
for i, op in enumerate(block.ops):
var_names = op.input_arg_names + op.output_arg_names
for var_name in var_names:
if var_name in visited_vars:
continue
visited_vars.add(var_name)
var, mem_size = get_var_and_memory_size(block, var_name, batch_size)
if var.persistable:
max_persistable_mem += mem_size
else:
cur_tmp_mem += mem_size
max_tmp_mem = max(max_tmp_mem, cur_tmp_mem)
cur_gc_vars = gc_vars[i]
for var_name in var_names:
if var_name not in cur_gc_vars:
continue
_, mem_size = get_var_and_memory_size(block, var_name, batch_size)
cur_tmp_mem -= mem_size
return max_tmp_mem, max_persistable_mem
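A usage sketch of the helpers above on a small static program; the network and sizes are illustrative, and pre-allocating the estimated peak is only one possible way to warm up the allocator.
import paddle
from paddle.fluid.memory_analysis import get_max_memory_info, pre_allocate_memory
paddle.enable_static()
main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name='x', shape=[None, 784], dtype='float32')
    loss = paddle.mean(paddle.static.nn.fc(x, size=10))
    paddle.optimizer.SGD(learning_rate=0.01).minimize(loss)
max_tmp_mem, max_persistable_mem = get_max_memory_info(main, batch_size=32)
if paddle.is_compiled_with_cuda():
    # Reserve roughly the estimated peak so later allocations reuse the pool.
    pre_allocate_memory(max_tmp_mem + max_persistable_mem, paddle.CUDAPlace(0))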
......@@ -2064,8 +2064,9 @@ class LarsMomentumOptimizer(Optimizer):
attrs = {
"mu": self._momentum,
"lars_coeff": self._lars_coeff,
"lars_weight_decay": _lars_weight_decay,
"lars_weight_decay": [_lars_weight_decay],
"multi_precision": find_master,
"epsilon": self._epsilon,
"rescale_grad": self._rescale_grad
}
......@@ -2084,7 +2085,7 @@ class LarsMomentumOptimizer(Optimizer):
# create the momentum optimize op
momentum_op = block.append_op(
type=self.type,
type=self.type if _lars_weight_decay != 0.0 else 'momentum',
inputs=inputs,
outputs=outputs,
attrs=attrs,
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
from paddle.device.cuda.graphs import CUDAGraph
import unittest
import numpy as np
from paddle.fluid.dygraph.base import switch_to_static_graph
from simple_nets import simple_fc_net_with_inputs
class TestCUDAGraph(unittest.TestCase):
def setUp(self):
if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm(
):
fluid.set_flags({
'FLAGS_allocator_strategy': 'auto_growth',
'FLAGS_sync_nccl_allreduce': False,
'FLAGS_cudnn_deterministic': True
})
def random_tensor(self, shape):
return paddle.to_tensor(
np.random.randint(
low=0, high=10, size=shape).astype("float32"))
@switch_to_static_graph
def test_cuda_graph_static_graph(self):
if not paddle.is_compiled_with_cuda() or paddle.is_compiled_with_rocm():
return
seed = 100
loss_cuda_graph = self.cuda_graph_static_graph_main(
seed, use_cuda_graph=True)
loss_no_cuda_graph = self.cuda_graph_static_graph_main(
seed, use_cuda_graph=False)
self.assertEqual(loss_cuda_graph, loss_no_cuda_graph)
def cuda_graph_static_graph_main(self, seed, use_cuda_graph):
batch_size = 1
class_num = 10
image_shape = [batch_size, 784]
label_shape = [batch_size, 1]
paddle.seed(seed)
np.random.seed(seed)
startup = paddle.static.Program()
main = paddle.static.Program()
with paddle.static.program_guard(main, startup):
image = paddle.static.data(
name="image", shape=image_shape, dtype='float32')
label = paddle.static.data(
name="label", shape=label_shape, dtype='int64')
image.persistable = True
label.persistable = True
loss = simple_fc_net_with_inputs(image, label, class_num)
loss.persistable = True
lr = paddle.optimizer.lr.PiecewiseDecay(
boundaries=[2, 3, 4], values=[0.01, 0.02, 0.03, 0.04])
optimizer = paddle.optimizer.SGD(learning_rate=lr)
optimizer.minimize(loss)
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
scope = paddle.static.Scope()
with paddle.static.scope_guard(scope):
exe.run(startup)
build_strategy = paddle.static.BuildStrategy()
build_strategy.allow_cuda_graph_capture = True
build_strategy.fix_op_run_order = True
build_strategy.fuse_all_optimizer_ops = True
compiled_program = paddle.static.CompiledProgram(
main).with_data_parallel(
loss_name=loss.name,
build_strategy=build_strategy,
places=place)
image_t = scope.var(image.name).get_tensor()
label_t = scope.var(label.name).get_tensor()
loss_t = scope.var(loss.name).get_tensor()
lr_var = main.global_block().var(lr._var_name)
self.assertTrue(lr_var.persistable)
lr_t = scope.var(lr_var.name).get_tensor()
cuda_graph = None
for batch_id in range(20):
image_t.set(
np.random.rand(*image_shape).astype('float32'), place)
label_t.set(np.random.randint(
low=0, high=class_num, size=label_shape, dtype='int64'),
place)
if batch_id == 1 and use_cuda_graph:
cuda_graph = CUDAGraph(place, mode="global")
cuda_graph.capture_begin()
exe.run(compiled_program)
cuda_graph.capture_end()
if cuda_graph:
lr_t.set(np.array([lr()], dtype='float32'), place)
cuda_graph.replay()
else:
exe.run(compiled_program)
lr.step()
if cuda_graph:
cuda_graph.reset()
return np.array(loss_t)
def test_cuda_graph_dynamic_graph(self):
if not paddle.is_compiled_with_cuda() or paddle.is_compiled_with_rocm():
return
shape = [2, 3]
x = self.random_tensor(shape)
z = self.random_tensor(shape)
g = CUDAGraph()
g.capture_begin()
y = x + 10
z.add_(x)
g.capture_end()
for _ in range(10):
z_np_init = z.numpy()
x_new = self.random_tensor(shape)
x.copy_(x_new, False)
g.replay()
x_np = x_new.numpy()
y_np = y.numpy()
z_np = z.numpy()
self.assertTrue((y_np - x_np == 10).all())
self.assertTrue((z_np - z_np_init == x_np).all())
g.reset()
if __name__ == "__main__":
unittest.main()
......@@ -103,7 +103,7 @@ class TestFleetLarsMetaOptimizer(unittest.TestCase):
'op_role_var')[0] or ".b" in op.attr('op_role_var')[0])
]
for op in ops_without_wd:
self.assertEqual(op.attr('lars_weight_decay'), 0)
self.assertEqual(op.attr('lars_weight_decay')[0], 0)
def test_lars_apply_with_amp(self):
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from paddle.fluid.memory_analysis import pre_allocate_memory, get_max_memory_info
from simple_nets import simple_fc_net
class TestMemoryAnalysis(unittest.TestCase):
def setUp(self):
paddle.enable_static()
def test_get_memory_info(self):
loss = simple_fc_net()
optimizer = paddle.optimizer.Adam(learning_rate=1e-3)
optimizer.minimize(loss)
main_prog = paddle.static.default_main_program()
        max_tmp_mem_1, max_persistable_mem_1 = get_max_memory_info(
            main_prog, batch_size=32)
        self.assertGreater(max_tmp_mem_1, 0)
        self.assertGreater(max_persistable_mem_1, 0)
        max_tmp_mem_2, max_persistable_mem_2 = get_max_memory_info(
            main_prog, batch_size=64)
        self.assertEqual(max_persistable_mem_1, max_persistable_mem_2)
self.assertLess(max_tmp_mem_1, max_tmp_mem_2)
class TestPreAllocateMemory(unittest.TestCase):
def setUp(self):
paddle.enable_static()
def test_pre_allocate(self):
size = 32 * 1024 * 1024
pre_allocate_memory(size, paddle.CPUPlace())
if paddle.is_compiled_with_cuda():
pre_allocate_memory(size, paddle.CUDAPlace(0))
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import numpy as np
from paddle.fluid.layer_helper import LayerHelper
from collections import OrderedDict
def run_momentum_op(params,
grads,
velocitys,
master_params,
learning_rate,
place,
multi_precision,
mu=0.9,
rescale_grad=0.01,
use_merged=False):
assert len(params) == len(grads)
assert len(params) == len(velocitys)
if multi_precision:
assert len(params) == len(master_params)
op_type = 'merged_momentum' if use_merged else 'momentum'
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
helper = LayerHelper(op_type, **locals())
attrs = {
'mu': mu,
'multi_precision': multi_precision,
'rescale_grad': rescale_grad,
}
param_vars = [
helper.create_variable(
persistable=True, shape=p.shape, dtype=p.dtype) for p in params
]
grad_vars = [
helper.create_variable(
shape=g.shape, dtype=g.dtype) for g in grads
]
velocity_vars = [
helper.create_variable(
persistable=True, shape=v.shape, dtype=v.dtype)
for v in velocitys
]
lr_var = helper.create_variable(
persistable=True,
shape=learning_rate.shape,
dtype=learning_rate.dtype)
feed_dict = OrderedDict()
feed_dict.update(
OrderedDict([(p_var.name, p_val)
for p_var, p_val in zip(param_vars, params)]))
feed_dict.update(
OrderedDict([(v_var.name, v_val)
for v_var, v_val in zip(velocity_vars, velocitys)]))
fetch_list = list(feed_dict.keys())
feed_dict.update(
OrderedDict([(g_var.name, g_val)
for g_var, g_val in zip(grad_vars, grads)]))
feed_dict.update({lr_var.name: learning_rate})
if multi_precision:
master_param_vars = [
helper.create_variable(
persistable=True, shape=p.shape, dtype=p.dtype)
for p in master_params
]
feed_dict.update(
OrderedDict([(mp_var.name, mp_val)
for mp_var, mp_val in zip(master_param_vars,
master_params)]))
# CPUPlace does not use MasterParam
if isinstance(place, paddle.CUDAPlace):
fetch_list = fetch_list + [
mp_var.name for mp_var in master_param_vars
]
else:
master_param_vars = None
if not use_merged:
for i, (p, g,
v) in enumerate(zip(param_vars, grad_vars, velocity_vars)):
inputs = {
'Param': p,
'Grad': g,
'Velocity': v,
'LearningRate': lr_var,
}
outputs = {'ParamOut': p, 'VelocityOut': v}
if multi_precision:
inputs['MasterParam'] = master_param_vars[i]
outputs['MasterParamOut'] = master_param_vars[i]
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
else:
inputs = {
'Param': param_vars,
'Grad': grad_vars,
'Velocity': velocity_vars,
'LearningRate': lr_var,
}
outputs = {'ParamOut': param_vars, 'VelocityOut': velocity_vars}
if multi_precision:
inputs['MasterParam'] = master_param_vars
outputs['MasterParamOut'] = master_param_vars
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
exe = paddle.static.Executor(place)
with paddle.static.scope_guard(paddle.static.Scope()):
exe.run(startup)
return exe.run(main, feed=feed_dict, fetch_list=fetch_list)
class TestMergedMomentum(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.shapes = [[3, 4], [2, 7], [5, 6], [7, 8]]
self.seed = 10
def gen_rand_data(self, shapes, dtype):
return [np.random.random(s).astype(dtype) for s in shapes]
def prepare_data(self, shapes, multi_precision, seed, place):
np.random.seed(seed)
mp_dtype = np.float32
dtype = np.float16 if multi_precision and isinstance(
place, paddle.CUDAPlace) else np.float32
params = self.gen_rand_data(shapes, dtype)
grads = self.gen_rand_data(shapes, dtype)
velocitys = self.gen_rand_data(shapes, mp_dtype)
learning_rate = self.gen_rand_data([[1]], mp_dtype)[0]
if multi_precision:
master_params = [p.astype(mp_dtype) for p in params]
else:
master_params = None
return params, grads, velocitys, master_params, learning_rate
def check_with_place(self, place, multi_precision):
params, grads, velocitys, master_params, learning_rate = self.prepare_data(
self.shapes, multi_precision, self.seed, place)
def run_op(use_merged):
# FIXME(zengjinle): CPU Momentum Op does not support rescale_grad
rescale_grad = 1.0 if isinstance(place, paddle.CPUPlace) else 0.01
return run_momentum_op(
params,
grads,
velocitys,
master_params,
learning_rate,
place,
multi_precision,
rescale_grad=rescale_grad,
use_merged=use_merged)
outs1 = run_op(True)
outs2 = run_op(False)
self.assertEqual(len(outs1), len(outs2))
for i, (out1, out2) in enumerate(zip(outs1, outs2)):
if isinstance(place, paddle.CUDAPlace):
self.assertTrue(np.array_equal(out1, out2))
else:
self.assertTrue(np.allclose(out1, out2, atol=1e-7))
def get_places(self):
places = [paddle.CPUPlace()]
if paddle.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
return places
def test_main(self):
for multi_precision in [False, True]:
for place in self.get_places():
self.check_with_place(place, multi_precision)
if __name__ == "__main__":
unittest.main()
......@@ -138,50 +138,70 @@ class TestMomentumOp2(OpTest):
"core is not compiled with CUDA")
class TestLarsMomentumOpWithMP(OpTest):
def setUp(self):
self.config()
self.op_type = "lars_momentum"
master_param = np.random.random((123, 321)).astype("float32")
param = master_param.astype("float16")
grad = np.random.random((123, 321)).astype("float16")
velocity = np.zeros((123, 321)).astype("float32")
learning_rate = np.array([0.001]).astype("float32")
mu = 0.0001
lars_coeff = 0.001
lars_weight_decay = 0.0005
rescale_grad = 1.0
params = []
grads = []
velocitys = []
learning_rates = []
master_params = []
param_outs = []
velocity_outs = []
master_param_outs = []
for i in range(self.params_num):
master_param = np.random.random((123, 321)).astype("float32")
param = master_param.astype("float16")
grad = np.random.random((123, 321)).astype("float16")
velocity = np.zeros((123, 321)).astype("float32")
learning_rate = np.array([0.001]).astype("float32")
fp32_grad = grad.astype("float32")
pnorm = np.sqrt(np.square(master_param).sum())
gnorm = np.sqrt(np.square(fp32_grad).sum())
local_lr = learning_rate * lars_coeff * pnorm / (
gnorm + lars_weight_decay * pnorm)
fp32_grad = fp32_grad * rescale_grad
velocity_out = mu * velocity + local_lr * (
fp32_grad + lars_weight_decay * master_param)
p_new = master_param - velocity_out
param_out = p_new.astype("float16")
master_param_out = p_new
params.append(("SubParam_" + str(i), param))
grads.append(("SubGrad_" + str(i), grad))
velocitys.append(("SubVelocity_" + str(i), velocity))
learning_rates.append(("SubLearning_rate_" + str(i), learning_rate))
velocity_outs.append(("SubVelocity_out_" + str(i), velocity_out))
param_outs.append(("SubParam_out_" + str(i), param_out))
master_params.append(("SubMasterParam_" + str(i), master_param))
master_param_outs.append(
("SubMasterParamOut_" + str(i), master_param_out))
self.inputs = {
'Param': param,
'Grad': grad,
'Velocity': velocity,
'LearningRate': learning_rate,
'MasterParam': master_param,
'Param': params,
'Grad': grads,
'Velocity': velocitys,
'LearningRate': learning_rates,
'MasterParam': master_params,
}
self.attrs = {
'mu': mu,
'lars_coeff': lars_coeff,
'lars_weight_decay': lars_weight_decay,
'lars_weight_decay': [lars_weight_decay],
'multi_precision': True,
'rescale_grad': rescale_grad
}
fp32_grad = grad.astype("float32")
pnorm = np.sqrt(np.square(master_param).sum())
gnorm = np.sqrt(np.square(fp32_grad).sum())
local_lr = learning_rate * lars_coeff * pnorm / (
gnorm + lars_weight_decay * pnorm)
fp32_grad = fp32_grad * rescale_grad
velocity_out = mu * velocity + local_lr * (fp32_grad + lars_weight_decay
* master_param)
p_new = master_param - velocity_out
param_out = p_new.astype("float16")
master_param_out = p_new
self.outputs = {
'ParamOut': param_out,
'VelocityOut': velocity_out,
'MasterParamOut': master_param_out
'ParamOut': param_outs,
'VelocityOut': velocity_outs,
'MasterParamOut': master_param_outs
}
def test_check_output(self):
......@@ -191,46 +211,65 @@ class TestLarsMomentumOpWithMP(OpTest):
if core.is_float16_supported(place):
self.check_output_with_place(place)
def config(self):
self.params_num = 1
class TestLarsMomentumOp(OpTest):
def setUp(self):
self.config()
self.op_type = "lars_momentum"
param = np.random.random((123, 321)).astype("float32")
grad = np.random.random((123, 321)).astype("float32")
velocity = np.zeros((123, 321)).astype("float32")
learning_rate = np.array([0.001]).astype("float32")
mu = 0.0001
lars_coeff = 0.001
lars_weight_decay = 0.0005
params = []
grads = []
velocitys = []
param_outs = []
velocity_outs = []
learning_rates = []
for i in range(self.params_num):
param = np.random.random((123, 321)).astype("float32")
grad = np.random.random((123, 321)).astype("float32")
velocity = np.zeros((123, 321)).astype("float32")
learning_rate = np.array([0.001]).astype("float32")
pnorm = np.sqrt(np.square(param).sum())
gnorm = np.sqrt(np.square(grad).sum())
local_lr = learning_rate * lars_coeff * pnorm / (
gnorm + lars_weight_decay * param)
velocity_out = mu * velocity + local_lr * (grad + lars_weight_decay
* param)
param_out = param - velocity_out
params.append(("SubParam_" + str(i), param))
grads.append(("SubGrad_" + str(i), grad))
velocitys.append(("SubVelocity_" + str(i), velocity))
learning_rates.append(("SubLearning_rate_" + str(i), learning_rate))
velocity_outs.append(("SubVelocity_out_" + str(i), velocity_out))
param_outs.append(("SubParam_out_" + str(i), param_out))
self.inputs = {
'Param': param,
'Grad': grad,
'Velocity': velocity,
'LearningRate': learning_rate
'Param': params,
'Grad': grads,
'Velocity': velocitys,
'LearningRate': learning_rates
}
self.attrs = {
'mu': mu,
'lars_coeff': lars_coeff,
'lars_weight_decay': lars_weight_decay
'lars_weight_decay': [lars_weight_decay]
}
pnorm = np.sqrt(np.square(param).sum())
gnorm = np.sqrt(np.square(grad).sum())
local_lr = learning_rate * lars_coeff * pnorm / (
gnorm + lars_weight_decay * param)
velocity_out = mu * velocity + local_lr * (grad + lars_weight_decay *
param)
param_out = param - velocity_out
self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out}
self.outputs = {'ParamOut': param_outs, 'VelocityOut': velocity_outs}
def test_check_output(self):
paddle.enable_static()
self.check_output()
def config(self):
self.params_num = 1
class TestSparseMomentumOp(unittest.TestCase):
def setUp(self):
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.fluid.contrib.layers.nn import pow2_decay_with_linear_warmup
from paddle.optimizer.lr import LinearWarmup
from paddle.optimizer.lr import PolynomialDecay
import unittest
def gen_pow2_warmup_op_lr(warmup_steps, total_steps, base_lr, end_lr, place):
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
lr = pow2_decay_with_linear_warmup(warmup_steps, total_steps, base_lr,
end_lr)
exe = paddle.static.Executor(place)
with paddle.static.scope_guard(paddle.static.Scope()):
exe.run(startup)
while True:
lr_np = exe.run(main, fetch_list=[lr])[0]
yield lr_np[0]
class Pow2Warmup(LinearWarmup):
def __init__(self, warmup_steps, total_steps, base_lr, end_lr):
assert total_steps > warmup_steps
lr_sch = PolynomialDecay(
learning_rate=base_lr,
decay_steps=total_steps - warmup_steps,
end_lr=end_lr,
power=2)
super(Pow2Warmup, self).__init__(
learning_rate=lr_sch,
warmup_steps=warmup_steps,
start_lr=0.0,
end_lr=base_lr)
def gen_pow2_warmup_py_lr(warmup_steps, total_steps, base_lr, end_lr, place):
lr_sch = Pow2Warmup(warmup_steps, total_steps, base_lr, end_lr)
lr_sch.step()
while True:
yield lr_sch()
lr_sch.step()
class TestPow2WarmupLRScheduler(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.params = {
'warmup_steps': 30,
'total_steps': 100,
'base_lr': 0.02,
'end_lr': 0.001,
}
self.step_num = 1000
def check_with_place(self, place):
kwargs = dict(self.params)
kwargs['place'] = place
lr_sch_op = gen_pow2_warmup_op_lr(**kwargs)
lr_sch_py = gen_pow2_warmup_py_lr(**kwargs)
for i, (lr_op, lr_py) in enumerate(zip(lr_sch_op, lr_sch_py)):
self.assertLess(abs(lr_op - lr_py), 1e-6)
if i > self.step_num:
break
def test_main(self):
self.check_with_place(paddle.CPUPlace())
if paddle.is_compiled_with_cuda():
self.check_with_place(paddle.CUDAPlace(0))
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import unittest
import numpy as np
from paddle.fluid.core import LoDTensor as Tensor
class TestTensorCopyFrom(unittest.TestCase):
def test_main(self):
place = paddle.CPUPlace()
np_value = np.random.random(size=[10, 30]).astype('float32')
t_src = Tensor()
t_src.set(np_value, place)
self.assertTrue(np.array_equal(np_value, t_src))
t_dst1 = Tensor()
t_dst1._copy_from(t_src, place)
self.assertTrue(np.array_equal(np_value, t_dst1))
t_dst2 = Tensor()
t_dst2._copy_from(t_src, place, 5)
self.assertTrue(np.array_equal(np.array(np_value[0:5]), t_dst2))
if __name__ == "__main__":
unittest.main()
......@@ -14,3 +14,4 @@
from .softmax_mask_fuse_upper_triangle import softmax_mask_fuse_upper_triangle # noqa: F401
from .softmax_mask_fuse import softmax_mask_fuse # noqa: F401
from .resnet_unit import ResNetUnit #noqa: F401
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import collections
import itertools
import six
import math
import sys
import warnings
from functools import partial, reduce
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle import framework
from paddle.device import get_device, get_cudnn_version
from paddle.nn import initializer as I
from paddle.nn import Layer, LayerList
from paddle.fluid.layers import utils
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers.utils import map_structure, flatten, pack_sequence_as
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.param_attr import ParamAttr
from paddle import _C_ops
__all__ = ['resnet_unit', 'ResNetUnit']
def resnet_unit(x, filter_x, scale_x, bias_x, mean_x, var_x, z, filter_z,
scale_z, bias_z, mean_z, var_z, stride, stride_z, padding,
dilation, groups, momentum, eps, data_format, fuse_add,
has_shortcut, use_global_stats, is_test, act):
helper = LayerHelper('resnet_unit', **locals())
bn_param_dtype = fluid.core.VarDesc.VarType.FP32
bit_mask_dtype = fluid.core.VarDesc.VarType.INT32
out = helper.create_variable_for_type_inference(x.dtype)
bit_mask = helper.create_variable_for_type_inference(
dtype=bit_mask_dtype, stop_gradient=True)
# intermediate_out for x
conv_x = helper.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=True)
saved_mean_x = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True)
saved_invstd_x = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True)
running_mean_x = mean_x
running_var_x = var_x
# intermediate_out for z
conv_z = helper.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=True)
saved_mean_z = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True)
saved_invstd_z = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True)
running_mean_z = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True) if mean_z is None else mean_z
running_var_z = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True) if var_z is None else var_z
inputs = {
'X': x,
'FilterX': filter_x,
'ScaleX': scale_x,
'BiasX': bias_x,
'MeanX': mean_x,
'VarX': var_x,
'Z': z,
'FilterZ': filter_z,
'ScaleZ': scale_z,
'BiasZ': bias_z,
'MeanZ': mean_z,
'VarZ': var_z
}
attrs = {
'stride': stride,
'stride_z': stride_z,
'padding': padding,
'dilation': dilation,
'group': groups,
'momentum': momentum,
'epsilon': eps,
'data_format': data_format,
'fuse_add': fuse_add,
'has_shortcut': has_shortcut,
'use_global_stats': use_global_stats,
'is_test': is_test,
'act_type': act
}
outputs = {
'Y': out,
'BitMask': bit_mask,
'ConvX': conv_x,
'SavedMeanX': saved_mean_x,
'SavedInvstdX': saved_invstd_x,
'RunningMeanX': running_mean_x,
'RunningVarX': running_var_x,
'ConvZ': conv_z,
'SavedMeanZ': saved_mean_z,
'SavedInvstdZ': saved_invstd_z,
'RunningMeanZ': running_mean_z,
'RunningVarZ': running_var_z,
}
helper.append_op(
type='resnet_unit', inputs=inputs, outputs=outputs, attrs=attrs)
return out
class ResNetUnit(Layer):
r"""
******Temporary version******.
    ResNetUnit is designed to optimize performance by using the cuDNN v8 API.
"""
def __init__(self,
num_channels_x,
num_filters,
filter_size,
stride=1,
momentum=0.9,
eps=1e-5,
data_format='NHWC',
act='relu',
fuse_add=False,
has_shortcut=False,
use_global_stats=False,
is_test=False,
filter_x_attr=None,
scale_x_attr=None,
bias_x_attr=None,
moving_mean_x_name=None,
moving_var_x_name=None,
num_channels_z=1,
stride_z=1,
filter_z_attr=None,
scale_z_attr=None,
bias_z_attr=None,
moving_mean_z_name=None,
moving_var_z_name=None):
super(ResNetUnit, self).__init__()
self._stride = stride
self._stride_z = stride_z
self._dilation = 1
self._kernel_size = utils.convert_to_list(filter_size, 2, 'kernel_size')
self._padding = (filter_size - 1) // 2
self._groups = 1
self._momentum = momentum
self._eps = eps
self._data_format = data_format
self._act = act
self._fuse_add = fuse_add
self._has_shortcut = has_shortcut
self._use_global_stats = use_global_stats
self._is_test = is_test
# check format
valid_format = {'NHWC'}
if data_format not in valid_format:
raise ValueError(
"conv_format must be one of {}, but got conv_format='{}'".
format(valid_format, data_format))
def _get_default_param_initializer(channels):
filter_elem_num = np.prod(self._kernel_size) * channels
std = (2.0 / filter_elem_num)**0.5
return I.Normal(0.0, std)
# initial filter
bn_param_dtype = fluid.core.VarDesc.VarType.FP32
bn_param_shape = [1, 1, 1, num_filters]
filter_x_shape = [num_filters, filter_size, filter_size, num_channels_x]
filter_z_shape = [num_filters, filter_size, filter_size, num_channels_z]
self.filter_x = self.create_parameter(
shape=filter_x_shape,
attr=filter_x_attr,
default_initializer=_get_default_param_initializer(num_channels_x))
self.scale_x = self.create_parameter(
shape=bn_param_shape,
attr=scale_x_attr,
dtype=bn_param_dtype,
default_initializer=I.Constant(1.0))
self.bias_x = self.create_parameter(
shape=bn_param_shape,
attr=bias_x_attr,
dtype=bn_param_dtype,
is_bias=True)
self.mean_x = self.create_parameter(
attr=ParamAttr(
name=moving_mean_x_name,
initializer=I.Constant(0.0),
trainable=False),
shape=bn_param_shape,
dtype=bn_param_dtype)
self.mean_x.stop_gradient = True
self.var_x = self.create_parameter(
attr=ParamAttr(
name=moving_var_x_name,
initializer=I.Constant(1.0),
trainable=False),
shape=bn_param_shape,
dtype=bn_param_dtype)
self.var_x.stop_gradient = True
if has_shortcut:
self.filter_z = self.create_parameter(
shape=filter_z_shape,
attr=filter_z_attr,
default_initializer=_get_default_param_initializer(
num_channels_z))
self.scale_z = self.create_parameter(
shape=bn_param_shape,
attr=scale_z_attr,
dtype=bn_param_dtype,
default_initializer=I.Constant(1.0))
self.bias_z = self.create_parameter(
shape=bn_param_shape,
attr=bias_z_attr,
dtype=bn_param_dtype,
is_bias=True)
self.mean_z = self.create_parameter(
attr=ParamAttr(
name=moving_mean_z_name,
initializer=I.Constant(0.0),
trainable=False),
shape=bn_param_shape,
dtype=bn_param_dtype)
self.mean_z.stop_gradient = True
self.var_z = self.create_parameter(
attr=ParamAttr(
name=moving_var_z_name,
initializer=I.Constant(1.0),
trainable=False),
shape=bn_param_shape,
dtype=bn_param_dtype)
self.var_z.stop_gradient = True
else:
self.filter_z = None
self.scale_z = None
self.bias_z = None
self.mean_z = None
self.var_z = None
def forward(self, x, z=None):
if self._fuse_add and z is None:
raise ValueError("z can not be None")
out = resnet_unit(
x, self.filter_x, self.scale_x, self.bias_x, self.mean_x,
self.var_x, z, self.filter_z, self.scale_z, self.bias_z,
self.mean_z, self.var_z, self._stride, self._stride_z,
self._padding, self._dilation, self._groups, self._momentum,
self._eps, self._data_format, self._fuse_add, self._has_shortcut,
self._use_global_stats, self._is_test, self._act)
return out