Unverified commit 5e222dc2, authored by huangjiyi, committed by GitHub

[PHI Decoupling] move maxouting and matrix_bit_code from fluid to phi (#49131)

* move maxouting from fluid to phi

* move matrix_bit_code from fluid to phi

* replace mutable_data and fix include

* fix include

* move gather_scatter_kernel from fluid to phi

* Revert "move gather_scatter_kernel from fluid to phi"

This reverts commit 3d0b1eaf179656072e8c483dfca688cccccdda01.
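
The "replace mutable_data" step above refers to the phi allocation convention: instead of asking the tensor to allocate itself for a place, the device context allocates the output. A minimal sketch of the pattern (hedged; the helper name is illustrative, not part of this commit):

template <typename T, typename Context>
T* AllocOutput(const Context& ctx, phi::DenseTensor* out) {
  // fluid style (removed in this commit):
  //   T* ptr = out->mutable_data<T>(ctx.GetPlace());
  // phi style (what the diff below switches to):
  return ctx.template Alloc<T>(out);
}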
Parent af599121
@@ -27,7 +27,6 @@ math_library(sample_prob)
 math_library(sampler DEPS generator)
 # math_library(math_function DEPS blas dense_tensor tensor)
-math_library(maxouting)
 math_library(sequence_padding)
 math_library(sequence_pooling DEPS math_function jit_kernel_helper)
@@ -39,7 +38,6 @@ elseif(WITH_XPU)
 else()
   math_library(beam_search DEPS math_function)
 endif()
-math_library(matrix_bit_code)
 math_library(unpooling)
 math_library(prelu)
......
@@ -14,17 +14,16 @@
 #pragma once
 
-#include "paddle/fluid/operators/math/matrix_bit_code.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/selected_rows.h"
+#include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/matrix_bit_code.h"
 
 namespace phi {
-
-namespace math = paddle::operators::math;
 
 template <typename T, typename Context>
 void HSigmoidLossGradKernelImpl(const Context& ctx,
                                 const DenseTensor& x,
@@ -55,12 +54,12 @@ void HSigmoidLossGradKernelImpl(const Context& ctx,
     is_custom = true;
   }
 
-  std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code;
+  std::unique_ptr<phi::funcs::MatrixBitCodeFunctor<T>> bit_code;
   if (!is_custom) {
-    bit_code.reset(new math::MatrixBitCodeFunctor<T>(
+    bit_code.reset(new phi::funcs::MatrixBitCodeFunctor<T>(
        num_classes, label.template data<int64_t>()));
   } else {
-    bit_code.reset(new math::MatrixBitCodeFunctor<T>(
+    bit_code.reset(new phi::funcs::MatrixBitCodeFunctor<T>(
        *(path.get_ptr()), *(code.get_ptr()), label.template data<int64_t>()));
   }
......
@@ -14,19 +14,17 @@
 #include "paddle/phi/kernels/hsigmoid_loss_kernel.h"
 
-#include "paddle/fluid/operators/math/matrix_bit_code.h"
 #include "paddle/fluid/platform/transform.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
 #include "paddle/phi/kernels/funcs/math_function_impl.h"
+#include "paddle/phi/kernels/funcs/matrix_bit_code.h"
 #include "paddle/phi/kernels/impl/clip_kernel_impl.h"
 
 namespace phi {
 
-namespace math = paddle::operators::math;
-
 template <typename T, typename Context>
 void HSigmoidLossKernel(const Context& ctx,
                         const DenseTensor& x,
@@ -48,8 +46,9 @@ void HSigmoidLossKernel(const Context& ctx,
   if (path.get_ptr()) {
     is_custom = true;
   }
-  int64_t code_length = path.get_ptr() ? path.get_ptr()->dims()[1]
-                                       : math::FindLastSet(num_classes_st - 1);
+  int64_t code_length = path.get_ptr()
+                            ? path.get_ptr()->dims()[1]
+                            : phi::funcs::FindLastSet(num_classes_st - 1);
   int64_t batch_size = x.dims()[0];
   DenseTensor sum;
   pre_out->Resize(phi::make_ddim({batch_size, code_length}));
@@ -63,12 +62,12 @@ void HSigmoidLossKernel(const Context& ctx,
   auto& place = *ctx.eigen_device();
   funcs::RowwiseSum<Context, T> row_sum;
 
-  std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code;
+  std::unique_ptr<phi::funcs::MatrixBitCodeFunctor<T>> bit_code;
   if (!is_custom) {
-    bit_code.reset(new math::MatrixBitCodeFunctor<T>(
+    bit_code.reset(new phi::funcs::MatrixBitCodeFunctor<T>(
        num_classes_st, label.template data<int64_t>()));
   } else {
-    bit_code.reset(new math::MatrixBitCodeFunctor<T>(
+    bit_code.reset(new phi::funcs::MatrixBitCodeFunctor<T>(
        *(path.get_ptr()), *(code.get_ptr()), label.template data<int64_t>()));
   }
......
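
A note on the FindLastSet call above: it returns the 1-based position of the most significant set bit, so FindLastSet(num_classes - 1) is the number of bits needed to encode any class id, i.e. the default code length. A hedged stand-alone sketch of that semantics (FindLastSetSketch is an illustrative name; the real helper is the one declared in the moved matrix_bit_code header):

#include <cstdint>

// 1-based index of the highest set bit; 0 when x == 0.
// Example: FindLastSetSketch(9) == 4, since 9 == 0b1001; with
// num_classes == 10, each label needs FindLastSetSketch(10 - 1) == 4 bits.
inline int FindLastSetSketch(uint64_t x) {
  int pos = 0;
  while (x) {
    ++pos;
    x >>= 1;
  }
  return pos;
}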
@@ -20,6 +20,8 @@ math_library(cross_entropy)
 math_library(im2col)
 math_library(vol2col)
 math_library(softmax DEPS math_function)
+math_library(maxouting)
+math_library(matrix_bit_code)
 
 cc_library(
   phi_data_layout_transform
......
@@ -12,11 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/math/matrix_bit_code.h"
+#include "paddle/phi/kernels/funcs/matrix_bit_code.h"
 
-namespace paddle {
-namespace operators {
-namespace math {
+#include <map>
+#include <unordered_map>
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/kernels/funcs/blas/blas.h"
+
+namespace phi {
+namespace funcs {
 
 template <typename T>
 struct MatrixBitCodeFunctorAdd {
@@ -354,6 +359,5 @@ void MatrixBitCodeFunctor<T>::Sub(phi::DenseTensor *tmat) {
 template class MatrixBitCodeFunctor<float>;
 template class MatrixBitCodeFunctor<double>;
 
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
+}  // namespace funcs
+}  // namespace phi
......
@@ -13,18 +13,12 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
-#include <map>
-#include <unordered_map>
 #include <utility>
 #include <vector>
 
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/lod_tensor.h"
-#include "paddle/fluid/framework/selected_rows_utils.h"
-#include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/platform/device_context.h"
-#include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/selected_rows.h"
+#include "paddle/utils/variant.h"
 
 #if defined(_WIN32)
 #include <intrin.h>
@@ -34,9 +28,8 @@ limitations under the License. */
 #include <windows.h>
 #endif  // _WIN32
 
-namespace paddle {
-namespace operators {
-namespace math {
+namespace phi {
+namespace funcs {
 
 /**
  * SimpleCodeTable class should support 3 functions:
 *
@@ -273,6 +266,5 @@ class MatrixBitCodeFunctor {
   const int64_t* ids_;
   CodeTable code_table_;
 };
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
+}  // namespace funcs
+}  // namespace phi
......
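
For readers of the header above: in the default (non-custom) code table, class c maps to tree-node id c + num_classes in a complete binary tree, and the bits of that id describe the root-to-leaf path. A hedged sketch of the indexing arithmetic (struct and member names are illustrative, patterned on the SimpleCode idea in this header, not copied from it):

#include <cstddef>

struct SimpleCodeSketch {
  size_t c;  // assumed encoding: class id + num_classes
  // Internal-node index used at depth `bit` (row of the path weight matrix).
  size_t calc_index(int bit) const { return (c >> (bit + 1)) - 1; }
  // Branch taken at depth `bit`: 0 = left, 1 = right.
  bool calc_bit(int bit) const { return (c >> bit) & 1; }
};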
@@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/math/maxouting.h"
+#include "paddle/phi/kernels/funcs/maxouting.h"
 
 #include "paddle/phi/backends/cpu/cpu_context.h"
 
-namespace paddle {
-namespace operators {
-namespace math {
+namespace phi {
+namespace funcs {
 
 // All tensors are in NCHW or NHWC format, and the groups must be greater than 1
 template <typename DeviceContext, typename T>
@@ -35,7 +34,7 @@ void MaxOutFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
   // c_size means the output size of each sample
   int c_size = fea_size * output_channels;
   const T* input_data = input.data<T>();
-  T* output_data = output->mutable_data<T>(context.GetPlace());
+  T* output_data = context.template Alloc<T>(output);
   for (int i = 0; i < batch_size; ++i) {
     int new_bindex = c_size * i;
     for (int c = 0; c < output_channels; ++c) {
@@ -80,8 +79,7 @@ void MaxOutGradFunctor<DeviceContext, T>::operator()(
   const T* input_data = input.data<T>();
   const T* output_data = output.data<T>();
   const T* output_grad_data = output_grad.data<T>();
-  T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
-
+  T* input_grad_data = context.template Alloc<T>(input_grad);
   for (int i = 0; i < batch_size; ++i) {
     int blen = fea_size * output_channels * i;
     for (int c = 0; c < output_channels; ++c) {
@@ -114,6 +112,5 @@ template class MaxOutGradFunctor<phi::CPUContext, double>;
 template class MaxOutFunctor<phi::CPUContext, float>;
 template class MaxOutFunctor<phi::CPUContext, double>;
 
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
+}  // namespace funcs
+}  // namespace phi
......
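
The CPU functor moved above computes a max over `groups` consecutive input channels per output channel. A self-contained sketch of that reduction for one NCHW sample (plain C++ over flat buffers, illustrative rather than the phi API):

#include <algorithm>
#include <vector>

// out[c][i] = max over g of in[c * groups + g][i], with i running over the
// fea_size (H * W) positions, mirroring the loop structure in the file above.
std::vector<float> MaxOutSample(const std::vector<float>& in,
                                int in_channels, int fea_size, int groups) {
  const int out_channels = in_channels / groups;
  std::vector<float> out(out_channels * fea_size);
  for (int c = 0; c < out_channels; ++c) {
    for (int i = 0; i < fea_size; ++i) {
      float ele = in[(c * groups) * fea_size + i];
      for (int g = 1; g < groups; ++g) {
        ele = std::max(ele, in[(c * groups + g) * fea_size + i]);
      }
      out[c * fea_size + i] = ele;
    }
  }
  return out;
}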
@@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/math/maxouting.h"
+#include "paddle/phi/kernels/funcs/maxouting.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
 
-namespace paddle {
-namespace operators {
-namespace math {
+namespace phi {
+namespace funcs {
 
 template <typename T>
 __global__ void KernelMaxOut(const int nthreads,
@@ -57,6 +56,7 @@ __global__ void KernelMaxOut(const int nthreads,
     output_data[i] = ele;
   }
 }
+
 template <typename T>
 __global__ void KernelMaxoutGrad(const int nthreads,
                                  const T* input_data,
@@ -102,6 +102,7 @@ __global__ void KernelMaxoutGrad(const int nthreads,
     }
   }
 }
+
 /*
 * All tensors are in NCHW or NHWC format.
 */
@@ -118,7 +119,7 @@ void MaxOutFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
   const int output_channels = output->dims()[axis];
 
   const T* input_data = input.data<T>();
-  T* output_data = output->mutable_data<T>(context.GetPlace());
+  T* output_data = context.template Alloc<T>(output);
   int nthreads = output->numel();
   int blocks = (nthreads + 1024 - 1) / 1024;
   dim3 threads(1024, 1);
@@ -155,7 +156,7 @@ void MaxOutGradFunctor<DeviceContext, T>::operator()(
   const T* input_data = input.data<T>();
   const T* output_data = output.data<T>();
   const T* output_grad_data = output_grad.data<T>();
-  T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
+  T* input_grad_data = context.template Alloc<T>(input_grad);
   int nthreads = output.numel();
   int blocks = (nthreads + 1024 - 1) / 1024;
   dim3 threads(1024, 1);
@@ -179,6 +180,5 @@ template class MaxOutGradFunctor<phi::GPUContext, double>;
 template class MaxOutFunctor<phi::GPUContext, float>;
 template class MaxOutFunctor<phi::GPUContext, double>;
 
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
+}  // namespace funcs
+}  // namespace phi
......
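
The launch configuration in the GPU functors above assigns one thread per output element in 1024-thread blocks; the block count is a ceiling division. As a one-function sketch of that arithmetic:

// Grid size for a 1-D launch: ceil(nthreads / threads_per_block),
// matching `(nthreads + 1024 - 1) / 1024` in the file above.
inline int NumBlocks(int nthreads, int threads_per_block = 1024) {
  return (nthreads + threads_per_block - 1) / threads_per_block;
}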
@@ -13,14 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
-#include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/platform/device_context.h"
-#include "paddle/fluid/platform/macros.h"
-#include "paddle/phi/core/hostdevice.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/macros.h"
 
-namespace paddle {
-namespace operators {
-namespace math {
+namespace phi {
+namespace funcs {
 
 template <typename DeviceContext, typename T>
 class MaxOutFunctor {
@@ -43,6 +40,5 @@ class MaxOutGradFunctor {
                   const int groups,
                   const int axis = 1);
 };
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
+}  // namespace funcs
+}  // namespace phi
......
@@ -14,8 +14,8 @@
 #pragma once
 
-#include "paddle/fluid/operators/math/maxouting.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/maxouting.h"
 #include "paddle/phi/kernels/maxout_grad_kernel.h"
 
 namespace phi {
@@ -36,7 +36,7 @@ void MaxOutGradKernel(const Context& dev_ctx,
   if (x_grad) {
     dev_ctx.template Alloc<T>(x_grad);
     zero(dev_ctx, x_grad, static_cast<T>(0.0));
-    paddle::operators::math::MaxOutGradFunctor<Context, T> maxout_backward;
+    phi::funcs::MaxOutGradFunctor<Context, T> maxout_backward;
     maxout_backward(dev_ctx, x, x_grad, out, out_grad, groups, axis);
   }
 }
......
@@ -14,7 +14,7 @@
 #pragma once
 
-#include "paddle/fluid/operators/math/maxouting.h"
+#include "paddle/phi/kernels/funcs/maxouting.h"
 #include "paddle/phi/kernels/maxout_kernel.h"
 
 namespace phi {
@@ -29,7 +29,7 @@ void MaxOutKernel(const Context& dev_ctx,
     axis += x.dims().size();
   }
 
-  paddle::operators::math::MaxOutFunctor<Context, T> maxout_forward;
+  phi::funcs::MaxOutFunctor<Context, T> maxout_forward;
   maxout_forward(dev_ctx, x, out, groups, axis);
 }
......