未验证 提交 12346cdc 编写于 作者: Z zyfncg 提交者: GitHub

[PHI] Move header of selected_rows kernel to selected_rows dir (#40128)

* move selected_rows kernel head to selected_rows dir

* update license

* add sr namespace

* refacter selected_rows kernel funciton name

* fix bug
上级 5dc76637
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
#include "paddle/phi/common/scalar.h" #include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/scalar_array.h" #include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/infermeta/nullary.h" #include "paddle/phi/infermeta/nullary.h"
#include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/empty_kernel.h"
...@@ -31,13 +30,6 @@ void FullKernel(const Context& dev_ctx, ...@@ -31,13 +30,6 @@ void FullKernel(const Context& dev_ctx,
DataType dtype, DataType dtype,
DenseTensor* out); DenseTensor* out);
template <typename T, typename Context>
void FullSR(const Context& dev_ctx,
const ScalarArray& shape,
const Scalar& val,
DataType dtype,
SelectedRows* out);
template <typename T, typename Context> template <typename T, typename Context>
void FullLikeKernel(const Context& dev_ctx, void FullLikeKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
......
...@@ -16,7 +16,6 @@ limitations under the License. */ ...@@ -16,7 +16,6 @@ limitations under the License. */
#include "paddle/phi/common/scalar.h" #include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/infermeta/unary.h" #include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/empty_kernel.h"
namespace phi { namespace phi {
...@@ -29,14 +28,6 @@ void ScaleKernel(const Context& dev_ctx, ...@@ -29,14 +28,6 @@ void ScaleKernel(const Context& dev_ctx,
bool bias_after_scale, bool bias_after_scale,
DenseTensor* out); DenseTensor* out);
template <typename T, typename Context>
void ScaleSR(const Context& dev_ctx,
const SelectedRows& x,
const Scalar& scale,
float bias,
bool bias_after_scale,
SelectedRows* out);
template <typename T, typename Context> template <typename T, typename Context>
DenseTensor Scale(const Context& dev_ctx, DenseTensor Scale(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
......
...@@ -12,34 +12,37 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,34 +12,37 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/selected_rows/full_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#endif #endif
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/complex.h" #include "paddle/phi/common/complex.h"
namespace phi { namespace phi {
namespace sr {
template <typename T, typename Context> template <typename T, typename Context>
void FullSR(const Context& dev_ctx, void FullKernel(const Context& dev_ctx,
const ScalarArray& shape, const ScalarArray& shape,
const Scalar& val, const Scalar& val,
DataType dtype, DataType dtype,
SelectedRows* out) { SelectedRows* out) {
phi::FullKernel<T>(dev_ctx, shape, val, dtype, out->mutable_value()); phi::FullKernel<T>(dev_ctx, shape, val, dtype, out->mutable_value());
} }
} // namespace sr
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(full_sr, PD_REGISTER_KERNEL(full_sr,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::FullSR, phi::sr::FullKernel,
float, float,
double, double,
uint8_t, uint8_t,
...@@ -56,7 +59,7 @@ PD_REGISTER_KERNEL(full_sr, ...@@ -56,7 +59,7 @@ PD_REGISTER_KERNEL(full_sr,
PD_REGISTER_KERNEL(full_sr, PD_REGISTER_KERNEL(full_sr,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::FullSR, phi::sr::FullKernel,
float, float,
double, double,
uint8_t, uint8_t,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/selected_rows.h"
namespace phi {
namespace sr {
template <typename T, typename Context>
void FullKernel(const Context& dev_ctx,
const ScalarArray& shape,
const Scalar& val,
DataType dtype,
SelectedRows* out);
} // namespace sr
} // namespace phi
...@@ -12,21 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,21 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/phi/kernels/scale_kernel.h" #include "paddle/phi/kernels/selected_rows/scale_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/scale_kernel.h"
namespace phi { namespace phi {
namespace sr {
template <typename T, typename Context> template <typename T, typename Context>
void ScaleSR(const Context& dev_ctx, void ScaleKernel(const Context& dev_ctx,
const SelectedRows& x, const SelectedRows& x,
const Scalar& scale, const Scalar& scale,
float bias, float bias,
bool bias_after_scale, bool bias_after_scale,
SelectedRows* out) { SelectedRows* out) {
if (x.value().Holder() != out->value().Holder() || if (x.value().Holder() != out->value().Holder() ||
x.value().data() != out->value().data()) { x.value().data() != out->value().data()) {
out->set_rows(x.rows()); out->set_rows(x.rows());
...@@ -36,12 +38,13 @@ void ScaleSR(const Context& dev_ctx, ...@@ -36,12 +38,13 @@ void ScaleSR(const Context& dev_ctx,
dev_ctx, x.value(), scale, bias, bias_after_scale, out->mutable_value()); dev_ctx, x.value(), scale, bias, bias_after_scale, out->mutable_value());
} }
} // namespace sr
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(scale_sr, PD_REGISTER_KERNEL(scale_sr,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::ScaleSR, phi::sr::ScaleKernel,
float, float,
double, double,
phi::dtype::bfloat16, phi::dtype::bfloat16,
...@@ -55,7 +58,7 @@ PD_REGISTER_KERNEL(scale_sr, ...@@ -55,7 +58,7 @@ PD_REGISTER_KERNEL(scale_sr,
PD_REGISTER_KERNEL(scale_sr, PD_REGISTER_KERNEL(scale_sr,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::ScaleSR, phi::sr::ScaleKernel,
float, float,
double, double,
phi::dtype::float16, phi::dtype::float16,
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/selected_rows.h"
namespace phi {
namespace sr {
template <typename T, typename Context>
void ScaleKernel(const Context& dev_ctx,
const SelectedRows& x,
const Scalar& scale,
float bias,
bool bias_after_scale,
SelectedRows* out);
} // namespace sr
} // namespace phi
...@@ -12,22 +12,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,22 +12,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/phi/kernels/uniform_random_kernel.h" #include "paddle/phi/kernels/selected_rows/uniform_random_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/uniform_random_kernel.h"
namespace phi { namespace phi {
namespace sr {
template <typename T, typename Context> template <typename T, typename Context>
void UniformRandomRawSRKernel(const Context& dev_ctx, void UniformRandomRawKernel(const Context& dev_ctx,
const ScalarArray& shape, const ScalarArray& shape,
DataType dtype, DataType dtype,
float min, float min,
float max, float max,
int seed, int seed,
int diag_num, int diag_num,
int diag_step, int diag_step,
float diag_val, float diag_val,
SelectedRows* out) { SelectedRows* out) {
phi::UniformRandomRawKernel<T>(dev_ctx, phi::UniformRandomRawKernel<T>(dev_ctx,
shape, shape,
dtype, dtype,
...@@ -41,23 +46,24 @@ void UniformRandomRawSRKernel(const Context& dev_ctx, ...@@ -41,23 +46,24 @@ void UniformRandomRawSRKernel(const Context& dev_ctx,
} }
template <typename T, typename Context> template <typename T, typename Context>
void UniformRandomSRKernel(const Context& dev_ctx, void UniformRandomKernel(const Context& dev_ctx,
const ScalarArray& shape, const ScalarArray& shape,
DataType dtype, DataType dtype,
float min, float min,
float max, float max,
int seed, int seed,
SelectedRows* out) { SelectedRows* out) {
phi::UniformRandomKernel<T>( phi::UniformRandomKernel<T>(
dev_ctx, shape, dtype, min, max, seed, out->mutable_value()); dev_ctx, shape, dtype, min, max, seed, out->mutable_value());
} }
} // namespace sr
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(uniform_random_raw_sr, PD_REGISTER_KERNEL(uniform_random_raw_sr,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::UniformRandomRawSRKernel, phi::sr::UniformRandomRawKernel,
float, float,
double, double,
phi::dtype::bfloat16) {} phi::dtype::bfloat16) {}
...@@ -65,7 +71,7 @@ PD_REGISTER_KERNEL(uniform_random_raw_sr, ...@@ -65,7 +71,7 @@ PD_REGISTER_KERNEL(uniform_random_raw_sr,
PD_REGISTER_KERNEL(uniform_random_sr, PD_REGISTER_KERNEL(uniform_random_sr,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::UniformRandomSRKernel, phi::sr::UniformRandomKernel,
float, float,
double, double,
phi::dtype::bfloat16) {} phi::dtype::bfloat16) {}
...@@ -75,14 +81,14 @@ PD_REGISTER_KERNEL(uniform_random_sr, ...@@ -75,14 +81,14 @@ PD_REGISTER_KERNEL(uniform_random_sr,
PD_REGISTER_KERNEL(uniform_random_raw_sr, PD_REGISTER_KERNEL(uniform_random_raw_sr,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::UniformRandomRawSRKernel, phi::sr::UniformRandomRawKernel,
float, float,
double) {} double) {}
PD_REGISTER_KERNEL(uniform_random_sr, PD_REGISTER_KERNEL(uniform_random_sr,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::UniformRandomSRKernel, phi::sr::UniformRandomKernel,
float, float,
double) {} double) {}
#endif #endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/selected_rows.h"
namespace phi {
namespace sr {
template <typename T, typename Context>
void UniformRandomRawKernel(const Context& dev_ctx,
const ScalarArray& shape,
DataType dtype,
float min,
float max,
int seed,
int diag_num,
int diag_step,
float diag_val,
SelectedRows* out);
template <typename T, typename Context>
void UniformRandomKernel(const Context& dev_ctx,
const ScalarArray& shape,
DataType dtype,
float min,
float max,
int seed,
SelectedRows* out);
} // namespace sr
} // namespace phi
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
#include "paddle/phi/common/scalar_array.h" #include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h" #include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/selected_rows.h"
namespace phi { namespace phi {
...@@ -42,25 +41,4 @@ void UniformRandomKernel(const Context& dev_ctx, ...@@ -42,25 +41,4 @@ void UniformRandomKernel(const Context& dev_ctx,
int seed, int seed,
DenseTensor* out); DenseTensor* out);
template <typename T, typename Context>
void UniformRandomRawSRKernel(const Context& dev_ctx,
const ScalarArray& shape,
DataType dtype,
float min,
float max,
int seed,
int diag_num,
int diag_step,
float diag_val,
SelectedRows* out);
template <typename T, typename Context>
void UniformRandomSRKernel(const Context& dev_ctx,
const ScalarArray& shape,
DataType dtype,
float min,
float max,
int seed,
SelectedRows* out);
} // namespace phi } // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册