未验证 提交 abfc2fe9 编写于 作者: Y YuanRisheng 提交者: GitHub

[PTen]Refactor scale kernel that has selected_rows input (#39278)

* refactor scale kernel that its input is selected_rows

* complement upload file
上级 848ae7dc
......@@ -88,6 +88,7 @@ function(kernel_library TARGET)
set(cpu_srcs)
set(gpu_srcs)
set(xpu_srcs)
set(selected_rows_srcs)
# parse and save the deps kerenl targets
set(all_srcs)
set(kernel_deps)
......@@ -106,6 +107,9 @@ function(kernel_library TARGET)
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/cpu/${TARGET}.cc)
list(APPEND cpu_srcs ${CMAKE_CURRENT_SOURCE_DIR}/cpu/${TARGET}.cc)
endif()
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/selected_rows/${TARGET}.cc)
list(APPEND selected_rows_srcs ${CMAKE_CURRENT_SOURCE_DIR}/selected_rows/${TARGET}.cc)
endif()
if (WITH_GPU OR WITH_ROCM)
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/gpu/${TARGET}.cu)
list(APPEND gpu_srcs ${CMAKE_CURRENT_SOURCE_DIR}/gpu/${TARGET}.cu)
......@@ -144,27 +148,30 @@ function(kernel_library TARGET)
list(LENGTH cpu_srcs cpu_srcs_len)
list(LENGTH gpu_srcs gpu_srcs_len)
list(LENGTH xpu_srcs xpu_srcs_len)
list(LENGTH selected_rows_srcs selected_rows_srcs_len)
# Build Target according different src organization
if((${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR
${xpu_srcs_len} GREATER 0) AND ${common_srcs_len} GREATER 0)
# If the common_srcs depends on specific device srcs, build target using this rule.
${xpu_srcs_len} GREATER 0) AND (${common_srcs_len} GREATER 0 OR
${selected_rows_srcs_len} GREATER 0))
# If the common_srcs/selected_rows_srcs depends on specific device srcs, build target using this rule.
if (WITH_GPU)
if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
nv_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
nv_library(${TARGET} SRCS ${common_srcs} DEPS ${TARGET}_part)
nv_library(${TARGET} SRCS ${common_srcs} ${selected_rows_srcs} DEPS ${TARGET}_part)
endif()
elseif (WITH_ROCM)
if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
hip_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
hip_library(${TARGET} SRCS ${common_srcs} DEPS ${TARGET}_part)
hip_library(${TARGET} SRCS ${common_srcs} ${selected_rows_srcs} DEPS ${TARGET}_part)
endif()
else()
if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0)
cc_library(${TARGET}_part SRCS ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
cc_library(${TARGET} SRCS ${common_srcs} DEPS ${TARGET}_part)
cc_library(${TARGET} SRCS ${common_srcs} ${selected_rows_srcs} DEPS ${TARGET}_part)
endif()
endif()
# If there are only specific device srcs, build target using this rule.
elseif (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0)
if (WITH_GPU)
if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
......@@ -179,25 +186,42 @@ function(kernel_library TARGET)
cc_library(${TARGET} SRCS ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
endif()
endif()
else()
if (${common_srcs_len} EQUAL 0)
message(FATAL_ERROR "Cannot find any implementation for ${TARGET}")
# If the selected_rows_srcs depends on common_srcs, build target using this rule.
elseif (${common_srcs_len} GREATER 0 AND ${selected_rows_srcs_len} GREATER 0)
if (WITH_GPU)
nv_library(${TARGET}_part SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
nv_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${TARGET}_part)
elseif (WITH_ROCM)
hip_library(${TARGET}_part SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
hip_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${TARGET}_part)
else()
# If the kernel has a device independent public implementation,
# we will use this implementation and will not adopt the implementation
# under specific devices
if (WITH_GPU)
nv_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
elseif (WITH_ROCM)
hip_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
else()
cc_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
endif()
endif()
cc_library(${TARGET}_part SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
cc_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${TARGET}_part)
endif()
# If there are only common_srcs or selected_rows_srcs, build target using below rules.
elseif (${common_srcs_len} GREATER 0)
if (WITH_GPU)
nv_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
elseif (WITH_ROCM)
hip_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
else()
cc_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
endif()
elseif (${selected_rows_srcs_len} GREATER 0)
if (WITH_GPU)
nv_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
elseif (WITH_ROCM)
hip_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
else()
cc_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
endif()
else()
message(FATAL_ERROR "Cannot find any implementation for ${TARGET}")
endif()
if (${common_srcs_len} GREATER 0 OR ${cpu_srcs_len} GREATER 0 OR
${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0)
${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR
${selected_rows_srcs_len} GREATER 0)
# append target into PTEN_KERNELS property
get_property(pten_kernels GLOBAL PROPERTY PTEN_KERNELS)
set(pten_kernels ${pten_kernels} ${TARGET})
......@@ -219,6 +243,9 @@ function(kernel_library TARGET)
if (${xpu_srcs_len} GREATER 0)
kernel_declare(${xpu_srcs})
endif()
if (${selected_rows_srcs_len} GREATER 0)
kernel_declare(${selected_rows_srcs})
endif()
endfunction()
function(register_kernels)
......
......@@ -43,34 +43,36 @@ class ScaleKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& ctx) const {
auto* in_var = ctx.InputVar("X");
auto* in = framework::GetLoDTensorOrSelectedRowsValueFromVar(*in_var);
auto bias = ctx.Attr<float>("bias");
auto bias_after_scale = ctx.Attr<bool>("bias_after_scale");
auto scale = ctx.Attr<float>("scale");
auto* out_var = ctx.OutputVar("Out");
if (ctx.HasInput("ScaleTensor")) {
auto* scale_tensor = ctx.Input<framework::Tensor>("ScaleTensor");
scale = static_cast<float>(GetAttrFromTensor<T>(scale_tensor));
}
auto* out_var = ctx.OutputVar("Out");
if (in_var->IsType<pten::SelectedRows>() && in_var != out_var) {
auto& in_slr = in_var->Get<pten::SelectedRows>();
auto* out_slr = out_var->GetMutable<pten::SelectedRows>();
out_slr->set_rows(in_slr.rows());
out_slr->set_height(in_slr.height());
}
auto* in = framework::GetLoDTensorOrSelectedRowsValueFromVar(*in_var);
auto* out =
framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(out_var);
out->mutable_data<T>(in->place());
auto& dev_ctx = ctx.device_context<DeviceContext>();
// call new kernel
pten::ScaleKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*in, scale, bias, bias_after_scale, out);
if (in_var->IsType<pten::SelectedRows>()) {
pten::ScaleSR<T>(
static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
in_var->Get<pten::SelectedRows>(), scale, bias, bias_after_scale,
out_var->GetMutable<pten::SelectedRows>());
} else {
pten::ScaleKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*in, scale, bias, bias_after_scale, out);
}
}
};
......
......@@ -74,6 +74,9 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
std::type_index(typeid(const std::vector<DenseTensor>&))) {
args_def->AppendInput(
default_key.backend(), default_tensor_layout, default_key.dtype());
} else if (arg_type == std::type_index(typeid(const SelectedRows&))) {
args_def->AppendInput(
default_key.backend(), default_tensor_layout, default_key.dtype());
} else if (arg_type == std::type_index(typeid(DenseTensor*))) {
args_def->AppendOutput(
default_key.backend(), default_tensor_layout, default_key.dtype());
......@@ -81,6 +84,9 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
std::type_index(typeid(std::vector<DenseTensor*>))) {
args_def->AppendOutput(
default_key.backend(), default_tensor_layout, default_key.dtype());
} else if (arg_type == std::type_index(typeid(SelectedRows*))) {
args_def->AppendOutput(
default_key.backend(), default_tensor_layout, default_key.dtype());
} else {
// Attribute deal with
// TODO(chenweihang): now here allow any types of attribute, maybe
......
......@@ -20,6 +20,7 @@
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_context.h"
#include "paddle/pten/core/kernel_def.h"
#include "paddle/pten/core/selected_rows.h"
#include "paddle/pten/core/sparse_coo_tensor.h"
#include "paddle/pten/core/sparse_csr_tensor.h"
......@@ -215,6 +216,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRows);
PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SparseCooTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SparseCooTensor);
......@@ -223,8 +225,6 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SparseCsrTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SparseCsrTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(SparseCsrTensor);
// TODO(chenweihang): adapt SelectedRows
// PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRowsTensor);
/* Attribute Helpers */
......@@ -244,14 +244,13 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SelectedRows);
PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SparseCooTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(SparseCooTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SparseCsrTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(SparseCsrTensor);
// TODO(chenweihang): adapt SelectedRows
// PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SelectedRowsTensor);
/* End case */
template <typename T>
......
......@@ -57,7 +57,7 @@ PT_REGISTER_KERNEL(scale,
pten::ScaleKernel,
float,
double,
paddle::platform::bfloat16,
pten::dtype::bfloat16,
uint8_t,
int8_t,
int16_t,
......
......@@ -72,7 +72,7 @@ PT_REGISTER_KERNEL(scale,
pten::ScaleKernel,
float,
double,
paddle::platform::float16,
pten::dtype::float16,
uint8_t,
int8_t,
int16_t,
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/selected_rows.h"
#include "paddle/pten/infermeta/unary.h"
#include "paddle/pten/kernels/empty_kernel.h"
namespace pten {
......@@ -28,6 +29,14 @@ void ScaleKernel(const Context& dev_ctx,
bool bias_after_scale,
DenseTensor* out);
template <typename T, typename Context>
void ScaleSR(const Context& dev_ctx,
const SelectedRows& x,
const Scalar& scale,
float bias,
bool bias_after_scale,
SelectedRows* out);
template <typename T, typename Context>
DenseTensor Scale(const Context& dev_ctx,
const DenseTensor& x,
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/kernels/scale_kernel.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/pten/common/bfloat16.h"
namespace pten {
template <typename T, typename Context>
void ScaleSR(const Context& dev_ctx,
const SelectedRows& x,
const Scalar& scale,
float bias,
bool bias_after_scale,
SelectedRows* out) {
if (x.value().data() != out->value().data()) {
out->set_rows(x.rows());
out->set_height(x.height());
}
pten::ScaleKernel<T>(
dev_ctx, x.value(), scale, bias, bias_after_scale, out->mutable_value());
}
} // namespace pten
PT_REGISTER_KERNEL(scale_sr,
CPU,
ALL_LAYOUT,
pten::ScaleSR,
float,
double,
pten::dtype::bfloat16,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_KERNEL(scale_sr,
GPU,
ALL_LAYOUT,
pten::ScaleSR,
float,
double,
pten::dtype::float16,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
#endif
......@@ -101,7 +101,7 @@ static void ScaleCPU(DataType kernel_dtype,
break;
}
case pten::DataType::BFLOAT16: {
pten::ScaleKernel<paddle::platform::bfloat16>(
pten::ScaleKernel<pten::dtype::bfloat16>(
dev_ctx, x, pten::Scalar(scale), bias, bias_after_scale, dense_out);
break;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册