提交 da478d1e 编写于 作者: P phlrain

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_some_yaml_config

......@@ -258,6 +258,12 @@ copy(inference_lib_dist
# Copy the paddle/utils headers (any.h, optional.h, none.h) into the
# experimental/utils include directory of the inference install tree.
copy(inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/utils/any.h
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/utils/)
copy(inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/utils/optional.h
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/utils/)
copy(inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/utils/none.h
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/utils/)
# Copy the top-level extension header into the experimental include directory.
copy(inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/extension.h
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/)
......
......@@ -39,8 +39,9 @@ static void CopyOrAddTensor(paddle::experimental::Tensor* tensor,
}
std::vector<std::vector<paddle::experimental::Tensor>> GradNodeAccumulation::
operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
bool create_graph) {
operator()(
std::vector<std::vector<paddle::experimental::Tensor>>& grads, // NOLINT
bool create_graph) {
VLOG(3) << "Running Eager Backward Node: GradNodeAccumulation";
PADDLE_ENFORCE(grads.size() == 1,
paddle::platform::errors::Fatal(
......
......@@ -35,7 +35,7 @@ class GradNodeAccumulation : public GradNodeBase {
// Functor: perform backward computations
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
std::vector<std::vector<paddle::experimental::Tensor>>& grads, // NOLINT
bool create_graph = false) override;
void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
......
......@@ -145,8 +145,9 @@ void GradNodeScale::SetTensorWrappers_X(
void GradNodeScale::SetAttributes_scale(float scale) { scale_ = scale; }
std::vector<std::vector<paddle::experimental::Tensor>> GradNodeScale::
operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
bool create_graph) {
operator()(
std::vector<std::vector<paddle::experimental::Tensor>>& grads, // NOLINT
bool create_graph) {
// 1. Check Output Size
PADDLE_ENFORCE(
((grads.size() == 1) && (grads[0].size() == 1)),
......
......@@ -39,7 +39,7 @@ class GradNodeScale : public GradNodeBase {
// Functor: perform backward computations
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
std::vector<std::vector<paddle::experimental::Tensor>>& grads, // NOLINT
bool create_graph = false) override;
void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
......
......@@ -47,6 +47,9 @@ std::unordered_map<std::string, std::vector<std::string>>
static std::unordered_map<std::string, paddle::framework::AttributeMap>
operators_with_attrs = {};
/* --- Forward op types whose generated grad function starts with a call to
 * egr::EagerUtils::FillZeroForEmptyGradInputs on its incoming grads --- */
static std::unordered_set<std::string> ops_to_fill_zero_for_empty_grads = {
"split"};
/* --- Black Ops list that's NO NEED to apply code generation --- */
static std::unordered_set<std::string> black_ops_list = {"run_program"};
......@@ -2243,11 +2246,21 @@ static std::string GenerateGradNodeCCContents(
// [Generation] Get Full Grad Function
const char* GRAD_FUNCTION_TEMPLATE =
"std::vector<std::vector<paddle::experimental::Tensor>> "
"GradNode%s::operator()(const "
"std::vector<std::vector<paddle::experimental::Tensor>>& grads, "
"bool create_graph) {\n%s\n}";
std::string grad_function_str = paddle::string::Sprintf(
GRAD_FUNCTION_TEMPLATE, fwd_op_type, generated_grad_function_body);
"GradNode%s::operator()("
"std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool "
"create_graph) {\n"
"%s"
"%s"
"\n}";
std::string fill_zero_str = "";
if (ops_to_fill_zero_for_empty_grads.count(fwd_op_type)) {
fill_zero_str =
"egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, "
"this->InputMeta());\n";
}
std::string grad_function_str =
paddle::string::Sprintf(GRAD_FUNCTION_TEMPLATE, fwd_op_type,
fill_zero_str, generated_grad_function_body);
VLOG(6) << "Generated returns";
......@@ -2279,9 +2292,9 @@ static std::string GenerateGradNodeHeaderContents(
" ~GradNode%s() override { VLOG(6) << \" Destruct GradNode%s \"; }\n"
"\n"
" virtual std::vector<std::vector<paddle::experimental::Tensor>> "
"operator()(const "
"std::vector<std::vector<paddle::experimental::Tensor>>& grads, const "
"bool create_graph = false) "
"operator()("
"std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool "
"create_graph = false) "
"override;\n"
"\n"
" void ClearTensorWrappers() override { \n"
......
......@@ -17,6 +17,8 @@ import re
import argparse
import os
# Forward API names whose grad node must fill zeros into uninitialized grad
# inputs before running. The set is built from a list of names: iterating the
# bare string "split" would yield its characters ('s', 'p', 'l', 'i', 't'),
# so a membership test against the op name "split" would never succeed.
ops_to_fill_zero_for_empty_grads = set(["split"])
# For API dispatch used at python-level
# { op_name : [arg_name, ...] }
core_ops_returns_info = {}
......@@ -599,7 +601,8 @@ class {} : public egr::GradNodeBase {{
~{}() override = default;
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
const std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph = false) override;
std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph = false) override;
std::string name() override {{ return \" {} \"; }}
void ClearTensorWrappers() override {{
......@@ -657,10 +660,11 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
for _, (ttype, fwd_position,
grad_api_position) in backward_grad_input_map.items():
if IsPlainTensorType(ttype):
grad_api_args[grad_api_position] = f"grads[{fwd_position}][0]"
grad_api_args[
grad_api_position] = f"hooked_grads[{fwd_position}][0]"
else:
assert IsVectorTensorType(ttype)
grad_api_args[grad_api_position] = f"grads[{fwd_position}]"
grad_api_args[grad_api_position] = f"hooked_grads[{fwd_position}]"
for name, _, _, grad_api_position in backward_attrs_list:
saved_attribute_name = GetSavedName(name)
......@@ -688,23 +692,30 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
grad_node_name = GetGradNodeName(fwd_api_name)
fill_zero_str = ""
if fwd_api_name in ops_to_fill_zero_for_empty_grads:
fill_zero_str = "egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());\n"
if len(namespace) > 0:
grad_api_namespace = f"paddle::experimental::{namespace}"
else:
grad_api_namespace = f"paddle::experimental"
FUNCTION_TEMPLATE = """
std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph) {{
std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph) {{
{}
auto hooked_grads = ApplyGradientHooks(grads);
// Call grad_api function
VLOG(3) << \"Finally State Running: \" << \"{}\";
VLOG(3) << \"Final State Running: \" << \"{}\";
auto grad_api_returns = {}::{}({});
{}
}}
"""
node_definition_str = FUNCTION_TEMPLATE.format(
grad_node_name, grad_node_name, grad_api_namespace, bwd_api_name,
grad_api_args_str, returns_str)
grad_node_name, fill_zero_str, grad_node_name, grad_api_namespace,
bwd_api_name, grad_api_args_str, returns_str)
return node_definition_str
......@@ -799,8 +810,15 @@ def GenerateNodeCreationCodes(
# SetAttributes
set_attributes_list = []
for name, _, _, _ in backward_attrs_list:
set_attributes = f" grad_node->SetAttribute{name}({name});"
forward_attrs_name_set = set()
for name, _, _, _ in forward_attrs_list:
forward_attrs_name_set.add(name)
for name, _, default_val_attr, _ in backward_attrs_list:
if name in forward_attrs_name_set:
set_attributes = f" grad_node->SetAttribute{name}({name});"
else:
set_attributes = f" grad_node->SetAttribute{name}({default_val_attr});"
set_attributes_list.append(set_attributes)
set_attributes_str = "\n".join(set_attributes_list)
......
......@@ -20,8 +20,8 @@
namespace egr {
std::vector<std::vector<paddle::experimental::Tensor>> RunCustomOpNode::
operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
bool create_graph) {
operator()(std::vector<std::vector<paddle::experimental::Tensor>>& grads,
bool create_graph) { // NOLINT
paddle::CustomOpKernelContext ctx;
auto grad_inputs_name = paddle::framework::OpMetaInfoHelper::GetInputs(
egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]);
......
......@@ -37,8 +37,9 @@ class RunCustomOpNode : public GradNodeBase {
// Functor: perform backward computations
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
bool create_graph) override;
std::vector<std::vector<paddle::experimental::Tensor>>& grads,
bool create_graph = false) // NOLINT
override;
std::string name() {
return paddle::string::Sprintf("RunCustomOpNode: %s_grad", op_type_);
......
......@@ -102,6 +102,7 @@ const std::vector<std::vector<GradSlotMeta>>& GradNodeBase::OutputMeta() const {
void GradNodeBase::SetGradInMeta(const paddle::experimental::Tensor& fwd_out,
size_t slot_rank) {
VLOG(6) << "Set GradSlotMeta for Grad Inputs";
auto* fwd_out_meta = egr::EagerUtils::nullable_autograd_meta(fwd_out);
PADDLE_ENFORCE_LE(
slot_rank, (bwd_in_meta_.size() - 1),
......@@ -117,6 +118,12 @@ void GradNodeBase::SetGradInMeta(const paddle::experimental::Tensor& fwd_out,
auto& meta = metas[0];
meta.SetStopGradient(fwd_out_meta->StopGradient());
if (!fwd_out.is_initialized()) {
VLOG(6)
<< "Skip Configuring GradSlotMeta for uninitialized GradInput Tensor";
return;
}
// Record TensorMeta
if (phi::DenseTensor::classof(fwd_out.impl().get())) {
// Only Copy Meta
......@@ -128,7 +135,9 @@ void GradNodeBase::SetGradInMeta(const paddle::experimental::Tensor& fwd_out,
paddle::platform::errors::Fatal(
"Attempting to copy DenseTensorMeta with phi::DataType::UNDEFINED,"
"which is illegal."));
meta.SetTensorMeta(dense_tensor->meta());
meta.SetPlace(fwd_out.inner_place());
if (paddle::framework::IsComplexType(
paddle::framework::TransToProtoVarType(dense_tensor->type()))) {
......@@ -143,6 +152,7 @@ void GradNodeBase::SetGradInMeta(const paddle::experimental::Tensor& fwd_out,
void GradNodeBase::SetGradInMeta(
const std::vector<paddle::experimental::Tensor>& fwd_out,
size_t slot_rank) {
VLOG(6) << "Set GradSlotMeta for Grad Inputs";
size_t slot_size = fwd_out.size();
PADDLE_ENFORCE_LE(
slot_rank, (bwd_in_meta_.size() - 1),
......@@ -172,6 +182,12 @@ void GradNodeBase::SetGradInMeta(
meta.SetStopGradient(fwd_out_meta->StopGradient());
}
if (!fwd_out_tensor.is_initialized()) {
VLOG(6)
<< "Skip Configuring GradSlotMeta for uninitialized GradInput Tensor";
return;
}
// Record TensorMeta
if (phi::DenseTensor::classof(fwd_out_tensor.impl().get())) {
// Only Copy Meta
......@@ -184,6 +200,8 @@ void GradNodeBase::SetGradInMeta(
"with phi::DataType::UNDEFINED,"
"which is illegal."));
meta.SetTensorMeta(dense_tensor->meta());
meta.SetPlace(fwd_out_tensor.inner_place());
if (paddle::framework::IsComplexType(
paddle::framework::TransToProtoVarType(dense_tensor->type()))) {
need_complex_to_real_ = true;
......@@ -228,6 +246,7 @@ void GradNodeBase::SetGradOutMeta(const paddle::experimental::Tensor& fwd_in,
"with phi::DataType::UNDEFINED,"
"which is illegal."));
meta.SetTensorMeta(dense_tensor->meta());
meta.SetPlace(fwd_in.inner_place());
}
} else {
VLOG(6) << "Unable to initialize the DenseTensorMeta of GradSlotMeta with "
......@@ -272,6 +291,7 @@ void GradNodeBase::SetGradOutMeta(
"phi::DataType::UNDEFINED,"
"which is illegal."));
meta.SetTensorMeta(dense_tensor->meta());
meta.SetPlace(fwd_in_tensor.inner_place());
}
} else {
VLOG(6) << "Unable to initialize the DenseTensorMeta of GradSlotMeta "
......
......@@ -76,8 +76,12 @@ class GradSlotMeta {
return *meta_.get();
}
// Stores `place` in this slot meta (place_).
void SetPlace(const phi::Place& place) { place_ = place; }
// Returns a reference to the place stored in this slot meta (place_).
const phi::Place& GetPlace() const { return place_; }
private:
bool stop_gradient_{false};
phi::Place place_;
std::shared_ptr<phi::DenseTensorMeta> meta_ = nullptr;
};
......@@ -102,7 +106,7 @@ class GradNodeBase {
* is better choice to fit this format.
* **/
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
std::vector<std::vector<paddle::experimental::Tensor>>& grads, // NOLINT
bool create_graph = false) = 0;
virtual void ClearTensorWrappers() = 0;
......
......@@ -53,7 +53,7 @@ class GradTensorHolder {
return buffer_[pos];
}
const std::vector<std::vector<paddle::experimental::Tensor>>& Buffers() {
std::vector<std::vector<paddle::experimental::Tensor>>& Buffers() {
return buffer_;
}
......
......@@ -80,13 +80,15 @@ TEST(AccumulationNode, Tensor) {
grad_meta->SetStopGradient(false);
// operator()
paddle::experimental::Tensor ret_et0 = node->operator()({{et0}})[0][0];
std::vector<std::vector<paddle::experimental::Tensor>> et0_vec = {{et0}};
paddle::experimental::Tensor ret_et0 = node->operator()(et0_vec)[0][0];
auto* ret_et0_ptr =
std::dynamic_pointer_cast<phi::DenseTensor>(ret_et0.impl())
->data<paddle::platform::float16>();
CHECK_EQ(ret_et0_ptr[0], paddle::platform::float16(10.0f));
paddle::experimental::Tensor ret_et1 = node->operator()({{et1}})[0][0];
std::vector<std::vector<paddle::experimental::Tensor>> et1_vec = {{et1}};
paddle::experimental::Tensor ret_et1 = node->operator()(et1_vec)[0][0];
auto* ret_et1_ptr =
std::dynamic_pointer_cast<phi::DenseTensor>(ret_et1.impl())
......@@ -121,7 +123,7 @@ TEST(AccumulationNode, Tensor) {
std::make_shared<egr::CppTensorVoidHook>(reduce_hook_1));
// operator()
paddle::experimental::Tensor _ret = node->operator()({{et0}})[0][0];
paddle::experimental::Tensor _ret = node->operator()(et0_vec)[0][0];
// Check operator() result, should be 36.0
auto* _ret_ptr = std::dynamic_pointer_cast<phi::DenseTensor>(_ret.impl())
......
......@@ -32,7 +32,7 @@ class GradTestNode : public egr::GradNodeBase {
GradTestNode() : GradNodeBase() { val_ = 1.0; }
std::string name() override { return "GradTestNode"; }
std::vector<std::vector<paddle::experimental::Tensor>> operator()(
const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
std::vector<std::vector<paddle::experimental::Tensor>>& grads,
bool create_graph = false) override {
val_ = std::dynamic_pointer_cast<phi::DenseTensor>(grads[0][0].impl())
->data<float>()[0];
......
......@@ -247,4 +247,20 @@ TEST(EagerUtils, GetGradAccumulationNode) {
ASSERT_ANY_THROW(egr::EagerUtils::GetGradAccumulationNode(t0));
}
// Checks that an uninitialized grad input is replaced by a zero tensor that
// is built from the slot meta (float32, dims {2, 4}, CPU place).
TEST(EagerUtils, FillZeroForEmptyGradInputs) {
  // One slot holding one default-constructed (uninitialized) tensor.
  std::vector<std::vector<paddle::experimental::Tensor>> grad_inputs(
      1, std::vector<paddle::experimental::Tensor>(1));
  // One slot holding one default-constructed meta, configured below.
  std::vector<std::vector<GradSlotMeta>> input_metas(
      1, std::vector<GradSlotMeta>(1));

  phi::DenseTensorMeta dense_meta;
  dense_meta.dtype = paddle::experimental::DataType::FLOAT32;
  dense_meta.dims = {2, 4};

  GradSlotMeta& only_meta = input_metas[0][0];
  only_meta.SetTensorMeta(dense_meta);
  only_meta.SetPlace(phi::CPUPlace());

  EagerUtils::FillZeroForEmptyGradInputs(&grad_inputs, input_metas);

  eager_test::CompareTensorWithValue<float>(grad_inputs[0][0], 0.0);
}
} // namespace egr
......@@ -370,7 +370,7 @@ class GradNodeRunProgram : public egr::GradNodeBase {
~GradNodeRunProgram() override = default;
// Functor: perform backward computations
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
const std::vector<std::vector<paddle::experimental::Tensor>> &grads,
std::vector<std::vector<paddle::experimental::Tensor>> &grads, // NOLINT
bool create_graph) override {
VLOG(3) << "Running Eager Backward Node: GradNodeRunProgram";
PADDLE_ENFORCE_EQ(
......
......@@ -20,6 +20,7 @@
#include "paddle/phi/api/all.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/fluid/framework/data_layout.h"
......@@ -392,4 +393,28 @@ std::shared_ptr<egr::GradNodeBase> EagerUtils::GetGradAccumulationNode(
}
}
/**
 * Replaces every uninitialized tensor in `in_grads` with a zero-filled
 * tensor.
 *
 * For each uninitialized tensor at position [i][j], the dims and dtype are
 * read from the tensor meta of `grad_in_metas[i][j]` and the place from the
 * same slot meta. Initialized tensors are left untouched.
 *
 * @param in_grads       Grad inputs, indexed as [slot][rank]; modified in
 *                       place.
 * @param grad_in_metas  Slot metas indexed the same way as `in_grads`.
 *
 * Raises a Fatal error if an uninitialized tensor's slot meta has no tensor
 * meta.
 */
void EagerUtils::FillZeroForEmptyGradInputs(
    std::vector<std::vector<paddle::experimental::Tensor>>* in_grads,
    const std::vector<std::vector<GradSlotMeta>>& grad_in_metas) {
  for (size_t i = 0; i < in_grads->size(); i++) {
    // The inner bound must be the size of slot i itself: slots may hold
    // different numbers of tensors, so using slot 0's size would skip
    // tensors or index past the end of shorter slots.
    for (size_t j = 0; j < (*in_grads)[i].size(); j++) {
      paddle::experimental::Tensor& grad = (*in_grads)[i][j];
      if (!grad.is_initialized()) {
        const GradSlotMeta& grad_in_meta = grad_in_metas[i][j];
        PADDLE_ENFORCE(
            grad_in_meta.HasTensorMeta(),
            paddle::platform::errors::Fatal(
                "Unable to fill empty grad inputs due to empty GradSlotMeta"));

        const auto& tensor_meta = grad_in_meta.GetTensorMeta();
        phi::Place place = grad_in_meta.GetPlace();

        auto tensor_with_zero = paddle::experimental::full(
            phi::vectorize(tensor_meta.dims), 0.0, tensor_meta.dtype, place);
        // Only the implementation is swapped in; the Tensor object itself
        // stays the same element of in_grads.
        grad.set_impl(tensor_with_zero.impl());
      }
    }
  }
}
} // namespace egr
......@@ -217,6 +217,13 @@ class EagerUtils {
const std::vector<paddle::experimental::Tensor>& tensors);
static std::shared_ptr<egr::GradNodeBase> GetGradAccumulationNode(
const paddle::experimental::Tensor& tensor);
/**
 * Fill Zero
 *
 * Replaces each uninitialized tensor in `in_grads` with a zero-filled tensor
 * whose dims, dtype and place come from the matching entry of
 * `grad_in_metas`. Parameter names match the definition.
 * **/
static void FillZeroForEmptyGradInputs(
    std::vector<std::vector<paddle::experimental::Tensor>>* in_grads,
    const std::vector<std::vector<GradSlotMeta>>& grad_in_metas);
};
} // namespace egr
......@@ -176,6 +176,20 @@ void Tracer::TraceOp(const std::string& type, const NameVarMap<VarType>& ins,
const std::map<std::string, std::string>& inplace_map,
paddle::framework::AttributeMap* passed_default_attrs_,
bool use_default_attr_map) {
TraceOpImpl<VarType>(type, ins, outs, attrs, place, trace_backward,
inplace_map, passed_default_attrs_,
use_default_attr_map);
}
template <typename VarType>
void Tracer::TraceOpImpl(const std::string& type,
const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs,
framework::AttributeMap& attrs,
const platform::Place& place, bool trace_backward,
const std::map<std::string, std::string>& inplace_map,
paddle::framework::AttributeMap* passed_default_attrs_,
bool use_default_attr_map) {
platform::RecordEvent op_type_record_event(
type + " trace_op", platform::TracerEventType::Operator, 1);
platform::ScopedFlushDenormal flush;
......@@ -340,25 +354,33 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
void Tracer::TraceOp(const std::string& type, const NameTensorMap& ins,
const NameTensorMap& outs,
paddle::framework::AttributeMap attrs,
paddle::framework::AttributeMap& attrs,
const paddle::platform::Place& place,
paddle::framework::AttributeMap* default_attrs,
bool use_default_attr_map,
const std::map<std::string, std::string>& inplace_map) {
VLOG(6) << "Running On Eager TraceOp with use_default_attr_map: "
<< use_default_attr_map;
TraceOp<egr::EagerVariable>(type, ins, outs, std::move(attrs), place, false,
inplace_map, default_attrs, use_default_attr_map);
TraceOpImpl<egr::EagerVariable>(type, ins, outs, attrs, place, false,
inplace_map, default_attrs,
use_default_attr_map);
}
void Tracer::TraceOp(const std::string& type, const NameTensorMap& ins,
const NameTensorMap& outs,
paddle::framework::AttributeMap attrs) {
VLOG(6) << "Running On Eager TraceOp(4 agrs): ";
TraceOpImpl<egr::EagerVariable>(type, ins, outs, attrs, expected_place_,
false, {}, nullptr, true);
}
void Tracer::TraceOp(const std::string& type, const NameTensorMap& ins,
const NameTensorMap& outs,
paddle::framework::AttributeMap attrs,
paddle::framework::AttributeMap& attrs,
const std::map<std::string, std::string>& inplace_map) {
VLOG(6) << "Running On Eager TraceOp(less): ";
TraceOp<egr::EagerVariable>(type, ins, outs, std::move(attrs),
expected_place_, false, inplace_map, nullptr,
true);
TraceOpImpl<egr::EagerVariable>(type, ins, outs, attrs, expected_place_,
false, inplace_map, nullptr, true);
}
void Tracer::SetExpectedPlace(platform::Place place) {
......
......@@ -74,16 +74,32 @@ class Tracer {
paddle::framework::AttributeMap* passed_default_attrs_ = nullptr,
bool use_default_attr_map = true);
template <typename VarType>
void TraceOpImpl(
const std::string& type, const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs,
framework::AttributeMap& attrs, // NOLINT
const platform::Place& place, bool trace_backward,
const std::map<std::string, std::string>& inplace_map = {},
paddle::framework::AttributeMap* passed_default_attrs_ = nullptr,
bool use_default_attr_map = true);
void TraceOp(const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, framework::AttributeMap attrs,
const std::map<std::string, std::string>& inplace_map = {});
void TraceOp(const std::string& type, const NameTensorMap& ins,
const NameTensorMap& outs, paddle::framework::AttributeMap attrs,
const NameTensorMap& outs,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::map<std::string, std::string>& inplace_map = {});
void TraceOp(const std::string& type, const NameTensorMap& ins,
const NameTensorMap& outs, paddle::framework::AttributeMap attrs,
const NameTensorMap& outs,
paddle::framework::AttributeMap attrs);
void TraceOp(const std::string& type, const NameTensorMap& ins,
const NameTensorMap& outs,
paddle::framework::AttributeMap& attrs, // NOLINT
const paddle::platform::Place& place,
paddle::framework::AttributeMap* default_attrs,
bool use_default_attr_map,
......
......@@ -34,6 +34,7 @@
#include "paddle/fluid/memory/allocation/thread_local_allocator.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/device/gpu/cuda/cuda_graph.h"
......@@ -210,13 +211,28 @@ class AllocatorFacadePrivate {
InitNaiveBestFitCPUAllocator();
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
allow_free_idle_chunk_ = allow_free_idle_chunk;
if (!FLAGS_use_stream_safe_cuda_allocator) {
for (int dev_id = 0; dev_id < platform::GetGPUDeviceCount();
++dev_id) {
InitAutoGrowthCUDAAllocator(platform::CUDAPlace(dev_id),
allow_free_idle_chunk_);
}
for (int dev_id = 0; dev_id < platform::GetGPUDeviceCount(); ++dev_id) {
InitAutoGrowthCUDAAllocator(platform::CUDAPlace(dev_id),
allow_free_idle_chunk_);
}
// Note(Ruibiao): For the GPU multi-stream case, the 'allocators_' map(place
// -> Allocator) holds the StreamSafeCUDAAllocator related to the default
// stream (i.e., the stream directly got from DeviceContext), while the
// 'cuda_allocators_' map(place -> map(stream -> Allocator)) holds the
// StreamSafeCUDAAllocator related to non-default streams (i.e., the
// streams users pass in). The default stream Allocator is built in the
// constructor of AllocatorFacadePrivate, while the non-default stream
// Allocator is built in a delayed manner in the GetAllocator function with
// 'create_if_not_found = true'. We make special treatment for the
// default stream for performance reasons. Since most Alloc calls are
// for the default stream in applications, treating it separately can avoid
// lots of overhead of acquiring the default stream and applying the
// read-write lock.
if (FLAGS_use_stream_safe_cuda_allocator) {
WrapStreamSafeCUDAAllocatorForDefault();
}
InitNaiveBestFitCUDAPinnedAllocator();
#endif
#ifdef PADDLE_WITH_ASCEND_CL
......@@ -301,7 +317,8 @@ class AllocatorFacadePrivate {
CheckAllocThreadSafe();
#ifdef PADDLE_WITH_CUDA
if (UNLIKELY(platform::CUDAGraph::IsThisThreadCapturing())) {
if (FLAGS_use_stream_safe_cuda_allocator == false &&
UNLIKELY(platform::CUDAGraph::IsThisThreadCapturing())) {
WrapCUDAGraphAllocator();
}
#endif
......@@ -341,7 +358,12 @@ class AllocatorFacadePrivate {
const std::shared_ptr<Allocator>& GetAllocator(
const platform::CUDAPlace& place, const gpuStream_t& stream,
bool create_if_not_found = false) {
{ // shared_lock_guard
if (stream == GetDefaultStream(place)) {
VLOG(7) << "Get Allocator by passing in a default stream";
return GetAllocator(place, /* A non-zero num to choose allocator_ */ 1);
}
/* shared_lock_guard */ {
std::shared_lock<std::shared_timed_mutex> lock_guard(
cuda_allocator_mutex_);
if (LIKELY(HasCUDAAllocator(place, stream))) {
......@@ -355,7 +377,7 @@ class AllocatorFacadePrivate {
}
}
{ // unique_lock_guard
/* unique_lock_guard */ {
std::unique_lock<std::shared_timed_mutex> lock_guard(
cuda_allocator_mutex_);
InitStreamSafeCUDAAllocator(place, stream);
......@@ -363,9 +385,40 @@ class AllocatorFacadePrivate {
}
}
gpuStream_t GetDefaultStream(const platform::CUDAPlace& place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
return static_cast<platform::CUDADeviceContext*>(pool.Get(place))->stream();
// Looks up the StreamSafeCUDAAllocator registered for `place` in
// default_stream_safe_cuda_allocators_. Raises NotFound when the place has
// no entry.
const std::shared_ptr<StreamSafeCUDAAllocator>
GetDefaultStreamSafeCUDAAllocator(const platform::CUDAPlace& place) const {
  const auto entry = default_stream_safe_cuda_allocators_.find(place);
  PADDLE_ENFORCE_NE(
      entry, default_stream_safe_cuda_allocators_.end(),
      platform::errors::NotFound(
          "No StreamSafeCUDAAllocator found for the place, %s", place));
  return entry->second;
}
// Returns the default stream held by the StreamSafeCUDAAllocator that is
// registered for `place`.
const gpuStream_t& GetDefaultStream(const platform::CUDAPlace& place) const {
  return GetDefaultStreamSafeCUDAAllocator(place)->GetDefaultStream();
}
// Sets `stream` as the default stream of the StreamSafeCUDAAllocator that is
// registered for `place`, then logs the change.
void SetDefaultStream(const platform::CUDAPlace& place,
                      const gpuStream_t& stream) {
  const auto default_allocator = GetDefaultStreamSafeCUDAAllocator(place);
  default_allocator->SetDefaultStream(stream);
  VLOG(8) << "Set default stream to " << stream
          << " for StreamSafeCUDAAllocator(" << default_allocator.get()
          << ") in " << place;
}
void SetDefaultStreamFromDeviceContext() {
VLOG(8) << "Set default stream from DeviceContex";
for (auto& pair : default_stream_safe_cuda_allocators_) {
platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
pair.second->SetDefaultStream(
static_cast<phi::GPUContext*>(pool.Get(pair.first))->stream());
}
}
void RecordStream(std::shared_ptr<phi::Allocation> allocation,
......@@ -635,6 +688,26 @@ class AllocatorFacadePrivate {
/* in_cuda_graph_capturing = */ !allow_free_idle_chunk_);
}
// Wraps the allocator of every GPU place in allocators_ with a
// StreamSafeCUDAAllocator (created with a null default stream) and also
// registers the wrapper in default_stream_safe_cuda_allocators_.
void WrapStreamSafeCUDAAllocatorForDefault() {
  for (auto& place_and_allocator : allocators_) {
    auto& place = place_and_allocator.first;
    if (!platform::is_gpu_place(place)) {
      continue;
    }

    auto& slot = place_and_allocator.second;
    std::shared_ptr<StreamSafeCUDAAllocator> wrapper =
        std::make_shared<StreamSafeCUDAAllocator>(
            slot, place, /* default_stream = */ nullptr,
            /* in_cuda_graph_capturing = */ !allow_free_idle_chunk_);
    slot = wrapper;

    // NOTE(Ruibiao): A tricky way to let the StreamSafeCUDAAllocator
    // interact with the outside world, i.e., to allow its default stream to
    // be changed from outside.
    default_stream_safe_cuda_allocators_[place] = wrapper;
    VLOG(8) << "WrapStreamSafeCUDAAllocator for " << place
            << ", allocator address = " << slot.get();
  }
}
void WrapCUDARetryAllocator(platform::CUDAPlace p, gpuStream_t stream,
size_t retry_time) {
PADDLE_ENFORCE_GT(
......@@ -813,7 +886,6 @@ class AllocatorFacadePrivate {
#endif
}
// NOTE(Ruibiao): Old single-stream version, will be removed later
void WrapCUDARetryAllocator(size_t retry_time) {
PADDLE_ENFORCE_GT(
retry_time, 0,
......@@ -828,6 +900,8 @@ class AllocatorFacadePrivate {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// a standalone CUDA allocator to support multi-stream GC in new executor
std::map<platform::Place, std::shared_ptr<StreamSafeCUDAAllocator>>
default_stream_safe_cuda_allocators_;
CUDAAllocatorMap cuda_allocators_;
std::shared_timed_mutex cuda_allocator_mutex_;
#endif
......@@ -870,15 +944,6 @@ AllocatorFacadePrivate* AllocatorFacade::GetPrivate() const {
const std::shared_ptr<Allocator>& AllocatorFacade::GetAllocator(
const platform::Place& place) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (FLAGS_use_stream_safe_cuda_allocator && platform::is_gpu_place(place) &&
FLAGS_use_system_allocator == false) {
AllocatorFacadePrivate* m = GetPrivate();
platform::CUDAPlace cuda_place(place.GetDeviceId());
return m->GetAllocator(cuda_place, m->GetDefaultStream(cuda_place));
}
#endif
return GetPrivate()->GetAllocator(
place, /* A non-zero num to choose allocator_ */ 1);
}
......@@ -898,19 +963,6 @@ void* AllocatorFacade::GetBasePtr(
return GetPrivate()->GetBasePtr(allocation);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
const std::shared_ptr<Allocator>& AllocatorFacade::GetAllocator(
const platform::Place& place, const gpuStream_t& stream) {
if (FLAGS_use_stream_safe_cuda_allocator && platform::is_gpu_place(place) &&
FLAGS_use_system_allocator == false) {
return GetPrivate()->GetAllocator(place, stream,
/*create_if_not_found=*/true);
}
return GetPrivate()->GetAllocator(
place, /* A non-zero num to choose allocator_ */ 1);
}
#endif
const std::shared_ptr<Allocator>& AllocatorFacade::GetZeroAllocator(
const platform::Place& place) {
return GetPrivate()->GetAllocator(place, /* zero size */ 0);
......@@ -923,26 +975,10 @@ std::shared_ptr<phi::Allocation> AllocatorFacade::AllocShared(
AllocationPtr AllocatorFacade::Alloc(const platform::Place& place,
size_t size) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (FLAGS_use_stream_safe_cuda_allocator && platform::is_gpu_place(place) &&
size > 0 && FLAGS_use_system_allocator == false) {
platform::CUDAPlace cuda_place(place.GetDeviceId());
phi::Stream default_stream = phi::Stream(reinterpret_cast<phi::StreamId>(
GetPrivate()->GetDefaultStream(cuda_place)));
return Alloc(cuda_place, size, default_stream);
}
#endif
return GetPrivate()->GetAllocator(place, size)->Allocate(size);
}
uint64_t AllocatorFacade::Release(const platform::Place& place) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (FLAGS_use_stream_safe_cuda_allocator && platform::is_gpu_place(place) &&
FLAGS_use_system_allocator == false) {
platform::CUDAPlace cuda_place(place.GetDeviceId());
return Release(cuda_place, GetPrivate()->GetDefaultStream(cuda_place));
}
#endif
return GetPrivate()
->GetAllocator(place, /* A non-zero num to choose allocator_ */ 1)
->Release(place);
......@@ -1028,6 +1064,17 @@ void AllocatorFacade::RecordStream(std::shared_ptr<phi::Allocation> allocation,
GetPrivate()->RecordStream(allocation, stream);
}
// Returns the allocator for `place` and `stream`.
//
// When the stream-safe CUDA allocator is enabled, `place` is a GPU place and
// the system allocator is not in use, the stream-specific allocator is
// returned (created if it does not exist yet). Otherwise the regular
// non-zero-size allocator for `place` is returned.
const std::shared_ptr<Allocator>& AllocatorFacade::GetAllocator(
    const platform::Place& place, const gpuStream_t& stream) {
  const bool use_stream_safe_allocator =
      FLAGS_use_stream_safe_cuda_allocator && platform::is_gpu_place(place) &&
      FLAGS_use_system_allocator == false;
  if (!use_stream_safe_allocator) {
    return GetPrivate()->GetAllocator(
        place, /* A non-zero num to choose allocator_ */ 1);
  }
  return GetPrivate()->GetAllocator(place, stream,
                                    /*create_if_not_found=*/true);
}
const gpuStream_t& AllocatorFacade::GetStream(
const std::shared_ptr<phi::Allocation>& allocation) const {
PADDLE_ENFORCE_EQ(
......@@ -1040,6 +1087,13 @@ const gpuStream_t& AllocatorFacade::GetStream(
return GetPrivate()->GetStream(allocation);
}
// Forwards the default-stream change for `place` to the private
// implementation. Does nothing unless the stream-safe CUDA allocator is
// enabled.
void AllocatorFacade::SetDefaultStream(const platform::CUDAPlace& place,
                                       const gpuStream_t& stream) {
  if (!FLAGS_use_stream_safe_cuda_allocator) {
    return;
  }
  GetPrivate()->SetDefaultStream(place, stream);
}
#ifdef PADDLE_WITH_CUDA
void AllocatorFacade::PrepareMemoryPoolForCUDAGraph(CUDAGraphID id) {
PADDLE_ENFORCE_EQ(GetAllocatorStrategy(), AllocatorStrategy::kAutoGrowth,
......@@ -1055,6 +1109,8 @@ void AllocatorFacade::PrepareMemoryPoolForCUDAGraph(CUDAGraphID id) {
"The memory pool of the CUDA Graph with ID %d have been prepared.",
id));
allocator.reset(new AllocatorFacadePrivate(/*allow_free_idle_chunk=*/false));
allocator->SetDefaultStreamFromDeviceContext();
VLOG(10) << "Prepare memory pool for CUDA Graph with ID " << id;
}
......
......@@ -55,11 +55,6 @@ class AllocatorFacade {
void* GetBasePtr(const std::shared_ptr<Allocation>& allocation);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
const std::shared_ptr<Allocator>& GetAllocator(const platform::Place& place,
const gpuStream_t& stream);
#endif
const std::shared_ptr<Allocator>& GetZeroAllocator(
const platform::Place& place);
......@@ -86,8 +81,12 @@ class AllocatorFacade {
uint64_t Release(const platform::CUDAPlace& place, const gpuStream_t& stream);
void RecordStream(std::shared_ptr<Allocation> allocation,
const gpuStream_t& stream);
const std::shared_ptr<Allocator>& GetAllocator(const platform::Place& place,
const gpuStream_t& stream);
const gpuStream_t& GetStream(
const std::shared_ptr<Allocation>& allocation) const;
void SetDefaultStream(const platform::CUDAPlace& place,
const gpuStream_t& stream);
#endif
#ifdef PADDLE_WITH_CUDA
......
......@@ -154,6 +154,14 @@ StreamSafeCUDAAllocator::~StreamSafeCUDAAllocator() {
// Always returns true.
bool StreamSafeCUDAAllocator::IsAllocThreadSafe() const { return true; }
// Returns a reference to the stream currently stored in default_stream_.
const gpuStream_t& StreamSafeCUDAAllocator::GetDefaultStream() const {
return default_stream_;
}
// Replaces the stored default stream with `stream`.
// NOTE(review): the assignment takes no lock inside this function; confirm
// that callers do not race with concurrent readers of default_stream_.
void StreamSafeCUDAAllocator::SetDefaultStream(const gpuStream_t& stream) {
default_stream_ = stream;
}
phi::Allocation* StreamSafeCUDAAllocator::AllocateImpl(size_t size) {
platform::RecordEvent("StreamSafeCUDAAllocator::Allocate",
platform::TracerEventType::UserDefined, 9 /*level*/);
......@@ -187,12 +195,8 @@ void StreamSafeCUDAAllocator::FreeImpl(phi::Allocation* allocation) {
platform::RecordEvent("StreamSafeCUDAAllocator::Free",
platform::TracerEventType::UserDefined, 9 /*level*/);
StreamSafeCUDAAllocation* stream_safe_cuda_allocation =
dynamic_cast<StreamSafeCUDAAllocation*>(allocation);
PADDLE_ENFORCE_NOT_NULL(stream_safe_cuda_allocation,
platform::errors::InvalidArgument(
"Failed to dynamic cast %p from Allocation* to "
"StreamSafeCUDAAllocation*",
allocation));
static_cast<StreamSafeCUDAAllocation*>(allocation);
VLOG(8) << "Try free allocation " << stream_safe_cuda_allocation->ptr();
if (stream_safe_cuda_allocation->CanBeFreed()) {
VLOG(9) << "Directly delete allocation";
......@@ -221,6 +225,12 @@ uint64_t StreamSafeCUDAAllocator::ReleaseImpl(const platform::Place& place) {
}
void StreamSafeCUDAAllocator::ProcessUnfreedAllocations() {
// NOTE(Ruibiao): This condition is to reduce lock competition. It does not
// need to be thread-safe since occasional misjudgments are permissible here.
if (unfreed_allocations_.empty()) {
return;
}
std::lock_guard<SpinLock> lock_guard(unfreed_allocation_lock_);
for (auto it = unfreed_allocations_.begin();
it != unfreed_allocations_.end();) {
......
......@@ -64,7 +64,10 @@ class StreamSafeCUDAAllocator
platform::CUDAPlace place, gpuStream_t default_stream,
bool in_cuda_graph_capturing = false);
~StreamSafeCUDAAllocator();
bool IsAllocThreadSafe() const override;
const gpuStream_t &GetDefaultStream() const;
void SetDefaultStream(const gpuStream_t &stream);
protected:
phi::Allocation *AllocateImpl(size_t size) override;
......
......@@ -24,7 +24,9 @@ class CinnInstructionRunOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs(kX), "Input", kX, "CinnInstructionRun");
// The CINN graph may have no inputs, because CINN now supports fill_constant
// and all of its inputs may be generated by fill_constant rather than fetched.
// OP_INOUT_CHECK(ctx->HasInputs(kX), "Input", kX, "CinnInstructionRun");
OP_INOUT_CHECK(ctx->HasOutputs(kOutputs), "Output", kOutputs,
"CinnInstructionRun");
const CinnCompiledObject& compiled_object =
......@@ -43,6 +45,53 @@ class CinnInstructionRunOp : public framework::OperatorWithKernel {
});
ctx->SetOutputsDim(kOutputs, output_dims);
}
protected:
// Chooses the kernel type for this op. When the op has inputs under kX the
// base-class logic is used; otherwise the data type is taken from the single
// output variable under kOutputs.
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
// Why do we need to override GetExpectedKernelType?
// A CINN graph may have no input variable. The base function would check
// whether the input tensors are initialized, so this override infers the
// kernel type from the output data type instead.
if (ctx.InputSize(kX)) {
// The instruction has inputs: infer the kernel type from the input data
// type via the base implementation.
return OperatorWithKernel::GetExpectedKernelType(ctx);
}
// Otherwise infer the kernel type from the output data type.
// Per the original author, `OutputVar` also checks that kOutputs holds
// exactly one output variable.
const framework::Variable* var = ctx.OutputVar(kOutputs);
PADDLE_ENFORCE_NE(
var, nullptr,
platform::errors::InvalidArgument(
"The cinn_instruction_run Op's Output Variable should not empty."));
// Locate the tensor that carries the dtype, depending on what the variable
// holds. `tensor` stays null for any other variable type and is rejected
// by the check below.
const framework::Tensor* tensor = nullptr;
if (var->IsType<framework::Tensor>()) {
tensor = &var->Get<framework::Tensor>();
} else if (var->IsType<framework::LoDTensor>()) {
tensor = &var->Get<framework::LoDTensor>();
} else if (var->IsType<phi::SelectedRows>()) {
tensor = &(var->Get<phi::SelectedRows>().value());
} else if (var->IsType<framework::LoDTensorArray>()) {
// A tensor array is accepted only when it contains exactly one tensor.
auto t_arr = &var->Get<framework::LoDTensorArray>();
PADDLE_ENFORCE_EQ(t_arr->size(), 1UL,
platform::errors::InvalidArgument(
"The cinn_instruction_run Op should just has One "
"Output when Input empty."));
tensor = &(t_arr->front());
}
PADDLE_ENFORCE_NE(
tensor, nullptr,
platform::errors::InvalidArgument(
"The cinn_instruction_run Op's Output Tensor should not empty."));
VLOG(4) << "The tensor [" << ctx.OutputName(kOutputs) << "]'s dtype is "
<< paddle::framework::DataType2String(tensor->dtype());
// Build the kernel type from the output dtype and the current device context.
auto output_type = paddle::framework::TransToProtoVarType(tensor->dtype());
return framework::OpKernelType(output_type, ctx.device_context());
}
};
class CinnInstructionRunOpMaker : public framework::OpProtoAndCheckerMaker {
......
......@@ -87,9 +87,12 @@ class CinnLaunchOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs(kX) || ctx->HasInputs(kNoNeedBufferX),
"Input", string::format_string("%s|%s", kX, kNoNeedBufferX),
"CinnLaunchOp");
// The CINN graph may have no inputs, because CINN now supports fill_constant
// and all of its inputs may be generated by fill_constant rather than fetched.
// OP_INOUT_CHECK(ctx->HasInputs(kX) || ctx->HasInputs(kNoNeedBufferX),
// "Input", string::format_string("%s|%s", kX,
// kNoNeedBufferX),
// "CinnLaunchOp");
OP_INOUT_CHECK(ctx->HasOutputs(kOutputs), "Output", kOutputs,
"CinnLaunchOp");
}
......
......@@ -35,143 +35,99 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/functors.h"
namespace paddle {
namespace operators {
template <typename T1, typename T2 = T1, typename OutT = T1>
struct DstMaskGenerator {
const float dropout_prob_;
const bool is_upscale_in_train_;
using MT = typename details::MPTypeTrait<T1>::Type;
MT factor;
HOSTDEVICE inline DstMaskGenerator(const float dropout_prob,
const bool is_upscale_in_train)
: dropout_prob_(dropout_prob), is_upscale_in_train_(is_upscale_in_train) {
factor = static_cast<MT>(1.0f / (1.0f - dropout_prob_));
}
template <typename T, typename MaskType>
__global__ void RandomGenerator(const size_t n, uint64_t seed,
const float dropout_prob, const T* src,
MaskType* mask, T* dst,
bool is_upscale_in_train, uint64_t increment) {
using MT = typename details::MPTypeTrait<T>::Type;
int idx = blockDim.x * blockIdx.x + threadIdx.x;
#ifdef PADDLE_WITH_HIP
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed, idx, increment, &state);
#else
curandStatePhilox4_32_10_t state;
curand_init(seed, idx, increment, &state);
#endif
MaskType mask_val;
T dst_val;
MT factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
for (; idx < n; idx += blockDim.x * gridDim.x) {
T src_val = src[idx];
#ifdef PADDLE_WITH_HIP
if (hiprand_uniform(&state) < dropout_prob) {
#else
if (curand_uniform(&state) < dropout_prob) {
#endif
mask_val = 0;
dst_val = 0;
} else {
mask_val = 1;
dst_val = is_upscale_in_train
? static_cast<T>(static_cast<MT>(src_val) * factor)
: src_val;
HOSTDEVICE inline void operator()(OutT* dst, const T1* src_val,
const T2* rand, int num) const {
static constexpr int kCount =
phi::funcs::uniform_distribution<T2>::kReturnsCount;
// 0 ~ kCount - 1 is dst, kCount ~ 2 * kCount - 1 is mask
#pragma unroll
for (int i = 0; i < kCount; i++) {
if (rand[i] < dropout_prob_) {
dst[i] = static_cast<T1>(0);
dst[i + kCount] = dst[i];
} else {
dst[i] = is_upscale_in_train_
? static_cast<T1>(static_cast<MT>(src_val[i]) * factor)
: static_cast<T1>(src_val[i]);
dst[i + kCount] = static_cast<T1>(1);
}
}
mask[idx] = mask_val;
dst[idx] = dst_val;
}
}
};
template <typename T, typename MaskType, int VecSize>
template <typename T, typename MaskType>
__global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
const float dropout_prob,
const T* src, MaskType* mask, T* dst,
bool is_upscale_in_train,
uint64_t increment) {
using MT = typename details::MPTypeTrait<T>::Type;
using LoadT = phi::AlignedVector<T, VecSize>;
using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;
uint64_t increment,
size_t main_offset) {
size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
static constexpr int kCount =
phi::funcs::uniform_distribution<float>::kReturnsCount;
size_t stride = BLOCK_NUM_X * GRID_NUM_X * kCount;
#ifdef PADDLE_WITH_HIP
int64_t idx = hipBlockDim_x * hipBlockIdx_x + hipThreadIdx_x;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed, idx, increment, &state);
hiprand_init(seed, idx + THREAD_ID_X, increment, &state);
using SType = hiprandStatePhilox4_32_10_t;
#else
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seed, idx, increment, &state);
#endif
MT factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
for (int i = idx * VecSize; i < n; i += blockDim.x * gridDim.x * VecSize) {
LoadT src_val;
phi::Load<T, VecSize>(&src[i], &src_val);
#ifdef PADDLE_WITH_HIP
float4 rand = hiprand_uniform4(&state);
#else
float4 rand = curand_uniform4(&state);
curand_init(seed, idx + THREAD_ID_X, increment, &state);
using SType = curandStatePhilox4_32_10_t;
#endif
LoadT dst_val;
MaskLoadT mask_val;
#pragma unroll
for (int j = 0; j < VecSize; j++) {
if ((&rand.x)[j] < dropout_prob) {
dst_val[j] = 0;
mask_val[j] = 0;
} else {
dst_val[j] = is_upscale_in_train
? static_cast<T>(static_cast<MT>(src_val[j]) * factor)
: src_val[j];
mask_val[j] = 1;
}
}
phi::Store<T, VecSize>(dst_val, &dst[i]);
phi::Store<MaskType, VecSize>(mask_val, &mask[i]);
T dst_mask[kCount * 2]; // 0 ~ kCount -1 : dst;kCount ~ 2 * kCount - 1: mask
float rands[kCount];
MaskType mask_result[kCount];
using Rand = phi::funcs::uniform_distribution<float>;
using Cast = kps::IdentityFunctor<T>;
int deal_size = BLOCK_NUM_X * kCount;
auto dst_functor =
DstMaskGenerator<T, float>(dropout_prob, is_upscale_in_train);
size_t fix = idx * kCount;
for (; fix < main_offset; fix += stride) {
kps::ReadData<T, kCount, 1, 1, false>(&dst_mask[0], src + fix, deal_size);
kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
&state);
// dst
kps::OperatorTernary<T, float, T, DstMaskGenerator<T, float>>(
&dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
kps::WriteData<T, kCount, 1, 1, false>(dst + fix, &dst_mask[0], deal_size);
// mask
kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
&mask_result[0], &dst_mask[kCount], Cast());
kps::WriteData<MaskType, kCount, 1, 1, false>(mask + fix, &mask_result[0],
deal_size);
}
}
// Binary functor that multiplies a gradient element by its mask value and by
// a constant factor. The arithmetic is carried out in MT (the type given by
// details::MPTypeTrait<T>) and the result is cast back to T.
template <typename T, typename MaskType>
struct CudaDropoutGradFunctor {
using MT = typename details::MPTypeTrait<T>::Type;
// Stores the scaling factor applied to every element.
explicit CudaDropoutGradFunctor(const MT factor) : factor_(factor) {}
// Returns dout * mask * factor_, evaluated in MT.
__device__ __forceinline__ T operator()(const T dout,
const MaskType mask) const {
return static_cast<T>(static_cast<MT>(dout) * static_cast<MT>(mask) *
factor_);
}
private:
MT factor_;
};
template <typename T, typename MaskType, int VecSize>
__global__ void DropoutGradCUDAKernel(
const T* dout, const MaskType* mask,
const typename details::MPTypeTrait<T>::Type factor, const int64_t size,
T* dx) {
using MT = typename details::MPTypeTrait<T>::Type;
using LoadT = phi::AlignedVector<T, VecSize>;
using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) {
LoadT dout_val;
phi::Load<T, VecSize>(&dout[i], &dout_val);
MaskLoadT mask_val;
phi::Load<MaskType, VecSize>(&mask[i], &mask_val);
LoadT dx_val;
#pragma unroll
for (int j = 0; j < VecSize; j++) {
dx_val[j] = static_cast<T>(static_cast<MT>(dout_val[j]) *
static_cast<MT>(mask_val[j]) * factor);
}
phi::Store<T, VecSize>(dx_val, &dx[i]);
int remainder = n - fix;
if (remainder > 0) {
kps::ReadData<T, kCount, 1, 1, true>(&dst_mask[0], src + fix, remainder);
kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
&state);
// dst
kps::OperatorTernary<T, float, T, DstMaskGenerator<T, float>>(
&dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
kps::WriteData<T, kCount, 1, 1, true>(dst + fix, &dst_mask[0], remainder);
// mask
kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
&mask_result[0], &dst_mask[kCount], Cast());
kps::WriteData<MaskType, kCount, 1, 1, true>(mask + fix, &mask_result[0],
remainder);
}
}
......@@ -218,42 +174,21 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test,
uint64_t seed_data;
uint64_t increment;
// VectorizedRandomGenerator use curand_uniform4, so we only support
// vec_size is 4;
int vec_size = (phi::GetVectorizedSize<T>(x_data) == 4) ? 4 : 1;
// kVecSize is 4;
constexpr int kVecSize =
phi::funcs::uniform_distribution<float>::kReturnsCount;
auto gpu_config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_numel, vec_size);
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_numel, kVecSize);
auto offset =
((x_numel - 1) / (gpu_config.GetThreadNum() * vec_size) + 1) * vec_size;
((x_numel - 1) / (gpu_config.GetThreadNum() * kVecSize) + 1) * kVecSize;
GetSeedDataAndIncrement(dev_ctx, seed, is_fix_seed, seed_val, offset,
&seed_data, &increment);
#ifdef __HIPCC__
if (vec_size == 4 && size % 4 == 0) {
hipLaunchKernelGGL(
HIP_KERNEL_NAME(VectorizedRandomGenerator<T, uint8_t, 4>),
gpu_config.GetGridSize(), gpu_config.GetBlockSize(), 0, stream, size,
seed_data, dropout_prob, x_data, mask_data, y_data, upscale_in_train,
increment);
} else {
hipLaunchKernelGGL(HIP_KERNEL_NAME(RandomGenerator<T, uint8_t>),
gpu_config.GetGridSize(), gpu_config.GetBlockSize(), 0,
stream, size, seed_data, dropout_prob, x_data,
mask_data, y_data, upscale_in_train, increment);
}
#else
if (vec_size == 4 && size % 4 == 0) {
VectorizedRandomGenerator<T, uint8_t, 4><<<
gpu_config.block_per_grid, gpu_config.thread_per_block, 0, stream>>>(
size, seed_data, dropout_prob, x_data, mask_data, y_data,
upscale_in_train, increment);
} else {
RandomGenerator<T, uint8_t><<<gpu_config.block_per_grid,
gpu_config.thread_per_block, 0, stream>>>(
size, seed_data, dropout_prob, x_data, mask_data, y_data,
upscale_in_train, increment);
}
#endif
size_t main_offset = size / (gpu_config.GetBlockSize() * kVecSize) *
(gpu_config.GetBlockSize() * kVecSize);
VectorizedRandomGenerator<T, uint8_t><<<
gpu_config.GetGridSize(), gpu_config.GetBlockSize(), 0, stream>>>(
size, seed_data, dropout_prob, x_data, mask_data, y_data,
upscale_in_train, increment, main_offset);
} else {
if (upscale_in_train) {
// todo: can y share with data with x directly?
......@@ -278,6 +213,22 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test,
}
}
// Binary functor that multiplies a gradient element by its mask value and by
// a constant factor. The arithmetic is carried out in MT (the type given by
// details::MPTypeTrait<T>) and the result is cast back to T.
template <typename T, typename MaskType>
struct CudaDropoutGradFunctor {
  using MT = typename details::MPTypeTrait<T>::Type;

  // Stores the scaling factor applied to every element.
  explicit CudaDropoutGradFunctor(const MT factor) : factor_(factor) {}

  // Returns dout * mask * factor_, evaluated in MT.
  __device__ __forceinline__ T operator()(const T dout,
                                          const MaskType mask) const {
    const MT grad = static_cast<MT>(dout);
    const MT keep = static_cast<MT>(mask);
    return static_cast<T>(grad * keep * factor_);
  }

 private:
  MT factor_;
};
template <typename T>
void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx,
const std::string dropout_implementation,
......
......@@ -58,19 +58,15 @@ __global__ void DequantizeOneScaleQuantAxis0(const T* in, const T* scale,
}
template <typename T>
__global__ void DequantizeOneScaleQuantAxis1(const T* in, const T* scale,
T max_range, const int num,
const int cin, const int cout,
T* out) {
int bid = blockIdx.x;
T s = scale[bid % cout];
int wh_size = num / (cin * cout);
const T* in_current = in + bid * wh_size;
T* out_current = out + bid * wh_size;
for (int i = threadIdx.x; i < wh_size; i += blockDim.x) {
out_current[i] = in_current[i] * s / max_range;
__global__ void DequantizeOneScaleQuantAxisN(const T* in, const T* scale,
const T max_range,
const int64_t num,
const int n_scales,
const int quant_stride, T* out) {
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
T s = scale[(i / quant_stride) % n_scales];
out[i] = in[i] * s / max_range;
}
}
......@@ -98,20 +94,32 @@ struct ChannelDequantizeFunctor<platform::CUDADeviceContext, T> {
const T* in_data = in->data<T>();
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
if (scale_num == 1) {
int num = in->numel();
int64_t num = in->numel();
const T* scale_factor = scales[0]->data<T>();
if (quant_axis == 0) {
int grid = in_dims[0];
int block = 1024;
DequantizeOneScaleQuantAxis0<T><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, scale_factor, max_range, num, in_dims[0], out_data);
} else if (quant_axis == 1) {
// Dequantize weight of Cin * Cout * W * H
int grid = in_dims[0] * in_dims[1];
int block = 1024;
DequantizeOneScaleQuantAxis1<T><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, scale_factor, max_range, num, in_dims[0], in_dims[1],
out_data);
} else {
int quant_stride = 1;
for (int i = quant_axis + 1; i < in_dims.size(); i++) {
quant_stride *= in_dims[i];
}
int64_t block_size = std::min(
num, static_cast<int64_t>(dev_ctx.GetMaxThreadsPerBlock() / 4));
int64_t max_threads =
dev_ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM
const int64_t max_blocks = std::max(
((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
const int64_t grid_size =
std::min(max_blocks, (num + block_size - 1) / block_size);
DequantizeOneScaleQuantAxisN<
T><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
in_data, scale_factor, max_range, num, in_dims[quant_axis],
quant_stride, out_data);
}
} else if (scale_num == 2) {
// Not need to consider quant_axis
......
......@@ -273,18 +273,18 @@ struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
template <typename T>
__global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale,
const int bin_cnt,
const int n, const int c,
T* out) {
const int64_t n,
const int c, T* out) {
int tid = threadIdx.x;
int channel_size = n / c;
int64_t channel_size = n / c;
const T* in_c = in + blockIdx.x * channel_size;
T* out_c = out + blockIdx.x * channel_size;
T s = scale[blockIdx.x];
T inv_s = inverse(s);
for (int i = tid; i < channel_size; i += blockDim.x) {
for (int64_t i = tid; i < channel_size; i += blockDim.x) {
T x = in_c[i];
T v = x > s ? s : x;
v = v < -s ? -s : v;
......@@ -293,25 +293,20 @@ __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale,
}
}
// ChannelClipAndQuantKernel for quant_axis is 1
// ChannelClipAndQuantKernel for quant_axis is N
template <typename T>
__global__ void ChannelClipAndQuantKernelQuantAxis1(const T* in, const T* scale,
const int bin_cnt,
const int n, const int cin,
const int cout, T* out) {
T s = scale[blockIdx.x % cout];
T inv_s = inverse(s);
int wh_size = n / (cin * cout);
const T* in_c = in + blockIdx.x * wh_size;
T* out_c = out + blockIdx.x * wh_size;
for (int i = threadIdx.x; i < wh_size; i += blockDim.x) {
T x = in_c[i];
__global__ void ChannelClipAndQuantKernelQuantAxisN(
const T* in, const T* scale, const int bin_cnt, const int64_t n,
const int nScale, const int quant_stride, T* out) {
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
for (int64_t i = idx; i < n; i += blockDim.x * gridDim.x) {
T s = scale[(i / quant_stride) % nScale];
T inv_s = 1.0 / s;
T x = in[i];
T v = x > s ? s : x;
v = v < -s ? -s : v;
v = bin_cnt * inv_s * v;
out_c[i] = round(v);
out[i] = round(v);
}
}
......@@ -327,7 +322,7 @@ struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
"the received is %d",
quant_axis));
int num = in.numel();
int64_t num = in.numel();
auto in_dims = in.dims();
const T* in_data = in.data<T>();
const T* scale_data = scale.data<T>();
......@@ -338,11 +333,24 @@ struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
int block = 1024;
ChannelClipAndQuantKernelQuantAxis0<T><<<grid, block, 0, ctx.stream()>>>(
in_data, scale_data, bin_cnt, num, in_dims[0], out_data);
} else if (quant_axis == 1) {
int grid = in_dims[0] * in_dims[1];
int block = 1024;
ChannelClipAndQuantKernelQuantAxis1<T><<<grid, block, 0, ctx.stream()>>>(
in_data, scale_data, bin_cnt, num, in_dims[0], in_dims[1], out_data);
} else {
// Number of consecutive elements that share one scale: the product of all
// dimensions that follow quant_axis.
int quant_stride = 1;
for (int i = quant_axis + 1; i < in_dims.size(); i++) {
  quant_stride *= in_dims[i];
}
// Use a quarter of the per-block thread limit (fewer for tiny inputs) and cap
// the grid by the device's physical thread count; the kernel's loop handles
// any elements beyond grid_size * block_size.
int64_t block_size =
    std::min(num, static_cast<int64_t>(ctx.GetMaxThreadsPerBlock() / 4));
int64_t max_threads = ctx.GetMaxPhysicalThreadCount();  // SM * block_per_SM
const int64_t max_blocks = std::max(((max_threads - 1) / block_size + 1),
                                    static_cast<int64_t>(1));
const int64_t grid_size =
    std::min(max_blocks, (num + block_size - 1) / block_size);
// Launch on the context's stream, exactly like the quant_axis == 0 branch, so
// the kernel is ordered with the other work queued on ctx.stream() instead of
// running on the default stream.
ChannelClipAndQuantKernelQuantAxisN<
    T><<<grid_size, block_size, 0, ctx.stream()>>>(
    in_data, scale_data, bin_cnt, num, in_dims[quant_axis], quant_stride,
    out_data);
}
}
};
......
......@@ -64,18 +64,26 @@ class FrameOp : public framework::OperatorWithKernel {
end_axis = x_rank - 2;
}
PADDLE_ENFORCE_LE(frame_length, seq_length,
platform::errors::InvalidArgument(
"Attribute(frame_length) of FrameOp should be less "
"equal than sequence length, but got (%s) > (%s).",
frame_length, seq_length));
bool contain_unknown_dim = phi::contain_unknown_dim(x_dims);
bool check = ctx->IsRuntime() || !contain_unknown_dim;
if (check) {
PADDLE_ENFORCE_LE(frame_length, seq_length,
platform::errors::InvalidArgument(
"Attribute(frame_length) of FrameOp should be less "
"equal than sequence length, but got (%s) > (%s).",
frame_length, seq_length));
}
// It won't go into for loop when x_rank == 1U.
for (int i = start_axis; i <= end_axis; i++) {
output_shape.push_back(x_dims[i]);
}
n_frames = 1 + (seq_length - frame_length) / hop_length;
if (seq_length == -1) {
n_frames = -1;
} else {
n_frames = 1 + (seq_length - frame_length) / hop_length;
}
if (axis == 0) {
// (n_frames, frame_length, ...)
......
......@@ -98,9 +98,17 @@ REGISTER_OP_CPU_KERNEL(
mean, ops::MeanKernel<paddle::platform::CPUDeviceContext, float>,
ops::MeanKernel<paddle::platform::CPUDeviceContext, double>,
ops::MeanKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
paddle::platform::bfloat16>,
ops::MeanKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::MeanKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
mean_grad, ops::MeanGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::MeanGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::MeanGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
paddle::platform::bfloat16>,
ops::MeanGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::MeanGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
......@@ -102,10 +102,17 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
mean, ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, plat::float16>);
ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
mean_grad,
ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext,
plat::float16>);
paddle::platform::complex<float>>,
ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
......@@ -54,6 +54,7 @@ class OverlapAddOp : public framework::OperatorWithKernel {
std::vector<int64_t> output_shape;
int n_frames;
int frame_length;
int seq_length;
int start_axis;
int end_axis;
......@@ -69,14 +70,22 @@ class OverlapAddOp : public framework::OperatorWithKernel {
end_axis = x_rank - 3;
}
PADDLE_ENFORCE_LE(
hop_length, frame_length,
platform::errors::InvalidArgument(
"Attribute(hop_length) of OverlapAddOp should be less or equal "
"than frame_length, but got hop_length(%s) > frame_length(%s).",
hop_length, frame_length));
bool contain_unknown_dim = phi::contain_unknown_dim(x_dims);
bool check = ctx->IsRuntime() || !contain_unknown_dim;
if (check) {
PADDLE_ENFORCE_LE(
hop_length, frame_length,
platform::errors::InvalidArgument(
"Attribute(hop_length) of OverlapAddOp should be less or equal "
"than frame_length, but got hop_length(%s) > frame_length(%s).",
hop_length, frame_length));
}
const int seq_length = (n_frames - 1) * hop_length + frame_length;
if (n_frames == -1) {
seq_length = -1;
} else {
seq_length = (n_frames - 1) * hop_length + frame_length;
}
// It won't go into for loop when x_rank == 2U.
for (int i = start_axis; i <= end_axis; i++) {
......
......@@ -13,28 +13,7 @@
// limitations under the License.
#include "paddle/fluid/operators/spectral_op.h"
#include <algorithm>
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#if defined(PADDLE_WITH_ONEMKL)
#include "paddle/phi/backends/dynload/mklrt.h"
#elif defined(PADDLE_WITH_POCKETFFT)
#include "extern_pocketfft/pocketfft_hdronly.h"
#endif
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/operators/spectral_helper.h"
namespace paddle {
namespace operators {
......@@ -355,465 +334,6 @@ FFTNormMode get_norm_from_string(const std::string& norm, bool forward) {
norm));
}
// FFT Functors
#if defined(PADDLE_WITH_ONEMKL)
// Evaluates an MKL DFTI call once and throws platform::errors::External with
// MKL's own error message when the returned status is not in the
// DFTI_NO_ERROR class. The expansion intentionally has no trailing semicolon,
// so `MKL_DFTI_CHECK(x);` is exactly one statement and composes with if/else.
#define MKL_DFTI_CHECK(expr)                                                   \
  do {                                                                         \
    MKL_LONG status = (expr);                                                  \
    if (!phi::dynload::DftiErrorClass(status, DFTI_NO_ERROR)) {                \
      PADDLE_THROW(                                                            \
          platform::errors::External(phi::dynload::DftiErrorMessage(status))); \
    }                                                                          \
  } while (0)
namespace {
struct DftiDescriptorDeleter {
void operator()(DFTI_DESCRIPTOR_HANDLE handle) {
if (handle != nullptr) {
MKL_DFTI_CHECK(phi::dynload::DftiFreeDescriptor(&handle));
}
}
};
// A RAII wrapper for DFTI_DESCRIPTOR*: the descriptor is held in a
// std::unique_ptr whose deleter is DftiDescriptorDeleter.
class DftiDescriptor {
public:
// Creates the underlying descriptor with the given precision, signal type,
// number of signal dimensions and per-dimension sizes, and takes ownership
// of it. Raises AlreadyExists if a descriptor is already held.
void init(DFTI_CONFIG_VALUE precision, DFTI_CONFIG_VALUE signal_type,
MKL_LONG signal_ndim, MKL_LONG* sizes) {
PADDLE_ENFORCE_EQ(desc_.get(), nullptr,
platform::errors::AlreadyExists(
"DftiDescriptor has already been initialized."));
DFTI_DESCRIPTOR* raw_desc;
MKL_DFTI_CHECK(phi::dynload::DftiCreateDescriptorX(
&raw_desc, precision, signal_type, signal_ndim, sizes));
desc_.reset(raw_desc);
}
// Returns the raw descriptor pointer without giving up ownership. Raises
// PreconditionNotMet if init() has not stored a descriptor yet.
DFTI_DESCRIPTOR* get() const {
DFTI_DESCRIPTOR* raw_desc = desc_.get();
PADDLE_ENFORCE_NOT_NULL(raw_desc,
platform::errors::PreconditionNotMet(
"DFTI DESCRIPTOR has not been initialized."));
return raw_desc;
}
private:
// Owned descriptor; null until init() succeeds.
std::unique_ptr<DFTI_DESCRIPTOR, DftiDescriptorDeleter> desc_;
};
// Builds, configures and commits an MKL DFTI descriptor for a batched FFT.
//
// in_dtype / out_dtype : data types of the input and output; they select the
//                        precision and the transform type (C2C, R2C or C2R).
// in_strides / out_strides : element 0 is used as the distance between
//                        consecutive transforms, elements 1.. as the strides
//                        of the signal dimensions.
// signal_sizes         : element 0 is the number of transforms (batch size),
//                        elements 1.. are the sizes of the signal dimensions.
// normalization        : FFTNormMode::none leaves the scale untouched;
//                        otherwise a scale of 1/sqrt(n) or 1/n is set, where
//                        n is the product of the signal dimension sizes.
// forward              : direction flag, used only to pick which scale
//                        (forward or backward) receives the normalization.
//
// Returns the committed descriptor. Throws InvalidArgument for an unsupported
// input data type; failures reported by MKL are raised by MKL_DFTI_CHECK.
DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
const framework::proto::VarType::Type& out_dtype,
const framework::DDim& in_strides,
const framework::DDim& out_strides,
const std::vector<int>& signal_sizes,
FFTNormMode normalization, bool forward) {
// Precision follows the input type: FP32/COMPLEX64 are single precision,
// FP64/COMPLEX128 are double precision.
const DFTI_CONFIG_VALUE precision = [&] {
switch (in_dtype) {
case framework::proto::VarType::FP32:
return DFTI_SINGLE;
case framework::proto::VarType::COMPLEX64:
return DFTI_SINGLE;
case framework::proto::VarType::FP64:
return DFTI_DOUBLE;
case framework::proto::VarType::COMPLEX128:
return DFTI_DOUBLE;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Invalid input datatype (%s), input data type should be FP32, "
"FP64, COMPLEX64 or COMPLEX128.",
framework::DataTypeToString(in_dtype)));
}
}();
// Transform type is one of C2C, R2C, C2R; only C2C uses the complex domain.
const FFTTransformType fft_type = GetFFTTransformType(in_dtype, out_dtype);
const DFTI_CONFIG_VALUE domain =
(fft_type == FFTTransformType::C2C) ? DFTI_COMPLEX : DFTI_REAL;
DftiDescriptor descriptor;
// Copy the sizes into MKL_LONG; index 0 is the batch size and is skipped
// when the descriptor is created.
std::vector<MKL_LONG> fft_sizes(signal_sizes.cbegin(), signal_sizes.cend());
const MKL_LONG signal_ndim = fft_sizes.size() - 1;
descriptor.init(precision, domain, signal_ndim, fft_sizes.data() + 1);
// Placement: the transform is configured as out-of-place.
MKL_DFTI_CHECK(phi::dynload::DftiSetValue(descriptor.get(), DFTI_PLACEMENT,
DFTI_NOT_INPLACE));
// Number of transforms computed by one descriptor (the batch size).
const MKL_LONG batch_size = fft_sizes[0];
MKL_DFTI_CHECK(phi::dynload::DftiSetValue(
descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, batch_size));
// Distance between consecutive transforms in the input and in the output.
const MKL_LONG idist = in_strides[0];
const MKL_LONG odist = out_strides[0];
MKL_DFTI_CHECK(
phi::dynload::DftiSetValue(descriptor.get(), DFTI_INPUT_DISTANCE, idist));
MKL_DFTI_CHECK(phi::dynload::DftiSetValue(descriptor.get(),
DFTI_OUTPUT_DISTANCE, odist));
// Input and output strides: element 0 is left at 0, elements 1..signal_ndim
// take the strides of the signal dimensions.
std::vector<MKL_LONG> mkl_in_stride(1 + signal_ndim, 0);
std::vector<MKL_LONG> mkl_out_stride(1 + signal_ndim, 0);
for (MKL_LONG i = 1; i <= signal_ndim; i++) {
mkl_in_stride[i] = in_strides[i];
mkl_out_stride[i] = out_strides[i];
}
MKL_DFTI_CHECK(phi::dynload::DftiSetValue(
descriptor.get(), DFTI_INPUT_STRIDES, mkl_in_stride.data()));
MKL_DFTI_CHECK(phi::dynload::DftiSetValue(
descriptor.get(), DFTI_OUTPUT_STRIDES, mkl_out_stride.data()));
// Conjugate-even storage is set only for the real transforms (R2C / C2R).
if (!(fft_type == FFTTransformType::C2C)) {
MKL_DFTI_CHECK(phi::dynload::DftiSetValue(
descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX));
}
// Total number of elements in one signal (product of the signal sizes).
MKL_LONG signal_numel =
std::accumulate(fft_sizes.cbegin() + 1, fft_sizes.cend(), 1UL,
std::multiplies<MKL_LONG>());
if (normalization != FFTNormMode::none) {
// by_sqrt_n scales by 1/sqrt(n); any other non-none mode scales by 1/n.
const double scale =
((normalization == FFTNormMode::by_sqrt_n)
? 1.0 / std::sqrt(static_cast<double>(signal_numel))
: 1.0 / static_cast<double>(signal_numel));
// R2C and forward C2C set the forward scale; C2R and backward C2C set the
// backward scale.
const auto scale_direction = [&]() {
if (fft_type == FFTTransformType::R2C ||
(fft_type == FFTTransformType::C2C && forward)) {
return DFTI_FORWARD_SCALE;
} else {
// (fft_type == FFTTransformType::C2R ||
// (fft_type == FFTTransformType::C2C && !forward))
return DFTI_BACKWARD_SCALE;
}
}();
MKL_DFTI_CHECK(
phi::dynload::DftiSetValue(descriptor.get(), scale_direction, scale));
}
// Commit the descriptor after all settings have been applied.
MKL_DFTI_CHECK(phi::dynload::DftiCommitDescriptor(descriptor.get()));
return descriptor;
}
// Execute a general fft operation (can be c2c, onesided r2c or onesided c2r)
//
// Steps, as implemented below:
//   1. permute `x` so that the non-transformed (batch) axes come first and
//      the transformed axes follow, in the order given by `axes`;
//   2. collapse all batch axes into a single leading dimension;
//   3. build an MKL DFTI descriptor for the collapsed layout and execute it;
//   4. give the result the permuted output shape and transpose it back into
//      `out` with the inverse permutation.
//
// ctx:           device context used for allocation, transposes and ForRange.
// x:             input tensor; element type is expected to match Ti.
// out:           output tensor; its dims() must already be set, because they
//                size the intermediate output buffer.
// axes:          axes of `x` to transform.
// normalization: forwarded to _plan_mkl_fft.
// forward:       transform direction; also forwarded to _plan_mkl_fft.
template <typename DeviceContext, typename Ti, typename To>
void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out,
const std::vector<int64_t>& axes, FFTNormMode normalization,
bool forward) {
const framework::DDim& in_sizes = x->dims();
const int ndim = in_sizes.size();
const int signal_ndim = axes.size();
const int batch_ndim = ndim - signal_ndim;
const framework::DDim& out_sizes = out->dims();
// make a dim permutation
// Start from the identity, move every non-transformed axis to the front,
// then overwrite the tail with `axes` so the transformed axes keep the
// caller's order.
std::vector<int> dim_permute(ndim);
std::iota(dim_permute.begin(), dim_permute.end(), 0);
std::vector<bool> is_transformed_dim(ndim, false);
for (const auto& d : axes) {
is_transformed_dim[d] = true;
}
const auto batch_end =
std::partition(dim_permute.begin(), dim_permute.end(),
[&](size_t axis) { return !is_transformed_dim[axis]; });
std::copy(axes.cbegin(), axes.cend(), batch_end);
// transpose input according to that permutation
framework::DDim transposed_input_shape = in_sizes.transpose(dim_permute);
std::vector<int64_t> transposed_input_shape_ =
phi::vectorize(transposed_input_shape);
framework::Tensor transposed_input;
transposed_input.Resize(transposed_input_shape);
const auto place = ctx.GetPlace();
transposed_input.mutable_data<Ti>(place);
// NOTE(review): TransCompute is instantiated with platform::CPUDeviceContext
// rather than DeviceContext; presumably this function is only ever
// instantiated for the CPU context - confirm.
TransCompute<platform::CPUDeviceContext, Ti>(ndim, ctx, *x, &transposed_input,
dim_permute);
// make a collapsed input: collapse batch axes for input
// NOTE(review): the product is accumulated starting from `1L` but stored in
// an `int`, so a very large batch would be narrowed - confirm this is
// acceptable.
const int batch_size = std::accumulate(
transposed_input_shape.Get(), transposed_input_shape.Get() + batch_ndim,
1L, std::multiplies<int64_t>());
std::vector<int> collapsed_input_shape_(1 + signal_ndim);
collapsed_input_shape_[0] = batch_size;
std::copy(transposed_input_shape_.begin() + batch_ndim,
transposed_input_shape_.end(), collapsed_input_shape_.begin() + 1);
const framework::DDim collapsed_input_shape =
phi::make_ddim(collapsed_input_shape_);
// The collapsed input is the same tensor object as the transposed input,
// only with its dims replaced by [batch, signal dims...].
transposed_input.Resize(collapsed_input_shape);
framework::Tensor& collapsed_input = transposed_input;
// make a collapsed output
// Signal extents are taken from `out` at the transformed axes.
std::vector<int> collapsed_output_shape_(1 + signal_ndim);
collapsed_output_shape_[0] = batch_size;
for (int i = 0; i < signal_ndim; i++) {
collapsed_output_shape_[1 + i] = out_sizes[axes[i]];
}
const framework::DDim collapsed_output_shape =
phi::make_ddim(collapsed_output_shape_);
framework::Tensor collapsed_output;
collapsed_output.Resize(collapsed_output_shape);
collapsed_output.mutable_data(place, out->type());
// signal sizes
// Entry 0 is the batch size; each following entry is the larger of the input
// and output extent along that signal axis.
std::vector<int> signal_sizes(1 + signal_ndim);
signal_sizes[0] = batch_size;
for (int i = 0; i < signal_ndim; i++) {
signal_sizes[1 + i] =
std::max(collapsed_input_shape[1 + i], collapsed_output_shape[1 + i]);
}
// input & output stride
const framework::DDim input_stride = phi::stride(collapsed_input_shape);
const framework::DDim output_stride = phi::stride(collapsed_output_shape);
// make a DFTI_DESCRIPTOR
DftiDescriptor desc =
_plan_mkl_fft(framework::TransToProtoVarType(x->dtype()),
framework::TransToProtoVarType(out->dtype()), input_stride,
output_stride, signal_sizes, normalization, forward);
const FFTTransformType fft_type =
GetFFTTransformType(framework::TransToProtoVarType(x->dtype()),
framework::TransToProtoVarType(out->type()));
// A "forward" c2r and a "backward" r2c are executed as the opposite DFTI
// call applied to conjugated data: the input is conjugated before a backward
// compute, or the output is conjugated after a forward compute.
if (fft_type == FFTTransformType::C2R && forward) {
framework::Tensor collapsed_input_conj(collapsed_input.dtype());
collapsed_input_conj.mutable_data<Ti>(collapsed_input.dims(),
ctx.GetPlace());
// conjugate the input
platform::ForRange<DeviceContext> for_range(ctx, collapsed_input.numel());
phi::funcs::ConjFunctor<Ti> functor(collapsed_input.data<Ti>(),
collapsed_input.numel(),
collapsed_input_conj.data<Ti>());
for_range(functor);
MKL_DFTI_CHECK(phi::dynload::DftiComputeBackward(
desc.get(), collapsed_input_conj.data(), collapsed_output.data()));
} else if (fft_type == FFTTransformType::R2C && !forward) {
// The transform writes into a scratch tensor; the conjugate of that scratch
// data is then written into collapsed_output.
framework::Tensor collapsed_output_conj(collapsed_output.dtype());
collapsed_output_conj.mutable_data<To>(collapsed_output.dims(),
ctx.GetPlace());
MKL_DFTI_CHECK(phi::dynload::DftiComputeForward(
desc.get(), collapsed_input.data(), collapsed_output_conj.data()));
// conjugate the output
platform::ForRange<DeviceContext> for_range(ctx, collapsed_output.numel());
phi::funcs::ConjFunctor<To> functor(collapsed_output_conj.data<To>(),
collapsed_output.numel(),
collapsed_output.data<To>());
for_range(functor);
} else {
// All remaining combinations map directly onto the matching DFTI call.
if (forward) {
MKL_DFTI_CHECK(phi::dynload::DftiComputeForward(
desc.get(), collapsed_input.data(), collapsed_output.data()));
} else {
MKL_DFTI_CHECK(phi::dynload::DftiComputeBackward(
desc.get(), collapsed_input.data(), collapsed_output.data()));
}
}
// resize for the collapsed output
// Same tensor object, now viewed with the permuted (un-collapsed) out shape.
framework::DDim transposed_output_shape = out_sizes.transpose(dim_permute);
collapsed_output.Resize(transposed_output_shape);
framework::Tensor& transposed_output = collapsed_output;
// reverse the transposition
// reverse_dim_permute is the inverse permutation of dim_permute.
std::vector<int> reverse_dim_permute(ndim);
for (int i = 0; i < ndim; i++) {
reverse_dim_permute[dim_permute[i]] = i;
}
TransCompute<platform::CPUDeviceContext, To>(ndim, ctx, transposed_output,
out, reverse_dim_permute);
}
} // anonymous namespace
// CPU complex-to-complex FFT for the MKL build: every argument is handed
// straight to exec_fft, which performs the transform.
template <typename Ti, typename To>
struct FFTC2CFunctor<platform::CPUDeviceContext, Ti, To> {
  void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x,
                  Tensor* out, const std::vector<int64_t>& axes,
                  FFTNormMode normalization, bool forward) {
    using CpuContext = platform::CPUDeviceContext;
    exec_fft<CpuContext, Ti, To>(ctx, x, out, axes, normalization, forward);
  }
};
// CPU real-to-complex FFT for the MKL build: every argument is handed
// straight to exec_fft, which performs the transform.
template <typename Ti, typename To>
struct FFTR2CFunctor<platform::CPUDeviceContext, Ti, To> {
  void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x,
                  Tensor* out, const std::vector<int64_t>& axes,
                  FFTNormMode normalization, bool forward) {
    using CpuContext = platform::CPUDeviceContext;
    exec_fft<CpuContext, Ti, To>(ctx, x, out, axes, normalization, forward);
  }
};
// CPU complex-to-real FFT for the MKL build.
//
// With a single axis the transform is one exec_fft call. With several axes,
// all but the last are first transformed complex-to-complex into a scratch
// tensor shaped like `x`; the last axis is then transformed complex-to-real
// from that scratch tensor into `out`.
template <typename Ti, typename To>
struct FFTC2RFunctor<platform::CPUDeviceContext, Ti, To> {
  void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x,
                  Tensor* out, const std::vector<int64_t>& axes,
                  FFTNormMode normalization, bool forward) {
    using CpuContext = platform::CPUDeviceContext;

    if (axes.size() <= 1) {
      exec_fft<CpuContext, Ti, To>(ctx, x, out, axes, normalization, forward);
      return;
    }

    // Stage 1: c2c over every axis except the last one.
    const std::vector<int64_t> leading_axes(axes.begin(), axes.end() - 1);
    Tensor intermediate;
    intermediate.mutable_data<Ti>(x->dims(), ctx.GetPlace());
    FFTC2CFunctor<CpuContext, Ti, Ti> complex_stage;
    complex_stage(ctx, x, &intermediate, leading_axes, normalization, forward);

    // Stage 2: c2r over the remaining (last) axis.
    const std::vector<int64_t> trailing_axis{axes.back()};
    exec_fft<CpuContext, Ti, To>(ctx, &intermediate, out, trailing_axis,
                                 normalization, forward);
  }
};
#elif defined(PADDLE_WITH_POCKETFFT)
namespace {
// Returns the scale factor for an FFT over `size` signal elements:
//   FFTNormMode::none      -> 1
//   FFTNormMode::by_n      -> 1 / size
//   FFTNormMode::by_sqrt_n -> 1 / sqrt(size)
// Any other value raises InvalidArgument.
template <typename T>
T compute_factor(int64_t size, FFTNormMode normalization) {
  const T unit = static_cast<T>(1);
  if (normalization == FFTNormMode::none) {
    return unit;
  }
  if (normalization == FFTNormMode::by_n) {
    return unit / static_cast<T>(size);
  }
  if (normalization == FFTNormMode::by_sqrt_n) {
    return unit / std::sqrt(static_cast<T>(size));
  }
  PADDLE_THROW(
      platform::errors::InvalidArgument("Unsupported normalization type"));
}
} // anonymous namespace
// CPU complex-to-complex FFT for the pocketfft build.
//
// The input's shape and strides are reused for the output, so `out` is
// addressed with the same layout as `x`. Element strides are scaled by
// sizeof(std::complex<R>) before being handed to pocketfft::c2c.
template <typename Ti, typename To>
struct FFTC2CFunctor<platform::CPUDeviceContext, Ti, To> {
  void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x,
                  Tensor* out, const std::vector<int64_t>& axes,
                  FFTNormMode normalization, bool forward) {
    using R = typename Ti::value_type;
    using C = std::complex<R>;

    const auto& dims = x->dims();
    const std::vector<size_t> shape = phi::vectorize<size_t>(dims);

    // Element strides -> byte strides.
    std::vector<std::ptrdiff_t> byte_strides =
        phi::vectorize<std::ptrdiff_t>(phi::stride(dims));
    const int64_t elem_bytes = sizeof(C);
    for (auto& stride : byte_strides) {
      stride = stride * elem_bytes;
    }

    const auto* src = reinterpret_cast<const C*>(x->data<Ti>());
    auto* dst = reinterpret_cast<C*>(out->data<To>());

    // pocketfft requires std::vector<size_t>
    std::vector<size_t> fft_axes(axes.size());
    std::copy(axes.begin(), axes.end(), fft_axes.begin());

    // compute the normalization factor from the number of transformed elements
    int64_t signal_numel = 1;
    for (auto axis : axes) {
      signal_numel *= shape[axis];
    }
    R scale = compute_factor<R>(signal_numel, normalization);

    pocketfft::c2c(shape, byte_strides, byte_strides, fft_axes, forward, src,
                   dst, scale);
  }
};
// CPU real-to-complex FFT for the pocketfft build.
//
// pocketfft::r2c is given the shape of the real input; strides for the real
// input and the complex output are computed separately and scaled by
// sizeof(R) and sizeof(std::complex<R>) respectively.
//
// ctx:           unused by this implementation.
// x:             real input tensor with element type Ti.
// out:           complex output tensor; only its dims (for strides) and data
//                pointer are read here.
// axes:          axes of `x` to transform.
// normalization: selects the scale factor via compute_factor.
// forward:       transform direction, forwarded to pocketfft.
template <typename Ti, typename To>
struct FFTR2CFunctor<platform::CPUDeviceContext, Ti, To> {
  void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x,
                  Tensor* out, const std::vector<int64_t>& axes,
                  FFTNormMode normalization, bool forward) {
    using R = Ti;
    using C = std::complex<R>;

    const auto& input_dim = x->dims();
    const std::vector<size_t> in_sizes = phi::vectorize<size_t>(input_dim);
    std::vector<std::ptrdiff_t> in_strides =
        phi::vectorize<std::ptrdiff_t>(phi::stride(input_dim));
    {
      // Element strides -> byte strides for the real input.
      const int64_t data_size = sizeof(R);
      std::transform(in_strides.begin(), in_strides.end(), in_strides.begin(),
                     [&](std::ptrdiff_t s) { return s * data_size; });
    }

    // Only the output strides are needed; the output shape itself is not
    // passed to pocketfft::r2c, so it is not materialized.
    std::vector<std::ptrdiff_t> out_strides =
        phi::vectorize<std::ptrdiff_t>(phi::stride(out->dims()));
    {
      // Element strides -> byte strides for the complex output.
      const int64_t data_size = sizeof(C);
      std::transform(out_strides.begin(), out_strides.end(),
                     out_strides.begin(),
                     [&](std::ptrdiff_t s) { return s * data_size; });
    }

    const auto* in_data = x->data<R>();
    auto* out_data = reinterpret_cast<C*>(out->data<To>());

    // pocketfft requires std::vector<size_t>
    std::vector<size_t> axes_(axes.size());
    std::copy(axes.begin(), axes.end(), axes_.begin());

    // compute normalization factor from the real (input) signal size
    int64_t signal_numel = 1;
    for (auto i : axes) {
      signal_numel *= in_sizes[i];
    }
    R factor = compute_factor<R>(signal_numel, normalization);

    pocketfft::r2c(in_sizes, in_strides, out_strides, axes_, forward, in_data,
                   out_data, factor);
  }
};
// CPU complex-to-real FFT for the pocketfft build.
//
// pocketfft::c2r is given the shape of the real output; strides for the
// complex input and the real output are computed separately and scaled by
// sizeof(std::complex<R>) and sizeof(R) respectively.
//
// ctx:           unused by this implementation.
// x:             complex input tensor; only its dims (for strides) and data
//                pointer are read here.
// out:           real output tensor with element type To.
// axes:          axes to transform.
// normalization: selects the scale factor via compute_factor.
// forward:       transform direction, forwarded to pocketfft.
template <typename Ti, typename To>
struct FFTC2RFunctor<platform::CPUDeviceContext, Ti, To> {
  void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x,
                  Tensor* out, const std::vector<int64_t>& axes,
                  FFTNormMode normalization, bool forward) {
    using R = To;
    using C = std::complex<R>;

    // Only the input strides are needed; the input shape itself is not
    // passed to pocketfft::c2r, so it is not materialized.
    std::vector<std::ptrdiff_t> in_strides =
        phi::vectorize<std::ptrdiff_t>(phi::stride(x->dims()));
    {
      // Element strides -> byte strides for the complex input.
      const int64_t data_size = sizeof(C);
      std::transform(in_strides.begin(), in_strides.end(), in_strides.begin(),
                     [&](std::ptrdiff_t s) { return s * data_size; });
    }

    const auto& output_dim = out->dims();
    const std::vector<size_t> out_sizes = phi::vectorize<size_t>(output_dim);
    std::vector<std::ptrdiff_t> out_strides =
        phi::vectorize<std::ptrdiff_t>(phi::stride(output_dim));
    {
      // Element strides -> byte strides for the real output.
      const int64_t data_size = sizeof(R);
      std::transform(out_strides.begin(), out_strides.end(),
                     out_strides.begin(),
                     [&](std::ptrdiff_t s) { return s * data_size; });
    }

    const auto* in_data = reinterpret_cast<const C*>(x->data<Ti>());
    auto* out_data = out->data<R>();

    // pocketfft requires std::vector<size_t>
    std::vector<size_t> axes_(axes.size());
    std::copy(axes.begin(), axes.end(), axes_.begin());

    // compute normalization factor from the real (output) signal size
    int64_t signal_numel = 1;
    for (auto i : axes) {
      signal_numel *= out_sizes[i];
    }
    R factor = compute_factor<R>(signal_numel, normalization);

    pocketfft::c2r(out_sizes, in_strides, out_strides, axes_, forward, in_data,
                   out_data, factor);
  }
};
#endif
} // namespace operators
} // namespace paddle
......
此差异已折叠。
......@@ -11,8 +11,11 @@
#pragma once
#define NOMINMAX // to use std::min std::max correctly on windows
#include <algorithm>
#include <functional>
#include <iostream>
#include <memory>
#include <numeric>
#include <string>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
......@@ -23,8 +26,10 @@
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/conj_op.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/padding.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "thrust/device_vector.h"
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/stft_op.h"
#include "paddle/fluid/operators/spectral_helper.h"
namespace paddle {
namespace operators {
// Operator definition for `stft`: validates the input and attributes and
// infers the shape of the output.
class StftOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Infers Out's shape from X (rank 2, [N, T]) and the attributes:
  //   Out = [N, onesided ? n_fft / 2 + 1 : n_fft,
  //          1 + (T - n_fft) / hop_length]
  // Raises InvalidArgument when X is not rank 2, when n_fft or hop_length is
  // not positive, or when n_fft exceeds the sequence length.
  void InferShape(framework::InferShapeContext* ctx) const override {
    // Report failures under this operator's own name.
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "stft");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "stft");

    const int n_fft = ctx->Attrs().Get<int>("n_fft");
    const int hop_length = ctx->Attrs().Get<int>("hop_length");
    const bool onesided = ctx->Attrs().Get<bool>("onesided");

    const auto x_dims = ctx->GetInputDim("X");
    const int x_rank = x_dims.size();

    PADDLE_ENFORCE_EQ(
        x_rank, 2,
        platform::errors::InvalidArgument(
            "Input(X) of StftOp should be a tensor with shape [N, T], "
            "but got rank %s.",
            x_rank));
    // A non-positive n_fft would produce a meaningless (zero or negative)
    // frequency dimension below.
    PADDLE_ENFORCE_GT(
        n_fft, 0,
        platform::errors::InvalidArgument(
            "Attribute(n_fft) should be greater than 0, but got %s.", n_fft));
    PADDLE_ENFORCE_GT(
        hop_length, 0,
        platform::errors::InvalidArgument(
            "Attribute(hop_length) should be greater than 0, but got %s.",
            hop_length));

    // NOTE(review): if the sequence dimension can be unknown (-1) when this
    // runs at graph-construction time, the check below would reject it -
    // confirm how callers handle dynamic lengths.
    const int seq_length = x_dims[x_rank - 1];
    PADDLE_ENFORCE_LE(n_fft, seq_length,
                      platform::errors::InvalidArgument(
                          "Attribute(n_fft) should be less equal than "
                          "sequence length, but got (%s) > (%s).",
                          n_fft, seq_length));

    // Computed only after the operands have been validated.
    const int n_frames = 1 + (seq_length - n_fft) / hop_length;

    std::vector<int64_t> output_shape;
    output_shape.push_back(x_dims[0]);
    if (onesided) {
      output_shape.push_back(n_fft / 2 + 1);
    } else {
      output_shape.push_back(n_fft);
    }
    output_shape.push_back(n_frames);

    ctx->SetOutputDim("Out", phi::make_ddim(output_shape));
  }

 protected:
  // The kernel is selected by the data type of input X and the current place.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return framework::OpKernelType(in_dtype, ctx.GetPlace());
  }
};
// Declares the proto of the `stft` operator: one input (X), one output (Out)
// and four attributes (n_fft, hop_length, normalized, onesided). None of the
// attributes is given a default value here.
class StftOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input waveforms with shape (N, T)");
AddOutput("Out",
"The complex STFT output tensor with shape (N, n_fft, "
"num_frames) or (N, n_fft/2 + 1, num_frames)");
// Frame / FFT size.
AddAttr<int>("n_fft", "The number of input samples to perform FFT");
// Step between the starts of consecutive frames.
AddAttr<int>("hop_length", "Number of samples between adjacent frames");
AddAttr<bool>("normalized",
"Control whether to scale the output by 1/sqrt(n_fft)");
AddAttr<bool>("onesided",
"Control whether to return half of the FFT output");
AddComment(R"DOC(
Short-time Fourier transform (STFT).
)DOC");
}
};
// Builds the `stft_grad` op description from a forward `stft` op: it receives
// the forward input X and the gradient of Out, produces the gradient of X,
// and inherits all attributes of the forward op.
template <typename T>
class StftGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("stft_grad");

    // Inputs: forward X and d(Out).
    grad_op->SetInput("X", this->Input("X"));
    const auto out_grad_name = framework::GradVarName("Out");
    grad_op->SetInput(out_grad_name, this->OutputGrad("Out"));

    // Output: d(X).
    const auto x_grad_name = framework::GradVarName("X");
    grad_op->SetOutput(x_grad_name, this->InputGrad("X"));

    grad_op->SetAttrMap(this->Attrs());
  }
};
// Operator definition for `stft_grad`: checks that the required variables are
// present and gives the gradient of X the same dims as X.
class StftGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    const auto dout_name = framework::GradVarName("Out");
    const auto dx_name = framework::GradVarName("X");

    OP_INOUT_CHECK(ctx->HasInput(dout_name), "Input", dout_name, "stft_grad");
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "stft_grad");
    OP_INOUT_CHECK(ctx->HasOutput(dx_name), "Output", dx_name, "stft_grad");

    ctx->ShareDim("X", /*->*/ dx_name);
  }

 protected:
  // The gradient of Out is complex; the kernel is keyed on the matching real
  // type obtained through framework::ToRealType.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    const auto complex_dtype = OperatorWithKernel::IndicateVarDataType(
        ctx, framework::GradVarName("Out"));
    return framework::OpKernelType(framework::ToRealType(complex_dtype),
                                   ctx.GetPlace());
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
// Register the forward operator together with its proto maker and the grad-op
// makers for the two op representations (OpDesc and imperative OpBase).
REGISTER_OPERATOR(stft, ops::StftOp, ops::StftOpMaker,
ops::StftGradOpMaker<paddle::framework::OpDesc>,
ops::StftGradOpMaker<paddle::imperative::OpBase>);
// The backward operator needs no maker of its own.
REGISTER_OPERATOR(stft_grad, ops::StftGradOp);
// CPU kernels for float and double.
REGISTER_OP_CPU_KERNEL(
stft, ops::StftKernel<paddle::platform::CPUDeviceContext, float>,
ops::StftKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
stft_grad, ops::StftGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::StftGradKernel<paddle::platform::CPUDeviceContext, double>);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/spectral_op.cu.h"
#include "paddle/fluid/operators/stft_op.h"
namespace ops = paddle::operators;
// CUDA kernels for the forward and backward stft operators, registered for
// float and double.
REGISTER_OP_CUDA_KERNEL(
stft, ops::StftKernel<paddle::platform::CUDADeviceContext, float>,
ops::StftKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
stft_grad, ops::StftGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::StftGradKernel<paddle::platform::CUDADeviceContext, double>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/frame_op.h"
#include "paddle/fluid/operators/spectral_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// Forward STFT kernel.
template <typename DeviceContext, typename T>
class StftKernel : public framework::OpKernel<T> {
 public:
  /*
    Pipeline:
      X (N, T)
        -> FrameFunctor -> frames (N, n_fft, num_frames)
        -> FFTR2C on axis 1
        -> Out (N, n_fft/2 + 1, num_frames) when onesided,
           otherwise (N, n_fft, num_frames)
  */
  void Compute(const framework::ExecutionContext& ctx) const override {
    using C = paddle::platform::complex<T>;

    const Tensor* x = ctx.Input<Tensor>("X");
    Tensor* out = ctx.Output<Tensor>("Out");
    out->mutable_data<C>(ctx.GetPlace());

    const int n_fft = ctx.Attr<int>("n_fft");
    const int hop_length = ctx.Attr<int>("hop_length");
    const bool normalized = ctx.Attr<bool>("normalized");
    const bool onesided = ctx.Attr<bool>("onesided");

    // The last dimension of Out / X holds the frame count / sample count.
    const int n_frames = out->dims()[out->dims().size() - 1];
    const int seq_length = x->dims()[x->dims().size() - 1];

    auto& dev_ctx = ctx.device_context<DeviceContext>();

    const std::vector<int64_t> axes{1};
    const int64_t fft_axis = axes.back();

    // Frame: same dims as Out, except that the FFT axis has length n_fft.
    framework::DDim frames_dims(out->dims());
    frames_dims.at(fft_axis) = n_fft;
    Tensor frames;
    frames.mutable_data<T>(frames_dims, ctx.GetPlace());
    FrameFunctor<DeviceContext, T>()(dev_ctx, x, &frames, seq_length, n_fft,
                                     n_frames, hop_length, /*is_grad*/ false);

    // FFTR2C
    const FFTNormMode normalization =
        get_norm_from_string(normalized ? "ortho" : "backward", true);
    FFTR2CFunctor<DeviceContext, T, C> r2c;

    if (onesided) {
      r2c(dev_ctx, &frames, out, axes, normalization, true);
      return;
    }

    // Two-sided output: transform into a half-size buffer first, then let
    // fill_conj populate the full-size `out` from it.
    framework::DDim half_dims(out->dims());
    half_dims.at(fft_axis) = out->dims().at(fft_axis) / 2 + 1;
    Tensor half_spectrum;
    half_spectrum.mutable_data<C>(half_dims, ctx.GetPlace());
    r2c(dev_ctx, &frames, &half_spectrum, axes, normalization, true);
    fill_conj<DeviceContext, C>(dev_ctx, &half_spectrum, out, axes);
  }
};
// Backward STFT kernel.
//
// Pipeline: d(Out) (complex) -> [zero-pad along the FFT axis to n_fft when
// onesided] -> FFT C2C with forward == false -> TransComplexToReal ->
// FrameFunctor with is_grad == true -> d(X).
template <typename DeviceContext, typename T>
class StftGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    using C = paddle::platform::complex<T>;
    auto& dev_ctx = ctx.device_context<DeviceContext>();

    const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    dx->mutable_data<T>(ctx.GetPlace());

    const int n_fft = ctx.Attr<int>("n_fft");
    const int hop_length = ctx.Attr<int>("hop_length");
    const bool normalized = ctx.Attr<bool>("normalized");
    const bool onesided = ctx.Attr<bool>("onesided");

    // The last dimension of d(Out) / d(X) holds the frame / sample count.
    const int n_frames = dy->dims()[dy->dims().size() - 1];
    const int seq_length = dx->dims()[dx->dims().size() - 1];

    const std::vector<int64_t> axes{1};
    const int64_t fft_axis = axes.back();

    // Real and complex frame-gradient buffers: dims of d(Out) with the FFT
    // axis set to n_fft.
    framework::DDim frame_dims(dy->dims());
    frame_dims.at(fft_axis) = n_fft;
    Tensor d_frames;
    d_frames.mutable_data<T>(frame_dims, ctx.GetPlace());
    Tensor complex_d_frames;
    complex_d_frames.mutable_data<C>(frame_dims, ctx.GetPlace());

    // dy -> d_frames
    const FFTNormMode normalization =
        get_norm_from_string(normalized ? "ortho" : "backward", true);
    FFTC2CFunctor<DeviceContext, C, C> c2c;

    if (onesided) {
      // Append zeros at the end of the FFT axis so the gradient has n_fft
      // entries along it before the transform.
      Tensor full_dy;
      full_dy.mutable_data<C>(frame_dims, ctx.GetPlace());

      const auto rank = dy->dims().size();
      std::vector<int> pads(rank * 2, 0);
      pads[fft_axis * 2 + 1] = static_cast<int>(full_dy.dims().at(fft_axis) -
                                                dy->dims().at(fft_axis));
      phi::funcs::PaddingFunctor<DeviceContext, C>(
          rank, dev_ctx, pads, static_cast<C>(0), *dy, &full_dy);

      c2c(dev_ctx, &full_dy, &complex_d_frames, axes, normalization, false);
    } else {
      c2c(dev_ctx, dy, &complex_d_frames, axes, normalization, false);
    }

    framework::TransComplexToReal(
        framework::TransToProtoVarType(d_frames.dtype()),
        framework::TransToProtoVarType(complex_d_frames.dtype()),
        complex_d_frames, &d_frames);

    // d_frames -> dx
    FrameFunctor<DeviceContext, T>()(dev_ctx, &d_frames, dx, seq_length, n_fft,
                                     n_frames, hop_length, /*is_grad*/ true);
  }
};
} // namespace operators
} // namespace paddle
......@@ -159,10 +159,8 @@ inline void EmplaceDeviceContext(
cuda_ctx,
platform::errors::InvalidArgument(
"Failed to dynamic_cast dev_ctx into CUDADeviceContext."));
// Note: A trick method to init context, why GetAllocator interface
// needs a stream parameter?
dev_ctx->SetAllocator(memory::allocation::AllocatorFacade::Instance()
.GetAllocator(p, cuda_ctx->stream())
.GetAllocator(p)
.get());
cuda_ctx->PartialInitWithAllocator();
dev_ctx->SetGenerator(
......@@ -517,10 +515,10 @@ CUDAContext::~CUDAContext() {
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : phi::GPUContext(place) {
phi::GPUContext::PartialInitWithoutAllocator();
cuda_stream_.reset(new stream::CUDAStream(phi::GPUContext::stream(), place));
workspace_.reset(new phi::DnnWorkspaceHandle(
memory::allocation::AllocatorFacade::Instance()
.GetAllocator(place, phi::GPUContext::stream())
.get()));
auto& instance = memory::allocation::AllocatorFacade::Instance();
instance.SetDefaultStream(place, phi::GPUContext::stream());
workspace_.reset(
new phi::DnnWorkspaceHandle(instance.GetAllocator(place).get()));
}
CUDADeviceContext::~CUDADeviceContext() = default;
......@@ -618,7 +616,7 @@ phi::DnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
// return workspace_.get();
return phi::DnnWorkspaceHandle(
memory::allocation::AllocatorFacade::Instance()
.GetAllocator(GetPlace(), phi::GPUContext::stream())
.GetAllocator(GetPlace())
.get());
}
return phi::GPUContext::cudnn_workspace_handle();
......
......@@ -118,8 +118,9 @@ float CpuUtilization::GetCpuUtilization() {
float busy_time = (system_kernel_time_end - system_kernel_time_start) +
(system_user_time_end - system_user_time_start);
float idle_time = system_idle_time_end - system_idle_time_start;
cpu_utilization = busy_time / (busy_time + idle_time);
if (busy_time + idle_time != 0) {
cpu_utilization = busy_time / (busy_time + idle_time);
}
#elif defined(__linux__)
float busy_time = (system_tms_end_.tms_utime - system_tms_start_.tms_utime) +
(system_tms_end_.tms_stime - system_tms_start_.tms_stime) +
......@@ -127,7 +128,9 @@ float CpuUtilization::GetCpuUtilization() {
(irq_end_ - irq_start_) + (softirq_end_ - softirq_start_) +
(steal_end_ - steal_start_);
float idle_time = (idle_end_ - idle_start_) + (iowait_end_ - iowait_start_);
cpu_utilization = busy_time / (busy_time + idle_time);
if (busy_time + idle_time != 0) {
cpu_utilization = busy_time / (busy_time + idle_time);
}
#else
LOG(WARNING)
<< "Current System is not supported to get system cpu utilization"
......@@ -148,13 +151,16 @@ float CpuUtilization::GetCpuCurProcessUtilization() {
uint64_t end = FileTimeToUint64(end_);
float busy_time = (process_kernel_time_end - process_kernel_time_start) +
(process_user_time_end - process_user_time_start);
cpu_process_utilization = busy_time / (end - start);
LOG(INFO) << "Process Utilization = " << cpu_process_utilization << std::endl;
if (end - start != 0) {
cpu_process_utilization = busy_time / (end - start);
}
#elif defined(__linux__)
float busy_time =
(process_tms_end_.tms_utime - process_tms_start_.tms_utime) +
(process_tms_end_.tms_stime - process_tms_start_.tms_stime);
cpu_process_utilization = busy_time / (end_ - start_);
if (end_ - start_ != 0) {
cpu_process_utilization = busy_time / (end_ - start_);
}
#else
LOG(WARNING)
<< "Current System is not supported to get process cpu utilization"
......
......@@ -44,6 +44,14 @@ std::unique_ptr<Profiler> Profiler::Create(const ProfilerOptions& options) {
return std::unique_ptr<Profiler>(new Profiler(options));
}
// Reports whether this build was compiled with PADDLE_WITH_CUPTI defined.
bool Profiler::IsCuptiSupported() {
#ifdef PADDLE_WITH_CUPTI
  return true;
#else
  return false;
#endif
}
Profiler::Profiler(const ProfilerOptions& options) {
options_ = options;
std::bitset<32> trace_switch(options_.trace_switch);
......
......@@ -43,6 +43,8 @@ class Profiler {
public:
static std::unique_ptr<Profiler> Create(const ProfilerOptions& options);
static bool IsCuptiSupported();
void Prepare();
void Start();
......
......@@ -18,7 +18,6 @@ limitations under the License. */
#include "glog/logging.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/dynload/cupti.h"
namespace paddle {
namespace platform {
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include <ctime>
#include <string>
#include "paddle/fluid/platform/dynload/cupti.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/os_info.h"
......
此差异已折叠。
......@@ -52,6 +52,12 @@ PyObject* tensor_properties_get_type(TensorObject* self, void* closure) {
EAGER_CATCH_AND_THROW_RETURN_NULL
}
// Property getter: passes self->tensor to egr::egr_utils_api::IsLeafTensor
// and converts the result to a Python object with ToPyObject. `closure` is
// not used. The body is bracketed by the EAGER_TRY /
// EAGER_CATCH_AND_THROW_RETURN_NULL macro pair.
PyObject* tensor_properties_is_leaf(TensorObject* self, void* closure) {
EAGER_TRY
return ToPyObject(egr::egr_utils_api::IsLeafTensor(self->tensor));
EAGER_CATCH_AND_THROW_RETURN_NULL
}
int tensor_properties_set_name(TensorObject* self, PyObject* value,
void* closure) {
EAGER_TRY
......@@ -179,6 +185,7 @@ struct PyGetSetDef variable_properties[] = {
nullptr},
{"dtype", (getter)tensor_properties_get_dtype, nullptr, nullptr, nullptr},
{"type", (getter)tensor_properties_get_type, nullptr, nullptr, nullptr},
{"is_leaf", (getter)tensor_properties_is_leaf, nullptr, nullptr, nullptr},
{nullptr, nullptr, nullptr, nullptr, nullptr}};
} // namespace pybind
......
......@@ -386,46 +386,6 @@ GetVarBaseListFromPyHandle(const py::handle &handle) {
return result;
}
// Cast a numpy array from element type S to element type T.
// When S and T are the same type the input array is returned as-is;
// otherwise a new array is produced by applying static_cast<T> element-wise
// through py::vectorize (this may allocate new memory).
template <class T, class S>
static py::array_t<T> CastNumpyType(py::array_t<S> array) {
  if (std::is_same<T, S>::value) {
    return array;
  }
  // py::vectorize builds the result array itself, so no output buffer is
  // pre-allocated here.
  return py::vectorize([](S s) { return static_cast<T>(s); })(array);
}
// Converts a Python object holding a numpy array of float, double, int32,
// int64 or bool elements into a py::array_t<T>, dispatching on the detected
// element type. Any other object raises InvalidArgument.
template <class T>
static py::array_t<T> CastNumpyArray(const py::object &array) {
  if (py::isinstance<py::array_t<float>>(array)) {
    return CastNumpyType<T>(array.cast<py::array_t<float>>());
  }
  if (py::isinstance<py::array_t<double>>(array)) {
    return CastNumpyType<T>(array.cast<py::array_t<double>>());
  }
  if (py::isinstance<py::array_t<int32_t>>(array)) {
    return CastNumpyType<T>(array.cast<py::array_t<int32_t>>());
  }
  if (py::isinstance<py::array_t<int64_t>>(array)) {
    return CastNumpyType<T>(array.cast<py::array_t<int64_t>>());
  }
  if (py::isinstance<py::array_t<bool>>(array)) {
    return CastNumpyType<T>(array.cast<py::array_t<bool>>());
  }

  PADDLE_THROW(platform::errors::InvalidArgument(
      "Value type error. The assign numpy value allows integer, float, "
      "double and bool, "
      "but received %s.",
      Py_TYPE(array.ptr())->tp_name));
  // can't reach here
  return py::array_t<T>();
}
static imperative::NameVarBaseMap ConvertToNameVarBaseMap(
const PyNameVarBaseMap &map) {
imperative::NameVarBaseMap result;
......@@ -854,27 +814,29 @@ void BindImperative(py::module *m_ptr) {
py::object value = value_obj;
if (self->DataType() == framework::proto::VarType::FP32) {
if (!py::isinstance<py::array_t<float>>(value_obj)) {
value = CastNumpyArray<float>(value_obj);
value = pybind11::detail::CastNumpyArray<float>(value_obj);
}
} else if (self->DataType() ==
framework::proto::VarType::FP64) {
if (!py::isinstance<py::array_t<double>>(value_obj)) {
value = CastNumpyArray<double>(value_obj);
value = pybind11::detail::CastNumpyArray<double>(value_obj);
}
} else if (self->DataType() ==
framework::proto::VarType::INT32) {
if (!py::isinstance<py::array_t<int32_t>>(value_obj)) {
value = CastNumpyArray<int32_t>(value_obj);
value =
pybind11::detail::CastNumpyArray<int32_t>(value_obj);
}
} else if (self->DataType() ==
framework::proto::VarType::INT64) {
if (!py::isinstance<py::array_t<int64_t>>(value_obj)) {
value = CastNumpyArray<int64_t>(value_obj);
value =
pybind11::detail::CastNumpyArray<int64_t>(value_obj);
}
} else if (self->DataType() ==
framework::proto::VarType::BOOL) {
if (!py::isinstance<py::array_t<bool>>(value_obj)) {
value = CastNumpyArray<bool>(value_obj);
value = pybind11::detail::CastNumpyArray<bool>(value_obj);
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
......
......@@ -38,7 +38,15 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"assign", {"X"}},
{"reshape2", {"X", "Shape"}},
{"expand", {"X", "ExpandTimes"}},
{"slice", {"Input", "StartsTensor", "EndsTensor"}},
{"slice",
{"Input", "StartsTensor", "EndsTensor", "StartsTensorList",
"EndsTensorList"}},
{"strided_slice",
{"Input", "StartsTensor", "EndsTensor", "StridesTensor",
"StartsTensorList", "EndsTensorList", "StridesTensorList"}},
{"set_value",
{"Input", "ValueTensor", "StartsTensorList", "EndsTensorList",
"StepsTensorList"}},
{"fake_quantize_dequantize_moving_average_abs_max",
{"X", "InScale", "InAccum", "InState"}},
{"nll_loss", {"X", "Label", "Weight"}},
......@@ -89,6 +97,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"Input", "Label", "Weight", "Bias", "SampleWeight", "CustomDistProbs",
"CustomDistAlias", "CustomDistAliasProbs"}},
{"check_finite_and_unscale", {"X", "Scale", "FloatStatus"}},
{"group_norm", {"X", "Scale", "Bias"}},
};
// NOTE(zhiqiu): Like op_ins_map.
......
......@@ -3322,6 +3322,7 @@ All parameter, weight, gradient are variables in Paddle.
py::class_<paddle::platform::Profiler>(m, "_Profiler")
.def("create", &paddle::platform::Profiler::Create,
py::return_value_policy::take_ownership)
.def("is_cupti_supported", &paddle::platform::Profiler::IsCuptiSupported)
.def("prepare",
[](paddle::platform::Profiler *profiler) {
platform::EnableHostEventRecorder();
......
......@@ -52,6 +52,46 @@ constexpr int NPY_UINT16_ = 4;
constexpr int NPY_COMPLEX64 = 14;
constexpr int NPY_COMPLEX128 = 15;
// Cast a numpy array from element type S to element type T.
// When S and T are the same type the input array is returned as-is;
// otherwise a new array is produced by applying static_cast<T> element-wise
// through py::vectorize (this may allocate new memory).
template <class T, class S>
static py::array_t<T> CastNumpyType(py::array_t<S> array) {
  if (std::is_same<T, S>::value) {
    return array;
  }
  // py::vectorize builds the result array itself, so no output buffer is
  // pre-allocated here.
  return py::vectorize([](S s) { return static_cast<T>(s); })(array);
}
// Convert a Python object holding a numpy array into py::array_t<T>.
// The object is matched against arrays of float, double, int32_t, int64_t and
// bool (in that order) and converted through CastNumpyType<T>. Anything else
// is reported as an InvalidArgument error via PADDLE_THROW, including the
// Python type name of the offending object.
template <class T>
static py::array_t<T> CastNumpyArray(const py::object &array) {
  if (py::isinstance<py::array_t<float>>(array)) {
    return CastNumpyType<T>(array.cast<py::array_t<float>>());
  }
  if (py::isinstance<py::array_t<double>>(array)) {
    return CastNumpyType<T>(array.cast<py::array_t<double>>());
  }
  if (py::isinstance<py::array_t<int32_t>>(array)) {
    return CastNumpyType<T>(array.cast<py::array_t<int32_t>>());
  }
  if (py::isinstance<py::array_t<int64_t>>(array)) {
    return CastNumpyType<T>(array.cast<py::array_t<int64_t>>());
  }
  if (py::isinstance<py::array_t<bool>>(array)) {
    return CastNumpyType<T>(array.cast<py::array_t<bool>>());
  }

  // None of the supported element types matched.
  PADDLE_THROW(paddle::platform::errors::InvalidArgument(
      "Value type error. The assign numpy value allows integer, float, "
      "double and bool, "
      "but received %s.",
      Py_TYPE(array.ptr())->tp_name));
  // can't reach here
  return py::array_t<T>();
}
// Note: Since float16 is not a builtin type in C++, we register
// paddle::platform::float16 as numpy.float16.
// Ref: https://github.com/pybind/pybind11/issues/1776
......
......@@ -15,6 +15,7 @@
#include "paddle/phi/kernels/pad3d_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
......@@ -574,5 +575,13 @@ void Pad3dKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(
pad3d, CPU, ALL_LAYOUT, phi::Pad3dKernel, float, double, int, int64_t) {}
// Register phi::Pad3dKernel under the name "pad3d" for the CPU backend and
// ALL_LAYOUT. The trailing arguments list the element types the kernel is
// registered for: float, double, int, int64_t and the two phi complex types.
// The empty braces are the (unused) extra registration body.
PD_REGISTER_KERNEL(pad3d,
CPU,
ALL_LAYOUT,
phi::Pad3dKernel,
float,
double,
int,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -50,11 +50,15 @@ struct exponential_transform {
HOSTDEVICE inline T operator()(T val) const {
#if defined(__NVCC__) || defined(__HIPCC__)
if (std::is_same<T, double>::value) {
return static_cast<T>(-1.0) / lambda_ * log(val);
} else {
return static_cast<T>(-1.0) / lambda_ * __logf(val);
// Fallback logarithm term used when val is too close to 1: -epsilon/2 is
// substituted for log(val), so this term is never exactly 0.
T log = -std::numeric_limits<T>::epsilon() / 2;
// Only evaluate the real logarithm when val is below 1 - epsilon/2.
if (val < static_cast<T>(1.) - std::numeric_limits<T>::epsilon() / 2) {
if (std::is_same<T, double>::value) {
// NOTE(review): the double branch calls the single-precision logf, which
// presumably loses precision for double inputs. The local variable named
// `log` also hides the log() function in this scope, which may be why logf
// was used -- confirm whether a double-precision logarithm was intended.
log = logf(val);
} else {
// Non-double types use the __logf intrinsic.
log = __logf(val);
}
}
// Scale the (negative) logarithm term by -1 / lambda_.
return static_cast<T>(-1.0) / lambda_ * log;
#else
return static_cast<T>(-1.0) / lambda_ * std::log(static_cast<T>(1.0) - val);
#endif
......@@ -114,13 +118,19 @@ struct normal_transform {
namespace kps = phi::kps;
/*********************** Distribution Function *************************/
template <typename T>
struct uniform_distribution;
template <typename T>
struct normal_distribution;
#if defined(__NVCC__)
// Primary uniform sampler template for the NVCC build: draws one value with
// curand_uniform from the given Philox state and converts it to T.
template <typename T>
struct uniform_distribution {
  // operator() yields a single T per call.
  static constexpr int kReturnsCount = 1;

  __device__ inline T operator()(curandStatePhilox4_32_10_t *state) const {
    const auto sample = curand_uniform(state);
    return static_cast<T>(sample);
  }
};
template <>
struct uniform_distribution<float> {
__device__ inline float4 operator()(curandStatePhilox4_32_10_t *state) const {
......@@ -177,6 +187,14 @@ struct normal_distribution<double> {
};
#else
// Primary uniform sampler template for the non-NVCC (hiprand) build: draws
// one value with hiprand_uniform from the given Philox state and returns it
// as T.
template <typename T>
struct uniform_distribution {
  __device__ inline T operator()(hiprandStatePhilox4_32_10_t *state) const {
    // Convert explicitly to T, mirroring the curand-based variant of this
    // template, instead of relying on an implicit conversion from the
    // hiprand_uniform result.
    return static_cast<T>(hiprand_uniform(state));
  }
  // operator() yields a single T per call.
  static constexpr int kReturnsCount = 1;
};
template <>
struct uniform_distribution<float> {
__device__ inline float4 operator()(
......
此差异已折叠。
......@@ -19,6 +19,7 @@
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
......@@ -585,4 +586,6 @@ PD_REGISTER_KERNEL(pad3d,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册