Unverified commit 859fc01b, authored by W wanghuancoder, committed by GitHub

Support stride2 (#55156)

support stride
Parent: ddc6feab
......@@ -83,7 +83,7 @@ function(kernel_declare TARGET_LIST)
string(
REGEX
MATCH
"(PD_REGISTER_KERNEL|PD_REGISTER_KERNEL_FOR_ALL_DTYPE|PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE)\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z_]*,[ \\\t\r\n]*[A-Z_]*"
"(PD_REGISTER_KERNEL|PD_REGISTER_KERNEL_FOR_ALL_DTYPE|PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE|PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM)\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z_]*,[ \\\t\r\n]*[A-Z_]*"
first_registry
"${kernel_impl}")
set(kernel_declare_id "")
......@@ -115,13 +115,26 @@ function(kernel_declare TARGET_LIST)
string(
REGEX
MATCH
"(PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE)\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z_]*,[ \\\t\r\n]*[A-Z_]*"
"(PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM)\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z_]*,[ \\\t\r\n]*[A-Z_]*"
is_all_backend
"${first_registry}")
if(NOT is_all_backend STREQUAL "")
# parse the registered kernel message
string(
REPLACE "PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM("
"" kernel_msg "${first_registry}")
else()
string(
REGEX
MATCH
"(PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE)\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z_]*,[ \\\t\r\n]*[A-Z_]*"
is_all_backend
"${first_registry}")
# parse the registered kernel message
string(REPLACE "PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(" ""
kernel_msg "${first_registry}")
# parse the registered kernel message
string(REPLACE "PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(" ""
kernel_msg "${first_registry}")
endif()
string(REPLACE "PD_REGISTER_KERNEL(" "" kernel_msg "${kernel_msg}")
string(REPLACE "PD_REGISTER_KERNEL_FOR_ALL_DTYPE(" "" kernel_msg
"${kernel_msg}")
......@@ -146,7 +159,7 @@ function(kernel_declare TARGET_LIST)
string(
REGEX
MATCH
"(PD_REGISTER_KERNEL|PD_REGISTER_KERNEL_FOR_ALL_DTYPE|PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE)\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z_]*,[ \\\t\r\n]*[A-Z_]*"
"(PD_REGISTER_KERNEL|PD_REGISTER_KERNEL_FOR_ALL_DTYPE|PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE|PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM)\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z_]*,[ \\\t\r\n]*[A-Z_]*"
first_registry
"${kernel_impl}")
endif()
......
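For reference, the widened pattern above only adds one more alternative to the macro-name alternation; everything after the macro name is parsed as before. The line below shows a typical registration the same pattern already matches; the new PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM variant is assumed (not shown in this hunk) to follow the same "name, layout, kernel function" shape without the per-backend dtype list.

    // Illustrative only: sign/SignKernel stand in for any kernel registration
    // that the CMake regex has to recognize.
    PD_REGISTER_KERNEL(sign, CPU, ALL_LAYOUT, phi::SignKernel, float, double) {}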
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/collective/reducer.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/backends/device_guard.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/core/flags.h"
......@@ -835,8 +836,18 @@ void EagerReducer::MarkVarReady(const size_t var_index,
const auto length = group.length_[inside_group_index];
if (is_used_var) {
auto *autograd_meta = tensors_[var_index].get_autograd_meta();
auto &grad_tensor =
paddle::Tensor grad_tensor =
static_cast<egr::AutogradMeta *>(autograd_meta)->Grad();
if (grad_tensor.is_dense_tensor()) {
const auto &tensor_impl = grad_tensor.impl();
auto dense_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor_impl);
if (!dense_tensor->meta().is_contiguous()) {
grad_tensor.set_impl(std::make_shared<phi::DenseTensor>(std::move(
paddle::experimental::Trans2Contiguous(*dense_tensor))));
}
}
group_tensor
.ShareDataWith(*(
std::dynamic_pointer_cast<phi::DenseTensor>(grad_tensor.impl())))
......@@ -851,6 +862,17 @@ void EagerReducer::MarkVarReady(const size_t var_index,
if (HasGrad(var_index)) {
VLOG(3) << "Tensor[" << tensors_[var_index].name() << "] has grad";
auto grad_tensor = egr::EagerUtils::mutable_grad(tensors_[var_index]);
if (grad_tensor->is_dense_tensor()) {
const auto &tensor_impl = grad_tensor->impl();
auto dense_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor_impl);
if (!dense_tensor->meta().is_contiguous()) {
grad_tensor->set_impl(std::make_shared<phi::DenseTensor>(std::move(
paddle::experimental::Trans2Contiguous(*dense_tensor))));
}
}
group_tensor
.ShareDataWith(*(std::dynamic_pointer_cast<phi::DenseTensor>(
grad_tensor->impl())))
......
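Both reducer hunks above apply the same guard before ShareDataWith(): if the dense gradient is not contiguous, its impl is first swapped for a contiguous copy. A condensed sketch of that shared pattern (the helper name is illustrative, not part of the diff):

    // Sketch only: replace a non-contiguous dense gradient's impl with a
    // contiguous copy so the reducer can safely share its buffer.
    static void EnsureContiguousGrad(paddle::Tensor* grad) {
      if (!grad->is_dense_tensor()) return;
      auto dense = std::dynamic_pointer_cast<phi::DenseTensor>(grad->impl());
      if (dense && !dense->meta().is_contiguous()) {
        grad->set_impl(std::make_shared<phi::DenseTensor>(
            paddle::experimental::Trans2Contiguous(*dense)));
      }
    }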
......@@ -287,6 +287,7 @@ class EagerVariable final {
auto* framework_tensor = var_.GetMutable<VarType>();
// Construct phi::DenseTensor from egr::EagerVariable
auto tensor_dense = std::dynamic_pointer_cast<VarType>(tensor.impl());
PADDLE_ENFORCE_EQ(
(tensor_dense.get() && tensor_dense),
true,
......@@ -296,6 +297,12 @@ class EagerVariable final {
"treat all kinds of tensor as what they are.",
tensor.name()));
*framework_tensor = *tensor_dense;
if (tensor.is_dense_tensor()) {
dynamic_cast<phi::DenseTensor*>(framework_tensor)
->set_strides(
std::dynamic_pointer_cast<phi::DenseTensor>(tensor_dense)
->strides());
}
}
template <typename VarType>
......
......@@ -350,7 +350,8 @@ cc_library(
selected_rows_utils
data_device_transform
data_type_transform
data_layout_transform)
data_layout_transform
phi)
cc_library(
attribute
......
......@@ -310,6 +310,7 @@ static void RunKernelFunc(
true_out_meta->dtype = calc_out->dtype();
true_out_meta->layout = calc_out->layout();
true_out_meta->offset = calc_out->offset();
true_out_meta->strides = true_out_meta->calc_strides(true_out_meta->dims);
// lod does not need to be reset
// reset holder if needed
if (true_out->Holder() != calc_out->Holder()) {
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_device_transform.h"
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/phi/api/lib/data_transform.h"
namespace paddle {
namespace framework {
......@@ -47,6 +48,13 @@ void TransformData(const phi::KernelKey &expected_kernel_type,
phi::DenseTensor out;
const DataLayout lin = kernel_type_for_var.layout();
const DataLayout lout = expected_kernel_type.layout();
if (NeedTransform2Contiguous(in.meta().is_contiguous())) {
out = paddle::experimental::Trans2Contiguous(in);
transformed = true;
PassTensorData(&out, &in);
}
// do layout transform
if (NeedTransformLayout(lout, lin)) {
#ifdef PADDLE_WITH_MKLDNN
......
......@@ -290,7 +290,9 @@ void CompatMetaTensor::set_dims(const DDim& dims) {
if (var == nullptr) return;
if (var->IsType<phi::DenseTensor>()) {
auto* tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
auto meta = phi::DenseTensorUtils::GetMutableMeta(tensor);
meta->dims = dims;
meta->strides = meta->calc_strides(dims);
} else if (var->IsType<phi::SelectedRows>()) {
var->GetMutable<phi::SelectedRows>()->set_height(dims[0]);
} else if (var->IsType<phi::SparseCooTensor>()) {
......@@ -355,10 +357,12 @@ void CompatMetaTensor::set_layout(DataLayout layout) {
if (var == nullptr) return;
if (var->IsType<phi::DenseTensor>()) {
auto* tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
auto meta = phi::DenseTensorUtils::GetMutableMeta(tensor);
meta->layout = layout;
} else if (var->IsType<phi::SelectedRows>()) {
auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
auto meta = phi::DenseTensorUtils::GetMutableMeta(tensor);
meta->layout = layout;
} else if (var->IsType<phi::SparseCooTensor>()) {
auto* tensor = var->GetMutable<phi::SparseCooTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
......@@ -424,8 +428,10 @@ void CompatMetaTensor::share_dims(const MetaTensor& meta_tensor) {
static_cast<const CompatMetaTensor&>(meta_tensor).GetSelectedRows();
selected_rows->set_rows(input_selected_rows.rows());
selected_rows->set_height(input_selected_rows.height());
phi::DenseTensorUtils::GetMutableMeta(selected_rows->mutable_value())
->dims = input_selected_rows.value().dims();
auto meta =
phi::DenseTensorUtils::GetMutableMeta(selected_rows->mutable_value());
meta->dims = input_selected_rows.value().dims();
meta->strides = meta->calc_strides(meta->dims);
}
}
}
......
......@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/kernels/concat_kernel.h"
namespace paddle {
......@@ -553,6 +554,9 @@ void MultiEncoderXPUFusePass::PrepareQKVWeight(Graph* graph,
qkv_w_int16_t.set_type(q_w_fp32_t.type());
auto* cpu_ctx = static_cast<phi::CPUContext*>(
platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
paddle::experimental::CheckAndTrans2Contiguous(&q_w_fp32_t);
paddle::experimental::CheckAndTrans2Contiguous(&k_w_fp32_t);
paddle::experimental::CheckAndTrans2Contiguous(&v_w_fp32_t);
std::vector<const phi::DenseTensor*> in_tensors{
&q_w_fp32_t, &k_w_fp32_t, &v_w_fp32_t};
phi::ConcatKernel<float>(*cpu_ctx, in_tensors, 0, &qkv_w_int16_t);
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include <vector>
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/backends/xpu/xpu_info.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/assign_kernel.h"
......@@ -31,10 +32,14 @@ void Assign(const phi::DenseTensor& in, phi::DenseTensor* out) {
out->Resize(in.dims());
out->set_type(in.dtype());
out->set_layout(in.layout());
paddle::experimental::CheckAndTrans2Contiguous(
const_cast<phi::DenseTensor*>(&in));
phi::AssignKernel(*cpu_ctx, in, out);
}
void Transpose2D(phi::DenseTensor* in, phi::DenseTensor* out) {
paddle::experimental::CheckAndTrans2Contiguous(in);
auto in_dims = in->dims();
PADDLE_ENFORCE_EQ(
in_dims.size(),
......@@ -108,6 +113,8 @@ void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out) {
auto* cpu_ctx = static_cast<phi::CPUContext*>(
platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
paddle::experimental::CheckAndTrans2Contiguous(in);
phi::DenseTensor fp32_tensor;
phi::DenseTensor* out_ptr = out == nullptr ? &fp32_tensor : out;
out_ptr->Resize(in->dims());
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <string>
#include "gflags/gflags.h"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/library_type.h"
......@@ -25,6 +26,8 @@ limitations under the License. */
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_factory.h"
DECLARE_bool(use_stride_kernel);
namespace paddle {
namespace framework {
......@@ -149,6 +152,10 @@ inline bool NeedTransformBackend(const phi::Backend& type_for_var_backend,
return !backends_are_same_class(type_for_var_backend, expected_backend);
}
inline bool NeedTransform2Contiguous(bool is_contiguous) {
return FLAGS_use_stride_kernel && !is_contiguous;
}
inline bool NeedTransform(const phi::KernelKey& kernel_type_for_var,
const phi::KernelKey& expected_kernel_key,
const phi::DenseTensor& tensor) {
......@@ -157,7 +164,8 @@ inline bool NeedTransform(const phi::KernelKey& kernel_type_for_var,
tensor) ||
NeedTransformDataType(kernel_type_for_var, expected_kernel_key) ||
NeedTransformLayout(kernel_type_for_var.layout(),
expected_kernel_key.layout());
expected_kernel_key.layout()) ||
NeedTransform2Contiguous(tensor.meta().is_contiguous());
}
} // namespace framework
......
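The new NeedTransform2Contiguous() check is purely flag-gated, so the overall NeedTransform() decision changes only when FLAGS_use_stride_kernel is enabled. A minimal sketch of the added condition in isolation (variable names are illustrative):

    // With the stride flag off, the helper always returns false; with it on, a
    // non-contiguous input now forces TransformData() even when backend, dtype
    // and layout of the kernel keys already match.
    bool need_contiguous_copy =
        NeedTransform2Contiguous(tensor.meta().is_contiguous());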
......@@ -287,12 +287,14 @@ void TensorCopy(const phi::DenseTensor& src,
const platform::Place& dst_place,
phi::DenseTensor* dst) {
TensorCopyImpl<phi::DenseTensor>(src, dst_place, dst);
dst->set_strides(src.strides());
}
void TensorCopy(const phi::DenseTensor& src,
const platform::Place& dst_place,
const platform::DeviceContext& ctx,
phi::DenseTensor* dst) {
TensorCopyImpl<phi::DenseTensor>(src, dst_place, ctx, dst);
dst->set_strides(src.strides());
}
void TensorCopySync(const phi::DenseTensor& src,
......@@ -447,6 +449,7 @@ void TensorCopySync(const phi::DenseTensor& src,
"Copy from %s to %s is not supported.", src_place, dst_place));
}
#endif
dst->set_strides(src.strides());
}
void TensorToStream(std::ostream& os,
......
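With the TensorCopy changes above, stride metadata now survives a copy. A small fragment, assuming an enclosing function where src is an existing (possibly non-contiguous) phi::DenseTensor and dst_place is a valid place:

    phi::DenseTensor dst;
    paddle::framework::TensorCopy(src, dst_place, &dst);
    // After this commit the destination keeps the source's stride metadata:
    // dst.strides() == src.strides()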
......@@ -74,7 +74,8 @@ cc_library(
garbage_collector
var_helper
layout_autotune
ops_extra_info)
ops_extra_info
phi)
cc_library(
basic_engine
SRCS basic_engine.cc
......
......@@ -37,6 +37,7 @@
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/device_manager.h"
#endif
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/kernels/elementwise_add_kernel.h"
namespace paddle {
......@@ -166,6 +167,10 @@ void TensorAdd(const VarType& src, VarType* dst) {
phi::DenseTensor* dst_tensor = GetInnerMutableTensor<phi::DenseTensor>(dst);
const phi::DenseTensor& src_tensor = GetInnerTensor<phi::DenseTensor>(src);
paddle::experimental::CheckAndTrans2Contiguous(
const_cast<phi::DenseTensor*>(&src_tensor));
paddle::experimental::CheckAndTrans2Contiguous(dst_tensor);
auto numel = src_tensor.numel();
// FIXME(minqiyang): loss_grad op will pass a zero grad of label
......@@ -298,6 +303,11 @@ void SelectedRowsAddToTensor(const VarType& src, VarType* dst) {
phi::DenseTensor* dst_tensor = GetInnerMutableTensor<phi::DenseTensor>(dst);
const phi::SelectedRows& src_selected_rows =
GetInnerTensor<phi::SelectedRows>(src);
paddle::experimental::CheckAndTrans2Contiguous(
const_cast<phi::SelectedRows*>(&src_selected_rows)->mutable_value());
paddle::experimental::CheckAndTrans2Contiguous(dst_tensor);
auto place = dst_tensor->place();
auto data_type =
framework::TransToProtoVarType(src_selected_rows.value().dtype());
......@@ -345,6 +355,10 @@ void SelectedRowsAddTensor(const VarType& src_selected_rows_var,
GetInnerTensor<phi::SelectedRows>(src_selected_rows_var);
const phi::DenseTensor& src_tensor =
GetInnerTensor<phi::DenseTensor>(src_tensor_var);
paddle::experimental::CheckAndTrans2Contiguous(
const_cast<phi::SelectedRows*>(&src_selected_rows)->mutable_value());
const auto& place = src_tensor.place();
auto data_type = framework::TransToProtoVarType(src_tensor.dtype());
auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
......@@ -402,6 +416,11 @@ std::shared_ptr<ReturnVarType> SelectedRowsMerge(const VarType& src1,
const phi::SelectedRows& src_selected_rows2 =
GetInnerTensor<phi::SelectedRows>(src2);
paddle::experimental::CheckAndTrans2Contiguous(
const_cast<phi::SelectedRows*>(&src_selected_rows1)->mutable_value());
paddle::experimental::CheckAndTrans2Contiguous(
const_cast<phi::SelectedRows*>(&src_selected_rows2)->mutable_value());
auto place = src_selected_rows1.value().place();
auto data_type =
framework::TransToProtoVarType(src_selected_rows1.value().dtype());
......
......@@ -31,8 +31,9 @@ void SetOutDataLayout(std::shared_ptr<VarType> var,
if (var->MutableVar()->IsInitialized()) {
paddle::framework::Variable* tmp_var = var->MutableVar();
auto* out = tmp_var->GetMutable<phi::DenseTensor>();
phi::DenseTensorUtils::GetMutableMeta(static_cast<phi::DenseTensor*>(out))
->layout = layout;
auto meta = phi::DenseTensorUtils::GetMutableMeta(
static_cast<phi::DenseTensor*>(out));
meta->layout = layout;
}
}
}
......
......@@ -18,6 +18,7 @@
#include <unordered_set>
#include <utility>
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/imperative/amp_auto_cast.h"
#include "paddle/fluid/imperative/execution_context.h"
......@@ -29,12 +30,15 @@
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(use_mkldnn);
PHI_DECLARE_string(tracer_mkldnn_ops_on);
PHI_DECLARE_string(tracer_mkldnn_ops_off);
DECLARE_bool(use_stride_kernel);
namespace paddle {
namespace imperative {
......@@ -222,12 +226,14 @@ void Tracer::TraceOpImpl(const std::string& type,
attrs["use_mkldnn"] = !is_off;
}
}
auto op = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
const auto& op_info = op->Info();
auto* attr_checker = op_info.Checker();
if (attr_checker) {
attr_checker->Check(&attrs, true, /*only_check_exist_value=*/true);
}
const auto& extra_attr_checkers =
operators::ExtraInfoUtils::Instance().GetExtraAttrsChecker(type);
for (const auto& checker : extra_attr_checkers) {
......@@ -296,6 +302,7 @@ void Tracer::TraceOpImpl(const std::string& type,
"CustomPlace."));
#endif
}
if (!use_default_attr_map) {
PADDLE_ENFORCE_NOT_NULL(passed_default_attrs_,
paddle::platform::errors::PermissionDenied(
......@@ -400,15 +407,53 @@ void Tracer::TraceOp(const std::string& type,
const std::map<std::string, std::string>& inplace_map) {
VLOG(6) << "Running On Eager TraceOp with use_default_attr_map: "
<< use_default_attr_map;
TraceOpImpl<egr::EagerVariable>(type,
ins,
outs,
attrs,
place,
false,
inplace_map,
default_attrs,
use_default_attr_map);
std::map<phi::DenseTensor*, phi::DenseTensor*> need_backup_inputs2outputs;
if (FLAGS_use_stride_kernel) {
for (auto& iter : inplace_map) {
auto inputs_iter = ins.find(iter.first);
for (size_t i = 0; i < inputs_iter->second.size(); i++) {
auto var = inputs_iter->second[i]->MutableVar();
if (var->IsType<phi::DenseTensor>()) {
auto dense_tensor = var->GetMutable<phi::DenseTensor>();
if (!dense_tensor->meta().is_contiguous()) {
NameTensorMap* tmp_out = const_cast<NameTensorMap*>(&outs);
auto outputs_iter = tmp_out->find(iter.second);
outputs_iter->second[i] = std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName());
need_backup_inputs2outputs[dense_tensor] =
outputs_iter->second[i]
->MutableVar()
->GetMutable<phi::DenseTensor>();
}
}
}
}
TraceOpImpl<egr::EagerVariable>(type,
ins,
outs,
attrs,
place,
false,
{},
default_attrs,
use_default_attr_map);
auto dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place);
for (auto& iter : need_backup_inputs2outputs) {
paddle::experimental::TransStride(dev_ctx, iter.second, iter.first);
}
} else {
TraceOpImpl<egr::EagerVariable>(type,
ins,
outs,
attrs,
place,
false,
inplace_map,
default_attrs,
use_default_attr_map);
}
}
void Tracer::TraceOp(const std::string& type,
......@@ -426,15 +471,40 @@ void Tracer::TraceOp(const std::string& type,
paddle::framework::AttributeMap& attrs,
const std::map<std::string, std::string>& inplace_map) {
VLOG(6) << "Running On Eager TraceOp(less): ";
TraceOpImpl<egr::EagerVariable>(type,
ins,
outs,
attrs,
expected_place_,
false,
inplace_map,
nullptr,
true);
std::map<phi::DenseTensor*, phi::DenseTensor*> need_backup_inputs2outputs;
if (FLAGS_use_stride_kernel) {
for (auto& iter : inplace_map) {
auto inputs_iter = ins.find(iter.first);
for (size_t i = 0; i < inputs_iter->second.size(); i++) {
auto var = inputs_iter->second[i]->MutableVar();
if (var->IsType<phi::DenseTensor>()) {
auto dense_tensor = var->GetMutable<phi::DenseTensor>();
if (!dense_tensor->meta().is_contiguous()) {
NameTensorMap* tmp_out = const_cast<NameTensorMap*>(&outs);
auto outputs_iter = tmp_out->find(iter.second);
outputs_iter->second[i] = std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName());
need_backup_inputs2outputs[dense_tensor] =
outputs_iter->second[i]
->MutableVar()
->GetMutable<phi::DenseTensor>();
}
}
}
}
} else {
TraceOpImpl<egr::EagerVariable>(type,
ins,
outs,
attrs,
expected_place_,
false,
inplace_map,
nullptr,
true);
}
}
void Tracer::SetExpectedPlace(platform::Place place) {
......
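The FLAGS_use_stride_kernel branch in TraceOp redirects each inplace output whose input is non-contiguous to a temporary EagerVariable, runs the op without the inplace_map, and then copies the temporaries back into the original strided buffers. The restore step alone, restated as a fragment (dev_ctx, place and the map name are the ones introduced above):

    auto* dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place);
    for (auto& kv : need_backup_inputs2outputs) {
      // kv.first : original strided input the caller expects to be updated
      // kv.second: contiguous temporary actually written by the kernel
      paddle::experimental::TransStride(dev_ctx, kv.second, kv.first);
    }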
......@@ -31,6 +31,8 @@
#include "paddle/fluid/imperative/layout_autotune.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/phi/core/compat/arg_map_context.h"
DECLARE_bool(use_stride_kernel);
namespace paddle {
namespace imperative {
......@@ -195,7 +197,16 @@ class Tracer {
void DisableLayoutAutoTune() { use_layout_autotune_ = false; }
void EnableLayoutAutoTune() { use_layout_autotune_ = true; }
void EnableLayoutAutoTune() {
use_layout_autotune_ = true;
if (FLAGS_use_stride_kernel) {
LOG(WARNING) << "When the layout_autotune policy is on, Paddle will turn "
"off the Stride policy. This will cause the input and "
"output of the Strided API no longer share memory, which "
"may cause problems with model accuracy.";
FLAGS_use_stride_kernel = false;
}
}
bool UseLayoutAutoTune() {
#if defined(PADDLE_WITH_CUDA)
......
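A short illustrative call sequence for the new EnableLayoutAutoTune() behaviour, assuming a valid paddle::imperative::Tracer* named tracer:

    FLAGS_use_stride_kernel = true;
    tracer->EnableLayoutAutoTune();  // logs the warning and disables the stride policy
    // FLAGS_use_stride_kernel is now false; layout autotune stays enabled.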
......@@ -136,6 +136,10 @@ class FeedOp : public framework::OperatorWithKernel {
meta.dtype = feed_tensor.dtype();
meta.layout = feed_tensor.layout();
meta.lod = feed_tensor.lod();
meta.strides = feed_tensor.strides();
if (meta.strides.size() == -1) {
meta.strides = meta.calc_strides(meta.dims);
}
out_tensor->set_meta(meta);
} else if (feed_item.index() == 1) { // Strings
auto& feed_str = PADDLE_GET_CONST(framework::Strings, feed_item);
......
......@@ -88,8 +88,6 @@ class LoDTensorArray2TensorOp : public framework::OperatorBase {
auto &inx = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>();
auto &out = *scope.FindVar(Output("Out"))->GetMutable<phi::DenseTensor>();
auto &out_inx =
*scope.FindVar(Output("OutIndex"))->GetMutable<phi::DenseTensor>();
const size_t n = inx.size();
PADDLE_ENFORCE_GT(
......@@ -102,23 +100,18 @@ class LoDTensorArray2TensorOp : public framework::OperatorBase {
std::string base_name = Inputs("X")[0];
std::vector<std::string> names;
// get the input tensorarray items' dim in out_inx
auto out_inx_dim = out_inx.dims();
out_inx_dim[0] = inx.size();
out_inx.Resize(out_inx_dim);
int *tmp_index_data = out_inx.mutable_data<int>(platform::CPUPlace());
auto out_dims = inx[0].dims();
size_t out_dim_sum = 0;
for (size_t index = 0; index < inx.size(); index++) {
auto inx_dims = inx[index].dims();
out_dim_sum += inx_dims[axis];
tmp_index_data[index] = inx_dims[axis];
size_t in_zero_dims_size = out_dims.size();
for (size_t i = 1; i < n; i++) {
for (size_t j = 0; j < in_zero_dims_size; j++) {
if (j == static_cast<size_t>(axis)) {
out_dims[axis] += inx[i].dims()[j];
}
}
}
// get input array items' dims
out_dims[axis] = out_dim_sum;
out.Resize(out_dims);
auto vec = phi::vectorize<int>(out_dims);
vec.insert(vec.begin() + axis, inx.size());
out.Resize(phi::make_ddim(vec));
LodTensorArray2LodTensorVector(scope, base_name, Input("X"), &names);
......
......@@ -58,6 +58,7 @@ typedef SSIZE_T ssize_t;
#include "paddle/fluid/framework/python_headers.h"
#include "paddle/fluid/memory/allocation/mmap_allocator.h"
#include "paddle/fluid/pybind/tensor_py.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/tensor_utils.h"
......@@ -67,6 +68,7 @@ typedef SSIZE_T ssize_t;
#endif
PHI_DECLARE_bool(set_to_1d);
DECLARE_bool(use_stride_kernel);
namespace paddle {
namespace pybind {
......@@ -155,6 +157,14 @@ static PyObject* tensor_method_numpy(TensorObject* self,
py_dims[0] = 1;
py_strides[0] = sizeof_dtype * numel;
}
} else if (self->tensor.is_dense_tensor()) {
auto tensor_stride = self->tensor.strides();
for (int i = tensor_dims.size() - 1; i >= 0; --i) {
py_dims[i] = static_cast<size_t>(tensor_dims[i]);
py_strides[i] = sizeof_dtype * tensor_stride[i];
numel *= py_dims[i];
}
} else {
for (int i = tensor_dims.size() - 1; i >= 0; --i) {
py_dims[i] = static_cast<size_t>(tensor_dims[i]);
......@@ -163,18 +173,18 @@ static PyObject* tensor_method_numpy(TensorObject* self,
}
}
PyObject* array = api.PyArray_NewFromDescr_(
api.PyArray_Type_,
api.PyArray_DescrFromType_(numpy_dtype),
py_rank,
py_dims,
py_strides,
nullptr,
pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
nullptr);
if (!self->tensor.impl()->initialized()) {
PyObject* array = api.PyArray_NewFromDescr_(
api.PyArray_Type_,
api.PyArray_DescrFromType_(numpy_dtype),
py_rank,
py_dims,
py_strides,
nullptr,
pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
nullptr);
if (tensor_dims.empty()) {
py_dims[0] = 0;
py_strides[0] = 0;
......@@ -193,6 +203,9 @@ static PyObject* tensor_method_numpy(TensorObject* self,
return array;
}
phi::DenseTensor cpu_tensor;
platform::CPUPlace cpu_place;
if (self->tensor.is_cpu() || self->tensor.is_gpu_pinned()) {
eager_gil_scoped_release guard;
platform::CPUPlace place;
......@@ -202,25 +215,32 @@ static PyObject* tensor_method_numpy(TensorObject* self,
static_cast<phi::SelectedRows*>(self->tensor.impl().get());
auto* dense_tensor =
static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
cpu_tensor.set_meta(dense_tensor->meta());
auto tmp_allocation_ptr =
memory::Alloc(cpu_place, dense_tensor->Holder()->size());
cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
// deep copy
paddle::memory::Copy(
place,
reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data),
place,
dense_tensor->data(),
sizeof_dtype * numel);
paddle::memory::Copy(place,
cpu_tensor.Holder()->ptr(),
place,
dense_tensor->Holder()->ptr(),
dense_tensor->Holder()->size());
} else {
VLOG(6) << "Getting DenseTensor's numpy value";
auto dense_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
cpu_tensor.set_meta(dense_tensor->meta());
auto tmp_allocation_ptr =
memory::Alloc(cpu_place, dense_tensor->Holder()->size());
cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
// deep copy
paddle::memory::Copy(
place,
reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data),
place,
dense_tensor->data(),
sizeof_dtype * numel);
paddle::memory::Copy(place,
cpu_tensor.Holder()->ptr(),
place,
dense_tensor->Holder()->ptr(),
dense_tensor->Holder()->size());
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
......@@ -237,20 +257,28 @@ static PyObject* tensor_method_numpy(TensorObject* self,
static_cast<phi::SelectedRows*>(self->tensor.impl().get());
auto* dense_tensor =
static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
paddle::platform::GpuMemcpySync(
pybind11::detail::array_proxy(array)->data,
dense_tensor->data(),
phi::SizeOf(dense_tensor->dtype()) * dense_tensor->numel(),
kind);
cpu_tensor.set_meta(dense_tensor->meta());
auto tmp_allocation_ptr =
memory::Alloc(cpu_place, dense_tensor->Holder()->size());
cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
paddle::platform::GpuMemcpySync(cpu_tensor.Holder()->ptr(),
dense_tensor->Holder()->ptr(),
dense_tensor->Holder()->size(),
kind);
} else {
VLOG(6) << "Getting DenseTensor's numpy value";
auto dense_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
paddle::platform::GpuMemcpySync(
pybind11::detail::array_proxy(array)->data,
dense_tensor->data(),
phi::SizeOf(dense_tensor->dtype()) * dense_tensor->numel(),
kind);
cpu_tensor.set_meta(dense_tensor->meta());
auto tmp_allocation_ptr =
memory::Alloc(cpu_place, dense_tensor->Holder()->size());
cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
paddle::platform::GpuMemcpySync(cpu_tensor.Holder()->ptr(),
dense_tensor->Holder()->ptr(),
dense_tensor->Holder()->size(),
kind);
}
#endif
#if defined(PADDLE_WITH_XPU)
......@@ -262,22 +290,30 @@ static PyObject* tensor_method_numpy(TensorObject* self,
static_cast<phi::SelectedRows*>(self->tensor.impl().get());
auto* dense_tensor =
static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
paddle::memory::Copy(
place,
reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data),
dense_tensor->place(),
dense_tensor->data(),
sizeof_dtype * numel);
cpu_tensor.set_meta(dense_tensor->meta());
auto tmp_allocation_ptr =
memory::Alloc(cpu_place, dense_tensor->Holder()->size());
cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
paddle::memory::Copy(place,
cpu_tensor.Holder()->ptr(),
dense_tensor->place(),
dense_tensor->Holder()->ptr(),
dense_tensor->Holder()->size());
} else {
VLOG(6) << "Getting DenseTensor's numpy value";
auto dense_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
paddle::memory::Copy(
place,
reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data),
dense_tensor->place(),
dense_tensor->data(),
sizeof_dtype * numel);
cpu_tensor.set_meta(dense_tensor->meta());
auto tmp_allocation_ptr =
memory::Alloc(cpu_place, dense_tensor->Holder()->size());
cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
paddle::memory::Copy(place,
cpu_tensor.Holder()->ptr(),
dense_tensor->place(),
dense_tensor->Holder()->ptr(),
dense_tensor->Holder()->size());
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
......@@ -289,11 +325,15 @@ static PyObject* tensor_method_numpy(TensorObject* self,
static_cast<phi::SelectedRows*>(self->tensor.impl().get());
auto* dense_tensor =
static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
cpu_tensor.set_meta(dense_tensor->meta());
auto tmp_allocation_ptr =
memory::Alloc(cpu_place, dense_tensor->Holder()->size());
cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
phi::DeviceManager::GetDeviceWithPlace(self->tensor.place())
->MemoryCopyD2H(
pybind11::detail::array_proxy(array)->data,
dense_tensor->data(),
phi::SizeOf(dense_tensor->dtype()) * dense_tensor->numel());
->MemoryCopyD2H(cpu_tensor.Holder()->ptr(),
dense_tensor->Holder()->ptr(),
dense_tensor->Holder()->size());
} else {
VLOG(6) << "Getting DenseTensor's numpy value";
auto dense_tensor =
......@@ -306,11 +346,15 @@ static PyObject* tensor_method_numpy(TensorObject* self,
dense_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(temp_tensor.impl());
}
cpu_tensor.set_meta(dense_tensor->meta());
auto tmp_allocation_ptr =
memory::Alloc(cpu_place, dense_tensor->Holder()->size());
cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
phi::DeviceManager::GetDeviceWithPlace(self->tensor.place())
->MemoryCopyD2H(
pybind11::detail::array_proxy(array)->data,
dense_tensor->data(),
phi::SizeOf(dense_tensor->dtype()) * dense_tensor->numel());
->MemoryCopyD2H(cpu_tensor.Holder()->ptr(),
dense_tensor->Holder()->ptr(),
dense_tensor->Holder()->size());
}
#endif
} else {
......@@ -319,6 +363,26 @@ static PyObject* tensor_method_numpy(TensorObject* self,
RETURN_PY_NONE
}
void* array_buffer = cpu_tensor.Holder()->ptr();
size_t array_offset = cpu_tensor.offset();
PyObject* base = ToPyObject(paddle::Tensor(
std::make_shared<phi::DenseTensor>(std::move(cpu_tensor))));
PyObject* array = api.PyArray_NewFromDescr_(
api.PyArray_Type_,
api.PyArray_DescrFromType_(numpy_dtype),
py_rank,
py_dims,
py_strides,
reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(array_buffer) +
array_offset),
pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
nullptr);
api.PyArray_SetBaseObject_(array, base);
return array;
EAGER_CATCH_AND_THROW_RETURN_NULL
}
......@@ -786,6 +850,25 @@ static PyObject* tensor_method_detach(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_method_detach_(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
PADDLE_ENFORCE_EQ(
self->tensor.defined(),
true,
platform::errors::InvalidArgument("Tensor %s has not been initialized!",
self->tensor.name()));
auto autograd_meta = std::make_shared<egr::AutogradMeta>();
autograd_meta->SetPersistable(
egr::EagerUtils::autograd_meta(&(self->tensor))->Persistable());
self->tensor.set_autograd_meta(autograd_meta);
return reinterpret_cast<PyObject*>(self);
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_method_get_underline_tensor(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
......@@ -864,7 +947,8 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
PyObject* _index = PyTuple_GET_ITEM(args, 0);
VLOG(4) << "Call _getitem_index_not_tensor";
std::vector<int> slice_axes, slice_starts, slice_ends, slice_strides,
decrease_axis, none_axes, infer_flags, list_select_idxs;
decrease_axis, none_axes, infer_flags;
std::vector<int64_t> list_select_idxs;
// if index is a list, list_select_flag will be true
bool list_select_flag = false;
// Note(0x45f): Using defined() instead of initialized()
......@@ -985,16 +1069,19 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
// the index is a list
if (list_select_flag) {
eager_gil_scoped_release guard;
auto select_index =
paddle::Tensor(egr::Controller::Instance().GenerateUniqueName());
auto idx_tensor = std::make_shared<phi::DenseTensor>();
select_index.set_impl(idx_tensor);
auto* dev_ctx = platform::DeviceContextPool::Instance().Get(
egr::Controller::Instance().GetExpectedPlace());
paddle::framework::TensorFromVector(
list_select_idxs, *dev_ctx, idx_tensor.get());
framework::AttributeMap attrs = {{"dim", 0}};
out = index_select_ad_func(self->tensor, select_index, 0);
if (FLAGS_use_stride_kernel && list_select_idxs.size() == 1) {
out = index_select_strided_ad_func(self->tensor, list_select_idxs[0], 0);
} else {
auto select_index =
paddle::Tensor(egr::Controller::Instance().GenerateUniqueName());
auto idx_tensor = std::make_shared<phi::DenseTensor>();
select_index.set_impl(idx_tensor);
auto* dev_ctx = platform::DeviceContextPool::Instance().Get(
egr::Controller::Instance().GetExpectedPlace());
paddle::framework::TensorFromVector(
list_select_idxs, *dev_ctx, idx_tensor.get());
out = index_select_ad_func(self->tensor, select_index, 0);
}
}
return ToPyObject(out);
......@@ -1027,11 +1114,10 @@ static PyObject* tensor__getitem_from_offset(TensorObject* self,
const auto& tensor_dims = tensor.dims();
std::vector<size_t> dims(tensor_dims.size());
std::vector<size_t> strides(tensor_dims.size());
std::vector<size_t> stride = phi::vectorize<size_t>(tensor.strides());
size_t numel = 1;
for (int i = tensor_dims.size() - 1; i >= 0; --i) {
strides[i] = numel;
dims[i] = static_cast<size_t>(tensor_dims[i]);
numel *= dims[i];
}
......@@ -1065,7 +1151,7 @@ static PyObject* tensor__getitem_from_offset(TensorObject* self,
index,
i,
dims[i]));
offset += index * strides[i];
offset += index * stride[i];
}
}
#define PD_FOR_EACH_DENSE_TENSOR_DATA_TYPE(_) \
......@@ -1157,7 +1243,8 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
// copies data to cpu place, which reduces performance.
if (parse_index) {
std::vector<int> axes, starts, ends, steps, decrease_axes, none_axes,
infer_flags, list_select_idxs;
infer_flags;
std::vector<int64_t> list_select_idxs;
// if index is a list, list_select_flag will be true
bool list_select_flag = false;
ParseIndexingSlice(self_tensor,
......@@ -1994,6 +2081,62 @@ static PyObject* tensor__grad_ivar(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_method_strides(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
std::vector<int64_t> value;
if (!self->tensor.defined() || !self->tensor.is_dense_tensor()) {
return ToPyObject(value);
}
auto stride = self->tensor.strides();
size_t rank = static_cast<size_t>(stride.size());
value.resize(rank);
for (size_t i = 0; i < rank; i++) {
value[i] = stride[i];
}
return ToPyObject(value);
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_contiguous(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
if (self->tensor.is_dense_tensor()) {
auto dense_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
if (dense_tensor->meta().is_contiguous()) {
Py_INCREF(self);
return reinterpret_cast<PyObject*>(self);
} else {
eager_gil_scoped_release guard;
return ToPyObject(
paddle::Tensor(std::make_shared<phi::DenseTensor>(std::move(
paddle::experimental::Trans2Contiguous(*(dense_tensor.get()))))));
}
} else {
Py_INCREF(self);
return reinterpret_cast<PyObject*>(self);
}
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_is_contiguous(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
if (self->tensor.is_dense_tensor()) {
auto dense_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
return ToPyObject(dense_tensor->meta().is_contiguous());
} else {
return ToPyObject(true);
}
EAGER_CATCH_AND_THROW_RETURN_NULL
}
#if defined(PADDLE_WITH_CUDA)
static PyObject* tensor_method__uva(TensorObject* self,
PyObject* args,
......@@ -2102,6 +2245,10 @@ PyMethodDef variable_methods[] = {
(PyCFunction)(void (*)())tensor_method_detach,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"detach_",
(PyCFunction)(void (*)(void))tensor_method_detach_,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"get_tensor",
(PyCFunction)(void (*)())tensor_method_get_underline_tensor,
METH_VARARGS | METH_KEYWORDS,
......@@ -2265,6 +2412,18 @@ PyMethodDef variable_methods[] = {
(PyCFunction)(void (*)())tensor__grad_ivar,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"contiguous",
(PyCFunction)(void (*)(void))tensor_contiguous,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"is_contiguous",
(PyCFunction)(void (*)(void))tensor_is_contiguous,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"get_strides",
(PyCFunction)(void (*)(void))tensor_method_strides,
METH_VARARGS | METH_KEYWORDS,
NULL},
#if defined(PADDLE_WITH_CUDA)
{"_tensor_uva",
(PyCFunction)(void (*)())tensor_method__uva,
......
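The numpy-conversion path above now derives byte strides from the tensor's element strides instead of assuming a contiguous layout. A small worked example with made-up values: a float32 view of shape [2, 3] obtained by transposing a contiguous [3, 2] tensor has element strides [1, 3], so the exported array gets byte strides [4, 12] rather than the contiguous [12, 4].

    // Worked example of the py_strides computation (values illustrative):
    const int64_t elem_strides[2] = {1, 3};   // transpose view of a [3, 2] tensor
    const size_t sizeof_dtype = sizeof(float);
    size_t py_strides[2];
    for (int i = 0; i < 2; ++i) py_strides[i] = sizeof_dtype * elem_strides[i];
    // py_strides == {4, 12}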
......@@ -97,6 +97,27 @@ PyObject* tensor_properties_get_stop_gradient(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
PyObject* tensor_properties_get_data(TensorObject* self, void* closure) {
EAGER_TRY
return reinterpret_cast<PyObject*>(self);
EAGER_CATCH_AND_THROW_RETURN_NULL
}
int tensor_properties_set_data(TensorObject* self,
PyObject* value,
void* closure) {
EAGER_TRY
auto src = CastPyArg2Tensor(value, 0);
self->tensor = src;
phi::DenseTensor tmp;
auto dense_tensor = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
if (dense_tensor) {
dense_tensor->ShareInplaceVersionCounterWith(tmp);
}
return 0;
EAGER_CATCH_AND_THROW_RETURN_NEG
}
PyObject* tensor_properties_get_grad(TensorObject* self, void* closure) {
EAGER_TRY
VLOG(6) << "Get grad for tensor: " << self->tensor.name();
......@@ -130,6 +151,26 @@ int tensor_properties_set_grad(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NEG
}
int tensor_properties_set_grad_(TensorObject* self,
PyObject* value,
void* closure) {
EAGER_TRY
auto src = CastPyArg2Tensor(value, 0);
PADDLE_ENFORCE(
egr::EagerUtils::IsLeafTensor(self->tensor),
paddle::platform::errors::Fatal("Only leaf Tensor can be set grad."));
paddle::Tensor* grad = egr::EagerUtils::mutable_grad(self->tensor);
PADDLE_ENFORCE(grad != nullptr,
paddle::platform::errors::Fatal(
"Detected NULL grad"
"Please check if you have manually cleared"
"the grad inside autograd_meta"));
*grad = src;
return 0;
EAGER_CATCH_AND_THROW_RETURN_NEG
}
int tensor_properties_set_stop_gradient(TensorObject* self,
PyObject* value,
void* closure) {
......@@ -245,6 +286,43 @@ PyObject* tensor_properties_get_shape(TensorObject* self, void* closure) {
EAGER_CATCH_AND_THROW_RETURN_NULL
}
PyObject* tensor_properties_get_strides(TensorObject* self, void* closure) {
EAGER_TRY
std::vector<int64_t> value;
if (!self->tensor.defined() || !self->tensor.is_dense_tensor()) {
return ToPyObject(value);
}
auto stride = self->tensor.strides();
size_t rank = static_cast<size_t>(stride.size());
value.resize(rank);
for (size_t i = 0; i < rank; i++) {
value[i] = stride[i];
}
return ToPyObject(value);
EAGER_CATCH_AND_THROW_RETURN_NULL
}
PyObject* tensor_properties_get_offset(TensorObject* self, void* closure) {
EAGER_TRY
if (!self->tensor.defined() || !self->tensor.is_dense_tensor()) {
RETURN_PY_NONE;
}
auto dense_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
if (dense_tensor == nullptr) {
RETURN_PY_NONE;
} else {
return ToPyObject(dense_tensor->offset());
}
EAGER_CATCH_AND_THROW_RETURN_NULL
}
PyObject* tensor_properties_get_layout(TensorObject* self, void* closure) {
EAGER_TRY
std::string layout = "";
......@@ -337,11 +415,21 @@ PyObject* tensor_properties_get_grad_fn(TensorObject* self, void* closure) {
}
struct PyGetSetDef variable_properties[] = {
{"data",
(getter)tensor_properties_get_data,
(setter)tensor_properties_set_data,
nullptr,
nullptr},
{"grad",
(getter)tensor_properties_get_grad,
(setter)tensor_properties_set_grad,
nullptr,
nullptr},
{"grad_",
(getter)tensor_properties_get_grad,
(setter)tensor_properties_set_grad_,
nullptr,
nullptr},
{"name",
(getter)tensor_properties_get_name,
(setter)tensor_properties_set_name,
......@@ -359,10 +447,13 @@ struct PyGetSetDef variable_properties[] = {
nullptr},
{"shape", (getter)tensor_properties_get_shape, nullptr, nullptr, nullptr},
{"layout", (getter)tensor_properties_get_layout, nullptr, nullptr, nullptr},
// {"is_leaf", (getter)tensor_properties_get_is_leaf, nullptr,
// nullptr,
// nullptr},
{"strides",
(getter)tensor_properties_get_strides,
nullptr,
nullptr,
nullptr},
{"place", (getter)tensor_properties_get_place, nullptr, nullptr, nullptr},
{"offset", (getter)tensor_properties_get_offset, nullptr, nullptr, nullptr},
{"dist_attr",
(getter)tensor_properties_get_dist_attr,
nullptr,
......
......@@ -810,7 +810,8 @@ void BindImperative(py::module *m_ptr) {
// copies data to cpu place, which reduces performance.
if (parse_index) {
std::vector<int> axes, starts, ends, steps, decrease_axes,
none_axes, infer_flags, list_select_idxs;
none_axes, infer_flags;
std::vector<int64_t> list_select_idxs;
// if index is a list, list_select_flag will be true
bool list_select_flag = false;
ParseIndexingSlice(self_tensor,
......@@ -1008,8 +1009,8 @@ void BindImperative(py::module *m_ptr) {
[](std::shared_ptr<imperative::VarBase> &self, py::handle _index) {
VLOG(4) << "Call _getitem_index_not_tensor";
std::vector<int> slice_axes, slice_starts, slice_ends,
slice_strides, decrease_axis, none_axes, infer_flags,
list_select_idxs;
slice_strides, decrease_axis, none_axes, infer_flags;
std::vector<int64_t> list_select_idxs;
// if index is a list, list_select_flag will be true
bool list_select_flag = false;
auto tensor = self->MutableVar()->GetMutable<phi::DenseTensor>();
......
......@@ -150,7 +150,7 @@ static void ParseIndexingSlice(phi::DenseTensor* tensor,
std::vector<int>* decrease_axis,
std::vector<int>* none_axes,
std::vector<int>* infer_flags,
std::vector<int>* list_select_idxs,
std::vector<int64_t>* list_select_idxs,
bool* list_select_flag) {
// We allow indexing by Integers, Slices, Ellipsis, None, tuples of those
// types, and list of Bool and Integers.
......
......@@ -964,16 +964,19 @@ inline py::array TensorToPyArray(const phi::DenseTensor &tensor,
bool is_xpu_tensor = platform::is_xpu_place(tensor.place());
bool is_custom_device_tensor = platform::is_custom_place(tensor.place());
const auto &tensor_dims = tensor.dims();
auto tensor_dtype = framework::TransToProtoVarType(tensor.dtype());
size_t sizeof_dtype = framework::SizeOfType(tensor_dtype);
size_t sizeof_dtype = phi::SizeOf(tensor.type());
std::vector<size_t> py_dims(tensor_dims.size());
std::vector<size_t> py_strides(tensor_dims.size());
auto rank = tensor_dims.size() == -1 ? 0 : tensor_dims.size();
std::vector<ssize_t> py_dims(rank);
std::vector<ssize_t> py_strides(rank);
size_t numel = 1;
auto tensor_stride = tensor.strides();
for (int i = tensor_dims.size() - 1; i >= 0; --i) {
py_dims[i] = static_cast<size_t>(tensor_dims[i]);
py_strides[i] = sizeof_dtype * numel;
py_strides[i] = sizeof_dtype * tensor_stride[i];
numel *= py_dims[i];
}
......@@ -991,47 +994,52 @@ inline py::array TensorToPyArray(const phi::DenseTensor &tensor,
const_cast<void *>(tensor_buf_ptr),
base);
} else {
py::array py_arr(py::dtype(py_dtype_str.c_str()), py_dims, py_strides);
PADDLE_ENFORCE_EQ(
py_arr.writeable(),
true,
platform::errors::InvalidArgument(
"PyArray is not writable, in which case memory leak "
"or double free would occur"));
PADDLE_ENFORCE_EQ(
py_arr.owndata(),
true,
platform::errors::InvalidArgument(
"PyArray does not own data, in which case memory leak "
"or double free would occur"));
platform::CPUPlace place;
size_t copy_bytes = sizeof_dtype * numel;
paddle::memory::Copy(
place, py_arr.mutable_data(), place, tensor_buf_ptr, copy_bytes);
phi::DenseTensor cpu_tensor;
platform::CPUPlace cpu_place;
cpu_tensor.set_meta(tensor.meta());
auto tmp_allocation_ptr =
memory::Alloc(cpu_place, tensor.Holder()->size());
cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
paddle::memory::Copy(cpu_place,
cpu_tensor.Holder()->ptr(),
cpu_place,
tensor.Holder()->ptr(),
tensor.Holder()->size());
auto data_ptr = cpu_tensor.data();
auto base = py::cast(std::move(cpu_tensor));
auto py_arr = py::array(
py::dtype(py_dtype_str.c_str()), py_dims, py_strides, data_ptr, base);
return py_arr;
}
} else if (is_xpu_tensor) {
#ifdef PADDLE_WITH_XPU
py::array py_arr(py::dtype(py_dtype_str.c_str()), py_dims, py_strides);
PADDLE_ENFORCE_EQ(py_arr.writeable(),
true,
platform::errors::InvalidArgument(
"PyArray is not writable, in which case memory leak "
"or double free would occur"));
PADDLE_ENFORCE_EQ(
py_arr.owndata(),
true,
platform::errors::InvalidArgument(
"PyArray does not own data, in which case memory leak "
"or double free would occur"));
size_t copy_bytes = sizeof_dtype * numel;
auto p = tensor.place();
paddle::memory::Copy(platform::CPUPlace(),
py_arr.mutable_data(),
phi::DenseTensor cpu_tensor;
platform::CPUPlace cpu_place;
cpu_tensor.set_meta(tensor.meta());
auto tmp_allocation_ptr = memory::Alloc(cpu_place, tensor.Holder()->size());
cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
paddle::memory::Copy(cpu_place,
cpu_tensor.Holder()->ptr(),
p,
tensor_buf_ptr,
copy_bytes);
tensor.Holder()->ptr(),
tensor.Holder()->size());
auto data_ptr = cpu_tensor.data();
auto base = py::cast(std::move(cpu_tensor));
auto py_arr = py::array(
py::dtype(py_dtype_str.c_str()), py_dims, py_strides, data_ptr, base);
return py_arr;
#else
PADDLE_THROW(platform::errors::PermissionDenied(
......@@ -1040,27 +1048,30 @@ inline py::array TensorToPyArray(const phi::DenseTensor &tensor,
#endif
} else if (is_gpu_tensor) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
py::array py_arr(py::dtype(py_dtype_str.c_str()), py_dims, py_strides);
PADDLE_ENFORCE_EQ(py_arr.writeable(),
true,
platform::errors::InvalidArgument(
"PyArray is not writable, in which case memory leak "
"or double free would occur"));
PADDLE_ENFORCE_EQ(
py_arr.owndata(),
true,
platform::errors::InvalidArgument(
"PyArray does not own data, in which case memory leak "
"or double free would occur"));
size_t copy_bytes = sizeof_dtype * numel;
auto p = tensor.place();
paddle::memory::Copy(platform::CPUPlace(),
py_arr.mutable_data(),
p,
tensor_buf_ptr,
copy_bytes,
nullptr);
#if defined(PADDLE_WITH_CUDA)
gpuMemcpyKind kind = cudaMemcpyDeviceToHost;
#elif defined(PADDLE_WITH_HIP)
gpuMemcpyKind kind = hipMemcpyDeviceToHost;
#endif
phi::DenseTensor cpu_tensor;
platform::CPUPlace cpu_place;
cpu_tensor.set_meta(tensor.meta());
auto tmp_allocation_ptr = memory::Alloc(cpu_place, tensor.Holder()->size());
cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
paddle::platform::GpuMemcpySync(cpu_tensor.Holder()->ptr(),
tensor.Holder()->ptr(),
tensor.Holder()->size(),
kind);
auto data_ptr = cpu_tensor.data();
auto base = py::cast(std::move(cpu_tensor));
auto py_arr = py::array(
py::dtype(py_dtype_str.c_str()), py_dims, py_strides, data_ptr, base);
return py_arr;
#else
PADDLE_THROW(platform::errors::PermissionDenied(
......@@ -1069,19 +1080,6 @@ inline py::array TensorToPyArray(const phi::DenseTensor &tensor,
#endif
} else if (is_custom_device_tensor) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
py::array py_arr(py::dtype(py_dtype_str.c_str()), py_dims, py_strides);
PADDLE_ENFORCE_EQ(py_arr.writeable(),
true,
platform::errors::InvalidArgument(
"PyArray is not writable, in which case memory leak "
"or double free would occur"));
PADDLE_ENFORCE_EQ(
py_arr.owndata(),
true,
platform::errors::InvalidArgument(
"PyArray does not own data, in which case memory leak "
"or double free would occur"));
// TODO(qili93): temporary for ascend npu performance to be removed along
// with npu_identity op
paddle::Tensor tensor_out(std::make_shared<phi::DenseTensor>());
......@@ -1090,21 +1088,66 @@ inline py::array TensorToPyArray(const phi::DenseTensor &tensor,
tensor_out = npu_identity_ad_func(tensor_in, -1);
auto dense_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor_out.impl());
tensor_buf_ptr = dense_tensor->data();
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto &ctx = *pool.Get(tensor.place());
auto p = dense_tensor->place();
phi::DenseTensor cpu_tensor;
platform::CPUPlace cpu_place;
cpu_tensor.set_meta(dense_tensor->meta());
auto tmp_allocation_ptr =
memory::Alloc(cpu_place, dense_tensor->Holder()->size());
cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
paddle::memory::Copy(
cpu_place,
cpu_tensor.Holder()->ptr(),
p,
dense_tensor->Holder()->ptr(),
dense_tensor->Holder()->size(),
reinterpret_cast<const platform::CustomDeviceContext &>(ctx)
.stream());
ctx.Wait();
auto data_ptr = cpu_tensor.data();
auto base = py::cast(std::move(cpu_tensor));
auto py_arr = py::array(
py::dtype(py_dtype_str.c_str()), py_dims, py_strides, data_ptr, base);
return py_arr;
}
size_t copy_bytes = sizeof_dtype * numel;
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &ctx = *pool.Get(tensor.place());
auto p = tensor.place();
phi::DenseTensor cpu_tensor;
platform::CPUPlace cpu_place;
cpu_tensor.set_meta(tensor.meta());
auto tmp_allocation_ptr = memory::Alloc(cpu_place, tensor.Holder()->size());
cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
paddle::memory::Copy(
platform::CPUPlace(),
py_arr.mutable_data(),
tensor.place(),
tensor_buf_ptr,
copy_bytes,
cpu_place,
cpu_tensor.Holder()->ptr(),
p,
tensor.Holder()->ptr(),
tensor.Holder()->size(),
reinterpret_cast<const platform::CustomDeviceContext &>(ctx).stream());
ctx.Wait();
auto data_ptr = cpu_tensor.data();
auto base = py::cast(std::move(cpu_tensor));
auto py_arr = py::array(
py::dtype(py_dtype_str.c_str()), py_dims, py_strides, data_ptr, base);
return py_arr;
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Cannot use CustomPlace in CPU/GPU/XPU version, "
......
......@@ -177,6 +177,13 @@ class PADDLE_API Tensor final {
*/
std::vector<int64_t> shape() const;
/**
* @brief Return the strides of the Tensor.
*
* @return phi::DDim
*/
const phi::DDim& strides() const;
/**
* @brief Reset the shape of the tensor.
* @note: This method means Reset the shape of the tensor,
......
......@@ -103,7 +103,8 @@ Tensor add_n_impl(const std::vector<Tensor>& x) {
temp_dense_tensots.reserve(x.size());
for (size_t i = 0; i < input_x.size(); ++i) {
if (phi::DenseTensor::classof(x[i].impl().get())) {
temp_dense_tensots.push_back(PrepareData(x[i], kernel.InputAt(0), {}));
temp_dense_tensots.push_back(
PrepareData(x[i], kernel.InputAt(0), {}, false));
input_x[i] = temp_dense_tensots.back().get();
} else {
input_x[i] = x[i].impl().get();
......@@ -167,9 +168,9 @@ void embedding_grad_impl(const Tensor& x,
auto* dev_ctx = GetDeviceContextByBackend(
kernel_result.has_fallback_cpu ? Backend::CPU : kernel_key.backend());
auto input_x = PrepareData(x, kernel.InputAt(0), {});
auto input_weight = PrepareData(weight, kernel.InputAt(1), {});
auto input_out_grad = PrepareData(out_grad, kernel.InputAt(2), {});
auto input_x = PrepareData(x, kernel.InputAt(0), {}, false);
auto input_weight = PrepareData(weight, kernel.InputAt(1), {}, false);
auto input_out_grad = PrepareData(out_grad, kernel.InputAt(2), {}, false);
if (sparse) {
auto* kernel_out = SetSelectedRowsKernelOutput(weight_grad);
......@@ -222,9 +223,9 @@ void embedding_grad_impl(const Tensor& x,
auto* dev_ctx = GetDeviceContextByBackend(
kernel_result.has_fallback_cpu ? Backend::CPU : kernel_key.backend());
auto input_x = PrepareData(x, kernel.InputAt(0), {});
auto input_x = PrepareData(x, kernel.InputAt(0), {}, false);
auto input_weight = TensorToSelectedRows(weight);
auto input_out_grad = PrepareData(out_grad, kernel.InputAt(2), {});
auto input_out_grad = PrepareData(out_grad, kernel.InputAt(2), {}, false);
if (sparse) {
auto* kernel_out = SetSelectedRowsKernelOutput(weight_grad);
......
......@@ -13,6 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/api/lib/api_gen_utils.h"
#include "gflags/gflags.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/strided_copy_kernel.h"
DECLARE_bool(use_stride_kernel);
namespace paddle {
namespace experimental {
......@@ -289,5 +294,186 @@ phi::TensorBase* SetStringsKernelOutput(Tensor* out, TensorType type) {
return out->impl().get();
}
phi::DenseTensor* ProcessStrideBackup(phi::DenseTensor** tensor) {
if (!FLAGS_use_stride_kernel || *tensor == nullptr ||
!(*tensor)->IsInitialized() || (*tensor)->meta().is_contiguous()) {
return nullptr;
} else {
phi::DenseTensor* backup = *tensor;
*tensor = new phi::DenseTensor();
return backup;
}
}
std::vector<phi::DenseTensor*> ProcessStrideBackup(
std::vector<phi::DenseTensor*>* tensor) {
std::vector<phi::DenseTensor*> backup;
backup.reserve(tensor->size());
for (auto& t : *tensor) {
if (!FLAGS_use_stride_kernel || t == nullptr || !t->IsInitialized() ||
t->meta().is_contiguous()) {
backup.emplace_back(nullptr);
} else {
backup.emplace_back(t);
t = new phi::DenseTensor();
}
}
return backup;
}
phi::SelectedRows* ProcessStrideBackup(phi::SelectedRows** tensor) {
return nullptr;
}
template <typename Context>
void TransStride(const Context& dev_ctx,
phi::DenseTensor* from,
phi::DenseTensor* to) {
if (to) {
PD_VISIT_ALL_TYPES(to->dtype(), "StridedCopyKernel", ([&] {
phi::StridedCopyKernel<data_t, Context>(
dev_ctx,
*from,
phi::vectorize<int64_t>(to->dims()),
phi::vectorize<int64_t>(to->strides()),
to->offset(),
to);
}));
delete from;
}
}
template <typename Context>
void TransStride(const Context& dev_ctx,
const std::vector<phi::DenseTensor*>& from,
const std::vector<phi::DenseTensor*>& to) {
for (size_t i = 0; i < to.size(); i++) {
if (to[i]) {
PD_VISIT_ALL_TYPES(to[i]->dtype(), "StridedCopyKernel", ([&] {
phi::StridedCopyKernel<data_t, Context>(
dev_ctx,
*from[i],
phi::vectorize<int64_t>(to[i]->dims()),
phi::vectorize<int64_t>(to[i]->strides()),
to[i]->offset(),
to[i]);
}));
delete from[i];
}
}
}
void TransStride(phi::DeviceContext* dev_ctx,
phi::DenseTensor* from,
phi::DenseTensor* to) {
if (to) {
auto* cpu_ctx = dynamic_cast<phi::CPUContext*>(dev_ctx);
if (cpu_ctx) {
PD_VISIT_ALL_TYPES(to->dtype(), "StridedCopyKernel", ([&] {
phi::StridedCopyKernel<data_t, phi::CPUContext>(
*cpu_ctx,
*from,
phi::vectorize<int64_t>(to->dims()),
phi::vectorize<int64_t>(to->strides()),
to->offset(),
to);
}));
delete from;
return;
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto* gpu_ctx = dynamic_cast<phi::GPUContext*>(dev_ctx);
if (gpu_ctx) {
PD_VISIT_ALL_TYPES(to->dtype(), "StridedCopyKernel", ([&] {
phi::StridedCopyKernel<data_t, phi::GPUContext>(
*gpu_ctx,
*from,
phi::vectorize<int64_t>(to->dims()),
phi::vectorize<int64_t>(to->strides()),
to->offset(),
to);
}));
delete from;
return;
}
#endif
#ifdef PADDLE_WITH_XPU
auto* xpu_ctx = dynamic_cast<phi::XPUContext*>(dev_ctx);
if (xpu_ctx) {
PD_VISIT_ALL_TYPES(to->dtype(), "StridedCopyKernel", ([&] {
phi::StridedCopyKernel<data_t, phi::XPUContext>(
*xpu_ctx,
*from,
phi::vectorize<int64_t>(to->dims()),
phi::vectorize<int64_t>(to->strides()),
to->offset(),
to);
}));
delete from;
return;
}
#endif
}
}
void TransStride(phi::DeviceContext* dev_ctx,
const std::vector<phi::DenseTensor*>& from,
const std::vector<phi::DenseTensor*>& to) {
for (size_t i = 0; i < to.size(); i++) {
if (to[i]) {
auto* cpu_ctx = dynamic_cast<phi::CPUContext*>(dev_ctx);
if (cpu_ctx) {
PD_VISIT_ALL_TYPES(to[i]->dtype(), "StridedCopyKernel", ([&] {
phi::StridedCopyKernel<data_t, phi::CPUContext>(
*cpu_ctx,
*from[i],
phi::vectorize<int64_t>(to[i]->dims()),
phi::vectorize<int64_t>(to[i]->strides()),
to[i]->offset(),
to[i]);
}));
delete from[i];
continue;
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto* gpu_ctx = dynamic_cast<phi::GPUContext*>(dev_ctx);
if (gpu_ctx) {
PD_VISIT_ALL_TYPES(to[i]->dtype(), "StridedCopyKernel", ([&] {
phi::StridedCopyKernel<data_t, phi::GPUContext>(
*gpu_ctx,
*from[i],
phi::vectorize<int64_t>(to[i]->dims()),
phi::vectorize<int64_t>(to[i]->strides()),
to[i]->offset(),
to[i]);
}));
delete from[i];
continue;
}
#endif
#ifdef PADDLE_WITH_XPU
auto* xpu_ctx = dynamic_cast<phi::XPUContext*>(dev_ctx);
if (xpu_ctx) {
PD_VISIT_ALL_TYPES(to[i]->dtype(), "StridedCopyKernel", ([&] {
phi::StridedCopyKernel<data_t, phi::XPUContext>(
*xpu_ctx,
*from[i],
phi::vectorize<int64_t>(to[i]->dims()),
phi::vectorize<int64_t>(to[i]->strides()),
to[i]->offset(),
to[i]);
}));
delete from[i];
continue;
}
#endif
}
}
}
void TransStride(phi::DeviceContext* dev_ctx,
phi::SelectedRows* from,
phi::SelectedRows* to) {}
} // namespace experimental
} // namespace paddle
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/meta_tensor.h"
......@@ -107,5 +108,24 @@ phi::TensorBase* SetSparseKernelOutput(Tensor* out, TensorType type);
phi::TensorBase* SetStringsKernelOutput(Tensor* out, TensorType type);
phi::DenseTensor* ProcessStrideBackup(phi::DenseTensor** tensor);
std::vector<phi::DenseTensor*> ProcessStrideBackup(
std::vector<phi::DenseTensor*>* tensor);
phi::SelectedRows* ProcessStrideBackup(phi::SelectedRows** tensor);
void TransStride(phi::DeviceContext* dev_ctx,
phi::DenseTensor* from,
phi::DenseTensor* to);
void TransStride(phi::DeviceContext* dev_ctx,
const std::vector<phi::DenseTensor*>& from,
const std::vector<phi::DenseTensor*>& to);
void TransStride(phi::DeviceContext* dev_ctx,
phi::SelectedRows* from,
phi::SelectedRows* to);
} // namespace experimental
} // namespace paddle
......@@ -16,13 +16,18 @@ limitations under the License. */
#include "glog/logging.h"
#include "gflags/gflags.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/contiguous_kernel.h"
#include "paddle/phi/kernels/transfer_layout_kernel.h"
DECLARE_bool(use_stride_kernel);
namespace paddle {
namespace experimental {
......@@ -39,6 +44,10 @@ inline bool NeedTransformLayout(const DataLayout& input,
const DataLayout& target,
const phi::Place& place,
const TransformFlag& transform_flag) {
if (FLAGS_use_stride_kernel && target == DataLayout::STRIDED) {
return false;
}
bool ret = transform_flag.need_trans_layout() &&
(input != DataLayout::ALL_LAYOUT &&
target != DataLayout::ALL_LAYOUT && input != target);
......@@ -48,6 +57,11 @@ inline bool NeedTransformLayout(const DataLayout& input,
return ret;
}
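// A contiguous copy is needed only when the stride feature is switched on,
// the chosen kernel is not stride-aware, and the input is not already
// contiguous.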
inline bool NeedTransform2Contiguous(bool is_stride_kernel,
bool is_contiguous) {
return FLAGS_use_stride_kernel && !is_stride_kernel && !is_contiguous;
}
inline phi::DenseTensor TransDataLayout(const phi::DenseTensor& tensor,
DataLayout layout) {
auto& pool = phi::DeviceContextPool::Instance();
......@@ -181,24 +195,83 @@ inline phi::DenseTensor TransDataPlace(const phi::DenseTensor& tensor,
return out;
}
template <typename Context>
phi::DenseTensor TensorContiguous(const Context& dev_ctx,
const phi::DenseTensor& tensor) {
phi::DenseTensor dense_out;
phi::MetaTensor meta_input(tensor);
phi::MetaTensor meta_out(&dense_out);
UnchangedInferMeta(meta_input, &meta_out);
PD_VISIT_ALL_TYPES(tensor.dtype(), "TensorContiguous", ([&] {
phi::ContiguousKernel<data_t, Context>(
dev_ctx, tensor, &dense_out);
}));
return dense_out;
}
phi::DenseTensor Trans2Contiguous(const phi::DenseTensor& tensor) {
auto& pool = paddle::platform::DeviceContextPool::Instance();
VLOG(3) << "Trans2Contiguous...";
if (tensor.place().GetType() == phi::AllocationType::CPU) {
auto* dev_ctx = static_cast<phi::CPUContext*>(pool.Get(tensor.place()));
return TensorContiguous<phi::CPUContext>(*dev_ctx, tensor);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
} else if (tensor.place().GetType() == phi::AllocationType::GPU) {
auto* dev_ctx = static_cast<phi::GPUContext*>(pool.Get(tensor.place()));
return TensorContiguous<phi::GPUContext>(*dev_ctx, tensor);
#endif
#ifdef PADDLE_WITH_XPU
} else if (tensor.place().GetType() == phi::AllocationType::XPU) {
auto* dev_ctx = static_cast<phi::XPUContext*>(pool.Get(tensor.place()));
return TensorContiguous<phi::XPUContext>(*dev_ctx, tensor);
#endif
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Place type is not supported when casting data type."));
}
return tensor;
}
void CheckAndTrans2Contiguous(phi::DenseTensor* tensor) {
if (!tensor->meta().is_contiguous()) {
phi::DenseTensor tmp = Trans2Contiguous(*tensor);
tensor->ShareDataWith(tmp);
}
}
phi::DenseTensor TransformData(phi::DenseTensor* tensor,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag) {
const TransformFlag& transform_flag,
bool is_stride_kernel) {
phi::DenseTensor out = *tensor;
bool trans_layout = false;
bool trans_dtype = false;
if (NeedTransform2Contiguous(is_stride_kernel, out.meta().is_contiguous())) {
out = Trans2Contiguous(out);
}
if (NeedTransformLayout(tensor->layout(),
target_args_def.layout,
tensor->place(),
transform_flag) &&
tensor->dims().size() != 1) {
if (NeedTransform2Contiguous(false, out.meta().is_contiguous())) {
out = Trans2Contiguous(out);
}
out = TransDataLayout(out, target_args_def.layout);
trans_layout = true;
}
if (NeedTransformDataType(
tensor->dtype(), target_args_def.dtype, transform_flag)) {
if (NeedTransform2Contiguous(false, out.meta().is_contiguous())) {
out = Trans2Contiguous(out);
}
out = TransDataType(out, target_args_def.dtype);
trans_dtype = true;
}
......@@ -217,7 +290,8 @@ phi::DenseTensor TransformData(phi::DenseTensor* tensor,
std::shared_ptr<phi::DenseTensor> PrepareData(
const Tensor& input,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag) {
const TransformFlag& transform_flag,
bool is_stride_kernel) {
const auto& tensor_in = input.impl();
if (tensor_in) {
phi::DenseTensor& dense_tensor =
......@@ -230,11 +304,13 @@ std::shared_ptr<phi::DenseTensor> PrepareData(
!NeedTransformLayout(dense_tensor.layout(),
target_args_def.layout,
dense_tensor.place(),
transform_flag))) {
transform_flag) &&
!NeedTransform2Contiguous(is_stride_kernel,
dense_tensor.meta().is_contiguous()))) {
return std::static_pointer_cast<phi::DenseTensor>(tensor_in);
}
phi::DenseTensor out =
TransformData(&dense_tensor, target_args_def, transform_flag);
phi::DenseTensor out = TransformData(
&dense_tensor, target_args_def, transform_flag, is_stride_kernel);
return std::make_shared<phi::DenseTensor>(std::move(out));
}
return nullptr;
......@@ -243,9 +319,11 @@ std::shared_ptr<phi::DenseTensor> PrepareData(
paddle::optional<phi::DenseTensor> PrepareData(
const paddle::optional<Tensor>& input,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag) {
const TransformFlag& transform_flag,
bool is_stride_kernel) {
if (input) {
return {*PrepareData(*input, target_args_def, transform_flag)};
return {*PrepareData(
*input, target_args_def, transform_flag, is_stride_kernel)};
}
return paddle::none;
}
......@@ -253,12 +331,14 @@ paddle::optional<phi::DenseTensor> PrepareData(
std::unique_ptr<std::vector<phi::DenseTensor>> PrepareData(
const std::vector<Tensor>& inputs,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag) {
const TransformFlag& transform_flag,
bool is_stride_kernel) {
auto pt_tensors = std::make_unique<std::vector<phi::DenseTensor>>();
pt_tensors->reserve(inputs.size());
for (const auto& input : inputs) {
const auto& tensor_in = input.impl();
auto dense_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(tensor_in);
if (!transform_flag.NeedTransform() || !tensor_in->initialized() ||
(!NeedTransformPlace(
tensor_in->place(), target_args_def.backend, transform_flag) &&
......@@ -267,14 +347,18 @@ std::unique_ptr<std::vector<phi::DenseTensor>> PrepareData(
!NeedTransformLayout(tensor_in->layout(),
target_args_def.layout,
tensor_in->place(),
transform_flag))) {
transform_flag) &&
!(dense_tensor &&
NeedTransform2Contiguous(is_stride_kernel,
dense_tensor->meta().is_contiguous())))) {
pt_tensors->emplace_back(
*std::dynamic_pointer_cast<phi::DenseTensor>(tensor_in));
} else {
pt_tensors->emplace_back(
TransformData((static_cast<phi::DenseTensor*>(tensor_in.get())),
target_args_def,
transform_flag));
transform_flag,
is_stride_kernel));
}
}
......@@ -284,9 +368,11 @@ std::unique_ptr<std::vector<phi::DenseTensor>> PrepareData(
paddle::optional<std::vector<phi::DenseTensor>> PrepareData(
const paddle::optional<std::vector<Tensor>>& inputs,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag) {
const TransformFlag& transform_flag,
bool is_stride_kernel) {
if (inputs) {
return {*PrepareData(*inputs, target_args_def, transform_flag)};
return {*PrepareData(
*inputs, target_args_def, transform_flag, is_stride_kernel)};
}
return paddle::none;
}
......@@ -299,23 +385,48 @@ std::shared_ptr<phi::SelectedRows> PrepareDataForSelectedRows(
if (tensor_in) {
phi::SelectedRows& selected_rows =
*static_cast<phi::SelectedRows*>(tensor_in.get());
if (!transform_flag.NeedTransform() || !selected_rows.initialized() ||
(!NeedTransformPlace(
selected_rows.place(), target_args_def.backend, transform_flag))) {
if ((!transform_flag.NeedTransform() || !selected_rows.initialized() ||
(!NeedTransformPlace(selected_rows.place(),
target_args_def.backend,
transform_flag))) &&
!NeedTransform2Contiguous(
false, selected_rows.value().meta().is_contiguous())) {
return std::static_pointer_cast<phi::SelectedRows>(tensor_in);
}
auto dense_out = TransDataPlace(
selected_rows.value(), phi::TransToPhiPlace(target_args_def.backend));
if (selected_rows.place().GetType() == AllocationType::GPUPINNED) {
selected_rows.mutable_value()->ShareBufferWith(dense_out);
if (NeedTransform2Contiguous(
false, selected_rows.value().meta().is_contiguous())) {
auto dense_out = Trans2Contiguous(selected_rows.value());
selected_rows.mutable_value()->ShareDataWith(dense_out);
}
if (transform_flag.NeedTransform() && selected_rows.initialized() &&
NeedTransformPlace(
selected_rows.place(), target_args_def.backend, transform_flag)) {
auto dense_out =
TransDataPlace(selected_rows.value(),
phi::TransToPhiPlace(target_args_def.backend));
selected_rows.mutable_value()->ShareBufferWith(dense_out);
}
return std::static_pointer_cast<phi::SelectedRows>(tensor_in);
} else {
auto out_new = std::make_shared<phi::SelectedRows>(
selected_rows.rows(), selected_rows.height());
if (NeedTransform2Contiguous(
false, selected_rows.value().meta().is_contiguous())) {
auto dense_out = Trans2Contiguous(selected_rows.value());
*out_new->mutable_value() = dense_out;
}
if (transform_flag.NeedTransform() && selected_rows.initialized() &&
NeedTransformPlace(
selected_rows.place(), target_args_def.backend, transform_flag)) {
auto dense_out =
TransDataPlace(selected_rows.value(),
phi::TransToPhiPlace(target_args_def.backend));
*out_new->mutable_value() = dense_out;
}
return out_new;
}
auto out_new = std::make_shared<phi::SelectedRows>(selected_rows.rows(),
selected_rows.height());
*out_new->mutable_value() = dense_out;
return out_new;
}
PADDLE_THROW(phi::errors::InvalidArgument(
"The impl() of input tensor is nullptr, it doesn't support for "
......@@ -332,6 +443,105 @@ paddle::optional<phi::SelectedRows> PrepareDataForSelectedRows(
return paddle::none;
}
std::shared_ptr<phi::SparseCooTensor> PrepareDataForSparseCooTensor(
const Tensor& input) {
const auto& tensor_in = input.impl();
if (tensor_in) {
phi::SparseCooTensor& sparse_tensor =
*static_cast<phi::SparseCooTensor*>(tensor_in.get());
if (sparse_tensor.indices().meta().is_contiguous() &&
sparse_tensor.values().meta().is_contiguous()) {
return std::static_pointer_cast<phi::SparseCooTensor>(tensor_in);
}
if (!sparse_tensor.indices().meta().is_contiguous()) {
*sparse_tensor.mutable_indices() =
Trans2Contiguous(sparse_tensor.indices());
}
if (!sparse_tensor.values().meta().is_contiguous()) {
*sparse_tensor.mutable_values() =
Trans2Contiguous(sparse_tensor.values());
}
return std::static_pointer_cast<phi::SparseCooTensor>(tensor_in);
}
PADDLE_THROW(phi::errors::InvalidArgument(
"The impl() of input tensor is nullptr, it doesn't support for "
"SparseCooTensor data transform now."));
}
paddle::optional<phi::SparseCooTensor> PrepareDataForSparseCooTensor(
const paddle::optional<Tensor>& input) {
if (input) {
return *PrepareDataForSparseCooTensor(*input);
}
return paddle::none;
}
std::shared_ptr<phi::SparseCsrTensor> PrepareDataForSparseCsrTensor(
const Tensor& input) {
const auto& tensor_in = input.impl();
if (tensor_in) {
phi::SparseCsrTensor& sparse_tensor =
*static_cast<phi::SparseCsrTensor*>(tensor_in.get());
if (sparse_tensor.crows().meta().is_contiguous() &&
sparse_tensor.cols().meta().is_contiguous() &&
sparse_tensor.values().meta().is_contiguous()) {
return std::static_pointer_cast<phi::SparseCsrTensor>(tensor_in);
}
if (!sparse_tensor.crows().meta().is_contiguous()) {
*sparse_tensor.mutable_crows() = Trans2Contiguous(sparse_tensor.crows());
}
if (!sparse_tensor.cols().meta().is_contiguous()) {
*sparse_tensor.mutable_cols() = Trans2Contiguous(sparse_tensor.cols());
}
if (!sparse_tensor.values().meta().is_contiguous()) {
*sparse_tensor.mutable_values() =
Trans2Contiguous(sparse_tensor.values());
}
return std::static_pointer_cast<phi::SparseCsrTensor>(tensor_in);
}
PADDLE_THROW(phi::errors::InvalidArgument(
"The impl() of input tensor is nullptr, it doesn't support for "
"SparseCsrTensor data transform now."));
}
paddle::optional<phi::SparseCsrTensor> PrepareDataForSparseCsrTensor(
const paddle::optional<Tensor>& input) {
if (input) {
return *PrepareDataForSparseCsrTensor(*input);
}
return paddle::none;
}
std::shared_ptr<phi::DenseTensor> PrepareDataForDenseTensorInSparse(
const Tensor& input) {
const auto& tensor_in = input.impl();
if (tensor_in) {
phi::DenseTensor& dense_tensor =
*static_cast<phi::DenseTensor*>(tensor_in.get());
if (dense_tensor.meta().is_contiguous()) {
return std::static_pointer_cast<phi::DenseTensor>(tensor_in);
}
return std::make_shared<phi::DenseTensor>(
std::move(Trans2Contiguous(dense_tensor)));
}
PADDLE_THROW(phi::errors::InvalidArgument(
"The impl() of input tensor is nullptr, it doesn't support for "
"DenseTensor data transform now."));
}
paddle::optional<phi::DenseTensor> PrepareDataForDenseTensorInSparse(
const paddle::optional<Tensor>& input) {
if (input) {
return *PrepareDataForDenseTensorInSparse(*input);
}
return paddle::none;
}
void TransDataBackend(const phi::DenseTensor* tensor,
Backend target_backend,
phi::DenseTensor* out) {
......
......@@ -17,6 +17,8 @@ limitations under the License. */
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
namespace paddle {
namespace experimental {
......@@ -78,22 +80,26 @@ static inline phi::TensorArgDef GetKernelInputArgDef(
std::shared_ptr<phi::DenseTensor> PrepareData(
const Tensor& input,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag);
const TransformFlag& transform_flag,
bool is_stride_kernel);
paddle::optional<phi::DenseTensor> PrepareData(
const paddle::optional<Tensor>& input,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag);
const TransformFlag& transform_flag,
bool is_stride_kernel);
std::unique_ptr<std::vector<phi::DenseTensor>> PrepareData(
const std::vector<Tensor>& inputs,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag);
const TransformFlag& transform_flag,
bool is_stride_kernel);
paddle::optional<std::vector<phi::DenseTensor>> PrepareData(
const paddle::optional<std::vector<Tensor>>& inputs,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag);
const TransformFlag& transform_flag,
bool is_stride_kernel);
// Only support transferring place for SelectedRows
std::shared_ptr<phi::SelectedRows> PrepareDataForSelectedRows(
......@@ -106,6 +112,27 @@ paddle::optional<phi::SelectedRows> PrepareDataForSelectedRows(
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag);
// Only support transforming to contiguous for SparseCooTensor
std::shared_ptr<phi::SparseCooTensor> PrepareDataForSparseCooTensor(
const Tensor& input);
paddle::optional<phi::SparseCooTensor> PrepareDataForSparseCooTensor(
const paddle::optional<Tensor>& input);
// Only support transforming to contiguous for SparseCsrTensor
std::shared_ptr<phi::SparseCsrTensor> PrepareDataForSparseCsrTensor(
const Tensor& input);
paddle::optional<phi::SparseCsrTensor> PrepareDataForSparseCsrTensor(
const paddle::optional<Tensor>& input);
// Only support transforming to contiguous for DenseTensor
std::shared_ptr<phi::DenseTensor> PrepareDataForDenseTensorInSparse(
const Tensor& input);
paddle::optional<phi::DenseTensor> PrepareDataForDenseTensorInSparse(
const paddle::optional<Tensor>& input);
void TransDataBackend(const phi::DenseTensor* tensor,
Backend target_backend,
phi::DenseTensor* out);
......@@ -118,6 +145,9 @@ void TransDataBackend(const phi::SelectedRows* tensor,
Backend target_backend,
phi::SelectedRows* out);
phi::DenseTensor Trans2Contiguous(const phi::DenseTensor& tensor);
void CheckAndTrans2Contiguous(phi::DenseTensor* tensor);
inline bool NeedTransformPlace(const phi::Place& src_place,
const Backend& target,
const TransformFlag& transform_flag) {
......
......@@ -104,6 +104,15 @@ std::vector<int64_t> Tensor::shape() const {
return phi::vectorize<int64_t>(dims);
}
const phi::DDim &Tensor::strides() const {
if (is_dense_tensor()) {
return static_cast<phi::DenseTensor *>(impl_.get())->strides();
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Only support strides operation on DenseTensor now."));
}
}
void Tensor::reshape(const std::vector<int64_t> &shape) {
LOG_FIRST_N(WARNING, 1)
<< "The function of resetting the shape of the uninitialized "
......
......@@ -104,6 +104,16 @@
output : Tensor(x_grad)
invoke : as_complex(out_grad)
- backward_op : as_strided_grad
forward : as_strided (Tensor input, int64_t[] dims = {}, int64_t[] stride = {}, int64_t offset = 0) -> Tensor(out)
args : (Tensor input, Tensor out_grad, int64_t[] dims = {}, int64_t[] stride = {}, int64_t offset = 0)
output : Tensor(input_grad)
infer_meta :
func : StridedUnChangedInferMeta
param : [input]
kernel :
func : as_strided_grad
- backward_op : asin_grad
forward : asin (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
......@@ -1121,6 +1131,18 @@
data_type : out_grad
no_need_buffer : x
- backward_op : index_select_strided_grad
forward : index_select_strided(Tensor x, int64_t index, int axis) -> Tensor(out)
args : (Tensor x, Tensor out_grad, int64_t index, int axis)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : index_select_strided_grad
data_type : out_grad
no_need_buffer : x
- backward_op : instance_norm_double_grad
forward : instance_norm_grad(Tensor x, Tensor fwd_scale, Tensor saved_mean, Tensor saved_variance, Tensor grad_y, float epsilon) -> Tensor(grad_x), Tensor(grad_scale), Tensor(grad_bias)
args : (Tensor x, Tensor fwd_scale, Tensor saved_mean, Tensor saved_variance, Tensor grad_y, Tensor grad_x_grad, Tensor grad_scale_grad, Tensor grad_bias_grad, float epsilon)
......@@ -2335,6 +2357,16 @@
func : temporal_shift_grad
data_type : out_grad
- backward_op : tensor_unfold_grad
forward : tensor_unfold (Tensor input, int64_t axis, int64_t size, int64_t step) -> Tensor(out)
args : (Tensor input, Tensor out_grad, int64_t axis, int64_t size, int64_t step)
output : Tensor(input_grad)
infer_meta :
func : StridedUnChangedInferMeta
param : [input]
kernel :
func : tensor_unfold_grad
- backward_op : thresholded_relu_grad
forward : thresholded_relu (Tensor x, float threshold) -> Tensor(out)
args : (Tensor x, Tensor out_grad, float threshold)
......@@ -2463,6 +2495,27 @@
kernel :
func : unstack_grad
- backward_op : view_dtype_grad
forward : view_dtype (Tensor input, DataType dtype) -> Tensor(out)
args : (Tensor input, Tensor out_grad, DataType dtype)
output : Tensor(input_grad)
infer_meta :
func : StridedUnChangedInferMeta
param : [input]
kernel :
func : view_dtype_grad
data_type : out_grad
- backward_op : view_shape_grad
forward : view_shape (Tensor input, int64_t[] dims = {}) -> Tensor(out)
args : (Tensor input, Tensor out_grad, int64_t[] dims = {})
output : Tensor(input_grad)
infer_meta :
func : StridedUnChangedInferMeta
param : [input]
kernel :
func : view_shape_grad
- backward_op : warpctc_grad
forward : warpctc (Tensor logits, Tensor label, Tensor logits_length, Tensor labels_length, int blank = 0, bool norm_by_times = false) -> Tensor(loss), Tensor(warpctcgrad)
args : (Tensor logits, Tensor logits_length, Tensor warpctcgrad, Tensor loss_grad, int blank, bool norm_by_times)
......
......@@ -666,6 +666,7 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
+ out_name.replace('kernel_', PREFIX_META_TENSOR_NAME)
+ "("
+ out_name
+ ", kernel_result.is_stride_kernel"
+ ");\n"
)
if len(kernel_output_names) == 1:
......@@ -709,7 +710,7 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
input_tensor_code = (
input_tensor_code
+ f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, GetKernelInputArgDef(kernel.InputAt({kernel_param.index(input_name)}), kernel_backend), {trans_flag});"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, GetKernelInputArgDef(kernel.InputAt({kernel_param.index(input_name)}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel);"""
)
return input_tensor_code
......@@ -761,7 +762,7 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
input_tensor_code = (
input_tensor_code
+ f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, GetKernelInputArgDef(kernel.InputAt({kernel_param.index(input_name)}), kernel_backend), {trans_flag});
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, GetKernelInputArgDef(kernel.InputAt({kernel_param.index(input_name)}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel);
{code_indent} paddle::optional<std::vector<const phi::DenseTensor*>> {PREFIX_TENSOR_NAME}{input_name};
{code_indent} if ({PREFIX_TENSOR_NAME}{input_name}_vec){{
{code_indent} {PREFIX_TENSOR_NAME}{input_name} = paddle::optional<std::vector<const phi::DenseTensor*>>({PREFIX_TENSOR_NAME}{input_name}_vec->size());
......@@ -799,7 +800,7 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
input_tensor_code = (
input_tensor_code
+ f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, GetKernelInputArgDef(kernel.InputAt({kernel_param.index(input_name)}), kernel_backend), {trans_flag});
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, GetKernelInputArgDef(kernel.InputAt({kernel_param.index(input_name)}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel);
{code_indent} std::vector<const phi::DenseTensor*> {PREFIX_TENSOR_NAME}{input_name}({PREFIX_TENSOR_NAME}{input_name}_vec->size());
{code_indent} for (size_t i = 0; i < {PREFIX_TENSOR_NAME}{input_name}.size(); ++i) {{
{code_indent} {PREFIX_TENSOR_NAME}{input_name}[i] = &{PREFIX_TENSOR_NAME}{input_name}_vec->at(i);
......@@ -1204,6 +1205,19 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
code_indent,
inplace_flag,
)
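# For inplace APIs (except the view-like squeeze/unsqueeze/reshape/flatten),
# back up every strided output, let the kernel write into a fresh contiguous
# DenseTensor, and copy the result back afterwards via TransStride.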
pre_save_stride = ""
transdata2strided = ""
if inplace_flag and kernel_name not in [
"squeeze",
"unsqueeze",
"reshape",
"flatten",
]:
i = 0
for kernel_out in outputs_args:
pre_save_stride += f"""{code_indent} auto backup{i} = ProcessStrideBackup(&{kernel_out});\n"""
transdata2strided += f"""{code_indent} TransStride(dev_ctx, {kernel_out}, backup{i});\n"""
i = i + 1
fallback_kernel_output_trans = ""
for kernel_out in outputs_args:
fallback_kernel_output_trans += f"""
......@@ -1220,6 +1234,7 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
{code_indent} auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend);
{input_tensors}
{output_create}
{pre_save_stride}
{code_indent} phi::RecordEvent *infer_shape_record_event = nullptr;
{code_indent} if(phi::RecordEvent::IsEnabled()){{
{code_indent} infer_shape_record_event = new phi::RecordEvent(\"{self.api} infer_meta\", phi::TracerEventType::OperatorInner, 1);
......@@ -1238,6 +1253,7 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
{code_indent} if(kernel_record_event != nullptr){{
{code_indent} delete kernel_record_event;
{code_indent} }}
{transdata2strided}
{code_indent} if (kernel_result.has_fallback_cpu) {{
{fallback_kernel_output_trans}
{self.reset_view_after_fallback(self.outputs['types'], code_indent, inplace_flag)}
......
......@@ -66,7 +66,7 @@ class ForwardAPI(BaseAPI):
input_tensor_code = (
input_tensor_code
+ f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt(0), {trans_flag});"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt(0), {trans_flag}, kernel_result.is_stride_kernel);"""
)
else:
# do nothing
......
......@@ -107,26 +107,104 @@ class SparseAPI(ForwardAPI):
}
input_names = self.inputs['names']
input_infos = self.inputs['input_info']
input_types = self.inputs['tensor_type']
tensor_type_map = {
'dense': 'phi::DenseTensor',
'sparse_coo': 'phi::SparseCooTensor',
'sparse_csr': 'phi::SparseCsrTensor',
}
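# Map each input name to its declared tensor type ('dense', 'sparse_coo' or
# 'sparse_csr') so the matching PrepareDataFor* helper can be emitted below.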
inputsname2tensortype = {}
for i in range(len(input_names)):
inputsname2tensortype[input_names[i]] = input_types[i]
attr_names = self.attrs['names']
kernel_param = self.kernel['param']
if kernel_param is None:
kernel_param = input_names + attr_names
infer_meta = self.infer_meta
infer_meta_params = (
infer_meta['param']
if infer_meta['param'] is not None
else input_names + attr_names
)
kernel_context_code = ""
for param in kernel_param:
if param in input_names and param not in infer_meta_params:
var_name = " auto " + PREFIX_TENSOR_NAME + param + " = "
if self.inputs['input_info'][param] == "const Tensor&":
if inputsname2tensortype[param] == "sparse_coo":
kernel_context_code = (
kernel_context_code
+ var_name
+ "PrepareDataForSparseCooTensor("
+ param
+ ");\n"
)
elif inputsname2tensortype[param] == "sparse_csr":
kernel_context_code = (
kernel_context_code
+ var_name
+ "PrepareDataForSparseCsrTensor("
+ param
+ ");\n"
)
else:
kernel_context_code = (
kernel_context_code
+ var_name
+ "PrepareDataForDenseTensorInSparse("
+ param
+ ");\n"
)
elif param in self.optional_vars:
tensor_type = 'phi::DenseTensor'
for name, input_type in zip(input_names, input_types):
if param == name:
tensor_type = tensor_type_map[input_type]
break
optional_var = "paddle::optional<" + tensor_type + ">("
if inputsname2tensortype[param] == "sparse_coo":
kernel_context_code = (
kernel_context_code
+ var_name
+ "PrepareDataForSparseCooTensor("
+ param
+ ");\n"
)
elif inputsname2tensortype[param] == "sparse_csr":
kernel_context_code = (
kernel_context_code
+ var_name
+ "PrepareDataForSparseCsrTensor("
+ param
+ ");\n"
)
else:
kernel_context_code = (
kernel_context_code
+ var_name
+ "PrepareDataForDenseTensorInSparse("
+ param
+ ");\n"
)
for param in kernel_param:
if param in input_names:
if param in self.optional_vars:
kernel_context_code = (
kernel_context_code
+ f"""
kernel_context.EmplaceBackInput({param} ? {param}->impl().get() : nullptr);"""
kernel_context.EmplaceBackInput({param} ? &(*{PREFIX_TENSOR_NAME}{param}) : nullptr);"""
)
else:
kernel_context_code = (
kernel_context_code
+ f"""
kernel_context.EmplaceBackInput({param}.impl().get());"""
kernel_context.EmplaceBackInput({PREFIX_TENSOR_NAME}{param}.get());"""
)
continue
......@@ -167,6 +245,10 @@ class SparseAPI(ForwardAPI):
else input_names + attr_names
)
inputsname2tensortype = {}
for i in range(len(input_names)):
inputsname2tensortype[input_names[i]] = input_types[i]
create_input_var_code = ""
tensor_type_map = {
'dense': 'phi::DenseTensor',
......@@ -175,11 +257,32 @@ class SparseAPI(ForwardAPI):
}
for param in infer_meta_params:
if param in input_names:
var_name = "auto " + PREFIX_TENSOR_NAME + param + " = "
var_name = " auto " + PREFIX_TENSOR_NAME + param + " = "
if self.inputs['input_info'][param] == "const Tensor&":
create_input_var_code = (
create_input_var_code + var_name + param + ".impl();\n"
)
if inputsname2tensortype[param] == "sparse_coo":
create_input_var_code = (
create_input_var_code
+ var_name
+ "PrepareDataForSparseCooTensor("
+ param
+ ");\n"
)
elif inputsname2tensortype[param] == "sparse_csr":
create_input_var_code = (
create_input_var_code
+ var_name
+ "PrepareDataForSparseCsrTensor("
+ param
+ ");\n"
)
else:
create_input_var_code = (
create_input_var_code
+ var_name
+ "PrepareDataForDenseTensorInSparse("
+ param
+ ");\n"
)
elif param in self.optional_vars:
tensor_type = 'phi::DenseTensor'
for name, input_type in zip(input_names, input_types):
......@@ -187,20 +290,30 @@ class SparseAPI(ForwardAPI):
tensor_type = tensor_type_map[input_type]
break
optional_var = "paddle::optional<" + tensor_type + ">("
create_input_var_code = (
create_input_var_code
+ var_name
+ param
+ " ? "
+ optional_var
+ "*static_cast<"
+ tensor_type
+ "*>((*"
+ param
+ ").impl().get())) : "
+ optional_var
+ "paddle::none);\n"
)
if inputsname2tensortype[param] == "sparse_coo":
create_input_var_code = (
create_input_var_code
+ var_name
+ "PrepareDataForSparseCooTensor("
+ param
+ ");\n"
)
elif inputsname2tensortype[param] == "sparse_csr":
create_input_var_code = (
create_input_var_code
+ var_name
+ "PrepareDataForSparseCsrTensor("
+ param
+ ");\n"
)
else:
create_input_var_code = (
create_input_var_code
+ var_name
+ "PrepareDataForDenseTensorInSparse("
+ param
+ ");\n"
)
return f"""{create_input_var_code}"""
def gen_sparse_kernel_code(self, kernel_name, inplace_flag=False):
......
......@@ -126,6 +126,7 @@ def source_include(header_file_path):
#include "paddle/phi/api/include/sparse_api.h"
#include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/infermeta/unary.h"
......
......@@ -178,6 +178,17 @@
func : as_real
backward : as_real_grad
- op : as_strided
args : (Tensor input, int64_t[] dims = {}, int64_t[] stride = {}, int64_t offset = 0)
output : Tensor
infer_meta :
func : StridedUnChangedInferMeta
param : [input]
kernel :
func : as_strided
backward : as_strided_grad
no_need_buffer : input
- op : asin
args : (Tensor x)
output : Tensor(out)
......@@ -1232,6 +1243,16 @@
data_type : x
backward : index_select_grad
- op : index_select_strided
args : (Tensor x, int64_t index, int axis = 0)
output : Tensor(out)
infer_meta :
func : IndexSelectStridedInferMeta
kernel :
func : index_select_strided
data_type : x
backward : index_select_strided_grad
- op : instance_norm
args : (Tensor x, Tensor scale, Tensor bias, float epsilon=1e-5)
output : Tensor(y), Tensor(saved_mean), Tensor(saved_variance)
......@@ -2503,6 +2524,17 @@
data_type : x
backward : temporal_shift_grad
- op : tensor_unfold
args : (Tensor input, int64_t axis, int64_t size, int64_t step)
output : Tensor
infer_meta :
func : StridedUnChangedInferMeta
param : [input]
kernel :
func : tensor_unfold
backward : tensor_unfold_grad
no_need_buffer : input
- op : thresholded_relu
args : (Tensor x, float threshold = 1.0)
output : Tensor(out)
......@@ -2650,6 +2682,29 @@
skip_transform : found_infinite
inplace : (x -> out), (prev_loss_scaling -> loss_scaling), (in_good_steps -> out_good_steps), (in_bad_steps -> out_bad_steps)
- op : view_dtype
args : (Tensor input, DataType dtype)
output : Tensor(out)
infer_meta :
func : StridedUnChangedInferMeta
param : [input]
kernel :
func : view_dtype
data_type : input
backward : view_dtype_grad
no_need_buffer : input
- op : view_shape
args : (Tensor input, int64_t[] dims = {})
output : Tensor(out)
infer_meta :
func : StridedUnChangedInferMeta
param : [input]
kernel :
func : view_shape
backward : view_shape_grad
no_need_buffer : input
- op : viterbi_decode
args : (Tensor potentials, Tensor transition_params, Tensor lengths, bool include_bos_eos_tag = true)
output : Tensor(scores), Tensor(path)
......
......@@ -45,6 +45,7 @@ enum class DataLayout {
SPARSE_COO,
SPARSE_CSR,
PSTRING_UNION,
STRIDED,
NUM_DATA_LAYOUTS,
......@@ -92,6 +93,8 @@ inline DataLayout StringToDataLayout(const std::string& str) {
return DataLayout::PSTRING_UNION;
} else if (s == "NCDHW") {
return DataLayout::kNCDHW;
} else if (s == "STRIDED") {
return DataLayout::STRIDED;
} else {
PD_THROW("Unknown data layout type string: ", s, ".");
}
......@@ -117,6 +120,8 @@ inline std::string DataLayoutToString(const DataLayout& layout) {
return "NCDHW";
case DataLayout::PSTRING_UNION:
return "PSTRING_UNION";
case DataLayout::STRIDED:
return "STRIDED";
default:
PD_THROW("Unknown Data Layout type ", static_cast<int>(layout), ".");
}
......
......@@ -42,8 +42,14 @@ struct DDimEqualityVisitor {
};
bool DDim::operator==(const DDim& d) const {
return size() == d.size() &&
this->apply_visitor(DDimEqualityVisitor(d.Get()));
if (size() == -1 && d.size() == -1) {
return true;
} else if (size() == -1 || d.size() == -1) {
return false;
} else {
return size() == d.size() &&
this->apply_visitor(DDimEqualityVisitor(d.Get()));
}
}
bool DDim::operator!=(const DDim& d) const { return !(*this == d); }
......@@ -66,6 +72,9 @@ struct ProductVisitor {
};
int64_t product(const DDim& ddim) {
if (ddim.size() == -1) {
return 0;
}
return ddim.apply_visitor(ProductVisitor());
}
......@@ -105,6 +114,9 @@ struct DDimPrinter {
};
std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
if (ddim.size() == -1) {
return os;
}
ddim.apply_visitor(DDimPrinter(os));
return os;
}
......
......@@ -52,6 +52,9 @@ namespace phi {
template <typename T1, typename T2>
inline void dynamic_dim_assign(const T1* in, T2* out, int n) {
if (n == -1) {
return;
}
PADDLE_VISIT_DDIM(n, (static_dim_assign<kRank, T1, T2>(in, out)));
}
......@@ -64,7 +67,7 @@ class DDim {
public:
constexpr static int kMaxRank = 9;
DDim() : rank_(1) { dim_[0] = 0; }
DDim() : rank_(-1) { dim_[0] = 0; }
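// NOTE: rank_ == -1 marks a DDim whose shape has never been set; product(),
// vectorize() and operator== all treat this value specially.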
DDim(const DDim& ddim) : dim_() { CopyFrom(ddim); }
......@@ -177,6 +180,10 @@ class DDim {
}
inline DDim& CopyFrom(const DDim& ddim) {
if (ddim.rank_ == -1) {
rank_ = -1;
return *this;
}
PADDLE_VISIT_DDIM(ddim.rank_, (*this = ddim.UnsafeCast<kRank>()));
}
......@@ -210,6 +217,9 @@ DDim make_ddim(std::initializer_list<int64_t> dims);
template <typename T = int64_t>
std::vector<T> vectorize(const DDim& ddim) {
if (ddim.size() == -1) {
return std::vector<T>({0});
}
std::vector<T> result(DDim::kMaxRank);
dynamic_dim_assign(ddim.Get(), result.data(), ddim.size());
result.resize(ddim.size());
......
......@@ -223,6 +223,11 @@ void DenseTensor::set_meta(const DenseTensorMeta& meta) {
meta_.lod = meta.lod;
meta_.offset = meta.offset;
meta_.use_gpudnn = meta.use_gpudnn;
if (meta.strides.size() == -1) {
meta_.strides = meta_.calc_strides(meta_.dims);
} else {
meta_.strides = meta.strides;
}
}
/* @jim19930609: This interface will be further modified until we finalized the
......@@ -236,7 +241,21 @@ void DenseTensor::set_meta(const DenseTensorMeta& meta) {
call to mutable_data(place)
*/
void DenseTensor::ResizeAndAllocate(const DDim& dims) {
if (meta_.dims.size() != -1 && meta_.dims != dims) {
PADDLE_ENFORCE_EQ(meta_.is_contiguous(),
true,
phi::errors::InvalidArgument(
"Right now Resize is only supported for contiguous "
"Tensor. Tensor dims is %s, Tensor layout is %s, "
"Tensor stride is %s. New dims is %s.",
meta_.dims,
meta_.layout,
meta_.strides,
dims));
}
meta_.dims = dims;
meta_.strides = meta_.calc_strides(meta_.dims);
if (holder_ != nullptr && place().GetType() != AllocationType::UNDEFINED) {
mutable_data(place());
}
......
......@@ -89,6 +89,14 @@ class DenseTensor : public TensorBase,
/// \return The dims of the tensor.
const DDim& dims() const noexcept override { return meta_.dims; }
/// \brief Returns the stride of the tensor.
/// \return The stride of the tensor.
const DDim& strides() const noexcept { return meta_.strides; }
/// \brief Sets the strides of the tensor.
/// \param strides The strides of the tensor.
void set_strides(const DDim& strides) { meta_.strides = strides; }
/// \brief Returns the lod of the tensor.
/// \return The lod of the tensor.
const LoD& lod() const noexcept { return meta_.lod; }
......
......@@ -42,16 +42,18 @@ void DenseTensor::check_memory_size() const {
holder_,
phi::errors::PreconditionNotMet("Tensor holds no memory. "
"Call Tensor::mutable_data firstly."));
PADDLE_ENFORCE_LE(
numel() * SizeOf(dtype()),
memory_size(),
phi::errors::PreconditionNotMet(
"Tensor's dimension is out of bound."
"Tensor's dimension must be equal or less than the size of its "
"memory."
"But received Tensor's dimension is %d, memory's size is %d.",
numel() * SizeOf(dtype()),
memory_size()));
if (meta_.is_contiguous()) {
PADDLE_ENFORCE_LE(
numel() * SizeOf(dtype()),
memory_size(),
phi::errors::PreconditionNotMet(
"Tensor's dimension is out of bound."
"Tensor's dimension must be equal or less than the size of its "
"memory."
"But received Tensor's dimension is %d, memory's size is %d.",
numel() * SizeOf(dtype()),
memory_size()));
}
}
const Place& DenseTensor::place() const {
......@@ -64,11 +66,16 @@ const Place& DenseTensor::place() const {
phi::DataType DenseTensor::type() const { return meta_.dtype; }
void DenseTensor::set_layout(const DataLayout layout) { meta_.layout = layout; }
void DenseTensor::set_layout(const DataLayout layout) {
if (meta_.strides.size() == -1) {
meta_.strides = meta_.calc_strides(meta_.dims);
}
meta_.layout = layout;
}
// Note: When you reset holder, you need to ensure the offset is correct
void DenseTensor::ResetHolder(const std::shared_ptr<phi::Allocation>& holder) {
if (holder_) {
if (holder_ && meta_.is_contiguous()) {
PADDLE_ENFORCE_LE(
numel() * static_cast<int64_t>(SizeOf(dtype())) +
static_cast<int64_t>(meta_.offset),
......@@ -156,7 +163,20 @@ inline T* DenseTensor::mutable_data(const DDim& dims,
const Place& place,
size_t requested_size) {
static_assert(std::is_pod<T>::value, "T must be POD");
if (meta_.dims.size() != -1 && meta_.dims != dims) {
PADDLE_ENFORCE_EQ(meta_.is_contiguous(),
true,
phi::errors::InvalidArgument(
"Right now Resize is only supported for contiguous "
"Tensor. Tensor dims is %s, Tensor layout is %s, "
"Tensor stride is %s. New dims is %s.",
meta_.dims,
meta_.layout,
meta_.strides,
dims));
}
meta_.dims = dims;
meta_.strides = meta_.calc_strides(meta_.dims);
return mutable_data<T>(place, requested_size);
}
......@@ -250,7 +270,20 @@ size_t DenseTensor::NumElements(size_t level) const {
}
DenseTensor& DenseTensor::Resize(const DDim& dims) {
if (meta_.dims.size() != -1 && meta_.dims != dims) {
PADDLE_ENFORCE_EQ(meta_.is_contiguous(),
true,
phi::errors::InvalidArgument(
"Right now Resize is only supported for contiguous "
"Tensor. Tensor dims is %s, Tensor layout is %s, "
"Tensor stride is %s. New dims is %s.",
meta_.dims,
meta_.layout,
meta_.strides,
dims));
}
meta_.dims = dims;
meta_.strides = meta_.calc_strides(meta_.dims);
return *this;
}
......@@ -358,6 +391,7 @@ DenseTensor& DenseTensor::ShareDataWith(const DenseTensor& src) {
meta_.layout = src.meta_.layout;
meta_.offset = src.meta_.offset;
meta_.use_gpudnn = src.meta_.use_gpudnn;
meta_.strides = src.meta_.strides;
storage_properties_ =
std::move(CopyStorageProperties(src.storage_properties_));
#ifdef PADDLE_WITH_MKLDNN
......
......@@ -25,9 +25,15 @@
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
#include "paddle/phi/backends/custom/custom_device_op_list.h"
#endif
#include "paddle/fluid/platform/flags.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/utils/string/string_helper.h"
PADDLE_DEFINE_EXPORTED_bool(
use_stride_kernel,
false,
"Whether to use strdie kernel if op support stride.");
DECLARE_int32(low_precision_op_list);
DECLARE_bool(enable_api_kernel_fallback);
DECLARE_bool(run_kp_kernel);
......@@ -218,6 +224,18 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
kernels_.end(),
phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name));
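// With FLAGS_use_stride_kernel on, first look for a STRIDED-layout
// registration of this kernel (a GPUDNN request falls back to the GPU
// backend key); only if none exists does the normal ALL_LAYOUT lookup below
// proceed.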
if (FLAGS_use_stride_kernel) {
auto stride_kernel_iter = iter->second.find(
{const_kernel_key.backend() == paddle::experimental::Backend::GPUDNN
? paddle::experimental::Backend::GPU
: const_kernel_key.backend(),
phi::DataLayout::STRIDED,
const_kernel_key.dtype()});
if (stride_kernel_iter != iter->second.end()) {
return {stride_kernel_iter->second, false, true};
}
}
KernelKey kernel_key = KernelKey(const_kernel_key.backend(),
phi::DataLayout::ALL_LAYOUT,
const_kernel_key.dtype());
......@@ -226,7 +244,7 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
auto kernel_iter = iter->second.find(
{Backend::GPUDNN, phi::DataLayout::ALL_LAYOUT, kernel_key.dtype()});
if (kernel_iter != iter->second.end()) {
return {kernel_iter->second, false};
return {kernel_iter->second, false, false};
}
kernel_key =
KernelKey(Backend::GPU, kernel_key.layout(), kernel_key.dtype());
......@@ -307,7 +325,7 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
<< ", expected_kernel_key:" << kernel_key
<< ", fallbacking to CPU one!";
return {kernel_iter->second, true};
return {kernel_iter->second, true, false};
}
PADDLE_ENFORCE_NE(
......@@ -322,7 +340,7 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
kernel_name,
KernelSelectionErrorMessage(kernel_name, kernel_key)));
return {kernel_iter->second, false};
return {kernel_iter->second, false, false};
}
const KernelArgsDef& KernelFactory::GetFirstKernelArgsDef(
......
......@@ -298,11 +298,14 @@ using KernelKeyMap = paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>;
using KernelNameMap = paddle::flat_hash_map<std::string, KernelKeyMap>;
struct KernelResult {
KernelResult(const Kernel& kernel, bool fallback_cpu)
: kernel(kernel), has_fallback_cpu(fallback_cpu) {}
KernelResult(const Kernel& kernel, bool fallback_cpu, bool is_stride_kernel)
: kernel(kernel),
has_fallback_cpu(fallback_cpu),
is_stride_kernel(is_stride_kernel) {}
const Kernel& kernel;
bool has_fallback_cpu = false;
bool is_stride_kernel = false;
};
/**
......
......@@ -1401,6 +1401,14 @@ struct KernelRegistrar {
meta_kernel_fn, \
BACKEND_LIST)
#define PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( \
kernel_name, layout, meta_kernel_fn) \
_PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(::phi::RegType::INNER, \
kernel_name, \
layout, \
meta_kernel_fn, \
BACKEND_LIST_EXCEPT_CUSTOM)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#define _DEVICE GPU,
#elif defined(PADDLE_WITH_XPU)
......@@ -1415,6 +1423,7 @@ struct KernelRegistrar {
#endif
#define BACKEND_LIST _DEVICE _CUSTOM CPU
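// Unlike BACKEND_LIST, BACKEND_LIST_EXCEPT_CUSTOM expands to the detected
// device backend plus CPU and leaves out the Custom backend.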
#define BACKEND_LIST_EXCEPT_CUSTOM _DEVICE CPU
#define _PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE( \
reg_type, kernel_name, layout, meta_kernel_fn, ...) \
......
......@@ -46,6 +46,14 @@ DDim MetaTensor::dims() const {
}
}
DDim MetaTensor::strides() const {
ValidCheck(*this);
if (dynamic_cast<DenseTensor*>(tensor_)) {
return dynamic_cast<DenseTensor*>(tensor_)->strides();
}
return DDim();
}
DataType MetaTensor::dtype() const {
ValidCheck(*this);
return tensor_->dtype();
......@@ -59,8 +67,12 @@ DataLayout MetaTensor::layout() const {
void MetaTensor::set_dims(const DDim& dims) {
ValidCheck(*this);
if (phi::DenseTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))->dims =
dims;
auto meta =
DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_));
meta->dims = dims;
if (!strided_kernel_used_) {
meta->strides = meta->calc_strides(dims);
}
} else if (phi::StringTensor::classof(tensor_)) {
StringTensorUtils::GetMutableMeta(static_cast<StringTensor*>(tensor_))
->dims = dims;
......@@ -78,6 +90,14 @@ void MetaTensor::set_dims(const DDim& dims) {
}
}
void MetaTensor::set_strides(const DDim& strides) {
ValidCheck(*this);
if (dynamic_cast<DenseTensor*>(tensor_)) {
DenseTensorUtils::GetMutableMeta(dynamic_cast<DenseTensor*>(tensor_))
->strides = strides;
}
}
void MetaTensor::set_dtype(DataType dtype) {
ValidCheck(*this);
if (phi::DenseTensor::classof(tensor_)) {
......@@ -105,14 +125,21 @@ void MetaTensor::set_dtype(DataType dtype) {
void MetaTensor::set_layout(DataLayout layout) {
ValidCheck(*this);
if (phi::DenseTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))
->layout = layout;
auto meta =
DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_));
meta->layout = layout;
if (!strided_kernel_used_) {
meta->strides = meta->calc_strides(meta->dims);
}
} else if (phi::StringTensor::classof(tensor_)) {
// No need to set layout
} else if (phi::SelectedRows::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(
static_cast<SelectedRows*>(tensor_)->mutable_value())
->layout = layout;
auto meta = DenseTensorUtils::GetMutableMeta(
static_cast<SelectedRows*>(tensor_)->mutable_value());
meta->layout = layout;
if (!strided_kernel_used_) {
meta->strides = meta->calc_strides(meta->dims);
}
} else if (phi::SparseCooTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(static_cast<SparseCooTensor*>(tensor_))
->layout = layout;
......@@ -192,9 +219,12 @@ void MetaTensor::share_dims(const MetaTensor& meta_tensor) {
auto* selected_rows_in = static_cast<SelectedRows*>(in_tensor_base);
selected_rows_out->set_rows(selected_rows_in->rows());
selected_rows_out->set_height(selected_rows_in->height());
DenseTensorUtils::GetMutableMeta(
static_cast<SelectedRows*>(tensor_)->mutable_value())
->dims = selected_rows_in->mutable_value()->dims();
auto meta = DenseTensorUtils::GetMutableMeta(
static_cast<SelectedRows*>(tensor_)->mutable_value());
meta->dims = selected_rows_in->mutable_value()->dims();
if (!strided_kernel_used_) {
meta->strides = meta->calc_strides(meta->dims);
}
} else {
set_dims(meta_tensor.dims());
}
......@@ -204,6 +234,13 @@ void MetaTensor::share_dims(const MetaTensor& meta_tensor) {
}
}
void MetaTensor::share_strides(const MetaTensor& meta_tensor) {
ValidCheck(*this);
if (phi::DenseTensor::classof(tensor_)) {
set_strides(meta_tensor.strides());
}
}
bool MetaTensor::initialized() const { return tensor_ != nullptr; }
// Private Member Methods
......
......@@ -42,12 +42,19 @@ class MetaTensor {
MetaTensor() : tensor_(nullptr) {}
// supporting implicit construction is easier to use
MetaTensor(TensorBase* tensor) : tensor_(tensor) {} // NOLINT
MetaTensor(const TensorBase& tensor) // NOLINT
: tensor_(const_cast<TensorBase*>(&tensor)) {}
MetaTensor(const TensorBase* tensor) // NOLINT
: tensor_(const_cast<TensorBase*>(tensor)) {}
MetaTensor(TensorBase& tensor) : tensor_(&tensor) {} // NOLINT
MetaTensor(TensorBase* tensor, bool strided_kernel_used = false) // NOLINT
: tensor_(tensor), strided_kernel_used_(strided_kernel_used) {}
MetaTensor(const TensorBase& tensor,
bool strided_kernel_used = false)
: tensor_(const_cast<TensorBase*>(&tensor)), // NOLINT
strided_kernel_used_(strided_kernel_used) {}
MetaTensor(const TensorBase* tensor,
bool strided_kernel_used = false) // NOLINT
: tensor_(const_cast<TensorBase*>(tensor)), // NOLINT
strided_kernel_used_(strided_kernel_used) {}
MetaTensor(TensorBase& tensor, bool strided_kernel_used = false) // NOLINT
: tensor_(&tensor), // NOLINT
strided_kernel_used_(strided_kernel_used) {}
MetaTensor(MetaTensor&&) = default;
MetaTensor& operator=(MetaTensor&&) = default;
......@@ -60,13 +67,16 @@ class MetaTensor {
virtual DDim dims() const;
virtual DataType dtype() const;
virtual DataLayout layout() const;
virtual DDim strides() const;
virtual void set_dims(const DDim& dims);
virtual void set_dtype(DataType dtype);
virtual void set_layout(DataLayout layout);
virtual void set_strides(const DDim& strides);
virtual void share_lod(const MetaTensor& meta_tensor);
virtual void share_meta(const MetaTensor& meta_tensor);
virtual void share_dims(const MetaTensor& meta_tensor);
virtual void share_strides(const MetaTensor& meta_tensor);
virtual bool initialized() const;
......@@ -92,6 +102,7 @@ class MetaTensor {
TensorBase* tensor() const;
TensorBase* tensor_ = nullptr;
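// When true, infer-meta keeps the tensor's existing strides instead of
// recomputing contiguous strides on every set_dims/set_layout call.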
bool strided_kernel_used_ = false;
};
} // namespace phi
......@@ -13,13 +13,123 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/ir/core/enforce.h"
namespace phi {
DDim DenseTensorMeta::calc_strides(const DDim& dims) {
if (dims.size() == -1 || product(dims) <= 0) {
return dims;
}
DDim strides(dims);
// NOTE: The NHWC and NDHWC in Paddle are implemented by actually modifying
// the video memory data format, and stride is not required. But it may be
// used in the future.
// if (dims.size() == 4 && layout == DataLayout::NHWC) {
// strides[1] = 1;
// strides[3] = dims[1];
// strides[2] = strides[3] * dims[3];
// strides[0] = strides[2] * dims[2];
// } else if (dims.size() == 5 && layout == DataLayout::NDHWC) {
// strides[1] = 1;
// strides[4] = dims[1];
// strides[3] = strides[4] * dims[4];
// strides[2] = strides[3] * dims[3];
// strides[0] = strides[2] * dims[2];
// } else {
// strides[dims.size() - 1] = 1;
// for (int i = dims.size() - 2; i >= 0; --i) {
// strides[i] = strides[i + 1] * dims[i + 1];
// }
// }
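// The switch below fills in row-major (C-contiguous) strides: the last axis
// gets stride 1 and every preceding stride is the product of all trailing
// dims, e.g. dims [2, 3, 4] -> strides [12, 4, 1].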
auto p_dims = dims.Get();
auto p_strides = strides.GetMutable();
switch (dims.size()) {
case 0:
return strides;
case 1:
p_strides[0] = 1;
return strides;
case 2:
p_strides[1] = 1;
p_strides[0] = p_dims[1];
return strides;
case 3:
p_strides[2] = 1;
p_strides[1] = p_dims[2];
p_strides[0] = p_strides[1] * p_dims[1];
return strides;
case 4:
p_strides[3] = 1;
p_strides[2] = p_dims[3];
p_strides[1] = p_strides[2] * p_dims[2];
p_strides[0] = p_strides[1] * p_dims[1];
return strides;
case 5:
p_strides[4] = 1;
p_strides[3] = p_dims[4];
p_strides[2] = p_strides[3] * p_dims[3];
p_strides[1] = p_strides[2] * p_dims[2];
p_strides[0] = p_strides[1] * p_dims[1];
return strides;
case 6:
p_strides[5] = 1;
p_strides[4] = p_dims[5];
p_strides[3] = p_strides[4] * p_dims[4];
p_strides[2] = p_strides[3] * p_dims[3];
p_strides[1] = p_strides[2] * p_dims[2];
p_strides[0] = p_strides[1] * p_dims[1];
return strides;
case 7:
p_strides[6] = 1;
p_strides[5] = p_dims[6];
p_strides[4] = p_strides[5] * p_dims[5];
p_strides[3] = p_strides[4] * p_dims[4];
p_strides[2] = p_strides[3] * p_dims[3];
p_strides[1] = p_strides[2] * p_dims[2];
p_strides[0] = p_strides[1] * p_dims[1];
return strides;
case 8:
p_strides[7] = 1;
p_strides[6] = p_dims[7];
p_strides[5] = p_strides[6] * p_dims[6];
p_strides[4] = p_strides[5] * p_dims[5];
p_strides[3] = p_strides[4] * p_dims[4];
p_strides[2] = p_strides[3] * p_dims[3];
p_strides[1] = p_strides[2] * p_dims[2];
p_strides[0] = p_strides[1] * p_dims[1];
return strides;
case 9:
p_strides[8] = 1;
p_strides[7] = p_dims[8];
p_strides[6] = p_strides[7] * p_dims[7];
p_strides[5] = p_strides[6] * p_dims[6];
p_strides[4] = p_strides[5] * p_dims[5];
p_strides[3] = p_strides[4] * p_dims[4];
p_strides[2] = p_strides[3] * p_dims[3];
p_strides[1] = p_strides[2] * p_dims[2];
p_strides[0] = p_strides[1] * p_dims[1];
return strides;
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"The rank of input should be less than 9, but received %d.",
dims.size()));
}
}
DenseTensorMeta::DenseTensorMeta() { use_gpudnn = true; }
DenseTensorMeta::DenseTensorMeta(DataType dtype, const DDim& dims)
: dims(dims), dtype(dtype) {
strides = calc_strides(dims);
use_gpudnn = true;
}
DenseTensorMeta::DenseTensorMeta(DataType dtype,
const DDim& dims,
const DDim& strides)
: dims(dims), dtype(dtype), strides(strides) {
use_gpudnn = true;
}
......@@ -28,6 +138,7 @@ DenseTensorMeta::DenseTensorMeta(DataType dtype,
DataLayout layout,
size_t offset)
: dims(dims), dtype(dtype), layout(layout), offset(offset) {
strides = calc_strides(dims);
use_gpudnn = true;
}
......@@ -37,9 +148,58 @@ DenseTensorMeta::DenseTensorMeta(DataType dtype,
const LoD& lod,
size_t offset)
: dims(dims), dtype(dtype), layout(layout), lod(lod), offset(offset) {
strides = calc_strides(dims);
use_gpudnn = true;
}
DenseTensorMeta::DenseTensorMeta(const DenseTensorMeta& other) {
is_scalar = other.is_scalar;
use_gpudnn = other.use_gpudnn;
dims = other.dims;
dtype = other.dtype;
layout = other.layout;
lod = other.lod;
offset = other.offset;
if (other.strides.size() == -1) {
strides = calc_strides(dims);
} else {
strides = other.strides;
}
}
DenseTensorMeta& DenseTensorMeta::operator=(const DenseTensorMeta& other) {
is_scalar = other.is_scalar;
use_gpudnn = other.use_gpudnn;
dims = other.dims;
dtype = other.dtype;
layout = other.layout;
lod = other.lod;
offset = other.offset;
if (other.strides.size() == -1) {
strides = calc_strides(dims);
} else {
strides = other.strides;
}
return *this;
}
DenseTensorMeta& DenseTensorMeta::operator=(DenseTensorMeta&& other) {
is_scalar = other.is_scalar;
use_gpudnn = other.use_gpudnn;
dims = std::move(other.dims);
dtype = other.dtype;
layout = other.layout;
lod = std::move(other.lod);
offset = other.offset;
if (other.strides.size() == -1) {
strides = calc_strides(dims);
} else {
strides = std::move(other.strides);
}
return *this;
}
bool DenseTensorMeta::valid() const noexcept {
bool valid{true};
valid = valid && (dtype != DataType::UNDEFINED);
......@@ -48,6 +208,10 @@ bool DenseTensorMeta::valid() const noexcept {
return valid;
}
bool DenseTensorMeta::is_contiguous() const noexcept {
return strides == calc_strides(dims);
}
StringTensorMeta::StringTensorMeta(const DDim& dims) : dims(dims) {}
bool StringTensorMeta::valid() const noexcept {
......
......@@ -48,6 +48,7 @@ using LoD = std::vector<std::vector<size_t>>;
struct DenseTensorMeta {
DenseTensorMeta();
DenseTensorMeta(DataType dtype, const DDim& dims);
DenseTensorMeta(DataType dtype, const DDim& dims, const DDim& stride);
DenseTensorMeta(DataType dtype,
const DDim& dims,
DataLayout layout,
......@@ -58,10 +59,19 @@ struct DenseTensorMeta {
const LoD& lod,
size_t offset = 0);
DenseTensorMeta(const DenseTensorMeta& other);
DenseTensorMeta& operator=(const DenseTensorMeta& other);
DenseTensorMeta& operator=(DenseTensorMeta&& other);
static DDim calc_strides(const DDim& dims);
/// \brief Test whether the metadata is valid. Does not throw exceptions.
/// \return Whether the metadata is valid.
bool valid() const noexcept;
bool is_contiguous() const noexcept;
bool is_scalar{false};
/// \brief Determine whether using gpudnn speed-up library in the new dygraph.
/// It maybe also support MKLDNN library in the near future.
......@@ -71,13 +81,14 @@ struct DenseTensorMeta {
DataLayout layout{DataLayout::NCHW};
LoD lod;
size_t offset{0};
DDim strides;
};
inline bool operator==(const DenseTensorMeta& lhs, const DenseTensorMeta& rhs) {
return (lhs.is_scalar == rhs.is_scalar) && lhs.use_gpudnn == rhs.use_gpudnn &&
(lhs.dims == rhs.dims) && (lhs.dtype == rhs.dtype) &&
(lhs.layout == rhs.layout) && (lhs.lod == rhs.lod) &&
(lhs.offset == rhs.offset);
(lhs.offset == rhs.offset) && (lhs.strides == rhs.strides);
}
struct StringTensorMeta {
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include "glog/logging.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
......@@ -31,6 +32,12 @@ void Copy(const Context& dev_ctx,
Place dst_place,
bool blocking,
DenseTensor* dst) {
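// A non-contiguous source is first materialized into a contiguous temporary
// so the raw buffer copies below stay valid.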
if (!src.meta().is_contiguous()) {
DenseTensor src_copy = paddle::experimental::Trans2Contiguous(src);
Copy(dev_ctx, src_copy, dst_place, blocking, dst);
return;
}
auto* src_ptr = src.data();
const auto& src_place = src.place();
......@@ -253,6 +260,7 @@ void Copy(const Context& dev_ctx,
PADDLE_THROW(errors::Unimplemented(
"Copy from %s to %s is not supported.", src_place, dst_place));
}
dst->set_strides(src.strides());
}
template <typename Context>
......
......@@ -270,7 +270,30 @@ namespace phi {
"`"); \
} \
}()
#if defined(PADDLE_WITH_XPU)
#define PD_VISIT_ALL_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::BOOL, bool, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::INT8, int8_t, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::UINT8, uint8_t, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::INT16, int16_t, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::INT32, int32_t, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::INT64, int64_t, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE( \
NAME, ::phi::DataType::BFLOAT16, phi::bfloat16, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE( \
NAME, ::phi::DataType::FLOAT16, phi::float16, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::FLOAT32, float, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE( \
NAME, ::phi::DataType::FLOAT64, double, __VA_ARGS__) \
default: \
PADDLE_THROW(phi::errors::InvalidArgument( \
"Invalid enum data type `%d`.", static_cast<int>(__dtype__))); \
} \
}()
#else
#define PD_VISIT_ALL_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
......@@ -297,6 +320,7 @@ namespace phi {
"Invalid enum data type `%d`.", static_cast<int>(__dtype__))); \
} \
}()
#endif
#define PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_4_TYPES(SPECIFIED_TYPE1, \
SPECIFIED_TYPE2, \
......
......@@ -1785,6 +1785,33 @@ void IndexSelectInferMeta(const MetaTensor& x,
output->share_lod(x);
}
void IndexSelectStridedInferMeta(const MetaTensor& x,
int64_t index,
int dim,
MetaTensor* output) {
auto input_dim = x.dims();
PADDLE_ENFORCE_EQ(
dim < input_dim.size() && dim >= (0 - input_dim.size()),
true,
phi::errors::OutOfRange(
"Attr(dim) is out of range. It's expected "
"to be in the range of [-%d, %d], but received Attr(dim) = %d.",
input_dim.size(),
input_dim.size() - 1,
dim));
auto output_dim = phi::vectorize(input_dim);
if (dim < 0) {
dim += input_dim.size();
}
output_dim.erase(output_dim.begin() + dim);
output->set_dims(phi::make_ddim(output_dim));
output->set_dtype(x.dtype());
output->set_layout(x.layout());
output->share_lod(x);
}
void IndexAddInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& add_value,
......
......@@ -287,6 +287,11 @@ void IndexSelectInferMeta(const MetaTensor& x,
int dim,
MetaTensor* output);
void IndexSelectStridedInferMeta(const MetaTensor& x,
int64_t index,
int dim,
MetaTensor* output);
void IndexAddInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& add_value,
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/parse_qr_mode.h"
......@@ -352,14 +353,13 @@ void BatchSizeLikeInferMeta(const MetaTensor& x,
phi::errors::InvalidArgument("Input dimension index must be greater "
"than or equal to 0, but received: %s.",
x_batch_size_dim));
PADDLE_ENFORCE_GT(input_dim_size,
x_batch_size_dim,
phi::errors::InvalidArgument(
"Input dimension size must be larger than "
"input dimension index, but received input "
"dimension size: %s, input dimension index: %s.",
input_dim_size,
x_batch_size_dim));
PADDLE_ENFORCE(input_dim_size > x_batch_size_dim || input_dim_size == -1,
phi::errors::InvalidArgument(
"Input dimension size must be larger than "
"input dimension index, but received input "
"dimension size: %s, input dimension index: %s.",
input_dim_size,
x_batch_size_dim));
int output_dim_size = static_cast<int>(shape.size());
PADDLE_ENFORCE_GE(
......@@ -829,6 +829,7 @@ void DiagonalInferMeta(const MetaTensor& input,
}
}
out->set_dims(phi::make_ddim(out_dims));
out->set_dtype(input.dtype());
}
void DirichletInferMeta(const MetaTensor& alpha, MetaTensor* out) {
......@@ -3340,6 +3341,7 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x,
}
xshape->set_dims(phi::make_ddim(xshape_dims));
xshape->share_lod(x);
xshape->set_strides(x.strides());
ReshapeInferMeta(x, shape, out, config);
}
......@@ -3597,6 +3599,7 @@ void SliceRawInferMeta(const MetaTensor& input,
if (!new_axes.empty() && new_axes[0] != 0) {
out->share_lod(input);
}
out->set_dtype(input.dtype());
}
void SoftmaxInferMeta(const MetaTensor& x, int axis, MetaTensor* out) {
......@@ -3869,6 +3872,7 @@ void SqueezeInferMeta(const MetaTensor& x,
output_size = 0;
}
std::vector<int64_t> vec_out_dims(output_size, -1);
out->set_dims(phi::make_ddim(vec_out_dims));
} else {
std::vector<int32_t> tmp;
......@@ -5114,6 +5118,11 @@ void QuantForCompressInferMeta(const MetaTensor& x,
scale->set_dtype(DataType::FLOAT32);
}
void StridedUnChangedInferMeta(const MetaTensor& x, MetaTensor* out) {
out->share_meta(x);
out->set_strides(x.strides());
}
} // namespace phi
PD_REGISTER_INFER_META_FN(flatten, phi::FlattenInferMeta);
......@@ -734,4 +734,6 @@ void QuantForCompressInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaTensor* scale);
void StridedUnChangedInferMeta(const MetaTensor& x, MetaTensor* out);
} // namespace phi
......@@ -95,6 +95,7 @@ set(cc_search_pattern
"strings/*.cc"
"strings/cpu/*.cc"
"fusion/*.cc"
"stride/*.cc"
"fusion/cpu/*.cc")
if(WITH_MKLDNN)
......
......@@ -23,4 +23,9 @@ void AsComplexKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out);
template <typename T, typename Context>
void AsComplexStridedKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out);
} // namespace phi
......@@ -23,4 +23,9 @@ void AsRealKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out);
template <typename T, typename Context>
void AsRealStridedKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename Context>
void AsStridedGradKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& out_grad,
const std::vector<int64_t>& dims,
const std::vector<int64_t>& stride,
int64_t offset,
DenseTensor* input_grad);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename Context>
void AsStridedKernel(const Context& dev_ctx,
const DenseTensor& input,
const std::vector<int64_t>& dims,
const std::vector<int64_t>& stride,
int64_t offset,
DenseTensor* out);
} // namespace phi
......@@ -36,4 +36,14 @@ void ComplexGradKernel(const Context& dev_ctx,
DenseTensor* dx,
DenseTensor* dy);
template <typename T, typename Context>
void RealGradStridedKernel(const Context& dev_ctx,
const DenseTensor& dout,
DenseTensor* dx);
template <typename T, typename Context>
void ImagGradStridedKernel(const Context& dev_ctx,
const DenseTensor& dout,
DenseTensor* dx);
} // namespace phi
......@@ -36,6 +36,16 @@ void ComplexKernel(const Context& dev_ctx,
const DenseTensor& y,
DenseTensor* out);
template <typename T, typename Context>
void RealStridedKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out);
template <typename T, typename Context>
void ImagStridedKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out);
// If T is complex
template <
typename T,
......
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/empty_kernel.h"
namespace phi {
/**
* @brief Computes a tensor that is contiguous in memory and contains the same
* data as the input tensor.
* @param dev_ctx Device context.
* @param input Source tensor to be made contiguous.
* @param out The resulting contiguous tensor.
*/
template <typename T, typename Context>
void ContiguousKernel(const Context& dev_ctx,
const DenseTensor& input,
DenseTensor* out);
template <typename T, typename Context>
DenseTensor Contiguous(const Context& dev_ctx, const DenseTensor& input) {
DenseTensor dense_out;
MetaTensor meta_input(input);
MetaTensor meta_out(&dense_out);
UnchangedInferMeta(meta_input, &meta_out);
ContiguousKernel<T, Context>(dev_ctx, input, &dense_out);
return dense_out;
}
} // namespace phi
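A minimal usage sketch for the Contiguous helper declared above (editor's addition; it assumes a CPU context and a float tensor, and the wrapper function name is hypothetical):
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/kernels/contiguous_kernel.h"
// Returns a tensor with freshly packed, row-major storage holding the same
// values as `x`; `x` itself keeps its original strides and holder.
phi::DenseTensor MakePacked(const phi::CPUContext& dev_ctx,
                            const phi::DenseTensor& x) {
  return phi::Contiguous<float, phi::CPUContext>(dev_ctx, x);
}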
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/contiguous_kernel.h"
#include <vector>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/impl/transpose_grad_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void ContiguousKernel(const Context& dev_ctx,
const DenseTensor& input,
DenseTensor* out) {
phi::DenseTensorMeta meta = input.meta();
meta.strides = meta.calc_strides(meta.dims);
meta.offset = 0;
out->set_meta(meta);
const T* input_data = input.data<T>();
T* output_data = dev_ctx.template Alloc<T>(out);
int rank = input.dims().size();
auto dims = input.dims();
auto input_stride = input.strides();
auto numel = input.numel();
for (int64_t i = 0; i < numel; i++) {
int64_t input_offset = 0;
int64_t index_tmp = i;
for (int dim = rank - 1; dim >= 0; --dim) {
int64_t mod = index_tmp % dims[dim];
index_tmp = index_tmp / dims[dim];
input_offset += mod * input_stride[dim];
}
output_data[i] = input_data[input_offset];
}
}
} // namespace phi
PD_REGISTER_KERNEL(contiguous,
CPU,
ALL_LAYOUT,
phi::ContiguousKernel,
bool,
uint8_t,
int8_t,
int16_t,
int32_t,
int64_t,
float,
double,
::phi::dtype::float16,
::phi::dtype::bfloat16,
::phi::dtype::complex<float>,
::phi::dtype::complex<double>) {}
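The gather loop in the CPU kernel above maps each packed linear index back to an offset in the strided source. A standalone sketch of the same arithmetic (editor's addition; the function name is hypothetical):
#include <cstdint>
#include <vector>
// Peel coordinates off the innermost dim outward and accumulate
// coordinate * source_stride. For dims {2, 3} and strides {1, 2} (a transposed
// view), linear index 4 -> coords (1, 1) -> offset = 1*1 + 1*2 = 3.
int64_t StridedOffset(int64_t linear_index,
                      const std::vector<int64_t>& dims,
                      const std::vector<int64_t>& strides) {
  int64_t offset = 0;
  for (int dim = static_cast<int>(dims.size()) - 1; dim >= 0; --dim) {
    offset += (linear_index % dims[dim]) * strides[dim];
    linear_index /= dims[dim];
  }
  return offset;
}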
......@@ -21,10 +21,15 @@ PD_REGISTER_KERNEL(fill,
CPU,
ALL_LAYOUT,
phi::FillKernel,
bool,
uint8_t,
int8_t,
int16_t,
int32_t,
int64_t,
float,
double,
int64_t,
int,
phi::dtype::float16,
phi::dtype::bfloat16,
bool) {}
::phi::dtype::float16,
::phi::dtype::bfloat16,
::phi::dtype::complex<float>,
::phi::dtype::complex<double>) {}
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/strided_copy_kernel.h"
#include <vector>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/impl/transpose_grad_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void StridedCopyKernel(const Context& dev_ctx,
const DenseTensor& input,
const std::vector<int64_t>& dims,
const std::vector<int64_t>& out_stride,
int64_t offset,
DenseTensor* out) {
phi::DenseTensorMeta meta = input.meta();
meta.strides = phi::make_ddim(out_stride);
meta.dims = phi::make_ddim(dims);
meta.offset = offset;
out->set_meta(meta);
PADDLE_ENFORCE_EQ(input.dims(),
out->dims(),
phi::errors::InvalidArgument(
"Input shape (%s) must be equal to out shape (%s).",
input.dims(),
out->dims()));
PADDLE_ENFORCE_EQ(input.numel(),
out->numel(),
phi::errors::InvalidArgument(
"Input numel (%d) must be equal to out numel (%d).",
input.numel(),
out->numel()));
if (input.numel() <= 0) {
return;
}
const T* input_data = input.data<T>();
int input_rank = input.dims().size();
const int64_t* input_dims = input.dims().Get();
const int64_t* input_stride = input.strides().Get();
T* output_data = out->data<T>();
PADDLE_ENFORCE_NOT_NULL(output_data,
phi::errors::InvalidArgument(
"StridedCopyKernel's out tensor must have its "
"mutable data allocated before calling the kernel."));
int output_rank = meta.dims.size();
const int64_t* output_dims = meta.dims.Get();
const int64_t* output_stride = meta.strides.Get();
auto numel = input.numel();
for (int64_t i = 0; i < numel; i++) {
int64_t input_offset = 0;
int64_t index_tmp = i;
for (int dim = input_rank - 1; dim >= 0; --dim) {
input_offset += (index_tmp % input_dims[dim]) * input_stride[dim];
index_tmp = index_tmp / input_dims[dim];
}
int64_t output_offset = 0;
index_tmp = i;
for (int dim = output_rank - 1; dim >= 0; --dim) {
output_offset += (index_tmp % output_dims[dim]) * output_stride[dim];
index_tmp = index_tmp / output_dims[dim];
}
output_data[output_offset] = input_data[input_offset];
}
}
} // namespace phi
PD_REGISTER_KERNEL(strided_copy,
CPU,
ALL_LAYOUT,
phi::StridedCopyKernel,
bool,
uint8_t,
int8_t,
int16_t,
int32_t,
int64_t,
float,
double,
::phi::dtype::float16,
::phi::dtype::bfloat16,
::phi::dtype::complex<float>,
::phi::dtype::complex<double>) {}
......@@ -26,4 +26,13 @@ void DiagonalGradKernel(const Context& dev_ctx,
int axis1,
int axis2,
DenseTensor* in_grad);
template <typename Context>
void DiagonalGradStridedKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
int offset,
int axis1,
int axis2,
DenseTensor* in_grad);
} // namespace phi
......@@ -41,6 +41,14 @@ void DiagonalKernel(const Context& dev_ctx,
int axis2,
DenseTensor* out);
template <typename Context>
void DiagonalStridedKernel(const Context& dev_ctx,
const DenseTensor& x,
int offset,
int axis1,
int axis2,
DenseTensor* out);
template <typename T, typename Context>
DenseTensor Diagonal(const Context& dev_ctx,
const DenseTensor& x,
......
......@@ -16,6 +16,7 @@
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/visit_type.h"
namespace phi {
......
......@@ -24,4 +24,10 @@ void FlattenGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
DenseTensor* x_grad);
template <typename Context>
void FlattenGradStridedKernel(const Context& dev_ctx,
const DenseTensor& xshape,
const DenseTensor& out_grad,
DenseTensor* x_grad);
} // namespace phi
......@@ -35,6 +35,21 @@ void FlattenKernel(const Context& dev_ctx,
DenseTensor* out,
DenseTensor* xshape);
template <typename Context>
void FlattenInferStridedKernel(const Context& dev_ctx,
const DenseTensor& x,
int start_axis,
int stop_axis,
DenseTensor* out);
template <typename Context>
void FlattenStridedKernel(const Context& dev_ctx,
const DenseTensor& x,
int start_axis,
int stop_axis,
DenseTensor* out,
DenseTensor* xshape);
template <typename T, typename Context>
DenseTensor Flatten(const Context& dev_ctx,
const DenseTensor& x,
......
......@@ -53,6 +53,7 @@ template struct SetConstant<phi::CPUContext, int>;
template struct SetConstant<phi::CPUContext, int64_t>;
template struct SetConstant<phi::CPUContext, bool>;
template struct SetConstant<phi::CPUContext, uint8_t>;
template struct SetConstant<phi::CPUContext, int8_t>;
template struct SetConstant<phi::CPUContext, phi::dtype::complex<float>>;
template struct SetConstant<phi::CPUContext, phi::dtype::complex<double>>;
......@@ -62,6 +63,7 @@ template struct SetConstant<phi::XPUContext, phi::dtype::bfloat16>;
template struct SetConstant<phi::XPUContext, float>;
template struct SetConstant<phi::XPUContext, double>;
template struct SetConstant<phi::XPUContext, uint8_t>;
template struct SetConstant<phi::XPUContext, int8_t>;
template struct SetConstant<phi::XPUContext, int16_t>;
template struct SetConstant<phi::XPUContext, int>;
template struct SetConstant<phi::XPUContext, int64_t>;
......
......@@ -146,6 +146,7 @@ template struct SetConstant<phi::GPUContext, bfloat16>;
template struct SetConstant<phi::GPUContext, float>;
template struct SetConstant<phi::GPUContext, double>;
template struct SetConstant<phi::GPUContext, uint8_t>;
template struct SetConstant<phi::GPUContext, int8_t>;
template struct SetConstant<phi::GPUContext, int>;
template struct SetConstant<phi::GPUContext, int16_t>;
template struct SetConstant<phi::GPUContext, int64_t>;
......@@ -158,6 +159,7 @@ template struct SetConstant<phi::GPUPinnedContext, bfloat16>;
template struct SetConstant<phi::GPUPinnedContext, float>;
template struct SetConstant<phi::GPUPinnedContext, double>;
template struct SetConstant<phi::GPUPinnedContext, uint8_t>;
template struct SetConstant<phi::GPUPinnedContext, int8_t>;
template struct SetConstant<phi::GPUPinnedContext, int>;
template struct SetConstant<phi::GPUPinnedContext, int16_t>;
template struct SetConstant<phi::GPUPinnedContext, int64_t>;
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/funcs/strided_reshape_utils.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/reshape_kernel.h"
namespace phi {
bool ReshapeStride(const DDim& old_dims,
const DDim& old_stride,
const DDim& new_dims,
DDim& new_stride) { // NOLINT
int64_t numel = product(old_dims);
if (numel < 0) {
int64_t tmp[2];
tmp[0] = 1;
tmp[1] = new_dims.size();
new_stride = DDim(tmp, 2);
return true;
} else if (numel == 0) {
if (old_dims == new_dims) {
new_stride = old_stride;
} else {
new_stride = new_dims;
new_stride[new_dims.size() - 1] = 1;
for (int i = new_dims.size() - 2; i >= 0; i--) {
new_stride[i] = new_stride[i + 1] *
std::max(static_cast<int64_t>(1), new_dims[i + 1]);
}
}
return true;
} else {
int64_t old_numel = 1;
int64_t new_numel = 1;
int64_t old_stride_lastvalue = old_stride[old_stride.size() - 1];
int new_stride_index = new_dims.size() - 1;
new_stride = new_dims;
for (int old_dims_index = old_dims.size() - 1; old_dims_index >= 0;
old_dims_index--) {
old_numel *= old_dims[old_dims_index];
if ((old_dims_index == 0) || (old_dims[old_dims_index - 1] != 1 &&
old_stride[old_dims_index - 1] !=
old_numel * old_stride_lastvalue)) {
while (new_stride_index >= 0 &&
(new_numel < old_numel || new_dims[new_stride_index] == 1)) {
new_stride[new_stride_index] = new_numel * old_stride_lastvalue;
new_numel *= new_dims[new_stride_index];
new_stride_index--;
}
if (new_numel != old_numel) {
return false;
}
if (old_dims_index > 0) {
old_numel = 1;
new_numel = 1;
old_stride_lastvalue = old_stride[old_dims_index - 1];
}
}
}
if (new_stride_index != -1) {
return false;
}
return true;
}
return false;
}
} // namespace phi
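An illustrative use of ReshapeStride (editor's sketch; the wrapper below is hypothetical): reshaping a contiguous [2, 6] tensor (strides [6, 1]) to [2, 3, 2] succeeds in place with strides [6, 2, 1], while reshaping a transposed [3, 2] view with strides [1, 3] to [6] returns false, so the caller must first materialize a contiguous copy.
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/funcs/strided_reshape_utils.h"
// Contiguous [2, 6] -> [2, 3, 2]: reuses the same storage with strides {6, 2, 1}.
bool CanReshapeInPlace() {
  phi::DDim new_stride;
  return phi::ReshapeStride(phi::make_ddim({2, 6}),
                            phi::make_ddim({6, 1}),
                            phi::make_ddim({2, 3, 2}),
                            new_stride);  // true; new_stride == {6, 2, 1}
}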
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/reshape_kernel.h"
namespace phi {
bool ReshapeStride(const DDim& old_dims,
const DDim& old_stride,
const DDim& new_dims,
DDim& new_stride); // NOLINT
} // namespace phi
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/contiguous_kernel.h"
#include <set>
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/transpose_kernel.h"
namespace phi {
template <typename T, size_t N>
__global__ void ContiguousFunc(
const T* input_data,
T* out_data,
phi::Array<int64_t, phi::DDim::kMaxRank + 1> input_stride,
phi::Array<int64_t, phi::DDim::kMaxRank + 1> dims,
const int64_t numel) {
int64_t gid = blockIdx.x * blockDim.x + threadIdx.x;
#pragma unroll
for (int64_t i = gid; i < numel; i += blockDim.x * gridDim.x) {
int64_t input_offset = 0;
int64_t index_tmp = i;
#pragma unroll
for (int dim = N - 1; dim >= 0; --dim) {
input_offset += index_tmp % dims[dim] * input_stride[dim];
index_tmp = index_tmp / dims[dim];
}
out_data[i] = input_data[input_offset];
}
}
bool is_only_transposed(const DDim& shape,
const DDim& stride,
uint64_t offset,
DDim& src_shape, // NOLINT
DDim& src_stride, // NOLINT
std::vector<int>& axis) { // NOLINT
if (offset != 0) {
return false;
}
std::set<int> visited_idx;
axis.resize(stride.size());
for (int i = 0; i < stride.size(); i++) {
int64_t max_num = 0;
int max_idx = -1;
for (int j = 0; j < stride.size(); j++) {
if (visited_idx.count(j)) {
continue;
}
if (stride[j] < 1) {
return false;
}
if (stride[j] > max_num) {
max_num = stride[j];
max_idx = j;
}
}
if (max_idx == -1) {
return false;
}
if (i != 0 && src_stride[i - 1] == max_num) {
return false;
}
visited_idx.insert(max_idx);
src_stride[i] = max_num;
src_shape[i] = shape[max_idx];
axis[max_idx] = i;
}
if (DenseTensorMeta::calc_strides(src_shape) == src_stride) {
return true;
} else {
return false;
}
}
template <typename T, typename Context>
void ContiguousKernel(const Context& dev_ctx,
const DenseTensor& input,
DenseTensor* out) {
phi::DenseTensorMeta meta = input.meta();
std::vector<int> axis;
DDim src_stride = meta.strides;
DDim src_shape = meta.dims;
if (is_only_transposed(
meta.dims, meta.strides, meta.offset, src_shape, src_stride, axis)) {
meta.strides = meta.calc_strides(meta.dims);
out->set_meta(meta);
DenseTensor tmp_tensor = input;
phi::DenseTensorMeta tmp_meta = meta;
tmp_meta.strides = src_stride;
tmp_meta.dims = src_shape;
tmp_tensor.set_meta(tmp_meta);
TransposeKernel<T, Context>(dev_ctx, tmp_tensor, axis, out);
return;
}
meta.strides = meta.calc_strides(meta.dims);
meta.offset = 0;
out->set_meta(meta);
const T* input_data = input.data<T>();
T* output_data = dev_ctx.template Alloc<T>(out);
int rank = input.dims().size();
auto numel = input.numel();
if (numel <= 0) {
return;
}
phi::Array<int64_t, phi::DDim::kMaxRank + 1> input_stride;
phi::Array<int64_t, phi::DDim::kMaxRank + 1> input_dims;
for (int i = 0; i < input.dims().size(); i++) {
input_dims[i] = input.dims()[i];
input_stride[i] = input.strides()[i];
}
if (rank == 0) {
rank = 1;
input_dims[0] = numel;
input_stride[0] = 1;
}
int64_t block = 512;
int64_t grid = (numel + block - 1) / block;
switch (rank) {
case 1:
ContiguousFunc<T, 1><<<grid, block, 0, dev_ctx.stream()>>>(
input_data, output_data, input_stride, input_dims, numel);
break;
case 2:
ContiguousFunc<T, 2><<<grid, block, 0, dev_ctx.stream()>>>(
input_data, output_data, input_stride, input_dims, numel);
break;
case 3:
ContiguousFunc<T, 3><<<grid, block, 0, dev_ctx.stream()>>>(
input_data, output_data, input_stride, input_dims, numel);
break;
case 4:
ContiguousFunc<T, 4><<<grid, block, 0, dev_ctx.stream()>>>(
input_data, output_data, input_stride, input_dims, numel);
break;
case 5:
ContiguousFunc<T, 5><<<grid, block, 0, dev_ctx.stream()>>>(
input_data, output_data, input_stride, input_dims, numel);
break;
case 6:
ContiguousFunc<T, 6><<<grid, block, 0, dev_ctx.stream()>>>(
input_data, output_data, input_stride, input_dims, numel);
break;
case 7:
ContiguousFunc<T, 7><<<grid, block, 0, dev_ctx.stream()>>>(
input_data, output_data, input_stride, input_dims, numel);
break;
case 8:
ContiguousFunc<T, 8><<<grid, block, 0, dev_ctx.stream()>>>(
input_data, output_data, input_stride, input_dims, numel);
break;
case 9:
ContiguousFunc<T, 9><<<grid, block, 0, dev_ctx.stream()>>>(
input_data, output_data, input_stride, input_dims, numel);
break;
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"The rank of input should be no more than 9, but received %d.", rank));
}
}
} // namespace phi
PD_REGISTER_KERNEL(contiguous,
GPU,
ALL_LAYOUT,
phi::ContiguousKernel,
bool,
uint8_t,
int8_t,
int16_t,
int32_t,
int64_t,
float,
double,
::phi::dtype::float16,
::phi::dtype::bfloat16,
::phi::dtype::complex<float>,
::phi::dtype::complex<double>) {}
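An illustrative trace of the is_only_transposed fast path above (editor's note): a [3, 4] tensor transposed into a [4, 3] view carries strides [1, 4]; scanning the strides in descending order recovers src_shape = [3, 4], src_stride = [4, 1] and axis = [1, 0], so the GPU kernel can delegate to TransposeKernel instead of running the generic gather.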
......@@ -22,10 +22,15 @@ PD_REGISTER_KERNEL(fill,
GPU,
ALL_LAYOUT,
phi::FillKernel,
bool,
uint8_t,
int8_t,
int16_t,
int32_t,
int64_t,
float,
double,
int64_t,
int,
phi::dtype::float16,
phi::dtype::bfloat16,
bool) {}
::phi::dtype::float16,
::phi::dtype::bfloat16,
::phi::dtype::complex<float>,
::phi::dtype::complex<double>) {}
(This file's diff has been collapsed.)
......@@ -56,6 +56,9 @@ PD_REGISTER_KERNEL(transpose,
ALL_LAYOUT,
phi::TransposeKernel,
bool,
uint8_t,
int8_t,
int16_t,
float,
double,
int32_t,
......
......@@ -26,4 +26,12 @@ void IndexSelectGradKernel(const Context& ctx,
int dim,
DenseTensor* x_grad);
template <typename Context>
void IndexSelectGradStridedKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
int64_t index,
int dim,
DenseTensor* x_grad);
} // namespace phi
......@@ -25,4 +25,11 @@ void IndexSelectKernel(const Context& ctx,
int dim,
DenseTensor* output);
template <typename Context>
void IndexSelectStridedKernel(const Context& ctx,
const DenseTensor& x,
int64_t index,
int dim,
DenseTensor* output);
} // namespace phi
......@@ -29,4 +29,15 @@ void ReshapeDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x_grad_grad,
DenseTensor* out_grad_grad);
template <typename Context>
void ReshapeGradStridedKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
DenseTensor* x_grad);
template <typename Context>
void ReshapeDoubleGradStridedKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const DenseTensor& x_grad_grad,
DenseTensor* out_grad_grad);
} // namespace phi
......@@ -34,6 +34,13 @@ void ReshapeKernel(const Context& dev_ctx,
DenseTensor* out,
DenseTensor* xshape);
template <typename Context>
void ReshapeStridedKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& shape,
DenseTensor* out,
DenseTensor* xshape);
template <typename T, typename Context>
DenseTensor Reshape(const Context& dev_ctx,
const DenseTensor& x,
......
......@@ -31,6 +31,17 @@ void SliceGradKernel(const Context& ctx,
const std::vector<int64_t>& decrease_axis,
DenseTensor* input_grad);
template <typename Context>
void SliceGradStridedKernel(const Context& ctx,
const DenseTensor& input,
const DenseTensor& out_grad,
const std::vector<int64_t>& axes,
const IntArray& starts,
const IntArray& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis,
DenseTensor* input_grad);
template <typename T, typename Context>
void SliceArrayGradKernel(const Context& dev_ctx,
const TensorArray& input,
......
......@@ -44,6 +44,16 @@ void SliceArrayDenseKernel(const Context& dev_ctx,
const IntArray& starts,
DenseTensor* out);
template <typename Context>
void SliceStridedKernel(const Context& ctx,
const DenseTensor& input,
const std::vector<int64_t>& axes,
const IntArray& starts,
const IntArray& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis,
DenseTensor* out);
template <typename T, typename Context>
DenseTensor Slice(const Context& ctx,
const DenseTensor& input,
......
......@@ -35,6 +35,20 @@ void SplitWithNumKernel(const Context& dev_ctx,
const Scalar& axis,
std::vector<DenseTensor*> out);
template <typename Context>
void SplitStridedKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& sections,
const Scalar& axis,
std::vector<DenseTensor*> out);
template <typename Context>
void SplitWithNumStridedKernel(const Context& dev_ctx,
const DenseTensor& x,
int num,
const Scalar& axis,
std::vector<DenseTensor*> out);
template <typename T, typename Context>
std::vector<DenseTensor> Split(const Context& dev_ctx,
const DenseTensor& x,
......
......@@ -26,4 +26,11 @@ void SqueezeGradKernel(const Context& dev_ctx,
const DenseTensor& dout,
const IntArray& axes,
DenseTensor* dx);
template <typename Context>
void SqueezeGradStridedKernel(const Context& dev_ctx,
const DenseTensor& xshape,
const DenseTensor& dout,
const IntArray& axes,
DenseTensor* dx);
} // namespace phi
......@@ -34,6 +34,19 @@ void SqueezeKernel(const Context& dev_ctx,
DenseTensor* out,
DenseTensor* xshape);
template <typename Context>
void SqueezeInferStridedKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out);
template <typename Context>
void SqueezeStridedKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out,
DenseTensor* xshape);
template <typename T, typename Context>
void Squeeze(const Context& dev_ctx,
const DenseTensor& x,
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/as_complex_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void AsComplexStridedKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
out->set_strides(DenseTensorMeta::calc_strides(out->dims()));
if (x.dtype() == DataType::FLOAT32) {
out->set_type(DataType::COMPLEX64);
} else if (x.dtype() == DataType::FLOAT64) {
out->set_type(DataType::COMPLEX128);
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"as_complex does not support data type (%s).",
DataTypeToString(x.dtype())));
}
out->set_offset(x.offset());
out->ResetHolder(x.Holder());
}
} // namespace phi
PD_REGISTER_KERNEL(
as_complex, CPU, STRIDED, phi::AsComplexStridedKernel, float, double) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(
as_complex, GPU, STRIDED, phi::AsComplexStridedKernel, float, double) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
#endif
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/as_real_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void AsRealStridedKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
out->set_strides(DenseTensorMeta::calc_strides(out->dims()));
if (x.dtype() == DataType::COMPLEX64) {
out->set_type(DataType::FLOAT32);
} else if (x.dtype() == DataType::COMPLEX128) {
out->set_type(DataType::FLOAT64);
} else {
PADDLE_THROW(
phi::errors::Unimplemented("as_real does not support data type (%s).",
DataTypeToString(x.dtype())));
}
out->set_offset(x.offset());
out->ResetHolder(x.Holder());
}
} // namespace phi
PD_REGISTER_KERNEL(as_real,
CPU,
STRIDED,
phi::AsRealStridedKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(as_real,
GPU,
STRIDED,
phi::AsRealStridedKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
#endif
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/as_strided_grad_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/as_strided_kernel.h"
#include "paddle/phi/kernels/fill_kernel.h"
#include "paddle/phi/kernels/strided_copy_kernel.h"
namespace phi {
template <typename Context>
void AsStridedGradKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& out_grad,
const std::vector<int64_t>& dims,
const std::vector<int64_t>& stride,
int64_t offset,
DenseTensor* input_grad) {
dev_ctx.Alloc(input_grad, input_grad->dtype());
input_grad->set_strides(DenseTensorMeta::calc_strides(input_grad->dims()));
PD_VISIT_ALL_TYPES(input_grad->dtype(), "AsStridedGradKernel", ([&] {
phi::FillKernel<data_t, Context>(
dev_ctx, *input_grad, 0, input_grad);
}));
DenseTensor tmp;
tmp.set_meta(out_grad.meta());
AsStridedKernel<Context>(dev_ctx, *input_grad, dims, stride, offset, &tmp);
PD_VISIT_ALL_TYPES(out_grad.dtype(), "AsStridedGradKernel", ([&] {
phi::StridedCopyKernel<data_t, Context>(
dev_ctx,
out_grad,
phi::vectorize<int64_t>(tmp.dims()),
phi::vectorize<int64_t>(tmp.strides()),
tmp.offset(),
&tmp);
}));
}
} // namespace phi
PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM(
as_strided_grad, STRIDED, phi::AsStridedGradKernel) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/as_strided_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename Context>
void AsStridedKernel(const Context& dev_ctx,
const DenseTensor& input,
const std::vector<int64_t>& dims,
const std::vector<int64_t>& stride,
int64_t offset,
DenseTensor* out) {
out->Resize(DDim(dims.data(), dims.size()));
out->set_strides(DDim(stride.data(), stride.size()));
out->set_offset(offset);
out->ResetHolder(input.Holder());
}
} // namespace phi
PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM(as_strided,
STRIDED,
phi::AsStridedKernel) {}
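A usage sketch for the view kernel above (editor's addition; it assumes a contiguous 1-D tensor with at least 6 elements, and the wrapper name is hypothetical): dims {4, 3} with stride {1, 1} and offset 0 exposes four overlapping length-3 windows over the same storage, copying no data.
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/kernels/as_strided_kernel.h"
// `windows` ends up sharing x's holder; only dims, strides and offset are set.
void MakeOverlappingWindows(const phi::CPUContext& dev_ctx,
                            const phi::DenseTensor& x,
                            phi::DenseTensor* windows) {
  phi::AsStridedKernel<phi::CPUContext>(
      dev_ctx, x, /*dims=*/{4, 3}, /*stride=*/{1, 1}, /*offset=*/0, windows);
}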
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/complex_grad_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/fill_kernel.h"
#include "paddle/phi/kernels/strided_copy_kernel.h"
namespace phi {
template <typename T, typename Context>
void RealGradStridedKernel(const Context& dev_ctx,
const DenseTensor& dout,
DenseTensor* dx) {
dev_ctx.Alloc(dx, dx->dtype());
dx->set_strides(DenseTensorMeta::calc_strides(dx->dims()));
PD_VISIT_ALL_TYPES(dx->dtype(), "RealGradStridedKernel", ([&] {
phi::FillKernel<data_t, Context>(dev_ctx, *dx, 0, dx);
}));
DenseTensor tmp;
tmp.set_meta(dout.meta());
RealStridedKernel<T, Context>(dev_ctx, *dx, &tmp);
PD_VISIT_ALL_TYPES(dout.dtype(), "RealGradStridedKernel", ([&] {
phi::StridedCopyKernel<data_t, Context>(
dev_ctx,
dout,
phi::vectorize<int64_t>(tmp.dims()),
phi::vectorize<int64_t>(tmp.strides()),
tmp.offset(),
&tmp);
}));
}
template <typename T, typename Context>
void ImagGradStridedKernel(const Context& dev_ctx,
const DenseTensor& dout,
DenseTensor* dx) {
dev_ctx.Alloc(dx, dx->dtype());
dx->set_strides(DenseTensorMeta::calc_strides(dx->dims()));
PD_VISIT_ALL_TYPES(dx->dtype(), "ImagGradStridedKernel", ([&] {
phi::FillKernel<data_t, Context>(dev_ctx, *dx, 0, dx);
}));
DenseTensor tmp;
tmp.set_meta(dout.meta());
ImagStridedKernel<T, Context>(dev_ctx, *dx, &tmp);
PD_VISIT_ALL_TYPES(dout.dtype(), "ImagGradStridedKernel", ([&] {
phi::StridedCopyKernel<data_t, Context>(
dev_ctx,
dout,
phi::vectorize<int64_t>(tmp.dims()),
phi::vectorize<int64_t>(tmp.strides()),
tmp.offset(),
&tmp);
}));
}
} // namespace phi
PD_REGISTER_KERNEL(real_grad,
CPU,
STRIDED,
phi::RealGradStridedKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->InputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
PD_REGISTER_KERNEL(imag_grad,
CPU,
STRIDED,
phi::ImagGradStridedKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->InputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(real_grad,
GPU,
STRIDED,
phi::RealGradStridedKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->InputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
PD_REGISTER_KERNEL(imag_grad,
GPU,
STRIDED,
phi::ImagGradStridedKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->InputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
#endif
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void RealStridedKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
if (x.dtype() != DataType::COMPLEX64 && x.dtype() != DataType::COMPLEX128) {
PADDLE_THROW(
phi::errors::NotFound("paddle.real only supports COMPLEX64 and "
"COMPLEX128, but the input dtype is %s",
x.dtype()));
}
DDim stride = x.strides();
for (int i = 0; i < stride.size(); i++) {
stride[i] = x.strides()[i] * 2;
}
out->set_offset(x.offset());
out->set_strides(stride);
out->ResetHolder(x.Holder());
}
template <typename T, typename Context>
void ImagStridedKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
if (x.dtype() != DataType::COMPLEX64 && x.dtype() != DataType::COMPLEX128) {
PADDLE_THROW(
phi::errors::NotFound("paddle.imag only supports COMPLEX64 and "
"COMPLEX128, but the input dtype is %s",
x.dtype()));
}
DDim stride = x.strides();
for (int i = 0; i < stride.size(); i++) {
stride[i] = x.strides()[i] * 2;
}
out->set_strides(stride);
out->set_offset(x.offset() + phi::SizeOf(out->dtype()));
out->ResetHolder(x.Holder());
}
} // namespace phi
PD_REGISTER_KERNEL(real,
CPU,
STRIDED,
phi::RealStridedKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
PD_REGISTER_KERNEL(imag,
CPU,
STRIDED,
phi::ImagStridedKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(real,
GPU,
STRIDED,
phi::RealStridedKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
PD_REGISTER_KERNEL(imag,
GPU,
STRIDED,
phi::ImagStridedKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
#endif
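A quick worked example of the stride arithmetic above (editor's note): a complex64 tensor with dims [2, 3] and strides [3, 1] is reinterpreted as float32 storage, so real() keeps the byte offset and doubles every stride to [6, 2], while imag() uses the same doubled strides but shifts the byte offset by sizeof(float) so it starts on the imaginary component.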
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/diagonal_grad_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/diagonal_kernel.h"
#include "paddle/phi/kernels/fill_kernel.h"
#include "paddle/phi/kernels/strided_copy_kernel.h"
namespace phi {
template <typename Context>
void DiagonalGradStridedKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
int offset,
int axis1,
int axis2,
DenseTensor* in_grad) {
dev_ctx.Alloc(in_grad, in_grad->dtype());
in_grad->set_strides(DenseTensorMeta::calc_strides(in_grad->dims()));
PD_VISIT_ALL_TYPES(in_grad->dtype(), "DiagonalGradStridedKernel", ([&] {
phi::FillKernel<data_t, Context>(
dev_ctx, *in_grad, 0, in_grad);
}));
DenseTensor tmp;
DiagonalStridedKernel<Context>(dev_ctx, *in_grad, offset, axis1, axis2, &tmp);
PD_VISIT_ALL_TYPES(out_grad.dtype(), "DiagonalGradStridedKernel", ([&] {
phi::StridedCopyKernel<data_t, Context>(
dev_ctx,
out_grad,
phi::vectorize<int64_t>(tmp.dims()),
phi::vectorize<int64_t>(tmp.strides()),
tmp.offset(),
&tmp);
}));
}
} // namespace phi
PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM(
diagonal_grad, STRIDED, phi::DiagonalGradStridedKernel) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/diagonal_kernel.h"
#include "glog/logging.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename Context>
void DiagonalStridedKernel(const Context& dev_ctx,
const DenseTensor& x,
int offset,
int axis1,
int axis2,
DenseTensor* out) {
size_t x_rank = x.dims().size();
if (axis1 < 0) {
axis1 += x_rank;
}
if (axis2 < 0) {
axis2 += x_rank;
}
int64_t diag_size;
int64_t x_offset = x.offset();
if (offset >= 0) {
diag_size = std::max<int64_t>(
std::min(x.dims()[axis1], x.dims()[axis2] - offset), 0);
if (diag_size != 0) {
x_offset += offset * x.strides()[axis2] * SizeOf(x.dtype());
}
} else {
diag_size = std::max<int64_t>(
std::min(x.dims()[axis1] + offset, x.dims()[axis2]), 0);
if (diag_size != 0) {
x_offset -= offset * x.strides()[axis1] * SizeOf(x.dtype());
}
}
std::vector<int64_t> shape = phi::vectorize<int64_t>(x.dims());
std::vector<int64_t> stride = phi::vectorize<int64_t>(x.strides());
shape.erase(shape.begin() + std::max(axis1, axis2));
stride.erase(stride.begin() + std::max(axis1, axis2));
shape.erase(shape.begin() + std::min(axis1, axis2));
stride.erase(stride.begin() + std::min(axis1, axis2));
shape.push_back(diag_size);
stride.push_back(x.strides()[axis1] + x.strides()[axis2]);
auto meta = out->meta();
auto tmp_dim = DDim(shape.data(), shape.size());
// if (product(meta.dims) > 0 && meta.dims != tmp_dim) {
// PADDLE_THROW(
// phi::errors::Fatal("Diagonal kernel stride compute diff, infer shape
// "
// "is %s, but compute is %s.",
// meta.dims,
// tmp_dim));
// }
meta.dims = tmp_dim;
meta.strides = DDim(stride.data(), stride.size());
meta.offset = x_offset;
out->set_meta(meta);
out->ResetHolder(x.Holder());
}
} // namespace phi
PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM(
diagonal, STRIDED, phi::DiagonalStridedKernel) {}
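A worked example for the diagonal view above (editor's note, offset = 0): for a [4, 4] tensor with row-major strides [4, 1] and axis1 = 0, axis2 = 1, both axes are erased and a trailing diagonal axis of size 4 with stride 4 + 1 = 5 is appended, so the view reads elements 0, 5, 10 and 15 of the underlying storage without copying.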
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/flatten_grad_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/reshape_kernel.h"
namespace phi {
template <typename Context>
void FlattenGradStridedKernel(const Context& dev_ctx,
const DenseTensor& xshape,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
auto xshape_dims = xshape.dims();
auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
ReshapeStridedKernel<Context>(dev_ctx,
out_grad,
IntArray(phi::vectorize<int64_t>(x_dims)),
x_grad,
nullptr);
}
} // namespace phi
PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM(
flatten_grad, STRIDED, phi::FlattenGradStridedKernel) {}
(The remaining file diffs in this commit have been collapsed.)