未验证 提交 895692e3 编写于 作者: Y YuanRisheng 提交者: GitHub

[PTen]Reshape Kernel Refactor (#37164)

* reshape kernel refactor

* fix compile bugs when run ci

* support xpu for reshape

* fix bugs when run unittest in kunlun ci

* fix compile bugs when run kunlun

* perfect code according to suggestion
上级 228eb898
...@@ -1883,6 +1883,10 @@ void OperatorWithKernel::BuildPtenKernelContext( ...@@ -1883,6 +1883,10 @@ void OperatorWithKernel::BuildPtenKernelContext(
pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(float, attr)); pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) { } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int>))) {
pt_kernel_context_->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int>, attr));
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"unsupported cast op attribute `%s` when construct " "unsupported cast op attribute `%s` when construct "
......
...@@ -372,6 +372,9 @@ static void BuildDygraphPtenKernelContext( ...@@ -372,6 +372,9 @@ static void BuildDygraphPtenKernelContext(
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(float, attr)); kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) { } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int>))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector<int>, attr));
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"unsupported cast op attribute `%s` when construct " "unsupported cast op attribute `%s` when construct "
......
...@@ -15,7 +15,12 @@ limitations under the License. */ ...@@ -15,7 +15,12 @@ limitations under the License. */
#include <string> #include <string>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/pten_utils.h"
// only can include the headers in paddle/pten/api dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/manipulation.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class InferShapeContext; class InferShapeContext;
...@@ -248,13 +253,6 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -248,13 +253,6 @@ class ReshapeOp : public framework::OperatorWithKernel {
auto input_data_type = auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
//#ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
//#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
...@@ -366,13 +364,6 @@ class ReshapeGradOp : public framework::OperatorWithKernel { ...@@ -366,13 +364,6 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
auto input_data_type = auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
//#ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
//#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
...@@ -382,42 +373,117 @@ class ReshapeKernel { ...@@ -382,42 +373,117 @@ class ReshapeKernel {
void operator()(const framework::ExecutionContext &ctx) const { void operator()(const framework::ExecutionContext &ctx) const {
auto *out = ctx.Output<framework::LoDTensor>("Out"); auto *out = ctx.Output<framework::LoDTensor>("Out");
auto *in = ctx.Input<framework::LoDTensor>("X"); auto *in = ctx.Input<framework::LoDTensor>("X");
// framework::DDim out_dims = out->dims();
framework::DDim out_dims = out->dims(); auto pt_x = paddle::experimental::MakePtenDenseTensor(*in);
// we can't MakePtenDenseTensor by out, because reshape will realloc memory
// and this will throw error(can't realloc shared memory) in current
// DenseTensor
// design. So, codes below create a tmp densetensor for output.
// TODO(YuanRisheng) we can use MakePtenDenseTensor after #36916 merge.
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
pten::DenseTensorMeta meta{pten::TransToPtenDataType(in->type()),
in->dims(),
pten::TransToPtenDataLayout(in->layout())};
auto pt_out_tmp =
std::make_shared<pten::DenseTensor>(alloc, std::move(meta));
pten::DenseTensor *pt_out = nullptr;
if (in == out) {
pt_out = pt_x.get();
} else {
pt_out = pt_out_tmp.get();
}
auto list_new_shape_tensor = auto list_new_shape_tensor =
ctx.MultiInput<framework::Tensor>("ShapeTensor"); ctx.MultiInput<framework::Tensor>("ShapeTensor");
if (list_new_shape_tensor.size() > 0) {
// have shape tensor
auto new_shape = get_new_shape(list_new_shape_tensor);
out_dims = ReshapeOp::ValidateShape(new_shape, in->dims());
} else {
auto *shape_tensor = ctx.HasInput("Shape") auto *shape_tensor = ctx.HasInput("Shape")
? ctx.Input<framework::LoDTensor>("Shape") ? ctx.Input<framework::LoDTensor>("Shape")
: nullptr; : nullptr;
if (list_new_shape_tensor.size() > 0) {
if (shape_tensor) { // have shape tensor
auto *shape_data = shape_tensor->data<int>(); std::vector<pten::DenseTensor> pt_vec_shape;
framework::Tensor cpu_shape_tensor; for (auto &tensor : list_new_shape_tensor) {
if (platform::is_gpu_place(tensor->place()) ||
platform::is_xpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
pt_vec_shape.push_back(
std::move(*(paddle::experimental::MakePtenDenseTensor(temp))));
} else {
pt_vec_shape.push_back(
std::move(*(paddle::experimental::MakePtenDenseTensor(*tensor))));
}
}
if (platform::is_cpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out);
}
#endif
#ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out);
}
#endif
} else if (shape_tensor) {
std::unique_ptr<pten::DenseTensor> pt_shape;
if (platform::is_gpu_place(shape_tensor->place()) || if (platform::is_gpu_place(shape_tensor->place()) ||
platform::is_xpu_place(shape_tensor->place())) { platform::is_xpu_place(shape_tensor->place())) {
TensorCopySync(*shape_tensor, platform::CPUPlace(), framework::Tensor temp;
&cpu_shape_tensor); TensorCopySync(*shape_tensor, platform::CPUPlace(), &temp);
shape_data = cpu_shape_tensor.data<int>(); pt_shape = paddle::experimental::MakePtenDenseTensor(temp);
} else {
pt_shape = paddle::experimental::MakePtenDenseTensor(*shape_tensor);
} }
auto shape =
std::vector<int>(shape_data, shape_data + shape_tensor->numel()); if (platform::is_cpu_place(ctx.GetPlace())) {
out_dims = ReshapeOp::ValidateShape(shape, in->dims()); auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out);
} }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out);
}
#endif
#ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out);
}
#endif
} else {
auto &shape_vec = ctx.Attr<std::vector<int>>("shape");
if (platform::is_cpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out);
}
#endif
#ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out);
}
#endif
}
// non-inplace need move all result from pt_out to out, inplace need set
// result dims.
if (in != out) {
paddle::experimental::MovesStorage(pt_out, static_cast<Tensor *>(out));
} else {
out->Resize(pt_out->dims());
} }
out->Resize(out_dims);
out->mutable_data(ctx.GetPlace(), in->type());
framework::TensorCopy(
*in, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), out);
out->Resize(out_dims);
} }
}; };
...@@ -479,6 +545,21 @@ class Reshape2Op : public ReshapeOp { ...@@ -479,6 +545,21 @@ class Reshape2Op : public ReshapeOp {
ReshapeOp::InferShape(ctx); ReshapeOp::InferShape(ctx);
} }
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
auto multi_inputs = ctx.MultiInput<framework::Tensor>("ShapeTensor");
if (multi_inputs.size() > 0) {
return framework::KernelSignature(
"reshape2.mulhost.mid", {"X", "ShapeTensor"}, {}, {"XShape", "Out"});
} else if (ctx.HasInput("Shape")) {
return framework::KernelSignature("reshape2.host.mid", {"X", "Shape"}, {},
{"XShape", "Out"});
} else {
return framework::KernelSignature("reshape2.mid", {"X"}, {"shape"},
{"XShape", "Out"});
}
}
}; };
class Reshape2OpMaker : public ReshapeOpMaker { class Reshape2OpMaker : public ReshapeOpMaker {
...@@ -557,13 +638,6 @@ class Reshape2GradOp : public framework::OperatorWithKernel { ...@@ -557,13 +638,6 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")); ctx, framework::GradVarName("Out"));
//#ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
//#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
......
...@@ -114,23 +114,6 @@ struct KernelRegistrar { ...@@ -114,23 +114,6 @@ struct KernelRegistrar {
KernelArgsParseFn args_parse_fn, KernelArgsParseFn args_parse_fn,
KernelArgsDefFn args_def_fn, KernelArgsDefFn args_def_fn,
KernelFn kernel_fn) { KernelFn kernel_fn) {
if (layout == DataLayout::ANY) {
for (size_t layout_iter = static_cast<size_t>(DataLayout::NHWC);
layout_iter != static_cast<size_t>(DataLayout::NUM_DATA_LAYOUTS);
layout_iter++) {
for (size_t dtype = static_cast<size_t>(DataType::BOOL);
dtype != static_cast<size_t>(DataType::NUM_DATA_TYPES);
dtype++) {
ConstructKernel(kernel_name_cstr,
backend,
static_cast<DataLayout>(layout_iter),
static_cast<DataType>(dtype),
args_parse_fn,
args_def_fn,
kernel_fn);
}
}
} else {
for (size_t dtype = static_cast<size_t>(DataType::BOOL); for (size_t dtype = static_cast<size_t>(DataType::BOOL);
dtype != static_cast<size_t>(DataType::NUM_DATA_TYPES); dtype != static_cast<size_t>(DataType::NUM_DATA_TYPES);
dtype++) { dtype++) {
...@@ -143,7 +126,6 @@ struct KernelRegistrar { ...@@ -143,7 +126,6 @@ struct KernelRegistrar {
kernel_fn); kernel_fn);
} }
} }
}
private: private:
void ConstructKernel(const char* kernel_name_cstr, void ConstructKernel(const char* kernel_name_cstr,
...@@ -158,7 +140,6 @@ struct KernelRegistrar { ...@@ -158,7 +140,6 @@ struct KernelRegistrar {
Kernel kernel(kernel_fn); Kernel kernel(kernel_fn);
args_parse_fn(kernel_key, kernel.mutable_args_def()); args_parse_fn(kernel_key, kernel.mutable_args_def());
args_def_fn(&kernel); args_def_fn(&kernel);
KernelFactory::Instance().InsertCompatibleOpType(kernel_name.name()); KernelFactory::Instance().InsertCompatibleOpType(kernel_name.name());
KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel; KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel;
} }
...@@ -846,7 +827,8 @@ struct KernelRegistrar { ...@@ -846,7 +827,8 @@ struct KernelRegistrar {
decltype(meta_kernel_fn) meta_kernel_fn; \ decltype(meta_kernel_fn) meta_kernel_fn; \
static void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \ static void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \
func_id)(::pten::Kernel*); \ func_id)(::pten::Kernel*); \
static const ::pten::KernelRegistrar __reg_pt_op_kernel_##func_id( \ static const ::pten::KernelRegistrar PT_CONCATENATE(__reg_pt_op_kernel_, \
func_id)( \
kernel_name, \ kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
......
...@@ -208,6 +208,7 @@ struct KernelImpl<Return (*)(Args...), kernel_fn> { ...@@ -208,6 +208,7 @@ struct KernelImpl<Return (*)(Args...), kernel_fn> {
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(paddle::platform::float16); PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(paddle::platform::float16);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const Scalar&); PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const Scalar&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int64_t>&); PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int64_t>&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int>&);
/* Output Helpers */ /* Output Helpers */
......
...@@ -37,4 +37,17 @@ DenseTensor Flatten(const ContextT& dev_ctx, ...@@ -37,4 +37,17 @@ DenseTensor Flatten(const ContextT& dev_ctx,
return dense_out; return dense_out;
} }
template <typename T, typename ContextT>
DenseTensor Reshape(const ContextT& dev_ctx,
const DenseTensor& x,
const std::vector<int>& shape) {
auto out_meta = InferShapeFromVecValue(x.meta(), shape);
const auto allocator =
std::make_shared<paddle::experimental::DefaultAllocator>(
dev_ctx.GetPlace());
pten::DenseTensor dense_out(allocator, out_meta);
ReshapeFromVectorVal(dev_ctx, x, shape, &dense_out);
return dense_out;
}
} // namespace pten } // namespace pten
...@@ -82,4 +82,142 @@ DenseTensorMeta FullLikeInferShape(const DenseTensorMeta& x_meta, ...@@ -82,4 +82,142 @@ DenseTensorMeta FullLikeInferShape(const DenseTensorMeta& x_meta,
layout == DataLayout::UNDEFINED ? x_meta.layout : layout}; layout == DataLayout::UNDEFINED ? x_meta.layout : layout};
} }
static paddle::framework::DDim ValidateShape(
const std::vector<int> shape, const paddle::framework::DDim& in_dims) {
const int64_t in_size = paddle::framework::product(in_dims);
auto in_dims_vec = paddle::framework::vectorize(in_dims);
bool all_positive = std::all_of(in_dims_vec.cbegin(),
in_dims_vec.cend(),
[](int64_t i) { return i > 0; });
// only one dimension can be set to -1, whose size will be automatically
// infered.
const int64_t unk_dim_val = -1;
const int64_t copy_dim_val = 0;
std::vector<int64_t> output_shape(shape.size(), 0);
int64_t capacity = 1;
int unk_dim_idx = -1;
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] == unk_dim_val) {
PADDLE_ENFORCE_EQ(
unk_dim_idx,
-1,
paddle::platform::errors::InvalidArgument(
"Only one dimension value of 'shape' in ReshapeOp can "
"be -1. But received shape = [%s], shape[%d] is also -1.",
paddle::framework::make_ddim(shape),
i));
unk_dim_idx = i;
} else if (shape[i] == copy_dim_val) {
PADDLE_ENFORCE_LT(
static_cast<int>(i),
in_dims.size(),
paddle::platform::errors::InvalidArgument(
"The index of 0 in `shape` must be less than "
"the input tensor X's dimensions. "
"But received shape = [%s], shape[%d] = 0, X's shape = [%s], "
"X's dimensions = %d.",
paddle::framework::make_ddim(shape),
i,
in_dims,
in_dims.size()));
} else {
PADDLE_ENFORCE_GT(
shape[i],
0,
paddle::platform::errors::InvalidArgument(
"Each dimension value of 'shape' in ReshapeOp must not "
"be negative except one unknown dimension. "
"But received shape = [%s], shape[%d] = %d.",
paddle::framework::make_ddim(shape),
i,
shape[i]));
}
// NOTE all non-zero values will be converted to True (include negative
// value)
capacity *= (shape[i] ? shape[i] : in_dims[i]);
output_shape[i] = (shape[i] ? static_cast<int64_t>(shape[i]) : in_dims[i]);
}
if (unk_dim_idx != -1) {
if (all_positive) {
// in_size < 0 and is un-determinate in compile time, skip the check,
// for example, in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8],
// capacity = -24, in_size = -8, output_shape[0] = 0
// the following check will fail.
output_shape[unk_dim_idx] = -in_size / capacity;
PADDLE_ENFORCE_EQ(
output_shape[unk_dim_idx] * capacity,
-in_size,
paddle::platform::errors::InvalidArgument(
"The 'shape' attribute in ReshapeOp is invalid. "
"The input tensor X'size must be divisible by known "
"capacity of 'shape'. "
"But received X's shape = [%s], X's size = %d, "
"'shape' is [%s], known capacity of 'shape' is %d.",
in_dims,
in_size,
paddle::framework::make_ddim(shape),
capacity));
} else {
output_shape[unk_dim_idx] = -1;
}
} else {
if (all_positive) {
PADDLE_ENFORCE_EQ(
capacity,
in_size,
paddle::platform::errors::InvalidArgument(
"The 'shape' in ReshapeOp is invalid. "
"The input tensor X'size must be equal to the capacity of "
"'shape'. "
"But received X's shape = [%s], X's size = %d, 'shape' is "
"[%s], the capacity of 'shape' is %d.",
in_dims,
in_size,
paddle::framework::make_ddim(shape),
capacity));
}
}
// support reshape with zero-input(input tensor with product(shape) == 0)
// by now we require that if the input tensor is zero shape, the target
// shape of output must be zero
if (in_size == 0) {
PADDLE_ENFORCE_LE(
capacity,
in_size,
paddle::platform::errors::InvalidArgument(
"The 'shape' in ReshapeOp is invalid. "
"The input tensor X's shape = [%s], X's capacity = %d."
"But the target shape of Out is [%s], the "
"capacity of 'Out' is %d.",
in_dims,
in_size,
paddle::framework::make_ddim(shape),
capacity));
}
return paddle::framework::make_ddim(output_shape);
}
DenseTensorMeta InferShapeFromVecValue(const DenseTensorMeta& x_meta,
const std::vector<int>& shape) {
PADDLE_ENFORCE_EQ(!shape.empty(),
true,
paddle::platform::errors::InvalidArgument(
"The parameter 'shape' in ReshapeOp must be set. "
"But received 'shape' is empty."));
auto x_dims = x_meta.dims;
auto out_dims = ValidateShape(shape, x_dims);
DenseTensorMeta return_meta(x_meta.type, out_dims, x_meta.layout);
if (x_dims[0] == return_meta.dims[0]) {
// Only pass LoD when the first dimension of output and Input(X)
// are the same.
return_meta.lod = x_meta.lod;
}
return return_meta;
}
} // namespace pten } // namespace pten
...@@ -45,4 +45,6 @@ DenseTensorMeta FullLikeInferShape(const DenseTensorMeta& x_meta, ...@@ -45,4 +45,6 @@ DenseTensorMeta FullLikeInferShape(const DenseTensorMeta& x_meta,
DataType dtype, DataType dtype,
DataLayout layout); DataLayout layout);
DenseTensorMeta InferShapeFromVecValue(const DenseTensorMeta& x_meta,
const std::vector<int>& shape);
} // namespace pten } // namespace pten
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/pten/kernels/cpu/manipulation.h" #include "paddle/pten/kernels/cpu/manipulation.h"
#include "paddle/pten/infermeta/unary.h" #include "paddle/pten/infermeta/unary.h"
#include "paddle/pten/kernels/cpu/utils.h" #include "paddle/pten/kernels/cpu/utils.h"
#include "paddle/pten/kernels/functions/general/manipulation.h"
namespace pten { namespace pten {
...@@ -40,14 +41,75 @@ void FlattenWithXShape(const CPUContext& dev_ctx, ...@@ -40,14 +41,75 @@ void FlattenWithXShape(const CPUContext& dev_ctx,
DenseTensor* out, DenseTensor* out,
DenseTensor* xshape) { DenseTensor* xshape) {
Flatten<T>(dev_ctx, x, start_axis, stop_axis, out); Flatten<T>(dev_ctx, x, start_axis, stop_axis, out);
const auto& in_dims = x.meta().dims; general::SetXShape(x, xshape);
std::vector<int64_t> xshape_dims(in_dims.size() + 1); }
xshape_dims[0] = 0;
for (int i = 0; i < in_dims.size(); ++i) { void ReshapeFromVectorVal(const CPUContext& dev_ctx,
xshape_dims[i + 1] = in_dims[i]; const DenseTensor& x,
const std::vector<int>& shape,
DenseTensor* out) {
auto out_meta = InferShapeFromVecValue(x.meta(), shape);
if (&x == out) {
out->Resize(out_meta.dims);
return;
}
pten::Copy(dev_ctx, x, out);
out->Resize(out_meta.dims);
}
void ReshapeFromVectorValWithXShape(const CPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<int>& shape,
DenseTensor* xshape,
DenseTensor* out) {
ReshapeFromVectorVal(dev_ctx, x, shape, out);
general::SetXShape(x, xshape);
}
void ReshapeFromDT(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& shape,
DenseTensor* out) {
auto* shape_data = shape.data<int>();
auto vector_shape = std::vector<int>(shape_data, shape_data + shape.numel());
ReshapeFromVectorVal(dev_ctx, x, vector_shape, out);
}
void ReshapeFromDTWithXShape(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& shape,
DenseTensor* xshape,
DenseTensor* out) {
ReshapeFromDT(dev_ctx, x, shape, out);
general::SetXShape(x, xshape);
}
void ReshapeFromVectorDT(const CPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<DenseTensor>& shape,
DenseTensor* out) {
std::vector<int> vector_shape;
for (auto& tensor : shape) {
PADDLE_ENFORCE_EQ(
tensor.dims(),
paddle::framework::make_ddim({1}),
paddle::platform::errors::InvalidArgument(
"If the element type of 'shape' in ReshapeOp is Tensor, "
"the element's shape must be [1]. But received the element's shape "
"is [%s]",
tensor.dims()));
vector_shape.push_back(*tensor.data<int32_t>());
} }
xshape->Resize(paddle::framework::make_ddim(xshape_dims)); ReshapeFromVectorVal(dev_ctx, x, vector_shape, out);
xshape->set_lod(x.lod()); }
void ReshapeFromVectorDTWithXShape(const CPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<DenseTensor>& shape,
DenseTensor* xshape,
DenseTensor* out) {
ReshapeFromVectorDT(dev_ctx, x, shape, out);
general::SetXShape(x, xshape);
} }
} // namespace pten } // namespace pten
...@@ -78,3 +140,15 @@ PT_REGISTER_KERNEL("flatten_contiguous_range.mid", ...@@ -78,3 +140,15 @@ PT_REGISTER_KERNEL("flatten_contiguous_range.mid",
int8_t, int8_t,
int, int,
int64_t) {} int64_t) {}
// TODO(yuanrisheng): "reshape2" is compatible with old kernel
// architecture, kernel_name should be "reshape".
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape2",
CPU,
ANY,
pten::ReshapeFromVectorVal) {}
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape2.mid",
CPU,
ANY,
pten::ReshapeFromVectorValWithXShape) {}
...@@ -15,8 +15,6 @@ limitations under the License. */ ...@@ -15,8 +15,6 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
// See Note [ Why still include the fluid headers? ] // See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
...@@ -31,4 +29,37 @@ void Flatten(const CPUContext& dev_ctx, ...@@ -31,4 +29,37 @@ void Flatten(const CPUContext& dev_ctx,
int stop_axis, int stop_axis,
DenseTensor* out); DenseTensor* out);
void ReshapeFromDT(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& shape,
DenseTensor* out);
void ReshapeFromVectorVal(const CPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<int>& shape,
DenseTensor* out);
void ReshapeFromVectorDT(const CPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<DenseTensor>& shape,
DenseTensor* out);
void ReshapeFromDTWithXShape(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& shape,
DenseTensor* xshape,
DenseTensor* out);
void ReshapeFromVectorValWithXShape(const CPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<int>& shape,
DenseTensor* xshape,
DenseTensor* out);
void ReshapeFromVectorDTWithXShape(const CPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<DenseTensor>& shape,
DenseTensor* xshape,
DenseTensor* out);
} // namespace pten } // namespace pten
...@@ -21,21 +21,23 @@ namespace pten { ...@@ -21,21 +21,23 @@ namespace pten {
void Copy(const CPUContext& dev_ctx, const DenseTensor& src, DenseTensor* dst) { void Copy(const CPUContext& dev_ctx, const DenseTensor& src, DenseTensor* dst) {
auto* src_ptr = src.data(); auto* src_ptr = src.data();
auto* dst_ptr = dst->mutable_data();
const auto& src_place = src.place(); const auto& src_place = src.place();
const auto& dst_place = dst->place(); const auto& dst_place = dst->place();
VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to "
<< dst_place;
dst->Resize(src.dims());
auto* dst_ptr = dst->mutable_data();
if (src_ptr == dst_ptr && src_place == dst_place) { if (src_ptr == dst_ptr && src_place == dst_place) {
VLOG(3) << "Skip copy the same data async from " << src_place << " to " VLOG(3) << "Skip copy the same data async from " << src_place << " to "
<< dst_place; << dst_place;
return; return;
} }
VLOG(4) << "src:" << src_ptr << ", dst:" << dst_ptr; VLOG(4) << "src:" << src_ptr << ", dst:" << dst_ptr;
VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to "
<< dst_place;
dst->Resize(src.dims());
CHECK(dst->layout() == src.layout()); CHECK(dst->layout() == src.layout());
auto size = src.numel() * paddle::framework::SizeOfType( auto size = src.numel() * paddle::framework::SizeOfType(
TransToProtoVarType(src.data_type())); TransToProtoVarType(src.data_type()));
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/pten/infermeta/unary.h" #include "paddle/pten/infermeta/unary.h"
#include "paddle/pten/kernels/cuda/manipulation.h" #include "paddle/pten/kernels/cuda/manipulation.h"
#include "paddle/pten/kernels/cuda/utils.h" #include "paddle/pten/kernels/cuda/utils.h"
#include "paddle/pten/kernels/functions/general/manipulation.h"
namespace pten { namespace pten {
...@@ -25,7 +26,7 @@ void Flatten(const CUDAContext& dev_ctx, ...@@ -25,7 +26,7 @@ void Flatten(const CUDAContext& dev_ctx,
int stop_axis, int stop_axis,
DenseTensor* out) { DenseTensor* out) {
auto out_dims = out->dims(); auto out_dims = out->dims();
pten::Copy(dev_ctx, x, out); pten::Copy(dev_ctx, x, false, out);
out->Resize(out_dims); out->Resize(out_dims);
} }
...@@ -40,14 +41,76 @@ void FlattenWithXShape(const CUDAContext& dev_ctx, ...@@ -40,14 +41,76 @@ void FlattenWithXShape(const CUDAContext& dev_ctx,
DenseTensor* out, DenseTensor* out,
DenseTensor* xshape) { DenseTensor* xshape) {
Flatten<T>(dev_ctx, x, start_axis, stop_axis, out); Flatten<T>(dev_ctx, x, start_axis, stop_axis, out);
const auto& in_dims = x.meta().dims; general::SetXShape(x, xshape);
std::vector<int64_t> xshape_dims(in_dims.size() + 1); }
xshape_dims[0] = 0;
for (int i = 0; i < in_dims.size(); ++i) { void ReshapeFromVectorVal(const CUDAContext& dev_ctx,
xshape_dims[i + 1] = in_dims[i]; const DenseTensor& x,
const std::vector<int>& shape,
DenseTensor* out) {
auto out_meta = InferShapeFromVecValue(x.meta(), shape);
if (&x == out) {
LOG(INFO) << "out_meta dims:" << out_meta.dims;
out->Resize(out_meta.dims);
return;
} }
xshape->Resize(paddle::framework::make_ddim(xshape_dims)); pten::Copy(dev_ctx, x, false, out);
xshape->set_lod(x.lod()); out->Resize(out_meta.dims);
}
void ReshapeFromVectorValWithXShape(const CUDAContext& dev_ctx,
const DenseTensor& x,
const std::vector<int>& shape,
DenseTensor* xshape,
DenseTensor* out) {
ReshapeFromVectorVal(dev_ctx, x, shape, out);
general::SetXShape(x, xshape);
}
void ReshapeFromDT(const CUDAContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& shape,
DenseTensor* out) {
auto* shape_data = shape.data<int>();
auto vector_shape = std::vector<int>(shape_data, shape_data + shape.numel());
ReshapeFromVectorVal(dev_ctx, x, vector_shape, out);
}
void ReshapeFromDTWithXShape(const CUDAContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& shape,
DenseTensor* xshape,
DenseTensor* out) {
ReshapeFromDT(dev_ctx, x, shape, out);
general::SetXShape(x, xshape);
}
void ReshapeFromVectorDT(const CUDAContext& dev_ctx,
const DenseTensor& x,
const std::vector<DenseTensor>& shape,
DenseTensor* out) {
std::vector<int> vector_shape;
for (auto& tensor : shape) {
PADDLE_ENFORCE_EQ(
tensor.dims(),
paddle::framework::make_ddim({1}),
paddle::platform::errors::InvalidArgument(
"If the element type of 'shape' in ReshapeOp is Tensor, "
"the element's shape must be [1]. But received the element's shape "
"is [%s]",
tensor.dims()));
vector_shape.push_back(*tensor.data<int32_t>());
}
ReshapeFromVectorVal(dev_ctx, x, vector_shape, out);
}
void ReshapeFromVectorDTWithXShape(const CUDAContext& dev_ctx,
const DenseTensor& x,
const std::vector<DenseTensor>& shape,
DenseTensor* xshape,
DenseTensor* out) {
ReshapeFromVectorDT(dev_ctx, x, shape, out);
general::SetXShape(x, xshape);
} }
} // namespace pten } // namespace pten
...@@ -80,3 +143,13 @@ PT_REGISTER_KERNEL("flatten_contiguous_range.mid", ...@@ -80,3 +143,13 @@ PT_REGISTER_KERNEL("flatten_contiguous_range.mid",
int8_t, int8_t,
int, int,
int64_t) {} int64_t) {}
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape2",
CUDA,
ANY,
pten::ReshapeFromVectorVal) {}
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape2.mid",
CUDA,
ANY,
pten::ReshapeFromVectorValWithXShape) {}
...@@ -33,6 +33,39 @@ void Flatten(const CUDAContext& dev_ctx, ...@@ -33,6 +33,39 @@ void Flatten(const CUDAContext& dev_ctx,
int stop_axis, int stop_axis,
DenseTensor* out); DenseTensor* out);
void ReshapeFromDT(const CUDAContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& shape,
DenseTensor* out);
void ReshapeFromVectorVal(const CUDAContext& dev_ctx,
const DenseTensor& x,
const std::vector<int>& shape,
DenseTensor* out);
void ReshapeFromVectorDT(const CUDAContext& dev_ctx,
const DenseTensor& x,
const std::vector<DenseTensor>& shape,
DenseTensor* out);
void ReshapeFromDTWithXShape(const CUDAContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& shape,
DenseTensor* xshape,
DenseTensor* out);
void ReshapeFromVectorValWithXShape(const CUDAContext& dev_ctx,
const DenseTensor& x,
const std::vector<int>& shape,
DenseTensor* xshape,
DenseTensor* out);
void ReshapeFromVectorDTWithXShape(const CUDAContext& dev_ctx,
const DenseTensor& x,
const std::vector<DenseTensor>& shape,
DenseTensor* xshape,
DenseTensor* out);
} // namespace pten } // namespace pten
#endif #endif
...@@ -22,23 +22,32 @@ namespace pten { ...@@ -22,23 +22,32 @@ namespace pten {
void Copy(const CUDAContext& dev_ctx, void Copy(const CUDAContext& dev_ctx,
const DenseTensor& src, const DenseTensor& src,
bool is_sync,
DenseTensor* dst) { DenseTensor* dst) {
auto* src_ptr = src.data(); auto* src_ptr = src.data();
auto* dst_ptr = dst->mutable_data();
const auto& src_place = src.place(); const auto& src_place = src.place();
const auto& dst_place = dst->place(); const auto& dst_place = dst->place();
if (src_place == dst_place && paddle::platform::is_cpu_place(src_place)) {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"The src and dst tensor are all CPU tensor, you should call copy "
"function in CPU mode."));
}
VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to "
<< dst_place;
dst->Resize(src.dims());
auto* dst_ptr = dst->mutable_data();
if (src_ptr == dst_ptr && src_place == dst_place) { if (src_ptr == dst_ptr && src_place == dst_place) {
VLOG(3) << "Skip copy the same data async from " << src_place << " to " VLOG(3) << "Skip copy the same data async from " << src_place << " to "
<< dst_place; << dst_place;
return; return;
} }
VLOG(4) << "src:" << src_ptr << ", dst:" << dst_ptr; VLOG(4) << "src:" << src_ptr << ", dst:" << dst_ptr;
VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to "
<< dst_place;
dst->Resize(src.dims());
CHECK(dst->layout() == src.layout()); CHECK(dst->layout() == src.layout());
auto size = src.numel() * paddle::framework::SizeOfType( auto size = src.numel() * paddle::framework::SizeOfType(
TransToProtoVarType(src.data_type())); TransToProtoVarType(src.data_type()));
...@@ -88,7 +97,9 @@ void Copy(const CUDAContext& dev_ctx, ...@@ -88,7 +97,9 @@ void Copy(const CUDAContext& dev_ctx,
src_gpu_place, src_gpu_place,
ctx_gpu_place)); ctx_gpu_place));
auto stream = auto stream =
reinterpret_cast<const paddle::platform::CUDADeviceContext&>(dev_ctx) is_sync ? nullptr
: reinterpret_cast<const paddle::platform::CUDADeviceContext&>(
dev_ctx)
.stream(); .stream();
paddle::memory::Copy( paddle::memory::Copy(
dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream); dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
...@@ -114,7 +125,9 @@ void Copy(const CUDAContext& dev_ctx, ...@@ -114,7 +125,9 @@ void Copy(const CUDAContext& dev_ctx,
dst_gpu_place, dst_gpu_place,
ctx_gpu_place)); ctx_gpu_place));
auto stream = auto stream =
reinterpret_cast<const paddle::platform::CUDADeviceContext&>(dev_ctx) is_sync ? nullptr
: reinterpret_cast<const paddle::platform::CUDADeviceContext&>(
dev_ctx)
.stream(); .stream();
paddle::memory::Copy( paddle::memory::Copy(
dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, stream); dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, stream);
...@@ -142,7 +155,9 @@ void Copy(const CUDAContext& dev_ctx, ...@@ -142,7 +155,9 @@ void Copy(const CUDAContext& dev_ctx,
src_gpu_place.device, src_gpu_place.device,
ctx_gpu_place.device)); ctx_gpu_place.device));
auto stream = auto stream =
reinterpret_cast<const paddle::platform::CUDADeviceContext&>(dev_ctx) is_sync ? nullptr
: reinterpret_cast<const paddle::platform::CUDADeviceContext&>(
dev_ctx)
.stream(); .stream();
paddle::memory::Copy( paddle::memory::Copy(
dst_cuda_pinned_place, dst_ptr, src_gpu_place, src_ptr, size, stream); dst_cuda_pinned_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
...@@ -170,7 +185,9 @@ void Copy(const CUDAContext& dev_ctx, ...@@ -170,7 +185,9 @@ void Copy(const CUDAContext& dev_ctx,
dst_gpu_place.device, dst_gpu_place.device,
ctx_gpu_place.device)); ctx_gpu_place.device));
auto stream = auto stream =
reinterpret_cast<const paddle::platform::CUDADeviceContext&>(dev_ctx) is_sync ? nullptr
: reinterpret_cast<const paddle::platform::CUDADeviceContext&>(
dev_ctx)
.stream(); .stream();
paddle::memory::Copy( paddle::memory::Copy(
dst_gpu_place, dst_ptr, src_cuda_pinned_place, src_ptr, size, stream); dst_gpu_place, dst_ptr, src_cuda_pinned_place, src_ptr, size, stream);
...@@ -188,7 +205,9 @@ void Copy(const CUDAContext& dev_ctx, ...@@ -188,7 +205,9 @@ void Copy(const CUDAContext& dev_ctx,
"Context place error, excepted GPUPlace, but actually %s.", "Context place error, excepted GPUPlace, but actually %s.",
ctx_place)); ctx_place));
auto stream = auto stream =
reinterpret_cast<const paddle::platform::CUDADeviceContext&>(dev_ctx) is_sync ? nullptr
: reinterpret_cast<const paddle::platform::CUDADeviceContext&>(
dev_ctx)
.stream(); .stream();
if (paddle::platform::is_same_place(src_place, dst_place)) { if (paddle::platform::is_same_place(src_place, dst_place)) {
paddle::memory::Copy( paddle::memory::Copy(
...@@ -213,7 +232,6 @@ void Copy(const CUDAContext& dev_ctx, ...@@ -213,7 +232,6 @@ void Copy(const CUDAContext& dev_ctx,
} }
} }
} }
} // namespace pten } // namespace pten
// TODO(chenweihang): replace by better impl // TODO(chenweihang): replace by better impl
......
...@@ -26,7 +26,10 @@ namespace pten { ...@@ -26,7 +26,10 @@ namespace pten {
using CUDAContext = paddle::platform::CUDADeviceContext; using CUDAContext = paddle::platform::CUDADeviceContext;
void Copy(const CUDAContext& dev_ctx, const DenseTensor& src, DenseTensor* dst); void Copy(const CUDAContext& dev_ctx,
const DenseTensor& src,
bool is_sync,
DenseTensor* dst);
} // namespace pten } // namespace pten
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/pten/core/dense_tensor.h"
namespace pten {
namespace general {
inline void SetXShape(const DenseTensor& x, DenseTensor* xshape) {
const auto& in_dims = x.meta().dims;
std::vector<int64_t> xshape_dims(in_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < in_dims.size(); ++i) {
xshape_dims[i + 1] = in_dims[i];
}
xshape->Resize(paddle::framework::make_ddim(xshape_dims));
xshape->set_lod(x.meta().lod);
}
} // namespace general
} // namespace pten
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/pten/kernels/xpu/manipulation.h" #include "paddle/pten/kernels/xpu/manipulation.h"
#include "paddle/pten/infermeta/unary.h" #include "paddle/pten/infermeta/unary.h"
#include "paddle/pten/kernels/functions/general/manipulation.h"
#include "paddle/pten/kernels/xpu/utils.h" #include "paddle/pten/kernels/xpu/utils.h"
namespace pten { namespace pten {
...@@ -50,6 +51,47 @@ void FlattenWithXShape(const XPUContext& dev_ctx, ...@@ -50,6 +51,47 @@ void FlattenWithXShape(const XPUContext& dev_ctx,
xshape->set_lod(x.lod()); xshape->set_lod(x.lod());
} }
void ReshapeFromVectorVal(const XPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<int>& shape,
DenseTensor* out) {
auto out_meta = InferShapeFromVecValue(x.meta(), shape);
if (&x == out) {
out->Resize(out_meta.dims);
return;
}
pten::Copy(dev_ctx, x, out);
out->Resize(out_meta.dims);
}
void ReshapeFromDT(const XPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& shape,
DenseTensor* out) {
auto* shape_data = shape.data<int>();
auto vector_shape = std::vector<int>(shape_data, shape_data + shape.numel());
ReshapeFromVectorVal(dev_ctx, x, vector_shape, out);
}
void ReshapeFromVectorDT(const XPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<DenseTensor>& shape,
DenseTensor* out) {
std::vector<int> vector_shape;
for (auto& tensor : shape) {
PADDLE_ENFORCE_EQ(
tensor.dims(),
paddle::framework::make_ddim({1}),
paddle::platform::errors::InvalidArgument(
"If the element type of 'shape' in ReshapeOp is Tensor, "
"the element's shape must be [1]. But received the element's shape "
"is [%s]",
tensor.dims()));
vector_shape.push_back(*tensor.data<int32_t>());
}
ReshapeFromVectorVal(dev_ctx, x, vector_shape, out);
}
} // namespace pten } // namespace pten
// TODO(chenweihang): replace by better impl // TODO(chenweihang): replace by better impl
...@@ -80,3 +122,10 @@ PT_REGISTER_KERNEL("flatten_contiguous_range.mid", ...@@ -80,3 +122,10 @@ PT_REGISTER_KERNEL("flatten_contiguous_range.mid",
int8_t, int8_t,
int, int,
int64_t) {} int64_t) {}
// TODO(yuanrisheng): "reshape2" is compatible with old kernel
// architecture, kernel_name should be "reshape".
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape2",
XPU,
ANY,
pten::ReshapeFromVectorVal) {}
...@@ -33,6 +33,21 @@ void Flatten(const XPUContext& dev_ctx, ...@@ -33,6 +33,21 @@ void Flatten(const XPUContext& dev_ctx,
int stop_axis, int stop_axis,
DenseTensor* out); DenseTensor* out);
void ReshapeFromDT(const XPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& shape,
DenseTensor* out);
void ReshapeFromVectorVal(const XPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<int>& shape,
DenseTensor* out);
void ReshapeFromVectorDT(const XPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<DenseTensor>& shape,
DenseTensor* out);
} // namespace pten } // namespace pten
#endif #endif
...@@ -125,8 +125,8 @@ TEST(API, matmul_cuda) { ...@@ -125,8 +125,8 @@ TEST(API, matmul_cuda) {
auto place = paddle::platform::CUDAPlace(); auto place = paddle::platform::CUDAPlace();
auto* dev_ctx = pool.GetByPlace(place); auto* dev_ctx = pool.GetByPlace(place);
pten::Copy(*dev_ctx, *ref_x.get(), dense_x.get()); pten::Copy(*dev_ctx, *ref_x.get(), false, dense_x.get());
pten::Copy(*dev_ctx, *ref_y.get(), dense_y.get()); pten::Copy(*dev_ctx, *ref_y.get(), false, dense_y.get());
paddle::experimental::Tensor x(dense_x); paddle::experimental::Tensor x(dense_x);
paddle::experimental::Tensor y(dense_y); paddle::experimental::Tensor y(dense_y);
...@@ -150,7 +150,7 @@ TEST(API, matmul_cuda) { ...@@ -150,7 +150,7 @@ TEST(API, matmul_cuda) {
pten::DenseTensorMeta( pten::DenseTensorMeta(
pten::DataType::FLOAT32, out.shape(), pten::DataLayout::NCHW)); pten::DataType::FLOAT32, out.shape(), pten::DataLayout::NCHW));
pten::Copy(*dev_ctx, *dense_out.get(), ref_out.get()); pten::Copy(*dev_ctx, *dense_out.get(), false, ref_out.get());
for (size_t i = 0; i < 9; i++) { for (size_t i = 0; i < 9; i++) {
ASSERT_NEAR(sum[i], ref_out->data<float>()[i], 1e-6f); ASSERT_NEAR(sum[i], ref_out->data<float>()[i], 1e-6f);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册