Unverified · Commit c9f4cfad · authored by Netpunk, committed by GitHub

[PHI decoupling] replace dependency of inclusive_scan.h from phi (#48980)

* replace dependency of inclusive_scan.h from phi

* format code

Parent 00f20313
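
For context, the change is mechanical: every call site that used paddle::operators::math::InclusiveScan now calls phi::funcs::InclusiveScan with an unchanged argument list (input, output, outer_dim, mid_dim, inner_dim, identity value, binary operator, reverse flag, device context), so phi's cumprod kernels no longer include a header from fluid. The following is a standalone CPU sketch of those scan semantics, not Paddle's GPU implementation; the name InclusiveScanRef and the flat [outer_dim, mid_dim, inner_dim] layout with the scan running along the middle axis are assumptions inferred from the call sites below.

// CPU reference for the InclusiveScan semantics used in this commit.
// Illustrative sketch only: InclusiveScanRef is a hypothetical name and the
// layout is inferred from the diff, not taken from Paddle's headers.
#include <cstddef>
#include <functional>
#include <iostream>
#include <vector>

template <typename T, typename BinaryOp>
void InclusiveScanRef(const T *in, T *out, size_t outer_dim, size_t mid_dim,
                      size_t inner_dim, T init, BinaryOp op, bool reverse) {
  for (size_t o = 0; o < outer_dim; ++o) {
    for (size_t i = 0; i < inner_dim; ++i) {
      T acc = init;
      for (size_t m = 0; m < mid_dim; ++m) {
        // Walk the middle axis forward or backward; its stride is inner_dim.
        size_t mm = reverse ? mid_dim - 1 - m : m;
        size_t idx = (o * mid_dim + mm) * inner_dim + i;
        acc = op(acc, in[idx]);
        out[idx] = acc;
      }
    }
  }
}

int main() {
  std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f};
  std::vector<float> y(x.size());
  // Cumulative product along a length-4 axis (outer_dim = inner_dim = 1),
  // mirroring the CumprodKernel call site below.
  InclusiveScanRef(x.data(), y.data(), 1, 4, 1, 1.0f,
                   std::multiplies<float>(), /*reverse=*/false);
  for (float v : y) std::cout << v << ' ';  // prints 1 2 6 24
  std::cout << '\n';
  return 0;
}

Compiled as a plain C++ program this prints 1 2 6 24, the inclusive product scan that CumprodKernel performs on the GPU.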
@@ -16,13 +16,13 @@
 #include <thrust/transform.h>
 
-#include "paddle/fluid/operators/math/inclusive_scan.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/complex_functors.h"
 #include "paddle/phi/kernels/funcs/cumprod.h"
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"
 #include "paddle/phi/kernels/funcs/for_range.h"
+#include "paddle/phi/kernels/funcs/inclusive_scan.h"
 // NOTE(@xiongkun): use of IsComplex<>
 #include "paddle/phi/core/utils/data_type.h"
@@ -194,16 +194,15 @@ void CumprodGradKernel(const Context &dev_ctx,
   auto zero_mask = const_cast<Allocator &>(dev_ctx.GetAllocator())
                        .Allocate(numel * sizeof(uint8_t));
   auto *zero_mask_data = reinterpret_cast<uint8_t *>(zero_mask->ptr());
-  paddle::operators::math::InclusiveScan<uint8_t, cub::Max>(
-      zero_mask_without_cummax_data,
-      zero_mask_data,
-      outer_dim,
-      mid_dim,
-      inner_dim,
-      static_cast<uint8_t>(0),
-      cub::Max(),
-      /*reverse=*/false,
-      dev_ctx);
+  phi::funcs::InclusiveScan<uint8_t, cub::Max>(zero_mask_without_cummax_data,
+                                               zero_mask_data,
+                                               outer_dim,
+                                               mid_dim,
+                                               inner_dim,
+                                               static_cast<uint8_t>(0),
+                                               cub::Max(),
+                                               /*reverse=*/false,
+                                               dev_ctx);
   zero_mask_without_cummax = nullptr;
 
   // Step 2: calculate reversed cumsum(dy * y)
@@ -222,16 +221,15 @@ void CumprodGradKernel(const Context &dev_ctx,
                            .Allocate(numel * sizeof(T));
   auto *dy_mul_y_reversed_cumsum_data =
       reinterpret_cast<T *>(dy_mul_y_reversed_cumsum->ptr());
-  paddle::operators::math::InclusiveScan<T, cub::Sum>(
-      dy_mul_y_data,
-      dy_mul_y_reversed_cumsum_data,
-      outer_dim,
-      mid_dim,
-      inner_dim,
-      static_cast<T>(0),
-      cub::Sum(),
-      /*reverse=*/true,
-      dev_ctx);
+  phi::funcs::InclusiveScan<T, cub::Sum>(dy_mul_y_data,
+                                         dy_mul_y_reversed_cumsum_data,
+                                         outer_dim,
+                                         mid_dim,
+                                         inner_dim,
+                                         static_cast<T>(0),
+                                         cub::Sum(),
+                                         /*reverse=*/true,
+                                         dev_ctx);
 
   // Step 3: calculate the gradient value except the first zero position.
   // The gradient value of the first zero position is filled with out[idx-1],
@@ -262,7 +260,7 @@ void CumprodGradKernel(const Context &dev_ctx,
   // Step 4: calculate cumprod of x_filled_one
   auto *x_filled_one_cumprod_data =
       dy_mul_y_reversed_cumsum_data;  // reuse former allocated memory
-  paddle::operators::math::InclusiveScan<T, funcs::MultiplyFunctor<T>>(
+  phi::funcs::InclusiveScan<T, funcs::MultiplyFunctor<T>>(
       x_filled_one_data,
       x_filled_one_cumprod_data,
       outer_dim,
@@ -284,7 +282,7 @@ void CumprodGradKernel(const Context &dev_ctx,
       funcs::MultiplyFunctor<T>());
   auto *dy_mul_x_filled_one_cumprod_reversed_cumsum =
       dy_mul_y_reversed_cumsum_data;  // reuse former allocated memory
-  paddle::operators::math::InclusiveScan<T, cub::Sum>(
+  phi::funcs::InclusiveScan<T, cub::Sum>(
       dy_mul_x_filled_one_cumprod,
       dy_mul_x_filled_one_cumprod_reversed_cumsum,
       outer_dim,
...
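
The step comments in CumprodGradKernel above follow the usual cumprod gradient identity: with y_i = x_0 * ... * x_i and no zeros in x, dL/dx_i = (sum over k >= i of dy_k * y_k) / x_i, which is why Step 2 computes a reversed cumsum of dy * y; the zero-mask scan in Step 1 and the x_filled_one cumprod in Step 4 exist to repair the positions where x is zero and that division is undefined. A minimal numeric check of the identity, assuming the no-zero case (all values below are illustrative):

// Hand check of dL/dx[i] = (sum_{k >= i} dy[k] * y[k]) / x[i] for x with no
// zeros; the backward loop mirrors Step 2's reversed cumsum of dy * y,
// followed by an elementwise division by x.
#include <iostream>
#include <vector>

int main() {
  std::vector<double> x = {2.0, 3.0, 4.0};
  std::vector<double> y(3), dx(3), dy = {1.0, 1.0, 1.0};
  double p = 1.0;
  for (int i = 0; i < 3; ++i) y[i] = (p *= x[i]);  // y = {2, 6, 24}
  double s = 0.0;
  for (int i = 2; i >= 0; --i) {
    s += dy[i] * y[i];  // reversed cumsum of dy * y
    dx[i] = s / x[i];
  }
  // For L = y[0] + y[1] + y[2] = x0 + x0*x1 + x0*x1*x2, the analytic
  // gradient is {1 + x1 + x1*x2, x0 + x0*x2, x0*x1} = {16, 10, 6}.
  for (double v : dx) std::cout << v << ' ';  // prints 16 10 6
  std::cout << '\n';
  return 0;
}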
@@ -14,12 +14,12 @@
 #include "paddle/phi/kernels/cumprod_kernel.h"
 
-#include "paddle/fluid/operators/math/inclusive_scan.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/complex_functors.h"
 #include "paddle/phi/kernels/funcs/cumprod.h"
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"
+#include "paddle/phi/kernels/funcs/inclusive_scan.h"
 
 namespace phi {
@@ -35,15 +35,15 @@ void CumprodKernel(const Context &dev_ctx,
   const auto *x_data = x->data<T>();
   auto *y_data = dev_ctx.template Alloc<T>(y);
 
-  paddle::operators::math::InclusiveScan(x_data,
-                                         y_data,
-                                         outer_dim,
-                                         mid_dim,
-                                         inner_dim,
-                                         static_cast<T>(1),
-                                         funcs::MultiplyFunctor<T>(),
-                                         /*reverse=*/false,
-                                         dev_ctx);
+  phi::funcs::InclusiveScan(x_data,
+                            y_data,
+                            outer_dim,
+                            mid_dim,
+                            inner_dim,
+                            static_cast<T>(1),
+                            funcs::MultiplyFunctor<T>(),
+                            /*reverse=*/false,
+                            dev_ctx);
 }
 
 }  // namespace phi
...
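
With the reference sketch above in mind, the forward kernel reads as a single primitive call: cumprod along the middle axis is just InclusiveScan with funcs::MultiplyFunctor<T> as the operator and static_cast<T>(1) as the identity, the 1 2 6 24 case in the sketch. The commit changes where that primitive lives (phi::funcs instead of paddle::operators::math), not what it computes.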