未验证 提交 ab583173 作者: zhangkaihuo 提交者: GitHub

Use base visit in cpu kernel (#45062)

上级 0b4268a6
...@@ -98,7 +98,7 @@ template <typename T, typename Context> ...@@ -98,7 +98,7 @@ template <typename T, typename Context>
void CoalesceKernel(const Context& dev_ctx, void CoalesceKernel(const Context& dev_ctx,
const SparseCooTensor& x, const SparseCooTensor& x,
SparseCooTensor* out) { SparseCooTensor* out) {
PD_VISIT_INTEGRAL_TYPES( PD_VISIT_BASE_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "CoalesceCPUKernel", ([&] { x.non_zero_indices().dtype(), "CoalesceCPUKernel", ([&] {
CoalesceCPUKernel<T, data_t>(dev_ctx, x, out); CoalesceCPUKernel<T, data_t>(dev_ctx, x, out);
})); }));
......
...@@ -196,7 +196,7 @@ void Conv3dCooGradKernel(const Context& dev_ctx, ...@@ -196,7 +196,7 @@ void Conv3dCooGradKernel(const Context& dev_ctx,
const std::string& key, const std::string& key,
SparseCooTensor* x_grad, SparseCooTensor* x_grad,
DenseTensor* kernel_grad) { DenseTensor* kernel_grad) {
PD_VISIT_INTEGRAL_TYPES( PD_VISIT_BASE_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "Conv3dCooGradCPUKernel", ([&] { x.non_zero_indices().dtype(), "Conv3dCooGradCPUKernel", ([&] {
Conv3dCooGradCPUKernel<T, data_t>(dev_ctx, Conv3dCooGradCPUKernel<T, data_t>(dev_ctx,
x, x,
......
...@@ -186,7 +186,7 @@ void Conv3dCooKernel(const Context& dev_ctx, ...@@ -186,7 +186,7 @@ void Conv3dCooKernel(const Context& dev_ctx,
SparseCooTensor* out, SparseCooTensor* out,
DenseTensor* rulebook, DenseTensor* rulebook,
DenseTensor* counter) { DenseTensor* counter) {
PD_VISIT_INTEGRAL_TYPES( PD_VISIT_BASE_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "Conv3dCooCPUKernel", ([&] { x.non_zero_indices().dtype(), "Conv3dCooCPUKernel", ([&] {
Conv3dCooCPUKernel<T, data_t>(dev_ctx, Conv3dCooCPUKernel<T, data_t>(dev_ctx,
x, x,
......
...@@ -236,7 +236,7 @@ void ElementWiseDivideCsrGradKernel(const Context& dev_ctx, ...@@ -236,7 +236,7 @@ void ElementWiseDivideCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& dout, const SparseCsrTensor& dout,
SparseCsrTensor* dx, SparseCsrTensor* dx,
SparseCsrTensor* dy) { SparseCsrTensor* dy) {
PD_VISIT_INTEGRAL_TYPES( PD_VISIT_BASE_INTEGRAL_TYPES(
x.non_zero_crows().dtype(), "ElementWiseDivideCsrGradCPUKernel", ([&] { x.non_zero_crows().dtype(), "ElementWiseDivideCsrGradCPUKernel", ([&] {
ElementWiseDivideCsrGradCPUKernel<T, data_t>( ElementWiseDivideCsrGradCPUKernel<T, data_t>(
dev_ctx, x, y, out, dout, dx, dy); dev_ctx, x, y, out, dout, dx, dy);
...@@ -250,7 +250,7 @@ void ElementWiseDivideCooGradKernel(const Context& dev_ctx, ...@@ -250,7 +250,7 @@ void ElementWiseDivideCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& dout, const SparseCooTensor& dout,
SparseCooTensor* dx, SparseCooTensor* dx,
SparseCooTensor* dy) { SparseCooTensor* dy) {
PD_VISIT_INTEGRAL_TYPES( PD_VISIT_BASE_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "ElementWiseDivideCooGradCPUKernel", ([&] { x.non_zero_indices().dtype(), "ElementWiseDivideCooGradCPUKernel", ([&] {
ElementWiseDivideCooGradCPUKernel<T, data_t>( ElementWiseDivideCooGradCPUKernel<T, data_t>(
dev_ctx, x, y, out, dout, dx, dy); dev_ctx, x, y, out, dout, dx, dy);
...@@ -262,36 +262,38 @@ void ElementWiseDivideCooGradKernel(const Context& dev_ctx, ...@@ -262,36 +262,38 @@ void ElementWiseDivideCooGradKernel(const Context& dev_ctx,
\ \
DEFINE_ELEMENTWISE_GRAD_KERNEL_COO(name) DEFINE_ELEMENTWISE_GRAD_KERNEL_COO(name)
#define DEFINE_ELEMENTWISE_GRAD_KERNEL_CSR(name) \ #define DEFINE_ELEMENTWISE_GRAD_KERNEL_CSR(name) \
template <typename T, typename Context> \ template <typename T, typename Context> \
void ElementWise##name##CsrGradKernel(const Context& dev_ctx, \ void ElementWise##name##CsrGradKernel(const Context& dev_ctx, \
const SparseCsrTensor& x, \ const SparseCsrTensor& x, \
const SparseCsrTensor& y, \ const SparseCsrTensor& y, \
const SparseCsrTensor& dout, \ const SparseCsrTensor& dout, \
SparseCsrTensor* dx, \ SparseCsrTensor* dx, \
SparseCsrTensor* dy) { \ SparseCsrTensor* dy) { \
PD_VISIT_INTEGRAL_TYPES(x.non_zero_crows().dtype(), \ PD_VISIT_BASE_INTEGRAL_TYPES( \
"ElementWise##name##CsrGradCPUKernel", \ x.non_zero_crows().dtype(), \
([&] { \ "ElementWise##name##CsrGradCPUKernel", \
ElementWise##name##CsrGradCPUKernel<T, data_t>( \ ([&] { \
dev_ctx, x, y, dout, dx, dy); \ ElementWise##name##CsrGradCPUKernel<T, data_t>( \
})); \ dev_ctx, x, y, dout, dx, dy); \
})); \
} }
#define DEFINE_ELEMENTWISE_GRAD_KERNEL_COO(name) \ #define DEFINE_ELEMENTWISE_GRAD_KERNEL_COO(name) \
template <typename T, typename Context> \ template <typename T, typename Context> \
void ElementWise##name##CooGradKernel(const Context& dev_ctx, \ void ElementWise##name##CooGradKernel(const Context& dev_ctx, \
const SparseCooTensor& x, \ const SparseCooTensor& x, \
const SparseCooTensor& y, \ const SparseCooTensor& y, \
const SparseCooTensor& dout, \ const SparseCooTensor& dout, \
SparseCooTensor* dx, \ SparseCooTensor* dx, \
SparseCooTensor* dy) { \ SparseCooTensor* dy) { \
PD_VISIT_INTEGRAL_TYPES(x.non_zero_indices().dtype(), \ PD_VISIT_BASE_INTEGRAL_TYPES( \
"ElementWise##name##CooGradCPUKernel", \ x.non_zero_indices().dtype(), \
([&] { \ "ElementWise##name##CooGradCPUKernel", \
ElementWise##name##CooGradCPUKernel<T, data_t>( \ ([&] { \
dev_ctx, x, y, dout, dx, dy); \ ElementWise##name##CooGradCPUKernel<T, data_t>( \
})); \ dev_ctx, x, y, dout, dx, dy); \
})); \
} }
DEFINE_ELEMENTWISE_GRAD_KERNEL(Add) DEFINE_ELEMENTWISE_GRAD_KERNEL(Add)
......
...@@ -57,11 +57,12 @@ void Merge(const IntT el_len, ...@@ -57,11 +57,12 @@ void Merge(const IntT el_len,
const IntT len_b_max, const IntT len_b_max,
IntT* c_index, IntT* c_index,
T* c_values, T* c_values,
IntT& nnz, IntT* out_nnz,
const Functor& functor_org, const Functor& functor_org,
const bool is_divide) { const bool is_divide) {
IntT a = 0; IntT a = 0;
IntT b = 0; IntT b = 0;
IntT& nnz = (*out_nnz);
nnz = 0; nnz = 0;
const IntT* b_index = nullptr; const IntT* b_index = nullptr;
std::vector<IntT> b_full_index; std::vector<IntT> b_full_index;
...@@ -94,9 +95,7 @@ void Merge(const IntT el_len, ...@@ -94,9 +95,7 @@ void Merge(const IntT el_len,
} }
++a; ++a;
++b; ++b;
} } else if (a_index[a] < b_index[b]) { // coordinate x[a] < coordinate y[b]
// coordinate x[a] < coordinate y[b]
else if (a_index[a] < b_index[b]) {
if (!functor(a_values + a * el_len, if (!functor(a_values + a * el_len,
zero.data(), zero.data(),
c_values + nnz * el_len, c_values + nnz * el_len,
...@@ -105,9 +104,7 @@ void Merge(const IntT el_len, ...@@ -105,9 +104,7 @@ void Merge(const IntT el_len,
++nnz; ++nnz;
} }
++a; ++a;
} } else if (a_index[a] > b_index[b]) { // coordinate x[a] > coordinate y[b]
// coordinate x[a] > coordinate y[b]
else if (a_index[a] > b_index[b]) {
if (!functor(zero.data(), if (!functor(zero.data(),
b_values[b_index[b]], b_values[b_index[b]],
c_values + nnz * el_len, c_values + nnz * el_len,
...@@ -215,7 +212,7 @@ void ElementWiseCooKernelImpl(const Context& dev_ctx, ...@@ -215,7 +212,7 @@ void ElementWiseCooKernelImpl(const Context& dev_ctx,
max_len, max_len,
out_indexs.data(), out_indexs.data(),
out_values_vec.data(), out_values_vec.data(),
nnz, &nnz,
functor, functor,
is_divide); is_divide);
...@@ -292,7 +289,7 @@ void ElementWiseCooKernelImpl(const Context& dev_ctx, ...@@ -292,7 +289,7 @@ void ElementWiseCooKernelImpl(const Context& dev_ctx,
const SparseCsrTensor& x, \ const SparseCsrTensor& x, \
const SparseCsrTensor& y, \ const SparseCsrTensor& y, \
SparseCsrTensor* out) { \ SparseCsrTensor* out) { \
PD_VISIT_INTEGRAL_TYPES( \ PD_VISIT_BASE_INTEGRAL_TYPES( \
x.non_zero_crows().dtype(), "ElementWise##name##CsrCPUKernel", ([&] { \ x.non_zero_crows().dtype(), "ElementWise##name##CsrCPUKernel", ([&] { \
ElementWise##name##CsrCPUKernel<T, data_t>(dev_ctx, x, y, out); \ ElementWise##name##CsrCPUKernel<T, data_t>(dev_ctx, x, y, out); \
})); \ })); \
...@@ -309,18 +306,18 @@ void ElementWiseCooKernelImpl(const Context& dev_ctx, ...@@ -309,18 +306,18 @@ void ElementWiseCooKernelImpl(const Context& dev_ctx,
dev_ctx, x, y, out, functor); \ dev_ctx, x, y, out, functor); \
} }
#define DEFINE_COO_ELEMENTWISE_KERNEL(name) \ #define DEFINE_COO_ELEMENTWISE_KERNEL(name) \
template <typename T, typename Context> \ template <typename T, typename Context> \
void ElementWise##name##CooKernel(const Context& dev_ctx, \ void ElementWise##name##CooKernel(const Context& dev_ctx, \
const SparseCooTensor& x, \ const SparseCooTensor& x, \
const SparseCooTensor& y, \ const SparseCooTensor& y, \
SparseCooTensor* out) { \ SparseCooTensor* out) { \
PD_VISIT_INTEGRAL_TYPES(x.non_zero_indices().dtype(), \ PD_VISIT_BASE_INTEGRAL_TYPES(x.non_zero_indices().dtype(), \
"ElementWise##name##CooCPUKernel", \ "ElementWise##name##CooCPUKernel", \
([&] { \ ([&] { \
ElementWise##name##CooCPUKernel<T, data_t>( \ ElementWise##name##CooCPUKernel<T, data_t>( \
dev_ctx, x, y, out); \ dev_ctx, x, y, out); \
})); \ })); \
} }
DEFINE_CSR_ELEMENTWISE_CPU_KERNEL(Add) DEFINE_CSR_ELEMENTWISE_CPU_KERNEL(Add)
......
...@@ -79,7 +79,7 @@ void SparseMaskKernel(const Context& dev_ctx, ...@@ -79,7 +79,7 @@ void SparseMaskKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const SparseCooTensor& mask, const SparseCooTensor& mask,
SparseCooTensor* out) { SparseCooTensor* out) {
PD_VISIT_INTEGRAL_TYPES( PD_VISIT_BASE_INTEGRAL_TYPES(
mask.non_zero_indices().dtype(), "SparseMaskCPUKernel", ([&] { mask.non_zero_indices().dtype(), "SparseMaskCPUKernel", ([&] {
SparseMaskCPUKernel<T, data_t>(dev_ctx, x, mask, out); SparseMaskCPUKernel<T, data_t>(dev_ctx, x, mask, out);
})); }));
...@@ -146,7 +146,7 @@ void SparseMaskHelperKernel(const Context& dev_ctx, ...@@ -146,7 +146,7 @@ void SparseMaskHelperKernel(const Context& dev_ctx,
const SparseCooTensor& x, const SparseCooTensor& x,
const DenseTensor& mask_indices, const DenseTensor& mask_indices,
DenseTensor* out) { DenseTensor* out) {
PD_VISIT_INTEGRAL_TYPES( PD_VISIT_BASE_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "SparseMaskHelperCPUKernel", ([&] { x.non_zero_indices().dtype(), "SparseMaskHelperCPUKernel", ([&] {
SparseMaskHelperCPUKernel<T, data_t>(dev_ctx, x, mask_indices, out); SparseMaskHelperCPUKernel<T, data_t>(dev_ctx, x, mask_indices, out);
})); }));
......
...@@ -83,7 +83,7 @@ void MaxPoolCooGradKernel(const Context& dev_ctx, ...@@ -83,7 +83,7 @@ void MaxPoolCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& out_grad, const SparseCooTensor& out_grad,
const std::vector<int>& kernel_sizes, const std::vector<int>& kernel_sizes,
SparseCooTensor* x_grad) { SparseCooTensor* x_grad) {
PD_VISIT_INTEGRAL_TYPES( PD_VISIT_BASE_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "MaxPoolCooGradCPUKernel", ([&] { x.non_zero_indices().dtype(), "MaxPoolCooGradCPUKernel", ([&] {
MaxPoolCooGradCPUKernel<T, data_t>( MaxPoolCooGradCPUKernel<T, data_t>(
dev_ctx, x, rulebook, counter, out, out_grad, kernel_sizes, x_grad); dev_ctx, x, rulebook, counter, out, out_grad, kernel_sizes, x_grad);
......
...@@ -109,7 +109,7 @@ void MaxPoolCooKernel(const Context& dev_ctx, ...@@ -109,7 +109,7 @@ void MaxPoolCooKernel(const Context& dev_ctx,
SparseCooTensor* out, SparseCooTensor* out,
DenseTensor* rulebook, DenseTensor* rulebook,
DenseTensor* counter) { DenseTensor* counter) {
PD_VISIT_INTEGRAL_TYPES( PD_VISIT_BASE_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "MaxPoolCooCPUKernel", ([&] { x.non_zero_indices().dtype(), "MaxPoolCooCPUKernel", ([&] {
MaxPoolCooCPUKernel<T, data_t>(dev_ctx, MaxPoolCooCPUKernel<T, data_t>(dev_ctx,
x, x,
......
...@@ -62,7 +62,7 @@ void SoftmaxCsrGradKernel(const Context& dev_ctx, ...@@ -62,7 +62,7 @@ void SoftmaxCsrGradKernel(const Context& dev_ctx,
T* dx_data = dx_values->data<T>(); T* dx_data = dx_values->data<T>();
// dx = (dout - sum(dout * out)) * out // dx = (dout - sum(dout * out)) * out
PD_VISIT_INTEGRAL_TYPES( PD_VISIT_BASE_INTEGRAL_TYPES(
out.non_zero_crows().dtype(), "SoftmaxCsrGradKernel", ([&] { out.non_zero_crows().dtype(), "SoftmaxCsrGradKernel", ([&] {
const data_t* out_crows_data = out_crows.data<data_t>(); const data_t* out_crows_data = out_crows.data<data_t>();
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
......
...@@ -60,7 +60,7 @@ void SoftmaxCsrKernel(const Context& dev_ctx, ...@@ -60,7 +60,7 @@ void SoftmaxCsrKernel(const Context& dev_ctx,
T* out_data = out_values->data<T>(); T* out_data = out_values->data<T>();
// out = exp(x-x_max) / sum( exp(x-x_max )) // out = exp(x-x_max) / sum( exp(x-x_max ))
PD_VISIT_INTEGRAL_TYPES( PD_VISIT_BASE_INTEGRAL_TYPES(
x.non_zero_crows().dtype(), "CsrSoftmaxKernel", ([&] { x.non_zero_crows().dtype(), "CsrSoftmaxKernel", ([&] {
const data_t* x_crows_data = x_crows.data<data_t>(); const data_t* x_crows_data = x_crows.data<data_t>();
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
......
...@@ -160,7 +160,7 @@ template <typename T, typename Context> ...@@ -160,7 +160,7 @@ template <typename T, typename Context>
void SparseCsrToCooKernel(const Context& dev_ctx, void SparseCsrToCooKernel(const Context& dev_ctx,
const SparseCsrTensor& x, const SparseCsrTensor& x,
SparseCooTensor* out) { SparseCooTensor* out) {
PD_VISIT_INTEGRAL_TYPES( PD_VISIT_BASE_INTEGRAL_TYPES(
x.non_zero_crows().dtype(), "SparseCsrToCooCPUKernel", ([&] { x.non_zero_crows().dtype(), "SparseCsrToCooCPUKernel", ([&] {
SparseCsrToCooCPUKernel<T, data_t>(dev_ctx, x, out); SparseCsrToCooCPUKernel<T, data_t>(dev_ctx, x, out);
})); }));
...@@ -250,7 +250,7 @@ template <typename T, typename Context> ...@@ -250,7 +250,7 @@ template <typename T, typename Context>
void SparseCooToCsrKernel(const Context& dev_ctx, void SparseCooToCsrKernel(const Context& dev_ctx,
const SparseCooTensor& x, const SparseCooTensor& x,
SparseCsrTensor* out) { SparseCsrTensor* out) {
PD_VISIT_INTEGRAL_TYPES( PD_VISIT_BASE_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "SparseCooToCsrCPUKernel", ([&] { x.non_zero_indices().dtype(), "SparseCooToCsrCPUKernel", ([&] {
SparseCooToCsrCPUKernel<T, data_t>(dev_ctx, x, out); SparseCooToCsrCPUKernel<T, data_t>(dev_ctx, x, out);
})); }));
...@@ -304,7 +304,7 @@ template <typename T, typename Context> ...@@ -304,7 +304,7 @@ template <typename T, typename Context>
void SparseCooToDenseKernel(const Context& dev_ctx, void SparseCooToDenseKernel(const Context& dev_ctx,
const SparseCooTensor& x, const SparseCooTensor& x,
DenseTensor* out) { DenseTensor* out) {
PD_VISIT_INTEGRAL_TYPES( PD_VISIT_BASE_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "SparseCooToDenseCPUKernel", ([&] { x.non_zero_indices().dtype(), "SparseCooToDenseCPUKernel", ([&] {
SparseCooToDenseCPUKernel<T, data_t>(dev_ctx, x, out); SparseCooToDenseCPUKernel<T, data_t>(dev_ctx, x, out);
})); }));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册