未验证 提交 ab583173 编写于 作者: Z zhangkaihuo 提交者: GitHub

Use base visit in cpu kernel (#45062)

上级 0b4268a6
......@@ -98,7 +98,7 @@ template <typename T, typename Context>
void CoalesceKernel(const Context& dev_ctx,
const SparseCooTensor& x,
SparseCooTensor* out) {
PD_VISIT_INTEGRAL_TYPES(
PD_VISIT_BASE_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "CoalesceCPUKernel", ([&] {
CoalesceCPUKernel<T, data_t>(dev_ctx, x, out);
}));
......
......@@ -196,7 +196,7 @@ void Conv3dCooGradKernel(const Context& dev_ctx,
const std::string& key,
SparseCooTensor* x_grad,
DenseTensor* kernel_grad) {
PD_VISIT_INTEGRAL_TYPES(
PD_VISIT_BASE_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "Conv3dCooGradCPUKernel", ([&] {
Conv3dCooGradCPUKernel<T, data_t>(dev_ctx,
x,
......
......@@ -186,7 +186,7 @@ void Conv3dCooKernel(const Context& dev_ctx,
SparseCooTensor* out,
DenseTensor* rulebook,
DenseTensor* counter) {
PD_VISIT_INTEGRAL_TYPES(
PD_VISIT_BASE_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "Conv3dCooCPUKernel", ([&] {
Conv3dCooCPUKernel<T, data_t>(dev_ctx,
x,
......
......@@ -236,7 +236,7 @@ void ElementWiseDivideCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& dout,
SparseCsrTensor* dx,
SparseCsrTensor* dy) {
PD_VISIT_INTEGRAL_TYPES(
PD_VISIT_BASE_INTEGRAL_TYPES(
x.non_zero_crows().dtype(), "ElementWiseDivideCsrGradCPUKernel", ([&] {
ElementWiseDivideCsrGradCPUKernel<T, data_t>(
dev_ctx, x, y, out, dout, dx, dy);
......@@ -250,7 +250,7 @@ void ElementWiseDivideCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& dout,
SparseCooTensor* dx,
SparseCooTensor* dy) {
PD_VISIT_INTEGRAL_TYPES(
PD_VISIT_BASE_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "ElementWiseDivideCooGradCPUKernel", ([&] {
ElementWiseDivideCooGradCPUKernel<T, data_t>(
dev_ctx, x, y, out, dout, dx, dy);
......@@ -262,36 +262,38 @@ void ElementWiseDivideCooGradKernel(const Context& dev_ctx,
\
DEFINE_ELEMENTWISE_GRAD_KERNEL_COO(name)
#define DEFINE_ELEMENTWISE_GRAD_KERNEL_CSR(name) \
template <typename T, typename Context> \
void ElementWise##name##CsrGradKernel(const Context& dev_ctx, \
const SparseCsrTensor& x, \
const SparseCsrTensor& y, \
const SparseCsrTensor& dout, \
SparseCsrTensor* dx, \
SparseCsrTensor* dy) { \
PD_VISIT_INTEGRAL_TYPES(x.non_zero_crows().dtype(), \
"ElementWise##name##CsrGradCPUKernel", \
([&] { \
ElementWise##name##CsrGradCPUKernel<T, data_t>( \
dev_ctx, x, y, dout, dx, dy); \
})); \
#define DEFINE_ELEMENTWISE_GRAD_KERNEL_CSR(name) \
template <typename T, typename Context> \
void ElementWise##name##CsrGradKernel(const Context& dev_ctx, \
const SparseCsrTensor& x, \
const SparseCsrTensor& y, \
const SparseCsrTensor& dout, \
SparseCsrTensor* dx, \
SparseCsrTensor* dy) { \
PD_VISIT_BASE_INTEGRAL_TYPES( \
x.non_zero_crows().dtype(), \
"ElementWise##name##CsrGradCPUKernel", \
([&] { \
ElementWise##name##CsrGradCPUKernel<T, data_t>( \
dev_ctx, x, y, dout, dx, dy); \
})); \
}
#define DEFINE_ELEMENTWISE_GRAD_KERNEL_COO(name) \
template <typename T, typename Context> \
void ElementWise##name##CooGradKernel(const Context& dev_ctx, \
const SparseCooTensor& x, \
const SparseCooTensor& y, \
const SparseCooTensor& dout, \
SparseCooTensor* dx, \
SparseCooTensor* dy) { \
PD_VISIT_INTEGRAL_TYPES(x.non_zero_indices().dtype(), \
"ElementWise##name##CooGradCPUKernel", \
([&] { \
ElementWise##name##CooGradCPUKernel<T, data_t>( \
dev_ctx, x, y, dout, dx, dy); \
})); \
#define DEFINE_ELEMENTWISE_GRAD_KERNEL_COO(name) \
template <typename T, typename Context> \
void ElementWise##name##CooGradKernel(const Context& dev_ctx, \
const SparseCooTensor& x, \
const SparseCooTensor& y, \
const SparseCooTensor& dout, \
SparseCooTensor* dx, \
SparseCooTensor* dy) { \
PD_VISIT_BASE_INTEGRAL_TYPES( \
x.non_zero_indices().dtype(), \
"ElementWise##name##CooGradCPUKernel", \
([&] { \
ElementWise##name##CooGradCPUKernel<T, data_t>( \
dev_ctx, x, y, dout, dx, dy); \
})); \
}
DEFINE_ELEMENTWISE_GRAD_KERNEL(Add)
......
......@@ -57,11 +57,12 @@ void Merge(const IntT el_len,
const IntT len_b_max,
IntT* c_index,
T* c_values,
IntT& nnz,
IntT* out_nnz,
const Functor& functor_org,
const bool is_divide) {
IntT a = 0;
IntT b = 0;
IntT& nnz = (*out_nnz);
nnz = 0;
const IntT* b_index = nullptr;
std::vector<IntT> b_full_index;
......@@ -94,9 +95,7 @@ void Merge(const IntT el_len,
}
++a;
++b;
}
// coordinate x[a] < coordinate y[b]
else if (a_index[a] < b_index[b]) {
} else if (a_index[a] < b_index[b]) { // coordinate x[a] < coordinate y[b]
if (!functor(a_values + a * el_len,
zero.data(),
c_values + nnz * el_len,
......@@ -105,9 +104,7 @@ void Merge(const IntT el_len,
++nnz;
}
++a;
}
// coordinate x[a] > coordinate y[b]
else if (a_index[a] > b_index[b]) {
} else if (a_index[a] > b_index[b]) { // coordinate x[a] > coordinate y[b]
if (!functor(zero.data(),
b_values[b_index[b]],
c_values + nnz * el_len,
......@@ -215,7 +212,7 @@ void ElementWiseCooKernelImpl(const Context& dev_ctx,
max_len,
out_indexs.data(),
out_values_vec.data(),
nnz,
&nnz,
functor,
is_divide);
......@@ -292,7 +289,7 @@ void ElementWiseCooKernelImpl(const Context& dev_ctx,
const SparseCsrTensor& x, \
const SparseCsrTensor& y, \
SparseCsrTensor* out) { \
PD_VISIT_INTEGRAL_TYPES( \
PD_VISIT_BASE_INTEGRAL_TYPES( \
x.non_zero_crows().dtype(), "ElementWise##name##CsrCPUKernel", ([&] { \
ElementWise##name##CsrCPUKernel<T, data_t>(dev_ctx, x, y, out); \
})); \
......@@ -309,18 +306,18 @@ void ElementWiseCooKernelImpl(const Context& dev_ctx,
dev_ctx, x, y, out, functor); \
}
#define DEFINE_COO_ELEMENTWISE_KERNEL(name) \
template <typename T, typename Context> \
void ElementWise##name##CooKernel(const Context& dev_ctx, \
const SparseCooTensor& x, \
const SparseCooTensor& y, \
SparseCooTensor* out) { \
PD_VISIT_INTEGRAL_TYPES(x.non_zero_indices().dtype(), \
"ElementWise##name##CooCPUKernel", \
([&] { \
ElementWise##name##CooCPUKernel<T, data_t>( \
dev_ctx, x, y, out); \
})); \
#define DEFINE_COO_ELEMENTWISE_KERNEL(name) \
template <typename T, typename Context> \
void ElementWise##name##CooKernel(const Context& dev_ctx, \
const SparseCooTensor& x, \
const SparseCooTensor& y, \
SparseCooTensor* out) { \
PD_VISIT_BASE_INTEGRAL_TYPES(x.non_zero_indices().dtype(), \
"ElementWise##name##CooCPUKernel", \
([&] { \
ElementWise##name##CooCPUKernel<T, data_t>( \
dev_ctx, x, y, out); \
})); \
}
DEFINE_CSR_ELEMENTWISE_CPU_KERNEL(Add)
......
......@@ -79,7 +79,7 @@ void SparseMaskKernel(const Context& dev_ctx,
const DenseTensor& x,
const SparseCooTensor& mask,
SparseCooTensor* out) {
PD_VISIT_INTEGRAL_TYPES(
PD_VISIT_BASE_INTEGRAL_TYPES(
mask.non_zero_indices().dtype(), "SparseMaskCPUKernel", ([&] {
SparseMaskCPUKernel<T, data_t>(dev_ctx, x, mask, out);
}));
......@@ -146,7 +146,7 @@ void SparseMaskHelperKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& mask_indices,
DenseTensor* out) {
PD_VISIT_INTEGRAL_TYPES(
PD_VISIT_BASE_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "SparseMaskHelperCPUKernel", ([&] {
SparseMaskHelperCPUKernel<T, data_t>(dev_ctx, x, mask_indices, out);
}));
......
......@@ -83,7 +83,7 @@ void MaxPoolCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& out_grad,
const std::vector<int>& kernel_sizes,
SparseCooTensor* x_grad) {
PD_VISIT_INTEGRAL_TYPES(
PD_VISIT_BASE_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "MaxPoolCooGradCPUKernel", ([&] {
MaxPoolCooGradCPUKernel<T, data_t>(
dev_ctx, x, rulebook, counter, out, out_grad, kernel_sizes, x_grad);
......
......@@ -109,7 +109,7 @@ void MaxPoolCooKernel(const Context& dev_ctx,
SparseCooTensor* out,
DenseTensor* rulebook,
DenseTensor* counter) {
PD_VISIT_INTEGRAL_TYPES(
PD_VISIT_BASE_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "MaxPoolCooCPUKernel", ([&] {
MaxPoolCooCPUKernel<T, data_t>(dev_ctx,
x,
......
......@@ -62,7 +62,7 @@ void SoftmaxCsrGradKernel(const Context& dev_ctx,
T* dx_data = dx_values->data<T>();
// dx = (dout - sum(dout * out)) * out
PD_VISIT_INTEGRAL_TYPES(
PD_VISIT_BASE_INTEGRAL_TYPES(
out.non_zero_crows().dtype(), "SoftmaxCsrGradKernel", ([&] {
const data_t* out_crows_data = out_crows.data<data_t>();
for (int i = 0; i < batch_size; ++i) {
......
......@@ -60,7 +60,7 @@ void SoftmaxCsrKernel(const Context& dev_ctx,
T* out_data = out_values->data<T>();
// out = exp(x-x_max) / sum( exp(x-x_max ))
PD_VISIT_INTEGRAL_TYPES(
PD_VISIT_BASE_INTEGRAL_TYPES(
x.non_zero_crows().dtype(), "CsrSoftmaxKernel", ([&] {
const data_t* x_crows_data = x_crows.data<data_t>();
for (int i = 0; i < batch_size; ++i) {
......
......@@ -160,7 +160,7 @@ template <typename T, typename Context>
void SparseCsrToCooKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
SparseCooTensor* out) {
PD_VISIT_INTEGRAL_TYPES(
PD_VISIT_BASE_INTEGRAL_TYPES(
x.non_zero_crows().dtype(), "SparseCsrToCooCPUKernel", ([&] {
SparseCsrToCooCPUKernel<T, data_t>(dev_ctx, x, out);
}));
......@@ -250,7 +250,7 @@ template <typename T, typename Context>
void SparseCooToCsrKernel(const Context& dev_ctx,
const SparseCooTensor& x,
SparseCsrTensor* out) {
PD_VISIT_INTEGRAL_TYPES(
PD_VISIT_BASE_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "SparseCooToCsrCPUKernel", ([&] {
SparseCooToCsrCPUKernel<T, data_t>(dev_ctx, x, out);
}));
......@@ -304,7 +304,7 @@ template <typename T, typename Context>
void SparseCooToDenseKernel(const Context& dev_ctx,
const SparseCooTensor& x,
DenseTensor* out) {
PD_VISIT_INTEGRAL_TYPES(
PD_VISIT_BASE_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "SparseCooToDenseCPUKernel", ([&] {
SparseCooToDenseCPUKernel<T, data_t>(dev_ctx, x, out);
}));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册