未验证 提交 982bf444 编写于 作者: Z zyfncg 提交者: GitHub

refactor matmul directory in pten (#38227)

* refactor matmul directory in pten

* fix merge conflict
上级 3310f519
......@@ -28,7 +28,7 @@ limitations under the License. */
// only can include the headers in paddle/pten/api dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/linalg.h"
#include "paddle/pten/kernels/matmul_kernel.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
......@@ -384,8 +384,8 @@ class MatMulV2Kernel : public framework::OpKernel<T> {
auto pt_out = paddle::experimental::MakePtenDenseTensor(*Out);
// call new kernel
pten::Matmul<T>(dev_ctx, *pt_x.get(), *pt_y.get(), trans_x, trans_y,
pt_out.get());
pten::MatmulKernel<T>(dev_ctx, *pt_x, *pt_y, trans_x, trans_y,
pt_out.get());
}
};
......
......@@ -28,10 +28,10 @@ get_property(pten_kernels GLOBAL PROPERTY PTEN_KERNELS)
# keep this message for debug, remove it later if needless
message(STATUS "All standard pten kernels: ${pten_kernels}")
set(PTEN_DEPS ${PTEN_DEPS} ${pten_kernels})
set(PTEN_DEPS ${PTEN_DEPS} math_cpu linalg_cpu)
set(PTEN_DEPS ${PTEN_DEPS} math_cpu)
set(PTEN_DEPS ${PTEN_DEPS} nary unary binary)
if(WITH_GPU OR WITH_ROCM)
set(PTEN_DEPS ${PTEN_DEPS} math_gpu linalg_gpu)
set(PTEN_DEPS ${PTEN_DEPS} math_gpu)
endif()
cc_library(pten SRCS all.cc DEPS ${PTEN_DEPS})
......@@ -20,10 +20,8 @@ limitations under the License. */
// the kernel declare statement is automatically generated according to the
// file name of the kernel, and this header file will be removed
PT_DECLARE_KERNEL(matmul, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(mean, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_KERNEL(matmul, GPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(mean, GPU, ALL_LAYOUT);
#endif
......@@ -17,9 +17,7 @@
// See Note: [ How do we organize the kernel directory ]
#include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/include/infermeta.h"
#include "paddle/pten/kernels/cpu/linalg.h"
#include "paddle/pten/kernels/dot_kernel.h"
#include "paddle/pten/kernels/gpu/linalg.h"
namespace pten {
......
# CPU-side pten kernel libraries: math_cpu is built from math.cc and
# linalg_cpu is built from linalg.cc, each with the dependencies listed.
cc_library(math_cpu SRCS math.cc DEPS dense_tensor kernel_context kernel_factory eigen_function blas pten_transpose_cpu cast_kernel)
cc_library(linalg_cpu SRCS linalg.cc DEPS dense_tensor kernel_context kernel_factory)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/cpu/linalg.h"
#include "paddle/pten/core/kernel_registry.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/kernels/hybird/math/matmul_func.h"
namespace pten {
// CPU matmul entry point: checks that neither input is empty, then forwards
// to math::MatMulFunction.
//
// dev_ctx      CPU context handed through to MatMulFunction.
// x, y         Operands; the product of each tensor's dims must be non-zero.
// transpose_x  Forwarded to MatMulFunction as the transpose flag for x.
// transpose_y  Forwarded to MatMulFunction as the transpose flag for y.
// out          Output tensor handed to MatMulFunction.
//
// Each PADDLE_ENFORCE_NE is given an InvalidArgument error describing the
// failed check.
template <typename T>
void Matmul(const CPUContext& dev_ctx,
            const DenseTensor& x,
            const DenseTensor& y,
            bool transpose_x,
            bool transpose_y,
            DenseTensor* out) {
  // product(dims) == 0 means at least one dimension of the input is zero.
  PADDLE_ENFORCE_NE(paddle::framework::product(x.dims()),
                    0,
                    paddle::platform::errors::InvalidArgument(
                        "The Input(X) dims size must not be equal 0,"
                        " but received dims size is 0. "));
  PADDLE_ENFORCE_NE(paddle::framework::product(y.dims()),
                    0,
                    paddle::platform::errors::InvalidArgument(
                        "The Input(Y) dims size must not be equal 0,"
                        " but received dims size is 0. "));
  math::MatMulFunction<CPUContext, T>(
      dev_ctx, x, y, out, transpose_x, transpose_y);
}
} // namespace pten
// Short aliases for the complex element types named in the registration below.
using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>;
// Registers pten::Matmul under the kernel name `matmul` for the CPU backend
// and ALL_LAYOUT, listing float, double, complex64 and complex128 as the
// element types. The trailing `{}` body is left empty.
PT_REGISTER_KERNEL(matmul,
CPU,
ALL_LAYOUT,
pten::Matmul,
float,
double,
complex64,
complex128) {}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/dense_tensor.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/device_context.h"
namespace pten {
// Declaration of the CPU matmul kernel, templated on the element type T.
//
// dev_ctx      CPU context used for the computation.
// x, y         Input tensors.
// transpose_x  Transpose flag for x.
// transpose_y  Transpose flag for y.
// out          Output tensor written by the kernel.
template <typename T>
void Matmul(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
bool transpose_x,
bool transpose_y,
DenseTensor* out);
} // namespace pten
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/kernels/matmul_kernel.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
// Registers pten::MatmulKernel under the kernel name `matmul` for the CPU
// backend and ALL_LAYOUT, listing float, double, complex<float> and
// complex<double> as the element types. The trailing `{}` body is left empty.
PT_REGISTER_CTX_KERNEL(matmul,
CPU,
ALL_LAYOUT,
pten::MatmulKernel,
float,
double,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
# GPU-side pten kernel libraries. The WITH_GPU and WITH_ROCM branches list the
# same sources and dependencies for math_gpu and linalg_gpu; only the library
# helper differs (nv_library vs. hip_library).
if(WITH_GPU)
nv_library(math_gpu SRCS math.cu DEPS eigen_function dense_tensor convert_utils kernel_context kernel_factory pten_transpose_gpu cast_kernel copy_kernel)
nv_library(linalg_gpu SRCS linalg.cu DEPS eigen_function dense_tensor kernel_context kernel_factory)
elseif(WITH_ROCM)
hip_library(math_gpu SRCS math.cu DEPS eigen_function dense_tensor convert_utils kernel_context kernel_factory pten_transpose_gpu cast_kernel copy_kernel)
hip_library(linalg_gpu SRCS linalg.cu DEPS eigen_function dense_tensor kernel_context kernel_factory)
endif()
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/gpu/linalg.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/hybird/math/matmul_func.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/complex.h"
namespace pten {
// GPU matmul entry point: checks that neither input is empty, then forwards
// to math::MatMulFunction.
//
// dev_ctx      GPU context handed through to MatMulFunction.
// x, y         Operands; the product of each tensor's dims must be non-zero.
// transpose_x  Forwarded to MatMulFunction as the transpose flag for x.
// transpose_y  Forwarded to MatMulFunction as the transpose flag for y.
// out          Output tensor handed to MatMulFunction.
//
// Each PADDLE_ENFORCE_NE is given an InvalidArgument error describing the
// failed check.
template <typename T>
void Matmul(const GPUContext& dev_ctx,
            const DenseTensor& x,
            const DenseTensor& y,
            bool transpose_x,
            bool transpose_y,
            DenseTensor* out) {
  // product(dims) == 0 means at least one dimension of the input is zero.
  PADDLE_ENFORCE_NE(paddle::framework::product(x.dims()),
                    0,
                    paddle::platform::errors::InvalidArgument(
                        "The Input(X) dims size must not be equal 0,"
                        " but received dims size is 0. "));
  PADDLE_ENFORCE_NE(paddle::framework::product(y.dims()),
                    0,
                    paddle::platform::errors::InvalidArgument(
                        "The Input(Y) dims size must not be equal 0,"
                        " but received dims size is 0. "));
  math::MatMulFunction<GPUContext, T>(
      dev_ctx, x, y, out, transpose_x, transpose_y);
}
} // namespace pten
// Short aliases for the half-precision and complex element types named in the
// registration below.
using float16 = paddle::platform::float16;
using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>;
// Registers pten::Matmul under the kernel name `matmul` for the GPU backend
// and ALL_LAYOUT, listing float, double, float16, complex64 and complex128 as
// the element types. The trailing `{}` body is left empty.
PT_REGISTER_KERNEL(matmul,
GPU,
ALL_LAYOUT,
pten::Matmul,
float,
double,
float16,
complex64,
complex128) {}
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/kernels/matmul_kernel.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
// Registers pten::MatmulKernel under the kernel name `matmul` for the GPU
// backend and ALL_LAYOUT, listing float, double, float16, complex<float> and
// complex<double> as the element types. The trailing `{}` body is left empty.
PT_REGISTER_CTX_KERNEL(matmul,
GPU,
ALL_LAYOUT,
pten::MatmulKernel,
float,
double,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
......@@ -20,7 +20,6 @@ limitations under the License. */
#include "paddle/pten/core/dense_tensor.h"
namespace pten {
namespace math {
static void GetBroadcastFromDims(const int x_ndim,
const std::int64_t* x_dims,
......@@ -86,8 +85,8 @@ static void IndexIncreaseFromDims(const int ndim,
}
}
template <typename DeviceContext, typename T>
void MatMulFunction(const DeviceContext& dev_ctx,
template <typename Context, typename T>
void MatMulFunction(const Context& context,
const DenseTensor& X,
const DenseTensor& Y,
const std::vector<std::int64_t>& x_dims,
......@@ -103,7 +102,7 @@ void MatMulFunction(const DeviceContext& dev_ctx,
const T* x_data = X.data<T>();
const T* y_data = Y.data<T>();
auto blas = paddle::operators::math::GetBlas<DeviceContext, T>(dev_ctx);
auto blas = paddle::operators::math::GetBlas<Context, T>(context);
if (x_ndim == 1 && y_ndim == 1) {
const int M = X.numel();
......@@ -471,8 +470,8 @@ void MatMulFunction(const DeviceContext& dev_ctx,
}
}
template <typename DeviceContext, typename T>
void MatMulFunction(const DeviceContext& dev_ctx,
template <typename Context, typename T>
void MatMulFunction(const Context& context,
const DenseTensor& X,
const DenseTensor& Y,
DenseTensor* Out,
......@@ -481,9 +480,28 @@ void MatMulFunction(const DeviceContext& dev_ctx,
bool flag = false) {
const std::vector<std::int64_t> x_dims = vectorize(X.dims());
const std::vector<std::int64_t> y_dims = vectorize(Y.dims());
MatMulFunction<DeviceContext, T>(
dev_ctx, X, Y, x_dims, y_dims, Out, trans_x, trans_y, flag);
MatMulFunction<Context, T>(
context, X, Y, x_dims, y_dims, Out, trans_x, trans_y, flag);
}
// Context-generic matmul kernel: checks that neither input is empty, then
// forwards to MatMulFunction<Context, T>.
//
// context      Device context handed through to MatMulFunction.
// x, y         Operands; the product of each tensor's dims must be non-zero.
// transpose_x  Forwarded to MatMulFunction as the transpose flag for x.
// transpose_y  Forwarded to MatMulFunction as the transpose flag for y.
// out          Output tensor handed to MatMulFunction.
//
// Each PADDLE_ENFORCE_NE is given an InvalidArgument error describing the
// failed check.
template <typename T, typename Context>
void MatmulKernel(const Context& context,
                  const DenseTensor& x,
                  const DenseTensor& y,
                  bool transpose_x,
                  bool transpose_y,
                  DenseTensor* out) {
  // product(dims) == 0 means at least one dimension of the input is zero.
  PADDLE_ENFORCE_NE(paddle::framework::product(x.dims()),
                    0,
                    paddle::platform::errors::InvalidArgument(
                        "The Input(X) dims size must not be equal 0,"
                        " but received dims size is 0. "));
  PADDLE_ENFORCE_NE(paddle::framework::product(y.dims()),
                    0,
                    paddle::platform::errors::InvalidArgument(
                        "The Input(Y) dims size must not be equal 0,"
                        " but received dims size is 0. "));
  // Note the swapped template-argument order: MatMulFunction takes
  // <Context, T> while this kernel is declared as <T, Context>.
  MatMulFunction<Context, T>(context, x, y, out, transpose_x, transpose_y);
}
} // namespace math
} // namespace pten
......@@ -14,22 +14,33 @@
#pragma once
// CUDA and HIP use same api
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/include/infermeta.h"
namespace pten {
// Declaration of the GPU matmul kernel, templated on the element type T.
//
// dev_ctx      GPU context used for the computation.
// x, y         Input tensors.
// transpose_x  Transpose flag for x.
// transpose_y  Transpose flag for y.
// out          Output tensor written by the kernel.
template <typename T>
void Matmul(const GPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
bool transpose_x,
bool transpose_y,
DenseTensor* out);
// Declaration of the matmul kernel, templated on the element type T and the
// device context type Context.
//
// context      Device context used for the computation.
// x, y         Input tensors.
// transpose_x  Transpose flag for x.
// transpose_y  Transpose flag for y.
// out          Output tensor written by the kernel.
template <typename T, typename Context>
void MatmulKernel(const Context& context,
const DenseTensor& x,
const DenseTensor& y,
bool transpose_x,
bool transpose_y,
DenseTensor* out);
} // namespace pten
// Functional wrapper around MatmulKernel that creates the output tensor
// itself and returns it by value.
//
// context      Device context; its GetPlace() selects where the output
//              storage is created, and it is passed on to MatmulKernel.
// x, y         Input tensors; their meta() is used to infer the output meta.
// transpose_x  Transpose flag for x (used by both inference and the kernel).
// transpose_y  Transpose flag for y (used by both inference and the kernel).
//
// Returns the DenseTensor filled in by MatmulKernel.
template <typename T, typename Context>
DenseTensor Matmul(const Context& context,
const DenseTensor& x,
const DenseTensor& y,
bool transpose_x,
bool transpose_y) {
// Output meta is derived from the input metas and the transpose flags.
auto out_meta = MatmulInferMeta(x.meta(), y.meta(), transpose_x, transpose_y);
// Output tensor backed by a SharedStorage created for the context's place.
DenseTensor dense_out(
pten::make_intrusive<paddle::experimental::SharedStorage>(
context.GetPlace()),
std::move(out_meta));
MatmulKernel<T, Context>(context, x, y, transpose_x, transpose_y, &dense_out);
return dense_out;
}
#endif
} // namespace pten
......@@ -2,6 +2,7 @@ cc_test(test_copy_dev_api SRCS test_copy_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_dot_dev_api SRCS test_dot_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_creation_dev_api SRCS test_creation_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_flatten_dev_api SRCS test_flatten_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_matmul_dev_api SRCS test_matmul_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_mean_dev_api SRCS test_mean_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_scale_dev_api SRCS test_scale_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_cast_dev_api SRCS test_cast_dev_api.cc DEPS pten pten_api_utils)
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/kernels/matmul_kernel.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
namespace pten {
namespace tests {
namespace framework = paddle::framework;
using DDim = paddle::framework::DDim;
// Checks the functional Matmul API on CPU: a 3x3 matrix of ones multiplied by
// a 3x3 matrix of twos must yield a 3x3 float tensor whose entries are all 6.
// (Named `matmul`; the previous name `dot` was carried over from the dot test.)
TEST(DEV_API, matmul) {
  // 1. create tensor
  const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
      paddle::platform::CPUPlace());
  DenseTensor dense_x(alloc,
                      pten::DenseTensorMeta(pten::DataType::FLOAT32,
                                            framework::make_ddim({3, 3}),
                                            pten::DataLayout::NCHW));
  auto* dense_x_data = dense_x.mutable_data<float>();

  DenseTensor dense_y(alloc,
                      pten::DenseTensorMeta(pten::DataType::FLOAT32,
                                            framework::make_ddim({3, 3}),
                                            pten::DataLayout::NCHW));
  auto* dense_y_data = dense_y.mutable_data<float>();

  for (size_t i = 0; i < 9; ++i) {
    dense_x_data[i] = 1.0;
    dense_y_data[i] = 2.0;
  }
  // Each output entry is a sum of three products 1 * 2, i.e. 6.
  std::vector<float> sum(9, 6.0);

  paddle::platform::DeviceContextPool& pool =
      paddle::platform::DeviceContextPool::Instance();
  auto* ctx = pool.Get(paddle::platform::CPUPlace());

  // 2. test API
  auto out = Matmul<float, CPUContext>(
      *(static_cast<CPUContext*>(ctx)), dense_x, dense_y, false, false);

  // 3. check result
  ASSERT_EQ(out.dims().size(), 2);
  ASSERT_EQ(out.dims()[0], 3);
  ASSERT_EQ(out.dims()[1], 3);
  ASSERT_EQ(out.numel(), 9);
  ASSERT_EQ(out.dtype(), DataType::FLOAT32);
  ASSERT_EQ(out.layout(), DataLayout::NCHW);
  ASSERT_TRUE(out.initialized());

  // Fetch the result pointer once rather than on every iteration.
  const float* out_data = out.data<float>();
  for (size_t i = 0; i < 9; i++) {
    ASSERT_NEAR(sum[i], out_data[i], 1e-6f);
  }
}
} // namespace tests
} // namespace pten
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册