Unverified · Commit c9f4cfad authored by Netpunk, committed by GitHub

[PHI decoupling] replace dependency of inclusive_scan.h from phi (#48980)

* replace dependency of inclusive_scan.h from phi

* format code
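
The change swaps every call site from paddle::operators::math::InclusiveScan to phi::funcs::InclusiveScan; the argument list stays the same, only the header and namespace move. Below is a minimal sketch of a migrated call site, mirroring the call shape visible in the diff; the wrapper function, buffer names, and size_t dimension types are illustrative assumptions, not part of the commit.

// Sketch only: assumes the argument order shown in the diff below.
#include <cub/cub.cuh>
#include "paddle/phi/kernels/funcs/inclusive_scan.h"  // replaces paddle/fluid/operators/math/inclusive_scan.h

template <typename T, typename Context>
void ReversedCumsumSketch(const Context &dev_ctx,
                          const T *in_data,
                          T *out_data,
                          size_t outer_dim,
                          size_t mid_dim,
                          size_t inner_dim) {
  // Only the namespace changes: paddle::operators::math:: -> phi::funcs::.
  phi::funcs::InclusiveScan<T, cub::Sum>(in_data,
                                         out_data,
                                         outer_dim,
                                         mid_dim,
                                         inner_dim,
                                         static_cast<T>(0),
                                         cub::Sum(),
                                         /*reverse=*/true,
                                         dev_ctx);
}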
Parent 00f20313
@@ -16,13 +16,13 @@
 #include <thrust/transform.h>
-#include "paddle/fluid/operators/math/inclusive_scan.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/complex_functors.h"
 #include "paddle/phi/kernels/funcs/cumprod.h"
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"
 #include "paddle/phi/kernels/funcs/for_range.h"
+#include "paddle/phi/kernels/funcs/inclusive_scan.h"
 // NOTE(@xiongkun): use of IsComplex<>
 #include "paddle/phi/core/utils/data_type.h"
@@ -194,16 +194,15 @@ void CumprodGradKernel(const Context &dev_ctx,
   auto zero_mask = const_cast<Allocator &>(dev_ctx.GetAllocator())
                        .Allocate(numel * sizeof(uint8_t));
   auto *zero_mask_data = reinterpret_cast<uint8_t *>(zero_mask->ptr());
-  paddle::operators::math::InclusiveScan<uint8_t, cub::Max>(
-      zero_mask_without_cummax_data,
-      zero_mask_data,
-      outer_dim,
-      mid_dim,
-      inner_dim,
-      static_cast<uint8_t>(0),
-      cub::Max(),
-      /*reverse=*/false,
-      dev_ctx);
+  phi::funcs::InclusiveScan<uint8_t, cub::Max>(zero_mask_without_cummax_data,
+                                               zero_mask_data,
+                                               outer_dim,
+                                               mid_dim,
+                                               inner_dim,
+                                               static_cast<uint8_t>(0),
+                                               cub::Max(),
+                                               /*reverse=*/false,
+                                               dev_ctx);
   zero_mask_without_cummax = nullptr;
   // Step 2: calculate reversed cumsum(dy * y)
@@ -222,16 +221,15 @@ void CumprodGradKernel(const Context &dev_ctx,
           .Allocate(numel * sizeof(T));
   auto *dy_mul_y_reversed_cumsum_data =
       reinterpret_cast<T *>(dy_mul_y_reversed_cumsum->ptr());
-  paddle::operators::math::InclusiveScan<T, cub::Sum>(
-      dy_mul_y_data,
-      dy_mul_y_reversed_cumsum_data,
-      outer_dim,
-      mid_dim,
-      inner_dim,
-      static_cast<T>(0),
-      cub::Sum(),
-      /*reverse=*/true,
-      dev_ctx);
+  phi::funcs::InclusiveScan<T, cub::Sum>(dy_mul_y_data,
+                                         dy_mul_y_reversed_cumsum_data,
+                                         outer_dim,
+                                         mid_dim,
+                                         inner_dim,
+                                         static_cast<T>(0),
+                                         cub::Sum(),
+                                         /*reverse=*/true,
+                                         dev_ctx);
   // Step 3: calculate the gradient value except the first zero position.
   // The gradient value of the first zero position is filled with out[idx-1],
@@ -262,7 +260,7 @@ void CumprodGradKernel(const Context &dev_ctx,
   // Step 4: calculate cumprod of x_filled_one
   auto *x_filled_one_cumprod_data =
       dy_mul_y_reversed_cumsum_data;  // reuse former allocated memory
-  paddle::operators::math::InclusiveScan<T, funcs::MultiplyFunctor<T>>(
+  phi::funcs::InclusiveScan<T, funcs::MultiplyFunctor<T>>(
       x_filled_one_data,
       x_filled_one_cumprod_data,
       outer_dim,
@@ -284,7 +282,7 @@ void CumprodGradKernel(const Context &dev_ctx,
                     funcs::MultiplyFunctor<T>());
   auto *dy_mul_x_filled_one_cumprod_reversed_cumsum =
       dy_mul_y_reversed_cumsum_data;  // reuse former allocated memory
-  paddle::operators::math::InclusiveScan<T, cub::Sum>(
+  phi::funcs::InclusiveScan<T, cub::Sum>(
       dy_mul_x_filled_one_cumprod,
       dy_mul_x_filled_one_cumprod_reversed_cumsum,
       outer_dim,
@@ -14,12 +14,12 @@
 #include "paddle/phi/kernels/cumprod_kernel.h"
-#include "paddle/fluid/operators/math/inclusive_scan.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/complex_functors.h"
 #include "paddle/phi/kernels/funcs/cumprod.h"
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"
+#include "paddle/phi/kernels/funcs/inclusive_scan.h"
 namespace phi {
@@ -35,15 +35,15 @@ void CumprodKernel(const Context &dev_ctx,
   const auto *x_data = x->data<T>();
   auto *y_data = dev_ctx.template Alloc<T>(y);
-  paddle::operators::math::InclusiveScan(x_data,
-                                         y_data,
-                                         outer_dim,
-                                         mid_dim,
-                                         inner_dim,
-                                         static_cast<T>(1),
-                                         funcs::MultiplyFunctor<T>(),
-                                         /*reverse=*/false,
-                                         dev_ctx);
+  phi::funcs::InclusiveScan(x_data,
+                            y_data,
+                            outer_dim,
+                            mid_dim,
+                            inner_dim,
+                            static_cast<T>(1),
+                            funcs::MultiplyFunctor<T>(),
+                            /*reverse=*/false,
+                            dev_ctx);
 }
 }  // namespace phi